TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

cshogiにGPS将棋のdf-pnを移植(その3 ソース解説)

GPS将棋のdf-pnを1つのソースファイルに移植したことで、LLMにコード全体をコンテキストとして与え、解説させられるようになった。

GPS将棋のdf-pnは、詰将棋探索アルゴリズムの知見が数多く詰め込まれた非常に価値の高いコードである。しかし、前々回の記事で述べたように、その複雑さゆえに、現在では将棋AI開発者からもあまり読まれなくなっている。

そこで、その価値を少しでも伝えるために、GPT-5.5 Proに移植版のコードを読ませ、詳細な解説を作成してもらった。

解説対象のソース(osl_dfpn.cpp):

github.com

このコードは何をするものか

このコードは、将棋の局面 Position に対して、

攻方が相手玉を詰ませられるか?

を調べるための 詰将棋用 df-pn 探索エンジンである。

中心になるのは osl_dfpn.cpp のこの2つである。

ProofDisproof OslDfPn::Impl::attack(...);
ProofDisproof OslDfPn::Impl::defense(...);

attack() は攻方ノード、defense() は受方ノードである。

低レベルの手生成は generateMoves.cpp が担当する。そこには generateMoves<CheckAllOslmate>generateMoves<CheckAllOslmateFixedRaw>generateMoves<Evasion> などがあり、osl_dfpn.cpp から王手生成・逃げ手生成として使われる。

全体を一言でいうと、

df-pn の理論をベースに、
将棋の詰み探索向けに、
手生成・置換表・持ち駒優越・ループ検出・DAG対策・証明流用・PV取得を組み込んだ実装

である。


まず理解すべき理論:AND/OR木

ゲーム木探索では、局面を ORノードANDノード に分けて考えることがある。

詰将棋に当てはめると、こうなる。

ORノード
  攻方の手番。
  どれか1つの王手で詰みに向かえればよい。

ANDノード
  受方の手番。
  受方のすべての逃げ手を潰さなければならない。

ORノードでは、少なくとも1つの子が真なら親も真である。ANDノードでは、すべての子が真でなければ親は真にならない。ORノードは「少なくとも一つの子ノードが真なら真」、ANDノードは「すべての子ノードが真なら真」と定義できる。

将棋の詰み探索では、これを次のように読むと自然である。

攻方:
  王手A、王手B、王手Cのうち、
  どれか1つで詰みが続けばよい。

受方:
  逃げ手A、逃げ手B、逃げ手Cのすべてに対して、
  攻方が詰ませられなければならない。

この対応が、attack()defense() の根本である。


証明数・反証数とは何か

ノードを証明するために展開する必要がある先端ノード数の最小値を 証明数、反証するために必要な先端ノード数の最小値を 反証数 と呼ぶ。証明数が小さいほど証明しやすく、反証数が小さいほど反証しやすい、という意味である。

このコードでは、その組を ProofDisproof が持つ。

proof
  詰みを証明する難しさ。

disproof
  不詰を証明する難しさ。

特別な意味はこうである。

proof == 0
  詰みが証明された。

disproof == 0
  不詰が証明された。

proof > 0 かつ disproof > 0
  まだ未解決。

そして、ORノードとANDノードで集約方法が違う。

ORノード:
  proof    = 子 proof の最小値
  disproof = 子 disproof の合計

ANDノード:
  proof    = 子 proof の合計
  disproof = 子 disproof の最小値

この式が、attack()defense() のコードにそのまま出ている。


attack()defense() の対応

この表が、ソース全体の地図である。

観点 attack() defense()
理論上のノード ORノード ANDノード
手番 攻方 受方
生成する手 王手 王手回避
主な手生成 generate_check_moves() generate_escape_moves()
子で呼ぶ関数 defense() attack()
成功条件 どれか1つの王手が詰みに進む すべての逃げ手を詰ませる
失敗条件 すべての王手が失敗 どれか1つ逃げられる
proof集約 最小値 合計
disproof集約 合計 最小値
次に読む子 proof が小さい王手 disproof が小さい逃げ手

つまり、合言葉はこれである。

attack は OR。
どれか1手でよい。
proof は min、disproof は sum。

defense は AND。
全部潰す必要がある。
proof は sum、disproof は min。

df-pnとは何か

df-pn は、証明数探索を深さ優先探索に変換した手法である。証明数と反証数の両方にしきい値を持ち、そのしきい値を超えるまで深さ優先で最有力ノードを追う。

このコードでは、しきい値は次の構造体で表される。

struct Threshold {
    uint32_t proof;
    uint32_t disproof;
};

理論上の φ / δth_pnum / th_dnum に相当するものである。

attack()defense() は、子を再帰探索するときに、その子専用のしきい値を作る。

oslmate_attack_child_proof_threshold(...)
oslmate_attack_child_disproof_threshold(...)
oslmate_defense_child_proof_threshold(...)

これらは、

この子をどこまで読めば、親の評価が変わるか?

を計算している。

df-pn の重要な特徴は、

常に最有力ノードへ向かって深く読む
ただし証明数・反証数のしきい値で読みすぎを防ぐ

ことである。

このコードでは、attack() なら proof が最小の子、defense() なら disproof が最小の子を next_index として選ぶ。


公開APIから全体を見る

まず読むべき入口は、ファイル下部にある公開関数である。

bool OslDfPn::dfpn(Position& pos);
bool OslDfPn::dfpn_andnode(Position& pos);
Move OslDfPn::dfpn_move(Position& pos);
ProofDisproof OslDfPn::dfpn_probe(Position& pos, Move* best_move);
void OslDfPn::get_pv(Position& pos, std::vector<u32>& pv);

通常の入口は dfpn() である。

流れはこうである。

OslDfPn::dfpn(pos)
  ↓
root_attack_color = pos.turn()
  ↓
探索表をクリア
  ↓
小さいノード上限から探索開始
  ↓
Impl::attack(pos, root threshold)
  ↓
未解決ならノード上限を増やして再探索
  ↓
詰み成功なら true

dfpn_andnode() は、root を受方ノードとして始めたい場合の入口である。

dfpn()
  root は attack()

dfpn_andnode()
  root は defense()

この2つの違いを最初に押さえると、下の実装が読みやすくなる。


ProofDisproof の読み方

このコードでは、探索結果は ProofDisproof で表す。

代表的な状態は次のように読むとよい。

Unknown
  まだ分からない。

Checkmate
  攻方が詰ませられる。

NoCheckmate
  攻方が詰ませられない。

NoEscape
  受方に逃げ手がない。
  攻方から見れば詰み成功。

PawnCheckmate
  打ち歩詰め絡みの特殊扱い。

LoopDetection
  ループ検出。

注意すべきなのは、NoEscape である。

NoEscape = 受方に逃げがない
         = 攻方の詰み成功

名前だけ見ると「逃げなし」だが、詰み探索の結果としては攻方成功側である。


attack() の読み方

attack() は攻方ノードである。

大まかな流れはこうである。

1. 深さ・ノード制限を確認する。
2. 現在局面を経路表に登録する。
3. 置換表から過去の探索結果を読む。
4. すでに詰み・不詰が分かっていれば返す。
5. 浅い詰みショートカットを試す。
6. 王手を生成する。
7. 王手がなければ NoCheckmate。
8. 各王手の子局面を調べる。
9. proof/disproof を集計する。
10. proof が最小の王手を選んで defense() を再帰する。
11. 結果を置換表に保存する。

擬似コードにするとこうである。

ProofDisproof attack(pos, threshold) {
    record = table.probe(pos);
    if (record is final)
        return record;

    if (fixed_attack shortcut succeeds)
        return Checkmate;

    moves = generate_check_moves(pos);
    if (moves.empty())
        return NoCheckmate;

    for move in moves:
        child = table.probe(pos after move);
        if unknown:
            child = estimate_child(...);

    while (not solved) {
        min_proof = min(child.proof);
        sum_disproof = sum(child.disproof);

        if threshold reached:
            save and return;

        next = child with smallest proof;
        do move;
        defense(child_pos, child_threshold);
        undo move;
        update child;
    }
}

attack() の中で最重要なのは、

min_proof
second_proof
sum_disproof
next_index

である。

この next_index が、次に深く読む王手である。


defense() の読み方

defense() は受方ノードである。

大まかな流れはこうである。

1. 深さ・ノード制限を確認する。
2. 現在局面を経路表に登録する。
3. 王手がかかっていなければ NoCheckmate。
4. 置換表から過去の探索結果を読む。
5. 逃げ手を生成する。
6. 逃げ手がなければ NoEscape。
7. 各逃げ手の子局面を調べる。
8. proof/disproof を集計する。
9. disproof が最小の逃げ手を選んで attack() を再帰する。
10. 結果を置換表に保存する。

擬似コードにするとこうである。

ProofDisproof defense(pos, threshold) {
    if (!pos.inCheck())
        return NoCheckmate;

    record = table.probe(pos);
    if (record is final)
        return record;

    moves = generate_escape_moves(pos);
    if (moves.empty())
        return NoEscape;

    for move in moves:
        child = table.probe(pos after move);
        if unknown:
            child = quick_estimate(...);

    while (not solved) {
        sum_proof = sum(child.proof);
        min_disproof = min(child.disproof);

        if threshold reached:
            save and return;

        next = child with smallest disproof;
        do move;
        attack(child_pos, child_threshold);
        undo move;
        update child;
    }
}

defense() で重要なのは、

sum_proof
min_disproof
second_disproof
next_index

である。

この next_index が、受方にとって一番逃げやすそうな手である。


先端ノードの初期値と評価

理論上、未展開の先端ノードは普通 (1,1) とする。

proof = 1
disproof = 1

しかし実際のゲームでは局面によって証明しやすさが違うため、ゲーム固有の評価関数で初期 proof/disproof を変える改良がある。

このコードでそれに対応するのが、次の関数群である。

attack_estimation_zero(...)
fixed_attack_estimation_zero(...)
estimate_attack_pdp(...)
estimate_attack_pdp_with_support(...)
attack_proof_cost(...)

これらは、未探索の子を単純な (1,1) とせず、

相手玉の逃げ道はいくつあるか
王手した駒が取られやすいか
攻方の利きが多いか
受方の利きが多いか
駒打ちで逃げ道を塞げそうか

を見て、初期値を調整する。


King8RuntimeInfo は相手玉8近傍の評価器

詰将棋では、相手玉の周囲8マスが非常に重要である。

このコードでは、

struct King8RuntimeInfo {
    uint8_t drop_candidate;
    uint8_t liberty;
    uint8_t liberty_candidate;
    uint8_t move_candidate2;
    uint8_t spaces;
    uint8_t moves;
    uint8_t liberty_count;
};

が使われる。

意味はこうである。

liberty
  本当の逃げ道になりそうなマス。

drop_candidate
  駒を打つと詰みに絡みそうなマス。

move_candidate2
  駒移動で詰みに絡みそうなマス。

spaces
  空きマス。

moves
  玉移動や応手に関係するマス。

liberty_count
  逃げ道の数。

make_king8_runtime_info() は、相手玉の周囲8マスについて、空きマスか、自駒か、相手駒か、攻方の利きがあるか、受方の十分な利きがあるかを見て、この情報を作る。

これが、1手詰め検出や proof/disproof の見積もりに使われる。


1手詰め検出

本格的な df-pn に入る前に、簡単な詰みは先に見つける。

中心はこれである。

Move immediate_mate_move_in_1_osl(Position& pos, ...);

流れはこうである。

1. 自玉が王手中なら試さない。
2. 攻方色と相手玉を取得する。
3. King8RuntimeInfo で相手玉周辺を評価する。
4. 駒移動候補を見る。
5. 桂候補を見る。
6. 駒打ち候補を見る。

immediate_mate_move_in_1_osl() は、まず移動による詰み候補を探し、次に桂、最後に駒打ち候補を探す。fixed_attack_depth2_osl() などの浅い探索でも、最初にこの1手詰め検出を試す。


固定深さショートカット

本格探索の前に浅い探索を行う仕組みもある。

代表的なのは、

fixed_attack_osl_shortcut(...)
fixed_attack_depth2_osl(...)
fixed_escape_by_move_zero(...)
fixed_has_escape_by_move_zero(...)
fixed_attack_may_unsafe_depth1(...)

である。

役割は、

1手詰め
3手詰めに近い浅い詰み
逃げ手が明らかにない局面
玉周りから見た簡易評価

を本格 df-pn より先に処理することである。

