生成式对抗网络的原理是什么?

生成式对抗网络(Generative Adversarial Network, GAN)是一种对抗性的无监督学习方法。

它包含两个神经网络:生成器(Generator)和判别器(Discriminator)。

  • 生成器 G 从潜在空间 z 中随机采样,并生成假样本 G(z)。
  • 判别器 D 尝试区分真实样本 x 和生成器 G 生成的假样本 G(z)。
  • G 和 D 相互对抗并不断提高学习,这个过程可以生成逼真的样本。

其工作原理为:

  1. 初始化 G 和 D。
  2. 从真实数据中采样输入 x,从潜在空间 z 中随机采样 z。
  3. G 生成假样本 G(z),D评估真实样本 x 和假样本 G(z),输出真假概率。
  4. 更新 D,最大化 log(D(x)) + log(1 – D(G(z)))。
  5. 固定 D,更新 G,最小化 log(1 – D(G(z)))。
  6. 重复 step 3-5, until G 生成的样本足够逼真。
  7. 获得最终的 G,它可以生成逼真的新数据。

GAN 的 Loss 函数为:

minG maxD L(D, G) = Ex~p(x)[logD(x)] + Ez~p(z)[log(1 - D(G(z)))]

代码实现示例:

python
# 生成器
G = ...

# 判别器 
D = ...

# 优化器
opt_D = torch.optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
opt_G = torch.optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))

for epoch in range(100):
    # 训练判别器
    for i in range(5):
        # 从真实数据和潜在向量中采样
        x = ...
        z = ...

        # 生成假样本
        G_z = G(z)

        # 真实数据标签为1,假样本标签为0
        real_labels = torch.ones((x.size(0), 1)) 
        fake_labels = torch.zeros((G_z.size(0), 1))  

        # 计算损失并更新
        real_loss = F.binary_cross_entropy(D(x), real_labels) 
        fake_loss = F.binary_cross_entropy(D(G_z), fake_labels)
        loss_D = real_loss + fake_loss
        opt_D.zero_grad()
        loss_D.backward()
        opt_D.step()

    # 训练生成器
    z = ...
    fake_labels = torch.ones((z.size(0), 1))
    loss_G = F.binary_cross_entropy(D(G(z)), fake_labels) 
    opt_G.zero_grad()
    loss_G.backward()
    opt_G.step()

所以,GAN 通过对抗训练生成器和判别器,可以实现无监督的样本生成。理解其工作原理,有助于我们设计更加强大稳定的GAN,并应用于更广泛的任务,如图像生成、风格迁移等领域。