Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,975 Bytes
648df8c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
from functools import partial
import torch
import torch.nn as nn
from einops import rearrange
from timm.layers import LayerNorm, LayerNorm2d
from timm.models.regnet import RegStage
def build_pos_embeds(
    pos_emb: bool, num_input_tokens: int, vision_hidden_size: int
):
    """Build a learnable positional-embedding table, or None when disabled.

    Args:
        pos_emb: if False, positional embeddings are disabled.
        num_input_tokens: number of visual tokens (sequence length L).
        vision_hidden_size: per-token embedding width.

    Returns:
        ``nn.Parameter`` of shape (1, num_input_tokens, vision_hidden_size),
        truncated-normal initialized (std=0.02), or None when disabled.
    """
    # Guard clause instead of reusing the bool flag as the result variable.
    if not pos_emb:
        return None
    embeds = nn.Parameter(
        torch.zeros(1, num_input_tokens, vision_hidden_size)
    )
    nn.init.trunc_normal_(embeds, mean=0.0, std=0.02)
    return embeds
def build_prenorm(prenorm, encoder_hidden_size):
    """Return a LayerNorm over the encoder width if ``prenorm`` is truthy, else None."""
    return LayerNorm(encoder_hidden_size) if prenorm else None
def build_mlp(depth, hidden_size, output_hidden_size):
    """Readout MLP: Linear(hidden->out), then (depth-1) x [SiLU, Linear(out->out)]."""
    modules = [nn.Linear(hidden_size, output_hidden_size)]
    # Each extra level of depth appends an activation + square projection.
    modules += [
        m
        for _ in range(depth - 1)
        for m in (nn.SiLU(), nn.Linear(output_hidden_size, output_hidden_size))
    ]
    return nn.Sequential(*modules)
class CAbstractor(nn.Module):
    """C-Abstractor visual projector.

    Resamples L visual-backbone tokens down to ``num_queries`` tokens and
    projects them to the language-model hidden size. Pipeline:
    optional pre-LayerNorm -> optional learned positional embedding ->
    RegNet stage -> adaptive average pool to a sqrt(num_queries)^2 grid ->
    RegNet stage -> MLP readout.
    """
    def __init__(
        self,
        num_input_tokens: int,
        encoder_hidden_size: int,
        output_hidden_size: int,
        hidden_size: int = 1024,
        depth: int = 3,
        mlp_depth: int = 2,
        num_queries: int = 144,
        pos_emb: bool = True,
        prenorm: bool = False
    ):
        """
        Args:
            num_input_tokens: number of tokens L from the vision tower
                (assumed to form a square h*w grid — see ``_forward``).
            encoder_hidden_size: width of the incoming visual tokens.
            output_hidden_size: width expected by the language model.
            hidden_size: internal conv-stage channel width.
            depth: number of blocks per RegNet stage.
            mlp_depth: depth of the readout MLP.
            num_queries: output token count; must be a perfect square.
            pos_emb: whether to add a learned positional embedding.
            prenorm: whether to LayerNorm the input tokens first.
        """
        super().__init__()
        self.num_input_tokens = num_input_tokens
        self.encoder_hidden_size = encoder_hidden_size
        self.output_hidden_size = output_hidden_size
        self.mlp_depth = mlp_depth
        self.depth = depth
        self.num_queries = num_queries
        self.hidden_size = hidden_size
        # Positional embedding is sized to the *encoder* width since it is
        # added before any projection.
        self.pos_emb = build_pos_embeds(pos_emb, num_input_tokens, encoder_hidden_size)
        self.prenorm = build_prenorm(prenorm, encoder_hidden_size)
        self.build_net()

    def build_net(self):
        """Construct conv stages, pooling sampler, and readout MLP."""
        encoder_hidden_size = self.encoder_hidden_size
        hidden_size = self.hidden_size
        output_hidden_size = self.output_hidden_size
        depth = self.depth
        mlp_depth = self.mlp_depth
        n_queries = self.num_queries
        assert (n_queries ** 0.5).is_integer(), "n_queries must be square number"
        hw = int(n_queries ** 0.5)
        # RegBlock = ResBlock + SE
        RegBlock = partial(
            RegStage,
            stride=1,
            dilation=1,
            act_layer=nn.SiLU,
            norm_layer=LayerNorm2d,
        )
        # Stage 1 projects encoder width -> hidden width at full resolution.
        s1 = RegBlock(
            depth,
            encoder_hidden_size,
            hidden_size,
        )
        # Downsample token grid to (hw, hw) via adaptive average pooling.
        sampler = nn.AdaptiveAvgPool2d((hw, hw))
        # Stage 2 refines features at the reduced resolution.
        s2 = RegBlock(
            depth,
            hidden_size,
            hidden_size,
        )
        self.net = nn.Sequential(s1, sampler, s2)
        self.readout = build_mlp(mlp_depth, hidden_size, output_hidden_size)

    def _forward(self, x):
        # x: [B, L, dim]
        # x = x[:, 1:] # drop cls token and 2d forward @Kyusong, If we output CLS token from vision tower, u can use this
        # Reshape token sequence to a 2D feature map; assumes L is a perfect
        # square (h == w) — TODO confirm for non-square vision grids.
        hw = int(x.size(1) ** 0.5)
        x = rearrange(x, "b (h w) d -> b d h w", h=hw, w=hw)
        x = self.net(x)
        x = rearrange(x, "b d h w -> b (h w) d")
        x = self.readout(x)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, L, encoder_hidden_size) tensor from the visual backbone (CLIP visual encoder), including cls token.

        Returns:
            (B, num_queries, output_hidden_size) tensor of abstracted tokens.
        """
        if self.prenorm is not None:
            x = self.prenorm(x)
        if self.pos_emb is not None:
            # BUGFIX: use out-of-place add. The previous `x += self.pos_emb`
            # mutated the caller's tensor in place, silently corrupting any
            # other consumer of the backbone output (and breaking autograd on
            # views / inference tensors).
            x = x + self.pos_emb
        x = self._forward(x)  # (B, L, output_hidden_size)
        return x
if __name__ == "__main__":
    # Smoke test: abstract 576 CLIP tokens into 256 queries of width 4096.
    B = 2           # batch size
    L = 576         # number of input tokens (24x24 grid)
    H = 1024        # encoder hidden size
    n_query = 256
    output_h = 4096
    # BUGFIX: torch.FloatTensor(B, L, H) returns *uninitialized* memory
    # (may contain NaN/inf), making this smoke test nondeterministic.
    # torch.randn yields well-defined input data.
    x = torch.randn(B, L, H)
    m = CAbstractor(L, H, output_h, num_queries=n_query)
    y = m(x)
    print(y.shape)  # torch.Size([B, n_query, output_h])
|