pytorch Dataloader

物体分類においてミニバッチ学習を行うとき、torch/torchvision の ImageFolder 関数で画像を読み込みを行い、DataLoader 関数でミニバッチの制御を行っている。ImageFolder 関数を使用するとき、画像がカテゴリごとに整理されている必要がある。例えば、train というフォルダを作成し、train フォルダの中に apple、orange、cherry などのサブフォルダを作成し、各サブフォルダの下に該当する画像を配置する。こうして整理した train フォルダへのパスを ImageFolder 関数に代入することで、ImageFolder 関数はフォルダ構造を解釈して、画像を学習に使えるように準備する。

画像データが整理されていない場合、あるいは独自のフォルダ構造で整理されている場合は ImageFolder 関数が使えなくなる。このとき、ImageFolder を独自に定義する必要がある。この定義は Dataset クラスを継承して定義する必要がある。話をまとめると、物体分類では ImageFolderDataLoader を使用してミニバッチを制御しているが、実際には DatasetDataLoader の機能を利用しているということになる。

PyTorch を利用したミニバッチ学習では、次のような流れとなっている。

  1. Dataset を利用してデータ(特徴量および教師ラベル)を読み込む。Dataset は、i 番目のデータがリクエストされたときに、i 番目のデータ(特徴量および教師ラベル)を返す。ただし、データに対する前処理 transforms を行いたい場合は、前処理を行ってから返す。
  2. DatasetDataLoader に代入する。DataLoader がミニバッチ学習できるように、データの枚数などを自動的に調整する。

このうち、前処理を行う transforms およびデータの読み込みを行う Dataset を、学習データの形式に合わせて独自に定義して使用できる。

torch.utils.data.Dataset

Dataset はデータを読み込んで、前処理などを行うためのクラス(高度な関数)となっている。Dataset を独自に定義する場合は、次のようなことに注意して定義する。

  • Dataset クラスには全データを参照するためのリストを保持してある。このリストは、全画像へのパスからなるリストでも、全画像を読み込んだ後の数値データでもよい。Dataset に対して len 関数を実行したときに、全データの件数を出力するような機能を備えている。
  • i 番目のデータへの問い合わせがある場合、i 番目のデータを返す機能を備えている。この場合、特徴量だけ返してもよいし、特徴量と教師ラベルを同時に返してもいい。

画像パスと教師ラベルのテキストファイル読み込み

ここで Dataset の作り方として、画像パスと教師ラベルが記載されているテキストファイルを読み込んで、データを整理する Dataset を定義する例を示す。このテキストファイルは、タブ区切りで、以下のように画像パスが 1 列目、教師ラベルが 2 列目に記載されているものとする。教師ラベルはゼロから始まる整数で記述されているものとする。例えば、apple ならば 0、orange ならば 1 などのように、この対応関係を自分で把握しておく必要がある。また、ほとんどの機械学習パッケージでは、教師ラベルを one-hot 表現で表す必要があるが、pytorch の場合は one-hot 表現に変更する必要はない。

/home/user/train/00/a894mca.jpg	0
/home/user/train/00/cdnkj48.jpg	1
/home/user/train/00/k959skk.jpg	1
/home/user/train/01/mi3487c.jpg	0
/home/user/train/01/lkddi9c.jpg	2

このテキストファイルを処理する Dataset クラスを作成する。このクラスは torch.utils.data.Dataset を継承して作成する。また、全データ数を返す関数として __len__ 関数を定義し、さらに、番号を受け取りその番号にあたるデータを返す関数として __getitem__ を定義する。

class MyDataset(torch.utils.data.Dataset):

  def __init__(self, label_path, transform=None):
    # テキストファイルを読み込み、このファイルに保存されているすべてのデータに対して、
    # 画像へのパスを変数 x に保存し、教師ラベルを y に保存する。
    x = []
    y = []
    
    # テキストファイルの読み込みと処理
    with open(label_path, 'r') as infh:
      for line in infh:
        d = line.replace('\n', '').split('\t')
        x.append(os.path.join(os.path.dirname(label_path), d[0]))
        y.append(int(d[1]))
    
    # 画像へのパスをプライベート変数に格納
    self.x = x
    
    # 教師ラベルを torch が利用できるデータ型に変換してから格納
    self.y = torch.tensor(y)
     
    # 画像に対する前処理が定義されている場合は、それの定義をプライベート変数に格納
    self.transform = transform
  
  
  def __len__(self):
    # 全データの件数を返す関数
    return len(self.x)
  
  
  def __getitem__(self, i):
    # 全データの中から i 番目の画像と教師ラベルを返す関数
    # 画像については、ファイルパスにある画像を Image.open 関数で読み込んで、行列に変換する。
    # 前処理が定義されているならば前処理を行う
    img = PIL.Image.open(self.x[i]).convert('RGB')
    if self.transform is not None:
      img = self.transform(img)
    
    return img, self.y[i]

RuntimeError: multi-target not supported

自作 Dataset クラスを使用すると、"RuntimeError: multi-target not supported" のようなエラーが起こる場合がある。ほとんどの場合、教師ラベルが one-hot 表現となっていることが原因と考えられる。したがって、次のようなエラーが起きた場合、教師ラベルが one-hot 表現ではなく、ゼロから始まる整数となっていることを再確認するとよい。

