How to modify the forward pass of a loaded torch model with forward and backward hooks?

In PyTorch, I have a pretrained architecture whose weights I load. I know its forward pass is not a simple sequential pass through all the layers, but I do not have access to the forward method itself. Let’s say the model looks like this:

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.A = nn.Linear(128, 128)
        self.B = nn.Linear(128, 128)
        self.C = nn.Linear(128, 128)
        self.D = nn.Linear(128, 128)
    
    def forward(self, x):
        x = F.relu(self.A(x)) + x
        x = F.relu(self.B(x))
        x = F.relu(self.C(x))
        x = self.D(x)
        return x
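Since the source of `forward` is not available, one way to at least see which submodules run, and in what order, is to attach a temporary forward hook to every leaf module and record its name during a single forward pass. This is a sketch: `trace_call_order` is a hypothetical helper, and the `Model` class is the toy architecture above, repeated so the snippet runs on its own. It only sees module calls. Functional operations such as `F.relu` and the residual addition do not show up.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    # Same toy architecture as above, standing in for the pretrained model.
    def __init__(self):
        super().__init__()
        self.A = nn.Linear(128, 128)
        self.B = nn.Linear(128, 128)
        self.C = nn.Linear(128, 128)
        self.D = nn.Linear(128, 128)

    def forward(self, x):
        x = F.relu(self.A(x)) + x
        x = F.relu(self.B(x))
        x = F.relu(self.C(x))
        return self.D(x)


def trace_call_order(model, example_input):
    """Hypothetical helper: run one forward pass and return leaf-module names in call order."""
    order = []
    handles = []
    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # hook leaf modules only
            # Bind `name` through a default argument so each hook records its own name.
            handles.append(module.register_forward_hook(
                lambda mod, args, out, name=name: order.append(name)))
    try:
        with torch.no_grad():
            model(example_input)
    finally:
        for h in handles:  # always detach the temporary hooks again
            h.remove()
    return order


print(trace_call_order(Model(), torch.randn(1, 128)))  # ['A', 'B', 'C', 'D']
```

The recorded order tells you which modules can serve as hook points (here, B and C).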

I would like to add a network E that runs in parallel to B: E takes the same input as B (the output of the residual block around A), and the outputs of B and E are added together to form the input of C. If I had access to the forward pass, it would look like this:

    def forward(self, x):
        x = F.relu(self.A(x)) + x

        x1 = F.relu(self.B(x))
        x2 = F.relu(self.E(x))

        x = F.relu(self.C(x1+x2))
        x = self.D(x)
        return x

I would like to achieve this with hooks (forward and backward) without modifying the Model class, but I am not sure how to do it so that the resulting computation graph matches the forward pass shown above.
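A minimal sketch of one possible approach. It assumes the layer names A, B and C from the example, and that B runs exactly once before C in each forward pass. A forward pre-hook on B caches B's input `h`, and a forward pre-hook on C adds `relu(E(h))` to C's input, which is already `relu(B(h))`. Because the extra operations run during the forward pass, autograd records them like any other operation, so in this sketch no backward hook is needed for gradients to reach E. `attach_parallel_branch` is a hypothetical helper name, and `Model` is the toy class from the question, repeated so the snippet is self-contained.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    # Toy stand-in for the pretrained model; this class is not modified below.
    def __init__(self):
        super().__init__()
        self.A = nn.Linear(128, 128)
        self.B = nn.Linear(128, 128)
        self.C = nn.Linear(128, 128)
        self.D = nn.Linear(128, 128)

    def forward(self, x):
        x = F.relu(self.A(x)) + x
        x = F.relu(self.B(x))
        x = F.relu(self.C(x))
        return self.D(x)


def attach_parallel_branch(model, branch):
    """Hypothetical helper: make C receive relu(B(h)) + relu(branch(h)), h = B's input."""
    cache = {}

    def save_b_input(module, args):
        # Forward pre-hook on B: args is the tuple of positional inputs, args[0] == h.
        cache["h"] = args[0]  # returning None leaves B's input unchanged

    def add_branch_to_c_input(module, args):
        # Forward pre-hook on C: args[0] is already relu(B(h)).
        x2 = F.relu(branch(cache.pop("h")))
        return (args[0] + x2,)  # the returned tuple replaces C's input

    return [
        model.B.register_forward_pre_hook(save_b_input),
        model.C.register_forward_pre_hook(add_branch_to_c_input),
    ]


torch.manual_seed(0)
model = Model()          # in practice, load the pretrained state_dict here
E = nn.Linear(128, 128)  # new branch; it lives outside `model`
handles = attach_parallel_branch(model, E)

x = torch.randn(4, 128)
out = model(x)           # computes D(relu(C(relu(B(h)) + relu(E(h)))))
out.sum().backward()     # gradients reach E (and A) through ordinary autograd
print(E.weight.grad is not None)  # True

for h in handles:        # removing the hooks restores the original behaviour
    h.remove()
```

Since E is not registered as a submodule of `model`, its parameters have to be passed to the optimizer explicitly, for example `torch.optim.SGD(list(model.parameters()) + list(E.parameters()), lr=0.01)`.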

  • What does “I do not have access to the forward pass” mean? Why can’t you simply overwrite the forward pass or use the pretrained model as a sub-module?
