Python:
I am trying to load my data (images) sequentially, but each new batch does not start where the previous one ended. For instance, after loading images 1 to 9 with a batch size of 9, the next load does not start from image 10; it incorrectly starts sampling from 1 again. For a list of batch sizes like [9, 10, 14], the current code samples images 1 to 9, then 1 to 10, then 1 to 14, which is wrong.
For correct sampling, it should cover images 1 to 9, then continue with images 10 to 19, then images 20 to 33.
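To make the expected behaviour concrete, here is a minimal sketch of the index windows I have in mind (using the example batch sizes [9, 10, 14]; indices are 0-based, so they correspond to images 1-9, 10-19, and 20-33):

batch_sizes = [9, 10, 14]
start = 0
for bs in batch_sizes:
    # Expected windows: [0..8], [9..18], [19..32]
    print(list(range(start, start + bs)))
    start += bs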
In the loading function (get_data), the indices look correct: they start from zero and advance to the end of the dataset according to the batch sizes.
In the training function (train), I added some code to save the images so I can confirm that the model is trained on the correct sequence of images. However, the saved images show 1 to 9, then 1 to 10, then 1 to 14.
These are the two functions I use to load and train:
import os
import logging
from copy import deepcopy

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# UNet, Diffusion, setup_logging, and save_images come from my own project modules.


def get_data(args):
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(80),
        torchvision.transforms.RandomResizedCrop(args.image_size, scale=(1.0, 1.0)),
        torchvision.transforms.Grayscale(num_output_channels=1),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,))
    ])
    dataset = torchvision.datasets.ImageFolder(args.dataset_path, transform=transforms)

    # If batch_size is a single value, convert it to a list
    batch_sizes = args.batch_size if isinstance(args.batch_size, list) else [args.batch_size]

    dataloaders = []
    start_idx = 0  # Keep track of the starting index for each batch size
    for batch_size in batch_sizes:
        end_idx = start_idx + batch_size
        # Ensure the end index does not exceed the dataset size
        end_idx = min(end_idx, len(dataset))
        # Create a new SequentialSampler for each batch
        sampler = torch.utils.data.sampler.SequentialSampler(range(start_idx, end_idx))
        dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler, shuffle=False)
        dataloaders.append((batch_size, dataloader))  # Store both the size and the loader
        # Update the starting index for the next batch size
        start_idx = end_idx
    return dataloaders
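As a quick diagnostic (just a sketch for debugging, not part of the training code), I can print the indices that each loader's sampler actually yields and compare them with the start_idx/end_idx windows built above:

# Diagnostic sketch: list the dataset indices each DataLoader iterates over
for batch_size, dataloader in get_data(args):
    print(batch_size, list(dataloader.sampler))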
def train(args):
    setup_logging(args.run_name)
    device = args.device
    model = UNet().to(device)
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)
    mse = nn.MSELoss()
    diffusion = Diffusion(img_size=args.image_size, device=device)
    logger = SummaryWriter(os.path.join("runs_RCCUDA", args.run_name))
    dataloaders = get_data(args)
    l = len(dataloaders)

    for epoch in range(args.epochs):
        logging.info(f"Starting epoch {epoch}:")
        ii = 0
        for batch_size, dataloader in dataloaders:
            # print('len(dataloaders) = ', len(dataloaders))
            pbar = tqdm(dataloader)
            for i, (images, _) in enumerate(pbar):
                # print('i = ', i)
                images = images.to(device)
                t = diffusion.sample_timesteps(images.shape[0]).to(device)
                x_t, noise = diffusion.noise_images(images, t)
                predicted_noise = model(x_t, t)
                loss = mse(noise, predicted_noise)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                pbar.set_postfix(MSE=loss.item())
                logger.add_scalar("MSE", loss.item(), global_step=epoch * l + i)

                immg = deepcopy(images)
                # Process and save the images
                immg = (immg.clamp(-1, 1) + 1) / 2
                immg = (immg * 255).type(torch.uint8)
                save_images(immg, os.path.join("results_RCCUDA", args.run_name, f"{ii}.jpg"))
                ii += 1
        # sampled_images = diffusion.sample(model, n=images.shape[0])
        # save_images(sampled_images, os.path.join("results", args.run_name, f"{epoch}.jpg"))
        torch.save(model.state_dict(), os.path.join("models_RCCUDA", args.run_name, f"ckpt_RCCUDA{epoch}.pt"))