国产片侵犯亲女视频播放_亚洲精品二区_在线免费国产视频_欧美精品一区二区三区在线_少妇久久久_在线观看av不卡

腳本之家,腳本語言編程技術及教程分享平臺!
分類導航

Python|VBS|Ruby|Lua|perl|VBA|Golang|PowerShell|Erlang|autoit|Dos|bat|

服務器之家 - 腳本之家 - Python - pytorch加載預訓練模型與自己模型不匹配的解決方案

pytorch加載預訓練模型與自己模型不匹配的解決方案

2021-11-01 10:00找不到服務器1703 Python

這篇文章主要介紹了pytorch加載預訓練模型與自己模型不匹配的解決方案,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教

pytorch中如果自己搭建網絡并且加載別人的與訓練模型的話,如果模型和參數不嚴格匹配,就可能會出問題,接下來記錄一下我的解決方法。

兩個有序字典找不同

模型的參數和pth文件的參數都是有序字典(OrderedDict),把字典中的鍵轉為列表就可以在for循環里迭代找不同了。

?
1
2
3
4
5
6
7
8
9
10
11
model = ResNet18(1)
model_dict1 = torch.load('resnet18.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
minlen = min(len1, len2)
for n in range(minlen):
    if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:
        err = 1

自己搭建模型的注意事項

搭網絡時要對照pth文件的字典順序搭,字典順序、權重尺寸(shape)和變量命名必須與pth文件完全一致。如果僅僅是變量命名不同,可采用類似的方法對模型的權重重新賦值。

?
1
2
3
4
5
6
7
8
9
10
11
12
13
model = ResNet18(1)
model_dict1 = torch.load('resnet18.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
minlen = min(len1, len2)
for n in range(minlen):
    if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:
        continue
    model_dict1[model_list1[n]] = model_dict2[model_list2[n]]
model.load_state_dict(model_dict2)

完整的代碼見自己搭建resnet18網絡并加載torchvision自帶權重

新增的改進代碼

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
model_dict1 = torch.load('yolov5.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
m, n = 0, 0
while True:
    if m >= len1 or n >= len2:
        break
    layername1, layername2 = model_list1[m], model_list2[n]
    w1, w2 = model_dict1[layername1], model_dict2[layername2]
    if w1.shape != w2.shape:
        continue
    model_dict2[layername2] = model_dict1[layername1]
    m += 1
    n += 1
model.load_state_dict(model_dict2)

如果因為模型不匹配,運行第14行語句后,可看自己情況手動對m或n加上1。

補充:pytorch的一些坑:用預訓練的vgg模型的部分層的特征報錯,如張量不匹配

看代碼吧~

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
#打算取VGG19的第二個全連接層的輸出,那么就需要構建一個類,這個類要包含VGG的全部卷積層,
#以及到第二個全連接層的全部網絡還有他們對應的參數
class Classification_att(nn.Module):
    def __init__(self, rgb_range):
        super(Classification_att, self).__init__()
        self.vgg19 =models.vgg19(pretrained=True)
        vgg = models.vgg19(pretrained=True).features
        conv_modules = [m for m in vgg]
        self.vgg_conv = nn.Sequential(*conv_modules[:37])
        classfi = models.vgg19(pretrained=True).classifier
        classif_modules = [n for n in classfi]
        self.vgg_class = nn.Sequential(*classif_modules[:4])
        vgg_mean = (0.485, 0.456, 0.406)
        vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
        self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
        for p in self.vgg_conv.parameters():
            p.requires_grad = False
        for p in self.vgg_class.parameters():
            p.requires_grad = False
        self.classifi = nn.Sequential(
            nn.Linear(4096, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 256),
            nn.ReLU(True),
            nn.Linear(256, 64),
        )
 
    def forward(self, x):
        x = F.interpolate(x, size=[224, 224], scale_factor=None, mode='bilinear',
        align_corners=False)
        x = self.sub_mean(x)
        x = self.vgg_conv(x) 
        x = self.vgg_class(x)  #執行這部報錯,說張量不匹配

原因是因為卷積層的輸出不能直接連接全連接層,即使輸出的張量的總的大小是一致的

查看vgg的pytorch源碼發現是

?
1
2
3
4
5
x = self.features(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
#自己的代碼沒有torch.flatten(x, 1)這步

所以自己的少了一步

?
1
x = torch.flatten(x, 1)

補上就好了!

以上為個人經驗,希望能給大家一個參考,也希望大家多多支持服務器之家。

原文鏈接:https://blog.csdn.net/qq_34288751/article/details/114160725

延伸 · 閱讀

精彩推薦
主站蜘蛛池模板: 色成人亚洲www78ixcom | 高清一区二区三区 | 精品欧美一区二区久久久伦 | 亚洲精品视频网 | 精品一区二区三区在线观看 | 国产九九九| 精品日韩一区二区 | 欧美 日韩 综合 | 日本精品久久 | 久久国产精品久久久久久 | 国产女人爽到高潮免费视频 | 99久久99久久精品 | 国产精品久久精品 | 国产精品爱久久久久久久 | 高清一区二区三区日本久 | 久久亚洲一区二区三区四区 | 国产亚洲精品精品国产亚洲综合 | 亚洲国产二区 | a视频在线观看 | a久久| 日韩在线观看中文字幕 | 欧美成人精品激情在线观看 | 亚洲国产99 | 亚州国产 | 日韩精品在线一区 | 国产精品久久久久久久午夜 | 91av电影在线观看 | 欧美一级在线观看 | 综合精品久久久 | 国产婷婷精品av在线 | 免费色网站 | 在线观看视频一区 | 国内自拍视频在线观看 | 国产探花在线精品一区二区 | 国产精品激情在线观看 | 免费成人一级片 | 久久精品国产99 | 国产精品极品美女在线观看免费 | 99久久精品一区二区成人 | 中文字幕欧美日韩 | 欧美日韩91 |