TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

【dlshogi】torch.compileに対応したら学習が1.6倍速くなった件

torch.compile は、PyTorch 2.0 以降で導入された高速化機能で、既存の PyTorch コードをほとんど変更せずに JIT コンパイルして最適化できる仕組みである。
主に GPU 実行時のオーバーヘッド削減やカーネル融合によって性能を向上できる。

dlshogiの train.py と ptl.py を torch.compile に対応して、どれくらい学習が速くなるか測定した。

torch.compile対応の実装

「--use_compile」オプションで、torch.compileの有効無効を切り替えられるようにした。

Pytorchのtorch.compileの引数に対応して、以下のオプションを指定できる。

「--compile_backend」で、backendを指定できる。
指定しない場合はデフォルトで、inductorが使用されるが、Windows環境では動かないため、Windowsではデフォルトを「aot_eager」にした(triton-windowsをインストールすればinductorも使える)。

「--compile_mode」で、default / reduce-overhead / max-autotune などの最適化モードを指定できる。
「max-autotune」を指定するとエラーなった。原因は調べられていない。

「--compile_fullgraph」で、モデル全体を1つのグラフとしてコンパイルすることを要求する。
有効にすると、途中で Python 処理や未対応 op によってグラフが分断される場合、エラーになりやすいです。
一方で、全体がきれいにコンパイルできるモデルでは最適化しやすくなる。

「--compile_dynamic」で、入力 shape が変わるケースに対応しやすくする設定を有効にする。
dlshogiのモデルは特に指定しなくてよい。

測定結果

測定条件
  • dlshogiの最新のResNet+Transformerモデル(20ブロック256フィルタ)
  • バッチサイズ: 4096
  • 学習率: 0.04
  • use_amp: 有効
  • amp_dtype: bfloat16
  • Ubuntu 22.04 + PyTorch 2.3
比較対象
  • no compile: torch.compileなし
  • compile: torch.compileのデフォルト設定
  • fullgraph: --compile_fullgraphを指定

ステップあたりの学習時間は以下の通り。


torch.compileを有効にすることで、学習速度が1.58倍になっている。
さらに、fullgraphを有効になると、学習速度が1.64倍になっている。

まとめ

dlshogiの train.py / ptl.py を torch.compile 対応した。
また、backend、mode、fullgraph、dynamic などの各種 compile オプションも指定可能にした。
Ubuntu 22.04 + PyTorch 2.3 環境で最新ResNet+Transformerモデルの学習時間を測定した結果、torch.compile有効時は約1.58倍、さらにfullgraph有効時は約1.64倍まで高速化した。

AMP対応以来の大きな学習速度向上となった。