UNETR: Transformers For 3D Medical Image Segmentation
1 Introduction
different resolutions, hence allowing for recovering spatial information that is lost
during downsampling.
Although such FCNN-based approaches have powerful representation learning
capabilities, their performance in learning long-range dependencies is limited by
their localized receptive fields. As a result, such a deficiency in capturing
multi-scale contextual information leads to sub-optimal segmentation of structures with
variable shapes and scales (e.g. brain lesions with different sizes). Several efforts
have tried to mitigate this issue by employing atrous convolutional layers [4,15,10].
However, due to the locality of CNNs, their receptive fields are still limited to a
small region.
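To make the receptive-field argument concrete, the following minimal sketch (an illustration added here, not code from the cited works) computes the receptive field along one axis of a stack of stride-1 convolutions. It uses the standard relation that a layer with kernel size k and dilation d widens the field by (k − 1)·d voxels. The function name and the example dilation rates are assumptions chosen for illustration.

```python
def receptive_field(layers):
    """Receptive field (in voxels, along one axis) of stacked stride-1 convolutions.

    `layers` is a list of (kernel_size, dilation) pairs; both the name and the
    input format are hypothetical and chosen only for this illustration.
    """
    rf = 1
    for kernel_size, dilation in layers:
        rf += (kernel_size - 1) * dilation
    return rf


# Three plain 3x3x3 convolutions versus three atrous ones with dilations 1, 2, 4.
plain = receptive_field([(3, 1), (3, 1), (3, 1)])
atrous = receptive_field([(3, 1), (3, 2), (3, 4)])
print(plain, atrous)  # 7 15
```

Under these assumptions, dilation widens the field from 7 to 15 voxels per axis, which is still a small fraction of a 128-voxel input crop.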
In the NLP domain, transformer-based models [24,6] have achieved state-of-the-art
results in various tasks. The self-attention mechanism in
transformers enables them to dynamically highlight the important features of
word sequences and learn their long-range dependencies. This notion has recently
been extended to computer vision by the introduction of the Vision Transformer
(ViT) [7]. In ViT, an image is represented as a sequence of patch embeddings
that are used for direct prediction of class labels for image classification.
In this work, we propose to leverage the power of transformers for volumetric
medical image segmentation and introduce a novel architecture dubbed UNETR
for this purpose. In particular, we reformulate the task of 3D segmentation as a
1D sequence-to-sequence prediction problem and use a pure transformer as the
encoder to learn contextual information from the embedded input patches. The
representations extracted by the transformer encoder are merged with a decoder via
skip connections at multiple resolutions for prediction of segmentation outputs.
We have extensively validated the effectiveness of our UNETR on brain
tumour and spleen segmentation tasks in the MSD dataset [22] and our
experiments demonstrate favorable performance in comparison to other models on
our validation set. To the best of our knowledge, we are the first to propose a
completely transformer-based encoder for volumetric medical image
segmentation. Considering the prevalence of volumetric data in medical imaging and their
extensive use in segmentation, we believe our UNETR paves the way for a new
class of transformer-based segmentation models which can be utilized for various
2 Related Work
CNN-based Segmentation Networks: Since the introduction of the seminal U-Net [21],
CNN-based networks have achieved state-of-the-art results on various 2D and 3D
medical image segmentation tasks [8,29,25,9,16,28]. Despite their success,
a limitation of these networks is their poor performance in learning global context
and long-range spatial dependencies, which can severely impact the segmentation
performance for challenging tasks.
[Fig. 1: architecture diagram; only its text labels survive. Input: H × W × D × 4. Transformer encoder: Linear projection followed by ×12 blocks (Norm, Multi-Head, Norm, MLP) on H/16 × W/16 × D/16 × 768 features, with intermediate outputs including z3 and z9. Decoder feature maps: H/8 × W/8 × D/8 × 512, H/4 × W/4 × D/4 × 256, H/2 × W/2 × D/2 × 128, H × W × D × 64. Legend: Deconv 2 × 2 × 2; Deconv 2 × 2 × 2, Conv 3 × 3 × 3, BN, ReLU; Conv 3 × 3 × 3, BN, ReLU; Conv 1 × 1 × 1. Output: H × W × D × 3.]
3 Methodology
3.1 Architecture
After the embedding layer, we utilize a stack of transformer blocks [24,7]
comprising multiheaded self-attention (MSA) and multilayer perceptron (MLP)
sublayers according to
\[ z'_i = \mathrm{MSA}(\mathrm{Norm}(z_{i-1})) + z_{i-1}, \tag{2} \]
\[ z_i = \mathrm{MLP}(\mathrm{Norm}(z'_i)) + z'_i, \tag{3} \]
where Norm represents layer normalization, MLP comprises two linear layers
with GELU activation functions and i is the intermediate block identifier ranging
from 1 to T = 12 total blocks in our current setting. An MSA block comprises
n parallel self-attention (SA) heads. The SA block is a parameterized function
that learns the similarity between two elements of the input sequence (z) from
their query (q) and key (k) representations. Thus, the output of SA is
computed as follows
\[ \mathrm{SA}(z) = \mathrm{Softmax}\left( \frac{q k^{\top}}{\sqrt{C_h}} \right) v, \tag{4} \]
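where v denotes the values in the input sequence and $C_h = C/n$ is a scaling
factor. Furthermore, the output of MSA is defined as
\[ \mathrm{MSA}(z) = [\mathrm{SA}_1(z); \mathrm{SA}_2(z); \ldots; \mathrm{SA}_n(z)] \, W_{msa}, \tag{5} \]
where $W_{msa}$ represents the learnable weight matrices of the different SA heads.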
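The SA computation in Eq. (4) can be illustrated with a minimal sketch in plain Python. It is not the implementation used in this work (which relies on PyTorch and MONAI); the helper names, the toy sequence and the identity projection weights are all assumptions made for illustration, and only a single head (n = 1) is shown.

```python
import math


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax(row):
    """Numerically stable softmax over one row of scores."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(z, w_q, w_k, w_v):
    """Single SA head: Softmax(q k^T / sqrt(C_h)) v, where C_h is the head width."""
    q, k, v = matmul(z, w_q), matmul(z, w_k), matmul(z, w_v)
    c_h = len(q[0])
    k_t = [list(col) for col in zip(*k)]  # transpose of k
    scores = [[s / math.sqrt(c_h) for s in row] for row in matmul(q, k_t)]
    weights = [softmax(row) for row in scores]  # one attention row per token
    return matmul(weights, v), weights


# Toy sequence: 3 tokens with embedding size C = 2 and one head, so C_h = 2.
z = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
identity = [[1.0, 0.0], [0.0, 1.0]]  # hypothetical projection weights
out, attn = self_attention(z, identity, identity, identity)
```

Each row of `attn` sums to one and spans the whole sequence, which is why every token can attend to every other token regardless of their spatial distance.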
Inspired by UNet-like architectures, where features from multiple resolutions
of the encoder are merged with the decoder, we extract sequence representations
$z_i$ ($i \in \{3, 6, 9, 12\}$), each of size $\frac{H \times W \times D}{N^3} \times C$, from the transformer and reshape
them into $\frac{H}{N} \times \frac{W}{N} \times \frac{D}{N} \times C$ tensors. A representation in our definition is in
the embedding space if it has been reshaped as an output of the transformer
and has a feature size of C (i.e. the transformer's embedding size). Consequently, we
project the reshaped tensor from the embedding space into the input space by
utilizing consecutive 3 × 3 × 3 convolutional layers that are followed by batch
normalization (see Fig. 1 for details).
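The reshaping step can be checked with simple arithmetic. The sketch below is an illustration only: it assumes N denotes the edge length of a cubic patch (16 in our experiments), that every volume dimension is divisible by N, and that tokens are stored in row-major order; the function names are hypothetical.

```python
def token_grid(volume_shape, patch_size):
    """Return the per-axis token grid and the sequence length for cubic patches."""
    if any(dim % patch_size for dim in volume_shape):
        raise ValueError("each volume dimension must be divisible by the patch size")
    grid = tuple(dim // patch_size for dim in volume_shape)
    return grid, grid[0] * grid[1] * grid[2]


def reshape_sequence(sequence, grid):
    """Reshape a flat list of token vectors into a nested grid (row-major order assumed)."""
    g0, g1, g2 = grid
    if len(sequence) != g0 * g1 * g2:
        raise ValueError("sequence length does not match the token grid")
    return [[[sequence[(a * g1 + b) * g2 + c] for c in range(g2)]
             for b in range(g1)] for a in range(g0)]


# Input crops used in the experiments, with 16 x 16 x 16 patches.
brain_grid, brain_tokens = token_grid((128, 128, 128), 16)  # (8, 8, 8), 512
spleen_grid, spleen_tokens = token_grid((96, 96, 96), 16)   # (6, 6, 6), 216

# Toy sequence whose token i carries the 1-element feature vector [i].
sequence = [[i] for i in range(spleen_tokens)]
reshaped = reshape_sequence(sequence, spleen_grid)
```

With C = 768, a 128 × 128 × 128 crop therefore yields a 512 × 768 sequence that is reshaped into an 8 × 8 × 8 × 768 tensor before the convolutional projection.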
At the bottleneck of our encoder (i.e. output of transformer’s last layer),
we apply a deconvolutional layer to the transformed feature map to increase
Our loss function is a combination of dice [18] and cross entropy terms that can
be computed in a voxel-wise manner according to
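\[ \mathcal{L} = 1 - \frac{2}{J} \sum_{j=1}^{J} \frac{\sum_{i=1}^{I} G_{i,j} Y_{i,j}}{\sum_{i=1}^{I} G_{i,j}^{2} + \sum_{i=1}^{I} Y_{i,j}^{2}} - \frac{1}{I} \sum_{i=1}^{I} \sum_{j=1}^{J} G_{i,j} \log Y_{i,j}. \tag{6} \]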
where I is the number of voxels; J is the number of classes; Yi,j and Gi,j denote
the probability output and one-hot encoded ground truth for class j at voxel i,
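As a minimal sketch of the combined loss in Eq. (6), the plain-Python function below sums a soft Dice term (with squared terms in the denominator) and a voxel-averaged cross-entropy term. It is an illustration rather than the training code of this work; the function name and the small constant `eps` are assumptions added for numerical safety.

```python
import math


def dice_ce_loss(probs, onehot, eps=1e-8):
    """Soft Dice plus cross-entropy over I voxels and J classes.

    probs[i][j]  : predicted probability of class j at voxel i (Y in the text)
    onehot[i][j] : one-hot ground truth for class j at voxel i (G in the text)
    eps          : small constant (an assumption, not in the text) that guards
                   against division by zero and log(0)
    """
    num_voxels, num_classes = len(probs), len(probs[0])
    dice_sum = 0.0
    for j in range(num_classes):
        intersection = sum(onehot[i][j] * probs[i][j] for i in range(num_voxels))
        denominator = sum(onehot[i][j] ** 2 + probs[i][j] ** 2 for i in range(num_voxels))
        dice_sum += intersection / (denominator + eps)
    dice_term = 1.0 - (2.0 / num_classes) * dice_sum
    ce_term = -sum(
        onehot[i][j] * math.log(probs[i][j] + eps)
        for i in range(num_voxels)
        for j in range(num_classes)
    ) / num_voxels
    return dice_term + ce_term


# Four voxels, two classes: a perfect prediction versus an uninformative one.
truth = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]
perfect = dice_ce_loss(truth, truth)
uniform = dice_ce_loss([[0.5, 0.5]] * 4, truth)
```

In this toy case the perfect prediction gives a loss close to zero, while the uniform prediction gives roughly 1/3 from the Dice term plus log 2 from the cross-entropy term.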
4 Experiments
4.1 Datasets
To cover various objects and image modalities, the datasets of task 1 (brain
tumour MRI segmentation) and task 9 (spleen CT segmentation) from the MSD
challenge [22] are adopted for experiments with our own data split for 5-fold
cross-validation. For task 1, the entire training set of 484 multi-modal,
multi-site MRI volumes (FLAIR, T1w, T1gd, T2w) with ground truth labels for glioma
segmentation (necrotic/active tumour and oedema) is utilized for model training.
The resolution/spacing of task 1 is uniformly 1.0 × 1.0 × 1.0 mm³. For task 9, 41
CT volumes with spleen body annotation are used. The resolution/spacing of
volumes in task 9 ranges from 0.613 × 0.613 × 1.50 mm³ to 0.977 × 0.977 × 8.0
mm³. All volumes are re-sampled into the isotropic voxel spacing 1.0 mm during
For task 1 with MRI images, the voxel intensities are pre-processed with
z-score normalization. For task 9 with CT images, the voxel intensities of the
images are normalized to the range [0, 1] according to the 5th and 95th percentiles
of the overall foreground intensities. Furthermore, task 1 is
formulated as a 3-class segmentation task with 4-channel input, whereas task 9 is
formulated as a binary segmentation task (foreground and background) with
single-channel input. We randomly sample the input images with volume sizes of
[128, 128, 128] and [96, 96, 96] for tasks 1 and 9, respectively. The random patches
of foreground/background are sampled at a ratio of 1:1.
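The two intensity normalization schemes can be sketched as follows. This is a hedged illustration in plain Python, not the MONAI pipeline used in this work: the function names are hypothetical, the percentile uses linear interpolation, values outside the percentile window are assumed to be clipped to [0, 1], and the inputs are assumed to have non-zero spread.

```python
import statistics


def z_score(values):
    """Zero-mean, unit-variance normalization of a flat list of intensities."""
    mean = statistics.fmean(values)
    std = statistics.pstdev(values)  # assumed non-zero for this sketch
    return [(v - mean) / std for v in values]


def percentile(sorted_values, q):
    """q-th percentile (0-100) with linear interpolation between neighbours."""
    pos = (len(sorted_values) - 1) * q / 100.0
    lo = int(pos)
    hi = min(lo + 1, len(sorted_values) - 1)
    return sorted_values[lo] + (sorted_values[hi] - sorted_values[lo]) * (pos - lo)


def percentile_scale(values, foreground, low_q=5, high_q=95):
    """Map [p_low, p_high] of the foreground intensities to [0, 1], clipping outside."""
    ordered = sorted(foreground)
    p_low, p_high = percentile(ordered, low_q), percentile(ordered, high_q)
    return [min(1.0, max(0.0, (v - p_low) / (p_high - p_low))) for v in values]


# MRI-style example: z-score normalization of five intensities.
mri = z_score([1.0, 2.0, 3.0, 4.0, 5.0])

# CT-style example: foreground intensities 0..100 give p5 = 5 and p95 = 95.
foreground = [float(v) for v in range(101)]
ct = percentile_scale([0.0, 50.0, 100.0], foreground)  # [0.0, 0.5, 1.0]
```

Deriving the percentiles from foreground voxels only, as assumed here, keeps large background regions from dominating the intensity window.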
Model            Split-1 Split-2 Split-3 Split-4 Split-5    DSC1    DSC2    DSC3    Avg.
VNet [18]          64.83   67.28   65.23    65.2   66.34   75.96   54.99   66.38   65.77
AHNet [17]         65.78   69.31   65.16   65.05   67.84    75.8   57.58   66.50   66.63
Att-UNet [20]      66.39   70.18   65.39   66.11   67.29   75.29   57.11   68.81   67.07
UNet [5]           67.20   69.11   66.84   66.95   68.16   75.03   57.87   70.06   67.65
SegResNet [19]     69.62   71.84   67.86   68.52   70.43   76.37   59.56   73.03   69.65
UNETR              70.79   73.70   70.12   72.10   72.38   79.00   60.62   75.82   71.81
Table 1. Cross-validation results for the brain tumour segmentation task. For each split, we
provide the average Dice score of the three classes. DSC1, DSC2 and DSC3 denote the average
Dice scores for the Whole Tumour (WT), Enhancing Tumour (ET) and Tumour Core
(TC) across all folds, respectively.
The UNETR is implemented in PyTorch and MONAI. The model was trained
on an NVIDIA V100 32GB GPU and an Intel® Core™ i7-7800X CPU @ 3.50GHz
× 12. All models were trained with a batch size of 2, using the Adam
optimization algorithm with an initial learning rate of 0.0001 for 25,000 iterations.
Using a fixed split for all experiments, we have used five-fold cross-validation and
evaluated the performance of our model using the Dice-Sørensen score (DSC). We
have used a dimension of 16 × 16 × 16 for generating the input patches, and T = 12
transformer blocks with an embedding size of C = 768 as the encoder of UNETR.
We did not use any pretrained transformer model (e.g. ViT on ImageNet) since
pretraining did not show any performance improvement.
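For reference, the Dice-Sørensen score used for evaluation can be sketched for a pair of binary masks as below. This is a minimal illustration, not the evaluation code of this work; the function name is hypothetical and the value returned for two empty masks is a convention assumed here.

```python
def dice_score(pred, truth):
    """Dice-Sorensen coefficient between two binary masks given as flat 0/1 lists."""
    intersection = sum(p * t for p, t in zip(pred, truth))
    total = sum(pred) + sum(truth)
    if total == 0:
        return 1.0  # convention assumed here: two empty masks agree perfectly
    return 2.0 * intersection / total


# One of the two predicted foreground voxels overlaps the single true one.
score = dice_score([1, 1, 0, 0], [1, 0, 0, 0])  # 2 * 1 / (2 + 1) = 0.666...
```

For multi-class results, such a score would be computed per class and then averaged.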
Fig. 2. (a) Ground Truth. Outputs of: (b) UNETR. (c) SegResNet. (d) UNet.
the task of spleen segmentation. Similarly, the UNETR outperforms the closest
baselines by at least 1.11%. Furthermore, in order to allow for a fair comparison,
we did not compare against external models on the MSD test set, since ensembling
(commonly used to boost test-time performance) and different
training conditions can significantly alter the benchmarks.
5 Ablation
6 Conclusion
tasks. Our proposed UNETR lays the foundation for a new class of transformer-
based models for medical image segmentation. Although the UNETR is primarily
designed for 3D applications, an extension for 2D applications is straightforward
and can be explored in future efforts.
A. Hatamizadeh et al.
