TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

C#からLibTorchを使ってMNISTデータセットを学習する

C#からディープラーニングフレームワークを使用する方法について、以前にいくつかの方法を検討した。

課題

プロセス間通信を行う方法は、オーバーヘッドによってパフォーマンスが低下する問題がある。
C#のTensorFlowライブラリを使う方法は、推論のみであれば問題ないが学習まで行う場合は、Pythonと同等に扱えるほど完成度が高いライブラリが見つからない。

LibTorch

C#からLibTorchを使えば、これらの問題を解決できるため試してみた。
LibTorchはC++用ライブラリで、PyTorchと同等の機能がC++から利用できる。
TensorFlowにもC言語用ライブラリがあるが、プリミティブな機能しか提供されておらず、Pythonと同等に扱うことはできない。

LibTorchはC#バインディングは提供されていないため、C#から利用するには、学習・推論処理をC++で実装して共有ライブラリにして、呼び出す必要がある。
C#C++の境界は、以下のようにすれば効率がよい。

  • C#側で入力特徴量を配列で作成する。
  • 入力特徴量の配列をunsafeとfixedを使用して、ポインタで渡す。
  • C++側で、from_blobを使ってポインタからtorch::Tensorを作成する。

unsafeが必要になるが、言語間で無駄な変換がないため、高いパフォーマンスが実現できる。

欠点としては、C#から共有ライブラリを使う場合、デバッグが難しくなる。
対策として、共有ライブラリをC++のプログラムからも呼び出すテストコードを作成して、C++のデバッガを使ってデバッグするとよい。

C#からLibTorchを使ってMNISTデータセットを学習するサンプルプログラム

以下、C#からLibTorchを使ってMNISTデータセットを学習するサンプルプログラムについて説明する。
Windowsを想定して説明しているが、Linuxでも実行できる。

完全なソースコードは、GitHubで公開している。
GitHub - TadaoYamaoka/CSLibTorchMNIST

Linux用には、Google ColabでJupyter NoteBookを作成している。
https://colab.research.google.com/drive/1vFzEPjkQ8wLN_OEZODSauePHbzhAfs-s

LibTorchのダウンロード

https://pytorch.orgのQUICK START LOCALLYから、Stable/Windows/LibTorch/C++/10.0(CUDAのバージョンが10.0の場合)を選んでlibtorch-win-shared-with-deps-latest.zipをダウンロードする。
適当な場所に解凍する(以下、C:\に解凍したとして説明)。

C++で学習・推論処理を実装

C++でLibTorchを使って、学習・推論処理を実装する。
C#側で読み込んだMNISTデータセットを、引数でポインタとして受け取るようにする。

ほとんどの処理は、公式のC++のサンプルを流用している。
共有ライブラリの実装については、こちらの記事も参照。

[mnist.cpp]

#include <torch/torch.h>

#include <cstddef>
#include <cstdio>
#include <iostream>
#include <string>
#include <vector>
#include <memory>

#ifdef _MSC_VER
#define DLL_EXPORT __declspec(dllexport)
#else
#define DLL_EXPORT
#endif

extern "C"
{
  DLL_EXPORT void init();
  DLL_EXPORT void train(float *dataset, int64_t *targetset, int dataset_size);
  DLL_EXPORT void test(float* dataset, int64_t* targetset, int dataset_size);
}

// The batch size for training.
const int64_t kTrainBatchSize = 64;

// The batch size for testing.
const int64_t kTestBatchSize = 1000;

// The number of epochs to train.
const int64_t kNumberOfEpochs = 10;

// After how many batches to log a new update with the loss value.
const int64_t kLogInterval = 10;

struct Net : torch::nn::Module
{
  Net()
      : conv1(torch::nn::Conv2dOptions(1, 10, /*kernel_size=*/5)),
        conv2(torch::nn::Conv2dOptions(10, 20, /*kernel_size=*/5)),
        fc1(320, 50),
        fc2(50, 10)
  {
    register_module("conv1", conv1);
    register_module("conv2", conv2);
    register_module("conv2_drop", conv2_drop);
    register_module("fc1", fc1);
    register_module("fc2", fc2);
  }

