【徹底解説】PyTorchのDataLoaderの使い方
- PyTorchのDataLoaderを何となく使ってるけど、実はよくわかってない…
- DataLoaderとDatasetの仕組みを勉強しようと思ったけど挫折した…
本記事では、その悩みと疑問を解決していきます。
PyTorchのDataLoaderとDatasetを使用することで、簡単にミニバッチ勾配降下法を実装することができます。
最初は、DataLoaderとDatasetが複雑に感じますが一度理解すると、なくてはならない存在になります。
本記事では、特にDataLoaderに焦点を当てて解説していきます。
『DataLoader, Datasetってなんだ?』という方でも理解できるように解説するので安心してください。
PyTorchのDatasetとDataLoaderについて
PyTorchのDataLoaderとDatasetを使用することで、簡単にデータの一部(ミニバッチ)をデータから取り出すことができます。
では、DatasetとDataLoaderについてもう少し詳しく説明します。
PyTorchのDatasetとは
PyTorchのDatasetとは、PyTorchが扱いやすいように、データを一つにまとめたデータベースのようなものです。
このDatasetをDataLoaderに渡すことで、データベースからミニバッチを簡単に取り出すことができるようになります。
Datasetの詳しい解説は下記を参考にしてください。
PyTochのDataLoaderとは
DataLoaderは、Datasetを受け取ってデータからミニバッチを取り出すことができます。
DataLoaderは、配列ではなくイテラブルとなり、For文等でデータを取り出すことができます。
主に、DataLoaderを使用するメリットは以下の2つです。
- ミニバッチでDataをランダムに取り出すことができる
- PyTorchで学習するための型で返してくれる
イテラブルとは、『繰り返し可能なオブジェクト』のことです。
基本的には、『for i in 【イテラブル】』のような使い方ができるものだと思っておけば問題はありません。
PyTorchのDataLoaderの使い方
ここからは、具体的な例を使用し、PyTorchの動作を詳しく解説します。
以下の流れで解説していきます。
- Datasetの作成
- DataLoaderの使い方
- DataLoaderの動作確認
- ミニバッチを確認しながらDataを取り出す方法
ここからの話は『Dataset』の基本的な理解が必要です。
詳しい理解は不要ですが、『Dataset』を聞いたことのない読者は下記を一読してから読み進めてください。
Datasetの作成
まずは、サンプルのデータをロードし、Datasetを作成します。
今回は、手書き数字のMNISTデータセットを使用します。
MNIST = torchvision.datasets.MNIST(root='./data',
train=True,
transform=transforms.ToTensor(),
download=True)
中身を簡単に確認しておきましょう。
<Input>
print(MNIST)
<output>
Dataset MNIST
Number of datapoints: 60000
Root location: ./data
Split: Train
StandardTransform
Transform: ToTensor()
では、DataLoaderを次節で紹介します。
DataLoaderの使い方
DataLoaderは、Datasetを入力して作成することができます。
具体的なコードを見てみましょう。
Loader = torch.utils.data.DataLoader(dataset=MNIST,
batch_size=10000,
shuffle=True,
num_workers=2)
これで、DataLoaderを作成することができます。
具体的な引数について説明します。
DataLoader | |
---|---|
第一引数 | Dataset |
batch_size | ミニバッチの個数 |
shuffle | Trueにするとミニバッチをランダムに取り出す |
num_workers | 複数処理するかどうか |
では、次節で動作を確認していきましょう。
DataLoaderの動作確認
では、DataLoaderの確認をしていきます。
DataLoaderはイテラブルなので、For文を使って取り出すことができます。
実際に取り出してみましょう。
for images, label in Loader:
images, label = images, label
images, labeになりが入っている確認してみましょう。
<Input>
# バッチサイズ, チャンネル数, 縦, 横
print(images.size())
# ラベル
print(label.size())
<Output>
torch.Size([10000, 1, 28, 28])
torch.Size([10000])
imageは、『バッチサイズ, チャネル(今回は白黒だから1), 縦, 横』の順番で取り出せされます。
ミニバッチ内の1番目に対応する数字を具体的に可視化してみてみましょう。
<Input>
# 可視化して確認する
fig, ax = plt.subplots()
ax.imshow(images[0][0])
ax.axis('off')
ax.set_title(f'images, label={label[0]}', fontsize=20)
plt.show()
<Output>
ミニバッチを確認しながらDataを取り出す方法
どのミニバッチを使用しているのか確認できるようにするためには、『enumerate関数』を使用します。
『enumerate関数?』という方は下記を参考にしてください。
では、具体的なコードで確認しましょう。
<Input>
# どのMini-Batchを使用しているのかを明らかにする
for i , (images, label) in enumerate(Loader):
images = images
label = label
print(i)
<output>
0
1
2
3
4
5
ここまでが基本的な使い方です。
次章では、ミニバッチが割り切れない時の対処方法を説明します。
DataLoaderのミニバッチが割り切れない場合の対処方法
全データセットに対して、ミニバッチサイズが割り切れないとき、最後のミニバッチデータセットは中途半端な数になってしまいます。
そのようなミニバッチデータセットは、学習に悪い影響を与えることがあるので基本的には取り除きます。
PyTorchのDataLoaderの場合、割り切れなかったミニバッチデータセットを除去するためには、『drop_last』をTrueにすることで除去することができます。
今回は、60000枚の画像なので、ミニバッチデータセットを10000枚にした上述例の場合、割り切れるので6つのミニバッチデータセットができました。
今回は、ミニバッチデータセットを10001枚に設定して『drop_last』の挙動を確認してみます。
<Input>
# 割り切れない場合
Loader = torch.utils.data.DataLoader(dataset=MNIST,
batch_size=10001,
shuffle=True,
num_workers=2,
drop_last = True)
# 最後にデータ数の異なるミニバッチを除去
for i , (images, label) in enumerate(Loader):
images = images
label = label
print(i)
<Output>
0
1
2
3
4
このように4つの10001枚のミニバッチデータセットが作成され、最後の9997枚のミニバッチデータセットは除去されました。
これで、問題なくDataLoaderを使用することができますね。
PyTorchのDataLoaderの注意事項
PyTorchのDataLoaderはイテラブルなので、全てのデータを取り出すまで、一度取り出したデータを再び取り出すことができません。
そのため中途半端に取り出してしまうと、後ほどの学習で全てのデータが使用できないです。
動作を確認するときは注意してください。
まとめ
今回は、DataLoaderの使用方法に焦点を当てて解説しました。
皆様も具体例を使ってDataLoaderを動かしてみてください。
そうすることで、DataLoaderの理解度が格段に上がります。
本記事が、皆様にとって有益であることを願います…
具体的に、ディープラーニングを実装してみたいという方は下記を参考にしてください。
Pythonを学習するのに効率的なサービスを紹介していきます。
まず最初におすすめするのは、Udemyです。
Udemyは、Pythonに特化した授業がたくさんあり、どの授業も良質です。
また、セール中は1500円定義で利用することができ、コスパも最強です。
下記の記事では、実際に私が15個以上の講義を受講して特におすすめだった講義を紹介しています。
他のPythonに特化したオンライン・オフラインスクールも下記の記事でまとめています。
自分の学習スタイルに合わせて最適なものを選びましょう。
また、私がPythonを学ぶ際に使用した本を全て暴露しているので参考にしてください。