TabNet
TabNet is a deep learning model for tabular data that applies a sequence of sparse attention and feature transformation steps.
TabNet Architecture
TabNet: Attentive Interpretable Tabular Learning uses a sequence of attentive transformers and feature transformers. The model can be built as an encoder-decoder or as an encoder only. For our purpose of coin price prediction, we use only the encoder.
TabNet works as follows:
- TabNet consists of multiple steps.
- Each step consists of an attentive transformer followed by a feature transformer.
- The attentive transformer identifies the most important features. It uses sparsemax (instead of softmax) to generate a sparse mask that selects those features.
- The feature transformer generates new features from the selected features. It has two parts: a shared block and a per-step block, both essentially FC-BN-GLU layers with skip connections.
- The new features are used to make predictions and also produce the input for the next step (a minimal sketch of one step follows this list).
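To make the step structure concrete, here is a minimal, illustrative PyTorch sketch of one decision step. All names (sparsemax, GLUBlock, TabNetStep), dimensions, and defaults are assumptions made for illustration rather than the project's actual code, and details such as the prior-scale update between steps and the sparsity regularization term are omitted.

```python
import math
import torch
import torch.nn as nn


def sparsemax(z: torch.Tensor) -> torch.Tensor:
    """Sparsemax over the last dim: Euclidean projection onto the probability
    simplex, so many mask entries become exactly zero (unlike softmax)."""
    z_sorted, _ = torch.sort(z, dim=-1, descending=True)
    k = torch.arange(1, z.size(-1) + 1, device=z.device, dtype=z.dtype)
    cumsum = z_sorted.cumsum(dim=-1)
    support = 1 + k * z_sorted > cumsum            # prefix of entries kept non-zero
    k_z = support.sum(dim=-1, keepdim=True)        # support size
    tau = (cumsum.gather(-1, k_z - 1) - 1) / k_z   # threshold subtracted from z
    return torch.clamp(z - tau, min=0.0)


class GLUBlock(nn.Module):
    """FC -> BN -> GLU, the basic building block of the feature transformer."""
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.fc = nn.Linear(in_dim, 2 * out_dim)
        self.bn = nn.BatchNorm1d(2 * out_dim)      # the real model uses ghost BN here

    def forward(self, x):
        return nn.functional.glu(self.bn(self.fc(x)), dim=-1)


class TabNetStep(nn.Module):
    """One decision step: attentive transformer (mask) + feature transformer."""
    def __init__(self, num_features, n_d=8, n_a=8):
        super().__init__()
        # Attentive transformer: FC -> BN -> sparsemax over the raw feature columns.
        self.att_fc = nn.Linear(n_a, num_features)
        self.att_bn = nn.BatchNorm1d(num_features)
        # Feature transformer: a shared block (reused across steps in the full
        # model) followed by a step-specific block.
        self.shared = GLUBlock(num_features, n_d + n_a)
        self.step_specific = GLUBlock(n_d + n_a, n_d + n_a)
        self.n_d = n_d

    def forward(self, features, prev_a, prior_scales):
        # 1) Sparse feature-selection mask from the previous step's attention output.
        mask = sparsemax(prior_scales * self.att_bn(self.att_fc(prev_a)))
        # 2) Transform the selected (masked) features; residual scaled by sqrt(0.5).
        h = self.shared(mask * features)
        h = (self.step_specific(h) + h) * math.sqrt(0.5)
        # 3) One slice feeds this step's prediction, the other feeds the next step.
        decision, next_a = h[:, :self.n_d], h[:, self.n_d:]
        return torch.relu(decision), next_a, mask


# Usage: one step over a batch of 32 rows with 16 features; prior scales of ones
# mean every feature is still available for selection.
x = torch.randn(32, 16)
step = TabNetStep(num_features=16)
decision, a, mask = step(x, prev_a=torch.zeros(32, 8), prior_scales=torch.ones(32, 16))
```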
Notably, TabNet widely uses ghost batch normalization (ghost BN) instead of regular BN. This lets TabNet train with a large batch size for speed while retaining the regularization effect of small-batch statistics.
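As a sketch of the idea (assuming PyTorch; the class name, virtual batch size, and momentum are illustrative choices, not the project's implementation):

```python
import torch
import torch.nn as nn


class GhostBatchNorm1d(nn.Module):
    """Applies BatchNorm1d to small 'virtual' batches carved out of a large batch,
    keeping the regularizing noise of small-batch statistics during training."""
    def __init__(self, num_features, virtual_batch_size=128, momentum=0.01):
        super().__init__()
        self.virtual_batch_size = virtual_batch_size
        self.bn = nn.BatchNorm1d(num_features, momentum=momentum)

    def forward(self, x):
        if self.training:
            chunks = x.split(self.virtual_batch_size, dim=0)
            return torch.cat([self.bn(chunk) for chunk in chunks], dim=0)
        return self.bn(x)  # inference uses the accumulated running statistics


# Usage: a batch of 1024 rows is normalized as 8 virtual batches of 128 rows each.
gbn = GhostBatchNorm1d(num_features=16)
out = gbn(torch.randn(1024, 16))
```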
TabNet for Air Passengers
To demonstrate model performance, we show the model's prediction results on the Air Passengers dataset. The cross-validation process identified the best transformation to make the time series stationary as well as the optimal hyperparameters. The root mean squared error (RMSE) on the next day's closing price was used to select the best model.
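A rough sketch of this kind of selection loop is shown below, assuming scikit-learn's TimeSeriesSplit; the candidate transformations, the make_model factory, and the window length are hypothetical placeholders rather than the project's actual pipeline.

```python
import numpy as np
from sklearn.model_selection import TimeSeriesSplit


def rmse(y_true, y_pred):
    return float(np.sqrt(np.mean((np.asarray(y_true) - np.asarray(y_pred)) ** 2)))


def sliding_windows(y, window):
    """Stack lagged windows as model inputs; the next value is the target."""
    X = np.stack([y[i:i + window] for i in range(len(y) - window)])
    return X, np.asarray(y[window:])


def select_best(series, transformations, hyperparam_grid, make_model, window=30):
    """Pick the (transformation, hyperparameters) pair with the lowest mean
    validation RMSE across time-ordered cross-validation splits."""
    best = (np.inf, None, None)
    for transform in transformations:              # e.g. identity, log, differencing
        X, y = sliding_windows(np.asarray(transform(series), dtype=float), window)
        for params in hyperparam_grid:
            scores = []
            for train_idx, valid_idx in TimeSeriesSplit(n_splits=5).split(X):
                model = make_model(**params)       # assumed: sklearn-style fit/predict
                model.fit(X[train_idx], y[train_idx])
                scores.append(rmse(y[valid_idx], model.predict(X[valid_idx])))
            if np.mean(scores) < best[0]:
                best = (float(np.mean(scores)), transform, params)
    return best
```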
In the chart, we display the model's predictions for the last cross-validation split and for the test data.
- train: Training data of the last split.
- validation: Validation data of the last split.
- prediction (train, validation): Predictions for the train and validation periods. For each row (i.e., each sliding window) of the data, predictions are made n days into the future, with n set to 1, 2, and 7. The predictions are then combined into a single series of dots (a sketch of this assembly follows the list). Since prediction accuracy decreases for larger n, some hiccups appear in the predictions. Predictions made from the tail of the train period spill into the validation period, because those dates lie in the future from the train period's viewpoint. This setup is somewhat unusual, but it works well for testing whether the model's predictions are good enough.
- test(input): Test input data.
- test(actual): Test actual data.
- prediction(test): The model's prediction given the test input. There is only one prediction, made from the last row (the last sliding window) of the test input; it corresponds to 1, 2, and 7 days after the end of 'test(input)'.
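The following sketch shows one way the prediction dots described above could be assembled; it assumes one fitted model per horizon with a scikit-learn-style predict(), and all names and the window length are illustrative rather than taken from the project.

```python
import numpy as np

HORIZONS = (1, 2, 7)  # forecast horizons in days, as described above


def multi_horizon_dots(models, series, window=30):
    """For each sliding window, predict 1, 2 and 7 days ahead and place every
    prediction at the index of the day it refers to, producing one combined
    series of dots. Later windows overwrite earlier dots for the same day."""
    series = np.asarray(series, dtype=float)
    dots = {}
    for start in range(len(series) - window + 1):
        x = series[start:start + window].reshape(1, -1)   # one input row
        for h, model in zip(HORIZONS, models):            # assumed: one model per horizon
            target_index = start + window - 1 + h         # the day the dot refers to
            dots[target_index] = float(np.ravel(model.predict(x))[0])
    # Dots with target_index >= len(series) spill past the end of the input
    # period, which is how train-period predictions show up in the validation span.
    return dots
```

For the test chart only the last window of 'test(input)' is used, so exactly three dots are drawn: 1, 2, and 7 days after its end.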