-
Notifications
You must be signed in to change notification settings - Fork 145
/
Copy pathvit.py
496 lines (424 loc) · 17.5 KB
/
vit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
"""ViT"""
import functools
from typing import Callable, Optional
import numpy as np
import mindspore as ms
from mindspore import Parameter, Tensor, nn, ops
from mindspore.common.initializer import TruncatedNormal, XavierUniform, initializer
from .helpers import load_pretrained
from .layers.compatibility import Dropout
from .layers.drop_path import DropPath
from .layers.mlp import Mlp
from .layers.patch_dropout import PatchDropout
from .layers.patch_embed import PatchEmbed
from .layers.pos_embed import resample_abs_pos_embed
from .registry import register_model
# Public API: the model class plus one factory function per registered variant.
# Entries marked "pretrained" have a non-empty checkpoint URL in ``default_cfgs``.
__all__ = [
    "VisionTransformer",
    "vit_b_16_224",
    "vit_b_16_384",
    "vit_l_16_224",  # pretrained
    "vit_l_16_384",
    "vit_b_32_224",  # pretrained
    "vit_b_32_384",
    "vit_l_32_224",  # pretrained
]
def _cfg(url="", **kwargs):
return {
"url": url,
"num_classes": 1000,
"input_size": (3, 224, 224),
"first_conv": "patch_embed.proj",
"classifier": "head",
**kwargs,
}
# Default config for each registered variant, keyed by factory-function name.
# Entries with url="" have no checkpoint configured here; the *_384 variants
# override the 224x224 default ``input_size``.
default_cfgs = {
    "vit_b_16_224": _cfg(url=""),
    "vit_b_16_384": _cfg(url="", input_size=(3, 384, 384)),
    "vit_l_16_224": _cfg(
        url="https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_16_224-d2635f8b.ckpt",
    ),
    "vit_l_16_384": _cfg(url="", input_size=(3, 384, 384)),
    "vit_b_32_224": _cfg(
        url="https://download.mindspore.cn/toolkits/mindcv/vit/vit_b_32_224-4a1c9d8e.ckpt",
    ),
    "vit_b_32_384": _cfg(url="", input_size=(3, 384, 384)),
    "vit_l_32_224": _cfg(
        url="https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_32_224-8c8ea164.ckpt",
    ),
}
# TODO: Flash Attention
class Attention(nn.Cell):
    """
    Multi-head self-attention.

    One Dense layer produces q, k and v for all heads at once. Each head attends
    independently, then the heads are merged and passed through an output Dense.

    Args:
        dim (int): Input and output feature size. Must be divisible by ``num_heads``.
        num_heads (int): Number of attention heads. Default: 8.
        qkv_bias (bool): Whether the q/k/v projection has a bias term. Default: True.
        qk_norm (bool): Whether to normalize q and k per head with ``norm_layer``. Default: False.
        attn_drop (float): Dropout rate applied to the attention weights. Default: 0.0.
        proj_drop (float): Dropout rate applied after the output projection. Default: 0.0.
        norm_layer (nn.Cell): Norm class, built with ``(head_dim,)``. Only used when
            ``qk_norm`` is True. Default: nn.LayerNorm.

    Inputs:
        x (Tensor): shape (B, N, dim).

    Returns:
        Tensor of shape (B, N, dim).

    Examples:
        >>> attn = Attention(768, 12)
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Cell = nn.LayerNorm,
    ):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        head_dim = dim // num_heads
        self.num_heads = num_heads
        self.head_dim = head_dim
        # Softmax temperature 1/sqrt(head_dim). construct applies sqrt(scale) to
        # q and to k separately, so q @ k^T carries the full factor.
        self.scale = Tensor(head_dim ** -0.5)

        # Sub-cells, created in the same order as the original implementation.
        self.qkv = nn.Dense(dim, dim * 3, has_bias=qkv_bias)
        if qk_norm:
            self.q_norm = norm_layer((head_dim,))
            self.k_norm = norm_layer((head_dim,))
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()
        self.attn_drop = Dropout(attn_drop)
        self.proj = nn.Dense(dim, dim)
        self.proj_drop = Dropout(proj_drop)

        # Primitive operators used in construct.
        self.mul = ops.Mul()
        self.reshape = ops.Reshape()
        self.transpose = ops.Transpose()
        self.unstack = ops.Unstack(axis=0)
        self.attn_matmul_v = ops.BatchMatMul()
        self.q_matmul_k = ops.BatchMatMul(transpose_b=True)

    def construct(self, x):
        batch, tokens, channels = x.shape
        # (B, N, 3*dim) -> (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
        packed = self.reshape(self.qkv(x), (batch, tokens, 3, self.num_heads, self.head_dim))
        packed = self.transpose(packed, (2, 0, 3, 1, 4))
        query, key, value = self.unstack(packed)
        query = self.q_norm(query)
        key = self.k_norm(key)
        # Split the scale evenly between q and k (presumably to keep the
        # magnitudes of the q @ k^T inputs small in low precision).
        root_scale = self.scale ** 0.5
        query = self.mul(query, root_scale)
        key = self.mul(key, root_scale)
        scores = self.q_matmul_k(query, key)
        # Softmax is computed in float32, then cast back to the scores' dtype.
        weights = ops.softmax(scores.astype(ms.float32), axis=-1).astype(scores.dtype)
        weights = self.attn_drop(weights)
        # (B, heads, N, head_dim) -> (B, N, heads, head_dim) -> (B, N, dim)
        merged = self.transpose(self.attn_matmul_v(weights, value), (0, 2, 1, 3))
        merged = self.reshape(merged, (batch, tokens, channels))
        return self.proj_drop(self.proj(merged))
class LayerScale(nn.Cell):
    """
    Learnable per-channel scaling (LayerScale) for a residual branch.

    Holds a parameter ``gamma`` of length ``dim``, filled with the constant
    ``init_values``. The input is multiplied by it, broadcasting over the last
    axis. LayerScale helps train deeper, higher-capacity vision transformers.

    Args:
        dim (int): Size of the input's last axis, i.e. the output size of the
            attention or MLP branch being scaled.
        init_values (float): Initial value of every entry of ``gamma``. Default: 1e-5.

    Returns:
        Tensor, the scaled input.

    Examples:
        >>> scale = LayerScale(768, 0.01)
    """

    def __init__(self, dim: int, init_values: float = 1e-5):
        super().__init__()
        self.gamma = Parameter(initializer(init_values, dim))

    def construct(self, x):
        return x * self.gamma
class Block(nn.Cell):
    """
    Pre-norm Transformer encoder block.

    Computes ``x = x + drop_path1(ls1(attn(norm1(x))))`` and then
    ``x = x + drop_path2(ls2(mlp(norm2(x))))``. LayerScale and DropPath are only
    built when enabled; otherwise identities are used in their place.

    Args:
        dim (int): Embedding size.
        num_heads (int): Number of attention heads. Default: 8.
        mlp_ratio (float): MLP hidden size as a multiple of ``dim``. Default: 4.0.
        qkv_bias (bool): Whether the q/k/v projection has a bias term. Default: False.
        qk_norm (bool): Whether attention normalizes q and k. Default: False.
        proj_drop (float): Dropout rate passed to the attention output projection
            and as ``drop`` to the MLP. Default: 0.0.
        attn_drop (float): Dropout rate on the attention weights. Default: 0.0.
        init_values (Optional[float]): Initial LayerScale value. A falsy value
            (None or 0) disables LayerScale. Default: None.
        drop_path (float): Drop-path rate for both residual branches. Values <= 0
            disable it. Default: 0.0.
        act_layer (nn.Cell): Activation class passed to ``mlp_layer``. Default: nn.GELU.
        norm_layer (nn.Cell): Norm class, built with ``(dim,)``. Also passed to
            ``Attention`` for q/k normalization. Default: nn.LayerNorm.
        mlp_layer (Callable): MLP factory, called with ``in_features``,
            ``hidden_features``, ``act_layer`` and ``drop``. Default: Mlp.

    Returns:
        Tensor with the input's shape (assuming ``mlp_layer`` maps dim -> dim).

    Examples:
        >>> block = Block(768, 12)
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        mlp_ratio: float = 4.,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        proj_drop: float = 0.,
        attn_drop: float = 0.,
        init_values: Optional[float] = None,
        drop_path: float = 0.,
        act_layer: nn.Cell = nn.GELU,
        norm_layer: nn.Cell = nn.LayerNorm,
        mlp_layer: Callable = Mlp,
    ):
        super().__init__()
        # Factories for the optional pieces; identity when the feature is off.
        if init_values:
            make_scale = lambda: LayerScale(dim=dim, init_values=init_values)  # noqa: E731
        else:
            make_scale = nn.Identity
        if drop_path > 0.:
            make_drop_path = lambda: DropPath(drop_path)  # noqa: E731
        else:
            make_drop_path = nn.Identity

        # Attention branch (sub-cells created in the original order).
        self.norm1 = norm_layer((dim,))
        self.attn = Attention(
            dim=dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
        )
        self.ls1 = make_scale()
        self.drop_path1 = make_drop_path()

        # MLP branch.
        self.norm2 = norm_layer((dim,))
        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
        )
        self.ls2 = make_scale()
        self.drop_path2 = make_drop_path()

    def construct(self, x):
        attn_out = self.attn(self.norm1(x))
        x = x + self.drop_path1(self.ls1(attn_out))
        mlp_out = self.mlp(self.norm2(x))
        x = x + self.drop_path2(self.ls2(mlp_out))
        return x
