国产片侵犯亲女视频播放_亚洲精品二区_在线免费国产视频_欧美精品一区二区三区在线_少妇久久久_在线观看av不卡

腳本之家,腳本語言編程技術(shù)及教程分享平臺(tái)!
分類導(dǎo)航

Python|VBS|Ruby|Lua|perl|VBA|Golang|PowerShell|Erlang|autoit|Dos|bat|

服務(wù)器之家 - 腳本之家 - Python - PyTorch深度學(xué)習(xí)模型的保存和加載流程詳解

PyTorch深度學(xué)習(xí)模型的保存和加載流程詳解

2022-02-14 20:46軟耳朵DONG Python

PyTorch是一個(gè)開源的Python機(jī)器學(xué)習(xí)庫,基于Torch,用于自然語言處理等應(yīng)用程序。2017年1月,由Facebook人工智能研究院(FAIR)基于Torch推出了PyTorch,這篇文章主要介紹了PyTorch模型的保存和加載流程

一、模型參數(shù)的保存和加載

  •  torch.save(module.state_dict(), path):使用module.state_dict()函數(shù)獲取各層已經(jīng)訓(xùn)練好的參數(shù)和緩沖區(qū),然后將參數(shù)和緩沖區(qū)保存到path所指定的文件存放路徑(常用文件格式為.pt.pth.pkl)。
  • torch.nn.Module.load_state_dict(state_dict):從state_dict中加載參數(shù)和緩沖區(qū)到Module及其子類中 。
  • torch.nn.Module.state_dict()函數(shù)返回python中的一個(gè)OrderedDict類型字典對(duì)象,該對(duì)象將每一層與它的對(duì)應(yīng)參數(shù)和緩沖區(qū)建立映射關(guān)系,字典的鍵值是參數(shù)或緩沖區(qū)的名稱。只有那些參數(shù)可以訓(xùn)練的層才會(huì)被保存到OrderedDict中,例如:卷積層、線性層等。
  • Python中的字典類以“鍵:值”方式存取數(shù)據(jù),OrderedDict是它的一個(gè)子類,實(shí)現(xiàn)了對(duì)字典對(duì)象中元素的排序(OrderedDict根據(jù)放入元素的先后順序進(jìn)行排序)。由于進(jìn)行了排序,所以順序不同的兩個(gè)OrderedDict字典對(duì)象會(huì)被當(dāng)做是兩個(gè)不同的對(duì)象。
  • 示例:
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 2, 3)
        self.pool1 = nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool1(x)
        return x

# 初始化網(wǎng)絡(luò)
net = Net()
net.conv1.weight[0].detach().fill_(1)
net.conv1.weight[1].detach().fill_(2)
net.conv1.bias.data.detach().zero_()
# 獲取state_dict
state_dict = net.state_dict()
# 字典的遍歷默認(rèn)是遍歷key,所以param_tensor實(shí)際上是鍵值
for param_tensor in state_dict: 
    print(param_tensor,":
",state_dict[param_tensor])
# 保存模型參數(shù)
torch.save(state_dict,"net_params.pth")
# 通過加載state_dict獲取模型參數(shù)
net.load_state_dict(state_dict)

輸出:

PyTorch深度學(xué)習(xí)模型的保存和加載流程詳解

二、完整模型的保存和加載

  •  torch.save(module, path):將訓(xùn)練完的整個(gè)網(wǎng)絡(luò)模型module保存到path所指定的文件存放路徑(常用文件格式為.pt.pth)。
  • torch.load(path):加載保存到path中的整個(gè)神經(jīng)網(wǎng)絡(luò)模型。
  • 示例:
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 2, 3)
        self.pool1 = nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool1(x)
        return x

# 初始化網(wǎng)絡(luò)
net = Net()
net.conv1.weight[0].detach().fill_(1)
net.conv1.weight[1].detach().fill_(2)
net.conv1.bias.data.detach().zero_()
# 保存整個(gè)網(wǎng)絡(luò)
torch.save(net,"net.pth")
# 加載網(wǎng)絡(luò)
net = torch.load("net.pth")

到此這篇關(guān)于PyTorch深度學(xué)習(xí)模型的保存和加載流程詳解的文章就介紹到這了,更多相關(guān)PyTorch 模型的保存 內(nèi)容請(qǐng)搜索服務(wù)器之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持服務(wù)器之家!

原文鏈接:https://blog.csdn.net/m0_52650517/article/details/120836999

延伸 · 閱讀

精彩推薦
主站蜘蛛池模板: 亚洲国产精品久久久 | 久久手机免费视频 | 国产va在线| 日韩欧美精品在线 | 亚洲www视频| 亚洲精品电影在线观看 | 国产一区二区h | 这里只有精品视频 | 午夜播放器在线观看 | 91精品国产综合久久久久久 | 高清国产午夜精品久久久久久 | 亚洲视频欧美视频 | 中国女人真人一级毛片 | 国产视频1区 | 一区二区国产在线观看 | 日韩精品一区二区在线观看视频 | 激情网页 | 日韩欧美一级电影 | 国产91久久精品一区二区 | 欧州一区二区三区 | 国产一区二区欧美 | 黄频免费在线观看 | 91综合网| 亚洲国产精品美女 | 欧美视频二区 | 久久人人爽人人爽人人片av不 | 国产精品视频一 | 成人午夜视频在线观看 | 九九热欧美| 成人国产精品久久久 | 国产成人精品久久二区二区 | 国产一卡二卡三卡 | 国产精品免费观看 | 九色影院| 欧美精品在线一区 | а√天堂资源中文最新版地址 | 亚洲国产一区在线 | 国产精品久久久久久久久久小说 | 国产黄色在线观看 | 久久久久久久久久久福利观看 | 欧美中文在线 |