TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

dropless MoE(Mixture of Experts)を試す その2(ONNXエクスポート)

前回、PyTorchで実装したMoE対応のSwinTransformerモデルをTensorRTで推論できるように、ONNXにエクスポートする。

最低限、grouped_mmのみをカスタムオペレータとすればよいが、router + dispatch + expert + combineを分離するとTensorRTのグラフとpluginの入出力が複雑になるため、MoE化したMLPをまとめてプラグイン化する。
ソースでは、DroplessMoEMlpをカスタムノードにする。

カスタムノード

以下のようなforwardが空のノードを定義して、symbolicにONNXのオペレータのメタ情報を定義する。

class CustomMoEFunction(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx: Any,
        x: Tensor,
        router_w: Tensor,
        w1: Tensor,
        b1: Tensor,
        w2: Tensor,
        b2: Tensor,
        top_k: int,
        num_experts: int,
        in_features: int,
        hidden_features: int,
    ) -> Tensor:
        # During ONNX tracing the symbolic() method below emits the actual
        # CustomMoE node. The eager value is needed only for shape propagation,
        # so avoid tracing the expensive dynamic routing reference here.
        return torch.empty_like(x)

    @staticmethod
    def symbolic(
        g: Any,
        x: Any,
        router_w: Any,
        w1: Any,
        b1: Any,
        w2: Any,
        b2: Any,
        top_k: int,
        num_experts: int,
        in_features: int,
        hidden_features: int,
    ) -> Any:
        out = g.op(
            "trt.plugins::CustomMoE",
            x,
            router_w,
            w1,
            b1,
            w2,
            b2,
            top_k_i=int(top_k),
            num_experts_i=int(num_experts),
            in_features_i=int(in_features),
            hidden_features_i=int(hidden_features),
            activation_s="gelu",
            plugin_version_s="1",
            plugin_namespace_s="",
            outputs=1,
        )
        out.setType(x.type())
        return out

元の MoE weight を custom node に接続する wrapperを用意する。
ONNXにエクスポートする際に、パラメータを書きこむために必要である。

class PluginMoEExportWrapper(nn.Module):
    def __init__(self, src: nn.Module) -> None:
        super().__init__()
        self.in_features = int(src.in_features)
        self.hidden_features = int(src.hidden_features)
        self.num_experts = int(src.num_experts)
        self.top_k = int(src.top_k)
        # Keep these as registered parameters so torch.onnx.export writes them
        # as ONNX initializers connected to the custom node.
        self.router_weight = src.router.weight
        self.w1 = src.w1
        self.b1 = src.b1
        self.w2 = src.w2
        self.b2 = src.b2

    def forward(self, x: Tensor) -> Tensor:
        return CustomMoEFunction.apply(
            x,
            self.router_weight,
            self.w1,
            self.b1,
            self.w2,
            self.b2,
            self.top_k,
            self.num_experts,
            self.in_features,
            self.hidden_features,
        )

元のモデルのノードを置換する。

def replace_moe(model: nn.Module, model_module: Any, mode: str, fp16_experts: bool) -> int:
    replaced = 0
    for name, child in list(model.named_modules()):
        if isinstance(child, model_module.DroplessMoEMlp):
            if fp16_experts:
                child.w1.data = child.w1.data.half()
                child.b1.data = child.b1.data.half()
                child.w2.data = child.w2.data.half()
                child.b2.data = child.b2.data.half()
                child.router.float()
            wrapper: nn.Module
            if mode == "plugin":
                wrapper = PluginMoEExportWrapper(child)
            elif mode == "dense":
                wrapper = DenseMoEExportWrapper(child)
            else:
                raise ValueError(f"Unsupported mode: {mode}")
            set_child(model, name, wrapper)
            replaced += 1
    return replaced

コード全体は、以下の場所にある。
grouped_mmを使わずDense MoEとして出力するdenseモードも実装している。

github.com

ONNXエクスポート結果

ONNXにエクスポートした結果を、Netronで確認すると、CustomMoEノードが出力されていることが確認できた。


なお、denseモードでエクスポートすると複雑なグラフがエキスパートの数分だけ出力される。

まとめ

dropless MoEを実装したSwinTransformerをカスタムノードを使用してONNXにエクスポートする処理を実装した。
エクスポート結果で、MoEを実装したMLPがカスタムノードとして出力されていることが確認できた。
次は、TensorRTのプラグインを実装したい。

dropless MoE(Mixture of Experts)を試す

dlshogiのモデルはパラメータを増やすほど精度が向上することが60ブロック768フィルタのサイズまで確認できている。
しかし、探索速度が落ちるため、対局した際の強さは、40ブロック512フィルタのモデルには及ばない。

MoE(Mixture of Experts)は、パラメータを増やしながら推論速度を落とさない手法としてLLMで取り入れられている。

DL系の将棋AIを強くするためには、MoEは有力な手法であると考えている。

しかし、MoEの実装は単純ではなく、GPUで効率的に推論するには、PyTorch → ONNX → TensorRTというdlshogiで使用している標準化された手法が使えず、カスタマイズ実装が必要になる。
この記事では、MoEの手法を整理し、GPUで効率的に推論するためのdropless MoEの訓練と推論の実装方法について記述する。

MoE(Mixture of Experts)

MoE(Mixture of Experts)は、入力ごとにゲーティング機構が多数の専門家モデルの一部だけを選んで計算させることで、計算量を抑えながらモデル容量を大きくする手法である。
Transformer系LLMでMoEを使う場合、Transformer block内のFFN(MLP)部分をMoE化する構成が主流である。

Dense MoE

MoEを単純に実装する場合、すべての expert を計算して、重み付き和を取ることで実装できる。
しかし、1つのトークンに対して全expertを計算する必要があり、MoEの計算量を抑えながらモデル容量を大きくできるというメリットは得られない。

Sparse MoE

gate が少数 expert だけを選択し、選択されたexpertのみ計算し、他のexpertの計算はスキップする。

これをGPUで実現するには、capacity-aware routingという手法が使われる。

capacity-aware routing

capacity-aware routingは、MoEで各expertがGPU上で効率よく固定長バッチとして処理できるように、expertごとの受け入れtoken数に上限 capacity を設け、足りない分はpaddingし、超過分はdrop・再割当・別処理するrouting方式である。

これはシンプルに実装できる反面、dropするトークンが発生する可能性があり、capacityを調整する必要がある。
capacityを増やすとdropの発生率は下がるが、paddingにより無駄な計算が増えるため、トレードオフになる。

最近では、dropを回避する、dropless MoEが主流になっている。

dropless MoE

dropless MoE とは、grouped GEMMやblock-sparse kernelなどで全tokenをGPU上で処理する方式である。

expertごとのトークン量が固定長でなくても、処理できるように専用のカーネルとして実装される。

MegaBlocksは、dropless MoEを実用的なGPU学習システムとして成立させた先駆的研究の一つである。

PyTorch 2.10には、torch.nn.functional.grouped_mmが追加され、expertごとの可変長GEMMが、PyTorch標準で実装できるようになった。
grouped_mmは、内部では専用のカーネルとして実装されている。

推論の課題

TensorRTでPyTorchで訓練したモデルを推論する際、ONNXにエクスポートして、TensorRTで読み込むのが標準的な方法である。
しかし、grouped_mmに相当するONNXのオペレータはなく、標準の方法が使えない。

