Programming lab

【徹底解説】PytorchのDataLoaderの使い方

【徹底解説】PytorchのDataLoaderの使い方

 

 

  • PytorchのDataLoaderを何となく使ってるけど、実はよくわかってない…
  • DataLoaderとDatasetのせいでPytorchが嫌いになった…

本記事では、その悩みを解決してきます。

PyTorchのDataLoaderとDatasetを使用することで、簡単にミニバッチ勾配降下法を実装することができます。

最初は、DataLoaderとDatasetが複雑に感じますが一度理解してしまうと、なくてはならない存在になります。

本記事では、DataLoaderに焦点を当てて解説していきます。

『DataLoader, Datasetってなんだ?』という方でも理解できるように解説するので安心してください。

 

PytorchのDatasetとDataLoaderについて

 

PytorchのDataLoaderとDatasetを使用することで、簡単にデータの一部(ミニバッチ)をデータから取り出すことができます。

では、DatasetとDataLoaderについてもう少し詳しく説明します。

 

PytorchのDatasetとは

 

PytorchのDatasetとは、データを一つにまとめたデータベースのようなものです。

このDatasetをDataLoaderに渡すことで、データベースからミニバッチを取り出すことができるようになります。

Datasetの解説に関しては、下記を参考にしてください。

 

PytorchのDatasetを徹底解説(自作データセットも作れる)PyTorchのDataset作成方法を徹底的に解説しました。本記事を読むことで、Numpy, PandasからDatasetを作成したり、自作のDatasetを作成しモジュール化する作業を初心者の方でも理解できるように徹底的に解説しました。...

 

PytochのDataLoaderとは

 

DataLoaderは、Datasetを受け取ってデータからミニバッチを取り出すことができます。

DataLoaderは、配列ではなくイテラブルとなり、For文等でデータを取り出すことができます。

DataLoaderのメリットは以下の2つです。

  1. ミニバッチでDataをランダムに取り出すことができる
  2. Pytorchで学習するための型で返してくれる

 

PytorchのDataLoaderの使い方

 

本節では、DataLoaderの使い方を具体的な例を使って解説します。

以下の流れで解説していきます。

  1. Datasetの作成 
  2. DataLoaderの使い方
  3. DataLoaderの動作確認
  4. ミニバッチを確認しながらDataを取り出す方法

 

『Datasetがよくわかない』という方は下記を読んでから読み進めてください。

 

PytorchのDatasetを徹底解説(自作データセットも作れる)PyTorchのDataset作成方法を徹底的に解説しました。本記事を読むことで、Numpy, PandasからDatasetを作成したり、自作のDatasetを作成しモジュール化する作業を初心者の方でも理解できるように徹底的に解説しました。...

 

Datasetの作成

 

まずは、サンプルのデータをロードし、Datasetを作成します。

今回は、手書き数字のMNISTデータセットを使用します。

MNIST = torchvision.datasets.MNIST(root='./data',
                                        train=True,
                                        transform=transforms.ToTensor(),
                                        download=True)

 

中身を簡単に確認しておきましょう。

<Input>

print(MNIST)

 

<output>

Dataset MNIST
    Number of datapoints: 60000
    Root location: ./data
    Split: Train
    StandardTransform
Transform: ToTensor()

 

では、DataLoaderを次節で紹介します。

 

DataLoaderの使い方

 

DataLoaderは、Datasetを入力して作成することができます。

具体的なコードを見てみましょう。

Loader = torch.utils.data.DataLoader(dataset=MNIST, 
                                         batch_size=10000, 
                                         shuffle=True, 
                                         num_workers=2)

 

これで、DataLoaderを作成することができます。

具体的な引数について説明します。

DataLoader  
第一引数 Dataset
batch_size ミニバッチの個数
shuffle Trueにするとミニバッチをランダムに取り出す
num_workers 複数処理するかどうか

 

では、次節で動作を確認していきましょう。

 

DataLoaderの動作確認

 

では、DataLoaderの確認をしていきます。

DataLoaderはイテラブルなので、For文を使って取り出すことができます。

実際に取り出してみましょう。

