TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

Google ColabでAlphaZero Shogiのモデルを教師あり学習する

Google ColabでAlphaZero Shogiのモデルを論文に通り定義して、テストのために教師ありで学習してみました。
TPUでも学習して学習時間の比較もしてみました。

教師データには、elmoで生成したhcpe形式のデータを使用し、入力特徴量と正解ラベルの加工には、先日作成したPythonの将棋ライブラリ(cshogi)を使用しました。

モデルの定義

AlphaZeroの論文の通り、ResNetで、policyとvalueの2つの出力を持つネットワークを定義します。
ブロック数とフィルタ数、全結合層のユニット数はパラメータにしています。
policyの出力の畳み込みの後の位置ごとのバイアスはカスタムレイヤーを定義しています。

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Layer, Input, Dense, Conv2D, BatchNormalization, Activation, Flatten, Add

class Bias(Layer):
    def __init__(self, **kwargs):
        super(Bias, self).__init__(**kwargs)

    def build(self, input_shape):
        self.b = self.add_weight(name='b',
                                 shape=(input_shape[1:]),
                                 initializer='zeros',
                                 trainable=True)
        super(Bias, self).build(input_shape)

    def call(self, x):
        return x + self.b

def conv_layer(inputs,
               filters,
               activation='relu',
               use_bias=True):

    x =  Conv2D(filters,
                kernel_size=3,
                strides=1,
                padding='same',
                data_format='channels_first',
                kernel_initializer='he_normal',
                use_bias=use_bias)(inputs)
    x = BatchNormalization(axis=1)(x)
    if activation is not None:
        x = Activation(activation)(x)
    return x

def ResNet(input_planes=45,
           res_blocks=20,
           filters=256,
           fcl_units=256,
           policy_planes=139):

    inputs = Input(shape=(input_planes, 9, 9))
    x = conv_layer(inputs, filters=filters, use_bias=False)

    for res_block in range(res_blocks):
        # bottleneck residual unit
        y = conv_layer(x, filters=filters, use_bias=False)
        y = conv_layer(y, filters=filters, use_bias=False, activation=None)
        x = Add()([x, y])
        x = Activation('relu')(x)

    # Add policy output
    policy_x = Conv2D(policy_planes,
                      kernel_size=1,
                      strides=1,
                      padding='same',
                      data_format='channels_first',
                      kernel_initializer='he_normal',
                      use_bias=False)(x)
    policy_x = Flatten()(policy_x)
    policy_y = Bias(name='policy')(policy_x)

    # Add value output
    value_x = conv_layer(x, filters=1)
    value_x = Flatten()(value_x)
    value_x = Dense(fcl_units,
                    activation='relu',
                    kernel_initializer='he_normal')(value_x)
    value_y = Dense(1,
                    activation='tanh',
                    kernel_initializer='he_normal',
                    name='value')(value_x)

    # Instantiate model
    return Model(inputs=inputs, outputs=[policy_y, value_y])

教師データ

教師データには、elmoを使用して、hcpe形式で生成したデータを使用します。
hcpe形式のデータには、

hcp ハフマン符号で圧縮した局面
bestMove16 指し手
gameResult 勝敗結果

が含まれます。

訓練データの局面数は、100万局面です。テストデータは10万局面です。

入力特徴量と正解データ

AlphaZeroの論文の通り、局面を入力特徴量に変換します。
ただし、hcpeには、局面の繰り返し数と手数が含まれないため、値を0にしています。
また、局面の繰り返し数は、単に繰り返しがあるかどうかだけ判定し、特徴を1面だけとしました。(今回は常に0)

持ち駒の数(prisoner count)は0から1の範囲の正規化を行わず、そのまま整数の値を設定しています。

policyの正解ラベルは、sparse_categorical_crossentropyを使用するので、one-hotベクトル化を行わず、整数で表しています。
valueの正解となる勝敗は、出力の活性関数にtanhを使うので、-1(負け),0(引き分け),1(勝ち)で表します。

履歴局面はなしです。

cshogiを使うと、以下のように記述できます。