そのため、ONNXのカスタムオペレータとしてエクスポートして、TensorRTのカスタムプラグインとして実装する必要がある。

カスタムプラグインは、grouped_mm相当の処理は、CUTLASSライブラリにあるgrouped GEMMで実装できる。

dropless MoEの実装

実際に、PyTorchでgrouped_mmを使用してdropless MoEを実装して学習できるか検証してみた。

SwinTransformerで、CIFAR-10データセットを学習するシンプルなコードで検証した。

まずは、MoEを使用しない実装は以下の通りである。
PyTorch Lightningを使用して実装している。

train.py
from lightning.pytorch.cli import LightningCLI

from data import CIFAR10DataModule
from model import SwinCIFAR10Classifier


def main() -> None:
    LightningCLI(
        model_class=SwinCIFAR10Classifier,
        datamodule_class=CIFAR10DataModule,
        save_config_kwargs={"overwrite": True},
    )


if __name__ == "__main__":
    main()
data.py
from __future__ import annotations

from pathlib import Path

import lightning.pytorch as pl
import torch
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms


class CIFAR10DataModule(pl.LightningDataModule):
    def __init__(
        self,
        data_dir: str = "data",
        batch_size: int = 128,
        num_workers: int = 4,
        image_size: int = 32,
        val_split: int = 5000,
        seed: int = 42,
        download: bool = True,
    ) -> None:
        super().__init__()
        self.data_dir = Path(data_dir)
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.image_size = image_size
        self.val_split = val_split
        self.seed = seed
        self.download = download

        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def prepare_data(self) -> None:
        datasets.CIFAR10(self.data_dir, train=True, download=self.download)
        datasets.CIFAR10(self.data_dir, train=False, download=self.download)

    def setup(self, stage: str | None = None) -> None:
        train_transform = transforms.Compose(
            [
                transforms.Resize((self.image_size, self.image_size)),
                transforms.RandomCrop(self.image_size, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
            ]
        )
        eval_transform = transforms.Compose(
            [
                transforms.Resize((self.image_size, self.image_size)),
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
            ]
        )

        if stage in (None, "fit"):
            train_base = datasets.CIFAR10(self.data_dir, train=True, transform=train_transform)
            val_base = datasets.CIFAR10(self.data_dir, train=True, transform=eval_transform)
            train_size = len(train_base) - self.val_split
            indices = torch.randperm(len(train_base), generator=torch.Generator().manual_seed(self.seed)).tolist()
            self.train_dataset = Subset(train_base, indices[:train_size])
            self.val_dataset = Subset(val_base, indices[train_size:])

        if stage in (None, "test"):
            self.test_dataset = datasets.CIFAR10(self.data_dir, train=False, transform=eval_transform)

    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=self.num_workers > 0,
        )

    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=self.num_workers > 0,
        )

    def test_dataloader(self) -> DataLoader:
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=self.num_workers > 0,
        )
model.py
from __future__ import annotations

from typing import Sequence

import lightning.pytorch as pl
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from torchmetrics.classification import MulticlassAccuracy


def window_partition(x: Tensor, window_size: int) -> Tensor:
    batch_size, height, width, channels = x.shape
    x = x.view(batch_size, height // window_size, window_size, width // window_size, window_size, channels)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return windows.view(-1, window_size, window_size, channels)


def window_reverse(windows: Tensor, window_size: int, height: int, width: int) -> Tensor:
    batch_size = int(windows.shape[0] / (height * width / window_size / window_size))
    x = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return x.view(batch_size, height, width, -1)


class DropPath(nn.Module):
    def __init__(self, drop_prob: float = 0.0) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: Tensor) -> Tensor:
        if self.drop_prob == 0.0 or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()
        return x.div(keep_prob) * random_tensor


class Mlp(nn.Module):
    def __init__(self, in_features: int, hidden_features: int, drop: float = 0.0) -> None:
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class WindowAttention(nn.Module):
    def __init__(self, dim: int, window_size: int, num_heads: int, qkv_bias: bool = True, attn_drop: float = 0.0, proj_drop: float = 0.0) -> None:
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        table_size = (2 * window_size - 1) * (2 * window_size - 1)
        self.relative_position_bias_table = nn.Parameter(torch.zeros(table_size, num_heads))

        coords = torch.stack(torch.meshgrid(torch.arange(window_size), torch.arange(window_size), indexing="ij"))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += window_size - 1
        relative_coords[:, :, 1] += window_size - 1
        relative_coords[:, :, 0] *= 2 * window_size - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index, persistent=False)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

    def forward(self, x: Tensor, mask: Tensor | None = None) -> Tensor:
        batch_windows, tokens, channels = x.shape
        qkv = self.qkv(x).reshape(batch_windows, tokens, 3, self.num_heads, channels // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        query, key, value = qkv[0], qkv[1], qkv[2]

        query = query * self.scale
        attn = query @ key.transpose(-2, -1)
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
        relative_position_bias = relative_position_bias.view(tokens, tokens, -1).permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            num_windows = mask.shape[0]
            attn = attn.view(batch_windows // num_windows, num_windows, self.num_heads, tokens, tokens)
            attn = attn + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, tokens, tokens)

        attn = F.softmax(attn, dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ value).transpose(1, 2).reshape(batch_windows, tokens, channels)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class SwinTransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        input_resolution: tuple[int, int],
        num_heads: int,
        window_size: int = 4,
        shift_size: int = 0,
        mlp_ratio: float = 4.0,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        drop_path: float = 0.0,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.window_size = min(window_size, input_resolution[0], input_resolution[1])
        self.shift_size = 0 if min(input_resolution) <= self.window_size else shift_size

        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, self.window_size, num_heads, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = Mlp(dim, int(dim * mlp_ratio), drop=drop)

        self.register_buffer("attn_mask", self._create_mask(), persistent=False)

    def _create_mask(self) -> Tensor | None:
        if self.shift_size == 0:
            return None

        height, width = self.input_resolution
        img_mask = torch.zeros((1, height, width, 1))
        height_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None))
        width_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None))
        count = 0
        for height_slice in height_slices:
            for width_slice in width_slices:
                img_mask[:, height_slice, width_slice, :] = count
                count += 1

        mask_windows = window_partition(img_mask, self.window_size).view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

    def forward(self, x: Tensor) -> Tensor:
        height, width = self.input_resolution
        batch_size, length, channels = x.shape
        if length != height * width:
            raise ValueError(f"Expected token length {height * width}, got {length}.")

        shortcut = x
        x = self.norm1(x).view(batch_size, height, width, channels)

        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        x_windows = window_partition(shifted_x, self.window_size).view(-1, self.window_size * self.window_size, channels)
        attn_windows = self.attn(x_windows, mask=self.attn_mask)
        shifted_x = window_reverse(attn_windows.view(-1, self.window_size, self.window_size, channels), self.window_size, height, width)

        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        x = x.view(batch_size, height * width, channels)
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class PatchMerging(nn.Module):
    def __init__(self, input_resolution: tuple[int, int], dim: int) -> None:
        super().__init__()
        self.input_resolution = input_resolution
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = nn.LayerNorm(4 * dim)

    def forward(self, x: Tensor) -> Tensor:
        height, width = self.input_resolution
        batch_size, length, channels = x.shape
        if length != height * width:
            raise ValueError(f"Expected token length {height * width}, got {length}.")
        if height % 2 != 0 or width % 2 != 0:
            raise ValueError("PatchMerging requires even spatial dimensions.")

        x = x.view(batch_size, height, width, channels)
        x0 = x[:, 0::2, 0::2, :]
        x1 = x[:, 1::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, :]
        x3 = x[:, 1::2, 1::2, :]
        x = torch.cat([x0, x1, x2, x3], dim=-1).view(batch_size, -1, 4 * channels)
        x = self.norm(x)
        return self.reduction(x)


