-
Notifications
You must be signed in to change notification settings - Fork 5
/
modeling_finetune.py
1035 lines (879 loc) · 39.1 KB
/
modeling_finetune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# --------------------------------------------------------
# Fast-iTPN: Integrally Pre-Trained Transformer Pyramid Network with Token Migration
# Github source: https://github.com/sunsmarterjie/iTPN/tree/main/fast_itpn
# Copyright (c) 2023 University of Chinese Academy of Sciences
# Licensed under The MIT License [see LICENSE for details]
# By Yunjie Tian
# Based on EVA02, timm and deit code bases
# https://github.com/baaivision/EVA/tree/master/EVA-02
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit/
# --------------------------------------------------------
from functools import partial
import math
import torch
import torch.nn as nn
from timm.models.registry import register_model
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import to_2tuple, drop_path, trunc_normal_
from torch import Tensor, Size
from typing import Union, List
def _cfg(url='', **kwargs):
    """Return a timm-style default configuration dict for a registered model.

    Args:
        url: location of pretrained weights ('' when there are none).
        **kwargs: extra entries; a key that collides with a default
            (e.g. ``num_classes``) overrides that default in place.

    Returns:
        dict with ``url``, ``num_classes``, ``input_size``, ``pool_size``,
        ``crop_pct``, ``interpolation``, ``mean`` and ``std`` keys, followed by
        any additional keys from ``kwargs``.
    """
    config = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=None,
        crop_pct=.9,
        interpolation='bicubic',
        # per-channel normalisation statistics of 0.5 / 0.5
        mean=(0.5, 0.5, 0.5),
        std=(0.5, 0.5, 0.5),
    )
    config.update(kwargs)
    return config
