GAN model学习
1.GAN的目标
GAN是生成模型的一种
区分生成模型与判别模型:
生成模型通常是无监督学习(事实上也有监督学习的模型),即数据集是没有标签的,模型从数据集中学习,可以生成数据集中没有的数据;
判别模型很多是有监督学习,即输入是带有标签的,模型通过从数据集中学习,可以对新数据进行判别。
假定输入x,输出y 判别模型相当于是在估计条件概率分布,生成模型则是在估计联合概率分布
举个例子,输入手写数字图片集,判别模型学习的目标是对于给出的图片,可以判断是哪个数字;而生成模型的目标是生成新的数字图片
GAN的目标
2.GAN(生成对抗网络)的思想:
a. 由两个模型组成
生成器G(Generator)
鉴别器D(Discriminator)
b.对抗思想
生成器努力生成能够欺骗鉴别器的样本,而鉴别器努力识别生成的样本是真是假(即是来自数据集还是有生成器生成的),
我们希望达到的目标是:鉴别器无法区分生成器生成的样本到底是真是假
c.生成器和鉴别器之间的关系是一种博弈
i.零和博弈
ii.纳什均衡
3.GAN的训练过程:
以pytorch代码为例:(判别器选择了简单的MLP)
首先是一个很简单的判别器,通过简单的MLP输出(0,1)之间的一个概率作为outputs,targets即为真实标签,交叉熵损失函数一般广泛应用于分类问题中
# 判别器
class Discriminator(nn.Module):
def __init__(self):
# 调用父类的构造函数,初始化父类
super().__init__()
# 定义神经网络
self.model = nn.Sequential(
nn.Linear(4, 3),
nn.Sigmoid(),
nn.Linear(3, 1),
nn.Sigmoid()
)
# 创建损失函数
self.loss_function = nn.MSELoss()
# 创建优化器,随机梯度下降
self.optimiser = torch.optim.SGD(self.parameters(), lr=0.01)
def forward(self, inputs):
return self.model(inputs)
def train(self, inputs, targets):
# 计算网络的输出值
outputs = self.forward(inputs)
loss = self.loss_function(outputs, targets)
# 反向传播
self.optimiser.zero_grad()
loss.backward()
self.optimiser.step()
然后是生成器
GAN model学习
https://fightforql.github.io/2025/03/13/GAN/