Pythonで機械学習体験(scikit-learn)
scikit-learnとは
scikit-learnは、Pythonで機械学習を簡単に実装するための代表的なライブラリです。
※読み方はサイキットラーン
公式ドキュメント : https://scikit-learn.org/stable/index.html
Pythonで機械学習を行うときの定番中の定番です。
できること
| 分類 | 数字認識、迷惑メール判定 |
|---|---|
| 回帰 | 売上予測、価格予測 |
| クラスタリング | 顧客分類 |
| データ分割 | 学習用・テスト用分離 |
| 評価 | 正解率計算 |
| 学習済みデータセット | load_digits(手書き数字画像)... etc |
今回は手書き数字画像の学習済みデータセットを使用して、手書き画像認識の機械学習をやってみます。
load_digitsは1797枚の手書きの数字データと、それの正解ラベル(0〜9)が含まれるデータセットです。
ここで重要なのは「AIは最初から数字を理解しているわけではない」という点です。
大量の「問題と答え」を見て学習しています。
サンプルスクリプト
スクリプトの仕様
- キャンバスを起動して数字を描く
- 描かれた画像データを学習済みデータセットを使用して認識する
- 精度が低い、または間違っている場合は、教師あり学習データとして利用できる
使用するライブラリ
- scikit-learn
機械学習 - opencv-python
画像処理・画面操作 - numpy
数値データ処理
スクリプト
# pip install scikit-learn opencv-python numpy
import cv2
import numpy as np
from sklearn.datasets import load_digits
from sklearn.svm import SVC
# 1. 学習データを読み込み
digits = load_digits()
X = digits.data
y = digits.target
# 2. モデルを学習
# probability=True を追加
model = SVC(gamma=0.001, probability=True)
model.fit(X, y)
# 3. お絵描きキャンバス
canvas = np.ones((400, 400), dtype=np.uint8) * 255
drawing = False
def draw(event, x, y, flags, param):
global drawing, canvas
if event == cv2.EVENT_LBUTTONDOWN:
drawing = True
elif event == cv2.EVENT_MOUSEMOVE:
if drawing:
cv2.circle(canvas, (x, y), 18, 0, -1)
elif event == cv2.EVENT_LBUTTONUP:
drawing = False
cv2.namedWindow("Draw Digit")
cv2.setMouseCallback("Draw Digit", draw)
print("マウスで数字を書いてください")
print("p: 判定 / c: クリア / q: 終了")
while True:
display = canvas.copy()
cv2.putText(display,
"p: predict c: clear q: quit",
(10, 30),
cv2.FONT_HERSHEY_SIMPLEX,
0.7,
0,
2)
cv2.imshow("Draw Digit", display)
key = cv2.waitKey(1) & 0xFF
if key == ord("p"):
# 白黒反転
img = 255 - canvas
coords = cv2.findNonZero(img)
if coords is None:
print("数字が書かれていません")
continue
x, y, w, h = cv2.boundingRect(coords)
digit = img[y:y+h, x:x+w]
# 正方形化
size = max(w, h)
square = np.zeros((size, size), dtype=np.uint8)
x_offset = (size - w) // 2
y_offset = (size - h) // 2
square[y_offset:y_offset+h,
x_offset:x_offset+w] = digit
# 8x8へ縮小
resized = cv2.resize(
square,
(8, 8),
interpolation=cv2.INTER_AREA
)
# 0〜16に変換
input_data = resized / 255 * 16
input_data = input_data.reshape(1, -1)
# 予測
prediction = model.predict(input_data)[0]
# 確率取得
probabilities = model.predict_proba(input_data)[0]
confidence = probabilities[prediction] * 100
print(f"AIの予測: {prediction}")
print(f"信頼度: {confidence:.2f}%")
# 全確率表示
print("各数字の確率")
for i, prob in enumerate(probabilities):
print(f"{i}: {prob*100:.2f}%")
# 画面表示
result_text = f"Prediction: {prediction}"
conf_text = f"Confidence: {confidence:.1f}%"
cv2.putText(display,
result_text,
(70, 180),
cv2.FONT_HERSHEY_SIMPLEX,
1.2,
0,
3)
cv2.putText(display,
conf_text,
(70, 240),
cv2.FONT_HERSHEY_SIMPLEX,
1.0,
0,
2)
cv2.imshow("Draw Digit", display)
cv2.waitKey(2000)
elif key == ord("c"):
canvas[:] = 255
elif key == ord("q"):
break
cv2.destroyAllWindows()
実行結果
起動直後
※Warningはフォントのエラーなので無視
「2」を描いて判定
結果はあってますが、完全な「2」が描いたつもりが67%は低い気がします。
描いた「2」を教師あり学習データとして利用
再度「2」を描いて判定
94%まで上がりました。(そもそも最初より上手く描けてる気もしますが…
まとめ
機械学習とは、大量のデータから特徴を学び、未知のデータを予測する技術です。
また、機械学習においては、人間のように理解しているわけではなく、画像や文字を数値として扱い確率的に判断しているということを理解しておく必要があります。
この仕組みは、現在の生成AIや画像認識AIなど、さまざまなAI技術の基礎になっています。




コメント
コメントを投稿