TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

2値分類で中間の値も学習する(続き)

昨日書いた2値分類で中間の値も学習するコードは、損失の計算で計算グラフを構築して、backward()時の微分はChainerに任せていた。

しかし、交差エントロピー微分は、以下のように引き算で表すことができるため、計算グラフを構築しなくてもよい。
\displaystyle
H'(p, t) = t - p

交差エントロピーの独自実装

Chainerで微分を独自実装する場合、カスタムFunctionクラスを実装する。
交差エントロピーは、以下のように実装できる(SigmoidCrossEntropyのコードを改造している)。

import numpy

from chainer import cuda
from chainer import function
from chainer.functions.activation import sigmoid
from chainer import utils
from chainer.utils import type_check

class SigmoidCrossEntropy2(function.Function):

    def __init__(self):
        pass

    def check_type_forward(self, in_types):
        type_check.expect(in_types.size() == 2)

        x_type, t_type = in_types
        type_check.expect(
            x_type.dtype == numpy.float32,
            t_type.dtype == numpy.float32,
            x_type.shape == t_type.shape
        )

    def forward(self, inputs):
        xp = cuda.get_array_module(*inputs)
        x, t = inputs

        # stable computation of the cross entropy.
        log1p_ex = xp.log1p(xp.exp(x))
        loss = t * (log1p_ex - x) + (1 - t) * log1p_ex

        self.count = len(x)

        return utils.force_array(
            xp.divide(xp.sum(loss), self.count, dtype=x.dtype)),

    def backward(self, inputs, grad_outputs):
        xp = cuda.get_array_module(*inputs)
        x, t = inputs
        gloss = grad_outputs[0]
        y, = sigmoid.Sigmoid().forward((x,))
        gx = xp.divide(
            gloss * (y - t), self.count,
            dtype=y.dtype)
        return gx, None


def sigmoid_cross_entropy2(x, t):
    return SigmoidCrossEntropy2()(x, t)

検証

昨日のコードの損失をsigmoid_cross_entropy2に変更して、MNISTデータセットを学習できるか検証した。

差分箇所
        ...
        # 損失計算
        loss = sigmoid_cross_entropy2(y, t)
        ...
        # 損失計算
        loss_test = sigmoid_cross_entropy2(y_test, t_test)
        ...
結果
GPU: 0
# unit: 1000
# Minibatch-size: 100
# epoch: 20
epoch=1, train loss=0.46885148, test loss=0.33832973
epoch=2, train loss=0.30149257, test loss=0.2579055
epoch=3, train loss=0.23578753, test loss=0.2050258
epoch=4, train loss=0.19029677, test loss=0.1703787
epoch=5, train loss=0.15914916, test loss=0.14760743
epoch=6, train loss=0.13843748, test loss=0.13353598
epoch=7, train loss=0.12384777, test loss=0.123705104
epoch=8, train loss=0.11282795, test loss=0.11471477
epoch=9, train loss=0.104268566, test loss=0.10847781
epoch=10, train loss=0.09730684, test loss=0.103607066
epoch=11, train loss=0.091508955, test loss=0.09841835
epoch=12, train loss=0.08627852, test loss=0.0952349
epoch=13, train loss=0.08196059, test loss=0.095587894
epoch=14, train loss=0.078234516, test loss=0.089423485
epoch=15, train loss=0.0742373, test loss=0.087859996
epoch=16, train loss=0.07113662, test loss=0.08602422
epoch=17, train loss=0.06793152, test loss=0.08308805
epoch=18, train loss=0.06521834, test loss=0.08113079
epoch=19, train loss=0.062456965, test loss=0.0854297
epoch=20, train loss=0.060211364, test loss=0.07905889

小さな誤差があるが、ほぼ同じ結果になった。