board = Board()
def mini_batch(hcpes):
    features = np.zeros((len(hcpes), 45, 81), dtype=np.float32)
    action_labels = np.empty(len(hcpes), dtype=np.int)
    game_outcomes = np.empty(len(hcpes), dtype=np.float32)

    for i, hcpe in enumerate(hcpes):
        # Input features
        #   P1 piece 14
        #   P2 piece 14
        #   Repetitions 1
        #   P1 prisoner count 7
        #   P2 prisoner count 7
        #   Colour 1
        #   Total move count 1
        board.set_hcp(hcpe['hcp'])
        # piece
        pieces = board.pieces
        for sq in SQUARES:
            piece = pieces[sq]
            if piece != NONE:
                if piece >= WPAWN:
                    piece = piece - 2
                features[i][piece - 1][sq] = 1
        # repetition
        if board.is_draw() == REPETITION_DRAW:
            features[i][28].fill(1)
        # prisoner count
        pieces_in_hand = board.pieces_in_hand
        for c, hands in enumerate(pieces_in_hand):
            for hp, num in enumerate(hands):
                features[i][29 + c * 7 + hp].fill(num)
        # Colour
        if board.turn == WHITE:
            features[i][43].fill(1)
        # Total move count
        # not implement for learning from hcpe

        # Action representation
        #   Queen moves 64
        #   Knight moves 2
        #   Promoting queen moves 64
        #   Promoting knight moves 2
        #   Drop 7
        move = hcpe['bestMove16']
        if not move_is_drop(move):
            from_sq = move_from(move)
            to_sq = move_to(move)
            from_file, from_rank = divmod(from_sq, 9)
            to_file, to_rank = divmod(to_sq, 9)
            diff_file = to_file - from_file
            diff_rank = to_rank - from_rank
            if abs(diff_file) != 1 or abs(diff_rank) != 2:
                # Queen moves
                if diff_file < 0:
                    if diff_rank < 0:
                        move_dd = -diff_file - 1
                    elif diff_rank > 0:
                        move_dd = 8 - diff_file - 1
                    else:
                        move_dd = 16 - diff_file - 1
                elif diff_file > 0:
                    if diff_rank < 0:
                        move_dd = 24 + diff_file - 1
                    elif diff_rank > 0:
                        move_dd = 32 + diff_file - 1
                    else:
                        move_dd = 40 + diff_file - 1
                else:
                    if diff_rank < 0:
                        move_dd = 48 - diff_rank - 1
                    else:
                        move_dd = 56 + diff_rank - 1
            else:
                # Knight moves
                if diff_file < 0:
                    move_dd = 64
                else:
                    move_dd = 65

            promotion = 1 if move_is_promotion(move) else 0

            action_labels[i] = (promotion * 66 + move_dd) * 81 + from_sq
        else:
            # drop
            to_sq = move_to(move)
            hp = move_drop_hand_piece(move)
            action_labels[i] = (132 + hp) * 81 + to_sq

        # game outcome
        #   z: −1 for a loss, 0 for a draw, and +1 for a win
        gameResult = hcpe['gameResult']
        if board.turn == BLACK:
            if gameResult == BLACK_WIN:
                game_outcomes[i] = 1
            if gameResult == WHITE_WIN:
                game_outcomes[i] = -1
            else:
                game_outcomes[i] = 0
        else:
            if gameResult == BLACK_WIN:
                game_outcomes[i] = -1
            if gameResult == WHITE_WIN:
                game_outcomes[i] = 1
            else:
                game_outcomes[i] = 0

    return (features.reshape((len(hcpes), 45, 9, 9)), { 'policy': action_labels, 'value': game_outcomes })

Googleドライブのマウント

学習データは、Googleドライブにアップロードして、Google Colabからマウントしてアクセスします。

from google.colab import drive
drive.mount('/content/drive')

Googleドライブのファイルは、「drive/My Drive」からアクセスできます。

学習

hcpeデータはnumpyのfromfileでデータ形式を指定して読み込みます。
hcpeのデータ形式HuffmanCodedPosAndEvalはcshogiで定義しています。
入力特徴量はデータ量が大きいため、fit_generatorでミニバッチごとに入力特徴量と正解データを作成しながら学習します。

モデルのサイズは、10ブロック、192フィルタとしています。

import tensorflow as tf
from tensorflow.keras.optimizers import SGD
import numpy as np
import os
from cshogi import *

train_hcpe_path = 'drive/My Drive/hcpe/elmo_teacher_depth8_uniq-001-01'
test_hcpe_path = 'drive/My Drive/hcpe/elmo_teacher_depth8_uniq-test-01'
batchsize = 256
epochs = 1
weight_decay = 1e-4
use_tpu = True

model = ResNet(res_blocks=10, filters=192)

train_hcpes = np.fromfile(train_hcpe_path, dtype=HuffmanCodedPosAndEval)
test_hcpes = np.fromfile(test_hcpe_path, dtype=HuffmanCodedPosAndEval)

