In [None]:
%pylab inline
import torch
import torchvision
from torchvision import transforms
# Prefer the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('device = ', device)

In [None]:
# This is the same data loading code from our in-class exercises
def fetch_dataloader(batch_size, transform=None, is_train=True):
    """Return a shuffled DataLoader over one side of a 9:1 split of CIFAR-10.

    The official CIFAR-10 training set (``train=True``) is downloaded to
    ``./data`` if necessary and split by sample index: samples whose index is
    a multiple of 10 form the held-out part, all remaining samples form the
    training part.

    Args:
        batch_size: Number of samples per batch.
        transform: Passed through unchanged to ``torchvision.datasets.CIFAR10``.
        is_train: If True, load the training part of the split; otherwise load
            the held-out part.

    Returns:
        A ``torch.utils.data.DataLoader`` with ``shuffle=True`` and two worker
        processes.
    """
    full_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

    # Index i belongs to the training part exactly when i is not a multiple of 10.
    selected = [idx for idx in range(len(full_set)) if (idx % 10 != 0) == is_train]

    subset = torch.utils.data.Subset(full_set, selected)
    return torch.utils.data.DataLoader(subset, batch_size=batch_size, shuffle=True, num_workers=2)

In [None]:
# Here we simply load the images without normalization
train_transform = transforms.ToTensor()
data_train = fetch_dataloader(64, train_transform, is_train=True)

# Now compute per-channel image statistics. Since this takes a bit of time to
# run, what I would do is just run this once then hardcode the output into the
# transform for future calls.
#
# We accumulate raw per-pixel sums and divide by the true pixel count at the
# end. Averaging the per-batch means instead would weight every batch equally,
# which is slightly wrong whenever the last batch is smaller than the others.
# The accumulators are float64 so that adding many batches does not lose
# precision.
sums = torch.zeros(3, dtype=torch.float64)
sq_sums = torch.zeros(3, dtype=torch.float64)
n_pixels = 0
for images, _ in data_train:
    # Reduce over every axis except dim 1 (the channel axis).
    sums += images.sum(dim=(0, 2, 3))
    sq_sums += (images**2).sum(dim=(0, 2, 3))
    # Number of values that went into each channel's sum for this batch.
    n_pixels += images.shape[0] * images.shape[2] * images.shape[3]

mean64 = sums / n_pixels
# Std[X] = sqrt(E[X^2] - E[X]^2)
std = ((sq_sums / n_pixels - mean64**2)**0.5).float()
mean = mean64.float()

print("Mean:", mean)
print("Std: ", std)

In [None]:
# Now I put the results into a normalize call. Every time I want to run this
# code from now on, I would use this cell instead of the previous one.

# Per-channel statistics, hardcoded from the output of the statistics cell.
CIFAR_MEAN = (0.4916, 0.4824, 0.4468)
CIFAR_STD = (0.2471, 0.2436, 0.2618)
BATCH_SIZE = 64

normalize = transforms.Normalize(CIFAR_MEAN, CIFAR_STD)
train_transform = transforms.Compose([transforms.ToTensor(), normalize])

# The same transform is used for both parts of the split.
data_train = fetch_dataloader(BATCH_SIZE, train_transform, is_train=True)
data_valid = fetch_dataloader(BATCH_SIZE, train_transform, is_train=False)

In [None]:
# Now let's look at Batch Normalization within the network.

class ConvNet(torch.nn.Module):
    """Small convolutional network with BatchNorm, producing one scalar per image.

    A strided 7x7 convolution followed by ReLU and max-pooling forms the stem.
    One ``Block`` (with stride 2) is then stacked per entry of ``layers``. The
    final feature map is averaged over its spatial dimensions and fed to a
    linear layer with a single output.
    """

    class Block(torch.nn.Module):
        """Two (BatchNorm -> 3x3 Conv -> ReLU) stages; only the first conv is strided."""

        def __init__(self, n_input, n_output, stride=1):
            """
            Args:
                n_input: Number of input channels.
                n_output: Number of output channels of both convolutions.
                stride: Stride of the first convolution (the second uses stride 1).
            """
            super().__init__()
            self.net = torch.nn.Sequential(
                # In this notebook I've put the normalizations before the convolutional layers. If
                # you want to do them afterward, comment the line below and uncomment the later
                # BatchNorm2d layer. In this case you will also want to set bias=False in the Conv2d
                # constructor.
                torch.nn.BatchNorm2d(n_input, affine=False),
                torch.nn.Conv2d(n_input, n_output, kernel_size=3, padding=1, stride=stride),
                #torch.nn.BatchNorm2d(n_output),
                torch.nn.ReLU(),
                torch.nn.BatchNorm2d(n_output, affine=False),
                torch.nn.Conv2d(n_output, n_output, kernel_size=3, padding=1),
                #torch.nn.BatchNorm2d(n_output),
                torch.nn.ReLU()
            )

        def forward(self, x):
            """Apply both normalize-convolve-activate stages to ``x``."""
            return self.net(x)

    # The default for `layers` is a tuple rather than a list so that the default
    # object shared between calls cannot be mutated. Any iterable of ints works.
    def __init__(self, layers=(32, 64, 128), n_input_channels=3):
        """
        Args:
            layers: Output channel count of each successive ``Block``.
            n_input_channels: Number of channels of the input images.
        """
        super().__init__()
        modules = [torch.nn.Conv2d(n_input_channels, 32, kernel_size=7, padding=3, stride=2),
                   # If you want to do BatchNorm after the convolutional layers, then you should add
                   # one here too. Remember to set bias=False in the convolution above.
                   #torch.nn.BatchNorm2d(32),
                   torch.nn.ReLU(),
                   torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
                  ]
        # Channel count produced by the most recently added module (the stem outputs 32).
        c = 32
        for width in layers:
            modules.append(self.Block(c, width, stride=2))
            c = width
        self.network = torch.nn.Sequential(*modules)
        self.classifier = torch.nn.Linear(c, 1)

    def forward(self, x):
        """Return a 1-D tensor holding one output value per sample in ``x``."""
        # Compute the features
        z = self.network(x)
        # Global average pooling
        z = z.mean(dim=[2,3])
        # Classify; the linear layer has a single output, so drop that axis
        return self.classifier(z)[:,0]
    
net = ConvNet()
# When using BatchNorm, it is important to switch the network explicitly between
# training mode and evaluation mode, because BatchNorm behaves differently in
# each. The `training` flag shows which mode is active: it prints True after
# train() and False after eval(), so the network is left in evaluation mode.
for switch_mode in (net.train, net.eval):
    switch_mode()
    print(net.training)