adversarial.py
4.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import utility
from types import SimpleNamespace
from model import common
from loss import discriminator
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
class Adversarial(nn.Module):
def __init__(self, args, gan_type):
super(Adversarial, self).__init__()
self.gan_type = gan_type
self.gan_k = args.gan_k
self.dis = discriminator.Discriminator(args)
if gan_type == 'WGAN_GP':
# see https://arxiv.org/pdf/1704.00028.pdf pp.4
optim_dict = {
'optimizer': 'ADAM',
'betas': (0, 0.9),
'epsilon': 1e-8,
'lr': 1e-5,
'weight_decay': args.weight_decay,
'decay': args.decay,
'gamma': args.gamma
}
optim_args = SimpleNamespace(**optim_dict)
else:
optim_args = args
self.optimizer = utility.make_optimizer(optim_args, self.dis)
def forward(self, fake, real):
# updating discriminator...
self.loss = 0
fake_detach = fake.detach() # do not backpropagate through G
for _ in range(self.gan_k):
self.optimizer.zero_grad()
# d: B x 1 tensor
d_fake = self.dis(fake_detach)
d_real = self.dis(real)
retain_graph = False
if self.gan_type == 'GAN':
loss_d = self.bce(d_real, d_fake)
elif self.gan_type.find('WGAN') >= 0:
loss_d = (d_fake - d_real).mean()
if self.gan_type.find('GP') >= 0:
epsilon = torch.rand_like(fake).view(-1, 1, 1, 1)
hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
hat.requires_grad = True
d_hat = self.dis(hat)
gradients = torch.autograd.grad(
outputs=d_hat.sum(), inputs=hat,
retain_graph=True, create_graph=True, only_inputs=True
)[0]
gradients = gradients.view(gradients.size(0), -1)
gradient_norm = gradients.norm(2, dim=1)
gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()
loss_d += gradient_penalty
# from ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks
elif self.gan_type == 'RGAN':
better_real = d_real - d_fake.mean(dim=0, keepdim=True)
better_fake = d_fake - d_real.mean(dim=0, keepdim=True)
loss_d = self.bce(better_real, better_fake)
retain_graph = True
# Discriminator update
self.loss += loss_d.item()
loss_d.backward(retain_graph=retain_graph)
self.optimizer.step()
if self.gan_type == 'WGAN':
for p in self.dis.parameters():
p.data.clamp_(-1, 1)
self.loss /= self.gan_k
# updating generator...
d_fake_bp = self.dis(fake) # for backpropagation, use fake as it is
if self.gan_type == 'GAN':
label_real = torch.ones_like(d_fake_bp)
loss_g = F.binary_cross_entropy_with_logits(d_fake_bp, label_real)
elif self.gan_type.find('WGAN') >= 0:
loss_g = -d_fake_bp.mean()
elif self.gan_type == 'RGAN':
better_real = d_real - d_fake_bp.mean(dim=0, keepdim=True)
better_fake = d_fake_bp - d_real.mean(dim=0, keepdim=True)
loss_g = self.bce(better_fake, better_real)
# Generator loss
return loss_g
def state_dict(self, *args, **kwargs):
state_discriminator = self.dis.state_dict(*args, **kwargs)
state_optimizer = self.optimizer.state_dict()
return dict(**state_discriminator, **state_optimizer)
def bce(self, real, fake):
label_real = torch.ones_like(real)
label_fake = torch.zeros_like(fake)
bce_real = F.binary_cross_entropy_with_logits(real, label_real)
bce_fake = F.binary_cross_entropy_with_logits(fake, label_fake)
bce_loss = bce_real + bce_fake
return bce_loss
# Some references
# https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py
# OR
# https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py