Pretrained Neural Network almost stops improving after a few epochs of training

I have a dataset of almost 200,000 medical images for multi-class classification. I’m trying to classify these images with pretrained neural networks (MobileNet, ResNet50, VGG16, etc.). However, no matter what I do, my model’s training and validation accuracy stop improving after the first 2 or 3 epochs.

Here’s what I’ve tried:

  • Data Augmentation
  • Getting more data
  • Freezing and unfreezing early layers
  • Removing layers
  • Balancing the training process by using class weights.
  • Changing batch size
  • Trying several different optimizers, with momentum, a decaying learning rate, etc.
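For context on the class-weights item, here is a minimal sketch of one common way to build such weights, assuming inverse-frequency ("balanced") weighting. The function name `balanced_class_weights` and the label counts in `example_labels` are made up for illustration and are not taken from the actual dataset.

```python
from collections import Counter


def balanced_class_weights(labels):
    """Return {class_index: weight} using n_samples / (n_classes * class_count)."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {cls: n_samples / (n_classes * cnt) for cls, cnt in sorted(counts.items())}


# Hypothetical integer labels for 4 classes (invented counts, illustration only).
example_labels = [0] * 50 + [1] * 25 + [2] * 20 + [3] * 5

class_weights = balanced_class_weights(example_labels)
for cls, weight in class_weights.items():
    print(f"class {cls}: weight {weight:.2f}")
```

Rare classes get larger weights, so their errors count more in the loss. With a Keras directory iterator, the integer labels would presumably come from its `classes` attribute, and the resulting dict is what gets passed as `class_weight` to `fit`.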

No matter what I do, the model accuracy jumps to ~60% in the first epoch and about 70% in the next, and from there it increases by less than 1% per epoch.

To make matters worse, the validation accuracy is almost always highest at the 1st or 2nd epoch and doesn’t get better from there. It also jumps up and down like crazy, from 70% validation accuracy at one epoch down to 30% at the following one. I’m guessing there’s some overfitting going on, but adding Dropout layers and other fixes doesn’t seem to make a difference. Here’s an example:

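from tensorflow import keras
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Sequential
from tensorflow.keras.preprocessing.image import ImageDataGenerator

IMG_SIZE = (224, 224)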
batch_size = 32
num_classes = 4

train_datagen = ImageDataGenerator(
    rotation_range=90,
    horizontal_flip=True,
    vertical_flip=True,
    rescale=1./255,
)

validation_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
    "data/data/train/train",
    target_size=IMG_SIZE,
    batch_size=batch_size,
    class_mode="categorical",
    shuffle=True
)

validation_generator = validation_datagen.flow_from_directory(
    "data/data/Validation/Validation",
    target_size=IMG_SIZE,
    batch_size=batch_size,
    class_mode="categorical",
    shuffle=False
)

# ResNet50 Implementation
resnet_base = ResNet50(input_shape=(224, 224, 3), include_top=False, weights="imagenet")
resnet_base.summary()

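# Freeze the first 30 layers of the base model
for layer in resnet_base.layers[:30]:
    layer.trainable = False

# Leave the remaining layers trainable for fine-tuning
for layer in resnet_base.layers[30:]:
    layer.trainable = True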

resnet_model = Sequential()
resnet_model.add(resnet_base)
resnet_model.add(GlobalAveragePooling2D())
resnet_model.add(Dense(num_classes, activation="softmax"))

resnet_model.summary()

#... 

def lr_schedule(epoch, lr):
    return lr * 0.9

learningSchedule = keras.callbacks.LearningRateScheduler(lr_schedule)

opt = keras.optimizers.SGD(learning_rate=0.01, momentum=0.9)
resnet_model.compile(optimizer=opt, loss="categorical_crossentropy", metrics=['accuracy'])

resnet_model.fit(train_generator, epochs=40, validation_data=validation_generator, class_weight=class_weights, callbacks=[checkpoint, earlyStopping, learningSchedule])
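A note on the schedule: `LearningRateScheduler` calls `lr_schedule` at the start of every epoch, including the first, so the rate is already multiplied by 0.9 before any training step runs. The sketch below (plain Python, no Keras needed; `effective_lr` is a helper name introduced here only for illustration) works out the rate in effect during a given epoch under that assumption.

```python
def effective_lr(epoch_number, initial_lr=0.01, decay=0.9):
    """Learning rate in effect during a 1-indexed epoch.

    Assumes the scheduler multiplies the current rate by `decay`
    once at the start of every epoch, starting with epoch 1.
    """
    return initial_lr * decay ** epoch_number


for n in (1, 2, 3, 40):
    print(f"epoch {n}: lr = {effective_lr(n):.4e}")
```

Under this assumption the rate starts at 0.009 rather than 0.01 and has dropped to roughly 1.5e-04 by epoch 40, which is consistent with the `lr` values in the training log.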

Here’s the training process:

Epoch 1/40
2993/2993 [==============================] - ETA: 0s - loss: 1.2157 - accuracy: 0.4815
Epoch 1: val_accuracy improved from -inf to 0.41756, saving model to Models/ResNet50V2\model_weights_epoch_01_val_acc_0.4176.keras
2993/2993 [==============================] - 1486s 495ms/step - loss: 1.2157 - accuracy: 0.4815 - val_loss: 1.7254 - val_accuracy: 0.4176 - lr: 0.0090
Epoch 2/40
2993/2993 [==============================] - ETA: 0s - loss: 1.0026 - accuracy: 0.5837
Epoch 2: val_accuracy improved from 0.41756 to 0.55022, saving model to Models/ResNet50V2\model_weights_epoch_02_val_acc_0.5502.keras
2993/2993 [==============================] - 2481s 829ms/step - loss: 1.0026 - accuracy: 0.5837 - val_loss: 1.0853 - val_accuracy: 0.5502 - lr: 0.0081
Epoch 3/40
2993/2993 [==============================] - ETA: 0s - loss: 0.9049 - accuracy: 0.6266
Epoch 3: val_accuracy improved from 0.55022 to 0.65756, saving model to Models/ResNet50V2\model_weights_epoch_03_val_acc_0.6576.keras
2993/2993 [==============================] - 3607s 1s/step - loss: 0.9049 - accuracy: 0.6266 - val_loss: 0.9128 - val_accuracy: 0.6576 - lr: 0.0073
Epoch 4/40
2993/2993 [==============================] - ETA: 0s - loss: 0.8458 - accuracy: 0.6491
Epoch 4: val_accuracy did not improve from 0.65756
2993/2993 [==============================] - 1033s 345ms/step - loss: 0.8458 - accuracy: 0.6491 - val_loss: 1.1537 - val_accuracy: 0.5336 - lr: 0.0066
Epoch 5/40
2993/2993 [==============================] - ETA: 0s - loss: 0.7964 - accuracy: 0.6679
Epoch 5: val_accuracy did not improve from 0.65756
2993/2993 [==============================] - 544s 182ms/step - loss: 0.7964 - accuracy: 0.6679 - val_loss: 1.3786 - val_accuracy: 0.4927 - lr: 0.0059
...

Epoch 34: val_accuracy did not improve from 0.74178
2993/2993 [==============================] - 630s 209ms/step - loss: 0.3616 - accuracy: 0.8150 - val_loss: 0.9580 - val_accuracy: 0.7087 - lr: 2.7813e-04
Epoch 35/40
2993/2993 [==============================] - ETA: 0s - loss: 0.3535 - accuracy: 0.8179
Epoch 35: val_accuracy did not improve from 0.74178
2993/2993 [==============================] - 4022s 1s/step - loss: 0.3535 - accuracy: 0.8179 - val_loss: 0.9790 - val_accuracy: 0.7129 - lr: 2.5032e-04
Epoch 36/40
2993/2993 [==============================] - ETA: 0s - loss: 0.3468 - accuracy: 0.8199
Epoch 36: val_accuracy did not improve from 0.74178
2993/2993 [==============================] - 1434s 479ms/step - loss: 0.3468 - accuracy: 0.8199 - val_loss: 0.9922 - val_accuracy: 0.7129 - lr: 2.2528e-04
Epoch 37/40
2993/2993 [==============================] - ETA: 0s - loss: 0.3427 - accuracy: 0.8238
Epoch 37: val_accuracy did not improve from 0.74178
2993/2993 [==============================] - 535s 179ms/step - loss: 0.3427 - accuracy: 0.8238 - val_loss: 0.9888 - val_accuracy: 0.7160 - lr: 2.0276e-04
Epoch 38/40
2993/2993 [==============================] - ETA: 0s - loss: 0.3349 - accuracy: 0.8239
Epoch 38: val_accuracy did not improve from 0.74178
2993/2993 [==============================] - 543s 181ms/step - loss: 0.3349 - accuracy: 0.8239 - val_loss: 1.0493 - val_accuracy: 0.6920 - lr: 1.8248e-04
Epoch 39/40
2993/2993 [==============================] - ETA: 0s - loss: 0.3348 - accuracy: 0.8261
Epoch 39: val_accuracy did not improve from 0.74178
2993/2993 [==============================] - 527s 176ms/step - loss: 0.3348 - accuracy: 0.8261 - val_loss: 1.0890 - val_accuracy: 0.6998 - lr: 1.6423e-04
Epoch 40/40
2993/2993 [==============================] - ETA: 0s - loss: 0.3290 - accuracy: 0.8280
Epoch 40: val_accuracy did not improve from 0.74178
2993/2993 [==============================] - 3417s 1s/step - loss: 0.3290 - accuracy: 0.8280 - val_loss: 1.0317 - val_accuracy: 0.7064 - lr: 1.4781e-04
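To put a number on the gap between training and validation accuracy, here is a small sketch that pulls both values out of Keras progress lines. The two sample lines are copied from epochs 37 and 40 of the log above; the helper name `accuracy_gap` and the regular expression are assumptions made for this illustration, not part of any Keras API.

```python
import re

# Progress lines copied verbatim from epochs 37 and 40 of the log.
log_lines = [
    "2993/2993 [==============================] - 535s 179ms/step - loss: 0.3427 - accuracy: 0.8238 - val_loss: 0.9888 - val_accuracy: 0.7160 - lr: 2.0276e-04",
    "2993/2993 [==============================] - 3417s 1s/step - loss: 0.3290 - accuracy: 0.8280 - val_loss: 1.0317 - val_accuracy: 0.7064 - lr: 1.4781e-04",
]

# " - accuracy: " (with the leading separator) avoids matching "val_accuracy".
PATTERN = re.compile(r" - accuracy: ([\d.]+) .* - val_accuracy: ([\d.]+)")


def accuracy_gap(line):
    """Return training accuracy minus validation accuracy, or None if absent."""
    match = PATTERN.search(line)
    if match is None:
        return None
    train_acc, val_acc = (float(value) for value in match.groups())
    return train_acc - val_acc


for line in log_lines:
    print(f"train - val accuracy gap: {accuracy_gap(line):.4f}")
```

In the last epochs the training accuracy sits more than ten points above the validation accuracy while `val_loss` stays near 1.0, which is the pattern that makes me suspect overfitting.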

I’m trying to get both validation and training accuracy to at least 90%, but I don’t know what else to try.

If you have any questions or need more info, please let me know.