fixed_attack_osl_shortcut()fixed_attack_depth2_osl() を呼ぶ形になっており、固定深さ用の王手生成では generate_fixed_depth_check_moves_into() が使われる。ここでは通常の generate_check_moves() とは違い、OSL の raw な生成順を保つようになっている。


generateMoves.cpp は手生成工場

osl_dfpn.cpp は探索本体であるが、低レベルな手生成は generateMoves.cpp にある。

中心はこのテンプレートである。

template <MoveType MT, Color US, bool ALL = false>
struct GenerateMoves;

MoveType によって生成する手が変わる。

Check
  王手を生成する。

CheckAll
  通常省略される不成も含めた王手を生成する。

CheckAllOslmate
  詰み探索用の王手生成。

CheckAllOslmateFixedRaw
  固定深さ探索用の raw 王手生成。

Evasion
  王手回避手を生成する。

Legal
  合法手を生成する。

PseudoLegal
  疑似合法手を生成する。

generateMoves<MT>(moveList, pos) は、手番が先手なら GenerateMoves<MT, Black>、後手なら GenerateMoves<MT, White> を呼ぶ。


王手生成の中身

王手生成では、大きく次の2種類を考える。

直接王手
  動いた駒そのものが相手玉に利く。

開き王手
  pin されている駒が動き、後ろの飛車・角・香などの利きが通る。

generateMoves.cppGenerateMoves<Check, US, ALL> では、王手を次のように整理している。

1. 成らない移動による直接王手
2. 成る移動による直接王手
3. pinされている駒の移動による間接王手

コード中では、直接王手候補を x、開き王手候補を y として分けて扱っている。ほとんどの局面では開き王手候補が空なので、その前提で最適化されている。

また、歩打ちでは二歩と打ち歩詰めを避ける処理がある。


王手回避生成の中身

GenerateMoves<Evasion, US, ALL> は、王手回避手を生成する。

流れはこうである。

1. 王手している駒を列挙する。
2. その駒の利きで玉が逃げられないマスを bannedKingToBB に入れる。
3. 玉の逃げ手を生成する。
4. 両王手なら玉移動しかないので終了する。
5. 単王手なら、王手駒を取る手や合駒を生成する。

makeBannedKingTo() は、王手している駒の種類に応じて、玉が移動できないマスを作る。GenerateMoves<Evasion> は、玉の移動先に敵の利きがある場合や pin された駒を動かす場合など、非合法手を含む疑似合法手を生成し、あとで合法性を確認する設計である。


generate_check_moves() の役割

osl_dfpn.cppgenerate_check_moves() は、攻方ノード用の王手生成である。

分岐はこうである。

pos.inCheck() == true
  王手を受けながら、さらに王手になる手だけを拾う。

pos.inCheck() == false
  generateMoves<CheckAllOslmate> で王手を生成する。

その後、

append_ignored_unpromote_checks
reorder_osl_numbered_check_moves
sort_moves

を行う。

つまり、

王手を作る
  ↓
必要な不成手を追加する
  ↓
OSL風の順序に整える
  ↓
探索用にソートする

という流れである。generate_check_moves()generateMoves<CheckAllOslmate> を使ったあと、OSL の phase / target / piece-number order に寄せてから sort_moves() をかけている。


generate_escape_moves() の役割

generate_escape_moves() は、受方ノード用の逃げ手生成である。

流れはこうである。

1. delay_node かどうかを見る。
2. delay_node なら特定マスへの cheap king escape を生成する。
3. 通常は cheap escape を生成する。
4. OSL駒番号順に補正する。
5. sort_moves する。
6. need_full_width なら全逃げ手を追加する。
7. 必要な不成逃げ手を追加する。

need_full_width が false のときは、まず軽い逃げ手だけを見る。危険な場合や詰み確認が必要な場合に need_full_width を立て、全幅の逃げ手を追加する。


need_full_width の意味

need_full_width は、受方の逃げ手生成をどこまで広げるかを表す。

need_full_width == 0
  cheap escape だけで探索する。

need_full_width != 0
  全幅の逃げ手を生成する。

詰み探索では、毎回すべての逃げ手を生成すると重い。

そこで最初は、

重要そうな逃げ手だけ見る

という軽量探索をする。

ただし、それだけで詰みと断定すると危険である。そこで必要になったら全逃げ手を追加する。defense() では、詰み成功に見えても need_full_width == 0 の場合、全幅確認へ戻す処理がある。


OSL互換の手順序

このコードには、手順序を整える関数が多くある。

reorder_osl_numbered_check_moves(...)
reorder_osl_numbered_escape_target_moves(...)
reorder_osl_numbered_long_moves_impl(...)
sort_moves(...)
dfpn_check_move_phase(...)
dfpn_check_move_subphase(...)

これは単なるソートではない。

cshogi の汎用王手生成順は OSL の raw order と違うため、OSL の phase / target / piece-number order に近づけてから sort_moves() を行う。

つまり、このコードは、

普通の df-pn 実装

ではなく、

OSL の探索順序に近づけた df-pn 実装

として読む必要がある。

探索アルゴリズムでは、同じ手を最終的に読むとしても、読む順番が速度や保存結果に大きく影響する。詰将棋探索では特に重要である。


OslPieceNumberState は駒番号の再現装置

OSL では、駒に番号があり、その番号順に手を生成する場面がある。

このコードでは、

class OslPieceNumberState

がそれを再現する。

主な中身は、

std::array<PieceInfo, 40> pieces_;
std::array<int, SquareNum> board_number_;

である。

将棋の駒は全部で40枚なので、40個の PieceInfo を持つ。

このクラスは、

この駒は何番か
今どこにあるか
持ち駒か
手を指したらどう更新するか
undo でどう戻すか

を管理する。

OslPieceNumberState には from_position()from_history()apply_move()undo_move()number_of_move_piece()number_of_square() がある。探索中に手順履歴へ追従し、OSL風の駒番号順ソートに使われる。


不成手の扱い

将棋では、成れる駒を成らずに指す手がある。

普通は成った方が強いので省略したくなるが、詰将棋では不成が意味を持つ場合がある。

このコードでは、

has_ignored_unpromote_check(...)
has_ignored_unpromote_escape(...)
is_ignored_unpromote_check_variant(...)
append_ignored_unpromote_checks(...)
unpromote_counterpart(...)

がそのためにある。

generate_check_moves() では、生成後に append_ignored_unpromote_checks() を呼ぶ。generate_escape_moves() でも、full-width 時に必要な不成逃げ手を追加する。


打ち歩詰めと PawnCheckmate

将棋固有のルールとして、打ち歩詰めがある。

generateMoves.cpp の歩打ち生成では、

一段目には打てない
二歩を避ける
打ち歩詰めを避ける

という処理がある。

osl_dfpn.cpp 側にも、

has_pawn_drop_checkmate(...)
ProofDisproof::PawnCheckmate()
normalize_pawn_drop_no_escape(...)
is_pawn_drop_no_escape(...)

がある。

読み方としては、

歩打ちで詰んでいるように見える
  ↓
しかし打ち歩詰めなら通常の詰みとは違う
  ↓
PawnCheckmate として特別扱いする

である。


局面キーとハッシュ

置換表に保存するため、局面にはキーが必要である。

主な関数はこれである。

osl_board64_key(pos)
osl_board32_key(pos)
board_index_key(pos)
secondary_board_key(pos)
board_index_key_after_move(...)
secondary_board_key_after_move(...)
board_keys_after_move(...)

このコードでは、盤面キーを2つに分けている。

board_index
  主キー。手番情報も含む。

board_secondary
  補助キー。

そして、

struct OslmateBoardKey {
    Key board_key;
    uint64_t board_secondary;
};

として扱う。

ただし将棋では、盤面が同じでも持ち駒が違えば別局面である。そのため、DfpnRecordstands[Black]stands[White] を持ち、同じ盤面キーの bucket に持ち駒違いのレコードを複数保存する。


DfpnRecord は局面のカルテ

DfpnRecord は、1局面分の探索結果を保存する構造体である。

主なフィールドは次の通りである。

ProofDisproof proof_disproof;
Move best_move;
Hand stands[ColorNum];
uint64_t board_secondary;
Hand proof_pieces;
Hand proof_pieces_candidate;
uint64_t solved;
uint64_t dag_moves;
uint32_t node_count;
uint32_t tried_oracle;
uint32_t min_pdp;
uint16_t remaining_depth;
Move last_move;
Square last_to;
ProofPiecesType proof_pieces_set;
uint8_t need_full_width;
bool false_branch;
bool dag_terminal;
bool exact;

意味はこうである。

proof_disproof
  この局面の現在の評価。

best_move
  有力手、または証明に使った手。

stands
  この記録が前提にしている持ち駒。

proof_pieces
  詰み証明または不詰証明に必要な持ち駒情報。

solved
  解決済みの子を bit で保持。

dag_moves
  DAG的にまとめられる子を bit で保持。

node_count
  この局面以下の探索量。

tried_oracle
  証明流用をどこまで試したか。

last_move
  DAG探索や証明流用で手順をたどるための手。

last_to
  逃げ手生成の限定に使うマス。

need_full_width
  全幅逃げ手生成が必要か。

exact
  厳密な探索結果かどうか。

つまり、

DfpnRecord = 局面の探索結果・再利用情報・証明情報のカルテ

である。


DfpnTable は置換表

df-pn では、探索済みノードの証明数・反証数を保存する置換表が不可欠である。置換表を使うことで、同じ局面を何度も探索する手間を減らせる。

このコードでは、それが DfpnTable である。

std::unordered_map<
    OslmateBoardKey,
    std::forward_list<DfpnRecord>,
    OslmateBoardKeyHash
> table_;

普通の置換表なら、

局面キー → レコード

しかし、このコードでは、

盤面キー → 複数の DfpnRecord

である。

持ち駒違いの局面を同じ bucket に入れるためである。


DfpnTable::probe() の読み方

DfpnTable::probe() は非常に重要である。

流れはこうである。

1. board key で bucket を探す。
2. bucket がなければ Unknown 相当の DfpnRecord を返す。
3. 同じ持ち駒の record があればそれを使う。
4. 詰み成功 record なら、攻方持ち駒が proofPieces を満たすか見る。
5. 不詰 record なら、受方持ち駒が disproofPieces を満たすか見る。
6. final でない record から proof_hint / disproof_hint を作る。
7. 未知なら初期 proof/disproof を hint で調整する。

この「持ち駒優越」を使うところが、普通の置換表より強力である。probe() は完全一致だけでなく、詰み成功なら攻方持ち駒が証明に必要な駒を満たすか、不詰なら受方持ち駒が反証に必要な駒を満たすかを確認する。


持ち駒優越とは何か

たとえば、

攻方が金1枚を持っていれば詰む

と証明できたとする。

その場合、

攻方が金1枚と銀1枚を持っている

局面でも、その詰み証明を再利用できる可能性が高い。

これが 持ち駒優越である。

コードでは、

osl_stand_is_superior_or_equal(...)
setProofPieces(...)
setDisproofPieces(...)
proofPieces()
disproofPieces()

が関係する。

詰み成功の証明では攻方の持ち駒、詰み失敗の証明では受方の持ち駒を見て、過去の探索結果を再利用する。


SearchContext は探索中の作業机

SearchContext は、探索中に必要な状態をまとめた構造体である。

主なフィールドはこうである。

std::vector<PathEncoding> path_encodings;
std::vector<Move> move_history;
std::vector<Threshold> threshold_history;
std::vector<PathRecord*> path_records;
std::unique_ptr<Position> root_position;
Optional<OslPieceNumberState> piece_numbers;
std::vector<Key> board_index_history;
std::vector<uint64_t> board_secondary_history;
std::vector<std::array<Hand, ColorNum>> stand_history;
std::vector<ActiveNode> active_nodes;
std::vector<std::vector<Move>> move_scratch;
std::vector<std::vector<ChildState>> child_scratch;

役割はこうである。

move_history
  ここまでの手順。

threshold_history
  各深さの proof/disproof しきい値。

path_encodings
  経路ハッシュ。

path_records
  ループ検出用の現在経路。

piece_numbers
  OSL駒番号順を再現する状態。

board_index_history / board_secondary_history
  局面キー履歴。

stand_history
  持ち駒履歴。

active_nodes
  現在展開中の祖先ノード群。

move_scratch / child_scratch
  深さごとの手・子状態の作業領域。

Position::doMove() だけでは、探索履歴や駒番号状態は更新されない。そのため探索中は、pos.doMove() とあわせて ctx.push_move()、戻すときに ctx.pop_move() を使う。


