TadaoYamaokaの日記

山岡忠夫Homeで公開しているプログラムの開発ネタを中心に書いていきます。

畳み込み層とBatchNormalizationのレイヤー融合をChainerで試してみた

畳み込み層のフィルタは行列で表すことができる。
BatchNormalizationも、入力の要素ごとに適用するスカラーの式だが、カーネルサイズ1×1の畳み込みで表すことができる。

推論のフェーズでは、BatchNormalizationの平均と分散は、学習時の統計情報を使うことで、固定の行列とすることができる。
つまり、推論では畳み込み層のフィルタと行列の合成(レイヤー融合)が可能である。

Chainerではレイヤー融合の機能は提供されていないため、自力で実装する必要がある。
Chainerで実装したサンプルを探したが見つからなかったので、こちらのPyTorchのサンプルを参考に実装してみた。

実装

モデルの定義

以下の通り、畳み込み層とBatchNormalizationの後に全結合層を行う、レイヤー融合の対象とするモデルを定義する。

class MyNet(Chain):

    def __init__(self, filters=64, units=64, n_out=10):
        super(MyNet, self).__init__()
        with self.init_scope():
            self.conv1 = L.Convolution2D(in_channels=3, out_channels=filters, ksize=3, stride=1)
            self.bn2 = L.BatchNormalization(filters)
            self.l3 = L.Linear(None, units)
            self.l4 = L.Linear(None, n_out)

    def forward(self, x):
        h1 = F.relu(self.bn2(self.conv1(x)))
        h2 = F.relu(self.l3(h1))
        return self.l4(h2)
レイヤー融合関数

畳み込み層とBatchNormalizationを融合して1つの畳み込み層とする関数を以下の通り実装する。

def fuse_conv_and_bn(conv, bn):
    # init
    fusedconv = L.Convolution2D(
        conv.W.shape[1],
        conv.out_channels,
        ksize=conv.ksize,
        stride=conv.stride,
        pad=conv.pad
    )

    # prepare filters
    w_conv = conv.W.data.reshape(conv.out_channels, -1)
    w_bn = np.diag(np.divide(bn.gamma.data, np.sqrt(bn.eps + bn.avg_var)))
    np.copyto(fusedconv.W.data, np.matmul(w_bn, w_conv).reshape(fusedconv.W.data.shape))

    # prepare spatial bias
    if conv.b is not None:
        b_conv = conv.b.data
    else:
        b_conv = np.zeros(conv.W.data.shape[0])
    b_bn = bn.beta.data - np.divide(np.multiply(bn.gamma.data, bn.avg_mean), np.sqrt(bn.avg_var + bn.eps))
    np.copyto(fusedconv.b.data, b_conv + b_bn)

    # we're done
    return fusedconv
レイヤー融合したモデル

上記の関数を使用して、レイヤー融合したモデルを定義する。

class MyFusedNet(Chain):

    def __init__(self, model, units=64, n_out=10):
        super(MyFusedNet, self).__init__()
        with self.init_scope():
            self.conv_fused = fuse_conv_and_bn(model.conv1, model.bn2)
            self.l3 = L.Linear(None, units, initialW=model.l3.W.data, initial_bias=model.l3.b.data)
            self.l4 = L.Linear(None, n_out, initialW=model.l4.W.data, initial_bias=model.l4.b.data)

    def forward(self, x):
        h1 = F.relu(self.conv_fused(x))
        h2 = F.relu(self.l3(h1))
        return self.l4(h2)

モデルの学習

CIFAR-10のデータセットを使用して、レイヤー融合前のモデル学習する。

train, test = chainer.datasets.get_cifar10()

batchsize = 128

train_iter = iterators.SerialIterator(train, batchsize, shuffle=False)
test_iter = iterators.SerialIterator(test, batchsize, False, shuffle=False)

gpu_id = 0

model = MyNet()
model.to_gpu(gpu_id)

max_epoch = 1

model = L.Classifier(model)

optimizer = optimizers.MomentumSGD()

optimizer.setup(model)

updater = training.updaters.StandardUpdater(train_iter, optimizer, device=gpu_id)

trainer = training.Trainer(updater, (max_epoch, 'epoch'), out='result')
trainer.extend(extensions.LogReport())
trainer.extend(extensions.Evaluator(test_iter, model, device=gpu_id))
trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'main/accuracy', 'validation/main/loss', 'validation/main/accuracy', 'elapsed_time']))

trainer.run()

学習済みモデルをレイヤー融合

学習済みモデルをレイヤー融合する。レイヤー融合前にモデルを一旦CPUに転送する必要がある。

model.to_cpu()

model_fused = MyFusedNet(model.predictor)

テスト

レイヤー融合前後のモデルで結果が同じになるかテストする。

x, t = test[0]
with chainer.using_config('train', False):
    y = model.predictor(x[None, ...])
    print(y)

with chainer.using_config('train', False):
    y = model_fused(x[None, ...])
    print(y)
結果
variable([[-0.4691602   0.63228214  0.3182091   2.3503518  -0.8040181
            1.8878069   2.7495365  -2.3511798   1.8371497  -0.57447565]])
variable([[-0.46915972  0.63228184  0.3182095   2.3503518  -0.8040179
            1.8878068   2.7495365  -2.3511798   1.8371491  -0.574476  ]])

非常に小さな誤差があるが、ほぼ同じ結果になった。

推論時間の比較

レイヤー融合前後で推論時間の比較を行った。

model.to_gpu(gpu_id)
model_fused.to_gpu(gpu_id)

train_iter.reset()
itr = 0
sum_train_accuracy = 0
start = time.time()
for i in range(0, len(train), batchsize):
    train_batch = train_iter.next()
    x_train, t_train = chainer.dataset.concat_examples(train_batch, gpu_id)
    with chainer.using_config('train', False):
        y_train = model.predictor(x_train)
    sum_train_accuracy += F.accuracy(y_train, t_train).data
    itr += 1
elapsed_time = time.time() - start
print("elapsed_time:{0}".format(elapsed_time) + "[sec]")
print(itr, sum_train_accuracy / itr)

train_iter.reset()
itr = 0
sum_train_accuracy = 0
start = time.time()
for i in range(0, len(train), batchsize):
    train_batch = train_iter.next()
    x_train, t_train = chainer.dataset.concat_examples(train_batch, gpu_id)
    with chainer.using_config('train', False):
        y_train = model_fused(x_train)
    sum_train_accuracy += F.accuracy(y_train, t_train).data
    itr += 1
elapsed_time = time.time() - start
print("elapsed_time:{0}".format(elapsed_time) + "[sec]")
print(itr, sum_train_accuracy / itr)
結果
elapsed_time:1.419318437576294[sec]
391 0.4076087
elapsed_time:1.3278264999389648[sec]
391 0.4076087

レイヤー融合後のモデルでは、推論時間が約93.5%になった。
accuracyは同じであり、精度の低下はない。
畳み込み層1層で試したが、畳み込み層が多い場合はレイヤー融合の効果は相対的に大きくなると思われる。