In [1]:
import torch
import numpy as np

In [2]:
class ConvNet(torch.nn.Module):
    """Small convolutional classifier used to demonstrate custom weight initialization.

    The network is a strided 7x7 stem convolution followed by one ``Block`` per
    consecutive pair in ``channels``. The resulting feature map is averaged over
    dimensions 2 and 3 and fed to a single linear classifier.
    """

    class Block(torch.nn.Module):
        """Two 3x3 convolutions, each followed by a ReLU.

        The first convolution keeps the channel count and spatial size
        (``padding='same'``); the second maps ``in_channels`` to ``out_channels``
        and applies ``stride``.
        """

        def __init__(self, in_channels, out_channels, stride):
            """Build the block and re-initialize both convolutions.

            Args:
                in_channels: Number of channels of the incoming feature map.
                out_channels: Number of channels produced by the second convolution.
                stride: Stride of the second convolution.
            """
            super().__init__()
            self.net = torch.nn.Sequential(
                torch.nn.Conv2d(in_channels, in_channels, 3, padding='same'),
                torch.nn.ReLU(),
                torch.nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1),
                torch.nn.ReLU(),
            )

            for conv in (self.net[0], self.net[2]):
                # Use Xavier initialization for the weights (gain sqrt(2), the
                # value recommended for ReLU).
                torch.nn.init.xavier_normal_(conv.weight, gain=np.sqrt(2))
                # Use a constant initializer for the bias term.
                # Typically you would not do this.
                torch.nn.init.constant_(conv.bias, 0.1)
            # If you want to try Kaiming initialization, you can do it like this:
            # torch.nn.init.kaiming_normal_(self.net[0].weight, mode='fan_in', nonlinearity='relu')
            # This version keeps the activation magnitudes constant.
            # If you want to keep gradient magnitudes constant:
            # torch.nn.init.kaiming_normal_(self.net[0].weight, mode='fan_out', nonlinearity='relu')

        def forward(self, x):
            """Apply both convolution + ReLU stages to ``x`` and return the result."""
            return self.net(x)

    def __init__(self, input_channels=3, channels=(16, 32, 64), output_classes=3):
        """Assemble the stem, the blocks and the classifier, then initialize them.

        Args:
            input_channels: Number of channels of the input tensor.
            channels: Non-empty sequence of channel widths. ``channels[0]`` is the
                stem output width; each following entry adds one stride-2 ``Block``.
                The default is an immutable tuple so it cannot be shared and
                mutated across instances; lists are accepted as well.
            output_classes: Number of outputs of the final linear layer.
        """
        super().__init__()
        layers = [
            torch.nn.Conv2d(input_channels, channels[0], 7, padding=3, stride=2),
            torch.nn.ReLU(),
        ]
        for i in range(len(channels) - 1):
            layers.append(self.Block(channels[i], channels[i + 1], stride=2))
        self.net = torch.nn.Sequential(*layers)
        self.classifier = torch.nn.Linear(channels[-1], output_classes)

        # Zero the weights of the last layer. Only the weight is touched here;
        # the classifier bias keeps whatever torch.nn.Linear assigned to it.
        torch.nn.init.zeros_(self.classifier.weight)
        # The stem convolution gets the same treatment as the convolutions in Block.
        torch.nn.init.xavier_normal_(self.net[0].weight, gain=np.sqrt(2))
        # Again, you normally would not change the bias initialization
        torch.nn.init.constant_(self.net[0].bias, 0.1)

    def forward(self, x):
        """Return class scores for ``x``.

        The convolutional stack is applied, the result is averaged over
        dimensions 2 and 3 (so the stack output must be 4-D), and the pooled
        features are passed through the linear classifier.
        """
        x = self.net(x)
        x = x.mean(dim=(2, 3))
        return self.classifier(x)

In [None]:
# Build the model with its default arguments and show the weight tensor of the
# first layer in `net.net` (the stem Conv2d, re-initialized with xavier_normal_).
net = ConvNet()
print(net.net[0].weight)

In [None]:
# Bias of the same first Conv2d; ConvNet.__init__ fills it with the constant 0.1.
print(net.net[0].bias)

In [None]:
# Weight of the final Linear classifier; ConvNet.__init__ zero-initializes it.
print(net.classifier.weight)