ループ検出と経路情報

サイクルを含むゲームでは、同じ局面に戻ることがある。

サイクルを含む探索空間では局面の真偽が経路に依存することがあり、単純な置換表利用で誤る可能性がある。

このコードでは、

PathEncoding
PathRecord
DfpnPathTable
VisitLock
child_loop_reason(...)
child_is_loop(...)

がそれに対応する。

PathRecord は、

int distance;
bool visiting;
uint32_t node_count;
std::forward_list<PathEncoding> twin_paths;

を持つ。

意味はこうである。

visiting
  現在の探索経路上にいるか。

twin_paths
  経路依存として記録された path。

distance
  探索経路上での距離情報。

node_count
  その path record の探索量。

VisitLock は RAII で、関数に入ったら visiting = true、出るときに false に戻す。

初心者向けにはこうである。

DfpnTable
  過去に調べた局面の表。

DfpnPathTable
  今たどっている道の表。

PathEncoding
  道筋を小さな値にしたもの。

DAG二重カウント対策

探索空間が木ではなくDAGの場合、別手順で同じ局面に到達することがある。

このとき、証明数・反証数を二重に数えてしまう問題が起こる。

このコードで対応するのが、

dag_moves
dag_terminal
find_dag_source(...)
add_dag(...)

である。

find_dag_source() は、終端レコードの last_move を使って親方向へたどり、現在の active_nodes の祖先と照合する。該当する祖先が見つかると、祖先レコードの dag_moves に bit を立てる。

意味はこうである。

同じ局面に合流する枝を見つける
  ↓
同じ proof/disproof を何度も数えないようにする

attack()defense() の中でも、proof/disproof が一定以上大きい未解決子に対して find_dag_source() を試す。


証明流用とシミュレーション

すでに証明されたノードの証明木を使って、似た別ノードも証明できるかを高速に調べる手法がある。ORノードでは証明木から取り出した指し手だけを使い、ANDノードではすべての応手を確認する。

このコードでは、対応するのが次の関数群である。

try_proof(...)
proof_oracle_attack(...)
proof_oracle_defense(...)
blocking_simulation(...)
grand_parent_simulation(...)

proof_oracle_attack() は、

既存の詰み証明レコードを探す
  ↓
best_move を取り出す
  ↓
現在局面でも使えるように補正する
  ↓
その1手だけを試す
  ↓
proof_oracle_defense() で受方応手を確認する

という流れである。

proof_oracle_attack() では、find_proof_oracle() で証明済みレコードを探し、adjust_oracle_attack_move() で現在局面でも使える手に補正する。


blocking_simulation

blocking_simulation() は、受方の合駒・遮断手のような「似た応手」に証明を流用できるか試す。

流れはこうである。

1. ある子が既に詰み成功している。
2. 同じ target に向かう別の未解決手を探す。
3. その手に対して proof_oracle_attack() を試す。
4. 成功すれば、その子の結果を更新する。

同じマスに合駒する手などは、証明構造が似ている可能性がある。そこで、すでに証明済みの子を利用して別の子も確認する。


grand_parent_simulation

grand_parent_simulation() は、祖父ノードの証明情報を使って現在の子にも証明を流用できるか試す処理である。

条件を満たすと、

祖父ノードの証明済み情報を oracle とする
  ↓
現在の子局面で proof_oracle_attack() を試す
  ↓
結果を child にコピーする

という流れになる。

grand_parent_simulation_suitable() は、直近3手の形を見て、証明流用に向いた形かを判定する。


メモリ管理とGC

df-pn は置換表を大量に使う。

ハッシュ表のメモリが限られるため、小さい部分木の情報を優先的に消す考え方がある。

このコードでは、

DfpnTable::run_gc()
DfpnPathTable::run_gc()
Impl::run_gc_if_needed(...)

が対応する。

DfpnTable::run_gc() は、

final でない
node_count が小さい

レコードを削除対象にする。

つまり、

探索量が小さく、再構築しやすそうな記録を消す

という考え方である。DfpnPathTable::run_gc() も、visiting 中のものや node_count が大きいもの、twin_paths を持つ重要そうなものを残す。


EffectSetCache は利き計算のキャッシュ

詰み探索では、

このマスに攻方の利きがあるか
受方の利きが何枚あるか
複数利きか

を大量に調べる。

このコードでは、

effect_set_at(...)
effect_has_at(...)
effect_count(...)
has_multiple_effect_at(...)

がよく出る。

内部では EffectSetCachepos.attackersTo(...) の結果をキャッシュする。

これにより、

王手候補の評価
玉の逃げ道評価
攻方・受方の利き比較
手のソート
1手詰め判定
固定深さショートカット

が速くなる。


oslmate との速度差の主因

osl_dfpn.cpp は、探索結果の PV や nodes が oslmate と一致するように、探索順序や置換表の扱いを OSL に寄せている。

ただし、速度まで完全に同じにはならない。主因は、

同じ node を読むときの、1 node あたりの局面評価コスト

の違いである。

oslmate は NumEffectState という局面表現を使う。これは、局面を進めたり戻したりするときに、利き情報も一緒に増分更新する仕組みである。

そのため oslmate では、次のような情報を比較的安く取り出せる。

このマスに利きがあるか
利きが何枚あるか
玉が王手されているか
pin / open 情報
玉周辺の利き情報

一方、この移植版は cshogi/Apery 系の PositionBitboard を使っている。EffectSetCache は利き計算をキャッシュするが、OSL の NumEffectState のように、局面の進行に合わせて全体の利き状態を常に増分更新しているわけではない。

そのため、探索中に次のような処理が多く発生する。

pos.attackersTo(...) を使った利きの再計算
effect_has_at(...) や effect_count(...) の問い合わせ
pin / open 判定の再構築
王手生成・王手回避生成の補助判定
玉周辺8マスの安全性評価

nodes が oslmate と一致している場合、探索木の形はおおむね一致している。それでも実行時間に差が出るなら、原因は「余分な node を読んでいること」ではなく、

同じ node を処理するために必要な利き計算や補助判定が重いこと

と考えるのが自然である。

つまり、EffectSetCache はこの実装における速度差を縮めるための重要な部品であるが、NumEffectState そのものの代替ではない。oslmate との残る速度差は、この局面表現の違いに由来する部分が大きい。


sort_moves() の意味

sort_moves() は、生成した王手や逃げ手を探索しやすい順に並べる。

主に見ているのは、

攻方の利きが受方の利きより多いか
移動先の位置
駒種
成りかどうか

である。

ただし、このソートは単なる良さそうな手順序ではない。OSL の生成順序に近づけた上で、その segment 内を並べ替える、という設計である。

generate_check_moves()generate_escape_moves() のコメントにも、OSL の raw order や piece-number order を再現する意図が書かれている。


PV取得:詰み手順を取り出す

探索で詰みが分かったあと、実際の手順を取り出す必要がある。

そのための関数が、

get_pv(...)
retrieve_pv(...)
pv_attack(...)
pv_defense(...)
linked_pv_attack_move(...)
linked_pv_defense_move(...)

である。

流れはこうである。

retrieve_pv()
  ↓
攻方番なら pv_attack()
受方番なら pv_defense()
  ↓
best_move を探す
  ↓
局面を進める
  ↓
次のノードへ

pv_attack() は、1手詰め、固定深さショートカット、置換表の best_move、証明流用などを順に見て、詰み手順を構成する。


理論と実装の対応表

理論上の概念 このコードの対応
ORノード attack()
ANDノード defense()
攻方の詰み成功
攻方の詰み失敗
証明数 ProofDisproof::proof
反証数 ProofDisproof::disproof
先端ノード初期値 ProofDisproof(1, 1)Unknown
先端ノード評価 attack_estimation_zero()estimate_attack_pdp()
df-pn のしきい値 Threshold{proof, disproof}
SelectChild next_index 選択
トランスポジションテーブル DfpnTable
TTLookup DfpnTable::probe()
TTSave store_exact_oslmate() / store_nonexact_oslmate()
経路情報 PathEncoding
経路依存記録 PathRecord::twin_paths
ループ検出 DfpnPathTablechild_loop_reason()
DAG二重カウント対策 find_dag_source()dag_moves
シミュレーション proof_oracle_*blocking_simulation()grand_parent_simulation()
GC DfpnTable::run_gc()DfpnPathTable::run_gc()
証明木の指し手 best_movelast_move
PV取得 get_pv()retrieve_pv()pv_attack()pv_defense()

主要関数の一言辞書

公開API

OslDfPn::dfpn
  現在手番を攻方として詰みを探索する。

OslDfPn::dfpn_andnode
  現在局面を受方ノードとして探索する。

OslDfPn::dfpn_move
  探索結果の best_move を返す。

OslDfPn::dfpn_probe
  現在局面の ProofDisproof を置換表から読む。

OslDfPn::get_pv
  詰み手順を取り出す。

探索本体

Impl::attack
  攻方ノード。王手を選ぶ。

Impl::defense
  受方ノード。逃げ手を選ぶ。

try_proof
  証明流用を試す入口。

proof_oracle_attack
  過去の詰み証明を攻方側から流用する。

proof_oracle_defense
  流用した証明が受方応手すべてに耐えるか確認する。

手生成

generate_check_moves
  df-pn 用の王手生成。

generate_escape_moves
  df-pn 用の逃げ手生成。

generate_fixed_depth_check_moves
  固定深さショートカット用の王手生成。

generateMoves<CheckAllOslmate>
  詰み探索用の王手生成。

generateMoves<Evasion>
  王手回避生成。

保存・再利用

DfpnRecord
  1局面分の探索結果。

DfpnTable
  探索結果の置換表。

DfpnPathTable
  探索経路上の局面表。ループ検出用。

PathEncoding
  手順を hash 的に表す。

OslPieceNumberState
  OSL と同じ駒番号順を再現するための状態。

高速化

immediate_mate_move_in_1_osl
  1手詰め候補の高速判定。

fixed_attack_osl_shortcut
  浅い固定深さ探索。

King8RuntimeInfo
  相手玉の8近傍評価。

EffectSetCache
  利き情報のキャッシュ。

sort_moves
  OSL互換と探索効率のための手順序付け。

読む順番のおすすめ

1周目:入口と主再帰だけ読む

OslDfPn::dfpn()
  ↓
Impl::attack()
  ↓
Impl::defense()

ここでは、細かい補正・証明流用・DAG対策は読まなくて大丈夫である。

見るべきことは、

attack は王手を生成して defense を呼ぶ。
defense は逃げ手を生成して attack を呼ぶ。

だけである。

2周目:proof/disproof の集約を見る

attack() では、

min_proof
second_proof
sum_disproof
next_index

を探す。

defense() では、

sum_proof
min_disproof
second_disproof
next_index

を探す。

ここが理論と実装の一致点である。

3周目:置換表を見る

DfpnRecord
DfpnTable::probe()
DfpnTable::store_exact_oslmate()
DfpnTable::run_gc()

ここで、

探索結果をどう保存するか
同じ局面をどう再利用するか
持ち駒優越をどう使うか

を理解する。

4周目:手生成を見る

generate_check_moves()
generate_escape_moves()
generateMoves<CheckAllOslmate>
GenerateMoves<Check>
GenerateMoves<Evasion>

ここで、探索本体と手生成がつながる。

5周目:高度な再利用を見る

PathEncoding
DfpnPathTable
find_dag_source
proof_oracle_attack
proof_oracle_defense
blocking_simulation
grand_parent_simulation

ここは最後でよい。


全体を4層に分ける

このソースは大きいが、4層に分けると見通しがよくなる。

第1層: 公開API
  OslDfPn::dfpn
  OslDfPn::dfpn_andnode
  OslDfPn::dfpn_probe
  OslDfPn::get_pv

第2層: df-pn探索本体
  Impl::attack
  Impl::defense
  Threshold
  ProofDisproof
  ChildState
  SearchContext

第3層: 保存・再利用・ループ対策
  DfpnRecord
  DfpnTable
  DfpnPathTable
  PathEncoding
  find_dag_source
  proof_oracle_attack
  proof_oracle_defense
  blocking_simulation
  grand_parent_simulation

第4層: 将棋固有の手生成・評価
  generate_check_moves
  generate_escape_moves
  generateMoves<CheckAllOslmate>
  generateMoves<Evasion>
  King8RuntimeInfo
  EffectSetCache
  OslPieceNumberState
  fixed_attack_osl_shortcut

最初から第4層を全部読もうとすると、かなり大変である。

まず第1層と第2層で、

