【超簡単】PyTorchで計算グラフを表示する方法|ネットワークを可視化しよう!
本記事では、PyTorchのネットワーク(計算グラフ)を可視化する方法を初心者の方でも理解できるように説明します。
計算グラフを表示する準備
まずは、計算グラフを表示する準備をしましょう。
torchvisの導入
Notebookで以下を入力して実行してください。
!pip install torchviz
ライブラリをインポート
今回使用するライブラリをインポートしてください。
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchviz import make_dot
from IPython.display import display
これで準備は完了です。
torchvisで計算グラフを表示
これで準備が整ったので実際にtorchvis
を使用して計算グラフを可視化してみましょう。
可視化するネットワークを定義
当ブログの『【入門】PyTorchの使い方をMNISTデータセットで学ぶ(15分)』で使用したネットワークを可視化します。
具体的には以下のようなネットワークを使用します。
num_classes = 10
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(28*28, 1024)
self.fc2 = nn.Linear(1024, 512)
self.fc3 = nn.Linear(512, num_classes)
def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
計算グラフを表示
計算グラフを表示するためには、一度適当な値を入力して出力を得てから、その出力とネットワークのパラメータを辞書形式にしてmake_dot
関数に入力します。
具体的な例を以下に示します。
model = Net()
# 適当な入力
x = torch.randn(1, 28*28)
# 出力
y = model(x)
# 計算グラフを表示
img = make_dot(y, params=dict(model.named_parameters()))
display(img)
<出力>
こんな感じに綺麗な計算グラフを表示できます。
参考資料|おすすめ教材
参考文献|おすすめ参考書
本記事を作成する際に利用した参考文献を下記にまとめました。
参考講座
UdemyのPyTorchコースがマジでおすすめです!
私は受講した中で特におすすめな講座を厳選しています。
PyTorch以外も講座もぜひ受講してみてください!
まとめ
本記事では、PyTorchを用いて計算グラフを可視化する方法を紹介しました。
PyTorchに関する他の記事は下記を参考にしてください。
Pythonを学習するのに効率的なサービスを紹介していきます。
まず最初におすすめするのは、Udemyです。
Udemyは、Pythonに特化した授業がたくさんあり、どの授業も良質です。
また、セール中は1500円定義で利用することができ、コスパも最強です。
下記の記事では、実際に私が15個以上の講義を受講して特におすすめだった講義を紹介しています。
他のPythonに特化したオンライン・オフラインスクールも下記の記事でまとめています。
自分の学習スタイルに合わせて最適なものを選びましょう。
また、私がPythonを学ぶ際に使用した本を全て暴露しているので参考にしてください。