TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

dlshogiモデルの20ブロックから15ブロックへの知識蒸留 追試2

前回、知識蒸留を定義通りKLダイバージェンスを計算して実装したが、交差エントロピーを使用しても勾配は変わらないため、交差エントロピーで実装し直した。
交差エントロピーの方がPyTorchで用意されているメソッドが使用でき、実行時間も短くなる。

実装

if args.distillation_model:
    with torch.no_grad():
        distillation_y1, distillation_y2 = distillation_model(x1, x2)
    distillation_loss1 = cross_entropy_loss_with_soft_target(y1, distillation_y1.softmax(dim=1)).mean()
    distillation_loss2 = bce_with_logits_loss(y2, distillation_y2.sigmoid())
    loss = (1 - args.distillation_alpha) * loss + args.distillation_alpha * (distillation_loss1 + distillation_loss2)

cross_entropy_loss_with_soft_targetは、方策の学習でも使用しており、以下の通り実装している。

def cross_entropy_loss_with_soft_target(pred, soft_targets):
    return torch.sum(-soft_targets * F.log_softmax(pred, dim=1), 1)

精度比較

KLダイバージェンスで学習した場合と、精度が変わらないことを確認した。

4回測定した平均

損失 方策損失 価値損失 方策正解率 価値正解率 方策エントロピー 価値エントロピー
KLD 1.56883 0.49999 0.488682 0.738732 1.47312 0.537952
交差エントロピー 1.56787 0.499179 0.489128 0.738752 1.47079 0.536013

方策正解率、価値正解率の差はそれぞれ、0.04%、0.00%である。
ほぼ同じ結果となった。

学習時間

1エポック当たりの学習時間は以下の通り。

KLD 0:43:40
交差エントロピー 0:42:00

学習時間は、96.2%となり、わずかに速くなった。
損失の計算は、全体の計算からするとわずかなので、たいして速くなっていない。

まとめ

知識蒸留の計算を、KLダイバージェンスから交差エントロピーに置き換えても実装上問題ないことを確認した。