attack / defense の相互再帰
proof/disproof の min/sum

を押さえるのが安全である。


このコードの設計思想

このコードは、単純な df-pn ではない。

設計思想は、おそらく次のようなものである。

1. df-pn の攻方/受方ノード構造を使う。
2. OSL の詰将棋探索に近い手順序を再現する。
3. cshogi/Apery系の Position / Move / Bitboard を使う。
4. 置換表で proof/disproof と best_move を保存する。
5. 将棋の持ち駒優越を使って、完全一致でない局面も再利用する。
6. 経路情報を使ってループや経路依存問題を抑える。
7. DAG合流を検出して二重カウントを減らす。
8. 証明流用で似た局面の探索を省く。
9. 1手詰め・浅い詰みは本格探索前に拾う。
10. メモリが増えすぎたらGCする。
11. 探索後にPVとして詰み手順を取り出す。

このため、attack()defense() だけならもっと短く書けるはずであるが、実戦的な詰将棋探索として多くの補助機構が追加されている。


混乱しやすいポイント

attack_colorpos.turn()

attack_color は、この探索全体で詰ませたい側である。

pos.turn() は現在手番である。

多くの場合、

attack() 中:
  pos.turn() == attack_color

defense() 中:
  pos.turn() == oppositeColor(attack_color)

である。

ただし、証明流用や履歴処理では混乱しやすいので注意である。

NoEscape は詰み成功

NoEscape は受方に逃げがない状態である。

攻方から見ると詰み成功である。

same_stands() が Black の持ち駒だけを見る

DfpnTable::same_stands() は、record.stands[Black] == pos.hand(Black) を見る。

これは OSL 互換の HashKey 的な扱いに由来する。普通の「両者の持ち駒を全部比較するはず」という感覚で読むと戸惑う。

証明流用は通常探索ではない

proof_oracle_* は通常の attack() / defense() とは違い、既存の証明手を利用して似た局面を高速確認する。

失敗しても「不詰」ではない。

証明流用に失敗した
  ↓
まだ分からない
  ↓
通常の df-pn 探索が必要

である。


最短理解用のミニ課題

課題1: dfpn() から attack() へ線を引く

見る場所は、

OslDfPn::dfpn
Impl::SearchContext ctx
ctx.set_root(pos)
impl_->attack(...)

確認することは、

root_attack_color が pos.turn() になる
node_limit を段階的に増やす
探索状態を最初にクリアする

である。

課題2: attack() の王手生成を見る

探す行は、

generate_check_moves(pos, moves, ...)

その前後に、

置換表 probe
固定深さショートカット
GC
王手がなければ NoCheckmate

がある。

課題3: defense() の逃げ手生成を見る

探す行は、

generate_escape_moves(pos, moves, ...)

その後に、

moves.empty() なら NoEscape

になることを確認する。

課題4: 手生成側へ飛ぶ

generate_check_moves() から、

generateMoves<CheckAllOslmate>

へ飛ぶ。

そこで、

GenerateMoves<Check>
直接王手
開き王手
駒打ち王手

の生成を見ると、探索本体と手生成の接続が分かる。

課題5: DfpnTable::probe() を読む

見るポイントは、

同じ持ち駒なら record を返す
詰み成功 record は proofPieces で再利用する
不詰 record は disproofPieces で再利用する
未知なら proof_hint / disproof_hint を使う

である。


最後のまとめ

このコードは、理論的には、

AND/OR木
  ↓
証明数・反証数
  ↓
df-pn
  ↓
置換表
  ↓
経路情報・DAG対策・証明流用

という流れの上にある。

実装としては、

attack()
  攻方ノード。
  王手を生成し、proof 最小の子を defense() で読む。

defense()
  受方ノード。
  逃げ手を生成し、disproof 最小の子を attack() で読む。

DfpnTable
  探索結果を保存・再利用する。

DfpnPathTable
  現在経路を管理してループを検出する。

generateMoves.cpp
  王手・逃げ手・駒打ち・合法手生成の工場。

King8RuntimeInfo
  相手玉の8近傍から詰みやすさを見積もる。

proof_oracle_* / simulation
  過去の証明を似た局面へ流用する。

である。

いちばん大切なのは、やはりこれである。

attack は OR。
どれか1つ詰めばよい。
proof = min、disproof = sum。

defense は AND。
全部の逃げを潰す必要がある。
proof = sum、disproof = min。

この対応を見失わなければ、置換表、持ち駒優越、経路情報、DAG、証明流用、手生成順序などの複雑な仕組みも、すべて「df-pn を将棋の詰み探索として実用化するための部品」として整理して読める。

cshogiにGPS将棋のdf-pnを移植(その2 攻め方に玉がない局面に対応)

前回、cshogiに移植したGPS将棋のdf-pnを、攻め方に玉がない局面に対応した。

対応内容

  • 既存のPositionクラスのcheckerBBを玉がない場合、空にする
  • findCheckers()、hiddenCheckers()を玉なしの場合、空にする
  • isOSのassertを玉がない場合に対応
  • 探索アルゴリズム中のkingSquare_をSquareNum(駒台)扱いにする
  • 玉位置参照している箇所をisInSquare()でガード

PyPI

対応したバージョンをPyPIにv1.0.1として登録した。

github.com

まとめ

cshogiを攻め方に玉がない局面に対応した。
これで、詰将棋ソルバーとして使用可能になった。
探索途中の読み筋は出力していないので、GUIから使う詰将棋エンジンとしては使いにくいが、Pythonスクリプトでバッチ的に処理するのには便利だと思う。

cshogiにGPS将棋のdf-pnを移植

以前に、GPS将棋(OpenShogiLib)のdf-pnをdlshogiのAperyベースに移植することを試みたが、変更が多すぎて途中で挫折したことを書いた。

今回は、Codex(GPT-5.4とGPT-5.5)を使って、移植できるか試した。

移植の難しさ

GPS将棋は、盤面管理、合法手生成が、Aperyベースのdlshogi/cshogiとは大きく異なり、詰み探索がそれらの処理とも密接に関連しているため、ソースの大部分を作り替える必要がある。
また、GPS将棋はC++のテンプレートを駆使して実装されているため、部分的な再利用が行いづらい。

以下のような多岐にわたる処理の移植が必要である。

- df-pn 探索ループ
- OR node / AND node の探索処理
- proof / disproof number の更新規則
- ProofDisproof の特殊値体系
- DfpnRecord 相当の探索レコード管理
- transposition table 相当の DfpnTable
- same-board / same-stand を考慮した record probe / store
- exact / non-exact record の扱い
- root search の反復深化的な node limit 拡張
- path table によるループ検出
- dominance / twin path 判定
- oracle proof / disproof 系の処理
- PV 復元処理
- PV depth table 相当の処理
- OSL の hash key 互換処理
- OSL の proof number 初期値・増加規則
- OSL の手順依存 record / path encoding
- OSL の GC / table growth limit 周辺処理
- 王手生成・王手回避生成以外の df-pn 周辺ロジック

段階的な移植

第1段階

GPS将棋(OpenShogiLib)は、df-pn以外も含む将棋ライブラリになっているため、まずは、df-pnに関係する処理のみを抜き出し、移植を行いやすくした。
以前の詰み探索USI mateエンジン化する作業がそれに相当する。

第2段階

df-pn関連処理のみにしたレポジトリ(oslmate)を元に、Codexを使用して、cshogiに移植する。

移植作業

初回指示

3/5/7/9/11手詰めと不詰みのテストコードを用意し、oslmateのレポジトリを、cshogiと同じディレクトリに配置し、以下の指示を行った。

src\dfpn.cppに実装しているdf-pnを、oslmateをベースに作り替えたい。
現在の実装は流用可能な部分だけ流用して、必要なら完全に破棄してもよい。

olsmateは、oslmateフォルダにある。

oslmateの王手・王手回避生成の王手・王手回避生成は使用しないで、既存のsrc\dfpn.cppで使用しているsrc\generateMoves.cppの王手・王手回避生成をベースにする。ただし、変更が必要な場合は、src\generateMoves.cppを直接修正しないで、dfpn.cppに実装する。

oslmateのソースが多数分かれているが、dfpn.hとdfpn.cppに実装し直す。
モダンなC++17で実装し直す。

テストは、test_cpp\test_board.cppにある。必要なテストを追加してよい。
テストを実行するときはバグで無応答になることを考慮してタイムアウトを設定すること。

ビルドコマンド:
    デバッグビルド: msbuild cshogi.sln /t:Build /p:Configuration=Debug /p:Platform=x64
    リリースビルド: msbuild cshogi.sln /t:Build /p:Configuration=Release /p:Platform=x64


oslmateの実行方法例:
olsmate\x64\Release\oslmate.exeを実行して、標準入力に以下の順に入力

isready
readyokが返ってから
usinewgame
position sfen +B+R5n1/5gk2/p1pps1gp1/4ppnsK/6pP1/1PPSP3L/PR1P1PP2/6S2/L2G1G3 w B2N2LP2p 1
go mate infinite
checkmate で始まる行が返ってから
quit で終了

まずは、現状のdfpnとoslmateのプログラムを理解し、差分を把握して、移植の方針を立ててください。

計画を元に実装を指示した。

この作業は4月下旬に開始したので、まだGPT-5.5がリリースされていなかったため、GPT-5.4で行った。

移植対象が多く、未実装や仮実装がある状態であきらめるため、具体的な処理に踏み込んで実装するように指示する必要があった。

一つのテストをパスしても他が通らなくなることを何度も繰り返すので、処理を確認すると、oslmateにないshortcutを実装してテストを通そうとしていた。
また、一部の処理を正しく実装しても、完全に実装するまでは、テストがNGになる場合があり、それが理由ですぐに元に戻そうとしていた。

指示を変更

テストを通すために実装したり、途中までの実装を戻すことを禁止する必要があるため、以下のように指示を変更した。

oslmateに合わせて他のテストが壊れても、oslmateに合わせることが優先です。
壊れる原因をつぶしてください。

oslmateではテストが通ることは担保されており、oslmateは動く正解です。
テストが通らないのは、oslmateと実装が異なるためです。
oslmateと合わせる実装をしたことでテストが悪化したとしも戻す必要はありません。
継続してoslmateとの違いを調べるべきです。
テストを通すために独自のshortcutや近似で対応することを禁じます。そのための時間は完全に無駄です。
近似実装はそのままにせず正確な実装にしてください。
oslmateとの挙動の差分を詳細に直接観測してoslmateと実装を合わせてください。

このあたりで、1週間くらい経過して、GPT-5.5がリリースされたので、GPT-5.5に切り替えた。
Plusプランだったので、1日5時間くらいで動かしていた。
週の上限もあるため、後半は数日待ちが入ったり、予定より早くリセットされたりで再開したりしていた。

指示の簡略化

上記のように細かく禁止事項を書いても指示を守らない場合があるため、指示をシンプルにした。

oslmateのソースと一致する修正であれば悪化したからといって戻す必要はありません。実装を継続してください。
前後の処理を十分に確認してolsmateと一致するように修正してください。
oslmateのソースと一致しない修正を試すことを禁止します。

これを繰り返していると、3/5/7/9/11手詰めと不詰みのテストコードはパスするようになった。
開始して2週目くらいだったと思う。

テスト追加

将棋図巧第1番をテストに追加して、同様に移植の指示を行った。
これが、いつまでたっても終わらないため、途中で$100のProプランに切り替えた。
2倍キャンペーンが終わった後、$200プランに切り替えて、24時間動かすようにした。

2週間くらい続けてテストはパスするようになったが、oslmateとPVとノード数が一致しない状態だった。
PVとノード数が一致するまでひたすら同じ指示を繰り返して実装を続けた。

最終的にPVとノード数が一致

6月2週目になって、ようやくPVとノード数が一致した。

最後に、大量に仕込まれた観測用コードを削除指示した。

開始から、1か月と3週間ほどかかった。

ノウハウ

  • VS CodexのCodex拡張機能だと、履歴が長くなるとログが壊れるので、Codex CLIを使う
  • 観測コードを追加して移植元と移植先の実際の挙動を比較する指示を行う
  • 指示はシンプルにする(多すぎる指示は守られない)

検証

やねうら王詰将棋データセットがすべて通ることを確認した。

処理時間は以下の通り。

  • mate3.sfen

998404/998404 [1:10:25<00:00, 236.31it/s]

  • mate5.sfen

998824/998824 [1:25:42<00:00, 194.23it/s]

  • mate7.sfen

