I am training a Variational Autoencoder with a custom loss that has three components:
- total_loss
- reconstruction_loss
- kl_loss
The full code is too big to post; I hope the snippet below articulates my problem (the running code is here in colab). With these two functions I am able to train my model, and at the end of every epoch I get values for all three losses.
def train_step(self, data):
    if isinstance(data, tuple):
        data = data[0]
    with tf.GradientTape() as tape:
        z_mean, z_log_var, z = self.encoder(data)
        reconstruction = self.decoder(z)
        reconstruction_loss = tf.reduce_mean(
            tf.square(data - reconstruction), axis=[1, 2, 3])
        reconstruction_loss *= self.r_loss_factor
        kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
        kl_loss = tf.reduce_sum(kl_loss, axis=1)
        kl_loss *= -0.5
        total_loss = reconstruction_loss + kl_loss
    grads = tape.gradient(total_loss, self.trainable_weights)
    self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
    return {
        "loss": total_loss,
        "reconstruction_loss": reconstruction_loss,
        "kl_loss": kl_loss,
    }
def test_step(self, input_data):
    validation_data = input_data[0]  # <-- Separate X and y
    z_mean, z_log_var, z = self.encoder(validation_data)
    val_reconstruction = self.decoder(z)
    val_reconstruction_loss = tf.reduce_mean(
        tf.square(validation_data - val_reconstruction), axis=[1, 2, 3])
    val_reconstruction_loss *= self.r_loss_factor
    val_kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
    val_kl_loss = tf.reduce_sum(val_kl_loss, axis=1)
    val_kl_loss *= -0.5
    val_total_loss = val_reconstruction_loss + val_kl_loss
    return {
        "loss": val_total_loss,
        "reconstruction_loss": val_reconstruction_loss,
        "kl_loss": val_kl_loss,
    }
However, if I insert an early stopper in the callbacks like this:
early_stopper = EarlyStopping(monitor="val_loss", min_delta=0.001, patience=10)
I get an error at the end of each epoch:
ValueError: The truth value of an array with more than one element is
ambiguous. Use a.any() or a.all()
My hunch is that the early stopper is getting three values to judge from. How can I select only the total validation loss for judging the early-stopping condition?
Problem
The `EarlyStopping` callback expects a single scalar value to monitor, but the losses your `test_step` returns are per-sample arrays: you reduce over the image axes but not over the batch axis, so the callback's internal comparison against its running best produces a boolean array, which is what triggers the error.
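The error message itself comes from NumPy's rule that a multi-element array has no single truth value. A minimal reproduction (values are illustrative) of the comparison the callback performs:

```python
import numpy as np

# EarlyStopping internally does comparisons like `current < best`.
# With a scalar loss this yields a single bool; with a per-sample
# loss vector it yields a boolean array, and using that array in
# an `if` statement raises the ambiguity error from the traceback.
per_sample_loss = np.array([0.9, 1.2, 0.7])  # shape (3,), like a batched loss
best = 1.0

try:
    if per_sample_loss < best:  # boolean array in `if` -> ValueError
        pass
except ValueError as e:
    print(e)  # "The truth value of an array with more than one element is ambiguous..."
```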
Solution
I suggest feeding the stopper only the scalar metric you want to monitor. Your expanded code (note that I assume your data is already batched):
early_stopper = EarlyStopping(monitor="val_total_loss", min_delta=0.001, patience=10)
early_stopper.set_model(model)  # the callback flips model.stop_training when triggered
early_stopper.on_train_begin()

num_epochs = 1
# Your training loop
for epoch in range(num_epochs):
    # Training step
    for batch in train_data:
        train_results = model.train_step(batch)

    # Validation step
    val_results = []
    for batch in val_data:
        val_results.append(model.test_step(batch))

    # Average the per-sample validation losses over all batches to get
    # scalars (the dict keys match what test_step actually returns)
    avg_val_total_loss = tf.reduce_mean(
        tf.concat([r["loss"] for r in val_results], axis=0))
    avg_val_reconstruction_loss = tf.reduce_mean(
        tf.concat([r["reconstruction_loss"] for r in val_results], axis=0))
    avg_val_kl_loss = tf.reduce_mean(
        tf.concat([r["kl_loss"] for r in val_results], axis=0))

    # Hand only the scalar total validation loss to the callback
    early_stopper.on_epoch_end(
        epoch, logs={"val_total_loss": float(avg_val_total_loss)})

    # Check if the early-stopping criterion was met
    if model.stop_training:
        break
The underlying issue is that your losses are per-sample arrays; `EarlyStopping` can only compare scalars, so you need to average over the batch axis before handing the value to the callback.
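To see the reduction in isolation, here is a sketch with dummy tensors (the shapes and names are illustrative, not from your model). Reducing the per-sample losses over the batch axis is what turns them into the scalars the callback can compare:

```python
import tensorflow as tf

# Hypothetical small batch: 2 "images" of shape 4x4x1
data = tf.ones((2, 4, 4, 1))
reconstruction = tf.zeros((2, 4, 4, 1))
z_mean = tf.zeros((2, 3))
z_log_var = tf.zeros((2, 3))

# Per-sample losses, shape (2,) -- one value per batch element,
# exactly what your train_step/test_step currently return
per_sample_rec = tf.reduce_mean(tf.square(data - reconstruction), axis=[1, 2, 3])
per_sample_kl = -0.5 * tf.reduce_sum(
    1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=1)

# Reduce over the batch axis to get a scalar EarlyStopping can compare
total_loss = tf.reduce_mean(per_sample_rec + per_sample_kl)
print(total_loss.shape)  # () -- a scalar
```

If you apply this `tf.reduce_mean` inside `train_step`/`test_step` themselves, the returned dict contains scalars and the built-in `model.fit(..., callbacks=[EarlyStopping(monitor="val_loss")])` works without a manual loop.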