class BasicLayer(nn.Module):
    def __init__(
        self,
        dim: int,
        input_resolution: tuple[int, int],
        depth: int,
        num_heads: int,
        window_size: int,
        mlp_ratio: float,
        drop: float,
        attn_drop: float,
        drop_path: Sequence[float],
        downsample: bool,
    ) -> None:
        super().__init__()
        self.blocks = nn.ModuleList(
            [
                SwinTransformerBlock(
                    dim=dim,
                    input_resolution=input_resolution,
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=0 if index % 2 == 0 else window_size // 2,
                    mlp_ratio=mlp_ratio,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path[index],
                )
                for index in range(depth)
            ]
        )
        self.downsample = PatchMerging(input_resolution, dim) if downsample else None

    def forward(self, x: Tensor) -> Tensor:
        for block in self.blocks:
            x = block(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x


class PatchEmbed(nn.Module):
    def __init__(self, image_size: int, patch_size: int, in_channels: int, embed_dim: int) -> None:
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.grid_size = image_size // patch_size
        self.num_patches = self.grid_size * self.grid_size
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x: Tensor) -> Tensor:
        _, _, height, width = x.shape
        if height != self.image_size or width != self.image_size:
            raise ValueError(f"Expected input size {self.image_size}x{self.image_size}, got {height}x{width}.")
        x = self.proj(x).flatten(2).transpose(1, 2)
        return self.norm(x)


class SwinTransformer(nn.Module):
    def __init__(
        self,
        image_size: int = 32,
        patch_size: int = 4,
        in_channels: int = 3,
        num_classes: int = 10,
        embed_dim: int = 64,
        depths: Sequence[int] = (2, 2, 2, 2),
        num_heads: Sequence[int] = (2, 4, 8, 16),
        window_size: int = 4,
        mlp_ratio: float = 4.0,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.1,
    ) -> None:
        super().__init__()
        if image_size % patch_size != 0:
            raise ValueError("image_size must be divisible by patch_size.")
        if len(depths) != len(num_heads):
            raise ValueError("depths and num_heads must have the same length.")

        self.num_layers = len(depths)
        self.num_features = embed_dim * 2 ** (self.num_layers - 1)
        self.patch_embed = PatchEmbed(image_size, patch_size, in_channels, embed_dim)
        self.pos_drop = nn.Dropout(drop_rate)

        total_depth = sum(depths)
        drop_paths = torch.linspace(0, drop_path_rate, total_depth).tolist()

        self.layers = nn.ModuleList()
        resolution = image_size // patch_size
        depth_offset = 0
        for layer_index in range(self.num_layers):
            dim = embed_dim * 2**layer_index
            input_resolution = (resolution, resolution)
            layer = BasicLayer(
                dim=dim,
                input_resolution=input_resolution,
                depth=depths[layer_index],
                num_heads=num_heads[layer_index],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=drop_paths[depth_offset : depth_offset + depths[layer_index]],
                downsample=layer_index < self.num_layers - 1,
            )
            self.layers.append(layer)
            depth_offset += depths[layer_index]
            if layer_index < self.num_layers - 1:
                resolution //= 2

        self.norm = nn.LayerNorm(self.num_features)
        self.head = nn.Linear(self.num_features, num_classes)
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        if isinstance(module, nn.Linear):
            nn.init.trunc_normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, x: Tensor) -> Tensor:
        x = self.patch_embed(x)
        x = self.pos_drop(x)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        x = x.mean(dim=1)
        return self.head(x)


class SwinCIFAR10Classifier(pl.LightningModule):
    def __init__(
        self,
        num_classes: int = 10,
        image_size: int = 32,
        patch_size: int = 4,
        in_channels: int = 3,
        embed_dim: int = 64,
        depths: Sequence[int] = (2, 2, 2, 2),
        num_heads: Sequence[int] = (2, 4, 8, 16),
        window_size: int = 4,
        mlp_ratio: float = 4.0,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.1,
        learning_rate: float = 0.001,
        weight_decay: float = 0.05,
    ) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.model = SwinTransformer(
            image_size=image_size,
            patch_size=patch_size,
            in_channels=in_channels,
            num_classes=num_classes,
            embed_dim=embed_dim,
            depths=depths,
            num_heads=num_heads,
            window_size=window_size,
            mlp_ratio=mlp_ratio,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
        )
        self.criterion = nn.CrossEntropyLoss()
        self.train_acc = MulticlassAccuracy(num_classes=num_classes)
        self.val_acc = MulticlassAccuracy(num_classes=num_classes)
        self.test_acc = MulticlassAccuracy(num_classes=num_classes)

    def forward(self, x: Tensor) -> Tensor:
        return self.model(x)

    def _shared_step(self, batch: tuple[Tensor, Tensor], stage: str) -> Tensor:
        images, targets = batch
        logits = self(images)
        loss = self.criterion(logits, targets)
        preds = torch.argmax(logits, dim=1)
        metric = getattr(self, f"{stage}_acc")
        metric(preds, targets)
        self.log(f"{stage}_loss", loss, prog_bar=True, on_step=stage == "train", on_epoch=True)
        self.log(f"{stage}_acc", metric, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def training_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        return self._shared_step(batch, "train")

    def validation_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        return self._shared_step(batch, "val")

    def test_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        return self._shared_step(batch, "test")

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay,
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.trainer.max_epochs)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}
config.yaml
seed_everything: 42

trainer:
  accelerator: auto
  devices: auto
  max_epochs: 20
  precision: bf16-mixed
  log_every_n_steps: 50
  callbacks:
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        monitor: val_acc
        mode: max
        save_top_k: 1
        filename: swin-cifar10-{epoch:02d}-{val_acc:.4f}
    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
      init_args:
        logging_interval: epoch

model:
  num_classes: 10
  image_size: 32
  patch_size: 4
  in_channels: 3
  embed_dim: 64
  depths: [2, 2, 2, 2]
  num_heads: [2, 4, 8, 16]
  window_size: 4
  mlp_ratio: 4.0
  drop_rate: 0.0
  attn_drop_rate: 0.0
  drop_path_rate: 0.1
  learning_rate: 0.001
  weight_decay: 0.05

data:
  data_dir: data
  batch_size: 128
  num_workers: 4
  image_size: 32
  val_split: 5000
  seed: 42
  download: true


続いて、torch.nn.functional.grouped_mmを使用して実装したdropless MoEの実装は以下の通り。

model.py
from __future__ import annotations

from typing import Sequence

import lightning.pytorch as pl
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from torchmetrics.classification import MulticlassAccuracy


