Unity Barracudaは、Unity用のマルチプラットフォームに対応した推論パッケージである。
ONNXのモデルをロードでき、WindowsでもAndroidでもiOSでも(WeGLでも?)同じコードでディープラーニングモデルの推論ができる。
Barracudaで、dlshogiのモデルが扱えるか使えるか試してみた。
UnityにBarracudaをインストール
UnityでBarracudaを使えるようにするには、Package Managerで、Barracudaを追加する必要がある。
Unity Repositoryから検索できないため、Add Package from git URLに、「com.unity.barracuda」と入力して追加する必要がある。
参考:unity3d - Can't find Barracuda package in Unity Registry - Stack Overflow
モデルをAssetsに追加
AssetsにModelというフォルダを作り(フォルダを作らなくても良い)、dlshogiのモデル(.onnx)をエクスプローラからドラッグ&ドロップして追加する。
追加したモデルを選択してインスペクタで確認すると警告が表示された。
output_policyの次元も間違った値になっている。
原因を調べたところ、torch.flattenを使うと、警告がでるようだ。
torch.flattenをやめて、view(-1, N)にすると、警告がでなくなった。
ただし、Nの部分は、数値リテラルで入力する必要がある。
出力の次元はPythonで計算で求めているので、変数で指定したいが数値リテラルでないとonnxにできなくなる。
モデル学習
torch.flattenをviewに変えただけなので、dlshogiのモデルのパラメータはロードできるが、今回はテスト用にモデルサイズを小さくして、floodgateの棋譜を学習させた。
推論のコード
HierarchyのルートにC#コードを追加して、以下の通りモデルをロードして推論を行うコードを記述した。
入力特徴量作成は、dlshogiのスクリプトで初期局面の入力特徴量を作成してバイナリで出力して、Assets/Resorcesに追加して、ファイルから読み込むようにしている。
注意点として、モデルはNCHW形式でも、入力Tensorは、NHWCにしないといけない点だ。
入力特徴量をファイル出力する際に、transposeで次元の変換を行った。
using System; using System.Collections; using System.Collections.Generic; using Unity.Barracuda; using UnityEngine; public class GameManager : MonoBehaviour { public NNModel modelAsset; // Start is called before the first frame update void Start() { Model runtimeModel = ModelLoader.Load(modelAsset, true); var worker = WorkerFactory.CreateWorker(WorkerFactory.Type.Compute, runtimeModel); TextAsset textAsset1 = Resources.Load<TextAsset>("input1"); int inputSize1 = textAsset1.bytes.Length / 4; float[] inputData1 = new float[inputSize1 * 2]; // batchsize2 for (int i = 0; i < inputSize1; ++i) { inputData1[i] = BitConverter.ToSingle(textAsset1.bytes, i * 4); inputData1[inputSize1 + i] = BitConverter.ToSingle(textAsset1.bytes, i * 4); } Tensor inputTensor1 = new Tensor(2, 9, 9, 62, inputData1); TextAsset textAsset2 = Resources.Load<TextAsset>("input2"); int inputSize2 = textAsset2.bytes.Length / 4; float[] inputData2 = new float[inputSize2 * 2]; for (int i = 0; i < inputSize2; ++i) { inputData2[i] = BitConverter.ToSingle(textAsset2.bytes, i * 4); inputData2[inputSize2 + i] = BitConverter.ToSingle(textAsset2.bytes, i * 4); } Tensor inputTensor2 = new Tensor(2, 9, 9, 57, inputData2); worker.Execute(new Dictionary<string, Tensor> { { "input1", inputTensor1 }, { "input2", inputTensor2 } }); Tensor outputTensorPolicy = worker.PeekOutput("output_policy"); Tensor outputTensorValue = worker.PeekOutput("output_value"); int batchsize = 2; float[] outputPolicy = outputTensorPolicy.data.Download(new TensorShape(2187 * batchsize)); float[] outputValue = outputTensorValue.data.Download(new TensorShape(batchsize)); var legalLabels = new Dictionary<string, int> { {"1g1f", 5}, {"2g2f", 14}, {"3g3f", 23}, {"4g4f", 32}, {"5g5f", 41}, {"6g6f", 50}, {"7g7f", 59}, {"8g8f", 68}, {"9g9f", 77}, {"1i1h", 7}, {"9i9h", 79}, {"3i3h", 25}, {"3i4h", 115}, {"7i6h", 214}, {"7i7h", 61}, {"2h1h", 331}, {"2h3h", 268}, {"2h4h", 277}, {"2h5h", 286}, {"2h6h", 295}, {"2h7h", 304}, {"4i3h", 187}, {"4i4h", 34}, {"4i5h", 124}, {"6i5h", 205}, {"6i6h", 52}, {"6i7h", 142}, {"5i4h", 196}, {"5i5h", 43}, {"5i6h", 133}, }; for (int i = 0; i < batchsize; ++i) { int offset = 2187 * i; float max = 0.0f; foreach (var kv in legalLabels) { float x = outputPolicy[offset + kv.Value]; if (x > max) { max = x; } } // オーバーフローを防止するため最大値で引く float sum = 0.0f; foreach (var kv in legalLabels) { float x = Mathf.Exp(outputPolicy[offset + kv.Value] - max); outputPolicy[offset + kv.Value] = x; sum += x; } // normalize foreach (var kv in legalLabels) { outputPolicy[offset + kv.Value] /= sum; } foreach (var kv in legalLabels) { Debug.Log($"{kv.Key} {outputPolicy[offset + kv.Value]}"); } Debug.Log(outputValue[i]); } inputTensor1.Dispose(); inputTensor2.Dispose(); outputTensorPolicy.Dispose(); outputTensorValue.Dispose(); worker.Dispose(); } // Update is called once per frame void Update() { } }
モデル設定
エディタ側で、modelAssetに、Assetsに追加したモデルを設定する。
結果
実行すると、推論結果は、コンソールログに以下のように表示される。
1g1f 0.03669669 2g2f 0.3327546 3g3f 0.003403991 4g4f 0.001507061 5g5f 0.001972405 6g6f 0.0003284561 7g7f 0.4425911 8g8f 0.000323784 9g9f 0.0374275 1i1h 3.537552E-05 9i9h 1.628024E-05 3i3h 0.01439433 3i4h 0.02409641 7i6h 0.008480824 7i7h 0.009563878 2h1h 0.0001176913 2h3h 0.0001388098 2h4h 0.0004411801 2h5h 0.00565653 2h6h 0.01059533 2h7h 0.004915519 4i3h 4.113196E-05 4i4h 1.68125E-05 4i5h 0.002398505 6i5h 0.0005775181 6i6h 2.587938E-05 6i7h 0.04750725 5i4h 0.0005738597 5i5h 8.232462E-05 5i6h 0.01331916 0.5261418 (2バッチ目は省略)
dlshogiでDebugMessage=true、Softmax_Temperature=100にして、表示した結果とほぼ一致することを確認した。
浮動小数点の精度の違いで少しだけ値がずれるが誤差の範囲である。
警告の出たモデルで推論
警告が表示された元のdlshogiのモデルでは、推論結果は正しく出力されなかった。
モデルの修正は必須のようである。