CRISPR/Cas9 はゲノム編集技術の一つである。例えば、ある遺伝子を破壊したければ、その遺伝子配列の重要な領域に変異を導入し、その遺伝子を働けなくすればよい。CRIPSR/Cas9 は、このような変異導入に使われる。CRIPSR/Cas9 によるゲノム編集は sgRNA と Cas9 の二つの酵素によって行われる。sgRNA の 5' 末端の 21 塩基は、DNA と相補的に結合する作用をもち、それ以外の部分は sgRNA の基本的な機能を担っている。DNA と相補的に結合した sgRNA は Cas9 との相互作用で、DNA を切断する。切断された DNA はその後修復されるが、細胞周期に依存しない修復であるため、修復過程中に欠損や挿入が頻繁に起こる。したがって、DNA さえ切断できれば、切断箇所で高い頻度で変異が入る。
CRISPR/Cas9 を利用したゲノム編集(変異導入)では、sgRNA と DNA が相補的に結合している部分で行われる。すなわち、sgRNA の 5' 末端にある 21 塩基を正確に設計すれば、破壊したい標的遺伝子のある領域をピンポイントで切断できる。しかし、sgRNA がその遺伝子に相補的に結合したとしても、100% 切断行われると限らない。そのため、一つの遺伝子を破壊したい時、なるべくその遺伝子の切断効率の高いところに sgRNA を結合させたい。そのためには、sgRNA が DNA と相補的に結合する 5' 末端にある 21 塩基をしっかり設計する必要がある。
ここでは、sgRNA の 5' 末端の 21 塩基の並びと切断効率を調べた実験データを使って、21 塩基を入力して切断効率を予測するモデルを構築する。塩基一つ一つを英単語とみなせて、21 塩基の並びを一文とみなせば、自然言語処理のタスクとみなせる。そこで、このページでは、LSTM を使って切断効率を予測するモデルを構築する例を示す。
データセット
Wang らの論文には数万件の実験データがあり、公開されている (Wang et al., 2019)。GitHub 上にはデータとともに予測モデルのソースコード(Keras)が公開されている。ここで、少し整形したデータセットを用いてモデルの構築を行う。まず、データをダウンロードして、展開する。このデータを展開すると train.tsv と valid.tsv の 2 つのデータファイルがみられる。それぞれ訓練データと検証データとして使用する。
wget https://aabbdd.jp/notes/data/crisprcas9.tar.gz
tar xzvf crisprcas9.tar.gz
ls crisprcas9
## README.md train.tsv valid.tsv
head crisprcas9/train
## 0.698979591836735 3 0 0 2 2 2 3 0 2 1 0 13 1 1 2 3 2 2 1 3
## 0.801237623762376 3 0 1 0 2 1 1 0 2 1 2 01 0 0 3 1 2 0 3 3
## 0.936394354999006 0 0 3 1 2 0 1 2 0 2 0 31 0 0 1 0 2 2 2 3
train.tsv と valid.tsv はタブ区切りのテキストファイルである。1 列目は、教師ラベルにあたる切断効率である。2 列目以降は、特徴量にあたる sgRNA の塩基配列の並びである。ただし、塩基 A, T, C, G は 0, 1, 2, 3 に変化して表示してある。
モデル構築(Pytorch / LSTM)
Pytorch で LSTM のモデルを構築する。まず、使用する Python のモジュールをすべて読み込む。
import os
import gzip
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
次にバッチ学習を行うための Dataset クラスを定義する。この Dataset クラス(readDataset
)を、train.tsv または valid.tsv を読み込んで、1 列目のデータを教師ラベルとして y
に保存し、2 列目のデータを特徴量として X
に保存するように設計してある。
class readDataset(torch.utils.data.Dataset):
def __init__(self, data_fpath):
self.X, self.y = self.__load_data(data_fpath)
def __load_data(self, data_fpath):
data = np.loadtxt(data_fpath, delimiter = '\t')
y = torch.from_numpy(data[:, 0]).float()
y = y.view(-1, 1)
X = torch.from_numpy(data[:, 1:]).long()
return X, y
def __len__(self):
return self.X.shape[0]
def __getitem__(self, i):
return self.X[i], self.y[i]
次に LSTM のアーキテクチャを次のように設計する。
- 21 塩基(21 単語)の入力を受け取り、64 次元への単語埋め込みを行う。今回使ったデータには 0, 1, 2, 3 (A, T, C, G) の四つの単語だけからなるので、
Embedding
では単語数を4
と指定してある。 - 単語埋め込み後の 64 次元データを LSTM に代入する。LSTM の隠れユニットを 256 個とする。PyTroch の
LSTM
は、各状態の出力と最後の状態(隠れ層の状態とセルの状態)を出力が、このうち、最後の隠れ層の状態hidden_state
を次の層に与える。 - LSTM の出力を全結合層に代入する。全結合層のユニット数を 512 個とする。このような全結合層を 2 層用意する。
- 全結合層の出力を、入力として受け取り、1 つの値を出力する出力層を設ける。
class GenomicLSTM(torch.nn.Module):
def __init__(self, embedding_dim=64, lstm_dim=256, fc_dim=512):
super(GenomicLSTM, self).__init__()
# embedding
self.embedding = torch.nn.Embedding(4, embedding_dim)
# LSTM
self.lstm = torch.nn.LSTM(embedding_dim, lstm_dim, batch_first=True)
# FC
self.fc_dropout = torch.nn.Dropout(0.5)
self.fc1 = torch.nn.Linear(lstm_dim, fc_dim)
self.fc2 = torch.nn.Linear(fc_dim, fc_dim)
# Dense
self.fc3 = torch.nn.Linear(fc_dim, 1)
def forward(self, x):
embeds = self.embedding(x)
# LSTM
lstm_output, (hidden_state, cell_state) = self.lstm(embeds)
y = hidden_state.view(x.size(0), -1)
# FC
y = torch.nn.functional.relu(self.fc1(y))
y = self.fc_dropout(y)
y = torch.nn.functional.relu(self.fc2(y))
y = self.fc_dropout(y)
y = self.fc3(y)
return y
バッチサイズなどを決め、訓練データと検証データを読み込む。
batch_size = 2048
train_dataset = 'crisprcas9/train.tsv'
valid_dataset = 'crisprcas9/valid.tsv'
datasets = {
'train': readDataset(train_dataset),
'valid': readDataset(valid_dataset)
}
dataloader = {
'train': torch.utils.data.DataLoader(datasets['train'],
batch_size=batch_size, shuffle=True),
'valid': torch.utils.data.DataLoader(datasets['valid'],
batch_size=batch_size, shuffle=False)
}
dataset_sizes = {
'train': len(datasets['train']),
'valid': len(datasets['valid']),
}
モデルの設計図からモデルの実体を作成し、訓練と検証を交互に 30 エポックを行う。この予測モデルは切断効率を予測するモデルであるから、損失関数として平均二乗誤差を用いる。
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = GenomicLSTM()
model.to(device)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
epoch_loss = {'train': [], 'valid': []}
epochs = 30
for epoch in range(epochs):
print('#epoch {}/{}'.format(epoch + 1, epochs))
for phase in ['train', 'valid']:
if phase == 'train':
model.train()
else:
model.eval()
running_loss = 0
# mini-batch
for inputs, labels in dataloader[phase]:
inputs = inputs.to(device)
labels = labels.to(device)
with torch.set_grad_enabled(phase == 'train'):
outputs = model(inputs)
loss = criterion(outputs, labels)
if phase == 'train':
optimizer.zero_grad()
loss.backward()
optimizer.step()
running_loss += loss.item() * inputs.size(0)
if phase == 'train':
scheduler.step()
print('loss: {}'.format(running_loss / dataset_sizes[phase]))
epoch_loss[phase].append(running_loss / dataset_sizes[phase])
## #epoch 1/30
## loss: 0.09723963847756385
## loss: 0.046688280802965165
## #epoch 2/30
## loss: 0.04272238164544106
## loss: 0.033998911436398827
## ...
## #epoch 29/30
## loss: 0.016399667099118234
## loss: 0.013579995899399122
## #epoch 30/30
## loss: 0.016287285193800925
## loss: 0.01359133894642194
最後にもう一度検証を行って、そのときの実験値と予測値の散布図を図示する。モデルの汎化性能を正確に評価するのが目的であれば、この部分は(訓練データ・検証データとは異なる)新たに実験を行って集めたテストデータを使うべきである。ここでは、便宜上、検証データを使用している。
test_dataset = 'crisprcas9/valid.tsv'
model.eval()
datasets = readDataset(test_dataset)
dataloader = torch.utils.data.DataLoader(datasets, batch_size=batch_size, shuffle=False)
pred = None
true = None
for inputs, labels in dataloader:
inputs = inputs.to(device)
with torch.set_grad_enabled(False):
outputs = model(inputs).data.cpu().numpy()
if pred is None:
pred = outputs
true = labels
else:
pred = np.concatenate([pred, outputs], 0)
true = np.concatenate([true, labels], 0)
df = pd.DataFrame({'pred': pred.reshape(-1), 'true': true.reshape(-1)})
fig = plt.figure(figsize=(6, 6), dpi=220)
ax = fig.add_subplot()
ax.scatter(df.iloc[:, 0], df.iloc[:, 1])
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.axis('square')
fig.show()
この散布図から、モデルの学習が正しく行われていることがわかる。ここからは、単語埋め込みの次元数、LSTM の隠れ層のユニット数や全結合層の数やユニット数を調整して、モデルを改善していく必要がある。
References
- Optimized CRISPR Guide RNA Design for Two High-Fidelity Cas9 Variants by Deep Learning. Nat Commun. 2019, 10(1):4284. DOI: 10.1038/s41467-019-12281-8