Pythonで機械学習体験(scikit-learn)

scikit-learnとは

scikit-learnは、Pythonで機械学習を簡単に実装するための代表的なライブラリです。

※読み方はサイキットラーン

公式ドキュメント : https://scikit-learn.org/stable/index.html

Pythonで機械学習を行うときの定番中の定番です。

できること

分類 数字認識、迷惑メール判定
回帰 売上予測、価格予測
クラスタリング 顧客分類
データ分割 学習用・テスト用分離
評価 正解率計算
学習済みデータセット load_digits(手書き数字画像)... etc

今回は手書き数字画像の学習済みデータセットを使用して、手書き画像認識の機械学習をやってみます。

load_digitsは1797枚の手書きの数字データと、それの正解ラベル(0〜9)が含まれるデータセットです。

ここで重要なのは「AIは最初から数字を理解しているわけではない」という点です。

大量の「問題と答え」を見て学習しています。

サンプルスクリプト

スクリプトの仕様

  1. キャンバスを起動して数字を描く
  2. 描かれた画像データを学習済みデータセットを使用して認識する
  3. 精度が低い、または間違っている場合は、教師あり学習データとして利用できる

使用するライブラリ

  • scikit-learn
    機械学習
  • opencv-python
    画像処理・画面操作
  • numpy
    数値データ処理

スクリプト

# pip install scikit-learn opencv-python numpy

import cv2
import numpy as np
from sklearn.datasets import load_digits
from sklearn.svm import SVC

# 1. 学習データを読み込み
digits = load_digits()
X = digits.data
y = digits.target

# 2. モデルを学習
# probability=True を追加
model = SVC(gamma=0.001, probability=True)
model.fit(X, y)

# 3. お絵描きキャンバス
canvas = np.ones((400, 400), dtype=np.uint8) * 255
drawing = False

def draw(event, x, y, flags, param):
    global drawing, canvas

    if event == cv2.EVENT_LBUTTONDOWN:
        drawing = True

    elif event == cv2.EVENT_MOUSEMOVE:
        if drawing:
            cv2.circle(canvas, (x, y), 18, 0, -1)

    elif event == cv2.EVENT_LBUTTONUP:
        drawing = False

cv2.namedWindow("Draw Digit")
cv2.setMouseCallback("Draw Digit", draw)

print("マウスで数字を書いてください")
print("p: 判定 / c: クリア / q: 終了")

while True:
    display = canvas.copy()

    cv2.putText(display,
                "p: predict  c: clear  q: quit",
                (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.7,
                0,
                2)

    cv2.imshow("Draw Digit", display)

    key = cv2.waitKey(1) & 0xFF

    if key == ord("p"):

        # 白黒反転
        img = 255 - canvas

        coords = cv2.findNonZero(img)

        if coords is None:
            print("数字が書かれていません")
            continue

        x, y, w, h = cv2.boundingRect(coords)
        digit = img[y:y+h, x:x+w]

        # 正方形化
        size = max(w, h)
        square = np.zeros((size, size), dtype=np.uint8)

        x_offset = (size - w) // 2
        y_offset = (size - h) // 2

        square[y_offset:y_offset+h,
               x_offset:x_offset+w] = digit

        # 8x8へ縮小
        resized = cv2.resize(
            square,
            (8, 8),
            interpolation=cv2.INTER_AREA
        )

        # 0〜16に変換
        input_data = resized / 255 * 16
        input_data = input_data.reshape(1, -1)

        # 予測
        prediction = model.predict(input_data)[0]

        # 確率取得
        probabilities = model.predict_proba(input_data)[0]

        confidence = probabilities[prediction] * 100

        print(f"AIの予測: {prediction}")
        print(f"信頼度: {confidence:.2f}%")

        # 全確率表示
        print("各数字の確率")
        for i, prob in enumerate(probabilities):
            print(f"{i}: {prob*100:.2f}%")

        # 画面表示
        result_text = f"Prediction: {prediction}"
        conf_text = f"Confidence: {confidence:.1f}%"

        cv2.putText(display,
                    result_text,
                    (70, 180),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    1.2,
                    0,
                    3)

        cv2.putText(display,
                    conf_text,
                    (70, 240),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    1.0,
                    0,
                    2)

        cv2.imshow("Draw Digit", display)

        cv2.waitKey(2000)

    elif key == ord("c"):
        canvas[:] = 255

    elif key == ord("q"):
        break

cv2.destroyAllWindows()

実行結果

起動直後

※Warningはフォントのエラーなので無視

「2」を描いて判定

結果はあってますが、完全な「2」が描いたつもりが67%は低い気がします。

描いた「2」を教師あり学習データとして利用

再度「2」を描いて判定

94%まで上がりました。(そもそも最初より上手く描けてる気もしますが…

まとめ

機械学習とは、大量のデータから特徴を学び、未知のデータを予測する技術です。

また、機械学習においては、人間のように理解しているわけではなく、画像や文字を数値として扱い確率的に判断しているということを理解しておく必要があります。

この仕組みは、現在の生成AIや画像認識AIなど、さまざまなAI技術の基礎になっています。

コメント

このブログの人気の投稿

docker-compose up で proxyconnect tcp: dial tcp: lookup proxy.example.com: no such host

[Java] JDBCドライバでMySQL接続するまでの手順

[Azure]キーコンテナー (Azure Key Vault)