国产片侵犯亲女视频播放_亚洲精品二区_在线免费国产视频_欧美精品一区二区三区在线_少妇久久久_在线观看av不卡

腳本之家,腳本語言編程技術及教程分享平臺!
分類導航

Python|VBS|Ruby|Lua|perl|VBA|Golang|PowerShell|Erlang|autoit|Dos|bat|

服務器之家 - 腳本之家 - Python - pytorch 利用lstm做mnist手寫數(shù)字識別分類的實例

pytorch 利用lstm做mnist手寫數(shù)字識別分類的實例

2020-04-29 09:44xckkcxxck Python

今天小編就為大家分享一篇pytorch 利用lstm做mnist手寫數(shù)字識別分類的實例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧

代碼如下,U我認為對于新手來說最重要的是學會rnn讀取數(shù)據(jù)的格式。

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# -*- coding: utf-8 -*-
"""
Created on Tue Oct 9 08:53:25 2018
@author: www
"""
 
import sys
sys.path.append('..')
 
import torch
import datetime
from torch.autograd import Variable
from torch import nn
from torch.utils.data import DataLoader
 
from torchvision import transforms as tfs
from torchvision.datasets import MNIST
 
#定義數(shù)據(jù)
data_tf = tfs.Compose([
   tfs.ToTensor(),
   tfs.Normalize([0.5], [0.5])
])
train_set = MNIST('E:/data', train=True, transform=data_tf, download=True)
test_set = MNIST('E:/data', train=False, transform=data_tf, download=True)
 
train_data = DataLoader(train_set, 64, True, num_workers=4)
test_data = DataLoader(test_set, 128, False, num_workers=4)
 
#定義模型
class rnn_classify(nn.Module):
   def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2):
     super(rnn_classify, self).__init__()
     self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers)#使用兩層lstm
     self.classifier = nn.Linear(hidden_feature, num_class)#將最后一個的rnn使用全連接的到最后的輸出結果
     
   def forward(self, x):
     #x的大小為(batch,1,28,28),所以我們需要將其轉(zhuǎn)化為rnn的輸入格式(28,batch,28)
     x = x.squeeze() #去掉(batch,1,28,28)中的1,變成(batch, 28,28)
     x = x.permute(2, 0, 1)#將最后一維放到第一維,變成(batch,28,28)
     out, _ = self.rnn(x) #使用默認的隱藏狀態(tài),得到的out是(28, batch, hidden_feature)
     out = out[-1,:,:]#取序列中的最后一個,大小是(batch, hidden_feature)
     out = self.classifier(out) #得到分類結果
     return out
     
net = rnn_classify()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adadelta(net.parameters(), 1e-1)
 
#定義訓練過程
def get_acc(output, label):
  total = output.shape[0]
  _, pred_label = output.max(1)
  num_correct = (pred_label == label).sum().item()
  return num_correct / total
  
  
def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
  if torch.cuda.is_available():
    net = net.cuda()
  prev_time = datetime.datetime.now()
  for epoch in range(num_epochs):
    train_loss = 0
    train_acc = 0
    net = net.train()
    for im, label in train_data:
      if torch.cuda.is_available():
        im = Variable(im.cuda()) # (bs, 3, h, w)
        label = Variable(label.cuda()) # (bs, h, w)
      else:
        im = Variable(im)
        label = Variable(label)
      # forward
      output = net(im)
      loss = criterion(output, label)
      # backward
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
 
      train_loss += loss.item()
      train_acc += get_acc(output, label)
 
    cur_time = datetime.datetime.now()
    h, remainder = divmod((cur_time - prev_time).seconds, 3600)
    m, s = divmod(remainder, 60)
    time_str = "Time %02d:%02d:%02d" % (h, m, s)
    if valid_data is not None:
      valid_loss = 0
      valid_acc = 0
      net = net.eval()
      for im, label in valid_data:
        if torch.cuda.is_available():
          im = Variable(im.cuda())
          label = Variable(label.cuda())
        else:
          im = Variable(im)
          label = Variable(label)
        output = net(im)
        loss = criterion(output, label)
        valid_loss += loss.item()
        valid_acc += get_acc(output, label)
      epoch_str = (
        "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
        % (epoch, train_loss / len(train_data),
          train_acc / len(train_data), valid_loss / len(valid_data),
          valid_acc / len(valid_data)))
    else:
      epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
             (epoch, train_loss / len(train_data),
             train_acc / len(train_data)))
    prev_time = cur_time
    print(epoch_str + time_str)
    
train(net, train_data, test_data, 10, optimizer, criterion)   

以上這篇pytorch 利用lstm做mnist手寫數(shù)字識別分類的實例就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持服務器之家。

原文鏈接:https://blog.csdn.net/xckkcxxck/article/details/82978942

延伸 · 閱讀

精彩推薦
主站蜘蛛池模板: 精品国产乱码久久久久夜 | 成年人黄色一级片 | 午夜久久久久 | 亚洲高清av | 国产精品中文字幕在线观看 | 中文字幕观看 | 色爱区综合五月激情 | 国产精品久久久久久久久久久久久 | 日韩在线精品强乱中文字幕 | 久久成人精品视频 | 欧美视频区 | 91春色| 欧美色影院 | 国产精品一区二 | 成人国产精品一级毛片视频 | 中文字幕国产视频 | 成人欧美一区二区 | 精品视频网站 | av久久 | 91精品视频在线 | 欧美一区二区在线视频 | 中文字幕综合 | 精品国产一级毛片 | 国产精品美乳一区二区免费 | 国产精品久久久久久久久久 | 欧美日韩国产一区二区三区 | 日韩精品一区在线视频 | 在线一区二区三区四区 | 一区二区三区四区在线视频 | 成人福利视频网 | 综合九九| 中文字幕国产一区二区 | 欧美 亚洲 一区 | 欧美综合激情 | 一区二区三区在线看 | 国产成年人电影在线观看 | 韩日一区二区 | 国产日韩欧美一区二区 | 国产成人av在线播放 | 午夜精品久久久久久久久久久久 | 亚洲精品国产精品国自产在线 |