画像生成モデルは、Stable Diffusionが出てきた頃は、Diffusionモデルが主流であったが、最近のStable Diffusion 3やFlux.1では、Flow Machingのモデルが使用されている。
Flow Machingにもいくつかの訓練方法がある。
Rectified Flowは比較的シンプルでスケール可能な方法であり、画像のような高次元の分布にも適用できる。
Flow Maching
Diffusionモデルは確率微分方程式(SDE)で分布の変換を行うが、Flow Machingでは常微分方程式(ODE)で分布の変換を行う。
ODEを使うことで、確率モデルのサンプリングが不要で、高速に変換が可能になる。
Flow Machingの目的は、分布間を変換するベクトル場を学習することである。
最適輸送で、分布を変換する際のコストを最小化する手法などが提案されているが、Rectified Flowではシンプルに線形に変換を行う。
初期値(ガウスノイズなど)が与えられると、訓練済みのモデルで変換方向を表すベクトルを求めて、少ないステップで変換できる。

公式実装を動かす
動かして確認した方が理解しやすいため、公式実装を動かしてみた。
GitHub - gnobitab/RectifiedFlow: Official Implementation of Rectified Flow (ICLR2023 Spotlight)
公式実装では、DDPM++のUNetモデルを流用して、CIFAR10データセットを学習する。
環境構築
WSL2のUbuntu 22.04で構築しようとしたが、公式リポジトリが公開されたのは2022年で、使用されているCUDAのバージョンが古いため、Ubuntu 20.04でないとインストールできなかった。
そこで、WSL2のUbuntu 20.04で構築した。
Pythonの仮想環境作成
conda create -n rectflow python=3.8 conda activate rectflow
PyTorchインストール
conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch
ライブラリインストール
pip install tensorflow==2.9.0 tensorflow-probability==0.12.2 tensorflow-gan==2.0.0 tensorflow-datasets==4.6.0 numpy==1.21.6 ninja==1.11.1 matplotlib==3.7.0 ml_collections==0.1.1 jax==0.4.6 jaxlib==0.4.6 scipy==1.10.0
ビルド環境
モデルの実装で、torch.utils.cpp_extensionを使用してCUDAカーネルをビルドしている部分があるため、ビルド環境が必要になる。
sudo apt-get install build-essential ninja-build -y
CUDAインストール
ビルドにCUDAが必要になる。
wget https://developer.download.nvidia.com/compute/cuda/repos/wsl-ubuntu/x86_64/cuda-wsl-ubuntu.pin sudo mv cuda-wsl-ubuntu.pin /etc/apt/preferences.d/cuda-repository-pin-600 wget https://developer.download.nvidia.com/compute/cuda/11.3.0/local_installers/cuda-repo-wsl-ubuntu-11-3-local_11.3.0-1_amd64.deb sudo dpkg -i cuda-repo-wsl-ubuntu-11-3-local_11.3.0-1_amd64.deb sudo apt-key add /var/cuda-repo-wsl-ubuntu-11-3-local/7fa2af80.pub sudo apt-get update sudo apt-get -y install cuda
実行時に環境変数
CUDA_HOME = /usr/local/cuda
の設定が必要になる。
訓練実行
デフォルトの設定だと、
training.n_iters = 1300001 training.snapshot_freq = 50000
となっており、訓練に時間がかかるため、configs/default_cifar10_configs.pyを編集して、
training.n_iters = 20000 training.snapshot_freq = 10000
20000ステップだけ学習するようにした。
以下のコマンドで、訓練を実行する。
python main.py --config ./configs/rectified_flow/cifar10_rf_gaussian_ddpmpp.py --eval_folder eval --mode train --workdir ./logs/1_rectified_flow
結果
logs/1_rectified_flow/samplesに、スナップショットごとの画像生成サンプルが出力される。
10000ステップ時点

20000ステップ時点

多少ぼんやりしているが、物体の画像が生成できている。
まとめ
Rectified Flowを理解するために公式実装を動かしてみた。
次は、ステップ実行しながら実装の詳細を確認したい。