Programming Lab

【徹底解説】PytorchのDataLoaderの使い方

【徹底解説】PyTorchのDataLoaderの使い方

 

 

  • PyTorchのDataLoaderを何となく使ってるけど、実はよくわかってない…
  • DataLoaderとDatasetの仕組みを勉強しようと思ったけど挫折した…

本記事では、その悩みと疑問を解決していきます。

PyTorchのDataLoaderとDatasetを使用することで、簡単にミニバッチ勾配降下法を実装することができます。

最初は、DataLoaderとDatasetが複雑に感じますが一度理解すると、なくてはならない存在になります。

本記事では、特にDataLoaderに焦点を当てて解説していきます。

『DataLoader, Datasetってなんだ?』という方でも理解できるように解説するので安心してください。

 

 

PyTorchのDatasetとDataLoaderについて

 

PyTorchのDataLoaderとDatasetを使用することで、簡単にデータの一部(ミニバッチ)をデータから取り出すことができます。

では、DatasetとDataLoaderについてもう少し詳しく説明します。

 

PyTorchのDatasetとは

 

PyTorchのDatasetとは、PyTorchが扱いやすいように、データを一つにまとめたデータベースのようなものです。

このDatasetをDataLoaderに渡すことで、データベースからミニバッチを簡単に取り出すことができるようになります。

Datasetの詳しい解説は下記を参考にしてください。

 

PytorchのDatasetを徹底解説(自作データセットも作れる)PyTorchのDataset作成方法を徹底的に解説しました。本記事を読むことで、Numpy, PandasからDatasetを作成したり、自作のDatasetを作成しモジュール化する作業を初心者の方でも理解できるように徹底的に解説しました。...

 

 

PyTochのDataLoaderとは

 

DataLoaderは、Datasetを受け取ってデータからミニバッチを取り出すことができます。

DataLoaderは、配列ではなくイテラブルとなり、For文等でデータを取り出すことができます。

主に、DataLoaderを使用するメリットは以下の2つです。

  1. ミニバッチでDataをランダムに取り出すことができる
  2. PyTorchで学習するための型で返してくれる

 

イテラブルとは、『繰り返し可能なオブジェクト』のことです。

基本的には、『for i in 【イテラブル】』のような使い方ができるものだと思っておけば問題はありません。

 

PyTorchのDataLoaderの使い方

 

ここからは、具体的な例を使用し、PyTorchの動作を詳しく解説します。

以下の流れで解説していきます。

  1. Datasetの作成 
  2. DataLoaderの使い方
  3. DataLoaderの動作確認
  4. ミニバッチを確認しながらDataを取り出す方法

 

ここからの話は『Dataset』の基本的な理解が必要です。

詳しい理解は不要ですが、『Dataset』を聞いたことのない読者は下記を一読してから読み進めてください。

 

PytorchのDatasetを徹底解説(自作データセットも作れる)PyTorchのDataset作成方法を徹底的に解説しました。本記事を読むことで、Numpy, PandasからDatasetを作成したり、自作のDatasetを作成しモジュール化する作業を初心者の方でも理解できるように徹底的に解説しました。...

 

Datasetの作成

 

まずは、サンプルのデータをロードし、Datasetを作成します。

今回は、手書き数字のMNISTデータセットを使用します。

MNIST = torchvision.datasets.MNIST(root='./data',
                                        train=True,
                                        transform=transforms.ToTensor(),
                                        download=True)

 

中身を簡単に確認しておきましょう。

<Input>

print(MNIST)

 

<output>

Dataset MNIST
    Number of datapoints: 60000
    Root location: ./data
    Split: Train
    StandardTransform
Transform: ToTensor()

 

では、DataLoaderを次節で紹介します。

 

DataLoaderの使い方

 

DataLoaderは、Datasetを入力して作成することができます。

具体的なコードを見てみましょう。

Loader = torch.utils.data.DataLoader(dataset=MNIST, 
                                         batch_size=10000, 
                                         shuffle=True, 
                                         num_workers=2)

 

これで、DataLoaderを作成することができます。

具体的な引数について説明します。

DataLoader  
第一引数 Dataset
batch_size ミニバッチの個数
shuffle Trueにするとミニバッチをランダムに取り出す
num_workers 複数処理するかどうか

 

では、次節で動作を確認していきましょう。

 

DataLoaderの動作確認

 

では、DataLoaderの確認をしていきます。

DataLoaderはイテラブルなので、For文を使って取り出すことができます。

実際に取り出してみましょう。

for images, label in Loader:
    images, label = images, label

 

images, labeになりが入っている確認してみましょう。

<Input>

# バッチサイズ, チャンネル数, 縦, 横
print(images.size())

# ラベル
print(label.size())

 

<Output>

torch.Size([10000, 1, 28, 28])
torch.Size([10000])

 

imageは、『バッチサイズ, チャネル(今回は白黒だから1), 縦, 横』の順番で取り出せされます。