  torch::Tensor forward(torch::Tensor x)
  {
    x = torch::relu(torch::max_pool2d(conv1->forward(x), 2));
    x = torch::relu(
        torch::max_pool2d(conv2_drop->forward(conv2->forward(x)), 2));
    x = x.view({-1, 320});
    x = torch::relu(fc1->forward(x));
    x = torch::dropout(x, /*p=*/0.5, /*training=*/is_training());
    x = fc2->forward(x);
    return torch::log_softmax(x, /*dim=*/1);
  }

  torch::nn::Conv2d conv1;
  torch::nn::Conv2d conv2;
  torch::nn::FeatureDropout conv2_drop;
  torch::nn::Linear fc1;
  torch::nn::Linear fc2;
};

std::unique_ptr<Net> model;
std::unique_ptr<torch::Device> device;
std::unique_ptr<torch::optim::Optimizer> optimizer;

DLL_EXPORT
void init()
{
  torch::DeviceType device_type;
  if (torch::cuda::is_available())
  {
    std::cout << "CUDA available! Training on GPU." << std::endl;
    device_type = torch::kCUDA;
  }
  else
  {
    std::cout << "Training on CPU." << std::endl;
    device_type = torch::kCPU;
  }
  device.reset(new torch::Device(device_type));

  model.reset(new Net());
  model->to(*device);

  optimizer.reset(new torch::optim::SGD(
      model->parameters(), torch::optim::SGDOptions(0.01).momentum(0.5)));
}

DLL_EXPORT
void train(
    float *dataset,
    int64_t *targetset,
    int dataset_size) {
  model->train();
  size_t batch_idx = 0;
  for (size_t i = 0; i <= dataset_size - kTrainBatchSize; i += kTrainBatchSize)
  {
    /*std::cout << "batch " << i << std::endl;
    for (int y = 0; y < 28; ++y) {
      for (int x = 0; x < 28; ++x) {
        std::cout << (dataset + i * 28 * 28)[y * 28 + x] << ",";
      }
      std::cout << std::endl;
    }
    std::cout << "label " << targetset[i] << std::endl;*/

    auto data = torch::from_blob(dataset + i * 28 * 28, {kTrainBatchSize, 1, 28, 28}, torch::dtype(torch::kFloat32)).to(*device);
    auto targets = torch::from_blob(targetset + i, {kTrainBatchSize}, torch::dtype(torch::kInt64)).to(*device);
    optimizer->zero_grad();
    auto output = model->forward(data);
    auto loss = torch::nll_loss(output, targets);
    AT_ASSERT(!std::isnan(loss.template item<float>()));
    loss.backward();
    optimizer->step();

    if (batch_idx++ % kLogInterval == 0)
    {
      std::printf(
          "\rTrain [%5ld/%5d] Loss: %.4f",
          batch_idx * kTrainBatchSize,
          dataset_size,
          loss.template item<float>());
    }
  }
}

DLL_EXPORT
void test(
    float *dataset,
    int64_t *targetset,
    int dataset_size) {
  torch::NoGradGuard no_grad;
  model->eval();
  double test_loss = 0;
  int32_t correct = 0;
  for (size_t i = 0; i <= dataset_size - kTestBatchSize; i += kTestBatchSize)
  {
    auto data = torch::from_blob(dataset + i * 28 * 28, {kTestBatchSize, 1, 28, 28}, torch::dtype(torch::kFloat32)).to(*device);
    auto targets = torch::from_blob(targetset + i, {kTestBatchSize}, torch::dtype(torch::kInt64)).to(*device);
    auto output = model->forward(data);
    test_loss += torch::nll_loss(
                     output,
                     targets,
                     /*weight=*/{},
                     Reduction::Sum)
                     .template item<float>();
    auto pred = output.argmax(1);
    correct += pred.eq(targets).sum().template item<int64_t>();
  }

  test_loss /= dataset_size;
  std::printf(
      "\nTest set: Average loss: %.4f | Accuracy: %.3f\n",
      test_loss,
      static_cast<double>(correct) / dataset_size);
}
ビルド

