Deep Convolutional GAN (DCGAN) for generating CIFAR-10 images

In [1]:
import os
from statistics import mean

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable

Step 1: Load the Dataset

In [2]:
# normalize the Training set with data augmentation
transform_train = transforms.Compose(
    [transforms.RandomCrop(32, padding=4),
     transforms.RandomHorizontalFlip(),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# load the Training dataset
trainset = torchvision.datasets.CIFAR10(root='./data',download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
Files already downloaded and verified

Step 2: Display Training Images

In [3]:
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Load the whole training set as a single batch so every class is guaranteed
# to be represented when sampling below.
displayloader = torch.utils.data.DataLoader(trainset, batch_size=50000, shuffle=True, num_workers=2)

# Grab one (full) batch of training images.
# FIX: `dataiter.next()` was removed from DataLoader iterators in modern
# PyTorch; the builtin next() works on all versions.
dataiter = iter(displayloader)
images, labels = next(dataiter)

num_classes = len(classes)
samples_per_class = 7

# Show a (samples_per_class x num_classes) grid: one column per class.
for y, cls in enumerate(classes):
    idxs = np.flatnonzero(labels == y)
    # select random 7 images from each class
    idxs = np.random.choice(idxs, samples_per_class, replace=False)
    for i, idx in enumerate(idxs):
        plt_idx = i * num_classes + y + 1
        plt.subplot(samples_per_class, num_classes, plt_idx)
        img = images[idx] / 2 + 0.5  # un-normalize ([-1, 1] -> [0, 1]) for display
        npimg = img.numpy()
        plt.imshow(np.transpose(npimg, (1, 2, 0)))
        plt.axis('off')
        if i == 0:
            plt.title(cls)
plt.show()
        

Step 3a: Define a Generator Network

In [4]:
# Generator is CNN with 4 convolution layers and 1 fully connected layer.
class Generator(nn.Module):
    """DCGAN generator: latent vector z -> 32x32 RGB image in [-1, 1].

    A fully connected layer projects z to a 4x4x256 feature map, which three
    transposed convolutions upsample 4 -> 9 -> 17 -> 32 spatially; a final
    3x3 convolution maps to 3 channels with Tanh.

    Args:
        latent_dim: size of the input noise vector (default 100, matching the
            original hard-coded value, so existing callers are unaffected).
    """

    def __init__(self, latent_dim=100):
        super(Generator, self).__init__()
        self.latent_dim = latent_dim

        # FC projection: latent_dim --> 4*4*256 = 4096
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 4*4*256),
            nn.LeakyReLU()
        )

        self.cnn = nn.Sequential(

            # DConv1 [state size = (4 x 4 x 256)] -> (9 x 9 x 128)
            nn.ConvTranspose2d(in_channels = 256, out_channels = 128, kernel_size = 3, stride=2, padding=0, output_padding=0),
            nn.LeakyReLU(),

            # DConv2 [state size = (9 x 9 x 128)] -> (17 x 17 x 64)
            nn.ConvTranspose2d(in_channels = 128, out_channels = 64, kernel_size = 3, stride=2, padding=1, output_padding=0),
            nn.LeakyReLU(),

            # DConv 3 [state size = (17 x 17 x 64)] -> (32 x 32 x 64)
            nn.ConvTranspose2d(in_channels = 64, out_channels = 64, kernel_size = 3, stride=2, padding=2, output_padding=1),
            nn.LeakyReLU(),

            # DConv 4 [state size = (32 x 32 x 64)] -> (32 x 32 x 3)
            nn.Conv2d(in_channels = 64, out_channels = 3, kernel_size = 3, stride=1, padding=1),
            nn.Tanh()  # bound output to [-1, 1], matching normalized real images
        )

    def forward(self, z):
        """Map a batch of latent vectors (N, latent_dim) to images (N, 3, 32, 32)."""
        x = self.model(z)
        x = x.view(-1, 256, 4, 4)  # deflatten FC output into a 4x4 feature map
        x = self.cnn(x)
        return x

Step 3b: Define a Discriminator Network

In [5]:
class Discriminator(nn.Module):
    """CNN that scores a 32x32 RGB image as real (-> 1) or generated (-> 0).

    Three conv/max-pool/dropout stages halve the spatial size each time
    (32x32x3 -> 16x16x64 -> 8x8x128 -> 4x4x256), followed by a two-layer
    fully connected head ending in a sigmoid probability.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        # Feature extractor: each stage is conv -> LeakyReLU -> max-pool -> dropout.
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.4),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.4),

            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.4),
        )
        # Classifier head: 4096 -> 128 -> 1, sigmoid output for BCELoss.
        self.fc = nn.Sequential(
            nn.Linear(4 * 4 * 256, 128),
            nn.LeakyReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return a (N, 1) tensor of real-image probabilities for input (N, 3, 32, 32)."""
        features = self.cnn(x)
        flat = features.view(-1, 4 * 4 * 256)  # flatten the 4x4x256 feature map
        return self.fc(flat)
     

Step 4: Transfer the neural network onto the GPU

In [6]:
# Run on the GPU when one is present, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

netD = Discriminator().to(device)
netG = Generator().to(device)

# On a CUDA machine this prints "cuda:0".
print(device)
cuda:0

Step 5: Define a Loss Function and Optimizers

In [7]:
# Binary cross-entropy: discriminator outputs and labels both live in [0, 1].
criterion = nn.BCELoss()

# Adam with the DCGAN-paper hyperparameters (lr = 2e-4, beta1 = 0.5),
# one optimizer per network.
lr = 0.0002
betas = (0.5, 0.999)
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=betas)
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=betas)

