大量な画像から学習を行う方法

Keras を使って学習を進めるとき、画像の枚数が少なければ、画像をすべて読み込んだ上で、学習が速く進む。しかし、画像が大量にあるとき、メモリ容量の関係で、それらの画像をすべて読み込めない場合がある。このような場合、大量の画像を少量ずつで読み込んで学習する、というサイクルを繰り返して学習を進めることができる。

Keras では、画像を少量ずつ読み込んで学習するために、画像を少量ずつ読み込んで学習データに変換するというジェネレーターを定義する必要がある。このジェネレーターを、下では BatchGenerator の名前で定義している。このジェネレーターでは、概ね次のような処理を行なっている。

  • ジェネレーターが新たに作られると __init__ が呼び出されて、画像ファイルへのパス image_path とラベル label がそれぞれジェネレーターの xy にセットされる。続いて、学習時に入力する画像のサイズ image_shape とバッチサイズ batch_size もジェネレーターにセットされる。
    • image_path は 1 次元リストで、各要素は文字列で、画像ファイルへのパスとなっている。
    • label は、image_path の画像がどのクラスに属しているのかを示している。ただし、文字列としてのラベルではなく、ラベルをすでに One-hot 表現に変換したデータである必要がある。labelimage_path の各要素が対応している。
  • 学習開始後、このジェネレーターの __getitem__ メソッドが繰り返し実行される。具体的に 1 バッチごとに 1 回実行され、すべての画像データ(x)に対して一回り実行が完了したとき、1 エポックとなる。
    • __getitem__ で、x のどこまでがすでに学習済みかをチェックし、まだ学習していないデータから 1 バッチ分(batch_size)読み込んで、前処理を施して、その結果を返している。
    • 画像を読み込んだ直後では、画像は 0 〜 255 までの整数値で表されている。このまま学習に回すと、値が大きく、学習に悪影響を及ぼすので、全体を 255 で割って、値の範囲を 0 〜 1 の小数値に標準化する。
import cv2
import keras

class BatchGenerator(keras.utils.Sequence):

    def __init__(self, image_path, label, image_shape=(299, 299, 3), batch_size=64):

        self.x = image_path
        self.y = label
        self.length = len(image_path)

        self.batch_size = batch_size
        self.image_shape = image_shape
        self.batches_per_epoch = int((self.length - 1) / batch_size) + 1



    def __getitem__(self, idx):

        batch_from = self.batch_size * idx
        batch_to   = batch_from + self.batch_size

        if batch_to > self.length:
            batch_to = self.length

        x_batch = []
        y_batch = []
        for i in range(batch_from, batch_to):
            img = cv2.imread(self.x[i], cv2.IMREAD_COLOR)
            x_batch.append(img)
            y_batch.append(self.y[i])

        x_batch = np.asarray(x_batch)
        x_batch = x_batch.astype('float32') / 255.0
        y_batch = np.asarray(y_batch)

        return x_batch, y_batch



    def __len__(self):
        return self.batches_per_epoch


    def on_epoch_end(self):
        pass

ジェネレーターを定義してから、次に、このジェネレーターを呼び出して、Keras の fit_generator というメソッドで学習を進めることができるようになる。

import glob
import keras


best_model_path = 'best_model.h5'
final_model_path = 'final_model.h5'
image_shape = (299, 299, 3)
batch_size = 64


# prepare train and test data sets
## x_train = ['sample1.jpg', 'sample2.jpg', ..., 'samplex.jpg']
## y_train = [[1, 0],        [0, 1],        ..., [1, 0]]
## x_test = ['test1.jpg', 'test2.jpg', ..., 'testx.jpg']
## y_test = [[1, 0],      [1, 0],     ..., [0, 1]]

train_batch_generator = BatchGenerator(x_train, y_train, image_shape, batch_size)
test_batch_generator = BatchGenerator(x_test, y_test, image_shape, batch_size)


# create a model using Keras
## model = ...


# start training
chk_point = keras.callbacks.ModelCheckpoint(filepath = best_model_path, monitor='val_loss',
                                            verbose=1, save_best_only=True, save_weights_only=False,
                                            mode='min', period=1)

fit_history = model.fit_generator(train_batch_generator, epochs=epochs,
                                  steps_per_epoch=train_batch_generator.batches_per_epoch,
                                  verbose=1,
                                  validation_data=test_batch_generator,
                                  validation_steps=test_batch_generator.batches_per_epoch,
                                  shuffle=True,
                                  callbacks=[chk_point])
model.save(final_model_path)