Instructions to use BiliSakura/EUPE-ViT-S with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use BiliSakura/EUPE-ViT-S with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline
pipe = pipeline("image-feature-extraction", model="BiliSakura/EUPE-ViT-S")

# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("BiliSakura/EUPE-ViT-S", dtype="auto")
- EUPE
How to use BiliSakura/EUPE-ViT-S with EUPE:
# No code snippets available yet for this library. # To use this model, check the repository files and the library's documentation. # Want to help? PRs adding snippets are welcome at: # https://github.com/huggingface/huggingface.js
- Notebooks
- Google Colab
- Kaggle
Add files using upload-large-folder tool
Browse files
- README.md +1 -0
- eupe/__init__.py +1 -0
- eupe/__pycache__/__init__.cpython-312.pyc +0 -0
- eupe/layers/__init__.py +8 -0
- eupe/layers/__pycache__/__init__.cpython-312.pyc +0 -0
- eupe/layers/__pycache__/attention.cpython-312.pyc +0 -0
- eupe/layers/__pycache__/block.cpython-312.pyc +0 -0
- eupe/layers/__pycache__/ffn_layers.cpython-312.pyc +0 -0
- eupe/layers/__pycache__/layer_scale.cpython-312.pyc +0 -0
- eupe/layers/__pycache__/patch_embed.cpython-312.pyc +0 -0
- eupe/layers/__pycache__/rms_norm.cpython-312.pyc +0 -0
- eupe/layers/__pycache__/rope_position_encoding.cpython-312.pyc +0 -0
- eupe/layers/attention.py +153 -0
- eupe/layers/block.py +249 -0
- eupe/layers/ffn_layers.py +73 -0
- eupe/layers/layer_scale.py +25 -0
- eupe/layers/patch_embed.py +73 -0
- eupe/layers/rms_norm.py +20 -0
- eupe/layers/rope_position_encoding.py +108 -0
- eupe/models/__init__.py +2 -0
- eupe/models/__pycache__/__init__.cpython-312.pyc +0 -0
- eupe/models/__pycache__/vision_transformer.cpython-312.pyc +0 -0
- eupe/models/vision_transformer.py +318 -0
- eupe/utils/__init__.py +2 -0
- eupe/utils/__pycache__/__init__.cpython-312.pyc +0 -0
- eupe/utils/__pycache__/utils.cpython-312.pyc +0 -0
- eupe/utils/utils.py +51 -0
- transformers_eupe.py +1 -0
README.md
CHANGED
|
@@ -34,6 +34,7 @@ This repository contains a converted EUPE checkpoint (from the original Facebook
|
|
| 34 |
- `config.json`: architecture/config parameters
|
| 35 |
- `preprocessor_config.json`: image preprocessing setup
|
| 36 |
- `transformers_eupe.py`: local EUPE Transformers registration wrapper
|
|
|
|
| 37 |
|
| 38 |
## Preprocessing
|
| 39 |
|
|
|
|
| 34 |
- `config.json`: architecture/config parameters
|
| 35 |
- `preprocessor_config.json`: image preprocessing setup
|
| 36 |
- `transformers_eupe.py`: local EUPE Transformers registration wrapper
|
| 37 |
+
- `eupe/`: vendored EUPE model implementation used by `transformers_eupe.py`
|
| 38 |
|
| 39 |
## Preprocessing
|
| 40 |
|
eupe/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Local EUPE package vendored for standalone model loading.
|
eupe/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (172 Bytes). View file
|
|
|
eupe/layers/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .attention import CausalSelfAttention, LinearKMaskedBias, SelfAttention
|
| 2 |
+
from .block import CausalSelfAttentionBlock, SelfAttentionBlock
|
| 3 |
+
from .ffn_layers import Mlp, SwiGLUFFN
|
| 4 |
+
from .layer_scale import LayerScale
|
| 5 |
+
from .patch_embed import PatchEmbed
|
| 6 |
+
from .rms_norm import RMSNorm
|
| 7 |
+
from .rope_position_encoding import RopePositionEmbedding
|
| 8 |
+
|
eupe/layers/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (637 Bytes). View file
|
|
|
eupe/layers/__pycache__/attention.cpython-312.pyc
ADDED
|
Binary file (10.7 kB). View file
|
|
|
eupe/layers/__pycache__/block.cpython-312.pyc
ADDED
|
Binary file (12.3 kB). View file
|
|
|
eupe/layers/__pycache__/ffn_layers.cpython-312.pyc
ADDED
|
Binary file (4.36 kB). View file
|
|
|
eupe/layers/__pycache__/layer_scale.cpython-312.pyc
ADDED
|
Binary file (1.73 kB). View file
|
|
|
eupe/layers/__pycache__/patch_embed.cpython-312.pyc
ADDED
|
Binary file (4.22 kB). View file
|
|
|
eupe/layers/__pycache__/rms_norm.cpython-312.pyc
ADDED
|
Binary file (1.92 kB). View file
|
|
|
eupe/layers/__pycache__/rope_position_encoding.cpython-312.pyc
ADDED
|
Binary file (6.14 kB). View file
|
|
|
eupe/layers/attention.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import List, Tuple
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from eupe.utils import cat_keep_shapes, uncat_with_shapes
|
| 7 |
+
from torch import Tensor, nn
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def rope_rotate_half(x: Tensor) -> Tensor:
    """Return ``(-x2, x1)`` where ``(x1, x2)`` are the two halves of the last dimension of ``x``."""
    first_half, second_half = torch.chunk(x, 2, dim=-1)
    rotated_parts = (second_half.neg(), first_half)
    return torch.cat(rotated_parts, dim=-1)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def rope_apply(x: Tensor, sin: Tensor, cos: Tensor) -> Tensor:
    """Rotate ``x`` with the given sine/cosine tables: ``x * cos + rotate_half(x) * sin``.

    The half-rotation (split the last dimension in two, return ``(-x2, x1)``) is
    written out inline here.
    """
    lower, upper = x.chunk(2, dim=-1)
    half_rotated = torch.cat([-upper, lower], dim=-1)
    cos_term = x * cos
    sin_term = half_rotated * sin
    return cos_term + sin_term
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class LinearKMaskedBias(nn.Linear):
    """``nn.Linear`` whose bias is multiplied element-wise by a ``bias_mask`` buffer.

    The buffer is created full of NaN, so until something overwrites it (e.g. a
    loaded state dict or an external init routine) a layer with a bias yields NaN.
    NOTE(review): the name suggests the mask zeroes the K third of a fused QKV
    bias; the code that fills the mask is not part of this block -- confirm.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The output is expected to be three equally sized chunks.
        assert self.out_features % 3 == 0
        if self.bias is None:
            return
        nan_mask = torch.full_like(self.bias, fill_value=math.nan)
        self.register_buffer("bias_mask", nan_mask)

    def forward(self, input: Tensor) -> Tensor:
        """Apply the linear map using ``bias * bias_mask`` (or no bias at all)."""
        if self.bias is None:
            return F.linear(input, self.weight, None)
        mask = self.bias_mask.to(self.bias.dtype)
        return F.linear(input, self.weight, self.bias * mask)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class SelfAttention(nn.Module):
|
| 33 |
+
def __init__(
|
| 34 |
+
self,
|
| 35 |
+
dim: int,
|
| 36 |
+
num_heads: int = 8,
|
| 37 |
+
qkv_bias: bool = False,
|
| 38 |
+
proj_bias: bool = True,
|
| 39 |
+
attn_drop: float = 0.0,
|
| 40 |
+
proj_drop: float = 0.0,
|
| 41 |
+
mask_k_bias: bool = False,
|
| 42 |
+
device=None,
|
| 43 |
+
) -> None:
|
| 44 |
+
super().__init__()
|
| 45 |
+
self.num_heads = num_heads
|
| 46 |
+
head_dim = dim // num_heads
|
| 47 |
+
self.scale = head_dim**-0.5
|
| 48 |
+
|
| 49 |
+
linear_class = LinearKMaskedBias if mask_k_bias else nn.Linear
|
| 50 |
+
self.qkv = linear_class(dim, dim * 3, bias=qkv_bias, device=device)
|
| 51 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 52 |
+
self.proj = nn.Linear(dim, dim, bias=proj_bias, device=device)
|
| 53 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 54 |
+
|
| 55 |
+
def apply_rope(self, q: Tensor, k: Tensor, rope: Tensor | Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
|
| 56 |
+
q_dtype = q.dtype
|
| 57 |
+
k_dtype = k.dtype
|
| 58 |
+
sin, cos = rope
|
| 59 |
+
rope_dtype = sin.dtype
|
| 60 |
+
q = q.to(dtype=rope_dtype)
|
| 61 |
+
k = k.to(dtype=rope_dtype)
|
| 62 |
+
N = q.shape[-2]
|
| 63 |
+
prefix = N - sin.shape[-2]
|
| 64 |
+
assert prefix >= 0
|
| 65 |
+
q_prefix = q[:, :, :prefix, :]
|
| 66 |
+
q = rope_apply(q[:, :, prefix:, :], sin, cos)
|
| 67 |
+
q = torch.cat((q_prefix, q), dim=-2)
|
| 68 |
+
k_prefix = k[:, :, :prefix, :]
|
| 69 |
+
k = rope_apply(k[:, :, prefix:, :], sin, cos)
|
| 70 |
+
k = torch.cat((k_prefix, k), dim=-2)
|
| 71 |
+
q = q.to(dtype=q_dtype)
|
| 72 |
+
k = k.to(dtype=k_dtype)
|
| 73 |
+
return q, k
|
| 74 |
+
|
| 75 |
+
def forward(self, x: Tensor, attn_bias=None, rope: Tensor = None) -> Tensor:
|
| 76 |
+
qkv = self.qkv(x)
|
| 77 |
+
attn_v = self.compute_attention(qkv=qkv, attn_bias=attn_bias, rope=rope)
|
| 78 |
+
x = self.proj(attn_v)
|
| 79 |
+
x = self.proj_drop(x)
|
| 80 |
+
return x
|
| 81 |
+
|
| 82 |
+
def forward_list(self, x_list, attn_bias=None, rope_list=None) -> List[Tensor]:
|
| 83 |
+
assert len(x_list) == len(rope_list)
|
| 84 |
+
x_flat, shapes, num_tokens = cat_keep_shapes(x_list)
|
| 85 |
+
qkv_flat = self.qkv(x_flat)
|
| 86 |
+
qkv_list = uncat_with_shapes(qkv_flat, shapes, num_tokens)
|
| 87 |
+
att_out = []
|
| 88 |
+
for _, (qkv, _, rope) in enumerate(zip(qkv_list, shapes, rope_list)):
|
| 89 |
+
att_out.append(self.compute_attention(qkv, attn_bias=attn_bias, rope=rope))
|
| 90 |
+
x_flat, shapes, num_tokens = cat_keep_shapes(att_out)
|
| 91 |
+
x_flat = self.proj(x_flat)
|
| 92 |
+
return uncat_with_shapes(x_flat, shapes, num_tokens)
|
| 93 |
+
|
| 94 |
+
def compute_attention(self, qkv: Tensor, attn_bias=None, rope=None) -> Tensor:
|
| 95 |
+
assert attn_bias is None
|
| 96 |
+
B, N, _ = qkv.shape
|
| 97 |
+
C = self.qkv.in_features
|
| 98 |
+
|
| 99 |
+
qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
| 100 |
+
q, k, v = torch.unbind(qkv, 2)
|
| 101 |
+
q, k, v = [t.transpose(1, 2) for t in [q, k, v]]
|
| 102 |
+
if rope is not None:
|
| 103 |
+
q, k = self.apply_rope(q, k, rope)
|
| 104 |
+
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
| 105 |
+
x = x.transpose(1, 2)
|
| 106 |
+
return x.reshape([B, N, C])
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class CausalSelfAttention(nn.Module):
|
| 110 |
+
def __init__(
|
| 111 |
+
self,
|
| 112 |
+
dim: int,
|
| 113 |
+
num_heads: int = 8,
|
| 114 |
+
qkv_bias: bool = False,
|
| 115 |
+
proj_bias: bool = True,
|
| 116 |
+
attn_drop: float = 0.0,
|
| 117 |
+
proj_drop: float = 0.0,
|
| 118 |
+
) -> None:
|
| 119 |
+
super().__init__()
|
| 120 |
+
self.dim = dim
|
| 121 |
+
self.num_heads = num_heads
|
| 122 |
+
head_dim = dim // num_heads
|
| 123 |
+
self.scale = head_dim**-0.5
|
| 124 |
+
|
| 125 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 126 |
+
self.attn_drop = attn_drop
|
| 127 |
+
self.proj = nn.Linear(dim, dim, bias=proj_bias)
|
| 128 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 129 |
+
|
| 130 |
+
def init_weights(
|
| 131 |
+
self, init_attn_std: float | None = None, init_proj_std: float | None = None, factor: float = 1.0
|
| 132 |
+
) -> None:
|
| 133 |
+
init_attn_std = init_attn_std or (self.dim**-0.5)
|
| 134 |
+
init_proj_std = init_proj_std or init_attn_std * factor
|
| 135 |
+
nn.init.normal_(self.qkv.weight, std=init_attn_std)
|
| 136 |
+
nn.init.normal_(self.proj.weight, std=init_proj_std)
|
| 137 |
+
if self.qkv.bias is not None:
|
| 138 |
+
nn.init.zeros_(self.qkv.bias)
|
| 139 |
+
if self.proj.bias is not None:
|
| 140 |
+
nn.init.zeros_(self.proj.bias)
|
| 141 |
+
|
| 142 |
+
def forward(self, x: Tensor, is_causal: bool = True) -> Tensor:
|
| 143 |
+
B, N, C = x.shape
|
| 144 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
| 145 |
+
q, k, v = torch.unbind(qkv, 2)
|
| 146 |
+
q, k, v = [t.transpose(1, 2) for t in [q, k, v]]
|
| 147 |
+
x = torch.nn.functional.scaled_dot_product_attention(
|
| 148 |
+
q, k, v, attn_mask=None, dropout_p=self.attn_drop if self.training else 0, is_causal=is_causal
|
| 149 |
+
)
|
| 150 |
+
x = x.transpose(1, 2).contiguous().view(B, N, C)
|
| 151 |
+
x = self.proj_drop(self.proj(x))
|
| 152 |
+
return x
|
| 153 |
+
|
eupe/layers/block.py
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Callable, List, Optional
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import Tensor, nn
|
| 5 |
+
|
| 6 |
+
from eupe.utils import cat_keep_shapes, uncat_with_shapes
|
| 7 |
+
|
| 8 |
+
from .attention import CausalSelfAttention, SelfAttention
|
| 9 |
+
from .ffn_layers import Mlp
|
| 10 |
+
from .layer_scale import LayerScale
|
| 11 |
+
|
| 12 |
+
torch._dynamo.config.automatic_dynamic_shapes = False
|
| 13 |
+
torch._dynamo.config.accumulated_cache_size_limit = 1024
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class SelfAttentionBlock(nn.Module):
    """Pre-norm transformer block: attention and FFN residual branches with optional LayerScale.

    In training mode with ``drop_path > 0`` each residual branch is evaluated on
    a random subset of the batch and added back scaled by ``batch / subset_size``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        ffn_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = SelfAttention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        mask_k_bias: bool = False,
        device=None,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            mask_k_bias=mask_k_bias,
            device=device,
        )
        # A falsy init_values (None or 0) disables LayerScale for the branch.
        self.ls1 = LayerScale(dim, init_values=init_values, device=device) if init_values else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * ffn_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
            device=device,
        )
        self.ls2 = LayerScale(dim, init_values=init_values, device=device) if init_values else nn.Identity()

        self.sample_drop_ratio = drop_path

    @staticmethod
    def _maybe_index_rope(rope: tuple[Tensor, Tensor] | None, indices: Tensor) -> tuple[Tensor, Tensor] | None:
        """Select ``indices`` along dim 0 of a ``(sin, cos)`` pair when both are 4-D.

        Pairs of any other rank are returned unchanged; ``None`` stays ``None``.
        """
        if rope is None:
            return None

        sin, cos = rope
        assert sin.ndim == cos.ndim
        if sin.ndim == 4:
            return sin[indices], cos[indices]
        else:
            return sin, cos

    def _forward(self, x: Tensor, rope=None) -> Tensor:
        """Single-tensor variant of the block (``forward`` itself routes through ``_forward_list``)."""
        b, _, _ = x.shape
        sample_subset_size = max(int(b * (1 - self.sample_drop_ratio)), 1)
        residual_scale_factor = b / sample_subset_size

        if self.training and self.sample_drop_ratio > 0.0:
            indices_1 = (torch.randperm(b, device=x.device))[:sample_subset_size]

            x_subset_1 = x[indices_1]
            rope_subset = self._maybe_index_rope(rope, indices_1)
            residual_1 = self.attn(self.norm1(x_subset_1), rope=rope_subset)

            x_attn = torch.index_add(
                x,
                dim=0,
                source=self.ls1(residual_1),
                index=indices_1,
                alpha=residual_scale_factor,
            )

            indices_2 = (torch.randperm(b, device=x.device))[:sample_subset_size]

            x_subset_2 = x_attn[indices_2]
            residual_2 = self.mlp(self.norm2(x_subset_2))

            x_ffn = torch.index_add(
                x_attn,
                dim=0,
                source=self.ls2(residual_2),
                index=indices_2,
                alpha=residual_scale_factor,
            )
        else:
            x_attn = x + self.ls1(self.attn(self.norm1(x), rope=rope))
            x_ffn = x_attn + self.ls2(self.mlp(self.norm2(x_attn)))

        return x_ffn

    def _forward_list(self, x_list: List[Tensor], rope_list=None) -> List[Tensor]:
        """Apply the block to every tensor in ``x_list``.

        ``rope_list`` holds one ``(sin, cos)`` pair or ``None`` per input. Passing
        ``None`` for the whole list means "no rope anywhere"; it is expanded up
        front so that neither the training nor the eval path iterates over ``None``.
        """
        if rope_list is None:
            rope_list = [None] * len(x_list)

        b_list = [x.shape[0] for x in x_list]
        sample_subset_sizes = [max(int(b * (1 - self.sample_drop_ratio)), 1) for b in b_list]
        residual_scale_factors = [b / sample_subset_size for b, sample_subset_size in zip(b_list, sample_subset_sizes)]

        if self.training and self.sample_drop_ratio > 0.0:
            # Attention branch on a random subset of each batch.
            indices_1_list = [
                (torch.randperm(b, device=x.device))[:sample_subset_size]
                for x, b, sample_subset_size in zip(x_list, b_list, sample_subset_sizes)
            ]
            x_subset_1_list = [x[indices_1] for x, indices_1 in zip(x_list, indices_1_list)]
            rope_subset_list = [
                self._maybe_index_rope(rope, indices_1) for rope, indices_1 in zip(rope_list, indices_1_list)
            ]

            flattened, shapes, num_tokens = cat_keep_shapes(x_subset_1_list)
            norm1 = uncat_with_shapes(self.norm1(flattened), shapes, num_tokens)
            residual_1_list = self.attn.forward_list(norm1, rope_list=rope_subset_list)

            x_attn_list = [
                torch.index_add(
                    x,
                    dim=0,
                    source=self.ls1(residual_1),
                    index=indices_1,
                    alpha=residual_scale_factor,
                )
                for x, residual_1, indices_1, residual_scale_factor in zip(
                    x_list, residual_1_list, indices_1_list, residual_scale_factors
                )
            ]

            # FFN branch on an independently drawn subset.
            indices_2_list = [
                (torch.randperm(b, device=x.device))[:sample_subset_size]
                for x, b, sample_subset_size in zip(x_list, b_list, sample_subset_sizes)
            ]
            x_subset_2_list = [x[indices_2] for x, indices_2 in zip(x_attn_list, indices_2_list)]
            flattened, shapes, num_tokens = cat_keep_shapes(x_subset_2_list)
            norm2_flat = self.norm2(flattened)
            norm2_list = uncat_with_shapes(norm2_flat, shapes, num_tokens)

            residual_2_list = self.mlp.forward_list(norm2_list)

            x_ffn = [
                torch.index_add(
                    x_attn,
                    dim=0,
                    source=self.ls2(residual_2),
                    index=indices_2,
                    alpha=residual_scale_factor,
                )
                for x_attn, residual_2, indices_2, residual_scale_factor in zip(
                    x_attn_list, residual_2_list, indices_2_list, residual_scale_factors
                )
            ]
        else:
            x_ffn = []
            for x, rope in zip(x_list, rope_list):
                x_attn = x + self.ls1(self.attn(self.norm1(x), rope=rope))
                x_ffn.append(x_attn + self.ls2(self.mlp(self.norm2(x_attn))))

        return x_ffn

    def forward(self, x_or_x_list, rope_or_rope_list=None) -> Tensor | List[Tensor]:
        """Apply the block to one tensor (returns a tensor) or a list of tensors (returns a list).

        Raises:
            AssertionError: if the input is neither a ``Tensor`` nor a ``list``.
        """
        if isinstance(x_or_x_list, Tensor):
            return self._forward_list([x_or_x_list], rope_list=[rope_or_rope_list])[0]
        if isinstance(x_or_x_list, list):
            return self._forward_list(x_or_x_list, rope_list=rope_or_rope_list)
        raise AssertionError(f"expected a Tensor or a list of Tensors, got {type(x_or_x_list).__name__}")
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
class CausalSelfAttentionBlock(nn.Module):
    """Pre-norm block: ``CausalSelfAttention`` followed by an ``Mlp``, each with a residual and optional LayerScale."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        ffn_ratio: float = 4.0,
        ls_init_value: Optional[float] = None,
        is_causal: bool = True,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
        dropout_prob: float = 0.0,
    ):
        super().__init__()

        self.dim = dim
        self.is_causal = is_causal

        def build_layer_scale() -> nn.Module:
            # A falsy ls_init_value (None or 0) means "no LayerScale".
            if ls_init_value:
                return LayerScale(dim, init_values=ls_init_value)
            return nn.Identity()

        # Attention branch.
        self.ls1 = build_layer_scale()
        self.attention_norm = norm_layer(dim)
        self.attention = CausalSelfAttention(dim, num_heads, attn_drop=dropout_prob, proj_drop=dropout_prob)

        # Feed-forward branch.
        self.ffn_norm = norm_layer(dim)
        self.feed_forward = Mlp(
            in_features=dim,
            hidden_features=int(dim * ffn_ratio),
            drop=dropout_prob,
            act_layer=act_layer,
        )

        self.ls2 = build_layer_scale()

    def init_weights(
        self,
        init_attn_std: float | None = None,
        init_proj_std: float | None = None,
        init_fc_std: float | None = None,
        factor: float = 1.0,
    ) -> None:
        """Re-initialise the attention and feed-forward weights and reset both norms.

        Falsy stds fall back to ``dim ** -0.5`` (attention), ``init_attn_std * factor``
        (projections) and ``(2 * dim) ** -0.5`` (``fc1``).
        """
        attn_std = init_attn_std if init_attn_std else self.dim**-0.5
        proj_std = init_proj_std if init_proj_std else attn_std * factor
        fc_std = init_fc_std if init_fc_std else (2 * self.dim) ** -0.5

        self.attention.init_weights(attn_std, proj_std)
        self.attention_norm.reset_parameters()
        ffn = self.feed_forward
        nn.init.normal_(ffn.fc1.weight, std=fc_std)
        nn.init.normal_(ffn.fc2.weight, std=proj_std)
        self.ffn_norm.reset_parameters()

    def forward(
        self,
        x: torch.Tensor,
    ):
        """Return ``x`` after the attention residual and then the feed-forward residual."""
        attn_branch = self.attention(self.attention_norm(x), self.is_causal)
        after_attn = x + self.ls1(attn_branch)
        ffn_branch = self.feed_forward(self.ffn_norm(after_attn))
        return after_attn + self.ls2(ffn_branch)
|
| 249 |
+
|
eupe/layers/ffn_layers.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Callable, List, Optional
|
| 2 |
+
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from torch import Tensor, nn
|
| 5 |
+
|
| 6 |
+
from eupe.utils import cat_keep_shapes, uncat_with_shapes
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class ListForwardMixin:
    """Mixin that adds ``forward_list`` on top of a subclass-provided ``forward``."""

    def forward(self, x: Tensor):
        """Placeholder; concrete classes must override it."""
        raise NotImplementedError

    def forward_list(self, x_list: List[Tensor]) -> List[Tensor]:
        """Concatenate the inputs, run ``forward`` once, and split the result back per input."""
        flat, shapes, num_tokens = cat_keep_shapes(x_list)
        return uncat_with_shapes(self.forward(flat), shapes, num_tokens)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class Mlp(nn.Module, ListForwardMixin):
    """Two-layer feed-forward network: ``fc1 -> act -> drop -> fc2 -> drop``."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
        device=None,
    ) -> None:
        super().__init__()
        # Falsy widths (None or 0) default to the input width.
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        linear_kwargs = {"bias": bias, "device": device}
        self.fc1 = nn.Linear(in_features, hidden_features, **linear_kwargs)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, **linear_kwargs)
        # One dropout module is reused after the activation and after fc2.
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the network to ``x``."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class SwiGLUFFN(nn.Module, ListForwardMixin):
    """Gated feed-forward network computing ``w3(silu(w1(x)) * w2(x))``.

    The inner width is ``int(hidden_features * 2 / 3)`` rounded up to a multiple
    of ``align_to``. ``act_layer`` and ``drop`` are accepted but not used.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Optional[Callable[..., nn.Module]] = None,
        drop: float = 0.0,
        bias: bool = True,
        align_to: int = 8,
        device=None,
    ) -> None:
        super().__init__()
        # Falsy widths (None or 0) default to the input width.
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        base_width = int(hidden_features * 2 / 3)
        padding = -base_width % align_to
        inner_width = base_width + padding
        self.w1 = nn.Linear(in_features, inner_width, bias=bias, device=device)
        self.w2 = nn.Linear(in_features, inner_width, bias=bias, device=device)
        self.w3 = nn.Linear(inner_width, out_features, bias=bias, device=device)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the gated network to ``x``."""
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.w3(gate * value)
|
| 73 |
+
|
eupe/layers/layer_scale.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Union
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import Tensor, nn
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class LayerScale(nn.Module):
    """Learnable per-channel scale: multiplies the input by a ``gamma`` vector of length ``dim``."""

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
        device=None,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(torch.empty(dim, device=device))
        self.init_values = init_values
        # torch.empty leaves gamma holding arbitrary memory; fill it right away so
        # a module that never sees reset_parameters() or a checkpoint is still
        # well-defined.
        self.reset_parameters()

    def reset_parameters(self):
        """Set ``gamma`` to ``init_values`` (a scalar, or a tensor broadcastable to ``gamma``)."""
        if isinstance(self.init_values, Tensor):
            with torch.no_grad():
                self.gamma.copy_(self.init_values)
        else:
            nn.init.constant_(self.gamma, self.init_values)

    def forward(self, x: Tensor) -> Tensor:
        """Scale ``x`` by ``gamma``; with ``inplace=True`` the input tensor itself is modified."""
        return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
