Programming Lab

【PyTorch】学習済みモデルの保存と読み込み|10分で習得!

【PyTorch】学習済みモデルの保存と読み込み|10分で習得!

 

 PyTorchの学習モデルを保存したり、ロードする方法がわからない…

 

本記事では、この悩みを解決していきます。

具体例を使用して、初心者の方でも、動作の詳細も理解することができるように記事を構成しました。

本記事を読むことで、学習済みモデルを確実に保存して、読み込むことができるようになります。

 

 

PyTorchのモデルの保存と読み込み

 

さっそく、PyTorchのモデルの保存と読み込みを行う方法を紹介します。

  1. 必要なライブラリをインポート
  2. 学習モデルを保存
  3. 学習モデルを読み込み

 

必要なライブラリをインポート

 

まずは、必要なライブラリをインポートします。

import torch
import torch.nn as nn
import torch.nn.functional as F

 

使用するデバイスを設定しておきます。

device = 'cuda' if torch.cuda.is_available() else 'cpu'

 

また、今回は具体例として以下のネットワークを使用します。

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(3, 3)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return x

 

モデルを指定したデバイスに送りましょう。

model = Net().to(device)

 

ここからは、このモデルを保存する方法を紹介していきます。

 

学習済みモデルを保存

 

torch.save()を利用することで、学習済みモデルを保存することができます。

具体的には、以下のように実行します。

torch.save(model.state_dict(), PATH)

 

第一引数には、保存する対象を入力してください。

そのためPATHには、保存先のPATH(~~~/~~~.pth)を指定してください(.pthという拡張子のファイルで保存されます)

 

PyTorchの学習済みモデルを保存する場合、モデルの学習可能なパラメータの値を一般的に保存します。

 

実際に、model.state_dict()の中身を見てみると以下のようになります。

model.state_dict()

 

<output>

OrderedDict([('fc1.weight', tensor([[-0.1916,  0.1562, -0.2462],
                      [ 0.5605,  0.1616, -0.2541],
                      [ 0.3653,  0.5372, -0.5730]], device='cuda:0')),
             ('fc1.bias',
              tensor([-0.4214,  0.2875, -0.0175], device='cuda:0'))])

 

このように、デバイスの情報と学習可能なパラメータの値を取得することができます。

これらの値を保存して、学習済みモデルを管理します。

 

実は、torch.save(model, PATH)とすることで、モデル全体をそのまま保存することもできます。

しかし、PyTorchでは非推奨です。

 

学習済みモデルを読み込み

 

torch.load()を利用することで、モデルの学習可能パラメータの情報を読み込むことができます。

具体的には、以下を実行してください。

model_params = torch.load(PATH)
model_params

 

<output>

OrderedDict([('fc1.weight', tensor([[-0.1916,  0.1562, -0.2462],
                      [ 0.5605,  0.1616, -0.2541],
                      [ 0.3653,  0.5372, -0.5730]], device='cuda:0')),
             ('fc1.bias',
              tensor([-0.4214,  0.2875, -0.0175], device='cuda:0'))])

 

先ほど保存したモデルのパラメータを読む込むことができました。

次は、load_state_dict()を利用して、このモデルのパラメータを新たなモデルに読み込みます。

 

# 新しいモデル
model2 = Net().to(device)
print('新しいモデル:\n', model2.state_dict())

# 保存したモデルパラメータの読み込み
model2.load_state_dict(torch.load(PATH))
print('読み込み後のモデル:\n', model2.state_dict())

 

<output>

新しいモデル:
 OrderedDict([('fc1.weight', tensor([[-0.3752,  0.3518,  0.5422],
        [ 0.2593, -0.4580,  0.4354],
        [ 0.2122, -0.2200,  0.4411]], device='cuda:0')), ('fc1.bias', tensor([-0.2742, -0.4467, -0.4960], device='cuda:0'))])
