生成式对抗网络(Generative Adversarial Network, GAN)是一种对抗性的无监督学习方法。
它包含两个神经网络:生成器(Generator)和判别器(Discriminator)。
- 生成器 G 从潜在空间 z 中随机采样,并生成假样本 G(z)。
- 判别器 D 尝试区分真实样本 x 和生成器 G 生成的假样本 G(z)。
- G 和 D 相互对抗并不断提高学习,这个过程可以生成逼真的样本。
其工作原理为:
- 初始化 G 和 D。
- 从真实数据中采样输入 x,从潜在空间 z 中随机采样 z。
- G 生成假样本 G(z),D评估真实样本 x 和假样本 G(z),输出真假概率。
- 更新 D,最大化 log(D(x)) + log(1 – D(G(z)))。
- 固定 D,更新 G,最小化 log(1 – D(G(z)))。
- 重复 step 3-5, until G 生成的样本足够逼真。
- 获得最终的 G,它可以生成逼真的新数据。
GAN 的 Loss 函数为:
minG maxD L(D, G) = Ex~p(x)[logD(x)] + Ez~p(z)[log(1 - D(G(z)))]
代码实现示例:
python
# 生成器
G = ...
# 判别器
D = ...
# 优化器
opt_D = torch.optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
opt_G = torch.optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
for epoch in range(100):
# 训练判别器
for i in range(5):
# 从真实数据和潜在向量中采样
x = ...
z = ...
# 生成假样本
G_z = G(z)
# 真实数据标签为1,假样本标签为0
real_labels = torch.ones((x.size(0), 1))
fake_labels = torch.zeros((G_z.size(0), 1))
# 计算损失并更新
real_loss = F.binary_cross_entropy(D(x), real_labels)
fake_loss = F.binary_cross_entropy(D(G_z), fake_labels)
loss_D = real_loss + fake_loss
opt_D.zero_grad()
loss_D.backward()
opt_D.step()
# 训练生成器
z = ...
fake_labels = torch.ones((z.size(0), 1))
loss_G = F.binary_cross_entropy(D(G(z)), fake_labels)
opt_G.zero_grad()
loss_G.backward()
opt_G.step()
所以,GAN 通过对抗训练生成器和判别器,可以实现无监督的样本生成。理解其工作原理,有助于我们设计更加强大稳定的GAN,并应用于更广泛的任务,如图像生成、风格迁移等领域。