TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

Hugging Face TrainerでMNISTを学習

PyTorchで画像分類モデルを学習するとき、学習ループやチェックポイント管理、TensorBoard対応などを毎回自前で実装するのはやや煩雑である。

Hugging Face Transformers の Trainer を使うことで、NLP用途だけでなくCNNのような画像モデルでも、シンプルかつ統一的な学習基盤を構築できる。

本記事では、MNISTを題材として、

  • CNNモデルの実装
  • Hugging Face Trainer を使った学習
  • 分散学習 (torchrun)
  • Resume training
  • TensorBoard ログ管理
  • Optimizer / Scheduler 設定

について紹介する。

ソース

github.com

特徴

  • jsonargparseで、YAML形式で設定管理
  • PyTorch Lightningと同様のインクリメンタルなバージョン管理
  • マルチGPU対応

学習実行方法

基本的な学習は以下で実行できる。

python train.py --config config.yaml

これだけで、

  • モデル保存
  • checkpoint保存
  • TensorBoardログ
  • evaluation
  • scheduler
  • optimizer

などが Trainer によって自動管理される。

分散学習 (torchrun)

GPUを複数枚使う場合は torchrun を利用する。

torchrun --nproc_per_node 2 train.py --config config.yaml

Trainer は DistributedDataParallel (DDP) に自然対応しているため、追加実装はほぼ不要である。

PyTorch Lightningに近い手軽さで分散学習を扱える。

出力ディレクトリ管理

デフォルト動作

training.output_dir を設定しない場合、デフォルトでは logs/ 以下に version 管理される。

例:

logs/version_0
logs/version_1
logs/version_2

実験ルート変更

CLIから experiment root を変更することも可能である。

python train.py --config config.yaml --experiment_root runs

すると以下のように生成される。

runs/version_0
runs/version_1

output_dir を直接指定

もし training.output_dir を設定した場合は、そのディレクトリを直接使用する。

python train.py \
  --config config.yaml \
  --training.output_dir runs/manual_experiment

この場合は version_* ディレクトリは作成されない。

Resume Training

学習途中から再開する場合は checkpoint を指定する。

python train.py \
  --config config.yaml \
  --training.resume_from_checkpoint logs/version_0/checkpoint-844

TensorBoard

TensorBoardでログを可視化できる。

tensorboard --logdir logs

Optimizer / Scheduler 設定

設定は config.yamltraining セクションで管理する。

例えば以下のように記述する。

training:
  learning_rate: 1e-3
  weight_decay: 0.01
  optim: adamw_torch
  lr_scheduler_type: cosine
  warmup_steps: 100

内部的には Trainer が optimizer / scheduler を生成する。

まとめ

今回紹介した構成では、

  • CNN
  • MNIST
  • Hugging Face Trainer
  • 分散学習
  • TensorBoard
  • Resume
  • Config管理

を非常に少ないコード量で実現できる。

Transformer以外のモデルでも Trainer を積極的に利用する価値は十分にある。