探索の内部処理
探索の処理は、searchに書かれている。
引数
- params: ルートおよび再帰関数に渡されるパラメータ。
- rng_key: 乱数生成器の状態。
- root: ルートノードの初期状態で、事前確率、価値、埋め込みを含む。
- recurrent_fn: 葉ノードおよび未訪問アクションに対して呼び出される関数。
- root_action_selection_fn: ルートでアクションを選択するための関数。
- interior_action_selection_fn: シミュレーション中にアクションを選択するための関数。
- num_simulations: 実行するシミュレーションの数。
- max_depth: 探索木の最大深度。
- invalid_actions: ルートでの無効なアクションのマスク。
- extra_data: ツリーに渡される追加データ。
- loop_fn: シミュレーションを実行するための関数。
行動選択関数の切り替え
ルートノードと中間ノードで行動選択関数を切り替える。
行動選択関数については、別途解説。
action_selection_fn = action_selection.switching_action_selection_wrapper(
root_action_selection_fn=root_action_selection_fn,
interior_action_selection_fn=interior_action_selection_fn
)
def switching_action_selection_wrapper(
root_action_selection_fn: base.RootActionSelectionFn,
interior_action_selection_fn: base.InteriorActionSelectionFn
) -> base.InteriorActionSelectionFn:
"""Wraps root and interior action selection fns in a conditional statement."""
def switching_action_selection_fn(
rng_key: chex.PRNGKey,
tree: tree_lib.Tree,
node_index: base.NodeIndices,
depth: base.Depth) -> chex.Array:
return jax.lax.cond(
depth == 0,
lambda x: root_action_selection_fn(*x[:3]),
lambda x: interior_action_selection_fn(*x),
(rng_key, tree, node_index, depth))
return switching_action_selection_fn
バッチサイズとバッチインデックス範囲の設定
バッチサイズとバッチインデックス範囲を設定する。
バッチサイズはルートの価値の1番目の次元から取得する。
max_depthとinvalid_actionsが提供されていない場合、デフォルト値を設定する。
batch_size = root.value.shape[0]
batch_range = jnp.arange(batch_size)
if max_depth is None:
max_depth = num_simulations
if invalid_actions is None:
invalid_actions = jnp.zeros_like(root.prior_logits)
ループの本体関数
シミュレーションを実行するループの本体を定義する。
以下のステップを実行する。
- RNGキーを分割。
- シミュレーションを実行し、行動を選択する。
- ノードが未訪問の場合、新しいノードを追加して木を拡張する。
- バックアップを行い木を更新する。
展開とバックアップは別途解説。
def body_fun(sim, loop_state):
rng_key, tree = loop_state
rng_key, simulate_key, expand_key = jax.random.split(rng_key, 3)
simulate_keys = jax.random.split(simulate_key, batch_size)
parent_index, action = simulate(simulate_keys, tree, action_selection_fn, max_depth)
next_node_index = tree.children_index[batch_range, parent_index, action]
next_node_index = jnp.where(next_node_index == Tree.UNVISITED, sim + 1, next_node_index)
tree = expand(params, expand_key, tree, recurrent_fn, parent_index, action, next_node_index)
tree = backward(tree, next_node_index)
loop_state = rng_key, tree
return loop_state
木の初期化
ルートノードから木を初期化する。
tree = instantiate_tree_from_root(root, num_simulations,
root_invalid_actions=invalid_actions,
extra_data=extra_data)
def instantiate_tree_from_root(
root: base.RootFnOutput,
num_simulations: int,
root_invalid_actions: chex.Array,
extra_data: Any) -> Tree:
"""Initializes tree state at search root."""
chex.assert_rank(root.prior_logits, 2)
batch_size, num_actions = root.prior_logits.shape
chex.assert_shape(root.value, [batch_size])
num_nodes = num_simulations + 1
data_dtype = root.value.dtype
batch_node = (batch_size, num_nodes)
batch_node_action = (batch_size, num_nodes, num_actions)
def _zeros(x):
return jnp.zeros(batch_node + x.shape[1:], dtype=x.dtype)
tree = Tree(
node_visits=jnp.zeros(batch_node, dtype=jnp.int32),
raw_values=jnp.zeros(batch_node, dtype=data_dtype),
node_values=jnp.zeros(batch_node, dtype=data_dtype),
parents=jnp.full(batch_node, Tree.NO_PARENT, dtype=jnp.int32),
action_from_parent=jnp.full(
batch_node, Tree.NO_PARENT, dtype=jnp.int32),
children_index=jnp.full(
batch_node_action, Tree.UNVISITED, dtype=jnp.int32),
children_prior_logits=jnp.zeros(
batch_node_action, dtype=root.prior_logits.dtype),
children_values=jnp.zeros(batch_node_action, dtype=data_dtype),
children_visits=jnp.zeros(batch_node_action, dtype=jnp.int32),
children_rewards=jnp.zeros(batch_node_action, dtype=data_dtype),
children_discounts=jnp.zeros(batch_node_action, dtype=data_dtype),
embeddings=jax.tree.map(_zeros, root.embedding),
root_invalid_actions=root_invalid_actions,
extra_data=extra_data)
root_index = jnp.full([batch_size], Tree.ROOT_INDEX)
tree = update_tree_node(
tree, root_index, root.prior_logits, root.value, root.embedding)
return tree
シミュレーションの実行
ループ関数を使用してnum_simulations回シミュレーションを実行する。
ループ関数はデフォルトで、jax.lax.fori_loopを使用する。
結果の木を返す。
tree = instantiate_tree_from_root(root, num_simulations,
root_invalid_actions=invalid_actions,
extra_data=extra_data)
_, tree = loop_fn(
0, num_simulations, body_fun, (rng_key, tree))
return tree