PyTorch Lightningに対応させたdlshogiで、並列(DDP)で学習するといくつか問題が発生したため、対処した。
保存したモデルが壊れる
on_train_endで、モデルを保存していたが、マルチGPUで実行している場合、並列処理用の各プロセスでon_train_endが実行されるため、ファイル書き込みが競合して保存したファイルが壊れる問題があった。
そのため、on_fit_endで保存するように変更した。
EMAのupdate_bn
EMAのupdate_bnを有効にしている場合、update_bnに渡すデータローダを新たに作成しているため、並列処理用の各プロセスで別々のデータローダが作成されて、並列処理されない。
update_bnも並列処理するには、逆伝播を無効にしてfitを再実行する必要がある。
PyTorch Lightningで、fit完了後に条件を変えてfitを再実行するようなことはできないため、update_bn用のスクリプトを別に用意した。
DeepLearningShogi/dlshogi/ptl_update_bn.py at master · TadaoYamaoka/DeepLearningShogi · GitHub
「--ckpt_path」に最終エポックのチェックポイントを指定して、「--trainer.max_epochs」を+1した値にして、「--trainer.sync_batchnorm」をtrueにして実行する。
まとめ
PyTorch Lightningに対応させたdlshogiをDDPで実行した場合に問題が起きたので対応した。
EMAは、PyTorch LightningでDDPに対応して欲しいところだがissueはあるものの対応は進んでいないようだ。