In [None]:
import tempfile
import torch.utils.tensorboard as tb

# TensorBoard event files are written under this fixed, relative directory so runs
# land in a predictable place and can be compared across sessions.
# For a throwaway location use `log_dir = tempfile.mkdtemp()` instead.
log_dir = 'lr_log_dir'

# Launch TensorBoard inline; `{log_dir}` is interpolated by IPython, and
# --reload_interval 1 makes TensorBoard re-read event files every second so
# scalars show up while training is still running.
%load_ext tensorboard
%tensorboard --logdir {log_dir} --reload_interval 1

In [None]:
import torchvision
import torchvision.transforms as transforms
import torch
import numpy as np


# Human-readable CIFAR-10 label names. Presumably ordered by target id
# (0 = plane ... 9 = truck) — verify against torchvision.datasets.CIFAR10.classes.
classes = 'plane car bird cat deer dog frog horse ship truck'.split()


def fetch_dataloader(batch_size, transform=None, is_train=True):
    data = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    indices = [i for i in range(len(data)) if (i%10 > 0) == is_train]
    data = torch.utils.data.Subset(data, indices)
    loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True, num_workers=2)
    return loader


# Data-pipeline configuration.
BATCH_SIZE = 64
# ToTensor() scales pixels to [0, 1]; normalizing each channel with mean 0.5 and
# std 0.5 maps them to [-1, 1].
NORM_MEAN = (0.5, 0.5, 0.5)
NORM_STD = (0.5, 0.5, 0.5)


def _base_transform():
    """Return a fresh ToTensor + Normalize pipeline shared by train and validation."""
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(NORM_MEAN, NORM_STD),
    ])


# Identical for now (no augmentation). They are kept as separate objects so
# train-time augmentation can later be added to train_transform alone.
train_transform = _base_transform()
val_transform = _base_transform()

data_train = fetch_dataloader(BATCH_SIZE, train_transform, is_train=True)
data_val = fetch_dataloader(BATCH_SIZE, val_transform, is_train=False)

In [None]:
class ResNetBlock(torch.nn.Module):
    """Residual block computing ``relu(conv2(relu(conv1(x)))) + skip(x)``.

    ``conv1`` is a 3x3 convolution that keeps ``c_in`` channels and the spatial size.
    ``conv2`` is a 3x3 convolution mapping to ``c_out`` channels with the given
    ``stride``. When the shape changes (``c_in != c_out`` or ``stride != 1``), the
    skip path is a strided 1x1 convolution so both branches can be summed. Otherwise
    the skip path is the identity.

    Args:
        c_in: number of input channels.
        c_out: number of output channels.
        stride: stride of the second convolution and of the skip projection.
    """

    def __init__(self, c_in, c_out, stride=1):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Conv2d(c_in, c_in, 3, padding='same'),
            torch.nn.ReLU(),
            torch.nn.Conv2d(c_in, c_out, 3, stride=stride, padding=1),
            torch.nn.ReLU()
        )
        self.skip = torch.nn.Identity()
        if c_in != c_out or stride != 1:
            self.skip = torch.nn.Conv2d(c_in, c_out, 1, stride=stride)

        # Initialize BOTH convolutions (net[0] and net[2]) with Xavier-normal
        # weights. sqrt(2) is the recommended gain for layers followed by ReLU
        # (torch.nn.init.calculate_gain('relu')).
        for conv in (self.net[0], self.net[2]):
            torch.nn.init.xavier_normal_(conv.weight, gain=np.sqrt(2))

    def forward(self, x):
        """Return ``net(x) + skip(x)`` (``c_out`` channels, spatial size divided by ``stride``)."""
        return self.net(x) + self.skip(x)