Installing C++ Distributions of PyTorch — PyTorch master documentation
の説明の通り、CMakeを使用してビルドを行う。
共有ライブラリとしてビルドするため、CMakeLists.txtには以下のように記述する。

[CMakeLists.txt]

cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
project(mnist)

find_package(Torch REQUIRED)

add_library(mnist SHARED mnist.cpp)
target_compile_features(mnist PUBLIC cxx_range_for)
target_link_libraries(mnist ${TORCH_LIBRARIES})

add_executable(test test.cpp mnist.cpp)
target_compile_features(test PUBLIC cxx_range_for)
target_link_libraries(test ${TORCH_LIBRARIES})

add_libraryにSHAREDを指定することで、共有ライブラリを生成できる。
最後の3行はC++からテストするためのC++プログラムの設定を記述している。

cmakeでビルド用プロジェクトを生成する。

mkdir build
cd build
cmake -G "Visual Studio 15 2017 Win64" -DCMAKE_PREFIX_PATH=C:\libtorch ..

Visual Studio用のプロジェクトが生成される。
※シンボルあり(RelWithDebInfo)でビルドする場合、.vcxprojのLibTorchのライブラリのパスが、torch-NOTFOUNDになっているので、「C:\libtorch\lib\torch.lib」に置換する。

ビルドする。

msbuild mnist.sln /t:build /p:Configuration=RelWithDebInfo;Platform="x64"
C#側の実装

MNISTデータセットを読み込んで、配列に格納して、引数にポインタを設定して共有ライブラリの関数を呼び出す処理を記述する。

MNISTデータセットは、公式のサンプルと同様に正規化する。
MNISTデータセットはビッグエンディアンなので読み込む際は注意する。
各エポックでランダムにシャッフルする(シャッフルには、MathNet.Numericsを使用した)。
入力特徴量はfloatの配列に格納する。
fixedを使って、配列がガベージコレクションによって再配置されないように固定して、C++で実装した訓練処理を呼び出す。
配列をポインタで渡すため、メソッドにunsafeを指定する。
また、プロジェクトの設定には、

<AllowUnsafeBlocks>true</AllowUnsafeBlocks>

を追加する。

[Program.cs]

using System;
using System.IO;
using System.Runtime.InteropServices;
using MathNet.Numerics;

namespace CSLibTorchMNIST
{
    class Program
    {
        [DllImport("mnist")]
        extern static void init();

        [DllImport("mnist")]
        unsafe extern static void train(float* dataset, long* targetset, int dataset_size);

        [DllImport("mnist")]
        unsafe extern static void test(float* dataset, long* targetset, int dataset_size);

        unsafe static void Main(string[] args)
        {
            init();

            // load mnist
            var imagesTrain = LoadImages(args[0]);
            var labelsTrain = LoadLabels(args[1]);
            var imagesTest = LoadImages(args[2]);
            var labelsTest = LoadLabels(args[3]);

            var imagesShuffled = new float[imagesTrain.GetLength(0), 28 * 28];
            var labelsShuffled = new long[labelsTrain.GetLength(0)];
            for (int epoch = 0; epoch < 10; ++epoch)
            {
                // shuffle
                var idxs = Combinatorics.GeneratePermutation(imagesTrain.GetLength(0));
                int i = 0;
                foreach (int idx in idxs)
                {
                    Array.Copy(imagesTrain, idx * 28 * 28, imagesShuffled, i * 28 * 28, 28 * 28);
                    labelsShuffled[i] = labelsTrain[idx];
                    ++i;
                }

                fixed(float* pImagesTrain = &imagesShuffled[0, 0])
                fixed(long* pLabelsTrain = &labelsShuffled[0])
                fixed(float* pImagesTest = &imagesTest[0, 0])
                fixed(long* pLabelsTest = &labelsTest[0])
                {
                    train(pImagesTrain, pLabelsTrain, imagesTrain.GetLength(0));
                    test(pImagesTest, pLabelsTest, imagesTest.GetLength(0));
                }
            }
        }