for images, label in Loader:
    images, label = images, label

 

images, labeになりが入っている確認してみましょう。

<Input>

# バッチサイズ, チャンネル数, 縦, 横
print(images.size())

# ラベル
print(label.size())

 

<Output>

torch.Size([10000, 1, 28, 28])
torch.Size([10000])

 

imageは、『バッチサイズ, チャネル(今回は白黒だから1), 縦, 横』の順番で取り出せされます。

ミニバッチ内の1番目に対応する数字を具体的に可視化してみてみましょう。

<Input>

# 可視化して確認する
fig, ax = plt.subplots()
ax.imshow(images[0][0])
ax.axis('off')
ax.set_title(f'images, label={label[0]}', fontsize=20)
plt.show()

 

<Output>

minibatchのデータサンプル

 

ミニバッチを確認しながらDataを取り出す方法

 

どのミニバッチを使用しているのか確認できるようにするためには、『enumerate関数』を使用します。

『enumerate関数?』という方は下記を参考にしてください。

 

enumerate関数の説明はこちら

では、具体的なコードで確認しましょう。

<Input>

# どのMini-Batchを使用しているのかを明らかにする
for i , (images, label) in enumerate(Loader):
    images = images
    label = label
    print(i)

 

<output>

0
1
2
3
4
5

 

ここまでが基本的な使い方です。

次章では、ミニバッチが割り切れない時の対処方法を説明します。

DataLoaderのミニバッチが割り切れない場合の対処方法

 

全データセットに対して、ミニバッチサイズが割り切れないとき、最後のミニバッチデータセットは中途半端な数になってしまいます。

そのようなミニバッチデータセットは、学習に悪い影響を与えることがあるので基本的には取り除きます。

PytorchのDataLoaderの場合、割り切れなかったミニバッチデータセットを除去するためには、『drop_last』をTrueにすることで除去することができます。

今回は、60000枚の画像なので、ミニバッチデータセットを10000枚にした上述例の場合、割り切れるので6つのミニバッチデータセットができました。

今回は、ミニバッチデータセットを10001枚に設定して『drop_last』の挙動を確認してみます。

<Input>

# 割り切れない場合
Loader = torch.utils.data.DataLoader(dataset=MNIST, 
                                         batch_size=10001, 
                                         shuffle=True, 
                                         num_workers=2,
                                         drop_last = True)

# 最後にデータ数の異なるミニバッチを除去
for i , (images, label) in enumerate(Loader):
    images = images
    label = label
    print(i)

 

<Output>

0
1
2
3
4

 

このように4つの10001枚のミニバッチデータセットが作成され、最後の9997枚のミニバッチデータセットは除去されました。

これで、問題なくDataLoaderを使用することができますね。

 

PytorchのDataLoaderの注意事項

 

PyTorchのDataLoaderはイテラブルなので、全てのデータを取り出すまで、一度取り出したデータを再び取り出すことができません。

そのため中途半端に取り出してしまうと、後ほどの学習で全てのデータが使用できないです。

動作を確認するときは注意してください。

 

まとめ

 

今回は、DataLoaderの使用方法に焦点を当てて解説しました。

皆様も具体例を使ってDataLoaderを動かしてみてください。

そうすることで、DataLoaderの理解度が格段に上がります。

本記事が、皆様にとって有益であることを願います…

具体的に、ディープラーニングを実装してみたいという方は下記を参考にしてください。

 

【PyTorch入門】多層パーセプトロンでMNISTを分類本記事では、PyTrochで多層パーセプトロンを構築し、手書き文字のデータセットのMNISTを分類しました。本記事を理解することでPytorchを使用してディープラーニングを構築する基本が身につきます。...
ABOUT ME
努力のガリレオ
【運営者】 : 東大で理論物理を研究中(経歴)東京大学, TOEIC950点, NASA留学, カナダ滞在経験有り, 最優秀塾講師賞, オンライン英会話講師試験合格, ブログと独自コンテンツで収益6桁達成 【編集者】: イングリッシュアドバイザーとして勤務中(経歴)中学校教諭一種免許取得[英語],カナダ留学経験あり, TOEIC650点