TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

Unity Barracudaでdlshogiのモデルを推論する

Unity Barracudaは、Unity用のマルチプラットフォームに対応した推論パッケージである。

ONNXのモデルをロードでき、WindowsでもAndroidでもiOSでも(WeGLでも?)同じコードでディープラーニングモデルの推論ができる。

Barracudaで、dlshogiのモデルが扱えるか使えるか試してみた。

UnityにBarracudaをインストール

UnityでBarracudaを使えるようにするには、Package Managerで、Barracudaを追加する必要がある。
Unity Repositoryから検索できないため、Add Package from git URLに、「com.unity.barracuda」と入力して追加する必要がある。
参考:unity3d - Can't find Barracuda package in Unity Registry - Stack Overflow

モデルをAssetsに追加

AssetsにModelというフォルダを作り(フォルダを作らなくても良い)、dlshogiのモデル(.onnx)をエクスプローラからドラッグ&ドロップして追加する。

追加したモデルを選択してインスペクタで確認すると警告が表示された。
f:id:TadaoYamaoka:20220216003243p:plain

output_policyの次元も間違った値になっている。

原因を調べたところ、torch.flattenを使うと、警告がでるようだ。
torch.flattenをやめて、view(-1, N)にすると、警告がでなくなった。
ただし、Nの部分は、数値リテラルで入力する必要がある。
出力の次元はPythonで計算で求めているので、変数で指定したいが数値リテラルでないとonnxにできなくなる。

モデル学習

torch.flattenをviewに変えただけなので、dlshogiのモデルのパラメータはロードできるが、今回はテスト用にモデルサイズを小さくして、floodgateの棋譜を学習させた。

推論のコード

HierarchyのルートにC#コードを追加して、以下の通りモデルをロードして推論を行うコードを記述した。
入力特徴量作成は、dlshogiのスクリプトで初期局面の入力特徴量を作成してバイナリで出力して、Assets/Resorcesに追加して、ファイルから読み込むようにしている。

注意点として、モデルはNCHW形式でも、入力Tensorは、NHWCにしないといけない点だ。
入力特徴量をファイル出力する際に、transposeで次元の変換を行った。

using System;
using System.Collections;
using System.Collections.Generic;
using Unity.Barracuda;
using UnityEngine;

public class GameManager : MonoBehaviour
{
    public NNModel modelAsset;