def window_partition(x: Tensor, window_size: int) -> Tensor:
    batch_size, height, width, channels = x.shape
    x = x.view(batch_size, height // window_size, window_size, width // window_size, window_size, channels)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return windows.view(-1, window_size, window_size, channels)


def window_reverse(windows: Tensor, window_size: int, height: int, width: int) -> Tensor:
    batch_size = int(windows.shape[0] / (height * width / window_size / window_size))
    x = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return x.view(batch_size, height, width, -1)


class DropPath(nn.Module):
    def __init__(self, drop_prob: float = 0.0) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: Tensor) -> Tensor:
        if self.drop_prob == 0.0 or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()
        return x.div(keep_prob) * random_tensor


class Mlp(nn.Module):
    def __init__(self, in_features: int, hidden_features: int, drop: float = 0.0) -> None:
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class DroplessMoEMlp(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        num_experts: int,
        top_k: int = 2,
        drop: float = 0.0,
    ) -> None:
        super().__init__()
        if num_experts <= 0:
            raise ValueError("num_experts must be greater than 0.")
        if top_k <= 0 or top_k > num_experts:
            raise ValueError("top_k must be in the range [1, num_experts].")

        self.in_features = in_features
        self.hidden_features = hidden_features
        self.num_experts = num_experts
        self.top_k = top_k

        self.router = nn.Linear(in_features, num_experts, bias=False)
        # grouped_mm operand layout avoids per-forward transpose: x @ w1 maps D -> H, h @ w2 maps H -> D.
        self.w1 = nn.Parameter(torch.empty(num_experts, in_features, hidden_features))
        self.b1 = nn.Parameter(torch.zeros(num_experts, hidden_features))
        self.w2 = nn.Parameter(torch.empty(num_experts, hidden_features, in_features))
        self.b2 = nn.Parameter(torch.zeros(num_experts, in_features))
        self.act = nn.GELU()
        self.drop = nn.Dropout(drop)
        self.aux_loss: Tensor | None = None
        self.z_loss: Tensor | None = None

        nn.init.trunc_normal_(self.w1, std=0.02)
        nn.init.trunc_normal_(self.w2, std=0.02)

    def _apply(self, fn):
        module = super()._apply(fn)
        self.router.float()
        return module

    def forward(self, x: Tensor) -> Tensor:
        if not x.is_cuda:
            raise RuntimeError("DroplessMoEMlp requires CUDA tensors because it uses torch.nn.functional.grouped_mm.")
        if not hasattr(F, "grouped_mm"):
            raise RuntimeError("torch.nn.functional.grouped_mm is not available in this PyTorch build.")
        if self.router.weight.dtype != torch.float32:
            raise RuntimeError("Router weights must remain FP32 for stable routing.")

        orig_shape = x.shape
        x = x.reshape(-1, self.in_features)
        tokens = x.shape[0]

        # Keep routing numerics in FP32; expert grouped_mm below still uses the activation dtype.
        with torch.autocast(device_type="cuda", enabled=False):
            logits = self.router(x.float())
            probs = F.softmax(logits, dim=-1)
            topk_prob, topk_expert = torch.topk(probs, k=self.top_k, dim=-1)
            topk_gate = topk_prob / topk_prob.sum(dim=-1, keepdim=True)

        flat_expert = topk_expert.reshape(-1)
        flat_gate = topk_gate.reshape(-1)
        flat_token = torch.arange(tokens, device=x.device).repeat_interleave(self.top_k)

        order = torch.argsort(flat_expert)
        expert_sorted = flat_expert[order]
        token_sorted = flat_token[order]
        gate_sorted = flat_gate[order]
        x_sorted = x.index_select(0, token_sorted)

        counts = torch.bincount(expert_sorted, minlength=self.num_experts)
        offsets = torch.cumsum(counts, dim=0).to(torch.int32)
        tokens_per_expert = counts.to(probs.dtype) / counts.sum().to(probs.dtype)
        prob_per_expert = probs.mean(dim=0)
        self.aux_loss = self.num_experts * torch.sum(tokens_per_expert * prob_per_expert)
        self.z_loss = torch.mean(torch.logsumexp(logits, dim=-1).square())

        w1 = self.w1 if self.w1.dtype == x.dtype else self.w1.to(dtype=x.dtype)

        # grouped_mm requires offs[-1] < mat_a.shape[0], so append one ignored row.
        padding = torch.zeros(1, x_sorted.shape[-1], dtype=x_sorted.dtype, device=x_sorted.device)
        x_grouped = torch.cat((x_sorted, padding), dim=0)
        h = F.grouped_mm(x_grouped, w1, offs=offsets)[: x_sorted.shape[0]]
        h = h + self.b1.index_select(0, expert_sorted).to(dtype=h.dtype)
        h = self.act(h)
        h = self.drop(h)

        w2 = self.w2 if self.w2.dtype == h.dtype else self.w2.to(dtype=h.dtype)

        # grouped_mm requires offs[-1] < mat_a.shape[0], so append one ignored row.
        hidden_padding = torch.zeros(1, h.shape[-1], dtype=h.dtype, device=h.device)
        h_grouped = torch.cat((h, hidden_padding), dim=0)
        y_sorted = F.grouped_mm(h_grouped, w2, offs=offsets)[: h.shape[0]]
        y_sorted = y_sorted + self.b2.index_select(0, expert_sorted).to(dtype=y_sorted.dtype)
        y_sorted = self.drop(y_sorted)
        y_sorted = y_sorted * gate_sorted.unsqueeze(-1).to(dtype=y_sorted.dtype)

        y = torch.zeros_like(x)
        y.index_add_(0, token_sorted, y_sorted.to(dtype=y.dtype))
        return y.reshape(orig_shape)


