TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

timmのモデルを可視化する

timmのモデルはソースが複雑なため、ソースからモデルの構造を理解するのは大変である。
そのため、モデル構造を可視化できると理解しやすい。

以下では、PyTorchのTensorBoardで、モデル構造を可視化する方法を紹介する。

TensorBoardでグラフ表示

TensorBoardには、グラフを表示する機能があり、ネットワークのブロックをドリルダウンしながら調べられるので使いやすい。

TensorBoardは、PyTorchに統合されており、以下のようにしてモデルのグラフを書き出すことができる。

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()
writer.add_graph(model, input)

VS Codeに統合されているTensorBoardで、出力したグラフを表示できる。

グラフは、下が入力で、上が出力なので注意が必要である。

timmのモデルのグラフ表示

以下のようなコードで、timmのモデルのグラフを出力できる。
add_graphの引数にバッチサイズ1のデータが必要になるため、値がすべて0のデータを与えている。
timmのモデルは、ImageNetのデータセット向けになっているので入力は244×244のカラー画像になる。

import torch
from torch.utils.tensorboard import SummaryWriter
import timm

# model
model = timm.create_model('mobilenetv3_large_100')

# graph
writer = SummaryWriter()
writer.add_graph(model, torch.zeros((1, 3, 244, 244)))
writer.close()