I am using this ResNet script with slightly modified training data. I would like to remove `nn.AdaptiveAvgPool2d((1, 1))` and replace it with `nn.AvgPool2d(kernel_size, stride)`. However, the kernel size and stride must equal `z.size(dim=2)` (tensor height) and `z.size(dim=3)` (tensor width), which are only known inside the `forward` function.
My question is: what is the best way to achieve this? I would still like to define the `nn.AvgPool2d` layer in the `__init__` function, like the other torch modules. I realize this may be more of a Python syntax question, so please bear with me.
import torch
from torch import nn
import torch.nn.functional as F
class block(nn.Module):
    """Two-layer residual block: conv3x3 -> BN -> ReLU -> conv3x3 -> BN (+ shortcut) -> ReLU.

    Parameters:
    - filters: int
        the number of output filters for both convolutions in this block
    - subsample: boolean
        when True, the first convolution has stride 2 and takes
        int(filters * 0.5) input channels, so the block halves the spatial
        size and doubles the number of filters

    The residual shortcut is not a constructor option: it is switched per
    call via ``forward(x, shortcuts=...)``. With ``shortcuts=False`` the
    block is a 'plain' convolutional block.
    """

    def __init__(self, filters, subsample=False):
        super().__init__()
        # Input-width scale: when subsampling, conv1 sees half as many input
        # channels and uses stride int(1 / 0.5) == 2; otherwise 1x and stride 1.
        s = 0.5 if subsample else 1.0

        # Main path
        self.conv1 = nn.Conv2d(int(filters * s), filters, kernel_size=3,
                               stride=int(1 / s), padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(filters, track_running_stats=True)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(filters, filters, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(filters, track_running_stats=True)
        self.relu2 = nn.ReLU()

        # Shortcut downsampling: a 1x1 average pool with stride 2 simply keeps
        # every other pixel, giving the same spatial size as stride-2 conv1.
        self.downsample = nn.AvgPool2d(kernel_size=1, stride=2)

        # Initialize weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def shortcut(self, z, x):
        """Add the residual shortcut of ``x`` to ``z``.

        Parameters:
        - x: tensor
            the input to the block
        - z: tensor
            activations of block prior to final non-linearity

        Returns ``z + x`` when the shapes match. Otherwise ``x`` is spatially
        subsampled and zero-padded along dim 1 to twice its channel count,
        which matches ``z`` for a block built with ``subsample=True``.
        """
        if x.shape != z.shape:
            d = self.downsample(x)
            # zeros_like instead of d * 0: multiplying by zero turns inf/NaN
            # entries of d into NaN, which would leak into the "zero" padding.
            p = torch.zeros_like(d)
            return z + torch.cat((d, p), dim=1)
        return z + x

    def forward(self, x, shortcuts=False):
        """Run the block on ``x``; add the residual shortcut if ``shortcuts`` is True."""
        z = self.conv1(x)
        z = self.bn1(z)
        z = self.relu1(z)
        z = self.conv2(z)
        z = self.bn2(z)
        # Shortcut connection is added before the final non-linearity.
        if shortcuts:
            z = self.shortcut(z, x)
        z = self.relu2(z)
        return z
class ResNet(nn.Module):
    """ResNet built from three stacks of ``block``s with 16, 32 and 64 filters.

    Parameters:
    - n: int
        number of blocks per stack; the network has 6n + 2 weighted layers
        (input conv, 3 stacks x n blocks x 2 convs, output linear)
    - shortcuts: boolean
        passed to every block's ``forward``; False gives a 'plain' network
    - num_classes: int
        size of the output layer (default 10)

    Input is expected to have 3 channels (``convIn``). ``forward`` returns
    log-probabilities of shape (batch, num_classes) via ``LogSoftmax``.
    """

    def __init__(self, n, shortcuts=True, num_classes=10):
        super().__init__()
        self.shortcuts = shortcuts
        # Input
        self.convIn = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bnIn = nn.BatchNorm2d(16, track_running_stats=True)
        self.relu = nn.ReLU()
        # Stack1
        self.stack1 = nn.ModuleList([block(16, subsample=False) for _ in range(n)])
        # Stack2
        self.stack2a = block(32, subsample=True)
        self.stack2b = nn.ModuleList([block(32, subsample=False) for _ in range(n - 1)])
        # Stack3
        self.stack3a = block(64, subsample=True)
        self.stack3b = nn.ModuleList([block(64, subsample=False) for _ in range(n - 1)])
        # Output.
        # Global average pooling: AdaptiveAvgPool2d((1, 1)) averages over the
        # whole H x W map that reaches it at run time, i.e. it computes the same
        # thing as nn.AvgPool2d(kernel_size=(H, W), stride=(H, W)) with H, W
        # taken from z -- so the pool size never needs to be stored on the module.
        # (Functional equivalent inside forward: F.avg_pool2d(z, z.shape[-2:]).)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fcOut = nn.Linear(64, num_classes, bias=True)
        # Named "softmax" for compatibility, but it outputs log-probabilities.
        self.softmax = nn.LogSoftmax(dim=-1)
        # Initialize weights in fully connected layer
        for m in self.modules():
            if isinstance(m, nn.Linear):
                # kaiming_normal (no underscore) is a deprecated alias of this.
                nn.init.kaiming_normal_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Map a batch of 3-channel images to (batch, num_classes) log-probabilities."""
        z = self.convIn(x)
        z = self.bnIn(z)
        z = self.relu(z)
        for layer in self.stack1:
            z = layer(z, shortcuts=self.shortcuts)
        z = self.stack2a(z, shortcuts=self.shortcuts)
        for layer in self.stack2b:
            z = layer(z, shortcuts=self.shortcuts)
        z = self.stack3a(z, shortcuts=self.shortcuts)
        for layer in self.stack3b:
            z = layer(z, shortcuts=self.shortcuts)
        z = self.avgpool(z)        # (batch, 64, 1, 1) for any H x W
        z = torch.flatten(z, 1)    # (batch, 64)
        z = self.fcOut(z)
        return self.softmax(z)
I've tried initializing `self.pool_width = None` and `self.pool_height = None` in the `__init__` function so they can be updated in the `forward` function, but that gives an error in `AvgPool2d`. I've also tried moving `AvgPool2d` inside the `forward` function, but I don't think that's best practice given that `nn.AvgPool2d` is a class.