class WindowAttention(nn.Module):
    def __init__(self, dim: int, window_size: int, num_heads: int, qkv_bias: bool = True, attn_drop: float = 0.0, proj_drop: float = 0.0) -> None:
        super().__init__()
        if num_heads <= 0:
            raise ValueError("num_heads must be greater than 0.")
        if dim % num_heads != 0:
            raise ValueError("dim must be divisible by num_heads.")
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        table_size = (2 * window_size - 1) * (2 * window_size - 1)
        self.relative_position_bias_table = nn.Parameter(torch.zeros(table_size, num_heads))

        coords = torch.stack(torch.meshgrid(torch.arange(window_size), torch.arange(window_size), indexing="ij"))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += window_size - 1
        relative_coords[:, :, 1] += window_size - 1
        relative_coords[:, :, 0] *= 2 * window_size - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index, persistent=False)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

    def forward(self, x: Tensor, mask: Tensor | None = None) -> Tensor:
        batch_windows, tokens, channels = x.shape
        qkv = self.qkv(x).reshape(batch_windows, tokens, 3, self.num_heads, channels // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        query, key, value = qkv[0], qkv[1], qkv[2]

        query = query * self.scale
        attn = query @ key.transpose(-2, -1)
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
        relative_position_bias = relative_position_bias.view(tokens, tokens, -1).permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            num_windows = mask.shape[0]
            attn = attn.view(batch_windows // num_windows, num_windows, self.num_heads, tokens, tokens)
            attn = attn + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, tokens, tokens)

        attn = F.softmax(attn, dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ value).transpose(1, 2).reshape(batch_windows, tokens, channels)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class SwinTransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        input_resolution: tuple[int, int],
        num_heads: int,
        window_size: int = 4,
        shift_size: int = 0,
        mlp_ratio: float = 4.0,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        drop_path: float = 0.0,
        use_moe_mlp: bool = False,
        moe_num_experts: int = 4,
        moe_top_k: int = 2,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.window_size = min(window_size, input_resolution[0], input_resolution[1])
        if input_resolution[0] % self.window_size != 0 or input_resolution[1] % self.window_size != 0:
            raise ValueError("input_resolution must be divisible by window_size, or padding is required.")
        self.shift_size = 0 if min(input_resolution) <= self.window_size else shift_size

        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, self.window_size, num_heads, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = nn.LayerNorm(dim)
        hidden_features = int(dim * mlp_ratio)
        if use_moe_mlp:
            self.mlp = DroplessMoEMlp(dim, hidden_features, num_experts=moe_num_experts, top_k=moe_top_k, drop=drop)
        else:
            self.mlp = Mlp(dim, hidden_features, drop=drop)

        self.register_buffer("attn_mask", self._create_mask(), persistent=False)

    def _create_mask(self) -> Tensor | None:
        if self.shift_size == 0:
            return None

        height, width = self.input_resolution
        img_mask = torch.zeros((1, height, width, 1))
        height_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None))
        width_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None))
        count = 0
        for height_slice in height_slices:
            for width_slice in width_slices:
                img_mask[:, height_slice, width_slice, :] = count
                count += 1

        mask_windows = window_partition(img_mask, self.window_size).view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

    def forward(self, x: Tensor) -> Tensor:
        height, width = self.input_resolution
        batch_size, length, channels = x.shape
        if length != height * width:
            raise ValueError(f"Expected token length {height * width}, got {length}.")

        shortcut = x
        x = self.norm1(x).view(batch_size, height, width, channels)

        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        x_windows = window_partition(shifted_x, self.window_size).view(-1, self.window_size * self.window_size, channels)
        attn_windows = self.attn(x_windows, mask=self.attn_mask)
        shifted_x = window_reverse(attn_windows.view(-1, self.window_size, self.window_size, channels), self.window_size, height, width)

        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        x = x.view(batch_size, height * width, channels)
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class PatchMerging(nn.Module):
    def __init__(self, input_resolution: tuple[int, int], dim: int) -> None:
        super().__init__()
        self.input_resolution = input_resolution
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = nn.LayerNorm(4 * dim)

    def forward(self, x: Tensor) -> Tensor:
        height, width = self.input_resolution
        batch_size, length, channels = x.shape
        if length != height * width:
            raise ValueError(f"Expected token length {height * width}, got {length}.")
        if height % 2 != 0 or width % 2 != 0:
            raise ValueError("PatchMerging requires even spatial dimensions.")

        x = x.view(batch_size, height, width, channels)
        x0 = x[:, 0::2, 0::2, :]
        x1 = x[:, 1::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, :]
        x3 = x[:, 1::2, 1::2, :]
        x = torch.cat([x0, x1, x2, x3], dim=-1).view(batch_size, -1, 4 * channels)
        x = self.norm(x)
        return self.reduction(x)


class BasicLayer(nn.Module):
    def __init__(
        self,
        dim: int,
        input_resolution: tuple[int, int],
        depth: int,
        num_heads: int,
        window_size: int,
        mlp_ratio: float,
        drop: float,
        attn_drop: float,
        drop_path: Sequence[float],
        downsample: bool,
        use_moe_mlp: bool,
        moe_num_experts: int,
        moe_top_k: int,
    ) -> None:
        super().__init__()
        self.blocks = nn.ModuleList(
            [
                SwinTransformerBlock(
                    dim=dim,
                    input_resolution=input_resolution,
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=0 if index % 2 == 0 else window_size // 2,
                    mlp_ratio=mlp_ratio,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path[index],
                    use_moe_mlp=use_moe_mlp,
                    moe_num_experts=moe_num_experts,
                    moe_top_k=moe_top_k,
                )
                for index in range(depth)
            ]
        )
        self.downsample = PatchMerging(input_resolution, dim) if downsample else None

    def moe_loss(self) -> Tensor | None:
        losses = [block.mlp.aux_loss for block in self.blocks if isinstance(block.mlp, DroplessMoEMlp) and block.mlp.aux_loss is not None]
        if not losses:
            return None
        return torch.stack(losses).mean()

    def moe_z_loss(self) -> Tensor | None:
        losses = [block.mlp.z_loss for block in self.blocks if isinstance(block.mlp, DroplessMoEMlp) and block.mlp.z_loss is not None]
        if not losses:
            return None
        return torch.stack(losses).mean()

    def forward(self, x: Tensor) -> Tensor:
        for block in self.blocks:
            x = block(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x


class PatchEmbed(nn.Module):
    def __init__(self, image_size: int, patch_size: int, in_channels: int, embed_dim: int) -> None:
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.grid_size = image_size // patch_size
        self.num_patches = self.grid_size * self.grid_size
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x: Tensor) -> Tensor:
        _, _, height, width = x.shape
        if height != self.image_size or width != self.image_size:
            raise ValueError(f"Expected input size {self.image_size}x{self.image_size}, got {height}x{width}.")
        x = self.proj(x).flatten(2).transpose(1, 2)
        return self.norm(x)


class SwinTransformer(nn.Module):
    def __init__(
        self,
        image_size: int = 32,
        patch_size: int = 4,
        in_channels: int = 3,
        num_classes: int = 10,
        embed_dim: int = 64,
        depths: Sequence[int] = (2, 2, 2, 2),
        num_heads: Sequence[int] = (2, 4, 8, 16),
        window_size: int = 4,
        mlp_ratio: float = 4.0,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.1,
        moe_stages: Sequence[int] = (0, 1),
        moe_num_experts: int = 4,
        moe_top_k: int = 2,
    ) -> None:
        super().__init__()
        if image_size % patch_size != 0:
            raise ValueError("image_size must be divisible by patch_size.")
        if len(depths) != len(num_heads):
            raise ValueError("depths and num_heads must have the same length.")
        moe_stage_set = set(moe_stages)
        invalid_moe_stages = moe_stage_set.difference(range(len(depths)))
        if invalid_moe_stages:
            raise ValueError(f"moe_stages contains invalid stage indices: {sorted(invalid_moe_stages)}.")

        self.num_layers = len(depths)
        self.num_features = embed_dim * 2 ** (self.num_layers - 1)
        self.patch_embed = PatchEmbed(image_size, patch_size, in_channels, embed_dim)
        self.pos_drop = nn.Dropout(drop_rate)

        total_depth = sum(depths)
        drop_paths = torch.linspace(0, drop_path_rate, total_depth).tolist()

        self.layers = nn.ModuleList()
        resolution = image_size // patch_size
        depth_offset = 0
        for layer_index in range(self.num_layers):
            dim = embed_dim * 2**layer_index
            input_resolution = (resolution, resolution)
            layer = BasicLayer(
                dim=dim,
                input_resolution=input_resolution,
                depth=depths[layer_index],
                num_heads=num_heads[layer_index],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=drop_paths[depth_offset : depth_offset + depths[layer_index]],
                downsample=layer_index < self.num_layers - 1,
                use_moe_mlp=layer_index in moe_stage_set,
                moe_num_experts=moe_num_experts,
                moe_top_k=moe_top_k,
            )
            self.layers.append(layer)
            depth_offset += depths[layer_index]
            if layer_index < self.num_layers - 1:
                resolution //= 2

        self.norm = nn.LayerNorm(self.num_features)
        self.head = nn.Linear(self.num_features, num_classes)
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        if isinstance(module, nn.Linear):
            nn.init.trunc_normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, x: Tensor) -> Tensor:
        x = self.patch_embed(x)
        x = self.pos_drop(x)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        x = x.mean(dim=1)
        return self.head(x)

    def moe_loss(self) -> Tensor | None:
        losses = [layer.moe_loss() for layer in self.layers]
        losses = [loss for loss in losses if loss is not None]
        if not losses:
            return None
        return torch.stack(losses).mean()

    def moe_z_loss(self) -> Tensor | None:
        losses = [layer.moe_z_loss() for layer in self.layers]
        losses = [loss for loss in losses if loss is not None]
        if not losses:
            return None
        return torch.stack(losses).mean()


