目次
PyTorch
基本
ndarrayからTensor
tensor = torch.from_numpy(arr.transpose(2,0,1)[np.newaxis,:,:,:])
Tensorからndarray
arr = tensor.numpy()
ModuleList U-Net
フィボナッチをワンライナーでやる系のやつを使う
seq | Sequential |
mods | ModuleList [seq,seq,seq,…] |
x | input tensor |
y | output tensors |
y = [(l.append(f(l[-1])),l[-1])[1] for l in [[x]] for f in mods]
変形
torch.nn.functional.affine_grid
| input | ||
|---|---|---|
| theta | (N,2,3) | 逆変換行列 |
| size | size | 出力特徴マップのサイズ |
| output | ||
| grid | (N,H,W,2_xy) | グリッドサンプラー |
座標系におけるアフィン変換の逆変換行列を与えると良い.
base_gridが格子点で生成され,この逆変換行列で座標変換されたものが返されるのだろう.
torch.nn.functional.grid_sample
入力画像をx[-1,+1],y[-1,+1]のキャンバスに置いて grid で示された座標の値を得る. この座標系では画像右方向にx軸,下方向にy軸がある.
Thin-Plate Splineを自前で実装する場合
Thin-Plate Splineのnn.Moduleを自前で実装するなら, コンストラクタに変換後コントロールポイント座標と特徴マップのサイズ(H,W)を与え, インスタンス生成時に $$ \left(\begin{array}{c|c} {\bf R_{\rm cpt}} & {\bf H_{\rm cpt}}\\ \hline {\bf H_{\rm cpt}}^\top & {\bf O} \end{array}\right)^{-1} $$ を作成する. 予めベースサンプラーとして $$\left(\begin{array}{c|c} {\bf R_{\rm gpt}} & {\bf H_{\rm gpt}} \end{array}\right)$$を用意しておく.
ネットワーク中で変換前コントロールポイント座標が与えられたら $$\left(\begin{array}{c} {\bf P'_{\rm cps}}\\ \hline {\bf O} \end{array}\right)$$を作り, $$ \left(\begin{array}{c} {\bf W}\\ \hline {\bf A} \end{array}\right) = \left(\begin{array}{c|c} {\bf R_{\rm cpt}} & {\bf H_{\rm cpt}}\\ \hline {\bf H_{\rm cpt}}^\top & {\bf O} \end{array}\right)^{-1} \left(\begin{array}{c} {\bf P'_{\rm cps}}\\ \hline {\bf O} \end{array}\right) $$ で変換後コントロールポイントから変換前コントロールポイントへ変換する変換行列が計算できる. これによって $$ \left(\begin{array}{c|c} {\bf R_{\rm gpt}} & {\bf H_{\rm gpt}} \end{array}\right) \left(\begin{array}{c}{\bf W}\\ \hline {\bf A} \end{array}\right) = {\bf P'_{\rm gps}} $$ でサンプラーが作成出来る.
| constructor | ||
|---|---|---|
| size | N,H,Wの情報だけあればいい | |
| cpt | (N,T,2_xy) | 目標コントロールポイント座標 |
| forward input | ||
| cps | (N,T,2_xy) | コントロールポイント元座標 |
| forward output | ||
| grid | (N,H,W,2_xy) | グリッド |
| instance public | forward |
| class private | calc_R_mat |
align_cornersについて
TODO
grid について
ダサい?
y_map = torch.zeros((1,W,1),device=device) \ + torch.arange(start=0,end=H,dtype=dtype,device=device).view(H,1,1) x_map = torch.zeros((H,1,1),device=device) \ + torch.arange(start=0,end=W,dtype=dtype,device=device).view(1,W,1) coord_map = torch.cat([x_map,y_map],2)
torch.stack( torch.meshgrid( torch.arange(W, dtype=dtype, device=device), torch.arange(H, dtype=dtype, device=device), indexing='xy'), -1 )
今度、速度比較したい。
nn.Module
register_bufferやregister_parameterはstate_dict()で呼び出せるようにする.
register_bufferはmodel.parameters()で呼び出されないのでoptimizerによる更新がない.
register_parameterは呼び出されるので更新がある.
print関数の調整
- precision : 桁数
- threshold : 省略する個数の閾値
- edgeitems : 真ん中省略するときの最初と最後の個数
- linewidth : 一行の最大文字数
- profile : 'default', 'short', 'full'
- sci_mode : 'True' or 'False'
torchvision
vgg, inception, resnet などの入力は 3チャンネル、range 0-1
backward を自前で定義
既製のものの動作を確認する
あるモジュールがあり、loss = mod(v1,v2,…) とする。
準備として、全ての始点変数(v1,v2,..のこと)を .requires_grad() しておく。
また、途中計算のすべての変数を .retain_grad() しておく。
こうすると、
逆伝播を loss.backward(create_graph=True) で実行すると、
すべての変数に対して .grad.data で中身が見れるようになる。
conv の逆伝播を自前で作る
# x (N,iC,H,W) # w (oC,iC,kH,kW) # b (oC,) y = F.conv2d(x,w,b) dLdy dLdb = dLdy.sum(dim=(0,2,3)) # (N,oC,H,W) -> (oC,) dLdx = F.conv_transpose2d( dLdy, w, padding=padding) dLdw = torch.flip( F.conv2d( dLdy.permute(1,0,2,3), x.permute(1,0,2,3), padding=padding ), (2,3)) # (oC,N,H,W) conv (iC,N,H,W) -> (oC,iC,kH,kW) # flip is [:,:,::-1,::-1]
自作関数
torch.autograd.Function を継承したクラスを作り、
中に staticmethod として forward() と backward() を
持っておくと動く。
forward(), backward() ともに第一引数は ctx とする。
この ctx は pytorch の特別なオブジェクト。
backward() の戻り値は forward() の引数に対する勾配。
backward() の引数は forward() の戻り値に対する勾配。
このクラスの apply というメンバが、ユーザーが実際に使う関数となる。
カスタム
pytorchコードリーディング
- pytorch/
- torch/
- csrc/ : C++とかで書かれたコード?
| aten | PyTorch用のテンソル計算ライブラリ |
ビルド
PyTorch1.5からビルドシステムにちょっと変更が入った?
マルチスレッドな Ninja に対応したことで MAX_JOBS=2 とか付けて setup.py するか、
BuildExtension.with_options(use_ninja=False) にするかしないと以前のシステムは動かないようになった。
未整理メモ
nn.Module やそれに所属するメソッドをprintしようとしたときの挙動
クラス内のメソッド (クラスメソッド、インスタンスメソッド両方) を print() すると 以下のように描画される。
<bound method メソッド名 of インスタンスをprint()したもの >
インスタンスのクラスなどに __str__ や __repr__ が設定されていたらそれが表示されてしまう。
nn.Module は __str__ と __repr__ がともに書き換えられているので、
nn.Module またはそれから継承したクラスのメソッドを print() すると
<bound method メソッド名 of クラス名( クラスに存在している Module を print() したもの )>
ちなみにクラスに対してあとからメソッドを追加するときは
hoge.fuga = myfunc ではなく hoge.__class__.fuga = myfunc とする。
前者は function となり後者は method となる。
