from functools import partial
import torch
import torch.nn as nn
from einops import rearrange
from timm.layers import LayerNorm, LayerNorm2d
from timm.models.regnet import RegStage


def build_pos_embeds(
    pos_emb: bool, num_input_tokens: int, vision_hidden_size: int
):
    # learnable position embeddings, zero-initialized then truncated-normal (std=0.02)
    if pos_emb:
        pos_emb = torch.nn.Parameter(torch.zeros(1, num_input_tokens, vision_hidden_size))
        nn.init.trunc_normal_(pos_emb, mean=0.0, std=0.02)
    else:
        pos_emb = None

    return pos_emb

def build_prenorm(prenorm: bool, encoder_hidden_size: int):
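    # optional LayerNorm applied to encoder outputs before the abstractor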
    if prenorm:
        prenorm = LayerNorm(encoder_hidden_size)
    else:
        prenorm = None
    return prenorm


def build_mlp(depth: int, hidden_size: int, output_hidden_size: int):
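    # depth-layer MLP: Linear(hidden -> output), then (depth - 1) x [SiLU, Linear(output -> output)]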
    layers = [nn.Linear(hidden_size, output_hidden_size)]
    for _ in range(1, depth):
        layers.append(nn.SiLU())
        layers.append(nn.Linear(output_hidden_size, output_hidden_size))
    return nn.Sequential(*layers)


class CAbstractor(nn.Module):
    """Base projector class"""

    def __init__(
        self,
        num_input_tokens: int,
        encoder_hidden_size: int,
        output_hidden_size: int,
        hidden_size: int = 1024,
        depth: int = 3,
        mlp_depth: int = 2,
        num_queries: int = 144,
        pos_emb: bool = True,
        prenorm: bool = False
    ):
        super().__init__()
        self.num_input_tokens = num_input_tokens
        self.encoder_hidden_size = encoder_hidden_size
        self.output_hidden_size = output_hidden_size
        self.mlp_depth = mlp_depth
        self.depth = depth
        self.num_queries = num_queries
        self.hidden_size = hidden_size

        # pos emb
        self.pos_emb = build_pos_embeds(pos_emb, num_input_tokens, encoder_hidden_size)

        self.prenorm = build_prenorm(prenorm, encoder_hidden_size)

        self.build_net()

    def build_net(self):
        encoder_hidden_size = self.encoder_hidden_size
        hidden_size = self.hidden_size
        output_hidden_size = self.output_hidden_size
        depth = self.depth
        mlp_depth = self.mlp_depth
        n_queries = self.num_queries

        assert (n_queries ** 0.5).is_integer(), "n_queries must be a perfect square"
        hw = int(n_queries ** 0.5)

        # RegBlock = RegNet stage: a stack of `depth` bottleneck ResBlocks with SE
        RegBlock = partial(
            RegStage,
            stride=1,
            dilation=1,
            act_layer=nn.SiLU,
            norm_layer=LayerNorm2d,
        )

        s1 = RegBlock(
            depth,
            encoder_hidden_size,
            hidden_size,
        )
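        # adaptive average pooling resamples the token grid to hw x hw = num_queries tokens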
        sampler = nn.AdaptiveAvgPool2d((hw, hw))
        s2 = RegBlock(
            depth,
            hidden_size,
            hidden_size,
        )

        self.net = nn.Sequential(s1, sampler, s2)  # conv stage -> pool -> conv stage
        self.readout = build_mlp(mlp_depth, hidden_size, output_hidden_size)

    def _forward(self, x):
        # x: [B, L, dim], where L forms a square h x w grid
        # x = x[:, 1:]  # drop cls token before the 2d reshape (@Kyusong: enable this if the vision tower outputs a CLS token)
        hw = int(x.size(1) ** 0.5)
        x = rearrange(x, "b (h w) d -> b d h w", h=hw, w=hw)  # token sequence -> 2d feature map
        x = self.net(x)  # conv stage -> adaptive pool (down to num_queries) -> conv stage
        x = rearrange(x, "b d h w -> b (h w) d")  # 2d feature map -> token sequence
        x = self.readout(x)  # MLP projection to output_hidden_size

        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, L, encoder_hidden_size) tensor from the visual backbone (CLIP visual encoder), including cls token.
        """
        if self.prenorm is not None:
            x = self.prenorm(x)

        if self.pos_emb is not None:
            x = x + self.pos_emb  # out-of-place add, so the caller's tensor is not mutated

        x = self._forward(x)  # (B, L, output_hidden_size)

        return x



if __name__ == "__main__":
    B = 2  # batch size
    L = 576  # number of input tokens (24 x 24 patch grid)
    H = 1024  # encoder hidden size

    n_query = 256
    output_h = 4096

    x = torch.randn(B, L, H)  # random dummy features; FloatTensor(B, L, H) would be uninitialized
    m = CAbstractor(L, H, output_h, num_queries=n_query)
    y = m(x)
    print(y.shape)  # (B, n_query, output_h)
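    # Expected: torch.Size([2, 256, 4096]); the 576 input tokens are pooled down to 16 x 16 = 256 queries.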