I'm running SparkXGBClassifier on a Databricks i3.2xlarge cluster (61 GB memory, 8 cores). I'm currently training 100+ models using a for loop like the following:
import os

from xgboost.spark import SparkXGBClassifier

# models, df, get_train_data, and s3 are defined earlier in the notebook
os.environ["PYSPARK_PIN_THREAD"] = "true"
spark.conf.set("spark.yarn.executor.memoryOverhead", 600)
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

# Load the tuned hyperparameters once, keyed by model name
tuned_model = spark.table("default.tuned_params").select("model_name", "xgboost_params").collect()
tuned_model_dict = {row.model_name: row.xgboost_params for row in tuned_model}

for i in models:
    labeled_train_df = get_train_data(df, i)
    if labeled_train_df:
        param_dict = tuned_model_dict.get(i)
        if param_dict:
            _n_estimators = param_dict["n_estimators"]
            _learning_rate = param_dict["learning_rate"]
            _max_depth = param_dict["max_depth"]
            _gamma = param_dict["gamma"]
        else:
            # Fall back to the XGBoost defaults when no tuned params exist
            _n_estimators, _learning_rate, _max_depth, _gamma = 100, 0.30, 6, 0
        xgb = SparkXGBClassifier(
            eval_metric="auc",
            booster="gbtree",
            features_col="features",
            label_col="label",
            tree_method="hist",
            n_estimators=_n_estimators,
            learning_rate=_learning_rate,
            max_depth=_max_depth,
            gamma=_gamma,
        )
        train_data, test_data = labeled_train_df.randomSplit([0.7, 0.3])
        model = xgb.fit(train_data)
        model.save(f"{s3}/{i}")
        del model, train_data, test_data, labeled_train_df
The problem: after the 2nd iteration, the job fails with the error shown further below. Is this an OOM issue? Each iteration should only involve roughly 3 million rows of training data. If it is OOM, why do iterations 1 and 2 run successfully before erroring out, even though I manually delete the model variables at the end of every iteration? How can I fix this?
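To rule out driver-side memory buildup, this is the extra per-iteration cleanup I could add on top of the del statements (a minimal sketch; clearCache() and gc.collect() are my own guesses, nothing in the error message points to them):

import gc

# Hypothetical extra cleanup at the end of each iteration: drop any cached
# DataFrames on the executors and force the driver-side garbage collector.
spark.catalog.clearCache()
gc.collect()

Here is the full error: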
Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.collectAndServe.
: org.apache.spark.SparkException: Job aborted due to stage failure: Could not recover from a failed barrier ResultStage. Most recent failure reason: Stage failed because barrier task ResultTask(96121, 0) finished unsuccessfully.
org.apache.spark.api.python.PythonException: 'TypeError: 'str' object cannot be interpreted as an integer'. Full traceback below:
Traceback (most recent call last):
File "/local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.9/site-packages/xgboost/spark/core.py", line 836, in _train_booster
booster = worker_train(
File "/local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.9/site-packages/xgboost/core.py", line 620, in inner_f
return func(**kwargs)
File "/local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.9/site-packages/xgboost/training.py", line 182, in train
for i in range(start_iteration, num_boost_round):
TypeError: 'str' object cannot be interpreted as an integer
at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:687)
at org.apache.spark.sql.execution.python.PythonArrowOutput$$anon$1.read(PythonArrowOutput.scala:118)
at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:640)
at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:491)
at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
at org.apache.spark.ContextAwareIterator.hasNext(ContextAwareIterator.scala:39)
at org.apache.spark.api.python.SerDeUtil$AutoBatchedPickler.hasNext(SerDeUtil.scala:88)
at scala.collection.Iterator.foreach(Iterator.scala:943)
at scala.collection.Iterator.foreach$(Iterator.scala:943)
at org.apache.spark.api.python.SerDeUtil$AutoBatchedPickler.foreach(SerDeUtil.scala:82)
at org.apache.spark.api.python.PythonRDD$.writeIteratorToStream(PythonRDD.scala:464)
at org.apache.spark.api.python.PythonRunner$$anon$2.writeIteratorToStream(PythonRunner.scala:864)
at org.apache.spark.api.python.BasePythonRunner$WriterThread.$anonfun$run$1(PythonRunner.scala:566)
at org.apache.spark.util.Utils$.logUncaughtExceptions(Utils.scala:2359)
at org.apache.spark.api.python.BasePythonRunner$WriterThread.run(PythonRunner.scala:357)
at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:3377)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:3309)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:3300)
at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:3300)
at org.apache.spark.scheduler.DAGScheduler.handleTaskCompletion(DAGScheduler.scala:2790)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:3583)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:3527)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:3515)
at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:51)
at org.apache.spark.scheduler.DAGScheduler.$anonfun$runJob$1(DAGScheduler.scala:1178)
at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
at com.databricks.spark.util.FrameProfiler$.record(FrameProfiler.scala:80)
at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:1166)
at org.apache.spark.SparkContext.runJobInternal(SparkContext.scala:2739)
at org.apache.spark.rdd.RDD.$anonfun$collect$1(RDD.scala:1070)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:165)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:125)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
at org.apache.spark.rdd.RDD.withScope(RDD.scala:445)
at org.apache.spark.rdd.RDD.collect(RDD.scala:1068)
at org.apache.spark.api.python.PythonRDD$.collectAndServe(PythonRDD.scala:282)
at org.apache.spark.api.python.PythonRDD.collectAndServe(PythonRDD.scala)
at sun.reflect.GeneratedMethodAccessor663.invoke(Unknown Source)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:380)
at py4j.Gateway.invoke(Gateway.java:306)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:195)
at py4j.ClientServerConnection.run(ClientServerConnection.java:115)
at java.lang.Thread.run(Thread.java:750)
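Edit: in case the parameter values themselves are relevant, here is a quick check of what Python types come back from the tuned-params table (a sketch; I'm assuming xgboost_params is stored as a map column, which is a guess on my part):

# Inspect the Python types of the stored hyperparameters for each model,
# since SparkXGBClassifier expects numeric values for n_estimators etc.
rows = spark.table("default.tuned_params").select("model_name", "xgboost_params").collect()
for row in rows[:5]:
    print(row.model_name, {k: type(v).__name__ for k, v in row.xgboost_params.items()})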