2025年【笔记】ACGAN:采用辅助分类器使GAN获取图像分类功能

【笔记】ACGAN:采用辅助分类器使GAN获取图像分类功能class Discriminato nn Module def init self num classes initialize param image size tuple 3 h w super init self num classes

大家好,我是讯享网,很高兴认识大家。


讯享网

class Discriminator(nn.Module): def __init__(self, num_classes): """ initialize :param image_size: tuple (3, h, w) """ super().__init__() self.num_classes = num_classes net = [] # 1:预先定义 channels_in = [3+self.num_classes, 64, 128, 256] channels_out = [64, 128, 256, 512] padding = [1, 1, 1, 0] for i in range(len(channels_in)): net.append(nn.Conv2d(in_channels=channels_in[i], out_channels=channels_out[i], kernel_size=4, stride=2, padding=padding[i], bias=False)) if i == 0: net.append(nn.LeakyReLU(0.2)) else: net.append(nn.BatchNorm2d(num_features=channels_out[i])) net.append(nn.LeakyReLU(0.2)) net.append(nn.Dropout(0.5)) self.classify = nn.Linear(in_features=3*3*512, out_features=num_classes) self.softmax = nn.Softmax(dim=1) self.disciminate = nn.Conv2d(in_channels=512, out_channels=1, kernel_size=3, stride=1, padding=0) self.sigmoid = nn.Sigmoid() self.discriminator = nn.Sequential(*net) def forward(self, x, label): label = label.unsqueeze(2).unsqueeze(3) label = label.repeat(1, 1, x.size(2), x.size(3)) data = torch.cat(tensors=(x, label), dim=1) out = self.discriminator(data) out_ = out.view(x.size(0), -1) classsify = self.softmax(self.classify(out_)) real_or_fake = self.sigmoid(self.disciminate(out)) return real_or_fake.view(x.size(0), -1), classsify

讯享网

讯享网d_out_real_dis, d_out_real_cls = discriminator(image, onehot_label) real_loss_dis = bce_loss(d_out_real_dis, real_label) real_loss_cls = nll_loss(d_out_real_cls, label)

小讯
上一篇 2025-02-27 18:53
下一篇 2025-02-25 18:14

相关推荐

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容,请联系我们,一经查实,本站将立刻删除。
如需转载请保留出处:https://51itzy.com/kjqy/60838.html