ミニバッチ内の1番目に対応する数字を具体的に可視化してみてみましょう。

<Input>

# 可視化して確認する
fig, ax = plt.subplots()
ax.imshow(images[0][0])
ax.axis('off')
ax.set_title(f'images, label={label[0]}', fontsize=20)
plt.show()

 

<Output>

minibatchのデータサンプル

 

ミニバッチを確認しながらDataを取り出す方法

 

どのミニバッチを使用しているのか確認できるようにするためには、『enumerate関数』を使用します。

『enumerate関数?』という方は下記を参考にしてください。

 

enumerate関数の説明はこちら

では、具体的なコードで確認しましょう。

<Input>

# どのMini-Batchを使用しているのかを明らかにする
for i , (images, label) in enumerate(Loader):
    images = images
    label = label
    print(i)

 

<output>

0
1
2
3
4
5

 

ここまでが基本的な使い方です。

次章では、ミニバッチが割り切れない時の対処方法を説明します。

 

DataLoaderのミニバッチが割り切れない場合の対処方法

 

全データセットに対して、ミニバッチサイズが割り切れないとき、最後のミニバッチデータセットは中途半端な数になってしまいます。

そのようなミニバッチデータセットは、学習に悪い影響を与えることがあるので基本的には取り除きます。

PyTorchのDataLoaderの場合、割り切れなかったミニバッチデータセットを除去するためには、『drop_last』をTrueにすることで除去することができます。

今回は、60000枚の画像なので、ミニバッチデータセットを10000枚にした上述例の場合、割り切れるので6つのミニバッチデータセットができました。

今回は、ミニバッチデータセットを10001枚に設定して『drop_last』の挙動を確認してみます。

<Input>

# 割り切れない場合
Loader = torch.utils.data.DataLoader(dataset=MNIST, 
                                         batch_size=10001, 
                                         shuffle=True, 
                                         num_workers=2,
                                         drop_last = True)

# 最後にデータ数の異なるミニバッチを除去
for i , (images, label) in enumerate(Loader):
    images = images
    label = label
    print(i)

 

<Output>

0
1
2
3
4

 

このように4つの10001枚のミニバッチデータセットが作成され、最後の9997枚のミニバッチデータセットは除去されました。

これで、問題なくDataLoaderを使用することができますね。

 

PyTorchのDataLoaderの注意事項

 

PyTorchのDataLoaderはイテラブルなので、全てのデータを取り出すまで、一度取り出したデータを再び取り出すことができません。

そのため中途半端に取り出してしまうと、後ほどの学習で全てのデータが使用できないです。

動作を確認するときは注意してください。

 

まとめ

 

今回は、DataLoaderの使用方法に焦点を当てて解説しました。

皆様も具体例を使ってDataLoaderを動かしてみてください。

そうすることで、DataLoaderの理解度が格段に上がります。

本記事が、皆様にとって有益であることを願います…

具体的に、ディープラーニングを実装してみたいという方は下記を参考にしてください。

 

【入門】PyTorchの使い方をMNISTデータセットで学ぶ(15分)本記事では、MNISTデータセットを利用して、PyTorchの実装の流れを紹介しました。本記事を通して読めば、PyTorchを実務で使用するイメージが確実に身につくと思います。...
ABOUT ME
努力のガリレオ
【運営者】 : 東大で理論物理を研究中(経歴)東京大学, TOEIC950点, NASA留学, カナダ滞在経験有り, 最優秀塾講師賞, オンライン英会話講師試験合格, ブログと独自コンテンツで収益6桁達成 【編集者】: イングリッシュアドバイザーとして勤務中(経歴)中学校教諭一種免許取得[英語],カナダ留学経験あり, TOEIC650点
Python学習を効率化させるサービス

 

Pythonを学習するのに効率的なサービスを紹介していきます。

まず最初におすすめするのは、Udemyです。

Udemyは、Pythonに特化した授業がたくさんあり、どの授業も良質です。

また、セール中は1500円定義で利用することができ、コスパも最強です。

下記の記事では、実際に私が15個以上の講義を受講して特におすすめだった講義を紹介しています。

 

【最新】UdemyでおすすめのPythonコース|東大生が厳選!10万を超える講座があるUdemyの中で、Pythonに関係する講座を厳選しました。また、本記事では、Udemyを使用しながらPythonをどのような順番で勉強するべきかを紹介しました。ぜひ参考にしてください。...

 

他のPythonに特化したオンライン・オフラインスクールも下記の記事でまとめています。

 

【最新】Pythonに強いプログラミングスクール7選|東大生が厳選Pythonの流行と共に、Pythonに強いプログラミングスクールが増えてきました。本記事では、特にPythonを効率的に学ぶことができるプログラミングスクールを経験をもとに厳選して、内容を詳しく解説しています。...

 

自分の学習スタイルに合わせて最適なものを選びましょう。