再帰型ニューラルネットワークによる系列データのモデリング

RNN

系列データには連続した値で記録されたデータである。時系列データのほかに、医療機器の診察データや音声データなどがあげられる。このような系列データをコンピュータ上で扱うには、次のようなことを考慮する必要がある。深層学習の分野において、系列データを、再帰型ニューラルネットワーク(Recurrent Neural Network; RNN)と呼ばれる手法で取り扱うことがある。

  • 系列データを任意の長さで処理できること
  • 系列データの順序を維持できること
  • 系列データを長期間の記憶を保持できること
  • 系列データモデルのパラメーターを全期間を通して共有できること

再帰型ニューラルネットワークの構造

RNN のモデル構造は一つのユニットからなる。このユニットは、入力と出力を行うとともに、状態を保持する。このユニットは、入力 xt を受け取ると、現在の状態 st をもとに、st+1 = g(Wxt + Ust) のように次の状態 st+1 を出力する。このとき、入力 xt と状態 st にはそれぞれウェイト W と U がかけられる。また、g は活性化関数で、tanh または ReLU 関数などが使われている。例えば t = 1 のときは次のように図示できる。

RNN モデルの構造図。状態 t=1 のときの入出力。

また、t = 2 のときは次のように図示できる。このとき、W と U は、t =1 のときのものと同じ出る。実際に、時点(状態) t の値に関わらず、W と U は常に同一ものが使われている。

RNN モデルの構造図。状態 t=1 のときの入出力。

このように、RNN では、1 つのユニットからなり、現在の状態と入力した値で、出力を決めている。入力が複数あるとき、例えば (x1, x2, x3) のときは、まず x1 と s1 で s2 を計算し、続いて、x2 と s2 で s3 を計算し、最後に、x3 と s3 で s4 を計算して出力を行う。s2, s3, s4 を計算する際に使われるウェイト W と U はすべて同一のものである。

RNN モデルの構造図。RNN を使用した予測モデルの例。

上図は、各状態ごとにユニットの入出力を並べたものである。上図を以下のように、各状態ごとにユニットを繋げたような形で図示することもできる。このとき、各ユニット(緑色の円)はすべて同じものであることに注意する必要がある(ただし、状態だけが異なる)。

RNN モデルの構造図。RNN を使用した予測モデルの例。

再帰型ニューラルネットワークの学習

RNN が各入力ごとになんらかの値が出力されるものとする。例えば、xt と入力すると、yt^ が出力される。そして、各 xt のときのなんらかの真の観測量である yt も存在すると、yt^ と yt の差を計算できる。このように yt^ と yt の差を損失関数として定義すると、学習を通して、損失関数が最小となるように RNN モデルのウェイトが計算される。

RNN モデルの学習例。

仮に、yt^ に対する誤差を Jt = (yt - yt^)2 と定義すると、系列全体の損失関数 J は次のように、各時点の損失関数 Jt の和として考えることができる。

\[ J(W) = \sum_{t}J_{t}(W) = \sum_{t}\left( y_{t} - \hat{y}_{t} \right)^{2} \]

このように損失関数が定義されれば、学習を通してウェイト W を計算できるようになる。実際の計算は逆誤差伝播法で行う。

再帰型ニューラルネットワークの問題点

逆誤差伝播法を行うために、J(W) に対して W で偏微分すると、次の式が得られる。

\[ \frac{\partial J(W)}{\partial W} = \sum_{t}\frac{\partial J_{t}(W)}{\partial W} \]

ここで、例えば t = 2 のときの偏微分を計算してみる。

RNN モデルの学習例。時点 t の偏微分は、t よりも古いデータに依存する。
\[ \frac{\partial J_{2}(W)}{\partial W} = \frac{\partial J_{2}(W)}{\partial y_{2}} \frac{\partial y_{2}}{\partial s_{2}} \frac{\partial s_{2}}{\partial W} \]

最後の項 \( \frac{\partial s_{2}}{\partial W} \) に着目すると、s2 は s2 = g(Wx1 + Us1) から計算され、また、s1 は s1 = g(Wx0 + Us0) から計算されているので、\( \frac{\partial s_{2}}{\partial W} \) は定数ではないことが明らかである。そのため \( \frac{\partial s_{2}}{\partial W} \) は正確には次のように計算される必要がある。

\[ \frac{\partial s_{2}}{\partial W} + \frac{\partial s_{2}}{\partial s_{1}} \frac{\partial s_{1}}{\partial W} + \frac{\partial s_{2}}{\partial s_{0}} \frac{\partial s_{0}}{\partial W} \]

よって、損失関数 J に対する偏微分は、t = n のときは、次のように一般化することができる。

\[ \frac{\partial J_{n}(W)}{\partial W} = \frac{\partial J_{n}(W)}{\partial y_{n}} \frac{\partial y_{n}}{\partial s_{n}} \left( \sum_{t=0}^{n} \frac{\partial s_{n}}{\partial s_{t}} \frac{\partial s_{t}}{\partial W} \right) \]

ここで

\[ \frac{\partial s_{n}}{\partial s_{t}} = \frac{\partial s_{n}}{\partial s_{n-1}} \frac{\partial s_{n-1}}{\partial s_{n-2}} \cdots \frac{\partial s_{3}}{\partial s_{2}} \frac{\partial s_{2}}{\partial s_{1}} \frac{\partial s_{1}}{\partial s_{0}} \]

であることと、\( \frac{\partial s_{n}}{\partial s_{n-1}} = W^{T} diag [g'(W_{s_{j-1}} + Ux_{j})] \) であり、W の要素がほとんど 1 未満でかつ g' < 1 (g = tanh) であることに着目すると、\( \frac{\partial s_{n}}{\partial s_{t}} \) は非常に小さな値が多くかけ合わさった値となる。そのため、n が大きくなればなるほど、古い系列データ(e.g. t = 0, 1, 2)の影響が小さくなる。学習全体として、最近の系列データにバイアスが大きくかかるような形で行われる。つまり、RNN の場合は、あまり古い系列データを保持できないという欠点がある。

これらの欠点を対処するためには、活性化関数を tanh から ReLU 関数に変更したり、W の初期値を単位行列にしたりすることがあげられる。また、RNN よりも複雑なモデルである LSTM、GRU などのモデルを使う方法もある。

References

  • Suresh H. MIT 6.S191: Sequence Modeling with Neural Networks. 2018. YouTube