| 25 |
+
|
eupe/layers/patch_embed.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import Callable, Tuple, Union
|
| 3 |
+
|
| 4 |
+
from torch import Tensor, nn
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def make_2tuple(x):
    """Return ``x`` unchanged if it is a 2-tuple, or ``(x, x)`` if it is an int.

    Anything else (including tuples of another length) fails an assertion.
    """
    if not isinstance(x, tuple):
        assert isinstance(x, int)
        return (x, x)
    assert len(x) == 2
    return x
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class PatchEmbed(nn.Module):
|
| 17 |
+
"""
|
| 18 |
+
2D image to patch embedding: (B,C,H,W) -> (B,N,D)
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
def __init__(
|
| 22 |
+
self,
|
| 23 |
+
img_size: Union[int, Tuple[int, int]] = 224,
|
| 24 |
+
patch_size: Union[int, Tuple[int, int]] = 16,
|
| 25 |
+
in_chans: int = 3,
|
| 26 |
+
embed_dim: int = 768,
|
| 27 |
+
norm_layer: Callable | None = None,
|
| 28 |
+
flatten_embedding: bool = True,
|
| 29 |
+
) -> None:
|
| 30 |
+
super().__init__()
|
| 31 |
+
|
| 32 |
+
image_HW = make_2tuple(img_size)
|
| 33 |
+
patch_HW = make_2tuple(patch_size)
|
| 34 |
+
patch_grid_size = (
|
| 35 |
+
image_HW[0] // patch_HW[0],
|
| 36 |
+
image_HW[1] // patch_HW[1],
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
self.img_size = image_HW
|
| 40 |
+
self.patch_size = patch_HW
|
| 41 |
+
self.patches_resolution = patch_grid_size
|
| 42 |
+
self.num_patches = patch_grid_size[0] * patch_grid_size[1]
|
| 43 |
+
|
| 44 |
+
self.in_chans = in_chans
|
| 45 |
+
self.embed_dim = embed_dim
|
| 46 |
+
|
| 47 |
+
self.flatten_embedding = flatten_embedding
|
| 48 |
+
|
| 49 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
|
| 50 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
| 51 |
+
|
| 52 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 53 |
+
x = self.proj(x) # B C H W
|
| 54 |
+
H, W = x.size(2), x.size(3)
|
| 55 |
+
x = x.flatten(2).transpose(1, 2) # B HW C
|
| 56 |
+
x = self.norm(x)
|
| 57 |
+
if not self.flatten_embedding:
|
| 58 |
+
x = x.reshape(-1, H, W, self.embed_dim) # B H W C
|
| 59 |
+
return x
|
| 60 |
+
|
| 61 |
+
def flops(self) -> float:
|
| 62 |
+
Ho, Wo = self.patches_resolution
|
| 63 |
+
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
|
| 64 |
+
if self.norm is not None:
|
| 65 |
+
flops += Ho * Wo * self.embed_dim
|
| 66 |
+
return flops
|
| 67 |
+
|
| 68 |
+
def reset_parameters(self):
|
| 69 |
+
k = 1 / (self.in_chans * (self.patch_size[0] ** 2))
|
| 70 |
+
nn.init.uniform_(self.proj.weight, -math.sqrt(k), math.sqrt(k))
|
| 71 |
+
if self.proj.bias is not None:
|
| 72 |
+
nn.init.uniform_(self.proj.bias, -math.sqrt(k), math.sqrt(k))
|
| 73 |
+
|
eupe/layers/rms_norm.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import Tensor, nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable per-channel weight."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def reset_parameters(self) -> None:
        """Restore the weight to all ones."""
        nn.init.ones_(self.weight)

    def _norm(self, x: Tensor) -> Tensor:
        """Divide ``x`` by ``sqrt(mean(x**2) + eps)`` computed over the last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self, x: Tensor) -> Tensor:
        """Normalise in float32, cast back to the dtype of ``x``, then apply the weight."""
        normalised = self._norm(x.float()).type_as(x)
        return normalised * self.weight