999071/999071 [1:40:00<00:00, 166.50it/s]

  • mate9.sfen

999672/999672 [2:17:09<00:00, 121.48it/s]

  • mate11.sfen

999998/999998 [3:20:41<00:00, 83.05it/s]

探索速度

将棋図巧第1番は、9562 msで解ける。

oslmate(GPT処理)は、5秒くらいなので、それには届いていない。

王手生成後ロジックの違いを吸収するために、王手生成後にソート処理を行っていたりする分遅くなっている。

cshogiに組み込み

cshogiに実装していたdlshogiのDfPnはそのまま残して、cshogi.oslモジュールにoslmateベースのDfPnを組み込んだ。

PYPIに、v1.0.0として登録済み。

使用例
from cshogi import *
from cshogi.osl import DfPn
board = Board()
board.set_sfen("1pG1B4/Gs+P6/pP7/n1ls5/3k5/nL4+r1b/1+p1p+R4/1S7/2N5K b SP2gn2l11p 1")
dfpn = DfPn()
dfpn.search(board)
注意事項
  • DfPnはcshogiモジュールにもあるので、import順に注意(from cshogi import *を後からimportするとDfPnが上書きされる)。
  • 現状、攻め方にも王が必要になっている。詰将棋に使えるように、今後対応予定。

移植したコード

移植したコードは、

cshogi/src/osl_dfpn.cpp at master · TadaoYamaoka/cshogi · GitHub

にある。
10319行ある。

まとめ

GPS将棋(OpenShogiLib)のdf-pnをcshogiに移植した。
まずdf-pn関連部分のみを切り出した「oslmate」を作成し、その後Codex(GPT-5.4→GPT-5.5)を用いて段階的に移植を進めた。
Proプランで24時間動かし続けて、約1か月半かかった。
完成したdf-pnは cshogi.osl.DfPn としてcshogi v1.0.0に組み込んだ。

dropless MoE(Mixture of Experts)を試す その6(推論速度比較)

前回、C++で実装したTensorRTプラグインを使ったMoEの推論処理の推論速度を比較する。

比較対象は、

  • Dense MoE
  • Sparse MoE (自前実装CUDAカーネル)
  • Sparse MoE (CUTLASS Grouped GEMM)

の3パターンとする。

比較条件

  • SwinTransformerのStage 0/1をMoE化
    • Stage 0の解像度 8x8、State 1の解像度 4x4
    • Stage 0のhidden_features 256、State 1のhidden_features 512
  • Expert数4
  • top_k=2
  • バッチサイズ128
  • FP16
  • RTX 4090で測定

前回実装した推論コマンドを呼び出す、以下のベンチマークスクリプトを使用する。

train_cifar10/scripts/benchmark_moe_trt.py at feature/moe · TadaoYamaoka/train_cifar10 · GitHub

測定結果

CUDA Graph 有効

batch=128 warmup=5 iters=10 cuda_graph=True

target avg_latency_ms throughput_images_per_s
Dense MoE 0.582870 219603.000
Sparse Plugin MoE 1.790130 71503.000
Sparse Plugin MoE (CUTLASS) 0.947411 135105.000

CUDA Graph 無効

batch=128 warmup=5 iters=10 cuda_graph=False

target avg_latency_ms throughput_images_per_s
Dense MoE 1.028980 124395.000
Sparse Plugin MoE 2.067490 61910.700
Sparse Plugin MoE (CUTLASS) 1.332480 96061.200

考察

Dense MoEが一番早く、Sparse MoEでは、CUTLASSを使った方が速いという結果になった。

以下に結果を考察する。

1. 概要

主因は、今回の条件では MoE を sparse にするメリットが小さく、sparse 実行のオーバーヘッドが相対的に大きいためであると考えられる。 結果を比較すると以下の通りである。

手法 Latency 備考
Dense MoE 0.583 ms 最速
Sparse Plugin MoE 1.790 ms Dense比 約3.07倍低速
Sparse Plugin MoE (CUTLASS) 0.947 ms Dense比 約1.63倍低速

CUTLASS 版は非 CUTLASS 版より約 1.89 倍高速であり、CUTLASS grouped GEMM 自体は有効に機能している。しかし、依然として Dense に及ばないという点が極めて重要である。

2. 理論メリットの限定性

今回の MoE は E=4, top_k=2 であるため、sparse 化による理論上の計算量削減幅が小さい。

  • Dense MoE: 全 expert (E=4) を計算。
  • Sparse MoE: top-k expert (K=2) のみを計算。

expert MLP の計算量のみに着目すれば、Sparse は Dense の約半分となる。しかし、モデル全体には Patch embedding、LayerNorm、Window attention、TensorRT の入出力処理などが含まれる。 モデル全体のレイテンシに対する MoE expert 計算の支配率が 100% でない限り、全体の短縮幅は 2 倍未満に留まる。この時点で Sparse が勝利するには、routing や packing などの追加コストを極小化する必要がある。

3. Dense MoE における TensorRT の最適化

Dense 版は全 expert を規則的な tensor 演算(torch.einsum 等)として実行する。これは算術量こそ増えるものの、GPU および TensorRT にとっては極めて扱いやすい。

  • contiguous な tensor 構造
  • 静的かつ規則的な shape
  • gather / scatter の抑制
  • TensorRT の標準オペレータとしての最適化
  • cuBLAS / Tensor Core tactic への適合

特に E=4 のような少数 expert 設定では、「計算の無駄」よりも「規則的な dense matmul による高速化」のメリットが上回る。

4. Sparse Plugin におけるオーバーヘッド

Sparse MoE は計算量を削減する代償として、不規則なメモリアクセスや多数のカーネル実行を必要とする。具体的には、routing、prefix sum、token assignment の packing、scatter-add、atomic add といった処理が Dense 版にはほぼ不要な追加コストとして発生する。 E=4, K=2 という設定では、削減できる計算量よりも、これら不規則処理のオーバーヘッドが支配的になりやすい。

5. Plugin 境界による最適化の制限

plugin モードでは、MoE 全体が trt.plugins::CustomMoE という単一の custom node となり、TensorRT から見て内部が不透明(opaque)になる。

  • TensorRT による layer fusion が適用されない。
  • 内部のカーネルに対して TensorRT の tactic selection が効かない。
  • カーネル起動やメモリアクセスの効率が、独自実装の品質に完全に依存する。

対して Dense 版は標準 ONNX オペレータのグラフであるため、TensorRT が広範囲にわたってグラフ最適化を適用できる。この差は極めて大きい。

6. Batch Size 128 における実行効率

今回のバッチサイズ 128 では、特に stage 0 等でトークン数が多くなる(例:128 * 64 = 8192 tokens)。 Dense MoE はこのまとまったトークン数を高 occupancy・高 arithmetic intensity で処理できる。一方、Sparse MoE では expert ごとにトークンが分割されるため、個々の GEMM サイズが小さくなり、metadata のセットアップやスケジューリングのオーバーヘッドによる効率低下の影響を受けやすくなる。

7. CUDA Graph の限界

cuda_graph=True の設定により CPU 側の起動オーバーヘッドは削減されているが、GPU 上での routing 計算、packing copy、atomicAdd といったデバイス側のコストは消失しない。今回の結果は、CPU launch ではなくデバイス側のオーバーヘッドが支配的であることを示唆している。

8. 結論:Dense が勝る理由と今後の展望

GPU 最適化において「無駄な計算をしても規則的な演算の方が速い」という典型的な事例である。

優先度の高い仮説

  1. E=4, K=2 では計算削減の恩恵がオーバーヘッドを相殺できない。
  2. Dense ONNX は TensorRT の標準最適化を最大限に享受している。
  3. Sparse plugin 内の routing や atomic 処理が依然として重い。

改善に向けたアプローチ:

今後 Sparse plugin で Dense を上回るには、少数 expert への特化(kernel fusion の徹底)、atomicAdd の排除、あるいはトークン数や設定に応じて Dense と Sparse を動的に切り替えるハイブリッド戦略の検討が有効である。

まとめ

TensorRTプラグインで実装したMoEの推論速度を比較した。
結果、Dense MoEが一番速いという結果になった。Sparse MoEではCUTLASSを使用することで約1.89 倍速くなることが分かった。

今回測定したSwinTransformerはモデルサイズが小さく、CIFAR-10の画像解像度も小さいため、Sparse化のオーバーヘッドが大きくかえって遅くなるという結果になった。

また、実装したCUDAカーネルはkernel fusionなどの最適化を行っていないため、改良の余地がある。
LLMの推論ライブラリとして人気のvLLMは、デバイスごとの手書きCUDAカーネルで最適化を行っている。
そこで使われている手法を参考すれば、まだまだ最適化できると思っている。

SwinTransformerでの実験はこれまでにして、次は、dlshogiのモデルにMoEを組み込むことを試したい。

dropless MoE(Mixture of Experts)を試す その5(推論処理)

前回、Pythonで実装したエンジンビルドスクリプトで保存した.engineを読み込んで、TensorRTで推論する処理をC++で実装する。

推論処理

TensorRTのライブラリの使用がメインである。 デフォルトで、CUDA Graphを有効にしている。 CUDA Graphは、プラグインでCUDAカーネルを呼び出す際のオーバーヘッドを削減する仕組みである。

infer.cpp

#include <NvInfer.h>
#include <NvInferPlugin.h>
#include <cuda_runtime_api.h>

#include <dlfcn.h>

#include <algorithm>
#include <chrono>
#include <cstdint>
#include <fstream>
#include <iostream>
#include <memory>
#include <numeric>
#include <random>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

using namespace nvinfer1;

class Logger final : public ILogger {
public:
    void log(Severity severity, const char* msg) noexcept override {
        if (severity <= Severity::kINFO) std::cerr << "[TRT] " << msg << '\n';
    }
};

#define CHECK_CUDA(expr) do { \
    cudaError_t _err = (expr); \
    if (_err != cudaSuccess) { \
        throw std::runtime_error(std::string("CUDA error: ") + cudaGetErrorString(_err)); \
    } \
} while (0)

struct Args {
    std::string engine;
    std::string plugin;
    std::string input_file;
    std::string output_file;
    int batch = 1;
    int channels = 3;
    int height = 32;
    int width = 32;
    int warmup = 20;
    int iters = 200;
    bool cuda_graph = true;
};

Args parseArgs(int argc, char** argv) {
    Args a;
    for (int i = 1; i < argc; ++i) {
        std::string k = argv[i];
        auto need = [&](const char* name) -> std::string {
            if (i + 1 >= argc) throw std::runtime_error(std::string("missing value for ") + name);
            return argv[++i];
        };
        if (k == "--engine") a.engine = need("--engine");
        else if (k == "--plugin") a.plugin = need("--plugin");
        else if (k == "--input") a.input_file = need("--input");
        else if (k == "--output") a.output_file = need("--output");
        else if (k == "--batch") a.batch = std::stoi(need("--batch"));
        else if (k == "--channels") a.channels = std::stoi(need("--channels"));
        else if (k == "--height") a.height = std::stoi(need("--height"));
        else if (k == "--width") a.width = std::stoi(need("--width"));
        else if (k == "--warmup") a.warmup = std::stoi(need("--warmup"));
        else if (k == "--iters") a.iters = std::stoi(need("--iters"));
        else if (k == "--no-cuda-graph") a.cuda_graph = false;
        else if (k == "--help" || k == "-h") {
            std::cout << "Usage: moe_trt_infer --engine model.engine [--plugin libcustom_moe_plugin.so] "
                      << "[--batch 1] [--input input.bin] [--output logits.bin] [--iters 200] [--no-cuda-graph]\n";
            std::exit(0);
        } else {
            throw std::runtime_error("unknown argument: " + k);
        }
    }
    if (a.engine.empty()) throw std::runtime_error("--engine is required");
    return a;
}

std::vector<char> readFile(const std::string& path) {
    std::ifstream f(path, std::ios::binary);
    if (!f) throw std::runtime_error("cannot open " + path);
    f.seekg(0, std::ios::end);
    size_t size = static_cast<size_t>(f.tellg());
    f.seekg(0, std::ios::beg);
    std::vector<char> data(size);
    f.read(data.data(), static_cast<std::streamsize>(size));
    return data;
}

size_t dtypeSize(DataType t) {
    switch (t) {
        case DataType::kFLOAT: return 4;
        case DataType::kHALF: return 2;
        case DataType::kINT8: return 1;
        case DataType::kINT32: return 4;
        case DataType::kBOOL: return 1;
#if NV_TENSORRT_MAJOR >= 9
        case DataType::kBF16: return 2;
#endif
        default: throw std::runtime_error("unsupported tensor dtype");
    }
}

