TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

ControlNetの学習を試す

ControlNetは、Stable Diffusionの出力を、スケッチや深度、ポーズなどで制御する手法で、従来のテキストによる条件付けやImage2Imageでは難しかった制御が可能になる。

ControlNetを独自のデータセットで学習をしたいと考えており、まずは公式で用意されているチュートリアルを試した。

以下の内容は、公式のチュートリアルのままなので、手順は公式を参照した方が良い。

チュートリアル

公式のチュートリアルでは、色の付いた円を描いただけの画像を使って、スケッチに着色するControlNetモデルを学習する。


"prompt": "pale golden rod circle with old lace background"

同様の画像が5万セットある。
データセットHuggingfaceにある。

環境構築

NVIDIAのDockerコンテナを使用した。
docker run --gpus all --shm-size=32g -it nvcr.io/nvidia/pytorch:22.05-py3

environment.yamlに記載されたパッケージをインストールする。
バージョンの不整合を起こすパッケージを先にアンインストールした。

pip uninstall torch torchvision torchaudio torchtext Pillow
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116
pip install torch==1.12.1 gradio==3.16.2 albumentations==1.3.0 opencv-contrib-python==4.3.0.36 imageio==2.9.0 imageio-ffmpeg==0.4.2 pytorch-lightning==1.5.0 omegaconf==2.1.1 test-tube streamlit==1.12.1 einops==0.3.0 transformers==4.19.2 webdataset==0.2.5 kornia==0.6 open_clip_torch==2.0.2 invisible-watermark streamlit-drawable-canvas==0.8.0 torchmetrics==0.6.0 timm==0.6.12 addict==2.4.0 yapf==0.32.0 prettytable==3.6.0 safetensors==0.2.7 basicsr==1.4.2 Pillow
pip uninstall tensorboard tb-nightly
pip install tensorboard

準備

Stable Diffusion 1.5のチェックポイントを元に初期チェックポイントを作成する。

python tool_add_control.py ./models/v1-5-pruned.ckpt ./models/control_sd15_ini.ckpt

4GPUで学習するため、pl.Trainerの引数をgpus=4に変更した。

trainer = pl.Trainer(gpus=4, precision=32, callbacks=[logger])

また、「if __name__ == '__main__':」がないと、データローダがエラーになるため、処理をmain関数に移動した。

訓練

訓練を開始する。

python tutorial_train.py

結果

訓練損失

2.4エポックほど学習した際の訓練損失は以下の通りとなった。

評価画像

200ステップごとに評価画像がimage_log/train/に生成される。
評価では、プロンプトとスケッチから画像を生成する。
訓練の過程で生成された画像は、以下のようになった。

訓練開始時

1エポック

2エポック

2.4エポック

訓練開始時は、Stable Diffusionの出力そのままになっている。
1エポック学習した時点で、ほぼ円を単色で塗った画像になっている。
2エポック学習した画像では、模様のようなものが出力されており、まだ十分に学習されていない。
2.4エポック学習した画像では、円と背景がほぼ単色で塗れているが、エッジがぼやけた個所が残っている。

学習時間

V100 4GPUで、バッチサイズ4で学習した場合、1エポック(3125step)に約1時間50分かかった。

まとめ

ControlNetの学習を公式のチュートリアルで試した。
円に色を塗るだけの単純なデータセットで、数エポックである程度学習できることが確認できた。
エッジがぼやけた個所が残ったため、どれくらいで鮮明になるか別途確認したい。

ControlNetの学習方法がわかったので、次は独自のデータセットで学習を試したい。