yuuho.wiki

カオスの欠片を集めて知恵の泉を作る

ユーザ用ツール

サイト用ツール


tips:python:pytorch:start

差分

このページの2つのバージョン間の差分を表示します。

この比較画面へのリンク

両方とも前のリビジョン前のリビジョン
次のリビジョン
前のリビジョン
tips:python:pytorch:start [2021/08/26 06:26] – [print関数の調整] yuuhotips:python:pytorch:start [2022/01/27 03:24] (現在) – [未整理メモ] yuuho
行 83: 行 83:
  
 TODO TODO
 +
 +==== grid について ====
 +ダサい?
 +<code python>
 +        y_map = torch.zeros((1,W,1),device=device) \
 +                    + torch.arange(start=0,end=H,dtype=dtype,device=device).view(H,1,1)
 +        x_map = torch.zeros((H,1,1),device=device) \
 +                    + torch.arange(start=0,end=W,dtype=dtype,device=device).view(1,W,1)
 +        coord_map = torch.cat([x_map,y_map],2)
 +</code>
 +
 +<code python>
 +torch.stack( torch.meshgrid(
 +    torch.arange(W, dtype=dtype, device=device),
 +    torch.arange(H, dtype=dtype, device=device),
 +                             indexing='xy'), -1 )
 +</code>
 +
 +今度、速度比較したい。
 ==== nn.Module ==== ==== nn.Module ====
  
行 96: 行 115:
     * threshold : 省略する個数の閾値     * threshold : 省略する個数の閾値
     * edgeitems : 真ん中省略するときの最初と最後の個数     * edgeitems : 真ん中省略するときの最初と最後の個数
 +    * linewidth : 一行の最大文字数
 +    * profile   : 'default', 'short', 'full'
 +    * sci_mode  : 'True' or 'False'
 +
  
  
行 101: 行 124:
  
 vgg, inception, resnet などの入力は 3チャンネル、range 0-1 vgg, inception, resnet などの入力は 3チャンネル、range 0-1
 +===== backward を自前で定義 =====
 +
 +==== 既製のものの動作を確認する ====
 +
 +あるモジュールがあり、''loss = mod(v1,v2,...)'' とする。\\
 +準備として、全ての始点変数(''v1,v2,..''のこと)を ''.requires_grad()'' しておく。\\
 +また、途中計算のすべての変数を ''.retain_grad()'' しておく。\\
 +こうすると、
 +逆伝播を ''loss.backward(create_graph=True)'' で実行すると、\\
 +すべての変数に対して ''.grad.data'' で中身が見れるようになる。
 +
 +
 +=== conv の逆伝播を自前で作る ===
 +
 +<code python>
 +# x (N,iC,H,W)
 +# w (oC,iC,kH,kW)
 +# b (oC,)
 +y = F.conv2d(x,w,b)
 +
 +dLdy
 +dLdb = dLdy.sum(dim=(0,2,3))                            # (N,oC,H,W) -> (oC,)
 +dLdx = F.conv_transpose2d( dLdy, w, padding=padding)
 +dLdw = torch.flip( F.conv2d( dLdy.permute(1,0,2,3),
 +                                x.permute(1,0,2,3), padding=padding ), (2,3))
 +        # (oC,N,H,W) conv (iC,N,H,W) -> (oC,iC,kH,kW)
 +        # flip is [:,:,::-1,::-1]
 +</code>
 +
 +
 +==== 自作関数 ====
 +
 +''torch.autograd.Function'' を継承したクラスを作り、\\
 +中に staticmethod として ''forward()'' と ''backward()''
 +持っておくと動く。
 +
 +''forward()'', ''backward()'' ともに第一引数は ''ctx'' とする。\\
 +この ''ctx'' は pytorch の特別なオブジェクト。
 +
 +''backward()'' の戻り値は ''forward()'' の引数に対する勾配。\\
 +''backward()'' の引数は ''forward()'' の戻り値に対する勾配。
 +
 +このクラスの ''apply'' というメンバが、ユーザーが実際に使う関数となる。
 ===== カスタム ===== ===== カスタム =====
  
行 122: 行 188:
  
  
 +
 +===== 未整理メモ =====
 +
 +=== nn.Module やそれに所属するメソッドをprintしようとしたときの挙動 ===
 +
 +クラス内のメソッド (クラスメソッド、インスタンスメソッド両方) を print() すると
 +以下のように描画される。
 +
 +<code><bound method メソッド名 of インスタンスをprint()したもの ></code>
 +
 +インスタンスのクラスなどに ''%%_%%_str_%%_%%'' や ''%%_%%_repr_%%_%%'' が設定されていたらそれが表示されてしまう。
 +
 +
 +nn.Module は ''%%_%%_str_%%_%%'' と ''_%%_%%repr%%_%%_'' がともに書き換えられているので、
 +nn.Module またはそれから継承したクラスのメソッドを print() すると
 +
 +<code>
 +<bound method メソッド名 of クラス名(
 +  クラスに存在している Module を print() したもの
 +)>
 +</code>
 +
 +ちなみにクラスに対してあとからメソッドを追加するときは
 +''hoge.fuga = myfunc'' ではなく ''hoge.%%_%%_class%%_%%_.fuga = myfunc'' とする。
 +前者は function となり後者は method となる。
tips/python/pytorch/start.1629959210.txt.gz · 最終更新: 2021/08/26 06:26 by yuuho