TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

大規模言語モデルで将棋AIを作る その4(ResNetの特徴マップ)

前回までは、ネットワーク全体をTransformerで構成したところ、ResNetと比較して精度が上がらないという結果になった。

今回は、ResNetとTransformerを組み合わせて、初めにResNetで特徴マップを作成した後、その特徴マップを座標ごとに分割しトークンとして、Transformerに入力することを試す。

ネットワーク構成

ResNetは12ブロック256フィルタとし、Transformerはヘッド数8、feedforward256、8層とする。
Transformerの埋め込みの次元はResNetのフィルタ数となるため256、トークン数は座標の数になるため81になる。

ResNet20ブロックの後半8ブロックをTransformer8層に置き換えた形になる。

実装

ResNetの特徴マップを入力とする場合、次元は(batchsize, channels, tokens)となるため、PyTorchの標準のTransformerを使う場合、channelsとtokensの次元を交換する必要がある。
Transformerを自作する場合は、次元を交換しないで、Q,K,VのLinearを1x1の畳み込み層で実装して効率化できそうだが、今回はPyTorchの標準のTransformerで実装した。

入力層と出力層は、ResNetから変更しない。

class PolicyValueNetwork(nn.Module):
    def __init__(self, blocks, channels, activation=nn.ReLU(), fcl=256):
        ...
        # Resnet blocks
        self.blocks = nn.Sequential(*[ResNetBlock(channels, activation) for _ in range(blocks)])
        
        # Transformer
        self.pos_encoder = PositionalEncoding(channels, 81)
        transformer_layer = nn.TransformerEncoderLayer(channels, nhead=8, dim_feedforward=1024, dropout=0.1, activation="gelu", batch_first=True)
        self.transformer = nn.TransformerEncoder(transformer_layer, num_layers=8)
        ...

    def forward(self, x1, x2):
        ...
        # resnet blocks
        h = self.blocks(u1)
        
        # Transformer
        h = h.flatten(2)
        h = h.transpose(1, 2)
        h = self.pos_encoder(h)
        h = self.transformer(h)
        h = h.transpose(1, 2)
        h = h.view(h.size(0), -1, 9, 9)
        ...

結果

ResNet20ブロック256フィルタのモデルと比較した。
また、ヘッド数、feedforwardのユニット数、活性化関数は条件を変えて比較した。

データは前回と同じで、4エポック学習した。

nhead feed forward 活性化関数 val loss policy acc. value acc.
ResNet 2.5547 0.4136 0.6714
Transformer 8 256 gelu 2.5627 0.4124 0.6769
Transformer 4 512 gelu 2.5560 0.4130 0.6805
Transformer 8 256 gelu 2.5933 0.4132 0.6497
Transformer 8 256 relu 2.5529 0.4129 0.6799
Transformer 8 1024 relu 2.5538 0.4156 0.6803
Transformer 8 1024 gelu 2.5524 0.4130 0.6798
Transformer 16 1024 gelu 2.5584 0.4125 0.6752

ResNet20ブロック256フィルタより若干悪いくらいの精度になった。
feedforwardのユニット数を増やすと、ResNet20ブロック256フィルタより少し良くなった。

ヘッド数は少なくても多すぎても良くない。
活性化関数は、条件によって変わるためreluとgeluがどちらが良いとも言えない。

学習時間

4エポックの学習時間は以下の通り。
ただし、Windowsで実行しているため、Flash Attentionは有効になっていない。

nhead feed forward 活性化関数 学習時間
ResNet 1:37
Transformer 8 256 gelu 1:58
Transformer 8 512 gelu 2:04
Transformer 4 256 gelu 2:00
Transformer 8 256 relu 2:01
Transformer 8 1024 relu 2:14
Transformer 8 1024 gelu 2:14
Transformer 16 1024 gelu 2:27

ResNetよりもTransformerの学習時間は長くなっている。

まとめ

ResNetの特徴マップをTransformerの入力にすることを試した。
ResNetの後半8ブロックを、Transformer8層に置き換えたところ、同じくらいの精度になった。
feed forwardのユニット数を増やすことでResNetよりも少し精度がよくなった。
しかし、学習時間は増えるため、計算量に見合う精度にはなっていない。

次は、位置エンコーダを工夫することで精度が上げられないか試したい。