class SwinCIFAR10Classifier(pl.LightningModule):
    def __init__(
        self,
        num_classes: int = 10,
        image_size: int = 32,
        patch_size: int = 4,
        in_channels: int = 3,
        embed_dim: int = 64,
        depths: Sequence[int] = (2, 2, 2, 2),
        num_heads: Sequence[int] = (2, 4, 8, 16),
        window_size: int = 4,
        mlp_ratio: float = 4.0,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.1,
        moe_stages: Sequence[int] = (0, 1),
        moe_num_experts: int = 4,
        moe_top_k: int = 2,
        moe_aux_loss_weight: float = 0.01,
        moe_z_loss_weight: float = 0.001,
        learning_rate: float = 0.001,
        weight_decay: float = 0.05,
    ) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.model = SwinTransformer(
            image_size=image_size,
            patch_size=patch_size,
            in_channels=in_channels,
            num_classes=num_classes,
            embed_dim=embed_dim,
            depths=depths,
            num_heads=num_heads,
            window_size=window_size,
            mlp_ratio=mlp_ratio,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            moe_stages=moe_stages,
            moe_num_experts=moe_num_experts,
            moe_top_k=moe_top_k,
        )
        self.criterion = nn.CrossEntropyLoss()
        self.train_acc = MulticlassAccuracy(num_classes=num_classes)
        self.val_acc = MulticlassAccuracy(num_classes=num_classes)
        self.test_acc = MulticlassAccuracy(num_classes=num_classes)

    def forward(self, x: Tensor) -> Tensor:
        return self.model(x)

    def _shared_step(self, batch: tuple[Tensor, Tensor], stage: str) -> Tensor:
        images, targets = batch
        logits = self(images)
        loss = self.criterion(logits, targets)
        moe_aux_loss = self.model.moe_loss()
        if moe_aux_loss is not None:
            self.log(f"{stage}_moe_aux_loss", moe_aux_loss, prog_bar=False, on_step=stage == "train", on_epoch=True)
            if stage == "train" and self.hparams.moe_aux_loss_weight > 0.0:
                loss = loss + self.hparams.moe_aux_loss_weight * moe_aux_loss
        moe_z_loss = self.model.moe_z_loss()
        if moe_z_loss is not None:
            self.log(f"{stage}_moe_z_loss", moe_z_loss, prog_bar=False, on_step=stage == "train", on_epoch=True)
            if stage == "train" and self.hparams.moe_z_loss_weight > 0.0:
                loss = loss + self.hparams.moe_z_loss_weight * moe_z_loss
        preds = torch.argmax(logits, dim=1)
        metric = getattr(self, f"{stage}_acc")
        metric(preds, targets)
        self.log(f"{stage}_loss", loss, prog_bar=True, on_step=stage == "train", on_epoch=True)
        self.log(f"{stage}_acc", metric, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def training_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        return self._shared_step(batch, "train")

    def validation_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        return self._shared_step(batch, "val")

    def test_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        return self._shared_step(batch, "test")

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay,
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.trainer.max_epochs)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}
config.yaml
seed_everything: 42

trainer:
  accelerator: auto
  devices: auto
  max_epochs: 20
  precision: bf16-mixed
  callbacks:
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        monitor: val_acc
        mode: max
        save_top_k: 1
        filename: swin-cifar10-{epoch:02d}-{val_acc:.4f}
    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
      init_args:
        logging_interval: epoch

model:
  num_classes: 10
  image_size: 32
  patch_size: 4
  in_channels: 3
  embed_dim: 64
  depths: [2, 2, 2, 2]
  num_heads: [2, 4, 8, 16]
  window_size: 4
  mlp_ratio: 4.0
  drop_rate: 0.0
  attn_drop_rate: 0.0
  drop_path_rate: 0.1
  moe_stages: [0, 1]
  moe_num_experts: 4
  moe_top_k: 2
  moe_aux_loss_weight: 0.01
  moe_z_loss_weight: 0.001
  learning_rate: 0.001
  weight_decay: 0.05

data:
  data_dir: data
  batch_size: 128
  num_workers: 4
  image_size: 32
  val_split: 5000
  seed: 42
  download: true

比較結果

パラメータ数
MoEなし 9.1 M
MoEあり 10.1 M
評価損失


※オレンジ:MoEなし、青:MoEあり

評価正解率


※オレンジ:MoEなし、青:MoEあり

訓練時間



※オレンジ:MoEなし、青:MoEあり

考察

MoEありの方が、評価精度は高くなっている。
パラメータ数は、MoEありの方が、1.1倍である。
学習時間は、MoEありの方が、1.4倍になった。

Expertは、Top-2に割り当てているため、その分計算量が増えた可能性がある。

追加で、Top-1でも測定した。



※赤:Top-1

Top-1でも学習時間はほぼ変わらなかった。
MoE化に伴うゲーティング処理のオーバーヘッドが原因と考える。

また、Top-1の場合、正解率はMoEなしとほとんど変わらなかった。

まとめ

PyTorchのgrouped_mmを使用して、dropless MoEを実装し、精度や学習時間を評価した。
結果、Top-2のMoEで精度が向上することが確認できたが、学習時間はパラメータ数の比率以上に遅くなった。
小さなSwinTransformerで実験したため、ゲーティングのオーバーヘッドが大きかったと考える。

次は、TensorRTのカスタムプラグインで推論処理を実装したい。

【dlshogi】QKNormを試す

最近のLLMでは、AttentionにQKNormが使われている。
特に、RMSNormを使う実装が主流になっている。

世界コンピュータ将棋の会場でnshogiの開発者と話した際に、SwiGLUとQKNormが効果があったということだった。
SwiGLUは、dlshogiでも採用して効果が高いことを確認していたが、QKNormは試していなかったので試してみた。

QKNorm

Attentionの Query と Key を内積前に正規化して、attention logits のスケール暴走や softmax saturation を防ぎ、学習を安定化する手法である。
元の論文では、正規化にL2 normalizationが使用されている。

また、内積の結果に通常の1 / \sqrt{d}の代わりに learnable scalar \alpha を掛けている。

RMSNorm

