前回、知識蒸留を定義通りKLダイバージェンスを計算して実装したが、交差エントロピーを使用しても勾配は変わらないため、交差エントロピーで実装し直した。
交差エントロピーの方がPyTorchで用意されているメソッドが使用でき、実行時間も短くなる。
実装
if args.distillation_model: with torch.no_grad(): distillation_y1, distillation_y2 = distillation_model(x1, x2) distillation_loss1 = cross_entropy_loss_with_soft_target(y1, distillation_y1.softmax(dim=1)).mean() distillation_loss2 = bce_with_logits_loss(y2, distillation_y2.sigmoid()) loss = (1 - args.distillation_alpha) * loss + args.distillation_alpha * (distillation_loss1 + distillation_loss2)
cross_entropy_loss_with_soft_targetは、方策の学習でも使用しており、以下の通り実装している。
def cross_entropy_loss_with_soft_target(pred, soft_targets): return torch.sum(-soft_targets * F.log_softmax(pred, dim=1), 1)
精度比較
KLダイバージェンスで学習した場合と、精度が変わらないことを確認した。
4回測定した平均
損失 | 方策損失 | 価値損失 | 方策正解率 | 価値正解率 | 方策エントロピー | 価値エントロピー |
---|---|---|---|---|---|---|
KLD | 1.56883 | 0.49999 | 0.488682 | 0.738732 | 1.47312 | 0.537952 |
交差エントロピー | 1.56787 | 0.499179 | 0.489128 | 0.738752 | 1.47079 | 0.536013 |
方策正解率、価値正解率の差はそれぞれ、0.04%、0.00%である。
ほぼ同じ結果となった。
学習時間
1エポック当たりの学習時間は以下の通り。
KLD | 0:43:40 |
交差エントロピー | 0:42:00 |
学習時間は、96.2%となり、わずかに速くなった。
損失の計算は、全体の計算からするとわずかなので、たいして速くなっていない。