timmのモデルはソースが複雑なため、ソースからモデルの構造を理解するのは大変である。
そのため、モデル構造を可視化できると理解しやすい。
以下では、PyTorchのTensorBoardで、モデル構造を可視化する方法を紹介する。
TensorBoardでグラフ表示
TensorBoardには、グラフを表示する機能があり、ネットワークのブロックをドリルダウンしながら調べられるので使いやすい。
TensorBoardは、PyTorchに統合されており、以下のようにしてモデルのグラフを書き出すことができる。
from torch.utils.tensorboard import SummaryWriter writer = SummaryWriter() writer.add_graph(model, input)
VS Codeに統合されているTensorBoardで、出力したグラフを表示できる。
グラフは、下が入力で、上が出力なので注意が必要である。
timmのモデルのグラフ表示
以下のようなコードで、timmのモデルのグラフを出力できる。
add_graphの引数にバッチサイズ1のデータが必要になるため、値がすべて0のデータを与えている。
timmのモデルは、ImageNetのデータセット向けになっているので入力は244×244のカラー画像になる。
import torch from torch.utils.tensorboard import SummaryWriter import timm # model model = timm.create_model('mobilenetv3_large_100') # graph writer = SummaryWriter() writer.add_graph(model, torch.zeros((1, 3, 244, 244))) writer.close()