文書の過去の版を表示しています。
PyTorch
基本
ndarrayからTensor
tensor = torch.from_numpy(arr.transpose(2,0,1)[np.newaxis,:,:,:])
Tensorからndarray
arr = tensor.numpy()
ModuleList U-Net
フィボナッチをワンライナーでやる系のやつを使う
seq | Sequential |
mods | ModuleList [seq,seq,seq,…] |
x | input tensor |
y | output tensors |
y = [(l.append(f(l[-1])),l[-1])[1] for l in [[x]] for f in mods]
変形
torch.nn.functional.affine_grid
| input | ||
|---|---|---|
| theta | (N,2,3) | 逆変換行列 |
| size | size | 出力特徴マップのサイズ |
| output | ||
| grid | (N,H,W,2_xy) | グリッドサンプラー |
座標系におけるアフィン変換の逆変換行列を与えると良い.
base_gridが格子点で生成され,この逆変換行列で座標変換されたものが返されるのだろう.
torch.nn.functional.grid_sample
入力画像をx[-1,+1],y[-1,+1]のキャンバスに置いて grid で示された座標の値を得る. この座標系では画像右方向にx軸,下方向にy軸がある.
Thin-Plate Splineを自前で実装する場合
Thin-Plate Splineのnn.Moduleを自前で実装するなら, コンストラクタに変換後コントロールポイント座標と特徴マップのサイズ(H,W)を与え, インスタンス生成時に $$ \left(\begin{array}{c|c} {\bf R_{\rm cpt}} & {\bf H_{\rm cpt}}\\ \hline {\bf H_{\rm cpt}}^\top & {\bf O} \end{array}\right)^{-1} $$ を作成する. 予めベースサンプラーとして $$\left(\begin{array}{c|c} {\bf R_{\rm gpt}} & {\bf H_{\rm gpt}} \end{array}\right)$$を用意しておく.
ネットワーク中で変換前コントロールポイント座標が与えられたら $$\left(\begin{array}{c} {\bf P'_{\rm cps}}\\ \hline {\bf O} \end{array}\right)$$を作り, $$ \left(\begin{array}{c} {\bf W}\\ \hline {\bf A} \end{array}\right) = \left(\begin{array}{c|c} {\bf R_{\rm cpt}} & {\bf H_{\rm cpt}}\\ \hline {\bf H_{\rm cpt}}^\top & {\bf O} \end{array}\right)^{-1} \left(\begin{array}{c} {\bf P'_{\rm cps}}\\ \hline {\bf O} \end{array}\right) $$ で変換後コントロールポイントから変換前コントロールポイントへ変換する変換行列が計算できる. これによって $$ \left(\begin{array}{c|c} {\bf R_{\rm gpt}} & {\bf H_{\rm gpt}} \end{array}\right) \left(\begin{array}{c}{\bf W}\\ \hline {\bf A} \end{array}\right) = {\bf P'_{\rm gps}} $$ でサンプラーが作成出来る.
| constructor | ||
|---|---|---|
| size | N,H,Wの情報だけあればいい | |
| cpt | (N,T,2_xy) | 目標コントロールポイント座標 |
| forward input | ||
| cps | (N,T,2_xy) | コントロールポイント元座標 |
| forward output | ||
| grid | (N,H,W,2_xy) | グリッド |
| instance public | forward |
| class private | calc_R_mat |
align_cornersについて
TODO
nn.Module
register_bufferやregister_parameterはstate_dict()で呼び出せるようにする.
register_bufferはmodel.parameters()で呼び出されないのでoptimizerによる更新がない.
register_parameterは呼び出されるので更新がある.
カスタム
pytorchコードリーディング
- pytorch/
- torch/
- csrc/ : C++とかで書かれたコード?
| aten | PyTorch用のテンソル計算ライブラリ |
ビルド
PyTorch1.5からビルドシステムにちょっと変更が入った?
マルチスレッドな Ninja に対応したことで MAX_JOBS=2 とか付けて setup.py するか、
BuildExtension.with_options(use_ninja=False) にするかしないと以前のシステムは動かないようになった。
