I am using this ‘collate_fn’:
import torch
from torch_geometric.data import Data

def custom_collate(batch):
    # batch is a list of PyTorch Geometric Data objects
    # We need to convert it into a batched PyTorch Geometric Data object
    # Extract individual components from the batch
    x_list = [data.x for data in batch]
    edge_index_list = [data.edge_index for data in batch]
    y_list = [data.y for data in batch]
    # Stack the node features, edge indices, and labels
    x = torch.stack(x_list)
    edge_index = torch.stack(edge_index_list)
    y = torch.stack(y_list)
    # Return a batched PyTorch Geometric Data object
    return Data(x=x, edge_index=edge_index, y=y)
It is passed to our ‘DataLoader’ like this:
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=False, collate_fn=custom_collate)
But iterating over the loader can’t access the data and gives an error like this:
KeyError Traceback (most recent call last)
<ipython-input-109-6a97bb16ec79> in <cell line: 5>()
8
9 # Iterate over the batches in the data loader
---> 10 for batch in train_loader:
11 optimizer.zero_grad() # Zero the gradients
12 x, edge_index, y = batch.x, batch.edge_index, batch.y # Extract data from the batch
6 frames
/usr/local/lib/python3.10/dist-packages/torch_geometric/data/storage.py in __getitem__(self, key)
109
110 def __getitem__(self, key: str) -> Any:
--> 111 return self._mapping[key]
112
113 def __setitem__(self, key: str, value: Any):
KeyError: 0
I have also tried changing it into a list before iterating over the batches in the data loader, which gives an error like this:
AttributeError: ‘list’ object has no attribute ‘x’
What am I missing in order to load my data and train my GCN model?
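
For reference, here is a minimal sketch of how graph batching is usually done, in case it helps narrow things down. torch.stack only works when every tensor has the same shape, so it breaks as soon as two graphs have different numbers of nodes or edges. The usual approach (and what torch_geometric.loader.DataLoader does without any custom collate_fn) is to concatenate the graphs into one large disconnected graph and shift each graph's edge indices by the number of nodes that came before it. The function name collate_graphs and the toy tensors below are hypothetical and only illustrate the idea in plain PyTorch. Separately, the KeyError: 0 is raised while looking up the key 0 inside a Data storage, which may mean that train_data is a single Data object rather than a list or dataset of Data objects, but that is a guess from the traceback.

```python
import torch

def collate_graphs(graphs):
    """Merge a list of (x, edge_index, y) triples into one disconnected graph.

    Hypothetical helper for illustration; it is not part of PyTorch Geometric.
    """
    xs, edge_indices, ys, batch_ids = [], [], [], []
    node_offset = 0
    for graph_id, (x, edge_index, y) in enumerate(graphs):
        num_nodes = x.size(0)
        xs.append(x)
        # Shift node ids so they point at this graph's rows in the merged x
        edge_indices.append(edge_index + node_offset)
        ys.append(y)
        # Record which graph every node belongs to
        batch_ids.append(torch.full((num_nodes,), graph_id, dtype=torch.long))
        node_offset += num_nodes
    return (
        torch.cat(xs, dim=0),            # [total_nodes, num_features]
        torch.cat(edge_indices, dim=1),  # [2, total_edges]
        torch.cat(ys, dim=0),            # [num_graphs]
        torch.cat(batch_ids, dim=0),     # [total_nodes]
    )

# Two toy graphs with different sizes (torch.stack would fail on these)
g1 = (torch.ones(3, 2), torch.tensor([[0, 1], [1, 2]]), torch.tensor([0]))
g2 = (torch.zeros(2, 2), torch.tensor([[0], [1]]), torch.tensor([1]))

x, edge_index, y, batch = collate_graphs([g1, g2])
```

If train_data is a list of Data objects, importing DataLoader from torch_geometric.loader and dropping the collate_fn argument should give batches that expose batch.x, batch.edge_index, batch.y and batch.batch in this same layout.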