    // Start is called before the first frame update
    void Start()
    {
        Model runtimeModel = ModelLoader.Load(modelAsset, true);

        var worker = WorkerFactory.CreateWorker(WorkerFactory.Type.Compute, runtimeModel);

        TextAsset textAsset1 = Resources.Load<TextAsset>("input1");
        int inputSize1 = textAsset1.bytes.Length / 4;
        float[] inputData1 = new float[inputSize1 * 2]; // batchsize2
        for (int i = 0; i < inputSize1; ++i)
        {
            inputData1[i] = BitConverter.ToSingle(textAsset1.bytes, i * 4);
            inputData1[inputSize1 + i] = BitConverter.ToSingle(textAsset1.bytes, i * 4);
        }
        Tensor inputTensor1 = new Tensor(2, 9, 9, 62, inputData1);

        TextAsset textAsset2 = Resources.Load<TextAsset>("input2");
        int inputSize2 = textAsset2.bytes.Length / 4;
        float[] inputData2 = new float[inputSize2 * 2];
        for (int i = 0; i < inputSize2; ++i)
        {
            inputData2[i] = BitConverter.ToSingle(textAsset2.bytes, i * 4);
            inputData2[inputSize2 + i] = BitConverter.ToSingle(textAsset2.bytes, i * 4);
        }
        Tensor inputTensor2 = new Tensor(2, 9, 9, 57, inputData2);

        worker.Execute(new Dictionary<string, Tensor> { { "input1", inputTensor1 }, { "input2", inputTensor2 } });

        Tensor outputTensorPolicy = worker.PeekOutput("output_policy");
        Tensor outputTensorValue = worker.PeekOutput("output_value");

        int batchsize = 2;
        float[] outputPolicy = outputTensorPolicy.data.Download(new TensorShape(2187 * batchsize));
        float[] outputValue = outputTensorValue.data.Download(new TensorShape(batchsize));

        var legalLabels = new Dictionary<string, int>
        {
            {"1g1f", 5},
            {"2g2f", 14},
            {"3g3f", 23},
            {"4g4f", 32},
            {"5g5f", 41},
            {"6g6f", 50},
            {"7g7f", 59},
            {"8g8f", 68},
            {"9g9f", 77},
            {"1i1h", 7},
            {"9i9h", 79},
            {"3i3h", 25},
            {"3i4h", 115},
            {"7i6h", 214},
            {"7i7h", 61},
            {"2h1h", 331},
            {"2h3h", 268},
            {"2h4h", 277},
            {"2h5h", 286},
            {"2h6h", 295},
            {"2h7h", 304},
            {"4i3h", 187},
            {"4i4h", 34},
            {"4i5h", 124},
            {"6i5h", 205},
            {"6i6h", 52},
            {"6i7h", 142},
            {"5i4h", 196},
            {"5i5h", 43},
            {"5i6h", 133},
        };
        for (int i = 0; i < batchsize; ++i)
        {
            int offset = 2187 * i;
            float max = 0.0f;
            foreach (var kv in legalLabels)
            {
                float x = outputPolicy[offset + kv.Value];
                if (x > max)
                {
                    max = x;
                }
            }
            // オーバーフローを防止するため最大値で引く
            float sum = 0.0f;
            foreach (var kv in legalLabels)
            {
                float x = Mathf.Exp(outputPolicy[offset + kv.Value] - max);
                outputPolicy[offset + kv.Value] = x;
                sum += x;
            }
            // normalize
            foreach (var kv in legalLabels)
            {
                outputPolicy[offset + kv.Value] /= sum;
            }
            foreach (var kv in legalLabels)
            {
                Debug.Log($"{kv.Key} {outputPolicy[offset + kv.Value]}");
            }

            Debug.Log(outputValue[i]);
        }

        inputTensor1.Dispose();
        inputTensor2.Dispose();
        outputTensorPolicy.Dispose();
        outputTensorValue.Dispose();
        worker.Dispose();
    }

    // Update is called once per frame
    void Update()
    {
    }
}

モデル設定

エディタ側で、modelAssetに、Assetsに追加したモデルを設定する。

結果

実行すると、推論結果は、コンソールログに以下のように表示される。

1g1f 0.03669669
2g2f 0.3327546
3g3f 0.003403991
4g4f 0.001507061
5g5f 0.001972405
6g6f 0.0003284561
7g7f 0.4425911
8g8f 0.000323784
9g9f 0.0374275
1i1h 3.537552E-05
9i9h 1.628024E-05
3i3h 0.01439433
3i4h 0.02409641
7i6h 0.008480824
7i7h 0.009563878
2h1h 0.0001176913
2h3h 0.0001388098
2h4h 0.0004411801
2h5h 0.00565653
2h6h 0.01059533
2h7h 0.004915519
4i3h 4.113196E-05
4i4h 1.68125E-05
4i5h 0.002398505
6i5h 0.0005775181
6i6h 2.587938E-05
6i7h 0.04750725
5i4h 0.0005738597
5i5h 8.232462E-05
5i6h 0.01331916
0.5261418
(2バッチ目は省略)

dlshogiでDebugMessage=true、Softmax_Temperature=100にして、表示した結果とほぼ一致することを確認した。
浮動小数点の精度の違いで少しだけ値がずれるが誤差の範囲である。

警告の出たモデルで推論

警告が表示された元のdlshogiのモデルでは、推論結果は正しく出力されなかった。
モデルの修正は必須のようである。

まとめ

Unity Barracudaでdlshogiのモデルが推論できることを確認した。
確認したのはWindow上だが、AndroidiOSでも同じコードで実行できるはずである。

少し躓いたのは、torch.flattenが使えなかったことであるが、torch.flattenをviewに変えるだけで解決できた。

推論速度は別途測定したい。