コア(Core)
- 入力:prev_state, embedded_entity, embedded_spatial, embedded_scalar
- 出力:
- next_state : 次のステップのLSTM状態
- lstm_output : LSTMの出力
- コアは、「embedded_entity」、「embedded_spatial」、および「embedded_scalar」を単一の1Dテンソルに連結し、そのテンソルを「prev_state」と一緒に、サイズ384の3つの隠れ層を持つLSTMに送る。投影(projection)は使用されない。 ゲートにレイヤー正規化を適用する。 LSTMの出力が、このモジュールの出力となる。
アクションタイプヘッド(Action Type Head)
- 入力:lstm_output, scalar_context
- 出力:
- action_type_logits : 各アクションを実行する確率に対応するロジット
- action_type : action_type_logitsからサンプリングされたaction_type
- autoregressive_embedding : 「lstm_output」と以前にサンプリングされたすべての引数からの情報を結合する埋め込み。 引数がサンプリングされる順序は、ネットワーク図(拡張データ図3)を参照
- アクションタイプヘッドは、 「lstm_output」をサイズ256の1Dテンソルに埋め込み、それぞれサイズ256のレイヤー正規化で16個のResBlocksに渡し、ReLUを適用する。 出力は、 「scalar_context」によってゲーティングされる「GLU」を介して、可能なアクションタイプごとに1つのロジットを持つテンソルに変換される。「action_type」は、温度0.8の多項ロジットからサンプリングされます。 ここで、教師あり学習中は、「action_type」が人間のアクションタイプの正解データとなり、温度は1.0である(他のすべての引数についても同様)。
- 次に、最初にReLUとサイズ256の全結合層を「action_type」のワンホットに適用し、「scalar_context」によってゲーティングされた「GLU」を介してサイズ1024の1Dテンソルに投影することにより、「autoregressive_embedding」が生成される 。 その投影は、「lstm_output」の別の投影に追加され、「scalar_context」によってゲーティングされたサイズ1024の1Dテンソルに「autoregressive_embedding」が生成される。
遅延ヘッド(Delay Head)
- 入力:autoregressive_embedding
- 出力:
- delay_logits : 各遅延の確率に対応するロジット
- delay : サンプリングされた遅延
- autoregressive_embedding : 「lstm_output」と以前にサンプリングされたすべての引数からの情報を結合する埋め込み。 引数がサンプリングされる順序は、ネットワーク図(拡張データ図3)を参照
- 「autoregressive_embedding」は、サイズ128(ゲームステップで要求される遅延ごとに1つ)を持つ「delay_logits」に埋め込まれる前に、ReLUで2層(それぞれサイズ256)の線形ネットワークを使用してデコードされる。 「delay」は多項ロジットを使用して「delay_logits」からサンプリングされるが、他のすべての引数とは異なり、サンプリングの前に「delay_logits」に温度は適用されない。 「action_type」と同様に、「delay」はReLUを含む2層(それぞれサイズ256)の線形ネットワークを介してサイズ1024の1Dテンソルに投影され、「autoregressive_embedding」に追加される。
キューヘッド(Queued Head)
- 入力:autoregressive_embedding, action_type, embedded_entity
- 出力:
- queued_logits : キューイングする確率とキューイングしない確率に対応するロジット
- autoregressive_embedding : 「lstm_output」と以前にサンプリングされたすべての引数からの情報を結合する埋め込み。 引数がサンプリングされる順序は、ネットワーク図(拡張データ図3)を参照
- キューヘッドは遅延ヘッドと似ているが、サンプリング前に0.8の温度がロジットに適用され、「queued_logits」のサイズは2(キューイングとキューイングなし)である。選択された「action_type」に対してキューイングが不可能な場合、投影された「queued」は「autoregressive_embedding」に追加されない。
ユニット選択ヘッド(Selected Units Head)
- 入力:autoregressive_embedding, action_type, entity_embeddings
- 出力:
- units_logits : 各ユニットを選択する確率に対応するロジット。可能な64ユニットの選択ごとに繰り返される。
- units : このアクションのために選択されたユニット。
- autoregressive_embedding : 「lstm_output」と以前にサンプリングされたすべての引数からの情報を結合する埋め込み。 引数がサンプリングされる順序は、ネットワーク図(拡張データ図3)を参照
- 該当する場合、ユニット選択ヘッドはまず、 「action_type」を受け入れることができるエンティティタイプを決定し、最大でユニットタイプの数に等しいそのタイプのワンホットを作成し、サイズ256の全結合層とReLUに渡す。 これは、このヘッドでは「func_embed」と呼ばれる。
- また、ユニットを選択できるマスクを計算し、存在するすべてのエンティティ(敵ユニットを含む)を選択できるように初期化する。
- 次に、32チャンネルとカーネルサイズ1の1D畳み込みを介して「entity_embeddings」を供給することにより、各エンティティに対応するキーを計算し、ユニットの選択終了に対応する新しい変数を作成する。
- 次に、最大64ユニットを選択するために繰り返され、ネットワークは「autoregressive_embedding」をサイズ256の全結合層に渡し、「func_embed」を追加し、ReLUとサイズ32の全結合層に組み合わせを渡す。結果は、クエリを取得するために、サイズが32で初期状態がゼロのLSTMに送られる。 エンティティキーはクエリで乗算され、マスクと温度0.8を使用してサンプリングされ、選択するエンティティが決定される。 そのエンティティは、将来の反復で選択できないようにマスクされる。 選択されたエンティティのワンホットの位置にキーが乗算され、エンティティ全体の平均が減らされ、サイズ1024の全結合層を通過し、後続の反復のために「autoregressive_embedding」に追加される。 最後の「autoregressive_embedding」が返される。 「action_type」にユニットの選択が含まれない場合、このヘッドは無視される。
ターゲットユニットヘッド(Target Unit Head)
- 入力:autoregressive_embedding, action_type, entity_embeddings
- 出力:
- target_unit_logits : ユニットをターゲットとする確率に対応するロジット
- target_unit : サンプリングされたターゲットユニット
- 「func_embed」は、ユニット選択ヘッドと同じように計算され、同じ方法でクエリに使用される(サイズ256の全結合層を介して渡される「 autoregressive_embedding」の出力に追加される)。 次に、クエリはReLUとサイズ32の全結合層を介して渡され、クエリはユニット選択ヘッドと同じ方法で作成されたキーに適用され、「target_unit_logits」を取得する。 「target_unit」は、温度0.8の多項ロジットを使用して「target_unit_logits」からサンプリングされます。 ここで、これは2つの終端引数の1つであるため(もう一つはロケーションヘッド。ターゲットユニットとターゲットロケーションはどちらもアクションがないため)、「autoregressive_embedding」を返さないことに注意する。
ロケーションヘッド(Location Head)
- 入力:autoregressive_embedding, action_type, map_skip
- 出力:
- target_location_logits : 各場所をターゲットとする確率に対応するロジット
- target_location : サンプリングされたターゲットの場所
- 「autoregressive_embedding」は、4つのチャネルで「map_skip」(マップ情報が1D埋め込みに変更される直前)の最後のスキップと同じ高さ/幅になるように変更され、2つはチャネル次元に沿って連結され、ReLUを介して渡され、128チャネルとカーネルサイズ1の2D畳み込みを通過してから、別のReLUを通過する。次に、3Dテンソル(高さ、幅、チャネル)は、「autoregressive_embedding」でゲートされ、「map_skip」の要素を使用して、128チャネル、カーネルサイズ3、FiLMの一連のゲート付きResBlocksを通過します。順番は最後のResBlockスキップが最初になる。 その後、カーネルサイズ4およびチャンネルサイズ128、64、16、1の一連の転置2D畳み込みのそれぞれによって2倍にアップサンプリングされる(128 x 128入力から256 x 256ターゲットロケーション選択にアップサンプリングされる)。これらの最終的なロジットは、実際のターゲット位置を取得するために、温度0.8でフラット化およびサンプリング(建築アクションのためのカメラの外側など、「action_type」を使用して無効な場所をマスク)する。
感想
初期の論文では、アクションの引数はそれぞれ別に出力してchain ruleを適用していましたが、今回はアクションタイプの引数は、ネットワークで表現されています。
選択したアクションタイプが埋め込みにされて、次の引数選択の入力となり、その出力も埋め込みに追加されて、次の引数に渡されています。
アクションの引数のネットワーク構成には、MLP、ポインターネットワーク、アテンション、Deconv ResNetが引数に応じて使われています。
StarCraftくらいに操作が複雑になると、それに応じた複雑な方策の表現が必要になっています。
アーキテクチャ詳細の方策に関する部分は以上です。残りはバリューネットワークに関する部分です。
AlphaStarの論文を読む - TadaoYamaokaの開発日記
AlphaStarの論文を読む その2 - TadaoYamaokaの開発日記
AlphaStarの論文を読む その3 - TadaoYamaokaの開発日記
AlphaStarの論文を読む その4 - TadaoYamaokaの開発日記
AlphaStarの論文を読む その5(アーキテクチャ) - TadaoYamaokaの開発日記
AlphaStarの論文を読む その6(アーキテクチャその2) - TadaoYamaokaの開発日記
AlphaStarの論文を読む その7(アーキテクチャその3) - TadaoYamaokaの開発日記
AlphaStarの論文を読む その8(教師あり学習、強化学習) - TadaoYamaokaの開発日記
AlphaStarの論文を読む その9(マルチエージェント学習) - TadaoYamaokaの開発日記
AlphaStarの論文を読む その10(リーグ構成) - TadaoYamaokaの開発日記
AlphaStarの論文を読む その11(インフラ) - TadaoYamaokaの開発日記
AlphaStarの論文を読む その12(評価) - TadaoYamaokaの開発日記
AlphaStarの論文を読む その13(分析、AlphaStarの一般性) - TadaoYamaokaの開発日記