training with a list of batch sizes (instead of a constant integer value) – Python

Python:

I am trying to load my data (images) sequentially. But each time, the next load batch does not start after the previous one. For instance, after loading images 1 to 9 for a batch size of 9, the next loading does not start from 10. It incorrectly starts to sample from 1 again. For example, for a list of batch sizes like [9, 10, 14], the current code samples from 1 to 9, then from 1 to 10, and then from 1 to 14. This is incorrect.
For a correct sampling, the sampling should start from 1 to 9, then continue to sample from 10 to 20, then sample from 21 to 35.

In the loading function (get_data), the indexes are correct. They start from zero and continue to the end based on the batch sizes.

In the training function (train), I added some codes to save the images to make sure that the model is being trained based on the correct sequence of images. However, it shows the images from 1 to 9, then from 1 to 10, then from 1 to 14.

These are the two functions I use to load and train:

def get_data(args):
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(80),
        torchvision.transforms.RandomResizedCrop(args.image_size, scale=(1.0, 1.0)),
        torchvision.transforms.Grayscale(num_output_channels=1),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,))
    ])
    dataset = torchvision.datasets.ImageFolder(args.dataset_path, transform=transforms)

    # If batch_size is a single value, convert it to a list
    batch_sizes = args.batch_size if isinstance(args.batch_size, list) else [args.batch_size]

    dataloaders = []
    start_idx = 0  # Keep track of the starting index for each batch size
    for batch_size in batch_sizes:
        end_idx = start_idx + batch_size
        # Ensure the end index does not exceed the dataset size
        end_idx = min(end_idx, len(dataset))

        # Create a new SequentialSampler for each batch
        sampler = torch.utils.data.sampler.SequentialSampler(range(start_idx, end_idx))
        # dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler, shuffle=False)
        dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler, shuffle=False)
        dataloaders.append((batch_size, dataloader))  # Store both the size and the loader

        # Update the starting index for the next batch size
        start_idx = end_idx

    return dataloaders
def train(args):
    setup_logging(args.run_name)
    device = args.device
    model = UNet().to(device)
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)
    mse = nn.MSELoss()
    diffusion = Diffusion(img_size=args.image_size, device=device)
    logger = SummaryWriter(os.path.join("runs_RCCUDA", args.run_name))

    dataloaders = get_data(args)
    l = len(dataloaders)

    for epoch in range(args.epochs):
        logging.info(f"Starting epoch {epoch}:")
        ii = 0
        
        for batch_size, dataloader in dataloaders:
            # print('len(dataloaders) = ', len(dataloaders))
            pbar = tqdm(dataloader)
            
            for i, (images, _) in enumerate(pbar):
                # print('i = ', i)
                images = images.to(device)
                t = diffusion.sample_timesteps(images.shape[0]).to(device)
                x_t, noise = diffusion.noise_images(images, t)
                predicted_noise = model(x_t, t)
                loss = mse(noise, predicted_noise)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                pbar.set_postfix(MSE=loss.item())
                logger.add_scalar("MSE", loss.item(), global_step=epoch * l + i)
                
                
                immg = deepcopy(images)
                # Process and save the images
                immg = (immg.clamp(-1, 1) + 1) / 2
                immg = (immg * 255).type(torch.uint8)
                save_images(immg, os.path.join("results_RCCUDA", args.run_name, f"{ii}.jpg"))
                ii += 1
                # sdvds

        # sampled_images = diffusion.sample(model, n=images.shape[0])
        # save_images(sampled_images, os.path.join("results", args.run_name, f"{epoch}.jpg"))
        torch.save(model.state_dict(), os.path.join("models_RCCUDA", args.run_name, f"ckpt_RCCUDA{epoch}.pt"))

Leave a Comment