Programming Lab

PyTorchで自作損失関数を作成する方法

PyTorchで自作損失関数を作成する方法

 

本記事では、PyTorchで自作損失関数を作成する方法を紹介します。

扱うタスクによってはPyTorchが対応していない損失関数を使用する場合もあるため一度学んでおくと良いです。

自作損失関数の使用方法

早速ですが、自作損失関数の定義方法を紹介します。

具体例を示した方が理解しやすいと思うので、下記に具体例を示します。

class CustomLoss(nn.Module):
    def __init__(self):
        super().__init__()
        # パラメータを設定 

    def forward(self, outputs, targets):
        '''損失関数を計算

        Parameters
        ------------------
        outputs : モデルの出力
        targets : 正解ラベル
        '''
        loss = (lossの計算)
        return loss

 

この記述方法は、PyTorchのデフォルトの損失関数と同じ書き方です。

この方法で記述すると通常の学習方法と同様の記述を使うことができます(下記参照)

criterion = CustomLoss()
outputs = model(input)
loss = criterion(outputs, targets)
loss.backward()

 

自作損失関数の場合もPyTorchが計算グラフを構築し、自動微分を実行してくれるので、改めて微分を計算する必要はありません!

 

自作損失関数の具体例

自作損失関数の具体例を紹介します。

今回は、以下のような誤差関数を自作して自動微分を行う過程を説明します。

$$\text{4thLoss} = \frac{1}{N} \sum_{i=1}^{N} (\text{outputs} – \text{labels})^{4}$$

*実際にこのような誤差関数があるかは不明です。

誤差関数の定義

 

まずは、上述のように誤差関数を以下のように定義します。

class FourthLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, outputs, targets):
        loss = ((outputs - targets)**4).mean()
        return loss        

 

これで定義は完了です!

自動微分の実行

 

先ほど定義した4thLossの勾配が自動微分によって求めることができるのかを確認します。

今回は下記のような出力とラベルが与えられる設定を考えます。

targets = torch.randn(100)
targets.requires_grad=True
outputs = torch.randn(100)
outputs.requires_grad=True

 

損失関数自体の値は以下のようにように求めることができます。

fourthloss = FourthLoss()
loss = fourthloss(outputs, targets)
print(loss)

 

<output>

tensor(12.8722, grad_fn=<MeanBackward0>)

 

勾配も以下のように自動微分で求めることができます。

loss.backward()
print(targets.grad)

 

<output>

tensor([ 9.1533e-02,  1.8771e-03,  1.6530e-01, -2.0632e-03,  1.3365e-01,
        -1.7062e-01, -6.6819e-03,  2.7819e-01,  3.4191e-02,  1.9113e-02,
         4.4310e-02,  2.4035e-01,  1.1377e-02,  8.5665e-02, -1.2062e-03,
         4.0217e-01,  1.6954e-01, -5.8900e-05, -1.5942e+00, -1.0333e-03,
         7.6033e-02, -1.7734e-01,  4.2902e-03,  1.9156e-03,  7.9497e-05,
         3.7292e-02,  1.1914e-01, -1.7258e-01,  2.9576e-05, -5.8094e-02,
         2.0214e-01,  1.8167e-01, -1.4757e-02,  1.2792e+00, -7.5485e-07,
         5.9614e-02,  3.5801e-01,  4.1280e-02,  9.7868e-02, -1.9571e-05,
        -4.0068e-02, -5.1422e-02,  4.9745e-04, -1.1641e-03, -9.2630e-02,
         1.4710e-01,  2.3637e-01,  4.9167e-01, -4.2075e-01,  5.0776e-01,
        -7.3926e-02, -4.8032e-02,  1.6428e-02, -2.1037e-03, -2.6576e-02,
        -1.0183e+00,  3.1362e-02, -5.0543e-01,  4.5202e-02, -1.9641e-03,
         5.6631e-03, -2.3474e-03,  8.3443e-02, -9.3296e-01,  2.7348e-02,
         1.7776e+00,  2.0189e-02,  3.8390e-01,  3.2938e-05, -8.2928e-07,
         3.2821e-02, -2.0227e-02, -7.3885e-02,  2.2900e-01,  1.0089e-02,
        -1.2580e-01, -3.8641e-05,  1.0434e-02, -2.7547e-01,  1.7748e-04,
         1.5187e+00, -9.7312e-02,  1.8129e-02, -3.6833e-01,  2.0181e+00,
         1.5603e-02,  1.1458e-03,  1.3111e-03, -1.7094e-04,  7.3340e-08,
        -2.4228e-02,  2.3902e-01, -5.6095e-01, -1.9228e-02,  2.0077e-02,
        -1.3741e-01,  2.8495e-01, -1.0526e-02,  1.1793e-01,  1.3773e-01])

 

これで自作損失関数の説明は終了です。お疲れ様でした。

 

まとめ

 

本記事では、Pytorchによる自作損失関数の使用方法を紹介しました。

実際に自作損失関数が動作するのかを実験したい方は、下記の記事で最も簡単な多層パーセプトロンの実装を行なっているので、そこで使用している損失関数を自作損失関数に変更し動作を確認してくみださい!

 

【入門】PyTorchの使い方をMNISTデータセットで学ぶ(15分)本記事では、MNISTデータセットを利用して、PyTorchの実装の流れを紹介しました。本記事を通して読めば、PyTorchを実務で使用するイメージが確実に身につくと思います。...

 

その他にも自作でPytorchのモジュールを作成する方法を本ブログでは紹介しています。

ABOUT ME
努力のガリレオ
【運営者】 : 東大で理論物理を研究中(経歴)東京大学, TOEIC950点, NASA留学, カナダ滞在経験有り, 最優秀塾講師賞, オンライン英会話講師試験合格, ブログと独自コンテンツで収益6桁達成 【編集者】: イングリッシュアドバイザーとして勤務中(経歴)中学校教諭一種免許取得[英語],カナダ留学経験あり, TOEIC650点
Python学習を効率化させるサービス

 

Pythonを学習するのに効率的なサービスを紹介していきます。

まず最初におすすめするのは、Udemyです。

Udemyは、Pythonに特化した授業がたくさんあり、どの授業も良質です。

また、セール中は1500円定義で利用することができ、コスパも最強です。

下記の記事では、実際に私が15個以上の講義を受講して特におすすめだった講義を紹介しています。

 

【最新】UdemyでおすすめのPythonコース|東大生が厳選!10万を超える講座があるUdemyの中で、Pythonに関係する講座を厳選しました。また、本記事では、Udemyを使用しながらPythonをどのような順番で勉強するべきかを紹介しました。ぜひ参考にしてください。...

 

他のPythonに特化したオンライン・オフラインスクールも下記の記事でまとめています。

 

【最新】Pythonに強いプログラミングスクール7選|東大生が厳選Pythonの流行と共に、Pythonに強いプログラミングスクールが増えてきました。本記事では、特にPythonを効率的に学ぶことができるプログラミングスクールを経験をもとに厳選して、内容を詳しく解説しています。...

 

自分の学習スタイルに合わせて最適なものを選びましょう。