Skip to content

Integrate with PySpark¶

Comet integrates with Apache Spark through PySpark, its Python API.

PySpark is the Python API for Apache Spark, an open-source unified analytics engine for large-scale data processing. Spark provides an interface for programming clusters with implicit data parallelism and fault tolerance.

When integrated with Spark, Comet tracks machine learning training runs.

End-to-end example¶

from comet_ml import start, login

from pyspark.ml.linalg import Vectors
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator

from pyspark.sql import SparkSession
# Reuse the active Spark session if one exists, otherwise create a new one.
spark = SparkSession.builder.getOrCreate()
# Underlying SparkContext of the session; only used to build the SQLContext below.
sc = spark.sparkContext

from pyspark.sql import SQLContext
# NOTE(review): SQLContext is the legacy entry point (deprecated in Spark 3.x
# in favour of SparkSession); it is kept because main() reads through it.
sqlContext = SQLContext(sc)

# Set up Comet credentials before any experiment is started.
# NOTE(review): presumably prompts for or loads the API key -- confirm against
# the comet_ml documentation.
login()

def run_logistic_regression(training_data, test_data):
    """Train a logistic regression model and log the run to Comet.

    Starts a Comet experiment in the ``comet-example-pyspark-doc`` project,
    fits the model on ``training_data``, scores ``test_data``, and logs the
    hyperparameters and the evaluation metrics to that experiment.

    Args:
        training_data: DataFrame the model is fitted on. The caller in this
            file builds it with ``features`` and ``label`` columns.
        test_data: DataFrame with the same layout, used for evaluation.

    Returns:
        None. Results are reported to Comet only.
    """
    experiment = start(project_name="comet-example-pyspark-doc")

    # Keep the hyperparameters in one explicit dict so the exact values used
    # to build the estimator are the ones logged, without reaching into the
    # estimator's private ``_input_kwargs`` attribute.
    hyperparameters = {
        "maxIter": 10,
        "regParam": 0.3,
        "elasticNetParam": 0.8,
    }

    try:
        # models
        lr = LogisticRegression(**hyperparameters)

        model = lr.fit(training_data)
        training_summary = model.summary

        predictions = model.transform(test_data)
        evaluator = BinaryClassificationEvaluator(
            rawPredictionCol="rawPrediction")

        metrics = {
            'train_auc_score': training_summary.areaUnderROC,
            'train_accuracy': training_summary.accuracy,
            'test_auc_roc_score': evaluator.evaluate(predictions),
            # Override the evaluator's metric for this one call only.
            'test_auc_pr_score': evaluator.evaluate(
                predictions, {evaluator.metricName: "areaUnderPR"}),
        }

        experiment.log_parameters(hyperparameters)  # logging hyperparams to Comet
        experiment.log_metrics(metrics)  # logging metrics to Comet
    finally:
        # Close the experiment even if training or evaluation fails, so the
        # run is not left open until the interpreter exits.
        experiment.end()

def _to_features_and_label(df):
    """Turn each row into a (features, label) pair.

    Every column except the last is packed into a dense feature vector; the
    last column is used as the label.
    """
    return df.rdd.map(
        lambda row: (Vectors.dense(row[0:-1]), row[-1])
    ).toDF(["features", "label"])

def main():
    """Load the breast cancer CSV, split it, and run the Comet-tracked training."""
    # Read through the SparkSession's built-in CSV reader instead of the
    # deprecated SQLContext and the Spark 1.x external
    # 'com.databricks.spark.csv' format name.
    df = spark.read.csv(
        './data/breast_cancer.csv', header=True, inferSchema=True)

    # Splitting into train and test sets. Beware: it sorts the dataset.
    # No seed is passed, so the split differs from run to run.
    train_df, test_df = df.randomSplit([0.7, 0.3])
    training_data = _to_features_and_label(train_df)
    test_data = _to_features_and_label(test_df)

    run_logistic_regression(training_data, test_data)

if __name__ == '__main__':
    main()
Dec. 2, 2024