TabNet: Attentive Interpretable Tabular Learning

Sercan Ö. Arık Tomas Pfister


Google Cloud AI Google Cloud AI
soarik@google.com tpfister@google.com

arXiv:1908.07442v4 [cs.LG] 14 Feb 2020

ABSTRACT

We propose a novel high-performance and interpretable canonical deep tabular data learning architecture, TabNet. TabNet uses sequential attention to choose which features to reason from at each decision step, enabling interpretability and more efficient learning as the learning capacity is used for the most salient features. We demonstrate that TabNet outperforms other neural network and decision tree variants on a wide range of non-performance-saturated tabular datasets and yields interpretable feature attributions plus insights into the global model behavior. Finally, for the first time to our knowledge, we demonstrate self-supervised learning for tabular data, significantly improving performance with unsupervised representation learning when unlabeled data is abundant.

KEYWORDS

Interpretable deep learning, tabular data, attention, self-supervised.

1 INTRODUCTION

Deep neural networks (DNNs) have shown notable success with images [21, 50], text [9, 34] and audio [1, 56]. For these data types, a major enabler of the progress is the availability of canonical DNN architectures that efficiently encode the raw data into meaningful representations, resulting in high performance on new datasets and related tasks with minor effort. For example, in image understanding, variants of residual convolutional networks (e.g. ResNet [21]) provide reasonably good performance on new image datasets or slightly different visual recognition problems (e.g. segmentation).

One data type that has yet to see such success with a canonical DNN architecture is tabular data. Despite being the most common data type in real-world AI¹ [8], deep learning for tabular data remains under-explored, with variants of ensemble decision trees still dominating most applications [28]. Why is this? First, because tree-based approaches have certain benefits that make them popular: (i) they are representationally efficient (and thus often high-performing) for decision manifolds with approximately hyperplane boundaries, which are common in tabular data; (ii) they are highly interpretable in their basic form (e.g. by tracking decision nodes) and there are effective post-hoc explainability methods for their ensemble form, e.g. [36] – this is an important concern in many real-world applications (e.g. in financial services, where trust behind a high-risk action is crucial); and (iii) they are fast to train. Second, because previously-proposed DNN architectures are not well-suited for tabular data: conventional DNNs based on stacked convolutional layers or multi-layer perceptrons (MLPs) are vastly overparametrized – the lack of appropriate inductive bias often causes them to fail to find optimal solutions for tabular decision manifolds [17].

¹ As it corresponds to a combination of any unrelated categorical and numerical features.

Why is deep learning worth exploring for tabular data? One obvious motivation is that, similarly to other domains, one would expect performance improvements from DNN-based architectures, particularly for large datasets [22]. In addition, unlike tree learning, which does not use back-propagation into its inputs to guide efficient learning from the error signal, DNNs enable gradient descent-based end-to-end learning for tabular data, which can have a multitude of benefits, just like in the other domains where it has already shown success: (i) it could efficiently encode multiple data types like images along with tabular data; (ii) it would alleviate or eliminate the need for feature engineering, which is currently a key aspect in tree-based tabular data learning methods; (iii) it would enable learning from streaming data – tree learning needs global statistics to select split points, and straightforward modifications such as [4] typically yield lower accuracy compared to learning from the entire data – in contrast, DNNs show great potential for continual learning [44]; and perhaps most importantly (iv) end-to-end models allow representation learning, which enables many valuable new application scenarios including data-efficient domain adaptation [17], generative modeling [46] and semi-supervised learning [11]. Clearly there are significant benefits in both tree-based and DNN-based methods. Is there a way to design a method that has the most beneficial aspects of both?

In this paper, we propose a new canonical DNN architecture for tabular data, TabNet, that is designed to learn a ‘decision-tree-like’ mapping in order to inherit the valuable benefits of tree-based methods (interpretability and sparse feature selection), while providing the key benefits of DNN-based methods (representation learning and end-to-end training). In particular, TabNet’s design considers two key needs: high performance and interpretability. As mentioned, high performance alone is often not enough – a DNN needs to be interpretable to substitute tree-based methods. Overall, we make the following contributions in the design of our method:

(1) Unlike tree-based methods, TabNet inputs raw tabular data without any feature preprocessing and is trained using gradient descent-based optimization to learn flexible representations and enable flexible integration into end-to-end learning.

(2) TabNet uses sequential attention to choose which features to reason from at each decision step, enabling interpretability and better learning as the learning capacity is used for the most salient features (see Fig. 1). This feature selection is instance-wise, e.g. it can be different for each input, and unlike other instance-wise feature selection methods like [6] or [61], TabNet employs a single deep learning architecture with end-to-end learning.

(3) We show that the above design choices lead to two valuable properties: (1) TabNet outperforms or is on par with other tabular learning models on various datasets for classification and regression problems from different domains; and (2) TabNet enables
Figure 1: TabNet’s sparse feature selection exemplified for Adult Census Income prediction [14]. Sparse feature selection enables interpretability and better learning as the capacity is used for the most salient features. TabNet employs multiple decision blocks that focus on processing a subset of input features for reasoning. Feature selection is based on feedback flowing from the preceding decision step. Two decision blocks shown as examples process features that are related to professional occupation and investments, respectively, in order to predict the income level.

Figure 2: Self-supervised tabular learning. Real-world tabular datasets have interdependent feature columns, e.g., the education level can be guessed from the occupation, or the gender can be guessed from the relationship. Unsupervised representation learning by masked self-supervised learning results in an improved encoder model for the supervised learning task.
two kinds of interpretability: local interpretability that visualizes the importance of input features and how they are combined, and global interpretability which quantifies the contribution of each input feature to the trained model.

(4) Finally, we show that our canonical DNN design achieves significant performance improvements by using unsupervised pre-training to predict masked features (see Fig. 2). Our work is the first demonstration of self-supervised learning for tabular data.

2 RELATED WORK

Feature selection: Feature selection in machine learning broadly refers to judiciously picking a subset of features based on their usefulness for prediction. Commonly-used techniques such as forward selection and LASSO regularization [20] attribute feature importance based on the entire training data set, and are referred to as global methods. Instance-wise feature selection refers to picking features individually for each input, studied in [6] by training an explainer model to maximize the mutual information between the selected features and the response variable, and in [61] by using an actor-critic framework to mimic a baseline while optimizing the feature selection. Unlike these, TabNet employs soft feature selection with controllable sparsity in end-to-end learning – a single model jointly performs feature selection and output mapping, resulting in superior performance with compact representations.

Tree-based learning: Tree-based models are the most common approaches for tabular data learning. The prominent strength of tree-based models is their efficacy in picking global features with the most statistical information gain [18]. To improve the performance of standard tree-based models by reducing the model variance, one common approach is ensembling. Among ensembling methods, random forests [23] use random subsets of data with randomly selected features to grow many trees. XGBoost [7] and LightGBM [30] are the two recent ensemble decision tree approaches that dominate most of the recent data science competitions. Our experimental results for various datasets show that tree-based models can be outperformed when the representation capacity is improved with deep learning while retaining their feature selecting property.

Integration of DNNs into decision trees: Representing decision trees with canonical DNN building blocks as in [26] yields redundancy in representation and inefficient learning. Soft (neural) decision trees [33, 58] are proposed with differentiable decision functions, instead of non-differentiable axis-aligned splits. However, abandoning trees loses their automatic feature selection ability, which is important for tabular data. In [60], a soft binning function is proposed to simulate decision trees in DNNs, which needs to enumerate all possible decisions and is inefficient. [31] proposes a DNN architecture by explicitly leveraging expressive feature combinations; however, learning is based on transferring knowledge from a gradient-boosted decision tree and limited performance improvements are observed. [53] proposes a DNN architecture by adaptively growing from primitive blocks while representation learning into edges, routing functions and leaf nodes of a decision tree. TabNet differs from these methods as it embeds the soft feature selection ability with controllable sparsity via sequential attention.

Attentive table-to-text models: Table-to-text models extract textual information from tabular data, for which recent works [3, 35] propose a sequential mechanism for field-level attention. Unlike these, we demonstrate the application of sequential attention for supervised or self-supervised learning instead of mapping tabular data to a different data type.

Self-supervised learning: Unsupervised representation learning is shown to benefit the supervised learning task especially in the small data regime [47]. Recent work for language [13] and image [55] data has shown significant advances – specifically, careful choice of the unsupervised learning objective (masked input prediction) and an attention-based deep learning architecture is important.

3 TABNET FOR TABULAR LEARNING

Decision trees are successful for learning from real-world tabular datasets. However, even conventional DNN building blocks can be used to implement a decision tree-like output manifold – see Fig. 3 for an example. In such a design, individual feature selection is key to obtain decision boundaries in hyperplane form. This idea can be generalized for a linear combination of features where constituent coefficients determine the proportion of each feature. TabNet is based on such a tree-like functionality. We show that it outperforms decision trees while reaping many of their benefits by careful design which: (i) uses sparse instance-wise feature selection learned based on the training dataset; (ii) constructs a sequential multi-step architecture, where each decision step can contribute to a portion of the decision that is based on the selected features; (iii) improves the learning capacity by non-linear processing of the selected features; and (iv) mimics an ensemble via higher dimensions and more steps.

Figure 3: Illustration of decision tree-like classification using conventional DNN blocks (left) and the corresponding decision manifold (right). Relevant features are selected by using multiplicative sparse masks on inputs. The selected features are linearly transformed, and after a bias addition (to represent boundaries) ReLU performs region selection by zeroing the regions that are on the negative side of the boundary. Aggregation of multiple regions is based on addition. As C1 and C2 get larger, the decision boundary gets sharper due to the softmax.
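The construction in Fig. 3 maps directly onto standard tensor operations. Below is a minimal PyTorch sketch of the figure, not code from the paper; the constants C1, C2 and the boundaries a and d are illustrative values.

```python
import torch
import torch.nn.functional as F

def tree_like_block(x, C1=100.0, C2=100.0, a=0.5, d=0.3):
    """Decision-tree-like mapping from conventional DNN blocks, following Fig. 3.

    x: (batch, 2) tensor of features [x1, x2]. Returns a 4-way softmax whose entries act
    as soft indicators for the half-spaces x1 > a, x1 < a, x2 > d, x2 < d; larger C1, C2
    make the boundaries sharper.
    """
    x1 = x * torch.tensor([1.0, 0.0])   # multiplicative sparse mask M: [1, 0]
    x2 = x * torch.tensor([0.0, 1.0])   # multiplicative sparse mask M: [0, 1]

    # FC weights/biases as in the figure: the bias encodes the boundary, and ReLU zeroes
    # the region on the negative side of it.
    W1 = torch.tensor([[C1, -C1, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]])
    b1 = torch.tensor([-C1 * a, C1 * a, -1.0, -1.0])
    W2 = torch.tensor([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, C2, -C2]])
    b2 = torch.tensor([-1.0, -1.0, -C2 * d, C2 * d])

    r1 = F.relu(x1 @ W1 + b1)
    r2 = F.relu(x2 @ W2 + b2)
    return F.softmax(r1 + r2, dim=-1)   # region aggregation by addition, then softmax

print(tree_like_block(torch.tensor([[0.9, 0.1]])))   # x1 > a and x2 < d dominate
```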
Fig. 4 shows the TabNet architecture for encoding tabular data. Tabular data are comprised of numerical and categorical features. We use the raw numerical features and consider mapping of categorical features with trainable embeddings². We do not consider any global normalization of features, but merely apply batch normalization (BN). We pass the same D-dimensional features f ∈ ℝ^{B×D} to each decision step, where B is the batch size. TabNet’s encoding is based on sequential multi-step processing with N_steps decision steps.

² E.g., the three possible categories A, B and C for a particular feature can be learned to be mapped to scalars 0.4, 0.1, and -0.2.
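To make footnote 2 concrete, a categorical column can be mapped to a single trainable scalar with an embedding table. The class below is an illustrative sketch (names and sizes are ours, not the reference implementation).

```python
import torch
import torch.nn as nn

class ScalarEmbedding(nn.Module):
    """Single-dimensional trainable embedding for one categorical column."""
    def __init__(self, n_categories):
        super().__init__()
        self.emb = nn.Embedding(n_categories, 1)   # e.g. categories A, B, C -> 3 learned scalars

    def forward(self, idx):                        # idx: (batch,) integer category codes
        return self.emb(idx).squeeze(-1)           # (batch,) scalar feature values

# Numerical columns pass through unchanged; the concatenated (B, D) feature matrix is
# batch-normalized and the same features f are fed to every decision step.
emb = ScalarEmbedding(3)
print(emb(torch.tensor([0, 2, 1])))
```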
Figure 4: (a) TabNet encoder for classification or regression, composed of a feature transformer, an attentive transformer and feature masking at each decision step. A split block divides the processed representation into two, to be used by the attentive transformer of the subsequent step as well as for constructing the overall output. At each decision step, the feature selection mask can provide interpretable information about the model’s functionality, and the masks can be aggregated to obtain global feature importance attribution. (b) TabNet decoder, composed of a feature transformer block at each step. (c) A feature transformer block example – a 4-layer network is shown, where 2 of the blocks are shared across all decision steps and 2 are decision step-dependent. Each layer is composed of a fully-connected (FC) layer, BN and GLU nonlinearity. (d) An attentive transformer block example – a single layer mapping is modulated with prior scale information, which aggregates how much each feature has been used before the current decision step. Normalization of the coefficients is done using sparsemax [37] for sparse selection of the most salient features at each decision step.
The i-th step inputs the processed information from the (i − 1)-th step to decide which features to use, and outputs the processed feature representation to be aggregated into the overall decision. The idea of top-down attention in the sequential form is inspired by its applications in processing visual and language data (e.g. in visual question answering [25]) and reinforcement learning [40] while searching for a small subset of relevant information in high-dimensional input. Ablation studies in the Appendix focus on the impact of various design choices, which are explained next. Guidelines on selection of the important hyperparameters are also provided in the Appendix.

Feature selection: We employ a learnable mask M[i] ∈ ℝ^{B×D} for soft selection of the salient features. Through sparse selection of the most salient features, the learning capacity of a decision step is not wasted on irrelevant features, and thus the model becomes more parameter efficient. The masking is in multiplicative form, M[i] · f. We use an attentive transformer (see Fig. 4) to obtain the masks using the processed features from the preceding step, a[i − 1]:

M[i] = sparsemax(P[i − 1] · h_i(a[i − 1])).   (1)

Sparsemax normalization [37] encourages sparsity by mapping the Euclidean projection onto the probabilistic simplex, which is observed to be superior in performance and aligned with the goal of sparse feature selection for most real-world datasets. Note that Eq. 1 ensures Σ_{j=1}^{D} M[i]_{b,j} = 1. h_i is a trainable function, shown in Fig. 4 using an FC layer, followed by BN. P[i] is the prior scale term, denoting how much a particular feature has been used previously:

P[i] = Π_{j=1}^{i} (γ − M[j]),   (2)

where γ is a relaxation parameter – when γ = 1, a feature is enforced to be used only at one decision step, and as γ increases, more flexibility is provided to use a feature at multiple decision steps. P[0] is initialized as all ones, 1^{B×D}, without any prior on the masked features. If some features are unused (as in self-supervised learning), the corresponding P[0] entries are made 0 to help the model’s learning. To further control the sparsity of the selected features, we propose sparsity regularization in the form of entropy [19]:

L_sparse = Σ_{i=1}^{N_steps} Σ_{b=1}^{B} Σ_{j=1}^{D} (−M_{b,j}[i] / (N_steps · B)) · log(M_{b,j}[i] + ϵ),

where ϵ is a small number for numerical stability. We add the sparsity regularization to the overall loss, with a coefficient λ_sparse.
Sparsity may provide a favorable inductive bias for convergence to higher accuracy for datasets where most features are redundant.
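To make Eq. 1, Eq. 2 and the entropy regularizer concrete, the sketch below implements one possible masking loop in PyTorch. It is a simplified stand-in for the paper’s attentive transformer: a single module is reused across steps, plain BN replaces ghost BN, and the feature transformer that would produce d[i] and the next a[i] is elided. The sparsemax projection follows [37]; all sizes are illustrative.

```python
import torch
import torch.nn as nn

def sparsemax(z):
    """Euclidean projection of each row of z onto the probability simplex [37]."""
    z_sorted, _ = torch.sort(z, dim=-1, descending=True)
    k = torch.arange(1, z.size(-1) + 1, device=z.device, dtype=z.dtype)
    z_cumsum = z_sorted.cumsum(dim=-1)
    k_z = (1 + k * z_sorted > z_cumsum).sum(dim=-1, keepdim=True)   # support size
    tau = (z_cumsum.gather(-1, k_z - 1) - 1) / k_z                  # threshold
    return torch.clamp(z - tau, min=0.0)

class AttentiveTransformer(nn.Module):
    """M[i] = sparsemax(P[i-1] * h_i(a[i-1])), with h_i an FC layer followed by BN (Eq. 1)."""
    def __init__(self, n_a, n_features):
        super().__init__()
        self.fc = nn.Linear(n_a, n_features, bias=False)
        self.bn = nn.BatchNorm1d(n_features)   # ghost BN in the paper; plain BN here

    def forward(self, a_prev, prior):
        return sparsemax(prior * self.bn(self.fc(a_prev)))

# Masking loop with the prior scale of Eq. 2 and the entropy regularizer L_sparse.
B, D, N_a, N_steps, gamma, eps = 32, 16, 8, 4, 1.5, 1e-5
att = AttentiveTransformer(N_a, D)
f = torch.randn(B, D)
a = torch.randn(B, N_a)                 # stand-in for the split output a[i-1] of the previous step
prior, l_sparse = torch.ones(B, D), 0.0
for i in range(N_steps):
    M = att(a, prior)                                  # feature selection mask M[i]
    prior = prior * (gamma - M)                        # P[i] = prod_{j<=i} (gamma - M[j])
    l_sparse = l_sparse + (-M * torch.log(M + eps)).sum() / (N_steps * B)
    masked_f = M * f                                   # passed to the feature transformer
    # ... the feature transformer would produce d[i] and the next a[i] ...
```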
Feature processing: We process the filtered features using a feature transformer (see Fig. 4) and then split for the decision step output and information for the subsequent step, [d[i], a[i]] = f_i(M[i] · f), where d[i] ∈ ℝ^{B×N_d} and a[i] ∈ ℝ^{B×N_a}. For parameter-efficient and robust learning with high capacity, a feature transformer should comprise layers that are shared across all decision steps (as the same features are input across different decision steps), as well as decision step-dependent layers. Fig. 4 shows the implementation as a concatenation of two shared layers and two decision step-dependent layers. Each FC layer is followed by BN and gated linear unit (GLU) nonlinearity [12]³, eventually connected to a normalized residual connection. Normalization with √0.5 helps to stabilize learning by ensuring that the variance throughout the network does not change dramatically [15]. For faster training, we aim for large batch sizes. To improve performance with large batch sizes, all BN operations, except the one applied to the input features, are implemented in ghost BN [24] form, with a virtual batch size B_V and momentum m_B. For the input features, we observe the benefit of low-variance averaging and hence avoid ghost BN. Finally, inspired by decision-tree-like aggregation as in Fig. 3, we construct the overall decision embedding as d_out = Σ_{i=1}^{N_steps} ReLU(d[i]). We apply a linear mapping W_final · d_out to get the output mapping. For discrete outputs, we additionally employ softmax during training (and argmax during inference).

³ In GLU, first a linear mapping is applied and the dimensionality is doubled, and then the second half of the output is used to determine nonlinear processing on the first half.
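A compact PyTorch sketch of this feature processing is given below: ghost BN over virtual batches, FC-BN-GLU blocks, √0.5-scaled residuals, shared plus step-dependent blocks, and the ReLU-sum decision aggregation. Module names, sizes and the momentum value are illustrative choices under stated assumptions, not the reference implementation.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class GhostBatchNorm(nn.Module):
    """Ghost BN [24]: batch normalization applied over virtual batches of size B_V."""
    def __init__(self, dim, virtual_batch_size=128, momentum=0.02):
        super().__init__()
        self.vbs = virtual_batch_size
        self.bn = nn.BatchNorm1d(dim, momentum=momentum)

    def forward(self, x):
        chunks = x.chunk(max(1, x.size(0) // self.vbs), dim=0)
        return torch.cat([self.bn(c) for c in chunks], dim=0)

class GLUBlock(nn.Module):
    """FC -> BN -> GLU: the FC doubles the width and the GLU gate halves it back (footnote 3)."""
    def __init__(self, d_in, d_out, virtual_batch_size=128):
        super().__init__()
        self.fc = nn.Linear(d_in, 2 * d_out)
        self.bn = GhostBatchNorm(2 * d_out, virtual_batch_size)

    def forward(self, x):
        return F.glu(self.bn(self.fc(x)), dim=-1)

class FeatureTransformer(nn.Module):
    """Shared GLU blocks followed by step-dependent ones, with sqrt(0.5)-scaled residuals."""
    def __init__(self, shared_blocks, d_model, n_specific=2):
        super().__init__()
        self.shared = shared_blocks          # nn.ModuleList reused by every decision step
        self.specific = nn.ModuleList([GLUBlock(d_model, d_model) for _ in range(n_specific)])

    def forward(self, x):
        x = self.shared[0](x)                # first block changes dimensionality, no residual
        for blk in list(self.shared[1:]) + list(self.specific):
            x = (x + blk(x)) * math.sqrt(0.5)
        return x

# Decision aggregation across steps: d_out = sum_i ReLU(d[i]), then a final linear map.
B, D, N_d, N_a, N_steps, n_classes = 256, 16, 8, 8, 4, 7
shared = nn.ModuleList([GLUBlock(D, N_d + N_a), GLUBlock(N_d + N_a, N_d + N_a)])
step_transformers = [FeatureTransformer(shared, N_d + N_a) for _ in range(N_steps)]
w_final = nn.Linear(N_d, n_classes)

masked_f = torch.randn(B, D)                 # stands in for M[i] * f at each step
d_out = torch.zeros(B, N_d)
for ft in step_transformers:
    out = ft(masked_f)
    d_i, a_i = out[:, :N_d], out[:, N_d:]    # split into decision output and attention input
    d_out = d_out + F.relu(d_i)
logits = w_final(d_out)                      # softmax over logits at training time
```

Sharing the first blocks across steps is what keeps the parameter count low, while the step-dependent blocks still give each decision step its own processing capacity.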
4 TABULAR SELF-SUPERVISED LEARNING

Decoding tabular features: We propose a decoder architecture to reconstruct tabular features from the encoded representations obtained from the TabNet encoder. The decoder is composed of feature transformer blocks, followed by FC layers, at each decision step. The outputs are summed to obtain the reconstructed features.⁴

⁴ We have also experimented with sequential decoding of features with an attentive transformer at each decision step, but did not observe significant benefits for the purpose of self-supervised learning and thus chose this simpler architecture.

Self-supervised objective: We propose the task of predicting missing feature columns from the others. Consider a binary mask S ∈ {0, 1}^{B×D}. The TabNet encoder inputs (1 − S) · f̂ and the TabNet decoder outputs the reconstructed features, S · f̂. We initialize P[0] = (1 − S) in the encoder so that the model emphasizes merely the known features, and the decoder’s last FC layer is multiplied with S to merely output the unknown features. We consider the following reconstruction loss to optimize in the self-supervised phase:

Σ_{b=1}^{B} Σ_{j=1}^{D} | (f̂_{b,j} − f_{b,j}) · S_{b,j} / √( Σ_{b=1}^{B} (f_{b,j} − (1/B) Σ_{b=1}^{B} f_{b,j})² ) |².

Normalization with the population standard deviation of the ground truth data is important, as the features may have very different ranges. We sample the entries S_{b,j} independently from a Bernoulli distribution with parameter p_s at each iteration.
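The masking and the normalized reconstruction loss can be sketched in a few lines of PyTorch. The `encoder`/`decoder` calls and the `prior0` argument in the commented step are placeholders for the TabNet modules, not an actual API.

```python
import torch

def sample_mask(f, p_s=0.8):
    """Draw S ~ Bernoulli(p_s); S[b, j] = 1 marks a feature that is hidden from the
    encoder and must be reconstructed by the decoder."""
    S = torch.bernoulli(torch.full_like(f, p_s))
    return S, (1 - S) * f          # encoder input keeps only the known entries

def reconstruction_loss(f_hat, f, S, eps=1e-9):
    """Reconstruction error restricted to masked entries, with each feature column
    normalized by the population standard deviation term of the equation above."""
    denom = torch.sqrt(((f - f.mean(dim=0)) ** 2).sum(dim=0)) + eps
    return ((((f_hat - f) * S) / denom) ** 2).sum()

# One hypothetical pre-training step (encoder/decoder stand for the TabNet modules):
# S, enc_in = sample_mask(f, p_s=0.8)
# f_hat = decoder(encoder(enc_in, prior0=(1 - S)))   # P[0] = 1 - S focuses on known features
# loss = reconstruction_loss(f_hat, f, S); loss.backward()
```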
5 EXPERIMENTS

We study TabNet on a wide range of problems that contain regression or classification tasks, particularly with published benchmarks. For all datasets, categorical inputs are mapped to a single-dimensional trainable scalar with a learnable embedding⁵ and numerical columns are input without any preprocessing.⁶ We use standard classification (softmax cross-entropy) and regression (mean squared error) loss functions and we train until convergence. Hyperparameters of the TabNet models are optimized on a validation set and listed in the Appendix. TabNet performance is not very sensitive to most hyperparameters, as shown with ablation studies in the Appendix. In all of the experiments where we cite results from other papers, we use the same training, validation and testing data split as the original work. The Adam optimization algorithm [32] and Glorot uniform initialization are used for training of all models. An open-source implementation can be found at https://github.com/google-research/google-research/tree/master/tabnet.

⁵ In some cases, higher dimensional embeddings may slightly improve the performance, but interpretation of individual dimensions may become challenging.

⁶ Specially-designed feature engineering, e.g. logarithmic transformation of variables with highly-skewed distributions, may further improve the results but we leave it out of the scope of this paper.

5.1 Instance-wise feature selection

Selection of the most salient features can be crucial for high performance, especially for small datasets. We consider the 6 synthetic tabular datasets from [6] (consisting of 10k training samples each). The synthetic datasets are constructed in such a way that only a subset of the features determine the output. For the Syn1, Syn2 and Syn3 datasets, the ‘salient’ features are the same for all instances, so that an accurate global feature selection mechanism should be optimal. E.g., the ground truth output of the Syn2 dataset only depends on features X3-X6. For the Syn4, Syn5 and Syn6 datasets, the salient features are instance dependent. E.g., for the Syn4 dataset, X11 is the indicator, and the ground truth output depends on either X1-X2 or X3-X6 depending on the value of X11. This instance dependence makes global feature selection suboptimal, as the globally-salient features would be redundant for some instances.

Table 1 shows the performance of the TabNet encoder vs. other techniques, including no selection, using only globally-salient features, Tree Ensembles [16], LASSO regularization, L2X [6] and INVASE [61]. We observe that TabNet outperforms all other methods and is on par with INVASE. For the Syn1, Syn2 and Syn3 datasets, we observe that the TabNet performance is very close to global feature selection. For the Syn4, Syn5 and Syn6 datasets, we observe that TabNet improves on global feature selection, which would contain redundant features. (Feature selection is visualized in Sec. 5.3.) All other methods utilize a predictive model with 43k parameters, and the total number of trainable parameters is 101k for INVASE due to the two other networks in the actor-critic framework. On the other hand, TabNet is a single DNN architecture, and its model size is 26k for the Syn1-Syn3 datasets and 31k for the Syn4-Syn6 datasets. This compact end-to-end representation is one of TabNet’s valuable properties.
Table 1: Mean and std. of test area under the receiver operating characteristic curve (AUC) on 6 synthetic datasets from [6], for TabNet vs. other feature selection-based DNN models: No selection: using all features without any feature selection, Global: using only globally-salient features, Tree: Tree Ensembles [16], LASSO: LASSO-regularized model, L2X [6] and INVASE [61]. Bold numbers are the best method for each dataset.

Model          Syn1         Syn2         Syn3         Syn4         Syn5         Syn6
No selection   .578 ± .004  .789 ± .003  .854 ± .004  .558 ± .021  .662 ± .013  .692 ± .015
Tree           .574 ± .101  .872 ± .003  .899 ± .001  .684 ± .017  .741 ± .004  .771 ± .031
Lasso          .498 ± .006  .555 ± .061  .886 ± .003  .512 ± .031  .691 ± .024  .727 ± .025
L2X            .498 ± .005  .823 ± .029  .862 ± .009  .678 ± .024  .709 ± .008  .827 ± .017
INVASE         .690 ± .006  .877 ± .003  .902 ± .003  .787 ± .004  .784 ± .005  .877 ± .003
Global         .686 ± .005  .873 ± .003  .900 ± .003  .774 ± .006  .784 ± .005  .858 ± .004
TabNet         .682 ± .005  .892 ± .004  .897 ± .003  .776 ± .017  .789 ± .009  .878 ± .004

5.2 Performance on real-world datasets

Forest Cover Type [14]: This dataset corresponds to the task of classification of forest cover type from cartographic variables. Table 2 shows that TabNet significantly outperforms ensemble tree based approaches that are known to achieve solid performance on this task [38]. In addition, we consider AutoInt [51] for this task given its strength for problems with high feature dimensionality. AutoInt models pairwise feature interactions with an attention-based DNN [51] and significantly underperforms TabNet, which employs instance-wise feature selection and considers the interaction between different features if the model infers that it is the appropriate processing to apply. Lastly, we consider AutoML Tables [2], an automated search framework based on an ensemble of models including linear feed-forward DNN, gradient boosted decision tree, AdaNet [10] and ensembles [2]. For AutoML Tables, the amount of node hours reflects the measure of the count of searched models for the ensemble and their complexity.⁷ A single TabNet model without fine-grained hyperparameter search outperforms the accuracy of ensemble models with very thorough hyperparameter search.

⁷ 10 node hours is well above the suggested exploration time [2] for this dataset.

Table 2: Performance for Forest Cover Type dataset.

Model                            Test accuracy (%)
XGBoost                          89.34
LightGBM                         89.28
CatBoost                         85.14
AutoInt                          90.24
AutoML Tables (2 node hours)     94.56
AutoML Tables (10 node hours)    96.67
AutoML Tables (30 node hours)    96.93
TabNet                           96.99

Poker Hand [14]: This dataset corresponds to the task of classification of the poker hand from the raw suit and rank attributes of the cards. The input-output relationship is deterministic and hand-crafted rules implemented with several lines of code can get 100% accuracy. Yet, conventional DNNs, decision trees, and even their hybrid variant of deep neural decision tree models [60] severely suffer from the imbalanced data and cannot learn the required sorting and ranking operations with the raw input features [60]. Tuned XGBoost, CatBoost, and LightGBM show very slight improvements over them. On the other hand, TabNet significantly outperforms the other methods and approaches the deterministic rule accuracy, as it can perform highly-nonlinear processing with high depth, without overfitting thanks to instance-wise feature selection.

Table 3: Performance for Poker Hand induction dataset.

Model                       Test accuracy (%)
Decision tree               50.0
MLP                         50.0
Deep neural decision tree   65.1
XGBoost                     71.1
LightGBM                    70.0
CatBoost                    66.6
TabNet                      99.2
Rule-based                  100.0

Sarcos Robotics Arm Inverse Dynamics [57]: This dataset corresponds to the task of regression of the inverse dynamics of seven degrees-of-freedom of an anthropomorphic robot arm. [53] shows that decent performance with a very small model is possible with a random forest, but the best performance is achieved with their adaptive neural tree, which slightly outperforms the gradient boosted tree. In the very small model size regime, TabNet’s performance is on par with the best proposed model with 100x more parameters. TabNet allocates its capacity to salient features, and yields a more compact model. When the model size is not constrained, TabNet achieves almost an order of magnitude lower test MSE.

Table 4: Performance for Sarcos Robotics Arm Inverse Dynamics dataset. Three TabNet models of different sizes are considered (denoted with -S, -M and -L).

Model                      Test MSE   Model size
Random forest              2.39       16.7K
Stochastic decision tree   2.11       28K
MLP                        2.13       0.14M
Adaptive neural tree       1.23       0.60M
Gradient boosted tree      1.44       0.99M
TabNet-S                   1.25       6.3K
TabNet-M                   0.28       0.59M
TabNet-L                   0.14       1.75M
Higgs Boson [14]: This dataset corresponds to the task of distinguishing between a signal process which produces Higgs bosons and a background process. Due to its much larger size (10.5M training examples), DNNs outperform decision tree variants on this task even with very large ensembles. We show that TabNet outperforms MLPs with more compact representations. We also compare to the state-of-the-art evolutionary sparsification algorithm [39] that applies non-structured sparsity integrated into training, yielding a low number of parameters. With its compact representation, TabNet yields almost similar performance to sparse evolutionary training for the same number of parameters. The sparsity learned by TabNet is structured differently from alternative approaches – it does not degrade the operational intensity of the model [59] and can efficiently utilize modern multi-core processors.⁸

⁸ Matrix sparsification techniques such as adaptive pruning [41] in TabNet could further improve the parameter-efficiency.

Table 5: Performance on Higgs Boson dataset. Two TabNet models are considered (denoted with -S and -M).

Model                     Test accuracy (%)   Model size
Sparse evolutionary MLP   78.47               81K
Gradient boosted tree-S   74.22               0.12M
Gradient boosted tree-M   75.97               0.69M
MLP                       78.44               2.04M
Gradient boosted tree-L   76.98               6.96M
TabNet-S                  78.25               81K
TabNet-M                  78.84               0.66M

Rossmann Store Sales [29]: This dataset corresponds to the task of forecasting the store sales from static and time-varying features. We observe that TabNet outperforms XGBoost, LightGBM and CatBoost, which are commonly used for such problems. The time features (e.g. day) obtain high importance, and the benefit of instance-wise feature selection is particularly observed for cases like holidays, where the sales dynamics are different.

Table 6: Performance for Rossmann Store Sales dataset.

Model      Test MSE
XGBoost    490.83
LightGBM   504.76
CatBoost   489.75
TabNet     485.12

KDD datasets: The Appetency, Churn and Upselling datasets are classification tasks for customer relationship management, and the KDD Census Income [14] dataset is for income prediction from demographic and employment related variables. These datasets show saturated behavior in performance (even simple models yield similar results). Table 7 shows that TabNet achieves very similar or slightly worse performance than XGBoost and CatBoost, which are known to be robust as they contain a high amount of ensembling.

Table 7: Performance on KDD datasets (test accuracy, %).

Model      Appetency   Churn   Upselling   Census
XGBoost    98.2        92.7    95.1        95.8
CatBoost   98.2        92.8    95.1        95.7
TabNet     98.2        92.7    95.0        95.5

5.3 Interpretability

The feature selection masks in TabNet enable understanding of the features selected at each step. Such a capability is not available for conventional DNNs such as MLPs, as each subsequent layer jointly processes all features without a sparsity-controlled selection mechanism. For the feature selection masks, if M_{b,j}[i] = 0, then the j-th feature of the b-th sample should have no contribution to the decision. If f_i were a linear function, the coefficient M_{b,j}[i] would correspond to the feature importance of f_{b,j}. Although each decision step employs non-linear processing, their outputs are combined later in a linear way. Our goal is to quantify an aggregate feature importance in addition to the analysis of each step. Combining the masks at different steps requires a coefficient that can weigh the relative importance of each step in the decision. We use η_b[i] = Σ_{c=1}^{N_d} ReLU(d_{b,c}[i]) to denote the aggregate decision contribution at the i-th decision step for the b-th sample. Intuitively, if d_{b,c}[i] < 0, then all features at the i-th decision step should have 0 contribution to the overall decision. As its value increases, it plays a higher role in the overall linear combination. Scaling the decision mask at each decision step with η_b[i], we propose the aggregate feature importance mask,

M_agg-b,j = Σ_{i=1}^{N_steps} η_b[i] M_{b,j}[i] / Σ_{j=1}^{D} Σ_{i=1}^{N_steps} η_b[i] M_{b,j}[i].

Normalization is used to ensure Σ_{j=1}^{D} M_agg-b,j = 1.
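A small PyTorch sketch of this aggregation (tensor shapes and variable names are ours):

```python
import torch
import torch.nn.functional as F

def aggregate_importance(d_steps, M_steps, eps=1e-9):
    """Combine per-step masks into the aggregate importance M_agg.

    d_steps: list of (B, N_d) decision outputs d[i]; M_steps: list of (B, D) masks M[i].
    eta_b[i] = sum_c ReLU(d[i])_{b,c} weighs each step before the masks are summed,
    and each row is normalized to sum to 1.
    """
    weighted = 0.0
    for d_i, M_i in zip(d_steps, M_steps):
        eta = F.relu(d_i).sum(dim=1, keepdim=True)   # (B, 1) per-step contribution
        weighted = weighted + eta * M_i
    return weighted / (weighted.sum(dim=1, keepdim=True) + eps)

# Example with random stand-ins for the encoder outputs:
d_steps = [torch.randn(5, 8) for _ in range(4)]
M_steps = [torch.rand(5, 11) for _ in range(4)]
M_agg = aggregate_importance(d_steps, M_steps)       # rows sum to ~1
```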
Synthetic datasets: Fig. 5 shows the aggregate feature importance masks for the synthetic datasets discussed in Sec. 5.1.⁹ The ground truth output of the Syn2 dataset only depends on features X3-X6. We observe that the aggregate masks are almost all zero for irrelevant features and they merely focus on relevant ones. For the Syn4 dataset, X11 is the indicator, and the ground truth output depends on either X1-X2 or X3-X6 depending on the value of X11. TabNet yields accurate instance-wise feature selection – it allocates a mask to focus on X11, and assigns almost all-zero weights to irrelevant features (the ones other than one of the two feature groups).

⁹ For better illustration here, unlike Sec. 5.1, the models are trained with 10M training samples rather than 10k, as we obtain sharper feature selection masks.

Table 8: Importance ranking of features for Adult Census Income. TabNet yields feature importance rankings consistent with the well-known methods.

Feature          SHAP   Skater   XGBoost   TabNet
Age              1      1        1         1
Capital gain     3      3        4         6
Capital loss     9      9        6         4
Education        5      2        3         2
Gender           8      10       12        8
Hours per week   7      7        2         7
Marital status   2      8        10        9
Native country   11     11       9         12
Occupation       6      5        5         3
Race             12     12       11        11
Relationship     4      4        8         5
Work class       10     8        7         10
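If the per-step masks and M_agg are available as arrays (e.g. from the aggregation sketch above), heatmaps of the kind shown in Fig. 5 can be produced with matplotlib. This is an illustrative plotting helper, not the paper's plotting code.

```python
import numpy as np
import matplotlib.pyplot as plt

def plot_masks(M_steps, M_agg):
    """Heatmaps of per-step masks M[i] and the aggregate mask M_agg (brighter = higher),
    in the spirit of Fig. 5. Inputs are (num_samples, D) NumPy arrays."""
    panels = [("M_agg", M_agg)] + [(f"M[{i + 1}]", M) for i, M in enumerate(M_steps)]
    fig, axes = plt.subplots(1, len(panels), figsize=(3 * len(panels), 3))
    for ax, (title, M) in zip(axes, panels):
        ax.imshow(M, aspect="auto", interpolation="nearest")
        ax.set_title(title)
        ax.set_xlabel("feature index")
    axes[0].set_ylabel("test samples")
    plt.tight_layout()
    plt.show()

# Example with random placeholders (real masks come from the trained TabNet encoder):
# plot_masks([np.random.rand(50, 11) for _ in range(4)], np.random.rand(50, 11))
```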
Figure 5: Feature importance masks M[i] (that indicate which features are selected at the i-th step) and the aggregate feature importance mask M_agg showing the global instance-wise feature selection, for the Syn2 and Syn6 datasets from [6]. Brighter colors show a higher value. E.g., for the Syn2 dataset, only four features (X3-X6) are used.
Real-world datasets: We first consider the simple real-world task of mushroom edibility prediction [14]. TabNet achieves 100% test accuracy on this dataset. It is indeed known [14] that “Odor” is the most discriminative feature for this task: with the “Odor” feature only, a model can get > 98.5% test accuracy [14]. Thus, a high feature importance is expected for it. TabNet assigns an importance score ratio of 43% for it, while other notable methods like LIME [48], Integrated Gradients [52] and DeepLift [49] assign importance score ratios of less than 30% [27].

Next, we consider Adult Census Income, where the task is to distinguish whether a person’s income is above $50,000. Table 8 shows the importance ranking of features for TabNet vs. other explainability techniques from [36], [42]. We observe the commonality of the most important features (“Age”, “Capital gain/loss”, “Education number”, “Relationship”) and the least important features (“Native country”, “Race”, “Gender”, “Work class”). For the same problem, Fig. 6 shows the impact of the most important feature on the output decision by visualizing the T-SNE of the decision manifold. A clear separation between age groups is observed, as suggested by “Age” being the most important feature by TabNet.

Figure 6: T-SNE of the decision manifold for Adult Census Income test samples and the impact of the top feature ‘Age’.

5.4 Self-supervised learning

We study self-supervised learning on the Higgs and Forest Cover Type datasets. For the pre-training task of guessing the missing columns, we use the masking parameter p_s = 0.8 and train for 1M iterations. We use a subset of the labeled dataset for supervised fine-tuning, with a validation set to determine the number of iterations for early stopping. A large validation dataset would be unrealistic for small training datasets, so in these experiments we assume its size is equal to the training dataset. Table 9 shows that unsupervised pre-training significantly improves performance on the supervised classification task, especially in the regime where the unlabeled dataset is much larger than the labeled dataset. As exemplified in Fig. 7, the model convergence is much faster with unsupervised pre-training. Very fast convergence can be highly beneficial, particularly in scenarios like continual learning and domain adaptation.

Figure 7: Convergence with unsupervised pre-training is much faster, shown for the Higgs dataset with 10k samples.

Table 9: Self-supervised tabular learning results. Mean and std. of accuracy (over 15 runs) on Higgs with the TabNet-M model, varying the size of the training dataset for supervised fine-tuning.

Training dataset size   Supervised     With pre-training
1k                      57.47 ± 1.78   61.37 ± 0.88
10k                     66.66 ± 0.88   68.06 ± 0.39
100k                    72.92 ± 0.21   73.19 ± 0.15

Table 10: Self-supervised tabular learning results. Mean and std. of accuracy (over 15 runs) on Forest Cover Type, varying the size of the training dataset for supervised fine-tuning.

Training dataset size   Supervised     With pre-training
1k                      65.91 ± 1.02   67.86 ± 0.63
10k                     78.85 ± 1.24   79.22 ± 0.78

6 CONCLUSION

We have proposed TabNet, a novel deep learning architecture for tabular learning. TabNet uses a sequential attention mechanism to choose a subset of semantically meaningful features to process at each decision step. Instance-wise feature selection enables efficient learning as the model capacity is fully used for the most salient features, and also yields more interpretable decision making via visualization of selection masks. We demonstrate that TabNet outperforms previous work across tabular datasets from different domains. Lastly, we demonstrate significant benefits of unsupervised pre-training for fast adaptation and improved performance.
REFERENCES
[1] Dario Amodei, Rishita Anubhai, Eric Battenberg, Carl Case, Jared Casper, et al. 2015. Deep Speech 2: End-to-End Speech Recognition in English and Mandarin. arXiv:1512.02595 (2015).
[2] AutoML. 2019. AutoML Tables – Google Cloud. https://cloud.google.com/automl-tables/
[3] J. Bao, D. Tang, N. Duan, Z. Yan, M. Zhou, and T. Zhao. 2019. Text Generation From Tables. IEEE Trans Audio, Speech, and Language Processing 27, 2 (Feb 2019), 311–320.
[4] Yael Ben-Haim and Elad Tom-Tov. 2010. A Streaming Parallel Decision Tree Algorithm. JMLR 11 (March 2010), 849–872.
[5] Catboost. 2019. Benchmarks. https://github.com/catboost/benchmarks. Accessed: 2019-11-10.
[6] Jianbo Chen, Le Song, Martin J. Wainwright, and Michael I. Jordan. 2018. Learning to Explain: An Information-Theoretic Perspective on Model Interpretation. arXiv:1802.07814 (2018).
[7] Tianqi Chen and Carlos Guestrin. 2016. XGBoost: A Scalable Tree Boosting System. In KDD.
[8] Michael Chui, James Manyika, Mehdi Miremadi, Nicolaus Henke, Rita Chung, et al. 2018. Notes from the AI Frontier. McKinsey Global Institute (4 2018).
[9] Alexis Conneau, Holger Schwenk, Loïc Barrault, and Yann LeCun. 2016. Very Deep Convolutional Networks for Natural Language Processing. arXiv:1606.01781 (2016).
[10] Corinna Cortes, Xavi Gonzalvo, Vitaly Kuznetsov, Mehryar Mohri, and Scott Yang. 2016. AdaNet: Adaptive Structural Learning of Artificial Neural Networks. arXiv:1607.01097 (2016).
[11] Zihang Dai, Zhilin Yang, Fan Yang, William W. Cohen, and Ruslan Salakhutdinov. 2017. Good Semi-supervised Learning that Requires a Bad GAN. arXiv:1705.09783 (2017).
[12] Yann N. Dauphin, Angela Fan, Michael Auli, and David Grangier. 2016. Language Modeling with Gated Convolutional Networks. arXiv:1612.08083 (2016).
[13] Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. 2018. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. arXiv:1810.04805 (2018).
[14] Dheeru Dua and Casey Graff. 2017. UCI Machine Learning Repository. http://archive.ics.uci.edu/ml
[15] Jonas Gehring, Michael Auli, David Grangier, Denis Yarats, and Yann N. Dauphin. 2017. Convolutional Sequence to Sequence Learning. arXiv:1705.03122 (2017).
[16] Pierre Geurts, Damien Ernst, and Louis Wehenkel. 2006. Extremely randomized trees. Machine Learning 63, 1 (01 Apr 2006), 3–42.
[17] Ian Goodfellow, Yoshua Bengio, and Aaron Courville. 2016. Deep Learning. MIT Press.
[18] K. Grabczewski and N. Jankowski. 2005. Feature selection with decision tree criterion. In HIS.
[19] Yves Grandvalet and Yoshua Bengio. 2004. Semi-supervised Learning by Entropy Minimization. In NIPS.
[20] Isabelle Guyon and André Elisseeff. 2003. An Introduction to Variable and Feature Selection. JMLR 3 (March 2003), 1157–1182.
[21] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. 2015. Deep Residual Learning for Image Recognition. arXiv:1512.03385 (2015).
[22] Joel Hestness, Sharan Narang, Newsha Ardalani, Gregory F. Diamos, Heewoo Jun, Hassan Kianinejad, Md. Mostofa Ali Patwary, Yang Yang, and Yanqi Zhou. 2017. Deep Learning Scaling is Predictable, Empirically. arXiv:1712.00409 (2017).
[23] Tin Kam Ho. 1998. The random subspace method for constructing decision forests. PAMI 20, 8 (Aug 1998), 832–844.
[24] Elad Hoffer, Itay Hubara, and Daniel Soudry. 2017. Train longer, generalize better: closing the generalization gap in large batch training of neural networks. arXiv:1705.08741 (2017).
[25] Drew A. Hudson and Christopher D. Manning. 2018. Compositional Attention Networks for Machine Reasoning. arXiv:1803.03067 (2018).
[26] K. D. Humbird, J. L. Peterson, and R. G. McClarren. 2018. Deep Neural Network Initialization With Decision Trees. IEEE Trans Neural Networks and Learning Systems (2018).
[27] Mark Ibrahim, Melissa Louie, Ceena Modarres, and John W. Paisley. 2019. Global Explanations of Neural Networks: Mapping the Landscape of Predictions. arXiv:1902.02384 (2019).
[28] Kaggle. 2019. Historical Data Science Trends on Kaggle. https://www.kaggle.com/shivamb/data-science-trends-on-kaggle. Accessed: 2019-04-20.
[29] Kaggle. 2019. Rossmann Store Sales. https://www.kaggle.com/c/rossmann-store-sales. Accessed: 2019-11-10.
[30] Guolin Ke, Qi Meng, Thomas Finley, Taifeng Wang, Wei Chen, et al. 2017. LightGBM: A Highly Efficient Gradient Boosting Decision Tree. In NIPS.
[31] Guolin Ke, Jia Zhang, Zhenhui Xu, Jiang Bian, and Tie-Yan Liu. 2019. TabNN: A Universal Neural Network Solution for Tabular Data. https://openreview.net/forum?id=r1eJssCqY7
[32] Diederik P. Kingma and Jimmy Ba. 2014. Adam: A Method for Stochastic Optimization. In ICLR.
[33] P. Kontschieder, M. Fiterau, A. Criminisi, and S. Rota Bulò. 2015. Deep Neural Decision Forests. In ICCV.
[34] Siwei Lai, Liheng Xu, Kang Liu, and Jun Zhao. 2015. Recurrent Convolutional Neural Networks for Text Classification. In AAAI.
[35] Tianyu Liu, Kexiang Wang, Lei Sha, Baobao Chang, and Zhifang Sui. 2017. Table-to-text Generation by Structure-aware Seq2seq Learning. arXiv:1711.09724 (2017).
[36] Scott M. Lundberg, Gabriel G. Erion, and Su-In Lee. 2018. Consistent Individualized Feature Attribution for Tree Ensembles. arXiv:1802.03888 (2018).
[37] André F. T. Martins and Ramón Fernández Astudillo. 2016. From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification. arXiv:1602.02068 (2016).
[38] Rory Mitchell, Andrey Adinets, Thejaswi Rao, and Eibe Frank. 2018. XGBoost: Scalable GPU Accelerated Learning. arXiv:1806.11248 (2018).
[39] Decebal Mocanu, Elena Mocanu, Peter Stone, Phuong Nguyen, Madeleine Gibescu, and Antonio Liotta. 2018. Scalable training of artificial neural networks with adaptive sparse connectivity inspired by network science. Nature Communications 9 (12 2018).
[40] Alex Mott, Daniel Zoran, Mike Chrzanowski, Daan Wierstra, and Danilo J. Rezende. 2019. S3TA: A Soft, Spatial, Sequential, Top-Down Attention Model. https://openreview.net/forum?id=B1gJOoRcYQ
[41] Sharan Narang, Gregory F. Diamos, Shubho Sengupta, and Erich Elsen. 2017. Exploring Sparsity in Recurrent Neural Networks. arXiv:1704.05119 (2017).
[42] Nbviewer. 2019. Notebook on Nbviewer. https://nbviewer.jupyter.org/github/dipanjanS/data_science_for_all/blob/master/tds_model_interpretation_xai/Human-interpretableMachineLearning-DS.ipynb#
[43] N. C. Oza. 2005. Online bagging and boosting. In IEEE Conference on Systems, Man and Cybernetics.
[44] German Ignacio Parisi, Ronald Kemker, Jose L. Part, Christopher Kanan, and Stefan Wermter. 2018. Continual Lifelong Learning with Neural Networks: A Review. arXiv:1802.07569 (2018).
[45] Liudmila Prokhorenkova, Gleb Gusev, Aleksandr Vorobev, Anna Veronika Dorogush, and Andrey Gulin. 2018. CatBoost: unbiased boosting with categorical features. In NIPS.
[46] Alec Radford, Luke Metz, and Soumith Chintala. 2015. Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks. arXiv:1511.06434 (2015).
[47] Rajat Raina, Alexis Battle, Honglak Lee, Benjamin Packer, and Andrew Y. Ng. 2007. Self-Taught Learning: Transfer Learning from Unlabeled Data. In ICML.
[48] Marco Ribeiro, Sameer Singh, and Carlos Guestrin. 2016. “Why Should I Trust You?”: Explaining the Predictions of Any Classifier. In KDD.
[49] Avanti Shrikumar, Peyton Greenside, and Anshul Kundaje. 2017. Learning Important Features Through Propagating Activation Differences. arXiv:1704.02685 (2017).
[50] Karen Simonyan and Andrew Zisserman. 2014. Very Deep Convolutional Networks for Large-Scale Image Recognition. arXiv:1409.1556 (2014).
[51] Weiping Song, Chence Shi, Zhiping Xiao, Zhijian Duan, Yewen Xu, Ming Zhang, and Jian Tang. 2018. AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks. arXiv:1810.11921 (2018).
[52] Mukund Sundararajan, Ankur Taly, and Qiqi Yan. 2017. Axiomatic Attribution for Deep Networks. arXiv:1703.01365 (2017).
[53] Ryutaro Tanno, Kai Arulkumaran, Daniel C. Alexander, Antonio Criminisi, and Aditya V. Nori. 2018. Adaptive Neural Trees. arXiv:1807.06699 (2018).
[54] Tensorflow. 2019. Classifying Higgs boson processes in the HIGGS Data Set. https://github.com/tensorflow/models/tree/master/official/boosted_trees
[55] Trieu H. Trinh, Minh-Thang Luong, and Quoc V. Le. 2019. Selfie: Self-supervised Pretraining for Image Embedding. arXiv:1906.02940 (2019).
[56] Aäron van den Oord, Sander Dieleman, Heiga Zen, Karen Simonyan, Oriol Vinyals, et al. 2016. WaveNet: A Generative Model for Raw Audio. arXiv:1609.03499 (2016).
[57] Sethu Vijayakumar and Stefan Schaal. 2000. Locally Weighted Projection Regression: An O(n) Algorithm for Incremental Real Time Learning in High Dimensional Space. In ICML.
[58] Suhang Wang, Charu Aggarwal, and Huan Liu. 2017. Using a random forest to inspire a neural network and improving on it. In SDM.
[59] Wei Wen, Chunpeng Wu, Yandan Wang, Yiran Chen, and Hai Li. 2016. Learning Structured Sparsity in Deep Neural Networks. arXiv:1608.03665 (2016).
[60] Yongxin Yang, Irene Garcia Morillo, and Timothy M. Hospedales. 2018. Deep Neural Decision Trees. arXiv:1806.06988 (2018).
[61] Jinsung Yoon, James Jordon, and Mihaela van der Schaar. 2019. INVASE: Instance-wise Variable Selection using Neural Networks. In ICLR.

APPENDIX

A EXPERIMENT HYPERPARAMETERS

For all datasets, we use a pre-defined hyperparameter search space. N_d and N_a are chosen from {8, 16, 24, 32, 64, 128}, N_steps is chosen from {3, 4, 5, 6, 7, 8, 9, 10}, γ is chosen from {1.0, 1.2, 1.5, 2.0}, λ_sparse is chosen from {0, 0.000001, 0.0001, 0.001, 0.01, 0.1}, B is chosen from {256, 512, 1024, 2048, 4096, 8192, 16384, 32768}, B_V is chosen from {256, 512, 1024, 2048, 4096} and m_B is chosen from {0.6, 0.7, 0.8, 0.9, 0.95, 0.98}. If the model size is not under the desired cutoff, we decrease the value to satisfy the size constraint. For all the comparison models, we run a hyperparameter tuning with the same number of search steps.

Synthetic: All TabNet models use N_d = N_a = 16, B = 3000, B_V = 100, m_B = 0.7. For Syn1 we use λ_sparse = 0.02, N_steps = 4 and γ = 2.0; for Syn2 and Syn3 we use λ_sparse = 0.01, N_steps = 4 and γ = 2.0; and for Syn4, Syn5 and Syn6 we use λ_sparse = 0.005, N_steps = 5 and γ = 1.5. Feature transformers use two shared and two decision step-dependent FC layer, ghost BN and GLU blocks. All models use Adam with a learning rate of 0.02 (decayed 0.7 every 200 iterations with an exponential decay) for 4k iterations. For the visualizations in Sec. 5.3, we also train TabNet models with datasets of size 10M samples. For this case, we choose N_d = N_a = 32, λ_sparse = 0.001, B = 10000, B_V = 100, m_B = 0.9. Adam is used with a learning rate of 0.02 (decayed 0.9 every 2k iterations with an exponential decay) for 15k iterations. For Syn2 and Syn3, N_steps = 4 and γ = 2. For Syn4 and Syn6, N_steps = 5 and γ = 1.5.

Forest Cover Type: The dataset partition details and the hyperparameters of XGBoost, LightGBM, and CatBoost are from [38]. We re-optimize the AutoInt hyperparameters. The TabNet model uses N_d = N_a = 64, λ_sparse = 0.0001, B = 16384, B_V = 512, m_B = 0.7, N_steps = 5 and γ = 1.5. Feature transformers use two shared and two decision step-dependent FC layer, ghost BN and GLU blocks. Adam is used with a learning rate of 0.02 (decayed 0.95 every 0.5k iterations with an exponential decay) for 130k iterations. For unsupervised pre-training, the decoder model uses N_d = N_a = 64, B = 16384, B_V = 512, m_B = 0.7, and N_steps = 10. For supervised fine-tuning, we use the batch size B = B_V as the training datasets are small.

Poker Hands: We split 6k samples for validation from the training dataset, and after optimization of the hyperparameters, we retrain with the entire training dataset. Decision tree, MLP and deep neural decision tree models follow the same hyperparameters as [60]. We tune the hyperparameters of XGBoost, LightGBM, and CatBoost. TabNet uses N_d = N_a = 16, λ_sparse = 0.000001, B = 4096, B_V = 1024, m_B = 0.95, N_steps = 4 and γ = 1.5. Feature transformers use two shared and two decision step-dependent FC layer, ghost BN and GLU blocks. Adam is used with a learning rate of 0.01 (decayed 0.95 every 500 iterations with an exponential decay) for 50k iterations.
Table 11: Ablation studies for the TabNet encoder model for the Forest Cover Type dataset.

Ablation case | Test accuracy % (difference) | Model size
Base (N_d = N_a = 64, γ = 1.5, N_steps = 5, λ_sparse = 0.0001, feature transformer block composed of two shared and two decision step-dependent layers, B = 16384) | 96.99 | 470k
Decreasing capacity via number of units (with N_d = N_a = 32) | 94.99 (-2.00) | 129k
Decreasing capacity via number of decision steps (with N_steps = 3) | 96.22 (-0.77) | 328k
Increasing capacity via number of decision steps (with N_steps = 9) | 95.48 (-1.51) | 755k
Decreasing capacity via all-shared feature transformer blocks | 96.74 (-0.25) | 143k
Increasing capacity via decision step-dependent feature transformer blocks | 96.76 (-0.23) | 703k
Feature transformer block as a single shared layer | 95.32 (-1.67) | 35k
Feature transformer block as a single shared layer, with ReLU instead of GLU | 93.92 (-3.07) | 27k
Feature transformer block as two shared layers | 96.34 (-0.66) | 71k
Feature transformer block as two shared layers and 1 decision step-dependent layer | 96.54 (-0.45) | 271k
Feature transformer block as a single decision step-dependent layer | 94.71 (-2.28) | 105k
Feature transformer block as a single decision step-dependent layer, with N_d = N_a = 128 | 96.24 (-0.75) | 208k
Feature transformer block as a single decision step-dependent layer, with N_d = N_a = 128 and replacing GLU with ReLU | 95.67 (-1.32) | 139k
Feature transformer block as a single decision step-dependent layer, with N_d = N_a = 256 and replacing GLU with ReLU | 96.41 (-0.58) | 278k
Reducing the impact of prior scale (with γ = 3.0) | 96.49 (-0.50) | 470k
Increasing the impact of prior scale (with γ = 1.0) | 96.67 (-0.32) | 470k
No sparsity regularization (with λ_sparse = 0) | 96.50 (-0.49) | 470k
High sparsity regularization (with λ_sparse = 0.01) | 93.87 (-3.12) | 470k
Small batch size (B = 4096) | 96.42 (-0.57) | 470k

Sarcos: We split 4.5k samples for validation from the training dataset, and after optimization of the hyperparameters, we retrain with the entire training dataset. All comparison models follow the hyperparameters from [53]. TabNet-S model uses N_d = N_a = 8, λ_sparse = 0.0001, B = 4096, B_V = 256, m_B = 0.9, N_steps = 3 and γ = 1.2. Each feature transformer block uses one shared and two decision step-dependent FC layer, ghost BN and GLU blocks. Adam is used with a learning rate of 0.01 (decayed 0.95 every 8k iterations with an exponential decay) for 600k iterations. TabNet-M model uses N_d = N_a = 64, λ_sparse = 0.0001, B = 4096, B_V = 128, m_B = 0.8, N_steps = 7 and γ = 1.5. Feature transformers use two shared and two decision step-dependent FC layer, ghost BN and GLU blocks. Adam is used with a learning rate of 0.01 (decayed 0.95 every 8k iterations with an exponential decay) for 600k iterations. The TabNet-L model uses N_d = N_a = 128, λ_sparse = 0.0001, B = 4096, B_V = 128, m_B = 0.8, N_steps = 5 and γ = 1.5. Feature transformers use two shared and two decision step-dependent FC layer, ghost BN and GLU blocks. Adam is used with a learning rate of 0.02 (decayed 0.9 every 8k iterations with an exponential decay) for 600k iterations.

Higgs: We split 500k samples for validation from the training dataset, and after optimization of the hyperparameters, we retrain with the entire training dataset. MLP models are from [39]. For gradient boosted trees [54], we tune the learning rate and depth – the gradient boosted tree-S, -M, and -L models use 50, 300 and 3000 trees respectively. TabNet-S model uses N_d = 24, N_a = 26, λ_sparse = 0.000001, B = 16384, B_V = 512, m_B = 0.6, N_steps = 5 and γ = 1.5. Feature transformers use two shared and two decision step-dependent FC layer, ghost BN and GLU blocks. Adam is used with a learning rate of 0.02 (decayed 0.9 every 20k iterations with an exponential decay) for 870k iterations. TabNet-M model uses N_d = 96, N_a = 32, λ_sparse = 0.000001, B = 8192, B_V = 256, m_B = 0.9, N_steps = 8 and γ = 2.0. Feature transformers use two shared and two decision step-dependent FC layer, ghost BN and GLU blocks. Adam is used with a learning rate of 0.025 (decayed 0.9 every 10k iterations with an exponential decay) for 370k iterations. For unsupervised pre-training, the decoder model uses N_d = N_a = 128, B = 8192, B_V = 256, m_B = 0.9, and N_steps = 20. For supervised fine-tuning, we use the batch size B = B_V as the training datasets are small.

Rossmann: We use the same preprocessing and data split as [5] – data from 2014 is used for training and validation, whereas 2015 is used for testing. We split 100k samples for validation from the training dataset, and after optimization of the hyperparameters, we retrain with the entire training dataset. The performance of the comparison models is from [5]. TabNet model uses N_d = N_a = 32, λ_sparse = 0.001, B = 4096, B_V = 512, m_B = 0.8, N_steps = 5 and γ = 1.2. Feature transformers use two shared and two decision step-dependent FC layer, ghost BN and GLU blocks. Adam is used with a learning rate of 0.002 (decayed 0.95 every 2000 iterations with an exponential decay) for 15k iterations.

KDD: For the Appetency, Churn and Upselling datasets, we apply similar preprocessing and splits as [45]. The performance of the comparison models is from [45]. TabNet models use N_d = N_a = 32, λ_sparse = 0.001, B = 8192, B_V = 256, m_B = 0.9, N_steps = 7 and γ = 1.2. Each feature transformer block uses two shared and two decision step-dependent FC layer, ghost BN and GLU blocks. Adam is used with a learning rate of 0.01 (decayed 0.9 every 1000 iterations with an exponential decay) for 10k iterations. For Census Income, the dataset and comparison model specifications follow [43]. TabNet model uses N_d = N_a = 48, λ_sparse = 0.001, B = 8192, B_V = 256, m_B = 0.9, N_steps = 5 and γ = 1.5. Feature transformers use two shared and two decision step-dependent FC layer, ghost BN and GLU blocks. Adam is used with a learning rate of 0.02 (decayed 0.7 every 2000 iterations with an exponential decay) for 4k iterations.
Mushroom edibility: TabNet model uses N_d = N_a = 8, λ_sparse = 0.001, B = 2048, B_V = 128, m_B = 0.9, N_steps = 3 and γ = 1.5. Feature transformers use two shared and two decision step-dependent FC layer, ghost BN and GLU blocks. Adam is used with a learning rate of 0.01 (decayed 0.8 every 400 iterations with an exponential decay) for 10k iterations.

Adult Census Income: TabNet model uses N_d = N_a = 16, λ_sparse = 0.0001, B = 4096, B_V = 128, m_B = 0.98, N_steps = 5 and γ = 1.5. Feature transformers use two shared and two decision step-dependent layer, ghost BN and GLU blocks. Adam is used with a learning rate of 0.02 (decayed 0.4 every 2.5k iterations with an exponential decay) for 7.7k iterations. 85.7% test accuracy is achieved.

B ABLATION STUDIES

Table 11 shows the impact of the ablation cases. For all cases, the number of iterations is optimized on the validation set.

Obtaining high performance necessitates appropriately-adjusted model capacity based on the characteristics of the dataset. Decreasing the number of units N_d, N_a or the number of decision steps N_steps is an efficient way of gradually decreasing the capacity without significant degradation in performance. On the other hand, increasing these parameters beyond some value causes optimization issues and does not yield performance benefits. Replacing the feature transformer block with a simpler alternative, such as a single shared layer, can still give strong performance while yielding a very compact model architecture. This shows the importance of the inductive bias introduced with feature selection and sequential attention. To push the performance, increasing the depth of the feature transformer is an effective approach. While increasing the depth, parameter sharing between feature transformer blocks across decision steps is an efficient way to decrease the model size without degradation in performance. We indeed observe the benefit of partial parameter sharing, compared to fully decision step-dependent blocks or fully shared blocks. We also observe the empirical benefit of GLU, compared to conventional nonlinearities like ReLU.

The strength of sparse feature selection depends on the two parameters we introduce: γ and λ_sparse. We show that an optimal choice of these two is important for performance. A γ close to 1, or a high λ_sparse, may yield too tight constraints on the strength of sparsity and may hurt performance. On the other hand, there is still a benefit from a sufficiently low γ and a sufficiently high λ_sparse, to aid learning of the model via a favorable inductive bias.

Lastly, given the fixed model architecture, we show the benefit of large-batch training, enabled by ghost BN [24]. The optimal batch size for TabNet seems considerably higher than the conventional batch sizes used for other data types, such as images or speech.

C GUIDELINES FOR HYPERPARAMETERS

We consider datasets ranging from ~10K to ~10M samples, with varying degrees of fitting difficulty. TabNet obtains high performance on all with a few general principles on hyperparameters:

• For most datasets, N_steps ∈ [3, 10] is optimal. Typically, when there are more information-bearing features, the optimal value of N_steps tends to be higher. On the other hand, increasing it beyond some value may adversely affect training dynamics, as some paths in the network become deeper and there are more potentially-problematic ill-conditioned matrices. A very high value of N_steps may suffer from overfitting and yield poor generalization.
• Adjustment of N_d and N_a is an efficient way of obtaining a trade-off between performance and complexity. N_d = N_a is a reasonable choice for most datasets. A very high value of N_d and N_a may suffer from overfitting and yield poor generalization.
• An optimal choice of γ can have a major role on the performance. Typically, a larger N_steps value favors a larger γ.
• A large batch size is beneficial – if the memory constraints permit, as large as 1-10% of the total training dataset size can help performance. The virtual batch size is typically much smaller.
• An initially large learning rate is important, and it should be gradually decayed until convergence (see the training-loop sketch after this list).
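A skeleton of the training recipe used throughout Appendix A – Adam with a large initial learning rate, a stepwise exponential decay, and the sparsity regularizer added to the supervised loss – is sketched below. The tiny linear model, random data and specific hyperparameter values are placeholders, not the paper's training code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Placeholders standing in for the TabNet encoder and a real dataset.
model = nn.Linear(16, 2)
x, y = torch.randn(4096, 16), torch.randint(0, 2, (4096,))
lambda_sparse, l_sparse = 1e-4, torch.tensor(0.0)   # L_sparse would come from the encoder masks

opt = torch.optim.Adam(model.parameters(), lr=0.02)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=500, gamma=0.95)  # e.g. decay 0.95 every 0.5k steps
for step in range(2000):
    loss = F.cross_entropy(model(x), y) + lambda_sparse * l_sparse
    opt.zero_grad()
    loss.backward()
    opt.step()
    sched.step()
```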