Step 6: Training the DCGAN

In [8]:
import os  # needed for creating the sample-output directory below

num_epochs = 120
D_loss = []    # per-epoch mean discriminator loss
G_loss = []    # per-epoch mean generator loss
img_list = []  # image grids of G(fixed_noise), saved every 5 epochs

# FIX: create the output directory up front; save_image below would otherwise
# fail at epoch 5 if 'V1/' does not already exist.
os.makedirs('V1', exist_ok=True)

# A fixed batch of 64 noise vectors to follow how the generator output evolves.
# NOTE: torch.autograd.Variable is deprecated (no-op since torch 0.4); plain
# tensors carry autograd state themselves.
fixed_noise = torch.randn(64, 100, device=device)

for epoch in range(num_epochs):  # loop over the dataset multiple times

    D_loss_it = []
    G_loss_it = []

    for i, data in enumerate(trainloader, 0):

        #####################################
            # (1) TRAIN DISCRIMINATOR
        #####################################

        # zero the parameter gradients
        netD.zero_grad()

        ##### (a) Train with real images #####
        # data is a list of [inputs, labels]; the class labels are unused.
        real_images = data[0].to(device)
        # Batch size in this iteration (last batch may be smaller than 64).
        batch_size = real_images.size(0)
        real_labels = torch.ones(batch_size, 1, device=device)  # real images -> all ones
        # forward + loss
        outputs = netD(real_images)
        real_loss = criterion(outputs, real_labels)

        ##### (b) Train with fake images #####
        noise = torch.randn(batch_size, 100, device=device)
        fake_images = netG(noise)
        fake_labels = torch.zeros(batch_size, 1, device=device)  # fake images -> all zeros
        # detach() so D's backward pass does not propagate into G's parameters
        outputs = netD(fake_images.detach())
        fake_loss = criterion(outputs, fake_labels)

        # Total loss + backward + optimize
        lossD = real_loss + fake_loss
        lossD.backward()
        optimizerD.step()

        #####################################
            # (2) TRAIN GENERATOR
        #####################################

        # zero the parameter gradients
        netG.zero_grad()

        # pick m fresh noise samples from the distribution
        noise = torch.randn(batch_size, 100, device=device)
        fake_images = netG(noise)
        outputs = netD(fake_images)

        # Non-saturating generator loss: push D's output on fakes toward 1.
        lossG = criterion(outputs, real_labels)
        lossG.backward()
        optimizerG.step()

        # record loss
        D_loss_it.append(lossD.item())
        G_loss_it.append(lossG.item())


    # report status
    print('Iteration: %2d | Discriminator loss: %.3f | Generator loss: %.3f '
              %(epoch + 1, mean(D_loss_it), mean(G_loss_it)))

    # average loss over all batches in one epoch
    D_loss.append(mean(D_loss_it))
    G_loss.append(mean(G_loss_it))

    # save images to follow generator output over fixed noise
    if epoch % 5 == 4:
        # no_grad(): sampling for visualization needs no autograd graph
        with torch.no_grad():
            fake = netG(fixed_noise)
        vutils.save_image(fake, 'V1/my_samples_epoch_%03d.png' % (epoch), normalize=True)
        img_list.append(vutils.make_grid(fake.cpu(), padding=2, normalize=True))

