TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

dlshogiモデルの枝刈りを試す

前回、深層強化学習において、モデルの枝刈りによりスケーリングが可能であることを示した論文を紹介した。

dlshogiの強化学習でもモデルの枝刈りが効果があるか試したいと考えているが、まずはモデルの枝刈りのみを行って、精度と探索速度にどう影響するかを調べてみる。

モデルの枝刈り

モデルの枝刈り(pruning)は、休止状態のニューロンを削除することで、モデルサイズを削減し、推論速度を向上させる手法である。

PyTorchには枝刈りの方法として、unstructuredとstructuredの2種類の方法が用意されている。

unstructuredは、層のパラメータ全体から一定の割合でパラメータを削除する。

structuredは、一定の割合で畳み込みのチャンネル全体を削除する。

パラメータの削除は、実際には削除は行っておらずパラメータを0にすることで実装されている。

値が0のパラメータも演算は必要なため、推論速度を向上するには、ハードウェア的な支援が必要になる。

NVIDIAGPUでは、Ampere以降で、structuredの枝刈りに対応している。
NVIDIA Ampere アーキテクチャと TensorRT を使用してスパース性で推論を高速化する - NVIDIA 技術ブログ

ただし、連続する4つのチャンネルのうち2つが0である必要があるため、枝刈り率50%以上でないと効果がない。

先述の論文では、枝刈りは5%でパフォーマンスが上がることが報告されているため、枝刈りしても推論速度の向上は期待できない。

PyTorchでの実装

チェックポイントを読み込んで、モデルの各層のモジュールをnamed_modules()で取得して、畳み込み層の場合、weightに対して、prune.ln_structuredを適用する。

コードは、以下の通り簡単に記述できる。

import argparse

import torch

from dlshogi.common import *
from dlshogi.network.policy_value_network import policy_value_network
from torch.nn.utils import prune

parser = argparse.ArgumentParser()
parser.add_argument("checkpoint")
parser.add_argument("output_checkpoint")
parser.add_argument("--network", default="resnet10_swish", help="network type")
parser.add_argument("--amount", type=float, default=0.05)
args = parser.parse_args()

device = torch.device("cpu")
model = policy_value_network(args.network)

checkpoint = torch.load(args.checkpoint, map_location=device)
model.load_state_dict(checkpoint["model"])

# prune
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        prune.ln_structured(module, name="weight", n=1, dim=0, amount=args.amount)
        prune.remove(module, "weight")

checkpoint["model"] = model.state_dict()
torch.save(checkpoint, args.output_checkpoint)

実験条件

dlshogiの30ブロック384フィルタのモデルを使用する。
枝刈り率は、論文と同じ5%の他に、推論速度の違いを見るため50%、90%でも試す。
テストデータには、floodgateのR3800以上の棋譜を使用する。
探索速度は、初期局面で10秒思考したときのNPSを測定する。
GPUは4090 1枚を使用する。

精度

通常は枝刈り後にファインチューニングを行うが、ファインチューニング前の精度を確かめた。

枝刈り 方策損失 価値損失 方策正解率 価値正解率
なし 1.3057 0.4402 0.5585 0.7765
5% 4.484 0.644 0.3173 0.6775
50% 7.3963 0.8144 0.0090 0.5878
90% 9.6545 1.5117 0.0001 0.4989


枝刈り5%でも、方策の正解率は24%、価値の正解率は9.9%下がっており、精度はかなり落ちる。
枝刈り50%、90%は、方策の正解率が1%未満であり、使い物にならない精度である。

推論速度

ソース修正

TensorRTで、枝刈りしたモデルの推論速度を向上させるには、オプションの設定が必要になる。

公式ドキュメントのSparsityの項目に説明がある。
Developer Guide :: NVIDIA Deep Learning TensorRT Documentation

ソースコードに以下の行を追加した。

config->setFlag(BuilderFlag::kSPARSE_WEIGHTS);
比較結果
枝刈り NPS
なし 20255
5% 19836
50% 23805
90% 26020

枝刈り5%の場合、枝刈りなしとほぼ同じの探索速度である。
枝刈り50%の場合17.5%、枝刈り90%の場合28.5%向上する。

ハードウェア支援が効果があることが確認できた。

kSPARSE_WEIGHTSオプションなしの場合

kSPARSE_WEIGHTSオプションなしの場合でも確認した。

枝刈り NPS
なし 20255
5% 19698
50% 20898
90% 20470

どの条件もほぼ同じ探索速度である。

まとめ

dlshogiモデルで枝刈りを試した。
5%の枝刈りでも精度が大きく落ちることが分かった。
50%の枝刈りでは方策正解率が1%未満になり使い物にならないほど精度が下がった。

また、TensorRTでオプションを有効にすると、枝刈り50%以上では探索速度が向上することが確認できた。
しかし、枝刈り5%では探索速度は変わらなかった。
論文で報告されている枝刈り率は5%のため、探索速度の向上は期待できない。

ファインチューニングを行った後に精度が回復するかは別途確認したい。