PyTorch/Tensor 型の自動微分 autograd について

Tensor 自動微分

PyTorch で Tensor を作るときに使用する torch.tensor 関数には requires_grad 引数が用意されている。デフォルトでは requires_grad=False となっているが、requires_grad=True を指定して Tensor を作成すると、順伝播の計算を行うと自動的にその微分値も計算される。

微分

いま、3 つの変数 x, w, b を考えて、この 3 つの変数からなる計算グラフ y = wx + b を考える。この計算グラフ中の w, x, b に値を代入して y を計算して、y を各変数で微分した値を求めてみる。ここで、計算しやすいように、仮に、x = 2.0, w = 3.0, b = 5.0 とする。このとき、y を x, w, b で微分した値は次のように計算される。

\[ \frac{dy}{dx} = w = 3.0 \] \[ \frac{dy}{dw} = x = 2.0 \] \[ \frac{dy}{db} = 1.0 \]

以下の項目で、PyTorch の Tensor を使ってこれらの微分を求める方法を説明する。

Tensor の自動微分

まず、y = wx + b となる計算グラフを生成するために、変数 w、x、b を Tensor として生成する。このとき、Tensor を requires_grad=True のオプションで生成する。Tensor を生成した直後では、w、x、b は互いに独立である。これら Tensor が計算グラフとして繋がっていなく、目的となる Tensor が存在しないので、微分できない。そのため、これらの Tensor の勾配属性(.grad)には何も保存されていないような状態である。

import torch

x = torch.tensor(2.0, requires_grad=True)
w = torch.tensor(3.0, requires_grad=True)
b = torch.tensor(1.0, requires_grad=True)

print(x.grad)
## None
print(w.grad)
## None
print(b.grad)
## None

次に、計算グラフ y = wx + b を構築するとともに変数を代入して実行する(define by run)。一度、計算グラフが構築されると、w、x、b が互いに繋がり、数値演算され、その結果が y に伝播されるようになる。そして、このグラフの構造や w、x、b の関係などがメモリ上に保存される。

y = w * x + b
print(y.grad_fn)
## <AddBackward0 object at 0x7f20a37583c8>

y = w * x + b を実行すると、wxb のデータが代入され、y が計算される。y を計算したとに、この値のもとで、y を w、x、b で微分した値を求めるには、計算グラフの終端 y から先端まで遡って反映していけば良い。情報を計算グラフの方向を逆に反映させるために、この計算グラフ y に対して .backward 関数を実行すればよい。

y.backward()

print(x.grad)
## tensor(3.)
print(w.grad)
## tensor(2.)
print(b.grad)
## tensor(1.)

自動微分オン/オフの切り替え

計算グラフを構築するときに、あとで逆伝播を行うかどうかを指定することができる。モデルの学習中では、順伝播を行った後に、誤差を逆伝播させる必要がある。そのため、モデルの学習時には、逆伝播可能な計算グラフを構築する必要がある。一方で、検証または推論するときには、逆伝播を行わない計算グラフを構築しても良い。PyTroch では、計算グラフを構築するときに、torch.set_grad_enabled 関数で容易に逆伝播可能あるいは不可能に切り替えることができる。

import torch

x = torch.tensor(2.0, requires_grad=True)
w = torch.tensor(3.0, requires_grad=True)
b = torch.tensor(1.0, requires_grad=True)


# training
with torch.set_grad_enabled(True):
  y = w * x + b
  print(y.requires_grad)
  ## True
  print(y.grad_fn)
  ## <AddBackward0 object at 0x7f8ee10712b0>


# validation
with torch.set_grad_enabled(False):
  y = w * x + b
  print(y.requires_grad)
  ## False
  print(y.grad_fn)
  ## None