一、mnist数据集
形如上图的手写数字就是mnist数据集。
二、GAN原理(生成对抗网络)
GAN网络一共由两部分组成:一个是伪造器(即生成器,Generator,简称G),一个是判别器(Discriminator,简称D)。
一开始,G的输入是从某种分布(如高斯分布)中随机采样的噪声。G生成的图片不断送给D判断真假,直到G生成的图片连D都以为是真的。D每一轮除了看G生成的假图片以外,还要看数据集中的真图片,并以这两部分得到的损失函数值为依据更新D网络中的权值。因此G和D都在不停地更新权值。以下图为例:
在v1时,G生成的只不过是一堆噪声,见过数据集中真实图片(real images)的D肯定能判断出G所生成的是假的。当然G也能得知D判断它为假的这个结果,因此G会更新权值,到v2时,G就能生成更逼真的图片让D判断;当然在v2时D也会先看一次真图片,再去判断G所生成的图片。以此类推,不断循环,这就是GAN的思想。
三、训练代码
"""Train a fully-connected GAN on MNIST.

Run as a script to train. The Generator and Discriminator classes (and the
module-level ``opt`` / ``img_shape`` they read) can also be imported by other
scripts, e.g. to unpickle a saved checkpoint.
"""
import argparse
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image samples")

# Parse leniently at import time. Besides being run directly, this module is
# imported by other scripts (e.g. `from gan import Generator`), whose own
# command-line flags (such as a notebook kernel's `-f ...`) would otherwise
# make parse_args() exit. main() re-parses strictly, so misspelled flags are
# still rejected when training.
opt, _ = parser.parse_known_args()

# (C, H, W) of one image. MNIST is grayscale, so channels defaults to 1 and the
# default shape is (1, 28, 28).
img_shape = (opt.channels, opt.img_size, opt.img_size)

cuda = torch.cuda.is_available()
# Legacy typed-tensor constructor: allocates on the GPU when one is available.
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor


class Generator(nn.Module):
    """MLP mapping a noise vector to an image.

    Input ``z`` has shape (batch, opt.latent_dim); output has shape
    (batch, *img_shape) with values in [-1, 1] (Tanh).
    """

    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            """Linear -> (optional BatchNorm1d) -> LeakyReLU(0.2)."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # NOTE(review): BatchNorm1d's 2nd positional parameter is `eps`,
                # so this sets eps=0.8 (not momentum). Kept unchanged to avoid
                # altering training behavior; confirm whether that was intended.
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(opt.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        # Un-flatten the per-sample vector back into (C, H, W).
        img = img.view(img.size(0), *img_shape)
        return img


class Discriminator(nn.Module):
    """MLP scoring each image with a probability of being real, shape (batch, 1)."""

    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity


def main():
    """Train G and D with alternating Adam steps.

    Every ``opt.sample_interval`` batches a 5x5 grid of generated images is
    written to ``images/<batches_done>.png``; every second epoch both networks
    are pickled to ``g<epoch>.pth`` / ``d<epoch>.pth``.
    """
    parser.parse_args()  # strict re-parse: exit on unknown or misspelled flags
    print(opt)
    os.makedirs("images", exist_ok=True)

    # Binary cross-entropy on D's sigmoid output.
    adversarial_loss = torch.nn.BCELoss()

    # Initialize generator and discriminator
    generator = Generator()
    discriminator = Discriminator()

    if cuda:
        generator.cuda()
        discriminator.cuda()
        adversarial_loss.cuda()

    # Configure data loader. ToTensor yields pixels in [0, 1]; Normalize(0.5, 0.5)
    # maps them to [-1, 1], the same range as the generator's Tanh output.
    os.makedirs("../../data/mnist", exist_ok=True)
    dataloader = DataLoader(
        datasets.MNIST(
            "../../data/mnist",
            train=True,
            download=True,
            transform=transforms.Compose(
                [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
            ),
        ),
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=opt.n_cpu,  # --n_cpu worker processes load batches in parallel
    )

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

    # ----------
    #  Training
    # ----------
    for epoch in range(opt.n_epochs):
        for i, (imgs, _) in enumerate(dataloader):
            n = imgs.size(0)  # the last batch of an epoch may be smaller

            # Adversarial ground truths: all ones = "real", all zeros = "fake".
            valid = Tensor(n, 1).fill_(1.0)
            fake = Tensor(n, 1).fill_(0.0)

            # Configure input
            real_imgs = imgs.type(Tensor)

            # -----------------
            #  Train Generator
            # -----------------
            optimizer_G.zero_grad()  # clear G's gradients from the previous batch

            # Noise with mean 0 and std 1, shape (n, latent_dim), e.g. (64, 100).
            z = Tensor(np.random.normal(0, 1, (n, opt.latent_dim)))

            # Generate a batch of images
            gen_imgs = generator(z)

            # G's loss measures how well it fools D (D should output "real").
            g_loss = adversarial_loss(discriminator(gen_imgs), valid)
            g_loss.backward()
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------
            # Also discards the gradients g_loss.backward() left on D's parameters.
            optimizer_D.zero_grad()

            # D's ability to tell real from generated samples; detach() keeps
            # D's loss from backpropagating into G.
            real_loss = adversarial_loss(discriminator(real_imgs), valid)
            fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
            d_loss = (real_loss + fake_loss) / 2

            d_loss.backward()  # d_loss updates D's weights only
            optimizer_D.step()

            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
            )

            batches_done = epoch * len(dataloader) + i
            if batches_done % opt.sample_interval == 0:
                # Save the first 25 images of this batch as a 5x5 grid.
                save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)

        if (epoch + 1) % 2 == 0:
            print('save..')
            # Pickles whole modules: loading them requires the Generator /
            # Discriminator classes to be importable where they are loaded.
            torch.save(generator, 'g%d.pth' % epoch)
            torch.save(discriminator, 'd%d.pth' % epoch)


if __name__ == '__main__':
    main()
运行结果:
一开始时,G生成的全是噪声:
然后逐渐呈现出数字的雏形:
最后一次生成的结果:
四、测试代码:
导入最后保存的生成器模型:
"""Generate a 5x5 grid of images from a trained GAN generator checkpoint."""
# Generator must be in scope: the checkpoint is a whole pickled nn.Module,
# and unpickling it needs the class definition.
from gan import Generator, Discriminator
import torch
import matplotlib.pyplot as plt
from torch.autograd import Variable
import numpy as np
from torchvision.utils import save_image

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the generator saved after the final epoch. map_location lets a
# checkpoint saved on a GPU load on a CPU-only machine.
# SECURITY: unpickling a full module can execute arbitrary code -- only load
# checkpoints you trust.
try:
    g = torch.load('g199.pth', map_location=device, weights_only=False)
except TypeError:
    # Older torch versions do not accept the `weights_only` argument.
    g = torch.load('g199.pth', map_location=device)
g = g.to(device)
g.eval()  # inference mode: BatchNorm uses its running statistics

# Noise size the generator expects: in_features of its first Linear layer.
latent_dim = g.model[0].in_features

with torch.no_grad():  # no autograd graph needed for generation
    z = torch.randn(64, latent_dim, device=device)  # standard-normal input noise
    gen_imgs = g(z)  # generate images

# Save the first 25 generated images as a 5x5 grid.
save_image(gen_imgs[:25], "images.png", nrow=5, normalize=True)
生成结果:
以上这篇pytorch GAN伪造手写体mnist数据集方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持服务器之家。
原文鏈接:https://blog.csdn.net/u014453898/article/details/95044228