"/lib/python3.6/site-packages/torch/nn/functional.py", line 2115, in nll_loss
    ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: multi-target not supported at /pytorch/aten/src/THCUNN/generic/ClassNLLCriterion.cu:18

torch.utils.data.DataLoader

DataLoader 関数は、Dataset クラスで読み込んだデータをミニバッチ学習できるように制御するための関数である。例えば、学習データと検証データの両方を用意する場合は、次のようにする。

train_data_dir = 'drive/My Drive/datasets/mnist/train_labels.tsv'
valid_data_dir = 'drive/My Drive/datasets/mnist/valid_labels.tsv'

trainset = MyDataset(train_data_dir)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=8)

validset = MyDataset(valid_data_dir)
validloader = torch.utils.data.DataLoader(validset, batch_size=64, shuffle=False, num_workers=8)

DataLoader 関数は、batch_sizeshufflenum_workers などの引数を持つ。

  • batch_size はミニバッチのサイズを制御する引数であり、画像の入力サイズや GPU メモリ容量に応じて決める。
  • shuffle はエポック開始時にデータをシャッフルするかどうかを決める引数である。一般的に、学習を行う際に、エポックごとにデータをシャッフルした方が極小解に陥りにくいので、シャッフルが推奨される。
  • num_workers はデータの読み込み時に利用するプロセス数を決める。数を多くすれば、データの読み込みが早くなるが、学習に使用するコンピューターのプロセッサー数を考慮する。

transforms

transforms はデータの前処理を行うための関数である。画像に対する簡単な前処理については torchvision のパッケージで定義されている。画像のサイズを揃えたり、左右反転したり、画像の行列データを pytorch のテンソル型に変換したり、データを正規化したりするといった基本的な機能は、すべて torchvision で定義されている。例えば、入力画像を 224×224 のサイズに揃えて、ランダムに反転してからテンソル型に変換して、正規化を行ったあとに、学習プロセスに渡す場合はの前処理を次のように定義できる。

transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize((240, 320)),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

前処理で、もう少し凝った augmentation 処理を行いたい場合は、独自の transforms を定義することもできる。この際、他の transforms と組み合わせて使えるようにするためには、__call__ 関数を定義する必要がある。

左右反転する transforms

一例として、確率 0.5 で左右反転する transforms を定義する場合は、次のように書く。

class RandomFlip:
    def __init__(self):
        pass

    def __call__(self, img):
        if random.random() > 0:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            return img

この transforms を使用する場合は以下のようにする。

transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize((240, 320)),
        RandomFlip(),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

画像サイズを揃える前処理

torchvision の Resize を利用して画像のサイズを変更するとき、入力画像のサイズを揃える必要がある。例えば、入力画像がすべて 256×256 ピクセルで、これらを 224×224 ピクセルに変更することは可能。しかし、入力が縦長だったり、横長っだったり、サイズが揃っていなかった場合は、次のようなエラーが起こる。

RuntimeError: stack expects each tensor to be equal size,
 but got [3, 224, 224] at entry 0 and [3, 224, 336] at entry 3

入力画像のサイズが揃っていないときにサイズの変更をしたければ、Resize のような前処理を独自に定義すればよい。例えば、下は、入力画像を指定された大きさの正方形に変更する処理を行う前処理の例を示している。


class SquareResize:
    def __init__(self, shape=224, bg_color = [0, 0, 0]):
        self.shape = shape
        self.bg_color = bg_color

    def __call__(self, img):
        w, h = img.size
        img_square = None

        if w == h:
            img_square = img
        elif w > h:
            img_square = PIL.Image.new(img.mode, (w, w), self.bg_color)
            img_square.paste(img, (0, (w - h) // 2))
        else:
            img_square = PIL.Image.new(img.mode, (h, h), self.bg_color)
            img_square.paste(img, ((h - w) // 2, 0))

        img_square = img_square.resize((self.shape, self.shape))
        return img_square


transforms = torchvision.transforms.Compose([
        SquareResize(224, (0, 0, 0)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

imgaug を利用した前処理

画像の augmentation 用のパッケージ imgaug を利用して前処理を定義する場合は、例えば以下のようにすることができる。

class ImgaugTransform:
    def __init__(self):

        self.aug = ia.augmenters.Sequential([
            ia.augmenters.Resize((224, 224)),
            ia.augmenters.Sometimes(0.1, ia.augmenters.Add((-40, 40))),
            ia.augmenters.Sometimes(0.1, ia.augmenters.AdditiveGaussianNoise(scale=(0, 0.2*255))),
            ia.augmenters.Sometimes(0.1, ia.augmenters.GaussianBlur(sigma=(0, 3.0))),
            ia.augmenters.Fliplr(0.5),
            ia.augmenters.Affine(rotate=(-20, 20), mode='symmetric'),
            ia.augmenters.Sometimes(0.1,
                          ia.augmenters.OneOf([ia.augmenters.Dropout(p=(0, 0.1)),
                                               ia.augmenters.CoarseDropout(0.1, size_percent=0.5)])),
            ia.augmenters.Cutout(nb_iterations=(1, 5), size=0.2, squared=False, fill_mode='gaussian', fill_per_channel=True),
            ia.augmenters.AddToHueAndSaturation(value=(-10, 10), per_channel=True)
        ])

    def __call__(self, img):
        img = np.array(img)
        return self.aug.augment_image(img)