## Coding 8: Binary Segmentation

In this exercise, we'll do segmentation (per-pixel classification).

We'll use a dataset of horses with their corresponding masks,  
and the goal is to train a network that predicts, for each pixel, whether it belongs to a horse or not.

<img src="http://cvlab.postech.ac.kr/research/deconvnet/images/overall.png" width=512px/>

Image: Learning Deconvolution Network for Semantic Segmentation, Noh et al. ICCV 2015


## Data Preparation

Run this cell to download the data.

In [None]:
# Remove any archive and extracted folder left over from a previous run.
!rm -f weizmann_horse_db.zip
!rm -rf weizmann_horse_db/
# gdown is the tool used on the next line to download from Google Drive.
!pip install gdown
# Presumably this file id resolves to weizmann_horse_db.zip, which is unzipped below.
!gdown https://drive.google.com/uc?id=1Tj7M8maC3QDkAf_dDtTFFUvtyzcuhlkg
# Extract quietly (-qq); later cells read from the weizmann_horse_db/ folder.
!unzip -qq weizmann_horse_db.zip

## Dataset

The dataset yields pairs of (image, label)

The image is a [3, 128, 128] float tensor of a horse.  
The label is a [1, 128, 128] 0-1 float tensor of the mask of the horse.

In [None]:
import numpy as np
import pathlib
import torch
import torchvision
import matplotlib.pyplot as plt

from PIL import Image


class WeizmannHorseDataset(torch.utils.data.Dataset):
    """In-memory dataset of horse photos and their segmentation masks.

    Every ``*.png`` under ``image_dir / 'horse'`` is paired with the ``*.png``
    at the same sorted position under ``image_dir / 'mask'``. All files are
    decoded once in the constructor and kept as PIL images; the transform is
    applied on every access in ``__getitem__``.

    Args:
        image_dir: Dataset root; must support the ``/`` operator and ``glob``
            (e.g. a ``pathlib.Path``).
        transform: Callable applied separately to the image and to the mask.
            When ``None``, ``Resize(128)`` followed by ``CenterCrop(128)``
            is used.

    Raises:
        ValueError: If the two folders contain a different number of PNG
            files, since positional pairing would then be unreliable.

    NOTE(review): the transform is invoked once for the image and once for
    the mask. A transform that draws random parameters would presumably draw
    them independently for each call and misalign the pair -- confirm before
    passing random augmentations.
    """

    def __init__(self, image_dir, transform=None):
        if transform is None:
            transform = torchvision.transforms.Compose([
                torchvision.transforms.Resize(128),
                torchvision.transforms.CenterCrop(128),
            ])
        self.transform = transform

        image_paths = sorted((image_dir / 'horse').glob('*.png'))
        mask_paths = sorted((image_dir / 'mask').glob('*.png'))
        # zip() would silently drop the surplus files and could pair an image
        # with the wrong mask, so fail loudly instead.
        if len(image_paths) != len(mask_paths):
            raise ValueError(
                'Found {} images but {} masks under {}'.format(
                    len(image_paths), len(mask_paths), image_dir))

        self.images = []
        self.masks = []
        for image_path, mask_path in zip(image_paths, mask_paths):
            # convert() returns a fully loaded, independent image, so the
            # underlying file can be closed as soon as the block exits.
            with Image.open(image_path) as raw_image:
                self.images.append(raw_image.convert('RGB'))
            with Image.open(mask_path) as raw_mask:
                self.masks.append(raw_mask.convert('L'))

    def __len__(self):
        """Return the number of (image, mask) pairs held in memory."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the transformed ``(image, mask)`` pair at position ``idx``.

        The image goes through ``to_tensor``; the mask is converted to a
        float32 tensor with a leading singleton dimension added, and its
        pixel values are otherwise left as stored in the file.
        """
        image = self.transform(self.images[idx])
        mask = self.transform(self.masks[idx])
        image = torchvision.transforms.functional.to_tensor(image)
        # [None] prepends a channel axis so the mask is [1, H, W].
        mask = torch.tensor(np.asarray(mask, np.float32)[None])
        return image, mask


