## Coding 4: CIFAR-10 with a Convolutional Neural Network

In this exercise, we will finish our CIFAR-10 adventures and tackle multi-class image classification with a convolutional neural network (CNN).

The data and setup for the model are identical to the previous coding assignment,
but this time we will explore the effects of several layers newly introduced in the lectures.

<img src="https://miro.medium.com/max/2510/1*vkQ0hXDaQv57sALXAJquxA.jpeg" width=1024px/>

### TensorBoard Setup

We'll use TensorBoard to monitor training across runs. For classification tasks, we will mostly care about loss and accuracy.

Run this cell only once. If TensorBoard fails to load, give it some time; if nothing shows up, restart your runtime.

In [None]:
import torch.utils.tensorboard as tb

log_dir = 'log'

%load_ext tensorboard
%tensorboard --logdir {log_dir} --reload_interval 1

## Data Preparation

This is the same code as in the previous exercise. This time, experiment with the `transform` passed to the training loader and use it to perform data augmentation.

In [None]:
import torchvision
import torchvision.transforms as transforms
import torch


classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']


def fetch_dataloader(batch_size, transform=None, is_train=True):
    """
    Loads CIFAR-10 from disk (downloading it if needed) and returns a DataLoader.
    A DataLoader is an iterable over batches of (images, labels).
    You do not need to fully understand this code to do this assignment, but we're happy to explain it.
    """
    data = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

    # Custom train/val split: every 10th image of the CIFAR-10 training set is held out for validation.
    indices = [i for i in range(len(data)) if (i%10 > 0) == is_train]

    data = torch.utils.data.Subset(data, indices)
    loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True, num_workers=2)
    return loader


train_transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

val_transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

data_train = fetch_dataloader(64, train_transform, is_train=True)
data_val = fetch_dataloader(64, val_transform, is_train=False)

## Model Implementation

Here we will implement a CNN. The following modules will be relevant:

* `torch.nn.Conv2d` 
* `torch.nn.ReLU` 
* `torch.nn.AvgPool2d` 
* `torch.nn.MaxPool2d`
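To see what the non-convolutional modules do, here is a small hand-made example on a single 4 x 4 activation map (the values are arbitrary and chosen only for illustration). With `kernel_size=2`, both pooling layers use a stride of 2 by default, so each 2 x 2 window collapses to one value and the map shrinks from 4 x 4 to 2 x 2.

```python
import torch

# One 1-channel 4x4 activation map, shape N x C x H x W = 1 x 1 x 4 x 4 (hand-picked values).
x = torch.tensor([[[[ 1.,  2., 5., 6.],
                    [ 3.,  4., 7., 8.],
                    [-1.,  0., 2., 2.],
                    [-3., -2., 2., 2.]]]])

relu = torch.nn.ReLU()                        # replaces negative values with 0
max_pool = torch.nn.MaxPool2d(kernel_size=2)  # keeps the largest value in each 2x2 window
avg_pool = torch.nn.AvgPool2d(kernel_size=2)  # averages each 2x2 window

print(max_pool(x)[0, 0].tolist())        # [[4.0, 8.0], [0.0, 2.0]]
print(avg_pool(x)[0, 0].tolist())        # [[2.5, 6.5], [-1.5, 2.0]]
print(avg_pool(relu(x))[0, 0].tolist())  # [[2.5, 6.5], [0.0, 2.0]]
```

Max pooling keeps the strongest response in each window, while average pooling blends it with its neighbours; applying ReLU first removes the negative activations before they can pull the average down.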

Take a close look at the parameters of `Conv2d`. Experiment with the following and observe how each one affects the output activation map, in particular its spatial size.

* kernel_size
* stride
* padding
* dilation
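The four parameters above determine the output size through the shape formula given in the PyTorch `Conv2d` documentation. The sketch below writes that formula as a helper (`conv2d_output_size` is a hypothetical name, not a PyTorch function) and compares it with the shape an actual `Conv2d` produces on a 32 x 32 input; the particular settings are arbitrary examples.

```python
import torch

def conv2d_output_size(size, kernel_size, stride=1, padding=0, dilation=1):
    """Output height (or width) of torch.nn.Conv2d along one spatial dimension.

    Shape formula from the PyTorch Conv2d documentation:
    floor((size + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1)
    """
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

x = torch.zeros(1, 3, 32, 32)  # one dummy CIFAR-10-sized image (N x C x H x W)

# (kernel_size, stride, padding, dilation) settings to compare.
settings = [(3, 1, 1, 1), (3, 2, 1, 1), (5, 1, 0, 1), (3, 1, 0, 2)]
results = []
for k, s, p, d in settings:
    conv = torch.nn.Conv2d(3, 16, kernel_size=k, stride=s, padding=p, dilation=d)
    actual = conv(x).shape[-1]
    predicted = conv2d_output_size(32, k, s, p, d)
    results.append((actual, predicted))
    print(f'kernel_size={k} stride={s} padding={p} dilation={d}: {actual}x{actual} (formula: {predicted})')
```

With `kernel_size=3` and `padding=1`, stride 1 preserves the 32 x 32 map while stride 2 halves it to 16 x 16. A dilation of 2 spreads a 3 x 3 kernel over a 5 x 5 area, so without padding it shrinks the map exactly as a 5 x 5 kernel does (32 to 28).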

*TensorBoard*: So far we have only logged scalars. Let's spend a couple of minutes today logging images to TensorBoard.

In [None]:
class CNNClassifier(torch.nn.Module):
    def __init__(self, input_channels, num_classes):
        """
        Define the layer(s) needed for the model.
        Feel free to define additional input arguments.
        """ 
        super().__init__()
        # A common starting point for Convolutional networks is:
        # [Conv Layer, ReLU, (maybe MaxPool)] x N, then global AvgPool, then Linear
        # Play around with the number of convolutional layers, the number of channels,
        # the kernel size, etc.
        # TODO: Define the model
    
    def forward(self, x):
        """
        Calculate the classification score (logits).

        Input: 
            x (float tensor N x 3 x 32 x 32): input images
        Output:
            y (float tensor N x 10): classification scores (logits) for each class
        """
        # TODO: Implement the forward pass
        pass
    
    def predict(self, image):
        return self(image).argmax(1)
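For reference, here is one possible instance of the pattern described in the comments above: [Conv, ReLU, MaxPool] x 3, a global average pool, then a Linear layer. The class name `ExampleCNNClassifier`, the channel widths (32, 64, 128), and the 3 x 3 kernels with `padding=1` are assumptions for illustration, not a reference solution; the global average pool is written as a mean over the spatial dimensions.

```python
import torch

class ExampleCNNClassifier(torch.nn.Module):
    """Hypothetical sketch: [Conv, ReLU, MaxPool] x N -> global average pool -> Linear."""

    def __init__(self, input_channels=3, num_classes=10, channels=(32, 64, 128)):
        super().__init__()
        layers = []
        c_in = input_channels
        for c_out in channels:
            layers += [
                torch.nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),  # keeps H x W unchanged
                torch.nn.ReLU(),
                torch.nn.MaxPool2d(kernel_size=2),                        # halves H and W
            ]
            c_in = c_out
        self.features = torch.nn.Sequential(*layers)
        self.classifier = torch.nn.Linear(c_in, num_classes)

    def forward(self, x):
        x = self.features(x)        # N x 128 x 4 x 4 for a 32 x 32 input
        x = x.mean(dim=(2, 3))      # global average pool over H and W -> N x 128
        return self.classifier(x)   # N x num_classes logits

    def predict(self, image):
        return self(image).argmax(1)

model_sketch = ExampleCNNClassifier(3, 10)
logits = model_sketch(torch.zeros(2, 3, 32, 32))
print(logits.shape)
```

To try it, pass an instance to the training function below in place of `CNNClassifier`, e.g. `train(ExampleCNNClassifier(3, 10).to(device), data_train, data_val, device)`. The global average pool makes the Linear layer independent of the input's spatial size, so changing the number of blocks only changes the Linear layer's input width.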

## Model Training and Validation

In [None]:
import time

def train(model, data_train, data_val, device, lr=0.001, epochs=10):
    """
    Train the model. Feel free to add arguments for additional model tuning.

    Input:
      model (torch.nn.Module): the model to train
      data_train (torch.utils.data.DataLoader): yields batches of training data
      data_val (torch.utils.data.DataLoader): use this to validate your model
      device (torch.device): which device to use to perform computation

      (optional) lr: learning rate hyperparameter
      (optional) epochs: number of passes over dataloader
    """
    # Setting up the tensorboard logger
    logger = tb.SummaryWriter(log_dir + '/{}'.format(time.strftime('%H-%M-%S')))
    global_step = 0

    # Setup the loss function to use
    loss_function = torch.nn.CrossEntropyLoss()

    # Setup the optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Loop over the training data for the given number of epochs.
    for epoch in range(epochs):
        # Set the model to training mode.
        model.train()

        for x, y in data_train:
            x = x.to(device)
            y = y.to(device)

            # Forward pass through the network
            output = model(x)

            # Compute loss
            loss = loss_function(output, y)
            
            # update model weights.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Add loss to TensorBoard.
            logger.add_scalar('loss', loss.item(), global_step=global_step)

            global_step += 1

        # Set the model to eval mode and compute accuracy.
        # No need to change this, but feel free to implement additional logging.
        model.eval()

        num_correct, num_total = 0, 0

        # Gradients are not needed for validation; no_grad saves memory and time.
        with torch.no_grad():
            for x, y in data_val:
                x = x.to(device)
                y = y.to(device)

                y_pred = model.predict(x)
                num_correct += (y_pred == y).sum().item()
                num_total += y.size(0)

        # Weight every image equally (the last batch may be smaller than batch_size).
        accuracy = num_correct / num_total

        logger.add_scalar('accuracy', accuracy, global_step=global_step)

        # Log the last validation batch, mapped from [-1, 1] back to [0, 1] for display.
        grid = torchvision.utils.make_grid(0.5*x+0.5)
        logger.add_image('images', grid, global_step=global_step)

    logger.close()


# Actually train the model here!
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = CNNClassifier(3, 10)
model.to(device)

train(model, data_train, data_val, device, lr=1e-3, epochs=10)

### References
[CNN image taken from blog post](https://towardsdatascience.com/a-comprehensive-guide-to-convolutional-neural-networks-the-eli5-way-3bd2b1164a53)

[PyTorch nn documentation](https://pytorch.org/docs/stable/nn.html)