国产片侵犯亲女视频播放_亚洲精品二区_在线免费国产视频_欧美精品一区二区三区在线_少妇久久久_在线观看av不卡

腳本之家,腳本語言編程技術及教程分享平臺!
分類導航

Python|VBS|Ruby|Lua|perl|VBA|Golang|PowerShell|Erlang|autoit|Dos|bat|

服務器之家 - 腳本之家 - Python - pytorch GAN偽造手寫體mnist數(shù)據(jù)集方式

pytorch GAN偽造手寫體mnist數(shù)據(jù)集方式

2020-04-29 10:14ZJE_ANDY Python

今天小編就為大家分享一篇pytorch GAN偽造手寫體mnist數(shù)據(jù)集方式,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧

一,mnist數(shù)據(jù)集

pytorch GAN偽造手寫體mnist數(shù)據(jù)集方式

形如上圖的數(shù)字手寫體就是mnist數(shù)據(jù)集。

二,GAN原理(生成對抗網(wǎng)絡)

GAN網(wǎng)絡一共由兩部分組成:一個是偽造器(Generator,簡稱G),一個是判別器(Discrimniator,簡稱D)

一開始,G由服從某幾個分布(如高斯分布)的噪音組成,生成的圖片不斷送給D判斷是否正確,直到G生成的圖片連D都判斷以為是真的。D每一輪除了看過G生成的假圖片以外,還要見數(shù)據(jù)集中的真圖片,以前者和后者得到的損失函數(shù)值為依據(jù)更新D網(wǎng)絡中的權值。因此G和D都在不停地更新權值。以下圖為例:

pytorch GAN偽造手寫體mnist數(shù)據(jù)集方式

在v1時的G只不過是 一堆噪聲,見過數(shù)據(jù)集(real images)的D肯定能判斷出G所生成的是假的。當然G也能知道D判斷它是假的這個結果,因此G就會更新權值,到v2的時候,G就能生成更逼真的圖片來讓D判斷,當然在v2時D也是會先看一次真圖片,再去判斷G所生成的圖片。以此類推,不斷循環(huán)就是GAN的思想。

三,訓練代碼

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import argparse
import os
import numpy as np
import math
 
import torchvision.transforms as transforms
from torchvision.utils import save_image
 
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
 
import torch.nn as nn
import torch.nn.functional as F
import torch
 
os.makedirs("images", exist_ok=True)
 
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples")
opt = parser.parse_args()
print(opt)
 
img_shape = (opt.channels, opt.img_size, opt.img_size) # 確定圖片輸入的格式為(1,28,28),由于mnist數(shù)據(jù)集是灰度圖所以通道為1
cuda = True if torch.cuda.is_available() else False
 
 
class Generator(nn.Module):
 def __init__(self):
  super(Generator, self).__init__()
 
  def block(in_feat, out_feat, normalize=True):
   layers = [nn.Linear(in_feat, out_feat)]
   if normalize:
    layers.append(nn.BatchNorm1d(out_feat, 0.8))
   layers.append(nn.LeakyReLU(0.2, inplace=True))
   return layers
 
  self.model = nn.Sequential(
   *block(opt.latent_dim, 128, normalize=False),
   *block(128, 256),
   *block(256, 512),
   *block(512, 1024),
   nn.Linear(1024, int(np.prod(img_shape))),
   nn.Tanh()
  )
 
 def forward(self, z):
  img = self.model(z)
  img = img.view(img.size(0), *img_shape)
  return img
 
 
class Discriminator(nn.Module):
 def __init__(self):
  super(Discriminator, self).__init__()
 
  self.model = nn.Sequential(
   nn.Linear(int(np.prod(img_shape)), 512),
   nn.LeakyReLU(0.2, inplace=True),
   nn.Linear(512, 256),
   nn.LeakyReLU(0.2, inplace=True),
   nn.Linear(256, 1),
   nn.Sigmoid(),
  )
 
 def forward(self, img):
  img_flat = img.view(img.size(0), -1)
  validity = self.model(img_flat)
  return validity
 
 
# Loss function
adversarial_loss = torch.nn.BCELoss()
 
# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()
 
if cuda:
 generator.cuda()
 discriminator.cuda()
 adversarial_loss.cuda()
 
# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
 datasets.MNIST(
  "../../data/mnist",
  train=True,
  download=True,
  transform=transforms.Compose(
   [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
  ),
 ),
 batch_size=opt.batch_size,
 shuffle=True,
)
 
# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
 
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
 
