昨日書いた2値分類で中間の値も学習するコードは、損失の計算で計算グラフを構築して、backward()時の微分はChainerに任せていた。
しかし、交差エントロピーの微分は、以下のように引き算で表すことができるため、計算グラフを構築しなくてもよい。
交差エントロピーの独自実装
Chainerで微分を独自実装する場合、カスタムFunctionクラスを実装する。
交差エントロピーは、以下のように実装できる(SigmoidCrossEntropyのコードを改造している)。
import numpy from chainer import cuda from chainer import function from chainer.functions.activation import sigmoid from chainer import utils from chainer.utils import type_check class SigmoidCrossEntropy2(function.Function): def __init__(self): pass def check_type_forward(self, in_types): type_check.expect(in_types.size() == 2) x_type, t_type = in_types type_check.expect( x_type.dtype == numpy.float32, t_type.dtype == numpy.float32, x_type.shape == t_type.shape ) def forward(self, inputs): xp = cuda.get_array_module(*inputs) x, t = inputs # stable computation of the cross entropy. log1p_ex = xp.log1p(xp.exp(x)) loss = t * (log1p_ex - x) + (1 - t) * log1p_ex self.count = len(x) return utils.force_array( xp.divide(xp.sum(loss), self.count, dtype=x.dtype)), def backward(self, inputs, grad_outputs): xp = cuda.get_array_module(*inputs) x, t = inputs gloss = grad_outputs[0] y, = sigmoid.Sigmoid().forward((x,)) gx = xp.divide( gloss * (y - t), self.count, dtype=y.dtype) return gx, None def sigmoid_cross_entropy2(x, t): return SigmoidCrossEntropy2()(x, t)
検証
昨日のコードの損失をsigmoid_cross_entropy2に変更して、MNISTデータセットを学習できるか検証した。
差分箇所
... # 損失計算 loss = sigmoid_cross_entropy2(y, t) ... # 損失計算 loss_test = sigmoid_cross_entropy2(y_test, t_test) ...
結果
GPU: 0 # unit: 1000 # Minibatch-size: 100 # epoch: 20 epoch=1, train loss=0.46885148, test loss=0.33832973 epoch=2, train loss=0.30149257, test loss=0.2579055 epoch=3, train loss=0.23578753, test loss=0.2050258 epoch=4, train loss=0.19029677, test loss=0.1703787 epoch=5, train loss=0.15914916, test loss=0.14760743 epoch=6, train loss=0.13843748, test loss=0.13353598 epoch=7, train loss=0.12384777, test loss=0.123705104 epoch=8, train loss=0.11282795, test loss=0.11471477 epoch=9, train loss=0.104268566, test loss=0.10847781 epoch=10, train loss=0.09730684, test loss=0.103607066 epoch=11, train loss=0.091508955, test loss=0.09841835 epoch=12, train loss=0.08627852, test loss=0.0952349 epoch=13, train loss=0.08196059, test loss=0.095587894 epoch=14, train loss=0.078234516, test loss=0.089423485 epoch=15, train loss=0.0742373, test loss=0.087859996 epoch=16, train loss=0.07113662, test loss=0.08602422 epoch=17, train loss=0.06793152, test loss=0.08308805 epoch=18, train loss=0.06521834, test loss=0.08113079 epoch=19, train loss=0.062456965, test loss=0.0854297 epoch=20, train loss=0.060211364, test loss=0.07905889
小さな誤差があるが、ほぼ同じ結果になった。