TadaoYamaokaの日記

山岡忠夫 Home で公開しているプログラムの開発ネタを中心に書いていきます。

Chainerで3層パーセプトロンの誤差逆伝播を実装してみた

前回の記事で、numpyを使って行列演算で3層パーセプトロンを実装しましたが、同じことをChainerを使って実装してみます。

import numpy as np
import chainer
from chainer import Function, Variable, optimizers
from chainer import Link, Chain
import chainer.functions as F
import chainer.links as L

model = Chain(layer1=L.Linear(2, 2, initialW=np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)),
        layer2=L.Linear(2, 2, initialW=np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)))

optimizer = optimizers.SGD()
optimizer.setup(model)

# 入力値と教師データ
x_data = np.array([[1.0, 0.5]], dtype=np.float32)
t_data = np.array([0]) # 教師データ(クラス)
x = Variable(x_data)
t = Variable(t_data)

# 順伝播
u1 = model.layer1(x)
z1 = F.sigmoid(u1)
u2 = model.layer2(z1)
y = F.softmax(u2)

# 損失関数
loss = F.softmax_cross_entropy(u2, t)

# 誤差逆伝播
optimizer.zero_grads()
loss.backward()
optimizer.weight_decay(0.005)
optimizer.update()


クラス分類を行う場合は、教師データは、1次元の整数になります。
2値のクラスの場合は、0か1です。


Chainerで扱うデータ型は、すべてVariable型になります。
numpyのarrayから変換が可能です。
浮動小数点は、float32のみ対応しています。

model = Chain(layer1=L.Linear(2, 2, initialW=np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)),
        layer2=L.Linear(2, 2, initialW=np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)))

の部分で、ニューラルネットワークの層数、ユニット数、ネットワークのパラメータの初期値を定義しています。
ネットワークのパラメータの初期値を省略した場合は、デフォルトでランダムな値が使用されます。

optimizer = optimizers.SGD()

で、誤差逆伝播の勾配を計算する方法を指定しています。最も単純な、確率的勾配降下法SGD)としています。

u1 = model.layer1(x)
z1 = F.sigmoid(u1)
u2 = model.layer2(z1)
y = F.softmax(u2)

の部分で、順伝播を計算しています。活性化関数の関数として隠れ層にsigmoid、出力層にsoftmaxを使用しています。

loss = F.softmax_cross_entropy(u2, t)

で、使用する損失関数を指定しています。引数は、出力層の入力と教師データです。
出力層の出力ではないので注意してください。
教師データは、上記で説明した通りクラスを示す1次元の整数です。

loss.backward()

で勾配を計算して、

optimizer.update()

で、ネットワークのパラメータを更新しています。

optimizer.zero_grads()

で、勾配を初期化しています。

optimizer.weight_decay(0.005)

は、学習係数の減衰率を設定しています。(最新のChainerでは記述方法が変わっています。公式のリファレンスを参照してください。)



GUIで教師データをマウスクリックで与えてネットワークパラメータの変化を観測できるツールを作成してみました。
左クリックでクラス0(赤)、右クリックでクラス1(青)の教師データを入力します。

f:id:TadaoYamaoka:20160410221633p:plain


Githubでソースを公開しています。

github.com

BackpropagationChainerTest.py 3層パーセプトロン誤差逆伝播(Chainer版)
BackpropagationTest.py 3層パーセプトロン誤差逆伝播(numpy版)
BackpropagationChainer.py GUIツール(Chainer版)
Backpropagation.py GUIツール(numpy版)