TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

DCGANを試す

ほぼ個人メモです。

今更ながら生成系のモデルをあまりさわっていなかったので、PyTorchのDCGANのTutorialを試してみた。
DCGAN Tutorial — PyTorch Tutorials 1.12.0+cu102 documentation

Windowsだとチュートリアルのソースがそのままでは動かないので、WSL上で実行した。

WSL上にMiniconda環境を構築

Windows11のWSL2にUbuntu 20.04をインストールし、Miniconda環境を構築した。
Windows11のWSL2だと、WSLgが有効になるので、チュートリアルで使われているmatplotlibによるグラフ表示ができる。
Anacondaだと、Pytorch 1.12のインストールでコンフリクトを起こすのでMinicondaとした。

sudo apt update
sudo apt wget
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
sh Miniconda3-latest-Linux-x86_64.sh

PyTorchインストール

condaでPyTorchをインストールする。
コンフリクトを起こさないように、jupyterやpandas、matplotlibも同時にインストールする。

conda install pytorch torchvision torchaudio cudatoolkit=11.6 jupyter pandas matplotlib -c pytorch -c conda-forge

チュートリアルのソースclone

sudo apt install git
git clone https://github.com/pytorch/tutorials.git

データセットダウンロード

チュートリアルに書かれている通り、Celeb-A Faces datasetを、ホームページにあるGoogleドライブのリンクからダウンロードし、チュートリアルのソースをcloneしたディレクトリのdata/celebaに展開する。

pip install gdown
cd tutorials
mkdir -p data/celeba
cd data/celeba
gdown --id 0B7EVK8r0v71pZjFTYXZWM3FlRnM

sudo apt install unzip
unzip img_align_celeba.zip

実行

チュートリアルのコードを実行する。

cd ~/tutorials
python beginner_source/dcgan_faces_tutorial.py

以下のようにログが出力される。

Random Seed:  999
Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)
Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (12): Sigmoid()
  )
)
Starting Training Loop...
[0/5][0/1583]   Loss_D: 1.9939  Loss_G: 5.4060  D(x): 0.4964    D(G(z)): 0.6326 / 0.0067
[0/5][50/1583]  Loss_D: 0.3333  Loss_G: 22.2574 D(x): 0.8377    D(G(z)): 0.0000 / 0.0000
[0/5][100/1583] Loss_D: 0.3782  Loss_G: 9.1242  D(x): 0.8275    D(G(z)): 0.0089 / 0.0010
[0/5][150/1583] Loss_D: 0.4501  Loss_G: 3.2344  D(x): 0.8090    D(G(z)): 0.1130 / 0.0577
[0/5][200/1583] Loss_D: 0.2085  Loss_G: 4.3650  D(x): 0.8801    D(G(z)): 0.0424 / 0.0236
[0/5][250/1583] Loss_D: 0.6191  Loss_G: 7.4736  D(x): 0.6742    D(G(z)): 0.0015 / 0.0023
[0/5][300/1583] Loss_D: 0.4908  Loss_G: 4.2814  D(x): 0.8352    D(G(z)): 0.2227 / 0.0239
(略)
[4/5][1200/1583]        Loss_D: 0.6867  Loss_G: 1.7331  D(x): 0.5906    D(G(z)): 0.0740 / 0.2289
[4/5][1250/1583]        Loss_D: 0.4646  Loss_G: 2.3542  D(x): 0.7931    D(G(z)): 0.1818 / 0.1189
[4/5][1300/1583]        Loss_D: 0.7274  Loss_G: 3.3515  D(x): 0.9134    D(G(z)): 0.4199 / 0.0480
[4/5][1350/1583]        Loss_D: 0.6512  Loss_G: 3.7683  D(x): 0.8772    D(G(z)): 0.3651 / 0.0332
[4/5][1400/1583]        Loss_D: 0.4803  Loss_G: 2.7035  D(x): 0.7928    D(G(z)): 0.1958 / 0.0844
[4/5][1450/1583]        Loss_D: 0.5692  Loss_G: 1.3964  D(x): 0.7153    D(G(z)): 0.1677 / 0.2903
[4/5][1500/1583]        Loss_D: 0.5118  Loss_G: 2.1602  D(x): 0.7533    D(G(z)): 0.1668 / 0.1532
[4/5][1550/1583]        Loss_D: 0.5780  Loss_G: 2.6719  D(x): 0.8204    D(G(z)): 0.2892 / 0.0883

実行の途中で、グラフが表示される。

データセットのサンプル


損失のグラフ


生成された画像


※実際はアニメーション表示

データセットの画像と生成された画像の比較