任務(wù)要求:
自定義一個層主要是定義該層的實現(xiàn)函數(shù),只需要重載Function的forward和backward函數(shù)即可,如下:
1
2
3
|
import torch from torch.autograd import Function from torch.autograd import Variable |
定義二值化函數(shù)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
|
class BinarizedF(Function): def forward( self , input ): self .save_for_backward( input ) a = torch.ones_like( input ) b = - torch.ones_like( input ) output = torch.where( input > = 0 ,a,b) return output def backward( self , output_grad): input , = self .saved_tensors input_abs = torch. abs ( input ) ones = torch.ones_like( input ) zeros = torch.zeros_like( input ) input_grad = torch.where(input_abs< = 1 ,ones, zeros) return input_grad |
定義一個module
1
2
3
4
5
6
7
8
|
class BinarizedModule(nn.Module): def __init__( self ): super (BinarizedModule, self ).__init__() self .BF = BinarizedF() def forward( self , input ): print ( input .shape) output = self .BF( input ) return output |
進(jìn)行測試
1
2
3
4
5
|
a = Variable(torch.randn( 4 , 480 , 640 ), requires_grad = True ) output = BinarizedModule()(a) output.backward(torch.ones(a.size())) print (a) print (a.grad) |
其中, 二值化函數(shù)部分也可以按照方式寫,但是速度慢了0.05s
1
2
3
4
5
6
7
8
9
10
11
12
|
class BinarizedF(Function): def forward( self , input ): self .save_for_backward( input ) output = torch.ones_like( input ) output[ input < 0 ] = - 1 return output def backward( self , output_grad): input , = self .saved_tensors input_grad = output_grad.clone() input_abs = torch. abs ( input ) input_grad[input_abs> 1 ] = 0 return input_grad |
以上這篇pytorch自定義二值化網(wǎng)絡(luò)層方式就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持服務(wù)器之家。
原文鏈接:https://blog.csdn.net/weixin_42696356/article/details/100899711