読み込み後のモデル:
 OrderedDict([('fc1.weight', tensor([[-0.1916,  0.1562, -0.2462],
        [ 0.5605,  0.1616, -0.2541],
        [ 0.3653,  0.5372, -0.5730]], device='cuda:0')), ('fc1.bias', tensor([-0.4214,  0.2875, -0.0175], device='cuda:0'))])

 

読み込み後は、保存したパラメータを持つモデルが構築されます。

一般的には、torch.save()load_state_dict()をまとめて以下のように記述することが多いです。

model2 = Net().to(device)
model2.load_state_dict(torch.load(PATH))

 

<output>

<All keys matched successfully>

 

 

学習済みモデルを部分的に読み込む方法

 

学習済みモデルを部分的に読み込むためには、load_state_dict()の引数をstrict=Falseとします。

具体例として、以下のネットワークを構築します。

class Net2(nn.Module):
    def __init__(self):
        super(Net2, self).__init__()
        self.fc1 = nn.Linear(3, 3, bias=False)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return x

 

バイアス項がないため、先ほど保存したモデルと形状が一致していないことに注意してください!

しかし、strict=Falseとすると適合する部分のみ読み込むことができます。

model3 = Net2().to(device)

print('新しいモデル:\n', model3.state_dict())

# 部分的に読み込む
model3.load_state_dict(torch.load('filename.pth'), strict=False)
print('読み込み後のモデル:\n', model3.state_dict())

 

<output>

新しいモデル:
 OrderedDict([('fc1.weight', tensor([[-0.2354, -0.4788, -0.0337],
        [ 0.0216, -0.2793,  0.4533],
        [ 0.2675,  0.0576, -0.2487]], device='cuda:0'))])
読み込み後のモデル:
 OrderedDict([('fc1.weight', tensor([[-0.1916,  0.1562, -0.2462],
        [ 0.5605,  0.1616, -0.2541],
        [ 0.3653,  0.5372, -0.5730]], device='cuda:0'))])

 

無事に、適合する部分は保存した値になっています。

 

まとめ

 

本記事では、PyTorchの学習済みモデルの保存と読み込み方法を具体例を使用しながら解説しました。

さらにPyTorchについて知りたい方は下記のボタンからアクセスしてください。

 

PyTorchの記事はこちら

ABOUT ME
努力のガリレオ
【運営者】 : 東大で理論物理を研究中(経歴)東京大学, TOEIC950点, NASA留学, カナダ滞在経験有り, 最優秀塾講師賞, オンライン英会話講師試験合格, ブログと独自コンテンツで収益6桁達成 【編集者】: イングリッシュアドバイザーとして勤務中(経歴)中学校教諭一種免許取得[英語],カナダ留学経験あり, TOEIC650点
Python学習を効率化させるサービス

 

Pythonを学習するのに効率的なサービスを紹介していきます。

まず最初におすすめするのは、Udemyです。

Udemyは、Pythonに特化した授業がたくさんあり、どの授業も良質です。

また、セール中は1500円定義で利用することができ、コスパも最強です。

下記の記事では、実際に私が15個以上の講義を受講して特におすすめだった講義を紹介しています。

 

【最新】UdemyでおすすめのPythonコース|東大生が厳選!10万を超える講座があるUdemyの中で、Pythonに関係する講座を厳選しました。また、本記事では、Udemyを使用しながらPythonをどのような順番で勉強するべきかを紹介しました。ぜひ参考にしてください。...

 

他のPythonに特化したオンライン・オフラインスクールも下記の記事でまとめています。

 

【最新】Pythonに強いプログラミングスクール7選|東大生が厳選Pythonの流行と共に、Pythonに強いプログラミングスクールが増えてきました。本記事では、特にPythonを効率的に学ぶことができるプログラミングスクールを経験をもとに厳選して、内容を詳しく解説しています。...

 

自分の学習スタイルに合わせて最適なものを選びましょう。