VGG16 転移学習

Open in Colab

torchvision パッケージには VGG や ResNet などのような有名なアーキテクチャが実装されている。また、これらのアーキテクチャを ImageNet データセットで学習させた重みも用意されている。学習済みのアーキテクチャを利用して、自分たちのデータセットで再学習させることで、学習時間を大きく短縮させることができる。また、場合によって、精度の向上も期待できる。このように、あるデータセットを使って学習させたモデルに、別のデータセットを与えて再学習させることを転移学習やファインチューニングと読んだりする。このページでは、転移学習・ファインチューニングの進め方を示す。

<p>必要なパッケージ群を読み込む。matplotlib は学習精度をグラフとして可視化する時に利用する。torch および torchvision は物体分類用のモデルを作成する際に利用する。


import matplotlib.pyplot as plt
import torch
import torchvision

学習用画像および検証用画像を、PyTorch の ImageFolder 関数で読み込めるようなフォルダ構造になるように整理する。ここでは PlantVillage 画像データの一部(ダウンロード)を使う。

data_transforms = {
    'train': torchvision.transforms.Compose([
        torchvision.transforms.RandomResizedCrop(224),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'valid': torchvision.transforms.Compose([
        torchvision.transforms.Resize(256),
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}


train_data_dir = 'drive/My Drive/datasets/plantvillage/train'
valid_data_dir = 'drive/My Drive/datasets/plantvillage/valid'


image_datasets = {
  'train': torchvision.datasets.ImageFolder(train_data_dir, transform=data_transforms['train']),
  'valid': torchvision.datasets.ImageFolder(valid_data_dir, transform=data_transforms['valid'])
}

dataloaders = {
  'train': torch.utils.data.DataLoader(image_datasets['train'], batch_size=16, shuffle=True),
  'valid': torch.utils.data.DataLoader(image_datasets['valid'], batch_size=16)
}

dataset_sizes = {
    'train': len(image_datasets['train']),
    'valid': len(image_datasets['valid'])
}

class_names = image_datasets['train'].classes
## c_0 c_1 c_2 c_3 c_4

転移学習(例1)

転移学習を行う例を示す。ここでは、ImageNet で学習させた VGG16 のアーキテクチャと重みを torchvision パッケージから読み込む。既存の VGG16 の出力層は 1000 カテゴリであるので、この出力層を PlantVillage サンプルデータのカテゴリ数(len(class_names) に修正する。その後、損失関数および学習パラメーターを定義し、学習の準備を整える。

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# 出力層の変更
net_ft1 = torchvision.models.vgg16(pretrained=True)
num_ftrs = net_ft1.classifier[6].in_features
net_ft1.classifier[6] = torch.nn.Linear(num_ftrs, len(class_names))

net_ft1 = net_ft1.to(device)

# 損失関数および学習パラメーターの定義
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net_ft1.parameters(), lr=0.001, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

転移学習は、初期状態からの学習に比べて、より少ない学習回数で収束することが多い。例えば、初期状態から学習する場合、40 エポックの学習を行っても、学習精度(訓練精度)が 70% ほどである(実施例)。これに比べて、転移学習の場合は、以下のように 20 エポックの学習を行うだけで、学習精度が 99% 前後になる。

num_epochs = 20
acc_history_ft1 = {'train': [], 'valid': []}

for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    print('-' * 10)

    for phase in ['train', 'valid']:
        if phase == 'train':
            net_ft1.train()
        else:
            net_ft1.eval()

        running_loss = 0.0
        running_corrects = 0

        for inputs, labels in dataloaders[phase]:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            with torch.set_grad_enabled(phase == 'train'):
                outputs = net_ft1(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                if phase == 'train':
                    loss.backward()
                    optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)
        if phase == 'train':
            scheduler.step()

        epoch_loss = running_loss / dataset_sizes[phase]
        epoch_acc = running_corrects.double() / dataset_sizes[phase]
        acc_history_ft1[phase].append(epoch_acc)
        print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
fig = plt.figure(figsize=(4,4),dpi=200)
ax = fig.add_subplot()
ax.plot(acc_history_ft1['train'], label='train')
ax.plot(acc_history_ft1['valid'], label='valid')
ax.legend()
ax.set_ylim(0, 1)
fig.show()
VGG16を利用した転移学習s。

転移学習(例2)

上と異なる学習例を示す。VGG16 のモデル構造は、特徴抽出を行う畳み込み層(features、avgpool)と分類を行う全結合層(classifier)の 2 つの部分からなる。上に示した転移学習では、畳み込み層と全結合層のパラメータ全体を再学習させていた。ここでは、畳み込み層のパラメータを固定させて学習できないようにし、全結合層のパラメータのみを再学習させてみる。

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

net_ft2 = torchvision.models.vgg16(pretrained=True)
num_ftrs = net_ft2.classifier[6].in_features
net_ft2.classifier[6] = torch.nn.Linear(num_ftrs, len(class_names))

# 畳み込み層(features, avgpool)部分のパラメーターを固定
for param in net_ft2.features.parameters():
  param.requires_grad = False
for param in net_ft2.avgpool.parameters():
  param.requires_grad = False

パラメータの固定を行った後は、上と同様にして、損失関数と学習パラメータを定義し、学習を進める。

net_ft2 = net_ft2.to(device)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net_ft2.parameters(), lr=0.001, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
num_epochs = 20
acc_history_ft2 = {'train': [], 'valid': []}

for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    print('-' * 10)

    for phase in ['train', 'valid']:
        if phase == 'train':
            net_ft2.train()
        else:
            net_ft2.eval()

        running_loss = 0.0
        running_corrects = 0

        for inputs, labels in dataloaders[phase]:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            with torch.set_grad_enabled(phase == 'train'):
                outputs = net_ft2(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                if phase == 'train':
                    loss.backward()
                    optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)
        if phase == 'train':
            scheduler.step()

        epoch_loss = running_loss / dataset_sizes[phase]
        epoch_acc = running_corrects.double() / dataset_sizes[phase]
        acc_history_ft2[phase].append(epoch_acc)
        print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
fig = plt.figure(figsize=(4,4),dpi=200)
ax = fig.add_subplot()
ax.plot(acc_history_ft1['train'], label='train')
ax.plot(acc_history_ft1['valid'], label='valid')
ax.legend()
ax.set_ylim(0, 1)
fig.show()
VGG16の畳み込み層のパラメータを固定し、全結合層のパラメータのみを再学習させる。