# ----------
# Training
# ----------
if __name__ == '__main__':
 for epoch in range(opt.n_epochs):
  for i, (imgs, _) in enumerate(dataloader):
   # print(imgs.shape)
   # Adversarial ground truths
   valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False) # 全1
   fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False) # 全0
   # Configure input
   real_imgs = Variable(imgs.type(Tensor))
 
   # -----------------
   # Train Generator
   # -----------------
 
   optimizer_G.zero_grad() # 清空G網(wǎng)絡 上一個batch的梯度
 
   # Sample noise as generator input
   z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim)))) # 生成的噪音,均值為0方差為1維度為(64,100)的噪音
   # Generate a batch of images
   gen_imgs = generator(z)
   # Loss measures generator's ability to fool the discriminator
   g_loss = adversarial_loss(discriminator(gen_imgs), valid)
 
   g_loss.backward() # g_loss用于更新G網(wǎng)絡的權值,g_loss于D網(wǎng)絡的判斷結果 有關
   optimizer_G.step()
 
   # ---------------------
   # Train Discriminator
   # ---------------------
 
   optimizer_D.zero_grad() # 清空D網(wǎng)絡 上一個batch的梯度
   # Measure discriminator's ability to classify real from generated samples
   real_loss = adversarial_loss(discriminator(real_imgs), valid)
   fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
   d_loss = (real_loss + fake_loss) / 2
 
   d_loss.backward() # d_loss用于更新D網(wǎng)絡的權值
   optimizer_D.step()
 
   print(
    "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
    % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
   )
 
   batches_done = epoch * len(dataloader) + i
   if batches_done % opt.sample_interval == 0:
    save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True) # 保存一個batchsize中的25張
   if (epoch+1) %2 ==0:
    print('save..')
    torch.save(generator,'g%d.pth' % epoch)
    torch.save(discriminator,'d%d.pth' % epoch)

運行結果:

一開始時,G生成的全是雜音:

pytorch GAN偽造手寫體mnist數(shù)據(jù)集方式

然后逐漸呈現(xiàn)數(shù)字的雛形:

pytorch GAN偽造手寫體mnist數(shù)據(jù)集方式

最后一次生成的結果:

pytorch GAN偽造手寫體mnist數(shù)據(jù)集方式

四,測試代碼:

導入最后保存生成器的模型:

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from gan import Generator,Discriminator
import torch
import matplotlib.pyplot as plt
from torch.autograd import Variable
import numpy as np
from torchvision.utils import save_image
 
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Tensor = torch.cuda.FloatTensor
g = torch.load('g199.pth') #導入生成器Generator模型
#d = torch.load('d.pth')
g = g.to(device)
#d = d.to(device)
 
z = Variable(Tensor(np.random.normal(0, 1, (64, 100)))) #輸入的噪音
gen_imgs =g(z) #生產(chǎn)圖片
save_image(gen_imgs.data[:25], "images.png" , nrow=5, normalize=True)

生成結果:

pytorch GAN偽造手寫體mnist數(shù)據(jù)集方式

以上這篇pytorch GAN偽造手寫體mnist數(shù)據(jù)集方式就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持服務器之家。

原文鏈接:https://blog.csdn.net/u014453898/article/details/95044228

延伸 · 閱讀

精彩推薦
主站蜘蛛池模板: 亚洲成人一区二区三区 | 欧美区日韩区 | 久久99深爱久久99精品 | 人人射av | jizz亚洲女人高潮大叫 | 欧美,日韩,国产精品免费观看 | 日韩看片 | 成年人免费在线观看视频网站 | 操操操影院 | 亚洲视频在线播放 | 国产成人精品一区二区三区视频 | 在线视频亚洲 | 天天射天天干 | 五月婷婷中文 | 中文字幕在线精品 | 国产精品久久久久一区二区三区 | 黄色免费美女网站 | 国产亚洲精品美女久久久久久久久久 | 久久精品久久久久久 | 天堂中文网官网 | 久久人人网 | 青青草一区二区 | 成人国产免费视频 | 成人福利电影 | 第一色网站| 亚洲国产成人在线 | 国产色在线 | 欧美久久久久 | 中文在线观看视频 | 欧美视频一二三区 | 精品一区久久 | 欧美日韩国产影院 | 啵啵羞羞影院 | 国产欧美综合视频 | 中文字幕 日韩有码 | 精品日韩一区二区三区 | 久久久久久国产 | 久久香蕉网 | 国产福利电影 | 国产欧美精品一区二区三区 | 久久日韩 |