int64_t volume(const Dims& d) {
    int64_t v = 1;
    for (int i = 0; i < d.nbDims; ++i) {
        if (d.d[i] < 0) throw std::runtime_error("dynamic dimension was not resolved");
        v *= d.d[i];
    }
    return v;
}

void fillRandomFloat(float* p, size_t n) {
    std::mt19937 gen(1234);
    std::normal_distribution<float> dist(0.0f, 1.0f);
    for (size_t i = 0; i < n; ++i) p[i] = dist(gen);
}

void loadInputFloat(const std::string& file, float* host, size_t count) {
    if (file.empty()) {
        fillRandomFloat(host, count);
        return;
    }
    std::ifstream f(file, std::ios::binary);
    if (!f) throw std::runtime_error("cannot open input file: " + file);
    f.read(reinterpret_cast<char*>(host), static_cast<std::streamsize>(count * sizeof(float)));
    if (static_cast<size_t>(f.gcount()) != count * sizeof(float)) {
        throw std::runtime_error("input file size does not match expected FP32 tensor size");
    }
}

int main(int argc, char** argv) {
    try {
        Args args = parseArgs(argc, argv);
        Logger logger;
        initLibNvInferPlugins(&logger, "");

        void* pluginHandle = nullptr;
        if (!args.plugin.empty()) {
            pluginHandle = dlopen(args.plugin.c_str(), RTLD_NOW | RTLD_GLOBAL);
            if (!pluginHandle) throw std::runtime_error(std::string("dlopen failed: ") + dlerror());
            std::cerr << "[INFO] loaded plugin " << args.plugin << '\n';
        }

        std::vector<char> engineData = readFile(args.engine);
        std::unique_ptr<IRuntime> runtime(createInferRuntime(logger));
        if (!runtime) throw std::runtime_error("createInferRuntime failed");
        std::unique_ptr<ICudaEngine> engine(runtime->deserializeCudaEngine(engineData.data(), engineData.size()));
        if (!engine) throw std::runtime_error("deserializeCudaEngine failed");
        std::unique_ptr<IExecutionContext> context(engine->createExecutionContext());
        if (!context) throw std::runtime_error("createExecutionContext failed");

        int nb = engine->getNbIOTensors();
        std::vector<std::string> names;
        names.reserve(nb);
        std::string inputName;
        std::vector<std::string> outputNames;
        for (int i = 0; i < nb; ++i) {
            const char* name = engine->getIOTensorName(i);
            names.emplace_back(name);
            if (engine->getTensorIOMode(name) == TensorIOMode::kINPUT) inputName = name;
            else outputNames.emplace_back(name);
        }
        if (inputName.empty()) throw std::runtime_error("no input tensor found");

        Dims inShape = engine->getTensorShape(inputName.c_str());
        if (inShape.nbDims != 4) throw std::runtime_error("expected NCHW input rank 4");
        inShape.d[0] = args.batch;
        inShape.d[1] = args.channels;
        inShape.d[2] = args.height;
        inShape.d[3] = args.width;
        if (!context->setInputShape(inputName.c_str(), inShape)) {
            throw std::runtime_error("setInputShape failed");
        }

        cudaStream_t stream{};
        CHECK_CUDA(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));

        struct Buf { void* dev = nullptr; void* host = nullptr; size_t bytes = 0; DataType dtype{}; Dims shape{}; bool isInput = false; };
        std::unordered_map<std::string, Buf> bufs;
        for (const auto& name : names) {
            Buf b;
            b.dtype = engine->getTensorDataType(name.c_str());
            b.shape = context->getTensorShape(name.c_str());
            b.bytes = static_cast<size_t>(volume(b.shape)) * dtypeSize(b.dtype);
            b.isInput = engine->getTensorIOMode(name.c_str()) == TensorIOMode::kINPUT;
            CHECK_CUDA(cudaMalloc(&b.dev, b.bytes));
            CHECK_CUDA(cudaHostAlloc(&b.host, b.bytes, cudaHostAllocDefault));
            if (!context->setTensorAddress(name.c_str(), b.dev)) {
                throw std::runtime_error("setTensorAddress failed for " + name);
            }
            bufs.emplace(name, b);
            std::cerr << "[INFO] tensor " << name << " bytes=" << b.bytes
                      << " dtype=" << static_cast<int>(b.dtype) << " dims=[";
            for (int j = 0; j < b.shape.nbDims; ++j) std::cerr << (j ? "," : "") << b.shape.d[j];
            std::cerr << "]\n";
        }

        Buf& input = bufs.at(inputName);
        if (input.dtype != DataType::kFLOAT) {
            throw std::runtime_error("runner currently expects FP32 network input; rebuild ONNX with FP32 input");
        }
        loadInputFloat(args.input_file, static_cast<float*>(input.host), input.bytes / sizeof(float));

        auto launchOnce = [&]() {
            CHECK_CUDA(cudaMemcpyAsync(input.dev, input.host, input.bytes, cudaMemcpyHostToDevice, stream));
            if (!context->enqueueV3(stream)) throw std::runtime_error("enqueueV3 failed");
            for (const auto& outName : outputNames) {
                Buf& out = bufs.at(outName);
                CHECK_CUDA(cudaMemcpyAsync(out.host, out.dev, out.bytes, cudaMemcpyDeviceToHost, stream));
            }
        };

        for (int i = 0; i < args.warmup; ++i) launchOnce();
        CHECK_CUDA(cudaStreamSynchronize(stream));

        cudaEvent_t start{}, stop{};
        CHECK_CUDA(cudaEventCreate(&start));
        CHECK_CUDA(cudaEventCreate(&stop));

        float elapsedMs = 0.0f;
        if (args.cuda_graph) {
            cudaGraph_t graph{};
            cudaGraphExec_t graphExec{};
            CHECK_CUDA(cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal));
            launchOnce();
            CHECK_CUDA(cudaStreamEndCapture(stream, &graph));
#if CUDART_VERSION >= 13000
            CHECK_CUDA(cudaGraphInstantiate(&graphExec, graph, 0));
#else
            CHECK_CUDA(cudaGraphInstantiate(&graphExec, graph, nullptr, nullptr, 0));
#endif
            CHECK_CUDA(cudaEventRecord(start, stream));
            for (int i = 0; i < args.iters; ++i) CHECK_CUDA(cudaGraphLaunch(graphExec, stream));
            CHECK_CUDA(cudaEventRecord(stop, stream));
            CHECK_CUDA(cudaEventSynchronize(stop));
            CHECK_CUDA(cudaEventElapsedTime(&elapsedMs, start, stop));
            CHECK_CUDA(cudaGraphExecDestroy(graphExec));
            CHECK_CUDA(cudaGraphDestroy(graph));
        } else {
            CHECK_CUDA(cudaEventRecord(start, stream));
            for (int i = 0; i < args.iters; ++i) launchOnce();
            CHECK_CUDA(cudaEventRecord(stop, stream));
            CHECK_CUDA(cudaEventSynchronize(stop));
            CHECK_CUDA(cudaEventElapsedTime(&elapsedMs, start, stop));
        }

        std::cout << "avg_latency_ms=" << (elapsedMs / std::max(1, args.iters))
                  << " throughput_images_per_s=" << (1000.0 * args.batch * args.iters / elapsedMs) << '\n';

        if (!args.output_file.empty() && !outputNames.empty()) {
            const Buf& out = bufs.at(outputNames[0]);
            std::ofstream f(args.output_file, std::ios::binary);
            f.write(static_cast<const char*>(out.host), static_cast<std::streamsize>(out.bytes));
            std::cerr << "[OK] wrote " << args.output_file << '\n';
        }

        for (auto& kv : bufs) {
            if (kv.second.dev) cudaFree(kv.second.dev);
            if (kv.second.host) cudaFreeHost(kv.second.host);
        }
        cudaEventDestroy(start);
        cudaEventDestroy(stop);
        cudaStreamDestroy(stream);
        context.reset();
        engine.reset();
        runtime.reset();
        if (pluginHandle) dlclose(pluginHandle);
        return 0;
    } catch (const std::exception& e) {
        std::cerr << "[ERROR] " << e.what() << '\n';
        return 1;
    }
}

解説

infer.cpp の概要

infer.cpp は、TensorRT の .engine ファイルを C++ で読み込み、必要に応じて MoE plugin の .so をロードし、CUDA stream 上で推論・ベンチマーク・出力保存を行うランナーである。主な処理の流れは以下の通りである。

  1. 引数の解析
  2. TensorRT plugin の初期化
  3. 必要に応じて libcustom_moe_plugin.sodlopen
  4. シリアライズされた engine の読み込み
  5. TensorRT runtime / engine / execution context の生成
  6. 入出力 tensor の列挙
  7. dynamic input shape の設定
  8. device / pinned host buffer の確保
  9. 入力データの読み込み
  10. ウォームアップ実行
  11. CUDA Graph の有無によるベンチマーク測定
  12. レイテンシおよびスループットの算出
  13. 必要に応じた logits の保存
  14. リソースの解放

1. ヘッダと namespace

冒頭では TensorRT と CUDA runtime を使用するため、以下のヘッダをインクルードしている。

#include <NvInfer.h>
#include <NvInferPlugin.h>
#include <cuda_runtime_api.h>

また、plugin の .so ファイルを動的にロードするために <dlfcn.h> を使用している。これは dlopen() および dlclose() のためである。 using namespace nvinfer1; により、IRuntime, ICudaEngine, IExecutionContext, Dims, DataType などを nvinfer1:: のプレフィックスなしで記述可能にしている。

2. Logger

Logger クラスは ILogger を継承した TensorRT 用のロガーである。runtime や parser、engine のデシリアライズ時のログを受け取る。本実装では kINFO 以下(INFO / WARNING / ERROR など)を標準エラー出力に表示する。

3. CHECK_CUDA

CUDA API 呼び出しのエラーチェックを行うマクロである。失敗時には cudaGetErrorString() の内容を含む例外を投げる。本ランナーは main() 全体を try-catch で囲んでいるため、エラー発生時には詳細を表示して終了する。

4. Args と parseArgs

Args はコマンドライン引数を保持する構造体である。主要な引数は以下の通りである。

引数 意味
--engine TensorRT .engine ファイル(必須)
--plugin libcustom_moe_plugin.so へのパス
--input 生の FP32 入力 tensor ファイル
--output 出力 tensor を保存するバイナリパス
--batch 推論バッチサイズ
--warmup ベンチマーク前のウォームアップ回数
--iters ベンチマークのイテレーション数
--no-cuda-graph CUDA Graph を無効化する

5. readFile

.engine ファイルを std::vector<char> に読み込むヘルパー関数である。TensorRT engine はバイナリデータであるため、バイナリモードで読み込む必要がある。

6. dtypeSize と volume

  • dtypeSize: DataType からバイトサイズ(kFLOAT なら 4 など)を返す。
  • volume: Dims から要素数を計算する。shape に -1(動的次元)が残っている場合は例外を投げる。

7. 入力データの生成・読み込み

  • fillRandomFloat: 入力が指定されない場合、固定シードの正規分布乱数で入力バッファを埋める。
  • loadInputFloat: 指定されたファイルを FP32 の NCHW 形式として読み込む。

8. TensorRT plugin の初期化とロード

initLibNvInferPlugins() で標準 plugin を初期化した後、--plugin が指定されていれば dlopen() でカスタム plugin をロードする。 カスタム MoE plugin を含む engine の場合、デシリアライズ前に creator が登録されている必要があるため、この順序は極めて重要である。

9. デシリアライズとコンテキスト生成

IRuntime を介して engine ファイルをデシリアライズし、ICudaEngine および IExecutionContext を生成する。

10. I/O tensor の列挙と Shape 設定

getNbIOTensors() を用いて入出力 tensor を列挙する。本コードは「入力 1 個、出力 1 個以上」の構成を想定している。 また、setInputShape() を呼び出すことで、実行時の具体的なバッチサイズ等を確定させる。

11. CUDA stream とバッファ確保

非同期実行用の CUDA stream を作成する。各 tensor に対して cudaMalloc(Device 用)と cudaHostAlloc(Pinned Host 用)を実行し、setTensorAddress() で TensorRT にアドレスを通知する。

12. launchOnce

推論 1 回分の処理をラムダ式として定義している。

  1. H2D copy: 入力をホストからデバイスへ転送。
  2. enqueueV3: 推論の実行(非同期)。
  3. D2H copy: 出力をデバイスからホストへ転送。

