## Coding 09: Keypose Estimation

In this exercise, we'll do keypose estimation on the [Leeds Sport Pose Dataset (LSP)](https://sam.johnson.io/research/lsp.html).

Keypose estimation is a detection task and one of the most popular tasks in core computer vision. There are many different ways to tackle this task; we'll explore the single-stage detection methods (YOLO, RetinaNet, CenterNet). We'll use a dataset of sports athletes with their corresponding joint labels, and the goal is to train a network that outputs the xy position of each joint.

<img src="https://www.ee.cuhk.edu.hk/~xgwang/StructureFeature/QualitiveRes.png" width=720px/>

Image: Structured Feature Learning for Pose Estimation, Chu et al., CVPR 2016


## Data Preparation

Run this cell to download the data.

In [None]:
# Delete any previously downloaded archive so wget does not save a renamed duplicate.
!rm -f lsp_dataset_original.zip
# Fetch the original LSP dataset archive into the current directory.
!wget http://sam.johnson.io/research/lsp_dataset_original.zip
# Extract quietly (-q), overwriting existing files without prompting (-o).
!unzip -oq lsp_dataset_original.zip

## Dataset

The dataset yields pairs of (image, label): 
 * The image is a [3, 128, 128] float tensor of an athlete doing something cool.  
 * The label is a [14, 2] float tensor holding the xy position of each joint in normalized coordinates (each value in the range -1 to 1)

#### Labels

There are 14 different joints; `label[0]` is the xy coordinate of joint `0`, which might be, for example, the head or the right wrist.

### Think about...

How would you implement additional transforms such as  

* RandomCrop
* RandomHorizontalFlip

In [None]:
import pathlib
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision

from scipy.io import loadmat
from PIL import Image, ImageDraw


# Pairs of joint indices; visualize_sample draws one line per pair to render
# the skeleton on top of the image.
# NOTE(review): the mapping from joint index to body part is not defined in
# this notebook — confirm it against the LSP dataset README.
EDGES = [[0, 1], [1, 2], [2, 3], [3, 4], [4, 5],  [6, 7], [7, 8],
         [8, 12], [12, 9], [9, 10], [10, 11], [12, 13], [2, 8], [3, 9]]

class MyResize:
    """Joint transform: resize the image and normalize its joint labels."""

    def __init__(self, size):
        # Target size, forwarded unchanged to torchvision's resize.
        self.size = size

    def __call__(self, image, label):
        """
        Resize `image` to `self.size` and rescale `label` to [-1, 1].

        Labels arrive in pixel coordinates of the *original* image:

        label[:, 0] is the x coordinate, in the range [0, width]
        label[:, 1] is the y coordinate, in the range [0, height]

        Only these two columns are kept in the returned label.

        We map them to [-1, 1] for several reasons - can you think of any?
        """
        # Normalize against the size before resizing; normalized coordinates
        # stay valid for the resized image.
        width, height = image.width, image.height
        norm_x = 2.0 * (label[:, 0] / width - 0.5)
        norm_y = 2.0 * (label[:, 1] / height - 0.5)
        normalized = np.stack([norm_x, norm_y], axis=1)

        resized = torchvision.transforms.functional.resize(image, self.size)

        return resized, normalized


class LSPDataset(torch.utils.data.Dataset):
    """
    Dataset of (image, label) pairs read from an extracted LSP archive.

    Expects `dataset_dir` to contain an `images/` folder of `*.jpg` files and
    a `joints.mat` file with a `joints` array.
    """

    def __init__(self, dataset_dir):
        root = pathlib.Path(dataset_dir)

        self.transform = MyResize((128, 128))
        self.to_tensor = torchvision.transforms.ToTensor()

        # Sorted so image i lines up with label i.
        # NOTE(review): assumes file-name order matches the order in joints.mat — confirm.
        self.images = sorted((root / 'images').glob('*.jpg'))

        # Reverse the axis order so the first axis can be indexed by sample.
        joints = loadmat(str(root / 'joints.mat'))['joints']
        self.labels = joints.transpose(2, 1, 0).astype(np.float32)

    def __len__(self):
        """Number of images found on disk."""
        return len(self.images)

    def __getitem__(self, i):
        """Return the i-th (image tensor, normalized joint label) pair."""
        picture = Image.open(self.images[i])
        picture, joints = self.transform(picture, self.labels[i])

        return self.to_tensor(picture), joints


