前回、Gumbel AlphaZeroの論文の概要と、公式実装の環境構築について記載した。
今回は、公式実装のサンプルプログラム examples/visualization_demo.py のソースを調べながらGumbel AlphaZeroのアルゴリズムを理解する。
visualization_demo.pyの概要
visualization_demo.py は、環境の状態遷移と報酬を表形式で定義して、Gumbel AlphaZeroのアルゴリズムで探索を行い、選択した行動と行動価値を出力する。
また、探索結果を木構造で可視化する。
出力例:
Selected action: 1 Selected action Q-value: 10.666667
環境
状態遷移と報酬は以下のように定義されている。
# We will define a deterministic toy environment. # The deterministic `transition_matrix` has shape `[num_states, num_actions]`. # The `transition_matrix[s, a]` holds the next state. transition_matrix = jnp.array([ [1, 2, 3, 4], [0, 5, 0, 0], [0, 0, 0, 6], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], ], dtype=jnp.int32) # The `rewards` have shape `[num_states, num_actions]`. The `rewards[s, a]` # holds the reward for that (s, a) pair. rewards = jnp.array([ [1, -1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [10, 0, 20, 0], ], dtype=jnp.float32)
7つの状態があり、transition_matrixで遷移先の状態、rewardsで報酬が定義されている。
また、収益の割引率も定義されている。
# The discount for each (s, a) pair. discounts = jnp.where(transition_matrix > 0, 1.0, 0.0)
状態0に遷移する場合は、割引率が0になっている。
価値の初期値
各状態の価値の初期値は15に設定されている。
# Using optimistic initial values to encourage exploration. values = jnp.full([num_states], 15.0)
方策の事前確率
各状態における方策の事前確率のlogitsは0に設定されている。つまり、どの行動も均等に選択される。
# The prior policies for each state.
all_prior_logits = jnp.zeros_like(rewards)
ルートノードと状態遷移関数
_make_batched_env_modelで、ルートノードと状態遷移関数を作成する。
root, recurrent_fn = _make_batched_env_model( # Using batch_size=2 to test the batched search. batch_size=2, transition_matrix=transition_matrix, rewards=rewards, discounts=discounts, values=values, prior_logits=all_prior_logits)
def _make_batched_env_model( batch_size: int, *, transition_matrix: chex.Array, rewards: chex.Array, discounts: chex.Array, values: chex.Array, prior_logits: chex.Array): """Returns a batched `(root, recurrent_fn)`.""" chex.assert_equal_shape([transition_matrix, rewards, discounts, prior_logits]) num_states, num_actions = transition_matrix.shape chex.assert_shape(values, [num_states]) # We will start the search at state zero. root_state = 0 root = mctx.RootFnOutput( prior_logits=jnp.full([batch_size, num_actions], prior_logits[root_state]), value=jnp.full([batch_size], values[root_state]), # The embedding will hold the state index. embedding=jnp.zeros([batch_size], dtype=jnp.int32), ) def recurrent_fn(params, rng_key, action, embedding): del params, rng_key chex.assert_shape(action, [batch_size]) chex.assert_shape(embedding, [batch_size]) recurrent_fn_output = mctx.RecurrentFnOutput( reward=rewards[embedding, action], discount=discounts[embedding, action], prior_logits=prior_logits[embedding], value=values[embedding]) next_embedding = transition_matrix[embedding, action] return recurrent_fn_output, next_embedding return root, recurrent_fn
ルートノードは、RootFnOutput型で、(事前確率、価値、状態埋め込み)を保持する。
状態埋め込みは、MuZeroの場合は状態を表現するベクトルになるが、visualization_demo.pyでは状態のインデックスである。
状態遷移関数は、入力に(行動、状態埋め込み)を受け取り、RecurrentFnOutput型(報酬、割引率、事前確率、価値)と次の状態埋め込みを返す。
バッチで探索可能なように、それぞれ最初の次元がバッチの次元になっている。
探索
以上の準備が整ったら、探索処理を呼び出す。
# Running the search.
policy_output = mctx.gumbel_muzero_policy(
params=(),
rng_key=rng_key,
root=root,
recurrent_fn=recurrent_fn,
num_simulations=FLAGS.num_simulations,
max_depth=FLAGS.max_depth,
max_num_considered_actions=FLAGS.max_num_considered_actions,
)
引数は、rng_key、ルートノード、ルートノードの状態遷移関数、シミュレーション回数、探索の深さ、ルートノードの最大行動数である。
rng_keyは、乱数を生成するために使用される。
シミュレーション回数は、32に設定されている。
flags.DEFINE_integer("num_simulations", 32, "Number of simulations.")
探索の深さは、制限なしに設定されている。
flags.DEFINE_integer("max_depth", None, "The maximum search depth.")
ルートノードの最大行動数は、16に設定されている。
flags.DEFINE_integer("max_num_considered_actions", 16, "The maximum number of actions expanded at the root.")
まとめ
環境定義と探索の呼び出しまで解説した。
次回は、探索処理を調べる。