class VisionTransformer(nn.Cell):
    """
    Vision Transformer (ViT) image classifier.

    ``construct`` runs ``forward_features`` and then ``forward_head``:
    - ``forward_features``: patch embedding, class token and position embedding,
      position/patch dropout, optional pre-norm, ``depth`` blocks, final norm.
    - ``forward_head``: token pooling, optional fc_norm, dropout, linear head.

    Args:
        image_size (int): Input image size passed to ``embed_layer``. Default: 224.
        patch_size (int): Patch size passed to ``embed_layer``. Default: 16.
        in_channels (int): Number of input channels, passed as ``in_chans`` to
            ``embed_layer``. Default: 3.
        global_pool (str): Pooling applied before the head:
            - ``'token'`` takes token 0 and requires ``class_token``.
            - ``'avg'`` averages the non-prefix tokens.
            - ``''`` applies the head to every token.
            Default: 'token'.
        embed_dim (int): Token embedding width. Default: 768.
        depth (int): Number of transformer blocks. Default: 12.
        num_heads (int): Attention heads per block. Default: 12.
        mlp_ratio (float): Block MLP hidden width as a multiple of ``embed_dim``. Default: 4.0.
        qkv_bias (bool): Whether the blocks' q/k/v projections have a bias. Default: True.
        qk_norm (bool): Whether the blocks normalize q and k. Default: False.
        drop_rate (float): Dropout rate applied right before the head. Default: 0.0.
        pos_drop_rate (float): Dropout rate applied after adding position
            embeddings. Default: 0.0.
        patch_drop_rate (float): If > 0, a ``PatchDropout`` with this rate runs
            after the position embedding. Default: 0.0.
        proj_drop_rate (float): ``proj_drop`` passed to every block. Default: 0.0.
        attn_drop_rate (float): ``attn_drop`` passed to every block. Default: 0.0.
        drop_path_rate (float): Largest stochastic-depth rate. Block ``i`` gets
            the ``i``-th of ``depth`` values evenly spaced in
            [0, drop_path_rate]. Default: 0.0.
        weight_init (bool): If True, re-initialize weights with ``_init_weights``
            at the end of construction. Default: True.
        init_values (Optional[float]): LayerScale init value passed to every
            block. The default ``Block`` skips LayerScale when it is falsy.
            Default: None.
        no_embed_class (bool): If True, ``pos_embed`` covers only the patch
            tokens and is added before the class token is prepended.
            Default: False.
        pre_norm (bool): If True, apply ``norm_layer`` before the blocks and
            build the patch embedding with ``bias=False``. Default: False.
        fc_norm (Optional[bool]): If True, normalize after pooling (``fc_norm``)
            instead of before it (``norm``). ``None`` enables it exactly when
            ``global_pool == 'avg'``. Default: None.
        dynamic_img_size (bool): Build the patch embedding with
            ``strict_img_size=False`` and NHWC output, and resample
            ``pos_embed`` to the runtime patch grid. Default: False.
        dynamic_img_pad (bool): Forwarded to ``embed_layer``. Also requests NHWC
            output and enables ``pos_embed`` resampling. Default: False.
        act_layer (nn.Cell): Activation class passed to every block. Default: nn.GELU.
        embed_layer (Callable): Patch-embedding factory. Default: PatchEmbed.
        norm_layer (nn.Cell): Normalization class, built with ``(embed_dim,)``
            and passed to every block. Default: nn.LayerNorm.
        mlp_layer (Callable): MLP factory passed to every block. Default: Mlp.
        class_token (bool): Whether to prepend a learnable class token. Default: True.
        block_fn (Callable): Transformer block factory. Default: Block.
        num_classes (int): Head output size. Values <= 0 replace the head by an
            identity. Default: 1000.

    Returns:
        ``construct`` returns the head output. With a non-empty ``global_pool``
        that is one vector per image. NOTE(review): assumes ``embed_layer``
        yields (B, N, embed_dim) tokens, or (B, H, W, embed_dim) in the NHWC
        modes — confirm against PatchEmbed.
    """
    def __init__(
        self,
        image_size: int = 224,
        patch_size: int = 16,
        in_channels: int = 3,
        global_pool: str = 'token',
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        drop_rate: float = 0.,
        pos_drop_rate: float = 0.,
        patch_drop_rate: float = 0.,
        proj_drop_rate: float = 0.,
        attn_drop_rate: float = 0.,
        drop_path_rate: float = 0.,
        weight_init: bool = True,
        init_values: Optional[float] = None,
        no_embed_class: bool = False,
        pre_norm: bool = False,
        fc_norm: Optional[bool] = None,
        dynamic_img_size: bool = False,
        dynamic_img_pad: bool = False,
        act_layer: nn.Cell = nn.GELU,
        embed_layer: Callable = PatchEmbed,
        norm_layer: nn.Cell = nn.LayerNorm,
        mlp_layer: Callable = Mlp,
        class_token: bool = True,
        block_fn: Callable = Block,
        num_classes: int = 1000,
    ):
        super(VisionTransformer, self).__init__()
        assert global_pool in ('', 'avg', 'token')
        # 'token' pooling reads token 0, which only exists with a class token.
        assert class_token or global_pool != 'token'
        # By default, normalize after pooling only for average pooling.
        use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
        self.global_pool = global_pool
        # Number of non-patch tokens prepended to the sequence (the class token).
        self.num_prefix_tokens = 1 if class_token else 0
        self.no_embed_class = no_embed_class
        self.dynamic_img_size = dynamic_img_size
        self.dynamic_img_pad = dynamic_img_pad
        embed_args = {}
        if dynamic_img_size:
            # flatten deferred until after pos embed
            embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
        elif dynamic_img_pad:
            # Keep the spatial grid (NHWC) so _pos_embed can resample pos_embed to it.
            embed_args.update(dict(output_fmt='NHWC'))
        self.patch_embed = embed_layer(
            image_size=image_size,
            patch_size=patch_size,
            in_chans=in_channels,
            embed_dim=embed_dim,
            bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
            dynamic_img_pad=dynamic_img_pad,
            **embed_args,
        )
        num_patches = self.patch_embed.num_patches
        self.cls_token = Parameter(initializer(TruncatedNormal(0.02), (1, 1, embed_dim))) if class_token else None
        # pos_embed has an entry for the class token unless no_embed_class is set.
        embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
        self.pos_embed = Parameter(initializer(TruncatedNormal(0.02), (1, embed_len, embed_dim)))
        self.pos_drop = Dropout(pos_drop_rate)
        if patch_drop_rate > 0:
            # num_prefix_tokens is passed so PatchDropout can presumably leave the
            # class token alone — confirm in layers/patch_dropout.py.
            self.patch_drop = PatchDropout(
                patch_drop_rate,
                num_prefix_tokens=self.num_prefix_tokens,
            )
        else:
            self.patch_drop = nn.Identity()
        self.norm_pre = norm_layer((embed_dim,)) if pre_norm else nn.Identity()
        # Per-block stochastic-depth rates, evenly spaced from 0 to drop_path_rate.
        dpr = [x.item() for x in np.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.CellList([
            block_fn(
                dim=embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm,
                attn_drop=attn_drop_rate, proj_drop=proj_drop_rate,
                mlp_ratio=mlp_ratio, drop_path=dpr[i], init_values=init_values,
                act_layer=act_layer, norm_layer=norm_layer, mlp_layer=mlp_layer,
            ) for i in range(depth)
        ])
        # Exactly one of these is a real norm: `norm` before pooling, or `fc_norm` after it.
        self.norm = norm_layer((embed_dim,)) if not use_fc_norm else nn.Identity()
        self.fc_norm = norm_layer((embed_dim,)) if use_fc_norm else nn.Identity()
        self.head_drop = Dropout(drop_rate)
        # num_classes <= 0 leaves the head as an identity (features-only output).
        self.head = nn.Dense(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        # Runs last so every sub-cell above already exists when it is re-initialized.
        if weight_init:
            self._init_weights()
    def get_num_layers(self):
        """Return the number of transformer blocks."""
        return len(self.blocks)
    def _init_weights(self):
        """Re-initialize weights in place.

        - ``patch_embed.proj.weight``: Xavier-uniform computed on the 2-D
          flattened shape ``(shape[0], prod(shape[1:]))``, then reshaped back.
        - Every ``nn.Dense``: Xavier-uniform weight, zero bias (if any).
        - Every ``nn.LayerNorm``: ``gamma`` = 1, ``beta`` = 0.

        ``cls_token`` and ``pos_embed`` are not touched; they keep their
        TruncatedNormal(0.02) initialization.
        """
        w = self.patch_embed.proj.weight
        # Treat the projection kernel as a (out, in * kh * kw) matrix so Xavier
        # fan-in/fan-out are computed from the flattened shape.
        w_shape_flatted = (w.shape[0], functools.reduce(lambda x, y: x*y, w.shape[1:]))
        w_value = initializer(XavierUniform(), w_shape_flatted, w.dtype)
        w_value.init_data()
        w.set_data(w_value.reshape(w.shape))
        for _, cell in self.cells_and_names():
            if isinstance(cell, nn.Dense):
                cell.weight.set_data(
                    initializer(XavierUniform(), cell.weight.shape, cell.weight.dtype)
                )
                if cell.bias is not None:
                    cell.bias.set_data(
                        initializer('zeros', cell.bias.shape, cell.bias.dtype)
                    )
            elif isinstance(cell, nn.LayerNorm):
                cell.gamma.set_data(
                    initializer('ones', cell.gamma.shape, cell.gamma.dtype)
                )
                cell.beta.set_data(
                    initializer('zeros', cell.beta.shape, cell.beta.dtype)
                )
    def _pos_embed(self, x):
        # Add position embeddings (prepending the class token if present), then apply pos_drop.
        # In the dynamic modes x arrives as (B, H, W, C). pos_embed is resampled to the
        # (H, W) patch grid and x is flattened to (B, H*W, C).
        if self.dynamic_img_size or self.dynamic_img_pad:
            # bhwc format
            B, H, W, C = x.shape
            pos_embed = resample_abs_pos_embed(
                self.pos_embed,
                (H, W),
                num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
            )
            x = ops.reshape(x, (B, -1, C))
        else:
            pos_embed = self.pos_embed
        if self.no_embed_class:
            # deit-3, updated JAX (big vision)
            # position embedding does not overlap with class token, add then concat
            x = x + pos_embed
            if self.cls_token is not None:
                # Expand the single (1, 1, C) token to the batch and match x's dtype.
                cls_tokens = ops.broadcast_to(self.cls_token, (x.shape[0], -1, -1))
                cls_tokens = cls_tokens.astype(x.dtype)
                x = ops.concat((cls_tokens, x), axis=1)
        else:
            # original timm, JAX, and deit vit impl
            # pos_embed has entry for class token, concat then add
            if self.cls_token is not None:
                cls_tokens = ops.broadcast_to(self.cls_token, (x.shape[0], -1, -1))
                cls_tokens = cls_tokens.astype(x.dtype)
                x = ops.concat((cls_tokens, x), axis=1)
            x = x + pos_embed
        return self.pos_drop(x)
    def forward_features(self, x):
        # Images -> token sequence after all blocks and the final `norm`
        # (`norm` is an identity when fc_norm is in use).
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x
    def forward_head(self, x):
        # Pool tokens per global_pool, then fc_norm -> head dropout -> head.
        if self.global_pool:
            # 'avg': mean over the non-prefix (patch) tokens; 'token': the first (class) token.
            x = x[:, self.num_prefix_tokens:].mean(axis=1) if self.global_pool == 'avg' else x[:, 0]
        x = self.fc_norm(x)
        x = self.head_drop(x)
        x = self.head(x)
        return x
    def construct(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
@register_model
def vit_b_16_224(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs):
    """Build ViT-B/16 for 224x224 inputs (patch_size=16, embed_dim=768, depth=12, num_heads=12).

    Args:
        pretrained: If True, call ``load_pretrained`` with ``default_cfgs["vit_b_16_224"]``.
            NOTE(review): that entry's ``url`` is empty, so no checkpoint is
            configured for this variant. Confirm how ``load_pretrained`` reacts.
        num_classes: Head output size (also forwarded to ``load_pretrained``).
        in_channels: Number of input image channels (also forwarded to ``load_pretrained``).
        **kwargs: Extra ``VisionTransformer`` arguments. Repeating one that is
            fixed here raises TypeError.

    Returns:
        The constructed ``VisionTransformer``.
    """
    cfg = default_cfgs["vit_b_16_224"]
    model = VisionTransformer(
        image_size=224,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        in_channels=in_channels,
        num_classes=num_classes,
        **kwargs,
    )
    if pretrained:
        load_pretrained(model, cfg, num_classes=num_classes, in_channels=in_channels)
    return model
@register_model
def vit_b_16_384(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs):
    """Build ViT-B/16 for 384x384 inputs (patch_size=16, embed_dim=768, depth=12, num_heads=12).

    Args:
        pretrained: If True, call ``load_pretrained`` with ``default_cfgs["vit_b_16_384"]``.
            NOTE(review): that entry's ``url`` is empty, so no checkpoint is
            configured for this variant. Confirm how ``load_pretrained`` reacts.
        num_classes: Head output size (also forwarded to ``load_pretrained``).
        in_channels: Number of input image channels (also forwarded to ``load_pretrained``).
        **kwargs: Extra ``VisionTransformer`` arguments. Repeating one that is
            fixed here raises TypeError.

    Returns:
        The constructed ``VisionTransformer``.
    """
    cfg = default_cfgs["vit_b_16_384"]
    model = VisionTransformer(
        image_size=384,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        in_channels=in_channels,
        num_classes=num_classes,
        **kwargs,
    )
    if pretrained:
        load_pretrained(model, cfg, num_classes=num_classes, in_channels=in_channels)
    return model
@register_model
def vit_l_16_224(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs):
    """Build ViT-L/16 for 224x224 inputs (patch_size=16, embed_dim=1024, depth=24, num_heads=16).

    Args:
        pretrained: If True, call ``load_pretrained`` with ``default_cfgs["vit_l_16_224"]``,
            whose ``url`` points at a MindCV checkpoint.
        num_classes: Head output size (also forwarded to ``load_pretrained``).
        in_channels: Number of input image channels (also forwarded to ``load_pretrained``).
        **kwargs: Extra ``VisionTransformer`` arguments. Repeating one that is
            fixed here raises TypeError.

    Returns:
        The constructed ``VisionTransformer``.
    """
    cfg = default_cfgs["vit_l_16_224"]
    model = VisionTransformer(
        image_size=224,
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        in_channels=in_channels,
        num_classes=num_classes,
        **kwargs,
    )
    if pretrained:
        load_pretrained(model, cfg, num_classes=num_classes, in_channels=in_channels)
    return model
@register_model
def vit_l_16_384(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs):
    """Build ViT-L/16 for 384x384 inputs (patch_size=16, embed_dim=1024, depth=24, num_heads=16).

    Args:
        pretrained: If True, call ``load_pretrained`` with ``default_cfgs["vit_l_16_384"]``.
            NOTE(review): that entry's ``url`` is empty, so no checkpoint is
            configured for this variant. Confirm how ``load_pretrained`` reacts.
        num_classes: Head output size (also forwarded to ``load_pretrained``).
        in_channels: Number of input image channels (also forwarded to ``load_pretrained``).
        **kwargs: Extra ``VisionTransformer`` arguments. Repeating one that is
            fixed here raises TypeError.

    Returns:
        The constructed ``VisionTransformer``.
    """
    cfg = default_cfgs["vit_l_16_384"]
    model = VisionTransformer(
        image_size=384,
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        in_channels=in_channels,
        num_classes=num_classes,
        **kwargs,
    )
    if pretrained:
        load_pretrained(model, cfg, num_classes=num_classes, in_channels=in_channels)
    return model
@register_model
def vit_b_32_224(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs):
    """Build ViT-B/32 for 224x224 inputs (patch_size=32, embed_dim=768, depth=12, num_heads=12).

    Args:
        pretrained: If True, call ``load_pretrained`` with ``default_cfgs["vit_b_32_224"]``,
            whose ``url`` points at a MindCV checkpoint.
        num_classes: Head output size (also forwarded to ``load_pretrained``).
        in_channels: Number of input image channels (also forwarded to ``load_pretrained``).
        **kwargs: Extra ``VisionTransformer`` arguments. Repeating one that is
            fixed here raises TypeError.

    Returns:
        The constructed ``VisionTransformer``.
    """
    cfg = default_cfgs["vit_b_32_224"]
    model = VisionTransformer(
        image_size=224,
        patch_size=32,
        embed_dim=768,
        depth=12,
        num_heads=12,
        in_channels=in_channels,
        num_classes=num_classes,
        **kwargs,
    )
    if pretrained:
        load_pretrained(model, cfg, num_classes=num_classes, in_channels=in_channels)
    return model
@register_model
def vit_b_32_384(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs):
    """Build ViT-B/32 for 384x384 inputs (patch_size=32, embed_dim=768, depth=12, num_heads=12).

    Args:
        pretrained: If True, call ``load_pretrained`` with ``default_cfgs["vit_b_32_384"]``.
            NOTE(review): that entry's ``url`` is empty, so no checkpoint is
            configured for this variant. Confirm how ``load_pretrained`` reacts.
        num_classes: Head output size (also forwarded to ``load_pretrained``).
        in_channels: Number of input image channels (also forwarded to ``load_pretrained``).
        **kwargs: Extra ``VisionTransformer`` arguments. Repeating one that is
            fixed here raises TypeError.

    Returns:
        The constructed ``VisionTransformer``.
    """
    cfg = default_cfgs["vit_b_32_384"]
    model = VisionTransformer(
        image_size=384,
        patch_size=32,
        embed_dim=768,
        depth=12,
        num_heads=12,
        in_channels=in_channels,
        num_classes=num_classes,
        **kwargs,
    )
    if pretrained:
        load_pretrained(model, cfg, num_classes=num_classes, in_channels=in_channels)
    return model
@register_model
def vit_l_32_224(pretrained: bool = False, num_classes: int = 1000, in_channels: int = 3, **kwargs):
    """Build ViT-L/32 for 224x224 inputs (patch_size=32, embed_dim=1024, depth=24, num_heads=16).

    Args:
        pretrained: If True, call ``load_pretrained`` with ``default_cfgs["vit_l_32_224"]``,
            whose ``url`` points at a MindCV checkpoint.
        num_classes: Head output size (also forwarded to ``load_pretrained``).
        in_channels: Number of input image channels (also forwarded to ``load_pretrained``).
        **kwargs: Extra ``VisionTransformer`` arguments. Repeating one that is
            fixed here raises TypeError.

    Returns:
        The constructed ``VisionTransformer``.
    """
    cfg = default_cfgs["vit_l_32_224"]
    model = VisionTransformer(
        image_size=224,
        patch_size=32,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        in_channels=in_channels,
        num_classes=num_classes,
        **kwargs,
    )
    if pretrained:
        load_pretrained(model, cfg, num_classes=num_classes, in_channels=in_channels)
    return model