yuuho.wiki

カオスの欠片を集めて知恵の泉を作る

ユーザ用ツール

サイト用ツール


tips:python:pytorch:start

差分

このページの2つのバージョン間の差分を表示します。

この比較画面へのリンク

両方とも前のリビジョン前のリビジョン
次のリビジョン
前のリビジョン
tips:python:pytorch:start [2021/08/28 14:08] – [カスタム] yuuhotips:python:pytorch:start [2022/01/27 03:24] (現在) – [未整理メモ] yuuho
行 83: 行 83:
  
 TODO TODO
 +
 +==== grid について ====
 +ダサい?
 +<code python>
 +        y_map = torch.zeros((1,W,1),device=device) \
 +                    + torch.arange(start=0,end=H,dtype=dtype,device=device).view(H,1,1)
 +        x_map = torch.zeros((H,1,1),device=device) \
 +                    + torch.arange(start=0,end=W,dtype=dtype,device=device).view(1,W,1)
 +        coord_map = torch.cat([x_map,y_map],2)
 +</code>
 +
 +<code python>
 +torch.stack( torch.meshgrid(
 +    torch.arange(W, dtype=dtype, device=device),
 +    torch.arange(H, dtype=dtype, device=device),
 +                             indexing='xy'), -1 )
 +</code>
 +
 +今度、速度比較したい。
 ==== nn.Module ==== ==== nn.Module ====
  
行 107: 行 126:
 ===== backward を自前で定義 ===== ===== backward を自前で定義 =====
  
-==== 有名どころ ==== +==== 既製のものの動作を確認する ==== 
-=== conv ===+ 
 +あるモジュールがあり、''loss = mod(v1,v2,...)'' とする。\\ 
 +準備として、全ての始点変数(''v1,v2,..''のこと)を ''.requires_grad()'' しておく。\\ 
 +また、途中計算のすべての変数を ''.retain_grad()'' しておく。\\ 
 +こうすると、 
 +逆伝播を ''loss.backward(create_graph=True)'' で実行すると、\\ 
 +すべての変数に対して ''.grad.data'' で中身が見れるようになる。 
 + 
 + 
 +=== conv の逆伝播を自前で作る ===
  
 <code python> <code python>
-y = conv(x,w,b)+# x (N,iC,H,W) 
 +# w (oC,iC,kH,kW) 
 +# b (oC,) 
 +y = F.conv2d(x,w,b)
  
 dLdy dLdy
行 122: 行 153:
 </code> </code>
  
 +
 +==== 自作関数 ====
 +
 +''torch.autograd.Function'' を継承したクラスを作り、\\
 +中に staticmethod として ''forward()'' と ''backward()''
 +持っておくと動く。
 +
 +''forward()'', ''backward()'' ともに第一引数は ''ctx'' とする。\\
 +この ''ctx'' は pytorch の特別なオブジェクト。
 +
 +''backward()'' の戻り値は ''forward()'' の引数に対する勾配。\\
 +''backward()'' の引数は ''forward()'' の戻り値に対する勾配。
 +
 +このクラスの ''apply'' というメンバが、ユーザーが実際に使う関数となる。
 ===== カスタム ===== ===== カスタム =====
  
行 143: 行 188:
  
  
 +
 +===== 未整理メモ =====
 +
 +=== nn.Module やそれに所属するメソッドをprintしようとしたときの挙動 ===
 +
 +クラス内のメソッド (クラスメソッド、インスタンスメソッド両方) を print() すると
 +以下のように描画される。
 +
 +<code><bound method メソッド名 of インスタンスをprint()したもの ></code>
 +
 +インスタンスのクラスなどに ''%%_%%_str_%%_%%'' や ''%%_%%_repr_%%_%%'' が設定されていたらそれが表示されてしまう。
 +
 +
 +nn.Module は ''%%_%%_str_%%_%%'' と ''_%%_%%repr%%_%%_'' がともに書き換えられているので、
 +nn.Module またはそれから継承したクラスのメソッドを print() すると
 +
 +<code>
 +<bound method メソッド名 of クラス名(
 +  クラスに存在している Module を print() したもの
 +)>
 +</code>
 +
 +ちなみにクラスに対してあとからメソッドを追加するときは
 +''hoge.fuga = myfunc'' ではなく ''hoge.%%_%%_class%%_%%_.fuga = myfunc'' とする。
 +前者は function となり後者は method となる。
tips/python/pytorch/start.1630159739.txt.gz · 最終更新: 2021/08/28 14:08 by yuuho