TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

Rectified Flowで画像生成する

画像生成モデルは、Stable Diffusionが出てきた頃は、Diffusionモデルが主流であったが、最近のStable Diffusion 3Flux.1では、Flow Machingのモデルが使用されている。

Flow Machingにもいくつかの訓練方法がある。
Rectified Flowは比較的シンプルでスケール可能な方法であり、画像のような高次元の分布にも適用できる。

Flow Maching

Diffusionモデルは確率微分方程式(SDE)で分布の変換を行うが、Flow Machingでは常微分方程式(ODE)で分布の変換を行う。
ODEを使うことで、確率モデルのサンプリングが不要で、高速に変換が可能になる。
Flow Machingの目的は、分布間を変換するベクトル場を学習することである。

最適輸送で、分布を変換する際のコストを最小化する手法などが提案されているが、Rectified Flowではシンプルに線形に変換を行う。
初期値(ガウスノイズなど)が与えられると、訓練済みのモデルで変換方向を表すベクトルを求めて、少ないステップで変換できる。


公式実装を動かす

動かして確認した方が理解しやすいため、公式実装を動かしてみた。

GitHub - gnobitab/RectifiedFlow: Official Implementation of Rectified Flow (ICLR2023 Spotlight)

公式実装では、DDPM++のUNetモデルを流用して、CIFAR10データセットを学習する。

環境構築

WSL2のUbuntu 22.04で構築しようとしたが、公式リポジトリが公開されたのは2022年で、使用されているCUDAのバージョンが古いため、Ubuntu 20.04でないとインストールできなかった。
そこで、WSL2のUbuntu 20.04で構築した。

Pythonの仮想環境作成
conda create -n rectflow python=3.8
conda activate rectflow
PyTorchインストール
conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch
ライブラリインストール
pip install tensorflow==2.9.0 tensorflow-probability==0.12.2 tensorflow-gan==2.0.0 tensorflow-datasets==4.6.0 numpy==1.21.6 ninja==1.11.1 matplotlib==3.7.0 ml_collections==0.1.1 jax==0.4.6 jaxlib==0.4.6 scipy==1.10.0
ビルド環境

モデルの実装で、torch.utils.cpp_extensionを使用してCUDAカーネルをビルドしている部分があるため、ビルド環境が必要になる。

sudo apt-get install build-essential ninja-build -y
CUDAインストール

ビルドにCUDAが必要になる。

wget https://developer.download.nvidia.com/compute/cuda/repos/wsl-ubuntu/x86_64/cuda-wsl-ubuntu.pin
sudo mv cuda-wsl-ubuntu.pin /etc/apt/preferences.d/cuda-repository-pin-600
wget https://developer.download.nvidia.com/compute/cuda/11.3.0/local_installers/cuda-repo-wsl-ubuntu-11-3-local_11.3.0-1_amd64.deb
sudo dpkg -i cuda-repo-wsl-ubuntu-11-3-local_11.3.0-1_amd64.deb
sudo apt-key add /var/cuda-repo-wsl-ubuntu-11-3-local/7fa2af80.pub
sudo apt-get update
sudo apt-get -y install cuda


実行時に環境変数

CUDA_HOME = /usr/local/cuda

の設定が必要になる。

訓練実行

デフォルトの設定だと、

training.n_iters = 1300001
training.snapshot_freq = 50000

となっており、訓練に時間がかかるため、configs/default_cifar10_configs.pyを編集して、

training.n_iters = 20000
training.snapshot_freq = 10000

20000ステップだけ学習するようにした。

以下のコマンドで、訓練を実行する。

python main.py --config ./configs/rectified_flow/cifar10_rf_gaussian_ddpmpp.py --eval_folder eval --mode train --workdir ./logs/1_rectified_flow

結果

logs/1_rectified_flow/samplesに、スナップショットごとの画像生成サンプルが出力される。

10000ステップ時点

20000ステップ時点

多少ぼんやりしているが、物体の画像が生成できている。

まとめ

Rectified Flowを理解するために公式実装を動かしてみた。
次は、ステップ実行しながら実装の詳細を確認したい。