yuuho.wiki

カオスの欠片を集めて知恵の泉を作る

ユーザ用ツール

サイト用ツール


tips:python:pytorch:start

PyTorch

基本

ndarrayからTensor

tensor = torch.from_numpy(arr.transpose(2,0,1)[np.newaxis,:,:,:])

Tensorからndarray

arr = tensor.numpy()

ModuleList U-Net

フィボナッチをワンライナーでやる系のやつを使う

seq Sequential
mods ModuleList [seq,seq,seq,…]
x input tensor
y output tensors
y = [(l.append(f(l[-1])),l[-1])[1] for l in [[x]] for f in mods]

変形

torch.nn.functional.affine_grid

input
theta (N,2,3) 逆変換行列
size size 出力特徴マップのサイズ
output
grid (N,H,W,2_xy) グリッドサンプラー

座標系におけるアフィン変換の逆変換行列を与えると良い.

base_gridが格子点で生成され,この逆変換行列で座標変換されたものが返されるのだろう.

torch.nn.functional.grid_sample

入力画像をx[-1,+1],y[-1,+1]のキャンバスに置いて grid で示された座標の値を得る. この座標系では画像右方向にx軸,下方向にy軸がある.

Thin-Plate Splineを自前で実装する場合

Thin-Plate Splineのnn.Moduleを自前で実装するなら, コンストラクタに変換後コントロールポイント座標と特徴マップのサイズ(H,W)を与え, インスタンス生成時に $$ \left(\begin{array}{c|c} {\bf R_{\rm cpt}} & {\bf H_{\rm cpt}}\\ \hline {\bf H_{\rm cpt}}^\top & {\bf O} \end{array}\right)^{-1} $$ を作成する. 予めベースサンプラーとして $$\left(\begin{array}{c|c} {\bf R_{\rm gpt}} & {\bf H_{\rm gpt}} \end{array}\right)$$を用意しておく.

ネットワーク中で変換前コントロールポイント座標が与えられたら $$\left(\begin{array}{c} {\bf P'_{\rm cps}}\\ \hline {\bf O} \end{array}\right)$$を作り, $$ \left(\begin{array}{c} {\bf W}\\ \hline {\bf A} \end{array}\right) = \left(\begin{array}{c|c} {\bf R_{\rm cpt}} & {\bf H_{\rm cpt}}\\ \hline {\bf H_{\rm cpt}}^\top & {\bf O} \end{array}\right)^{-1} \left(\begin{array}{c} {\bf P'_{\rm cps}}\\ \hline {\bf O} \end{array}\right) $$ で変換後コントロールポイントから変換前コントロールポイントへ変換する変換行列が計算できる. これによって $$ \left(\begin{array}{c|c} {\bf R_{\rm gpt}} & {\bf H_{\rm gpt}} \end{array}\right) \left(\begin{array}{c}{\bf W}\\ \hline {\bf A} \end{array}\right) = {\bf P'_{\rm gps}} $$ でサンプラーが作成出来る.

constructor
size N,H,Wの情報だけあればいい
cpt (N,T,2_xy) 目標コントロールポイント座標
forward input
cps (N,T,2_xy) コントロールポイント元座標
forward output
grid (N,H,W,2_xy) グリッド
instance public forward
class private calc_R_mat

align_cornersについて

TODO

grid について

ダサい?

        y_map = torch.zeros((1,W,1),device=device) \
                    + torch.arange(start=0,end=H,dtype=dtype,device=device).view(H,1,1)
        x_map = torch.zeros((H,1,1),device=device) \
                    + torch.arange(start=0,end=W,dtype=dtype,device=device).view(1,W,1)
        coord_map = torch.cat([x_map,y_map],2)
torch.stack( torch.meshgrid(
    torch.arange(W, dtype=dtype, device=device),
    torch.arange(H, dtype=dtype, device=device),
                             indexing='xy'), -1 )

今度、速度比較したい。

nn.Module

register_bufferregister_parameterはstate_dict()で呼び出せるようにする. register_bufferはmodel.parameters()で呼び出されないのでoptimizerによる更新がない. register_parameterは呼び出されるので更新がある.

print関数の調整

  • precision : 桁数
  • threshold : 省略する個数の閾値
  • edgeitems : 真ん中省略するときの最初と最後の個数
  • linewidth : 一行の最大文字数
  • profile : 'default', 'short', 'full'
  • sci_mode : 'True' or 'False'

torchvision

vgg, inception, resnet などの入力は 3チャンネル、range 0-1

backward を自前で定義

既製のものの動作を確認する

あるモジュールがあり、loss = mod(v1,v2,…) とする。
準備として、全ての始点変数(v1,v2,..のこと)を .requires_grad() しておく。
また、途中計算のすべての変数を .retain_grad() しておく。
こうすると、 逆伝播を loss.backward(create_graph=True) で実行すると、
すべての変数に対して .grad.data で中身が見れるようになる。

conv の逆伝播を自前で作る

# x (N,iC,H,W)
# w (oC,iC,kH,kW)
# b (oC,)
y = F.conv2d(x,w,b)
 
dLdy
dLdb = dLdy.sum(dim=(0,2,3))                            # (N,oC,H,W) -> (oC,)
dLdx = F.conv_transpose2d( dLdy, w, padding=padding)
dLdw = torch.flip( F.conv2d( dLdy.permute(1,0,2,3),
                                x.permute(1,0,2,3), padding=padding ), (2,3))
        # (oC,N,H,W) conv (iC,N,H,W) -> (oC,iC,kH,kW)
        # flip is [:,:,::-1,::-1]

自作関数

torch.autograd.Function を継承したクラスを作り、
中に staticmethod として forward()backward() を 持っておくと動く。

forward(), backward() ともに第一引数は ctx とする。
この ctx は pytorch の特別なオブジェクト。

backward() の戻り値は forward() の引数に対する勾配。
backward() の引数は forward() の戻り値に対する勾配。

このクラスの apply というメンバが、ユーザーが実際に使う関数となる。

カスタム

pytorchコードリーディング

- pytorch/
    - torch/
        - csrc/    : C++とかで書かれたコード?
aten PyTorch用のテンソル計算ライブラリ

ビルド

PyTorch1.5からビルドシステムにちょっと変更が入った? マルチスレッドな Ninja に対応したことで MAX_JOBS=2 とか付けて setup.py するか、 BuildExtension.with_options(use_ninja=False) にするかしないと以前のシステムは動かないようになった。

未整理メモ

nn.Module やそれに所属するメソッドをprintしようとしたときの挙動

クラス内のメソッド (クラスメソッド、インスタンスメソッド両方) を print() すると 以下のように描画される。

<bound method メソッド名 of インスタンスをprint()したもの >

インスタンスのクラスなどに __str____repr__ が設定されていたらそれが表示されてしまう。

nn.Module は __str____repr__ がともに書き換えられているので、 nn.Module またはそれから継承したクラスのメソッドを print() すると

<bound method メソッド名 of クラス名(
  クラスに存在している Module を print() したもの
)>

ちなみにクラスに対してあとからメソッドを追加するときは hoge.fuga = myfunc ではなく hoge.__class__.fuga = myfunc とする。 前者は function となり後者は method となる。

tips/python/pytorch/start.txt · 最終更新: 2022/01/27 03:24 by yuuho