def datagen(hcpes, batchsize):
    while True:
        np.random.shuffle(hcpes)
        for i in range(0, len(hcpes) - batchsize, batchsize):
            yield mini_batch(hcpes[i:i+batchsize])

def categorical_crossentropy(y_true, y_pred):
    return tf.keras.backend.sparse_categorical_crossentropy(y_true, y_pred, from_logits=True)

def categorical_accuracy(y_true, y_pred):
    return tf.keras.metrics.sparse_categorical_accuracy(y_true, tf.nn.softmax(y_pred))

def binary_accuracy(y_true, y_pred):
    return tf.keras.metrics.binary_accuracy(tf.keras.backend.round((y_true + 1) / 2), y_pred, threshold=0)

# add weight decay
for layer in model.layers:
    if isinstance(layer, tf.keras.layers.Conv2D) or isinstance(layer, tf.keras.layers.Dense):
        layer.add_loss(tf.keras.regularizers.l2(weight_decay)(layer.kernel))

model.compile(optimizer=tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.9),
              loss={'policy': categorical_crossentropy, 'value': 'mse'},
              metrics={'policy': categorical_accuracy, 'value': binary_accuracy})

# TPU
if use_tpu:
    model = tf.contrib.tpu.keras_to_tpu_model(
        model,
        strategy=tf.contrib.tpu.TPUDistributionStrategy(
            tf.contrib.cluster_resolver.TPUClusterResolver(tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])
        )
    )

