前回までは、ネットワーク全体をTransformerで構成したところ、ResNetと比較して精度が上がらないという結果になった。
今回は、ResNetとTransformerを組み合わせて、初めにResNetで特徴マップを作成した後、その特徴マップを座標ごとに分割しトークンとして、Transformerに入力することを試す。
ネットワーク構成
ResNetは12ブロック256フィルタとし、Transformerはヘッド数8、feedforward256、8層とする。
Transformerの埋め込みの次元はResNetのフィルタ数となるため256、トークン数は座標の数になるため81になる。
ResNet20ブロックの後半8ブロックをTransformer8層に置き換えた形になる。
実装
ResNetの特徴マップを入力とする場合、次元は(batchsize, channels, tokens)となるため、PyTorchの標準のTransformerを使う場合、channelsとtokensの次元を交換する必要がある。
Transformerを自作する場合は、次元を交換しないで、Q,K,VのLinearを1x1の畳み込み層で実装して効率化できそうだが、今回はPyTorchの標準のTransformerで実装した。
入力層と出力層は、ResNetから変更しない。
class PolicyValueNetwork(nn.Module): def __init__(self, blocks, channels, activation=nn.ReLU(), fcl=256): ... # Resnet blocks self.blocks = nn.Sequential(*[ResNetBlock(channels, activation) for _ in range(blocks)]) # Transformer self.pos_encoder = PositionalEncoding(channels, 81) transformer_layer = nn.TransformerEncoderLayer(channels, nhead=8, dim_feedforward=1024, dropout=0.1, activation="gelu", batch_first=True) self.transformer = nn.TransformerEncoder(transformer_layer, num_layers=8) ... def forward(self, x1, x2): ... # resnet blocks h = self.blocks(u1) # Transformer h = h.flatten(2) h = h.transpose(1, 2) h = self.pos_encoder(h) h = self.transformer(h) h = h.transpose(1, 2) h = h.view(h.size(0), -1, 9, 9) ...
結果
ResNet20ブロック256フィルタのモデルと比較した。
また、ヘッド数、feedforwardのユニット数、活性化関数は条件を変えて比較した。
データは前回と同じで、4エポック学習した。
nhead | feed forward | 活性化関数 | val loss | policy acc. | value acc. | |
---|---|---|---|---|---|---|
ResNet | 2.5547 | 0.4136 | 0.6714 | |||
Transformer | 8 | 256 | gelu | 2.5627 | 0.4124 | 0.6769 |
Transformer | 4 | 512 | gelu | 2.5560 | 0.4130 | 0.6805 |
Transformer | 8 | 256 | gelu | 2.5933 | 0.4132 | 0.6497 |
Transformer | 8 | 256 | relu | 2.5529 | 0.4129 | 0.6799 |
Transformer | 8 | 1024 | relu | 2.5538 | 0.4156 | 0.6803 |
Transformer | 8 | 1024 | gelu | 2.5524 | 0.4130 | 0.6798 |
Transformer | 16 | 1024 | gelu | 2.5584 | 0.4125 | 0.6752 |
ResNet20ブロック256フィルタより若干悪いくらいの精度になった。
feedforwardのユニット数を増やすと、ResNet20ブロック256フィルタより少し良くなった。
ヘッド数は少なくても多すぎても良くない。
活性化関数は、条件によって変わるためreluとgeluがどちらが良いとも言えない。
学習時間
4エポックの学習時間は以下の通り。
ただし、Windowsで実行しているため、Flash Attentionは有効になっていない。
nhead | feed forward | 活性化関数 | 学習時間 | |
---|---|---|---|---|
ResNet | 1:37 | |||
Transformer | 8 | 256 | gelu | 1:58 |
Transformer | 8 | 512 | gelu | 2:04 |
Transformer | 4 | 256 | gelu | 2:00 |
Transformer | 8 | 256 | relu | 2:01 |
Transformer | 8 | 1024 | relu | 2:14 |
Transformer | 8 | 1024 | gelu | 2:14 |
Transformer | 16 | 1024 | gelu | 2:27 |
ResNetよりもTransformerの学習時間は長くなっている。
まとめ
ResNetの特徴マップをTransformerの入力にすることを試した。
ResNetの後半8ブロックを、Transformer8層に置き換えたところ、同じくらいの精度になった。
feed forwardのユニット数を増やすことでResNetよりも少し精度がよくなった。
しかし、学習時間は増えるため、計算量に見合う精度にはなっていない。
次は、位置エンコーダを工夫することで精度が上げられないか試したい。