## Coding 7: UpConvs for Image Generation

In this exercise, we'll take a break from CIFAR-10 and do the opposite of multi-class classification. In multi-class classification, we take an image as input and output a label.

This week, we'll train a network that **takes in a label as input, and outputs an image**.

<img src="https://miro.medium.com/max/700/1*wy3oRM8jh8LKF6Iku2Uh4w.png" width=1024px/>


## TensorBoard Setup

This is the same embedded TensorBoard code as before.

In [None]:
import torch.utils.tensorboard as tb
import tempfile

log_dir = tempfile.mkdtemp()

%load_ext tensorboard
%tensorboard --logdir {log_dir} --reload_interval 1

## Data Preparation

Here we'll use a custom dataset of faces.

**WARNING**: This cell will overwrite any directory called `images/`. If you have an `images/` directory in the folder this file is in, move it before running this cell.

This cell relies on some Unix-like tools, but all it's doing is downloading and unzipping an archive. If you can't run this cell, just download the zip file from the given URL and unzip it manually.
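If the shell tools are not available, the same two steps can be sketched in plain Python with the standard library. This is only a sketch: the helper name `fetch_and_unzip` and its parameters are hypothetical, it assumes the archive unpacks into an `images/` directory as described above, and it assumes the URL serves the zip file directly.

```python
import pathlib
import shutil
import urllib.request
import zipfile

def fetch_and_unzip(url, zip_path="images.zip", extract_to="."):
    """Download a zip archive and extract it (hypothetical helper)."""
    zip_path = pathlib.Path(zip_path)
    extract_to = pathlib.Path(extract_to)

    # Like the shell cell, remove any existing images/ directory first.
    shutil.rmtree(extract_to / "images", ignore_errors=True)

    # Download the archive, overwriting any existing file at zip_path.
    urllib.request.urlretrieve(url, zip_path)

    # Extract every member of the archive below extract_to.
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(extract_to)

    # Return the extracted JPEG files so the caller can check the result.
    return sorted(extract_to.rglob("*.jpeg"))
```

Calling `fetch_and_unzip(url)` with the URL from the cell below should leave an `images/` directory next to the notebook, with the same overwrite caveat as the warning above.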

In [None]:
!rm -f images.zip
!rm -rf images/
!wget 'https://docs.google.com/uc?export=download&id=1CWDIj3jdiAWINZXBlczhOMT3wImzlk1v' -O images.zip
!unzip images.zip

By the way, the people in this dataset are:

- [An Wang](https://en.wikipedia.org/wiki/An_Wang)
- [Héctor García-Molina](https://en.wikipedia.org/wiki/H%c3%a9ctor_Garc%c3%ada-Molina)
- [Katherine Johnson](https://en.wikipedia.org/wiki/Katherine_Johnson)
- [Marian Croak](https://en.wikipedia.org/wiki/Marian_Croak)
- [Mark Dean](https://en.wikipedia.org/wiki/Mark_Dean_(computer_scientist))
- [Raj Reddy](https://en.wikipedia.org/wiki/Raj_Reddy)
- [Roy Clay Sr.](https://en.wikipedia.org/wiki/Roy_Clay)
- [Sanghamitra Mohanty](https://en.wikipedia.org/wiki/Sanghamitra_Mohanty)
- [Sophie Wilson](https://en.wikipedia.org/wiki/Sophie_Wilson)
- [Xia Peisu](https://en.wikipedia.org/wiki/Xia_Peisu)

## Dataset

Currently, the dataset yields pairs of (image, label), where the label is an integer (the image id). Since our network will take the label as input, it's common practice to transform the label into a **one-hot vector**.

If we have a total of $n$ labels, the one-hot representation of a label $l$ is an $n$-dimensional vector $x$, where $x_i = 1$ if $i = l$, and $x_i = 0$ otherwise.

### One-hot example
If we have 5 labels, the one-hot representation of the label 0 is `[1, 0, 0, 0, 0]`, the one-hot representation of the label 1 is `[0, 1, 0, 0, 0]`, and so on.
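As a quick check of this definition, the following minimal sketch builds the one-hot vectors for 5 labels in two ways. It assumes PyTorch is installed; `num_labels` and the other variable names are illustrative and not part of the exercise code.

```python
import torch
import torch.nn.functional as F

num_labels = 5  # illustrative value, matching the example above

# Option 1: row i of the identity matrix is the one-hot vector for label i.
one_hot_eye = torch.eye(num_labels)

# Option 2: encode integer labels explicitly, then cast to float.
label_ids = torch.arange(num_labels)
one_hot_fn = F.one_hot(label_ids, num_classes=num_labels).float()

print(one_hot_eye[0].tolist())  # [1.0, 0.0, 0.0, 0.0, 0.0]
print(one_hot_eye[1].tolist())  # [0.0, 1.0, 0.0, 0.0, 0.0]
print(torch.equal(one_hot_eye, one_hot_fn))  # True
```

Both options give the same matrix here because the labels are exactly `0, 1, ..., n-1`, one per row.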

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import pathlib
from PIL import Image

import torch
import torchvision

to_tensor = torchvision.transforms.ToTensor()
images = torch.stack([to_tensor(Image.open(p).convert('RGB')) for p in sorted(pathlib.Path('images').glob('*.jpeg'))])
labels = torch.eye(images.shape[0])
# torch.eye creates the identity matrix with the given dimensions. Each row of the
# matrix is a one-hot encoding for one image.

vis = torchvision.utils.make_grid(images, nrow=10)

plt.figure(figsize=(30, 3))
plt.imshow(vis.permute(1, 2, 0))
plt.show()

## Model Implementation and Training

Some questions to think about:

* What network should we use?
* What loss should we use?
* Should we constrain the network output to [0, 1] like an image? How?
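To make the first and third questions a little more concrete, here is a small sketch of two building blocks the exercise title hints at: an up-convolution (`torch.nn.ConvTranspose2d`) that doubles the spatial resolution, and a sigmoid that squashes outputs into [0, 1]. This is not the intended solution. The class name `TinyUpBlock`, the channel counts, and the 4 x 4 input size are assumptions chosen only for illustration.

```python
import torch

class TinyUpBlock(torch.nn.Module):
    """Hypothetical block: doubles height and width, outputs values in [0, 1]."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # With kernel_size=4, stride=2, padding=1 the output size is
        # (H - 1) * 2 - 2 * 1 + 4 = 2 * H, i.e. exactly double the input.
        self.up = torch.nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=4, stride=2, padding=1
        )

    def forward(self, x):
        # Sigmoid maps every real number into the interval (0, 1).
        return torch.sigmoid(self.up(x))

block = TinyUpBlock(in_channels=8, out_channels=3)
x = torch.randn(2, 8, 4, 4)  # N x C x H x W
y = block(x)
print(y.shape)  # torch.Size([2, 3, 8, 8])
print(float(y.min()) >= 0.0, float(y.max()) <= 1.0)
```

One possible design, under these assumptions, is to reshape the one-hot input from N x 10 to N x 10 x 1 x 1 and stack eight such doublings to reach 256 x 256, applying the sigmoid only after the last layer. For the loss, a mean squared error between predicted and target images is one candidate that lines up with the squared distances used in `model_quality` below.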



In [None]:
import time
import numpy as np

class UpNetwork(torch.nn.Module):
    def __init__(self, input_channels, output_channels):
        super().__init__()
        # TODO: Define the model
    
    def forward(self, x):
        """
        Translate one-hot vector to image.
        Remember that image pixel values range from 0 to 1.

        Input: 
            x (float tensor N x 10): input id encoded as a one-hot vector
        Output:
            y (float tensor N x 3 x 256 x 256): image corresponding to each id
        """
        pass

    def check_num_parameters(self):
        """
        Get the total number of parameters in the network.
        """
        return sum([np.prod(param.shape) for param in self.parameters()])


def model_quality(model):
    """
    Get a measure of the quality of the model.
    
    This measure is -log(x/y + e), where x is the sum of squared distances
    between the generated images and the input images, y is the sum of squared
    distances between the input images and the image where every pixel is 0.5,
    and e is some very small value (exp(-10)).
    
    Intuitively, as the generated images get close to the input images, x
    becomes very small, and x/y approaches 0 from above. The denominator y
    acts as a kind of normalizer, showing how a "random" image might score.
    If x is zero (i.e., we reconstruct the images perfectly) then this
    expression returns -log(e) = 10. As x grows, x/y grows, log(x/y) grows,
    and -log(x/y) shrinks.
    """
    with torch.no_grad():  # evaluation only, so skip building the autograd graph
        return float(-(torch.sum((model(labels)-images)**2) / torch.sum((0.5-images)**2)+np.exp(-10)).log())

def size_penalty(model):
    """
    Get a penalty based on the size of the model.
    
    We want to learn small models since they are both faster to train and
    faster to evaluate. This function gets a penalty based on the model size
    and the size of the inputs.
    """
    return 10 * model.check_num_parameters() / np.prod(images.shape)


def train(model, device, lr=5e-3, iterations=5000):
    # Setting up the tensorboard logger
    logger = tb.SummaryWriter(log_dir + '/{}'.format(time.strftime('%H-%M-%S')), flush_secs=1)

    # TODO: What loss should we use?
    loss_function = None

    optim = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.)

    model.train()

    print('size_penalty =', size_penalty(model))

    for global_step in range(iterations):
        if global_step % 100 == 0:
            print("Iteration:", global_step)
        images_pred = model(labels)

        loss = loss_function(images_pred, images)
        optim.zero_grad()
        loss.backward()
        optim.step()
        
        # Add loss to TensorBoard.
        if global_step % 10 == 0:
            logger.add_scalar('loss', loss.item(), global_step=global_step)
            logger.add_scalar('quality', model_quality(model), global_step=global_step)

        if global_step % 100 == 0:
            # Clamp to [0, 1] first so the conversion to bytes cannot wrap around.
            image_grid = (torchvision.utils.make_grid(images_pred.detach().clamp(0, 1), nrow=10) * 255).byte()
            logger.add_image('image_pred', image_grid, global_step=global_step)

    model.eval()
    print('model_quality   = ', model_quality(model))
    print('score           = ', model_quality(model)-size_penalty(model))

# Actually train the model here!
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = UpNetwork(10, 3)
model.to(device)
images = images.to(device)
labels = labels.to(device)

train(model, device)
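
To get a feel for the numbers printed by `train`, the following sketch evaluates the formulas from the `model_quality` and `size_penalty` docstrings on toy data. The function names `quality` and `penalty`, the 8 x 8 toy images, and the parameter count are illustrative assumptions. The image count and size (ten 3 x 256 x 256 images) follow the shapes given in the `forward` docstring.

```python
import math
import torch

EPS = math.exp(-10)  # the small constant e from the model_quality docstring

def quality(pred, target):
    """Same formula as model_quality, written for arbitrary tensors."""
    x = torch.sum((pred - target) ** 2)  # distance of predictions to targets
    y = torch.sum((0.5 - target) ** 2)   # distance of an all-gray image to targets
    return float(-(x / y + EPS).log())

def penalty(num_params, num_image_values):
    """Same formula as size_penalty, with explicit counts."""
    return 10 * num_params / num_image_values

torch.manual_seed(0)
target = torch.rand(10, 3, 8, 8)  # toy stand-in for the real images

perfect = quality(target.clone(), target)             # x = 0, so -log(e)
gray = quality(torch.full_like(target, 0.5), target)  # x = y, so -log(1 + e)

print(perfect)  # approximately 10
print(gray)     # approximately 0

num_image_values = 10 * 3 * 256 * 256       # ten 3 x 256 x 256 images
print(penalty(196_608, num_image_values))   # 1.0
```

Under these assumptions, a perfect reconstruction scores about 10, an all-gray output scores about 0, and every 196,608 parameters cost one point of score.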

## Model Evaluation and Visualization

Let's see how our image generation model works!

In [None]:
model.eval()
model = model.cpu()
# Disable autograd so the output can be converted to NumPy for plotting.
with torch.no_grad():
    image_pred = model(torch.eye(10))

vis = torchvision.utils.make_grid(image_pred, nrow=10)

plt.figure(figsize=(30, 3))
plt.imshow(vis.permute(1, 2, 0))
plt.show()