TadaoYamaokaの日記

山岡忠夫Homeで公開しているプログラムの開発ネタを中心に書いていきます。

ChainerでSENetを実装する

ILSVRC 2017で優勝したSqueeze-and-Excitation Networks (SENet)を、こちらのPyTorchの実装を参考にChainerで実装した。

GitHub - TadaoYamaoka/senet.chainer
実装したのは、SE-ResNet20/Cifar10のみ。

結果

通常のResNet
>python cifar.py --batch_size 64 --epochs 10 --baseline
epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy  elapsed_time
1           1.54999     1.5513                0.42727        0.460589                  28.2861
2           0.995298    0.95569               0.645407       0.660131                  53.5864
3           0.776467    0.86988               0.728153       0.698746                  79.6711
4           0.651746    0.68348               0.773467       0.768611                  104.735
5           0.5714      0.754005              0.80229        0.745422                  130.665
6           0.517285    0.652727              0.820042       0.777966                  157.214
7           0.472674    0.59009               0.835147       0.801453                  183.322
8           0.437737    0.579935              0.846951       0.801154                  209.395
9           0.411429    0.618022              0.854979       0.785032                  235.985
10          0.384668    0.614629              0.865217       0.797074                  261.89

SENet(reduction 16)

>python cifar.py --batch_size 64 --epochs 10 --reduction 16
epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy  elapsed_time
1           1.37516     2.41607               0.490369       0.348129                  49.4437
2           0.907583    0.860231              0.675136       0.697253                  95.2803
3           0.719768    0.791775              0.748319       0.723229                  141.569
4           0.61616     0.781935              0.785531       0.736664                  187.876
5           0.542171    0.70497               0.811881       0.759654                  234.988
6           0.489325    0.624126              0.828185       0.789112                  282.168
7           0.450093    0.674201              0.84221        0.771497                  329.322
8           0.411275    0.760015              0.855954       0.748408                  376.937
9           0.387795    0.6829                0.863091       0.785928                  424.193
10          0.363779    0.609526              0.872639       0.806131                  471.804

SENet(reduction 8)

>python cifar.py --batch_size 64 --epochs 10 --reduction 8
epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy  elapsed_time
1           1.3705      1.46585               0.494226       0.510251                  49.7231
2           0.91669     0.876904              0.671655       0.691182                  96.7188
3           0.747184    0.859343              0.737476       0.693969                  143.924
4           0.634395    0.798829              0.777629       0.733678                  191.17
5           0.566925    0.775446              0.803788       0.741043                  239.154
6           0.512038    0.688727              0.823223       0.768113                  287.381
7           0.466351    0.666799              0.838828       0.77707                   334.824
8           0.430302    0.651521              0.849532       0.77926                   381.764
9           0.407922    0.589221              0.858955       0.798268                  429.276
10          0.381586    0.627554              0.866957       0.792596                  475.764

SENet(reduction 4)

>python cifar.py --batch_size 64 --epochs 10 --reduction 4
epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy  elapsed_time
1           1.40231     1.74731               0.480858       0.462679                  49.334
2           0.919275    1.12474               0.669354       0.613455                  96.3172
3           0.724938    0.746381              0.745378       0.73756                   142.618
4           0.61876     0.680163              0.784871       0.761445                  189.007
5           0.547908    0.653891              0.808883       0.77916                   236.142
6           0.496254    0.847583              0.828265       0.720143                  282.802
7           0.4581      0.620524              0.839309       0.79578                   331.504
8           0.419443    0.582007              0.855014       0.804936                  377.842
9           0.399142    0.609093              0.861992       0.79369                   425.787
10          0.372682    0.606829              0.868938       0.796477                  473.873

Cifar10の10エポックの学習では、誤差程度の差しか現れなかった。

将棋AIでも効果があるか試す予定。