class ResNetClassifier(torch.nn.Module):
    """Small residual CNN image classifier.

    Layers, in order:
    - a 7x7 stem convolution (``input_channels -> channels[0]``) followed by ReLU;
    - for each consecutive pair of widths in ``channels``, a ``ResNetBlock`` and a
      2x2 max-pool that halves the spatial size;
    - global average pooling over H and W;
    - a linear classifier.

    Args:
        input_channels: number of channels in the input images (3 for RGB).
        num_classes: number of output logits.
        channels: widths of the stem and of the successive residual stages.
    """

    def __init__(self, input_channels, num_classes, channels=(32, 64, 128)):
        super().__init__()
        layers = [
            torch.nn.Conv2d(input_channels, channels[0], 7, stride=1, padding=3),
            torch.nn.ReLU(),
        ]
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            layers.append(ResNetBlock(c_in, c_out, stride=1))
            layers.append(torch.nn.MaxPool2d(2, stride=2))
        self.conv = torch.nn.Sequential(*layers)
        self.cls = torch.nn.Linear(channels[-1], num_classes)

        # sqrt(2) is the recommended Xavier gain for a layer followed by ReLU.
        torch.nn.init.xavier_normal_(self.conv[0].weight, gain=np.sqrt(2))
        # Zero classifier weights: initial logits depend only on the bias.
        torch.nn.init.zeros_(self.cls.weight)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes).

        Assumes ``x`` is (batch, input_channels, H, W).
        """
        features = self.conv(x)
        pooled = features.mean(dim=(2, 3))  # global average pool over H and W
        return self.cls(pooled)

    def predict(self, image):
        """Return the index of the highest-scoring class for each image in the batch.

        Runs without gradient tracking. Argmax indices are not differentiable, so an
        autograd graph here would only waste memory.
        """
        with torch.no_grad():
            return self(image).argmax(1)

In [None]:
import time

def _evaluate_accuracy(model, loader, device):
    """Return the fraction of samples in ``loader`` that ``model.predict`` classifies correctly.

    Switches the model to eval mode and disables gradient tracking. Accuracy is
    weighted by sample count (total correct / total seen), so a smaller final batch
    does not skew it. Returns 0.0 for an empty loader.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            correct += (model.predict(x) == y).sum().item()
            total += y.numel()
    return correct / total if total else 0.0


def train(model, data_train, data_val, device, lr=0.001, epochs=10):
    """Train ``model`` with SGD + momentum and log progress to TensorBoard.

    Each epoch makes one pass over ``data_train`` and then measures validation
    accuracy on ``data_val``. ReduceLROnPlateau (mode='max', patience=2, relative
    threshold 0.01) lowers the learning rate when validation accuracy plateaus.

    Events go to a run directory named ``<log_dir>/<HH-MM-SS>``; this relies on the
    notebook-level ``log_dir``. Logged scalars:
    - 'loss' after every step;
    - 'lr', 'train_accuracy' and 'accuracy' (validation) once per epoch.

    Args:
        model: classifier exposing ``predict(x)`` in addition to ``forward``.
        data_train: iterable of ``(x, y)`` training batches.
        data_val: iterable of ``(x, y)`` validation batches.
        device: device the batches are moved to (the model is assumed to be there already).
        lr: initial learning rate.
        epochs: number of passes over ``data_train``.
    """
    logger = tb.SummaryWriter(log_dir + '/{}'.format(time.strftime('%H-%M-%S')))
    global_step = 0

    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)
    # Alternative schedule: StepLR(optimizer, step_size=10, gamma=0.1), stepped
    # with scheduler.step() (no metric argument).
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=2, threshold=0.01)

    try:
        for epoch in range(epochs):
            current_lr = optimizer.param_groups[0]['lr']
            print("Epoch:", epoch, 'lr=', current_lr)
            logger.add_scalar('lr', current_lr, global_step=global_step)

            model.train()
            train_correct, train_total = 0, 0
            for x, y in data_train:
                x = x.to(device)
                y = y.to(device)

                output = model(x)
                loss = loss_function(output, y)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                logger.add_scalar('loss', loss.item(), global_step=global_step)
                # Sample-weighted running accuracy (not a mean of per-batch means).
                train_correct += (output.argmax(dim=1) == y).sum().item()
                train_total += y.numel()
                global_step += 1

            train_acc = train_correct / train_total if train_total else 0.0
            logger.add_scalar('train_accuracy', train_acc, global_step=global_step)

            accuracy = _evaluate_accuracy(model, data_val, device)
            logger.add_scalar('accuracy', accuracy, global_step=global_step)
            # ReduceLROnPlateau in mode='max' needs the monitored metric.
            scheduler.step(accuracy)
    finally:
        # Flush pending events and release the event file, even if training is interrupted.
        logger.close()


# Run configuration. The fixed seed makes weight initialization and shuffle order
# repeatable; cuDNN kernels may still add small nondeterminism on GPU.
SEED = 0
LEARNING_RATE = 5e-2
EPOCHS = 30

torch.manual_seed(SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = ResNetClassifier(3, len(classes), channels=[32, 64, 128])
model.to(device)

train(model, data_train, data_val, device, lr=LEARNING_RATE, epochs=EPOCHS)