print('Finished Training')
Iteration:  1 | Discriminator loss: 0.786 | Generator loss: 2.668 
Iteration:  2 | Discriminator loss: 0.653 | Generator loss: 2.598 
Iteration:  3 | Discriminator loss: 0.713 | Generator loss: 2.212 
Iteration:  4 | Discriminator loss: 0.584 | Generator loss: 2.375 
Iteration:  5 | Discriminator loss: 0.708 | Generator loss: 2.065 
Iteration:  6 | Discriminator loss: 0.765 | Generator loss: 1.992 
Iteration:  7 | Discriminator loss: 0.759 | Generator loss: 1.978 
Iteration:  8 | Discriminator loss: 0.730 | Generator loss: 2.047 
Iteration:  9 | Discriminator loss: 0.795 | Generator loss: 1.862 
Iteration: 10 | Discriminator loss: 0.856 | Generator loss: 1.726 
Iteration: 11 | Discriminator loss: 0.867 | Generator loss: 1.655 
Iteration: 12 | Discriminator loss: 0.896 | Generator loss: 1.604 
Iteration: 13 | Discriminator loss: 0.877 | Generator loss: 1.623 
Iteration: 14 | Discriminator loss: 0.866 | Generator loss: 1.636 
Iteration: 15 | Discriminator loss: 0.846 | Generator loss: 1.714 
Iteration: 16 | Discriminator loss: 0.845 | Generator loss: 1.697 
Iteration: 17 | Discriminator loss: 0.854 | Generator loss: 1.676 
Iteration: 18 | Discriminator loss: 0.850 | Generator loss: 1.705 
Iteration: 19 | Discriminator loss: 0.858 | Generator loss: 1.650 
Iteration: 20 | Discriminator loss: 0.873 | Generator loss: 1.686 
Iteration: 21 | Discriminator loss: 0.844 | Generator loss: 1.710 
Iteration: 22 | Discriminator loss: 0.851 | Generator loss: 1.680 
Iteration: 23 | Discriminator loss: 0.853 | Generator loss: 1.681 
Iteration: 24 | Discriminator loss: 0.849 | Generator loss: 1.680 
Iteration: 25 | Discriminator loss: 0.861 | Generator loss: 1.639 
Iteration: 26 | Discriminator loss: 0.836 | Generator loss: 1.691 
Iteration: 27 | Discriminator loss: 0.826 | Generator loss: 1.712 
Iteration: 28 | Discriminator loss: 0.808 | Generator loss: 1.772 
Iteration: 29 | Discriminator loss: 0.801 | Generator loss: 1.791 
Iteration: 30 | Discriminator loss: 0.772 | Generator loss: 1.847 
Iteration: 31 | Discriminator loss: 0.788 | Generator loss: 1.809 
Iteration: 32 | Discriminator loss: 0.771 | Generator loss: 1.838 
Iteration: 33 | Discriminator loss: 0.793 | Generator loss: 1.806 
Iteration: 34 | Discriminator loss: 0.770 | Generator loss: 1.850 
Iteration: 35 | Discriminator loss: 0.788 | Generator loss: 1.835 
Iteration: 36 | Discriminator loss: 0.771 | Generator loss: 1.875 
Iteration: 37 | Discriminator loss: 0.762 | Generator loss: 1.874 
Iteration: 38 | Discriminator loss: 0.778 | Generator loss: 1.825 
Iteration: 39 | Discriminator loss: 0.777 | Generator loss: 1.848 
Iteration: 40 | Discriminator loss: 0.754 | Generator loss: 1.870 
Iteration: 41 | Discriminator loss: 0.747 | Generator loss: 1.921 
Iteration: 42 | Discriminator loss: 0.756 | Generator loss: 1.893 
Iteration: 43 | Discriminator loss: 0.770 | Generator loss: 1.850 
Iteration: 44 | Discriminator loss: 0.763 | Generator loss: 1.852 
Iteration: 45 | Discriminator loss: 0.745 | Generator loss: 1.897 
Iteration: 46 | Discriminator loss: 0.759 | Generator loss: 1.881 
Iteration: 47 | Discriminator loss: 0.744 | Generator loss: 1.922 
Iteration: 48 | Discriminator loss: 0.738 | Generator loss: 1.932 
Iteration: 49 | Discriminator loss: 0.744 | Generator loss: 1.913 
Iteration: 50 | Discriminator loss: 0.739 | Generator loss: 1.932 
Iteration: 51 | Discriminator loss: 0.766 | Generator loss: 1.873 
Iteration: 52 | Discriminator loss: 0.759 | Generator loss: 1.912 
Iteration: 53 | Discriminator loss: 0.702 | Generator loss: 2.020 
Iteration: 54 | Discriminator loss: 0.712 | Generator loss: 1.985 
Iteration: 55 | Discriminator loss: 0.725 | Generator loss: 1.943 
Iteration: 56 | Discriminator loss: 0.710 | Generator loss: 2.028 
Iteration: 57 | Discriminator loss: 0.714 | Generator loss: 1.995 
Iteration: 58 | Discriminator loss: 0.722 | Generator loss: 2.025 
Iteration: 59 | Discriminator loss: 0.700 | Generator loss: 2.040 
Iteration: 60 | Discriminator loss: 0.703 | Generator loss: 2.055 
Iteration: 61 | Discriminator loss: 0.718 | Generator loss: 2.019 
Iteration: 62 | Discriminator loss: 0.719 | Generator loss: 2.042 
Iteration: 63 | Discriminator loss: 0.694 | Generator loss: 2.102 
Iteration: 64 | Discriminator loss: 0.699 | Generator loss: 2.034 
Iteration: 65 | Discriminator loss: 0.714 | Generator loss: 1.987 
Iteration: 66 | Discriminator loss: 0.719 | Generator loss: 2.008 
Iteration: 67 | Discriminator loss: 0.719 | Generator loss: 2.007 
Iteration: 68 | Discriminator loss: 0.717 | Generator loss: 2.028 
Iteration: 69 | Discriminator loss: 0.701 | Generator loss: 2.080 
Iteration: 70 | Discriminator loss: 0.698 | Generator loss: 2.057 
Iteration: 71 | Discriminator loss: 0.701 | Generator loss: 2.056 
Iteration: 72 | Discriminator loss: 0.701 | Generator loss: 2.078 
Iteration: 73 | Discriminator loss: 0.702 | Generator loss: 2.074 
Iteration: 74 | Discriminator loss: 0.702 | Generator loss: 2.100 
Iteration: 75 | Discriminator loss: 0.705 | Generator loss: 2.125 
Iteration: 76 | Discriminator loss: 0.692 | Generator loss: 2.103 
Iteration: 77 | Discriminator loss: 0.685 | Generator loss: 2.130 
Iteration: 78 | Discriminator loss: 0.680 | Generator loss: 2.170 
Iteration: 79 | Discriminator loss: 0.685 | Generator loss: 2.133 
Iteration: 80 | Discriminator loss: 0.683 | Generator loss: 2.130 
Iteration: 81 | Discriminator loss: 0.672 | Generator loss: 2.162 
Iteration: 82 | Discriminator loss: 0.675 | Generator loss: 2.151 
Iteration: 83 | Discriminator loss: 0.664 | Generator loss: 2.224 
Iteration: 84 | Discriminator loss: 0.684 | Generator loss: 2.152 
Iteration: 85 | Discriminator loss: 0.697 | Generator loss: 2.090 
Iteration: 86 | Discriminator loss: 0.673 | Generator loss: 2.198 
Iteration: 87 | Discriminator loss: 0.677 | Generator loss: 2.186 
Iteration: 88 | Discriminator loss: 0.662 | Generator loss: 2.211 
Iteration: 89 | Discriminator loss: 0.684 | Generator loss: 2.121 
Iteration: 90 | Discriminator loss: 0.665 | Generator loss: 2.186 
Iteration: 91 | Discriminator loss: 0.670 | Generator loss: 2.186 
Iteration: 92 | Discriminator loss: 0.660 | Generator loss: 2.215 
Iteration: 93 | Discriminator loss: 0.654 | Generator loss: 2.207 
Iteration: 94 | Discriminator loss: 0.646 | Generator loss: 2.280 
Iteration: 95 | Discriminator loss: 0.657 | Generator loss: 2.226 
Iteration: 96 | Discriminator loss: 0.656 | Generator loss: 2.212 
Iteration: 97 | Discriminator loss: 0.667 | Generator loss: 2.199 
Iteration: 98 | Discriminator loss: 0.661 | Generator loss: 2.219 
Iteration: 99 | Discriminator loss: 0.652 | Generator loss: 2.236 
Iteration: 100 | Discriminator loss: 0.645 | Generator loss: 2.252 
Iteration: 101 | Discriminator loss: 0.657 | Generator loss: 2.215 
Iteration: 102 | Discriminator loss: 0.664 | Generator loss: 2.222 
Iteration: 103 | Discriminator loss: 0.662 | Generator loss: 2.200 
Iteration: 104 | Discriminator loss: 0.659 | Generator loss: 2.241 
Iteration: 105 | Discriminator loss: 0.665 | Generator loss: 2.223 
Iteration: 106 | Discriminator loss: 0.660 | Generator loss: 2.211 
Iteration: 107 | Discriminator loss: 0.668 | Generator loss: 2.205 
Iteration: 108 | Discriminator loss: 0.665 | Generator loss: 2.184 
Iteration: 109 | Discriminator loss: 0.661 | Generator loss: 2.247 
Iteration: 110 | Discriminator loss: 0.666 | Generator loss: 2.198 
Iteration: 111 | Discriminator loss: 0.682 | Generator loss: 2.145 
Iteration: 112 | Discriminator loss: 0.671 | Generator loss: 2.183 
Iteration: 113 | Discriminator loss: 0.676 | Generator loss: 2.181 
Iteration: 114 | Discriminator loss: 0.669 | Generator loss: 2.193 
Iteration: 115 | Discriminator loss: 0.664 | Generator loss: 2.204 
Iteration: 116 | Discriminator loss: 0.682 | Generator loss: 2.168 
Iteration: 117 | Discriminator loss: 0.679 | Generator loss: 2.200 
Iteration: 118 | Discriminator loss: 0.685 | Generator loss: 2.189 
Iteration: 119 | Discriminator loss: 0.679 | Generator loss: 2.188 
Iteration: 120 | Discriminator loss: 0.678 | Generator loss: 2.186 
Finished Training

