前回の続き
探索の内部処理
探索の処理は、searchに書かれている。
引数
- params: ルートおよび再帰関数に渡されるパラメータ。
- rng_key: 乱数生成器の状態。
- root: ルートノードの初期状態で、事前確率、価値、埋め込みを含む。
- recurrent_fn: 葉ノードおよび未訪問アクションに対して呼び出される関数。
- root_action_selection_fn: ルートでアクションを選択するための関数。
- interior_action_selection_fn: シミュレーション中にアクションを選択するための関数。
- num_simulations: 実行するシミュレーションの数。
- max_depth: 探索木の最大深度。
- invalid_actions: ルートでの無効なアクションのマスク。
- extra_data: ツリーに渡される追加データ。
- loop_fn: シミュレーションを実行するための関数。
行動選択関数の切り替え
ルートノードと中間ノードで行動選択関数を切り替える。
行動選択関数については、別途解説。
action_selection_fn = action_selection.switching_action_selection_wrapper( root_action_selection_fn=root_action_selection_fn, interior_action_selection_fn=interior_action_selection_fn )
def switching_action_selection_wrapper( root_action_selection_fn: base.RootActionSelectionFn, interior_action_selection_fn: base.InteriorActionSelectionFn ) -> base.InteriorActionSelectionFn: """Wraps root and interior action selection fns in a conditional statement.""" def switching_action_selection_fn( rng_key: chex.PRNGKey, tree: tree_lib.Tree, node_index: base.NodeIndices, depth: base.Depth) -> chex.Array: return jax.lax.cond( depth == 0, lambda x: root_action_selection_fn(*x[:3]), lambda x: interior_action_selection_fn(*x), (rng_key, tree, node_index, depth)) return switching_action_selection_fn
バッチサイズとバッチインデックス範囲の設定
バッチサイズとバッチインデックス範囲を設定する。
バッチサイズはルートの価値の1番目の次元から取得する。
max_depthとinvalid_actionsが提供されていない場合、デフォルト値を設定する。
# Do simulation, expansion, and backward steps. batch_size = root.value.shape[0] batch_range = jnp.arange(batch_size) if max_depth is None: max_depth = num_simulations if invalid_actions is None: invalid_actions = jnp.zeros_like(root.prior_logits)
ループの本体関数
シミュレーションを実行するループの本体を定義する。
以下のステップを実行する。
- RNGキーを分割。
- シミュレーションを実行し、行動を選択する。
- ノードが未訪問の場合、新しいノードを追加して木を拡張する。
- バックアップを行い木を更新する。
展開とバックアップは別途解説。
def body_fun(sim, loop_state): rng_key, tree = loop_state rng_key, simulate_key, expand_key = jax.random.split(rng_key, 3) simulate_keys = jax.random.split(simulate_key, batch_size) parent_index, action = simulate(simulate_keys, tree, action_selection_fn, max_depth) next_node_index = tree.children_index[batch_range, parent_index, action] next_node_index = jnp.where(next_node_index == Tree.UNVISITED, sim + 1, next_node_index) tree = expand(params, expand_key, tree, recurrent_fn, parent_index, action, next_node_index) tree = backward(tree, next_node_index) loop_state = rng_key, tree return loop_state
木の初期化
ルートノードから木を初期化する。
# Allocate all necessary storage.
tree = instantiate_tree_from_root(root, num_simulations,
root_invalid_actions=invalid_actions,
extra_data=extra_data)
def instantiate_tree_from_root( root: base.RootFnOutput, num_simulations: int, root_invalid_actions: chex.Array, extra_data: Any) -> Tree: """Initializes tree state at search root.""" chex.assert_rank(root.prior_logits, 2) batch_size, num_actions = root.prior_logits.shape chex.assert_shape(root.value, [batch_size]) num_nodes = num_simulations + 1 data_dtype = root.value.dtype batch_node = (batch_size, num_nodes) batch_node_action = (batch_size, num_nodes, num_actions) def _zeros(x): return jnp.zeros(batch_node + x.shape[1:], dtype=x.dtype) # Create a new empty tree state and fill its root. tree = Tree( node_visits=jnp.zeros(batch_node, dtype=jnp.int32), raw_values=jnp.zeros(batch_node, dtype=data_dtype), node_values=jnp.zeros(batch_node, dtype=data_dtype), parents=jnp.full(batch_node, Tree.NO_PARENT, dtype=jnp.int32), action_from_parent=jnp.full( batch_node, Tree.NO_PARENT, dtype=jnp.int32), children_index=jnp.full( batch_node_action, Tree.UNVISITED, dtype=jnp.int32), children_prior_logits=jnp.zeros( batch_node_action, dtype=root.prior_logits.dtype), children_values=jnp.zeros(batch_node_action, dtype=data_dtype), children_visits=jnp.zeros(batch_node_action, dtype=jnp.int32), children_rewards=jnp.zeros(batch_node_action, dtype=data_dtype), children_discounts=jnp.zeros(batch_node_action, dtype=data_dtype), embeddings=jax.tree.map(_zeros, root.embedding), root_invalid_actions=root_invalid_actions, extra_data=extra_data) root_index = jnp.full([batch_size], Tree.ROOT_INDEX) tree = update_tree_node( tree, root_index, root.prior_logits, root.value, root.embedding) return tree
シミュレーションの実行
ループ関数を使用してnum_simulations回シミュレーションを実行する。
ループ関数はデフォルトで、jax.lax.fori_loopを使用する。
結果の木を返す。
# Allocate all necessary storage. tree = instantiate_tree_from_root(root, num_simulations, root_invalid_actions=invalid_actions, extra_data=extra_data) _, tree = loop_fn( 0, num_simulations, body_fun, (rng_key, tree)) return tree
まとめ
探索の内部処理を解説した。
次回は、探索の内部処理で呼ばれている各関数の詳細を解説する。