Proxy SHAP: Speed Up Explainability with Simpler Models

Author: Murphy  |  Time: 2025-03-22 20:22:28

Data scientists love running experiments, training models, and getting their hands dirty with data. At the beginning of a project, enthusiasm is at its peak, but when things become complicated or too time-consuming, looking for simpler solutions becomes a must.

There may be situations where business stakeholders ask for changes to the underlying solution logic, or for further adjustments and trials, in an attempt to improve performance while keeping the predictive algorithms involved explainable. Identifying possible bottlenecks in the code implementation, which may add complexity and delay the delivery of the final product, is crucial.

Imagine being a data scientist tasked with developing a predictive model. We have everything we need at our disposal and, after a while, we are ready to present to the business people our fancy predictive solution, built on thousands of features and millions of records, which achieves astonishing performance.

The business stakeholders are fascinated by our presentation and understand the technology's potential, but they add a request: they want to know how the model makes its decisions. Nothing easier, we may think…

Let's import shap into our code, add some fancy colorful plots, and go back to the business people for the glory.

Back at our PC, the first thing we try is to get the feature contributions (SHAP values) for each predicted sample, but we notice that the SHAP computation is very slow when thousands of features and millions of records are involved. This can be a serious problem if our predictive solution is adopted in real time to make predictions and provide explanations. We urgently need a way to speed up the computation of SHAP values without lowering model performance and without expensive refactoring of our excellent work.

In this post, we propose a methodology that produces reliable SHAP values, consistent with the knowledge learned by our model, and that is fast enough not to keep the end user waiting forever.

SHAP times as a function of model complexity

Anyone may need to make SHAP faster at some point in their Data Science journey, but what makes SHAP so slow? This is an interesting question that we investigate empirically.

Given three different gradient-boosting algorithms (XGBoost, LightGBM, and CatBoost), we track the time taken to compute SHAP values while varying three parameters: the data size, the tree depth, and the number of estimators/iterations used during training.

Shap time performances varying data sizes and boosting iterations [Image by the author]
Shap time performances varying data sizes and depths [Image by the author]

As we might expect, the more data we use, the longer the process takes. The computation time also grows linearly with the number of estimators/iterations, while it grows exponentially with the depth of the decision trees involved in the boosting.
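A back-of-the-envelope cost model helps explain why depth hurts so much more than the number of iterations. The sketch below assumes the Tree SHAP complexity of O(T·L·D²) per explained sample reported by Lundberg et al., with T trees, L leaves per tree, and depth D. It also assumes full binary trees, so L = 2^D, which holds for CatBoost's default symmetric (oblivious) trees. The function name `treeshap_cost` is hypothetical, and it returns relative operation counts, not seconds.

```python
def treeshap_cost(n_samples: int, n_trees: int, depth: int) -> int:
    """Relative cost of Tree SHAP, assuming O(T * L * D^2) work per sample.

    Assumes full binary trees, so the number of leaves is L = 2**depth
    (true for symmetric/oblivious trees such as CatBoost's default).
    """
    n_leaves = 2 ** depth
    return n_samples * n_trees * n_leaves * depth ** 2


base = treeshap_cost(n_samples=10_000, n_trees=100, depth=6)

# Doubling the number of trees doubles the cost (linear in T).
print(treeshap_cost(10_000, 200, 6) / base)  # 2.0

# Adding two levels of depth multiplies the cost by 2**2 * (8/6)**2.
print(treeshap_cost(10_000, 100, 8) / base)
```

Under this model, the 2^D term dominates quickly: going from depth 6 to depth 10 multiplies the estimate by roughly 44, whereas going from 100 to 400 trees only multiplies it by 4.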

Given the above observations, how can we make the process faster without compromising the SHAP insights obtained by our best model?

Building Proxy SHAP

Our trained model, which we assume to be a CatBoost model, is at our disposal and its performance satisfies us. We spent time and energy optimizing it. Repeating the whole training/optimization procedure with fewer iterations and shallower trees could be a waste of time, since it would most likely worsen predictive performance.

import catboost as ctb
from scipy import stats
from sklearn import model_selection

# X_train, y_train: training features and target, assumed to be already loaded
CV = model_selection.KFold(5, shuffle=False)

# Random search over the number of trees and the tree depth of a CatBoost regressor
model = model_selection.RandomizedSearchCV(
    ctb.CatBoostRegressor(verbose=0, thread_count=-1, random_state=123),
    {'n_estimators': stats.randint(1, 300),
     'depth': [4, 6, 8, 10]},
    random_state=123, n_iter=20, refit=True,
    cv=CV, scoring='neg_mean_absolute_error'
).fit(X_train, y_train)

What about training a new lighter model optimized to emulate the SHAP contributions of our original model? The idea sounds interesting. Let's give it a try.

The dataset used to obtain the following results is "California Housing", which is directly available in scikit-learn (under an open BSD license) and was originally published in [Pace, R. Kelley and Ronald Barry, Sparse Spatial Autoregressions, Statistics and Probability Letters, 33 (1997) 291–297].

As a first step, we retrieve out-of-fold SHAP values for the training samples. They will serve as the "ground truth" for our lighter model.

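import numpy as np
import catboost as ctb
from sklearn import model_selection

class ShapCatBoostRegressor(ctb.CatBoostRegressor):
    # Expose SHAP values as a "prediction" so that cross_val_predict can call it.
    # Output shape: (n_samples, n_features + 1); the last column is the expected value.
    def predict_shap(self, X):
        return self.get_feature_importance(
            ctb.Pool(X), type='ShapValues'
        )

# Out-of-fold SHAP values from the best configuration found above
ref_shap_val = model_selection.cross_val_predict(
    ShapCatBoostRegressor(
        **model.best_params_,
        verbose=0, thread_count=-1, random_state=123
    ),
    X_train, y_train,
    method='predict_shap',
    cv=CV
)

# Feature weights: mean |SHAP| per feature (expected-value column excluded), normalized to sum to 1
shap_feat_importance = np.abs(ref_shap_val[:,:-1]).mean(0)
shap_feat_importance /= shap_feat_importance.sum()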

Secondly, we train smaller models varying the depth and the number of iterations, searching for a new model that reproduces the SHAP values quickly and reliably. For each combination of depth and iterations, the SHAP values are retrieved and compared with the original ones produced by our deeper model. The quality of the approximation is measured with a standard regression metric, R2, computed for each feature between the original SHAP values and the approximated ones. A single final score is then obtained as the weighted mean of the per-feature R2 scores, where the features that are more important for our original model receive more weight. The same comparison can also be done sample-wise, here by comparing the total SHAP contribution assigned to each sample.

import datetime
import itertools

import numpy as np
import pandas as pd
from sklearn import metrics

result = []  # one record per (iterations, depth) combination
param_combi = {'iters': range(5,125,5), 'depth': range(1,8)}
for i,d in itertools.product(*param_combi.values()):

    # The measured time includes fitting each fold as well as computing SHAP values
    start_time = datetime.datetime.now()
    shap_val = model_selection.cross_val_predict(
        ShapCatBoostRegressor(
            n_estimators=i, depth=d,
            verbose=0, thread_count=-1, random_state=123
        ),
        X_train, y_train,
        method='predict_shap',
        cv=CV
    )
    end_time = datetime.datetime.now()
    delta = (end_time - start_time).total_seconds()

    result.append({
        "time": delta,
        "iters": i,
        "depth": d,
        # Per-feature R2 vs. the reference SHAP, averaged with importance weights
        "r2_feat_shap": np.average(
            metrics.r2_score(
                ref_shap_val[:,:-1], shap_val[:,:-1],
                multioutput='raw_values'
            ).round(3),
            weights=shap_feat_importance
        ),
        # Sample-wise R2 on the total SHAP contribution of each sample
        "r2_sample_shap": metrics.r2_score(
            ref_shap_val[:,:-1].sum(1), shap_val[:,:-1].sum(1)
        ),
    })

result = pd.DataFrame(result)
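The two scores computed inside the loop can also be reproduced without CatBoost or scikit-learn, which makes the metric easier to inspect. The NumPy sketch below assumes both SHAP matrices follow CatBoost's 'ShapValues' layout (expected value in the last column) and re-implements R2 column by column. The helper names `r2_per_column` and `proxy_shap_score` are hypothetical, the random arrays are toy stand-ins for real SHAP values, and, unlike the loop, it does not round the per-feature R2 before averaging.

```python
import numpy as np


def r2_per_column(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
    """Coefficient of determination computed independently for each column."""
    ss_res = ((y_true - y_pred) ** 2).sum(axis=0)
    ss_tot = ((y_true - y_true.mean(axis=0)) ** 2).sum(axis=0)
    return 1.0 - ss_res / ss_tot


def proxy_shap_score(ref_shap: np.ndarray, proxy_shap: np.ndarray):
    """Score proxy SHAP values against reference ones.

    Both inputs have shape (n_samples, n_features + 1), with the expected
    value in the last column (assumed CatBoost 'ShapValues' layout).
    Returns (importance-weighted feature-wise R2, sample-wise R2).
    """
    ref, prx = ref_shap[:, :-1], proxy_shap[:, :-1]

    # Weights: normalized mean |SHAP| of the reference (original) model
    weights = np.abs(ref).mean(axis=0)
    weights /= weights.sum()

    feat_r2 = float(np.average(r2_per_column(ref, prx), weights=weights))

    # Sample-wise: compare the total contribution assigned to each sample
    sample_r2 = float(r2_per_column(ref.sum(axis=1, keepdims=True),
                                    prx.sum(axis=1, keepdims=True))[0])
    return feat_r2, sample_r2


# Toy data: 3 features with decreasing importance, plus a constant expected value
rng = np.random.default_rng(0)
contrib = rng.normal(size=(1_000, 3)) * np.array([3.0, 1.0, 0.5])
ref_shap = np.column_stack([contrib, np.full(1_000, 2.0)])
noise = np.column_stack([rng.normal(scale=0.3, size=(1_000, 3)), np.zeros(1_000)])
noisy_shap = ref_shap + noise

print(proxy_shap_score(ref_shap, ref_shap))
print(proxy_shap_score(ref_shap, noisy_shap))
```

With the same noise on every feature, the least important feature gets the lowest R2, but its small weight keeps the weighted score high; this is the behavior the importance weighting is meant to produce.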

Inspecting the results of the search process, this is what we observe:

On the left, R2 Shap feature scores vs boosting iterations. On the right, time for Shap computation vs boosting iterations [Image by the author]
On the left, R2 Shap feature scores vs boosting depths. On the right, time for Shap computation vs boosting depths [Image by the author]

As the number of iterations or the depth increases, the time taken to compute SHAP values increases and the accuracy (R2) of the SHAP approximation improves. This result is not surprising, but it confirms our initial hypothesis.

R2 Shap feature scores and times of Shap computation for all parameters combinations [Image by the author]

The model that best approximates the SHAP calculation is the one with the best trade-off between accuracy and time. More weight can be assigned to accuracy or time depending on what we would like to prioritize.

def distance(time, r2, w_time=0.1, w_r2=0.9):
    # Weighted squared distance from the ideal point (fastest time, best R2)
    return (((time - result.time.min()) ** 2) * w_time +
            ((result.r2_feat_shap.max() - r2) ** 2) * w_r2)

result['distance'] = result.apply(
    lambda x: distance(x.time, x.r2_feat_shap), axis=1
)
result = result.sort_values('distance')

# Refit the best-ranked (iterations, depth) combination as the proxy model
proxy_model = ctb.CatBoostRegressor(
    n_estimators=int(result.head(1).iters.squeeze()),
    depth=int(result.head(1).depth.squeeze()),
    verbose=0, thread_count=-1, random_state=123,
).fit(X_train, y_train)

Optimal parameter combination [Image by the author]
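The `distance` function squares raw seconds and raw R2 gaps, so its weights also depend on the time unit and on the spread of each axis. One possible variant, not the author's code, min-max scales both axes to [0, 1] before weighting. The sketch below is a minimal pure-Python version; `select_proxy_params` is a hypothetical name, and the candidate records are toy numbers, not results from this article.

```python
def select_proxy_params(results, w_time=0.1, w_r2=0.9):
    """Pick the (iters, depth) closest to the ideal point (min time, max R2).

    Variant of the distance above: time and R2 are min-max scaled to [0, 1]
    first, so the weights do not depend on the unit of time.
    `results` is a list of dicts with keys 'time', 'iters', 'depth', 'r2_feat_shap'.
    """
    t_min = min(r["time"] for r in results)
    t_max = max(r["time"] for r in results)
    r2_min = min(r["r2_feat_shap"] for r in results)
    r2_max = max(r["r2_feat_shap"] for r in results)
    t_span = (t_max - t_min) or 1.0      # avoid division by zero
    r2_span = (r2_max - r2_min) or 1.0

    def distance(r):
        time_gap = (r["time"] - t_min) / t_span          # 0 = fastest candidate
        r2_gap = (r2_max - r["r2_feat_shap"]) / r2_span  # 0 = most accurate candidate
        return w_time * time_gap ** 2 + w_r2 * r2_gap ** 2

    best = min(results, key=distance)
    return best["iters"], best["depth"]


# Toy candidates (hypothetical numbers)
candidates = [
    {"time": 1.2, "iters": 20, "depth": 2, "r2_feat_shap": 0.80},
    {"time": 3.5, "iters": 60, "depth": 4, "r2_feat_shap": 0.95},
    {"time": 9.8, "iters": 120, "depth": 7, "r2_feat_shap": 0.97},
]
print(select_proxy_params(candidates))  # (60, 4)
```

With `w_time=1.0, w_r2=0.0` the fastest candidate wins; with `w_time=0.0, w_r2=1.0` the most accurate one does. Intermediate weights trade the two off on a common scale.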

Going from around 3 seconds to around 3 microseconds, we register a clear benefit in the time taken to compute SHAP values on a test set of 10k samples.

Time taken to compute shap [Image by the author]
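For a comparison like this one, `time.perf_counter` is better suited than the `datetime.datetime.now()` calls used in the search loop: it is a high-resolution clock intended for measuring durations and is not affected by system clock adjustments. Repeating the call and taking the median also smooths out noise. The helper `time_call` below is a hypothetical sketch that uses only the standard library. The commented lines show how it could be pointed at the original and proxy models, assuming a held-out `X_test`.

```python
import time
from statistics import median


def time_call(fn, *args, repeat=5, **kwargs):
    """Median wall-clock time (seconds) of fn(*args, **kwargs) over `repeat` runs."""
    timings = []
    for _ in range(repeat):
        start = time.perf_counter()  # high-resolution clock for measuring durations
        fn(*args, **kwargs)
        timings.append(time.perf_counter() - start)
    return median(timings)


# Hypothetical usage with the models above and a held-out X_test:
# pool = ctb.Pool(X_test)
# t_orig = time_call(model.best_estimator_.get_feature_importance, pool, type='ShapValues')
# t_proxy = time_call(proxy_model.get_feature_importance, pool, type='ShapValues')

# Self-contained demo on a pure-Python workload
print(f"{time_call(sum, range(1_000_000)):.6f} s")
```

Building the `Pool` once, outside the timed call, keeps data conversion out of the measurement so that only the SHAP computation is compared.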

A final comparison between the original model and the proxy model also shows how well the proxy SHAP values approximate the real values on unseen data, especially for the most important features.

Original vs Proxy Shap value comparison on unseen data [Image by the author]

Summary

In this post, we proposed a methodology to compute reliable SHAP values using simpler and lighter models. The approach consists of training a lighter model to emulate the SHAP contributions of an original, heavier model. By optimizing parameters such as depth and number of iterations, we achieved a balance between speed and accuracy. This significantly reduces computation time while maintaining reliable SHAP value approximations, offering a practical solution for real-time predictive applications.
