How to pick one validation loss value for EarlyStopping when multiple losses are returned

I am training a Variational Autoencoder with custom losses, which has three components

  1. total_loss
  2. reconstruction_loss
  3. kl_loss

The full code is too long to post; I hope the snippet below articulates my problem (the running code is here in Colab). With these two functions I can train my model, and at the end of every epoch I get values for all three losses.

def train_step(self, data):
    if isinstance(data, tuple):
        data = data[0]
    with tf.GradientTape() as tape:
        z_mean, z_log_var, z = self.encoder(data)
        reconstruction = self.decoder(z)
        reconstruction_loss = tf.reduce_mean(tf.square(data - reconstruction),
                                             axis = [1,2,3])
        reconstruction_loss *= self.r_loss_factor
        kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
        kl_loss = tf.reduce_sum(kl_loss, axis = 1)
        kl_loss *= -0.5
        total_loss = reconstruction_loss + kl_loss
    grads = tape.gradient(total_loss, self.trainable_weights)
    self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
    return {
        "loss": total_loss,
        "reconstruction_loss": reconstruction_loss,
        "kl_loss": kl_loss,
    }

def test_step(self, input_data):
    validation_data = input_data[0] # <-- Separate X and y
    z_mean, z_log_var, z = self.encoder(validation_data)
    val_reconstruction = self.decoder(z)
    val_reconstruction_loss = tf.reduce_mean(tf.square(validation_data - val_reconstruction),
                                 axis = [1,2,3])
    val_reconstruction_loss *= self.r_loss_factor

    val_kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
    val_kl_loss = tf.reduce_sum(val_kl_loss, axis = 1)
    val_kl_loss *= -0.5
    val_total_loss = val_reconstruction_loss + val_kl_loss
    return {
        "loss": val_total_loss,
        "reconstruction_loss": val_reconstruction_loss,
        "kl_loss": val_kl_loss,
    }

However, if I add an EarlyStopping callback like this:

early_stopper = EarlyStopping(monitor="val_loss", min_delta=0.001, patience=10)

I get an error at the end of the epoch:

ValueError: The truth value of an array with more than one element is
ambiguous. Use a.any() or a.all()

My hunch is that the EarlyStopping callback is receiving three values to judge from. How can I select only the total validation loss for the early-stopping condition?

  • The issue is that your losses are arrays; you need to average over the batch axis to get scalars.

Problem

The EarlyStopping callback expects the monitored value to be a single scalar, but your test_step returns per-sample tensors of shape (batch_size,) — tf.reduce_mean(..., axis=[1,2,3]) and tf.reduce_sum(..., axis=1) keep the batch axis. When the callback compares such an array against its previous best value, the elementwise comparison has no single truth value, which raises the ValueError you see.
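A reduced sketch of the failing comparison (plain NumPy, mirroring what the callback does with the monitored value, not the callback's actual source):

```python
import numpy as np

# Hypothetical per-sample validation losses, shape (batch_size,),
# like the tensors returned from the question's test_step.
per_sample_val_loss = np.array([0.52, 0.61, 0.47])
best = 0.55

try:
    # For arrays, `<` is elementwise -> array([True, False, True]),
    # and an array of booleans has no single truth value.
    if per_sample_val_loss < best:
        pass
except ValueError as err:
    message = str(err)

print(message)  # "The truth value of an array with more than one element is ambiguous..."
```

With a scalar (e.g. `per_sample_val_loss.mean()`), the same `<` comparison works fine, which is why averaging fixes the error.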

Solution

I suggest feeding the callback only the (scalar) metric you want to monitor. Expanded code (note that I assume your data is already batched):

early_stopper = EarlyStopping(monitor="val_loss", min_delta=0.001, patience=10)
early_stopper.set_model(model)   # the callback signals stopping via model.stop_training
early_stopper.on_train_begin()   # resets the callback's wait counter and best value
num_epochs = 100

# Your training loop
for epoch in range(num_epochs):
    # Training step
    for batch in train_data:
        train_results = model.train_step(batch)

    # Validation step
    val_results = []
    for batch in val_data:
        val_results.append(model.test_step(batch))

    # Average the validation losses; reduce_mean also collapses the
    # per-sample axis, so each value below is a scalar
    # (the keys match what test_step returns)
    avg_val_total_loss = tf.reduce_mean([result["loss"] for result in val_results])
    avg_val_reconstruction_loss = tf.reduce_mean([result["reconstruction_loss"] for result in val_results])
    avg_val_kl_loss = tf.reduce_mean([result["kl_loss"] for result in val_results])

    # Feed only the scalar total validation loss to the callback
    early_stopper.on_epoch_end(epoch, logs={"val_loss": float(avg_val_total_loss)})

    # EarlyStopping sets model.stop_training once patience runs out
    if model.stop_training:
        break
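Alternatively, you can keep model.fit and monitor="val_loss" unchanged by reducing every loss to a scalar before returning it from train_step/test_step. A minimal sketch (hypothetical `scalar_losses` helper with shapes borrowed from the question, not the author's code) showing the extra tf.reduce_mean over the batch axis:

```python
import tensorflow as tf

def scalar_losses(data, reconstruction, z_mean, z_log_var, r_loss_factor=1000.0):
    # Per-sample losses, shape (batch_size,), as in the question
    reconstruction_loss = tf.reduce_mean(
        tf.square(data - reconstruction), axis=[1, 2, 3])
    reconstruction_loss *= r_loss_factor
    kl_loss = -0.5 * tf.reduce_sum(
        1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=1)
    total_loss = reconstruction_loss + kl_loss
    # One extra reduce_mean over the batch axis makes every entry a scalar,
    # so Keras logs a scalar "val_loss" that EarlyStopping can compare.
    return {
        "loss": tf.reduce_mean(total_loss),
        "reconstruction_loss": tf.reduce_mean(reconstruction_loss),
        "kl_loss": tf.reduce_mean(kl_loss),
    }

# Dummy tensors just to show the shapes involved
data = tf.random.uniform((8, 28, 28, 1))
reconstruction = tf.random.uniform((8, 28, 28, 1))
z_mean = tf.random.uniform((8, 2))
z_log_var = tf.random.uniform((8, 2))
losses = scalar_losses(data, reconstruction, z_mean, z_log_var)
print(losses["loss"].shape)  # () -- a scalar, safe for EarlyStopping
```

Note that averaging the per-sample total loss before tape.gradient does not change the direction of the gradients, only their scale by 1/batch_size.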

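As a sanity check, here is a toy Dense model (deliberately not the VAE) confirming that once the logged val_loss is a scalar, the stock EarlyStopping works under model.fit with no manual loop:

```python
import numpy as np
import tensorflow as tf

# Toy regression model -- built-in losses are already scalars per epoch
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="adam", loss="mse")

x = np.random.rand(64, 4).astype("float32")
y = np.random.rand(64, 1).astype("float32")

early_stopper = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss", min_delta=0.001, patience=2)
history = model.fit(x, y, validation_split=0.25, epochs=30,
                    callbacks=[early_stopper], verbose=0)

# Training ran at most 30 epochs and may have stopped early
print(len(history.history["val_loss"]))
```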