国产片侵犯亲女视频播放_亚洲精品二区_在线免费国产视频_欧美精品一区二区三区在线_少妇久久久_在线观看av不卡

腳本之家,腳本語言編程技術及教程分享平臺!
分類導航

Python|VBS|Ruby|Lua|perl|VBA|Golang|PowerShell|Erlang|autoit|Dos|bat|

服務器之家 - 腳本之家 - Python - pytorch:實現簡單的GAN示例(MNIST數據集)

pytorch:實現簡單的GAN示例(MNIST數據集)

2020-04-29 10:18xckkcxxck Python

今天小編就為大家分享一篇pytorch:實現簡單的GAN示例(MNIST數據集),具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧

我就廢話不多說了,直接上代碼吧!

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
# -*- coding: utf-8 -*-
"""
Created on Sat Oct 13 10:22:45 2018
@author: www
"""
 
import torch
from torch import nn
from torch.autograd import Variable
 
import torchvision.transforms as tfs
from torch.utils.data import DataLoader, sampler
from torchvision.datasets import MNIST
 
import numpy as np
 
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
 
plt.rcParams['figure.figsize'] = (10.0, 8.0) # 設置畫圖的尺寸
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
 
def show_images(images): # 定義畫圖工具
  images = np.reshape(images, [images.shape[0], -1])
  sqrtn = int(np.ceil(np.sqrt(images.shape[0])))
  sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))
 
  fig = plt.figure(figsize=(sqrtn, sqrtn))
  gs = gridspec.GridSpec(sqrtn, sqrtn)
  gs.update(wspace=0.05, hspace=0.05)
 
  for i, img in enumerate(images):
    ax = plt.subplot(gs[i])
    plt.axis('off')
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_aspect('equal')
    plt.imshow(img.reshape([sqrtimg,sqrtimg]))
  return
  
def preprocess_img(x):
  x = tfs.ToTensor()(x)
  return (x - 0.5) / 0.5
 
def deprocess_img(x):
  return (x + 1.0) / 2.0
 
class ChunkSampler(sampler.Sampler): # 定義一個取樣的函數
  """Samples elements sequentially from some offset.
  Arguments:
    num_samples: # of desired datapoints
    start: offset where we should start selecting from
  """
  def __init__(self, num_samples, start=0):
    self.num_samples = num_samples
    self.start = start
 
  def __iter__(self):
    return iter(range(self.start, self.start + self.num_samples))
 
  def __len__(self):
    return self.num_samples
    
NUM_TRAIN = 50000
NUM_VAL = 5000
 
NOISE_DIM = 96
batch_size = 128
 
train_set = MNIST('E:/data', train=True, transform=preprocess_img)
 
train_data = DataLoader(train_set, batch_size=batch_size, sampler=ChunkSampler(NUM_TRAIN, 0))
 
val_set = MNIST('E:/data', train=True, transform=preprocess_img)
 
val_data = DataLoader(val_set, batch_size=batch_size, sampler=ChunkSampler(NUM_VAL, NUM_TRAIN))
 
imgs = deprocess_img(train_data.__iter__().next()[0].view(batch_size, 784)).numpy().squeeze() # 可視化圖片效果
show_images(imgs)
 
#判別網絡
def discriminator():
  net = nn.Sequential(   
      nn.Linear(784, 256),
      nn.LeakyReLU(0.2),
      nn.Linear(256, 256),
      nn.LeakyReLU(0.2),
      nn.Linear(256, 1)
    )
  return net
  
#生成網絡
def generator(noise_dim=NOISE_DIM): 
  net = nn.Sequential(
    nn.Linear(noise_dim, 1024),
    nn.ReLU(True),
    nn.Linear(1024, 1024),
    nn.ReLU(True),
    nn.Linear(1024, 784),
    nn.Tanh()
  )
  return net
  
#判別器的 loss 就是將真實數據的得分判斷為 1,假的數據的得分判斷為 0,而生成器的 loss 就是將假的數據判斷為 1
 
