TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

dlshogiモデルの20ブロックから15ブロックへの知識蒸留

dlshogiでは、10ブロックのモデルから始めて、15ブロック、20ブロックとモデルサイズを大きくしている。
ブロックが大きいほど、精度が高くなっており強さにも反映される。
第32回世界コンピュータ将棋選手権のdlshogiでは20ブロックのモデルを採用している。

ブロック数を増やすと、NPSが低下するため、精度向上とNPSはトレードオフになっている。
ブロック数を増やすことによる精度向上分が上回っているため、強くなっていると考える。
しかし、NPSが低くなると探索が浅くなるため、探索が重要になる局面では、ブロック数が小さい方が有利になる可能性がある。

小さなブロック数で精度が高いモデルが理想であるため、より大きなブロック数で質の高い教師を生成して、小さなブロックを学習するということを試みたい。

知識蒸留

単に、より大きなブロック数で生成した教師を小さなブロックで学習するだけでも、小さなブロックの精度を上げることはできそうだが、知識蒸留といういう方法で、より効果的に学習できると考える。

知識蒸留は、大きなモデルから小さなモデルへ知識を転送する手法で、小さなモデルの教師あり学習の損失と、大きなモデルが出力する分布とのKLダイバージェンスの加重平均をとるのが、代表的な方法である。

z_tを大きなモデル(教師モデル)の出力、z_sを小さなモデル(生徒モデル)の出力とすると、損失関数は以下の通りになる。

\displaystyle
L=(1 - \alpha) L_{cross\_entropy} + \alpha D_{KL}(z_t, z_s)
ここで、\alphaは加重平均の重み定数である。

モデルの出力に温度を適用する場合があるが、今回は温度なしとした。

実験

dlshogiの20ブロックを教師モデル、15ブロックを生徒モデルとして、20ブロックで生成した9.4千万局面(同一局面を平均化すると7千万局面)を学習し、知識蒸留ありとなしを比較する。

自己対局では訪問回数の分布を生成して学習しているため、知識蒸留はしなくてもよいのではという気もしているので、実験で確かめたい。

実装

知識蒸留ありの場合、訓練スクリプト(train.py)に、以下の通りKLDを計算して加重平均をとる処理を追加した。

if args.distillation_model:
    with torch.no_grad():
        distillation_y1, distillation_y2 = distillation_model(x1, x2)
    distillation_loss1 = F.kl_div(F.log_softmax(y1, dim=1), F.softmax(distillation_y1, dim=1), reduction='batchmean')
    distillation_p2 = distillation_y2.sigmoid()
    distillation_loss2 = (distillation_p2 * (F.logsigmoid(distillation_y2) - F.logsigmoid(y2)) + (1 - distillation_p2) * (-F.softplus(distillation_y2) + F.softplus(y2))).mean()
    loss = (1 - args.distillation_alpha) * loss + args.distillation_alpha * (distillation_loss1 + distillation_loss2)

価値の損失は、2値分類のため、PyTorchのkl_divが使えず、定義通り計算している。
その際、log(1-p)=-softplus(y)の関係を利用している(yはロジット)。

ただし、教師となる大きなモデルの出力は固定となるため、KLダイバージェンスは、実装上は交差エントロピーに置き換えても問題ないはずである。
今回は定義通り実装した。
交差エントロピーを使って実装した場合も、別途実験したい。

実験条件
  • 加重平均の定数\alpha=0.5
  • バッチサイズ 4096
  • 学習率 0.04、1エポックごと半減
  • エポック数 4
  • 平均化あり、評価値補正あり
実験結果

初期値によって、結果が変わるため、4回測定して平均をとった。
テストデータにfloodgateのR3500以上の棋譜からサンプリングした856,923局面(重複なし)を使用して評価した。

知識蒸留 方策損失 価値損失 方策正解率 価値正解率 方策エントロピー 価値エントロピー
なし 1.6143 0.510782 0.478344 0.729586 1.53176 0.539728
あり 1.56883 0.49999 0.488682 0.738732 1.47312 0.537952

知識蒸留ありの方が、方策、価値ともに精度が高くなった。
正解率は、方策、価値ともに、約1%高くなっている。

学習時間

知識蒸留を行うと、教師モデルの推論の分、学習時間が増える。
学習時間は、以下の通りであった。

知識蒸留 1エポックあたり平均
なし 1:33:19
あり 2:16:59

学習時間は、約47%増加した。

まとめ

知識蒸留により、20ブロックのモデルを教師として、15ブロックのモデルを学習した。
20ブロックで生成した教師を単に教師ありで学習するよりも、知識蒸留を行った方が精度が高くなることが確かめられた。

自己対局で生成した訪問回数の分布を学習するため、知識蒸留は効果はない気もしていたが、訪問回数の分布を学習する場合にも効果があることがわかった。

今回は、加重平均の定数0.5で実験したが、他の条件でも試したい。

また、知識蒸留を行った15ブロックのモデルが、同じデータで学習した20ブロックのモデルよりも強くなるか別途検証したい。