XGBOOST: Job aborted due to stage failure: Could not recover from a failed barrier ResultStage

Running the SparkXGBClassifier on Databricks i3.2xlarge (61gb memory, 8 cores)

Currently training 100+ models using a for loop like the following:

import os
os.environ["PYSPARK_PIN_THREAD"] = "true"

spark.conf.set("spark.yarn.executor.memoryOverhead", 600)
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

for i in models:
    labeled_train_df = get_train_data(df, i)

    if labeled_train_df:
        tuned_model = spark.table("default.tuned_params").select("model_name", "xgboost_params").collect()
        tuned_model_dict = {i.model_name: i.xgboost_params for i in tuned_model}

        param_dict = tuned_model_dict.get(i)
        if param_dict:
            _n_estimators, _learning_rate, _max_depth, _gamma = param_dict['n_estimators'], param_dict['learning_rate'], param_dict['max_depth'], param_dict['gamma']
        else:
            _n_estimators, _learning_rate, _max_depth, _gamma = 100, 0.30, 6, 0

        xgb = SparkXGBClassifier(
            eval_metric="auc", 
            booster="gbtree",
            features_col="features", 
            label_col="label",
            tree_method='hist', 
            n_estimators = _n_estimators,
            learning_rate = _learning_rate,
            max_depth = _max_depth,
            gamma = _gamma        
        )

        train_data, test_data = labeled_train_df.randomSplit([0.7, .3])

        model = xgb.fit(train_data)
        model.save(f'{s3}/{i}')

        del model, train_data, test_data, labeled_train_df

The problem is after the 2nd iteration, I get the following error below. Is this a OOM issue? Each iteration should only have roughly 3million rows of training data. If it is a OOM issue, why is it running iteration 1 and 2 successfully before erroring out even after manually deleting model variables? How can I fix this?

[0;31mPy4JJavaError[0m: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.collectAndServe.
: org.apache.spark.SparkException: Job aborted due to stage failure: Could not recover from a failed barrier ResultStage. Most recent failure reason: Stage failed because barrier task ResultTask(96121, 0) finished unsuccessfully.
org.apache.spark.api.python.PythonException: 'TypeError: 'str' object cannot be interpreted as an integer'. Full traceback below:
Traceback (most recent call last):
  File "/local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.9/site-packages/xgboost/spark/core.py", line 836, in _train_booster
    booster = worker_train(
  File "/local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.9/site-packages/xgboost/core.py", line 620, in inner_f
    return func(**kwargs)
  File "/local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.9/site-packages/xgboost/training.py", line 182, in train
    for i in range(start_iteration, num_boost_round):
TypeError: 'str' object cannot be interpreted as an integer

    at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:687)
    at org.apache.spark.sql.execution.python.PythonArrowOutput$$anon$1.read(PythonArrowOutput.scala:118)
    at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:640)
    at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
    at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:491)
    at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
    at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
    at org.apache.spark.ContextAwareIterator.hasNext(ContextAwareIterator.scala:39)
    at org.apache.spark.api.python.SerDeUtil$AutoBatchedPickler.hasNext(SerDeUtil.scala:88)
    at scala.collection.Iterator.foreach(Iterator.scala:943)
    at scala.collection.Iterator.foreach$(Iterator.scala:943)
    at org.apache.spark.api.python.SerDeUtil$AutoBatchedPickler.foreach(SerDeUtil.scala:82)
    at org.apache.spark.api.python.PythonRDD$.writeIteratorToStream(PythonRDD.scala:464)
    at org.apache.spark.api.python.PythonRunner$$anon$2.writeIteratorToStream(PythonRunner.scala:864)
    at org.apache.spark.api.python.BasePythonRunner$WriterThread.$anonfun$run$1(PythonRunner.scala:566)
    at org.apache.spark.util.Utils$.logUncaughtExceptions(Utils.scala:2359)
    at org.apache.spark.api.python.BasePythonRunner$WriterThread.run(PythonRunner.scala:357)

    at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:3377)
    at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:3309)
    at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:3300)
    at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
    at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
    at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
    at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:3300)
    at org.apache.spark.scheduler.DAGScheduler.handleTaskCompletion(DAGScheduler.scala:2790)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:3583)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:3527)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:3515)
    at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:51)
    at org.apache.spark.scheduler.DAGScheduler.$anonfun$runJob$1(DAGScheduler.scala:1178)
    at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
    at com.databricks.spark.util.FrameProfiler$.record(FrameProfiler.scala:80)
    at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:1166)
    at org.apache.spark.SparkContext.runJobInternal(SparkContext.scala:2739)
    at org.apache.spark.rdd.RDD.$anonfun$collect$1(RDD.scala:1070)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:165)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:125)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
    at org.apache.spark.rdd.RDD.withScope(RDD.scala:445)
    at org.apache.spark.rdd.RDD.collect(RDD.scala:1068)
    at org.apache.spark.api.python.PythonRDD$.collectAndServe(PythonRDD.scala:282)
    at org.apache.spark.api.python.PythonRDD.collectAndServe(PythonRDD.scala)
    at sun.reflect.GeneratedMethodAccessor663.invoke(Unknown Source)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
    at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:380)
    at py4j.Gateway.invoke(Gateway.java:306)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:195)
    at py4j.ClientServerConnection.run(ClientServerConnection.java:115)
    at java.lang.Thread.run(Thread.java:750)

Leave a Comment