Step 7: Visualization of G's progression

In [9]:
import matplotlib.animation as animation
from IPython.display import HTML

# Animate the saved G(fixed_noise) grids: one frame per 5 training epochs.
fig = plt.figure(figsize=(8, 8))
plt.axis("off")
frames = [[plt.imshow(np.transpose(grid, (1, 2, 0)), animated=True)] for grid in img_list]
ani = animation.ArtistAnimation(fig, frames, interval=1000, repeat_delay=1000, blit=True)

# Render the animation as interactive JavaScript in the notebook.
HTML(ani.to_jshtml())
Out[9]:

Step 8: Plotting Generator and Discriminator Loss Vs Epoch

In [12]:
# Epoch-averaged training losses for both networks.
plt.plot(G_loss, label='Generator Loss')
plt.plot(D_loss, label='Discriminator Loss')
plt.ylabel('cost')
plt.xlabel('epochs')
plt.title('Loss vs Epoch (CIFAR-10)')
plt.grid()
plt.legend(loc='upper right')
plt.show()

Step 9: Real Images Vs. Fake Images

In [11]:
# One batch of real CIFAR-10 images for visual comparison.
real_batch = next(iter(trainloader))

plt.figure(figsize=(15, 15))

# Left panel: up to 64 real training images.
plt.subplot(1, 2, 1)
plt.axis("off")
plt.title("Real Images")
real_grid = vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True)
plt.imshow(np.transpose(real_grid.cpu(), (1, 2, 0)))

# Right panel: generator samples from the last saved snapshot.
plt.subplot(1, 2, 2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1], (1, 2, 0)))
plt.show()

Thank You!

In [ ]: