PyTorchで自作損失関数を作成する方法
本記事では、PyTorchで自作損失関数を作成する方法を紹介します。
扱うタスクによってはPyTorchが対応していない損失関数を使用する場合もあるため一度学んでおくと良いです。
自作損失関数の使用方法
早速ですが、自作損失関数の定義方法を紹介します。
具体例を示した方が理解しやすいと思うので、下記に具体例を示します。
class CustomLoss(nn.Module):
def __init__(self):
super().__init__()
# パラメータを設定
def forward(self, outputs, targets):
'''損失関数を計算
Parameters
------------------
outputs : モデルの出力
targets : 正解ラベル
'''
loss = (lossの計算)
return loss
この記述方法は、PyTorchのデフォルトの損失関数と同じ書き方です。
この方法で記述すると通常の学習方法と同様の記述を使うことができます(下記参照)
criterion = CustomLoss()
outputs = model(input)
loss = criterion(outputs, targets)
loss.backward()
自作損失関数の場合もPyTorchが計算グラフを構築し、自動微分を実行してくれるので、改めて微分を計算する必要はありません!
自作損失関数の具体例
自作損失関数の具体例を紹介します。
今回は、以下のような誤差関数を自作して自動微分を行う過程を説明します。
$$\text{4thLoss} = \frac{1}{N} \sum_{i=1}^{N} (\text{outputs} – \text{labels})^{4}$$
*実際にこのような誤差関数があるかは不明です。
誤差関数の定義
まずは、上述のように誤差関数を以下のように定義します。
class FourthLoss(nn.Module):
def __init__(self):
super().__init__()
def forward(self, outputs, targets):
loss = ((outputs - targets)**4).mean()
return loss
これで定義は完了です!
自動微分の実行
先ほど定義した4thLossの勾配が自動微分によって求めることができるのかを確認します。
今回は下記のような出力とラベルが与えられる設定を考えます。
targets = torch.randn(100)
targets.requires_grad=True
outputs = torch.randn(100)
outputs.requires_grad=True
損失関数自体の値は以下のようにように求めることができます。
fourthloss = FourthLoss()
loss = fourthloss(outputs, targets)
print(loss)
<output>
tensor(12.8722, grad_fn=<MeanBackward0>)
勾配も以下のように自動微分で求めることができます。
loss.backward()
print(targets.grad)
<output>
tensor([ 9.1533e-02, 1.8771e-03, 1.6530e-01, -2.0632e-03, 1.3365e-01,
-1.7062e-01, -6.6819e-03, 2.7819e-01, 3.4191e-02, 1.9113e-02,
4.4310e-02, 2.4035e-01, 1.1377e-02, 8.5665e-02, -1.2062e-03,
4.0217e-01, 1.6954e-01, -5.8900e-05, -1.5942e+00, -1.0333e-03,
7.6033e-02, -1.7734e-01, 4.2902e-03, 1.9156e-03, 7.9497e-05,
3.7292e-02, 1.1914e-01, -1.7258e-01, 2.9576e-05, -5.8094e-02,
2.0214e-01, 1.8167e-01, -1.4757e-02, 1.2792e+00, -7.5485e-07,
5.9614e-02, 3.5801e-01, 4.1280e-02, 9.7868e-02, -1.9571e-05,
-4.0068e-02, -5.1422e-02, 4.9745e-04, -1.1641e-03, -9.2630e-02,
1.4710e-01, 2.3637e-01, 4.9167e-01, -4.2075e-01, 5.0776e-01,
-7.3926e-02, -4.8032e-02, 1.6428e-02, -2.1037e-03, -2.6576e-02,
-1.0183e+00, 3.1362e-02, -5.0543e-01, 4.5202e-02, -1.9641e-03,
5.6631e-03, -2.3474e-03, 8.3443e-02, -9.3296e-01, 2.7348e-02,
1.7776e+00, 2.0189e-02, 3.8390e-01, 3.2938e-05, -8.2928e-07,
3.2821e-02, -2.0227e-02, -7.3885e-02, 2.2900e-01, 1.0089e-02,
-1.2580e-01, -3.8641e-05, 1.0434e-02, -2.7547e-01, 1.7748e-04,
1.5187e+00, -9.7312e-02, 1.8129e-02, -3.6833e-01, 2.0181e+00,
1.5603e-02, 1.1458e-03, 1.3111e-03, -1.7094e-04, 7.3340e-08,
-2.4228e-02, 2.3902e-01, -5.6095e-01, -1.9228e-02, 2.0077e-02,
-1.3741e-01, 2.8495e-01, -1.0526e-02, 1.1793e-01, 1.3773e-01])
これで自作損失関数の説明は終了です。お疲れ様でした。
まとめ
本記事では、Pytorchによる自作損失関数の使用方法を紹介しました。
実際に自作損失関数が動作するのかを実験したい方は、下記の記事で最も簡単な多層パーセプトロンの実装を行なっているので、そこで使用している損失関数を自作損失関数に変更し動作を確認してくみださい!
その他にも自作でPytorchのモジュールを作成する方法を本ブログでは紹介しています。
Pythonを学習するのに効率的なサービスを紹介していきます。
まず最初におすすめするのは、Udemyです。
Udemyは、Pythonに特化した授業がたくさんあり、どの授業も良質です。
また、セール中は1500円定義で利用することができ、コスパも最強です。
下記の記事では、実際に私が15個以上の講義を受講して特におすすめだった講義を紹介しています。
他のPythonに特化したオンライン・オフラインスクールも下記の記事でまとめています。
自分の学習スタイルに合わせて最適なものを選びましょう。
また、私がPythonを学ぶ際に使用した本を全て暴露しているので参考にしてください。