轻松学Pytorch – 构建生成对抗网络

共 5229字,需浏览 11分钟

 ·

2022-05-24 10:10

点击上方小白学视觉”,选择加"星标"或“置顶

重磅干货,第一时间送达

又好久没有继续写了,这个是我写的第21篇文章,我还在继续坚持写下去,虽然经常各种拖延症,但是我还记得,一直没有敢忘记!今天给大家分享一下Pytorch生成对抗网络代码实现。

 

01.什么是生成对抗网络


Ian J. Goodfellow在2014年提出生成对抗网络,从此打开了深度学习中另外一个重要分支,让生成对抗网络(GAN)成为与卷积神经网络(CNN)、循环神经网络(RNN/LSTM)可以并驾齐驱的分支领域。今天GAN仍然是计算机视觉领域研究热点之一,每年还有大量相关的论文产生,GAN已经被用在视觉任务的很多方面,主要包括:

  • 图像合成与数据增广

  • 图像翻译与变换

  • 缺陷检测

  • 图像去噪与重建

  • 图像分割

但是GAN最基本的核心思想还是2014年Ian J. Goodfellow在论文中提到的两个基本的模型分别是:生成器与判别器

生成器(G):

根据输入噪声Z生成输出样本G(z)目标:通过生成样本与目标样本分布一致,成功欺骗鉴别器

判别器(D):

根据输入样本数据来分辨真实样本概率从数据中学习样本数据的差异性

从a到d,可以看到输入噪声的生成分布越来越接近真实分布X,最终达到一种平衡状态,这种稳定的平衡状态叫纳什均衡,还有一部电影跟这个有关系叫《美丽心灵》。

 

02.GAN代码实现


下面的代码实现了基于Mnist数据集实现判别器与生成器,最终通过生成器可以自动生成手写数字识别的图像,输入的z=100是随机噪声,输出的是784个数据表示28x28大小的手写数字样本,损失主要来自两个部分,生成器生成损失,判别器分别判别真实与虚构样本概率,基于反向传播训练两个网络,设置epoch=100,得到最终的生成器生成结果如下:


生成器与判别器代码实现如下


判别器与生成器代码:(后面文字忽略)2004论文中提出,其主要思想可以通过下面一张图像解释:

 1transform = tv.transforms.Compose([tv.transforms.ToTensor(),
2                                   tv.transforms.Normalize((0.5,), (0.5,))])
3train_ts = tv.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
4test_ts = tv.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
5train_dl = DataLoader(train_ts, batch_size=128, shuffle=True, drop_last=False)
6test_dl = DataLoader(test_ts, batch_size=128, shuffle=True, drop_last=False)
7
8
9class Generator(t.nn.Module):
10    def __init__(self, g_input_dim, g_output_dim):
11        super(Generator, self).__init__()
12        self.fc1 = t.nn.Linear(g_input_dim, 256)
13        self.fc2 = t.nn.Linear(self.fc1.out_features, self.fc1.out_features * 2)
14        self.fc3 = t.nn.Linear(self.fc2.out_features, self.fc2.out_features * 2)
15        self.fc4 = t.nn.Linear(self.fc3.out_features, g_output_dim)
16
17    # forward method
18    def forward(self, x):
19        x = F.leaky_relu(self.fc1(x), 0.2)
20        x = F.leaky_relu(self.fc2(x), 0.2)
21        x = F.leaky_relu(self.fc3(x), 0.2)
22        return t.tanh(self.fc4(x))
23
24
25class Discriminator(t.nn.Module):
26    def __init__(self, d_input_dim):
27        super(Discriminator, self).__init__()
28        self.fc1 = t.nn.Linear(d_input_dim, 1024)
29        self.fc2 = t.nn.Linear(self.fc1.out_features, self.fc1.out_features // 2)
30        self.fc3 = t.nn.Linear(self.fc2.out_features, self.fc2.out_features // 2)
31        self.fc4 = t.nn.Linear(self.fc3.out_features, 1)
32
33    # forward method
34    def forward(self, x):
35        x = F.leaky_relu(self.fc1(x), 0.2)
36        x = F.dropout(x, 0.3)
37        x = F.leaky_relu(self.fc2(x), 0.2)
38        x = F.dropout(x, 0.3)
39        x = F.leaky_relu(self.fc3(x), 0.2)
40        x = F.dropout(x, 0.3)
41        return t.sigmoid(self.fc4(x))


损失与训练代码如下


分别定义生成网络训练与鉴别网络的训练方法,然后开始训练即可,代码实现如下:

 1# 生成者与判别者
2bs = 128
3z_dim = 100
4mnist_dim = 784
5# loss
6criterion = t.nn.BCELoss()
7
8# optimizer
9device = "cuda"
10gnet = Generator(g_input_dim = z_dim, g_output_dim = mnist_dim).to(device)
11dnet = Discriminator(mnist_dim).to(device)
12lr = 0.0002
13G_optimizer = t.optim.Adam(gnet.parameters(), lr=lr)
14D_optimizer = t.optim.Adam(dnet.parameters(), lr=lr)
15
16
17def D_train(x):
18    # =======================Train the discriminator=======================#
19    dnet.zero_grad()
20
21    # train discriminator on real
22    x_real, y_real = x.view(-1, mnist_dim), t.ones(bs, 1)
23    x_real, y_real = Variable(x_real.to(device)), Variable(y_real.to(device))
24
25    D_output = dnet(x_real)
26    D_real_loss = criterion(D_output, y_real)
27
28    # train discriminator on facke
29    z = Variable(t.randn(bs, z_dim).to(device))
30    x_fake, y_fake = gnet(z), Variable(t.zeros(bs, 1).to(device))
31
32    D_output = dnet(x_fake)
33    D_fake_loss = criterion(D_output, y_fake)
34
35    # gradient backprop & optimize ONLY D's parameters
36    D_loss = D_real_loss + D_fake_loss
37    D_loss.backward()
38    D_optimizer.step()
39
40    return D_loss.data.item()
41
42
43def G_train(x):
44    # =======================Train the generator=======================#
45    gnet.zero_grad()
46
47    z = Variable(t.randn(bs, z_dim).to(device))
48    y = Variable(t.ones(bs, 1).to(device))
49
50    G_output = gnet(z)
51    D_output = dnet(G_output)
52    G_loss = criterion(D_output, y)
53
54    # gradient backprop & optimize ONLY G's parameters
55    G_loss.backward()
56    G_optimizer.step()
57
58    return G_loss.data.item()
59
60
61n_epoch = 100
62for epoch in range(1, n_epoch+1):
63    D_losses, G_losses = [], []
64    for batch_idx, (x, _) in enumerate(train_dl):
65        bs_, _,_,_ = x.size()
66        bs = bs_
67        D_losses.append(D_train(x))
68        G_losses.append(G_train(x))
69
70    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
71            (epoch), n_epoch, t.mean(t.FloatTensor(D_losses)), t.mean(t.FloatTensor(G_losses))))



下载1:OpenCV-Contrib扩展模块中文版教程
在「小白学视觉」公众号后台回复:扩展模块中文教程即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。

下载2:Python视觉实战项目52讲
小白学视觉公众号后台回复:Python视觉实战项目即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。

下载3:OpenCV实战项目20讲
小白学视觉公众号后台回复:OpenCV实战项目20讲即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。

交流群


欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~


浏览 34
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报