【小白学习PyTorch教程】十一、基于MNIST数据集训练第一个生成性对抗网络

共 14778字,需浏览 30分钟

 ·

2021-08-15 11:58

「@Author:Runsen」

GAN 是使用两个神经网络模型训练的生成模型。一种模型称为生成网络模型,它学习生成新的似是而非的样本。另一个模型被称为判别网络,它学习区分生成的例子和真实的例子。

生成性对抗网络

2014,蒙特利尔大学的Ian Goodfellow和他的朋友发明了生成性对抗网络(GAN)。自它出版以来,有许多它的变体和客观功能来解决它的问题

论文在这里找到.

论文提出了两种模型:生成模型和判别模型。两个模型竞争,以产生真实和假的样本。2016年,Yann LeCun将GANs描述为“过去二十年机器学习中最酷的想法”。

GAN 的大部分研究和应用都集中在计算机视觉领域。

其原因是卷积神经网络 (CNN) 等深度学习模型在过去 5 到 7 年中在计算机视觉领域取得了巨大成功,例如在具有挑战性的任务(如对象检测和人脸识别。

GAN 的典型例子是生成新的逼真的照片,最令人吃惊的是生成照片般逼真的人脸的例子。

在本教程中,我们将实现一个简单的GAN生成假的MNIST样本。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as utils

import numpy as np
import matplotlib.pyplot as plt
# CPU / GPU Setting
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)  #cuda

使用MNIST数据集,具有最小大小的数据集。

它由60000个训练图像和10000个测试图像组成,每个图像有28*28的大小和一个彩色通道。

# Define a transform 
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean = (0.5, ), std = (0.5, ))
])

# batch_size是一个前向和后向传播过程中的图像数。
batch_size = 100

mnist = datasets.MNIST('./data/MNIST'
                       download = True
                       train = True
                       transform = transform)

mnist_loader = DataLoader(dataset = mnist, 
                          batch_size = batch_size, 
                          shuffle = True)
# CPU
def imshow(img, title):
    img = utils.make_grid(img.cpu().detach())
    img = (img+1)/2
    npimg = img.detach().numpy()
    plt.imshow(np.transpose(npimg, (120)))
    plt.title(title)
    plt.show()
#GPU
def imshow(img, title):
    npimg = img.detach().numpy()
    fig = plt.figure(figsize = (1010))
    plt.imshow(np.transpose(npimg, (120)))
    plt.title(title)
    plt.show()

images, labels = iter(mnist_loader).next()
imshow(images[0:16, :, :], "MNIST Images")

建立一个GANs模型。一个Generator和Discriminator

GANs由完全连接的层组成。它将从100维高斯分布采样的噪声转换为MNIST图像。鉴别器网络也由完全连接的层组成,用于区分输入数据是真是假。

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        
        latent_size = 100
        output = 28*28
        
        self.main = nn.Sequential(
            nn.Linear(latent_size, 128),
            nn.ReLU(inplace=True),
            
            nn.Linear(128256),
            nn.ReLU(inplace=True),
            
            nn.Linear(256512),
            nn.ReLU(inplace=True),
            
            nn.Linear(512, output),
            nn.Tanh()
        )
        
    def forward(self, x):
        out = self.main(x)
        out = out.view(-112828)
        return out


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        n_features = 28 * 28
        n_out = 1
        
        self.main = nn.Sequential(
            nn.Linear(n_features, 512),
            nn.ReLU(inplace=True),
            
            nn.Linear(512256),
            nn.ReLU(inplace=True),
            
            nn.Linear(256128),
            nn.ReLU(inplace=True),
            
            nn.Linear(12864),
            nn.ReLU(inplace=True),
            
            nn.Linear(64, n_out),
            nn.Sigmoid()        
        )
        
    def forward(self, x):
        x = x.view(-128*28)
        out = self.main(x)
        return out

G = Generator().to(device)
D = Discriminator().to(device)

