TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

EfficientZeroを試す

NeurIPS 2021で提案されたEfficientZeroを試してみた。

EfficientZeroは、MuZeroのようなモデルベースの強化学習の手法で、サンプル効率が非常に高いことが特徴になっている。

DQNでは、5億フレーム(約38日間のリアルゲーム時間)が必要だったが、EfficientZeroでは、40万フレーム(2時間のリアルゲーム時間)で人間のパフォーマンスを上回っている。
サンプル効率は、現実世界に強化学習を適用しようとした場合に課題になるため、強化学習の応用範囲が広がることが期待できる。

オープンソースでコードが公開されているため、実際に試してみた。

環境には、BreakoutNoFrameskip-v4を使用した。

訓練

はじめ、CUDAが使えるようになったWSL2で試そうしたが、GPU1枚だと時間がかかりすぎるので、GPU(V100)8枚の環境で試した。
学習するフレーム数は、2時間のリアルゲーム時間であっても、学習時間は8GPUでも16時間かかった。
Pythonで実装されているため、MCTSのシミュレーションが遅いのが原因かもしれない。

訓練のパラメータは、用意されていたtrain.shを8GPUを使うように修正して、メモリが不足したので、object_store_memoryをデフォルトの半分にした。

$ cat train.sh
set -ex
export CUDA_DEVICE_ORDER='PCI_BUS_ID'
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7

python main.py --env BreakoutNoFrameskip-v4 --case atari --opr train --force \
  --num_gpus 8 --num_cpus 40 --cpu_actor 14 --gpu_actor 20 \
  --seed 0 \
  --use_priority \
  --use_max_priority \
  --amp_type 'torch_amp' \
  --info 'EfficientZero-V1' \
  --object_store_memory 68719476736

訓練のログは、resultsの下に出力される。

results/atari/EfficientZero-V1/BreakoutNoFrameskip-v4/seed=0/Mon Dec  6 14:40:26 2021/logs$ cat train.log
[2021-12-06 14:40:26,721][train][INFO][main.py><module>] ==> Path: /work/results/atari/EfficientZero-V1/BreakoutNoFrameskip-v4/seed=0/Mon Dec  6 14:40:26 2021
[2021-12-06 14:40:26,721][train][INFO][main.py><module>] ==> Param: {'action_space_size': 4, 'num_actors': 1, 'do_consistency': True, 'use_value_prefix': True, 'off_correction': True, 'gray_scale': False, 'auto_td_steps_ratio': 0.3, 'episode_life': True, 'change_temperature': True, 'init_zero': True, 'state_norm': False, 'clip_reward': True, 'random_start': True, 'cvt_string': True, 'image_based': True, 'max_moves': 27000, 'test_max_moves': 3000, 'history_length': 400, 'num_simulations': 50, 'discount': 0.988053892081, 'max_grad_norm': 5, 'test_interval': 10000, 'test_episodes': 32, 'value_delta_max': 0.01, 'root_dirichlet_alpha': 0.3, 'root_exploration_fraction': 0.25, 'pb_c_base': 19652, 'pb_c_init': 1.25, 'training_steps': 100000, 'last_steps': 20000, 'checkpoint_interval': 100, 'target_model_interval': 200, 'save_ckpt_interval': 10000, 'log_interval': 1000, 'vis_interval': 1000, 'start_transitions': 2000, 'total_transitions': 100000, 'transition_num': 1, 'batch_size': 256, 'num_unroll_steps': 5, 'td_steps': 5, 'frame_skip': 4, 'stacked_observations': 4, 'lstm_hidden_size': 512, 'lstm_horizon_len': 5, 'reward_loss_coeff': 1, 'value_loss_coeff': 0.25, 'policy_loss_coeff': 1, 'consistency_coeff': 2, 'device': 'cuda', 'debug': False, 'seed': 0, 'value_support': <core.config.DiscreteSupport object at 0x7f9b7e3794f0>, 'reward_support': <core.config.DiscreteSupport object at 0x7f9b7e379580>, 'weight_decay': 0.0001, 'momentum': 0.9, 'lr_warm_up': 0.01, 'lr_warm_step': 1000, 'lr_init': 0.2, 'lr_decay_rate': 0.1, 'lr_decay_steps': 100000, 'mini_infer_size': 64, 'priority_prob_alpha': 0.6, 'priority_prob_beta': 0.4, 'prioritized_replay_eps': 1e-06, 'image_channel': 3, 'proj_hid': 1024, 'proj_out': 1024, 'pred_hid': 512, 'pred_out': 1024, 'bn_mt': 0.1, 'blocks': 1, 'channels': 64, 'reduced_channels_reward': 16, 'reduced_channels_value': 16, 'reduced_channels_policy': 16, 'resnet_fc_reward_layers': [32], 'resnet_fc_value_layers': [32], 'resnet_fc_policy_layers': [32], 'downsample': True, 'env_name': 'BreakoutNoFrameskip-v4', 'obs_shape': (12, 96, 96), 'case': 'atari', 'amp_type': 'torch_amp', 'use_priority': True, 'use_max_priority': True, 'cpu_actor': 14, 'gpu_actor': 20, 'p_mcts_num': 8, 'use_root_value': False, 'auto_td_steps': 30000.0, 'use_augmentation': True, 'augmentation': ['shift', 'intensity'], 'revisit_policy_search_rate': 0.99, 'model_dir': '/work/results/atari/EfficientZero-V1/BreakoutNoFrameskip-v4/seed=0/Mon Dec  6 14:40:26 2021/model'}
[2021-12-06 14:43:36,680][train][INFO][log.py>_log] ==> #0          Total Loss: 49.282   [weighted Loss:49.282   Policy Loss: 7.685    Value Loss: 38.391   Reward Sum Loss: 31.993   Consistency Loss: 0.003    ] Replay Episodes Collected: 61         Buffer Size: 61         Transition Number: 2.012   k Batch Size: 256        Lr: 0.000
[2021-12-06 14:52:16,592][train][INFO][log.py>_log] ==> #1000       Total Loss: 0.142    [weighted Loss:0.142    Policy Loss: 8.076    Value Loss: 0.722    Reward Sum Loss: 0.326    Consistency Loss: -3.840   ] Replay Episodes Collected: 61         Buffer Size: 61         Transition Number: 2.012   k Batch Size: 256        Lr: 0.200
[2021-12-06 15:01:10,856][train][INFO][log.py>_log] ==> #2000       Total Loss: -0.803   [weighted Loss:-0.803   Policy Loss: 7.566    Value Loss: 1.042    Reward Sum Loss: 0.289    Consistency Loss: -4.619   ] Replay Episodes Collected: 61         Buffer Size: 61         Transition Number: 2.012   k Batch Size: 256        Lr: 0.200
(略)
[2021-12-07 06:21:44,635][train][INFO][log.py>_log] ==> #116000     Total Loss: -0.055   [weighted Loss:-0.055   Policy Loss: 6.586    Value Loss: 3.227    Reward Sum Loss: 0.584    Consistency Loss: -4.803   ] Replay Episodes Collected: 608        Buffer Size: 608        Transition Number: 100.193 k Batch Size: 256        Lr: 0.020
[2021-12-07 06:28:38,125][train][INFO][log.py>_log] ==> #117000     Total Loss: -0.144   [weighted Loss:-0.144   Policy Loss: 6.856    Value Loss: 3.196    Reward Sum Loss: 0.477    Consistency Loss: -4.851   ] Replay Episodes Collected: 608        Buffer Size: 608        Transition Number: 100.193 k Batch Size: 256        Lr: 0.020
[2021-12-07 06:35:18,312][train][INFO][log.py>_log] ==> #118000     Total Loss: -0.076   [weighted Loss:-0.076   Policy Loss: 6.861    Value Loss: 3.185    Reward Sum Loss: 0.514    Consistency Loss: -4.852   ] Replay Episodes Collected: 608        Buffer Size: 608        Transition Number: 100.193 k Batch Size: 256        Lr: 0.020
[2021-12-07 06:42:05,937][train][INFO][log.py>_log] ==> #119000     Total Loss: -0.055   [weighted Loss:-0.055   Policy Loss: 6.863    Value Loss: 3.213    Reward Sum Loss: 0.548    Consistency Loss: -4.847   ] Replay Episodes Collected: 608        Buffer Size: 608        Transition Number: 100.193 k Batch Size: 256        Lr: 0.020

