1. 对于GAN摸索了一段时间,有一点心得:使用普通的网络(例如全连接网络)作为生成器和判别器时,需要使用BatchNormalization进行批量归一化,不然很难出现好的结果。另外,生成器的最后一层推荐使用tanh()函数(输出范围为[-1, 1],训练图像也要相应地归一化到[-1, 1]),也可以使用sigmoid(输出范围为[0, 1],与ToTensor()得到的图像范围一致,本文的代码使用的就是sigmoid)。
  2. 这是GAN的pytorch版本的实现。
  1. 导入相关库

    import torch
    import torchvision
    import torch.nn as nn
    import torch.nn.functional as F
    
    import matplotlib.pylab as plt
    from matplotlib import animation
    from IPython.display import HTML
    
  2. 设置用到的一些常量

    # Number of images per training batch; also the number of fixed preview samples drawn below.
    BATCH_SIZE = 100
    # Channel count of the images: the generator's output channels and the discriminator's input channels.
    IMG_CHANNELS = 1
    # Length of the latent noise vector fed to the generator.
    NUM_Z = 100
    # Base channel width that the generator's layers scale up from.
    NUM_GENERATOR_FEATURES = 64
    # Base channel width for the discriminator.
    NUM_DISCRIMINATOR_FEATURES = 64
    
    # Use the GPU when CUDA is available, otherwise fall back to the CPU.
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Fixed noise batch, drawn once and fed to the generator after every epoch
    # so that the preview samples are comparable across epochs.
    INPUTS_G = torch.randn(BATCH_SIZE, NUM_Z, 1, 1, device=DEVICE)
    # Flat (batch, NUM_Z) variant of the noise; presumably for a fully connected generator.
    # INPUTS_G = torch.randn(BATCH_SIZE, NUM_Z, device=DEVICE)
    
  3. 加载MNIST数据集

    # The only preprocessing step is ToTensor, which yields float image tensors in [0, 1].
    transform = torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor()]
    )

    # Alternative dataset kept for reference:
    # ds = torchvision.datasets.cifar.CIFAR10(root="data", train=True, transform=transform, download=True)
    ds = torchvision.datasets.MNIST(
        root="data",
        train=True,
        download=True,
        transform=transform,
    )
    # Reshuffled on every pass, served in batches of BATCH_SIZE.
    ds_loader = torch.utils.data.DataLoader(
        dataset=ds,
        batch_size=BATCH_SIZE,
        shuffle=True,
    )
    
  4. 查看数据

    # Pull a single batch from the loader to inspect it.
    img_batch, lab_batch = next(iter(ds_loader))
    # Bare expression: shows both shapes when it is the last line of a notebook cell.
    img_batch.shape, lab_batch.shape
    
  5. 绘制数据集图像

    # Tile the batch into one image: 10 samples per row, 2-pixel padding with pad value 1.
    preview_grid = torchvision.utils.make_grid(
        img_batch,
        nrow=10,
        padding=2,
        pad_value=1,
        normalize=True,
    )

    plt.figure(figsize=(8, 8), dpi=80)
    # The grid is channels-first; move the channel axis last for imshow.
    plt.imshow(preview_grid.permute(1, 2, 0))
    plt.tight_layout()
    plt.axis("off")
    
  6. 定义生成器和判别器

    class Generator(nn.Module):
        """Transposed-convolution generator mapping latent noise to images.

        Shape comments below assume the default constants
        (NUM_Z = 100, NUM_GENERATOR_FEATURES = 64, IMG_CHANNELS = 1).
        The final Sigmoid keeps every output pixel in [0, 1].
        """

        def __init__(self):
            super().__init__()
            width = NUM_GENERATOR_FEATURES
            # ConvTranspose2d output size: (in - 1) * stride - 2 * padding + kernel
            stages = [
                # 100 x 1 x 1 --> 512 x 4 x 4
                nn.ConvTranspose2d(NUM_Z, width * 8, kernel_size=4, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(width * 8),
                nn.ReLU(inplace=True),
                # 512 x 4 x 4 --> 256 x 8 x 8
                nn.ConvTranspose2d(width * 8, width * 4, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(width * 4),
                nn.ReLU(inplace=True),
                # 256 x 8 x 8 --> 128 x 16 x 16
                nn.ConvTranspose2d(width * 4, width * 2, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(width * 2),
                nn.ReLU(inplace=True),
                # 128 x 16 x 16 --> 64 x 14 x 14 (a 1 x 1 kernel with padding 1 trims the border)
                nn.ConvTranspose2d(width * 2, width * 1, kernel_size=1, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(width * 1),
                nn.ReLU(inplace=True),
                # 64 x 14 x 14 --> 1 x 28 x 28
                nn.ConvTranspose2d(width * 1, IMG_CHANNELS, kernel_size=2, stride=2, padding=0, bias=False),
                nn.Sigmoid(),
            ]
            self.main = nn.Sequential(*stages)

        def forward(self, x):
            """Run the noise batch ``x`` through the layer stack and return the images."""
            return self.main(x)
    
        
    class Discriminator(nn.Module):
        """Convolutional discriminator that scores images with a value in (0, 1).

        The training code uses 1 as the target for real images and 0 for
        generated ones. Shape comments below assume the default constants
        (IMG_CHANNELS = 1, NUM_DISCRIMINATOR_FEATURES = 64) and 28 x 28 inputs.
        """

        def __init__(self):
            super(Discriminator, self).__init__()

            # Size the layers from the discriminator's own width constant, so that
            # tuning the generator's width does not silently resize this network.
            width = NUM_DISCRIMINATOR_FEATURES
            self.main = nn.Sequential(
                # 1 x 28 x 28 --> 256 x 14 x 14
                nn.Conv2d(IMG_CHANNELS, width * 4, 4, 2, 1, bias=False),
                nn.BatchNorm2d(width * 4),
                nn.LeakyReLU(0.2, inplace=True),
                # 256 x 14 x 14 --> 128 x 7 x 7
                nn.Conv2d(width * 4, width * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(width * 2),
                nn.LeakyReLU(0.2, inplace=True),
                # 128 x 7 x 7 --> 64 x 3 x 3
                nn.Conv2d(width * 2, width * 1, 4, 2, 1, bias=False),
                nn.BatchNorm2d(width * 1),
                nn.LeakyReLU(0.2, inplace=True),
                # 64 x 3 x 3 --> 1 x 1 x 1
                nn.Conv2d(width * 1, 1, 3, 1, 0, bias=False),
                nn.Sigmoid()
            )

        def forward(self, x):
            """Score the image batch ``x`` and flatten the result to one dimension."""
            return self.main(x).view(-1)
    
  7. 测试定义的模型

    # Smoke test: push random noise through freshly constructed models
    # (no device is specified here, so everything stays on the default device).
    noise = torch.randn(BATCH_SIZE, NUM_Z, 1, 1)
    generator = Generator()
    fake_img = generator(noise)
    discriminator = Discriminator()
    # Bare expression: the discriminator's scores are shown when this is the last line of a notebook cell.
    discriminator(fake_img)
    
  8. 网络参数初始化函数

    # custom weights initialization called on netG and netD
    def weights_init(m):
        """Re-initialise one module in place; intended for ``Module.apply``.

        Modules whose class name contains ``Conv`` get weights drawn from
        N(0, 0.02). Modules whose class name contains ``BatchNorm`` get
        weights drawn from N(1, 0.02) and a zero bias. Every other module
        is left untouched.
        """
        layer_kind = m.__class__.__name__
        if 'Conv' in layer_kind:
            nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
            return
        if 'BatchNorm' in layer_kind:
            nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
            nn.init.constant_(m.bias.data, val=0)
    
  9. 定义生成器和判别器对象,优化器,损失函数和评估标准

    # Fresh networks on the target device, re-initialised through weights_init.
    generator = Generator().to(DEVICE)
    generator.apply(weights_init)
    discriminator = Discriminator().to(DEVICE)
    discriminator.apply(weights_init)

    # One Adam optimiser per network, both with the same hyper-parameters.
    optimizer_g = torch.optim.Adam(
        generator.parameters(),
        lr=0.0002,
        betas=(0.5, 0.999),
    )
    optimizer_d = torch.optim.Adam(
        discriminator.parameters(),
        lr=0.0002,
        betas=(0.5, 0.999),
    )

    # Binary cross-entropy on the discriminator's sigmoid outputs.
    loss_fn = nn.BCELoss()


    def metrics_fn(y_true, y_pred):
        """Return the accuracy of ``y_pred`` thresholded at 0.5 against ``y_true``."""
        one = torch.tensor(1., device=DEVICE)
        zero = torch.tensor(0., device=DEVICE)
        hard_pred = torch.where(y_pred >= 0.5, one, zero)
        hits = (y_true == hard_pred).to(torch.float32)
        return torch.mean(hits)
    
  10. 定义训练步骤(重点)

    def _set_requires_grad(module, flag):
        """Switch gradient tracking on or off for every parameter of ``module``."""
        for parameter in module.parameters():
            parameter.requires_grad_(flag)


    def train_step(inputs, labels):
        """Run one GAN update: first the discriminator, then the generator.

        Args:
            inputs: batch of real images, already on the models' device.
            labels: tensor with one entry per image. Only its shape and device
                are used, to build the all-ones / all-zeros targets; its
                values are ignored.

        Returns:
            Tuple ``(loss_d, metrics_d, loss_g, metrics_g)`` of Python floats:
            the discriminator's loss and accuracy averaged over the real and
            fake halves, and the generator's loss and the fraction of fakes
            the discriminator scored as real.
        """
        real_targets = torch.ones_like(labels, dtype=torch.float32)
        fake_targets = torch.zeros_like(real_targets)

        # Draw as many noise vectors as there are real images, so that a
        # smaller final batch still lines up with the targets.
        inputs_g = torch.randn(inputs.size(0), NUM_Z, 1, 1, device=DEVICE)
        # inputs_g = torch.randn(inputs.size(0), NUM_Z, device=DEVICE)
        outputs_g = generator(inputs_g)

        # ---- Discriminator update ----
        # The generator needs no freezing here: the fake batch is detached
        # below, so no gradient can reach it during this phase.
        _set_requires_grad(discriminator, True)
        optimizer_d.zero_grad()

        # Real images should be scored as 1.
        outputs = discriminator(inputs)
        loss_real = loss_fn(outputs, real_targets)
        metrics_real = metrics_fn(real_targets, outputs)
        loss_real.backward()

        # Generated images should be scored as 0.
        outputs = discriminator(outputs_g.detach())
        loss_fake = loss_fn(outputs, fake_targets)
        metrics_fake = metrics_fn(fake_targets, outputs)
        loss_fake.backward()

        # The two backward passes above already accumulated both gradients;
        # the averages are only used for reporting.
        loss_d = (loss_real + loss_fake) / 2
        metrics_d = (metrics_real + metrics_fake) / 2
        optimizer_d.step()

        # ---- Generator update ----
        # Freeze the discriminator so this backward pass only computes the
        # gradients the generator needs; they still flow through it to outputs_g.
        _set_requires_grad(discriminator, False)
        optimizer_g.zero_grad()

        # The generator is rewarded when its fakes are scored as real (1).
        outputs = discriminator(outputs_g)
        loss_g = loss_fn(outputs, real_targets)
        metrics_g = metrics_fn(real_targets, outputs)
        loss_g.backward()
        optimizer_g.step()

        # Leave the discriminator trainable for whatever runs next.
        _set_requires_grad(discriminator, True)

        return loss_d.item(), metrics_d.item(), loss_g.item(), metrics_g.item()
    
  11. 测试定义的训练步骤

    # Run one update on the batch fetched earlier to check that train_step executes end to end.
    # This call does update both networks' weights.
    train_step(img_batch.to(DEVICE), lab_batch.to(DEVICE))
    
  12. 定义训练循环

    epochs = 8
    # Per-epoch averages, plotted after training.
    loss_d_list, metrics_d_list, loss_g_list, metrics_g_list = [], [], [], []
    # One grid of samples per epoch, all generated from the fixed noise INPUTS_G.
    grid_img_list = []

    for epoch in range(epochs):

        # Running sums over the epoch, turned into means once the epoch ends.
        loss_d_batch = metrics_d_batch = loss_g_batch = metrics_g_batch = .0
        num_batch = 0
        for img_batch, lab_batch in ds_loader:
            img_batch = img_batch.to(DEVICE)
            lab_batch = lab_batch.to(DEVICE)
            # The class labels are not used for training; only a tensor of the same shape is passed on.
            loss_d, metrics_d, loss_g, metrics_g = train_step(img_batch, torch.ones_like(lab_batch))
            num_batch += 1
            loss_d_batch, metrics_d_batch = loss_d_batch + loss_d, metrics_d_batch + metrics_d
            loss_g_batch, metrics_g_batch = loss_g_batch + loss_g, metrics_g_batch + metrics_g

        loss_d_batch, metrics_d_batch = loss_d_batch / num_batch, metrics_d_batch / num_batch
        loss_g_batch, metrics_g_batch = loss_g_batch / num_batch, metrics_g_batch / num_batch

        loss_d_list.append(loss_d_batch)
        metrics_d_list.append(metrics_d_batch)
        loss_g_list.append(loss_g_batch)
        metrics_g_list.append(metrics_g_batch)

        # Report the epoch 1-based so that the last line reads [epochs/epochs].
        print("[%d/%d] loss_discriminator: %.2f, metrics_discriminator: %.2f, loss_generator: %.2f, metrics_generator: %.2f" % (
            epoch + 1, epochs, loss_d_batch, metrics_d_batch, loss_g_batch, metrics_g_batch))

        # Preview: generate from the fixed noise without tracking gradients.
        with torch.no_grad():
            outputs_g = generator(INPUTS_G)
            outputs_d = discriminator(outputs_g)

            grid_img_list.append(torchvision.utils.make_grid(outputs_g.cpu(), nrow=10, normalize=True, pad_value=1))

            # Show the first 16 samples, each titled with the discriminator's score.
            plt.figure(figsize=(20, 2), dpi=80)
            for i, (img, lab) in enumerate(zip(outputs_g[:16], outputs_d[:16])):
                plt.subplot(1, 16, i+1)
                plt.imshow(img.cpu().permute(1, 2, 0), cmap=plt.cm.binary)
                plt.title("%.2f" % lab.cpu().item())
                plt.axis("off")
            plt.tight_layout()
            plt.show()
    
  13. 绘制损失值和评估指标

    plt.figure(figsize=(12, 4), dpi=80)

    # Each panel: (title, y-axis label, curves), where a curve is (values, legend label).
    panels = [
        (
            "Loss of discriminator and generator",
            "loss",
            [(loss_d_list, "discriminator_loss"), (loss_g_list, "generator_loss")],
        ),
        (
            "Metrics of discriminator and generator",
            "metrics",
            [(metrics_d_list, "discriminator_metrics"), (metrics_g_list, "generator_metrics")],
        ),
    ]

    # Left panel shows the losses, right panel the metrics, both per epoch.
    for position, (title, y_label, curves) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        for values, curve_label in curves:
            plt.plot(values, label=curve_label)
        plt.title(title)
        plt.xlabel("epochs")
        plt.ylabel(y_label)
        plt.legend()

    plt.show()
    
  14. 绘制动态的GAN图像生成过程

    fig = plt.figure(figsize=(10, 10), dpi=80)
    plt.axis("off")

    # One animation frame per saved grid. The grids are channels-first tensors,
    # so the channel axis is moved last for imshow. Tensor.permute is used
    # because numpy is not imported in this file.
    imgs = [[plt.imshow(img.permute(1, 2, 0), animated=True)] for img in grid_img_list]
    ani = animation.ArtistAnimation(fig, imgs, interval=1000, repeat_delay=1000, blit=True)

    # Bare expression: renders the animation inline when it is the last line of a notebook cell.
    HTML(ani.to_jshtml())
    
  15. 绘制真实图片和GAN生成图片

    plt.figure(figsize=(20, 10), dpi=80)

    # Left: the most recently loaded batch of real images, tiled 10 per row.
    plt.subplot(1, 2, 1)
    plt.title("real digits image")
    plt.imshow(torchvision.utils.make_grid(img_batch.cpu(), nrow=10, normalize=True, pad_value=1).permute(1, 2, 0))
    plt.axis("off")

    # Right: the grid generated after the last epoch. Tensor.permute moves the
    # channel axis last; numpy is not imported in this file.
    plt.subplot(1, 2, 2)
    plt.title("fake digits image")
    plt.imshow(grid_img_list[-1].permute(1, 2, 0))
    plt.axis("off")
    

本文地址:https://blog.csdn.net/bash_winner/article/details/113998997