        static private float[,] LoadImages(string path)
        {
            using (BinaryReader reader = new BinaryReader(File.OpenRead(path)))
            {
                if (ReadInt32(reader) != 0x00000803)
                {
                    throw new FormatException();
                }

                var count = ReadInt32(reader);
                if (ReadInt32(reader) != 28)
                {
                    throw new FormatException();
                }
                if (ReadInt32(reader) != 28)
                {
                    throw new FormatException();
                }

                float[,] data = new float[count, 28 * 28];
                for (int i = 0; i < count; ++i)
                {
                    for (int j = 0; j < 28 * 28; ++j)
                    {
                        var v = reader.ReadByte();
                        // normalize
                        data[i, j] = ((v / 255f) -0.1307f) / 0.3081f;
                    }
                }

                return data;
            }
        }

        static private Int64[] LoadLabels(string path)
        {
            using (BinaryReader reader = new BinaryReader(File.OpenRead(path)))
            {
                if (ReadInt32(reader) != 0x00000801)
                {
                    throw new FormatException();
                }

                var count = ReadInt32(reader);

                Int64[] labels = new Int64[count];
                for (int i = 0; i < count; ++i)
                {
                    var v = reader.ReadByte();
                    labels[i] = v;
                }

                return labels;
            }
        }

        static private int ReadInt32(BinaryReader reader)
        {
            var v = reader.ReadBytes(4);
            Array.Reverse(v);
            return BitConverter.ToInt32(v);
        }
    }
}
MNISTデータセットのダウンロード

http://yann.lecun.com/exdb/mnist/からtrain-images-idx3-ubyte.gz、train-labels-idx1-ubyte.gz、t10k-images-idx3-ubyte.gz、t10k-labels-idx1-ubyte.gzの4ファイルをダウンロードして、適当な場所に解凍する。

実行

共有ライブラリ(.dll)を実行ファイルのあるフォルダにコピーするか、環境変数PATHに追加する必要がある。
Visual StuidoやVS Codeから実行する場合は、デバッグの設定に環境変数PATHを追加するとよい。

MNISTデータセットのファイルを引数にして実行する。

dotnet run train-images-idx3-ubyte train-labels-idx1-ubyte t10k-images-idx3-ubyte t10k-labels-idx1-ubyte
実行結果

成功すれば、以下のように表示される。

CUDA available! Training on GPU.
Train [59584/60000] Loss: 0.6364
Test set: Average loss: 0.1992 | Accuracy: 0.939
Train [59584/60000] Loss: 0.6344
Test set: Average loss: 0.1255 | Accuracy: 0.961
Train [59584/60000] Loss: 0.3608
Test set: Average loss: 0.0965 | Accuracy: 0.970
Train [59584/60000] Loss: 0.1603
Test set: Average loss: 0.0817 | Accuracy: 0.975
Train [59584/60000] Loss: 0.2735
Test set: Average loss: 0.0735 | Accuracy: 0.978
Train [59584/60000] Loss: 0.0666
Test set: Average loss: 0.0664 | Accuracy: 0.980
Train [59584/60000] Loss: 0.1049
Test set: Average loss: 0.0589 | Accuracy: 0.983
Train [59584/60000] Loss: 0.0999
Test set: Average loss: 0.0564 | Accuracy: 0.983
Train [59584/60000] Loss: 0.1263
Test set: Average loss: 0.0543 | Accuracy: 0.984
Train [59584/60000] Loss: 0.0833
Test set: Average loss: 0.0497 | Accuracy: 0.984

はまりポイント

上記のコードを実行できるようになるまでにかなり苦労した。
LibTorchはコードにミスがあると、メモリアクセス違反になるため、Pythonだとすぐにわかるエラーでも原因にたどり着くのに苦労した。

以下の点で、特に解析に時間がかかった。

  • from_blobで、deviceに直接GPUを指定できない。CPUのメモリ上にtorch::Tensorを作成してから、to()でGPUに転送する必要がある。
  • shapeを間違えると、例外が発生する。PyTorchのデフォルトは、channels_firstなので間違わないようにする。
  • nll_lossのtargetのshapeは、1次元。{batch_size, 1}と2次元にすると例外が発生する。