def fetch_dataloader(batch_size, transform=None, split='train'):
    """
    Build a DataLoader over one split of the Weizmann horse dataset.

    The split is decided by each sample's index modulo 10:
    remainders 2-9 form 'train', remainder 1 forms 'val', and remainder 0
    is used for any other value of `split` (i.e. the test split).
    Only the 'train' loader shuffles its samples.
    """
    dataset = WeizmannHorseDataset(pathlib.Path('weizmann_horse_db'), transform=transform)
    is_train = split == 'train'

    def belongs_to_split(index):
        # Custom train/val/test assignment based on the index's last digit.
        bucket = index % 10
        if is_train:
            return bucket > 1
        if split == 'val':
            return bucket == 1
        return bucket == 0

    chosen = list(filter(belongs_to_split, range(len(dataset))))
    subset = torch.utils.data.Subset(dataset, chosen)
    return torch.utils.data.DataLoader(
        subset,
        batch_size=batch_size,
        shuffle=is_train,
        num_workers=0,
        drop_last=False,
    )


# One loader per split; fetch_dataloader shuffles only when split == 'train'.
# Train and val use batches of 32, test uses batches of 128.
data_train = fetch_dataloader(32, split='train')
data_val = fetch_dataloader(32, split='val')
data_test = fetch_dataloader(128, split='test')

In [None]:
# Setup the evaluation metric
class Metric:
    """
       Compute two metrics of success.
       Average accuracy and intersection over union (also called Jaccard index)

       Call `add` once per batch; `accuracy` and `iou` then report the values
       accumulated over every batch seen so far. Both return 0 before any
       data has been added.
    """
    def __init__(self):
        self.count, self.correct, self.intersection, self.union = 0, 0, 0, 0

    def add(self, pred, label):
        """
        Accumulate statistics for one batch.

        Both tensors are binarised with `> 0`, so `pred` can be raw logits.

        Raises:
            ValueError: if `pred` and `label` differ in shape. Broadcasting
                would otherwise inflate `correct` relative to `count`.
        """
        if pred.shape != label.shape:
            raise ValueError(
                'pred and label must have the same shape, got {} and {}'.format(
                    tuple(pred.shape), tuple(label.shape)))
        pred = pred > 0
        label = label > 0
        # Summing the boolean tensors directly gives exact integer counts,
        # whereas a float32 sum loses precision beyond 2**24 elements.
        self.intersection += float((pred & label).sum())
        self.union += float((pred | label).sum())
        self.count += float(pred.numel())
        self.correct += float((label == pred).sum())

    @property
    def iou(self):
        """Intersection over union of the positive class (0 if the union is empty)."""
        return self.intersection / max(self.union, 1)

    @property
    def accuracy(self):
        """Fraction of elements whose binarised prediction matches the label."""
        return self.correct / max(self.count, 1)

In [None]:
def visualize(data):
    """
    Plot the first four images of the first batch above their masks.

    The single-channel masks are repeated three times along the channel
    axis so they can share a grid with the 3-channel images.
    """
    batch_images, batch_masks = next(iter(data))
    rgb_masks = batch_masks.repeat(1, 3, 1, 1)

    # Top row: images, bottom row: masks (four tiles per row).
    tiles = list(batch_images[:4]) + list(rgb_masks[:4])
    grid = torchvision.utils.make_grid(tiles, nrow=4)

    # make_grid returns channels-first; imshow wants channels-last.
    plt.imshow(grid.permute(1, 2, 0))
    plt.show()


# Show a sample grid from the first batch of the training and validation loaders.
visualize(data_train)
visualize(data_val)

## TensorBoard Setup

Same old embedded tensorboard code.

In [None]:
import torch.utils.tensorboard as tb
import tempfile

# Fresh temporary directory; train() writes its TensorBoard runs below it.
log_dir = tempfile.mkdtemp()

# Load the notebook extension and launch TensorBoard on log_dir.
%load_ext tensorboard
%tensorboard --logdir {log_dir} --reload_interval 1

## Model and Training


In [None]:
import time

