I have written two torch.nn modules, and I want to pass some learnable parameters between them. I know this may not be the norm, but could this lead to training errors or parameter-update errors?
A toy example is shown below:
import torch
import torch.nn as nn

class SubModel(nn.Module):
    def __init__(self):
        super(SubModel, self).__init__()
        self.weight = nn.Parameter(torch.randn(5, 5))
        self.bias = nn.Parameter(torch.zeros(5))

    def forward(self, x, other_params):
        # combine this module's own weight with the parameter passed in from the parent
        hidden = torch.matmul(self.weight, other_params) + self.bias
        out = torch.matmul(x, hidden)
        return out

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.weight = nn.Parameter(torch.randn(5, 5))
        self.bias = nn.Parameter(torch.zeros(5))
        # learnable parameter owned by MyModel but used inside SubModel's forward
        self.other_params = nn.Parameter(torch.randn(5, 5))
        self.sub_model = SubModel()

    def forward(self, x):
        # pass the parent's parameter into the submodule
        hidden = self.sub_model(x, self.other_params)
        out = torch.matmul(hidden, self.weight.t()) + self.bias
        return out
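As a sanity check, listing the parameters should show that other_params is registered once on MyModel, while SubModel only owns its own weight and bias (a minimal sketch, with the shapes assumed above):

model = MyModel()
for name, p in model.named_parameters():
    print(name, tuple(p.shape))
# expected names: weight, bias, other_params, sub_model.weight, sub_model.bias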
Furthermore, I would like to confirm whether the parameter other_params
can be shared among multiple submodels. Could this lead to training errors or parameter-update errors? Thank you very much for explaining how the parameters are updated.
For example:
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.weight = nn.Parameter(torch.randn(5, 5))
        self.bias = nn.Parameter(torch.zeros(5))
        # single parameter shared by both submodels
        self.other_params = nn.Parameter(torch.randn(5, 5))
        # self.sub_model = SubModel()
        self.sub_model_1 = SubModel()
        self.sub_model_2 = SubModel()

    def forward(self, x):
        hidden = self.sub_model_1(x, self.other_params)
        hidden = self.sub_model_2(hidden, self.other_params)
        out = torch.matmul(hidden, self.weight.t()) + self.bias
        return out
I have added breakpoints and observed that the parameters are being updated and the loss decreases.
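For reference, this is roughly the kind of check I mean, written out with prints instead of breakpoints (a sketch with dummy data and an arbitrary squared loss, assuming x has shape (batch, 5)):

model = MyModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(8, 5)                  # dummy batch of 8 samples
loss = model(x).pow(2).mean()          # arbitrary scalar loss just for the check
loss.backward()
# other_params receives gradient contributions from both submodel calls
# (autograd accumulates gradients for a tensor used in multiple places)
print(model.other_params.grad is not None)
print(model.sub_model_1.weight.grad is not None)
optimizer.step()                       # the shared tensor is updated once per step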