前回実装した、Chainerで学習したモデルを使用してcuDNNで推論するコードを、Residual Network(ResNet)構成にした。
推論時には、テンソルの加算を行うだけで特に難しいことはない。
ネットワーク定義(Chainer)
ResNetは1ブロックのみで、ブロック内の畳み込み層は1層のみとした。
nn.py
from chainer import Chain
import chainer.functions as F
import chainer.links as L
k = 16
fcl = 256
class NN(Chain):
def __init__(self):
super(NN, self).__init__()
with self.init_scope():
self.conv1 = L.Convolution2D(in_channels = 1, out_channels = k, ksize = 3, pad = 1)
self.conv2 = L.Convolution2D(in_channels = k, out_channels = k, ksize = 3, pad = 1)
self.conv3 = L.Convolution2D(in_channels = k, out_channels = k, ksize = 3, pad = 1)
self.l4 = L.Linear(7*7*k, fcl)
self.l5 = L.Linear(fcl, 10)
self.bn1 = L.BatchNormalization(k)
self.bn2 = L.BatchNormalization(k)
def __call__(self, x):
h = self.conv1(F.reshape(x, (len(x), 1, 28, 28)))
h = self.bn1(h)
h1 = F.relu(h)
h = self.conv2(h1)
h = self.bn2(h)
h = h + h1
h = F.max_pooling_2d(F.relu(h), 2)
h = self.conv3(h)
h = F.max_pooling_2d(F.relu(h), 2)
h = F.relu(self.l4(h))
h = F.dropout(h, ratio=0.4)
return self.l5(h)
cuDNNを使用してテンソルの加算を行うクラスを追加した。
layers.h
class Add {
public:
void operator() (cudnnHandle_t handle, cudnnTensorDescriptor_t xDesc, float* x, float* y) {
const float alpha = 1.0f;
const float beta = 1.0f;
checkCUDNN(cudnnAddTensor(handle, &alpha, xDesc, x, &beta, xDesc, y));
}
};
ネットワーク定義(C++)
ResNet構成のネットワーク定義を以下の通り実装した。
nn.h
class NN {
public:
typedef float x_t[batch_size][1][IMAGE_H][IMAGE_W];
typedef float y_t[batch_size][10];
NN();
~NN();
void load_model(const char* filename);
void foward(x_t x, y_t y);
private:
static CudnnHandle cudnnHandle;
static CublasHandle cublasHandle;
static const int k = 16;
static const int fcl = 256;
ConvLayer<k, 1, 3, 1> conv1;
Bias<k, 1, 1> bias1;
ConvLayer<k, k, 3, 1> conv2;
Bias<k, 1, 1> bias2;
ConvLayer<k, k, 3, 1> conv3;
Bias<k, 1, 1> bias3;
Linear<7 * 7 * k, fcl> l4;
Bias<fcl, 1, 1> bias4;
Linear<fcl, 10> l5;
Bias<10, 1, 1> bias5;
BatchNormalization<k> bn1;
BatchNormalization<k> bn2;
ReLU relu;
MaxPooling2D<2> max_pooling_2d;
Add add;
CudnnTensorDescriptor xDesc;
CudnnTensorDescriptor h1Desc;
CudnnTensorDescriptor h2Desc;
CudnnTensorDescriptor h3Desc;
CudnnTensorDescriptor h4Desc;
CudnnTensorDescriptor h5Desc;
CudnnTensorDescriptor h6Desc;
CudnnTensorDescriptor yDesc;
float* x_dev;
float* h1_dev;
float* h1_bn_dev;
float* h2_dev;
float* h2_bn_dev;
float* h3_dev;
float* h4_dev;
float* h5_dev;
float* h6_dev;
float* y_dev;
};
nn.cpp
NN::NN()
{
conv1.get_xdesc(xDesc, batch_size, IMAGE_H, IMAGE_W);
const int h1_h = conv1.get_yh(IMAGE_H);
const int h1_w = conv1.get_yw(IMAGE_W);
conv1.get_ydesc(h1Desc, batch_size, h1_h, h1_w);
const int h3_h = max_pooling_2d.get_yh(h1_h);
const int h3_w = max_pooling_2d.get_yw(h1_w);
conv3.get_xdesc(h3Desc, batch_size, h3_h, h3_w);
const int h4_h = conv3.get_yh(h3_h);
const int h4_w = conv3.get_yw(h3_w);
conv3.get_ydesc(h4Desc, batch_size, h4_h, h4_w);
const int h5_h = max_pooling_2d.get_yh(h4_h);
const int h5_w = max_pooling_2d.get_yw(h4_w);
max_pooling_2d.get_desc(h5Desc, batch_size, k, h5_h, h5_w);
l4.get_ydesc(h6Desc, batch_size);
l5.get_ydesc(yDesc, batch_size);
conv1.init(cudnnHandle, xDesc, h1Desc);
conv2.init(cudnnHandle, h1Desc, h1Desc);
conv3.init(cudnnHandle, h3Desc, h4Desc);
checkCudaErrors(cudaMalloc((void**)&x_dev, conv1.get_xsize(batch_size, IMAGE_H, IMAGE_W)));
checkCudaErrors(cudaMalloc((void**)&h1_dev, conv1.get_ysize(batch_size, h1_h, h1_w)));
checkCudaErrors(cudaMalloc((void**)&h1_bn_dev, conv1.get_ysize(batch_size, h1_h, h1_w)));
checkCudaErrors(cudaMalloc((void**)&h2_dev, conv2.get_ysize(batch_size, h1_h, h1_w)));
checkCudaErrors(cudaMalloc((void**)&h2_bn_dev, conv2.get_ysize(batch_size, h1_h, h1_w)));
checkCudaErrors(cudaMalloc((void**)&h3_dev, conv3.get_xsize(batch_size, h3_h, h3_w)));
checkCudaErrors(cudaMalloc((void**)&h4_dev, conv3.get_ysize(batch_size, h4_h, h4_w)));
checkCudaErrors(cudaMalloc((void**)&h5_dev, batch_size * k * h5_h * h5_w * sizeof(float)));
checkCudaErrors(cudaMalloc((void**)&h6_dev, batch_size * fcl * sizeof(float)));
checkCudaErrors(cudaMalloc((void**)&y_dev, batch_size * 10 * sizeof(float)));
}
NN::~NN() {
checkCudaErrors(cudaFree(x_dev));
checkCudaErrors(cudaFree(h1_dev));
checkCudaErrors(cudaFree(h1_bn_dev));
checkCudaErrors(cudaFree(h2_dev));
checkCudaErrors(cudaFree(h2_bn_dev));
checkCudaErrors(cudaFree(h3_dev));
checkCudaErrors(cudaFree(h4_dev));
checkCudaErrors(cudaFree(h5_dev));
checkCudaErrors(cudaFree(h6_dev));
checkCudaErrors(cudaFree(y_dev));
}
Chainerで学習したモデルの読み込み(C++)
前回までの畳み込みの読み込みと同じでResNet構成によって特別な処理はない。
void NN::load_model(const char* filepath)
{
ParamMap params;
load_npz(filepath, params);
conv1.set_param(params["conv1/W.npy"].data);
bias1.set_bias(params["conv1/b.npy"].data);
conv2.set_param(params["conv2/W.npy"].data);
bias2.set_bias(params["conv2/b.npy"].data);
conv3.set_param(params["conv3/W.npy"].data);
bias3.set_bias(params["conv3/b.npy"].data);
l4.set_param(params["l4/W.npy"].data);
bias4.set_bias(params["l4/b.npy"].data);
l5.set_param(params["l5/W.npy"].data);
bias5.set_bias(params["l5/b.npy"].data);
bn1.set_param(params["bn1/gamma.npy"].data, params["bn1/beta.npy"].data, params["bn1/avg_mean.npy"].data, params["bn1/avg_var.npy"].data);
bn2.set_param(params["bn2/gamma.npy"].data, params["bn2/beta.npy"].data, params["bn2/avg_mean.npy"].data, params["bn2/avg_var.npy"].data);
}
ResNetブロックの処理とテンソルの加算処理を追加した。
void NN::foward(x_t x, y_t y)
{
checkCudaErrors(cudaMemcpy(x_dev, x, sizeof(x_t), cudaMemcpyHostToDevice));
conv1(cudnnHandle, xDesc, x_dev, h1Desc, h1_dev);
bias1(cudnnHandle, h1Desc, h1_dev);
bn1(cudnnHandle, h1Desc, h1_dev, h1_bn_dev);
relu(cudnnHandle, h1Desc, h1_bn_dev);
conv2(cudnnHandle, h1Desc, h1_bn_dev, h1Desc, h2_dev);
bias2(cudnnHandle, h1Desc, h2_dev);
bn2(cudnnHandle, h1Desc, h2_dev, h2_bn_dev);
add(cudnnHandle, h1Desc, h1_bn_dev, h2_bn_dev);
relu(cudnnHandle, h1Desc, h2_bn_dev);
max_pooling_2d(cudnnHandle, h1Desc, h2_bn_dev, h3Desc, h3_dev);
conv3(cudnnHandle, h3Desc, h3_dev, h4Desc, h4_dev);
bias3(cudnnHandle, h4Desc, h4_dev);
relu(cudnnHandle, h4Desc, h4_dev);
max_pooling_2d(cudnnHandle, h4Desc, h4_dev, h5Desc, h5_dev);
l4(cublasHandle, batch_size, h5_dev, h6_dev);
bias4(cudnnHandle, h6Desc, h6_dev);
relu(cudnnHandle, h6Desc, h6_dev);
l5(cublasHandle, batch_size, h6_dev, y_dev);
bias5(cudnnHandle, yDesc, y_dev);
checkCudaErrors(cudaMemcpy(y, y_dev, sizeof(y_t), cudaMemcpyDeviceToHost));
}
テンソル加算の実際の処理は、layers.hのAddクラスのoperator()に実装している。
1層目の畳み込み層の活性化関数の出力h1_bn_devをResNetブロックの入力にして、ResNetブロックの出力h2_bn_devをh1_bn_devに加算している。
ResNetブロックの入力と出力のテンソルのサイズは変わらないため、テンソルディスクリプタは1層目の畳み込み層と共通のh1Descを使用している。