RuntimeError: Given groups=1, weight of size [128, 64, 4, 4], expected input[1, 128, 65, 65] to have 64 channels, but got 128 channels instead

When running the below code:

class NLayerDiscriminator(nn.Module):
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
        super(NLayerDiscriminator, self).__init__()
        self.n_layers = n_layers

        kw = 4
        padw = int(np.ceil((kw-1.0)/2))
        self.conv1=nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw)
        self.act=nn.LeakyReLU(0.2, True)
        
        nf=min(ndf*2,512)
        self.conv2=nn.Conv2d(ndf, nf, kernel_size=kw, stride=2, padding=padw)
        self.norm1=norm_layer(nf)
        
        ngf=min(nf*2,512)
        self.conv3=nn.Conv2d(nf, ngf, kernel_size=kw, stride=1, padding=padw)
        self.norm2=norm_layer(ngf)
        self.conv4=nn.Conv2d(ngf, 1, kernel_size=kw, stride=1, padding=padw)
        self.sig=nn.Sigmoid()
        

    def forward(self, input):
        x=self.conv1(input)
        x=self.act(x)
        
        for n in range(1, 3):
            x=self.conv2(x)
            x=self.norm1(x)
            x=self.act(x)
        
        x=self.conv3(x)
        x=self.norm(x)
        x=self.act(x)
        x=self.conv4(x)
        if use_sigmoid:
            x=self.sig(x)
        return x

I get the following error when I try to train:

RuntimeError                              Traceback (most recent call last)
Cell In[19], line 71
     61                     torch.save({
     62                         'gen_AB': gen_AB.state_dict(),
     63                         'gen_BA': gen_BA.state_dict(),
   (...)
     68                         'disc_B_opt': disc_B_opt.state_dict()
     69                     }, f"cycleGAN_{cur_step}.pth")
     70             cur_step += 1
---> 71 train()

Cell In[19], line 27, in train(save_model)
     25 with torch.no_grad():
     26     fake_A = gen_BA(real_B)
---> 27 disc_A_loss = get_disc_loss(real_A, fake_A, disc_A, adv_criterion)
     28 disc_A_loss.backward(retain_graph=True) # Update gradients
     29 disc_A_opt.step() # Update optimizer

Cell In[9], line 16, in get_disc_loss(real_X, fake_X, disc_X, adv_criterion)
      4 '''
      5 Return the loss of the discriminator given inputs.
      6 Parameters:
   (...)
     13         loss (which you aim to minimize)
     14 '''
     15 #### START CODE HERE ####
---> 16 disc_fake_X_hat = disc_X(fake_X.detach()) # Detach generator
     17 disc_fake_X_loss = adv_criterion(disc_fake_X_hat, torch.zeros_like(disc_fake_X_hat))
     18 disc_real_X_hat = disc_X(real_X)

File ~\anaconda3\Lib\site-packages\torch\nn\modules\module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~\anaconda3\Lib\site-packages\torch\nn\modules\module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

Cell In[5], line 55, in NLayerDiscriminator.forward(self, input)
     52 x=self.act(x)
     54 for n in range(1, 3):
---> 55     x=self.conv2(x)
     56     x=self.norm1(x)
     57     x=self.act(x)

File ~\anaconda3\Lib\site-packages\torch\nn\modules\module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~\anaconda3\Lib\site-packages\torch\nn\modules\module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~\anaconda3\Lib\site-packages\torch\nn\modules\conv.py:460, in Conv2d.forward(self, input)
    459 def forward(self, input: Tensor) -> Tensor:
--> 460     return self._conv_forward(input, self.weight, self.bias)

File ~\anaconda3\Lib\site-packages\torch\nn\modules\conv.py:456, in Conv2d._conv_forward(self, input, weight, bias)
    452 if self.padding_mode != 'zeros':
    453     return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
    454                     weight, bias, self.stride,
    455                     _pair(0), self.dilation, self.groups)
--> 456 return F.conv2d(input, weight, bias, self.stride,
    457                 self.padding, self.dilation, self.groups)

RuntimeError: Given groups=1, weight of size [128, 64, 4, 4], expected input[1, 128, 65, 65] to have 64 channels, but got 128 channels instead

I tried changing the channel number, but it did not work.
Above is the entire traceback.
‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎ ‎

  • Include the entire traceback, please.

    – 

  • Try to print the size of tensor by print(x.size()) so that you can check the number of channels in each line of codes.

    – 

Leave a Comment