TadaoYamaokaの日記

山岡忠夫Homeで公開しているプログラムの開発ネタを中心に書いていきます。

Kerasでクラス分類モデルの出力をlogitsにする

Google ColabでTPUを使うには、今のところフレームワークにTesorFlow(Keras)を使う必要がある。
Kerasで将棋AI用のモデル定義を行っていて、ChainerではできてKerasでは簡単にできない問題にぶつかった。

Kerasでクラス分類のモデルを定義して学習する際、通常の方法では、出力層のSoftmax活性化関数を含めたモデルしか定義できない。

学習済みモデルを使って、ボルツマン分布で確率を出力したい場合には、出力として、Softmax活性化関数に入力する前のロジットを取得したい。
ChainerのようなDefine-by-Runのフレームワークでは途中の値を取り出すのは容易だが、Kerasはモデルを事前にコンパイルするため途中の値は取得できない。

Kerasでロジットを出力するモデルを定義して、できるだけ標準的な方法で学習できないか調べてみた。
nlp - Keras - how to get unnormalized logits instead of probabilities - Stack Overflow
このQAにヒントが書かれていた。

引数にfrom_logits=Trueを与えたsparse_categorical_crossentropyをラッピングしてカスタム損失関数を定義すればよい。

def sparse_categorical_crossentropy(y_true, y_pred):
    return tf.keras.backend.sparse_categorical_crossentropy(y_true, y_pred, from_logits=True)

metricsについては上記のQAには書かれていなかったが、試行錯誤の結果、以下のようにすればよいことがわかった。
metricsには、sparse_categorical_accuracyを使用したしたいが、モデルの出力がロジットになっているため、softmaxを計算した上で、sparse_categorical_accuracyに渡す必要がある。つまり、そのように処理を行うカスタム関数を定義すればよい。

def sparse_categorical_accuracy(y_true, y_pred):
    return tf.keras.metrics.sparse_categorical_accuracy(y_true, tf.nn.softmax(y_pred))

通常のモデルと、ロジットを出力するモデルのそれぞれの、モデル定義と学習方法は以下の通りになる。

通常のモデル

import tensorflow as tf
mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(512, activation=tf.nn.relu),
  tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(x_train, y_train, epochs=5)
model.evaluate(x_test, y_test)

ロジットを出力するモデル

import tensorflow as tf

mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(512, activation=tf.nn.relu),
  tf.keras.layers.Dense(10)
])

def sparse_categorical_crossentropy(y_true, y_pred):
    return tf.keras.backend.sparse_categorical_crossentropy(y_true, y_pred, from_logits=True)

def sparse_categorical_accuracy(y_true, y_pred):
    return tf.keras.metrics.sparse_categorical_accuracy(y_true, tf.nn.softmax(y_pred))

model.compile(optimizer='adam',
              loss=sparse_categorical_crossentropy,
              metrics=[sparse_categorical_accuracy])

model.fit(x_train, y_train, epochs=5)
model.evaluate(x_test, y_test)