# Type alias for a shape-like argument: a single int, a list of ints, or a torch.Size.
# NOTE(review): not referenced anywhere in the visible part of this file.
_shape_t = Union[int, List[int], Size]
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Thin module wrapper around timm's functional ``drop_path``.

    Args:
        drop_prob: probability of dropping a sample's residual branch while
            training. ``None`` (the default) is treated as "never drop".
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # timm's drop_path only short-circuits on ``drop_prob == 0.`` or eval
        # mode; in training it computes ``1 - drop_prob``, which raises a
        # TypeError for the default None. Treat None as a no-op instead.
        if self.drop_prob is None:
            return x
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return 'p={}'.format(self.drop_prob)
class Mlp(nn.Module):
    """Token-wise feed-forward network: ``fc1 -> act -> [norm] -> fc2 -> dropout``.

    Args:
        in_features: input channel dim.
        hidden_features: hidden dim (falls back to ``in_features`` when falsy).
        out_features: output dim (falls back to ``in_features`` when falsy).
        act_layer: activation class, instantiated without arguments.
        drop: dropout probability applied to the output.
        norm_layer: norm class used for the optional hidden-layer norm.
        subln: if True, normalise the hidden activations (``ffn_ln``);
            otherwise ``ffn_ln`` is an identity.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,
                 norm_layer=nn.LayerNorm, subln=False
                 ):
        super().__init__()
        out_dim = out_features or in_features
        hidden_dim = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.act = act_layer()
        if subln:
            self.ffn_ln = norm_layer(hidden_dim)
        else:
            self.ffn_ln = nn.Identity()
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        for layer in (self.fc1, self.act, self.ffn_ln, self.fc2, self.drop):
            x = layer(x)
        return x
class ConvMlp(nn.Module):
    """Feed-forward network for (B, C, H, W) maps built from 1x1 convolutions.

    Pipeline: ``fc1 -> act -> [channel norm] -> fc2 -> dropout``.

    Args:
        in_features: input channels.
        hidden_features: hidden channels (falls back to ``in_features`` when falsy).
        out_features: output channels (falls back to ``in_features`` when falsy).
        act_layer: activation class, instantiated without arguments.
        drop: dropout probability applied to the output.
        norm_layer: norm class for the optional hidden norm.
        subln: if True, normalise hidden activations over the channel dim;
            otherwise ``ffn_ln`` is None and skipped.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,
                 norm_layer=nn.LayerNorm, subln=False
                 ):
        super().__init__()
        out_dim = out_features or in_features
        hidden_dim = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_dim, 1)
        self.act = act_layer()
        if subln:
            self.ffn_ln = norm_layer(hidden_dim)
        else:
            self.ffn_ln = None
        self.fc2 = nn.Conv2d(hidden_dim, out_dim, 1)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.act(self.fc1(x))
        if self.ffn_ln is not None:
            # the norm acts on the last dim, so move channels last and back
            hidden = self.ffn_ln(hidden.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return self.drop(self.fc2(hidden))
class SwiGLU(nn.Module):
    """Gated feed-forward network: ``w3(norm(act(w1 x) * w2 x))`` followed by dropout.

    Args:
        in_features: input channel dim.
        hidden_features: gated hidden dim (falls back to ``in_features`` when falsy).
        out_features: output dim (falls back to ``in_features`` when falsy).
        act_layer: gate activation class (SiLU by default).
        drop: dropout probability applied to the output.
        norm_layer: norm class for the optional hidden norm.
        subln: if True, normalise the gated hidden activations; otherwise
            ``ffn_ln`` is an identity.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
                 norm_layer=nn.LayerNorm, subln=False
                 ):
        super().__init__()
        out_dim = out_features or in_features
        hidden_dim = hidden_features or in_features
        self.w1 = nn.Linear(in_features, hidden_dim)
        self.w2 = nn.Linear(in_features, hidden_dim)
        self.act = act_layer()
        if subln:
            self.ffn_ln = norm_layer(hidden_dim)
        else:
            self.ffn_ln = nn.Identity()
        self.w3 = nn.Linear(hidden_dim, out_dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        gate = self.w1(x)
        value = self.w2(x)
        gated = self.ffn_ln(self.act(gate) * value)
        return self.drop(self.w3(gated))
class ConvSwiGLU(nn.Module):
    """Gated feed-forward network for (B, C, H, W) maps built from 1x1 convolutions.

    Computes ``w3(norm(act(w1 x) * w2 x))`` followed by dropout. The optional
    hidden norm (``subln=True``) is applied over the hidden-channel dim.

    Args:
        in_features: input channels.
        hidden_features: hidden channels (falls back to ``in_features`` when falsy).
        out_features: output channels (falls back to ``in_features`` when falsy).
        act_layer: gate activation class (SiLU by default).
        drop: dropout probability applied to the output.
        norm_layer: norm class for the optional hidden norm.
        subln: enable the hidden norm; otherwise ``ffn_ln`` is an identity.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
                 norm_layer=nn.LayerNorm, subln=False
                 ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.w1 = nn.Conv2d(in_features, hidden_features, 1)
        self.w2 = nn.Conv2d(in_features, hidden_features, 1)
        self.act = act_layer()
        self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
        self.w3 = nn.Conv2d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        B, _, H, W = x.shape
        # (B, hidden, H, W) -> (B, H*W, hidden) so the norm sees channels last
        x1 = self.w1(x).flatten(2).transpose(1, 2)
        x2 = self.w2(x).flatten(2).transpose(1, 2)
        hidden = self.act(x1) * x2
        # The hidden map has `hidden_features` channels, not the input's C, so
        # infer the channel dim rather than reusing C (the old view(B, C, H, W)
        # failed whenever hidden_features != in_features).
        x = self.ffn_ln(hidden).transpose(1, 2).reshape(B, -1, H, W)
        x = self.w3(x)
        x = self.drop(x)
        return x
class Attention(nn.Module):
    """Multi-head self-attention with optional relative position bias.

    Args:
        dim: channel dim of the input/output token sequence.
        num_heads: number of attention heads.
        qkv_bias: if True, learn separate q and v biases (k is always unbiased).
        qk_scale: override for the logit scale; defaults to ``head_dim ** -0.5``.
        attn_drop: dropout on the attention probabilities.
        proj_drop: dropout after the output projection.
        window_size: (Wh, Ww) patch grid. When truthy, a per-layer relative
            position bias is created (decoupled or table-based).
        attn_head_dim: per-head dim override (default ``dim // num_heads``).
        use_decoupled_rel_pos_bias: use DecoupledRelativePositionBias instead of
            the full 2-D bias table.
        deepnorm, subln: when either is set, q/k/v use three separate
            projections instead of one fused qkv projection.
    """

    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., window_size=None,
            attn_head_dim=None, use_decoupled_rel_pos_bias=False, deepnorm=False, subln=False
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.deepnorm = deepnorm
        self.subln = subln
        if self.deepnorm or self.subln:
            # separate bias-free projections; biases come from q_bias/v_bias below
            self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
            self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
            self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
        else:
            self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None
        self.rel_pos_bias = None
        # compute q @ k^T in float32 (always on here)
        self.qk_float = True
        self.window_size = None
        self.relative_position_bias_table = None
        if window_size:
            if use_decoupled_rel_pos_bias:
                self.rel_pos_bias = DecoupledRelativePositionBias(window_size=window_size, num_heads=num_heads)
            else:
                self.window_size = window_size
                # (2*Wh-1)*(2*Ww-1) 2-D offsets plus 3 entries for a leading cls token
                self.num_relative_distance = (2 * window_size[0] - 1) * (
                        2 * window_size[1] - 1) + 3  # (2*14-1) * (2*14-1) + 3
                self.relative_position_bias_table = nn.Parameter(
                    torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
                # cls to token & token 2 cls & cls to cls
                # get pair-wise relative position index for each token inside the window
                coords_h = torch.arange(window_size[0])
                coords_w = torch.arange(window_size[1])
                coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
                coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
                relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
                relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
                relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
                relative_coords[:, :, 1] += window_size[1] - 1
                # row-major encoding of the (row, col) offset into a single table index
                relative_coords[:, :, 0] *= 2 * window_size[1] - 1
                # one extra row/column at index 0 is reserved for the cls token
                relative_position_index = \
                    torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
                relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
                # order matters: the [0, 0] entry is overwritten last (cls to cls)
                relative_position_index[0, 0:] = self.num_relative_distance - 3
                relative_position_index[0:, 0] = self.num_relative_distance - 2
                relative_position_index[0, 0] = self.num_relative_distance - 1
                self.register_buffer("relative_position_index", relative_position_index)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, rel_pos_bias=None, attn_mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: token tensor, unpacked below as (B, N, C).
            rel_pos_bias: optional extra bias added to the attention logits
                (e.g. one shared across layers); cast to the logits' dtype.
            attn_mask: optional 2-D key mask, presumably (B, N); keys where it is
                False/0 get a logit of -inf.

        Returns:
            Tensor of shape (B, N, dim).
        """
        B, N, C = x.shape
        if self.deepnorm or self.subln:
            q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
            k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
            v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
            q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)  # B, num_heads, N, C
            k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
            v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
        else:
            qkv_bias = None
            if self.q_bias is not None:
                # fused bias: q gets q_bias, k gets zeros, v gets v_bias
                qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
            qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
            qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)  # 3, B, num_heads, N, C
            q, k, v = qkv[0], qkv[1], qkv[2]
        q = q * self.scale
        if self.qk_float:
            attn = (q.float() @ k.float().transpose(-2, -1))
        else:
            attn = (q @ k.transpose(-2, -1))
        if self.relative_position_bias_table is not None:
            relative_position_bias = \
                self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                    self.window_size[0] * self.window_size[1] + 1,
                    self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
            attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)
        if self.rel_pos_bias is not None:
            attn = attn + self.rel_pos_bias().type_as(attn)
        if rel_pos_bias is not None:
            attn = attn + rel_pos_bias.type_as(attn)
        if attn_mask is not None:
            attn_mask = attn_mask.bool()
            attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
        # softmax runs in the logits' dtype, then is cast back to the input dtype
        attn = attn.softmax(dim=-1).type_as(x)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class Block(nn.Module):
    """Transformer block: optional attention branch followed by an MLP branch.

    ``num_heads == 0`` omits the attention branch entirely (``norm1`` and
    ``attn`` are None), giving a pure MLP block.

    Residual style:
      * ``postnorm``: ``x + drop_path(norm(f(x)))``;
      * ``deepnorm``: ``norm(alpha * x + drop_path(f(x)))`` with
        ``alpha = (2 * depth) ** 0.25`` (needs ``depth``). It only takes effect
        when layer scale is disabled;
      * otherwise pre-norm: ``x + drop_path(f(norm(x)))``.

    ``init_values > 0`` enables learnable per-channel layer-scale vectors
    ``gamma_1``/``gamma_2`` that multiply each branch.

    MLP variant: ``swiglu`` uses xformers' fused SwiGLU, ``naiveswiglu`` the
    local SwiGLU, and otherwise a GELU Mlp. The hidden width is
    ``int(dim * mlp_ratio)``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., init_values=None, norm_layer=nn.LayerNorm, window_size=None, attn_head_dim=None,
                 use_decoupled_rel_pos_bias=False,
                 depth=None,
                 postnorm=False,
                 deepnorm=False,
                 subln=False,
                 swiglu=False,
                 naiveswiglu=False,
                 ):
        super().__init__()
        with_attn = num_heads > 0
        self.norm1 = norm_layer(dim) if with_attn else None
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, window_size=window_size,
            use_decoupled_rel_pos_bias=use_decoupled_rel_pos_bias, attn_head_dim=attn_head_dim,
            deepnorm=deepnorm,
            subln=subln
        ) if with_attn else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        if swiglu:
            # `xops` used to be referenced here without ever being imported,
            # so swiglu=True raised NameError. Import xformers lazily so it
            # stays an optional dependency that only this option needs.
            try:
                import xformers.ops as xops
            except ImportError as err:
                raise ImportError("swiglu=True requires the xformers package") from err
            self.mlp = xops.SwiGLU(
                in_features=dim,
                hidden_features=mlp_hidden_dim
            )  # hidden_features: 2/3
        elif naiveswiglu:
            self.mlp = SwiGLU(
                in_features=dim,
                hidden_features=mlp_hidden_dim,
                subln=subln,
                norm_layer=norm_layer,
            )
        else:
            self.mlp = Mlp(
                in_features=dim,
                hidden_features=mlp_hidden_dim,
                subln=subln,
                norm_layer=norm_layer
            )
        if init_values is not None and init_values > 0:
            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),
                                        requires_grad=True) if self.attn is not None else None
            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
        else:
            self.gamma_1, self.gamma_2 = None, None
        self.deepnorm = deepnorm
        if self.deepnorm:
            self.alpha = math.pow(2.0 * depth, 0.25)
        self.postnorm = postnorm

    def forward(self, x, rel_pos_bias=None, attn_mask=None):
        """Apply the (optional) attention branch, then the MLP branch.

        ``rel_pos_bias`` and ``attn_mask`` are forwarded to the attention module.
        """
        if self.gamma_2 is None:
            if self.postnorm:
                if self.attn is not None:
                    x = x + self.drop_path(
                        self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
                x = x + self.drop_path(self.norm2(self.mlp(x)))
            elif self.deepnorm:
                if self.attn is not None:
                    residual = x
                    x = self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)
                    x = self.drop_path(x)
                    x = residual * self.alpha + x
                    x = self.norm1(x)
                residual = x
                x = self.mlp(x)
                x = self.drop_path(x)
                x = residual * self.alpha + x
                x = self.norm2(x)
            else:
                if self.attn is not None:
                    x = x + self.drop_path(
                        self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
                x = x + self.drop_path(self.mlp(self.norm2(x)))
        else:
            # layer-scale path: only post-norm or pre-norm (deepnorm is not applied here)
            if self.postnorm:
                if self.attn is not None:
                    x = x + self.drop_path(
                        self.gamma_1 * self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
                x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
            else:
                if self.attn is not None:
                    x = x + self.drop_path(
                        self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
                x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x
class ConvMlpBlock(nn.Module):
    """Residual MLP block for channels-first (B, C, H, W) feature maps.

    Mirrors ``Block`` without an attention branch. ``norm2`` is a channels-last
    norm (e.g. LayerNorm over C), so every use moves channels last and back.
    Residual styles:
      * pre-norm (default);
      * ``postnorm``;
      * ``deepnorm``, which only takes effect when layer scale is disabled.

    ``init_values > 0`` enables a per-channel layer-scale vector ``gamma_2``.
    """

    def __init__(self, dim, mlp_ratio=4., drop_path=0., init_values=None, norm_layer=nn.LayerNorm,
                 depth=None,
                 postnorm=False,
                 deepnorm=False,
                 subln=False,
                 swiglu=False,
                 naiveswiglu=False,
                 ):
        super().__init__()
        self.attn = None  # no attention branch; kept for attribute parity with Block
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        if swiglu:
            # Previously this referenced the never-imported `xops` (NameError),
            # and the fused SwiGLU operates on the last (channels-last) dim,
            # which cannot consume the (B, C, H, W) maps this block receives.
            raise NotImplementedError(
                "ConvMlpBlock does not support swiglu=True; use naiveswiglu=True instead")
        elif naiveswiglu:
            self.mlp = ConvSwiGLU(
                in_features=dim,
                hidden_features=mlp_hidden_dim,
                subln=subln,
                norm_layer=norm_layer,
            )
        else:
            self.mlp = ConvMlp(
                in_features=dim,
                hidden_features=mlp_hidden_dim,
                subln=subln,
                norm_layer=norm_layer
            )
        if init_values is not None and init_values > 0:
            # there is no attention branch, so gamma_1 is always None
            self.gamma_1 = None
            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
        else:
            self.gamma_1, self.gamma_2 = None, None
        self.deepnorm = deepnorm
        if self.deepnorm:
            self.alpha = math.pow(2.0 * depth, 0.25)
        self.postnorm = postnorm

    def _norm_channels(self, x):
        """Apply ``norm2`` over the channel dim of a (B, C, H, W) tensor."""
        return self.norm2(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

    def forward(self, x):
        """Apply the residual MLP branch to a (B, C, H, W) tensor."""
        if self.gamma_2 is None:
            if self.postnorm:
                x = x + self.drop_path(self._norm_channels(self.mlp(x)))
            elif self.deepnorm:
                residual = x
                x = self.mlp(x)
                x = self.drop_path(x)
                x = residual * self.alpha + x
                x = self._norm_channels(x)
            else:
                x = x + self.drop_path(self.mlp(self._norm_channels(x)))
        else:
            # gamma_2 has shape (C,); broadcast it over H and W, not over W
            gamma = self.gamma_2.view(-1, 1, 1)
            if self.postnorm:
                x = x + self.drop_path(gamma * self._norm_channels(self.mlp(x)))
            else:
                m = self.mlp(self._norm_channels(x))
                x = x + self.drop_path(gamma * m)
        return x
class PatchEmbed(nn.Module):
    """Embed an image into per-patch groups of inner sub-patch tokens.

    Each ``patch_size`` patch is split into ``inner_patches x inner_patches``
    sub-patches, each embedded by a strided convolution. The output has shape
    (B, num_patches, inner_patches, inner_patches, embed_dim).

    Args:
        img_size: nominal image size (int or pair); used for ``num_patches``.
        patch_size: patch size (int or pair).
        inner_patches: sub-patches per patch side.
        in_chans: input image channels.
        embed_dim: embedding dim of each sub-patch token.
        norm_layer: optional norm over the embedding dim.
    """

    def __init__(self, img_size=224, patch_size=16, inner_patches=4, in_chans=3, embed_dim=128, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        grid = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.inner_patches = inner_patches
        self.patches_resolution = self.patch_shape = grid
        self.num_patches = grid[0] * grid[1]
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        # kernel == stride == patch_size // inner_patches: one conv step per sub-patch
        kernel = [side // inner_patches for side in patch_size]
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel, stride=kernel)
        self.norm = norm_layer(embed_dim) if norm_layer is not None else None

    def forward(self, x):
        batch, _, height, width = x.shape
        grid_h = height // self.patch_size[0]
        grid_w = width // self.patch_size[1]
        k = self.inner_patches
        feats = self.proj(x)
        # (B, D, gh*k, gw*k) -> (B, D, gh, k, gw, k) -> (B, gh, gw, k, k, D)
        feats = feats.view(batch, -1, grid_h, k, grid_w, k).permute(0, 2, 4, 3, 5, 1)
        x = feats.reshape(batch, grid_h * grid_w, k, k, -1)
        if self.norm is not None:
            x = self.norm(x)
        return x
class ConvPatchEmbed(nn.Module):
    """Convolutional stem producing a channels-first (B, embed_dim, h, w) feature map.

    The stem conv uses kernel = stride = ``patch_size // inner_patches``, so each
    ``patch_size`` patch becomes an ``inner_patches x inner_patches`` group of
    feature-map positions.

    Args:
        img_size: nominal image size (int or pair); used for ``num_patches``.
        patch_size: patch size (int or pair).
        inner_patches: sub-patches per patch side.
        in_chans: input image channels.
        embed_dim: output channels.
        norm_layer: optional norm applied over the channel dim.
        stop_grad_conv1: if True, scale the gradient reaching the stem conv by 0.1.
    """

    def __init__(self, img_size=224, patch_size=16, inner_patches=4, in_chans=3, embed_dim=128, norm_layer=None,
                 stop_grad_conv1=False):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.stop_grad_conv1 = stop_grad_conv1
        self.inner_patches = inner_patches
        self.patches_resolution = self.patch_shape = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        conv_size = [size // inner_patches for size in patch_size]
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=conv_size, stride=conv_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x, bool_masked_pos=None, mask_token=None):
        """Embed ``x`` and optionally replace masked patches with ``mask_token``.

        Args:
            x: image tensor (B, C, H, W).
            bool_masked_pos: optional (B, L) mask over patches; L is the number
                of ``inner_patches x inner_patches`` groups in the feature map.
            mask_token: tensor broadcastable to (B, L, embed_dim * inner_patches**2);
                required when ``bool_masked_pos`` is given.

        Returns:
            Feature map (B, embed_dim, h, w).
        """
        B = x.shape[0]
        x = self.proj(x)
        if self.stop_grad_conv1:
            # forward value stays (numerically) x; only 10% of the gradient flows back
            x = x.detach() * 0.9 + x * 0.1
        if bool_masked_pos is not None:
            k = self.inner_patches
            feat_h, feat_w = x.shape[-2:]
            # (B, D, h, w) -> (B, L, D*k*k): one row per k x k patch group.
            # The kernel used to be hard-coded to 4 and the output size to the
            # input's H // 4, W // 4; both equal these values for the defaults.
            x = F.unfold(x, kernel_size=k, stride=k, padding=0).transpose(1, 2)
            seq_len = x.shape[1]
            mask_token = mask_token.expand(B, seq_len, -1)
            w = bool_masked_pos.unsqueeze(-1).type_as(mask_token)
            x = x * (1 - w) + mask_token * w
            x = F.fold(x.transpose(1, 2), output_size=(feat_h, feat_w), kernel_size=k, padding=0, stride=k)
        if self.norm is not None:
            # norm_layer normalises the last dim, so move channels last and back
            # (applying it directly normalised the W axis of the NCHW map)
            x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return x
class PatchMerge(nn.Module):
    """Merge each 2x2 neighbourhood of sub-patch tokens into one token.

    Operates on the two axes just before the channel axis (``[..., h, w, C]``):
    the four neighbours are concatenated to ``4 * dim`` channels, normalised,
    and linearly reduced to ``2 * dim``.

    Args:
        dim: input channel dim.
        norm_layer: norm class applied to the concatenated ``4 * dim`` features.
    """

    def __init__(self, dim, norm_layer):
        super().__init__()
        self.norm = norm_layer(dim * 4)
        self.reduction = nn.Linear(dim * 4, dim * 2, bias=False)
        self.mlp = None

    def forward(self, x):
        # gather neighbours in the order (0,0), (1,0), (0,1), (1,1) as (row, col) offsets
        quads = [x[..., row::2, col::2, :] for col in (0, 1) for row in (0, 1)]
        merged = torch.cat(quads, dim=-1)
        return self.reduction(self.norm(merged))
class ConvPatchMerge(nn.Module):
    """Downsample a (B, C, H, W) map by 2 with a channel norm and a 2x2 stride-2 conv.

    Args:
        dim: input channels; the output has ``2 * dim`` channels.
        norm_layer: channels-last norm class applied over C before the conv.
    """

    def __init__(self, dim, norm_layer):
        super().__init__()
        self.norm = norm_layer(dim)
        self.reduction = nn.Conv2d(dim, dim * 2, kernel_size=2, stride=2, padding=0)
        self.mlp = None

    def forward(self, x):
        channels_last = x.permute(0, 2, 3, 1)
        normed = self.norm(channels_last).permute(0, 3, 1, 2)
        return self.reduction(normed)
class RelativePositionBias(nn.Module):
    """Learnable 2-D relative position bias shared by the attention layers that use it.

    The table holds one bias per head for each of the (2*Wh-1)*(2*Ww-1) 2-D
    offsets, plus 3 entries for a leading cls token (cls to token, token to cls,
    cls to cls).

    Args:
        window_size: (Wh, Ww) patch grid.
        num_heads: number of attention heads.
    """

    def __init__(self, window_size, num_heads):
        super().__init__()
        self.window_size = window_size
        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
        # cls to token & token 2 cls & cls to cls
        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(window_size[0])
        coords_w = torch.arange(window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += window_size[1] - 1
        # row-major encoding of the (row, col) offset into a single table index
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
        # one extra row/column at index 0 is reserved for the cls token
        relative_position_index = \
            torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        # order matters: the [0, 0] entry is overwritten last (cls to cls)
        relative_position_index[0, 0:] = self.num_relative_distance - 3
        relative_position_index[0:, 0] = self.num_relative_distance - 2
        relative_position_index[0, 0] = self.num_relative_distance - 1
        self.register_buffer("relative_position_index", relative_position_index)

    def forward(self):
        """Return the bias as a (num_heads, Wh*Ww + 1, Wh*Ww + 1) tensor."""
        relative_position_bias = \
            self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1] + 1,
                self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
        return relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
def _mask_1d_rel_pos_index(seq_len):
    """Return the (seq_len, seq_len) matrix of 1-D relative offsets, shifted to be non-negative.

    Entry ``[i, j]`` equals ``j - i + seq_len - 1``, so values lie in
    ``[0, 2 * seq_len - 2]``.
    """
    positions = torch.arange(seq_len)
    return positions[None, :] - positions[:, None] + (seq_len - 1)
def _add_cls_to_index_matrix(index, num_tokens, offset):
    """Prepend a cls row and column to a (num_tokens, num_tokens) index matrix.

    Row 0 (cls to token) is filled with ``offset``, column 0 (token to cls) with
    ``offset + 1``, and the corner (cls to cls) with ``offset + 2``. The result
    has shape (num_tokens + 1, num_tokens + 1) and the dtype of ``index``.
    """
    body = index.contiguous().view(num_tokens, num_tokens)
    size = num_tokens + 1
    result = torch.zeros(size=(size, size), dtype=index.dtype)
    result[1:, 1:] = body
    result[0, :] = offset
    result[:, 0] = offset + 1
    result[0, 0] = offset + 2
    return result
class DecoupledRelativePositionBias(nn.Module):
    """Relative position bias factorised into separate row and column terms.

    The bias for a token pair is the sum of a row-offset entry and a
    column-offset entry. Each table has 2*W - 1 offset rows plus 3 rows for a
    leading cls token (cls to token, token to cls, cls to cls), placed by
    ``_add_cls_to_index_matrix``.

    Args:
        window_size: (Wh, Ww) patch grid.
        num_heads: number of attention heads (one bias per head).
    """

    def __init__(self, window_size, num_heads):
        super().__init__()
        self.window_size = window_size
        # (2*Wh - 1) + 3 and (2*Ww - 1) + 3 table rows
        self.num_relative_distance = (2 * window_size[0] + 2, 2 * window_size[1] + 2)
        num_tokens = window_size[0] * window_size[1]
        self.relative_position_bias_for_high = nn.Parameter(torch.zeros(self.num_relative_distance[0], num_heads))
        self.relative_position_bias_for_width = nn.Parameter(torch.zeros(self.num_relative_distance[1], num_heads))
        # cls to token & token 2 cls & cls to cls
        # Row-offset index for every (token, token) pair. It depends only on the
        # two row coordinates, so the 1-D index is broadcast over both column
        # coordinates, flattened to (Wh*Ww, Wh*Ww), and given a cls row/col.
        h_index = _mask_1d_rel_pos_index(window_size[0]).view(
            window_size[0], 1, window_size[0], 1).expand(-1, window_size[1], -1, window_size[1])
        h_index = _add_cls_to_index_matrix(h_index, num_tokens, 2 * window_size[0] - 1)
        self.register_buffer("relative_position_high_index", h_index)
        # column-offset index, built the same way along the width axis
        w_index = _mask_1d_rel_pos_index(window_size[1]).view(
            1, window_size[1], 1, window_size[1]).expand(window_size[0], -1, window_size[0], -1)
        w_index = _add_cls_to_index_matrix(w_index, num_tokens, 2 * window_size[1] - 1)
        self.register_buffer("relative_position_width_index", w_index)

    def forward(self):
        """Return the bias as a (num_heads, Wh*Ww + 1, Wh*Ww + 1) tensor."""
        relative_position_bias = \
            F.embedding(input=self.relative_position_high_index, weight=self.relative_position_bias_for_high) + \
            F.embedding(input=self.relative_position_width_index, weight=self.relative_position_bias_for_width)
        return relative_position_bias.permute(2, 0, 1).contiguous()
class Fast_iTPN(nn.Module):
    """Fast-iTPN backbone with a linear classification head.

    Three stages at strides 4, 8 and 16 with channel dims ``embed_dim // 4``,
    ``embed_dim // 2`` and ``embed_dim``:

    * Stages 1 and 2 contain ``depth_stage1`` / ``depth_stage2`` MLP-only
      blocks, and each stage is followed by a 2x patch merge.
      - ``convmlp=True``: they run on (B, C, H, W) maps (ConvPatchEmbed,
        ConvMlpBlock, ConvPatchMerge).
      - Otherwise they run on per-patch groups of sub-patch tokens (PatchEmbed,
        Block with ``num_heads=0``, PatchMerge).
    * Stage 3 contains ``depth`` attention blocks on the (B, N, embed_dim)
      token sequence.

    ``forward_features`` returns a pooled (B, embed_dim) feature and ``forward``
    applies ``head`` on top of it. Unknown ``**kwargs`` are accepted and ignored.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=512, depth_stage1=3, depth_stage2=3,
                 depth=24,
                 num_heads=8, bridge_mlp_ratio=3., mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0.0, init_values=0.1, attn_head_dim=None,
                 norm_layer=nn.LayerNorm,
                 patch_norm=False, num_classes=1000, use_mean_pooling=False,
                 init_scale=0.01,
                 cls_token=False,
                 grad_ckpt=False,
                 stop_grad_conv1=False,
                 use_abs_pos_emb=True,
                 use_rel_pos_bias=False,
                 use_shared_rel_pos_bias=False,
                 use_shared_decoupled_rel_pos_bias=False,
                 convmlp=False,
                 postnorm=False,
                 deepnorm=False,
                 subln=False,
                 swiglu=False,
                 naiveswiglu=False,
                 **kwargs):
        super().__init__()
        self.img_size = img_size
        # reset_classifier() updates this; set it here so it always exists
        self.num_classes = num_classes
        self.mlp_ratio = mlp_ratio
        self.grad_ckpt = grad_ckpt
        self.num_main_blocks = depth
        self.depth_stage1 = depth_stage1
        self.depth_stage2 = depth_stage2
        self.depth = depth
        self.patch_size = patch_size
        self.num_features = self.embed_dim = embed_dim
        self.convmlp = convmlp
        self.stop_grad_conv1 = stop_grad_conv1
        self.use_rel_pos_bias = use_rel_pos_bias
        self.use_shared_rel_pos_bias = use_shared_rel_pos_bias
        self.use_shared_decoupled_rel_pos_bias = use_shared_decoupled_rel_pos_bias
        self.use_decoupled_rel_pos_bias = False
        # channel dim per stage, keyed by stride
        mlvl_dims = {'4': embed_dim // 4, '8': embed_dim // 2, '16': embed_dim}
        # split image into non-overlapping patches
        if convmlp:
            self.patch_embed = ConvPatchEmbed(
                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=mlvl_dims['4'],
                stop_grad_conv1=stop_grad_conv1,
                norm_layer=norm_layer if patch_norm else None)
        else:
            self.patch_embed = PatchEmbed(
                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=mlvl_dims['4'],
                norm_layer=norm_layer if patch_norm else None)
        num_patches = self.patch_embed.num_patches
        if cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        else:
            self.cls_token = None
        if use_abs_pos_emb:
            if cls_token:
                self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
            else:
                self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        else:
            self.pos_embed = None
        self.pos_drop = nn.Dropout(p=drop_rate)
        if use_shared_rel_pos_bias:
            self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
        else:
            self.rel_pos_bias = None
        if use_shared_decoupled_rel_pos_bias:
            assert self.rel_pos_bias is None
            self.rel_pos_bias = DecoupledRelativePositionBias(window_size=self.patch_embed.patch_shape,
                                                              num_heads=num_heads)
        self.subln = subln
        self.swiglu = swiglu
        self.naiveswiglu = naiveswiglu
        self.build_blocks(
            depths=[depth_stage1, depth_stage2, depth],
            dims=mlvl_dims,
            num_heads=num_heads,
            bridge_mlp_ratio=bridge_mlp_ratio,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,
            drop=drop_rate,
            attn_drop=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            norm_layer=norm_layer,
            init_values=init_values,
            attn_head_dim=attn_head_dim,
            postnorm=postnorm,
            deepnorm=deepnorm,
            subln=subln,
            swiglu=swiglu,
            naiveswiglu=naiveswiglu,
            convmlp=convmlp,
        )
        self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
        self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=.02)
        if self.cls_token is not None:
            trunc_normal_(self.cls_token, std=.02)
        if isinstance(self.head, nn.Linear):
            trunc_normal_(self.head.weight, std=.02)
        self.apply(self._init_weights)
        if isinstance(self.head, nn.Linear):
            # shrink the freshly initialised head
            self.head.weight.data.mul_(init_scale)
            self.head.bias.data.mul_(init_scale)

    def build_blocks(self,
                     depths=None,
                     dims=None,
                     num_heads=8,
                     bridge_mlp_ratio=3.,
                     mlp_ratio=4.0,
                     qkv_bias=True,
                     qk_scale=None,
                     window_size=None,
                     drop=0.,
                     attn_drop=0.,
                     drop_path_rate=0.,
                     norm_layer=nn.LayerNorm,
                     init_values=0.,
                     attn_head_dim=None,
                     postnorm=False,
                     deepnorm=False,
                     subln=False,
                     swiglu=False,
                     naiveswiglu=False,
                     convmlp=False,
                     ):
        """Create ``self.blocks``.

        Layout: ``depths[0]`` stage-1 blocks, a merge, ``depths[1]`` stage-2
        blocks, a merge, then ``depths[2]`` stage-3 attention blocks.

        Stochastic-depth rates rise linearly from 0 to ``drop_path_rate`` across
        all non-merge blocks. With ``convmlp`` the bridge blocks use no layer
        scale and never use SwiGLU variants.

        ``depths`` defaults to ``[3, 3, 24]`` and ``dims`` to
        ``{'4': 32, '8': 256, '16': 512}``; ``None`` selects these defaults,
        which avoids mutable default arguments.
        """
        if depths is None:
            depths = [3, 3, 24]
        if dims is None:
            dims = {'4': 128 // 4, '8': 256, '16': 512}
        dpr = iter(x.item() for x in torch.linspace(0, drop_path_rate, depths[0] + depths[1] + depths[2]))
        self.blocks = nn.ModuleList()
        if convmlp:
            self.blocks.extend([
                ConvMlpBlock(
                    dim=dims['4'],
                    mlp_ratio=bridge_mlp_ratio,
                    drop_path=next(dpr),
                    norm_layer=norm_layer,
                    init_values=0.,
                    depth=depths[-1],
                    postnorm=postnorm,
                    deepnorm=deepnorm,
                    subln=subln,
                    swiglu=False,
                    naiveswiglu=False,
                ) for _ in range(depths[0])
            ])
            self.blocks.append(ConvPatchMerge(dims['4'], norm_layer))
            self.blocks.extend([
                ConvMlpBlock(
                    dim=dims['8'],
                    mlp_ratio=bridge_mlp_ratio,
                    drop_path=next(dpr),
                    norm_layer=norm_layer,
                    init_values=0.,
                    depth=depths[-1],
                    postnorm=postnorm,
                    deepnorm=deepnorm,
                    subln=subln,
                    swiglu=False,
                    naiveswiglu=False,
                ) for _ in range(depths[1])
            ])
            self.blocks.append(ConvPatchMerge(dims['8'], norm_layer))
        else:
            self.blocks.extend([
                Block(
                    dim=dims['4'],
                    num_heads=0,
                    mlp_ratio=bridge_mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=next(dpr),
                    norm_layer=norm_layer,
                    init_values=init_values,
                    window_size=window_size,
                    depth=depths[-1],
                    postnorm=postnorm,
                    deepnorm=deepnorm,
                    subln=subln,
                    swiglu=swiglu,
                    naiveswiglu=naiveswiglu,
                ) for _ in range(depths[0])
            ])
            self.blocks.append(PatchMerge(dims['4'], norm_layer))
            self.blocks.extend([
                Block(
                    dim=dims['8'],
                    num_heads=0,
                    mlp_ratio=bridge_mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=next(dpr),
                    norm_layer=norm_layer,
                    init_values=init_values,
                    window_size=window_size,
                    depth=depths[-1],
                    postnorm=postnorm,
                    deepnorm=deepnorm,
                    subln=subln,
                    swiglu=swiglu,
                    naiveswiglu=naiveswiglu,
                ) for _ in range(depths[1])
            ])
            self.blocks.append(PatchMerge(dims['8'], norm_layer))
        ######### stage 3 ########
        self.blocks.extend([
            Block(
                dim=dims['16'],
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=next(dpr),
                norm_layer=norm_layer,
                init_values=init_values,
                window_size=window_size,
                attn_head_dim=attn_head_dim,
                depth=depths[-1],
                postnorm=postnorm,
                deepnorm=deepnorm,
                subln=subln,
                swiglu=swiglu,
                naiveswiglu=naiveswiglu,
            ) for _ in range(depths[2])
        ])

    def _init_weights(self, m):
        """Truncated-normal Linear weights with zero bias; LayerNorm to identity."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_num_layers(self):
        """Return ``len(self.blocks)`` (patch-merge layers included)."""
        return len(self.blocks)

    @torch.jit.ignore
    def no_weight_decay(self):
        if self.cls_token is not None:
            return {'pos_embed', 'cls_token'}
        return {'pos_embed'}

    def get_classifer(self):
        return self.head

    def get_classifier(self):
        """Correctly spelled alias of ``get_classifer`` (the old name is kept for compatibility)."""
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        """Replace the head with a fresh one for ``num_classes`` outputs (``global_pool`` is unused)."""
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward_features(self, x):
        """Run the backbone on images ``x`` (B, C, H, W) and return a pooled (B, embed_dim) feature."""
        B = x.shape[0]
        x = self.patch_embed(x)
        if not self.convmlp and self.stop_grad_conv1:
            # ConvPatchEmbed applies this itself; PatchEmbed does not
            x = x.detach() * 0.9 + x * 0.1
        # self.blocks = [stage-1/2 blocks and merges] + [num_main_blocks stage-3 blocks].
        # Compute the split explicitly: with depth == 0, blocks[:-0] would be empty.
        split = len(self.blocks) - self.num_main_blocks
        for blk in self.blocks[:split]:
            x = checkpoint.checkpoint(blk, x) if self.grad_ckpt else blk(x)
        if self.convmlp:
            # (B, C, H, W) -> (B, H*W, C)
            x = x.flatten(2).transpose(1, 2)
        else:
            # With the default 4x4 inner patches the two merges leave
            # (B, N, 1, 1, C); flatten(2).transpose(1, 2) would yield (B, C, N)
            # here, so collapse the inner axes into the token axis instead.
            x = x.flatten(1, -2)
        if self.cls_token is not None:
            cls_tokens = self.cls_token.expand(B, -1, -1)
            x = torch.cat([cls_tokens, x], dim=1)
        if self.pos_embed is not None:
            x = x + self.pos_embed
        x = self.pos_drop(x)
        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
        for blk in self.blocks[split:]:
            x = checkpoint.checkpoint(blk, x, rel_pos_bias) if self.grad_ckpt else blk(x, rel_pos_bias)
        x = self.norm(x)
        if self.fc_norm is not None:
            # mean pooling over patch tokens only
            if self.cls_token is not None:
                x = x[:, 1:, :]
            return self.fc_norm(x.mean(1))
        else:
            return x.mean(1)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x
@register_model
def fast_itpn_tiny_1112_patch16_224(pretrained=False, **kwargs):
    """Build Fast-iTPN tiny (embed_dim 384, 6 heads, 1 + 1 bridge blocks, 12 attention blocks).

    Uses the convolutional stem/bridge (``convmlp``), SwiGLU MLPs in the
    attention stage (``naiveswiglu``) and sub-LN (``subln``).

    Args:
        pretrained: if True, load weights from ``kwargs["init_ckpt"]``
            (a KeyError is raised when it is missing). The checkpoint's
            ``"model"`` entry is used as the state dict.
        **kwargs: forwarded to ``Fast_iTPN`` (``init_ckpt`` included; it is
            absorbed by that constructor's ``**kwargs``).
    """
    net = Fast_iTPN(
        patch_size=16, embed_dim=384, depth_stage1=1, depth_stage2=1, depth=12,
        num_heads=6, bridge_mlp_ratio=3., mlp_ratio=3., qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        convmlp=True, naiveswiglu=True, subln=True,
        **kwargs)
    net.default_cfg = _cfg()
    if not pretrained:
        return net
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(kwargs["init_ckpt"], map_location="cpu")
    net.load_state_dict(state["model"])
    return net
@register_model
def fast_itpn_small_2220_patch16_224(pretrained=False, **kwargs):
    """Build Fast-iTPN small (embed_dim 384, 6 heads, 2 + 2 bridge blocks, 20 attention blocks).

    Uses the convolutional stem/bridge (``convmlp``), SwiGLU MLPs in the
    attention stage (``naiveswiglu``) and sub-LN (``subln``).

    Args:
        pretrained: if True, load weights from ``kwargs["init_ckpt"]``
            (a KeyError is raised when it is missing). The checkpoint's
            ``"model"`` entry is used as the state dict.
        **kwargs: forwarded to ``Fast_iTPN`` (``init_ckpt`` included; it is
            absorbed by that constructor's ``**kwargs``).
    """
    net = Fast_iTPN(
        patch_size=16, embed_dim=384, depth_stage1=2, depth_stage2=2, depth=20,
        num_heads=6, bridge_mlp_ratio=3., mlp_ratio=3., qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        convmlp=True, naiveswiglu=True, subln=True,
        **kwargs)
    net.default_cfg = _cfg()
    if not pretrained:
        return net
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(kwargs["init_ckpt"], map_location="cpu")
    net.load_state_dict(state["model"])
    return net