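For dlshogi models, accuracy has been confirmed to improve as the parameter count grows, up to a size of 60 blocks with 768 filters.
However, because search speed drops, the playing strength in actual games falls short of the 40-block, 512-filter model.
MoE (Mixture of Experts) has been adopted in LLMs as a technique for increasing parameters while keeping inference speed from dropping.
I consider MoE a promising technique for making deep-learning-based shogi AI stronger.
However, implementing MoE is not simple: efficient GPU inference cannot use the standard PyTorch → ONNX → TensorRT pipeline that dlshogi relies on, and a custom implementation is required.
This article organizes the MoE techniques and describes how to implement training and inference of a dropless MoE for efficient inference on GPUs.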
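MoE (Mixture of Experts)
MoE (Mixture of Experts) is a technique in which a gating mechanism selects, for each input, only a subset of many expert models to run, which increases model capacity while keeping the amount of computation down.
When MoE is used in Transformer-based LLMs, the mainstream configuration turns the FFN (MLP) part inside each Transformer block into an MoE.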
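To make the "capacity versus computation" distinction concrete, the following is a small arithmetic sketch that counts the parameters of an FFN and of an MoE version of it. The helper names `ffn_params` and `moe_ffn_params` are hypothetical, and a bias-free linear router is assumed. The sizes (dim 64, hidden 256, 4 experts, top-2) are illustrative values.

```python
def ffn_params(dim: int, hidden: int) -> int:
    """Weights and biases of a two-layer FFN: dim -> hidden -> dim."""
    return (dim * hidden + hidden) + (hidden * dim + dim)


def moe_ffn_params(dim: int, hidden: int, num_experts: int, top_k: int) -> tuple[int, int]:
    """Return (total, active-per-token) parameter counts for an MoE FFN."""
    router = dim * num_experts          # assumed: bias-free linear router
    expert = ffn_params(dim, hidden)    # every expert is one FFN
    total = router + num_experts * expert
    active = router + top_k * expert    # only top_k experts run per token
    return total, active


dense = ffn_params(64, 256)
total, active = moe_ffn_params(64, 256, num_experts=4, top_k=2)
print(dense, total, active)
```

The total grows with the number of experts, while the parameters touched per token grow only with top_k.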
Dense MoE
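The simplest way to implement MoE is to compute every expert and take the weighted sum of their outputs.
However, every expert must then be computed for each token, so the benefit of MoE, growing model capacity while keeping computation down, is lost.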
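A minimal sketch of such a dense MoE layer in PyTorch is shown here. The class name `DenseMoE` and the tiny sizes are assumptions for illustration only. Note that the forward pass runs every expert on every token.

```python
import torch
from torch import nn
from torch.nn import functional as F


class DenseMoE(nn.Module):
    """Dense MoE: all experts are evaluated, then mixed by the gate."""

    def __init__(self, dim: int, hidden: int, num_experts: int) -> None:
        super().__init__()
        self.router = nn.Linear(dim, num_experts, bias=False)
        self.experts = nn.ModuleList(
            [
                nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))
                for _ in range(num_experts)
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.softmax(self.router(x), dim=-1)                        # (tokens, experts)
        outs = torch.stack([expert(x) for expert in self.experts], 1)   # (tokens, experts, dim)
        return (gate.unsqueeze(-1) * outs).sum(dim=1)                   # weighted sum over experts


torch.manual_seed(0)
moe = DenseMoE(dim=8, hidden=16, num_experts=4)
x = torch.randn(5, 8)
y = moe(x)
print(y.shape)
```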
Sparse MoE
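The gate selects only a small number of experts; only the selected experts are computed and the computation of the others is skipped.
To realize this on a GPU, a technique called capacity-aware routing is used.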
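The selection step itself can be sketched as top-k gating: softmax over the router logits, keep the k largest probabilities, and renormalize them so the kept gates sum to 1. The function name `top_k_gate` and the example logits are hypothetical.

```python
import torch
from torch.nn import functional as F


def top_k_gate(logits: torch.Tensor, k: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Return (gate weights, expert indices), each of shape (tokens, k)."""
    probs = F.softmax(logits, dim=-1)
    top_prob, top_expert = torch.topk(probs, k=k, dim=-1)
    gate = top_prob / top_prob.sum(dim=-1, keepdim=True)  # renormalize over the kept experts
    return gate, top_expert


# Two tokens, four experts.
logits = torch.tensor([[2.0, 0.0, 1.0, -1.0],
                       [0.0, 3.0, 0.0, 1.0]])
gate, expert = top_k_gate(logits, k=2)
print(expert.tolist())
```

Only the experts listed in `expert` need to run for each token; the remaining question is how to batch that work efficiently on a GPU.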
capacity-aware routing
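Capacity-aware routing is a routing scheme that puts an upper limit, the capacity, on the number of tokens each expert accepts, so that each expert can be processed efficiently on the GPU as a fixed-length batch; experts that receive fewer tokens are padded, and tokens beyond the limit are dropped, reassigned, or handled separately.
This is simple to implement, but tokens may be dropped, so the capacity has to be tuned.
Raising the capacity lowers the drop rate but increases the computation wasted on padding, so there is a trade-off.
Recently, dropless MoE, which avoids dropping tokens, has become mainstream.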
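The capacity mechanism described above can be sketched for top-1 routing as follows. This is an illustrative toy under assumed names (`capacity_route`, `buffers`), not the implementation of any particular library: each token gets a slot in its expert's fixed-size buffer, and tokens that arrive after the buffer is full are dropped.

```python
import torch
from torch.nn import functional as F


def capacity_route(expert_index: torch.Tensor, num_experts: int, capacity: int):
    """Return (slot, kept): slot[i] is token i's position inside its expert's
    buffer; kept[i] is False when that position exceeds the capacity."""
    one_hot = F.one_hot(expert_index, num_experts)          # (tokens, experts)
    # Running count of tokens per expert, read off at each token's own expert.
    position = (torch.cumsum(one_hot, dim=0) - 1) * one_hot
    slot = position.sum(dim=-1)
    kept = slot < capacity
    return slot, kept


num_experts, capacity = 2, 3
expert_index = torch.tensor([0, 0, 1, 0, 1, 0])             # expert 0 gets 4 tokens, expert 1 gets 2
x = torch.arange(6, dtype=torch.float32).unsqueeze(-1)      # token i carries the value i

slot, kept = capacity_route(expert_index, num_experts, capacity)

# Fixed-shape (experts, capacity, dim) buffer; unused slots stay zero (padding).
buffers = torch.zeros(num_experts, capacity, 1)
buffers[expert_index[kept], slot[kept]] = x[kept]

dropped = int((~kept).sum())
padded = num_experts * capacity - int(kept.sum())
print(dropped, padded)
```

With capacity 3, one token routed to expert 0 is dropped while expert 1 carries one padded slot, which is the trade-off in miniature.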
dropless MoE
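Dropless MoE processes all tokens on the GPU using grouped GEMM, block-sparse kernels, or similar.
It is implemented as dedicated kernels so that it works even when the number of tokens per expert is not a fixed length.
MegaBlocks is one of the pioneering studies that made dropless MoE viable as a practical GPU training system.
PyTorch 2.10 added torch.nn.functional.grouped_mm, which makes the variable-length per-expert GEMM implementable with standard PyTorch.
Internally, grouped_mm is implemented as a dedicated kernel.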
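What a grouped GEMM computes can be sketched with a plain loop. The function `grouped_mm_reference` below is a hypothetical CPU stand-in written for illustration; it is not the PyTorch kernel and makes no claim about its exact signature. Tokens are sorted by expert, cumulative counts give the end offset of each expert's rows, and each contiguous row group is multiplied by its own weight matrix. No token is dropped and no padding is added, even for an expert that receives zero tokens.

```python
import torch


def grouped_mm_reference(x_sorted: torch.Tensor, weight: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
    """Loop-based stand-in for a grouped GEMM.

    x_sorted: (N, D) rows sorted by expert; weight: (E, D, H);
    offsets: (E,) cumulative row counts, so expert e owns rows offsets[e-1]:offsets[e].
    """
    out = x_sorted.new_empty(x_sorted.shape[0], weight.shape[-1])
    start = 0
    for e, end in enumerate(offsets.tolist()):
        out[start:end] = x_sorted[start:end] @ weight[e]   # variable-length group
        start = end
    return out


torch.manual_seed(0)
tokens, dim, hidden, num_experts = 6, 4, 3, 3
x = torch.randn(tokens, dim)
weight = torch.randn(num_experts, dim, hidden)
expert = torch.tensor([2, 0, 2, 2, 0, 2])                  # expert 1 receives no token

order = torch.argsort(expert)                              # group tokens by expert
counts = torch.bincount(expert, minlength=num_experts)
offsets = torch.cumsum(counts, dim=0)
y_sorted = grouped_mm_reference(x[order], weight, offsets)

y = torch.empty_like(y_sorted)
y[order] = y_sorted                                        # scatter back to token order

# Per-token reference: each token times its own expert's matrix.
expected = torch.stack([x[i] @ weight[expert[i]] for i in range(tokens)])
print(offsets.tolist())
```

A dedicated kernel performs the same per-group products in a single launch instead of a Python loop.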
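Inference challenges
When a model trained in PyTorch is run with TensorRT, the standard approach is to export it to ONNX and load that into TensorRT.
However, there is no ONNX operator corresponding to grouped_mm, so the standard approach cannot be used.
The operation therefore has to be exported as an ONNX custom operator and implemented as a TensorRT custom plugin.
Inside the custom plugin, the grouped_mm-equivalent computation can be implemented with the grouped GEMM in the CUTLASS library.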
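Implementing dropless MoE
I verified whether a dropless MoE can actually be implemented and trained in PyTorch using grouped_mm.
The verification uses simple code that trains a Swin Transformer on the CIFAR-10 dataset.
First, the implementation without MoE is shown below.
It is implemented with PyTorch Lightning.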
train.py
from lightning.pytorch.cli import LightningCLI

from data import CIFAR10DataModule
from model import SwinCIFAR10Classifier


def main() -> None:
    LightningCLI(
        model_class=SwinCIFAR10Classifier,
        datamodule_class=CIFAR10DataModule,
        save_config_kwargs={"overwrite": True},
    )


if __name__ == "__main__":
    main()
data.py
from __future__ import annotations

from pathlib import Path

import lightning.pytorch as pl
import torch
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms


class CIFAR10DataModule(pl.LightningDataModule):
    def __init__(
        self,
        data_dir: str = "data",
        batch_size: int = 128,
        num_workers: int = 4,
        image_size: int = 32,
        val_split: int = 5000,
        seed: int = 42,
        download: bool = True,
    ) -> None:
        super().__init__()
        self.data_dir = Path(data_dir)
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.image_size = image_size
        self.val_split = val_split
        self.seed = seed
        self.download = download
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def prepare_data(self) -> None:
        datasets.CIFAR10(self.data_dir, train=True, download=self.download)
        datasets.CIFAR10(self.data_dir, train=False, download=self.download)

    def setup(self, stage: str | None = None) -> None:
        train_transform = transforms.Compose(
            [
                transforms.Resize((self.image_size, self.image_size)),
                transforms.RandomCrop(self.image_size, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
            ]
        )
        eval_transform = transforms.Compose(
            [
                transforms.Resize((self.image_size, self.image_size)),
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
            ]
        )
        if stage in (None, "fit"):
            train_base = datasets.CIFAR10(self.data_dir, train=True, transform=train_transform)
            val_base = datasets.CIFAR10(self.data_dir, train=True, transform=eval_transform)
            train_size = len(train_base) - self.val_split
            indices = torch.randperm(len(train_base), generator=torch.Generator().manual_seed(self.seed)).tolist()
            self.train_dataset = Subset(train_base, indices[:train_size])
            self.val_dataset = Subset(val_base, indices[train_size:])
        if stage in (None, "test"):
            self.test_dataset = datasets.CIFAR10(self.data_dir, train=False, transform=eval_transform)

    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=self.num_workers > 0,
        )

    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=self.num_workers > 0,
        )

    def test_dataloader(self) -> DataLoader:
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=self.num_workers > 0,
        )
model.py
from __future__ import annotations

from typing import Sequence

import lightning.pytorch as pl
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from torchmetrics.classification import MulticlassAccuracy


def window_partition(x: Tensor, window_size: int) -> Tensor:
    batch_size, height, width, channels = x.shape
    x = x.view(batch_size, height // window_size, window_size, width // window_size, window_size, channels)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return windows.view(-1, window_size, window_size, channels)


def window_reverse(windows: Tensor, window_size: int, height: int, width: int) -> Tensor:
    batch_size = int(windows.shape[0] / (height * width / window_size / window_size))
    x = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return x.view(batch_size, height, width, -1)


class DropPath(nn.Module):
    def __init__(self, drop_prob: float = 0.0) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: Tensor) -> Tensor:
        if self.drop_prob == 0.0 or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()
        return x.div(keep_prob) * random_tensor


class Mlp(nn.Module):
    def __init__(self, in_features: int, hidden_features: int, drop: float = 0.0) -> None:
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class WindowAttention(nn.Module):
    def __init__(self, dim: int, window_size: int, num_heads: int, qkv_bias: bool = True, attn_drop: float = 0.0, proj_drop: float = 0.0) -> None:
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        table_size = (2 * window_size - 1) * (2 * window_size - 1)
        self.relative_position_bias_table = nn.Parameter(torch.zeros(table_size, num_heads))
        coords = torch.stack(torch.meshgrid(torch.arange(window_size), torch.arange(window_size), indexing="ij"))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += window_size - 1
        relative_coords[:, :, 1] += window_size - 1
        relative_coords[:, :, 0] *= 2 * window_size - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index, persistent=False)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

    def forward(self, x: Tensor, mask: Tensor | None = None) -> Tensor:
        batch_windows, tokens, channels = x.shape
        qkv = self.qkv(x).reshape(batch_windows, tokens, 3, self.num_heads, channels // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        query, key, value = qkv[0], qkv[1], qkv[2]
        query = query * self.scale
        attn = query @ key.transpose(-2, -1)
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
        relative_position_bias = relative_position_bias.view(tokens, tokens, -1).permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)
        if mask is not None:
            num_windows = mask.shape[0]
            attn = attn.view(batch_windows // num_windows, num_windows, self.num_heads, tokens, tokens)
            attn = attn + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, tokens, tokens)
        attn = F.softmax(attn, dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ value).transpose(1, 2).reshape(batch_windows, tokens, channels)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class SwinTransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        input_resolution: tuple[int, int],
        num_heads: int,
        window_size: int = 4,
        shift_size: int = 0,
        mlp_ratio: float = 4.0,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        drop_path: float = 0.0,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.window_size = min(window_size, input_resolution[0], input_resolution[1])
        self.shift_size = 0 if min(input_resolution) <= self.window_size else shift_size
        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, self.window_size, num_heads, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = Mlp(dim, int(dim * mlp_ratio), drop=drop)
        self.register_buffer("attn_mask", self._create_mask(), persistent=False)

    def _create_mask(self) -> Tensor | None:
        if self.shift_size == 0:
            return None
        height, width = self.input_resolution
        img_mask = torch.zeros((1, height, width, 1))
        height_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None))
        width_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None))
        count = 0
        for height_slice in height_slices:
            for width_slice in width_slices:
                img_mask[:, height_slice, width_slice, :] = count
                count += 1
        mask_windows = window_partition(img_mask, self.window_size).view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

    def forward(self, x: Tensor) -> Tensor:
        height, width = self.input_resolution
        batch_size, length, channels = x.shape
        if length != height * width:
            raise ValueError(f"Expected token length {height * width}, got {length}.")
        shortcut = x
        x = self.norm1(x).view(batch_size, height, width, channels)
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x
        x_windows = window_partition(shifted_x, self.window_size).view(-1, self.window_size * self.window_size, channels)
        attn_windows = self.attn(x_windows, mask=self.attn_mask)
        shifted_x = window_reverse(attn_windows.view(-1, self.window_size, self.window_size, channels), self.window_size, height, width)
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(batch_size, height * width, channels)
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class PatchMerging(nn.Module):
    def __init__(self, input_resolution: tuple[int, int], dim: int) -> None:
        super().__init__()
        self.input_resolution = input_resolution
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = nn.LayerNorm(4 * dim)

    def forward(self, x: Tensor) -> Tensor:
        height, width = self.input_resolution
        batch_size, length, channels = x.shape
        if length != height * width:
            raise ValueError(f"Expected token length {height * width}, got {length}.")
        if height % 2 != 0 or width % 2 != 0:
            raise ValueError("PatchMerging requires even spatial dimensions.")
        x = x.view(batch_size, height, width, channels)
        x0 = x[:, 0::2, 0::2, :]
        x1 = x[:, 1::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, :]
        x3 = x[:, 1::2, 1::2, :]
        x = torch.cat([x0, x1, x2, x3], dim=-1).view(batch_size, -1, 4 * channels)
        x = self.norm(x)
        return self.reduction(x)


class BasicLayer(nn.Module):
    def __init__(
        self,
        dim: int,
        input_resolution: tuple[int, int],
        depth: int,
        num_heads: int,
        window_size: int,
        mlp_ratio: float,
        drop: float,
        attn_drop: float,
        drop_path: Sequence[float],
        downsample: bool,
    ) -> None:
        super().__init__()
        self.blocks = nn.ModuleList(
            [
                SwinTransformerBlock(
                    dim=dim,
                    input_resolution=input_resolution,
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=0 if index % 2 == 0 else window_size // 2,
                    mlp_ratio=mlp_ratio,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path[index],
                )
                for index in range(depth)
            ]
        )
        self.downsample = PatchMerging(input_resolution, dim) if downsample else None

    def forward(self, x: Tensor) -> Tensor:
        for block in self.blocks:
            x = block(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x


class PatchEmbed(nn.Module):
    def __init__(self, image_size: int, patch_size: int, in_channels: int, embed_dim: int) -> None:
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.grid_size = image_size // patch_size
        self.num_patches = self.grid_size * self.grid_size
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x: Tensor) -> Tensor:
        _, _, height, width = x.shape
        if height != self.image_size or width != self.image_size:
            raise ValueError(f"Expected input size {self.image_size}x{self.image_size}, got {height}x{width}.")
        x = self.proj(x).flatten(2).transpose(1, 2)
        return self.norm(x)


class SwinTransformer(nn.Module):
    def __init__(
        self,
        image_size: int = 32,
        patch_size: int = 4,
        in_channels: int = 3,
        num_classes: int = 10,
        embed_dim: int = 64,
        depths: Sequence[int] = (2, 2, 2, 2),
        num_heads: Sequence[int] = (2, 4, 8, 16),
        window_size: int = 4,
        mlp_ratio: float = 4.0,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.1,
    ) -> None:
        super().__init__()
        if image_size % patch_size != 0:
            raise ValueError("image_size must be divisible by patch_size.")
        if len(depths) != len(num_heads):
            raise ValueError("depths and num_heads must have the same length.")
        self.num_layers = len(depths)
        self.num_features = embed_dim * 2 ** (self.num_layers - 1)
        self.patch_embed = PatchEmbed(image_size, patch_size, in_channels, embed_dim)
        self.pos_drop = nn.Dropout(drop_rate)
        total_depth = sum(depths)
        drop_paths = torch.linspace(0, drop_path_rate, total_depth).tolist()
        self.layers = nn.ModuleList()
        resolution = image_size // patch_size
        depth_offset = 0
        for layer_index in range(self.num_layers):
            dim = embed_dim * 2**layer_index
            input_resolution = (resolution, resolution)
            layer = BasicLayer(
                dim=dim,
                input_resolution=input_resolution,
                depth=depths[layer_index],
                num_heads=num_heads[layer_index],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=drop_paths[depth_offset : depth_offset + depths[layer_index]],
                downsample=layer_index < self.num_layers - 1,
            )
            self.layers.append(layer)
            depth_offset += depths[layer_index]
            if layer_index < self.num_layers - 1:
                resolution //= 2
        self.norm = nn.LayerNorm(self.num_features)
        self.head = nn.Linear(self.num_features, num_classes)
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        if isinstance(module, nn.Linear):
            nn.init.trunc_normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, x: Tensor) -> Tensor:
        x = self.patch_embed(x)
        x = self.pos_drop(x)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        x = x.mean(dim=1)
        return self.head(x)


class SwinCIFAR10Classifier(pl.LightningModule):
    def __init__(
        self,
        num_classes: int = 10,
        image_size: int = 32,
        patch_size: int = 4,
        in_channels: int = 3,
        embed_dim: int = 64,
        depths: Sequence[int] = (2, 2, 2, 2),
        num_heads: Sequence[int] = (2, 4, 8, 16),
        window_size: int = 4,
        mlp_ratio: float = 4.0,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.1,
        learning_rate: float = 0.001,
        weight_decay: float = 0.05,
    ) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.model = SwinTransformer(
            image_size=image_size,
            patch_size=patch_size,
            in_channels=in_channels,
            num_classes=num_classes,
            embed_dim=embed_dim,
            depths=depths,
            num_heads=num_heads,
            window_size=window_size,
            mlp_ratio=mlp_ratio,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
        )
        self.criterion = nn.CrossEntropyLoss()
        self.train_acc = MulticlassAccuracy(num_classes=num_classes)
        self.val_acc = MulticlassAccuracy(num_classes=num_classes)
        self.test_acc = MulticlassAccuracy(num_classes=num_classes)

    def forward(self, x: Tensor) -> Tensor:
        return self.model(x)

    def _shared_step(self, batch: tuple[Tensor, Tensor], stage: str) -> Tensor:
        images, targets = batch
        logits = self(images)
        loss = self.criterion(logits, targets)
        preds = torch.argmax(logits, dim=1)
        metric = getattr(self, f"{stage}_acc")
        metric(preds, targets)
        self.log(f"{stage}_loss", loss, prog_bar=True, on_step=stage == "train", on_epoch=True)
        self.log(f"{stage}_acc", metric, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def training_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        return self._shared_step(batch, "train")

    def validation_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        return self._shared_step(batch, "val")

    def test_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        return self._shared_step(batch, "test")

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay,
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.trainer.max_epochs)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}
config.yaml
seed_everything: 42
trainer:
  accelerator: auto
  devices: auto
  max_epochs: 20
  precision: bf16-mixed
  log_every_n_steps: 50
  callbacks:
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        monitor: val_acc
        mode: max
        save_top_k: 1
        filename: swin-cifar10-{epoch:02d}-{val_acc:.4f}
    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
      init_args:
        logging_interval: epoch
model:
  num_classes: 10
  image_size: 32
  patch_size: 4
  in_channels: 3
  embed_dim: 64
  depths: [2, 2, 2, 2]
  num_heads: [2, 4, 8, 16]
  window_size: 4
  mlp_ratio: 4.0
  drop_rate: 0.0
  attn_drop_rate: 0.0
  drop_path_rate: 0.1
  learning_rate: 0.001
  weight_decay: 0.05
data:
  data_dir: data
  batch_size: 128
  num_workers: 4
  image_size: 32
  val_split: 5000
  seed: 42
  download: true
Next is the dropless MoE implementation that uses torch.nn.functional.grouped_mm; the modified model.py is listed below.
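The model.py listing that follows stores two router regularizers on each MoE layer, `aux_loss` and `z_loss`. As a reading aid, here is a standalone sketch of the same two formulas; the function name `router_losses` and the toy inputs are assumptions. The load-balancing term is num_experts × Σ_e f_e · P_e, where f_e is the fraction of routed token slots assigned to expert e and P_e is the mean router probability of expert e. The z-loss is the mean squared log-sum-exp of the logits.

```python
import math

import torch
from torch.nn import functional as F


def router_losses(logits: torch.Tensor, top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Return (load-balancing loss, z-loss) for router logits of shape (tokens, experts)."""
    num_experts = logits.shape[-1]
    probs = F.softmax(logits, dim=-1)
    _, top_expert = torch.topk(probs, k=top_k, dim=-1)
    counts = torch.bincount(top_expert.reshape(-1), minlength=num_experts)
    tokens_per_expert = counts.to(probs.dtype) / counts.sum()   # f_e
    prob_per_expert = probs.mean(dim=0)                          # P_e
    aux_loss = num_experts * torch.sum(tokens_per_expert * prob_per_expert)
    z_loss = torch.logsumexp(logits, dim=-1).square().mean()
    return aux_loss, z_loss


# Uniform router: every expert has probability 1/4.
aux_uniform, z_uniform = router_losses(torch.zeros(8, 4), top_k=2)

# Collapsed router: every token strongly prefers expert 0.
skewed = torch.zeros(8, 4)
skewed[:, 0] = 10.0
aux_skewed, _ = router_losses(skewed, top_k=1)

print(float(aux_uniform), float(z_uniform), float(aux_skewed))
```

With uniform probabilities the balancing term equals 1, and it approaches the number of experts when routing collapses onto a single expert, which is why minimizing it pushes the router toward an even load.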
model.py
from __future__ import annotations from typing import Sequence import lightning.pytorch as pl import torch from torch import Tensor, nn from torch.nn import functional as F from torchmetrics.classification import MulticlassAccuracy def window_partition(x: Tensor, window_size: int) -> Tensor: batch_size, height, width, channels = x.shape x = x.view(batch_size, height // window_size, window_size, width // window_size, window_size, channels) windows = x.permute(0, 1, 3, 2, 4, 5).contiguous() return windows.view(-1, window_size, window_size, channels) def window_reverse(windows: Tensor, window_size: int, height: int, width: int) -> Tensor: batch_size = int(windows.shape[0] / (height * width / window_size / window_size)) x = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous() return x.view(batch_size, height, width, -1) class DropPath(nn.Module): def __init__(self, drop_prob: float = 0.0) -> None: super().__init__() self.drop_prob = drop_prob def forward(self, x: Tensor) -> Tensor: if self.drop_prob == 0.0 or not self.training: return x keep_prob = 1 - self.drop_prob shape = (x.shape[0],) + (1,) * (x.ndim - 1) random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) random_tensor.floor_() return x.div(keep_prob) * random_tensor class Mlp(nn.Module): def __init__(self, in_features: int, hidden_features: int, drop: float = 0.0) -> None: super().__init__() self.fc1 = nn.Linear(in_features, hidden_features) self.act = nn.GELU() self.fc2 = nn.Linear(hidden_features, in_features) self.drop = nn.Dropout(drop) def forward(self, x: Tensor) -> Tensor: x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x class DroplessMoEMlp(nn.Module): def __init__( self, in_features: int, hidden_features: int, num_experts: int, top_k: int = 2, drop: float = 0.0, ) -> None: super().__init__() if num_experts <= 0: raise ValueError("num_experts must 
be greater than 0.") if top_k <= 0 or top_k > num_experts: raise ValueError("top_k must be in the range [1, num_experts].") self.in_features = in_features self.hidden_features = hidden_features self.num_experts = num_experts self.top_k = top_k self.router = nn.Linear(in_features, num_experts, bias=False) # grouped_mm operand layout avoids per-forward transpose: x @ w1 maps D -> H, h @ w2 maps H -> D. self.w1 = nn.Parameter(torch.empty(num_experts, in_features, hidden_features)) self.b1 = nn.Parameter(torch.zeros(num_experts, hidden_features)) self.w2 = nn.Parameter(torch.empty(num_experts, hidden_features, in_features)) self.b2 = nn.Parameter(torch.zeros(num_experts, in_features)) self.act = nn.GELU() self.drop = nn.Dropout(drop) self.aux_loss: Tensor | None = None self.z_loss: Tensor | None = None nn.init.trunc_normal_(self.w1, std=0.02) nn.init.trunc_normal_(self.w2, std=0.02) def _apply(self, fn): module = super()._apply(fn) self.router.float() return module def forward(self, x: Tensor) -> Tensor: if not x.is_cuda: raise RuntimeError("DroplessMoEMlp requires CUDA tensors because it uses torch.nn.functional.grouped_mm.") if not hasattr(F, "grouped_mm"): raise RuntimeError("torch.nn.functional.grouped_mm is not available in this PyTorch build.") if self.router.weight.dtype != torch.float32: raise RuntimeError("Router weights must remain FP32 for stable routing.") orig_shape = x.shape x = x.reshape(-1, self.in_features) tokens = x.shape[0] # Keep routing numerics in FP32; expert grouped_mm below still uses the activation dtype. 
with torch.autocast(device_type="cuda", enabled=False): logits = self.router(x.float()) probs = F.softmax(logits, dim=-1) topk_prob, topk_expert = torch.topk(probs, k=self.top_k, dim=-1) topk_gate = topk_prob / topk_prob.sum(dim=-1, keepdim=True) flat_expert = topk_expert.reshape(-1) flat_gate = topk_gate.reshape(-1) flat_token = torch.arange(tokens, device=x.device).repeat_interleave(self.top_k) order = torch.argsort(flat_expert) expert_sorted = flat_expert[order] token_sorted = flat_token[order] gate_sorted = flat_gate[order] x_sorted = x.index_select(0, token_sorted) counts = torch.bincount(expert_sorted, minlength=self.num_experts) offsets = torch.cumsum(counts, dim=0).to(torch.int32) tokens_per_expert = counts.to(probs.dtype) / counts.sum().to(probs.dtype) prob_per_expert = probs.mean(dim=0) self.aux_loss = self.num_experts * torch.sum(tokens_per_expert * prob_per_expert) self.z_loss = torch.mean(torch.logsumexp(logits, dim=-1).square()) w1 = self.w1 if self.w1.dtype == x.dtype else self.w1.to(dtype=x.dtype) # grouped_mm requires offs[-1] < mat_a.shape[0], so append one ignored row. padding = torch.zeros(1, x_sorted.shape[-1], dtype=x_sorted.dtype, device=x_sorted.device) x_grouped = torch.cat((x_sorted, padding), dim=0) h = F.grouped_mm(x_grouped, w1, offs=offsets)[: x_sorted.shape[0]] h = h + self.b1.index_select(0, expert_sorted).to(dtype=h.dtype) h = self.act(h) h = self.drop(h) w2 = self.w2 if self.w2.dtype == h.dtype else self.w2.to(dtype=h.dtype) # grouped_mm requires offs[-1] < mat_a.shape[0], so append one ignored row. 
hidden_padding = torch.zeros(1, h.shape[-1], dtype=h.dtype, device=h.device) h_grouped = torch.cat((h, hidden_padding), dim=0) y_sorted = F.grouped_mm(h_grouped, w2, offs=offsets)[: h.shape[0]] y_sorted = y_sorted + self.b2.index_select(0, expert_sorted).to(dtype=y_sorted.dtype) y_sorted = self.drop(y_sorted) y_sorted = y_sorted * gate_sorted.unsqueeze(-1).to(dtype=y_sorted.dtype) y = torch.zeros_like(x) y.index_add_(0, token_sorted, y_sorted.to(dtype=y.dtype)) return y.reshape(orig_shape) class WindowAttention(nn.Module): def __init__(self, dim: int, window_size: int, num_heads: int, qkv_bias: bool = True, attn_drop: float = 0.0, proj_drop: float = 0.0) -> None: super().__init__() if num_heads <= 0: raise ValueError("num_heads must be greater than 0.") if dim % num_heads != 0: raise ValueError("dim must be divisible by num_heads.") self.dim = dim self.window_size = window_size self.num_heads = num_heads head_dim = dim // num_heads self.scale = head_dim**-0.5 table_size = (2 * window_size - 1) * (2 * window_size - 1) self.relative_position_bias_table = nn.Parameter(torch.zeros(table_size, num_heads)) coords = torch.stack(torch.meshgrid(torch.arange(window_size), torch.arange(window_size), indexing="ij")) coords_flatten = torch.flatten(coords, 1) relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] relative_coords = relative_coords.permute(1, 2, 0).contiguous() relative_coords[:, :, 0] += window_size - 1 relative_coords[:, :, 1] += window_size - 1 relative_coords[:, :, 0] *= 2 * window_size - 1 relative_position_index = relative_coords.sum(-1) self.register_buffer("relative_position_index", relative_position_index, persistent=False) self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02) def forward(self, x: Tensor, mask: Tensor | None = None) -> Tensor: batch_windows, tokens, 
channels = x.shape qkv = self.qkv(x).reshape(batch_windows, tokens, 3, self.num_heads, channels // self.num_heads) qkv = qkv.permute(2, 0, 3, 1, 4) query, key, value = qkv[0], qkv[1], qkv[2] query = query * self.scale attn = query @ key.transpose(-2, -1) relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)] relative_position_bias = relative_position_bias.view(tokens, tokens, -1).permute(2, 0, 1).contiguous() attn = attn + relative_position_bias.unsqueeze(0) if mask is not None: num_windows = mask.shape[0] attn = attn.view(batch_windows // num_windows, num_windows, self.num_heads, tokens, tokens) attn = attn + mask.unsqueeze(1).unsqueeze(0) attn = attn.view(-1, self.num_heads, tokens, tokens) attn = F.softmax(attn, dim=-1) attn = self.attn_drop(attn) x = (attn @ value).transpose(1, 2).reshape(batch_windows, tokens, channels) x = self.proj(x) x = self.proj_drop(x) return x class SwinTransformerBlock(nn.Module): def __init__( self, dim: int, input_resolution: tuple[int, int], num_heads: int, window_size: int = 4, shift_size: int = 0, mlp_ratio: float = 4.0, drop: float = 0.0, attn_drop: float = 0.0, drop_path: float = 0.0, use_moe_mlp: bool = False, moe_num_experts: int = 4, moe_top_k: int = 2, ) -> None: super().__init__() self.dim = dim self.input_resolution = input_resolution self.window_size = min(window_size, input_resolution[0], input_resolution[1]) if input_resolution[0] % self.window_size != 0 or input_resolution[1] % self.window_size != 0: raise ValueError("input_resolution must be divisible by window_size, or padding is required.") self.shift_size = 0 if min(input_resolution) <= self.window_size else shift_size self.norm1 = nn.LayerNorm(dim) self.attn = WindowAttention(dim, self.window_size, num_heads, attn_drop=attn_drop, proj_drop=drop) self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.norm2 = nn.LayerNorm(dim) hidden_features = int(dim * mlp_ratio) if use_moe_mlp: self.mlp = 
            DroplessMoEMlp(dim, hidden_features, num_experts=moe_num_experts, top_k=moe_top_k, drop=drop)
        else:
            self.mlp = Mlp(dim, hidden_features, drop=drop)
        self.register_buffer("attn_mask", self._create_mask(), persistent=False)

    def _create_mask(self) -> Tensor | None:
        if self.shift_size == 0:
            return None
        height, width = self.input_resolution
        img_mask = torch.zeros((1, height, width, 1))
        height_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None))
        width_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None))
        count = 0
        for height_slice in height_slices:
            for width_slice in width_slices:
                img_mask[:, height_slice, width_slice, :] = count
                count += 1
        mask_windows = window_partition(img_mask, self.window_size).view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

    def forward(self, x: Tensor) -> Tensor:
        height, width = self.input_resolution
        batch_size, length, channels = x.shape
        if length != height * width:
            raise ValueError(f"Expected token length {height * width}, got {length}.")
        shortcut = x
        x = self.norm1(x).view(batch_size, height, width, channels)
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x
        x_windows = window_partition(shifted_x, self.window_size).view(-1, self.window_size * self.window_size, channels)
        attn_windows = self.attn(x_windows, mask=self.attn_mask)
        shifted_x = window_reverse(attn_windows.view(-1, self.window_size, self.window_size, channels), self.window_size, height, width)
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(batch_size, height * width, channels)
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class PatchMerging(nn.Module):
    def __init__(self, input_resolution: tuple[int, int], dim: int) -> None:
        super().__init__()
        self.input_resolution = input_resolution
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = nn.LayerNorm(4 * dim)

    def forward(self, x: Tensor) -> Tensor:
        height, width = self.input_resolution
        batch_size, length, channels = x.shape
        if length != height * width:
            raise ValueError(f"Expected token length {height * width}, got {length}.")
        if height % 2 != 0 or width % 2 != 0:
            raise ValueError("PatchMerging requires even spatial dimensions.")
        x = x.view(batch_size, height, width, channels)
        x0 = x[:, 0::2, 0::2, :]
        x1 = x[:, 1::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, :]
        x3 = x[:, 1::2, 1::2, :]
        x = torch.cat([x0, x1, x2, x3], dim=-1).view(batch_size, -1, 4 * channels)
        x = self.norm(x)
        return self.reduction(x)


class BasicLayer(nn.Module):
    def __init__(
        self,
        dim: int,
        input_resolution: tuple[int, int],
        depth: int,
        num_heads: int,
        window_size: int,
        mlp_ratio: float,
        drop: float,
        attn_drop: float,
        drop_path: Sequence[float],
        downsample: bool,
        use_moe_mlp: bool,
        moe_num_experts: int,
        moe_top_k: int,
    ) -> None:
        super().__init__()
        self.blocks = nn.ModuleList(
            [
                SwinTransformerBlock(
                    dim=dim,
                    input_resolution=input_resolution,
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=0 if index % 2 == 0 else window_size // 2,
                    mlp_ratio=mlp_ratio,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path[index],
                    use_moe_mlp=use_moe_mlp,
                    moe_num_experts=moe_num_experts,
                    moe_top_k=moe_top_k,
                )
                for index in range(depth)
            ]
        )
        self.downsample = PatchMerging(input_resolution, dim) if downsample else None

    def moe_loss(self) -> Tensor | None:
        losses = [
            block.mlp.aux_loss
            for block in self.blocks
            if isinstance(block.mlp, DroplessMoEMlp) and block.mlp.aux_loss is not None
        ]
        if not losses:
            return None
        return torch.stack(losses).mean()

    def moe_z_loss(self) -> Tensor | None:
        losses = [
            block.mlp.z_loss
            for block in self.blocks
            if isinstance(block.mlp, DroplessMoEMlp) and block.mlp.z_loss is not None
        ]
        if not losses:
            return None
        return torch.stack(losses).mean()

    def forward(self, x: Tensor) -> Tensor:
        for block in self.blocks:
            x = block(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x


class PatchEmbed(nn.Module):
    def __init__(self, image_size: int, patch_size: int, in_channels: int, embed_dim: int) -> None:
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.grid_size = image_size // patch_size
        self.num_patches = self.grid_size * self.grid_size
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x: Tensor) -> Tensor:
        _, _, height, width = x.shape
        if height != self.image_size or width != self.image_size:
            raise ValueError(f"Expected input size {self.image_size}x{self.image_size}, got {height}x{width}.")
        x = self.proj(x).flatten(2).transpose(1, 2)
        return self.norm(x)


class SwinTransformer(nn.Module):
    def __init__(
        self,
        image_size: int = 32,
        patch_size: int = 4,
        in_channels: int = 3,
        num_classes: int = 10,
        embed_dim: int = 64,
        depths: Sequence[int] = (2, 2, 2, 2),
        num_heads: Sequence[int] = (2, 4, 8, 16),
        window_size: int = 4,
        mlp_ratio: float = 4.0,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.1,
        moe_stages: Sequence[int] = (0, 1),
        moe_num_experts: int = 4,
        moe_top_k: int = 2,
    ) -> None:
        super().__init__()
        if image_size % patch_size != 0:
            raise ValueError("image_size must be divisible by patch_size.")
        if len(depths) != len(num_heads):
            raise ValueError("depths and num_heads must have the same length.")
        moe_stage_set = set(moe_stages)
        invalid_moe_stages = moe_stage_set.difference(range(len(depths)))
        if invalid_moe_stages:
            raise ValueError(f"moe_stages contains invalid stage indices: {sorted(invalid_moe_stages)}.")
        self.num_layers = len(depths)
        self.num_features = embed_dim * 2 ** (self.num_layers - 1)
        self.patch_embed = PatchEmbed(image_size, patch_size, in_channels, embed_dim)
        self.pos_drop = nn.Dropout(drop_rate)
        total_depth = sum(depths)
        drop_paths = torch.linspace(0, drop_path_rate, total_depth).tolist()
        self.layers = nn.ModuleList()
        resolution = image_size // patch_size
        depth_offset = 0
        for layer_index in range(self.num_layers):
            dim = embed_dim * 2**layer_index
            input_resolution = (resolution, resolution)
            layer = BasicLayer(
                dim=dim,
                input_resolution=input_resolution,
                depth=depths[layer_index],
                num_heads=num_heads[layer_index],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=drop_paths[depth_offset : depth_offset + depths[layer_index]],
                downsample=layer_index < self.num_layers - 1,
                use_moe_mlp=layer_index in moe_stage_set,
                moe_num_experts=moe_num_experts,
                moe_top_k=moe_top_k,
            )
            self.layers.append(layer)
            depth_offset += depths[layer_index]
            if layer_index < self.num_layers - 1:
                resolution //= 2
        self.norm = nn.LayerNorm(self.num_features)
        self.head = nn.Linear(self.num_features, num_classes)
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        if isinstance(module, nn.Linear):
            nn.init.trunc_normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, x: Tensor) -> Tensor:
        x = self.patch_embed(x)
        x = self.pos_drop(x)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        x = x.mean(dim=1)
        return self.head(x)

    def moe_loss(self) -> Tensor | None:
        losses = [layer.moe_loss() for layer in self.layers]
        losses = [loss for loss in losses if loss is not None]
        if not losses:
            return None
        return torch.stack(losses).mean()

    def moe_z_loss(self) -> Tensor | None:
        losses = [layer.moe_z_loss() for layer in self.layers]
        losses = [loss for loss in losses if loss is not None]
        if not losses:
            return None
        return torch.stack(losses).mean()


class SwinCIFAR10Classifier(pl.LightningModule):
    def __init__(
        self,
        num_classes: int = 10,
        image_size: int = 32,
        patch_size: int = 4,
        in_channels: int = 3,
        embed_dim: int = 64,
        depths: Sequence[int] = (2, 2, 2, 2),
        num_heads: Sequence[int] = (2, 4, 8, 16),
        window_size: int = 4,
        mlp_ratio: float = 4.0,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.1,
        moe_stages: Sequence[int] = (0, 1),
        moe_num_experts: int = 4,
        moe_top_k: int = 2,
        moe_aux_loss_weight: float = 0.01,
        moe_z_loss_weight: float = 0.001,
        learning_rate: float = 0.001,
        weight_decay: float = 0.05,
    ) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.model = SwinTransformer(
            image_size=image_size,
            patch_size=patch_size,
            in_channels=in_channels,
            num_classes=num_classes,
            embed_dim=embed_dim,
            depths=depths,
            num_heads=num_heads,
            window_size=window_size,
            mlp_ratio=mlp_ratio,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            moe_stages=moe_stages,
            moe_num_experts=moe_num_experts,
            moe_top_k=moe_top_k,
        )
        self.criterion = nn.CrossEntropyLoss()
        self.train_acc = MulticlassAccuracy(num_classes=num_classes)
        self.val_acc = MulticlassAccuracy(num_classes=num_classes)
        self.test_acc = MulticlassAccuracy(num_classes=num_classes)

    def forward(self, x: Tensor) -> Tensor:
        return self.model(x)

    def _shared_step(self, batch: tuple[Tensor, Tensor], stage: str) -> Tensor:
        images, targets = batch
        logits = self(images)
        loss = self.criterion(logits, targets)
        moe_aux_loss = self.model.moe_loss()
        if moe_aux_loss is not None:
            self.log(f"{stage}_moe_aux_loss", moe_aux_loss, prog_bar=False, on_step=stage == "train", on_epoch=True)
            if stage == "train" and self.hparams.moe_aux_loss_weight > 0.0:
                loss = loss + self.hparams.moe_aux_loss_weight * moe_aux_loss
        moe_z_loss = self.model.moe_z_loss()
        if moe_z_loss is not None:
            self.log(f"{stage}_moe_z_loss", moe_z_loss, prog_bar=False, on_step=stage == "train", on_epoch=True)
            if stage == "train" and self.hparams.moe_z_loss_weight > 0.0:
                loss = loss + self.hparams.moe_z_loss_weight * moe_z_loss
        preds = torch.argmax(logits, dim=1)
        metric = getattr(self, f"{stage}_acc")
        metric(preds, targets)
        self.log(f"{stage}_loss", loss, prog_bar=True, on_step=stage == "train", on_epoch=True)
        self.log(f"{stage}_acc", metric, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def training_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        return self._shared_step(batch, "train")

    def validation_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        return self._shared_step(batch, "val")

    def test_step(self, batch: tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        return self._shared_step(batch, "test")

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay,
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.trainer.max_epochs)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}
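The way `_shared_step` combines the losses can be illustrated with a small pure-Python sketch. It assumes the auxiliary loss has the Switch Transformer load-balancing form (N times the sum of f_i * P_i over experts) and the z-loss has the ST-MoE form (mean squared log-sum-exp of the router logits); the actual `DroplessMoEMlp` may define them differently. The function names and the toy logits are hypothetical; only the weights 0.01 and 0.001 are taken from the default hyperparameters.

```python
import math


def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [v / total for v in exps]


def load_balance_loss(logits, assignments):
    """Assumed form: num_experts * sum_i f_i * P_i."""
    num_tokens = len(logits)
    num_experts = len(logits[0])
    probs = [softmax(row) for row in logits]
    num_slots = sum(len(chosen) for chosen in assignments)
    # f_i: fraction of routing slots that were assigned to expert i
    frac = [sum(chosen.count(i) for chosen in assignments) / num_slots for i in range(num_experts)]
    # P_i: mean router probability of expert i over all tokens
    mean_prob = [sum(p[i] for p in probs) / num_tokens for i in range(num_experts)]
    return num_experts * sum(f * p for f, p in zip(frac, mean_prob))


def router_z_loss(logits):
    """Assumed form: mean over tokens of logsumexp(logits) ** 2."""
    total = 0.0
    for row in logits:
        m = max(row)
        lse = m + math.log(sum(math.exp(v - m) for v in row))
        total += lse * lse
    return total / len(logits)


def total_loss(ce_loss, aux_loss, z_loss, aux_weight=0.01, z_weight=0.001):
    # Mirrors the training branch of _shared_step: CE + weighted MoE terms
    return ce_loss + aux_weight * aux_loss + z_weight * z_loss


# Toy example with 4 tokens and 2 experts (hypothetical values)
balanced = load_balance_loss([[0.0, 0.0]] * 4, [[0], [0], [1], [1]])
collapsed = load_balance_loss([[4.0, 0.0]] * 4, [[0], [0], [0], [0]])
print(balanced, collapsed)
```

Under these assumptions the balanced routing gives the minimum value of 1.0, while routing every token to one expert gives a larger value, so the auxiliary term pushes the gate toward even expert usage. Note that in `_shared_step` the weighted terms are added only when `stage == "train"`; validation and test report the plain cross-entropy loss.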
config.yaml
seed_everything: 42
trainer:
  accelerator: auto
  devices: auto
  max_epochs: 20
  precision: bf16-mixed
  callbacks:
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        monitor: val_acc
        mode: max
        save_top_k: 1
        filename: swin-cifar10-{epoch:02d}-{val_acc:.4f}
    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
      init_args:
        logging_interval: epoch
model:
  num_classes: 10
  image_size: 32
  patch_size: 4
  in_channels: 3
  embed_dim: 64
  depths: [2, 2, 2, 2]
  num_heads: [2, 4, 8, 16]
  window_size: 4
  mlp_ratio: 4.0
  drop_rate: 0.0
  attn_drop_rate: 0.0
  drop_path_rate: 0.1
  moe_stages: [0, 1]
  moe_num_experts: 4
  moe_top_k: 2
  moe_aux_loss_weight: 0.01
  moe_z_loss_weight: 0.001
  learning_rate: 0.001
  weight_decay: 0.05
data:
  data_dir: data
  batch_size: 128
  num_workers: 4
  image_size: 32
  val_split: 5000
  seed: 42
  download: true
Comparison results
Number of parameters
| Model | Parameters |
|---|---|
| Without MoE | 9.1 M |
| With MoE | 10.1 M |
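As a rough sanity check on the table, the extra parameters can be estimated from the configuration. This is a minimal sketch under stated assumptions: each expert is taken to be a two-layer MLP with biases and the same shape as the dense `Mlp` (hidden size = dim × mlp_ratio), and the gate is taken to be a bias-free linear layer from dim to the number of experts. The helper names are hypothetical and the real `DroplessMoEMlp` may differ in detail.

```python
def mlp_params(dim: int, mlp_ratio: float = 4.0) -> int:
    """Parameters of a two-layer MLP (dim -> hidden -> dim) with biases."""
    hidden = int(dim * mlp_ratio)
    return (dim * hidden + hidden) + (hidden * dim + dim)


def moe_extra_params(dim: int, num_experts: int, mlp_ratio: float = 4.0) -> int:
    """Parameters added when one dense MLP is replaced by an MoE layer."""
    gate = dim * num_experts  # assumed bias-free linear gate
    return (num_experts - 1) * mlp_params(dim, mlp_ratio) + gate


# Values taken from the configuration in this article
embed_dim = 64
depths = (2, 2, 2, 2)
moe_stages = (0, 1)
num_experts = 4

extra = sum(
    depths[stage] * moe_extra_params(embed_dim * 2**stage, num_experts)
    for stage in moe_stages
)
print(f"{extra / 1e6:.2f} M")  # -> 0.99 M
```

Under these assumptions the estimate of about 0.99 M extra parameters agrees with the roughly 1.0 M difference between the two rows of the table.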
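Evaluation loss

Note: orange = without MoE; blue = with MoE
Evaluation accuracy

Note: orange = without MoE; blue = with MoE
Training time


Note: orange = without MoE; blue = with MoE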
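Discussion
The model with MoE reaches higher evaluation accuracy.
It has about 1.1 times as many parameters as the model without MoE (10.1 M vs. 9.1 M).
Training took about 1.4 times as long with MoE.
Each token is routed to its Top-2 experts, so the additional expert computation may account for part of this increase.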
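The compute argument above can be made concrete with a rough per-token FLOP estimate. This sketch assumes each expert has the same hidden size as the dense MLP (dim × mlp_ratio), counts only the matrix multiplications (2 FLOPs per multiply-add), and ignores routing, token permutation and memory traffic. The function names are hypothetical.

```python
def mlp_flops_per_token(dim: int, mlp_ratio: float = 4.0) -> int:
    """Matmul FLOPs of a dim -> hidden -> dim MLP for one token."""
    hidden = int(dim * mlp_ratio)
    return 2 * dim * hidden + 2 * hidden * dim


def moe_flops_per_token(dim: int, num_experts: int, top_k: int, mlp_ratio: float = 4.0) -> int:
    """Matmul FLOPs of an MoE layer: gate projection plus top_k experts."""
    gate = 2 * dim * num_experts
    return top_k * mlp_flops_per_token(dim, mlp_ratio) + gate


dim, num_experts = 64, 4  # stage-0 values from the configuration
dense = mlp_flops_per_token(dim)
ratio_top2 = moe_flops_per_token(dim, num_experts, top_k=2) / dense
ratio_top1 = moe_flops_per_token(dim, num_experts, top_k=1) / dense
print(f"Top-2: {ratio_top2:.3f}x, Top-1: {ratio_top1:.3f}x")
```

In this simplified model a Top-2 layer needs about twice the MLP FLOPs of the dense layer, while a Top-1 layer needs about the same, so comparing the two settings is a reasonable way to separate expert computation from other overhead.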
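To check this, I also measured Top-1 routing.


Note: red = Top-1
Training time was almost unchanged with Top-1.
I therefore attribute the slowdown to the overhead of the gating (routing) process that MoE introduces.
In addition, with Top-1 the accuracy was almost the same as without MoE.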
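The gating overhead mentioned above comes largely from bookkeeping that a dense MLP does not need: tokens have to be grouped by expert before the expert GEMMs and restored to their original order afterwards. The following is a minimal pure-Python sketch of that grouping step with hypothetical names. It is not the article's `DroplessMoEMlp`, which would perform the same steps with tensor operations on the GPU.

```python
from itertools import accumulate


def route(expert_ids, num_experts):
    """Group (token, expert) pairs by expert.

    expert_ids[t] lists the top-k experts chosen for token t.
    Returns the token order after grouping, the number of rows per
    expert, and the cumulative end offset of each expert's group.
    """
    # Flatten to (expert, token) pairs, one per routing slot
    pairs = [(e, t) for t, chosen in enumerate(expert_ids) for e in chosen]
    # Stable sort by expert so each expert's rows become contiguous
    order = sorted(range(len(pairs)), key=lambda i: pairs[i][0])
    token_order = [pairs[i][1] for i in order]
    counts = [0] * num_experts
    for expert, _ in pairs:
        counts[expert] += 1
    offsets = list(accumulate(counts))
    return token_order, counts, offsets


# 4 tokens, 4 experts, Top-2 routing (hypothetical assignments)
token_order, counts, offsets = route([[0, 2], [2, 1], [0, 1], [2, 3]], num_experts=4)
print(token_order, counts, offsets)
```

The group sizes are variable (here 2, 2, 3 and 1 rows), which is why a grouped GEMM is needed instead of a fixed-size batched matmul. For a small model these sort, gather and scatter steps can take a noticeable share of the step time because the expert GEMMs themselves are tiny.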

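Summary
I implemented a dropless MoE using PyTorch's grouped_mm and evaluated its accuracy and training time.
The Top-2 MoE improved accuracy, but training slowed down by more than the ratio of parameter counts (about 1.4x versus 1.1x).
Because the experiment used a small SwinTransformer, the gating overhead was probably relatively large.
Next, I would like to implement inference with a TensorRT custom plugin.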