前回の記事で、numpyを使って行列演算で3層パーセプトロンを実装しましたが、同じことをChainerを使って実装してみます。
import numpy as np import chainer from chainer import Function, Variable, optimizers from chainer import Link, Chain import chainer.functions as F import chainer.links as L model = Chain(layer1=L.Linear(2, 2, initialW=np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)), layer2=L.Linear(2, 2, initialW=np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32))) optimizer = optimizers.SGD() optimizer.setup(model) # 入力値と教師データ x_data = np.array([[1.0, 0.5]], dtype=np.float32) t_data = np.array([0]) # 教師データ(クラス) x = Variable(x_data) t = Variable(t_data) # 順伝播 u1 = model.layer1(x) z1 = F.sigmoid(u1) u2 = model.layer2(z1) y = F.softmax(u2) # 損失関数 loss = F.softmax_cross_entropy(u2, t) # 誤差逆伝播 optimizer.zero_grads() loss.backward() optimizer.weight_decay(0.005) optimizer.update()
クラス分類を行う場合は、教師データは、1次元の整数になります。
2値のクラスの場合は、0か1です。
Chainerで扱うデータ型は、すべてVariable型になります。
numpyのarrayから変換が可能です。
浮動小数点は、float32のみ対応しています。
model = Chain(layer1=L.Linear(2, 2, initialW=np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)), layer2=L.Linear(2, 2, initialW=np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)))
の部分で、ニューラルネットワークの層数、ユニット数、ネットワークのパラメータの初期値を定義しています。
ネットワークのパラメータの初期値を省略した場合は、デフォルトでランダムな値が使用されます。
optimizer = optimizers.SGD()
で、誤差逆伝播の勾配を計算する方法を指定しています。最も単純な、確率的勾配降下法(SGD)としています。
u1 = model.layer1(x) z1 = F.sigmoid(u1) u2 = model.layer2(z1) y = F.softmax(u2)
の部分で、順伝播を計算しています。活性化関数の関数として隠れ層にsigmoid、出力層にsoftmaxを使用しています。
loss = F.softmax_cross_entropy(u2, t)
で、使用する損失関数を指定しています。引数は、出力層の入力と教師データです。
出力層の出力ではないので注意してください。
教師データは、上記で説明した通りクラスを示す1次元の整数です。
loss.backward()
で勾配を計算して、
optimizer.update()
で、ネットワークのパラメータを更新しています。
optimizer.zero_grads()
で、勾配を初期化しています。
optimizer.weight_decay(0.005)
は、学習係数の減衰率を設定しています。(最新のChainerでは記述方法が変わっています。公式のリファレンスを参照してください。)
GUIで教師データをマウスクリックで与えてネットワークパラメータの変化を観測できるツールを作成してみました。
左クリックでクラス0(赤)、右クリックでクラス1(青)の教師データを入力します。
Githubでソースを公開しています。
BackpropagationChainerTest.py | 3層パーセプトロンの誤差逆伝播(Chainer版) |
BackpropagationTest.py | 3層パーセプトロンの誤差逆伝播(numpy版) |
BackpropagationChainer.py | GUIツール(Chainer版) |
Backpropagation.py | GUIツール(numpy版) |