TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

Gumbel AlphaZeroの論文を読む その12(方策の学習の実装)

その9で解説した、方策の学習について、公式の実装を確認する。

完成されたQ値(Completed Q-values)による方策学習の実装

Gumbel AlphaZeroでは、探索で得られたQ値を最大限活用するため、「完成されたQ値(Completed Q-values)」という概念を導入し、これを用いて方策ネットワークを学習する。 この仕組みは、qtransforms.pyqtransform_completed_by_mix_value関数を中心に、tree.pyTree.qvalues_compute_mixed_valueなど複数の関数で実装されている。

1. Q値変換関数(qtransform_completed_by_mix_value

実装箇所
  • qtransforms.py → qtransform_completed_by_mix_value
def qtransform_completed_by_mix_value(
    tree: tree_lib.Tree,
    node_index: chex.Numeric,
    *,
    value_scale: chex.Numeric = 0.1,
    maxvisit_init: chex.Numeric = 50.0,
    rescale_values: bool = True,
    use_mixed_value: bool = True,
    epsilon: chex.Numeric = 1e-8,
) -> chex.Array:
  chex.assert_shape(node_index, ())
  qvalues = tree.qvalues(node_index)
  visit_counts = tree.children_visits[node_index]

  # Computing the mixed value and producing completed_qvalues.
  raw_value = tree.raw_values[node_index]
  prior_probs = jax.nn.softmax(
      tree.children_prior_logits[node_index])
  if use_mixed_value:
    value = _compute_mixed_value(
        raw_value,
        qvalues=qvalues,
        visit_counts=visit_counts,
        prior_probs=prior_probs)
  else:
    value = raw_value
  completed_qvalues = _complete_qvalues(
      qvalues, visit_counts=visit_counts, value=value)

  # Scaling the Q-values.
  if rescale_values:
    completed_qvalues = _rescale_qvalues(completed_qvalues, epsilon)
  maxvisit = jnp.max(visit_counts, axis=-1)
  visit_scale = maxvisit_init + maxvisit
  return visit_scale * value_scale * completed_qvalues


def _rescale_qvalues(qvalues, epsilon):
  """Rescales the given completed Q-values to be from the [0, 1] interval."""
  min_value = jnp.min(qvalues, axis=-1, keepdims=True)
  max_value = jnp.max(qvalues, axis=-1, keepdims=True)
  return (qvalues - min_value) / jnp.maximum(max_value - min_value, epsilon)


def _complete_qvalues(qvalues, *, visit_counts, value):
  """Returns completed Q-values, with the `value` for unvisited actions."""
  chex.assert_equal_shape([qvalues, visit_counts])
  chex.assert_shape(value, [])

  # The missing qvalues are replaced by the value.
  completed_qvalues = jnp.where(
      visit_counts > 0,
      qvalues,
      value)
  chex.assert_equal_shape([completed_qvalues, qvalues])
  return completed_qvalues


def _compute_mixed_value(raw_value, qvalues, visit_counts, prior_probs):
  """Interpolates the raw_value and weighted qvalues.

  Args:
    raw_value: an approximate value of the state. Shape `[]`.
    qvalues: Q-values for all actions. Shape `[num_actions]`. The unvisited
      actions have undefined Q-value.
    visit_counts: the visit counts for all actions. Shape `[num_actions]`.
    prior_probs: the action probabilities, produced by the policy network for
      each action. Shape `[num_actions]`.

  Returns:
    An estimator of the state value. Shape `[]`.
  """
  sum_visit_counts = jnp.sum(visit_counts, axis=-1)
  # Ensuring non-nan weighted_q, even if the visited actions have zero
  # prior probability.
  prior_probs = jnp.maximum(jnp.finfo(prior_probs.dtype).tiny, prior_probs)
  # Summing the probabilities of the visited actions.
  sum_probs = jnp.sum(jnp.where(visit_counts > 0, prior_probs, 0.0),
                      axis=-1)
  weighted_q = jnp.sum(jnp.where(
      visit_counts > 0,
      prior_probs * qvalues / jnp.where(visit_counts > 0, sum_probs, 1.0),
      0.0), axis=-1)
  return (raw_value + sum_visit_counts * weighted_q) / (sum_visit_counts + 1)
主な処理の流れ
  1. ノード情報の取得 指定ノード(node_index)の子アクションのQ値(qvalues)、訪問回数(visit_counts)、生の価値(raw_value)、事前確率(prior_probs)を取得する。

  2. Mixed Valueの計算

    • use_mixed_value=Trueの場合、_compute_mixed_valueを呼び出し、探索済みアクションのQ値の加重平均と生の価値を補間した「混合価値近似(Mixed Value)」を計算する。
    • そうでなければ、生の価値(raw_value)をそのまま使う。
  3. Q値ベクトルの補完 _complete_qvaluesで、未訪問アクションのQ値を上記で計算した値で埋める(補完する)。

  4. Q値の正規化 rescale_values=Trueなら、_rescale_qvaluesでQ値を[0,1]に正規化する。

  5. 訪問回数によるスケーリング 最も訪問されたアクションの回数(maxvisit)を取得し、maxvisit_initvalue_scaleでスケーリングする。

要約

この関数は、探索で得られたQ値と価値ネットワークの予測値を組み合わせ、未訪問アクションも含めて全アクションの「完成されたQ値」を生成する。 このベクトルは、方策改善やKLダイバージェンス損失の計算に直接使われる。


2. 子アクションのQ値計算(Tree.qvalues

実装箇所
  • tree.py → Tree.qvalues_unbatched_qvalues
def _unbatched_qvalues(tree: Tree, index: int) -> int:
  chex.assert_rank(tree.children_discounts, 2)
  return (  # pytype: disable=bad-return-type  # numpy-scalars
      tree.children_rewards[index]
      + tree.children_discounts[index] * tree.children_values[index]
  )
主な処理の流れ
  • 指定ノードの各アクションについて、 Q(s, a) = R(s, a) + γ * V(s') (即時報酬+割引後の次状態価値)を計算する。
  • バッチ入力にも対応し、jax.vmapで複数ノードのQ値を一括計算可能である。
要約

MCTS木の各ノードにおける全アクションのQ値を、強化学習のベルマン方程式に基づき計算する。 このQ値が、後述の「完成されたQ値」の基礎となる。


3. Mixed Valueの計算(_compute_mixed_value

実装箇所
  • qtransforms.py → _compute_mixed_value
def _compute_mixed_value(raw_value, qvalues, visit_counts, prior_probs):
  sum_visit_counts = jnp.sum(visit_counts, axis=-1)
  # Ensuring non-nan weighted_q, even if the visited actions have zero
  # prior probability.
  prior_probs = jnp.maximum(jnp.finfo(prior_probs.dtype).tiny, prior_probs)
  # Summing the probabilities of the visited actions.
  sum_probs = jnp.sum(jnp.where(visit_counts > 0, prior_probs, 0.0),
                      axis=-1)
  weighted_q = jnp.sum(jnp.where(
      visit_counts > 0,
      prior_probs * qvalues / jnp.where(visit_counts > 0, sum_probs, 1.0),
      0.0), axis=-1)
  return (raw_value + sum_visit_counts * weighted_q) / (sum_visit_counts + 1)
主な処理の流れ
  1. 総訪問回数の計算 visit_countsを合計し、総訪問回数を取得する。

  2. 訪問済みアクションのQ値加重平均 方策ネットワークの事前確率(prior_probs)で重み付けし、訪問済みアクションのQ値の期待値(weighted_q)を計算する。

  3. 補間(ブレンド raw_value(価値ネットワークの予測値)とweighted_qを、総訪問回数に応じて補間する。 訪問回数が0ならraw_value、多いほどweighted_qに近づく。

要約

探索初期は価値ネットワークの予測値を、探索が進むほど探索で得られたQ値の加重平均を重視する、動的な価値推定を実現する。 これが未訪問アクションのQ値補完に使われる。

4. KLダイバージェンスによる学習

実装箇所
  • policies.py → gumbel_muzero_policyの内部
    • qtransformで得た「完成されたQ値」をsoftmaxに通し、改善方策π'を構築

KLダイバージェンス損失を計算して学習を行う部分は、mctxライブラリの利用者側が実装する部分であり、mctxライブラリ自体には含まれない。

mctxライブラリの役割は、探索を実行し、その結果として学習のターゲット(教師ラベル)となる改善方策π' を計算して出力することである。

具体的には、mctx.gumbel_muzero_policyなどの関数は、最終的にPolicyOutputというデータ構造を返す。この中のaction_weightsが、論文で述べられている改善方策π'に相当する。

@chex.dataclass(frozen=True)
class PolicyOutput(Generic[T]):
  """The output of a policy.

  action: `[B]` the proposed action.
  action_weights: `[B, num_actions]` the targets used to train a policy network.
    The action weights sum to one. Usually, the policy network is trained by
    cross-entropy:
    `cross_entropy(labels=stop_gradient(action_weights), logits=prior_logits)`.
  search_tree: `[B, ...]` the search tree of the finished search.
  """
  action: chex.Array
  action_weights: chex.Array
  search_tree: tree.Tree[T]

mctxライブラリの利用者は、このPolicyOutput.action_weightsを受け取り、それを教師ラベルとして、自身の方策ネットワークの出力(prior_logits)とのKLダイバージェンスを計算する損失関数を定義し、オプティマイザ(例: optax)を使ってパラメータを更新する。

主な処理の流れ
  • 探索後、qtransformで得た「完成されたQ値」と事前ロジットを合成し、softmaxで改善方策π'を計算する。
  • 元の方策ネットワークπが、このπ'に近づくようにKLダイバージェンス損失を最小化する。
要約

最善手だけでなく、探索で得られた全アクションの価値情報を活用し、方策ネットワーク全体を効率的に改善する。

まとめ

Gumbel AlphaZeroの方策学習は、 - 探索で得られたQ値と価値ネットワークの予測値を組み合わせて「完成されたQ値」を生成し、 - これを用いて改善方策π'を構築、 - 元の方策ネットワークπがπ'に近づくようKLダイバージェンス損失で学習する という流れで実装されている。

この仕組みにより、単一の最善手だけでなく、探索で得られた価値情報を最大限活用した方策改善が可能となっている。

これで、一通り論文の主要部分の解説と実装の確認ができた。 次は、Gumbel AlphaZeroで強化学習するミニマムなプログラムをPyTorchで実装したい。