入力ベクトルを平均ではなく RMS(root mean square)だけで正規化することで、スケール変化に対する不変性を保ちながら LayerNorm を軽量化した正規化手法である。
QKNormに、RMSNormを使うのが主流であり、LLaMAGemmaでもRMSNormが使用されている。

Gemma4の実装では、ValueにもRMSNormが使用されている。

比較パターン

  • QKNormにRMSNormを使用
  • QKNormにRMSNormを使用し、Q、Kそれぞれにlearnable weightを掛ける
  • QKNormに加えて、ValueにもRMSNormを使用
  • QKNormに加えて、ValueにもRMSNormを使用し、Q、Kそれぞれにlearnable weightを掛ける
  • QKNormにLayerNormを使用
  • QKNormにBatchNorm2Dを使用

訓練条件

  • WCSC36のdlshogiのResNet+Transformerモデルの20ブロック256フィルタ
  • 訓練データ約3.9億局面
  • バッチサイズ4096
  • Momentum SGD
  • 学習率0.04からエポックごとに半減
  • 8エポック

評価データは、2017年~2018年6月のfloodgateのR3500以上の棋譜からサンプリングした856,923局面を使用。
シードを変えて、2回測定して平均をとる。

実験結果

方策損失 価値損失 方策正解率 価値正解率
WCSC36版 1.41530 0.46225 0.52857 0.76188
QK Norm(RMS Norm) 1.40991 0.46068 0.53044 0.76306
QK Norm(RMS Norm + weight) 1.41369 0.46050 0.52923 0.76342
QK Norm(RMS Norm)+V RMSNorm 1.41012 0.46052 0.52989 0.76334
QK Norm(RMS Norm + weight)+V RMSNorm 1.41564 0.46104 0.52966 0.76279
QK Norm(Layer Norm) 1.41171 0.46335 0.53010 0.76078
QK Norm(Batch Norm) 1.41429 0.46268 0.52927 0.76157


考察

QKNormにRMSNormを使用したパターンが最も方策損失が低く、方策正解率が高い。
価値損失は、QKNormにlearnable weightを使用した方が低いが僅差である。

learnable weight

方策損失はlearnable weightはない方が良い。
価値損失はlearnable weightがある方が、わずかに良くなっているが誤差程度である。
ValueにRMSNormを使うパターンでは、learnable weightがない方が方策、価値ともによい。

learnable weightはない方が良いと言える。

ValueのRMSNorm

ValueのRMSNormを適用した場合、方策、価値ともに条件により少し良くなったり、悪くなったりしており、大きな改善は見られない。
効果はなさそうである。

LayerNormとBatchNorm2D

方策、価値ともに、RMSNormの方が明らかに良い。

まとめ

dlshogiのResNet+Transformerモデルに、QKNormが効果があるか試してみた。
結果、QKNormにRMSNormを使用した場合、方策、価値ともに精度が改善することが確かめられた。

Hugging Face TrainerでMNISTを学習

PyTorchで画像分類モデルを学習するとき、学習ループやチェックポイント管理、TensorBoard対応などを毎回自前で実装するのはやや煩雑である。

Hugging Face Transformers の Trainer を使うことで、NLP用途だけでなくCNNのような画像モデルでも、シンプルかつ統一的な学習基盤を構築できる。

本記事では、MNISTを題材として、

  • CNNモデルの実装
  • Hugging Face Trainer を使った学習
  • 分散学習 (torchrun)
  • Resume training
  • TensorBoard ログ管理
  • Optimizer / Scheduler 設定

について紹介する。

ソース

github.com

特徴

  • jsonargparseで、YAML形式で設定管理
  • PyTorch Lightningと同様のインクリメンタルなバージョン管理
  • マルチGPU対応

学習実行方法

基本的な学習は以下で実行できる。

python train.py --config config.yaml

これだけで、

  • モデル保存
  • checkpoint保存
  • TensorBoardログ
  • evaluation
  • scheduler
  • optimizer

などが Trainer によって自動管理される。

分散学習 (torchrun)

GPUを複数枚使う場合は torchrun を利用する。

torchrun --nproc_per_node 2 train.py --config config.yaml

Trainer は DistributedDataParallel (DDP) に自然対応しているため、追加実装はほぼ不要である。

PyTorch Lightningに近い手軽さで分散学習を扱える。

出力ディレクトリ管理

デフォルト動作

training.output_dir を設定しない場合、デフォルトでは logs/ 以下に version 管理される。

例:

logs/version_0
logs/version_1
logs/version_2

実験ルート変更

CLIから experiment root を変更することも可能である。

python train.py --config config.yaml --experiment_root runs

すると以下のように生成される。

runs/version_0
runs/version_1

output_dir を直接指定

もし training.output_dir を設定した場合は、そのディレクトリを直接使用する。

python train.py \
  --config config.yaml \
  --training.output_dir runs/manual_experiment

この場合は version_* ディレクトリは作成されない。

Resume Training

学習途中から再開する場合は checkpoint を指定する。

python train.py \
  --config config.yaml \
  --training.resume_from_checkpoint logs/version_0/checkpoint-844

TensorBoard

TensorBoardでログを可視化できる。

tensorboard --logdir logs

Optimizer / Scheduler 設定

設定は config.yamltraining セクションで管理する。

例えば以下のように記述する。

training:
  learning_rate: 1e-3
  weight_decay: 0.01
  optim: adamw_torch
  lr_scheduler_type: cosine
  warmup_steps: 100

内部的には Trainer が optimizer / scheduler を生成する。

まとめ

今回紹介した構成では、

  • CNN
  • MNIST
  • Hugging Face Trainer
  • 分散学習
  • TensorBoard
  • Resume
  • Config管理

を非常に少ないコード量で実現できる。

Transformer以外のモデルでも Trainer を積極的に利用する価値は十分にある。

【dlshogi】torch.compileに対応したら学習が1.6倍速くなった件

torch.compile は、PyTorch 2.0 以降で導入された高速化機能で、既存の PyTorch コードをほとんど変更せずに JIT コンパイルして最適化できる仕組みである。
主に GPU 実行時のオーバーヘッド削減やカーネル融合によって性能を向上できる。

dlshogiの train.py と ptl.py を torch.compile に対応して、どれくらい学習が速くなるか測定した。

torch.compile対応の実装

「--use_compile」オプションで、torch.compileの有効無効を切り替えられるようにした。

Pytorchのtorch.compileの引数に対応して、以下のオプションを指定できる。

「--compile_backend」で、backendを指定できる。
指定しない場合はデフォルトで、inductorが使用されるが、Windows環境では動かないため、Windowsではデフォルトを「aot_eager」にした(triton-windowsをインストールすればinductorも使える)。

「--compile_mode」で、default / reduce-overhead / max-autotune などの最適化モードを指定できる。
「max-autotune」を指定するとエラーなった。原因は調べられていない。

「--compile_fullgraph」で、モデル全体を1つのグラフとしてコンパイルすることを要求する。
有効にすると、途中で Python 処理や未対応 op によってグラフが分断される場合、エラーになりやすいです。
一方で、全体がきれいにコンパイルできるモデルでは最適化しやすくなる。

「--compile_dynamic」で、入力 shape が変わるケースに対応しやすくする設定を有効にする。
dlshogiのモデルは特に指定しなくてよい。