def fetch_dataloader(batch_size, transform=None, split='train'):
    """
    Build a DataLoader over one split of the LSP data in the current directory.

    A DataLoader behaves much like a list of (image, label) batches.

    The split is decided by the sample index modulo 10:
    'train' keeps remainders 2-9, 'val' keeps remainder 1, and any other
    value of `split` keeps remainder 0 (the test split).

    `transform` is accepted but currently unused. Only the 'train' split is
    shuffled.
    """
    dataset = LSPDataset(dataset_dir='.')

    def in_split(index):
        # Custom train/val/test split on the index remainder.
        remainder = index % 10
        if split == 'train':
            return remainder > 1
        if split == 'val':
            return remainder == 1
        return remainder == 0

    chosen = [index for index in range(len(dataset)) if in_split(index)]
    subset = torch.utils.data.Subset(dataset, chosen)

    return torch.utils.data.DataLoader(
        subset,
        batch_size=batch_size,
        shuffle=(split == 'train'),
        num_workers=0)


@torch.no_grad()
def visualize_sample(image, label):
    """
    Draw the skeleton described by `label` on top of `image`.

    `image` is a channel-first float tensor scaled by 255 for display;
    `label` holds normalized [-1, 1] xy coordinates per joint.
    Returns the annotated picture as a tensor (via torchvision's to_tensor).
    """
    # Channel-first tensor -> channel-last uint8 array that PIL understands.
    pixels = image.cpu().numpy().transpose(1, 2, 0)
    canvas = Image.fromarray((255 * pixels).astype(np.uint8))
    pen = ImageDraw.Draw(canvas)

    def to_pixels(joint):
        # Undo the [-1, 1] normalization using the size of the drawn image.
        x = (label[joint, 0].item() / 2.0 + 0.5) * canvas.width
        y = (label[joint, 1].item() / 2.0 + 0.5) * canvas.height
        return x, y

    for start, end in EDGES:
        pen.line((*to_pixels(start), *to_pixels(end)), fill=(255, 0, 0))

    return torchvision.transforms.functional.to_tensor(np.array(canvas))


# Render matplotlib figures inside the notebook.
%matplotlib inline

# One loader per split; these globals are read by train() further down.
data_train = fetch_dataloader(32, split='train')
data_val = fetch_dataloader(32, split='val')
data_test = fetch_dataloader(32, split='test')

# Sanity check: draw the ground-truth skeleton on the first few samples of
# one training batch.
num_to_show = 4
vis = []
vis_data = next(iter(data_train))
for i in range(num_to_show):
    # vis_data[0] holds the batch of images, vis_data[1] the matching labels.
    image, label = vis_data[0][i], vis_data[1][i]
    vis.append(visualize_sample(image, label))

# Tile the annotated samples into a single row and show it.
vis = torchvision.utils.make_grid(vis, nrow=num_to_show)
plt.figure(figsize=(num_to_show * 5, 5))
# make_grid returns a channel-first tensor; imshow expects channel-last.
plt.imshow(vis.permute(1, 2, 0))
plt.show()

## TensorBoard Setup


In [None]:
import torch.utils.tensorboard as tb
import tempfile

# Fresh temporary directory for this session's TensorBoard runs; train()
# creates one sub-directory per call underneath it.
log_dir = tempfile.mkdtemp()

# Load the notebook extension and start TensorBoard on log_dir.
%load_ext tensorboard
%tensorboard --logdir {log_dir} --reload_interval 1

## Model and Training

Some questions to wonder about -

* how should the network be structured?
* what loss should we use?

In [None]:
import time

class SoftArgmax(torch.nn.Module):
    """
    Differentiable keypoint head.

    A 1x1 convolution produces one score map per joint; each map is turned
    into a spatial probability distribution and reduced to its expected
    (x, y) location in normalized [-1, 1] coordinates.
    """

    def __init__(self, input_channels=3, num_joints=14):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(input_channels, num_joints, kernel_size=1)

    def forward(self, x):
        """Map features (N, C, H, W) to joint coordinates (N, num_joints, 2)."""
        scores = self.conv1(x)

        # Softmax over all spatial positions of each joint's score map.
        probs = torch.softmax(scores.flatten(2), dim=-1).view_as(scores)

        # Coordinate grids: -1 is the left/top edge, +1 the right/bottom edge.
        grid_x = torch.linspace(-1, 1, probs.shape[-1], device=scores.device)
        grid_y = torch.linspace(-1, 1, probs.shape[-2], device=scores.device)

        # Expected coordinate under each heatmap.
        expected_x = (probs * grid_x.view(1, 1, 1, -1)).sum(dim=(2, 3))
        expected_y = (probs * grid_y.view(1, 1, -1, 1)).sum(dim=(2, 3))

        return torch.stack((expected_x, expected_y), dim=-1)

class Network(torch.nn.Module):
    """
    Minimal keypoint regressor: input normalization -> strided 5x5 conv ->
    ReLU -> soft-argmax head.
    """

    def __init__(self, input_channels=3, num_joints=14):
        super().__init__()
        # forward() normalizes its input through `input_norm`, but the layer
        # was never created, so every forward pass raised AttributeError.
        # A learnable per-channel batch norm is used here, which means the
        # model behaves differently in train() and eval() mode.
        self.input_norm = torch.nn.BatchNorm2d(input_channels)
        # stride=2 halves the spatial resolution before the heatmap head.
        self.conv1 = torch.nn.Conv2d(input_channels, 16, 5, padding=2, stride=2)
        self.soft_argmax = SoftArgmax(16, num_joints)

    def forward(self, x):
        """
        Predict joint positions from the image.

        Input:
            x (float tensor N x 3 x 128 x 128): input image of a person
        Output:
            y (float tensor N x 14 x 2): xy coordinates of joint positions
        """
        x = self.conv1(self.input_norm(x))
        return self.soft_argmax(torch.relu(x))


