PyTorch で Tensor を作るときに使用する torch.tensor
関数には requires_grad
引数が用意されている。requires_grad=True
を指定して Tensor を作成すると、順伝播の計算を行うと自動的にその微分値も計算される。
Tensor の自動微分のオン・オフついては、次のように使い分ける。ネットワークを生成して、学習を通してパラメーターを更新する際に requires_grad=True
として自動微分可の形にする。学習を終えて、検証あるいはテストする際に、パラメーターを凍結させたいときに requires_grad=False
として自動微分不可のに切り替える。
微分
いま、3 つの変数 x, w, b を考えて、この 3 つの変数からなる計算グラフ y = wx + b を考える。この計算グラフ中の w, x, b に値を代入して y を計算して、y を各変数で微分した値を求めてみる。ここで、計算しやすいように、仮に、x = 2.0, w = 3.0, b = 5.0 とする。このとき、y を x, w, b で微分した値は次のように計算される。
\[ \frac{dy}{dx} = w = 3.0 \] \[ \frac{dy}{dw} = x = 2.0 \] \[ \frac{dy}{db} = 1.0 \]以下の項目で、PyTorch の Tensor を使ってこれらの微分を求める方法を説明する。
Tensor の自動微分
まず、y = wx + b となる計算グラフを生成するために、変数 w、x、b を Tensor として生成する。このとき、Tensor を requires_grad=True
のオプションで生成する。Tensor を生成した直後では、w、x、b は互いに独立である。これら Tensor が計算グラフとして繋がっていなく、目的となる Tensor が存在しないので、微分できない。そのため、これらの Tensor の勾配属性(.grad
)には何も保存されていないような状態である。
import torch
x = torch.tensor(2.0, requires_grad=True)
w = torch.tensor(3.0, requires_grad=True)
b = torch.tensor(1.0, requires_grad=True)
print(x.grad)
## None
print(w.grad)
## None
print(b.grad)
## None
次に、計算グラフ y = wx + b を構築するとともに変数を代入して実行する(define by run)。一度、計算グラフが構築されると、w、x、b が互いに繋がり、数値演算され、その結果が y に伝播されるようになる。そして、このグラフの構造や w、x、b の関係などがメモリ上に保存される。
y = w * x + b
print(y.grad_fn)
## <AddBackward0 object at 0x7f20a37583c8>
y = w * x + b
を実行すると、w
、x
、b
のデータが代入され、y
が計算される。y
を計算したとに、この値のもとで、y を w、x、b で微分した値を求めるには、計算グラフの終端 y から先端まで遡って反映していけば良い。情報を計算グラフの方向を逆に反映させるために、この計算グラフ y
に対して .backward
関数を実行すればよい。
y.backward()
print(x.grad)
## tensor(3.)
print(w.grad)
## tensor(2.)
print(b.grad)
## tensor(1.)
自動微分オン/オフの切り替え
計算グラフを構築するときに、あとで逆伝播を行うかどうかを指定することができる。モデルの学習中では、順伝播を行った後に、誤差を逆伝播させる必要がある。そのため、モデルの学習時には、逆伝播可能な計算グラフを構築する必要がある。一方で、検証または推論するときには、逆伝播を行わない計算グラフを構築しても良い。PyTroch では、計算グラフを構築するときに、torch.set_grad_enabled
関数で容易に逆伝播可能あるいは不可能に切り替えることができる。
import torch
x = torch.tensor(2.0, requires_grad=True)
w = torch.tensor(3.0, requires_grad=True)
b = torch.tensor(1.0, requires_grad=True)
# training
with torch.set_grad_enabled(True):
y = w * x + b
print(y.requires_grad)
## True
print(y.grad_fn)
## <AddBackward0 object at 0x7f8ee10712b0>
# validation
with torch.set_grad_enabled(False):
y = w * x + b
print(y.requires_grad)
## False
print(y.grad_fn)
## None