13. ウォームアップとベンチマーク

本測定の前にウォームアップを行い、初回実行コストや GPU クロックの影響を排除する。測定には cudaEvent を使用し、GPU 上の経過時間を正確に取得する。

14. CUDA Graph

--cuda-graph が有効な場合、launchOnce の一連の処理をキャプチャしてグラフ化する。これにより、毎回のカーネル起動に伴う CPU オーバーヘッドを削減できる。MoE のようにカーネル数が多いモデルでは特に有効である。

15. レイテンシとスループットの計算

  • 平均レイテンシ: 全実行時間 / イテレーション数
  • スループット: (1000 * batch * iters) / 全実行時間

16. 出力保存と後片付け

推論終了後、必要に応じて結果をバイナリ保存し、確保したすべての CUDA リソースおよび TensorRT オブジェクトを解放する。

検証

Dense MoE版の推論処理も実装して、出力のlogitsが誤差範囲で一致することを確認した。

まとめ

C++で、TensorRTのプラグインを使用したMoEの推論処理を実装した。 次は、Dense MoEとSparse MoEの推論速度を比較したい。

dropless MoE(Mixture of Experts)を試す その4(エンジンビルド)

前回、実装したプラグインを使用して、カスタムノードを含むONNXをTensorRT の serialized engineに変換する。

dlshogiでは、C++でエンジンビルドを実装しているが、今回はPythonのtensorrtライブラリを使用して実装する。 生成したserialized engineを後でC++の推論プログラムからロードする。

エンジンビルド

TensorRT の serialized engine とは、TensorRT が ONNX などのモデルから生成したバイナリで、以下を含む。

  • TensorRT が最適化した network graph
  • 選択済みの CUDA kernel / tactic
  • layer fusion の結果
  • dynamic shape 用 optimization profile
  • weight
  • plugin layer の情報

実行環境に最適化するため、TensorRT version、GPU architecture、plugin ABI、CUDA 環境に依存する。

変換スクリプト

build_engine.py

"""Build a TensorRT engine from ONNX.

For plugin-mode ONNX, load libcustom_moe_plugin.so before parsing.
"""
from __future__ import annotations

import argparse
import ctypes
from pathlib import Path

import tensorrt as trt


def main() -> None:
    p = argparse.ArgumentParser()
    p.add_argument("--onnx", type=Path, required=True)
    p.add_argument("--engine", type=Path, required=True)
    p.add_argument("--plugin", type=Path, default=None, help="Path to libcustom_moe_plugin.so")
    p.add_argument("--min-batch", type=int, default=1)
    p.add_argument("--opt-batch", type=int, default=8)
    p.add_argument("--max-batch", type=int, default=32)
    p.add_argument("--channels", type=int, default=3)
    p.add_argument("--height", type=int, default=32)
    p.add_argument("--width", type=int, default=32)
    p.add_argument("--fp16", action="store_true")
    p.add_argument("--workspace-gb", type=float, default=4.0)
    p.add_argument("--version-compatible", action="store_true")
    args = p.parse_args()

    logger = trt.Logger(trt.Logger.INFO)
    trt.init_libnvinfer_plugins(logger, "")
    if args.plugin is not None:
        ctypes.CDLL(args.plugin.as_posix(), mode=ctypes.RTLD_GLOBAL)
        print(f"[INFO] loaded plugin: {args.plugin}")

    explicit_batch = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    builder = trt.Builder(logger)
    network = builder.create_network(explicit_batch)
    parser = trt.OnnxParser(network, logger)

    data = args.onnx.read_bytes()
    if not parser.parse(data):
        print("[ERROR] ONNX parse failed")
        for i in range(parser.num_errors):
            print(parser.get_error(i))
        raise SystemExit(1)

    config = builder.create_builder_config()
    if hasattr(config, "set_memory_pool_limit"):
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, int(args.workspace_gb * (1 << 30)))
    else:
        config.max_workspace_size = int(args.workspace_gb * (1 << 30))

    if args.fp16 and builder.platform_has_fast_fp16:
        config.set_flag(trt.BuilderFlag.FP16)
        print("[INFO] enabled FP16")
    elif args.fp16:
        print("[WARN] requested FP16, but platform_has_fast_fp16 is false")

    if args.version_compatible and hasattr(trt.BuilderFlag, "VERSION_COMPATIBLE"):
        config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE)
        if hasattr(parser, "get_used_vc_plugin_libraries") and hasattr(config, "set_plugins_to_serialize"):
            libs = parser.get_used_vc_plugin_libraries()
            if libs:
                config.set_plugins_to_serialize(libs)
                print(f"[INFO] serializing plugin libraries into engine: {libs}")

    inp = network.get_input(0)
    profile = builder.create_optimization_profile()
    min_shape = (args.min_batch, args.channels, args.height, args.width)
    opt_shape = (args.opt_batch, args.channels, args.height, args.width)
    max_shape = (args.max_batch, args.channels, args.height, args.width)
    profile.set_shape(inp.name, min_shape, opt_shape, max_shape)
    config.add_optimization_profile(profile)
    print(f"[INFO] optimization profile for {inp.name}: min={min_shape}, opt={opt_shape}, max={max_shape}")

    serialized = builder.build_serialized_network(network, config)
    if serialized is None:
        raise RuntimeError("TensorRT engine build failed")
    args.engine.parent.mkdir(parents=True, exist_ok=True)
    args.engine.write_bytes(bytes(serialized))
    print(f"[OK] wrote {args.engine}")


if __name__ == "__main__":
    main()

解説

build_engine.py は、export_moe_onnx.py で出力した ONNX を TensorRT の serialized engine.engine ファイル)に変換するスクリプトである。特に plugin モードの ONNX には CustomMoE カスタムノードが含まれるため、ONNX のパース前に libcustom_moe_plugin.so をロードする役割も担っている。

1. 役割

このスクリプトの処理フローは、概念的に以下の通りである。

ONNX
  ↓
TensorRT OnnxParser
  ↓
TensorRT network
  ↓
Builder + BuilderConfig + OptimizationProfile
  ↓
Serialized TensorRT engine
  ↓
.engine file

plugin モードの場合は、さらに前処理として以下の流れが加わる。

libcustom_moe_plugin.so をロード
  ↓
CustomMoEPluginCreator が TensorRT plugin registry に登録
  ↓
ONNX parser が trt.plugins::CustomMoE を plugin layer として解決

2. CLI 引数

冒頭で argparse によって定義されている引数は以下の通りである。

p.add_argument("--onnx", type=Path, required=True)
p.add_argument("--engine", type=Path, required=True)
p.add_argument("--plugin", type=Path, default=None)
p.add_argument("--min-batch", type=int, default=1)
p.add_argument("--opt-batch", type=int, default=8)
p.add_argument("--max-batch", type=int, default=32)
p.add_argument("--channels", type=int, default=3)
p.add_argument("--height", type=int, default=32)
p.add_argument("--width", type=int, default=32)
p.add_argument("--fp16", action="store_true")
p.add_argument("--workspace-gb", type=float, default=4.0)
p.add_argument("--version-compatible", action="store_true")

主な引数の意味を以下にまとめる。

引数 意味
--onnx 入力 ONNX ファイル
--engine 出力 TensorRT engine ファイル
--plugin libcustom_moe_plugin.so のパス
--min-batch dynamic batch の最小値
--opt-batch TensorRT が最適化の基準とする batch サイズ
--max-batch dynamic batch の最大値
--channels 入力チャンネル数(CIFAR-10 の場合は通常 3)
--height 入力画像の高さ
--width 入力画像の幅
--fp16 TensorRT の FP16 ビルドフラグを有効化
--workspace-gb 戦略選択やビルドに使用するワークスペースの上限
--version-compatible TensorRT のバージョン互換エンジン機能を試行

3. logger と plugin の初期化

まず、TensorRT logger を作成する。

logger = trt.Logger(trt.Logger.INFO)

その後、標準の TensorRT プラグインを初期化する。

trt.init_libnvinfer_plugins(logger, "")

さらに --plugin が指定されている場合は、ctypes.CDLL を用いて .so ファイルをロードする。

if args.plugin is not None:
    ctypes.CDLL(args.plugin.as_posix(), mode=ctypes.RTLD_GLOBAL)
    print(f"[INFO] loaded plugin: {args.plugin}")

この工程は plugin モードにおいて極めて重要である。export_moe_onnx.py の plugin モードで作成された ONNX には trt.plugins::CustomMoE ノードが含まれている。そのため、パース前に libcustom_moe_plugin.so をロードしておかなければ、TensorRT parser は CustomMoE を解決できなくなる。

RTLD_GLOBAL を指定する理由は、ロードした共有ライブラリのシンボルやプラグインクリエイターをプロセス全体から参照可能にするためである。

4. explicit batch network の作成

次に、TensorRT の builder、network、ONNX parser を作成する。

explicit_batch = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
builder = trt.Builder(logger)
network = builder.create_network(explicit_batch)
parser = trt.OnnxParser(network, logger)

このコードでは explicit batch mode を採用している。これは現代の TensorRT における標準的な形式であり、入力シェイプを [N, C, H, W] のようにバッチ次元を含めて明示的に扱う。

export_moe_onnx.py ではバッチ軸を dynamic に設定して ONNX をエクスポートしているため、TensorRT 側でも explicit batch と optimization profile の組み合わせが必要となる。


5. ONNX のパース

ONNX ファイルをバイト列として読み込み、parser に渡す。

data = args.onnx.read_bytes()
if not parser.parse(data):
    print("[ERROR] ONNX parse failed")
    for i in range(parser.num_errors):
        print(parser.get_error(i))
    raise SystemExit(1)

パースに失敗した場合は、エラー内容をすべて表示して処理を終了する。plugin モードで失敗する典型的な原因は以下の通りである。

  • --plugin が指定されていない
  • libcustom_moe_plugin.so のロードに失敗している
  • ONNX のカスタムオペレーション名とプラグイン名が一致していない
  • プラグインのバージョンやネームスペースが一致していない
  • TensorRT がカスタムオペレーションの属性を解釈できない

6. BuilderConfig とワークスペース

パース完了後、エンジンのビルド設定(config)を作成する。

config = builder.create_builder_config()

ワークスペースの上限設定については、TensorRT のバージョンによる API の差異を吸収するため、2 通りの方法に対応させている。

if hasattr(config, "set_memory_pool_limit"):
    config.set_memory_pool_limit(
        trt.MemoryPoolType.WORKSPACE,
        int(args.workspace_gb * (1 << 30))
    )
else:
    config.max_workspace_size = int(args.workspace_gb * (1 << 30))

ワークスペースは、ビルド時の戦略選択や実行時の一時バッファとして利用されるメモリ領域である。値を大きくすれば必ずしも高速化するわけではないが、不足すると最適な戦略が選ばれず、ビルドの失敗や性能低下を招く。

7. FP16 フラグ

--fp16 が指定され、かつ GPU が高速な FP16 演算に対応している場合に限り、FP16 フラグを有効化する。

if args.fp16 and builder.platform_has_fast_fp16:
    config.set_flag(trt.BuilderFlag.FP16)
    print("[INFO] enabled FP16")
elif args.fp16:
    print("[WARN] requested FP16, but platform_has_fast_fp16 is false")

このフラグは、TensorRT に対して「可能なレイヤーや戦略において FP16 実行を許可する」ことを伝えるものである。

なお、MoE プラグインについては、x/outputexpert weights/biases は FP16 または FP32、router_w は FP32 である必要がある。FP16 での実行を最適化するには、ONNX エクスポート時に --fp16-experts を使用し、あらかじめ重みを FP16 に変換しておく設計が推奨される。

8. バージョン互換エンジン

--version-compatible オプションが指定された場合の処理は以下の通りである。

if args.version_compatible and hasattr(trt.BuilderFlag, "VERSION_COMPATIBLE"):
    config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE)
    if hasattr(parser, "get_used_vc_plugin_libraries") and hasattr(config, "set_plugins_to_serialize"):
        libs = parser.get_used_vc_plugin_libraries()
        if libs:
            config.set_plugins_to_serialize(libs)
            print(f"[INFO] serializing plugin libraries into engine: {libs}")

これはエンジンの互換性を高め、必要に応じてプラグインライブラリの情報をエンジン内にシリアライズする機能である。ただし、動作は TensorRT のバージョンやプラグインの実装に依存するため、通常の開発フェーズではこのフラグなしでビルドを行う方がトラブルシューティングは容易である。

