TFT: an Interpretable Transformer

Author:Murphy  |  View: 25005  |  Time: 2025-03-22 23:27:51

Introduction

Every company in the world needs forecasting to plan its operations, regardless of the sector in which it operates. Companies face many forecasting use cases, such as sales for yearly planning, customer service contacts for the monthly planning of agents for each language, and SKU sales to plan production and/or procurement.

Although the use cases differ, their stakeholders all share one need: interpretability! If you have deployed a forecast model for a stakeholder in the past, you have probably come across the question: ‘Why is the model making such a prediction?’

In this article I explore TFT, an interpretable Transformer for time series forecasting. I also provide a step-by-step implementation of TFT to forecast weekly sales in a Walmart dataset using Darts (a forecasting library for Python). Finally, I show how to interpret the model and evaluate its performance on a 16-week forecast horizon for the Walmart dataset.

Figure 1: Interpretable AI (image generated by the author with DALL-E)

As always, the code is available on GitHub.

TFT: Temporal Fusion Transformers

What is it?

In time series forecasting, a series is usually influenced not only by its historical values but also by other inputs. These can be a complex mix of static covariates (i.e. time-invariant features such as the brand of a product), dynamic covariates whose future values are known, such as a planned product discount, and dynamic covariates whose future values are unknown, such as the number of visitors in the coming weeks.

Several Deep Learning models have been proposed to handle multiple inputs in time series forecasting, but they are typically ‘black-box’ models that do not let us understand how each component affects the forecast.

Temporal Fusion Transformers (TFT) [1] is an attention-based architecture that combines multi-horizon forecasting with interpretable insights. To learn temporal relationships at different scales, it uses recurrent layers for local processing and interpretable self-attention layers for long-term dependencies. It also has variable selection networks to perform feature selection and gating layers to suppress unnecessary components, and it is trained with a quantile loss to produce forecast intervals. Figure 2 shows the TFT architecture, which is explained in more detail in the next section.

Figure 2: TFT architecture (source)

How does it work?

TFT has 5 major components:

1. Gating Mechanisms, powered by Gated Residual Networks (GRNs), give the model the flexibility to apply non-linear processing only when needed. This matters because the non-linear relationship between the dynamic covariates and the target is hard to know in advance. On small and noisy datasets, simpler models are often more beneficial, so the non-linear processing can be skipped.

The GRN works the following way:

  • It starts with a Primary Input a and a Context Vector c (the output of the static covariate encoder, explained later).
  • These inputs go through a Dense Layer with an ELU [2] activation function, which acts as an identity function when $W_{2,\omega} a + W_{3,\omega} c + b_{2,\omega} \gg 0$ and as a constant when $W_{2,\omega} a + W_{3,\omega} c + b_{2,\omega} \ll 0$.
  • Its output then goes through a second Dense Layer, which produces the input for the Gated Linear Unit (GLU) [3]. The GLU has a sigmoid activation that controls how much the GRN contributes on top of the original input a.
  • Because the sigmoid can output values close to 0, the GLU can effectively switch off the non-linear contribution, in which case the Layer Normalization receives only the original a from the skip connection.
Figure 3: The inside of Gated Residual Network and how it works (image made by the author)
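To make the flow concrete, here is a minimal NumPy sketch of a GRN following the equations in the TFT paper [1]: $\eta_2 = \mathrm{ELU}(W_{2,\omega} a + W_{3,\omega} c + b_{2,\omega})$, $\eta_1 = W_{1,\omega} \eta_2 + b_{1,\omega}$, $\mathrm{GLU}_\omega(\gamma) = \sigma(W_{4,\omega} \gamma + b_{4,\omega}) \odot (W_{5,\omega} \gamma + b_{5,\omega})$ and $\mathrm{GRN}_\omega(a, c) = \mathrm{LayerNorm}(a + \mathrm{GLU}_\omega(\eta_1))$. The sizes, the random weights, the row-vector convention (`x @ W`) and the LayerNorm without learnable scale and shift are simplifying assumptions. The usage at the end shows the skipping behaviour: with a very negative gate bias the output collapses to LayerNorm(a).

```python
import numpy as np

def elu(x):
    # identity for positive inputs, saturates towards -1 for large negative inputs
    return np.where(x > 0, x, np.expm1(np.minimum(x, 0)))

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def layer_norm(x, eps=1e-5):
    # normalise over the feature axis (no learnable scale/shift in this sketch)
    return (x - x.mean(axis=-1, keepdims=True)) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

def glu(gamma, p):
    # sigmoid gate in (0, 1) multiplied element-wise by a linear projection
    return sigmoid(gamma @ p["W4"] + p["b4"]) * (gamma @ p["W5"] + p["b5"])

def grn(a, c, p):
    """a: (n, d) primary input, c: (n, d) context vector, p: dict of weights."""
    eta2 = elu(a @ p["W2"] + c @ p["W3"] + p["b2"])  # first dense layer + ELU
    eta1 = eta2 @ p["W1"] + p["b1"]                   # second dense layer
    return layer_norm(a + glu(eta1, p))               # gated skip connection + LayerNorm

rng = np.random.default_rng(0)
d = 8
p = {k: rng.normal(scale=0.1, size=(d, d)) for k in ["W1", "W2", "W3", "W4", "W5"]}
p.update({k: np.zeros(d) for k in ["b1", "b2", "b4", "b5"]})
a, c = rng.normal(size=(3, d)), rng.normal(size=(3, d))
print(grn(a, c, p).shape)  # (3, 8)

# A very negative gate bias closes the GLU: the GRN reduces to LayerNorm(a)
p_closed = dict(p, b4=np.full(d, -50.0))
print(np.allclose(grn(a, c, p_closed), layer_norm(a), atol=1e-6))  # True
```

In a trained model the gate is learned, so the network itself decides per input how much non-linear processing to keep.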

2. Variable Selection Networks (VSNs) help the model weigh the relevance and contribution of each static and dynamic covariate. Besides indicating which features matter most for the forecast, they also perform feature selection, removing unnecessary noisy inputs that could hurt performance. Each type of input (static, past and future) has its own VSN, represented by different colours in Figure 2.

The input features are transformed before entering the VSN: categorical inputs are encoded into a d-dimensional embedding vector, while numerical features are linearly transformed into a d-dimensional vector.
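A minimal NumPy sketch of this input transformation: a categorical input selects a row of an embedding table, and a numerical input is multiplied by a weight vector and shifted by a bias, so both end up as d-dimensional vectors. The example features, the size d = 4 and the random (rather than learned) weights are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4  # embedding size shared by all inputs

# Categorical input (e.g. a store type with 3 levels): one row per category
type_embedding = rng.normal(size=(3, d))

def embed_categorical(category_index):
    return type_embedding[category_index]

# Numerical input (e.g. temperature): linear map from a scalar to d dimensions
w_temp, b_temp = rng.normal(size=d), rng.normal(size=d)

def embed_numerical(x):
    return x * w_temp + b_temp

# Both inputs become d-dimensional vectors that can be stacked for the VSN
E = np.stack([embed_categorical(1), embed_numerical(21.5)])
print(E.shape)  # (2, 4)
```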

After that, VSN unfolds into two branches:

1st Branch:

  • The transformed features $E^{(j)}_t$ at time $t$ are concatenated into a single vector of all past inputs at time $t$, $[E^{(1)}_t, E^{(2)}_t, \dots, E^{(m)}_t]$, where $j$ indexes the features and $m$ is the number of features.
  • This vector, together with the context vector s from the static covariate encoder, goes through a GRN, where a non-linear transformation is applied as explained previously.
  • The output of the GRN is passed to a Softmax layer that produces a vector of weights, one per feature. This is where feature selection happens: the Softmax weights lie between 0 and 1 and sum to 1, so irrelevant features can receive weights close to 0.
Figure 4: Process to get the feature importance of each feature (image made by the author)

2nd Branch:

  • Each transformed feature $E^{(j)}_t$ feeds its own independent GRN, which produces the non-linearly processed output $\tilde{E}^{(j)}_t$.
Figure 5: Transformation of E(j) at time t with GRN (image made by the author)

Combination Step:

  • Using the feature weights and the non-linear transformations $\tilde{E}^{(j)}_t$, each transformed feature is multiplied by its weight and the results are summed element-wise, producing a processed feature vector in which each feature is weighted by its relevance.
Figure 6: Combination of Feature weight and Transformed features (image made by the author)
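The two branches and the combination step can be sketched for a single time step in NumPy. To keep it short, the GRNs are replaced by hypothetical stand-ins passed in as functions (a linear map for the first branch, `np.tanh` for the second) and the context vector s is omitted; only the Softmax weighting and the weighted sum follow the description above.

```python
import numpy as np

def softmax(x):
    z = np.exp(x - x.max())  # subtract the max for numerical stability
    return z / z.sum()

def variable_selection(E, weight_logits_fn, feature_fns):
    """E: (m, d) transformed features at one time step.
    weight_logits_fn: stand-in for the 1st-branch GRN, maps the flattened (m*d,) vector to m logits.
    feature_fns: m stand-ins for the independent per-feature GRNs of the 2nd branch."""
    v = softmax(weight_logits_fn(E.reshape(-1)))                 # 1st branch: feature weights
    E_tilde = np.stack([f(e) for f, e in zip(feature_fns, E)])  # 2nd branch: per-feature processing
    combined = (v[:, None] * E_tilde).sum(axis=0)                # weight each feature and sum
    return combined, v

rng = np.random.default_rng(1)
m, d = 3, 4
E = rng.normal(size=(m, d))
W = rng.normal(size=(m * d, m))
combined, v = variable_selection(E, lambda x: x @ W, [np.tanh] * m)
print(combined.shape)          # (4,)
print(np.isclose(v.sum(), 1))  # True
```

The weights `v` are what TFT reports as feature importance, which is why the VSN doubles as an interpretability tool.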

3. Static Covariate Encoders encode the static covariates into four different vectors using four different GRNs:

  • Context vector s, used for temporal variable selection in the VSNs
  • Context vectors c and h, used as the initial cell state and hidden state of the LSTM Encoder-Decoder for local processing of temporal features
  • Context vector e, used to enrich temporal features with static information in the Enrichment Layer
Figure 7: Static Covariate Encoder into 4 different context vectors for different uses (image made by the author)
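The key point is that the same static embedding is turned into four different context vectors because each encoder has its own weights. In the NumPy sketch below each GRN is replaced by a hypothetical one-layer stand-in (`make_encoder`) to keep it short; sizes and weights are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8

def make_encoder():
    """Hypothetical stand-in for a GRN: one dense layer + tanh with its own weights."""
    W = rng.normal(scale=0.5, size=(d, d))
    b = np.zeros(d)
    return lambda z: np.tanh(z @ W + b)

# Four independent encoders, one per context vector
encoders = {name: make_encoder() for name in ["s", "c", "h", "e"]}

static_embedding = rng.normal(size=d)  # e.g. the static VSN output for one series
contexts = {name: enc(static_embedding) for name, enc in encoders.items()}
print(sorted(contexts))  # ['c', 'e', 'h', 's']
```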

4. Temporal Processing is important in time series because the observations surrounding a point in time are often the most useful for predicting the future. Previous attention-based architectures already exploited this local context, but their approaches are only suitable for observed inputs and cannot handle known future inputs at the same time.

To overcome this problem, the authors propose a sequence-to-sequence model that handles past inputs and known future inputs: past inputs are fed into an LSTM Encoder and known future inputs into an LSTM Decoder. The context vectors c and h are used as the initial cell and hidden states of the Encoder, whose final state is passed on to the Decoder, so that static metadata influences local processing when creating the temporal features.

Figure 8: LSTM Encoder-Decoder Model (image made by the author)
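A minimal NumPy sketch of the encoder-decoder flow: the encoder runs over the past inputs starting from the static context vectors h and c, and the decoder continues from the encoder's final state over the known future inputs. The hand-written LSTM cell, the sizes and the random weights are illustrative assumptions, and the gated skip connection that TFT applies after the LSTMs is omitted.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4  # hidden size = input size in this sketch

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def make_lstm_cell():
    # weights of the input, forget, candidate and output gates stacked together
    W = rng.normal(scale=0.1, size=(2 * d, 4 * d))
    b = np.zeros(4 * d)
    def step(x, h, c):
        i, f, g, o = np.split(np.concatenate([x, h]) @ W + b, 4)
        c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
        h = sigmoid(o) * np.tanh(c)
        return h, c
    return step

encoder, decoder = make_lstm_cell(), make_lstm_cell()

def encode_decode(past, future, h0, c0):
    """past: (T_past, d) observed inputs, future: (T_future, d) known future inputs.
    h0, c0: static context vectors used as the initial hidden and cell states."""
    h, c, outputs = h0, c0, []
    for x in past:    # encoder over the look-back window
        h, c = encoder(x, h, c)
        outputs.append(h)
    for x in future:  # decoder continues from the encoder state over the horizon
        h, c = decoder(x, h, c)
        outputs.append(h)
    return np.stack(outputs)

out = encode_decode(rng.normal(size=(10, d)), rng.normal(size=(3, d)),
                    h0=rng.normal(size=d), c0=rng.normal(size=d))
print(out.shape)  # (13, 4)
```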

The outputs of the Encoder and Decoder are then combined with the context vector e and passed through a GRN in the Enrichment Layer, whose weights are shared across all time steps, to enhance the temporal features with static metadata.

Figure 9: Final Temporal Features created in the Enrichment Layer (image made by the author)
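Sharing the GRN weights across time steps means the same transformation, conditioned on the same context vector e, is applied at every position. The NumPy sketch below illustrates this with a hypothetical, simplified stand-in for the GRN (a dense layer with ELU and no gating); sizes and weights are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4
W_phi = rng.normal(scale=0.1, size=(d, d))
W_e = rng.normal(scale=0.1, size=(d, d))
b = np.zeros(d)

def enrich(phi, e):
    """Simplified stand-in for the enrichment GRN (dense layer + ELU, no gating).
    phi: (T, d) temporal features from the LSTM encoder-decoder, e: (d,) static context."""
    z = phi @ W_phi + e @ W_e + b  # e @ W_e is broadcast to every time step
    return np.where(z > 0, z, np.expm1(np.minimum(z, 0)))  # ELU

phi = rng.normal(size=(5, d))
e = rng.normal(size=d)
enriched = enrich(phi, e)
print(enriched.shape)  # (5, 4)
# Shared weights: a time step enriched on its own matches the same step in the full sequence
print(np.allclose(enrich(phi[2:3], e)[0], enriched[2]))  # True
```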

Finally, the temporal features enriched with static metadata are fed to an interpretable multi-head attention layer, which learns the relevance of each time step t with respect to the preceding steps of the input sequence.

Figure 10: Temporal Interpretability of Temporal Features in relation to t (image made by the author)
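In the TFT paper [1], the attention is made interpretable by sharing the value projection across heads and averaging the heads' attention weights, so a single matrix shows how much each time step attends to each earlier one. Below is a minimal NumPy sketch of that idea with a causal mask; the sizes, the random weights and the name `interpretable_attention` are my own assumptions, not the Darts implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_heads = 8, 2
d_head = d // n_heads

def softmax(x, axis=-1):
    z = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

# Per-head query/key projections, but ONE value projection shared by all heads
W_q = rng.normal(scale=0.1, size=(n_heads, d, d_head))
W_k = rng.normal(scale=0.1, size=(n_heads, d, d_head))
W_v = rng.normal(scale=0.1, size=(d, d_head))
W_h = rng.normal(scale=0.1, size=(d_head, d))  # output projection

def interpretable_attention(X):
    """X: (T, d) enriched temporal features. Returns the outputs and the head-averaged attention."""
    T = X.shape[0]
    future = np.triu(np.ones((T, T), dtype=bool), k=1)  # True where step j comes after step t
    heads = []
    for h in range(n_heads):
        scores = (X @ W_q[h]) @ (X @ W_k[h]).T / np.sqrt(d_head)
        heads.append(softmax(np.where(future, -np.inf, scores)))  # causal mask
    A = np.mean(heads, axis=0)  # one (T, T) attention matrix to interpret
    # with shared values, averaging the weights equals averaging the head outputs
    return (A @ (X @ W_v)) @ W_h, A

out, A = interpretable_attention(rng.normal(size=(6, d)))
print(out.shape, A.shape)  # (6, 8) (6, 6)
```

Row t of `A` is the kind of temporal importance pattern illustrated in Figure 10.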

5. Quantile Prediction is achieved by predicting several quantiles (for example the 10th, 50th and 90th percentiles) at each time step. These forecasts are generated by a linear transformation of the output of the temporal fusion decoder.

Figure 11: Information flow from LSTM Decoder to the Dense layer that performs the quantile regression (image made by the author)
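The quantile (pinball) loss used to train these outputs can be sketched in a few lines of NumPy. It follows the standard definition; the quantiles 0.1, 0.5 and 0.9 and the toy numbers are illustrative assumptions.

```python
import numpy as np

def quantile_loss(y, y_hat, q):
    """Pinball loss for one quantile q in (0, 1), averaged over all points."""
    diff = y - y_hat
    # under-prediction (diff > 0) costs q per unit, over-prediction costs (1 - q) per unit
    return np.mean(np.maximum(q * diff, (q - 1) * diff))

def total_quantile_loss(y, preds, quantiles=(0.1, 0.5, 0.9)):
    """Sum of pinball losses, where preds[i] is the forecast for quantiles[i]."""
    return sum(quantile_loss(y, p, q) for p, q in zip(preds, quantiles))

y = np.array([10.0, 12.0, 9.0])
flat = np.full(3, 11.0)
print(quantile_loss(y, flat, 0.9))  # ≈ 0.4: the single under-prediction dominates
print(total_quantile_loss(y, [np.full(3, 8.0), flat, np.full(3, 13.0)]))
```

Because a high quantile is penalised much more for under-predicting, its forecast is pushed upwards, which is what turns the set of quantiles into a forecast interval.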

How to use and interpret TFT in practice

This section covers a step-by-step implementation of TFT using the same dataset as my previous post about TiDE: a weekly sales dataset from Walmart available on Kaggle (License CC0: Public Domain).

I will use the implementation in Darts to train, predict and interpret TFT.

The dataset has 2 years and 8 months of weekly sales and 16 columns:

  • Store – store number and one of the static covariates
  • Dept – department number and another static covariate
  • Type – type of store and another static covariate
  • Size – size of the store and the last static covariate
  • Date – the weekly temporal index of the time series, also used to extract dynamic covariates such as the week number and the month
  • Weekly_Sales – the target variable
  • IsHoliday – a dynamic covariate that identifies whether a specific week contains a holiday
  • Temperature – a dynamic covariate with the average temperature in a specific week
  • Fuel_Price – a dynamic covariate with the price of fuel in a specific week
  • MarkDown1, MarkDown2, MarkDown3, MarkDown4 and MarkDown5 – dynamic covariates with the average discounts in a specific week
  • CPI – a dynamic covariate with the consumer price index
  • Unemployment – a dynamic covariate with the unemployment rate

We start by importing the libraries and defining global variables: the date column, the target column, the static covariates, the dynamic covariates to fill with 0, the dynamic covariates to fill with linear interpolation, the frequency of our series, the forecast horizon, and the scaler and transformer to use:

```python
import pandas as pd
import numpy as np
from datetime import timedelta
import matplotlib.pyplot as plt

from darts import TimeSeries
from darts.dataprocessing.pipeline import Pipeline
from darts.models import TFTModel
from darts.dataprocessing.transformers import Scaler
from darts.utils.timeseries_generation import datetime_attribute_timeseries
from darts.utils.likelihood_models import QuantileRegression
from darts.dataprocessing.transformers import StaticCovariatesTransformer, MissingValuesFiller

TIME_COL = "Date"
TARGET = "Weekly_Sales"
# time-invariant features of each series
STATIC_COV = ["Store", "Dept", "Type", "Size"]
# dynamic covariates whose missing values are filled with 0
DYNAMIC_COV_FILL_0 = ["IsHoliday", "MarkDown1", "MarkDown2", "MarkDown3", "MarkDown4", "MarkDown5"]
# dynamic covariates whose missing values are filled by linear interpolation
DYNAMIC_COV_FILL_INTERPOLATE = ["Temperature", "Fuel_Price", "CPI", "Unemployment"]
FREQ = "W-FRI"  # weekly frequency, weeks ending on Friday
FORECAST_HORIZON = 16  # weeks
SCALER = Scaler()  # scales the target and covariates (MinMaxScaler by default)
TRANSFORMER = StaticCovariatesTransformer()  # encodes/scales the static covariates
PIPELINE = Pipeline([SCALER, TRANSFORMER])
```

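The default scaler is scikit-learn's MinMaxScaler, but we can use any scaler we want, from scikit-learn or elsewhere, as long as it has fit(), transform() and inverse_transform() methods. The same applies to the StaticCovariatesTransformer, which by default encodes categorical static covariates with scikit-learn's OrdinalEncoder and scales numerical ones with MinMaxScaler.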
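To illustrate the fit()/transform()/inverse_transform() contract, here is a hypothetical custom scaler (`Log1pScaler`, not part of Darts or scikit-learn) that log-transforms sales before min-max scaling them. It assumes non-negative values, which holds once negative sales are clipped to 0. Since it exposes the three required methods, it should in principle be usable as `Scaler(Log1pScaler())`; the sketch below only exercises it standalone.

```python
import numpy as np

class Log1pScaler:
    """Hypothetical custom scaler: log1p, then min-max scaling to [0, 1] per column."""

    def fit(self, X):
        logged = np.log1p(np.asarray(X, dtype=float))
        self.min_ = logged.min(axis=0)
        spread = logged.max(axis=0) - self.min_
        self.range_ = np.where(spread == 0, 1.0, spread)  # avoid division by zero for constant columns
        return self

    def transform(self, X):
        return (np.log1p(np.asarray(X, dtype=float)) - self.min_) / self.range_

    def inverse_transform(self, X):
        return np.expm1(np.asarray(X, dtype=float) * self.range_ + self.min_)

X = np.array([[0.0], [100.0], [10_000.0]])  # toy weekly sales, one column
scaler = Log1pScaler().fit(X)
scaled = scaler.transform(X)
print(scaled.ravel())  # values scaled into [0, 1]
print(np.allclose(scaler.inverse_transform(scaled), X))  # True
```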

Next, we load the dataset and enrich it with the exogenous features described in the dataset overview:

```python
# load data and exogenous features
df = pd.read_csv("data/train.csv")
store_info = pd.read_csv("data/stores.csv")
# IsHoliday is already in train.csv, so drop it here to avoid duplicated columns after the merge
exo_feat = pd.read_csv("data/features.csv").drop(columns="IsHoliday")

# join all data frames
df = pd.merge(df, store_info, on=["Store"], how="left")
df = pd.merge(df, exo_feat, on=["Store", TIME_COL], how="left")
```
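Why drop IsHoliday from features.csv? Both train.csv and features.csv contain it, and pandas suffixes overlapping non-key columns when merging. The toy frames below are hypothetical stand-ins for the three CSV files and only illustrate this behaviour.

```python
import pandas as pd

# Hypothetical toy versions of train.csv, stores.csv and features.csv
train = pd.DataFrame({"Store": [1, 1], "Date": ["2010-02-05", "2010-02-12"],
                      "Weekly_Sales": [100.0, 120.0], "IsHoliday": [False, True]})
stores = pd.DataFrame({"Store": [1], "Type": ["A"], "Size": [150000]})
features = pd.DataFrame({"Store": [1, 1], "Date": ["2010-02-05", "2010-02-12"],
                         "Temperature": [40.0, 38.0], "IsHoliday": [False, True]})

# Without dropping IsHoliday, pandas keeps both copies and suffixes them
naive = train.merge(features, on=["Store", "Date"], how="left")
print([c for c in naive.columns if c.startswith("IsHoliday")])  # ['IsHoliday_x', 'IsHoliday_y']

# Dropping it first keeps a single IsHoliday column
merged = (train.merge(stores, on=["Store"], how="left")
               .merge(features.drop(columns="IsHoliday"), on=["Store", "Date"], how="left"))
print(merged.columns.tolist())
```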

Once the dataset is loaded, we apply some preprocessing to clean up the data:

  • We convert the time column to datetime (with pd.to_datetime)
  • We convert negative sales values to 0 (those negative values might indicate returns, but I did not look into them in detail since it is out of the scope of this article)
  • We fill missing values in the MarkDown columns with 0, since we assume that a missing value means there was no promotion
  • We convert the boolean column that identifies a holiday in a specific week into a binary column
  • We transform the Size static covariate from continuous to categorical: stores below the 25th percentile are ‘small’, stores above the 75th percentile are ‘large’ and stores between the 25th and 75th percentiles are ‘medium’
  • Finally, we forecast only the 7 stores with the highest volume of sales to reduce run time (I ran it once with all stores and it took 11 hours to train the TFT model)
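A minimal pandas sketch of these steps on a hypothetical toy frame (it is not the author's code, which is on GitHub). Two assumptions to flag: the Size percentiles are computed over the rows of the frame rather than over unique stores, and the top stores are ranked by total Weekly_Sales after clipping negative values.

```python
import numpy as np
import pandas as pd

def preprocess(df, n_stores=7):
    """Sketch of the preprocessing steps listed above."""
    df = df.copy()
    df["Date"] = pd.to_datetime(df["Date"])                # time column as datetime
    df["Weekly_Sales"] = df["Weekly_Sales"].clip(lower=0)  # negative sales -> 0
    markdowns = [f"MarkDown{i}" for i in range(1, 6)]
    df[markdowns] = df[markdowns].fillna(0)                # missing markdown = no promotion
    df["IsHoliday"] = df["IsHoliday"].astype(int)          # bool -> 0/1
    q25, q75 = df["Size"].quantile([0.25, 0.75])
    df["Size"] = np.where(df["Size"] < q25, "small",
                          np.where(df["Size"] > q75, "large", "medium"))
    top = df.groupby("Store")["Weekly_Sales"].sum().nlargest(n_stores).index
    return df[df["Store"].isin(top)]

# Hypothetical toy data: 10 stores, 2 weeks each
rng = np.random.default_rng(0)
toy = pd.DataFrame({
    "Store": np.repeat(np.arange(1, 11), 2),
    "Date": ["2010-02-05", "2010-02-12"] * 10,
    "Weekly_Sales": rng.normal(1000, 800, size=20),
    "IsHoliday": [False, True] * 10,
    "Size": np.repeat(np.arange(10, 110, 10), 2),
    **{f"MarkDown{i}": [np.nan, 5.0] * 10 for i in range(1, 6)},
})
out = preprocess(toy)
print(out["Store"].nunique())  # 7
```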