def visualize(logger, image, pred, label, global_step):
    """
    Log two image grids to TensorBoard for one batch: the predicted
    skeletons under tag 'pred' and the ground-truth skeletons under 'label'.
    """
    # Zip all three together so both grids cover exactly the same samples.
    samples = list(zip(image, pred, label))

    image_preds = [visualize_sample(picture, guess) for picture, guess, _ in samples]
    image_labels = [visualize_sample(picture, truth) for picture, _, truth in samples]

    pred_grid = torchvision.utils.make_grid(image_preds)
    label_grid = torchvision.utils.make_grid(image_labels)

    logger.add_image('pred', pred_grid, global_step=global_step)
    logger.add_image('label', label_grid, global_step=global_step)


def calculate_pdj(pred, label):
    """
    Fraction of joints predicted within 5% of the pose's bounding-box diagonal.

    For every sample the bounding box of its ground-truth joints gives a
    diagonal length; a joint counts as detected when its predicted position
    is closer than 0.05 * diagonal to the ground truth. The result is the
    mean over all joints of all samples, returned as a scalar tensor.
    """
    # Per-sample bounding-box extent of the ground-truth joints.
    extent = label.max(dim=1).values - label.min(dim=1).values
    box_w, box_h = extent[:, 0], extent[:, 1]
    # Trailing axis lets the per-sample threshold broadcast over joints.
    threshold = 0.05 * torch.sqrt(box_w ** 2 + box_h ** 2).unsqueeze(1)

    # Euclidean distance between prediction and ground truth, per joint.
    dx = pred[:, :, 0] - label[:, :, 0]
    dy = pred[:, :, 1] - label[:, :, 1]
    distance = torch.sqrt(dx ** 2 + dy ** 2)

    detected = distance < threshold
    return detected.float().mean()


def train(model, device, lr=1e-3, epochs=10, loss_fun=None):
    """
    Train `model` on `data_train` and evaluate it on `data_val` every epoch.

    Per-step 'train/pdj' and 'train/loss', per-epoch 'val/pdj', and
    prediction/label image grids for the last validation batch are written
    to a new TensorBoard run directory under `log_dir`.

    Input:
        model: network mapping an image batch to joint coordinates
        device: torch device the model and batches are moved to
        lr: learning rate for AdamW
        epochs: number of passes over the training data
        loss_fun: callable (pred, label) -> scalar loss. When None, an L1
            loss on the joint coordinates is used as a reasonable default;
            pass your own to experiment with other choices.
    """
    # The loss was previously left as None, so the first batch failed with
    # "'NoneType' object is not callable".
    if loss_fun is None:
        loss_fun = torch.nn.L1Loss()

    # Setting up the tensorboard logger
    logger = tb.SummaryWriter(log_dir + '/{}'.format(time.strftime('%H-%M-%S')), flush_secs=1)
    global_step = 0

    # Move the model before the first batch so inputs and weights share a device.
    model.to(device)
    optim = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)

    for epoch in range(epochs):
        print("Epoch:", epoch)
        model.train()

        for image, label in data_train:
            image, label = image.to(device), label.to(device)

            pred = model(image)
            loss_val = loss_fun(pred, label)
            optim.zero_grad()
            loss_val.backward()
            optim.step()

            with torch.no_grad():
                pdj = calculate_pdj(pred, label)
            logger.add_scalar('train/pdj', float(pdj), global_step=global_step)
            logger.add_scalar('train/loss', float(loss_val), global_step=global_step)
            global_step += 1

        # Compute the average pdj on the validation set. This used to iterate
        # data_test, which leaked the test split into model selection.
        model.eval()
        val_pdj = []
        with torch.no_grad():
            for image, label in data_val:
                image = image.to(device)
                label = label.to(device)
                pred = model(image)
                val_pdj.append(calculate_pdj(pred, label).item())

        # Skip logging when the validation split is empty: the mean would be
        # NaN and the visualized batch would come from the training loop.
        if val_pdj:
            logger.add_scalar('val/pdj', torch.FloatTensor(val_pdj).mean().item(), global_step=global_step)
            visualize(logger, image, pred, label, global_step)

    # Flush pending events and release the event file.
    logger.close()


# Actually train the model here!
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 3 input channels, 14 joints.
model = Network(3, 14)
model.to(device)

# Runs with train()'s default learning rate and epoch count.
train(model, device)