9. Optimization Profile

ONNX 側でバッチ次元が dynamic になっているため、ビルドには optimization profile の設定が不可欠である。

inp = network.get_input(0)
profile = builder.create_optimization_profile()

min_shape = (args.min_batch, args.channels, args.height, args.width)
opt_shape = (args.opt_batch, args.channels, args.height, args.width)
max_shape = (args.max_batch, args.channels, args.height, args.width)

profile.set_shape(inp.name, min_shape, opt_shape, max_shape)
config.add_optimization_profile(profile)

opt_shape は TensorRT が最適化の際に最も重視する形状である。実運用でバッチサイズ 1 の推論がメインであれば --opt-batch 1 とし、複数のバッチを処理するならデフォルトの 8 程度が妥当な設定となる。

10. エンジンのビルドと保存

最後に、エンジンをビルドする。

serialized = builder.build_serialized_network(network, config)
if serialized is None:
    raise RuntimeError("TensorRT engine build failed")

ビルドが成功すれば、指定されたパスにバイト列として .engine ファイルを保存する。

args.engine.parent.mkdir(parents=True, exist_ok=True)
args.engine.write_bytes(bytes(serialized))
print(f"[OK] wrote {args.engine}")

このエンジンファイルは、C++ 側の infer.cpp などで deserializeCudaEngine() を用いてロードされる。推論時にも、デシリアライズの前にビルド時と同じプラグイン .so をロードしておく必要がある点に注意したい。

11. 実行例

plugin モードの ONNX からエンジンを作成する場合

python build_engine.py \
  --onnx model_moe_plugin.onnx \
  --engine model_moe_plugin.engine \
  --plugin ./libcustom_moe_plugin.so \
  --min-batch 1 \
  --opt-batch 8 \
  --max-batch 32 \
  --workspace-gb 4

FP16 エンジンを作成する場合

python build_engine.py \
  --onnx model_moe_plugin_fp16.onnx \
  --engine model_moe_plugin_fp16.engine \
  --plugin ./libcustom_moe_plugin.so \
  --fp16 \
  --min-batch 1 \
  --opt-batch 8 \
  --max-batch 32

### dense モードの ONNX の場合

`CustomMoE` プラグインノードを含まないため、`--plugin` の指定は不要である。

python build_engine.py \ --onnx model_moe_dense.onnx \ --engine model_moe_dense.engine \ --fp16

## 12. 制限事項と注意点

* **入力数は 1 個を想定:** コード上、0 番目の入力テンソルのみを取得しているため、複数入力モデルには拡張が必要である。
* **バッチ以外の次元は固定:** `C/H/W` は引数で固定されており、同一エンジンで解像度を動的に切り替える設計にはなっていない。
* **プロファイルは 1 つ:** 複数のバッチ範囲や解像度を 1 つのエンジンに含める場合は、複数のプロファイルを追加する実装が必要である。
* **プラグインのロード順:** `ctypes.CDLL()` は、必ず `parser.parse()` よりも前に実行される必要がある。
* **環境の整合性:** `libcustom_moe_plugin.so` は、ビルド環境と実行環境の間で TensorRT、CUDA、コンパイラの ABI が一致していなければならない。


## 13. 処理順のまとめ

1. CLI 引数の読み込み
2. TensorRT logger の作成
3. TensorRT 標準プラグインの初期化
4. 必要に応じた `libcustom_moe_plugin.so` のロード
5. explicit batch network の作成
6. ONNX parser によるネットワークの構築
7. BuilderConfig の作成
8. ワークスペース上限の設定
9. FP16 フラグの設定(任意)
10. バージョン互換フラグの設定(任意)
11. Optimization Profile の設定
12. エンジンのビルド実行
13. `.engine` ファイルとしての保存

# まとめ
カスタムノードを含むONNXからTensorRT の serialized engineを生成するエンジンビルドスクリプトについて解説した。
次は、C++でTensorRTを使用した推論処理を実装したい。

dropless MoE(Mixture of Experts)を試す その3(プラグイン)

前回、ONNXにエクスポートしたカスタムノードに対応する処理をTensorRTのプラグインで実装する。

TensorRTのプラグイン

公式ドキュメントに記載されている通り、プラグインクラスとプラグイン クリエーターを実装する。 プラグインがREGISTER_TENSORRT_PLUGINで登録されていると、ONNXのパーサが自動でカスタムノードをプラグインに置換する。

今回は、互換性を優先して、TensorRT 8ベースで実装する。 TensorRT 8ベースだと、nvinfer1::IPluginV2DynamicExtを継承したプラグインクラスと、nvinfer1::IPluginCreatorを継承したプラグイン クリエーターを実装する。

ここでは、プラグインクラスのenqueue()から呼び出すCUDAカーネルの実装を中心に解説する。 詳細な実装は、GitHubのソースを参照して欲しい。

CUDAカーネルは何パターンか実装しているが、CUTLASS grouped GEMMを使うenqueueTypedCutlassHalfに絞って解説する。

cudaError_t enqueueTypedCutlassHalf(
    const CustomMoeConfig& cfg,
    int64_t M,
    const void* x_void,
    const float* router_w,
    const void* w1_void,
    const void* b1_void,
    const void* w2_void,
    const void* b2_void,
    void* y_void,
    void* workspace,
    cudaStream_t stream) {
    const __half* x = static_cast<const __half*>(x_void);
    const __half* w1 = static_cast<const __half*>(w1_void);
    const __half* b1 = static_cast<const __half*>(b1_void);
    const __half* w2 = static_cast<const __half*>(w2_void);
    const __half* b2 = static_cast<const __half*>(b2_void);
    __half* y = static_cast<__half*>(y_void);

    WorkspaceParts ws = splitWorkspace(workspace, cfg, M, nvinfer1::DataType::kHALF);
    const int64_t N = M * cfg.top_k;
    const int threads = 256;

    cudaMemsetAsync(ws.counts, 0, sizeof(int) * cfg.num_experts, stream);
    cudaMemsetAsync(y, 0, sizeof(__half) * M * cfg.in_features, stream);

    routeTopKKernel<__half><<<static_cast<unsigned>((M + threads - 1) / threads), threads, 0, stream>>>(
        x, router_w, ws.topk_experts, ws.topk_gates, M, cfg.in_features, cfg.num_experts, cfg.top_k);
    countExpertsKernel<<<static_cast<unsigned>((N + threads - 1) / threads), threads, 0, stream>>>(
        ws.topk_experts, ws.counts, N);
    prefixAndResetKernel<<<1, 1, 0, stream>>>(ws.counts, ws.offsets, ws.write_ptr, cfg.num_experts);
    packAssignmentsKernel<<<static_cast<unsigned>((N + threads - 1) / threads), threads, 0, stream>>>(
        ws.topk_experts, ws.topk_gates, ws.write_ptr, ws.assignment_dst, ws.token_sorted,
        ws.expert_sorted, ws.gate_sorted, M, cfg.top_k);
    copyPackedXKernel<__half><<<static_cast<unsigned>((N * cfg.in_features + threads - 1) / threads), threads, 0, stream>>>(
        x, ws.assignment_dst, static_cast<__half*>(ws.x_sorted), M, cfg.in_features, cfg.top_k);

    cudaError_t err = runCutlassGroupedGemm(
        ws,
        cfg.num_experts,
        cfg.in_features,
        cfg.hidden_features,
        static_cast<const __half*>(ws.x_sorted),
        w1,
        static_cast<__half*>(ws.h_sorted),
        static_cast<__half*>(ws.h_sorted),
        stream);
    if (err != cudaSuccess) return err;
    addBiasGeluKernel<__half><<<static_cast<unsigned>((N * cfg.hidden_features + threads - 1) / threads), threads, 0, stream>>>(
        static_cast<__half*>(ws.h_sorted), b1, ws.expert_sorted, N, cfg.hidden_features);

    err = runCutlassGroupedGemm(
        ws,
        cfg.num_experts,
        cfg.hidden_features,
        cfg.in_features,
        static_cast<const __half*>(ws.h_sorted),
        w2,
        static_cast<__half*>(ws.y_sorted),
        static_cast<__half*>(ws.y_sorted),
        stream);
    if (err != cudaSuccess) return err;
    addBiasGateCombineKernel<__half><<<static_cast<unsigned>((N * cfg.in_features + threads - 1) / threads), threads, 0, stream>>>(
        static_cast<__half*>(ws.y_sorted), b2, ws.token_sorted, ws.expert_sorted, ws.gate_sorted,
        y, N, cfg.in_features);
    return cudaGetLastError();
}

PyTorchで実装したDroplessMoEMlpの流れをそのままCUDAで実装している。

  routeTopKKernel
  countExpertsKernel
  prefixAndResetKernel
  packAssignmentsKernel
  copyPackedXKernel
  CUTLASS Grouped GEMM FC1
  addBiasGeluKernel
  CUTLASS Grouped GEMM FC2
  addBiasGateCombineKernel

関数冒頭:void ポインタのキャスト

enqueueTypedCutlassHalf は、TensorRT plugin から void* で渡されたテンソルを __half* にキャストする。ただし、router_w だけは const float* のまま保持する。これは routing 計算を FP32 で行い、数値的安定性を確保する設計に基づいている。

Workspace の分割

splitWorkspace により、TensorRT が提供する一時領域を各内部バッファへ割り当てる。 重要な点は、x_sorted, h_sorted, y_sorted が assignment 単位(N = M * K 行)を持つことである。1 token が K 個の expert に送られるため、expert MLP の入力行数は M ではなく N となる。

初期化:counts と y のゼロクリア

counts(expert ごとの割当数)と出力 ycudaMemsetAsync で 0 初期化する。y の初期化が必要な理由は、最終的な結合処理において y[token, d] += gate * expert_output という atomic add を行うためである。

routeTopKKernel:Router と Top-k Gate の計算

1 thread が 1 token を担当し、以下の処理を行う。

  1. Logits 計算: x (FP16) と router_w (FP32) の内積を float で計算。
  2. Top-k 選択と正規化: E 個の logits から top-k を選び、その中だけで再正規化(Softmax)を行い gate 値を算出する。

count / prefixAndResetKernel:配置の決定

  • countExpertsKernel: 各 assignment がどの expert に割り当てられたかを走査し、atomicAddcounts[e] を算出する。
  • prefixAndResetKernel: counts から prefix sum を計算し、offsets(各 expert の開始位置)と write_ptr を作成する。これにより、各 expert が処理するパックドバッファ上の範囲が確定する。

pack / copyPackedXKernel:データの並べ替え

packAssignmentsKernel で各 assignment の書き込み先インデックスを確定し、copyPackedXKernel で実データを x_sorted へコピーする。これにより、x_sorted は同一の expert に送られる token が連続して並ぶレイアウト(expert-major)となる。

FC1:CUTLASS Grouped GEMM

本関数の主目的である。runCutlassGroupedGemm を呼び出し、全 expert の行列積を一括実行する。 各 expert e において、A_e = [M_e, D]、B_e = [D, H] となり、結果 D_e = [M_e, H] を得る。

setupGroupedGemmMetaKernel では、CUTLASS が要求する各 expert の問題サイズ(Problem Size)やポインタの配列を構築する。

  • Arch: Sm80 (Ampere)
  • 精度: 入出力 FP16、Accumulator FP32
  • Epilogue: LinearCombination (beta=0)

addBiasGeluKernel:FC1 の Bias と GELU

CUTLASS の Grouped GEMM は純粋な行列積のみを担当するため、Bias 加算と GELU 活性化関数は別個の kernel で実行する。この際、expert_sorted を参照して各行に対応する expert の bias を特定する。

FC2:CUTLASS Grouped GEMM

FC1 と同様に runCutlassGroupedGemm を実行する。

  • 入力 (A): h_sorted [N, H]
  • 重み (B): w2 [E, H, D]
  • 出力 (D): y_sorted [N, D]

addBiasGateCombineKernel:FC2 Bias と Scatter-Add

最後に FC2 の出力に対して bias を加算し、gate 値を乗じた上で、元の token インデックスに基づき y へ足し込む。 atomicAdd を用いることで、1 token に対する K 個の expert 出力を正しく集約する。

まとめ

MoEのカスタムノードに対応するTensorRTプラグインを実装し、CUTLASS Grouped GEMMを使うCUDAカーネルを中心に解説した。 この実装の最適化は不十分で、Biasや活性化関数は一つのカーネルに融合する余地がある。

次回は、プラグインをロードしてTensorRTエンジンをビルドする処理を実装したい。