TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

Gumbel AlphaZeroの論文を読む その2

前回、Gumbel AlphaZeroの論文の概要と、公式実装の環境構築について記載した。

今回は、公式実装のサンプルプログラム examples/visualization_demo.py のソースを調べながらGumbel AlphaZeroのアルゴリズムを理解する。

visualization_demo.pyの概要

visualization_demo.py は、環境の状態遷移と報酬を表形式で定義して、Gumbel AlphaZeroのアルゴリズムで探索を行い、選択した行動と行動価値を出力する。
また、探索結果を木構造で可視化する。

出力例:

Selected action: 1
Selected action Q-value: 10.666667


環境

状態遷移と報酬は以下のように定義されている。

  # We will define a deterministic toy environment.
  # The deterministic `transition_matrix` has shape `[num_states, num_actions]`.
  # The `transition_matrix[s, a]` holds the next state.
  transition_matrix = jnp.array([
      [1, 2, 3, 4],
      [0, 5, 0, 0],
      [0, 0, 0, 6],
      [0, 0, 0, 0],
      [0, 0, 0, 0],
      [0, 0, 0, 0],
      [0, 0, 0, 0],
  ], dtype=jnp.int32)
  # The `rewards` have shape `[num_states, num_actions]`. The `rewards[s, a]`
  # holds the reward for that (s, a) pair.
  rewards = jnp.array([
      [1, -1, 0, 0],
      [0, 0, 0, 0],
      [0, 0, 0, 0],
      [0, 0, 0, 0],
      [0, 0, 0, 0],
      [0, 0, 0, 0],
      [10, 0, 20, 0],
  ], dtype=jnp.float32)

7つの状態があり、transition_matrixで遷移先の状態、rewardsで報酬が定義されている。

また、収益の割引率も定義されている。

  # The discount for each (s, a) pair.
  discounts = jnp.where(transition_matrix > 0, 1.0, 0.0)

状態0に遷移する場合は、割引率が0になっている。

価値の初期値

各状態の価値の初期値は15に設定されている。

  # Using optimistic initial values to encourage exploration.
  values = jnp.full([num_states], 15.0)

方策の事前確率

各状態における方策の事前確率のlogitsは0に設定されている。つまり、どの行動も均等に選択される。

  # The prior policies for each state.
  all_prior_logits = jnp.zeros_like(rewards)

ルートノードと状態遷移関数

_make_batched_env_modelで、ルートノードと状態遷移関数を作成する。

  root, recurrent_fn = _make_batched_env_model(
      # Using batch_size=2 to test the batched search.
      batch_size=2,
      transition_matrix=transition_matrix,
      rewards=rewards,
      discounts=discounts,
      values=values,
      prior_logits=all_prior_logits)
def _make_batched_env_model(
    batch_size: int,
    *,
    transition_matrix: chex.Array,
    rewards: chex.Array,
    discounts: chex.Array,
    values: chex.Array,
    prior_logits: chex.Array):
  """Returns a batched `(root, recurrent_fn)`."""
  chex.assert_equal_shape([transition_matrix, rewards, discounts,
                           prior_logits])
  num_states, num_actions = transition_matrix.shape
  chex.assert_shape(values, [num_states])
  # We will start the search at state zero.
  root_state = 0
  root = mctx.RootFnOutput(
      prior_logits=jnp.full([batch_size, num_actions],
                            prior_logits[root_state]),
      value=jnp.full([batch_size], values[root_state]),
      # The embedding will hold the state index.
      embedding=jnp.zeros([batch_size], dtype=jnp.int32),
  )

  def recurrent_fn(params, rng_key, action, embedding):
    del params, rng_key
    chex.assert_shape(action, [batch_size])
    chex.assert_shape(embedding, [batch_size])
    recurrent_fn_output = mctx.RecurrentFnOutput(
        reward=rewards[embedding, action],
        discount=discounts[embedding, action],
        prior_logits=prior_logits[embedding],
        value=values[embedding])
    next_embedding = transition_matrix[embedding, action]
    return recurrent_fn_output, next_embedding

  return root, recurrent_fn

ルートノードは、RootFnOutput型で、(事前確率、価値、状態埋め込み)を保持する。
状態埋め込みは、MuZeroの場合は状態を表現するベクトルになるが、visualization_demo.pyでは状態のインデックスである。

状態遷移関数は、入力に(行動、状態埋め込み)を受け取り、RecurrentFnOutput型(報酬、割引率、事前確率、価値)と次の状態埋め込みを返す。

バッチで探索可能なように、それぞれ最初の次元がバッチの次元になっている。

探索

以上の準備が整ったら、探索処理を呼び出す。

  # Running the search.
  policy_output = mctx.gumbel_muzero_policy(
      params=(),
      rng_key=rng_key,
      root=root,
      recurrent_fn=recurrent_fn,
      num_simulations=FLAGS.num_simulations,
      max_depth=FLAGS.max_depth,
      max_num_considered_actions=FLAGS.max_num_considered_actions,
  )

引数は、rng_key、ルートノード、ルートノードの状態遷移関数、シミュレーション回数、探索の深さ、ルートノードの最大行動数である。
rng_keyは、乱数を生成するために使用される。

シミュレーション回数は、32に設定されている。

flags.DEFINE_integer("num_simulations", 32, "Number of simulations.")

探索の深さは、制限なしに設定されている。

flags.DEFINE_integer("max_depth", None, "The maximum search depth.")

ルートノードの最大行動数は、16に設定されている。

flags.DEFINE_integer("max_num_considered_actions", 16,
                     "The maximum number of actions expanded at the root.")

まとめ

環境定義と探索の呼び出しまで解説した。
次回は、探索処理を調べる。