torch.compile は、PyTorch 2.0 以降で導入された高速化機能で、既存の PyTorch コードをほとんど変更せずに JIT コンパイルして最適化できる仕組みである。
主に GPU 実行時のオーバーヘッド削減やカーネル融合によって性能を向上できる。
dlshogiの train.py と ptl.py を torch.compile に対応して、どれくらい学習が速くなるか測定した。
torch.compile対応の実装
「--use_compile」オプションで、torch.compileの有効無効を切り替えられるようにした。
Pytorchのtorch.compileの引数に対応して、以下のオプションを指定できる。
「--compile_backend」で、backendを指定できる。
指定しない場合はデフォルトで、inductorが使用されるが、Windows環境では動かないため、Windowsではデフォルトを「aot_eager」にした(triton-windowsをインストールすればinductorも使える)。
「--compile_mode」で、default / reduce-overhead / max-autotune などの最適化モードを指定できる。
「max-autotune」を指定するとエラーなった。原因は調べられていない。
「--compile_fullgraph」で、モデル全体を1つのグラフとしてコンパイルすることを要求する。
有効にすると、途中で Python 処理や未対応 op によってグラフが分断される場合、エラーになりやすいです。
一方で、全体がきれいにコンパイルできるモデルでは最適化しやすくなる。
「--compile_dynamic」で、入力 shape が変わるケースに対応しやすくする設定を有効にする。
dlshogiのモデルは特に指定しなくてよい。
測定結果
測定条件
- dlshogiの最新のResNet+Transformerモデル(20ブロック256フィルタ)
- バッチサイズ: 4096
- 学習率: 0.04
- use_amp: 有効
- amp_dtype: bfloat16
- Ubuntu 22.04 + PyTorch 2.3
比較対象
- no compile: torch.compileなし
- compile: torch.compileのデフォルト設定
- fullgraph: --compile_fullgraphを指定
ステップあたりの学習時間は以下の通り。

torch.compileを有効にすることで、学習速度が1.58倍になっている。
さらに、fullgraphを有効になると、学習速度が1.64倍になっている。
まとめ
dlshogiの train.py / ptl.py を torch.compile 対応した。
また、backend、mode、fullgraph、dynamic などの各種 compile オプションも指定可能にした。
Ubuntu 22.04 + PyTorch 2.3 環境で最新ResNet+Transformerモデルの学習時間を測定した結果、torch.compile有効時は約1.58倍、さらにfullgraph有効時は約1.64倍まで高速化した。
AMP対応以来の大きな学習速度向上となった。