このページでは、sgRNA の 5' 末端の 21 塩基の並びと切断効率を調べた実験データを使って、21 塩基を入力して切断効率を予測するモデルを bidirectional LSTM で構築する。sgRNA と切断効率の紹介は LSTM のページを参照。
データセット
Wang らの論文には数万件の実験データがあり、公開されている (Wang et al., 2019)。GitHub 上にはデータとともに予測モデルのソースコード(Keras)が公開されている。ここで、少し整形したデータセットを用いてモデルの構築を行う。まず、データをダウンロードして、展開する。このデータを展開すると train.tsv と valid.tsv の 2 つのデータファイルがみられる。それぞれ訓練データと検証データとして使用する。
wget https://aabbdd.jp/notes/data/crisprcas9.tar.gz
tar xzvf crisprcas9.tar.gz
ls crisprcas9
## README.md train.tsv valid.tsv
head crisprcas9/train
## 0.698979591836735 3 0 0 2 2 2 3 0 2 1 0 13 1 1 2 3 2 2 1 3
## 0.801237623762376 3 0 1 0 2 1 1 0 2 1 2 01 0 0 3 1 2 0 3 3
## 0.936394354999006 0 0 3 1 2 0 1 2 0 2 0 31 0 0 1 0 2 2 2 3
train.tsv と valid.tsv はタブ区切りのテキストファイルである。1 列目は、教師ラベルにあたる切断効率である。2 列目以降は、特徴量にあたる sgRNA の塩基配列の並びである。ただし、塩基 A, T, C, G は 0, 1, 2, 3 に変化して表示してある。
モデル構築(Pytorch / bidirectional LSTM)
Pytorch で bidirectional LSTM のモデルを構築する。まず、使用する Python のモジュールをすべて読み込む。
import os
import gzip
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
次にバッチ学習を行うための Dataset クラスを定義する。この Dataset クラス(readDataset
)を、train.tsv または valid.tsv を読み込んで、1 列目のデータを教師ラベルとして y
に保存し、2 列目のデータを特徴量として X
に保存するように設計してある。
class readDataset(torch.utils.data.Dataset):
def __init__(self, data_fpath):
self.X, self.y = self.__load_data(data_fpath)
def __load_data(self, data_fpath):
data = np.loadtxt(data_fpath, delimiter = '\t')
y = torch.from_numpy(data[:, 0]).float()
y = y.view(-1, 1)
X = torch.from_numpy(data[:, 1:]).long()
return X, y
def __len__(self):
return self.X.shape[0]
def __getitem__(self, i):
return self.X[i], self.y[i]
次に bidirectional LSTM のアーキテクチャを次のように設計する。基本的に LSTM のアーキテクチャとほぼ同じであるが、LSTM
のオプションでbidirectional=True
を設定する必要がある。
- 21 塩基(21 単語)の入力を受け取り、64 次元への単語埋め込みを行う。今回使ったデータには 0, 1, 2, 3 (A, T, C, G) の四つの単語だけからなるので、
Embedding
では単語数を4
と指定してある。 - 単語埋め込み後の 64 次元データを bidirectional LSTM に代入する。LSTM の隠れユニットを 256 個とする。PyTroch の
LSTM
は、各状態の出力と最後の状態(隠れ層の状態とセルの状態)を出力が、このうち、最後の隠れ層の状態hidden_state
を次の層に与える。また、bidirectional LSTM であるため、前方向と逆方向の出力があるため、LSTM の 2 倍の出力がある。 - LSTM の出力を全結合層に代入する。全結合層のユニット数を 512 個とする。このような全結合層を 2 層用意する。
- 全結合層の出力を、入力として受け取り、1 つの値を出力する出力層を設ける。
class GenomicBiLSTM(torch.nn.Module):
def __init__(self, embedding_dim=128, lstm_dim=256, fc_dim=512):
super(GenomicBiLSTM, self).__init__()
# embedding
self.embedding = torch.nn.Embedding(4, embedding_dim)
# LSTM
self.lstm = torch.nn.LSTM(embedding_dim, lstm_dim, batch_first=True,
bidirectional=True)
for name, param in self.lstm.named_parameters():
if 'bias' in name:
torch.nn.init.constant_(param, 0.0)
elif 'weight_ih' in name:
torch.nn.init.kaiming_normal_(param)
elif 'weight_hh' in name:
torch.nn.init.orthogonal_(param)
# FC
self.fc_dropout = torch.nn.Dropout(0.5)
self.fc1 = torch.nn.Linear(lstm_dim * 2, fc_dim)
self.fc2 = torch.nn.Linear(fc_dim, fc_dim)
# Output
self.fc3 = torch.nn.Linear(fc_dim, 1)
def forward(self, x):
embeds = self.embedding(x)
lstm_output, (hidden_state, cell_state) = self.lstm(embeds)
y = torch.cat([hidden_state[0], hidden_state[1]], dim=1)
y = torch.nn.functional.relu(self.fc1(y))
y = self.fc_dropout(y)
y = torch.nn.functional.relu(self.fc2(y))
y = self.fc_dropout(y)
y = self.fc3(y)
return y
バッチサイズなどを決め、訓練データと検証データを読み込む。
batch_size = 2048
train_dataset = 'crisprcas9/train.tsv'
valid_dataset = 'crisprcas9/valid.tsv'
datasets = {
'train': readDataset(train_dataset),
'valid': readDataset(valid_dataset)
}
dataloader = {
'train': torch.utils.data.DataLoader(datasets['train'],
batch_size=batch_size, shuffle=True),
'valid': torch.utils.data.DataLoader(datasets['valid'],
batch_size=batch_size, shuffle=False)
}
dataset_sizes = {
'train': len(datasets['train']),
'valid': len(datasets['valid']),
}
モデルの設計図からモデルの実体を作成し、訓練と検証を交互に 40 エポックを行う。この予測モデルは切断効率を予測するモデルであるから、損失関数として平均二乗誤差を用いる。
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = GenomicBiLSTM()
model.to(device)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
epoch_loss = {'train': [], 'valid': []}
epochs = 40
for epoch in range(epochs):
print('#epoch {}/{}'.format(epoch + 1, epochs))
for phase in ['train', 'valid']:
if phase == 'train':
model.train()
else:
model.eval()
running_loss = 0
# mini-batch
for inputs, labels in dataloader[phase]:
inputs = inputs.to(device)
labels = labels.to(device)
with torch.set_grad_enabled(phase == 'train'):
outputs = model(inputs)
loss = criterion(outputs, labels)
if phase == 'train':
optimizer.zero_grad()
loss.backward()
optimizer.step()
running_loss += loss.item() * inputs.size(0)
if phase == 'train':
scheduler.step()
print('loss: {}'.format(running_loss / dataset_sizes[phase]))
epoch_loss[phase].append(running_loss / dataset_sizes[phase])
## #epoch 1/40
## loss: 0.10379909208416939
## loss: 0.03999912516077359
## #epoch 2/40
## loss: 0.04331813626885414
## loss: 0.03559660889705022
## ...
## #epoch 39/40
## loss: 0.01410519671589136
## loss: 0.013499775982399782
## #epoch 40/40
## loss: 0.014087982614338398
## loss: 0.01331430471688509
最後にもう一度検証を行って、そのときの実験値と予測値の散布図を図示する。モデルの汎化性能を正確に評価するのが目的であれば、この部分は(訓練データ・検証データとは異なる)新たに実験を行って集めたテストデータを使うべきである。ここでは、便宜上、検証データを使用している。
test_dataset = 'crisprcas9/valid.tsv'
model.eval()
datasets = readDataset(test_dataset)
dataloader = torch.utils.data.DataLoader(datasets, batch_size=batch_size, shuffle=False)
pred = None
true = None
for inputs, labels in dataloader:
inputs = inputs.to(device)
with torch.set_grad_enabled(False):
outputs = model(inputs).data.cpu().numpy()
if pred is None:
pred = outputs
true = labels
else:
pred = np.concatenate([pred, outputs], 0)
true = np.concatenate([true, labels], 0)
df = pd.DataFrame({'pred': pred.reshape(-1), 'true': true.reshape(-1)})
fig = plt.figure(figsize=(6, 6), dpi=220)
ax = fig.add_subplot()
ax.scatter(df.iloc[:, 0], df.iloc[:, 1])
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.axis('square')
fig.show()
この散布図から、モデルの学習が正しく行われていることがわかる。ここからは、単語埋め込みの次元数、bidirectional LSTM の隠れ層のユニット数や全結合層の数やユニット数を調整して、モデルを改善していく必要がある。
References
- Optimized CRISPR Guide RNA Design for Two High-Fidelity Cas9 Variants by Deep Learning. Nat Commun. 2019, 10(1):4284. DOI: 10.1038/s41467-019-12281-8