TensorFlow Decision Forests: A Comprehensive Introduction

Author:Murphy  |  View: 29572  |  Time: 2025-03-23 18:55:31
Photo by Javier Allegue Barros on Unsplash

Introduction

Two years ago, the TensorFlow (TF) team open-sourced a library for training tree-based models called TensorFlow Decision Forests (TFDF). Just last month, they announced that the package is finally production-ready, so I've decided it's time to take a closer look. The aim of this post is to give you a better idea of the package and to show you how to use it (effectively). The structure of the post is outlined below; feel free to skip to whichever part interests you most.

  1. What is TFDF and why use it?
  2. Train Random Forest (RF) and Gradient Boosted Trees (GBT) models using TFDF
  3. Hyper-parameter tuning with TFDF and Optuna
  4. Model inspection
  5. Serving the GBT model with TF Serving

Setup

You can find all the code in my repo, so make sure to star it if you haven't already. In this post, we'll train a few models for loan default prediction using the U.S. Small Business Administration dataset (CC BY-SA 4.0 license). The models are trained on already pre-processed data, but the repo contains a notebook that describes the processing and feature engineering steps. Make sure to follow them if you want to replicate my code directly. Alternatively, use this code as a starting point and adapt it to your own dataset (my recommended approach).

Installing TensorFlow Decision Forests is quite straightforward: just run pip install tensorflow_decision_forests, and most of the time this should work. Some issues have been reported on M1 and M2 Macs, but it worked fine for me personally with the latest version of TFDF.

TensorFlow Decision Forests

What is TFDF?

TensorFlow Decision Forests is built on top of Yggdrasil Decision Forests, a C++ library that was also developed by Google. The original C++ algorithms are designed to build scalable decision tree models that can handle large datasets and high-dimensional feature spaces. By integrating this library into the wider TF ecosystem, TFDF lets users build scalable RF and GBT models without having to learn another language.

Why use it?

The main advantage of this library over, e.g., XGBoost or LightGBM is its tight integration with the other components of the TF ecosystem. It might be particularly interesting for teams who already have other TensorFlow models in their pipeline or who use TFX. TFDF can be integrated quite easily with, e.g., NLP models, which makes multi-modal pipelines easier to build. Also, if you serve models with TF Serving, you might want to consider this library because of its native support (no need for ONNX or other cross-package serialisation methods). Finally, this library gives you a ton of parameters that you can adjust to approximate models from XGBoost, LightGBM, and many other Gradient Boosted Machine (GBM) methods. This means you don't need to switch between different GBM libraries during training, which is nice from a code-maintainability perspective.

Model Training

Make sure to pull this notebook and follow along, as only parts of the code are shown below.

Data

As mentioned in the Setup section, I'm going to use a pre-processed version of this dataset. To prepare it for TFDF, we first read it in with pandas as usual and decide which columns will be treated as categorical and which as numerical.

Code: https://gist.github.com/aruberts/58932428737a52192021bebeeae389af
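If you want to automate that decision, here's a minimal sketch. It assumes (hypothetically) that the data is held as a dict of column lists instead of a pandas DataFrame, and that every string-valued column should be treated as categorical. With pandas you would typically reach for df.select_dtypes instead. The sample values come from the serving request shown at the end of this post.

```python
def split_columns(columns, exclude=()):
    """Split column names into categorical (all strings) and numerical (all ints/floats).

    `columns` maps a column name to a list of values (a stand-in for a DataFrame);
    names in `exclude` (e.g. the label column) are skipped.
    """
    categorical, numerical = [], []
    for name, values in columns.items():
        if name in exclude:
            continue
        if all(isinstance(v, str) for v in values):
            categorical.append(name)
        elif all(isinstance(v, (int, float)) and not isinstance(v, bool) for v in values):
            numerical.append(name)
        else:
            raise TypeError(f"column {name!r} has mixed or unsupported value types")
    return categorical, numerical


# A few columns from the serving request used later in this post.
sample = {
    "Bank": ["Other"],
    "State": ["TN"],
    "FranchiseCode": ["0"],  # numeric code stored as a string -> categorical
    "Term": [240.0],
    "NoEmp": [28.0],
    "GrAppv": [14900000.0],
}
CATEGORICAL_COLUMNS, NUMERICAL_COLUMNS = split_columns(sample)
print(CATEGORICAL_COLUMNS)  # ['Bank', 'State', 'FranchiseCode']
print(NUMERICAL_COLUMNS)    # ['Term', 'NoEmp', 'GrAppv']
```

Codes such as FranchiseCode are deliberately kept as strings. If they were stored as numbers, a type-based split like this would silently treat them as numerical.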

Feature Usage

To keep the project well structured and avoid unexpected behaviour, it is considered good practice (although not mandatory) to specify a FeatureUsage for each feature. Fortunately, this is easy. You simply assign each feature one of the six supported types: BOOLEAN, CATEGORICAL, CATEGORICAL_SET, DISCRETIZED_NUMERICAL, HASH, and NUMERICAL. Some of these types come with additional parameters, so make sure to read more about them here.

While we'll keep things simple in this example and stick to numerical and categorical features only, don't hesitate to experiment with the other options. DISCRETIZED_NUMERICAL in particular can significantly speed up training (similar to LightGBM). As you can see below, the chosen type is passed to the semantic parameter. For categorical features, we also set the min_vocab_frequency parameter to get rid of rare values.

Code: https://gist.github.com/aruberts/29598baff60bfaea56108cfbef5c27f0
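To build some intuition for what min_vocab_frequency and DISCRETIZED_NUMERICAL do, here's a pure-Python sketch of both ideas:

- Values seen fewer than min_vocab_frequency times are collapsed into a single out-of-vocabulary bucket.
- A discretised numerical feature is mapped to a small number of bins once, up front, so the split search only has to consider bin boundaries.

This is my simplified illustration, not TFDF's implementation. The <OOV> token and the quantile rule for picking bin boundaries are assumptions.

```python
from bisect import bisect_right
from collections import Counter

OOV = "<OOV>"  # hypothetical out-of-vocabulary token


def build_vocab(values, min_vocab_frequency):
    """Keep only the categorical values seen at least `min_vocab_frequency` times."""
    counts = Counter(values)
    return {value for value, count in counts.items() if count >= min_vocab_frequency}


def encode_categorical(values, vocab):
    """Replace rare (out-of-vocabulary) values with a single shared bucket."""
    return [value if value in vocab else OOV for value in values]


def quantile_boundaries(values, num_bins):
    """Pick up to num_bins - 1 cut points at evenly spaced positions of the sorted data."""
    if not values:
        raise ValueError("values must be non-empty")
    ordered = sorted(values)
    n = len(ordered)
    return sorted({ordered[(i * n) // num_bins] for i in range(1, num_bins)})


def discretize(values, boundaries):
    """Map each numerical value to the index of its bin (0 .. len(boundaries))."""
    return [bisect_right(boundaries, value) for value in values]


states = ["TN", "TN", "CA", "TN", "CA", "NY"]
print(encode_categorical(states, build_vocab(states, min_vocab_frequency=2)))
# ['TN', 'TN', 'CA', 'TN', 'CA', '<OOV>']

terms = [12, 36, 60, 84, 120, 180, 240, 300]
print(discretize(terms, quantile_boundaries(terms, num_bins=4)))
# [0, 0, 1, 1, 2, 2, 3, 3]
```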

Reading Data Using TF Dataset

The simplest way to read in the dataset is with TF Dataset. TFDF has a very handy utility function called pd_dataframe_to_tf_dataset that makes this step a piece of cake.

Code: https://gist.github.com/aruberts/cfb95f75f4ebfcadba4aad3e8a61314e

In the code above we pass our DataFrame objects into the function and provide the following parameters:

  • Name of the label column
  • Name of the weight column (None in this case)
  • Batch size (helps to speed up reading of the data)

The resulting datasets are already in the TF Dataset format TFDF expects (batched and prefetched) and are ready to be used for training and evaluation. You can of course write your own method for reading in the datasets, but you must pay special attention to the output format.
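If you do write your own reader, it helps to know what a batch looks like. As a rough mental model (a sketch, not TFDF's code), each element is a tuple made of:

- a dict mapping feature names to a batch of values,
- a batch of labels,
- a third element holding the weights, when a weight column is given.

The hypothetical helper below builds that structure from plain Python lists. In a real TF Dataset the lists would be tensors. The label name "default" is a placeholder.

```python
def to_batches(columns, label, batch_size, weight=None):
    """Yield (features, labels) or (features, labels, weights) tuples.

    `columns` maps column names to equal-length lists. Every column other than
    the label and weight columns is treated as a feature.
    """
    n_rows = len(columns[label])
    feature_names = [name for name in columns if name not in (label, weight)]
    for start in range(0, n_rows, batch_size):
        stop = start + batch_size
        features = {name: columns[name][start:stop] for name in feature_names}
        labels = columns[label][start:stop]
        if weight is None:
            yield features, labels
        else:
            yield features, labels, columns[weight][start:stop]


columns = {
    "Term": [240.0, 60.0, 120.0],
    "Bank": ["Other", "A", "B"],
    "default": [0, 1, 0],  # hypothetical label column
}
for batch in to_batches(columns, label="default", batch_size=2):
    print(batch)
# ({'Term': [240.0, 60.0], 'Bank': ['Other', 'A']}, [0, 1])
# ({'Term': [120.0], 'Bank': ['B']}, [0])
```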

TFDF Default Parameters

Training the models is quite straightforward if you've followed all the previous data preparation steps.

Code: https://gist.github.com/aruberts/5499d158e6e3db22aaf593a807ece915

As you can see from the code above, it takes just a few lines to build and train GBT and RF models with default parameters. All you need to specify are the features used and the training and validation datasets, and you're good to go. Evaluating both models with ROC AUC and PR AUC shows that the performance is already quite good.

# GBT with Default Parameters
PR AUC: 0.8367
ROC AUC: 0.9583

# RF with Default Parameters
PR AUC: 0.8102
ROC AUC: 0.9453
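Both metrics are computed from the predicted default probabilities on the validation set. In practice you'd use sklearn's roc_auc_score and average_precision_score. The dependency-free sketch below only shows what the two numbers measure.

A few caveats about the sketch:

- It uses average precision as the PR AUC summary, which is one common choice. The notebook may compute PR AUC slightly differently.
- The pairwise ROC loop is quadratic, so it's only meant for small examples.

```python
def roc_auc(labels, scores):
    """Probability that a random positive is scored above a random negative (ties count 0.5)."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative example")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def average_precision(labels, scores):
    """Mean of precision@k over the ranks k at which a positive appears (no tie handling)."""
    n_pos = sum(labels)
    if n_pos == 0:
        raise ValueError("need at least one positive example")
    ranked = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    hits, total = 0, 0.0
    for k, (_, y) in enumerate(ranked, start=1):
        if y == 1:
            hits += 1
            total += hits / k
    return total / n_pos


labels = [1, 1, 0, 1, 0]
scores = [0.9, 0.8, 0.7, 0.6, 0.2]
print(round(roc_auc(labels, scores), 4))            # 0.8333
print(round(average_precision(labels, scores), 4))  # 0.9167
```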

Let's see if these results can be improved further with hyper-parameter tuning. For simplicity, I'm going to focus solely on optimising the GBT model, but everything applies equally to the RF model.

Hyper-parameter Tuning

There are a ton of parameters to tune, and a very good explanation of each one can be found in the official Yggdrasil documentation. TFDF gives you a few built-in tuning options, but you can also use more standard libraries like Optuna or Hyperopt. Here are the approaches, ordered from least to most involved:

  1. Hyper-parameter templates
  2. Hyper-parameter search using pre-defined space
  3. Hyper-parameter search using custom space

Hyper-parameter Templates

A very cool feature of TFDF is its hyper-parameter templates. These are parameter sets that the paper showed to perform best across a wide range of datasets. The two available templates are better_default and benchmark_rank1. If you're short on time or aren't that familiar with Machine Learning, this might be a good option for you. Specifying a template is literally one line of code.

Code: https://gist.github.com/aruberts/683c2a18e32c0204b83f9f061e2d8fea

Looking at the results, we can see that the better_default template gives a slight uplift in both ROC and PR AUCs. The benchmark_rank1 template, on the other hand, performs much worse. This is why it's important to evaluate the resulting models properly before deploying them.

GBT with 'Better Default' Parameters
PR AUC: 0.8483
ROC AUC: 0.9593

GBT with 'Benchmark Rank 1' Parameters
PR AUC: 0.7869
ROC AUC: 0.9442

Pre-defined Search Space

TFDF comes with a nice utility called RandomSearch, which performs a randomised grid search (similar to sklearn) across many of the available parameters. You can specify these parameters manually (see example here), but you can also use a pre-defined search space. Again, if you're not that familiar with ML, this might be a good option for you because it removes the need to set these parameters manually.

WARNING: this search took me ages, so I had to stop it after 12 iterations. Some parameters that get tested (e.g. oblique splits) take a long time to fit.

Code: https://gist.github.com/aruberts/14e79c16210186dbe867452d1807c83d

You can access all of the tried combinations using the following command.

tuning_logs = tuned_model.make_inspector().tuning_logs()
Hyper-parameter table. Screenshot by author.

After 12 iterations, the best model performed a bit worse than the baseline, so use this tuning method cautiously. You can try to alter the search space, but at this point you might as well use another library.

PR AUC: 0.8216
ROC AUC: 0.9418
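Under the hood, a randomised grid search of the kind RandomSearch runs is a simple loop:

1. Each trial samples one value per parameter from a discrete grid.
2. A model is trained with that configuration and scored.
3. The best score is kept.

Below is a stdlib sketch of that loop. The search space is hypothetical, although max_depth, shrinkage and num_candidate_attributes_ratio are real GBT parameter names. The objective is a toy stand-in for "train a GBT and read its validation score". Neither is TFDF's pre-defined space.

```python
import random


def random_search(space, objective, num_trials, seed=0):
    """Sample `num_trials` configurations from a discrete grid and return the best one.

    Higher objective values are better. Sampling is with replacement, so the same
    configuration can be tried twice. Returns (best_params, best_score, log).
    """
    rng = random.Random(seed)
    log = []
    for _ in range(num_trials):
        params = {name: rng.choice(values) for name, values in space.items()}
        log.append((params, objective(params)))
    best_params, best_score = max(log, key=lambda entry: entry[1])
    return best_params, best_score, log


# Hypothetical search space.
space = {
    "max_depth": [3, 4, 6, 8],
    "shrinkage": [0.02, 0.05, 0.1],
    "num_candidate_attributes_ratio": [0.5, 0.9, 1.0],
}


def toy_objective(params):
    # Stand-in for "train a model with `params` and return its validation score".
    return -abs(params["max_depth"] - 6) - abs(params["shrinkage"] - 0.05)


best_params, best_score, log = random_search(space, toy_objective, num_trials=12)
print(best_params, best_score)
```

The same loop structure is why a single slow configuration (such as oblique splits) can dominate the total runtime.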

Custom Search Space (with custom loss)

There are a few notable disadvantages to the RandomSearch approach:

  • Only the randomised grid search algorithm is available
  • There's no option to define your own loss to optimise
  • The full parameter grid needs to be provided if you don't use the use_predefined_hps flag

For these reasons, I highly recommend using an external optimisation library if you know enough to set a sensible search space yourself. Below you can see how to do the tuning with Optuna.

Code: https://gist.github.com/aruberts/986f6a537faecdc9cdc4888dbe6ed257

Most of these parameters are quite standard for GBTs, but a few are noteworthy. First, we can set growing_strategy to BEST_FIRST_GLOBAL (a.k.a. leaf-wise growth), which is the strategy used by LightGBM. Second, we can use BINARY_FOCAL_LOSS, which is supposed to perform better on imbalanced datasets (source). Third, the split_axis parameter can be set to use sparse oblique splits, which this paper showed to be quite effective. Finally, we can build "honest trees" using the honest parameter.
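For intuition on BINARY_FOCAL_LOSS, here's the standard per-example binary focal loss in plain Python. It scales the log loss by (1 - p_t)^gamma. Examples the model already classifies confidently therefore contribute little, and training focuses on the hard (often minority-class) ones.

This is the textbook formula, not code taken from Yggdrasil. I believe TFDF exposes gamma and alpha as focal_loss_gamma and focal_loss_alpha, but double-check the API reference for your version.

```python
import math


def binary_focal_loss(y, p, gamma=2.0, alpha=0.5, eps=1e-12):
    """Per-example binary focal loss.

    y: true label (0 or 1); p: predicted probability of class 1.
    gamma = 0 recovers the alpha-weighted log loss.
    """
    p = min(max(p, eps), 1.0 - eps)         # avoid log(0)
    p_t = p if y == 1 else 1.0 - p          # probability assigned to the true class
    alpha_t = alpha if y == 1 else 1.0 - alpha
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


# How much the focal term shrinks the loss compared with plain (gamma=0) log loss:
easy = binary_focal_loss(1, 0.9) / binary_focal_loss(1, 0.9, gamma=0.0)
hard = binary_focal_loss(1, 0.1) / binary_focal_loss(1, 0.1, gamma=0.0)
print(round(easy, 4))  # 0.01  (confident, correct prediction is heavily down-weighted)
print(round(hard, 4))  # 0.81  (badly wrong prediction keeps most of its loss)
```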

Below you can see the results achieved with the best parameters. As you can see, tuning with a custom search space has yielded the best results so far.

GBT with Custom Tuned Parameters
PR AUC: 0.8666
ROC AUC: 0.9631

Now that we've settled on the hyper-parameters, let's re-train the model and proceed with its inspection.

Model Inspection

TFDF offers a nice utility object for inspecting the trained model, called Inspector. There are three main uses for this object that I'll explore below:

  1. Inspecting the model's attributes like type, number of trees or features used
  2. Obtaining feature importances
  3. Extracting tree structures

Inspect Model Attributes

The inspector stores various attributes that you might want to explore if, for example, you've loaded somebody else's model or haven't used yours in a while. You can print out the model type (GBT or RF), the number of trees, the training objective, and the features that were used to train the model. The number of trees is especially worth checking: if early stopping kicked in, it will be smaller than the number you configured.

Code: https://gist.github.com/aruberts/b45a30605c45d0e957702c3687800d45

Another option is to simply run manual_tuned.summary() to examine the model in more detail.
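Why can the number of trees end up smaller than configured? With early stopping, boosting watches the validation loss. Once the loss hasn't improved for a number of consecutive trees (a look-ahead window), training stops and the model is kept at the best iteration.

Here's a simplified sketch of that rule with made-up losses. As I understand it, the window corresponds to TFDF's early_stopping_num_trees_look_ahead parameter, but TFDF's exact stopping logic may differ in the details.

```python
def early_stopped_num_trees(val_losses, look_ahead):
    """Return how many trees are kept under a 'stop after `look_ahead` trees without
    improvement' rule. val_losses[i] is the validation loss after i + 1 trees."""
    best_loss = float("inf")
    best_num_trees = 0
    for num_trees, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss, best_num_trees = loss, num_trees
        elif num_trees - best_num_trees >= look_ahead:
            break  # no improvement within the look-ahead window: stop training
    return best_num_trees


val_losses = [0.50, 0.40, 0.35, 0.33, 0.34, 0.36, 0.37]  # made-up numbers
print(early_stopped_num_trees(val_losses, look_ahead=3))  # 4
```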

Feature Importances

Like most tree-based libraries, TFDF comes with built-in feature importance scores. For GBTs, you get access to the NUM_NODES, SUM_SCORE, INV_MEAN_MIN_DEPTH, and NUM_AS_ROOT importance types. You can also set the compute_permutation_variable_importance parameter to True during training, which adds a few more importance types. The downside is that training will take significantly longer, so use it with care (perhaps on a sample of the data).

Code: https://gist.github.com/aruberts/8dc837809cae00ee223b0d2b4f3c4e23
Importances bar plot. Screenshot by author.

For the model I've built, the Term variable consistently came up as the most important feature, followed by categorical variables like Bank, State, and BankState. I'd say one of the biggest disadvantages of the TFDF library is that it can't be used with SHAP. Hopefully, support will come in future versions.
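Permutation importance is a model-agnostic idea:

1. Shuffle one feature's values across rows.
2. Re-score the model.
3. Measure how much the metric drops.

The sketch below applies it to a hypothetical rule-based "model" using accuracy, just to show the mechanics. TFDF computes its permutation importances internally with its own metrics, so treat this as an illustration of the technique, not of TFDF's code. It also shows why the option is expensive: every feature needs n_repeats extra passes over the data.

```python
import random


def accuracy(model, rows, labels):
    """Fraction of rows where the model's prediction equals the label."""
    return sum(model(row) == y for row, y in zip(rows, labels)) / len(labels)


def permutation_importance(model, rows, labels, feature, n_repeats=20, seed=0):
    """Average drop in accuracy after shuffling one feature's values across rows."""
    rng = random.Random(seed)
    baseline = accuracy(model, rows, labels)
    drops = []
    for _ in range(n_repeats):
        column = [row[feature] for row in rows]
        rng.shuffle(column)
        shuffled = [{**row, feature: value} for row, value in zip(rows, column)]
        drops.append(baseline - accuracy(model, shuffled, labels))
    return sum(drops) / n_repeats


# Hypothetical toy data and rule-based "model" (not the trained GBT).
rows = [
    {"Term": t, "NoEmp": e}
    for t, e in zip([12, 24, 36, 60, 120, 180, 240, 300], [5, 40, 3, 18, 7, 25, 2, 11])
]
labels = [1, 1, 1, 1, 0, 0, 0, 0]


def toy_model(row):
    return int(row["Term"] < 100)


print(permutation_importance(toy_model, rows, labels, "NoEmp"))     # 0.0: NoEmp is ignored
print(permutation_importance(toy_model, rows, labels, "Term") > 0)  # True
```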

Inspect Individual Trees

Sometimes we want to look at individual trees for the sake of explainability or model validation. TFDF gives easy access to all the trees constructed during training through the inspector object. For now, let's inspect the first tree of our GBT model, since it is usually the most informative one.

Code: https://gist.github.com/aruberts/a429255ceb7ebd76900172501136c909
Tree structure. Screenshot by author.

As you can see, large trees are not that convenient to inspect with a print statement. That's why TFDF also has a tree plotting utility: tfdf.model_plotter.plot_model.

Code: https://gist.github.com/aruberts/5238565bc2e5541ccdc024bfba3e7c52
First GBT tree (depth=4). Screenshot by author.

Also note that for Random Forest models you can use the dtreeviz package, which gives more visually appealing results (here's how to use it). GBT models are not yet supported by this package.
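When reading a printed or plotted GBT tree, it helps to remember how the trees are used at prediction time:

1. An example is routed from the root of each tree to a single leaf.
2. The leaf values of all trees are summed, together with an initial prediction.
3. For binary classification, the sum is passed through a sigmoid.

Here is a stdlib sketch of that under simplified assumptions: only numerical "feature >= threshold" conditions, and made-up leaf values. The Leaf and Split classes are hypothetical stand-ins, not TFDF's tree classes.

```python
import math
from dataclasses import dataclass
from typing import Union


@dataclass
class Leaf:
    value: float  # this tree's additive contribution, in log-odds


@dataclass
class Split:
    feature: str
    threshold: float
    pos: "Node"  # branch taken when example[feature] >= threshold
    neg: "Node"  # branch taken otherwise


Node = Union[Leaf, Split]


def tree_contribution(node, example):
    """Route one example from the root to a leaf and return the leaf value."""
    while isinstance(node, Split):
        node = node.pos if example[node.feature] >= node.threshold else node.neg
    return node.value


def gbt_probability(trees, example, initial_log_odds=0.0):
    """Sum the contributions of all trees and squash with a sigmoid."""
    log_odds = initial_log_odds + sum(tree_contribution(t, example) for t in trees)
    return 1.0 / (1.0 + math.exp(-log_odds))


def depth(node):
    """Number of splits on the longest root-to-leaf path."""
    if isinstance(node, Leaf):
        return 0
    return 1 + max(depth(node.pos), depth(node.neg))


tree = Split("Term", 100.0,
             pos=Leaf(-0.4),
             neg=Split("GrAppv", 50000.0, pos=Leaf(0.1), neg=Leaf(0.3)))
example = {"Term": 60.0, "GrAppv": 20000.0}
print(depth(tree))                                 # 2
print(tree_contribution(tree, example))            # 0.3
print(round(gbt_probability([tree], example), 4))  # 0.5744
```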

TF Serving

So far, we've trained, tuned, and evaluated the model. What else is there? Serving the model, of course! Luckily, TFDF is natively supported by TF Serving (as of the latest version), so this part is also quite easy. If you already have an up-to-date TF Serving instance, all you need to do is point the model_base_path parameter at your model's directory (the folder that contains the numbered version folders). You can save the TFDF model with the save method. Note that you should save it into a sub-folder named 1, since it's the first version of your model.

manual_tuned.save("../models/loan_default_model/1/")
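Why the folder named 1? TF Serving treats every numbered sub-folder of model_base_path as a model version and, by default, serves the highest one. A retrained model saved to ../models/loan_default_model/2/ would therefore be picked up as the new version.

The stdlib sketch below mimics that lookup in a simplified way (it's not TF Serving's code). It also shows the usual pitfall: pointing model_base_path at the version folder itself leaves no numbered sub-folders to find.

```python
import os
import tempfile


def latest_servable_version(model_base_path):
    """Return the highest numbered version folder under model_base_path, or None.

    A simplified imitation of TF Serving's default "serve the latest version" policy.
    """
    versions = [
        int(name)
        for name in os.listdir(model_base_path)
        if name.isdecimal() and os.path.isdir(os.path.join(model_base_path, name))
    ]
    return max(versions) if versions else None


with tempfile.TemporaryDirectory() as base:
    for name in ["1", "2", "assets"]:
        os.makedirs(os.path.join(base, name))
    print(latest_servable_version(base))                     # 2
    print(latest_servable_version(os.path.join(base, "1")))  # None: nothing numbered inside "1"
```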

If you haven't used TF Serving before, you can find a good tutorial by the TF team here. I've also written this Colab notebook in case you're working on an M1 or M2 Mac (TF Serving currently doesn't support them).

Essentially, all you need to do is install TF Serving locally and launch it with the right parameters. Once you have the binary downloaded, here's the command to launch the server:

./tensorflow_model_server \
    --rest_api_port=8501 \
    --model_name=loan_default_model \
    --model_base_path=/path/models/loan_default_model

Please note that model_base_path must be an absolute path, not a relative one. Once the TF Serving server is running, you can start sending requests to it. There are two expected formats: instances (row format) and inputs (columnar format). Below you can see an example of the latter, but you can find examples of both in this tutorial.

import requests

# Input data formatted correctly (columnar "inputs" format: one list per feature)
data = {
    "Bank": ["Other"],
    "BankState": ["TN"],
    "City": ["Other"],
    "CreateJob": [12.0],
    "FranchiseCode": ["0"],
    "GrAppv": [14900000.0],
    "NoEmp": [28.0],
    "RetainedJob": [16.0],
    "RevLineCr": ["N"],
    "SBA_Appv": [14900000.0],
    "State": ["TN"],
    "Term": [240.0],
    "UrbanRural": ["0"],
    "is_new": [0.0],
    "latitude": [35.3468],
    "longitude": [-86.22],
    "naics_first_two": ["44"],
    "same_state": [1.0],
    "ApprovalFY": [1]
}
payload = {"inputs": data}

# Send the request; the model name in the URL must match --model_name
url = "http://localhost:8501/v1/models/loan_default_model:predict"
response = requests.post(url, json=payload)
response.raise_for_status()

# Print out the response
print(response.json()["outputs"])
# Expected output: [[0.0138759678]]
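The inputs example uses the columnar format. The same data can also be sent in the row-oriented instances format, where each instance is a dict of feature values. Below is a stdlib-only sketch (urllib instead of requests) that converts between the two layouts and builds the POST request without sending it. The helper names are my own, and the model name assumes the --model_name used when launching the server. A real request would include every feature the model was trained on; the two-feature payload is only for illustration.

```python
import json
import urllib.request


def columns_to_instances(columns):
    """Convert the columnar `inputs` layout ({name: [v1, v2, ...]}) into a list of rows."""
    if not columns:
        return []
    lengths = {len(values) for values in columns.values()}
    if len(lengths) != 1:
        raise ValueError("all columns must have the same number of rows")
    n_rows = lengths.pop()
    return [{name: values[i] for name, values in columns.items()} for i in range(n_rows)]


def build_predict_request(payload, model_name, host="localhost", port=8501):
    """Build (but don't send) a POST request for TF Serving's REST predict endpoint."""
    url = f"http://{host}:{port}/v1/models/{model_name}:predict"
    return urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


rows = columns_to_instances({"Bank": ["Other"], "Term": [240.0]})
print(rows)  # [{'Bank': 'Other', 'Term': 240.0}]
request = build_predict_request({"instances": rows}, "loan_default_model")
print(request.full_url)  # http://localhost:8501/v1/models/loan_default_model:predict
```

To actually send it, pass the request to urllib.request.urlopen and parse the body with json.load. Responses to instances requests carry the results under a predictions key, while inputs requests return outputs.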

If you've managed to get a response (the value may differ from mine), congratulations! You've made it to the last step of this post. Now, let's look back at what you've accomplished if you've followed all these sections.

Conclusion

To summarise, TFDF is a powerful and scalable library for training tree-based models in TensorFlow. TFDF models are well integrated with the rest of the TensorFlow ecosystem, so if you're using TFX, have other TF models in production, or are using TF Serving, you'll find this library quite useful.

If you've followed along with the notebooks, you should now know how to train, tune, inspect, and serve TFDF models. As you saw, TFDF models are highly customisable, so if you need a high-performance library for tree-based models, give it a shot and let me know how it goes!

Not a Medium Member yet?

Join Medium with my referral link – Antons Tocilins-Ruberts

Tags: Decision Tree Hands On Tutorials Machine Learning TensorFlow Xgboost
