Data Loading into GCN

I am using this custom ‘collate_fn’:

import torch
from torch_geometric.data import Data

def custom_collate(batch):
    # batch is a list of PyTorch Geometric Data objects
    # We need to convert it into a batched PyTorch Geometric Data object

    # Extract individual components from the batch
    x_list = [data.x for data in batch]
    edge_index_list = [data.edge_index for data in batch]
    y_list = [data.y for data in batch]

    # Stack the node features, edge indices, and labels
    x = torch.stack(x_list)
    edge_index = torch.stack(edge_index_list)
    y = torch.stack(y_list)

    # Return a batched PyTorch Geometric Data object
    return Data(x=x, edge_index=edge_index, y=y)

which is used in our ‘DataLoader’:

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=False, collate_fn=custom_collate)

but I can’t access the data: iterating over the loader gives an error like this:

KeyError                                  Traceback (most recent call last)
<ipython-input-109-6a97bb16ec79> in <cell line: 5>()
      8 
      9     # Iterate over the batches in the data loader
---> 10     for batch in train_loader:
     11         optimizer.zero_grad()  # Zero the gradients
     12         x, edge_index, y = batch.x, batch.edge_index, batch.y  # Extract data from the batch

6 frames
/usr/local/lib/python3.10/dist-packages/torch_geometric/data/storage.py in __getitem__(self, key)
    109 
    110     def __getitem__(self, key: str) -> Any:
--> 111         return self._mapping[key]
    112 
    113     def __setitem__(self, key: str, value: Any):

KeyError: 0

I have also tried changing it into a list before iterating over the batches in the data loader, which gives this error instead:

AttributeError: 'list' object has no attribute 'x'

What am I missing to load my data so that I can train my GCN model?