bce_loss = nn.BCEWithLogitsLoss()#交叉熵損失函數
 
def discriminator_loss(logits_real, logits_fake): # 判別器的 loss
  size = logits_real.shape[0]
  true_labels = Variable(torch.ones(size, 1)).float()
  false_labels = Variable(torch.zeros(size, 1)).float()
  loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels)
  return loss
  
def generator_loss(logits_fake): # 生成器的 loss
  size = logits_fake.shape[0]
  true_labels = Variable(torch.ones(size, 1)).float()
  loss = bce_loss(logits_fake, true_labels)
  return loss
  
# 使用 adam 來進行訓練,學習率是 3e-4, beta1 是 0.5, beta2 是 0.999
def get_optimizer(net):
  optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, betas=(0.5, 0.999))
  return optimizer
  
def train_a_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss, show_every=250,
        noise_size=96, num_epochs=10):
  iter_count = 0
  for epoch in range(num_epochs):
    for x, _ in train_data:
      bs = x.shape[0]
      # 判別網絡
      real_data = Variable(x).view(bs, -1) # 真實數據
      logits_real = D_net(real_data) # 判別網絡得分
      
      sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5 # -1 ~ 1 的均勻分布
      g_fake_seed = Variable(sample_noise)
      fake_images = G_net(g_fake_seed) # 生成的假的數據
      logits_fake = D_net(fake_images) # 判別網絡得分
 
      d_total_error = discriminator_loss(logits_real, logits_fake) # 判別器的 loss
      D_optimizer.zero_grad()
      d_total_error.backward()
      D_optimizer.step() # 優化判別網絡
      
      # 生成網絡
      g_fake_seed = Variable(sample_noise)
      fake_images = G_net(g_fake_seed) # 生成的假的數據
 
      gen_logits_fake = D_net(fake_images)
      g_error = generator_loss(gen_logits_fake) # 生成網絡的 loss
      G_optimizer.zero_grad()
      g_error.backward()
      G_optimizer.step() # 優化生成網絡
 
      if (iter_count % show_every == 0):
        print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_total_error.item(), g_error.item()))
        imgs_numpy = deprocess_img(fake_images.data.cpu().numpy())
        show_images(imgs_numpy[0:16])
        plt.show()
        print()
      iter_count += 1
 
D = discriminator()
G = generator()
 
D_optim = get_optimizer(D)
G_optim = get_optimizer(G)
 
train_a_gan(D, G, D_optim, G_optim, discriminator_loss, generator_loss)     

以上這篇pytorch:實現簡單的GAN示例(MNIST數據集)就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持服務器之家。

原文鏈接:https://blog.csdn.net/xckkcxxck/article/details/83037025

延伸 · 閱讀

精彩推薦
主站蜘蛛池模板: 国产婷婷精品av在线 | 国产精品久久久久久久久久东京 | 日韩免费在线 | 午夜精品福利在线观看 | 人人射人人舔 | 久久综合九九 | 日韩免费在线视频 | 伊人福利视频 | 91网页版 | 色噜噜狠狠狠综合曰曰曰88av | 国产精品美女久久久久aⅴ国产馆 | 国产精品日日 | 国产午夜一区二区三区 | 韩国毛片在线观看 | 九色在线 | 在线观看亚洲精品 | 久久亚洲精品中文字幕 | 亚洲男人在线天堂 | 久热国产视频 | 在线日韩欧美 | 久久国产精品久久久久久电车 | 在线免费国产 | 91免费观看视频 | 日韩欧美在线一区 | 精品久久精品久久 | 久久99精品国产麻豆婷婷洗澡 | 日本精品视频一区二区 | 高清一区在线观看 | 免费精品视频 | 另类视频区| 精品久久久久久国产 | 高清一区二区三区 | 日韩中文字幕一区 | av有声小说一区二区三区 | 日韩电影在线 | 久久久久久久国产视频 | 精品一区二区在线观看 | 亚州av| 欧美日本亚洲 | 欧美在线不卡视频 | 四虎影院在线免费播放 |