Distributed hyperparameter tuning of Scikit-learn models in Spark



Hyperparameter tuning of machine learning models often requires significant computing time. Scikit-learn implements parallel processing (via the n_jobs parameter) to speed things up, but by default it only uses the cores of a single machine; much larger speed gains come from distributing the work over a cluster, for example with Spark. In this blog post I show how to do hyperparameter tuning in Spark for any machine learning model, regardless of whether it’s Scikit-learn, TensorFlow/Keras, XGBoost, LightGBM, etc.

Create combinations of hyperparameter values

First, we’re going to create the hyperparameter combinations that we want to test our model with. Below is a helper function that creates all combinations for a param_grid, a dict that maps each argument name to the list of values to test.

import numpy as np
from itertools import product

def create_hyperparameter_combinations(param_grid):
    combinations = list(product(*param_grid.values()))
    return [dict(zip(param_grid.keys(), x)) for x in combinations]

So if our desired hyperparameter space is as follows…

param_grid = {'max_features': ['auto', 0.1, 0.3], 'min_samples_leaf': [None, 50]}

…then create_hyperparameter_combinations(param_grid) returns all six combinations of these two hyperparameters and their values:

[{'max_features': 'auto', 'min_samples_leaf': None},
 {'max_features': 'auto', 'min_samples_leaf': 50},
 {'max_features': 0.1, 'min_samples_leaf': None},
 {'max_features': 0.1, 'min_samples_leaf': 50},
 {'max_features': 0.3, 'min_samples_leaf': None},
 {'max_features': 0.3, 'min_samples_leaf': 50}]
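Scikit-learn also ships sklearn.model_selection.ParameterGrid, which should yield the same combinations as this helper. The check below is a sketch assuming a recent Scikit-learn version; ParameterGrid sorts the keys alphabetically, which for this param_grid happens to match the insertion order. It also confirms that the number of combinations is the product of the lengths of the value lists, 3 × 2 = 6 here.

```python
import math
from itertools import product

from sklearn.model_selection import ParameterGrid

def create_hyperparameter_combinations(param_grid):
    # Cartesian product of the value lists, one dict per combination
    combinations = list(product(*param_grid.values()))
    return [dict(zip(param_grid.keys(), x)) for x in combinations]

param_grid = {'max_features': ['auto', 0.1, 0.3], 'min_samples_leaf': [None, 50]}
combos = create_hyperparameter_combinations(param_grid)

# Number of combinations = product of the number of values per hyperparameter
n_expected = math.prod(len(v) for v in param_grid.values())
print(len(combos), n_expected)  # 6 6

# ParameterGrid also applies itertools.product over the (sorted) keys,
# so for alphabetically ordered keys both lists are identical
print(list(ParameterGrid(param_grid)) == combos)  # True
```

If the keys in your param_grid are not in alphabetical order, the two lists contain the same dicts but in a different order.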

Create evaluation function

Now, we’re going to create a function that takes a single combination of hyperparameter values and returns the model’s performance metrics on train and test data for these hyperparameters. This function will be evaluated in a distributed fashion on our Spark cluster: if you want to test 500 different hyperparameter combinations, you will see 500 tasks being executed by Spark. In this example, we’re going to optimise hyperparameters for a Scikit-learn model. This requires Scikit-learn to be installed on the worker nodes of your Spark cluster.

The hyperparameter values for a particular combination are passed to the function as a JSON string. The advantage of this is that we don’t need to change the function definition if we want to add hyperparameters. Moreover, we don’t depend on the types of the hyperparameter values. Many Scikit-learn hyperparameters accept values of different types; max_features in the RandomForestClassifier, for example, takes integers, floats and strings. Because every combination arrives as a single string, Spark never has to deal with these mixed types.
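To make the JSON argument concrete, here is a small round trip with the param_grid from above, using only the standard library: each mixed-type combination becomes one plain string, and json.loads restores the original values.

```python
import json
from itertools import product

param_grid = {'max_features': ['auto', 0.1, 0.3], 'min_samples_leaf': [None, 50]}
# Same combinations as create_hyperparameter_combinations(param_grid)
combos = [dict(zip(param_grid.keys(), x)) for x in product(*param_grid.values())]

# Whatever the value types, each combination becomes a single string
combos_json = [json.dumps(c) for c in combos]
print(combos_json[0])  # {"max_features": "auto", "min_samples_leaf": null}

# Decoding gives back the original str, float and None values
restored = [json.loads(s) for s in combos_json]
print(restored == combos)  # True
```

This only works for JSON-serialisable values: a tuple comes back as a list, and NumPy integer scalars such as np.int64 make json.dumps raise a TypeError, so it is safest to build param_grid from plain Python types.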

The function definition is as follows; the explanation continues below.

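import json

from sklearn.model_selection import cross_validate

def evaluate_clf(base_clf, hyperpars_json, X, y, cv=5):
    # Decode the single hyperparameter combination into a dict
    hyperpars = json.loads(hyperpars_json)
    # Apply these hyperparameters to the Scikit-learn model
    base_clf.set_params(**hyperpars)
    # One train score and one test score per fold
    cv_results = cross_validate(base_clf, X, y, cv=cv, return_train_score=True)
    # Return the JSON string plus the fold-averaged scores as plain Python floats
    return (hyperpars_json,
            float(cv_results['train_score'].mean()),
            float(cv_results['test_score'].mean()))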

In the first line, we convert the single combination of hyperparameter values to a Python dict using json.loads. The next line sets these parameters on our Scikit-learn model, base_clf, using set_params. The Scikit-learn function cross_validate takes the model, our training features X and target y, and produces train and test scores using cross-validation. It returns a dict in which train_score and test_score hold one score per fold; we return the mean of each, along with the hyperpars_json that we passed in. Since Spark sometimes has difficulties with NumPy types such as np.float64, we convert the scores to plain Python floats.
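The structure of the cross_validate output is easiest to see by calling it once on the driver, outside Spark. The sketch below assumes the iris data and a DecisionTreeClassifier, as in the full example further down; running such a local call first is also a cheap way to catch errors before they surface inside hundreds of Spark tasks.

```python
import json

from sklearn.datasets import load_iris
from sklearn.model_selection import cross_validate
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# One combination, encoded the same way the Spark job receives it
hyperpars_json = json.dumps({'max_depth': 3, 'min_samples_leaf': 5})

clf = DecisionTreeClassifier()
clf.set_params(**json.loads(hyperpars_json))
cv_results = cross_validate(clf, X, y, cv=5, return_train_score=True)

# A dict of NumPy arrays with one entry per fold
print(sorted(cv_results))  # ['fit_time', 'score_time', 'test_score', 'train_score']
print(cv_results['test_score'].shape)  # (5,)

# The tuple evaluate_clf returns: JSON string plus two fold-averaged scores
row = (hyperpars_json,
       float(cv_results['train_score'].mean()),
       float(cv_results['test_score'].mean()))
print(row)
```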

Distribute the evaluation function in Spark

The next step is to distribute the hyperparameter combinations and use our evaluation function to calculate model performance metrics for these hyperparameter values. This is done in the function below.

import json

from pyspark.sql import types as T

def get_hyperparameter_results(spark_session, base_clf, X, y,
                               hyperpar_combinations, cv=5):
    # Serialise every combination to a JSON string
    hyperpar_combinations_json = [json.dumps(x) for x in hyperpar_combinations]
    # One partition per combination, so Spark runs one task per combination
    hyperpars_rdd = spark_session.sparkContext.parallelize(hyperpar_combinations_json,
                                                           numSlices=len(hyperpar_combinations_json))

    # Evaluate each combination on the workers
    rdd_map_result = hyperpars_rdd.map(lambda x: evaluate_clf(base_clf, x, X, y, cv))

    result_schema = T.StructType([T.StructField('hyperpars', T.StringType()),
                                  T.StructField('mean_train_score', T.FloatType()),
                                  T.StructField('mean_test_score', T.FloatType())])

    result_sdf = spark_session.createDataFrame(rdd_map_result, schema=result_schema)
    # Collect to Pandas, best mean test score first
    result = (result_sdf.toPandas()
              .sort_values('mean_test_score', ascending=False))

    # Turn the JSON strings back into dicts
    result['hyperpars'] = result['hyperpars'].apply(json.loads)
    return result

In the first line we convert the dicts to JSON strings. Then we parallelize these strings by creating an RDD and map the evaluate_clf function over it. The schema of the result is defined as a StructType containing one StructField per value returned by evaluate_clf. Then a Spark dataframe is created that holds the results of our hyperparameter tuning. We convert this Spark dataframe to a Pandas dataframe, sorted by mean test score, in order to explore the results easily. As a last step we convert the JSON strings back to dicts.
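For small grids or for debugging, the same pipeline can be run on the driver without Spark by replacing parallelize(...).map(...) with a plain Python loop. The function get_hyperparameter_results_local below is a hypothetical helper, not part of the code above, and assumes pandas and Scikit-learn are installed locally. Unlike evaluate_clf, it clones the model for each combination, because on the driver a single model object would otherwise carry parameters over from one combination to the next.

```python
import json

import pandas as pd
from sklearn.base import clone
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_validate
from sklearn.tree import DecisionTreeClassifier

def get_hyperparameter_results_local(base_clf, X, y, hyperpar_combinations, cv=5):
    """Hypothetical driver-only counterpart of get_hyperparameter_results."""
    rows = []
    for hyperpars in hyperpar_combinations:
        # Fresh, unfitted copy so parameters never leak between combinations
        clf = clone(base_clf).set_params(**hyperpars)
        cv_results = cross_validate(clf, X, y, cv=cv, return_train_score=True)
        rows.append((json.dumps(hyperpars),
                     float(cv_results['train_score'].mean()),
                     float(cv_results['test_score'].mean())))
    result = (pd.DataFrame(rows, columns=['hyperpars', 'mean_train_score', 'mean_test_score'])
              .sort_values('mean_test_score', ascending=False))
    # Same final step as the Spark version: JSON strings back to dicts
    result['hyperpars'] = result['hyperpars'].apply(json.loads)
    return result

X, y = load_iris(return_X_y=True)
combos = [{'max_depth': 3, 'min_samples_leaf': 5},
          {'max_depth': 10, 'min_samples_leaf': 10}]
local_results = get_hyperparameter_results_local(DecisionTreeClassifier(), X, y, combos)
print(local_results)
```

Because the output has the same columns and ordering as the Spark version, it can serve as a quick reference when checking the distributed results on a small grid.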

Full working example

In the example below we do hyperparameter tuning of a DecisionTreeClassifier to predict classes for the famous iris dataset.

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

param_grid = {'max_depth': [3, 4, 5, 10], 'min_samples_leaf': [0.1, 5, 10]}

hyperpar_combinations = create_hyperparameter_combinations(param_grid)

results = get_hyperparameter_results(spark, DecisionTreeClassifier(), X, y,
                                     hyperpar_combinations, cv=5)
print(results.to_string(index=False))

This gives the following results, with the best mean test score first:

hyperpars                                   mean_train_score  mean_test_score
{'max_depth': 3, 'min_samples_leaf': 5}             0.968333         0.940000
{'max_depth': 4, 'min_samples_leaf': 5}             0.968333         0.940000
{'max_depth': 5, 'min_samples_leaf': 5}             0.968333         0.940000
{'max_depth': 10, 'min_samples_leaf': 5}            0.968333         0.940000
{'max_depth': 3, 'min_samples_leaf': 0.1}           0.961667         0.933333
{'max_depth': 3, 'min_samples_leaf': 10}            0.961667         0.933333
{'max_depth': 4, 'min_samples_leaf': 0.1}           0.961667         0.933333
{'max_depth': 4, 'min_samples_leaf': 10}            0.961667         0.933333
{'max_depth': 5, 'min_samples_leaf': 0.1}           0.961667         0.933333
{'max_depth': 5, 'min_samples_leaf': 10}            0.961667         0.933333
{'max_depth': 10, 'min_samples_leaf': 0.1}          0.961667         0.933333
{'max_depth': 10, 'min_samples_leaf': 10}           0.961667         0.933333