分類問題を機械学習で解く

今回の目的

理論とか細かいことは後回し、取り敢えず実行してみて、機械学習がどのようなものか感じをつかむ。

分類問題とは

データを複数のクラスに分類すること。

今回は、機械学習の手法の一つである教師あり学習を使って分類する。教師あり学習とは、予め用意された問題と正解の傾向を学習することで、未知の問題に対する正解を推測する手法をいう。

必要なもの

  • データセット
  • 分類器 (データを分類する機械学習モデル)
    • k-近傍法(k-NN)
    • 決定木
    • サポートベクターマシン(SVM)
    • ロジスティック回帰など

データの傾向を学習させる必要があるため、目的に合わせたデータセットを事前に用意する。分類器はライブラリとして既に実装されているものを利用する。

機械学習を試すとき、まずはデータを準備することが一つのハードルになるが、scikit-learnにはいくつかの標準的なデータセットが付属しているので、自分で用意しなくても試すことが出来る。

今回はsklearn.datasetsのload_iris(アヤメの花のデータセット)を使用する。がくや花びらの大きさとアヤメの種類がデータに含まれていて、がくや花びらの大きさからアヤメの種類を予測する。

Pythonで実装

from sklearn.datasets import load_iris

iris = load_iris()
# データの内容を確認する
iris
{'DESCR': '(データの説明は省略)',
 'data': array([[5.1, 3.5, 1.4, 0.2],
        [4.9, 3. , 1.4, 0.2],
        [4.7, 3.2, 1.3, 0.2],
        [4.6, 3.1, 1.5, 0.2],
        [5. , 3.6, 1.4, 0.2],
        (中略)
        [5.9, 3. , 5.1, 1.8]]),
 'feature_names': ['sepal length (cm)', # がくの長さ
  'sepal width (cm)',                   # がくの幅
  'petal length (cm)',                  # 花びらの長さ
  'petal width (cm)'],                  # 花びらの幅
 'filename': '/usr/local/lib/python3.6/dist-packages/sklearn/datasets/data/iris.csv',
 # アヤメの種類 0: 'setosa', 1: 'versicolor', 2: 'virginica'
 'target': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]),
 'target_names': array(['setosa', 'versicolor', 'virginica'], dtype='<U10')} 

この後のデータの解析を行いやすくするため、pandasのDataFrameに変換する。

import pandas as pd

# DataFrameに変換する
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df['target'] = iris.target_names[iris.target]

# 先頭5行を表示する
df.head()
sepal length (cm)sepal width (cm)petal length (cm)petal width (cm)target
05.13.51.40.2setosa
14.93.01.40.2setosa
24.73.21.30.2setosa
34.63.11.50.2setosa
45.03.61.40.2setosa

ざっとデータの内容を確認する。欠損値もないしそのまま使えるように既に整えられている。

df.info()

RangeIndex: 150 entries, 0 to 149
Data columns (total 5 columns):
 #   Column             Non-Null Count  Dtype  
---  ------             --------------  -----  
 0   sepal length (cm)  150 non-null    float64
 1   sepal width (cm)   150 non-null    float64
 2   petal length (cm)  150 non-null    float64
 3   petal width (cm)   150 non-null    float64
 4   target             150 non-null    object 
dtypes: float64(4), object(1)
memory usage: 6.0+ KB
df.describe()
sepal length (cm)sepal width (cm)petal length (cm)petal width (cm)
count150150150150
mean5.8433333.0573333.7581.199333
std0.8280660.4358661.7652980.762238
min4.3210.1
25%5.12.81.60.3
50%5.834.351.3
75%6.43.35.11.8
max7.94.46.92.5
df['target'].value_counts()

virginica     50
setosa        50
versicolor    50
Name: target, dtype: int64

次に説明変数と目的変数に分ける。
説明変数とは目的変数を説明する変数のこと。これをもとに予測する。
目的変数とは予測したい変数のこと。

# 説明変数
X = df.drop('target', axis=1)
# 目的変数
y = df['target']

さらに、トレーニング用データセットとテスト用データセットに分ける。未知の値について予測する性能をテストするため、テスト用のデータはトレーニングで使用してはいけない。

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

分類器にはサポートベクターマシンを使用する。scikit-learnサポートベクターマシンを実装したライブラリがあるので、それをそのまま使用する。

from sklearn.svm import SVC

svc = SVC()
# 問題と正解の傾向を学習させ、学習済みモデルを作成する
svc.fit(X_train, y_train)

学習済みモデルが作成出来たら、テストデータを使ってアヤメの種類を予測してみる。

svc.predict(X_test)

array(['virginica', 'versicolor', 'setosa', 'virginica', 'setosa',
       'virginica', 'setosa', 'versicolor', 'versicolor', 'versicolor',
       'virginica', 'versicolor', 'versicolor', 'versicolor',
       'versicolor', 'setosa', 'versicolor', 'versicolor', 'setosa',
       'setosa', 'virginica', 'versicolor', 'setosa', 'setosa',
       'virginica', 'setosa', 'setosa', 'versicolor', 'versicolor',
       'setosa', 'virginica', 'versicolor', 'setosa', 'virginica',
       'virginica', 'versicolor', 'setosa', 'virginica'], dtype=object)

どのくらい正解しているのかは下記で確認出来る。

svc.score(X_test, y_test)

0.9736842105263158

97%正しいアヤメの種類を予測出来ている。

まとめ

上記はかなり単純な例です。このように高い率で予測出来ているのは、使用しやすく既にデータが整備されていたためです。本来データには欠損値であったり、ノイズであったり、そもそも予測するための説明変数が不足していたりしていますので、まず学習で使用するデータの作成に時間がかかります。
また、分類器も今回はサポートベクターマシンを使用しましたが、最適な分類器ではないかもしれません。どの分類器が最適か探すこともあるでしょうし、そのパラメーターのチューニングも必要になるかもしれません。
これらを詰めていき正解率を上げていく作業はなかなか楽しいです。興味があればさらに詳しく調べてみてください。

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

このサイトはスパムを低減するために Akismet を使っています。コメントデータの処理方法の詳細はこちらをご覧ください