測定結果

測定条件
  • dlshogiの最新のResNet+Transformerモデル(20ブロック256フィルタ)
  • バッチサイズ: 4096
  • 学習率: 0.04
  • use_amp: 有効
  • amp_dtype: bfloat16
  • Ubuntu 22.04 + PyTorch 2.3
比較対象
  • no compile: torch.compileなし
  • compile: torch.compileのデフォルト設定
  • fullgraph: --compile_fullgraphを指定

ステップあたりの学習時間は以下の通り。


torch.compileを有効にすることで、学習速度が1.58倍になっている。
さらに、fullgraphを有効になると、学習速度が1.64倍になっている。

まとめ

dlshogiの train.py / ptl.py を torch.compile 対応した。
また、backend、mode、fullgraph、dynamic などの各種 compile オプションも指定可能にした。
Ubuntu 22.04 + PyTorch 2.3 環境で最新ResNet+Transformerモデルの学習時間を測定した結果、torch.compile有効時は約1.58倍、さらにfullgraph有効時は約1.64倍まで高速化した。

AMP対応以来の大きな学習速度向上となった。

【dlshogi】TensorRTの推論処理の最適化でNPSが1.3倍になった件

先日の世界コンピュータ将棋選手権の会場で「あすとら将棋」さんから、GPUへのデータ転送を推論と並列化すると1割くらい速くなるという話を伺って、さっそく実験してみた。

データ転送の並列化

これまでは、一つのGPUを複数スレッドで共有して、1つのスレッドが推論中はロックして排他的に利用していた。
しかし、推論中にもデータ転送は行うことができるため、次のスレッドは先にデータ転送を行って、純粋に推論部分だけを排他すればよい。

github.com

また、GPT-5.5で実装したところ、圧縮した特徴量をfloat32にするunpack処理をCUDAで実装しているが、その部分も並列化して問題ないことに気づいた。
初期のCUDAのプログラミングモデルでは、同時に実行できるカーネルは一つだけだったが、Compute Capability 2.0以上では、ストリームが分かれていれば、複数カーネルをSMに分配して実行できるようになっている。

測定結果

benchmark.pyを使用して、floodgateから抽出した100局面で、1秒探索した際のNPSは以下の通り。

  • 5回測定した平均値の100局面の統計
  • H100 PCI 1枚、3スレッド
変更前 変更後
平均 10591 11751
中央値 10625 11818
最大値 12468 14168
最小値 9076 10678

NPSが平均で、1.11倍になった。

実行コンテキストによる非同期化

上記を実装する際、GPT-5.5が実行コンテキスト(IExecutionContext)を使うと、推論も排他不要である指摘をしてくれた。

NVIDIAのドキュメントにも、1 つの重みセットを複数の重複する推論タスクに使用できると記載がある。
Developer Guide :: NVIDIA Deep Learning TensorRT Documentation

IExecutionContextを使うことで、各スレッドは非同期で実行できる。
それにより、ロック処理を外すことができた。

github.com

IExecutionContextを使う場合、推論プロファイルの作成方法が変わるため、シリアライズファイルもスレッド数に応じて別になっている。

測定結果
変更前 変更後
平均 10591 13953
中央値 10625 13855
最大値 12468 16660
最小値 9076 13227

NPSが平均で、1.32倍になった。

強さ

NPSが強さに反映されるか連続対局で測定した。

  • 中終盤互角局面集を使用
  • 基準として氷彗8スレッド(hayabusa-8th)を加えている
  • H100 PCI x 2、3スレッド
   # PLAYER            :  RATING  ERROR  POINTS  PLAYED   (%)  CFS(%)    W    D    L  D(%)
   1 hayabusa-8th      :    34.3   31.7   125.5     221    57      85  120   11   90     5
   2 pre68_40x512_e    :     4.8   32.5   114.0     223    51      94  108   12  103     5
   3 pre68_40x512      :   -39.1   31.9    95.5     226    42     ---   91    9  126     4

pre68_40x512_eが実行コンテキストに対応した版で、pre68_40x512は変更前の版である。
測定数は少ないが速報値としては、R+33.9になっている。

2GPU、3スレッドで測定している。
変更前はスレッド数 3が最適だったが、実行コンテキスを使う場合、3よりも増やすことで、さらに強くできる可能性がある。

ただし、他のスレッドの結果を待たずに実行するため、ツリーの成長の仕方が変わりNPSが上がるだけで強くならない可能性もあるので、強さを計測してチューニングする必要がある。

まとめ

GPUへのデータ転送と特徴量unpack処理を推論と並列化し、H100環境でNPSが平均1.11倍向上した。
また、TensorRTのIExecutionContextを利用して推論自体も並列化し、ロック不要化によってNPSは平均1.32倍まで向上した。
連続対局の速報値では実行コンテキスト対応版が約+34 Elo強くなった。

この改良はすでにmasterブランチにマージしている。

【dlshogi】TransformerモデルのPython実装

将棋AIの大会で、DL系の開発者が減少傾向にあるため、dlshogiの成果物を少し共有したいと思います。

TransformerのPython実装

第5回電竜戦第35回世界コンピュータ将棋選手権で使用したResNet+TransformerモデルのPython実装をGitHubのmasterブランチに追加しました。

github.com


先日開催された第36回世界コンピュータ将棋選手権で使用したモデルは別ですが、ResNet+Transformerモデルのベースとして使用してください。

ResNetとハイブリッド構成にするため、通常の言語モデルのMulti Head Attentionとはテンソルの次元の扱いが異なっており、Linerは1x1のConv2Dで代替しています。
BatchNorm2dを使用しているのと、活性化の位置も通常のMulti Head Attentionから変更している部分があります。
変更の余地は多々あると思います。

モデルの訓練は、「--amp_dtype bfloat16」を指定しないと損失がnanになりやすいです。

入玉特徴量

第35回世界コンピュータ将棋選手権以降に使用している入玉特徴量をmasterブランチのVisual Studioプロジェクトで有効化しました。
使用しない場合は、プロジェクトのC++->プリプロセッサの定義から「NYUGYOKU_FEATURES」を削除してください。
また、unpack.cuのプロパティのカスタムビルドのコマンドラインから、「-DNYUGYOKU_FEATURES」を削除してください。

Linuxでビルドする場合は、デフォルトはOFFになっています。
有効にする場合は、makeの引数に「NYUGYOKU_FEATURES=1」を追加してください。

make NYUGYOKU_FEATURES=1

有効になっているかは、「usi」コマンドで、id nameが、「dlshogi NYUGYOKU_FEATURES」になっているかで確認できます。



また、モデルの訓練時に入玉特徴量を有効にするには、Pythonモジュールのインストール時に有効化が必要です。
環境変数に「NYUGYOKU_FEATURES=1」を追加してインストールしてください。

NYUGYOKU_FEATURES=1 pip install -e .

有効になっているかは、

from dlshogi.common import MAX_FEATURES2_NYUGYOKU_NUM
MAX_FEATURES2_NYUGYOKU_NUM

が31になっているかで確認できます。

まとめ

DL系の開発者の助けになるようにdlshogiの成果物を少し共有しました。
アピール文書で要点は書かれても実装が公開されていない状況だったので、実装を共有することでDL系の開発が活性化することを願います。