BiliSakura commited on
Commit
a7134ff
·
verified ·
1 Parent(s): ef71abe

Add files using upload-large-folder tool

Browse files
README.md CHANGED
@@ -34,6 +34,7 @@ This repository contains a converted EUPE checkpoint (from the original Facebook
34
  - `config.json`: architecture/config parameters
35
  - `preprocessor_config.json`: image preprocessing setup
36
  - `transformers_eupe.py`: local EUPE Transformers registration wrapper
 
37
 
38
  ## Preprocessing
39
 
 
34
  - `config.json`: architecture/config parameters
35
  - `preprocessor_config.json`: image preprocessing setup
36
  - `transformers_eupe.py`: local EUPE Transformers registration wrapper
37
+ - `eupe/`: vendored EUPE model implementation used by `transformers_eupe.py`
38
 
39
  ## Preprocessing
40
 
eupe/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Local EUPE package vendored for standalone model loading.
eupe/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (172 Bytes). View file
 
eupe/layers/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from .attention import CausalSelfAttention, LinearKMaskedBias, SelfAttention
2
+ from .block import CausalSelfAttentionBlock, SelfAttentionBlock
3
+ from .ffn_layers import Mlp, SwiGLUFFN
4
+ from .layer_scale import LayerScale
5
+ from .patch_embed import PatchEmbed
6
+ from .rms_norm import RMSNorm
7
+ from .rope_position_encoding import RopePositionEmbedding
8
+
eupe/layers/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (637 Bytes). View file
 
eupe/layers/__pycache__/attention.cpython-312.pyc ADDED
Binary file (10.7 kB). View file
 
eupe/layers/__pycache__/block.cpython-312.pyc ADDED
Binary file (12.3 kB). View file
 
eupe/layers/__pycache__/ffn_layers.cpython-312.pyc ADDED
Binary file (4.36 kB). View file
 
eupe/layers/__pycache__/layer_scale.cpython-312.pyc ADDED
Binary file (1.73 kB). View file
 
eupe/layers/__pycache__/patch_embed.cpython-312.pyc ADDED
Binary file (4.22 kB). View file
 
eupe/layers/__pycache__/rms_norm.cpython-312.pyc ADDED
Binary file (1.92 kB). View file
 
eupe/layers/__pycache__/rope_position_encoding.cpython-312.pyc ADDED
Binary file (6.14 kB). View file
 
eupe/layers/attention.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import List, Tuple
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from eupe.utils import cat_keep_shapes, uncat_with_shapes
7
+ from torch import Tensor, nn
8
+
9
+
10
+ def rope_rotate_half(x: Tensor) -> Tensor:
11
+ x1, x2 = x.chunk(2, dim=-1)
12
+ return torch.cat([-x2, x1], dim=-1)
13
+
14
+
15
+ def rope_apply(x: Tensor, sin: Tensor, cos: Tensor) -> Tensor:
16
+ return (x * cos) + (rope_rotate_half(x) * sin)
17
+
18
+
19
+ class LinearKMaskedBias(nn.Linear):
20
+ def __init__(self, *args, **kwargs):
21
+ super().__init__(*args, **kwargs)
22
+ o = self.out_features
23
+ assert o % 3 == 0
24
+ if self.bias is not None:
25
+ self.register_buffer("bias_mask", torch.full_like(self.bias, fill_value=math.nan))
26
+
27
+ def forward(self, input: Tensor) -> Tensor:
28
+ masked_bias = self.bias * self.bias_mask.to(self.bias.dtype) if self.bias is not None else None
29
+ return F.linear(input, self.weight, masked_bias)
30
+
31
+
32
+ class SelfAttention(nn.Module):
33
+ def __init__(
34
+ self,
35
+ dim: int,
36
+ num_heads: int = 8,
37
+ qkv_bias: bool = False,
38
+ proj_bias: bool = True,
39
+ attn_drop: float = 0.0,
40
+ proj_drop: float = 0.0,
41
+ mask_k_bias: bool = False,
42
+ device=None,
43
+ ) -> None:
44
+ super().__init__()
45
+ self.num_heads = num_heads
46
+ head_dim = dim // num_heads
47
+ self.scale = head_dim**-0.5
48
+
49
+ linear_class = LinearKMaskedBias if mask_k_bias else nn.Linear
50
+ self.qkv = linear_class(dim, dim * 3, bias=qkv_bias, device=device)
51
+ self.attn_drop = nn.Dropout(attn_drop)
52
+ self.proj = nn.Linear(dim, dim, bias=proj_bias, device=device)
53
+ self.proj_drop = nn.Dropout(proj_drop)
54
+
55
+ def apply_rope(self, q: Tensor, k: Tensor, rope: Tensor | Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
56
+ q_dtype = q.dtype
57
+ k_dtype = k.dtype
58
+ sin, cos = rope
59
+ rope_dtype = sin.dtype
60
+ q = q.to(dtype=rope_dtype)
61
+ k = k.to(dtype=rope_dtype)
62
+ N = q.shape[-2]
63
+ prefix = N - sin.shape[-2]
64
+ assert prefix >= 0
65
+ q_prefix = q[:, :, :prefix, :]
66
+ q = rope_apply(q[:, :, prefix:, :], sin, cos)
67
+ q = torch.cat((q_prefix, q), dim=-2)
68
+ k_prefix = k[:, :, :prefix, :]
69
+ k = rope_apply(k[:, :, prefix:, :], sin, cos)
70
+ k = torch.cat((k_prefix, k), dim=-2)
71
+ q = q.to(dtype=q_dtype)
72
+ k = k.to(dtype=k_dtype)
73
+ return q, k
74
+
75
+ def forward(self, x: Tensor, attn_bias=None, rope: Tensor = None) -> Tensor:
76
+ qkv = self.qkv(x)
77
+ attn_v = self.compute_attention(qkv=qkv, attn_bias=attn_bias, rope=rope)
78
+ x = self.proj(attn_v)
79
+ x = self.proj_drop(x)
80
+ return x
81
+
82
+ def forward_list(self, x_list, attn_bias=None, rope_list=None) -> List[Tensor]:
83
+ assert len(x_list) == len(rope_list)
84
+ x_flat, shapes, num_tokens = cat_keep_shapes(x_list)
85
+ qkv_flat = self.qkv(x_flat)
86
+ qkv_list = uncat_with_shapes(qkv_flat, shapes, num_tokens)
87
+ att_out = []
88
+ for _, (qkv, _, rope) in enumerate(zip(qkv_list, shapes, rope_list)):
89
+ att_out.append(self.compute_attention(qkv, attn_bias=attn_bias, rope=rope))
90
+ x_flat, shapes, num_tokens = cat_keep_shapes(att_out)
91
+ x_flat = self.proj(x_flat)
92
+ return uncat_with_shapes(x_flat, shapes, num_tokens)
93
+
94
+ def compute_attention(self, qkv: Tensor, attn_bias=None, rope=None) -> Tensor:
95
+ assert attn_bias is None
96
+ B, N, _ = qkv.shape
97
+ C = self.qkv.in_features
98
+
99
+ qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)
100
+ q, k, v = torch.unbind(qkv, 2)
101
+ q, k, v = [t.transpose(1, 2) for t in [q, k, v]]
102
+ if rope is not None:
103
+ q, k = self.apply_rope(q, k, rope)
104
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
105
+ x = x.transpose(1, 2)
106
+ return x.reshape([B, N, C])
107
+
108
+
109
+ class CausalSelfAttention(nn.Module):
110
+ def __init__(
111
+ self,
112
+ dim: int,
113
+ num_heads: int = 8,
114
+ qkv_bias: bool = False,
115
+ proj_bias: bool = True,
116
+ attn_drop: float = 0.0,
117
+ proj_drop: float = 0.0,
118
+ ) -> None:
119
+ super().__init__()
120
+ self.dim = dim
121
+ self.num_heads = num_heads
122
+ head_dim = dim // num_heads
123
+ self.scale = head_dim**-0.5
124
+
125
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
126
+ self.attn_drop = attn_drop
127
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
128
+ self.proj_drop = nn.Dropout(proj_drop)
129
+
130
+ def init_weights(
131
+ self, init_attn_std: float | None = None, init_proj_std: float | None = None, factor: float = 1.0
132
+ ) -> None:
133
+ init_attn_std = init_attn_std or (self.dim**-0.5)
134
+ init_proj_std = init_proj_std or init_attn_std * factor
135
+ nn.init.normal_(self.qkv.weight, std=init_attn_std)
136
+ nn.init.normal_(self.proj.weight, std=init_proj_std)
137
+ if self.qkv.bias is not None:
138
+ nn.init.zeros_(self.qkv.bias)
139
+ if self.proj.bias is not None:
140
+ nn.init.zeros_(self.proj.bias)
141
+
142
+ def forward(self, x: Tensor, is_causal: bool = True) -> Tensor:
143
+ B, N, C = x.shape
144
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
145
+ q, k, v = torch.unbind(qkv, 2)
146
+ q, k, v = [t.transpose(1, 2) for t in [q, k, v]]
147
+ x = torch.nn.functional.scaled_dot_product_attention(
148
+ q, k, v, attn_mask=None, dropout_p=self.attn_drop if self.training else 0, is_causal=is_causal
149
+ )
150
+ x = x.transpose(1, 2).contiguous().view(B, N, C)
151
+ x = self.proj_drop(self.proj(x))
152
+ return x
153
+
eupe/layers/block.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, List, Optional
2
+
3
+ import torch
4
+ from torch import Tensor, nn
5
+
6
+ from eupe.utils import cat_keep_shapes, uncat_with_shapes
7
+
8
+ from .attention import CausalSelfAttention, SelfAttention
9
+ from .ffn_layers import Mlp
10
+ from .layer_scale import LayerScale
11
+
12
+ torch._dynamo.config.automatic_dynamic_shapes = False
13
+ torch._dynamo.config.accumulated_cache_size_limit = 1024
14
+
15
+
16
+ class SelfAttentionBlock(nn.Module):
17
+ def __init__(
18
+ self,
19
+ dim: int,
20
+ num_heads: int,
21
+ ffn_ratio: float = 4.0,
22
+ qkv_bias: bool = False,
23
+ proj_bias: bool = True,
24
+ ffn_bias: bool = True,
25
+ drop: float = 0.0,
26
+ attn_drop: float = 0.0,
27
+ init_values=None,
28
+ drop_path: float = 0.0,
29
+ act_layer: Callable[..., nn.Module] = nn.GELU,
30
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
31
+ attn_class: Callable[..., nn.Module] = SelfAttention,
32
+ ffn_layer: Callable[..., nn.Module] = Mlp,
33
+ mask_k_bias: bool = False,
34
+ device=None,
35
+ ) -> None:
36
+ super().__init__()
37
+ self.norm1 = norm_layer(dim)
38
+ self.attn = attn_class(
39
+ dim,
40
+ num_heads=num_heads,
41
+ qkv_bias=qkv_bias,
42
+ proj_bias=proj_bias,
43
+ attn_drop=attn_drop,
44
+ proj_drop=drop,
45
+ mask_k_bias=mask_k_bias,
46
+ device=device,
47
+ )
48
+ self.ls1 = LayerScale(dim, init_values=init_values, device=device) if init_values else nn.Identity()
49
+
50
+ self.norm2 = norm_layer(dim)
51
+ mlp_hidden_dim = int(dim * ffn_ratio)
52
+ self.mlp = ffn_layer(
53
+ in_features=dim,
54
+ hidden_features=mlp_hidden_dim,
55
+ act_layer=act_layer,
56
+ drop=drop,
57
+ bias=ffn_bias,
58
+ device=device,
59
+ )
60
+ self.ls2 = LayerScale(dim, init_values=init_values, device=device) if init_values else nn.Identity()
61
+
62
+ self.sample_drop_ratio = drop_path
63
+
64
+ @staticmethod
65
+ def _maybe_index_rope(rope: tuple[Tensor, Tensor] | None, indices: Tensor) -> tuple[Tensor, Tensor] | None:
66
+ if rope is None:
67
+ return None
68
+
69
+ sin, cos = rope
70
+ assert sin.ndim == cos.ndim
71
+ if sin.ndim == 4:
72
+ return sin[indices], cos[indices]
73
+ else:
74
+ return sin, cos
75
+
76
+ def _forward(self, x: Tensor, rope=None) -> Tensor:
77
+ b, _, _ = x.shape
78
+ sample_subset_size = max(int(b * (1 - self.sample_drop_ratio)), 1)
79
+ residual_scale_factor = b / sample_subset_size
80
+
81
+ if self.training and self.sample_drop_ratio > 0.0:
82
+ indices_1 = (torch.randperm(b, device=x.device))[:sample_subset_size]
83
+
84
+ x_subset_1 = x[indices_1]
85
+ rope_subset = self._maybe_index_rope(rope, indices_1)
86
+ residual_1 = self.attn(self.norm1(x_subset_1), rope=rope_subset)
87
+
88
+ x_attn = torch.index_add(
89
+ x,
90
+ dim=0,
91
+ source=self.ls1(residual_1),
92
+ index=indices_1,
93
+ alpha=residual_scale_factor,
94
+ )
95
+
96
+ indices_2 = (torch.randperm(b, device=x.device))[:sample_subset_size]
97
+
98
+ x_subset_2 = x_attn[indices_2]
99
+ residual_2 = self.mlp(self.norm2(x_subset_2))
100
+
101
+ x_ffn = torch.index_add(
102
+ x_attn,
103
+ dim=0,
104
+ source=self.ls2(residual_2),
105
+ index=indices_2,
106
+ alpha=residual_scale_factor,
107
+ )
108
+ else:
109
+ x_attn = x + self.ls1(self.attn(self.norm1(x), rope=rope))
110
+ x_ffn = x_attn + self.ls2(self.mlp(self.norm2(x_attn)))
111
+
112
+ return x_ffn
113
+
114
+ def _forward_list(self, x_list: List[Tensor], rope_list=None) -> List[Tensor]:
115
+ b_list = [x.shape[0] for x in x_list]
116
+ sample_subset_sizes = [max(int(b * (1 - self.sample_drop_ratio)), 1) for b in b_list]
117
+ residual_scale_factors = [b / sample_subset_size for b, sample_subset_size in zip(b_list, sample_subset_sizes)]
118
+
119
+ if self.training and self.sample_drop_ratio > 0.0:
120
+ indices_1_list = [
121
+ (torch.randperm(b, device=x.device))[:sample_subset_size]
122
+ for x, b, sample_subset_size in zip(x_list, b_list, sample_subset_sizes)
123
+ ]
124
+ x_subset_1_list = [x[indices_1] for x, indices_1 in zip(x_list, indices_1_list)]
125
+
126
+ if rope_list is not None:
127
+ rope_subset_list = [
128
+ self._maybe_index_rope(rope, indices_1) for rope, indices_1 in zip(rope_list, indices_1_list)
129
+ ]
130
+ else:
131
+ rope_subset_list = rope_list
132
+
133
+ flattened, shapes, num_tokens = cat_keep_shapes(x_subset_1_list)
134
+ norm1 = uncat_with_shapes(self.norm1(flattened), shapes, num_tokens)
135
+ residual_1_list = self.attn.forward_list(norm1, rope_list=rope_subset_list)
136
+
137
+ x_attn_list = [
138
+ torch.index_add(
139
+ x,
140
+ dim=0,
141
+ source=self.ls1(residual_1),
142
+ index=indices_1,
143
+ alpha=residual_scale_factor,
144
+ )
145
+ for x, residual_1, indices_1, residual_scale_factor in zip(
146
+ x_list, residual_1_list, indices_1_list, residual_scale_factors
147
+ )
148
+ ]
149
+
150
+ indices_2_list = [
151
+ (torch.randperm(b, device=x.device))[:sample_subset_size]
152
+ for x, b, sample_subset_size in zip(x_list, b_list, sample_subset_sizes)
153
+ ]
154
+ x_subset_2_list = [x[indices_2] for x, indices_2 in zip(x_attn_list, indices_2_list)]
155
+ flattened, shapes, num_tokens = cat_keep_shapes(x_subset_2_list)
156
+ norm2_flat = self.norm2(flattened)
157
+ norm2_list = uncat_with_shapes(norm2_flat, shapes, num_tokens)
158
+
159
+ residual_2_list = self.mlp.forward_list(norm2_list)
160
+
161
+ x_ffn = [
162
+ torch.index_add(
163
+ x_attn,
164
+ dim=0,
165
+ source=self.ls2(residual_2),
166
+ index=indices_2,
167
+ alpha=residual_scale_factor,
168
+ )
169
+ for x_attn, residual_2, indices_2, residual_scale_factor in zip(
170
+ x_attn_list, residual_2_list, indices_2_list, residual_scale_factors
171
+ )
172
+ ]
173
+ else:
174
+ x_out = []
175
+ for x, rope in zip(x_list, rope_list):
176
+ x_attn = x + self.ls1(self.attn(self.norm1(x), rope=rope))
177
+ x_ffn = x_attn + self.ls2(self.mlp(self.norm2(x_attn)))
178
+ x_out.append(x_ffn)
179
+ x_ffn = x_out
180
+
181
+ return x_ffn
182
+
183
+ def forward(self, x_or_x_list, rope_or_rope_list=None) -> List[Tensor]:
184
+ if isinstance(x_or_x_list, Tensor):
185
+ return self._forward_list([x_or_x_list], rope_list=[rope_or_rope_list])[0]
186
+ elif isinstance(x_or_x_list, list):
187
+ if rope_or_rope_list is None:
188
+ rope_or_rope_list = [None for x in x_or_x_list]
189
+ return self._forward_list(x_or_x_list, rope_list=rope_or_rope_list)
190
+ else:
191
+ raise AssertionError
192
+
193
+
194
+ class CausalSelfAttentionBlock(nn.Module):
195
+ def __init__(
196
+ self,
197
+ dim: int,
198
+ num_heads: int,
199
+ ffn_ratio: float = 4.0,
200
+ ls_init_value: Optional[float] = None,
201
+ is_causal: bool = True,
202
+ act_layer: Callable = nn.GELU,
203
+ norm_layer: Callable = nn.LayerNorm,
204
+ dropout_prob: float = 0.0,
205
+ ):
206
+ super().__init__()
207
+
208
+ self.dim = dim
209
+ self.is_causal = is_causal
210
+ self.ls1 = LayerScale(dim, init_values=ls_init_value) if ls_init_value else nn.Identity()
211
+ self.attention_norm = norm_layer(dim)
212
+ self.attention = CausalSelfAttention(dim, num_heads, attn_drop=dropout_prob, proj_drop=dropout_prob)
213
+
214
+ self.ffn_norm = norm_layer(dim)
215
+ ffn_hidden_dim = int(dim * ffn_ratio)
216
+ self.feed_forward = Mlp(
217
+ in_features=dim,
218
+ hidden_features=ffn_hidden_dim,
219
+ drop=dropout_prob,
220
+ act_layer=act_layer,
221
+ )
222
+
223
+ self.ls2 = LayerScale(dim, init_values=ls_init_value) if ls_init_value else nn.Identity()
224
+
225
+ def init_weights(
226
+ self,
227
+ init_attn_std: float | None = None,
228
+ init_proj_std: float | None = None,
229
+ init_fc_std: float | None = None,
230
+ factor: float = 1.0,
231
+ ) -> None:
232
+ init_attn_std = init_attn_std or (self.dim**-0.5)
233
+ init_proj_std = init_proj_std or init_attn_std * factor
234
+ init_fc_std = init_fc_std or (2 * self.dim) ** -0.5
235
+ self.attention.init_weights(init_attn_std, init_proj_std)
236
+ self.attention_norm.reset_parameters()
237
+ nn.init.normal_(self.feed_forward.fc1.weight, std=init_fc_std)
238
+ nn.init.normal_(self.feed_forward.fc2.weight, std=init_proj_std)
239
+ self.ffn_norm.reset_parameters()
240
+
241
+ def forward(
242
+ self,
243
+ x: torch.Tensor,
244
+ ):
245
+
246
+ x_attn = x + self.ls1(self.attention(self.attention_norm(x), self.is_causal))
247
+ x_ffn = x_attn + self.ls2(self.feed_forward(self.ffn_norm(x_attn)))
248
+ return x_ffn
249
+
eupe/layers/ffn_layers.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, List, Optional
2
+
3
+ import torch.nn.functional as F
4
+ from torch import Tensor, nn
5
+
6
+ from eupe.utils import cat_keep_shapes, uncat_with_shapes
7
+
8
+
9
+ class ListForwardMixin(object):
10
+ def forward(self, x: Tensor):
11
+ raise NotImplementedError
12
+
13
+ def forward_list(self, x_list: List[Tensor]) -> List[Tensor]:
14
+ x_flat, shapes, num_tokens = cat_keep_shapes(x_list)
15
+ x_flat = self.forward(x_flat)
16
+ return uncat_with_shapes(x_flat, shapes, num_tokens)
17
+
18
+
19
+ class Mlp(nn.Module, ListForwardMixin):
20
+ def __init__(
21
+ self,
22
+ in_features: int,
23
+ hidden_features: Optional[int] = None,
24
+ out_features: Optional[int] = None,
25
+ act_layer: Callable[..., nn.Module] = nn.GELU,
26
+ drop: float = 0.0,
27
+ bias: bool = True,
28
+ device=None,
29
+ ) -> None:
30
+ super().__init__()
31
+ out_features = out_features or in_features
32
+ hidden_features = hidden_features or in_features
33
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, device=device)
34
+ self.act = act_layer()
35
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, device=device)
36
+ self.drop = nn.Dropout(drop)
37
+
38
+ def forward(self, x: Tensor) -> Tensor:
39
+ x = self.fc1(x)
40
+ x = self.act(x)
41
+ x = self.drop(x)
42
+ x = self.fc2(x)
43
+ x = self.drop(x)
44
+ return x
45
+
46
+
47
+ class SwiGLUFFN(nn.Module, ListForwardMixin):
48
+ def __init__(
49
+ self,
50
+ in_features: int,
51
+ hidden_features: Optional[int] = None,
52
+ out_features: Optional[int] = None,
53
+ act_layer: Optional[Callable[..., nn.Module]] = None,
54
+ drop: float = 0.0,
55
+ bias: bool = True,
56
+ align_to: int = 8,
57
+ device=None,
58
+ ) -> None:
59
+ super().__init__()
60
+ out_features = out_features or in_features
61
+ hidden_features = hidden_features or in_features
62
+ d = int(hidden_features * 2 / 3)
63
+ swiglu_hidden_features = d + (-d % align_to)
64
+ self.w1 = nn.Linear(in_features, swiglu_hidden_features, bias=bias, device=device)
65
+ self.w2 = nn.Linear(in_features, swiglu_hidden_features, bias=bias, device=device)
66
+ self.w3 = nn.Linear(swiglu_hidden_features, out_features, bias=bias, device=device)
67
+
68
+ def forward(self, x: Tensor) -> Tensor:
69
+ x1 = self.w1(x)
70
+ x2 = self.w2(x)
71
+ hidden = F.silu(x1) * x2
72
+ return self.w3(hidden)
73
+
eupe/layers/layer_scale.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union
2
+
3
+ import torch
4
+ from torch import Tensor, nn
5
+
6
+
7
+ class LayerScale(nn.Module):
8
+ def __init__(
9
+ self,
10
+ dim: int,
11
+ init_values: Union[float, Tensor] = 1e-5,
12
+ inplace: bool = False,
13
+ device=None,
14
+ ) -> None:
15
+ super().__init__()
16
+ self.inplace = inplace
17
+ self.gamma = nn.Parameter(torch.empty(dim, device=device))
18
+ self.init_values = init_values
19
+
20
+ def reset_parameters(self):
21
+ nn.init.constant_(self.gamma, self.init_values)
22
+
23
+ def forward(self, x: Tensor) -> Tensor:
24
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
25
+
eupe/layers/patch_embed.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Callable, Tuple, Union
3
+
4
+ from torch import Tensor, nn
5
+
6
+
7
+ def make_2tuple(x):
8
+ if isinstance(x, tuple):
9
+ assert len(x) == 2
10
+ return x
11
+
12
+ assert isinstance(x, int)
13
+ return (x, x)
14
+
15
+
16
+ class PatchEmbed(nn.Module):
17
+ """
18
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
19
+ """
20
+
21
+ def __init__(
22
+ self,
23
+ img_size: Union[int, Tuple[int, int]] = 224,
24
+ patch_size: Union[int, Tuple[int, int]] = 16,
25
+ in_chans: int = 3,
26
+ embed_dim: int = 768,
27
+ norm_layer: Callable | None = None,
28
+ flatten_embedding: bool = True,
29
+ ) -> None:
30
+ super().__init__()
31
+
32
+ image_HW = make_2tuple(img_size)
33
+ patch_HW = make_2tuple(patch_size)
34
+ patch_grid_size = (
35
+ image_HW[0] // patch_HW[0],
36
+ image_HW[1] // patch_HW[1],
37
+ )
38
+
39
+ self.img_size = image_HW
40
+ self.patch_size = patch_HW
41
+ self.patches_resolution = patch_grid_size
42
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
43
+
44
+ self.in_chans = in_chans
45
+ self.embed_dim = embed_dim
46
+
47
+ self.flatten_embedding = flatten_embedding
48
+
49
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
50
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
51
+
52
+ def forward(self, x: Tensor) -> Tensor:
53
+ x = self.proj(x) # B C H W
54
+ H, W = x.size(2), x.size(3)
55
+ x = x.flatten(2).transpose(1, 2) # B HW C
56
+ x = self.norm(x)
57
+ if not self.flatten_embedding:
58
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
59
+ return x
60
+
61
+ def flops(self) -> float:
62
+ Ho, Wo = self.patches_resolution
63
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
64
+ if self.norm is not None:
65
+ flops += Ho * Wo * self.embed_dim
66
+ return flops
67
+
68
+ def reset_parameters(self):
69
+ k = 1 / (self.in_chans * (self.patch_size[0] ** 2))
70
+ nn.init.uniform_(self.proj.weight, -math.sqrt(k), math.sqrt(k))
71
+ if self.proj.bias is not None:
72
+ nn.init.uniform_(self.proj.bias, -math.sqrt(k), math.sqrt(k))
73
+
eupe/layers/rms_norm.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import Tensor, nn
3
+
4
+
5
+ class RMSNorm(nn.Module):
6
+ def __init__(self, dim: int, eps: float = 1e-5):
7
+ super().__init__()
8
+ self.weight = nn.Parameter(torch.ones(dim))
9
+ self.eps = eps
10
+
11
+ def reset_parameters(self) -> None:
12
+ nn.init.constant_(self.weight, 1)
13
+
14
+ def _norm(self, x: Tensor) -> Tensor:
15
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
16
+
17
+ def forward(self, x: Tensor) -> Tensor:
18
+ output = self._norm(x.float()).type_as(x)
19
+ return output * self.weight
20
+
eupe/layers/rope_position_encoding.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Literal
3
+
4
+ import torch
5
+ from torch import Tensor, nn
6
+
7
+
8
+ class RopePositionEmbedding(nn.Module):
9
+ def __init__(
10
+ self,
11
+ embed_dim: int,
12
+ *,
13
+ num_heads: int,
14
+ base: float | None = 100.0,
15
+ min_period: float | None = None,
16
+ max_period: float | None = None,
17
+ normalize_coords: Literal["min", "max", "separate"] = "separate",
18
+ shift_coords: float | None = None,
19
+ jitter_coords: float | None = None,
20
+ rescale_coords: float | None = None,
21
+ dtype: torch.dtype | None = None,
22
+ device: torch.device | None = None,
23
+ ):
24
+ super().__init__()
25
+ assert embed_dim % (4 * num_heads) == 0
26
+ both_periods = min_period is not None and max_period is not None
27
+ if (base is None and not both_periods) or (base is not None and both_periods):
28
+ raise ValueError("Either `base` or `min_period`+`max_period` must be provided.")
29
+
30
+ D_head = embed_dim // num_heads
31
+ self.base = base
32
+ self.min_period = min_period
33
+ self.max_period = max_period
34
+ self.D_head = D_head
35
+ self.normalize_coords = normalize_coords
36
+ self.shift_coords = shift_coords
37
+ self.jitter_coords = jitter_coords
38
+ self.rescale_coords = rescale_coords
39
+
40
+ self.dtype = dtype
41
+ self.register_buffer(
42
+ "periods",
43
+ torch.empty(D_head // 4, device=device, dtype=dtype),
44
+ persistent=True,
45
+ )
46
+ self._init_weights()
47
+
48
+ def forward(self, *, H: int, W: int) -> tuple[Tensor, Tensor]:
49
+ device = self.periods.device
50
+ dtype = self.dtype
51
+ dd = {"device": device, "dtype": dtype}
52
+
53
+ if self.normalize_coords == "max":
54
+ max_HW = max(H, W)
55
+ coords_h = torch.arange(0.5, H, **dd) / max_HW
56
+ coords_w = torch.arange(0.5, W, **dd) / max_HW
57
+ elif self.normalize_coords == "min":
58
+ min_HW = min(H, W)
59
+ coords_h = torch.arange(0.5, H, **dd) / min_HW
60
+ coords_w = torch.arange(0.5, W, **dd) / min_HW
61
+ elif self.normalize_coords == "separate":
62
+ coords_h = torch.arange(0.5, H, **dd) / H
63
+ coords_w = torch.arange(0.5, W, **dd) / W
64
+ else:
65
+ raise ValueError(f"Unknown normalize_coords: {self.normalize_coords}")
66
+ coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)
67
+ coords = coords.flatten(0, 1)
68
+ coords = 2.0 * coords - 1.0
69
+
70
+ if self.training and self.shift_coords is not None:
71
+ shift_hw = torch.empty(2, **dd).uniform_(-self.shift_coords, self.shift_coords)
72
+ coords += shift_hw[None, :]
73
+
74
+ if self.training and self.jitter_coords is not None:
75
+ jitter_max = math.log(self.jitter_coords)
76
+ jitter_min = -jitter_max
77
+ jitter_hw = torch.empty(2, **dd).uniform_(jitter_min, jitter_max).exp()
78
+ coords *= jitter_hw[None, :]
79
+
80
+ if self.training and self.rescale_coords is not None:
81
+ rescale_max = math.log(self.rescale_coords)
82
+ rescale_min = -rescale_max
83
+ rescale_hw = torch.empty(1, **dd).uniform_(rescale_min, rescale_max).exp()
84
+ coords *= rescale_hw
85
+
86
+ angles = 2 * math.pi * coords[:, :, None] / self.periods[None, None, :]
87
+ angles = angles.flatten(1, 2)
88
+ angles = angles.tile(2)
89
+ cos = torch.cos(angles)
90
+ sin = torch.sin(angles)
91
+
92
+ return (sin, cos)
93
+
94
+ def _init_weights(self):
95
+ device = self.periods.device
96
+ dtype = self.dtype
97
+ if self.base is not None:
98
+ periods = self.base ** (
99
+ 2 * torch.arange(self.D_head // 4, device=device, dtype=dtype) / (self.D_head // 2)
100
+ )
101
+ else:
102
+ base = self.max_period / self.min_period
103
+ exponents = torch.linspace(0, 1, self.D_head // 4, device=device, dtype=dtype)
104
+ periods = base**exponents
105
+ periods = periods / base
106
+ periods = periods * self.max_period
107
+ self.periods.data = periods
108
+
eupe/models/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .vision_transformer import DinoVisionTransformer
2
+
eupe/models/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (246 Bytes). View file
 
eupe/models/__pycache__/vision_transformer.cpython-312.pyc ADDED
Binary file (16.7 kB). View file
 
eupe/models/vision_transformer.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from functools import partial
3
+ from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn.init
7
+ from torch import Tensor, nn
8
+
9
+ from eupe.layers import LayerScale, Mlp, PatchEmbed, RMSNorm, RopePositionEmbedding, SelfAttentionBlock, SwiGLUFFN
10
+ from eupe.utils import named_apply
11
+
12
+ logger = logging.getLogger("eupe")
13
+
14
+ ffn_layer_dict = {
15
+ "mlp": Mlp,
16
+ "swiglu": SwiGLUFFN,
17
+ "swiglu32": partial(SwiGLUFFN, align_to=32),
18
+ "swiglu64": partial(SwiGLUFFN, align_to=64),
19
+ "swiglu128": partial(SwiGLUFFN, align_to=128),
20
+ }
21
+
22
+ norm_layer_dict = {
23
+ "layernorm": partial(nn.LayerNorm, eps=1e-6),
24
+ "layernormbf16": partial(nn.LayerNorm, eps=1e-5),
25
+ "rmsnorm": RMSNorm,
26
+ }
27
+
28
+ dtype_dict = {
29
+ "fp32": torch.float32,
30
+ "fp16": torch.float16,
31
+ "bf16": torch.bfloat16,
32
+ }
33
+
34
+
35
+ def init_weights_vit(module: nn.Module, name: str = ""):
36
+ if isinstance(module, nn.Linear):
37
+ torch.nn.init.trunc_normal_(module.weight, std=0.02)
38
+ if module.bias is not None:
39
+ nn.init.zeros_(module.bias)
40
+ if hasattr(module, "bias_mask") and module.bias_mask is not None:
41
+ o = module.out_features
42
+ module.bias_mask.fill_(1)
43
+ module.bias_mask[o // 3 : 2 * o // 3].fill_(0)
44
+ if isinstance(module, nn.LayerNorm):
45
+ module.reset_parameters()
46
+ if isinstance(module, LayerScale):
47
+ module.reset_parameters()
48
+ if isinstance(module, PatchEmbed):
49
+ module.reset_parameters()
50
+ if isinstance(module, RMSNorm):
51
+ module.reset_parameters()
52
+
53
+
54
+ class DinoVisionTransformer(nn.Module):
55
+ def __init__(
56
+ self,
57
+ *,
58
+ img_size: int = 224,
59
+ patch_size: int = 16,
60
+ in_chans: int = 3,
61
+ pos_embed_rope_base: float = 100.0,
62
+ pos_embed_rope_min_period: float | None = None,
63
+ pos_embed_rope_max_period: float | None = None,
64
+ pos_embed_rope_normalize_coords: Literal["min", "max", "separate"] = "separate",
65
+ pos_embed_rope_shift_coords: float | None = None,
66
+ pos_embed_rope_jitter_coords: float | None = None,
67
+ pos_embed_rope_rescale_coords: float | None = None,
68
+ pos_embed_rope_dtype: str = "bf16",
69
+ embed_dim: int = 768,
70
+ depth: int = 12,
71
+ num_heads: int = 12,
72
+ ffn_ratio: float = 4.0,
73
+ qkv_bias: bool = True,
74
+ drop_path_rate: float = 0.0,
75
+ layerscale_init: float | None = None,
76
+ norm_layer: str = "layernorm",
77
+ ffn_layer: str = "mlp",
78
+ ffn_bias: bool = True,
79
+ proj_bias: bool = True,
80
+ n_storage_tokens: int = 0,
81
+ mask_k_bias: bool = False,
82
+ untie_cls_and_patch_norms: bool = False,
83
+ untie_global_and_local_cls_norm: bool = False,
84
+ device: Any | None = None,
85
+ **ignored_kwargs,
86
+ ):
87
+ super().__init__()
88
+ if len(ignored_kwargs) > 0:
89
+ logger.warning(f"Ignored kwargs: {ignored_kwargs}")
90
+ del ignored_kwargs
91
+
92
+ norm_layer_cls = norm_layer_dict[norm_layer]
93
+
94
+ self.num_features = self.embed_dim = embed_dim
95
+ self.n_blocks = depth
96
+ self.num_heads = num_heads
97
+ self.patch_size = patch_size
98
+
99
+ self.patch_embed = PatchEmbed(
100
+ img_size=img_size,
101
+ patch_size=patch_size,
102
+ in_chans=in_chans,
103
+ embed_dim=embed_dim,
104
+ flatten_embedding=False,
105
+ )
106
+
107
+ self.cls_token = nn.Parameter(torch.empty(1, 1, embed_dim, device=device))
108
+ self.n_storage_tokens = n_storage_tokens
109
+ if self.n_storage_tokens > 0:
110
+ self.storage_tokens = nn.Parameter(torch.empty(1, n_storage_tokens, embed_dim, device=device))
111
+ logger.info(f"using base={pos_embed_rope_base} for rope new")
112
+ logger.info(f"using min_period={pos_embed_rope_min_period} for rope new")
113
+ logger.info(f"using max_period={pos_embed_rope_max_period} for rope new")
114
+ logger.info(f"using normalize_coords={pos_embed_rope_normalize_coords} for rope new")
115
+ logger.info(f"using shift_coords={pos_embed_rope_shift_coords} for rope new")
116
+ logger.info(f"using rescale_coords={pos_embed_rope_rescale_coords} for rope new")
117
+ logger.info(f"using jitter_coords={pos_embed_rope_jitter_coords} for rope new")
118
+ logger.info(f"using dtype={pos_embed_rope_dtype} for rope new")
119
+ self.rope_embed = RopePositionEmbedding(
120
+ embed_dim=embed_dim,
121
+ num_heads=num_heads,
122
+ base=pos_embed_rope_base,
123
+ min_period=pos_embed_rope_min_period,
124
+ max_period=pos_embed_rope_max_period,
125
+ normalize_coords=pos_embed_rope_normalize_coords,
126
+ shift_coords=pos_embed_rope_shift_coords,
127
+ jitter_coords=pos_embed_rope_jitter_coords,
128
+ rescale_coords=pos_embed_rope_rescale_coords,
129
+ dtype=dtype_dict[pos_embed_rope_dtype],
130
+ device=device,
131
+ )
132
+ logger.info(f"using {ffn_layer} layer as FFN")
133
+ ffn_layer_cls = ffn_layer_dict[ffn_layer]
134
+ ffn_ratio_sequence = [ffn_ratio] * depth
135
+ blocks_list = [
136
+ SelfAttentionBlock(
137
+ dim=embed_dim,
138
+ num_heads=num_heads,
139
+ ffn_ratio=ffn_ratio_sequence[i],
140
+ qkv_bias=qkv_bias,
141
+ proj_bias=proj_bias,
142
+ ffn_bias=ffn_bias,
143
+ drop_path=drop_path_rate,
144
+ norm_layer=norm_layer_cls,
145
+ act_layer=nn.GELU,
146
+ ffn_layer=ffn_layer_cls,
147
+ init_values=layerscale_init,
148
+ mask_k_bias=mask_k_bias,
149
+ device=device,
150
+ )
151
+ for i in range(depth)
152
+ ]
153
+
154
+ self.chunked_blocks = False
155
+ self.blocks = nn.ModuleList(blocks_list)
156
+
157
+ self.norm = norm_layer_cls(embed_dim)
158
+
159
+ self.untie_cls_and_patch_norms = untie_cls_and_patch_norms
160
+ if untie_cls_and_patch_norms:
161
+ self.cls_norm = norm_layer_cls(embed_dim)
162
+ else:
163
+ self.cls_norm = None
164
+
165
+ self.untie_global_and_local_cls_norm = untie_global_and_local_cls_norm
166
+ if untie_global_and_local_cls_norm:
167
+ self.local_cls_norm = norm_layer_cls(embed_dim)
168
+ else:
169
+ self.local_cls_norm = None
170
+ self.head = nn.Identity()
171
+ self.mask_token = nn.Parameter(torch.empty(1, embed_dim, device=device))
172
+
173
+ def init_weights(self):
174
+ self.rope_embed._init_weights()
175
+ nn.init.normal_(self.cls_token, std=0.02)
176
+ if self.n_storage_tokens > 0:
177
+ nn.init.normal_(self.storage_tokens, std=0.02)
178
+ nn.init.zeros_(self.mask_token)
179
+ named_apply(init_weights_vit, self)
180
+
181
+ def prepare_tokens_with_masks(self, x: Tensor, masks=None) -> Tuple[Tensor, Tuple[int]]:
182
+ x = self.patch_embed(x)
183
+ B, H, W, _ = x.shape
184
+ x = x.flatten(1, 2)
185
+
186
+ if masks is not None:
187
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
188
+ cls_token = self.cls_token
189
+ else:
190
+ cls_token = self.cls_token + 0 * self.mask_token
191
+ if self.n_storage_tokens > 0:
192
+ storage_tokens = self.storage_tokens
193
+ else:
194
+ storage_tokens = torch.empty(
195
+ 1,
196
+ 0,
197
+ cls_token.shape[-1],
198
+ dtype=cls_token.dtype,
199
+ device=cls_token.device,
200
+ )
201
+
202
+ x = torch.cat(
203
+ [
204
+ cls_token.expand(B, -1, -1),
205
+ storage_tokens.expand(B, -1, -1),
206
+ x,
207
+ ],
208
+ dim=1,
209
+ )
210
+
211
+ return x, (H, W)
212
+
213
+ def forward_features_list(self, x_list: List[Tensor], masks_list: List[Tensor]) -> List[Dict[str, Tensor]]:
214
+ x = []
215
+ rope = []
216
+ for t_x, t_masks in zip(x_list, masks_list):
217
+ t2_x, hw_tuple = self.prepare_tokens_with_masks(t_x, t_masks)
218
+ x.append(t2_x)
219
+ rope.append(hw_tuple)
220
+ for _, blk in enumerate(self.blocks):
221
+ if self.rope_embed is not None:
222
+ rope_sincos = [self.rope_embed(H=H, W=W) for H, W in rope]
223
+ else:
224
+ rope_sincos = [None for _ in rope]
225
+ x = blk(x, rope_sincos)
226
+ all_x = x
227
+ output = []
228
+ for idx, (x, masks) in enumerate(zip(all_x, masks_list)):
229
+ if self.untie_cls_and_patch_norms or self.untie_global_and_local_cls_norm:
230
+ if self.untie_global_and_local_cls_norm and self.training and idx == 1:
231
+ x_norm_cls_reg = self.local_cls_norm(x[:, : self.n_storage_tokens + 1])
232
+ elif self.untie_cls_and_patch_norms:
233
+ x_norm_cls_reg = self.cls_norm(x[:, : self.n_storage_tokens + 1])
234
+ else:
235
+ x_norm_cls_reg = self.norm(x[:, : self.n_storage_tokens + 1])
236
+ x_norm_patch = self.norm(x[:, self.n_storage_tokens + 1 :])
237
+ else:
238
+ x_norm = self.norm(x)
239
+ x_norm_cls_reg = x_norm[:, : self.n_storage_tokens + 1]
240
+ x_norm_patch = x_norm[:, self.n_storage_tokens + 1 :]
241
+ output.append(
242
+ {
243
+ "x_norm_clstoken": x_norm_cls_reg[:, 0],
244
+ "x_storage_tokens": x_norm_cls_reg[:, 1:],
245
+ "x_norm_patchtokens": x_norm_patch,
246
+ "x_prenorm": x,
247
+ "masks": masks,
248
+ }
249
+ )
250
+ return output
251
+
252
+ def forward_features(self, x: Tensor | List[Tensor], masks: Optional[Tensor] = None) -> List[Dict[str, Tensor]]:
253
+ if isinstance(x, torch.Tensor):
254
+ return self.forward_features_list([x], [masks])[0]
255
+ else:
256
+ return self.forward_features_list(x, masks)
257
+
258
+ def _get_intermediate_layers_not_chunked(self, x: Tensor, n: int = 1) -> List[Tensor]:
259
+ x, (H, W) = self.prepare_tokens_with_masks(x)
260
+ output, total_block_len = [], len(self.blocks)
261
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
262
+ for i, blk in enumerate(self.blocks):
263
+ if self.rope_embed is not None:
264
+ rope_sincos = self.rope_embed(H=H, W=W)
265
+ else:
266
+ rope_sincos = None
267
+ x = blk(x, rope_sincos)
268
+ if i in blocks_to_take:
269
+ output.append(x)
270
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
271
+ return output
272
+
273
+ def get_intermediate_layers(
274
+ self,
275
+ x: torch.Tensor,
276
+ *,
277
+ n: Union[int, Sequence] = 1,
278
+ reshape: bool = False,
279
+ return_class_token: bool = False,
280
+ return_extra_tokens: bool = False,
281
+ norm: bool = True,
282
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor, ...]]]:
283
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
284
+ if norm:
285
+ outputs_normed = []
286
+ for out in outputs:
287
+ if self.untie_cls_and_patch_norms:
288
+ x_norm_cls_reg = self.cls_norm(out[:, : self.n_storage_tokens + 1])
289
+ x_norm_patch = self.norm(out[:, self.n_storage_tokens + 1 :])
290
+ outputs_normed.append(torch.cat((x_norm_cls_reg, x_norm_patch), dim=1))
291
+ else:
292
+ outputs_normed.append(self.norm(out))
293
+ outputs = outputs_normed
294
+ class_tokens = [out[:, 0] for out in outputs]
295
+ extra_tokens = [out[:, 1 : self.n_storage_tokens + 1] for out in outputs]
296
+ outputs = [out[:, self.n_storage_tokens + 1 :] for out in outputs]
297
+ if reshape:
298
+ B, _, h, w = x.shape
299
+ outputs = [
300
+ out.reshape(B, h // self.patch_size, w // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
301
+ for out in outputs
302
+ ]
303
+ if not return_class_token and not return_extra_tokens:
304
+ return tuple(outputs)
305
+ elif return_class_token and not return_extra_tokens:
306
+ return tuple(zip(outputs, class_tokens))
307
+ elif not return_class_token and return_extra_tokens:
308
+ return tuple(zip(outputs, extra_tokens))
309
+ elif return_class_token and return_extra_tokens:
310
+ return tuple(zip(outputs, class_tokens, extra_tokens))
311
+
312
+ def forward(self, *args, is_training: bool = False, **kwargs) -> List[Dict[str, Tensor]] | Tensor:
313
+ ret = self.forward_features(*args, **kwargs)
314
+ if is_training:
315
+ return ret
316
+ else:
317
+ return self.head(ret["x_norm_clstoken"])
318
+
eupe/utils/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .utils import cat_keep_shapes, named_apply, uncat_with_shapes
2
+
eupe/utils/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (279 Bytes). View file
 
eupe/utils/__pycache__/utils.cpython-312.pyc ADDED
Binary file (3.08 kB). View file
 
eupe/utils/utils.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from typing import Callable, List, Tuple
3
+
4
+ import numpy as np
5
+ import torch
6
+ from torch import Tensor, nn
7
+
8
+
9
+ def cat_keep_shapes(x_list: List[Tensor]) -> Tuple[Tensor, List[Tuple[int]], List[int]]:
10
+ shapes = [x.shape for x in x_list]
11
+ num_tokens = [x.select(dim=-1, index=0).numel() for x in x_list]
12
+ flattened = torch.cat([x.flatten(0, -2) for x in x_list])
13
+ return flattened, shapes, num_tokens
14
+
15
+
16
+ def uncat_with_shapes(flattened: Tensor, shapes: List[Tuple[int]], num_tokens: List[int]) -> List[Tensor]:
17
+ outputs_splitted = torch.split_with_sizes(flattened, num_tokens, dim=0)
18
+ shapes_adjusted = [shape[:-1] + torch.Size([flattened.shape[-1]]) for shape in shapes]
19
+ outputs_reshaped = [o.reshape(shape) for o, shape in zip(outputs_splitted, shapes_adjusted)]
20
+ return outputs_reshaped
21
+
22
+
23
+ def named_apply(
24
+ fn: Callable,
25
+ module: nn.Module,
26
+ name: str = "",
27
+ depth_first: bool = True,
28
+ include_root: bool = False,
29
+ ) -> nn.Module:
30
+ if not depth_first and include_root:
31
+ fn(module=module, name=name)
32
+ for child_name, child_module in module.named_children():
33
+ child_name = ".".join((name, child_name)) if name else child_name
34
+ named_apply(
35
+ fn=fn,
36
+ module=child_module,
37
+ name=child_name,
38
+ depth_first=depth_first,
39
+ include_root=True,
40
+ )
41
+ if depth_first and include_root:
42
+ fn(module=module, name=name)
43
+ return module
44
+
45
+
46
+ def fix_random_seeds(seed: int = 31):
47
+ torch.manual_seed(seed)
48
+ torch.cuda.manual_seed_all(seed)
49
+ np.random.seed(seed)
50
+ random.seed(seed)
51
+
transformers_eupe.py CHANGED
@@ -112,6 +112,7 @@ class EupeViTModel(PreTrainedModel):
112
  mask_k_bias=config.mask_k_bias,
113
  )
114
  self.vit.init_weights()
 
115
 
116
  def _init_weights(self, module: nn.Module) -> None:
117
  # Signature required by PreTrainedModel; initialization is delegated to DinoVisionTransformer.
 
112
  mask_k_bias=config.mask_k_bias,
113
  )
114
  self.vit.init_weights()
115
+ self.post_init()
116
 
117
  def _init_weights(self, module: nn.Module) -> None:
118
  # Signature required by PreTrainedModel; initialization is delegated to DinoVisionTransformer.