TadaoYamaoka's Development Diary

I mainly write about development topics from my personal smartphone apps and shogi AI.

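Trying dropless MoE (Mixture of Experts)

For dlshogi models, accuracy has been confirmed to improve as the parameter count grows, up to a size of 60 blocks with 768 filters.
However, because search speed drops, the playing strength of that model in actual games falls short of the 40-block, 512-filter model.

MoE (Mixture of Experts) has been adopted in LLMs as a technique for increasing the parameter count while keeping inference speed from dropping.

I believe MoE is a promising technique for making deep-learning-based shogi AI stronger.

However, implementing MoE is not simple. To run inference efficiently on a GPU, the standardized PyTorch → ONNX → TensorRT pipeline that dlshogi uses cannot be applied as-is, and a custom implementation is required.
This article organizes the main MoE approaches and describes how to implement training and inference for dropless MoE so that inference runs efficiently on a GPU.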

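MoE (Mixture of Experts)

MoE (Mixture of Experts) is a technique in which a gating mechanism selects, for each input, only a subset of many expert models to compute, which increases model capacity while keeping the amount of computation low.
When MoE is used in Transformer-based LLMs, the mainstream configuration converts the FFN (MLP) part inside each Transformer block into an MoE layer.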
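
As a rough illustration of the capacity-versus-compute trade-off, the sketch below counts the parameters of a dense FFN and of an MoE FFN. The sizes (dim=512, hidden=2048, 8 experts, top_k=2), the bias-free router, and the helper names are assumptions chosen only for the arithmetic; they are not dlshogi settings.

```python
from __future__ import annotations


def ffn_params(dim: int, hidden: int) -> int:
    """Weights and biases of one two-layer FFN (dim -> hidden -> dim)."""
    return dim * hidden + hidden + hidden * dim + dim


def moe_ffn_params(dim: int, hidden: int, num_experts: int, top_k: int) -> tuple[int, int]:
    """Return (total, active-per-token) parameter counts for an MoE FFN.

    The router is assumed to be a bias-free linear layer (dim x num_experts).
    """
    router = dim * num_experts
    expert = ffn_params(dim, hidden)
    total = router + num_experts * expert
    active = router + top_k * expert  # only top_k experts run for each token
    return total, active


dense = ffn_params(512, 2048)
total, active = moe_ffn_params(512, 2048, num_experts=8, top_k=2)
print(dense, total, active)
```

Under these assumptions the MoE layer stores roughly eight times the parameters of the dense FFN, while each token touches only about twice as many.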

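Dense MoE

The simplest way to implement MoE is to compute every expert and take the weighted sum of their outputs.
However, every expert must then be computed for each token, so the benefit of MoE, namely larger model capacity at low computational cost, is lost.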
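
A minimal sketch of the dense variant, using only the Python standard library: the experts are toy scalar functions standing in for FFNs, and `dense_moe` is a hypothetical helper name. It also returns how many experts were evaluated, to make the cost visible.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def dense_moe(x, gate_logits, experts):
    """Evaluate every expert and return the gate-weighted sum of their outputs."""
    gates = softmax(gate_logits)
    outputs = [expert(x) for expert in experts]  # all experts run for this token
    y = sum(g * out for g, out in zip(gates, outputs))
    return y, len(outputs)


# Toy scalar "experts"; real experts would be FFNs.
experts = [lambda v: 1.0 * v, lambda v: 2.0 * v, lambda v: 3.0 * v, lambda v: 4.0 * v]
y, evaluated = dense_moe(1.0, [0.0, 0.0, 0.0, 0.0], experts)
print(y, evaluated)
```

With uniform gate logits every expert contributes equally, and all four are evaluated for the single token.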

Sparse MoE

The gate selects only a small number of experts; only the selected experts are computed, and the computation for the others is skipped.
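
The selection step can be sketched as follows for a single token. This is a standard-library illustration with toy scalar experts; `sparse_moe` is a hypothetical name. Renormalizing the selected gate probabilities so they sum to 1 is one common choice and is assumed here.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sparse_moe(x, gate_logits, experts, top_k=2):
    """Run only the top_k experts and combine them with renormalized gates."""
    probs = softmax(gate_logits)
    chosen = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
    norm = sum(probs[i] for i in chosen)
    y = sum(probs[i] / norm * experts[i](x) for i in chosen)  # other experts are skipped
    return y, chosen


experts = [lambda v: 1.0 * v, lambda v: 2.0 * v, lambda v: 3.0 * v, lambda v: 4.0 * v]
# These logits give gate probabilities 0.1, 0.2, 0.3, 0.4.
logits = [math.log(1.0), math.log(2.0), math.log(3.0), math.log(4.0)]
y, chosen = sparse_moe(1.0, logits, experts, top_k=2)
print(y, chosen)
```

Only two of the four experts are called, so the per-token compute stays fixed as more experts are added.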

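To realize this on a GPU, a technique called capacity-aware routing is used.

capacity-aware routing

Capacity-aware routing is a routing scheme that lets each expert be processed efficiently on the GPU as a fixed-length batch: it places an upper limit, the capacity, on the number of tokens each expert accepts, pads any shortfall, and drops, reassigns, or separately handles any overflow.

This is simple to implement, but tokens may be dropped, so the capacity has to be tuned.
Raising the capacity lowers the drop rate but increases the wasted computation spent on padding, so the two trade off against each other.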
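
The trade-off can be made concrete with a small simulation. The sketch below assumes top-1 routing, a capacity of ceil(capacity_factor * tokens / experts), and the simplest overflow policy (drop); `capacity_route` and the token assignments are hypothetical.

```python
import math


def capacity_route(expert_ids, num_experts, capacity_factor):
    """Assign tokens to fixed-size expert buffers; overflow tokens are dropped."""
    capacity = math.ceil(capacity_factor * len(expert_ids) / num_experts)
    buffers = [[] for _ in range(num_experts)]
    dropped = []
    for token, expert in enumerate(expert_ids):
        if len(buffers[expert]) < capacity:
            buffers[expert].append(token)
        else:
            dropped.append(token)  # this expert's buffer is already full
    padded_slots = sum(capacity - len(buf) for buf in buffers)
    return capacity, dropped, padded_slots


# Top-1 assignments for 8 tokens over 4 experts, skewed toward expert 0.
assignments = [0, 0, 0, 0, 0, 1, 2, 0]
for factor in (1.0, 2.0, 3.0):
    print(factor, capacity_route(assignments, 4, factor))
```

For this skewed input, raising the factor from 1.0 to 3.0 removes all drops but grows the padded slots from 4 to 16.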

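Recently, dropless MoE, which avoids dropping tokens, has become mainstream.

dropless MoE

Dropless MoE is an approach that processes all tokens on the GPU by using grouped GEMM, block-sparse kernels, and similar techniques.

It is implemented as dedicated kernels so that tokens can be processed even when the number of tokens per expert is not a fixed length.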
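
The bookkeeping a variable-length kernel needs can be sketched in plain Python: sort the tokens by expert, then describe each expert's contiguous slice with cumulative end offsets. `group_by_expert` is a hypothetical helper, and the assignments are made-up top-1 routing results.

```python
def group_by_expert(expert_ids, num_experts):
    """Sort token indices by expert and return per-expert counts and end offsets."""
    # sorted() is stable, so tokens keep their original order within each expert.
    order = sorted(range(len(expert_ids)), key=lambda token: expert_ids[token])
    counts = [0] * num_experts
    for expert in expert_ids:
        counts[expert] += 1
    offsets, running = [], 0
    for count in counts:
        running += count
        offsets.append(running)  # expert e owns rows offsets[e-1]:offsets[e]
    return order, counts, offsets


assignments = [0, 0, 0, 0, 0, 1, 2, 0]
order, counts, offsets = group_by_expert(assignments, 4)
print(order, counts, offsets)
```

No token is dropped and no padding is added: expert 0 simply gets a group of six rows and expert 3 an empty group.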

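MegaBlocks is one of the pioneering studies that established dropless MoE as a practical GPU training system.

PyTorch 2.10 added torch.nn.functional.grouped_mm, which makes it possible to implement variable-length GEMM per expert with standard PyTorch.
Internally, grouped_mm is implemented as a dedicated kernel.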
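
To show what a grouped matrix multiply computes, here is a pure-Python reference on nested lists. It assumes the layout this article's implementation relies on: rows are already sorted by expert and the offsets are cumulative end positions. `grouped_matmul_reference` is a hypothetical name for illustration, not PyTorch's API or its kernel.

```python
def matmul(rows, weight):
    """Plain (n x d) @ (d x h) matrix product on nested lists."""
    inner, cols = len(weight), len(weight[0])
    return [[sum(row[k] * weight[k][j] for k in range(inner)) for j in range(cols)] for row in rows]


def grouped_matmul_reference(rows, weights, offsets):
    """Multiply each group rows[start:end] by that group's own weight matrix."""
    out, start = [], 0
    for weight, end in zip(weights, offsets):
        out.extend(matmul(rows[start:end], weight))  # an empty group adds nothing
        start = end
    return out


rows = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
weights = [
    [[1.0, 0.0], [0.0, 1.0]],  # expert 0: identity
    [[2.0, 0.0], [0.0, 2.0]],  # expert 1: doubles (receives no rows here)
    [[0.0, 1.0], [1.0, 0.0]],  # expert 2: swaps the two columns
]
result = grouped_matmul_reference(rows, weights, offsets=[2, 2, 3])
print(result)
```

A dedicated kernel performs the same per-group products in a single launch rather than one GEMM per expert, which is where the efficiency comes from.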

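Challenges for inference

When running inference in TensorRT on a model trained with PyTorch, the standard approach is to export the model to ONNX and load it into TensorRT.
However, ONNX has no operator equivalent to grouped_mm, so the standard approach cannot be used.

Instead, the operation has to be exported as an ONNX custom operator and implemented as a TensorRT custom plugin.

Inside the custom plugin, the grouped_mm-equivalent computation can be implemented with the grouped GEMM provided by the CUTLASS library.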

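Implementing dropless MoE

I implemented dropless MoE in PyTorch using grouped_mm and verified whether it can actually be trained.

The verification uses simple code that trains a Swin Transformer on the CIFAR-10 dataset.

First, the implementation without MoE is shown below.
It is implemented with PyTorch Lightning.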

train.py
from lightning.pytorch.cli import LightningCLI

from data import CIFAR10DataModule
from model import SwinCIFAR10Classifier


def main() -> None:
    LightningCLI(
        model_class=SwinCIFAR10Classifier,
        datamodule_class=CIFAR10DataModule,
        save_config_kwargs={"overwrite": True},
    )


if __name__ == "__main__":
    main()
data.py
from __future__ import annotations

from pathlib import Path

import lightning.pytorch as pl
import torch
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms


class CIFAR10DataModule(pl.LightningDataModule):
    def __init__(
        self,
        data_dir: str = "data",
        batch_size: int = 128,
        num_workers: int = 4,
        image_size: int = 32,
        val_split: int = 5000,
        seed: int = 42,
        download: bool = True,
    ) -> None:
        super().__init__()
        self.data_dir = Path(data_dir)
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.image_size = image_size
        self.val_split = val_split
        self.seed = seed
        self.download = download

        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def prepare_data(self) -> None:
        datasets.CIFAR10(self.data_dir, train=True, download=self.download)
        datasets.CIFAR10(self.data_dir, train=False, download=self.download)

    def setup(self, stage: str | None = None) -> None:
        train_transform = transforms.Compose(
            [
                transforms.Resize((self.image_size, self.image_size)),
                transforms.RandomCrop(self.image_size, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
            ]
        )
        eval_transform = transforms.Compose(
            [
                transforms.Resize((self.image_size, self.image_size)),
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
            ]
        )

        if stage in (None, "fit"):
            train_base = datasets.CIFAR10(self.data_dir, train=True, transform=train_transform)
            val_base = datasets.CIFAR10(self.data_dir, train=True, transform=eval_transform)
            train_size = len(train_base) - self.val_split
            indices = torch.randperm(len(train_base), generator=torch.Generator().manual_seed(self.seed)).tolist()
            self.train_dataset = Subset(train_base, indices[:train_size])
            self.val_dataset = Subset(val_base, indices[train_size:])

        if stage in (None, "test"):
            self.test_dataset = datasets.CIFAR10(self.data_dir, train=False, transform=eval_transform)

    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=self.num_workers > 0,
        )

    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=self.num_workers > 0,
        )

    def test_dataloader(self) -> DataLoader:
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=self.num_workers > 0,
        )
model.py
from __future__ import annotations

from typing import Sequence

import lightning.pytorch as pl
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from torchmetrics.classification import MulticlassAccuracy


def window_partition(x: Tensor, window_size: int) -> Tensor:
    batch_size, height, width, channels = x.shape
    x = x.view(batch_size, height // window_size, window_size, width // window_size, window_size, channels)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return windows.view(-1, window_size, window_size, channels)


def window_reverse(windows: Tensor, window_size: int, height: int, width: int) -> Tensor:
    batch_size = int(windows.shape[0] / (height * width / window_size / window_size))
    x = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return x.view(batch_size, height, width, -1)


class DropPath(nn.Module):
    def __init__(self, drop_prob: float = 0.0) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: Tensor) -> Tensor:
        if self.drop_prob == 0.0 or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()
        return x.div(keep_prob) * random_tensor


class Mlp(nn.Module):
    def __init__(self, in_features: int, hidden_features: int, drop: float = 0.0) -> None:
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class WindowAttention(nn.Module):
    def __init__(self, dim: int, window_size: int, num_heads: int, qkv_bias: bool = True, attn_drop: float = 0.0, proj_drop: float = 0.0) -> None:
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        table_size = (2 * window_size - 1) * (2 * window_size - 1)
        self.relative_position_bias_table = nn.Parameter(torch.zeros(table_size, num_heads))

        coords = torch.stack(torch.meshgrid(torch.arange(window_size), torch.arange(window_size), indexing="ij"))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += window_size - 1
        relative_coords[:, :, 1] += window_size - 1
        relative_coords[:, :, 0] *= 2 * window_size - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index, persistent=False)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

    def forward(self, x: Tensor, mask: Tensor | None = None) -> Tensor:
        batch_windows, tokens, channels = x.shape
        qkv = self.qkv(x).reshape(batch_windows, tokens, 3, self.num_heads, channels // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        query, key, value = qkv[0], qkv[1], qkv[2]

        query = query * self.scale
        attn = query @ key.transpose(-2, -1)
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
        relative_position_bias = relative_position_bias.view(tokens, tokens, -1).permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            num_windows = mask.shape[0]
            attn = attn.view(batch_windows // num_windows, num_windows, self.num_heads, tokens, tokens)
            attn = attn + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, tokens, tokens)

        attn = F.softmax(attn, dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ value).transpose(1, 2).reshape(batch_windows, tokens, channels)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class SwinTransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        input_resolution: tuple[int, int],
        num_heads: int,
        window_size: int = 4,
        shift_size: int = 0,
        mlp_ratio: float = 4.0,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        drop_path: float = 0.0,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.window_size = min(window_size, input_resolution[0], input_resolution[1])
        self.shift_size = 0 if min(input_resolution) <= self.window_size else shift_size

        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, self.window_size, num_heads, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = Mlp(dim, int(dim * mlp_ratio), drop=drop)

        self.register_buffer("attn_mask", self._create_mask(), persistent=False)

    def _create_mask(self) -> Tensor | None:
        if self.shift_size == 0:
            return None

        height, width = self.input_resolution
        img_mask = torch.zeros((1, height, width, 1))
        height_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None))
        width_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None))
        count = 0
        for height_slice in height_slices:
            for width_slice in width_slices:
                img_mask[:, height_slice, width_slice, :] = count
                count += 1

        mask_windows = window_partition(img_mask, self.window_size).view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

    def forward(self, x: Tensor) -> Tensor:
        height, width = self.input_resolution
        batch_size, length, channels = x.shape
        if length != height * width:
            raise ValueError(f"Expected token length {height * width}, got {length}.")

        shortcut = x
        x = self.norm1(x).view(batch_size, height, width, channels)

        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        x_windows = window_partition(shifted_x, self.window_size).view(-1, self.window_size * self.window_size, channels)
        attn_windows = self.attn(x_windows, mask=self.attn_mask)
        shifted_x = window_reverse(attn_windows.view(-1, self.window_size, self.window_size, channels), self.window_size, height, width)

        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        x = x.view(batch_size, height * width, channels)
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class PatchMerging(nn.Module):
    def __init__(self, input_resolution: tuple[int, int], dim: int) -> None:
        super().__init__()
        self.input_resolution = input_resolution
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = nn.LayerNorm(4 * dim)

    def forward(self, x: Tensor) -> Tensor:
        height, width = self.input_resolution
        batch_size, length, channels = x.shape
        if length != height * width:
            raise ValueError(f"Expected token length {height * width}, got {length}.")
        if height % 2 != 0 or width % 2 != 0:
            raise ValueError("PatchMerging requires even spatial dimensions.")

        x = x.view(batch_size, height, width, channels)
        x0 = x[:, 0::2, 0::2, :]
        x1 = x[:, 1::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, :]
        x3 = x[:, 1::2, 1::2, :]
        x = torch.cat([x0, x1, x2, x3], dim=-1).view(batch_size, -1, 4 * channels)
        x = self.norm(x)
        return self.reduction(x)


class BasicLayer(nn.Module):
    def __init__(
        self,
        dim: int,
        input_resolution: tuple[int, int],
        depth: int,
        num_heads: int,
        window_size: int,
        mlp_ratio: float,
        drop: float,
        attn_drop: float,
        drop_path: Sequence[float],
        downsample: bool,
    ) -> None:
        super().__init__()
        self.blocks = nn.ModuleList(
            [
                SwinTransformerBlock(
                    dim=dim,
                    input_resolution=input_resolution,
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=0 if index % 2 == 0 else window_size // 2,
                    mlp_ratio=mlp_ratio,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path[index],
                )
                for index in range(depth)
            ]
        )
        self.downsample = PatchMerging(input_resolution, dim) if downsample else None

    def forward(self, x: Tensor) -> Tensor:
        for block in self.blocks:
            x = block(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x


class PatchEmbed(nn.Module):
    def __init__(self, image_size: int, patch_size: int, in_channels: int, embed_dim: int) -> None:
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.grid_size = image_size // patch_size
        self.num_patches = self.grid_size * self.grid_size
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x: Tensor) -> Tensor:
        _, _, height, width = x.shape
        if height != self.image_size or width != self.image_size:
            raise ValueError(f"Expected input size {self.image_size}x{self.image_size}, got {height}x{width}.")
        x = self.proj(x).flatten(2).transpose(1, 2)
        return self.norm(x)


class SwinTransformer(nn.Module):
    def __init__(
        self,
        image_size: int = 32,
        patch_size: int = 4,
        in_channels: int = 3,
        num_classes: int = 10,
        embed_dim: int = 64,
        depths: Sequence[int] = (2, 2, 2, 2),
        num_heads: Sequence[int] = (2, 4, 8, 16),
        window_size: int = 4,
        mlp_ratio: float = 4.0,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.1,
    ) -> None:
        super().__init__()
        if image_size % patch_size != 0:
            raise ValueError("image_size must be divisible by patch_size.")
        if len(depths) != len(num_heads):
            raise ValueError("depths and num_heads must have the same length.")

        self.num_layers = len(depths)
        self.num_features = embed_dim * 2 ** (self.num_layers - 1)
        self.patch_embed = PatchEmbed(image_size, patch_size, in_channels, embed_dim)
        self.pos_drop = nn.Dropout(drop_rate)

        total_depth = sum(depths)
        drop_paths = torch.linspace(0, drop_path_rate, total_depth).tolist()

        self.layers = nn.ModuleList()
        resolution = image_size // patch_size
        depth_offset = 0
        for layer_index in range(self.num_layers):
            dim = embed_dim * 2**layer_index
            input_resolution = (resolution, resolution)
            layer = BasicLayer(
                dim=dim,
                input_resolution=input_resolution,
                depth=depths[layer_index],
                num_heads=num_heads[layer_index],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=drop_paths[depth_offset : depth_offset + depths[layer_index]],
                downsample=layer_index < self.num_layers - 1,
            )
            self.layers.append(layer)
            depth_offset += depths[layer_index]
            if layer_index < self.num_layers - 1:
                resolution //= 2

        self.norm = nn.LayerNorm(self.num_features)
        self.head = nn.Linear(self.num_features, num_classes)
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        if isinstance(module, nn.Linear):
            nn.init.trunc_normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, x: Tensor) -> Tensor:
        x = self.patch_embed(x)
        x = self.pos_drop(x)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        x = x.mean(dim=1)
        return self.head(x)


class SwinCIFAR10Classifier(pl.LightningModule):
    def __init__(
        self,
        num_classes: int = 10,
        image_size: int = 32,
        patch_size: int = 4,
        in_channels: int = 3,
        embed_dim: int = 64,
        depths: Sequence[int] = (2, 2, 2, 2),
        num_heads: Sequence[int] = (2, 4, 8, 16),
        window_size: int = 4,
        mlp_ratio: float = 4.0,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.1,
        learning_rate: float = 0.001,
        weight_decay: float = 0.05,
    ) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.model = SwinTransformer(
            image_size=image_size,
            patch_size=patch_size,
            in_channels=in_channels,
            num_classes=num_classes,
            embed_dim=embed_dim,
            depths=depths,
            num_heads=num_heads,
            window_size=window_size,
            mlp_ratio=mlp_ratio,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
        )
        self.criterion = nn.CrossEntropyLoss()
        self.train_acc = MulticlassAccuracy(num_classes=num_classes)
        self.val_acc = MulticlassAccuracy(num_classes=num_classes)
        self.test_acc = MulticlassAccuracy(num_classes=num_classes)

    def forward(self, x: Tensor) -> Tensor:
        return self.model(x)

    def _shared_step(self, batch: tuple[Tensor, Tensor], stage: str) -> Tensor:
        images, targets = batch
        logits = self(images)
        loss = self.criterion(logits, targets)
        preds = torch.argmax(logits, dim=1)
        metric = getattr(self, f"{stage}_acc")
        metric(preds, targets)
        self.log(f"{stage}_loss", loss, prog_bar=True, on_step=stage == "train", on_epoch=True)
        self.log(f"{stage}_acc", metric, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def training_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        return self._shared_step(batch, "train")

    def validation_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        return self._shared_step(batch, "val")

    def test_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        return self._shared_step(batch, "test")

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay,
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.trainer.max_epochs)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}
config.yaml
seed_everything: 42

trainer:
  accelerator: auto
  devices: auto
  max_epochs: 20
  precision: bf16-mixed
  log_every_n_steps: 50
  callbacks:
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        monitor: val_acc
        mode: max
        save_top_k: 1
        filename: swin-cifar10-{epoch:02d}-{val_acc:.4f}
    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
      init_args:
        logging_interval: epoch

model:
  num_classes: 10
  image_size: 32
  patch_size: 4
  in_channels: 3
  embed_dim: 64
  depths: [2, 2, 2, 2]
  num_heads: [2, 4, 8, 16]
  window_size: 4
  mlp_ratio: 4.0
  drop_rate: 0.0
  attn_drop_rate: 0.0
  drop_path_rate: 0.1
  learning_rate: 0.001
  weight_decay: 0.05

data:
  data_dir: data
  batch_size: 128
  num_workers: 4
  image_size: 32
  val_split: 5000
  seed: 42
  download: true


Next, the dropless MoE implementation using torch.nn.functional.grouped_mm is shown below.

model.py
from __future__ import annotations

from typing import Sequence

import lightning.pytorch as pl
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from torchmetrics.classification import MulticlassAccuracy


def window_partition(x: Tensor, window_size: int) -> Tensor:
    batch_size, height, width, channels = x.shape
    x = x.view(batch_size, height // window_size, window_size, width // window_size, window_size, channels)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return windows.view(-1, window_size, window_size, channels)


def window_reverse(windows: Tensor, window_size: int, height: int, width: int) -> Tensor:
    batch_size = int(windows.shape[0] / (height * width / window_size / window_size))
    x = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return x.view(batch_size, height, width, -1)


class DropPath(nn.Module):
    def __init__(self, drop_prob: float = 0.0) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: Tensor) -> Tensor:
        if self.drop_prob == 0.0 or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()
        return x.div(keep_prob) * random_tensor


class Mlp(nn.Module):
    def __init__(self, in_features: int, hidden_features: int, drop: float = 0.0) -> None:
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class DroplessMoEMlp(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        num_experts: int,
        top_k: int = 2,
        drop: float = 0.0,
    ) -> None:
        super().__init__()
        if num_experts <= 0:
            raise ValueError("num_experts must be greater than 0.")
        if top_k <= 0 or top_k > num_experts:
            raise ValueError("top_k must be in the range [1, num_experts].")

        self.in_features = in_features
        self.hidden_features = hidden_features
        self.num_experts = num_experts
        self.top_k = top_k

        self.router = nn.Linear(in_features, num_experts, bias=False)
        # grouped_mm operand layout avoids per-forward transpose: x @ w1 maps D -> H, h @ w2 maps H -> D.
        self.w1 = nn.Parameter(torch.empty(num_experts, in_features, hidden_features))
        self.b1 = nn.Parameter(torch.zeros(num_experts, hidden_features))
        self.w2 = nn.Parameter(torch.empty(num_experts, hidden_features, in_features))
        self.b2 = nn.Parameter(torch.zeros(num_experts, in_features))
        self.act = nn.GELU()
        self.drop = nn.Dropout(drop)
        self.aux_loss: Tensor | None = None
        self.z_loss: Tensor | None = None

        nn.init.trunc_normal_(self.w1, std=0.02)
        nn.init.trunc_normal_(self.w2, std=0.02)

    def _apply(self, fn):
        module = super()._apply(fn)
        self.router.float()
        return module

    def forward(self, x: Tensor) -> Tensor:
        if not x.is_cuda:
            raise RuntimeError("DroplessMoEMlp requires CUDA tensors because it uses torch.nn.functional.grouped_mm.")
        if not hasattr(F, "grouped_mm"):
            raise RuntimeError("torch.nn.functional.grouped_mm is not available in this PyTorch build.")
        if self.router.weight.dtype != torch.float32:
            raise RuntimeError("Router weights must remain FP32 for stable routing.")

        orig_shape = x.shape
        x = x.reshape(-1, self.in_features)
        tokens = x.shape[0]

        # Keep routing numerics in FP32; expert grouped_mm below still uses the activation dtype.
        with torch.autocast(device_type="cuda", enabled=False):
            logits = self.router(x.float())
            probs = F.softmax(logits, dim=-1)
            topk_prob, topk_expert = torch.topk(probs, k=self.top_k, dim=-1)
            topk_gate = topk_prob / topk_prob.sum(dim=-1, keepdim=True)

        flat_expert = topk_expert.reshape(-1)
        flat_gate = topk_gate.reshape(-1)
        flat_token = torch.arange(tokens, device=x.device).repeat_interleave(self.top_k)

        order = torch.argsort(flat_expert)
        expert_sorted = flat_expert[order]
        token_sorted = flat_token[order]
        gate_sorted = flat_gate[order]
        x_sorted = x.index_select(0, token_sorted)

        counts = torch.bincount(expert_sorted, minlength=self.num_experts)
        offsets = torch.cumsum(counts, dim=0).to(torch.int32)
        tokens_per_expert = counts.to(probs.dtype) / counts.sum().to(probs.dtype)
        prob_per_expert = probs.mean(dim=0)
        self.aux_loss = self.num_experts * torch.sum(tokens_per_expert * prob_per_expert)
        self.z_loss = torch.mean(torch.logsumexp(logits, dim=-1).square())

        w1 = self.w1 if self.w1.dtype == x.dtype else self.w1.to(dtype=x.dtype)

        # grouped_mm requires offs[-1] < mat_a.shape[0], so append one ignored row.
        padding = torch.zeros(1, x_sorted.shape[-1], dtype=x_sorted.dtype, device=x_sorted.device)
        x_grouped = torch.cat((x_sorted, padding), dim=0)
        h = F.grouped_mm(x_grouped, w1, offs=offsets)[: x_sorted.shape[0]]
        h = h + self.b1.index_select(0, expert_sorted).to(dtype=h.dtype)
        h = self.act(h)
        h = self.drop(h)

        w2 = self.w2 if self.w2.dtype == h.dtype else self.w2.to(dtype=h.dtype)

        # grouped_mm requires offs[-1] < mat_a.shape[0], so append one ignored row.
        hidden_padding = torch.zeros(1, h.shape[-1], dtype=h.dtype, device=h.device)
        h_grouped = torch.cat((h, hidden_padding), dim=0)
        y_sorted = F.grouped_mm(h_grouped, w2, offs=offsets)[: h.shape[0]]
        y_sorted = y_sorted + self.b2.index_select(0, expert_sorted).to(dtype=y_sorted.dtype)
        y_sorted = self.drop(y_sorted)
        y_sorted = y_sorted * gate_sorted.unsqueeze(-1).to(dtype=y_sorted.dtype)

        y = torch.zeros_like(x)
        y.index_add_(0, token_sorted, y_sorted.to(dtype=y.dtype))
        return y.reshape(orig_shape)


class WindowAttention(nn.Module):
    def __init__(self, dim: int, window_size: int, num_heads: int, qkv_bias: bool = True, attn_drop: float = 0.0, proj_drop: float = 0.0) -> None:
        super().__init__()
        if num_heads <= 0:
            raise ValueError("num_heads must be greater than 0.")
        if dim % num_heads != 0:
            raise ValueError("dim must be divisible by num_heads.")
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        table_size = (2 * window_size - 1) * (2 * window_size - 1)
        self.relative_position_bias_table = nn.Parameter(torch.zeros(table_size, num_heads))

        coords = torch.stack(torch.meshgrid(torch.arange(window_size), torch.arange(window_size), indexing="ij"))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += window_size - 1
        relative_coords[:, :, 1] += window_size - 1
        relative_coords[:, :, 0] *= 2 * window_size - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index, persistent=False)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

    def forward(self, x: Tensor, mask: Tensor | None = None) -> Tensor:
        batch_windows, tokens, channels = x.shape
        qkv = self.qkv(x).reshape(batch_windows, tokens, 3, self.num_heads, channels // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        query, key, value = qkv[0], qkv[1], qkv[2]

        query = query * self.scale
        attn = query @ key.transpose(-2, -1)
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
        relative_position_bias = relative_position_bias.view(tokens, tokens, -1).permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            num_windows = mask.shape[0]
            attn = attn.view(batch_windows // num_windows, num_windows, self.num_heads, tokens, tokens)
            attn = attn + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, tokens, tokens)

        attn = F.softmax(attn, dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ value).transpose(1, 2).reshape(batch_windows, tokens, channels)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class SwinTransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        input_resolution: tuple[int, int],
        num_heads: int,
        window_size: int = 4,
        shift_size: int = 0,
        mlp_ratio: float = 4.0,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        drop_path: float = 0.0,
        use_moe_mlp: bool = False,
        moe_num_experts: int = 4,
        moe_top_k: int = 2,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.window_size = min(window_size, input_resolution[0], input_resolution[1])
        if input_resolution[0] % self.window_size != 0 or input_resolution[1] % self.window_size != 0:
            raise ValueError("input_resolution must be divisible by window_size, or padding is required.")
        self.shift_size = 0 if min(input_resolution) <= self.window_size else shift_size

        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, self.window_size, num_heads, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = nn.LayerNorm(dim)
        hidden_features = int(dim * mlp_ratio)
        if use_moe_mlp:
            self.mlp = DroplessMoEMlp(dim, hidden_features, num_experts=moe_num_experts, top_k=moe_top_k, drop=drop)
        else:
            self.mlp = Mlp(dim, hidden_features, drop=drop)

        self.register_buffer("attn_mask", self._create_mask(), persistent=False)

    def _create_mask(self) -> Tensor | None:
        # Additive attention mask for shifted windows: token pairs from different regions get -100.
        if self.shift_size == 0:
            return None

        height, width = self.input_resolution
        img_mask = torch.zeros((1, height, width, 1))
        height_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None))
        width_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None))
        count = 0
        for height_slice in height_slices:
            for width_slice in width_slices:
                img_mask[:, height_slice, width_slice, :] = count
                count += 1

        mask_windows = window_partition(img_mask, self.window_size).view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

    def forward(self, x: Tensor) -> Tensor:
        height, width = self.input_resolution
        batch_size, length, channels = x.shape
        if length != height * width:
            raise ValueError(f"Expected token length {height * width}, got {length}.")

        shortcut = x
        x = self.norm1(x).view(batch_size, height, width, channels)

        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        x_windows = window_partition(shifted_x, self.window_size).view(-1, self.window_size * self.window_size, channels)
        attn_windows = self.attn(x_windows, mask=self.attn_mask)
        shifted_x = window_reverse(attn_windows.view(-1, self.window_size, self.window_size, channels), self.window_size, height, width)

        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        x = x.view(batch_size, height * width, channels)
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class PatchMerging(nn.Module):
    def __init__(self, input_resolution: tuple[int, int], dim: int) -> None:
        super().__init__()
        self.input_resolution = input_resolution
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = nn.LayerNorm(4 * dim)

    def forward(self, x: Tensor) -> Tensor:
        height, width = self.input_resolution
        batch_size, length, channels = x.shape
        if length != height * width:
            raise ValueError(f"Expected token length {height * width}, got {length}.")
        if height % 2 != 0 or width % 2 != 0:
            raise ValueError("PatchMerging requires even spatial dimensions.")

        x = x.view(batch_size, height, width, channels)
        x0 = x[:, 0::2, 0::2, :]
        x1 = x[:, 1::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, :]
        x3 = x[:, 1::2, 1::2, :]
        x = torch.cat([x0, x1, x2, x3], dim=-1).view(batch_size, -1, 4 * channels)
        x = self.norm(x)
        return self.reduction(x)


class BasicLayer(nn.Module):
    def __init__(
        self,
        dim: int,
        input_resolution: tuple[int, int],
        depth: int,
        num_heads: int,
        window_size: int,
        mlp_ratio: float,
        drop: float,
        attn_drop: float,
        drop_path: Sequence[float],
        downsample: bool,
        use_moe_mlp: bool,
        moe_num_experts: int,
        moe_top_k: int,
    ) -> None:
        super().__init__()
        self.blocks = nn.ModuleList(
            [
                SwinTransformerBlock(
                    dim=dim,
                    input_resolution=input_resolution,
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=0 if index % 2 == 0 else window_size // 2,
                    mlp_ratio=mlp_ratio,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path[index],
                    use_moe_mlp=use_moe_mlp,
                    moe_num_experts=moe_num_experts,
                    moe_top_k=moe_top_k,
                )
                for index in range(depth)
            ]
        )
        self.downsample = PatchMerging(input_resolution, dim) if downsample else None

    def moe_loss(self) -> Tensor | None:
        losses = [block.mlp.aux_loss for block in self.blocks if isinstance(block.mlp, DroplessMoEMlp) and block.mlp.aux_loss is not None]
        if not losses:
            return None
        return torch.stack(losses).mean()

    def moe_z_loss(self) -> Tensor | None:
        losses = [block.mlp.z_loss for block in self.blocks if isinstance(block.mlp, DroplessMoEMlp) and block.mlp.z_loss is not None]
        if not losses:
            return None
        return torch.stack(losses).mean()

    def forward(self, x: Tensor) -> Tensor:
        for block in self.blocks:
            x = block(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x


class PatchEmbed(nn.Module):
    def __init__(self, image_size: int, patch_size: int, in_channels: int, embed_dim: int) -> None:
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.grid_size = image_size // patch_size
        self.num_patches = self.grid_size * self.grid_size
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x: Tensor) -> Tensor:
        _, _, height, width = x.shape
        if height != self.image_size or width != self.image_size:
            raise ValueError(f"Expected input size {self.image_size}x{self.image_size}, got {height}x{width}.")
        x = self.proj(x).flatten(2).transpose(1, 2)
        return self.norm(x)


class SwinTransformer(nn.Module):
    def __init__(
        self,
        image_size: int = 32,
        patch_size: int = 4,
        in_channels: int = 3,
        num_classes: int = 10,
        embed_dim: int = 64,
        depths: Sequence[int] = (2, 2, 2, 2),
        num_heads: Sequence[int] = (2, 4, 8, 16),
        window_size: int = 4,
        mlp_ratio: float = 4.0,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.1,
        moe_stages: Sequence[int] = (0, 1),
        moe_num_experts: int = 4,
        moe_top_k: int = 2,
    ) -> None:
        super().__init__()
        if image_size % patch_size != 0:
            raise ValueError("image_size must be divisible by patch_size.")
        if len(depths) != len(num_heads):
            raise ValueError("depths and num_heads must have the same length.")
        moe_stage_set = set(moe_stages)
        invalid_moe_stages = moe_stage_set.difference(range(len(depths)))
        if invalid_moe_stages:
            raise ValueError(f"moe_stages contains invalid stage indices: {sorted(invalid_moe_stages)}.")

        self.num_layers = len(depths)
        self.num_features = embed_dim * 2 ** (self.num_layers - 1)
        self.patch_embed = PatchEmbed(image_size, patch_size, in_channels, embed_dim)
        self.pos_drop = nn.Dropout(drop_rate)

        total_depth = sum(depths)
        drop_paths = torch.linspace(0, drop_path_rate, total_depth).tolist()

        self.layers = nn.ModuleList()
        resolution = image_size // patch_size
        depth_offset = 0
        for layer_index in range(self.num_layers):
            dim = embed_dim * 2**layer_index
            input_resolution = (resolution, resolution)
            layer = BasicLayer(
                dim=dim,
                input_resolution=input_resolution,
                depth=depths[layer_index],
                num_heads=num_heads[layer_index],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=drop_paths[depth_offset : depth_offset + depths[layer_index]],
                downsample=layer_index < self.num_layers - 1,
                use_moe_mlp=layer_index in moe_stage_set,
                moe_num_experts=moe_num_experts,
                moe_top_k=moe_top_k,
            )
            self.layers.append(layer)
            depth_offset += depths[layer_index]
            if layer_index < self.num_layers - 1:
                resolution //= 2

        self.norm = nn.LayerNorm(self.num_features)
        self.head = nn.Linear(self.num_features, num_classes)
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        if isinstance(module, nn.Linear):
            nn.init.trunc_normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, x: Tensor) -> Tensor:
        x = self.patch_embed(x)
        x = self.pos_drop(x)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        # Global average pooling over the token dimension.
        x = x.mean(dim=1)
        return self.head(x)

    def moe_loss(self) -> Tensor | None:
        losses = [layer.moe_loss() for layer in self.layers]
        losses = [loss for loss in losses if loss is not None]
        if not losses:
            return None
        return torch.stack(losses).mean()

    def moe_z_loss(self) -> Tensor | None:
        losses = [layer.moe_z_loss() for layer in self.layers]
        losses = [loss for loss in losses if loss is not None]
        if not losses:
            return None
        return torch.stack(losses).mean()


class SwinCIFAR10Classifier(pl.LightningModule):
    def __init__(
        self,
        num_classes: int = 10,
        image_size: int = 32,
        patch_size: int = 4,
        in_channels: int = 3,
        embed_dim: int = 64,
        depths: Sequence[int] = (2, 2, 2, 2),
        num_heads: Sequence[int] = (2, 4, 8, 16),
        window_size: int = 4,
        mlp_ratio: float = 4.0,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.1,
        moe_stages: Sequence[int] = (0, 1),
        moe_num_experts: int = 4,
        moe_top_k: int = 2,
        moe_aux_loss_weight: float = 0.01,
        moe_z_loss_weight: float = 0.001,
        learning_rate: float = 0.001,
        weight_decay: float = 0.05,
    ) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.model = SwinTransformer(
            image_size=image_size,
            patch_size=patch_size,
            in_channels=in_channels,
            num_classes=num_classes,
            embed_dim=embed_dim,
            depths=depths,
            num_heads=num_heads,
            window_size=window_size,
            mlp_ratio=mlp_ratio,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            moe_stages=moe_stages,
            moe_num_experts=moe_num_experts,
            moe_top_k=moe_top_k,
        )
        self.criterion = nn.CrossEntropyLoss()
        self.train_acc = MulticlassAccuracy(num_classes=num_classes)
        self.val_acc = MulticlassAccuracy(num_classes=num_classes)
        self.test_acc = MulticlassAccuracy(num_classes=num_classes)

    def forward(self, x: Tensor) -> Tensor:
        return self.model(x)

    def _shared_step(self, batch: tuple[Tensor, Tensor], stage: str) -> Tensor:
        images, targets = batch
        logits = self(images)
        loss = self.criterion(logits, targets)
        # Auxiliary MoE losses are logged at every stage but added to the objective only during training.
        moe_aux_loss = self.model.moe_loss()
        if moe_aux_loss is not None:
            self.log(f"{stage}_moe_aux_loss", moe_aux_loss, prog_bar=False, on_step=stage == "train", on_epoch=True)
            if stage == "train" and self.hparams.moe_aux_loss_weight > 0.0:
                loss = loss + self.hparams.moe_aux_loss_weight * moe_aux_loss
        moe_z_loss = self.model.moe_z_loss()
        if moe_z_loss is not None:
            self.log(f"{stage}_moe_z_loss", moe_z_loss, prog_bar=False, on_step=stage == "train", on_epoch=True)
            if stage == "train" and self.hparams.moe_z_loss_weight > 0.0:
                loss = loss + self.hparams.moe_z_loss_weight * moe_z_loss
        preds = torch.argmax(logits, dim=1)
        metric = getattr(self, f"{stage}_acc")
        metric(preds, targets)
        self.log(f"{stage}_loss", loss, prog_bar=True, on_step=stage == "train", on_epoch=True)
        self.log(f"{stage}_acc", metric, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def training_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        return self._shared_step(batch, "train")

    def validation_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        return self._shared_step(batch, "val")

    def test_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        return self._shared_step(batch, "test")

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay,
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.trainer.max_epochs)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}
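
As a quick check of what `configure_optimizers` produces, the sketch below evaluates the closed-form cosine annealing schedule documented for `CosineAnnealingLR`. It assumes that `eta_min` keeps its default of 0, that the scheduler is stepped once per epoch (Lightning's default interval), and that `max_epochs` is 20 and `learning_rate` is 0.001 as in config.yaml. The helper name `cosine_lr` is illustrative only.

```python
import math


def cosine_lr(epoch: int, base_lr: float = 0.001, t_max: int = 20, eta_min: float = 0.0) -> float:
    """Closed form: eta_min + 0.5 * (base_lr - eta_min) * (1 + cos(pi * epoch / t_max))."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * epoch / t_max))


# Learning rate at the start, the midpoint, and the end of training.
schedule = {epoch: cosine_lr(epoch) for epoch in (0, 10, 20)}
print(schedule)
```

Under these assumptions the learning rate starts at 0.001, is halved at epoch 10, and reaches zero at epoch 20.
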
config.yaml
seed_everything: 42

trainer:
  accelerator: auto
  devices: auto
  max_epochs: 20
  precision: bf16-mixed
  callbacks:
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        monitor: val_acc
        mode: max
        save_top_k: 1
        filename: swin-cifar10-{epoch:02d}-{val_acc:.4f}
    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
      init_args:
        logging_interval: epoch

model:
  num_classes: 10
  image_size: 32
  patch_size: 4
  in_channels: 3
  embed_dim: 64
  depths: [2, 2, 2, 2]
  num_heads: [2, 4, 8, 16]
  window_size: 4
  mlp_ratio: 4.0
  drop_rate: 0.0
  attn_drop_rate: 0.0
  drop_path_rate: 0.1
  moe_stages: [0, 1]
  moe_num_experts: 4
  moe_top_k: 2
  moe_aux_loss_weight: 0.01
  moe_z_loss_weight: 0.001
  learning_rate: 0.001
  weight_decay: 0.05

data:
  data_dir: data
  batch_size: 128
  num_workers: 4
  image_size: 32
  val_split: 5000
  seed: 42
  download: true
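
The weights `moe_aux_loss_weight` and `moe_z_loss_weight` scale the router regularizers that `_shared_step` adds to the cross-entropy loss. As a rough numeric illustration, the sketch below assumes the commonly used formulations: a Switch Transformer-style load-balancing loss and an ST-MoE-style router z-loss. Whether `DroplessMoEMlp` computes `aux_loss` and `z_loss` exactly this way is an assumption, and the function names are hypothetical. Plain Python lists are used for readability; training code would use tensor operations.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [v / total for v in exps]


def logsumexp(logits: list[float]) -> float:
    peak = max(logits)
    return peak + math.log(sum(math.exp(v - peak) for v in logits))


def router_losses(router_logits: list[list[float]], top_k: int) -> tuple[float, float]:
    """Return (load-balancing loss, z-loss) for logits shaped [tokens][experts]."""
    num_tokens = len(router_logits)
    num_experts = len(router_logits[0])
    probs = [softmax(row) for row in router_logits]

    # f_e: fraction of the (tokens * top_k) routed slots that go to expert e.
    load = [0.0] * num_experts
    for row in probs:
        chosen = sorted(range(num_experts), key=lambda e: row[e], reverse=True)[:top_k]
        for e in chosen:
            load[e] += 1.0 / (num_tokens * top_k)

    # P_e: mean router probability given to expert e.
    importance = [sum(row[e] for row in probs) / num_tokens for e in range(num_experts)]

    # Load-balancing loss: num_experts * sum_e f_e * P_e (1.0 when the probabilities are uniform).
    aux_loss = num_experts * sum(f * p for f, p in zip(load, importance))

    # Router z-loss: mean squared log-sum-exp of the logits, which penalizes large logits.
    z_loss = sum(logsumexp(row) ** 2 for row in router_logits) / num_tokens
    return aux_loss, z_loss


# Assumed example: 4 tokens, 4 experts, all logits zero (a uniform router).
aux_loss, z_loss = router_losses([[0.0] * 4 for _ in range(4)], top_k=2)
cross_entropy = 1.0  # placeholder value, not a measured result
total_loss = cross_entropy + 0.01 * aux_loss + 0.001 * z_loss
print(aux_loss, z_loss, total_loss)
```

With the weights from config.yaml, both regularizers stay small compared with the classification loss, so they nudge the router without dominating the objective.
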

Comparison results

Parameter count

Without MoE: 9.1 M
With MoE: 10.1 M

Evaluation loss


Legend: orange is without MoE, blue is with MoE.

Evaluation accuracy


Legend: orange is without MoE, blue is with MoE.

Training time



Legend: orange is without MoE, blue is with MoE.

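Discussion

Evaluation accuracy is higher with MoE than without it.
The parameter count with MoE is about 1.1 times that of the baseline (10.1 M versus 9.1 M).
Training time with MoE was about 1.4 times longer.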
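
The difference of about 1.0 M parameters is consistent with a back-of-the-envelope count. The sketch below assumes that each expert is a two-layer FFN with biases, of the same shape as the dense `Mlp`, and that the router is a bias-free linear layer from `dim` to `moe_num_experts`. Both are assumptions about `DroplessMoEMlp`, and the helper names are illustrative.

```python
def dense_mlp_params(dim: int, hidden: int) -> int:
    # fc1 (dim -> hidden) and fc2 (hidden -> dim), each with a bias vector.
    return dim * hidden + hidden + hidden * dim + dim


def moe_mlp_params(dim: int, hidden: int, num_experts: int) -> int:
    # num_experts copies of the FFN plus an assumed bias-free router (dim -> num_experts).
    return num_experts * dense_mlp_params(dim, hidden) + dim * num_experts


embed_dim, mlp_ratio, num_experts = 64, 4, 4
depths = (2, 2, 2, 2)
moe_stages = (0, 1)

extra_params = 0
for stage in moe_stages:
    dim = embed_dim * 2**stage
    hidden = dim * mlp_ratio
    per_block = moe_mlp_params(dim, hidden, num_experts) - dense_mlp_params(dim, hidden)
    extra_params += depths[stage] * per_block

print(f"extra parameters: {extra_params:,} (~{extra_params / 1e6:.2f} M)")
```

Under these assumptions the count comes to about 0.99 M extra parameters, in line with the reported 9.1 M versus 10.1 M.
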

Because each token is assigned to its top-2 experts, the amount of computation may have increased accordingly.
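
A rough FLOP count illustrates this. The sketch below counts only matrix multiplications in the MLP of a stage-0 block (64 tokens per image, `dim` 64, hidden size 256). It assumes that each expert has the same shape as the dense MLP and that the router is a single linear layer, and it ignores routing bookkeeping and memory traffic. The function names are illustrative.

```python
def ffn_flops(tokens: int, dim: int, hidden: int) -> int:
    # Two matmuls (dim -> hidden, hidden -> dim), counting 2 FLOPs per multiply-add.
    return 2 * tokens * (dim * hidden + hidden * dim)


def moe_flops(tokens: int, dim: int, hidden: int, num_experts: int, top_k: int) -> int:
    # Every token passes through top_k experts, plus the router matmul (dim -> num_experts).
    router = 2 * tokens * dim * num_experts
    return top_k * ffn_flops(tokens, dim, hidden) + router


tokens, dim, hidden, num_experts = 64, 64, 256, 4
dense = ffn_flops(tokens, dim, hidden)
ratio_top2 = moe_flops(tokens, dim, hidden, num_experts, top_k=2) / dense
ratio_top1 = moe_flops(tokens, dim, hidden, num_experts, top_k=1) / dense
print(f"top-2: {ratio_top2:.3f}x, top-1: {ratio_top1:.3f}x")
```

By this count, Top-2 roughly doubles the MLP matmul work of an MoE block, while Top-1 leaves it nearly unchanged; the router matmul itself is well under 1% of the dense MLP cost.
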

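In addition, I measured Top-1 routing.



Legend: red is Top-1.

Even with Top-1, training time was almost the same as with Top-2.
I attribute this to the overhead of the gating process that MoE introduces.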
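
To see where such overhead can come from, the following list-based sketch walks through the bookkeeping a dropless layer needs around the expert matmuls: grouping the routed rows by expert, then computing per-expert counts and cumulative offsets for a grouped GEMM. It is an illustration with hypothetical inputs and names, not the code used in the experiment. On the GPU, each step would typically be a separate tensor operation such as `torch.topk`, `torch.argsort`, `torch.bincount`, and `torch.cumsum`.

```python
from itertools import accumulate


def group_rows_by_expert(top_experts: list[list[int]], num_experts: int):
    """Group (token, expert) assignments by expert without padding or dropping."""
    # One row per (expert, token) assignment; with top_k=2 each token appears twice.
    pairs = sorted((expert, token) for token, experts in enumerate(top_experts) for expert in experts)
    token_order = [token for _, token in pairs]

    counts = [0] * num_experts
    for expert, _ in pairs:
        counts[expert] += 1
    offsets = list(accumulate(counts))  # end index of each expert's contiguous row group
    return token_order, counts, offsets


# Hypothetical top-2 assignments for four tokens and four experts.
top_experts = [[0, 2], [2, 1], [2, 0], [3, 2]]
token_order, counts, offsets = group_rows_by_expert(top_experts, num_experts=4)
print(token_order)  # [0, 2, 1, 0, 1, 2, 3, 3]
print(counts)       # [2, 1, 4, 1]
print(offsets)      # [2, 3, 7, 8]
```

Expert 2 receives four rows and expert 1 only one, yet every row is kept. These variable-length groups are what a grouped GEMM consumes, and in a small model the steps around the matmuls can take a noticeable share of the time.
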

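Also, with Top-1, accuracy was almost the same as without MoE.

Summary

I implemented dropless MoE with PyTorch's grouped_mm and evaluated its accuracy and training time.
Top-2 MoE improved accuracy, but training time increased by more than the ratio of parameter counts (about 1.4 times versus about 1.1 times).
Because the experiment used a small SwinTransformer, I believe the gating overhead was relatively large.

Next, I would like to implement the inference side as a TensorRT custom plugin.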