pix2pix损失函数理解(精)

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

 下面分为生成器和鉴别器的损失函数分别进行说明:

1.生成器(generator)的损失函数:生成器的损失函数由对抗损失和像素损失构成。

    def backward_G(self):
        """Calculate GAN and L1 loss for the generator"""
        # 1.对抗损失,G(A) should fake the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)
        # 2.像素损失,G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1
        # combine loss and calculate gradients
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

 (2)判别器的损失函数: pix2pix中判别器的损失与cGAN相同。

    def backward_D(self):
        """Calculate GAN loss for the discriminator"""
        # Fake; 后半部分,stop backprop to the generator by detaching fake_B
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)  # we use conditional GANs; we need to feed both input and output to the discriminator
        pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)
        # Real:前半部分
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        pred_real = self.netD(real_AB)
        self.loss_D_real = self.criterionGAN(pred_real, True)
        # combine loss and calculate gradients
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

 Pix2pix-两个领域匹配图像的转换 - 简书

版权声明:本文为CSDN博主「马鹏森」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/weixin_43135178/article/details/120548580

马鹏森

我还没有学会写个人说明!

暂无评论

发表评论

相关推荐