TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

PaintsChainerをPyTorchで学習

Stable Diffusionが公開されてから、拡散モデルで自動着色したいと思っていて、自動着色についていろいろ調べていた。

最近、Style2Paints V5 Previewが発表されて、拡散モデルによる自動着色で高いクオリティの自動着色が実現できることが示された。

ControlNet

Style2Paints V5は、ControlNetという方法が使われており、つい先日、論文ソースが公開された。
線画の着色の学習済みモデルも公開されているが、アニメイラストの自動着色のモデルは、権利の関係で公開されていない。
また、公開されたモデルでは、元のイラストの線が少し書き換えられることがあり、生成画像の質は高いが着色用途には問題がある。
イラストの着色には、元の線を保持する特化したモデルが必要である。

従来の着色技術

拡散モデル以前は、GANによる手法が主流であったが、GANによる手法は以下のような問題があった。

  • 水彩画のようなにじみのある色になる
  • 色の鮮やかさが足りない
  • 色が線からはみ出す
  • 色の間違いがある
  • アーティファクトが表れる

PaintsChainerが最初期のモデルで、それ以降も改良手法が提案されている。

PaintsChainerの学習を試す

拡散モデルを試す前に、GANによる自動着色を試してみた。
PaintsChainer V1/V2は学習のソースコードも公開されており、試すことができる。
PaintsChainerはChainerで実装されているが、Chainerは既に開発が終了しているため、PyTorchにポーティングを行った。
GitHub - TadaoYamaoka/PaintsPyTorch: PaintsChainer training with PyTorch

PaintsChainerの学習

PaintsChainerの学習は、3段階に分かれている。

  1. カラー画像からスケッチ画像を予測するモデル(lnet)の学習
  2. GANによる128x128の画像の着色モデル
  3. UNetによる128x128の画像から512x512の画像への超解像度モデル
データセット

danbooru2017の512x512の画像を使用した。
グレースケールの画像も含まれるため、以下のような処理で事前に除外した。

    img = cv2.imread(path)
    b, g, r = cv2.split(img)
    if np.abs(b - g).mean() + np.abs(g - r).mean() + np.abs(r - b).mean() >= 100:
        files.append(path[len(dir):] + '\n')
スケッチ画像

スケッチ画像は、PaintsChainerでは、dilate(膨張)した画像と元画像の差分をとることで作成しているが、sketchKerasを使用して作成した。

縮小画像

lnetとGANモデルの学習には、128x128の画像が必要になるため、512x512画像をリサイズした画像を作成した。
その際、cv2.resizeのデフォルトパラメータで変換すると、スケッチの線が途切れるため、cv2.INTER_AREAを使用した。

lnetの学習

lnetは、GANのGeneratorが生成した画像をスケッチに変換して元のスケッチを再現できるかの損失関数のために使用される。
これがあることによって、生成画像に元のスケッチの線を残すことができる。
PaintsChainerでは、線の一致には強いペナルティが与えられている。

lnetは、以前は公式サイトで、学習済みモデルが配布されていたが、現在ダウンロードできなくなっているため、自分で学習した。
lnetは、学習用のコードがなかったため、自分で実装した。
PaintsPyTorch/train_lnet.py at main · TadaoYamaoka/PaintsPyTorch · GitHub

カラー画像とスケッチ画像のペアの平均絶対誤差(MAE)が最小になるように学習する。

GANによる着色モデル学習

128x128のカラー画像とスケッチ画像のペアから、GANにより着色モデルを学習する。
上記のlnetの箇所で説明した通り、通常のGANの損失に加えて、生成画像から元のスケッチが生成できるかも損失に加えられている。

PaintsChainerでは、訓練中の処理で画像のチャンネルを編集する処理があったが、実行が遅くなるためデータローダ側で処理するようにした。
PaintsPyTorch/train_128.py at main · TadaoYamaoka/PaintsPyTorch · GitHub

40エポックほど学習した時点の損失は以下のようになった。

着色画像の例

パーツの塗り分けは学習できているが、にじみが多く質は高くない。
これは比較的うまく塗れている例だが、書き込みが細かいとまったく着色ができていない。

生成例を確認していて気づいたが、元のデータセットには、四コマ漫画のようなイラスト以外の画像も含まれていたため、学習の妨げになっていそうである。

学習の途中で生成例を確認していると、色の間違いは多いがひとまず塗れていたものが、白っぽく塗るようになり、再び色が着くようになるという変化が見られて、GeneratorとDiscriminatorが競っている様子が観察できた。

超解像度モデル

超解像度モデルでは、512x512のスケッチ画像に、着色モデルで生成した128x128の画像を拡大して512x512にした画像をチャネル方向に連結したものを入力として、元のカラー画像を予測する。
GANは使用しておらず、単純なUNetのモデルになっている。

PaintsChainerでは、訓練中の処理で、画像の拡大を一旦CPUに転送してOpenCV2で拡大した後にGPUに戻す処理があったが、実行が遅くなるためTorchVisionを使用してGPU側で処理するようにした。
PaintsPyTorch/train_x2.py at main · TadaoYamaoka/PaintsPyTorch · GitHub

着色の質があまり良くない状態なので、超解像度モデルの学習はまだ試していない。

まとめ

GANによる自動着色を、PaintsChainerのソースをPyTorchにポーティングして試してみた。
PaintsChainerが元の線を残すために、通常のGANにない損失を加えていることなどがわかり、理解を深めることができた。
また、danbooruデータセットをそのまま使うと、イラスト以外の画像が含まれており、学習の妨げになりそうなことが分かった。
WaifuDiffusionでも行われているようにCLIP Aesthetic scoreによるフィルタリングは必要そうである。
データセットをクリーニングしたもので質がどう変わるか試してみたい。

続き
PaintsChainerをPyTorchで学習 その2 - TadaoYamaokaの開発日記