TensorBoardでも確認できる。
f:id:TadaoYamaoka:20211207164218p:plain

スコアが上がっているのが確認できる。
f:id:TadaoYamaoka:20211207164600p:plain

テスト

訓練できたモデルを使用して、実際にゲームをプレイさせてみた。
テストは、WSL2+WSLgの環境でゲーム画面を確認しながら実行した。

WSL2のUbuntu20.04にAnacondaをインストールして、以下の通り環境構築した。

sudo apt update
sudo apt install build-essential cmake zlib1g-dev -y

conda create -n EfficientZero python=3.8
conda activate EfficientZero
pip install numpy==1.19.5 ray==1.0.0 gym==0.15.7 atari-py==0.2.6 cython==0.29.23
pip install torch==1.10.0+cu113 torchvision==0.11.1+cu113 torchaudio==0.10.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
pip install tensorboard opencv-python kornia tqdm

test.shを編集して、以下のコマンドを実行した。

python main.py --env BreakoutNoFrameskip-v4 --case atari --opr test --seed 0 --num_gpus 1 --num_cpus 20 --force --test_episodes 8 --load_model --amp_type torch_amp --model_path 'results/atari/EfficientZero-V1/BreakoutNoFrameskip-v4/seed=0/Mon Dec  6 14:40:26 2021/model.p' --info Test --render --save_video

パッケージ不足のエラーがでたので、以下のパッケージをインストールした。

sudo apt-get install ffmpeg
sudo apt-get install python-opengl


8個のウィンドウがわらわらと表示されて、テストの様子が確認できた。
引数--test_episodes 8が同時実行されるエピソードだったようである。
フレームがかくかくして、やはり遅いようである。
f:id:TadaoYamaoka:20211207165124p:plain

動画も保存される。
これは一番うまくプレイできていたエピソード。
youtu.be

まとめ

最新の強化学習手法であるEfficientZeroを試してみた。
たしかに、40万フレームの学習でブロック崩しを人間並みにプレイできるようになることが確認できた。
しかし、学習するゲーム時間は2時間分でも、実行が遅く学習時間はGPU8枚で16時間かかった。

OpenAI Gymのようなstepを操作できる環境であればよいが、リアルタイムで実行するゲームを学習・プレイするにはこの実行速度では厳しそうである。
Pythonで実装されているので、MCTS部分をC++で実装するなどの速度改善は必要そうである。