【PyTorch】学習済みモデルの保存と読み込み|10分で習得!
本記事では、この悩みを解決していきます。
具体例を使用して、初心者の方でも、動作の詳細も理解することができるように記事を構成しました。
本記事を読むことで、学習済みモデルを確実に保存して、読み込むことができるようになります。
PyTorchのモデルの保存と読み込み
さっそく、PyTorchのモデルの保存と読み込みを行う方法を紹介します。
- 必要なライブラリをインポート
- 学習モデルを保存
- 学習モデルを読み込み
必要なライブラリをインポート
まずは、必要なライブラリをインポートします。
import torch
import torch.nn as nn
import torch.nn.functional as F
使用するデバイスを設定しておきます。
device = 'cuda' if torch.cuda.is_available() else 'cpu'
また、今回は具体例として以下のネットワークを使用します。
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(3, 3)
def forward(self, x):
x = F.relu(self.fc1(x))
return x
モデルを指定したデバイスに送りましょう。
model = Net().to(device)
ここからは、このモデルを保存する方法を紹介していきます。
学習済みモデルを保存
torch.save()
を利用することで、学習済みモデルを保存することができます。
具体的には、以下のように実行します。
torch.save(model.state_dict(), PATH)
第一引数には、保存する対象を入力してください。
そのためPATH
には、保存先のPATH(~~~/~~~.pth
)を指定してください(.pth
という拡張子のファイルで保存されます)
PyTorchの学習済みモデルを保存する場合、モデルの学習可能なパラメータの値を一般的に保存します。
実際に、model.state_dict()
の中身を見てみると以下のようになります。
model.state_dict()
<output>
OrderedDict([('fc1.weight', tensor([[-0.1916, 0.1562, -0.2462],
[ 0.5605, 0.1616, -0.2541],
[ 0.3653, 0.5372, -0.5730]], device='cuda:0')),
('fc1.bias',
tensor([-0.4214, 0.2875, -0.0175], device='cuda:0'))])
このように、デバイスの情報と学習可能なパラメータの値を取得することができます。
これらの値を保存して、学習済みモデルを管理します。
実は、torch.save(model, PATH)
とすることで、モデル全体をそのまま保存することもできます。
しかし、PyTorchでは非推奨です。
学習済みモデルを読み込み
torch.load()
を利用することで、モデルの学習可能パラメータの情報を読み込むことができます。
具体的には、以下を実行してください。
model_params = torch.load(PATH)
model_params
<output>
OrderedDict([('fc1.weight', tensor([[-0.1916, 0.1562, -0.2462],
[ 0.5605, 0.1616, -0.2541],
[ 0.3653, 0.5372, -0.5730]], device='cuda:0')),
('fc1.bias',
tensor([-0.4214, 0.2875, -0.0175], device='cuda:0'))])
先ほど保存したモデルのパラメータを読む込むことができました。
次は、load_state_dict()
を利用して、このモデルのパラメータを新たなモデルに読み込みます。
# 新しいモデル
model2 = Net().to(device)
print('新しいモデル:\n', model2.state_dict())
# 保存したモデルパラメータの読み込み
model2.load_state_dict(torch.load(PATH))
print('読み込み後のモデル:\n', model2.state_dict())
<output>
新しいモデル:
OrderedDict([('fc1.weight', tensor([[-0.3752, 0.3518, 0.5422],
[ 0.2593, -0.4580, 0.4354],
[ 0.2122, -0.2200, 0.4411]], device='cuda:0')), ('fc1.bias', tensor([-0.2742, -0.4467, -0.4960], device='cuda:0'))])
読み込み後のモデル:
OrderedDict([('fc1.weight', tensor([[-0.1916, 0.1562, -0.2462],
[ 0.5605, 0.1616, -0.2541],
[ 0.3653, 0.5372, -0.5730]], device='cuda:0')), ('fc1.bias', tensor([-0.4214, 0.2875, -0.0175], device='cuda:0'))])
読み込み後は、保存したパラメータを持つモデルが構築されます。
一般的には、torch.save()
とload_state_dict()
をまとめて以下のように記述することが多いです。
model2 = Net().to(device)
model2.load_state_dict(torch.load(PATH))
<output>
<All keys matched successfully>
学習済みモデルを部分的に読み込む方法
学習済みモデルを部分的に読み込むためには、load_state_dict()
の引数をstrict=False
とします。
具体例として、以下のネットワークを構築します。
class Net2(nn.Module):
def __init__(self):
super(Net2, self).__init__()
self.fc1 = nn.Linear(3, 3, bias=False)
def forward(self, x):
x = F.relu(self.fc1(x))
return x
バイアス項がないため、先ほど保存したモデルと形状が一致していないことに注意してください!
しかし、strict=False
とすると適合する部分のみ読み込むことができます。
model3 = Net2().to(device)
print('新しいモデル:\n', model3.state_dict())
# 部分的に読み込む
model3.load_state_dict(torch.load('filename.pth'), strict=False)
print('読み込み後のモデル:\n', model3.state_dict())
<output>
新しいモデル:
OrderedDict([('fc1.weight', tensor([[-0.2354, -0.4788, -0.0337],
[ 0.0216, -0.2793, 0.4533],
[ 0.2675, 0.0576, -0.2487]], device='cuda:0'))])
読み込み後のモデル:
OrderedDict([('fc1.weight', tensor([[-0.1916, 0.1562, -0.2462],
[ 0.5605, 0.1616, -0.2541],
[ 0.3653, 0.5372, -0.5730]], device='cuda:0'))])
無事に、適合する部分は保存した値になっています。
まとめ
本記事では、PyTorchの学習済みモデルの保存と読み込み方法を具体例を使用しながら解説しました。
さらにPyTorchについて知りたい方は下記のボタンからアクセスしてください。
Pythonを学習するのに効率的なサービスを紹介していきます。
まず最初におすすめするのは、Udemyです。
Udemyは、Pythonに特化した授業がたくさんあり、どの授業も良質です。
また、セール中は1500円定義で利用することができ、コスパも最強です。
下記の記事では、実際に私が15個以上の講義を受講して特におすすめだった講義を紹介しています。
他のPythonに特化したオンライン・オフラインスクールも下記の記事でまとめています。
自分の学習スタイルに合わせて最適なものを選びましょう。
また、私がPythonを学ぶ際に使用した本を全て暴露しているので参考にしてください。