Can learnable parameters be passed through the forward function of a torch.nn module and still be updated?

I have written two torch.nn modules, and I want to pass some learnable parameters between them. I know this may not be the usual pattern, but could it lead to training errors or problems with parameter updates?

A toy example is shown below:

import torch
import torch.nn as nn

class SubModel(nn.Module):
    def __init__(self):
        super(SubModel, self).__init__()
        self.weight = nn.Parameter(torch.randn(5, 5))
        self.bias = nn.Parameter(torch.zeros(5))

    def forward(self, x, other_params):
        # other_params is owned by the parent model and passed in at call time
        hidden = torch.matmul(self.weight, other_params) + self.bias
        out = torch.matmul(x, hidden)
        return out


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()

        self.weight = nn.Parameter(torch.randn(5, 5))
        self.bias = nn.Parameter(torch.zeros(5))
        self.other_params = nn.Parameter(torch.randn(5, 5))
        self.sub_model = SubModel()

    def forward(self, x):
        # other_params is handed to the sub-module through its forward call
        hidden = self.sub_model(x, self.other_params)
        out = torch.matmul(hidden, self.weight.t()) + self.bias
        return out
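
As a quick sanity check (just a sketch, using random placeholder data), one can verify that other_params is registered on MyModel and receives a gradient after a backward pass:

model = MyModel()
print([name for name, _ in model.named_parameters()])  # includes 'other_params' and the sub-model's parameters

x = torch.randn(3, 5)     # dummy input batch
loss = model(x).sum()     # arbitrary scalar loss, just to backpropagate
loss.backward()
print(model.other_params.grad is not None)  # True: the parameter is part of the graph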

Furthermore, I would like to confirm whether the parameter other_params can be shared among multiple sub-models. Could this lead to training errors or problems with parameter updates? Thank you very much for explaining how the parameters are updated in this case.

For example:

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()

        self.weight = nn.Parameter(torch.randn(5, 5))
        self.bias = nn.Parameter(torch.zeros(5))
        self.other_params = nn.Parameter(torch.randn(5, 5))
        self.sub_model_1 = SubModel()
        self.sub_model_2 = SubModel()

    def forward(self, x):
        # the same parameter is passed to both sub-models
        hidden = self.sub_model_1(x, self.other_params)
        hidden = self.sub_model_2(hidden, self.other_params)
        out = torch.matmul(hidden, self.weight.t()) + self.bias
        return out

I have added breakpoints and observed that the parameters are updated and the loss decreases.
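
For the shared version, this is roughly what I checked (random placeholder data and a plain SGD step, just for illustration):

model = MyModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(3, 5)
before = model.other_params.detach().clone()

optimizer.zero_grad()
loss = model(x).sum()
loss.backward()
optimizer.step()

print(model.other_params.grad)                      # gradient accumulated from both sub-models
print(torch.allclose(before, model.other_params))   # False: the shared parameter was updated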
