PyTorch サンプルコード

PyTorch

PyTorch は深層学習を行うための Python ライブラリーの一つである。PyTorch の他にも tensorflow/Keras や Caffee といったライブラリーも存在するが、とりわけ近年 PyTorch のユーザー数が急速に伸びし、大きなシェアを占めるようになった。機械学習のトップカンファレンスでも PyTorch が tensorflow などに比べて優位である。

PyTorch は、define-by-run 形式で学習・推論を行うアプローチを取っている。define-by-run では、まずモデルのアーキテクチャを構築して、次にデータをそのアーキテクチャに流して計算グラフを構築する。続けて、その計算グラフを利用して誤差逆伝播により学習を進めることができる。define-by-run では、データを流すたびに計算グラフを構築するため、計算グラフの構築コストが高い。しかし、モデルのアーキテクチャを順伝播の形で定義することができるので、コードの可読性が高い。また、アーキテクチャを定義しているコード群の中に、Python のプログラムを埋め込むことができるため、非常にデバッグしやすいといったメリットがある。

PyTorch の使用例は PyTorch tutorials で多く紹介されている。例えば、物体分類、転移学習や物体検出などの例が掲載されている。PyTorch のチュートリアルにある使用例を真似て勉強していけば、PyTorch の基本的な使い方を身に付けることができる。このページでは、PyTorch のチュートリアルにある使用例をもう少し簡単に砕いた使用例など紹介する。

PyTorch 基本

ニューラルネットワーク

畳み込みニューラルネットワーク

再帰型ニューラルネットワーク