model.fit_generator(datagen(train_hcpes, batchsize), len(train_hcpes) // batchsize,
                    epochs=epochs,
                    validation_data=datagen(test_hcpes, batchsize), validation_steps=len(test_hcpes) // batchsize)

学習結果

TPUでの学習結果
INFO:tensorflow:Querying Tensorflow master (grpc://10.26.203.146:8470) for TPU system metadata.
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, -1, 2190986923747320621)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 17179869184, 12416730924266005929)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 13813975489927709994)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, 8536274068396247395)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 17179869184, 3398334431816235909)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 17179869184, 590228271416283825)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 17179869184, 15798692355657786920)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 17179869184, 1954696542070558698)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 17179869184, 898663154261538184)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 17179869184, 4503259854332690989)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 17179869184, 10060441942931197921)
WARNING:tensorflow:tpu_model (from tensorflow.contrib.tpu.python.tpu.keras_support) is experimental and may change or be removed at any time, and without warning.
INFO:tensorflow:New input shapes; (re-)compiling: mode=train (# of cores 8), [TensorSpec(shape=(32,), dtype=tf.int32, name='core_id_40'), TensorSpec(shape=(32, 9, 9, 45), dtype=tf.float32, name='input_4_10'), TensorSpec(shape=(32, 1), dtype=tf.float32, name='policy_target_120'), TensorSpec(shape=(32, 1), dtype=tf.float32, name='value_target_120')]
INFO:tensorflow:Overriding default placeholder.
INFO:tensorflow:Remapping placeholder for input_4
INFO:tensorflow:Started compiling
INFO:tensorflow:Finished compiling. Time elapsed: 38.74543237686157 secs
INFO:tensorflow:Setting weights on TPU model.
3905/3906 [============================>.] - ETA: 0s - loss: 4.4481 - policy_loss: 1.7086 - value_loss: 1.7085 - policy_categorical_accuracy: 1.7085 - value_binary_accuracy: 1.7084INFO:tensorflow:New input shapes; (re-)compiling: mode=eval (# of cores 8), [TensorSpec(shape=(32,), dtype=tf.int32, name='core_id_50'), TensorSpec(shape=(32, 9, 9, 45), dtype=tf.float32, name='input_4_10'), TensorSpec(shape=(32, 1), dtype=tf.float32, name='policy_target_120'), TensorSpec(shape=(32, 1), dtype=tf.float32, name='value_target_120')]
INFO:tensorflow:Overriding default placeholder.
INFO:tensorflow:Remapping placeholder for input_4
INFO:tensorflow:Started compiling
INFO:tensorflow:Finished compiling. Time elapsed: 17.765841245651245 secs
390/390 [==============================] - 43s 110ms/step - loss: 3.7269 - policy_loss: 1.4889 - value_loss: 1.4889 - policy_categorical_accuracy: 1.4878 - value_binary_accuracy: 1.4872
3906/3906 [==============================] - 644s 165ms/step - loss: 4.4478 - policy_loss: 1.7085 - value_loss: 1.7084 - policy_categorical_accuracy: 1.7084 - value_binary_accuracy: 1.7083 - val_loss: 3.7269 - val_policy_loss: 1.4889 - val_value_loss: 1.4889 - val_policy_categorical_accuracy: 1.4878 - val_value_binary_accuracy: 1.4872
<tensorflow.python.keras.callbacks.History at 0x7f841a0954a8>

TPUでは、lossは正しく表示されますが、accuracyが正しく表示されませんでした。
ネットワークの出力が2つの場合に発生するようです。
まだKerasのTPUサポートは正式版ではないので、バグの可能性があります。

Google ColabのGPU(Tesla K80)での学習結果
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.cast instead.
390/390 [==============================] - 62s 159ms/step - loss: 4.4728 - policy_loss: 3.4411 - value_loss: 0.2205 - policy_categorical_accuracy: 0.2172 - value_binary_accuracy: 0.7430
3906/3906 [==============================] - 1803s 462ms/step - loss: 5.2603 - policy_loss: 4.1786 - value_loss: 0.2296 - policy_categorical_accuracy: 0.1737 - value_binary_accuracy: 0.7433 - val_loss: 4.4728 - val_policy_loss: 3.4411 - val_value_loss: 0.2205 - val_policy_categorical_accuracy: 0.2172 - val_value_binary_accuracy: 0.7430
<tensorflow.python.keras.callbacks.History at 0x7f61b7ee5780>
ローカルPCのGPU(1080 Ti)での学習結果
WARNING:tensorflow:From C:\Anaconda3\lib\site-packages\tensorflow\python\ops\resource_variable_ops.py:435: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
WARNING:tensorflow:From C:\Anaconda3\lib\site-packages\tensorflow\python\keras\utils\losses_utils.py:170: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.cast instead.
WARNING:tensorflow:From C:\Anaconda3\lib\site-packages\tensorflow\python\ops\math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.cast instead.
2019-02-17 16:51:40.039612: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2
2019-02-17 16:51:40.216090: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1433] Found device 0 with properties:
name: GeForce GTX 1080 Ti major: 6 minor: 1 memoryClockRate(GHz): 1.6705
pciBusID: 0000:01:00.0
totalMemory: 11.00GiB freeMemory: 9.10GiB
2019-02-17 16:51:40.224327: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1512] Adding visible gpu devices: 0
2019-02-17 16:51:40.638740: I tensorflow/core/common_runtime/gpu/gpu_device.cc:984] Device interconnect StreamExecutor with strength 1 edge matrix:
2019-02-17 16:51:40.642624: I tensorflow/core/common_runtime/gpu/gpu_device.cc:990]      0
2019-02-17 16:51:40.645263: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1003] 0:   N
2019-02-17 16:51:40.648587: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1115] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 8780 MB memory) -> physical GPU (device: 0, name: GeForce GTX 1080 Ti, pci bus id: 0000:01:00.0, compute capability: 6.1)
2019-02-17 16:51:45.858150: I tensorflow/stream_executor/dso_loader.cc:152] successfully opened CUDA library cublas64_100.dll locally
390/390 [==============================] - 16s 42ms/step - loss: 4.4274 - policy_loss: 3.3871 - value_loss: 0.2285 - policy_categorical_accuracy: 0.2272 - value_binary_accuracy: 0.7431
3906/3906 [==============================] - 451s 116ms/step - loss: 5.2415 - policy_loss: 4.1575 - value_loss: 0.2314 - policy_categorical_accuracy: 0.1760 - value_binary_accuracy: 0.7424 - val_loss: 4.4274 - val_policy_loss: 3.3871 - val_value_loss: 0.2285 - val_policy_categorical_accuracy: 0.2272 - val_value_binary_accuracy: 0.7431

GPUでは、accuracyも正しく表示されています。

比較

TPUとGPUで学習時間を比較した結果は以下の通りでした。

TPU 644s
Google ColabのGPU(Tesla K80) 1803s
ローカルPCのGPU(1080Ti) 451s

バッチサイズを2048に増やすと以下の通りになりました。(Google ColabのGPUは遅いので測定から外しています。)

TPU 806s
ローカルPCのGPU(1080Ti) 516s

20ブロック、256フィルタ、バッチサイズ256では以下の通りでした。

TPU 1448s
ローカルPCのGPU(1080Ti) 1267s

いずれの条件でも、ローカルPCのGPU(1080Ti)が早いという結果になりました。
条件を変えると変わってくるかもしれません。

Google Colabを使う場合は、TPUの方がGPUより圧倒的に速いです。
ただし、TPUはまだKerasでの動作が安定していないため、安心して使えるのはTensorFlow 2.0で正式サポートされてからになりそうです。