pytorch中如果自己搭建網絡并且加載別人的與訓練模型的話,如果模型和參數不嚴格匹配,就可能會出問題,接下來記錄一下我的解決方法。
兩個有序字典找不同
模型的參數和pth文件的參數都是有序字典(OrderedDict),把字典中的鍵轉為列表就可以在for循環里迭代找不同了。
1
2
3
4
5
6
7
8
9
10
11
|
model = ResNet18( 1 ) model_dict1 = torch.load( 'resnet18.pth' ) model_dict2 = model.state_dict() model_list1 = list (model_dict1.keys()) model_list2 = list (model_dict2.keys()) len1 = len (model_list1) len2 = len (model_list2) minlen = min (len1, len2) for n in range (minlen): if model_dict1[model_list1[n]].shape ! = model_dict2[model_list2[n]].shape: err = 1 |
自己搭建模型的注意事項
搭網絡時要對照pth文件的字典順序搭,字典順序、權重尺寸(shape)和變量命名必須與pth文件完全一致。如果僅僅是變量命名不同,可采用類似的方法對模型的權重重新賦值。
1
2
3
4
5
6
7
8
9
10
11
12
13
|
model = ResNet18( 1 ) model_dict1 = torch.load( 'resnet18.pth' ) model_dict2 = model.state_dict() model_list1 = list (model_dict1.keys()) model_list2 = list (model_dict2.keys()) len1 = len (model_list1) len2 = len (model_list2) minlen = min (len1, len2) for n in range (minlen): if model_dict1[model_list1[n]].shape ! = model_dict2[model_list2[n]].shape: continue model_dict1[model_list1[n]] = model_dict2[model_list2[n]] model.load_state_dict(model_dict2) |
完整的代碼見自己搭建resnet18網絡并加載torchvision自帶權重
新增的改進代碼
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
|
model_dict1 = torch.load( 'yolov5.pth' ) model_dict2 = model.state_dict() model_list1 = list (model_dict1.keys()) model_list2 = list (model_dict2.keys()) len1 = len (model_list1) len2 = len (model_list2) m, n = 0 , 0 while True : if m > = len1 or n > = len2: break layername1, layername2 = model_list1[m], model_list2[n] w1, w2 = model_dict1[layername1], model_dict2[layername2] if w1.shape ! = w2.shape: continue model_dict2[layername2] = model_dict1[layername1] m + = 1 n + = 1 model.load_state_dict(model_dict2) |
如果因為模型不匹配,運行第14行語句后,可看自己情況手動對m或n加上1。
補充:pytorch的一些坑:用預訓練的vgg模型的部分層的特征報錯,如張量不匹配
看代碼吧~
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
|
#打算取VGG19的第二個全連接層的輸出,那么就需要構建一個類,這個類要包含VGG的全部卷積層, #以及到第二個全連接層的全部網絡還有他們對應的參數 class Classification_att(nn.Module): def __init__( self , rgb_range): super (Classification_att, self ).__init__() self .vgg19 = models.vgg19(pretrained = True ) vgg = models.vgg19(pretrained = True ).features conv_modules = [m for m in vgg] self .vgg_conv = nn.Sequential( * conv_modules[: 37 ]) classfi = models.vgg19(pretrained = True ).classifier classif_modules = [n for n in classfi] self .vgg_class = nn.Sequential( * classif_modules[: 4 ]) vgg_mean = ( 0.485 , 0.456 , 0.406 ) vgg_std = ( 0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range) self .sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std) for p in self .vgg_conv.parameters(): p.requires_grad = False for p in self .vgg_class.parameters(): p.requires_grad = False self .classifi = nn.Sequential( nn.Linear( 4096 , 1024 ), nn.ReLU( True ), nn.Linear( 1024 , 256 ), nn.ReLU( True ), nn.Linear( 256 , 64 ), ) def forward( self , x): x = F.interpolate(x, size = [ 224 , 224 ], scale_factor = None , mode = 'bilinear' , align_corners = False ) x = self .sub_mean(x) x = self .vgg_conv(x) x = self .vgg_class(x) #執行這部報錯,說張量不匹配 |
原因是因為卷積層的輸出不能直接連接全連接層,即使輸出的張量的總的大小是一致的
查看vgg的pytorch源碼發現是
1
2
3
4
5
|
x = self .features(x) x = self .avgpool(x) x = torch.flatten(x, 1 ) x = self .classifier(x) #自己的代碼沒有torch.flatten(x, 1)這步 |
所以自己的少了一步
1
|
x = torch.flatten(x, 1 ) |
補上就好了!
以上為個人經驗,希望能給大家一個參考,也希望大家多多支持服務器之家。
原文鏈接:https://blog.csdn.net/qq_34288751/article/details/114160725