TabNet: Attentive Interpretable Tabular Learning
Sercan O. Arık, Tomas Pfister
ABSTRACT
We propose a novel high-performance and interpretable canonical deep tabular data learning architecture, TabNet. TabNet uses sequential attention to choose which features to reason from at each decision step, enabling interpretability and better learning as the capacity is used for the most salient features.

Why is deep learning worth exploring for tabular data? One obvious motivation is that, similarly to other domains, one would expect performance improvements from DNN-based architectures, particularly for large datasets [22]. In addition, unlike tree learning, which does not use back-propagation into their inputs to guide efficient
Figure 1: TabNet’s sparse feature selection exemplified for Adult Census Income prediction [14]. Sparse feature selection
enables interpretability and better learning as the capacity is used for the most salient features. TabNet employs multiple
decision blocks that focus on processing a subset of input features for reasoning. Feature selection is based on feedback flowing
from the preceding decision step. Two decision blocks shown as examples process features that are related to professional
occupation and investments, respectively, in order to predict the income level.
[Figure 2 graphic: example Adult Census rows with columns Age, Cap. gain, Education, Occupation, Gender, Relationship, and Income > $50k, where several cells are masked.]
Figure 2: Self-supervised tabular learning. Real-world tabular datasets have interdependent feature columns, e.g., the educa-
tion level can be guessed from the occupation, or the gender can be guessed from the relationship. Unsupervised representa-
tion learning by masked self-supervised learning results in an improved encoder model for the supervised learning task.
two kinds of interpretability: local interpretability that visualizes the importance of input features and how they are combined, and global interpretability which quantifies the contribution of each input feature to the trained model.
(4) Finally, we show that our canonical DNN design achieves significant performance improvements by using unsupervised pre-training to predict masked features (see Fig. 2). Our work is the first demonstration of self-supervised learning for tabular data.

2 RELATED WORK
Feature selection: Feature selection in machine learning broadly refers to judiciously picking a subset of features based on their usefulness for prediction. Commonly-used techniques such as forward selection and LASSO regularization [20] attribute feature importance based on the entire training data set, and are referred to as global methods. Instance-wise feature selection refers to picking features individually for each input, studied in [6] by training an explainer model to maximize the mutual information between the selected features and the response variable, and in [61] by using an actor-critic framework to mimic a baseline while optimizing the feature selection. Unlike these, TabNet employs soft feature selection with controllable sparsity in end-to-end learning – a single model jointly performs feature selection and output mapping, resulting in superior performance with compact representations.
[Figure 3 graphic: the input [x1, x2] passes through two multiplicative masks, M = [1, 0] and M = [0, 1]; the masked features go through FC layers with weights W = [C1, -C1, 0, 0] and W = [0, 0, C2, -C2] and biases b = [-a·C1, a·C1, -1, -1] and b = [-1, -1, -d·C2, d·C2], followed by ReLU, addition and a softmax. The right panel shows the resulting decision manifold partitioned by the conditions x1 < a, x1 > a, x2 < d, x2 > d.]
Figure 3: Illustration of decision tree-like classification using conventional DNN blocks (left) and the corresponding decision manifold (right). Relevant features are selected by using multiplicative sparse masks on inputs. The selected features are linearly transformed, and after a bias addition (to represent boundaries) ReLU performs region selection by zeroing the regions that are on the negative side of the boundary. Aggregation of multiple regions is based on addition. As C1 and C2 get larger, the decision boundary gets sharper due to the softmax.
Tree-based learning: Tree-based models are the most common approaches for tabular data learning. The prominent strength of tree-based models is their efficacy in picking global features with the most statistical information gain [18]. To improve the performance of standard tree-based models by reducing the model variance, one common approach is ensembling. Among ensembling methods, random forests [23] use random subsets of data with randomly selected features to grow many trees. XGBoost [7] and LightGBM [30] are the two recent ensemble decision tree approaches that dominate most of the recent data science competitions. Our experimental results for various datasets show that tree-based models can be outperformed when the representation capacity is improved with deep learning while retaining their feature selecting property.

Integration of DNNs into decision trees: Representing decision trees with canonical DNN building blocks as in [26] yields redundancy in representation and inefficient learning. Soft (neural) decision trees [33, 58] are proposed with differentiable decision functions, instead of non-differentiable axis-aligned splits. However, abandoning trees loses their automatic feature selection ability, which is important for tabular data. In [60], a soft binning function is proposed to simulate decision trees in DNNs, which needs to enumerate all possible decisions and is inefficient. [31] proposes a DNN architecture that explicitly leverages expressive feature combinations; however, learning is based on transferring knowledge from a gradient-boosted decision tree and limited performance improvements are observed. [53] proposes a DNN architecture that adaptively grows from primitive blocks, with representation learning in the edges, routing functions and leaf nodes of a decision tree. TabNet differs from these methods as it embeds the soft feature selection ability with controllable sparsity via sequential attention.

Attentive table-to-text models: Table-to-text models extract textual information from tabular data, for which recent works [3, 35] propose a sequential mechanism for field-level attention. Unlike these, we demonstrate the application of sequential attention for supervised or self-supervised learning instead of mapping tabular data to a different data type.

Self-supervised learning: Unsupervised representation learning is shown to benefit the supervised learning task, especially in the small data regime [47]. Recent work for language [13] and image [55] data has shown significant advances – specifically, careful choice of the unsupervised learning objective (masked input prediction) and an attention-based deep learning architecture are important.

3 TABNET FOR TABULAR LEARNING
Decision trees are successful for learning from real-world tabular datasets. However, even conventional DNN building blocks can be used to implement a decision tree-like output manifold – see Fig. 3 for an example. In such a design, individual feature selection is key to obtain decision boundaries in hyperplane form. This idea can be generalized to a linear combination of features where constituent coefficients determine the proportion of each feature. TabNet is based on such tree-like functionality. We show that it outperforms decision trees while reaping many of their benefits by careful design which: (i) uses sparse instance-wise feature selection learned based on the training dataset; (ii) constructs a sequential multi-step architecture, where each decision step can contribute to a portion of the decision that is based on the selected features; (iii) improves the learning capacity by non-linear processing of the selected features; and (iv) mimics an ensemble via higher dimensions and more steps.

Fig. 4 shows the TabNet architecture for encoding tabular data. Tabular data are comprised of numerical and categorical features. We use the raw numerical features and consider a mapping of categorical features with trainable embeddings². We do not consider any global normalization features, but merely apply batch normalization (BN). We pass the same D-dimensional features f ∈ ℝ^{B×D} to each decision step, where B is the batch size. TabNet's encoding is based on sequential multi-step processing with Nsteps decision steps.

² E.g., the three possible categories A, B and C for a particular feature can be learned to be mapped to scalars 0.4, 0.1, and -0.2.
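Returning to the decision-tree-like construction of Fig. 3, the following is a minimal NumPy sketch (not taken from the paper's released code) showing how two masked FC+ReLU branches, addition and a softmax reproduce an axis-aligned decision manifold. The thresholds a, d and slopes C1, C2 are the hypothetical values used in the figure.

```python
import numpy as np

def tree_like_block(x, a=0.5, d=0.3, C1=10.0, C2=10.0):
    """Decision-tree-like classification with conventional DNN blocks (cf. Fig. 3).

    x: array of shape (n, 2) with features [x1, x2].
    Returns softmax scores over the four regions
    {x1 > a}, {x1 < a}, {x2 > d}, {x2 < d}.
    """
    # Multiplicative sparse masks select one feature per branch.
    m1 = x * np.array([1.0, 0.0])          # keeps x1
    m2 = x * np.array([0.0, 1.0])          # keeps x2

    # FC weights/biases encode the decision boundaries x1 = a and x2 = d.
    W1 = np.array([[C1, -C1, 0.0, 0.0],
                   [0.0, 0.0, 0.0, 0.0]])
    b1 = np.array([-C1 * a, C1 * a, -1.0, -1.0])
    W2 = np.array([[0.0, 0.0, 0.0, 0.0],
                   [0.0, 0.0, C2, -C2]])
    b2 = np.array([-1.0, -1.0, -C2 * d, C2 * d])

    # ReLU zeroes the regions on the negative side of each boundary,
    # and the two branches are aggregated by addition.
    h = np.maximum(m1 @ W1 + b1, 0.0) + np.maximum(m2 @ W2 + b2, 0.0)

    # Softmax over the region scores; larger C1, C2 give sharper boundaries.
    e = np.exp(h - h.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

scores = tree_like_block(np.array([[0.9, 0.1], [0.2, 0.8]]))
print(scores.round(3))  # first row favors {x1 > a}, second favors {x2 > d}
```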
Figure 4: (a) TabNet encoder, composed of a feature transformer, an attentive transformer and feature masking at each decision step; a split block passes the processed representation to the attentive transformer of the subsequent step as well as for constructing the overall output. At each decision step, the feature selection mask can provide interpretable information about the model's functionality, and the masks can be aggregated to obtain global feature importance attribution. (b) TabNet decoder, composed of a feature transformer block at each step. (c) A feature transformer block example – a 4-layer network is shown, where 2 of the blocks are shared across all decision steps and 2 are decision step-dependent. Each layer is composed of a fully-connected (FC) layer, BN and GLU nonlinearity. (d) An attentive transformer block example – a single layer mapping is modulated with prior scale information, which aggregates how much each feature has been used before the current decision step. Normalization of the coefficients is done using sparsemax [37] for sparse selection of the most salient features at each decision step.
The i-th step inputs the processed information from the (i−1)-th step to decide which features to use, and outputs the processed feature representation to be aggregated into the overall decision. The idea of top-down attention in the sequential form is inspired by its applications in processing visual and language data (e.g. in visual question answering [25]) and reinforcement learning [40] while searching for a small subset of relevant information in high-dimensional input. Ablation studies in the Appendix focus on the impact of various design choices, which are explained next. Guidelines on selection of the important hyperparameters are also provided in the Appendix.

Feature selection: We employ a learnable mask M[i] ∈ ℝ^{B×D} for soft selection of the salient features. Through sparse selection of the most salient features, the learning capacity of a decision step is not wasted on irrelevant features, and thus the model becomes more parameter efficient. The masking is in multiplicative form, M[i] · f. We use an attentive transformer (see Fig. 4) to obtain the masks using the processed features from the preceding step, a[i − 1]:

M[i] = sparsemax(P[i − 1] · h_i(a[i − 1])).   (1)

Sparsemax normalization [37] encourages sparsity by mapping the Euclidean projection onto the probabilistic simplex, which is observed to be superior in performance and aligned with the goal of sparse feature selection for most real-world datasets. Note that Eq. 1 ensures ∑_{j=1}^{D} M[i]_{b,j} = 1. h_i is a trainable function, shown in Fig. 4 using a FC layer followed by BN. P[i] is the prior scale term, denoting how much a particular feature has been used previously:

P[i] = ∏_{j=1}^{i} (γ − M[j]),   (2)

where γ is a relaxation parameter – when γ = 1, a feature is enforced to be used only at one decision step, and as γ increases, more flexibility is provided to use a feature at multiple decision steps. P[0] is initialized as all ones, 1^{B×D}, without any prior on the masked features. If some features are unused (as in self-supervised learning), the corresponding P[0] entries are made 0 to help the model's learning. To further control the sparsity of the selected features, we propose sparsity regularization in the form of entropy [19]:

L_sparse = ∑_{i=1}^{Nsteps} ∑_{b=1}^{B} ∑_{j=1}^{D} (−M_{b,j}[i] / (Nsteps · B)) · log(M_{b,j}[i] + ε),

where ε is a small number for numerical stability. We add the sparsity regularization to the overall loss, with a coefficient λsparse.
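A minimal NumPy sketch of the mask computation in Eqs. (1)–(2) and of the entropy-based sparsity term follows. The random FC weights stand in for the trainable function h_i, and BN/ghost BN are omitted; this is an illustrative reimplementation under those assumptions, not the released code.

```python
import numpy as np

def sparsemax(z):
    """Row-wise sparsemax [37]: Euclidean projection onto the probability simplex."""
    z_sorted = np.sort(z, axis=1)[:, ::-1]
    cumsum = np.cumsum(z_sorted, axis=1)
    k = np.arange(1, z.shape[1] + 1)
    support = 1.0 + k * z_sorted > cumsum                 # candidate support set
    k_z = support.sum(axis=1)                             # size of the support per row
    tau = (cumsum[np.arange(z.shape[0]), k_z - 1] - 1.0) / k_z
    return np.maximum(z - tau[:, None], 0.0)

def attentive_step(a_prev, prior, W, gamma=1.5):
    """One attentive transformer: mask M[i] (Eq. 1) and updated prior scales (Eq. 2)."""
    mask = sparsemax(prior * (a_prev @ W))   # M[i] = sparsemax(P[i-1] * h_i(a[i-1]))
    prior_next = prior * (gamma - mask)      # P[i] = P[i-1] * (gamma - M[i])
    return mask, prior_next

def sparsity_loss(masks, eps=1e-15):
    """Entropy regularization L_sparse, averaged over decision steps and batch."""
    n_steps, batch = len(masks), masks[0].shape[0]
    return sum((-m * np.log(m + eps)).sum() for m in masks) / (n_steps * batch)

rng = np.random.default_rng(0)
B, D, Na, n_steps = 4, 6, 8, 3
a = rng.normal(size=(B, Na))                 # processed features from the previous step
prior = np.ones((B, D))                      # P[0] = 1 (no prior on masked features)
masks = []
for _ in range(n_steps):
    W = rng.normal(size=(Na, D))             # stands in for the trainable h_i (FC + BN)
    m, prior = attentive_step(a, prior, W)
    masks.append(m)

print(masks[0].sum(axis=1))                  # each row of M[i] sums to 1 (Eq. 1)
print(round(sparsity_loss(masks), 4))
```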
Sparsity may provide a favorable inductive bias for convergence to higher accuracy for datasets where most features are redundant.

Feature processing: We process the filtered features using a feature transformer (see Fig. 4) and then split the result into the decision step output and the information for the subsequent step, [d[i], a[i]] = f_i(M[i] · f), where d[i] ∈ ℝ^{B×Nd} and a[i] ∈ ℝ^{B×Na}. For parameter-efficient and robust learning with high capacity, a feature transformer should comprise layers that are shared across all decision steps (as the same features are input across different decision steps), as well as decision step-dependent layers. Fig. 4 shows the implementation as a concatenation of two shared layers and two decision step-dependent layers. Each FC layer is followed by BN and gated linear unit (GLU) nonlinearity [12]³, eventually connected to a residual connection with √0.5 normalization. Normalization with √0.5 helps to stabilize learning by ensuring that the variance throughout the network does not change dramatically [15]. For faster training, we aim for large batch sizes. To improve performance with large batch sizes, all BN operations, except the one applied to the input features, are implemented in ghost BN [24] form, with a virtual batch size BV and momentum mB. For the input features, we observe the benefit of low-variance averaging and hence avoid ghost BN. Finally, inspired by decision-tree-like aggregation as in Fig. 3, we construct the overall decision embedding as d_out = ∑_{i=1}^{Nsteps} ReLU(d[i]). We apply a linear mapping W_final · d_out to get the output mapping. For discrete outputs, we additionally employ softmax during training (and argmax during inference).

³ In GLU, first a linear mapping is applied and the dimensionality is doubled, and then the second half of the output is used to determine nonlinear processing on the first half.
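The per-layer pattern described above (FC → GLU → √0.5-scaled residual) and the decision aggregation d_out = ∑_i ReLU(d[i]) can be sketched as follows. Ghost BN and the shared/step-dependent weight split are omitted, and all weights and masks are random placeholders; this is a sketch under those assumptions rather than the released implementation.

```python
import numpy as np

def glu_layer(x, W, b):
    """FC that doubles the width, then GLU: the second half gates the first half."""
    z = x @ W + b
    half = z.shape[1] // 2
    return z[:, :half] * (1.0 / (1.0 + np.exp(-z[:, half:])))   # value * sigmoid(gate)

def feature_transformer(x, layers):
    """Stack of FC+GLU layers with sqrt(0.5)-normalized residual connections."""
    h = glu_layer(x, *layers[0])                 # first layer changes the width
    for W, b in layers[1:]:
        h = (h + glu_layer(h, W, b)) * np.sqrt(0.5)
    return h

rng = np.random.default_rng(0)
B, D, Nd, Na, n_steps = 4, 6, 8, 8, 3
f = rng.normal(size=(B, D))

d_steps = []
for _ in range(n_steps):
    mask = rng.dirichlet(np.ones(D), size=B)     # placeholder for M[i] from Eq. (1)
    layers = [(rng.normal(size=(D, 2 * (Nd + Na))), np.zeros(2 * (Nd + Na)))] + \
             [(rng.normal(size=(Nd + Na, 2 * (Nd + Na))), np.zeros(2 * (Nd + Na)))
              for _ in range(3)]                 # 4 FC+GLU layers in total
    out = feature_transformer(mask * f, layers)  # process the masked features
    d_i, a_i = out[:, :Nd], out[:, Nd:]          # split: decision output / next-step input
    d_steps.append(np.maximum(d_i, 0.0))         # ReLU(d[i]); a_i would feed the next step

d_out = np.sum(d_steps, axis=0)                  # overall decision embedding
logits = d_out @ rng.normal(size=(Nd, 3))        # W_final maps to, e.g., 3 classes
print(logits.shape)                              # (4, 3)
```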
4 TABULAR SELF-SUPERVISED LEARNING
Decoding tabular features: We propose a decoder architecture to reconstruct tabular features from the encoded representations obtained from the TabNet encoder. The decoder is composed of feature transformer blocks, followed by FC layers, at each decision step. The outputs are summed to obtain the reconstructed features.⁴

Self-supervised objective: We propose the task of predicting missing feature columns from the others. Consider a binary mask S ∈ {0, 1}^{B×D}. The TabNet encoder inputs (1 − S) · f̂ and the TabNet decoder outputs the reconstructed features, S · f̂. We initialize P[0] = (1 − S) in the encoder so that the model emphasizes merely the known features, and the decoder's last FC layer is multiplied with S to merely output the unknown features. We consider the following reconstruction loss to optimize in the self-supervised phase:

∑_{b=1}^{B} ∑_{j=1}^{D} | (f̂_{b,j} − f_{b,j}) · S_{b,j} / √( ∑_{b=1}^{B} (f_{b,j} − (1/B) ∑_{b=1}^{B} f_{b,j})² ) |².

Normalization with the population standard deviation of the ground truth data is important, as the features may have very different ranges. We sample the entries S_{b,j} independently from a Bernoulli distribution with parameter p_s at each iteration.

⁴ We have also experimented with sequential decoding of features with an attentive transformer at each decision step, but did not observe significant benefits for the purpose of self-supervised learning and thus chose this simpler architecture.
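The masking and the normalized reconstruction loss above can be sketched as follows; the decoder output is a random placeholder standing in for the TabNet encoder–decoder pair, so only the mask sampling, the input masking, and the loss itself follow the description in the text.

```python
import numpy as np

def self_supervised_loss(f, f_hat, S):
    """Reconstruction loss with per-column normalization by the population std term."""
    col_norm = np.sqrt(((f - f.mean(axis=0, keepdims=True)) ** 2).sum(axis=0))
    return np.sum((((f_hat - f) * S) / col_norm) ** 2)

rng = np.random.default_rng(0)
B, D, p_s = 8, 5, 0.8
f = rng.normal(size=(B, D))

# Binary mask S: entries sampled independently from Bernoulli(p_s) at each iteration.
S = (rng.random(size=(B, D)) < p_s).astype(float)

encoder_input = (1.0 - S) * f            # the encoder only sees the unmasked cells
f_hat = rng.normal(size=(B, D))          # placeholder for the decoder output (times S)
print(round(self_supervised_loss(f, f_hat, S), 3))
```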
5 EXPERIMENTS
We study TabNet on a wide range of problems that contain regression or classification tasks, particularly with published benchmarks. For all datasets, categorical inputs are mapped to a single-dimensional trainable scalar with a learnable embedding⁵ and numerical columns are input without any preprocessing.⁶ We use standard classification (softmax cross entropy) and regression (mean squared error) loss functions and we train until convergence. Hyperparameters of the TabNet models are optimized on a validation set and listed in the Appendix. TabNet performance is not very sensitive to most hyperparameters, as shown with ablation studies in the Appendix. In all of the experiments where we cite results from other papers, we use the same training, validation and testing data split as the original work. The Adam optimization algorithm [32] and Glorot uniform initialization are used for training of all models. An open-source implementation can be found at https://github.com/google-research/google-research/tree/master/tabnet.

5.1 Instance-wise feature selection
Selection of the most salient features can be crucial for high performance, especially for small datasets. We consider the 6 synthetic tabular datasets from [6] (consisting of 10k training samples). The synthetic datasets are constructed in such a way that only a subset of the features determine the output. For the Syn1, Syn2 and Syn3 datasets, the 'salient' features are the same for all instances, so that an accurate global feature selection mechanism should be optimal. E.g., the ground truth output of the Syn2 dataset only depends on features X3-X6. For the Syn4, Syn5 and Syn6 datasets, the salient features are instance dependent. E.g., for the Syn4 dataset, X11 is the indicator, and the ground truth output depends on either X1-X2 or X3-X6 depending on the value of X11. This instance dependence makes global feature selection suboptimal, as the globally-salient features would be redundant for some instances.

Table 1 shows the performance of the TabNet encoder vs. other techniques, including no selection, using only globally-salient features, Tree Ensembles [16], LASSO regularization, L2X [6] and INVASE [61]. We observe that TabNet outperforms all other methods and is on par with INVASE. For the Syn1, Syn2 and Syn3 datasets, we observe that the TabNet performance is very close to global feature selection. For the Syn4, Syn5 and Syn6 datasets, we observe that TabNet improves upon global feature selection, which would contain redundant features. (Feature selection is visualized in Sec. 5.3.) All other methods utilize a predictive model with 43k parameters, and the total number of trainable parameters is 101k for INVASE due to the two other networks in the actor-critic framework. On the other hand, TabNet is a single DNN architecture, and its model size is 26k for the Syn1-Syn3 datasets and 31k for the Syn4-Syn6 datasets. This compact end-to-end representation is one of TabNet's valuable properties.

5.2 Performance on real-world datasets
Forest Cover Type [14]: This dataset corresponds to the task of classification of forest cover type from cartographic variables. Table 2 shows that TabNet significantly outperforms ensemble tree based approaches that are known to achieve solid performance on this task [38].

⁵ In some cases, higher dimensional embeddings may slightly improve the performance, but interpretation of individual dimensions may become challenging.
⁶ Specially-designed feature engineering, e.g. logarithmic transformation of variables with highly-skewed distributions, may further improve the results, but we leave it out of the scope of this paper.
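As stated above (and exemplified in footnote 2), each categorical feature is mapped to a single trainable scalar; a minimal sketch of such a lookup, using the made-up category values from footnote 2, is:

```python
import numpy as np

# One learnable scalar per category (e.g. A -> 0.4, B -> 0.1, C -> -0.2 as in footnote 2);
# during training these values would be updated by back-propagation.
embedding_table = np.array([0.4, 0.1, -0.2])

category_ids = np.array([0, 2, 1, 0])          # a column of categorical inputs (A, C, B, A)
embedded_column = embedding_table[category_ids]
print(embedded_column)                          # [ 0.4 -0.2  0.1  0.4]
```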
Table 1: Mean and std. of test area under the receiver operating characteristic curve (AUC) on 6 synthetic datasets from [6], for
TabNet vs. other feature selection-based DNN models: No selection: using all features without any feature selection, Global:
using only globally-salient features, Tree: Tree Ensembles [16], LASSO: LASSO-regularized model, L2X [6] and INVASE [61].
Bold numbers are the best method for each dataset.
Model (Test AUC)   Syn1          Syn2          Syn3          Syn4          Syn5          Syn6
No selection .578 ± .004 .789 ± .003 .854 ± .004 .558 ± .021 .662 ± .013 .692 ± .015
Tree .574 ± .101 .872 ± .003 .899 ± .001 .684 ± .017 .741 ± .004 .771 ± .031
Lasso .498 ± .006 .555 ± .061 .886 ± .003 .512 ± .031 .691 ± .024 .727 ± .025
L2X .498 ± .005 .823 ± .029 .862 ± .009 .678 ± .024 .709 ± .008 .827 ± .017
INVASE .690 ± .006 .877 ± .003 .902 ± .003 .787 ± .004 .784 ± .005 .877 ± .003
Global .686 ± .005 .873 ± .003 .900 ± .003 .774 ± .006 .784 ± .005 .858 ± .004
TabNet .682 ± .005 .892 ± .004 .897 ± .003 .776 ± .017 .789 ± .009 .878 ± .004
Table 2: Performance for Forest Cover Type dataset.

Model                            Test accuracy (%)
XGBoost                          89.34
LightGBM                         89.28
CatBoost                         85.14
AutoInt                          90.24
AutoML Tables (2 node hours)     94.56
AutoML Tables (10 node hours)    96.67
AutoML Tables (30 node hours)    96.93
TabNet                           96.99

In addition, we consider AutoInt [51] for this task given its strength for problems with high feature dimensionality. AutoInt models pairwise feature interactions with an attention-based DNN [51] and significantly underperforms TabNet, which employs instance-wise feature selection and considers the interaction between different features only if the model infers that it is the appropriate processing to apply. Lastly, we consider AutoML Tables [2], an automated search framework based on an ensemble of models including linear feed-forward DNN, gradient boosted decision tree, AdaNet [10] and ensembles [2]. For AutoML Tables, the amount of node hours reflects the measure of the count of searched models for the ensemble and their complexity.⁷ A single TabNet model without fine-grained hyperparameter search outperforms the accuracy of ensemble models with very thorough hyperparameter search.

Poker Hand [14]: The task is to classify the poker hand from the raw suit and rank attributes of the cards. The input-output relationship is deterministic, and hand-crafted rules implemented with several lines of code can get 100% accuracy. Yet, conventional DNNs, decision trees, and even their hybrid variant of deep neural decision tree models [60] severely suffer from the imbalanced data and cannot learn the required sorting and ranking operations with the raw input features [60]. Tuned XGBoost, CatBoost, and LightGBM show very slight improvements over them. On the other hand, TabNet significantly outperforms the other methods and approaches the accuracy of the deterministic rules, as it can perform highly-nonlinear processing with high depth, without overfitting thanks to instance-wise feature selection.

Table 4: Performance for Sarcos Robotics Arm Inverse Dynamics dataset. Three TabNet models of different sizes are considered (denoted with -S, -M and -L).

Model                       Test MSE    Model size
Random forest               2.39        16.7K
Stochastic decision tree    2.11        28K
MLP                         2.13        0.14M
Adaptive neural tree        1.23        0.60M
Gradient boosted tree       1.44        0.99M
TabNet-S                    1.25        6.3K
TabNet-M                    0.28        0.59M
TabNet-L                    0.14        1.75M
Figure 5: Feature importance masks M[i] (that indicate which features are selected at the i-th step) and the aggregate feature importance mask M_agg showing the global instance-wise feature selection for the Syn2 and Syn6 datasets from [6]. Brighter colors show a higher value. E.g., for the Syn2 dataset, only four features (X3-X6) are used.
Figure 6: T-SNE of the decision manifold for Adult Census Income test samples and the impact of the top feature 'Age'.

of mushroom edibility prediction [14]. TabNet achieves 100% test accuracy on this dataset. It is indeed known [14] that "Odor" is the most discriminative feature for this task; with the "Odor" feature only, a model can get > 98.5% test accuracy [14]. Thus, a high feature importance is expected for it. TabNet assigns an importance score ratio of 43% for it, while other notable methods like LIME [48],

5.4 Self-supervised learning
We study self-supervised learning on the Higgs and Forest Cover Type datasets. For the pre-training task of guessing the missing columns, we use the masking parameter p_s = 0.8 and train for 1M iterations. We use a subset of the labeled dataset for supervised fine-tuning, with a validation set to determine the number of iterations for early stopping. A large validation dataset would be unrealistic for small training datasets, so in these experiments we assume its size is equal to the training dataset. Table 9 shows that unsupervised pre-training significantly improves performance on the supervised classification task, especially in the regime where the unlabeled dataset is much larger than the labeled dataset. As exemplified in Fig. 7, the model convergence is much faster with unsupervised pre-training. Very fast convergence can be highly beneficial, particularly in scenarios like continual learning and domain adaptation.

Figure 7: Convergence with unsupervised pre-training is much faster, shown for the Higgs dataset with 10k samples.

Table 9: Self-supervised tabular learning results. Mean and std. of accuracy (over 15 runs) on Higgs with the TabNet-M model, varying the size of the training dataset for supervised fine-tuning.

                         Test accuracy (%)
Training dataset size    Supervised        With pre-training
1k                       57.47 ± 1.78      61.37 ± 0.88
10k                      66.66 ± 0.88      68.06 ± 0.39
100k                     72.92 ± 0.21      73.19 ± 0.15

Table 10: Self-supervised tabular learning results. Mean and std. of accuracy (over 15 runs) on Forest Cover Type, varying the size of the training dataset for supervised fine-tuning.

                         Test accuracy (%)
Training dataset size    Supervised        With pre-training
1k                       65.91 ± 1.02      67.86 ± 0.63
10k                      78.85 ± 1.24      79.22 ± 0.78

6 CONCLUSION
We have proposed TabNet, a novel deep learning architecture for tabular learning. TabNet uses a sequential attention mechanism to choose a subset of semantically meaningful features to process at each decision step. Instance-wise feature selection enables efficient learning as the model capacity is fully used for the most salient features, and it also yields more interpretable decision making via visualization of selection masks. We demonstrate that TabNet outperforms previous work across tabular datasets from different domains. Lastly, we demonstrate significant benefits of unsupervised pre-training for fast adaptation and improved performance.
REFERENCES
[1] Dario Amodei, Rishita Anubhai, Eric Battenberg, Carl Case, Jared Casper, et al. 2015. Deep Speech 2: End-to-End Speech Recognition in English and Mandarin. arXiv:1512.02595 (2015).
[2] AutoML. 2019. AutoML Tables – Google Cloud. https://cloud.google.com/automl-tables/
[3] J. Bao, D. Tang, N. Duan, Z. Yan, M. Zhou, and T. Zhao. 2019. Text Generation From Tables. IEEE Trans Audio, Speech, and Language Processing 27, 2 (Feb 2019), 311–320.
[4] Yael Ben-Haim and Elad Tom-Tov. 2010. A Streaming Parallel Decision Tree Algorithm. JMLR 11 (March 2010), 849–872.
[5] Catboost. 2019. Benchmarks. https://github.com/catboost/benchmarks. Accessed: 2019-11-10.
[6] Jianbo Chen, Le Song, Martin J. Wainwright, and Michael I. Jordan. 2018. Learning to Explain: An Information-Theoretic Perspective on Model Interpretation. arXiv:1802.07814 (2018).
[7] Tianqi Chen and Carlos Guestrin. 2016. XGBoost: A Scalable Tree Boosting System. In KDD.
[8] Michael Chui, James Manyika, Mehdi Miremadi, Nicolaus Henke, Rita Chung, et al. 2018. Notes from the AI Frontier. McKinsey Global Institute (4 2018).
[9] Alexis Conneau, Holger Schwenk, Loïc Barrault, and Yann LeCun. 2016. Very Deep Convolutional Networks for Natural Language Processing. arXiv:1606.01781 (2016).
[10] Corinna Cortes, Xavi Gonzalvo, Vitaly Kuznetsov, Mehryar Mohri, and Scott Yang. 2016. AdaNet: Adaptive Structural Learning of Artificial Neural Networks. arXiv:1607.01097 (2016).
[11] Zihang Dai, Zhilin Yang, Fan Yang, William W. Cohen, and Ruslan Salakhutdinov. 2017. Good Semi-supervised Learning that Requires a Bad GAN. arXiv:1705.09783 (2017).
[12] Yann N. Dauphin, Angela Fan, Michael Auli, and David Grangier. 2016. Language Modeling with Gated Convolutional Networks. arXiv:1612.08083 (2016).
[13] Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. 2018. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. arXiv:1810.04805 (2018).
[14] Dheeru Dua and Casey Graff. 2017. UCI Machine Learning Repository. http://archive.ics.uci.edu/ml
[15] Jonas Gehring, Michael Auli, David Grangier, Denis Yarats, and Yann N. Dauphin. 2017. Convolutional Sequence to Sequence Learning. arXiv:1705.03122 (2017).
[16] Pierre Geurts, Damien Ernst, and Louis Wehenkel. 2006. Extremely randomized trees. Machine Learning 63, 1 (01 Apr 2006), 3–42.
[17] Ian Goodfellow, Yoshua Bengio, and Aaron Courville. 2016. Deep Learning. MIT Press.
[18] K. Grabczewski and N. Jankowski. 2005. Feature selection with decision tree criterion. In HIS.
[19] Yves Grandvalet and Yoshua Bengio. 2004. Semi-supervised Learning by Entropy Minimization. In NIPS.
[20] Isabelle Guyon and André Elisseeff. 2003. An Introduction to Variable and Feature Selection. JMLR 3 (March 2003), 1157–1182.
[21] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. 2015. Deep Residual Learning for Image Recognition. arXiv:1512.03385 (2015).
[22] Joel Hestness, Sharan Narang, Newsha Ardalani, Gregory F. Diamos, Heewoo Jun, Hassan Kianinejad, Md. Mostofa Ali Patwary, Yang Yang, and Yanqi Zhou. 2017. Deep Learning Scaling is Predictable, Empirically. arXiv:1712.00409 (2017).
[23] Tin Kam Ho. 1998. The random subspace method for constructing decision forests. PAMI 20, 8 (Aug 1998), 832–844.
[24] Elad Hoffer, Itay Hubara, and Daniel Soudry. 2017. Train longer, generalize better: closing the generalization gap in large batch training of neural networks. arXiv:1705.08741 (2017).
[25] Drew A. Hudson and Christopher D. Manning. 2018. Compositional Attention Networks for Machine Reasoning. arXiv:1803.03067 (2018).
[26] K. D. Humbird, J. L. Peterson, and R. G. McClarren. 2018. Deep Neural Network Initialization With Decision Trees. IEEE Trans Neural Networks and Learning Systems (2018).
[27] Mark Ibrahim, Melissa Louie, Ceena Modarres, and John W. Paisley. 2019. Global Explanations of Neural Networks: Mapping the Landscape of Predictions. arXiv:1902.02384 (2019).
[28] Kaggle. 2019. Historical Data Science Trends on Kaggle. https://www.kaggle.com/shivamb/data-science-trends-on-kaggle. Accessed: 2019-04-20.
[29] Kaggle. 2019. Rossmann Store Sales. https://www.kaggle.com/c/rossmann-store-sales. Accessed: 2019-11-10.
[30] Guolin Ke, Qi Meng, Thomas Finley, Taifeng Wang, Wei Chen, et al. 2017. LightGBM: A Highly Efficient Gradient Boosting Decision Tree. In NIPS.
[31] Guolin Ke, Jia Zhang, Zhenhui Xu, Jiang Bian, and Tie-Yan Liu. 2019. TabNN: A Universal Neural Network Solution for Tabular Data. https://openreview.net/forum?id=r1eJssCqY7
[32] Diederik P. Kingma and Jimmy Ba. 2014. Adam: A Method for Stochastic Optimization. In ICLR.
[33] P. Kontschieder, M. Fiterau, A. Criminisi, and S. R. Bulò. 2015. Deep Neural Decision Forests. In ICCV.
[34] Siwei Lai, Liheng Xu, Kang Liu, and Jun Zhao. 2015. Recurrent Convolutional Neural Networks for Text Classification. In AAAI.
[35] Tianyu Liu, Kexiang Wang, Lei Sha, Baobao Chang, and Zhifang Sui. 2017. Table-to-text Generation by Structure-aware Seq2seq Learning. arXiv:1711.09724 (2017).
[36] Scott M. Lundberg, Gabriel G. Erion, and Su-In Lee. 2018. Consistent Individualized Feature Attribution for Tree Ensembles. arXiv:1802.03888 (2018).
[37] André F. T. Martins and Ramón Fernández Astudillo. 2016. From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification. arXiv:1602.02068 (2016).
[38] Rory Mitchell, Andrey Adinets, Thejaswi Rao, and Eibe Frank. 2018. XGBoost: Scalable GPU Accelerated Learning. arXiv:1806.11248 (2018).
[39] Decebal Mocanu, Elena Mocanu, Peter Stone, Phuong Nguyen, Madeleine Gibescu, and Antonio Liotta. 2018. Scalable training of artificial neural networks with adaptive sparse connectivity inspired by network science. Nature Communications 9 (12 2018).
[40] Alex Mott, Daniel Zoran, Mike Chrzanowski, Daan Wierstra, and Danilo J. Rezende. 2019. S3TA: A Soft, Spatial, Sequential, Top-Down Attention Model. https://openreview.net/forum?id=B1gJOoRcYQ
[41] Sharan Narang, Gregory F. Diamos, Shubho Sengupta, and Erich Elsen. 2017. Exploring Sparsity in Recurrent Neural Networks. arXiv:1704.05119 (2017).
[42] Nbviewer. 2019. Notebook on Nbviewer. https://nbviewer.jupyter.org/github/dipanjanS/data science for all/blob/master/tds model interpretation xai/Human-interpretableMachineLearning-DS.ipynb#
[43] N. C. Oza. 2005. Online bagging and boosting. In IEEE Trans Conference on Systems, Man and Cybernetics.
[44] German Ignacio Parisi, Ronald Kemker, Jose L. Part, Christopher Kanan, and Stefan Wermter. 2018. Continual Lifelong Learning with Neural Networks: A Review. arXiv:1802.07569 (2018).
[45] Liudmila Prokhorenkova, Gleb Gusev, Aleksandr Vorobev, Anna Veronika Dorogush, and Andrey Gulin. 2018. CatBoost: unbiased boosting with categorical features. In NIPS.
[46] Alec Radford, Luke Metz, and Soumith Chintala. 2015. Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks. arXiv:1511.06434 (2015).
[47] Rajat Raina, Alexis Battle, Honglak Lee, Benjamin Packer, and Andrew Y. Ng. 2007. Self-Taught Learning: Transfer Learning from Unlabeled Data. In ICML.
[48] Marco Ribeiro, Sameer Singh, and Carlos Guestrin. 2016. "Why Should I Trust You?": Explaining the Predictions of Any Classifier. In KDD.
[49] Avanti Shrikumar, Peyton Greenside, and Anshul Kundaje. 2017. Learning Important Features Through Propagating Activation Differences. arXiv:1704.02685 (2017).
[50] Karen Simonyan and Andrew Zisserman. 2014. Very Deep Convolutional Networks for Large-Scale Image Recognition. arXiv:1409.1556 (2014).
[51] Weiping Song, Chence Shi, Zhiping Xiao, Zhijian Duan, Yewen Xu, Ming Zhang, and Jian Tang. 2018. AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks. arXiv:1810.11921 (2018).
[52] Mukund Sundararajan, Ankur Taly, and Qiqi Yan. 2017. Axiomatic Attribution for Deep Networks. arXiv:1703.01365 (2017).
[53] Ryutaro Tanno, Kai Arulkumaran, Daniel C. Alexander, Antonio Criminisi, and Aditya V. Nori. 2018. Adaptive Neural Trees. arXiv:1807.06699 (2018).
[54] Tensorflow. 2019. Classifying Higgs boson processes in the HIGGS Data Set. https://github.com/tensorflow/models/tree/master/official/boosted trees
[55] Trieu H. Trinh, Minh-Thang Luong, and Quoc V. Le. 2019. Selfie: Self-supervised Pretraining for Image Embedding. arXiv:1906.02940 (2019).
[56] Aäron van den Oord, Sander Dieleman, Heiga Zen, Karen Simonyan, Oriol Vinyals, et al. 2016. WaveNet: A Generative Model for Raw Audio. arXiv:1609.03499 (2016).
[57] Sethu Vijayakumar and Stefan Schaal. 2000. Locally Weighted Projection Regression: An O(n) Algorithm for Incremental Real Time Learning in High Dimensional Space. In ICML.
[58] Suhang Wang, Charu Aggarwal, and Huan Liu. 2017. Using a random forest to inspire a neural network and improving on it. In SDM.
[59] Wei Wen, Chunpeng Wu, Yandan Wang, Yiran Chen, and Hai Li. 2016. Learning Structured Sparsity in Deep Neural Networks. arXiv:1608.03665 (2016).
[60] Yongxin Yang, Irene Garcia Morillo, and Timothy M. Hospedales. 2018. Deep Neural Decision Trees. arXiv:1806.06988 (2018).
[61] Jinsung Yoon, James Jordon, and Mihaela van der Schaar. 2019. INVASE: Instance-wise Variable Selection using Neural Networks. In ICLR.

APPENDIX

A EXPERIMENT HYPERPARAMETERS
For all datasets, we use a pre-defined hyperparameter search space. Nd and Na are chosen from {8, 16, 24, 32, 64, 128}, Nsteps is chosen from {3, 4, 5, 6, 7, 8, 9, 10}, γ is chosen from {1.0, 1.2, 1.5, 2.0}, λsparse is chosen from {0, 0.000001, 0.0001, 0.001, 0.01, 0.1}, B is chosen from {256, 512, 1024, 2048, 4096, 8192, 16384, 32768}, BV is chosen from {256, 512, 1024, 2048, 4096} and mB is chosen from {0.6, 0.7, 0.8, 0.9, 0.95, 0.98}. If the model size is not under the desired cutoff, we decrease the value to satisfy the size constraint. For all the comparison models, we run a hyperparameter tuning with the same number of search steps.
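The pre-defined search space above can be written down directly, e.g. as a Python dictionary (shown only for illustration; the paper does not prescribe a particular search library or representation):

```python
import math

# Hyperparameter search space of Appendix A (values copied from the text above).
SEARCH_SPACE = {
    "N_d": [8, 16, 24, 32, 64, 128],
    "N_a": [8, 16, 24, 32, 64, 128],
    "N_steps": [3, 4, 5, 6, 7, 8, 9, 10],
    "gamma": [1.0, 1.2, 1.5, 2.0],
    "lambda_sparse": [0, 0.000001, 0.0001, 0.001, 0.01, 0.1],
    "B": [256, 512, 1024, 2048, 4096, 8192, 16384, 32768],
    "B_v": [256, 512, 1024, 2048, 4096],
    "m_B": [0.6, 0.7, 0.8, 0.9, 0.95, 0.98],
}

# Size of the full grid (the paper samples a fixed number of search steps instead).
print(math.prod(len(v) for v in SEARCH_SPACE.values()))
```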
Synthetic: All TabNet models use Nd=Na=16, B=3000, BV=100, mB=0.7. For Syn1 we use λsparse=0.02, Nsteps=4 and γ=2.0; for Syn2 and Syn3 we use λsparse=0.01, Nsteps=4 and γ=2.0; and for Syn4, Syn5 and Syn6 we use λsparse=0.005, Nsteps=5 and γ=1.5. Feature transformers use two shared and two decision step-dependent FC layers, ghost BN and GLU blocks. All models use Adam with a learning rate of 0.02 (decayed 0.7 every 200 iterations with an exponential decay) for 4k iterations. For visualizations in Sec. 5.3, we also train TabNet models with datasets of size 10M samples. For this case, we choose Nd=Na=32, λsparse=0.001, B=10000, BV=100, mB=0.9. Adam is used with a learning rate of 0.02 (decayed 0.9 every 2k iterations with an exponential decay) for 15k iterations. For Syn2 and Syn3, Nsteps=4 and γ=2. For Syn4 and Syn6, Nsteps=5 and γ=1.5.

Forest Cover Type: The dataset partition details and the hyperparameters of XGBoost, LightGBM, and CatBoost are from [38]. We re-optimize AutoInt hyperparameters. The TabNet model uses Nd=Na=64, λsparse=0.0001, B=16384, BV=512, mB=0.7, Nsteps=5 and γ=1.5. Feature transformers use two shared and two decision step-dependent FC layers, ghost BN and GLU blocks. Adam is used with a learning rate of 0.02 (decayed 0.95 every 0.5k iterations with an exponential decay) for 130k iterations. For unsupervised pre-training, the decoder model uses Nd=Na=64, B=16384, BV=512, mB=0.7, and Nsteps=10. For supervised fine-tuning, we use the batch size B=BV as the training datasets are small.

Poker Hands: We split 6k samples for validation from the training dataset, and after optimization of the hyperparameters, we retrain with the entire training dataset. Decision tree, MLP and deep neural decision tree models follow the same hyperparameters as [60]. We tune the hyperparameters of XGBoost, LightGBM, and CatBoost. TabNet uses Nd=Na=16, λsparse=0.000001, B=4096, BV=1024, mB=0.95, Nsteps=4 and γ=1.5. Feature transformers use two shared and two decision step-dependent FC layers, ghost BN and GLU blocks. Adam is used with a learning rate of 0.01 (decayed 0.95 every 500 iterations with an exponential decay) for 50k iterations.

Sarcos: We split 4.5k samples for validation from the training dataset, and after optimization of the hyperparameters, we retrain with the entire training dataset. All comparison models follow the hyperparameters from [53]. The TabNet-S model uses Nd=Na=8, λsparse=0.0001, B=4096, BV=256, mB=0.9, Nsteps=3 and γ=1.2. Each feature transformer block uses one shared and two decision step-dependent FC layers, ghost BN and GLU blocks. Adam is used with a learning rate of 0.01 (decayed 0.95 every 8k iterations with an exponential decay) for 600k iterations. The TabNet-M model uses Nd=Na=64, λsparse=0.0001, B=4096, BV=128, mB=0.8, Nsteps=7 and γ=1.5. Feature transformers use two shared and two decision step-dependent FC layers, ghost BN and GLU blocks. Adam is used with a learning rate of 0.01 (decayed 0.95 every 8k iterations with an exponential decay) for 600k iterations. The TabNet-L model uses Nd=Na=128, λsparse=0.0001, B=4096, BV=128, mB=0.8, Nsteps=5 and γ=1.5. Feature transformers use two shared and two decision step-dependent FC layers, ghost BN and GLU blocks. Adam is used with a learning rate of 0.02 (decayed 0.9 every 8k iterations with an exponential decay) for 600k iterations.

Higgs: We split 500k samples for validation from the training dataset, and after optimization of the hyperparameters, we retrain with the entire training dataset. MLP models are from [39]. For gradient boosted trees [54], we tune the learning rate and depth – the gradient boosted tree-S, -M, and -L models use 50, 300 and 3000 trees, respectively. The TabNet-S model uses Nd=24, Na=26, λsparse=0.000001, B=16384, BV=512, mB=0.6, Nsteps=5 and γ=1.5. Feature transformers use two shared and two decision step-dependent FC layers, ghost BN and GLU blocks. Adam is used with a learning rate of 0.02 (decayed 0.9 every 20k iterations with an exponential decay) for 870k iterations. The TabNet-M model uses Nd=96, Na=32, λsparse=0.000001, B=8192, BV=256, mB=0.9, Nsteps=8 and γ=2.0. Feature transformers use two shared and two decision step-dependent FC layers, ghost BN and GLU blocks. Adam is used with a learning rate of 0.025 (decayed 0.9 every 10k iterations with an exponential decay) for 370k iterations. For unsupervised pre-training, the decoder model uses Nd=Na=128, B=8192, BV=256, mB=0.9, and Nsteps=20. For supervised fine-tuning, we use the batch size B=BV as the training datasets are small.

Rossmann: We use the same preprocessing and data split as [5] – data from 2014 is used for training and validation, whereas 2015 is used for testing. We split 100k samples for validation from the training dataset, and after optimization of the hyperparameters, we retrain with the entire training dataset. The performance of the comparison models is from [5]. The TabNet model uses Nd=Na=32, λsparse=0.001, B=4096, BV=512, mB=0.8, Nsteps=5 and γ=1.2. Feature transformers use two shared and two decision step-dependent FC layers, ghost BN and GLU blocks. Adam is used with a learning rate of 0.002 (decayed 0.95 every 2000 iterations with an exponential decay) for 15k iterations.

KDD: For the Appetency, Churn and Upselling datasets, we apply a similar preprocessing and split as [45]. The performance of the comparison models is from [45]. TabNet models use Nd=Na=32, λsparse=0.001, B=8192, BV=256, mB=0.9, Nsteps=7 and γ=1.2. Each feature transformer block uses two shared and two decision step-dependent FC layers, ghost BN and GLU blocks. Adam is used with a learning rate of 0.01 (decayed 0.9 every 1000 iterations with an exponential decay) for 10k iterations. For Census Income, the dataset and comparison model specifications follow [43]. The TabNet model uses Nd=Na=48, λsparse=0.001, B=8192, BV=256, mB=0.9, Nsteps=5 and γ=1.5. Feature transformers use two shared and two decision step-dependent FC layers, ghost BN and GLU blocks. Adam is used with a learning rate of 0.02 (decayed 0.7 every 2000 iterations with an exponential decay) for 4k iterations.

Mushroom edibility: The TabNet model uses Nd=Na=8, λsparse=0.001, B=2048, BV=128, mB=0.9, Nsteps=3 and γ=1.5. Feature transformers use two shared and two decision step-dependent FC layers, ghost BN and GLU blocks. Adam is used with a learning rate of 0.01 (decayed 0.8 every 400 iterations with an exponential decay) for 10k iterations.

Adult Census Income: The TabNet model uses Nd=Na=16, λsparse=0.0001, B=4096, BV=128, mB=0.98, Nsteps=5 and γ=1.5. Feature transformers use two shared and two decision step-dependent layers, ghost BN and GLU blocks. Adam is used with a learning rate of 0.02 (decayed 0.4 every 2.5k iterations with an exponential decay) for 7.7k iterations. 85.7% test accuracy is achieved.

B ABLATION STUDIES
Table 11: Ablation studies for the TabNet encoder model for the Forest Cover Type dataset.
Table 11 shows the impact of ablation cases. For all cases, the
number of iterations is optimized on the validation set.
Obtaining high performance necessitates appropriately-adjusted model capacity based on the characteristics of the dataset. Decreasing the number of units Nd, Na or the number of decision steps Nsteps is an efficient way of gradually decreasing the capacity without significant degradation in performance. On the other hand, increasing these parameters beyond some value causes optimization issues and does not yield performance benefits. Replacing the
feature transformer block with a simpler alternative, such as a sin-
gle shared layer, can still give strong performance while yielding
a very compact model architecture. This shows the importance of
the inductive bias introduced with feature selection and sequential
attention. To push the performance, increasing the depth of the
feature transformer is an effective approach. While increasing the
depth, parameter sharing between feature transformer blocks across
decision steps is an efficient way to decrease model size without
degradation in performance. We indeed observe the benefit of par-
tial parameter sharing, compared to fully decision step-dependent
blocks or fully shared blocks. We also observe the empirical benefit
of GLU, compared to conventional nonlinearities like ReLU.
The strength of sparse feature selection depends on the two parameters we introduce: γ and λsparse. We show that an optimal choice of these two is important for performance. A γ close to 1, or a high λsparse, may yield too tight a constraint on the strength of sparsity and may hurt performance. On the other hand, there is still the benefit of a sufficiently low γ and a sufficiently high λsparse, to aid learning of the model via a favorable inductive bias.
Lastly, given the fixed model architecture, we show the benefit of
large-batch training, enabled by ghost BN [24]. The optimal batch
size for TabNet seems considerably higher than the conventional
batch sizes used for other data types, such as images or speech.
Overall, we suggest the following guidelines for selecting the key hyperparameters:
• For most datasets, Nsteps ∈ [3, 10] is optimal. Typically, when there are more information-bearing features, the optimal value of Nsteps tends to be higher. On the other hand, increasing it beyond some value may adversely affect training dynamics, as some paths in the network become deeper and there are more potentially-problematic ill-conditioned matrices. A very high value of Nsteps may suffer from overfitting and yield poor generalization.
• Adjustment of Nd and Na is an efficient way of obtaining a trade-off between performance and complexity. Nd = Na is a reasonable choice for most datasets. A very high value of Nd and Na may suffer from overfitting and yield poor generalization.
• An optimal choice of γ can have a major role on the performance. Typically a larger Nsteps value favors a larger γ.
• A large batch size is beneficial – if the memory constraints permit, as large as 1-10% of the total training dataset size can help performance. The virtual batch size is typically much smaller.
• An initially large learning rate is important, and it should be gradually decayed until convergence.
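The learning-rate schedule used throughout Appendix A ("decayed by r every k iterations with an exponential decay") and the last guideline above correspond to a staircase exponential decay; a small illustrative sketch, using the Forest Cover Type settings quoted earlier as example values, is:

```python
def staircase_exponential_lr(step, base_lr=0.02, decay_rate=0.95, decay_every=500):
    """Learning rate decayed by `decay_rate` every `decay_every` iterations."""
    return base_lr * decay_rate ** (step // decay_every)

for step in (0, 500, 5000, 50000):
    print(step, round(staircase_exponential_lr(step), 5))
```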