PyTorch custom Dataset TypeError: list indices must be integers or slices, not list

I am facing an error with a PyTorch custom dataset. The problem is really strange to me because the code was working and I didn't change anything in it. Here is the scenario:

  1. After building my deep learning model, I tested it by training for 10 and then 100 epochs. It worked fine, but I saw that the model needed more epochs to get a better result.

  2. So I changed the number of epochs to 500. My GPU crashed, maybe because I printed the result every 10 epochs and it ran out of memory (I don't know exactly what the real problem was).

  3. Now, after restarting my GPU, Jupyter Notebook throws a server error with status code 500.

  4. I searched the internet and found a solution for my Jupyter Notebook, which was to run the following command: pip install --upgrade nbconvert

  5. After that, the code does not work anymore. I tried to debug it and found something strange:

  • The custom dataset class gives me the error below if I put the class in a Python file and call it from the Jupyter Notebook (the parameter idx in __getitem__ is a list).
  • The custom dataset class works if I put the class directly in a Jupyter Notebook cell and call it from that notebook (the parameter idx in __getitem__ is an integer). A quick check that makes the difference visible is sketched right after this list.
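
Here is a minimal sketch of that check, assuming src/custom_dataset.py and the notebook code below are in place; it prints the inheritance chain of the class from the separate file:

import torch.utils.data
from src.custom_dataset import LoadPairedDataset

# The chain contains datasets.arrow_dataset.Dataset (the Hugging Face class),
# not torch.utils.data.Dataset
print(LoadPairedDataset.__mro__)
# Most likely False: the Hugging Face Dataset does not subclass torch's Dataset
print(issubclass(LoadPairedDataset, torch.utils.data.Dataset))

Running the same check on CustomImageDataset after executing the notebook cell below shows torch.utils.data.dataset.Dataset in the chain instead.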

Thank you in advance for your answer.

Here is the custom dataset class:

Put this in src/custom_dataset.py, for example:

import os
from natsort import natsorted
from PIL import Image
from datasets import Dataset

class LoadPairedDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = os.listdir(root_dir)


    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        print(idx)
        img_name = os.path.join(self.root_dir, self.images[idx])
        image = Image.open(img_name)

        if self.transform:
            image = self.transform(image)

        return image

Here is the code I use to call the custom dataset:

Put this in a Jupyter Notebook cell:

# Imports
import os
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from src.custom_dataset import LoadPairedDataset
 
# Define the same dataset class directly in the notebook
class CustomImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = os.listdir(root_dir)


    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        print(idx)  
        img_name = os.path.join(self.root_dir, self.images[idx])
        image = Image.open(img_name)

        if self.transform:
            image = self.transform(image)

        return image
    

base_path = "../lol-custom"

# dataloader = {"train_n": None, "train_p": None}
transform = transforms.Compose([
            transforms.ToTensor()                       
        ])

train_data = CustomImageDataset(root_dir=base_path + "/train/low", transform=transform)
dataloader = torch.utils.data.DataLoader(
    train_data,
    batch_size=5,
    sampler=None,
    num_workers=0
)
# The output will be:
# 0 1 2 3 4 from the print(idx) in the __getitem__ function in CustomImageDataset class
# torch.Size([5, 3, 400, 600]) from the print below
print(next(iter(dataloader)).shape)
print("######### The below throw an error ##############")

train_data = LoadPairedDataset(root_dir=base_path + "/train/low", transform=transform)
dataloader = torch.utils.data.DataLoader(
    train_data,
    batch_size=5,
    sampler=None,
    num_workers=0
)
# The output will be:
# [0, 1, 2, 3, 4] from the print(idx) in the __getitem__ function in LoadPairedDataset class
# THEN ERROR: TypeError: list indices must be integers or slices, not list
print(next(iter(dataloader)).shape) 

I am using torch 2.0.1 with Python 3.9.18.

And finally, here is the output and the stack trace of the error:

0
1
2
3
4
torch.Size([5, 3, 400, 600])
######### The below throw an error ##############
[0, 1, 2, 3, 4]

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[1], line 61
     52 dataloader = torch.utils.data.DataLoader(
     53     train_data,
     54     batch_size=5,
     55     sampler=None,
     56     num_workers=0
     57 )
     58 # This output will be:
     59 # [0, 1, 2, 3, 4] from the print(idx) in the __getitem__ function in CustomImageDataset class
     60 # THEN ERROR: TypeError: list indices must be integers or slices, not list
---> 61 print(next(iter(dataloader)).shape)

File ~\anaconda3\envs\mmie\lib\site-packages\torch\utils\data\dataloader.py:633, in _BaseDataLoaderIter.__next__(self)
    630 if self._sampler_iter is None:
    631     # TODO(https://github.com/pytorch/pytorch/issues/76750)
    632     self._reset()  # type: ignore[call-arg]
--> 633 data = self._next_data()
    634 self._num_yielded += 1
    635 if self._dataset_kind == _DatasetKind.Iterable and \
    636         self._IterableDataset_len_called is not None and \
    637         self._num_yielded > self._IterableDataset_len_called:

File ~\anaconda3\envs\mmie\lib\site-packages\torch\utils\data\dataloader.py:677, in _SingleProcessDataLoaderIter._next_data(self)
    675 def _next_data(self):
    676     index = self._next_index()  # may raise StopIteration
--> 677     data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    678     if self._pin_memory:
    679         data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)

File ~\anaconda3\envs\mmie\lib\site-packages\torch\utils\data\_utils\fetch.py:49, in _MapDatasetFetcher.fetch(self, possibly_batched_index)
     47 if self.auto_collation:
     48     if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
---> 49         data = self.dataset.__getitems__(possibly_batched_index)
     50     else:
     51         data = [self.dataset[idx] for idx in possibly_batched_index]

File ~\anaconda3\envs\mmie\lib\site-packages\datasets\arrow_dataset.py:2807, in Dataset.__getitems__(self, keys)
   2805 def __getitems__(self, keys: List) -> List:
   2806     """Can be used to get a batch using a list of integers indices."""
-> 2807     batch = self.__getitem__(keys)
   2808     n_examples = len(batch[next(iter(batch))])
   2809     return [{col: array[i] for col, array in batch.items()} for i in range(n_examples)]

File ~\Projects\mmie\src\custom_dataset.py:21, in LoadPairedDataset.__getitem__(self, idx)
     19 def __getitem__(self, idx):
     20     print(idx)
---> 21     img_name = os.path.join(self.root_dir, self.images[idx])
     22     image = Image.open(img_name)
     24     if self.transform:

TypeError: list indices must be integers or slices, not list


It's probably because your custom dataset inherits from a Dataset class, but the meaning of Dataset changes depending on where it is imported from.

When the class is in a separate file, you define Dataset as from datasets import Dataset (the Hugging Face class), but in the Jupyter cell Dataset is from torch.utils.data import Dataset, which is clearly different. You can see the consequence in the traceback: the DataLoader's fetcher finds the __getitems__ method that the Hugging Face Dataset defines and passes the whole list of batch indices to __getitem__ at once, instead of one integer at a time. I suggest you use the definition that works, from torch.utils.data import Dataset, in the separate file as well.
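
Here is a minimal sketch of src/custom_dataset.py with that change, where the only difference from the question's version is the import (the path and class name are the ones from the question):

import os
from PIL import Image
from torch.utils.data import Dataset  # torch's map-style Dataset, not the Hugging Face one

class LoadPairedDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # Keep the list of image file names found in root_dir
        self.images = os.listdir(root_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # idx is now a single integer chosen by the DataLoader's sampler
        img_name = os.path.join(self.root_dir, self.images[idx])
        image = Image.open(img_name)

        if self.transform:
            image = self.transform(image)

        return image

With this base class, the DataLoader falls back to calling __getitem__ once per index and collating the five images itself, so next(iter(dataloader)).shape should again print torch.Size([5, 3, 400, 600]), just as it does for the in-notebook class.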
