TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

Gumbel AlphaZeroの論文を読む その3

前回に続き、examples/visualization_demo.py のソースを解説する。

探索

探索の処理は、gumbel_muzero_policyに書かれている。

引数は、以下の通り。

  • params: ルートおよび再帰関数に渡されるパラメータ。
  • rng_key: 乱数生成器の状態。
  • root: (prior_logits, value, embedding)の形式のRootFnOutput。prior_logitsは方策ネットワークからのもので、形状はそれぞれ([B, num_actions], [B], [B, ...])。
  • recurrent_fn: シミュレーションステップで取得された葉ノードおよび未訪問アクションに対して呼び出される関数。引数として(params, rng_key, action, embedding)を取り、RecurrentFnOutputと新しい状態埋め込みを返す。
  • num_simulations: シミュレーションの数。
  • invalid_actions: 無効なアクションのマスク。無効な行動は1、有効な行動は0のマスク。形状は[B, num_actions]。
  • max_depth: シミュレーション中に許可される最大探索木の深さ。
  • loop_fn: シミュレーションを実行するために使用される関数。Haikuモジュール内でこの関数を使用する場合、hk.fori_loopを渡す必要があるかもしれない。
  • qtransform: ノードの完成したQ値を取得するための関数。
  • max_num_considered_actions: ルートノードで展開される最大行動数。有効な行動の数が少ない場合は、より少ない行動が展開される。
  • gumbel_scale: Gumbelノイズのスケール。完全情報ゲームの評価ではgumbel_scale=0.0を使用できる。


処理の流れは以下の通り。

無効な行動をマスクする

無効な行動の事前確率のlogitをfloatの最小値にする。

  # Masking invalid actions.
  root = root.replace(
      prior_logits=_mask_invalid_actions(root.prior_logits, invalid_actions))
def _mask_invalid_actions(logits, invalid_actions):
  """Returns logits with zero mass to invalid actions."""
  if invalid_actions is None:
    return logits
  chex.assert_equal_shape([logits, invalid_actions])
  logits = logits - jnp.max(logits, axis=-1, keepdims=True)
  # At the end of an episode, all actions can be invalid. A softmax would then
  # produce NaNs, if using -inf for the logits. We avoid the NaNs by using
  # a finite `min_logit` for the invalid actions.
  min_logit = jnp.finfo(logits.dtype).min
  return jnp.where(invalid_actions, min_logit, logits)
Gumbelノイズの生成

Jaxでは同じキーからは同じ乱数が生成されるため、jax.random.splitで新しいキーを生成する。
jax.random.gumbelを使用して、Gumbel分布に従う乱数を生成する。
次元は、事前確率のlogitsの次元とする。

  # Generating Gumbel.
  rng_key, gumbel_rng = jax.random.split(rng_key)
  gumbel = gumbel_scale * jax.random.gumbel(
      gumbel_rng, shape=root.prior_logits.shape, dtype=root.prior_logits.dtype)
探索の実行

探索処理を呼び出す。
探索処理の内容は別途解説。

  # Searching.
  extra_data = action_selection.GumbelMuZeroExtraData(root_gumbel=gumbel)
  search_tree = search.search(
      params=params,
      rng_key=rng_key,
      root=root,
      recurrent_fn=recurrent_fn,
      root_action_selection_fn=functools.partial(
          action_selection.gumbel_muzero_root_action_selection,
          num_simulations=num_simulations,
          max_num_considered_actions=max_num_considered_actions,
          qtransform=qtransform,
      ),
      interior_action_selection_fn=functools.partial(
          action_selection.gumbel_muzero_interior_action_selection,
          qtransform=qtransform,
      ),
      num_simulations=num_simulations,
      max_depth=max_depth,
      invalid_actions=invalid_actions,
      extra_data=extra_data,
      loop_fn=loop_fn)
  summary = search_tree.summary()
最適な行動の選択

最も訪問された行動数を計算(considered_visit)する。
qtransformで、未訪問の行動のQ値を補完する。
seq_halving.score_consideredで、行動ごとのスコアを計算する。
スコアが最大の行動を選択する。

  # Acting with the best action from the most visited actions.
  # The "best" action has the highest `gumbel + logits + q`.
  # Inside the minibatch, the considered_visit can be different on states with
  # a smaller number of valid actions.
  considered_visit = jnp.max(summary.visit_counts, axis=-1, keepdims=True)
  # The completed_qvalues include imputed values for unvisited actions.
  completed_qvalues = jax.vmap(qtransform, in_axes=[0, None])(  # pytype: disable=wrong-arg-types  # numpy-scalars  # pylint: disable=line-too-long
      search_tree, search_tree.ROOT_INDEX)
  to_argmax = seq_halving.score_considered(
      considered_visit, gumbel, root.prior_logits, completed_qvalues,
      summary.visit_counts)
  action = action_selection.masked_argmax(to_argmax, invalid_actions)

スコアの計算処理は、以下の通り。
訪問数が最大の行動にペナルティを課して選択されないようにする。
gumbel + logits(=root.prior_logits) + normalized_qvalues(=completed_qvalues)を計算する。
gumbelは乱数なので、確率的に行動を選択することになる。

def score_considered(considered_visit, gumbel, logits, normalized_qvalues,
                     visit_counts):
  """Returns a score usable for an argmax."""
  # We allow to visit a child, if it is the only considered child.
  low_logit = -1e9
  logits = logits - jnp.max(logits, keepdims=True, axis=-1)
  penalty = jnp.where(
      visit_counts == considered_visit,
      0, -jnp.inf)
  chex.assert_equal_shape([gumbel, logits, normalized_qvalues, penalty])
  return jnp.maximum(low_logit, gumbel + logits + normalized_qvalues) + penalty
行動の重みの生成

方策の学習に使用する行動の重みを計算する。
root.prior_logits + completed_qvaluesに、softmaxを適用して新しい方策を求める。

# Producing action_weights usable to train the policy network.
  completed_search_logits = _mask_invalid_actions(
      root.prior_logits + completed_qvalues, invalid_actions)
  action_weights = jax.nn.softmax(completed_search_logits)

まとめ

探索処理(gumbel_muzero_policy)の流れを解説した。
次回は、探索の内部処理を解説する。