TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

書籍のpython-dlshogi2のバグについて 続き

先日の記事で報告した、書籍「強い将棋ソフトの創りかた」の第5章のソースコードのmake_move_labelの移動方向の判定に誤りがあった点について、GitHubソースコードを修正して、再学習したモデルファイルをプッシュしました。

make_move_labelで移動方向の判定が誤っている · Issue #5 · TadaoYamaoka/python-dlshogi2 · GitHub

checkpoints/checkpoint.pthについて

python-dlshogi2のGitHubリポジトリにあるcheckpoints/checkpoint.pthは、書籍の第7章のデータをすべて使用して、3epoch学習したチェックポイントです。

Colabではなくローカルで学習しています。
参考までに、学習に使用したシェルスクリプトは以下の通りです。Colabの場合は分割して学習しないと、データを読み込むためのメモリが不足します。

#!/bin/sh
python -m pydlshogi2.train floodgate_2019-2021_r3500-*.hcpe suisho3kai-*.hcpe dlshogi_with_gct-*.hcpe floodgate_test_2017-2018_r3500_eval5000.hcpe -b 4096 --lr 0.04 --eval_interval 1000 --log log_train_pydlshogi2-001.txt
python -m pydlshogi2.train floodgate_2019-2021_r3500-*.hcpe suisho3kai-*.hcpe dlshogi_with_gct-*.hcpe floodgate_test_2017-2018_r3500_eval5000.hcpe -b 4096 --lr 0.004 --eval_interval 1000 -r checkpoints/checkpoint-001.pth --log log_train_pydlshogi2-002.txt
python -m pydlshogi2.train floodgate_2019-2021_r3500-*.hcpe suisho3kai-*.hcpe dlshogi_with_gct-*.hcpe floodgate_test_2017-2018_r3500_eval5000.hcpe -b 4096 --lr 0.0004 --eval_interval 1000 -r checkpoints/checkpoint-002.pth --log log_train_pydlshogi2-003.txt

GPUがV100で、1epochに約18時間かかります。

48先生からプルリクを頂いているAMP対応を入れると、時間は半分くらいになると思います。
AMP対応 by bleu48 · Pull Request #4 · TadaoYamaoka/python-dlshogi2 · GitHub

3epoch学習した際のテスト精度は以下の通りです。

2021/12/28 15:15:11     INFO    epoch = 3, steps = 195555, train loss avr = 1.6887461, 0.4733937, 2.1621399, test loss = 1.6887461, 0.4733937, 2.1621399, test accuracy = 0.4764219, 0.7450078

floodgateに放流

checkpoint.pthを使用したpython-dlshogi2を、しばらくfloodgateに放流させておきます。
http://wdoor.c.u-tokyo.ac.jp/shogi/view/show-player.cgi?event=LATEST&filter=floodgate&show_self_play=1&user=python-dlshogi2
技巧2 1コアには勝てているようです。

まとめ

書籍の第5章のmake_move_labelの移動方向の判定にバグがあったため修正し、再学習したモデルをGitHubにプッシュしました。
Colabのノートブックについては、特にpython-dlshogi2のバージョン指定はしていないので、これから実行する場合は修正後のコードが反映されます。

すでにローカルでpython-dlshogi2を試されていた場合は、申し訳ありませんがgit pullを行ってください。
また、修正前のコードで学習したモデルと互換性がないため、再学習が必要になりますm(__)m。