生成性对抗网络训练过程的损失函数是二进制交叉熵损失,由torch.nn.BCELoss实现。

这两种模型都使用torch.optim.Adam作为优化工具,学习率设置为0.002。

# Objective Function
criterion = nn.BCELoss()

# Optimizer
G_optimizer = optim.Adam(G.parameters(), lr = 0.0002)
D_optimizer = optim.Adam(D.parameters(), lr = 0.0002)

# Constants
noise_dim = 100
num_epochs = 50
total_batch = len(mnist_loader)

# Lists
G_losses = []
D_losses = []

# Noise
sample_size = 16
fixed_noise = torch.randn(sample_size, noise_dim).to(device)

# Train
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(mnist_loader):
        
        # Images #
        images = images.reshape(batch_size, -1).float().to(device)
        
        # Labels #
        ones = torch.ones(batch_size, 1).to(device)
        zeros = torch.zeros(batch_size, 1).to(device)
        
        # Noise #
        noise = torch.randn(batch_size, noise_dim).to(device)
        
        # Initialize Optimizers
        D_optimizer.zero_grad()
        G_optimizer.zero_grad()
        
        #######################
        # Train Discriminator #
        #######################
        
        # Forward Images #
        prob_real = D(images)
        D_real_loss = criterion(prob_real, ones)
        
        # Generate Samples #
        fake_images = G(noise)
        prob_fake = D(fake_images)
        
        # Forward Fake Samples and Calculate Discriminator Loss #
        D_fake_loss = criterion(prob_fake, zeros)
        D_loss = (D_real_loss + D_fake_loss).mean()
        
        # Back Propagation and Update
        D_loss.backward()
        D_optimizer.step()
        
        ###################
        # Train Generator #
        ###################
        
        fake_images = G(noise)
        prob_fake = D(fake_images)
        
        # According to the section 3 in paper,
        # early in learning, when G is very poor, D can reject samples from G.
        # In this case, log(1-D(G(z))) saturates. 
        # thus, train G to maximiaze log(D(G(z))) instead of minimizing log(1-D(G(z)))
        G_loss = criterion(prob_fake, ones)
        
        # Back Propagation and Update
        G_loss.backward()
        G_optimizer.step()
        
        # Save Losses for Plotting Later
        G_losses.append(G_loss.item())
        D_losses.append(D_loss.item())
        
        # Print Statistics #
        if (i + 1) % 100 == 0:
            print("Epoch [%d/%d] Iter [%d/%d], D_Loss: %.4f G_Loss: %.4f"
                  %(epoch+1, num_epochs, i+1, total_batch, D_loss.item(), G_loss.item()))
    
    # Generate Samples #
    if epoch % 1 == 0:
        fake_samples = G(fixed_noise)
        imshow(fake_samples, "Generated MNIST Images")
    
# Save Model Weights for Digit Generation
torch.save(G.state_dict(), './data/GAN.pkl')
plt.figure(figsize = (86))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="Generator")
plt.plot(D_losses, label="Discriminator")
plt.xlabel("Iterations")
plt.ylabel("Losses")
plt.legend()
plt.show()
sample_size = 64
noise_dim = 100

noise = torch.randn(sample_size, noise_dim).to(device)

G.load_state_dict(torch.load('GAN.pkl'))
fake_samples = G(fixed_noise)
imshow(fake_samples, "Generated MNIST Images")

GAN生成性对抗网络的运用

  • 将语义图像翻译成城市景观和建筑物的照片。
  • 将卫星照片翻译成地图。
  • 从白天到晚上的照片翻译。
  • 将黑白照片翻译成彩色。

- 论文在这里找到:https://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf

- 上述代码的论文:https://arxiv.org/abs/1511.06434

- 上述代码:https://github.com/yihui-he/GAN-MNIST

往期精彩回顾




本站qq群851320808,加入微信群请扫码:
浏览 39
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报