|
| 20 |
+
|
eupe/layers/rope_position_encoding.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import Literal
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import Tensor, nn
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class RopePositionEmbedding(nn.Module):
|
| 9 |
+
def __init__(
|
| 10 |
+
self,
|
| 11 |
+
embed_dim: int,
|
| 12 |
+
*,
|
| 13 |
+
num_heads: int,
|
| 14 |
+
base: float | None = 100.0,
|
| 15 |
+
min_period: float | None = None,
|
| 16 |
+
max_period: float | None = None,
|
| 17 |
+
normalize_coords: Literal["min", "max", "separate"] = "separate",
|
| 18 |
+
shift_coords: float | None = None,
|
| 19 |
+
jitter_coords: float | None = None,
|
| 20 |
+
rescale_coords: float | None = None,
|
| 21 |
+
dtype: torch.dtype | None = None,
|
| 22 |
+
device: torch.device | None = None,
|
| 23 |
+
):
|
| 24 |
+
super().__init__()
|
| 25 |
+
assert embed_dim % (4 * num_heads) == 0
|
| 26 |
+
both_periods = min_period is not None and max_period is not None
|
| 27 |
+
if (base is None and not both_periods) or (base is not None and both_periods):
|
| 28 |
+
raise ValueError("Either `base` or `min_period`+`max_period` must be provided.")
|
| 29 |
+
|
| 30 |
+
D_head = embed_dim // num_heads
|
| 31 |
+
self.base = base
|
| 32 |
+
self.min_period = min_period
|
| 33 |
+
self.max_period = max_period
|
| 34 |
+
self.D_head = D_head
|
| 35 |
+
self.normalize_coords = normalize_coords
|
| 36 |
+
self.shift_coords = shift_coords
|
| 37 |
+
self.jitter_coords = jitter_coords
|
| 38 |
+
self.rescale_coords = rescale_coords
|
| 39 |
+
|
| 40 |
+
self.dtype = dtype
|
| 41 |
+
self.register_buffer(
|
| 42 |
+
"periods",
|
| 43 |
+
torch.empty(D_head // 4, device=device, dtype=dtype),
|
| 44 |
+
persistent=True,
|
| 45 |
+
)
|
| 46 |
+
self._init_weights()
|
| 47 |
+
|
| 48 |
+
def forward(self, *, H: int, W: int) -> tuple[Tensor, Tensor]:
|
| 49 |
+
device = self.periods.device
|
| 50 |
+
dtype = self.dtype
|
| 51 |
+
dd = {"device": device, "dtype": dtype}
|
| 52 |
+
|
| 53 |
+
if self.normalize_coords == "max":
|
| 54 |
+
max_HW = max(H, W)
|
| 55 |
+
coords_h = torch.arange(0.5, H, **dd) / max_HW
|
| 56 |
+
coords_w = torch.arange(0.5, W, **dd) / max_HW
|
| 57 |
+
elif self.normalize_coords == "min":
|
| 58 |
+
min_HW = min(H, W)
|
| 59 |
+
coords_h = torch.arange(0.5, H, **dd) / min_HW
|
| 60 |
+
coords_w = torch.arange(0.5, W, **dd) / min_HW
|
| 61 |
+
elif self.normalize_coords == "separate":
|
| 62 |
+
coords_h = torch.arange(0.5, H, **dd) / H
|
| 63 |
+
coords_w = torch.arange(0.5, W, **dd) / W
|
| 64 |
+
else:
|
| 65 |
+
raise ValueError(f"Unknown normalize_coords: {self.normalize_coords}")
|
| 66 |
+
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)
|
| 67 |
+
coords = coords.flatten(0, 1)
|
| 68 |
+
coords = 2.0 * coords - 1.0
|
| 69 |
+
|
| 70 |
+
if self.training and self.shift_coords is not None:
|
| 71 |
+
shift_hw = torch.empty(2, **dd).uniform_(-self.shift_coords, self.shift_coords)
|
| 72 |
+
coords += shift_hw[None, :]
|
| 73 |
+
|
| 74 |
+
if self.training and self.jitter_coords is not None:
|
| 75 |
+
jitter_max = math.log(self.jitter_coords)
|
| 76 |
+
jitter_min = -jitter_max
|
| 77 |
+
jitter_hw = torch.empty(2, **dd).uniform_(jitter_min, jitter_max).exp()
|
| 78 |
+
coords *= jitter_hw[None, :]
|
| 79 |
+
|
| 80 |
+
if self.training and self.rescale_coords is not None:
|
| 81 |
+
rescale_max = math.log(self.rescale_coords)
|
| 82 |
+
rescale_min = -rescale_max
|
| 83 |
+
rescale_hw = torch.empty(1, **dd).uniform_(rescale_min, rescale_max).exp()
|
| 84 |
+
coords *= rescale_hw
|
| 85 |
+
|
| 86 |
+
angles = 2 * math.pi * coords[:, :, None] / self.periods[None, None, :]
|
| 87 |
+
angles = angles.flatten(1, 2)
|
| 88 |
+
angles = angles.tile(2)
|
| 89 |
+
cos = torch.cos(angles)
|
| 90 |
+
sin = torch.sin(angles)
|
| 91 |
+
|
| 92 |
+
return (sin, cos)
|
| 93 |
+
|
| 94 |
+
def _init_weights(self):
|
| 95 |
+
device = self.periods.device
|
| 96 |
+
dtype = self.dtype
|
| 97 |
+
if self.base is not None:
|
| 98 |
+
periods = self.base ** (
|
| 99 |
+
2 * torch.arange(self.D_head // 4, device=device, dtype=dtype) / (self.D_head // 2)
|
| 100 |
+
)
|
| 101 |
+
else:
|
| 102 |
+
base = self.max_period / self.min_period
|
| 103 |
+
exponents = torch.linspace(0, 1, self.D_head // 4, device=device, dtype=dtype)
|
| 104 |
+
periods = base**exponents
|
| 105 |
+
periods = periods / base
|
| 106 |
+
periods = periods * self.max_period
|
| 107 |
+
self.periods.data = periods
|
| 108 |
+
|
eupe/models/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .vision_transformer import DinoVisionTransformer
|
| 2 |
+
|
eupe/models/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (246 Bytes). View file
|
|
|
eupe/models/__pycache__/vision_transformer.cpython-312.pyc
ADDED
|
Binary file (16.7 kB). View file
|
|
|
eupe/models/vision_transformer.py
ADDED
|
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from functools import partial
|
| 3 |
+
from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.init
|
| 7 |
+
from torch import Tensor, nn
|
| 8 |
+
|
| 9 |
+
from eupe.layers import LayerScale, Mlp, PatchEmbed, RMSNorm, RopePositionEmbedding, SelfAttentionBlock, SwiGLUFFN
|
| 10 |
+
from eupe.utils import named_apply
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger("eupe")
|
| 13 |
+
|
| 14 |
+
# Feed-forward block factories, keyed by the `ffn_layer` constructor argument.
# The "swigluN" entries are SwiGLUFFN with `align_to=N` pre-bound.
ffn_layer_dict = {
    "mlp": Mlp,
    "swiglu": SwiGLUFFN,
    **{f"swiglu{align}": partial(SwiGLUFFN, align_to=align) for align in (32, 64, 128)},
}

# Normalization layer factories, keyed by the `norm_layer` constructor argument.
norm_layer_dict = dict(
    layernorm=partial(nn.LayerNorm, eps=1e-6),
    layernormbf16=partial(nn.LayerNorm, eps=1e-5),
    rmsnorm=RMSNorm,
)

# Short dtype names accepted for `pos_embed_rope_dtype`.
dtype_dict = dict(
    fp32=torch.float32,
    fp16=torch.float16,
    bf16=torch.bfloat16,
)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def init_weights_vit(module: nn.Module, name: str = ""):
    """Initialise a single sub-module; meant to be used with ``named_apply``.

    ``nn.Linear`` layers get a truncated-normal weight (std 0.02) and a zero
    bias; if the layer carries a non-``None`` ``bias_mask`` it is set to 1
    everywhere except the middle third of the output features, which is set
    to 0. Normalization, layer-scale and patch-embedding modules delegate to
    their own ``reset_parameters``.

    Args:
        module: The module to initialise in place.
        name: Qualified module name (unused; part of the ``named_apply`` callback signature).
    """
    if isinstance(module, nn.Linear):
        torch.nn.init.trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
        bias_mask = getattr(module, "bias_mask", None)
        if bias_mask is not None:
            third = module.out_features
            bias_mask.fill_(1)
            bias_mask[third // 3 : 2 * third // 3].fill_(0)

    # Each matching type triggers its own reset, in the same order as checked.
    for resettable in (nn.LayerNorm, LayerScale, PatchEmbed, RMSNorm):
        if isinstance(module, resettable):
            module.reset_parameters()
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class DinoVisionTransformer(nn.Module):
|
| 55 |
+
def __init__(
|
| 56 |
+
self,
|
| 57 |
+
*,
|
| 58 |
+
img_size: int = 224,
|
| 59 |
+
patch_size: int = 16,
|
| 60 |
+
in_chans: int = 3,
|
| 61 |
+
pos_embed_rope_base: float = 100.0,
|
| 62 |
+
pos_embed_rope_min_period: float | None = None,
|
| 63 |
+
pos_embed_rope_max_period: float | None = None,
|
| 64 |
+
pos_embed_rope_normalize_coords: Literal["min", "max", "separate"] = "separate",
|
| 65 |
+
pos_embed_rope_shift_coords: float | None = None,
|
| 66 |
+
pos_embed_rope_jitter_coords: float | None = None,
|
| 67 |
+
pos_embed_rope_rescale_coords: float | None = None,
|
| 68 |
+
pos_embed_rope_dtype: str = "bf16",
|
| 69 |
+
embed_dim: int = 768,
|
| 70 |
+
depth: int = 12,
|
| 71 |
+
num_heads: int = 12,
|
| 72 |
+
ffn_ratio: float = 4.0,
|
| 73 |
+
qkv_bias: bool = True,
|
| 74 |
+
drop_path_rate: float = 0.0,
|
| 75 |
+
layerscale_init: float | None = None,
|
| 76 |
+
norm_layer: str = "layernorm",
|
| 77 |
+
ffn_layer: str = "mlp",
|
| 78 |
+
ffn_bias: bool = True,
|
| 79 |
+
proj_bias: bool = True,
|
| 80 |
+
n_storage_tokens: int = 0,
|
| 81 |
+
mask_k_bias: bool = False,
|
| 82 |
+
untie_cls_and_patch_norms: bool = False,
|
| 83 |
+
untie_global_and_local_cls_norm: bool = False,
|
| 84 |
+
device: Any | None = None,
|
| 85 |
+
**ignored_kwargs,
|
| 86 |
+
):
|
| 87 |
+
super().__init__()
|
| 88 |
+
if len(ignored_kwargs) > 0:
|
| 89 |
+
logger.warning(f"Ignored kwargs: {ignored_kwargs}")
|
| 90 |
+
del ignored_kwargs
|
| 91 |
+
|
| 92 |
+
norm_layer_cls = norm_layer_dict[norm_layer]
|
| 93 |
+
|
| 94 |
+
self.num_features = self.embed_dim = embed_dim
|
| 95 |
+
self.n_blocks = depth
|
| 96 |
+
self.num_heads = num_heads
|
| 97 |
+
self.patch_size = patch_size
|
| 98 |
+
|
| 99 |
+
self.patch_embed = PatchEmbed(
|
| 100 |
+
img_size=img_size,
|
| 101 |
+
patch_size=patch_size,
|
| 102 |
+
in_chans=in_chans,
|
| 103 |
+
embed_dim=embed_dim,
|
| 104 |
+
flatten_embedding=False,
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
self.cls_token = nn.Parameter(torch.empty(1, 1, embed_dim, device=device))
|
| 108 |
+
self.n_storage_tokens = n_storage_tokens
|
| 109 |
+
if self.n_storage_tokens > 0:
|
| 110 |
+
self.storage_tokens = nn.Parameter(torch.empty(1, n_storage_tokens, embed_dim, device=device))
|
| 111 |
+
logger.info(f"using base={pos_embed_rope_base} for rope new")
|
| 112 |
+
logger.info(f"using min_period={pos_embed_rope_min_period} for rope new")
|
| 113 |
+
logger.info(f"using max_period={pos_embed_rope_max_period} for rope new")
|
| 114 |
+
logger.info(f"using normalize_coords={pos_embed_rope_normalize_coords} for rope new")
|
| 115 |
+
logger.info(f"using shift_coords={pos_embed_rope_shift_coords} for rope new")
|
| 116 |
+
logger.info(f"using rescale_coords={pos_embed_rope_rescale_coords} for rope new")
|
| 117 |
+
logger.info(f"using jitter_coords={pos_embed_rope_jitter_coords} for rope new")
|
| 118 |
+
logger.info(f"using dtype={pos_embed_rope_dtype} for rope new")
|
| 119 |
+
self.rope_embed = RopePositionEmbedding(
|
| 120 |
+
embed_dim=embed_dim,
|
| 121 |
+
num_heads=num_heads,
|
| 122 |
+
base=pos_embed_rope_base,
|
| 123 |
+
min_period=pos_embed_rope_min_period,
|
| 124 |
+
max_period=pos_embed_rope_max_period,
|
| 125 |
+
normalize_coords=pos_embed_rope_normalize_coords,
|
| 126 |
+
shift_coords=pos_embed_rope_shift_coords,
|
| 127 |
+
jitter_coords=pos_embed_rope_jitter_coords,
|
| 128 |
+
rescale_coords=pos_embed_rope_rescale_coords,
|
| 129 |
+
dtype=dtype_dict[pos_embed_rope_dtype],
|
| 130 |
+
device=device,
|
| 131 |
+
)
|
| 132 |
+
logger.info(f"using {ffn_layer} layer as FFN")
|
| 133 |
+
ffn_layer_cls = ffn_layer_dict[ffn_layer]
|
| 134 |
+
ffn_ratio_sequence = [ffn_ratio] * depth
|
| 135 |
+
blocks_list = [
|
| 136 |
+
SelfAttentionBlock(
|
| 137 |
+
dim=embed_dim,
|
| 138 |
+
num_heads=num_heads,
|
| 139 |
+
ffn_ratio=ffn_ratio_sequence[i],
|
| 140 |
+
qkv_bias=qkv_bias,
|
| 141 |
+
proj_bias=proj_bias,
|
| 142 |
+
ffn_bias=ffn_bias,
|
| 143 |
+
drop_path=drop_path_rate,
|
| 144 |
+
norm_layer=norm_layer_cls,
|
| 145 |
+
act_layer=nn.GELU,
|
| 146 |
+
ffn_layer=ffn_layer_cls,
|
| 147 |
+
init_values=layerscale_init,
|
| 148 |
+
mask_k_bias=mask_k_bias,
|
| 149 |
+
device=device,
|
| 150 |
+
)
|
| 151 |
+
for i in range(depth)
|
| 152 |
+
]
|
| 153 |
+
|
| 154 |
+
self.chunked_blocks = False
|
| 155 |
+
self.blocks = nn.ModuleList(blocks_list)
|
| 156 |
+
|
| 157 |
+
self.norm = norm_layer_cls(embed_dim)
|
| 158 |
+
|
| 159 |
+
self.untie_cls_and_patch_norms = untie_cls_and_patch_norms
|
| 160 |
+
if untie_cls_and_patch_norms:
|
| 161 |
+
self.cls_norm = norm_layer_cls(embed_dim)
|
| 162 |
+
else:
|
| 163 |
+
self.cls_norm = None
|
| 164 |
+
|
| 165 |
+
self.untie_global_and_local_cls_norm = untie_global_and_local_cls_norm
|
| 166 |
+
if untie_global_and_local_cls_norm:
|
| 167 |
+
self.local_cls_norm = norm_layer_cls(embed_dim)
|
| 168 |
+
else:
|
| 169 |
+
self.local_cls_norm = None
|
| 170 |
+
self.head = nn.Identity()
|
| 171 |
+
self.mask_token = nn.Parameter(torch.empty(1, embed_dim, device=device))
|
| 172 |
+
|
| 173 |
+
def init_weights(self):
|
| 174 |
+
self.rope_embed._init_weights()
|
| 175 |
+
nn.init.normal_(self.cls_token, std=0.02)
|
| 176 |
+
if self.n_storage_tokens > 0:
|
| 177 |
+
nn.init.normal_(self.storage_tokens, std=0.02)
|
| 178 |
+
nn.init.zeros_(self.mask_token)
|
| 179 |
+
named_apply(init_weights_vit, self)
|
| 180 |
+
|
| 181 |
+
def prepare_tokens_with_masks(self, x: Tensor, masks=None) -> Tuple[Tensor, Tuple[int]]:
|
| 182 |
+
x = self.patch_embed(x)
|
| 183 |
+
B, H, W, _ = x.shape
|
| 184 |
+
x = x.flatten(1, 2)
|
| 185 |
+
|
| 186 |
+
if masks is not None:
|
| 187 |
+
x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
|
| 188 |
+
cls_token = self.cls_token
|
| 189 |
+
else:
|
| 190 |
+
cls_token = self.cls_token + 0 * self.mask_token
|
| 191 |
+
if self.n_storage_tokens > 0:
|
| 192 |
+
storage_tokens = self.storage_tokens
|
| 193 |
+
else:
|
| 194 |
+
storage_tokens = torch.empty(
|
| 195 |
+
1,
|
| 196 |
+
0,
|
| 197 |
+
cls_token.shape[-1],
|
| 198 |
+
dtype=cls_token.dtype,
|
| 199 |
+
device=cls_token.device,
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
x = torch.cat(
|
| 203 |
+
[
|
| 204 |
+
cls_token.expand(B, -1, -1),
|
| 205 |
+
storage_tokens.expand(B, -1, -1),
|
| 206 |
+
x,
|
| 207 |
+
],
|
| 208 |
+
dim=1,
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
return x, (H, W)
|
| 212 |
+
|
| 213 |
+
def forward_features_list(self, x_list: List[Tensor], masks_list: List[Tensor]) -> List[Dict[str, Tensor]]:
|
| 214 |
+
x = []
|
| 215 |
+
rope = []
|
| 216 |
+
for t_x, t_masks in zip(x_list, masks_list):
|
| 217 |
+
t2_x, hw_tuple = self.prepare_tokens_with_masks(t_x, t_masks)
|
| 218 |
+
x.append(t2_x)
|
| 219 |
+
rope.append(hw_tuple)
|
| 220 |
+
for _, blk in enumerate(self.blocks):
|
| 221 |
+
if self.rope_embed is not None:
|
| 222 |
+
rope_sincos = [self.rope_embed(H=H, W=W) for H, W in rope]
|
| 223 |
+
else:
|
| 224 |
+
rope_sincos = [None for _ in rope]
|
| 225 |
+
x = blk(x, rope_sincos)
|
| 226 |
+
all_x = x
|
| 227 |
+
output = []
|
| 228 |
+
for idx, (x, masks) in enumerate(zip(all_x, masks_list)):
|
| 229 |
+
if self.untie_cls_and_patch_norms or self.untie_global_and_local_cls_norm:
|
| 230 |
+
if self.untie_global_and_local_cls_norm and self.training and idx == 1:
|
| 231 |
+
x_norm_cls_reg = self.local_cls_norm(x[:, : self.n_storage_tokens + 1])
|
| 232 |
+
elif self.untie_cls_and_patch_norms:
|
| 233 |
+
x_norm_cls_reg = self.cls_norm(x[:, : self.n_storage_tokens + 1])
|
| 234 |
+
else:
|
| 235 |
+
x_norm_cls_reg = self.norm(x[:, : self.n_storage_tokens + 1])
|
| 236 |
+
x_norm_patch = self.norm(x[:, self.n_storage_tokens + 1 :])
|
| 237 |
+
else:
|
| 238 |
+
x_norm = self.norm(x)
|
| 239 |
+
x_norm_cls_reg = x_norm[:, : self.n_storage_tokens + 1]
|
| 240 |
+
x_norm_patch = x_norm[:, self.n_storage_tokens + 1 :]
|
| 241 |
+
output.append(
|
| 242 |
+
{
|
| 243 |
+
"x_norm_clstoken": x_norm_cls_reg[:, 0],
|
| 244 |
+
"x_storage_tokens": x_norm_cls_reg[:, 1:],
|
| 245 |
+
"x_norm_patchtokens": x_norm_patch,
|
| 246 |
+
"x_prenorm": x,
|
| 247 |
+
"masks": masks,
|
| 248 |
+
}
|
| 249 |
+
)
|
| 250 |
+
return output
|
| 251 |
+
|
| 252 |
+
def forward_features(self, x: Tensor | List[Tensor], masks: Optional[Tensor] = None) -> List[Dict[str, Tensor]]:
|
| 253 |
+
if isinstance(x, torch.Tensor):
|
| 254 |
+
return self.forward_features_list([x], [masks])[0]
|
| 255 |
+
else:
|
| 256 |
+
return self.forward_features_list(x, masks)
|
| 257 |
+
|
| 258 |
+
def _get_intermediate_layers_not_chunked(self, x: Tensor, n: int = 1) -> List[Tensor]:
|
| 259 |
+
x, (H, W) = self.prepare_tokens_with_masks(x)
|
| 260 |
+
output, total_block_len = [], len(self.blocks)
|
| 261 |
+
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
| 262 |
+
for i, blk in enumerate(self.blocks):
|
| 263 |
+
if self.rope_embed is not None:
|
| 264 |
+
rope_sincos = self.rope_embed(H=H, W=W)
|
| 265 |
+
else:
|
| 266 |
+
rope_sincos = None
|
| 267 |
+
x = blk(x, rope_sincos)
|
| 268 |
+
if i in blocks_to_take:
|
| 269 |
+
output.append(x)
|
| 270 |
+
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
| 271 |
+
return output
|
| 272 |
+
|
| 273 |
+
def get_intermediate_layers(
|
| 274 |
+
self,
|
| 275 |
+
x: torch.Tensor,
|
| 276 |
+
*,
|
| 277 |
+
n: Union[int, Sequence] = 1,
|
| 278 |
+
reshape: bool = False,
|
| 279 |
+
return_class_token: bool = False,
|
| 280 |
+
return_extra_tokens: bool = False,
|
| 281 |
+
norm: bool = True,
|
| 282 |
+
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor, ...]]]:
|
| 283 |
+
outputs = self._get_intermediate_layers_not_chunked(x, n)
|
| 284 |
+
if norm:
|
| 285 |
+
outputs_normed = []
|
| 286 |
+
for out in outputs:
|
| 287 |
+
if self.untie_cls_and_patch_norms:
|
| 288 |
+
x_norm_cls_reg = self.cls_norm(out[:, : self.n_storage_tokens + 1])
|
| 289 |
+
x_norm_patch = self.norm(out[:, self.n_storage_tokens + 1 :])
|
| 290 |
+
outputs_normed.append(torch.cat((x_norm_cls_reg, x_norm_patch), dim=1))
|
| 291 |
+
else:
|
| 292 |
+
outputs_normed.append(self.norm(out))
|
| 293 |
+
outputs = outputs_normed
|
| 294 |
+
class_tokens = [out[:, 0] for out in outputs]
|
| 295 |
+
extra_tokens = [out[:, 1 : self.n_storage_tokens + 1] for out in outputs]
|
| 296 |
+
outputs = [out[:, self.n_storage_tokens + 1 :] for out in outputs]
|
| 297 |
+
if reshape:
|
| 298 |
+
B, _, h, w = x.shape
|
| 299 |
+
outputs = [
|
| 300 |
+
out.reshape(B, h // self.patch_size, w // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
|
| 301 |
+
for out in outputs
|
| 302 |
+
]
|
| 303 |
+
if not return_class_token and not return_extra_tokens:
|
| 304 |
+
return tuple(outputs)
|
| 305 |
+
elif return_class_token and not return_extra_tokens:
|
| 306 |
+
return tuple(zip(outputs, class_tokens))
|
| 307 |
+
elif not return_class_token and return_extra_tokens:
|
| 308 |
+
return tuple(zip(outputs, extra_tokens))
|
| 309 |
+
elif return_class_token and return_extra_tokens:
|
| 310 |
+
return tuple(zip(outputs, class_tokens, extra_tokens))
|
| 311 |
+
|
| 312 |
+
def forward(self, *args, is_training: bool = False, **kwargs) -> List[Dict[str, Tensor]] | Tensor:
|
| 313 |
+
ret = self.forward_features(*args, **kwargs)
|
| 314 |
+
if is_training:
|
| 315 |
+
return ret
|
| 316 |
+
else:
|
| 317 |
+
return self.head(ret["x_norm_clstoken"])
|
| 318 |
+
|
eupe/utils/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .utils import cat_keep_shapes, named_apply, uncat_with_shapes
|
| 2 |
+
|
eupe/utils/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (279 Bytes). View file
|
|
|
eupe/utils/__pycache__/utils.cpython-312.pyc
ADDED
|
Binary file (3.08 kB). View file
|
|
|
eupe/utils/utils.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
from typing import Callable, List, Tuple
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
from torch import Tensor, nn
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def cat_keep_shapes(x_list: List[Tensor]) -> Tuple[Tensor, List[Tuple[int]], List[int]]:
    """Flatten every tensor to 2-D (rows x last dim) and concatenate along rows.

    Returns:
        ``(flattened, shapes, num_tokens)`` where ``shapes`` holds the original
        shapes and ``num_tokens`` the number of rows each tensor contributed,
        which is the bookkeeping ``uncat_with_shapes`` takes as input.
    """
    shapes = []
    num_tokens = []
    rows = []
    for tensor in x_list:
        shapes.append(tensor.shape)
        # Element count of one slice along the last dim == number of rows.
        num_tokens.append(tensor.select(dim=-1, index=0).numel())
        rows.append(tensor.flatten(0, -2))
    flattened = torch.cat(rows)
    return flattened, shapes, num_tokens
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def uncat_with_shapes(flattened: Tensor, shapes: List[Tuple[int]], num_tokens: List[int]) -> List[Tensor]:
    """Split a row-concatenated tensor and restore the leading dims of each piece.

    Each piece takes its leading dimensions from the matching entry of
    ``shapes``; the last dimension is taken from ``flattened``, so it may
    differ from the one recorded in ``shapes``.
    """
    pieces = torch.split_with_sizes(flattened, num_tokens, dim=0)
    last_dim = torch.Size([flattened.shape[-1]])
    target_shapes = [shape[:-1] + last_dim for shape in shapes]
    restored = []
    for piece, target in zip(pieces, target_shapes):
        restored.append(piece.reshape(target))
    return restored
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def named_apply(
    fn: Callable,
    module: nn.Module,
    name: str = "",
    depth_first: bool = True,
    include_root: bool = False,
) -> nn.Module:
    """Recursively call ``fn(module=..., name=...)`` on the descendants of ``module``.

    Args:
        fn: Callback invoked with keyword arguments ``module`` and ``name``.
        module: Root of the traversal.
        name: Dotted name of ``module``; child names are appended to it.
        depth_first: If true a module is visited after its children,
            otherwise before them.
        include_root: Whether ``fn`` is also called on ``module`` itself.
            Descendants are always visited.

    Returns:
        ``module``, unchanged as an object.
    """
    visit_before = include_root and not depth_first
    visit_after = include_root and depth_first

    if visit_before:
        fn(module=module, name=name)
    for child_name, child_module in module.named_children():
        qualified = ".".join((name, child_name)) if name else child_name
        named_apply(
            fn=fn,
            module=child_module,
            name=qualified,
            depth_first=depth_first,
            include_root=True,
        )
    if visit_after:
        fn(module=module, name=name)
    return module
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def fix_random_seeds(seed: int = 31):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
|
| 51 |
+
|
transformers_eupe.py
CHANGED
|
@@ -112,6 +112,7 @@ class EupeViTModel(PreTrainedModel):
|
|
| 112 |
mask_k_bias=config.mask_k_bias,
|
| 113 |
)
|
| 114 |
self.vit.init_weights()
|
|
|
|
| 115 |
|
| 116 |
def _init_weights(self, module: nn.Module) -> None:
|
| 117 |
# Signature required by PreTrainedModel; initialization is delegated to DinoVisionTransformer.
|
|
|
|
| 112 |
mask_k_bias=config.mask_k_bias,
|
| 113 |
)
|
| 114 |
self.vit.init_weights()
|
| 115 |
+
self.post_init()
|
| 116 |
|
| 117 |
def _init_weights(self, module: nn.Module) -> None:
|
| 118 |
# Signature required by PreTrainedModel; initialization is delegated to DinoVisionTransformer.
|