Torch 為了提高速度,向量或是矩陣的賦值是指向同一內(nèi)存的,這不同于 Matlab。如果需要保存舊的tensor即需要開(kāi)辟新的存儲(chǔ)地址而不是引用,可以用 clone() 進(jìn)行深拷貝,
首先我們來(lái)打印出來(lái)clone()操作后的數(shù)據(jù)類型定義變化:
(1). 簡(jiǎn)單打印類型
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
|
import torch a = torch.tensor( 1.0 , requires_grad = True ) b = a.clone() c = a.detach() a.data * = 3 b + = 1 print (a) # tensor(3., requires_grad=True) print (b) print (c) ''' 輸出結(jié)果: tensor(3., requires_grad=True) tensor(2., grad_fn=<AddBackward0>) tensor(3.) # detach()后的值隨著a的變化出現(xiàn)變化 ''' |
grad_fn=<CloneBackward>,表示clone后的返回值是個(gè)中間變量,因此支持梯度的回溯。clone操作在一定程度上可以視為是一個(gè)identity-mapping函數(shù)。
detach()操作后的tensor與原始tensor共享數(shù)據(jù)內(nèi)存,當(dāng)原始tensor在計(jì)算圖中數(shù)值發(fā)生反向傳播等更新之后,detach()的tensor值也發(fā)生了改變。
注意: 在pytorch中我們不要直接使用id是否相等來(lái)判斷tensor是否共享內(nèi)存,這只是充分條件,因?yàn)橐苍S底層共享數(shù)據(jù)內(nèi)存,但是仍然是新的tensor,比如detach(),如果我們直接打印id會(huì)出現(xiàn)以下情況。
1
2
3
4
5
6
7
8
|
import torch as t a = t.tensor([ 1.0 , 2.0 ], requires_grad = True ) b = a.detach() #c[:] = a.detach() print ( id (a)) print ( id (b)) #140568935450520 140570337203616 |
顯然直接打印出來(lái)的id不等,我們可以通過(guò)簡(jiǎn)單的賦值后觀察數(shù)據(jù)變化進(jìn)行判斷。
(2). clone()的梯度回傳
detach()函數(shù)可以返回一個(gè)完全相同的tensor,與舊的tensor共享內(nèi)存,脫離計(jì)算圖,不會(huì)牽扯梯度計(jì)算。
而clone充當(dāng)中間變量,會(huì)將梯度傳給源張量進(jìn)行疊加,但是本身不保存其grad,即值為None
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
|
import torch a = torch.tensor( 1.0 , requires_grad = True ) a_ = a.clone() y = a * * 2 z = a * * 2 + a_ * 3 y.backward() print (a.grad) # 2 z.backward() print (a_.grad) # None. 中間variable,無(wú)grad print (a.grad) ''' 輸出: tensor(2.) None tensor(7.) # 2*2+3=7 ''' |
使用torch.clone()獲得的新tensor和原來(lái)的數(shù)據(jù)不再共享內(nèi)存,但仍保留在計(jì)算圖中,clone操作在不共享數(shù)據(jù)內(nèi)存的同時(shí)支持梯度梯度傳遞與疊加,所以常用在神經(jīng)網(wǎng)絡(luò)中某個(gè)單元需要重復(fù)使用的場(chǎng)景下。
通常如果原tensor的requires_grad=True,則:
- clone()操作后的tensor requires_grad=True
- detach()操作后的tensor requires_grad=False。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
|
import torch torch.manual_seed( 0 ) x = torch.tensor([ 1. , 2. ], requires_grad = True ) clone_x = x.clone() detach_x = x.detach() clone_detach_x = x.clone().detach() f = torch.nn.Linear( 2 , 1 ) y = f(x) y.backward() print (x.grad) print (clone_x.requires_grad) print (clone_x.grad) print (detach_x.requires_grad) print (clone_detach_x.requires_grad) ''' 輸出結(jié)果如下: tensor([-0.0053, 0.3793]) True None False False ''' |
另一個(gè)比較特殊的是當(dāng)源張量的 require_grad=False,clone后的張量 require_grad=True,此時(shí)不存在張量回傳現(xiàn)象,可以得到clone后的張量求導(dǎo)。
如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
|
import torch a = torch.tensor( 1.0 ) a_ = a.clone() a_.requires_grad_() #require_grad=True y = a_ * * 2 y.backward() print (a.grad) # None print (a_.grad) ''' 輸出: None tensor(2.) ''' |
了解了兩者的區(qū)別后我們常與其他函數(shù)進(jìn)行搭配使用,實(shí)現(xiàn)數(shù)據(jù)拷貝后的其他需要。
比如我們經(jīng)常使用view()函數(shù)對(duì)tensor進(jìn)行reshape操作。返回的新Tensor與源Tensor可能有不同的size,但是是共享data的,即其中的一個(gè)發(fā)生變化,另外一個(gè)也會(huì)跟著改變。
需要注意的是view返回的Tensor與源Tensor是共享data的,但是依然是一個(gè)新的Tensor(因?yàn)門ensor除了包含data外還有一些其他屬性),兩者id(內(nèi)存地址)并不一致。
1
2
3
4
5
|
x = torch.rand( 2 , 2 ) y = x.view( 4 ) x + = 1 print (x) print (y) # 也加了1 |
view() 僅僅是改變了對(duì)這個(gè)張量的觀察角度,內(nèi)部數(shù)據(jù)并未改變。這時(shí)候想返回一個(gè)真正新的副本(即不共享data內(nèi)存)該怎么辦呢?Pytorch還提供了一個(gè)reshape()可以改變形狀,但是此函數(shù)并不能保證返回的是其拷貝,所以不推薦使用。推薦先用clone創(chuàng)造一個(gè)副本然后再使用view。參考此處
1
2
3
4
5
6
7
8
9
10
11
12
13
14
|
x = torch.rand( 2 , 2 ) x_cp = x.clone().view( 4 ) x + = 1 print ( id (x)) print ( id (x_cp)) print (x) print (x_cp) ''' 140568935036464 140568935035816 tensor([[0.4963, 0.7682], [0.1320, 0.3074]]) tensor([[1.4963, 1.7682, 1.1320, 1.3074]]) ''' |
另外使用clone()會(huì)被記錄在計(jì)算圖中,即梯度回傳到副本時(shí)也會(huì)傳到源Tensor。在上一篇中有總結(jié)。
總結(jié):
- torch.detach() — 新的tensor會(huì)脫離計(jì)算圖,不會(huì)牽扯梯度計(jì)算
-
torch.clone() — 新的tensor充當(dāng)中間變量,會(huì)保留在計(jì)算圖中,參與梯度計(jì)算(回傳疊加),但是一般不會(huì)保留自身梯度。
原地操作(in-place, such as resize_ / resize_as_ / set_ / transpose_) 在上面兩者中執(zhí)行都會(huì)引發(fā)錯(cuò)誤或者警告。 - 共享數(shù)據(jù)內(nèi)存是底層設(shè)計(jì),并不能簡(jiǎn)單的通過(guò)直接打印tensor的id地址進(jìn)行判斷,需要在進(jìn)行賦值或運(yùn)算操作后打印比較數(shù)據(jù)的變化進(jìn)行判斷。
- 復(fù)制操作可以根據(jù)實(shí)際需要進(jìn)行結(jié)合使用。
引用官方文檔的話:如果你使用了in-place operation而沒(méi)有報(bào)錯(cuò)的話,那么你可以確定你的梯度計(jì)算是正確的。另外盡量避免in-place的使用。
像y = x + y這樣的運(yùn)算會(huì)新開(kāi)內(nèi)存,然后將y指向新內(nèi)存。我們可以使用Python自帶的id函數(shù)進(jìn)行驗(yàn)證:如果兩個(gè)實(shí)例的ID相同,則它們所對(duì)應(yīng)的內(nèi)存地址相同。
到此這篇關(guān)于PyTorch中clone()、detach()及相關(guān)擴(kuò)展詳解的文章就介紹到這了,更多相關(guān)PyTorch中clone()、detach()及相關(guān)擴(kuò)展內(nèi)容請(qǐng)搜索服務(wù)器之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持服務(wù)器之家!
原文鏈接:https://blog.csdn.net/weixin_43199584/article/details/106876679