GAN model学习

1.GAN的目标

GAN是生成模型的一种

区分生成模型与判别模型:

生成模型通常是无监督学习(事实上也有监督学习的模型),即数据集是没有标签的,模型从数据集中学习,可以生成数据集中没有的数据;
判别模型很多是有监督学习,即输入是带有标签的,模型通过从数据集中学习,可以对新数据进行判别。

假定输入x,输出y 判别模型相当于是在估计条件概率分布P(yx)P(y|x),生成模型则是在估计联合概率分布P(x,y)P(x,y)

举个例子,输入手写数字图片集,判别模型学习的目标是对于给出的图片,可以判断是哪个数字;而生成模型的目标是生成新的数字图片

GAN的目标

2.GAN(生成对抗网络)的思想:

a. 由两个模型组成

生成器G(Generator)
鉴别器D(Discriminator)

b.对抗思想

生成器努力生成能够欺骗鉴别器的样本,而鉴别器努力识别生成的样本是真是假(即是来自数据集还是有生成器生成的),
我们希望达到的目标是:鉴别器无法区分生成器生成的样本到底是真是假

c.生成器和鉴别器之间的关系是一种博弈

i.零和博弈

ii.纳什均衡

3.GAN的训练过程:

以pytorch代码为例:(判别器选择了简单的MLP)

首先是一个很简单的判别器,通过简单的MLP输出(0,1)之间的一个概率作为outputs,targets即为真实标签,交叉熵损失函数一般广泛应用于分类问题中


# 判别器
class Discriminator(nn.Module):
    def __init__(self):
        # 调用父类的构造函数,初始化父类
        super().__init__()
        # 定义神经网络
        self.model = nn.Sequential(
        	nn.Linear(4, 3),
        	nn.Sigmoid(),
        	nn.Linear(3, 1), 
        	nn.Sigmoid()
        )
        # 创建损失函数
        self.loss_function = nn.MSELoss()
        # 创建优化器,随机梯度下降
        self.optimiser = torch.optim.SGD(self.parameters(), lr=0.01)
        
    def forward(self, inputs):
        return self.model(inputs)
        
    def train(self, inputs, targets):
        # 计算网络的输出值
        outputs = self.forward(inputs)
        loss = self.loss_function(outputs, targets)
        # 反向传播
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

然后是生成器


GAN model学习
https://fightforql.github.io/2025/03/13/GAN/
Author
lsy
Posted on
March 13, 2025
Licensed under