I am facing an error using pytorch custom dataset. The problem is really strange to me because it was working, I didn’t change anything on the code. Here is the scenario:
-
After building my deep learning model, I test it by training it in 10 then 100 epochs. It was working fine, but I saw that the model need to be trained with more epoch to get better result.
-
So, I changed the number of epochs to 500. My GPU crashed, maybe because I printed the result every 10 epochs, and it runs out of memory (I don’t know exactly what was the real problem)
-
Now after restarting my GPU, jupiter notebook throw me a server error with status code 500
-
I search on internet and found a solution for my jupiter notebook by running the following command:
pip install --upgrade nbconvert
-
After that, the code does not work anymore. I tried to debug it and I find something strange:
- The custom dataset class give me an error below if I put the class in a python file and call it from jupiter notebook (the parameter idx in the function getitem is a list)
- The custom dataset class work if I put the class directly in the jupiter notebook cell and call it from this jupiter notebook (the parameter idx in the function getitem is an integer)
Thank you in advance for your answer.
Here is the custom dataset class:
put this in src/custom_dataset.py for example
import os
from natsort import natsorted
from PIL import Image
# Let's see if we have an available GPU
from datasets import Dataset
class LoadPairedDataset(Dataset):
def __init__(self, root_dir, transform=None):
self.root_dir = root_dir
self.transform = transform
self.images = os.listdir(root_dir)
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
print(idx)
img_name = os.path.join(self.root_dir, self.images[idx])
image = Image.open(img_name)
if self.transform:
image = self.transform(image)
return image
Here is the code I use to call the custom dataset:
put this in a jupiter notebook cell
# Imports
import os
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from src.custom_dataset import LoadPairedDataset
# Define your own class LoadFromFolder
class CustomImageDataset(Dataset):
def __init__(self, root_dir, transform=None):
self.root_dir = root_dir
self.transform = transform
self.images = os.listdir(root_dir)
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
print(idx)
img_name = os.path.join(self.root_dir, self.images[idx])
image = Image.open(img_name)
if self.transform:
image = self.transform(image)
return image
base_path = "../lol-custom"
# dataloader = {"train_n": None, "train_p": None}
transform = transforms.Compose([
transforms.ToTensor()
])
train_data = CustomImageDataset(root_dir=base_path + "/train/low", transform=transform)
dataloader = torch.utils.data.DataLoader(
train_data,
batch_size=5,
sampler=None,
num_workers=0
)
# The output will be:
# 0 1 2 3 4 from the print(idx) in the __getitem__ function in CustomImageDataset class
# torch.Size([5, 3, 400, 600]) from the below print
print(next(iter(dataloader)).shape) # This will print 0 1 2 3 4
print("######### The below throw an error ##############")
train_data = LoadPairedDataset(root_dir=base_path + "/train/low", transform=transform)
dataloader = torch.utils.data.DataLoader(
train_data,
batch_size=5,
sampler=None,
num_workers=0
)
# The output will be:
# [0, 1, 2, 3, 4] from the print(idx) in the __getitem__ function in CustomImageDataset class
# THEN ERROR: TypeError: list indices must be integers or slices, not list
print(next(iter(dataloader)).shape)
I am using torch 2.01 with python 3.9.18
And finally, here is the output and the stacktrace of the error
0
1
2
3
4
torch.Size([5, 3, 400, 600])
######### The below throw an error ##############
[0, 1, 2, 3, 4]
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[1], line 61
52 dataloader = torch.utils.data.DataLoader(
53 train_data,
54 batch_size=5,
55 sampler=None,
56 num_workers=0
57 )
58 # This output will be:
59 # [0, 1, 2, 3, 4] from the print(idx) in the __getitem__ function in CustomImageDataset class
60 # THEN ERROR: TypeError: list indices must be integers or slices, not list
---> 61 print(next(iter(dataloader)).shape)
File ~\anaconda3\envs\mmie\lib\site-packages\torch\utils\data\dataloader.py:633, in _BaseDataLoaderIter.__next__(self)
630 if self._sampler_iter is None:
631 # TODO(https://github.com/pytorch/pytorch/issues/76750)
632 self._reset() # type: ignore[call-arg]
--> 633 data = self._next_data()
634 self._num_yielded += 1
635 if self._dataset_kind == _DatasetKind.Iterable and \
636 self._IterableDataset_len_called is not None and \
637 self._num_yielded > self._IterableDataset_len_called:
File ~\anaconda3\envs\mmie\lib\site-packages\torch\utils\data\dataloader.py:677, in _SingleProcessDataLoaderIter._next_data(self)
675 def _next_data(self):
676 index = self._next_index() # may raise StopIteration
--> 677 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
678 if self._pin_memory:
679 data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
File ~\anaconda3\envs\mmie\lib\site-packages\torch\utils\data\_utils\fetch.py:49, in _MapDatasetFetcher.fetch(self, possibly_batched_index)
47 if self.auto_collation:
48 if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
---> 49 data = self.dataset.__getitems__(possibly_batched_index)
50 else:
51 data = [self.dataset[idx] for idx in possibly_batched_index]
File ~\anaconda3\envs\mmie\lib\site-packages\datasets\arrow_dataset.py:2807, in Dataset.__getitems__(self, keys)
2805 def __getitems__(self, keys: List) -> List:
2806 """Can be used to get a batch using a list of integers indices."""
-> 2807 batch = self.__getitem__(keys)
2808 n_examples = len(batch[next(iter(batch))])
2809 return [{col: array[i] for col, array in batch.items()} for i in range(n_examples)]
File ~\Projects\mmie\src\custom_dataset.py:21, in LoadPairedDataset.__getitem__(self, idx)
19 def __getitem__(self, idx):
20 print(idx)
---> 21 img_name = os.path.join(self.root_dir, self.images[idx])
22 image = Image.open(img_name)
24 if self.transform:
TypeError: list indices must be integers or slices, not list
It’s probably because your custom dataset inherits from a Dataset
class but the meaning of a Datset
changes.
When in a separate file, you define Dataset
as from datasets import Dataset
but in the jupyter cell, Dataset
is from torch.utils.data import Dataset
which is cleary different. I suggest you keep the definition that works in the separate file as well