Custom backward() functions in PyTorch
In image processing we sometimes run images through algorithms of our own, and these algorithms are usually built on NumPy, SciPy, or similar packages. So how do we add the gradient of such a custom algorithm to PyTorch's computation graph, so that loss.backward() can differentiate and optimize automatically? The code below demonstrates this:
```python
import torch
import numpy as np
from PIL import Image
from torch.autograd import gradcheck


class Bicubic(torch.autograd.Function):
    # Legacy-style Function: forward()/backward() are instance methods and the
    # state needed by backward() is stored on self.

    def basis_function(self, x, a=-1):
        # Cubic convolution kernel used to compute the interpolation weights.
        x_abs = np.abs(x)
        if 0 <= x_abs < 1:
            y = (a + 2) * np.power(x_abs, 3) - (a + 3) * np.power(x_abs, 2) + 1
        elif 1 < x_abs < 2:
            y = a * np.power(x_abs, 3) - 5 * a * np.power(x_abs, 2) + 8 * a * x_abs - 4 * a
        else:
            y = 0
        return y

    def bicubic_interpolate(self, data_in, scale=1 / 4, mode='edge'):
        # data_in = data_in.detach().numpy()
        self.grad = np.zeros(data_in.shape, dtype=np.float32)
        obj_shape = (int(data_in.shape[0] * scale), int(data_in.shape[1] * scale), data_in.shape[2])
        data_tmp = data_in.copy()
        data_obj = np.zeros(shape=obj_shape, dtype=np.float32)
        data_in = np.pad(data_in, pad_width=((2, 2), (2, 2), (0, 0)), mode=mode)
        print(data_tmp.shape)
        for axis0 in range(obj_shape[0]):
            f_0 = float(axis0) / scale - np.floor(axis0 / scale)
            int_0 = int(axis0 / scale) + 2
            axis0_weight = np.array(
                [[self.basis_function(1 + f_0), self.basis_function(f_0),
                  self.basis_function(1 - f_0), self.basis_function(2 - f_0)]])
            for axis1 in range(obj_shape[1]):
                f_1 = float(axis1) / scale - np.floor(axis1 / scale)
                int_1 = int(axis1 / scale) + 2
                axis1_weight = np.array(
                    [[self.basis_function(1 + f_1), self.basis_function(f_1),
                      self.basis_function(1 - f_1), self.basis_function(2 - f_1)]])
                nbr_pixel = np.zeros(shape=(obj_shape[2], 4, 4), dtype=np.float32)
                grad_point = np.matmul(np.transpose(axis0_weight, (1, 0)), axis1_weight)
                for i in range(4):
                    for j in range(4):
                        nbr_pixel[:, i, j] = data_in[int_0 + i - 1, int_1 + j - 1, :]
                        for ii in range(data_in.shape[2]):
                            # Remember each pixel's interpolation weight so backward() can reuse it.
                            self.grad[int_0 - 2 + i - 1, int_1 - 2 + j - 1, ii] = grad_point[i, j]
                tmp = np.matmul(axis0_weight, nbr_pixel)
                data_obj[axis0, axis1, :] = np.matmul(tmp, np.transpose(axis1_weight, (1, 0)))[:, 0, 0]
                # img = np.transpose(img[0, :, :, :], [1, 2, 0])
        return data_obj

    def forward(self, input):
        print(type(input))
        input_ = input.detach().numpy()
        output = self.bicubic_interpolate(input_)
        # return input.new(output)
        return torch.Tensor(output)

    def backward(self, grad_output):
        print(self.grad.shape, grad_output.shape)
        grad_output = grad_output.detach().numpy()
        grad_output_tmp = np.zeros(self.grad.shape, dtype=np.float32)
        for i in range(self.grad.shape[0]):
            for j in range(self.grad.shape[1]):
                # Upsample the incoming gradient back to the input resolution (scale = 1/4).
                grad_output_tmp[i, j, :] = grad_output[int(i / 4), int(j / 4), :]
        grad_input = grad_output_tmp * self.grad
        print(type(grad_input))
        # return grad_output.new(grad_input)
        return torch.Tensor(grad_input)


def bicubic(input):
    return Bicubic()(input)


def main():
    hr = Image.open('./baboon/baboon_hr.png').convert('L')
    hr = torch.Tensor(np.expand_dims(np.array(hr), axis=2))
    hr.requires_grad = True
    lr = bicubic(hr)
    print(lr.is_leaf)
    loss = torch.mean(lr)
    loss.backward()


if __name__ == '__main__':
    main()
```
To make automatic differentiation work, you must implement both forward() and backward().
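The listing above uses the legacy style, where forward() and backward() are instance methods and state lives on self. Newer PyTorch releases (roughly 1.3 onward) require both to be @staticmethods that receive a ctx object and are invoked through .apply(). Below is a minimal sketch of the same forward/backward pairing in that current API; ScaleByTwo is a made-up toy function used only for illustration, not part of the original article.

```python
import torch


class ScaleByTwo(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # State that backward() needs goes on ctx (the article's code stores it on self instead).
        ctx.save_for_backward(input)
        return input * 2

    @staticmethod
    def backward(ctx, grad_output):
        # d(2*x)/dx = 2, so the upstream gradient is simply scaled.
        return grad_output * 2


x = torch.randn(3, requires_grad=True)
y = ScaleByTwo.apply(x).sum()
y.backward()
print(x.grad)   # tensor([2., 2., 2.])
```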
1. As the code shows, forward() operates on NumPy data, and its return value is converted back into a torch.Tensor. That raises a question: the input to forward is converted to a NumPy array and the output is converted back to a tensor, so how does the output get its grad_fn attribute? Debugging shows that once hr in main() has requires_grad set to True, i.e. hr is marked as a leaf node that requires gradients, and Bicubic inherits from torch.autograd.Function, then the grad_fn of the output (lr in the code) is set to <main.Bicubic object at 0x000001DD5A280D68>, that is, the Bicubic object itself (see the short check below).
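A quick way to see this wiring is to inspect is_leaf and grad_fn on the input and output tensors. The check below uses the toy ScaleByTwo function sketched earlier; with the static-method API the printed grad_fn is a generated ...Backward node rather than the Function instance itself, so the exact repr differs from the one quoted above.

```python
x = torch.randn(4, requires_grad=True)   # leaf tensor, analogous to hr in main()
y = ScaleByTwo.apply(x)                  # analogous to lr = bicubic(hr)

print(x.is_leaf, y.is_leaf)   # True False: y was produced by an autograd Function
print(y.grad_fn)              # something like <torch.autograd.function.ScaleByTwoBackward object at 0x...>
```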
2. backward() is the differentiation function: grad_output is the gradient passed down from the next stage of the chain rule, and grad_input is the gradient we want to produce. You only need to take grad_output as the input; at some step inside the loss.backward() call, Bicubic's backward() function will be executed. The gradcheck sketch below is one way to verify such a hand-written backward().
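The listing imports gradcheck but never calls it. It is the standard tool for checking that a hand-written backward() agrees with numerical finite differences, and is a useful sanity check for any custom Function. A minimal sketch, again exercising the toy ScaleByTwo rather than the article's Bicubic, and using double-precision inputs as gradcheck recommends:

```python
from torch.autograd import gradcheck

inp = torch.randn(5, dtype=torch.double, requires_grad=True)
# gradcheck compares the analytic backward() against finite-difference gradients.
ok = gradcheck(ScaleByTwo.apply, (inp,), eps=1e-6, atol=1e-4)
print(ok)   # True when the two agree
```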
That is all of this example of custom backpropagation and differentiation in PyTorch; hopefully it serves as a useful reference.
Original article: https://blog.csdn.net/xuxiaoyuxuxiaoyu/article/details/86737492