TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

Gumbel AlphaZeroの論文を読む その4

前回の続き

探索の内部処理

探索の処理は、searchに書かれている。

引数
  • params: ルートおよび再帰関数に渡されるパラメータ。
  • rng_key: 乱数生成器の状態。
  • root: ルートノードの初期状態で、事前確率、価値、埋め込みを含む。
  • recurrent_fn: 葉ノードおよび未訪問アクションに対して呼び出される関数。
  • root_action_selection_fn: ルートでアクションを選択するための関数。
  • interior_action_selection_fn: シミュレーション中にアクションを選択するための関数。
  • num_simulations: 実行するシミュレーションの数。
  • max_depth: 探索木の最大深度。
  • invalid_actions: ルートでの無効なアクションのマスク。
  • extra_data: ツリーに渡される追加データ。
  • loop_fn: シミュレーションを実行するための関数。
行動選択関数の切り替え

ルートノードと中間ノードで行動選択関数を切り替える。
行動選択関数については、別途解説。

  action_selection_fn = action_selection.switching_action_selection_wrapper(
      root_action_selection_fn=root_action_selection_fn,
      interior_action_selection_fn=interior_action_selection_fn
  )
def switching_action_selection_wrapper(
    root_action_selection_fn: base.RootActionSelectionFn,
    interior_action_selection_fn: base.InteriorActionSelectionFn
) -> base.InteriorActionSelectionFn:
  """Wraps root and interior action selection fns in a conditional statement."""

  def switching_action_selection_fn(
      rng_key: chex.PRNGKey,
      tree: tree_lib.Tree,
      node_index: base.NodeIndices,
      depth: base.Depth) -> chex.Array:
    return jax.lax.cond(
        depth == 0,
        lambda x: root_action_selection_fn(*x[:3]),
        lambda x: interior_action_selection_fn(*x),
        (rng_key, tree, node_index, depth))

  return switching_action_selection_fn
バッチサイズとバッチインデックス範囲の設定

バッチサイズとバッチインデックス範囲を設定する。
バッチサイズはルートの価値の1番目の次元から取得する。
max_depthとinvalid_actionsが提供されていない場合、デフォルト値を設定する。

  # Do simulation, expansion, and backward steps.
  batch_size = root.value.shape[0]
  batch_range = jnp.arange(batch_size)
  if max_depth is None:
    max_depth = num_simulations
  if invalid_actions is None:
    invalid_actions = jnp.zeros_like(root.prior_logits)
ループの本体関数

シミュレーションを実行するループの本体を定義する。
以下のステップを実行する。

  • RNGキーを分割。
  • シミュレーションを実行し、行動を選択する。
  • ノードが未訪問の場合、新しいノードを追加して木を拡張する。
  • バックアップを行い木を更新する。

展開とバックアップは別途解説。

def body_fun(sim, loop_state):
  rng_key, tree = loop_state
  rng_key, simulate_key, expand_key = jax.random.split(rng_key, 3)
  simulate_keys = jax.random.split(simulate_key, batch_size)
  parent_index, action = simulate(simulate_keys, tree, action_selection_fn, max_depth)
  next_node_index = tree.children_index[batch_range, parent_index, action]
  next_node_index = jnp.where(next_node_index == Tree.UNVISITED, sim + 1, next_node_index)
  tree = expand(params, expand_key, tree, recurrent_fn, parent_index, action, next_node_index)
  tree = backward(tree, next_node_index)
  loop_state = rng_key, tree
  return loop_state
木の初期化

ルートノードから木を初期化する。

  # Allocate all necessary storage.
  tree = instantiate_tree_from_root(root, num_simulations,
                                    root_invalid_actions=invalid_actions,
                                    extra_data=extra_data)
def instantiate_tree_from_root(
    root: base.RootFnOutput,
    num_simulations: int,
    root_invalid_actions: chex.Array,
    extra_data: Any) -> Tree:
  """Initializes tree state at search root."""
  chex.assert_rank(root.prior_logits, 2)
  batch_size, num_actions = root.prior_logits.shape
  chex.assert_shape(root.value, [batch_size])
  num_nodes = num_simulations + 1
  data_dtype = root.value.dtype
  batch_node = (batch_size, num_nodes)
  batch_node_action = (batch_size, num_nodes, num_actions)

  def _zeros(x):
    return jnp.zeros(batch_node + x.shape[1:], dtype=x.dtype)

  # Create a new empty tree state and fill its root.
  tree = Tree(
      node_visits=jnp.zeros(batch_node, dtype=jnp.int32),
      raw_values=jnp.zeros(batch_node, dtype=data_dtype),
      node_values=jnp.zeros(batch_node, dtype=data_dtype),
      parents=jnp.full(batch_node, Tree.NO_PARENT, dtype=jnp.int32),
      action_from_parent=jnp.full(
          batch_node, Tree.NO_PARENT, dtype=jnp.int32),
      children_index=jnp.full(
          batch_node_action, Tree.UNVISITED, dtype=jnp.int32),
      children_prior_logits=jnp.zeros(
          batch_node_action, dtype=root.prior_logits.dtype),
      children_values=jnp.zeros(batch_node_action, dtype=data_dtype),
      children_visits=jnp.zeros(batch_node_action, dtype=jnp.int32),
      children_rewards=jnp.zeros(batch_node_action, dtype=data_dtype),
      children_discounts=jnp.zeros(batch_node_action, dtype=data_dtype),
      embeddings=jax.tree.map(_zeros, root.embedding),
      root_invalid_actions=root_invalid_actions,
      extra_data=extra_data)

  root_index = jnp.full([batch_size], Tree.ROOT_INDEX)
  tree = update_tree_node(
      tree, root_index, root.prior_logits, root.value, root.embedding)
  return tree
シミュレーションの実行

ループ関数を使用してnum_simulations回シミュレーションを実行する。
ループ関数はデフォルトで、jax.lax.fori_loopを使用する。
結果の木を返す。

  # Allocate all necessary storage.
  tree = instantiate_tree_from_root(root, num_simulations,
                                    root_invalid_actions=invalid_actions,
                                    extra_data=extra_data)
  _, tree = loop_fn(
      0, num_simulations, body_fun, (rng_key, tree))

  return tree

まとめ

探索の内部処理を解説した。
次回は、探索の内部処理で呼ばれている各関数の詳細を解説する。