Reference: GAN.pdf
# Principles
# Q1. Where will D converge, given a fixed G
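A quick sketch of the standard result behind this question (the usual minimax value function is assumed; see GAN.pdf):

$$V(D, G) = \mathbb{E}_{x \sim p_r}[\log D(x)] + \mathbb{E}_{x \sim p_g}[\log (1 - D(x))]$$

Maximizing the integrand $p_r(x)\log D(x) + p_g(x)\log(1 - D(x))$ pointwise over $D(x) \in (0, 1)$ gives

$$D^{*}(x) = \frac{p_r(x)}{p_r(x) + p_g(x)},$$

so for a fixed $G$ the discriminator converges to this density ratio, and to $1/2$ wherever $p_g = p_r$.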
# KL Divergence vs. JS Divergence
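For reference, the two divergences being compared (standard definitions):

$$\mathrm{KL}(P \,\|\, Q) = \mathbb{E}_{x \sim P}\left[\log \frac{P(x)}{Q(x)}\right], \qquad \mathrm{JS}(P \,\|\, Q) = \frac{1}{2}\,\mathrm{KL}\!\left(P \,\Big\|\, \frac{P + Q}{2}\right) + \frac{1}{2}\,\mathrm{KL}\!\left(Q \,\Big\|\, \frac{P + Q}{2}\right)$$

KL is asymmetric and can blow up where $Q$ has no mass but $P$ does; JS is symmetric and bounded above by $\log 2$.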
# Q2. Where will G converge, given the optimal D
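Substituting $D^{*}$ back into the value function (the standard derivation):

$$V(D^{*}, G) = 2\,\mathrm{JS}(p_r \,\|\, p_g) - 2\log 2,$$

so minimizing over $G$ minimizes the JS divergence, and the optimum is reached at $p_g = p_r$, where the value is $-2\log 2$ and $D^{*} \equiv 1/2$.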
# Problems GAN training runs into
Training is unstable: when the real and generated distributions have (almost) no overlap, the generator gets essentially no useful gradient and can go a long time without being updated.
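The one-line reason: when the supports of $p_r$ and $p_g$ are (nearly) disjoint, the JS divergence saturates,

$$\mathrm{supp}(p_r) \cap \mathrm{supp}(p_g) = \varnothing \;\Rightarrow\; \mathrm{JS}(p_r \,\|\, p_g) = \log 2,$$

a constant that does not depend on $G$'s parameters, so a (near-)optimal discriminator passes essentially zero gradient back to the generator.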
# WGAN
WGAN replaces the JS divergence with the Wasserstein distance; the gradient-penalty variant (WGAN-GP) adds a penalty term to the critic loss, which alleviates the failure-to-train problem above.
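Concretely, the gradient-penalty variant (WGAN-GP), which is what the code below implements, trains a critic with the loss

$$L_D = \mathbb{E}_{\tilde{x} \sim p_g}[D(\tilde{x})] - \mathbb{E}_{x \sim p_r}[D(x)] + \lambda\, \mathbb{E}_{\hat{x}}\big[(\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2\big],$$

where $\hat{x}$ is sampled uniformly on line segments between real and generated points. The penalty softly enforces the 1-Lipschitz constraint required by the Wasserstein distance; `gradient_penalty()` in the listings computes exactly this term, with $\lambda = 0.3$.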
# Hands-on code
# GAN
```python
import torch
from torch import nn, optim, autograd
import numpy as np
import visdom
from torch.nn import functional as F
from matplotlib import pyplot as plt
import random

h_dim = 400
batchsz = 512
viz = visdom.Visdom()

class Generator(nn.Module):

    def __init__(self):
        super(Generator, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, 2),
        )

    def forward(self, z):
        output = self.net(z)
        return output

class Discriminator(nn.Module):

    def __init__(self):
        super(Discriminator, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        output = self.net(x)
        return output.view(-1)

def data_generator():
    """
    Yields batches of 2-D points sampled from an 8-Gaussian mixture
    whose means lie on a circle of radius `scale`.
    """
    scale = 2.
    centers = [
        (1, 0),
        (-1, 0),
        (0, 1),
        (0, -1),
        (1. / np.sqrt(2), 1. / np.sqrt(2)),
        (1. / np.sqrt(2), -1. / np.sqrt(2)),
        (-1. / np.sqrt(2), 1. / np.sqrt(2)),
        (-1. / np.sqrt(2), -1. / np.sqrt(2))
    ]
    centers = [(scale * x, scale * y) for x, y in centers]

    while True:
        dataset = []
        for i in range(batchsz):
            point = np.random.randn(2) * .02
            center = random.choice(centers)
            point[0] += center[0]
            point[1] += center[1]
            dataset.append(point)
        dataset = np.array(dataset, dtype='float32')
        dataset /= 1.414  # stdev
        yield dataset
    # for i in range(100000//25):
    #     for x in range(-2, 3):
    #         for y in range(-2, 3):
    #             point = np.random.randn(2).astype(np.float32) * 0.05
    #             point[0] += 2 * x
    #             point[1] += 2 * y
    #             dataset.append(point)
    #
    # dataset = np.array(dataset)
    # print('dataset:', dataset.shape)
    # viz.scatter(dataset, win='dataset', opts=dict(title='dataset', webgl=True))
    #
    # while True:
    #     np.random.shuffle(dataset)
    #
    #     for i in range(len(dataset)//batchsz):
    #         yield dataset[i*batchsz : (i+1)*batchsz]

def generate_image(D, G, xr, epoch):
    """
    Plots the real samples, the generator's samples, and the
    discriminator/critic's level sets.
    """
    N_POINTS = 128
    RANGE = 3
    plt.clf()

    points = np.zeros((N_POINTS, N_POINTS, 2), dtype='float32')
    points[:, :, 0] = np.linspace(-RANGE, RANGE, N_POINTS)[:, None]
    points[:, :, 1] = np.linspace(-RANGE, RANGE, N_POINTS)[None, :]
    points = points.reshape((-1, 2))  # (16384, 2)
    # print('p:', points.shape)

    # draw contour of D over the grid
    with torch.no_grad():
        points = torch.Tensor(points).cuda()  # [16384, 2]
        disc_map = D(points).cpu().numpy()    # [16384]
    x = y = np.linspace(-RANGE, RANGE, N_POINTS)
    cs = plt.contour(x, y, disc_map.reshape((len(x), len(y))).transpose())
    plt.clabel(cs, inline=1, fontsize=10)
    # plt.colorbar()

    # draw samples
    with torch.no_grad():
        z = torch.randn(batchsz, 2).cuda()    # [b, 2]
        samples = G(z).cpu().numpy()          # [b, 2]
    plt.scatter(xr[:, 0], xr[:, 1], c='orange', marker='.')
    plt.scatter(samples[:, 0], samples[:, 1], c='green', marker='+')

    viz.matplot(plt, win='contour', opts=dict(title='p(x):%d' % epoch))

def weights_init(m):
    if isinstance(m, nn.Linear):
        # m.weight.data.normal_(0.0, 0.02)
        nn.init.kaiming_normal_(m.weight)
        m.bias.data.fill_(0)

def gradient_penalty(D, xr, xf):
    """
    :param D: discriminator/critic network
    :param xr: [b, 2] real samples
    :param xf: [b, 2] generated samples
    :return: scalar gradient-penalty term
    """
    LAMBDA = 0.3

    # the penalty only constrains the Discriminator, so block gradients into G
    xf = xf.detach()
    xr = xr.detach()

    # [b, 1] => [b, 2]: one interpolation coefficient per sample
    alpha = torch.rand(batchsz, 1).cuda()
    alpha = alpha.expand_as(xr)

    interpolates = alpha * xr + ((1 - alpha) * xf)
    interpolates.requires_grad_()

    disc_interpolates = D(interpolates)

    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones_like(disc_interpolates),
                              create_graph=True, retain_graph=True, only_inputs=True)[0]

    gp = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA

    return gp

def main():
    torch.manual_seed(23)
    np.random.seed(23)

    G = Generator().cuda()
    D = Discriminator().cuda()
    G.apply(weights_init)
    D.apply(weights_init)

    optim_G = optim.Adam(G.parameters(), lr=1e-3, betas=(0.5, 0.9))
    optim_D = optim.Adam(D.parameters(), lr=1e-3, betas=(0.5, 0.9))

    data_iter = data_generator()
    print('batch:', next(data_iter).shape)

    viz.line([[0, 0]], [0], win='loss', opts=dict(title='loss',
                                                  legend=['D', 'G']))

    for epoch in range(50000):

        # 1. train the discriminator/critic for k steps
        for _ in range(5):
            x = next(data_iter)
            xr = torch.from_numpy(x).cuda()
            # [b]
            predr = D(xr)
            # maximize E[D(xr)]  <=>  minimize -E[D(xr)]
            lossr = -predr.mean()

            # [b, 2]
            z = torch.randn(batchsz, 2).cuda()
            # stop gradient on G
            # [b, 2]
            xf = G(z).detach()
            # [b]
            predf = D(xf)
            # minimize E[D(xf)]
            lossf = predf.mean()

            # gradient penalty
            gp = gradient_penalty(D, xr, xf)

            loss_D = lossr + lossf + gp
            optim_D.zero_grad()
            loss_D.backward()
            # for p in D.parameters():
            #     print(p.grad.norm())
            optim_D.step()

        # 2. train the Generator
        z = torch.randn(batchsz, 2).cuda()
        xf = G(z)
        predf = D(xf)
        # maximize E[D(xf)]  <=>  minimize -E[D(xf)]
        loss_G = -predf.mean()
        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()

        if epoch % 100 == 0:
            viz.line([[loss_D.item(), loss_G.item()]], [epoch], win='loss', update='append')
            # move xr to CPU so matplotlib can consume it
            generate_image(D, G, xr.cpu(), epoch)
            print(loss_D.item(), loss_G.item())


if __name__ == '__main__':
    main()
```
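Note that the listing above already trains with the Wasserstein-style critic loss plus the gradient penalty (i.e. it is effectively WGAN-GP). For comparison, a minimal sketch of what the loss lines would look like with the original GAN's binary cross-entropy objective; this is not part of the original notes and reuses `D`, `G`, `batchsz`, and a real batch `xr` from the listing (the `Discriminator` ends in a `Sigmoid`, so its output is a probability):

```python
# Sketch only: vanilla-GAN losses with binary cross-entropy.
z = torch.randn(batchsz, 2).cuda()
pred_real = D(xr)                      # [b], should be pushed toward 1
pred_fake = D(G(z).detach())           # [b], should be pushed toward 0

# discriminator step: maximize log D(x) + log(1 - D(G(z)))
loss_D = F.binary_cross_entropy(pred_real, torch.ones_like(pred_real)) + \
         F.binary_cross_entropy(pred_fake, torch.zeros_like(pred_fake))

# generator step (non-saturating form): maximize log D(G(z))
pred_fake_for_G = D(G(z))
loss_G = F.binary_cross_entropy(pred_fake_for_G, torch.ones_like(pred_fake_for_G))
```

Both listings assume a CUDA GPU and a running Visdom server (start it with `python -m visdom.server`) before launching the script.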
# WGAN
```python
import torch
from torch import nn, optim, autograd
import numpy as np
import visdom
from torch.nn import functional as F
from matplotlib import pyplot as plt
import random

h_dim = 400
batchsz = 512
viz = visdom.Visdom()

class Generator(nn.Module):

    def __init__(self):
        super(Generator, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, 2),
        )

    def forward(self, z):
        output = self.net(z)
        return output

class Discriminator(nn.Module):

    def __init__(self):
        super(Discriminator, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        output = self.net(x)
        return output.view(-1)

def data_generator():
    """
    Yields batches of 2-D points sampled from an 8-Gaussian mixture
    whose means lie on a circle of radius `scale`.
    """
    scale = 2.
    centers = [
        (1, 0),
        (-1, 0),
        (0, 1),
        (0, -1),
        (1. / np.sqrt(2), 1. / np.sqrt(2)),
        (1. / np.sqrt(2), -1. / np.sqrt(2)),
        (-1. / np.sqrt(2), 1. / np.sqrt(2)),
        (-1. / np.sqrt(2), -1. / np.sqrt(2))
    ]
    centers = [(scale * x, scale * y) for x, y in centers]

    while True:
        dataset = []
        for i in range(batchsz):
            point = np.random.randn(2) * .02
            center = random.choice(centers)
            point[0] += center[0]
            point[1] += center[1]
            dataset.append(point)
        dataset = np.array(dataset, dtype='float32')
        dataset /= 1.414  # stdev
        yield dataset
    # for i in range(100000//25):
    #     for x in range(-2, 3):
    #         for y in range(-2, 3):
    #             point = np.random.randn(2).astype(np.float32) * 0.05
    #             point[0] += 2 * x
    #             point[1] += 2 * y
    #             dataset.append(point)
    #
    # dataset = np.array(dataset)
    # print('dataset:', dataset.shape)
    # viz.scatter(dataset, win='dataset', opts=dict(title='dataset', webgl=True))
    #
    # while True:
    #     np.random.shuffle(dataset)
    #
    #     for i in range(len(dataset)//batchsz):
    #         yield dataset[i*batchsz : (i+1)*batchsz]

def generate_image(D, G, xr, epoch):
    """
    Plots the real samples, the generator's samples, and the
    discriminator/critic's level sets.
    """
    N_POINTS = 128
    RANGE = 3
    plt.clf()

    points = np.zeros((N_POINTS, N_POINTS, 2), dtype='float32')
    points[:, :, 0] = np.linspace(-RANGE, RANGE, N_POINTS)[:, None]
    points[:, :, 1] = np.linspace(-RANGE, RANGE, N_POINTS)[None, :]
    points = points.reshape((-1, 2))  # (16384, 2)
    # print('p:', points.shape)

    # draw contour of D over the grid
    with torch.no_grad():
        points = torch.Tensor(points).cuda()  # [16384, 2]
        disc_map = D(points).cpu().numpy()    # [16384]
    x = y = np.linspace(-RANGE, RANGE, N_POINTS)
    cs = plt.contour(x, y, disc_map.reshape((len(x), len(y))).transpose())
    plt.clabel(cs, inline=1, fontsize=10)
    # plt.colorbar()

    # draw samples
    with torch.no_grad():
        z = torch.randn(batchsz, 2).cuda()    # [b, 2]
        samples = G(z).cpu().numpy()          # [b, 2]
    plt.scatter(xr[:, 0], xr[:, 1], c='orange', marker='.')
    plt.scatter(samples[:, 0], samples[:, 1], c='green', marker='+')

    viz.matplot(plt, win='contour', opts=dict(title='p(x):%d' % epoch))

def weights_init(m):
    if isinstance(m, nn.Linear):
        # m.weight.data.normal_(0.0, 0.02)
        nn.init.kaiming_normal_(m.weight)
        m.bias.data.fill_(0)

def gradient_penalty(D, xr, xf):
    """
    :param D: discriminator/critic network
    :param xr: [b, 2] real samples
    :param xf: [b, 2] generated samples
    :return: scalar gradient-penalty term
    """
    LAMBDA = 0.3

    # the penalty only constrains the Discriminator, so block gradients into G
    xf = xf.detach()
    xr = xr.detach()

    # [b, 1] => [b, 2]: one interpolation coefficient per sample
    alpha = torch.rand(batchsz, 1).cuda()
    alpha = alpha.expand_as(xr)

    interpolates = alpha * xr + ((1 - alpha) * xf)
    interpolates.requires_grad_()

    disc_interpolates = D(interpolates)

    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones_like(disc_interpolates),
                              create_graph=True, retain_graph=True, only_inputs=True)[0]

    gp = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA

    return gp

def main():
    torch.manual_seed(23)
    np.random.seed(23)

    G = Generator().cuda()
    D = Discriminator().cuda()
    G.apply(weights_init)
    D.apply(weights_init)

    optim_G = optim.Adam(G.parameters(), lr=1e-3, betas=(0.5, 0.9))
    optim_D = optim.Adam(D.parameters(), lr=1e-3, betas=(0.5, 0.9))

    data_iter = data_generator()
    print('batch:', next(data_iter).shape)

    viz.line([[0, 0]], [0], win='loss', opts=dict(title='loss',
                                                  legend=['D', 'G']))

    for epoch in range(50000):

        # 1. train the critic for k steps
        for _ in range(5):
            x = next(data_iter)
            xr = torch.from_numpy(x).cuda()
            # [b]
            predr = D(xr)
            # maximize E[D(xr)]  <=>  minimize -E[D(xr)]
            lossr = -predr.mean()

            # [b, 2]
            z = torch.randn(batchsz, 2).cuda()
            # stop gradient on G
            # [b, 2]
            xf = G(z).detach()
            # [b]
            predf = D(xf)
            # minimize E[D(xf)]
            lossf = predf.mean()

            # gradient penalty
            gp = gradient_penalty(D, xr, xf)

            loss_D = lossr + lossf + gp
            optim_D.zero_grad()
            loss_D.backward()
            # for p in D.parameters():
            #     print(p.grad.norm())
            optim_D.step()

        # 2. train the Generator
        z = torch.randn(batchsz, 2).cuda()
        xf = G(z)
        predf = D(xf)
        # maximize E[D(xf)]  <=>  minimize -E[D(xf)]
        loss_G = -predf.mean()
        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()

        if epoch % 100 == 0:
            viz.line([[loss_D.item(), loss_G.item()]], [epoch], win='loss', update='append')
            # move xr to CPU so matplotlib can consume it
            generate_image(D, G, xr.cpu(), epoch)
            print(loss_D.item(), loss_G.item())


if __name__ == '__main__':
    main()
```