在遷移學習finetune時我們通常需要凍結前幾層的參數不參與訓練,在Pytorch中的實現如下:
1
2
3
4
5
6
7
8
9
|
class Model(nn.Module): def __init__( self ): super (Transfer_model, self ).__init__() self .linear1 = nn.Linear( 20 , 50 ) self .linear2 = nn.Linear( 50 , 20 ) self .linear3 = nn.Linear( 20 , 2 ) def forward( self , x): pass |
假如我們想要凍結linear1層,需要做如下操作:
1
2
3
4
5
6
|
model = Model() # 這里是一般情況,共享層往往不止一層,所以做一個for循環 for para in model.linear1.parameters(): para.requires_grad = False # 假如真的只有一層也可以這樣操作: # model.linear1.weight.requires_grad = False |
最后我們需要將需要優化的參數傳入優化器,不需要傳入的參數過濾掉,所以要用到filter()函數。
1
|
optimizer = optim.Adam( filter ( lambda p: p.requires_grad, model.parameters()), lr = 0.1 ) |
其它的博客中都沒有講解filter()函數的作用,在這里我簡單講一下有助于更好的理解。
filter(function, iterable)
- function: 判斷函數
- iterable: 可迭代對象
filter() 函數用于過濾序列,過濾掉不符合條件的元素,返回一個迭代器對象,如果要轉換為列表,可以使用 list() 來轉換。
該接收兩個參數,第一個為函數,第二個為序列,序列的每個元素作為參數傳遞給函數進行判,然后返回 True 或 False,最后將返回 True 的元素放到新列表中。
filter()函數將requires_grad = True的參數傳入優化器進行反向傳播,requires_grad = False的則被過濾掉。
以上就是本文的全部內容,希望對大家的學習有所幫助,也希望大家多多支持服務器之家。
原文鏈接:https://blog.csdn.net/qq_40210586/article/details/103878155