class Network(torch.nn.Module):
    """
    This is the main network for this task.
    
    Note that since we are looking at a segmentation task with only two classes,
    we are using a slightly different output format than we have looked at
    before. We will only output a single value per pixel, and if the value is
    less than zero, we classify that pixel as 'background', if it's greater than
    zero we classify that pixel as 'horse'. This reduces the complexity of
    training slightly since we only need to produce one output per pixel rather
    than two.

    The model is a single 5x5 convolution with 'same' padding, so the output
    keeps the spatial size of the input.

    Args:
        input_channels: number of channels of the input image (default 3).
        output_channels: number of values predicted per pixel (default 1).
    """
    
    def __init__(self, input_channels=3, output_channels=1):
        super().__init__()
        # Honour the constructor arguments instead of hard-coding 3 -> 1.
        self.net = torch.nn.Conv2d(input_channels, output_channels, 5, padding='same')

    def forward(self, x):
        """
        Translate the image to per-pixel scores for the horse mask.

        Input: 
            x (float tensor N x input_channels x H x W): input image of a horse
        Output:
            y (float tensor N x output_channels x H x W): raw scores;
              values > 0 are classified as 'horse'
        """
        return self.net(x)


def train(model, lr=1e-4, epochs=50):
    """
    Train `model` on `data_train` and evaluate it on `data_val` after each epoch.

    Relies on the module-level names `data_train`, `data_val`, `device`,
    `log_dir`, `tb` and `Metric`. Per-step training loss, per-epoch
    accuracy/IoU for both splits, and a few validation images are written to
    a new TensorBoard run under `log_dir`.

    Args:
        model: network mapping an image batch to one raw score per pixel.
        lr: learning rate handed to the AdamW optimizer.
        epochs: number of passes over the training loader.
    """
    # Setting up the tensorboard logger
    logger = tb.SummaryWriter(log_dir + '/{}'.format(time.strftime('%H-%M-%S')), flush_secs=1)
    global_step = 0

    # Pick a loss and optimizer
    # We haven't seen BCEWithLogits before -- it is similar to CrossEntropyLoss,
    # but it is designed for binary tasks.
    loss_fun = torch.nn.BCEWithLogitsLoss()
    # Pass lr explicitly; otherwise AdamW falls back to its own default rate.
    optim = torch.optim.AdamW(model.parameters(), lr=lr)

    try:
        # Train the model
        for epoch in range(epochs):
            print("Epoch:", epoch)
            # Train for an epoch
            model.train()
            metric = Metric()
            for image, label in data_train:
                # Move image, label to GPU
                image, label = image.to(device), label.to(device)

                # Compute network output
                pred = model(image)

                # Compute loss
                loss_val = loss_fun(pred, label)

                metric.add(pred, label)

                # Zero gradient
                optim.zero_grad()
                # Backward
                loss_val.backward()
                # Step optim
                optim.step()
                # Logging
                logger.add_scalar('train/loss', float(loss_val), global_step=global_step)
                global_step += 1

            logger.add_scalar('train/accuracy', float(metric.accuracy), global_step=global_step)
            logger.add_scalar('train/iou', float(metric.iou), global_step=global_step)

            # Evaluate the model
            model.eval()
            metric = Metric()
            # No gradients are needed for evaluation; skipping the autograd
            # graph saves memory and time.
            with torch.no_grad():
                for it, (image, label) in enumerate(data_val):
                    # Move image, label to GPU
                    image, label = image.to(device), label.to(device)

                    # Compute network output
                    pred = model(image)

                    metric.add(pred, label)

                    # Log example images for the first validation batch only.
                    if it == 0:
                        logger.add_image('val/image', torchvision.utils.make_grid(image[:8], nrow=4), global_step=global_step)
                        logger.add_image('val/pred', torchvision.utils.make_grid(torch.sigmoid(pred[:8]), nrow=4), global_step=global_step)
                        logger.add_image('val/output', torchvision.utils.make_grid((pred[:8]>0).float(), nrow=4), global_step=global_step)
            logger.add_scalar('val/accuracy', float(metric.accuracy), global_step=global_step)
            logger.add_scalar('val/iou', float(metric.iou), global_step=global_step)
    finally:
        # Flush pending events and release the writer even if training is interrupted.
        logger.close()


# Actually train the model here!
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# input_channels=3, output_channels=1
model = Network(3, 1)
model.to(device)

# Uses train()'s default lr and epochs.
train(model)