ここでは混合ガウスモデル(GMM:Gaussian Mixture Models)と呼ぶクラスタリングモデルを学び、手描き数字のサンプルから新たな手描き数字を作り出すというプログラムを紹介します。
まずscikit-learnモジュールに手描き数字のサンプルがあるため、それをロードします。以下のような手書きの数字になっています。
この手書き数字の次元削減を行うために、データを可逆次元削減アルゴリズムに通します。ここでは簡単なPCA(主成分分析)を使用して、分散の99%を保存する程度までデータの次元を削減します。そしてPCAオブジェクトの逆変換を使用して、新しい数字を作成します。コードは以下の通りです。
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import load_digits
from sklearn.mixture import GaussianMixture
from sklearn.decomposition import PCA
digits=load_digits()
pca=PCA(0.99,whiten=True)
data=pca.fit_transform(digits.data)
n_components=np.arange(50,210,10)
models=[GaussianMixture(n,covariance_type='full',random_state=0) for n in n_components]
aics=[model.fit(data).aic(data) for model in models]
gmm=GaussianMixture(150,covariance_type='full',random_state=0)
gmm.fit(data)
data_new,label_new=gmm.sample(100)
digits_new=pca.inverse_transform(data_new)
def plot_digits(data):
fig,ax=plt.subplots(5,10,figsize=(8,4),subplot_kw=dict(xticks=[],yticks=[]))
fig.subplots_adjust(hspace=0.05,wspace=0.05)
for i,axi in enumerate(ax.flat):
im=axi.imshow(data[i].reshape(8,8),cmap='binary')
im.set_clim(0,16)
plot_digits(digits_new)
plt.grid()
plt.show()
このコードを実行すると、以下のような新しいランダムな数字ができあがります。
最初に示した数字のサンプルと同じように、素晴らしい手書き数字ができあがっていることがわかります。