PengLiu committed
Commit 648df8c · 1 Parent(s): c77ed95
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the complete change set.
Files changed (50)
  1. .gitattributes +2 -0
  2. README.md +1 -3
  3. demo/gradio_demo.py +374 -0
  4. demo/gradio_demo_with_sam3.py +323 -0
  5. demo/sam3_examples/init.py +0 -0
  6. detect_tools/sam3/.gitignore +153 -0
  7. detect_tools/sam3/CODE_OF_CONDUCT.md +80 -0
  8. detect_tools/sam3/CONTRIBUTING.md +30 -0
  9. detect_tools/sam3/LICENSE +61 -0
  10. detect_tools/sam3/MANIFEST.in +6 -0
  11. detect_tools/sam3/README.md +387 -0
  12. detect_tools/sam3/README_TRAIN.md +190 -0
  13. detect_tools/sam3/assets/init.py +0 -0
  14. detect_tools/sam3/pyproject.toml +131 -0
  15. detect_tools/sam3/sam3/__init__.py +7 -0
  16. detect_tools/sam3/sam3/logger.py +54 -0
  17. detect_tools/sam3/sam3/model/__init__.py +1 -0
  18. detect_tools/sam3/sam3/model/act_ckpt_utils.py +114 -0
  19. detect_tools/sam3/sam3/model/box_ops.py +217 -0
  20. detect_tools/sam3/sam3/model/data_misc.py +209 -0
  21. detect_tools/sam3/sam3/model/decoder.py +956 -0
  22. detect_tools/sam3/sam3/model/edt.py +173 -0
  23. detect_tools/sam3/sam3/model/encoder.py +594 -0
  24. detect_tools/sam3/sam3/model/geometry_encoders.py +850 -0
  25. detect_tools/sam3/sam3/model/io_utils.py +709 -0
  26. detect_tools/sam3/sam3/model/maskformer_segmentation.py +323 -0
  27. detect_tools/sam3/sam3/model/memory.py +201 -0
  28. detect_tools/sam3/sam3/model/model_misc.py +428 -0
  29. detect_tools/sam3/sam3/model/necks.py +125 -0
  30. detect_tools/sam3/sam3/model/position_encoding.py +124 -0
  31. detect_tools/sam3/sam3/model/sam1_task_predictor.py +458 -0
  32. detect_tools/sam3/sam3/model/sam3_image.py +883 -0
  33. detect_tools/sam3/sam3/model/sam3_image_processor.py +222 -0
  34. detect_tools/sam3/sam3/model/sam3_tracker_base.py +1188 -0
  35. detect_tools/sam3/sam3/model/sam3_tracker_utils.py +427 -0
  36. detect_tools/sam3/sam3/model/sam3_tracking_predictor.py +1370 -0
  37. detect_tools/sam3/sam3/model/sam3_video_base.py +1767 -0
  38. detect_tools/sam3/sam3/model/sam3_video_inference.py +1709 -0
  39. detect_tools/sam3/sam3/model/sam3_video_predictor.py +521 -0
  40. detect_tools/sam3/sam3/model/text_encoder_ve.py +328 -0
  41. detect_tools/sam3/sam3/model/tokenizer_ve.py +253 -0
  42. detect_tools/sam3/sam3/model/utils/__init__.py +5 -0
  43. detect_tools/sam3/sam3/model/utils/misc.py +77 -0
  44. detect_tools/sam3/sam3/model/utils/sam1_utils.py +119 -0
  45. detect_tools/sam3/sam3/model/utils/sam2_utils.py +233 -0
  46. detect_tools/sam3/sam3/model/vitdet.py +879 -0
  47. detect_tools/sam3/sam3/model/vl_combiner.py +176 -0
  48. detect_tools/sam3/sam3/model_builder.py +793 -0
  49. detect_tools/sam3/sam3/perflib/__init__.py +8 -0
  50. detect_tools/sam3/sam3/perflib/associate_det_trk.py +137 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
+ *.so filter=lfs diff=lfs merge=lfs -text
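For context, rules like these are normally generated with `git lfs track` rather than written by hand; a minimal sketch, assuming Git LFS is installed:

```bash
# Track the new binary file types with Git LFS; this appends the two rules above to .gitattributes
git lfs track "*.jpg" "*.so"
git add .gitattributes
```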
README.md CHANGED
@@ -5,10 +5,8 @@ colorFrom: pink
  colorTo: green
  sdk: gradio
  sdk_version: 5.49.1
- app_file: app.py
+ app_file: demo/gradio_demo_with_sam3.py
  pinned: false
  license: apache-2.0
  short_description: Complex text label detection using SAM3 with VLM-FO1
  ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
demo/gradio_demo.py ADDED
@@ -0,0 +1,374 @@
1
+ import gradio as gr
2
+ from PIL import Image, ImageDraw, ImageFont
3
+ import re
4
+ import numpy as np
5
+ from skimage.measure import label, regionprops
6
+ from skimage.morphology import binary_dilation, disk
7
+ from detect_tools.upn import UPNWrapper
8
+ from vlm_fo1.model.builder import load_pretrained_model
9
+ from vlm_fo1.mm_utils import (
10
+ prepare_inputs,
11
+ extract_predictions_to_indexes,
12
+ )
13
+ from vlm_fo1.task_templates import *
14
+ import torch
15
+ import os
16
+ from copy import deepcopy
17
+
18
+
19
+ TASK_TYPES = {
20
+ "OD/REC": OD_template,
21
+ "ODCounting": OD_Counting_template,
22
+ "Region_OCR": "Please provide the ocr results of these regions in the image.",
23
+ "Brief_Region_Caption": "Provide a brief description for these regions in the image.",
24
+ "Detailed_Region_Caption": "Provide a detailed description for these regions in the image.",
25
+ "Viusal_Region_Reasoning": Viusal_Region_Reasoning_template,
26
+ "OD_All": OD_All_template,
27
+ "Grounding": Grounding_template,
28
+ }
29
+
30
+ EXAMPLES = [
31
+ ["demo_image.jpg", TASK_TYPES["OD/REC"].format("orange, apple"), "OD/REC"],
32
+ ["demo_image_01.jpg", TASK_TYPES["ODCounting"].format("airplane with only one propeller"), "ODCounting"],
33
+ ["demo_image_02.jpg", TASK_TYPES["OD/REC"].format("the ball closest to the bear"), "OD/REC"],
34
+ ["demo_image_03.jpg", TASK_TYPES["OD_All"].format(""), "OD_All"],
35
+ ["demo_image_03.jpg", TASK_TYPES["Viusal_Region_Reasoning"].format("What's the brand of this computer?"), "Viusal_Region_Reasoning"],
36
+ ]
37
+
38
+
39
+ def get_valid_examples():
40
+ valid_examples = []
41
+ demo_dir = os.path.dirname(os.path.abspath(__file__))
42
+ for example in EXAMPLES:
43
+ img_path = example[0]
44
+ full_path = os.path.join(demo_dir, img_path)
45
+ if os.path.exists(full_path):
46
+ valid_examples.append([
47
+ full_path,
48
+ example[1],
49
+ example[2]
50
+ ])
51
+ elif os.path.exists(img_path):
52
+ valid_examples.append([
53
+ img_path,
54
+ example[1],
55
+ example[2]
56
+ ])
57
+ return valid_examples
58
+
59
+
60
+ def detect_model(image, threshold=0.3):
61
+ proposals = upn_model.inference(image)
62
+ filtered_proposals = upn_model.filter(proposals, min_score=threshold)
63
+ return filtered_proposals['original_xyxy_boxes'][0][:100]
64
+
65
+
66
+ def multimodal_model(image, bboxes, text):
67
+ if '<image>' in text:
68
+ print(text)
69
+ parts = [part.replace('\\n', '\n') for part in re.split(rf'(<image>)', text) if part.strip()]
70
+ print(parts)
71
+ content = []
72
+ for part in parts:
73
+ if part == '<image>':
74
+ content.append({"type": "image_url", "image_url": {"url": image}})
75
+ else:
76
+ content.append({"type": "text", "text": part})
77
+ else:
78
+ content = [{
79
+ "type": "image_url",
80
+ "image_url": {
81
+ "url": image
82
+ }
83
+ }, {
84
+ "type": "text",
85
+ "text": text
86
+ }]
87
+
88
+ messages = [
89
+ {
90
+ "role": "user",
91
+ "content": content,
92
+ "bbox_list": bboxes
93
+ }
94
+ ]
95
+ generation_kwargs = prepare_inputs(model_path, model, image_processors, tokenizer, messages,
96
+ max_tokens=4096, top_p=0.05, temperature=0.0, do_sample=False)
97
+ with torch.inference_mode():
98
+ output_ids = model.generate(**generation_kwargs)
99
+ outputs = tokenizer.decode(output_ids[0, generation_kwargs['inputs'].shape[1]:]).strip()
100
+ print("========output========\n", outputs)
101
+
102
+ if '<ground>' in outputs:
103
+ prediction_dict = extract_predictions_to_indexes(outputs)
104
+ else:
105
+ match_pattern = r"<region(\d+)>"
106
+ matches = re.findall(match_pattern, outputs)
107
+ prediction_dict = {f"<region{m}>": {int(m)} for m in matches}
108
+
109
+ ans_bbox_json = []
110
+ ans_bbox_list = []
111
+ for k, v in prediction_dict.items():
112
+ for box_index in v:
113
+ box_index = int(box_index)
114
+ if box_index < len(bboxes):
115
+ current_bbox = bboxes[box_index]
116
+ ans_bbox_json.append({
117
+ "region_index": f"<region{box_index}>",
118
+ "xmin": current_bbox[0],
119
+ "ymin": current_bbox[1],
120
+ "xmax": current_bbox[2],
121
+ "ymax": current_bbox[3],
122
+ "label": k
123
+ })
124
+ ans_bbox_list.append(current_bbox)
125
+
126
+ return outputs, ans_bbox_json, ans_bbox_list
127
+
128
+
129
+ def draw_bboxes(image, bboxes, labels=None):
130
+ image = image.copy()
131
+ draw = ImageDraw.Draw(image)
132
+
133
+ for bbox in bboxes:
134
+ draw.rectangle(bbox, outline="red", width=3)
135
+ return image
136
+
137
+
138
+ def extract_bbox_and_original_image(edited_image):
139
+ """Extract original image and bounding boxes from ImageEditor output"""
140
+ if edited_image is None:
141
+ return None, []
142
+
143
+ if isinstance(edited_image, dict):
144
+ original_image = edited_image.get("background")
145
+ bbox_list = []
146
+
147
+ if original_image is None:
148
+ return None, []
149
+
150
+ if edited_image.get("layers") is None or len(edited_image.get("layers", [])) == 0:
151
+ return original_image, []
152
+
153
+ try:
154
+ drawing_layer = edited_image["layers"][0]
155
+ alpha_channel = drawing_layer.getchannel('A')
156
+ alpha_np = np.array(alpha_channel)
157
+
158
+ binary_mask = alpha_np > 0
159
+
160
+ structuring_element = disk(5)
161
+ dilated_mask = binary_dilation(binary_mask, structuring_element)
162
+
163
+ labeled_image = label(dilated_mask)
164
+ regions = regionprops(labeled_image)
165
+
166
+ for prop in regions:
167
+ y_min, x_min, y_max, x_max = prop.bbox
168
+ bbox_list.append((x_min, y_min, x_max, y_max))
169
+ except Exception as e:
170
+ print(f"Error extracting bboxes from layers: {e}")
171
+ return original_image, []
172
+
173
+ return original_image, bbox_list
174
+ elif isinstance(edited_image, Image.Image):
175
+ return edited_image, []
176
+ else:
177
+ print(f"Unknown input type: {type(edited_image)}")
178
+ return None, []
179
+
180
+
181
+ def process(image, example_image, prompt, threshold):
182
+ image, bbox_list = extract_bbox_and_original_image(image)
183
+
184
+ if example_image is not None:
185
+ image = example_image
186
+
187
+ if image is None:
188
+ error_msg = "Error: Please upload an image or select a valid example."
189
+ print(f"Error: image is None, original input type: {type(image)}")
190
+ return None, None, error_msg, []
191
+
192
+ try:
193
+ image = image.convert('RGB')
194
+ except Exception as e:
195
+ error_msg = f"Error: Cannot process image - {str(e)}"
196
+ return None, None, error_msg, []
197
+
198
+ if len(bbox_list) == 0:
199
+ bboxes = detect_model(image, threshold)
200
+ else:
201
+ bboxes = bbox_list
202
+ for idx in range(len(bboxes)):
203
+ prompt += f'<region{idx}>'
204
+
205
+ ans, ans_bbox_json, ans_bbox_list = multimodal_model(image, bboxes, prompt)
206
+
207
+ image_with_detection = draw_bboxes(image, bboxes)
208
+
209
+ annotated_bboxes = []
210
+ if len(ans_bbox_json) > 0:
211
+ for item in ans_bbox_json:
212
+ annotated_bboxes.append(
213
+ ((int(item['xmin']), int(item['ymin']), int(item['xmax']), int(item['ymax'])), item['label'])
214
+ )
215
+ annotated_image = (image, annotated_bboxes)
216
+
217
+ return annotated_image, image_with_detection, ans, ans_bbox_json
218
+
219
+
220
+ def update_btn(is_processing):
221
+ if is_processing:
222
+ return gr.update(value="Processing...", interactive=False)
223
+ else:
224
+ return gr.update(value="Submit", interactive=True)
225
+
226
+
227
+ def launch_demo():
228
+ with gr.Blocks() as demo:
229
+ gr.Markdown("# 🚀 VLM-FO1 Demo")
230
+ gr.Markdown("""
231
+ ### 📋 Instructions
232
+
233
+ **Step 1: Prepare Your Image**
234
+ - Upload an image using the image editor below
235
+ - *Optional:* Draw circular regions with the red brush to specify areas of interest
236
+ - *Alternative:* If not drawing regions, the detection model will automatically identify regions
237
+
238
+ **Step 2: Configure Your Task**
239
+ - Select a task template from the dropdown menu
240
+ - Replace `[WRITE YOUR INPUT HERE]` with your target objects or query
241
+ - *Example:* For detecting "person" and "dog", replace with: `person, dog`
242
+ - *Or:* Write your own custom prompt
243
+
244
+ **Step 3: Fine-tune Detection** *(Optional)*
245
+ - Adjust the detection threshold slider to control sensitivity
246
+
247
+ **Step 4: Generate Results**
248
+ - Click the **Submit** button to process your request
249
+ - View the detection results and model outputs below
250
+
251
+ 🔗 [GitHub Repository](https://github.com/om-ai-lab/VLM-FO1)
252
+ """)
253
+
254
+ with gr.Row():
255
+ with gr.Column():
256
+ img_input_draw = gr.ImageEditor(
257
+ label="Image Input",
258
+ image_mode="RGBA",
259
+ type="pil",
260
+ sources=['upload'],
261
+ brush=gr.Brush(colors=["#FF0000"], color_mode="fixed", default_size=2),
262
+ interactive=True
263
+ )
264
+
265
+ gr.Markdown("### Prompt & Parameters")
266
+
267
+ def set_prompt_from_template(selected_task):
268
+ return gr.update(value=TASK_TYPES[selected_task].format("[WRITE YOUR INPUT HERE]"))
269
+
270
+ def load_example(prompt_input, task_type_input, hidden_image_box):
271
+ cached_image = deepcopy(hidden_image_box)
272
+ w, h = cached_image.size
273
+
274
+ transparent_layer = Image.new('RGBA', (w, h), (0, 0, 0, 0))
275
+
276
+ new_editor_value = {
277
+ "background": cached_image,
278
+ "layers": [transparent_layer],
279
+ "composite": None
280
+ }
281
+
282
+ return new_editor_value, prompt_input, task_type_input
283
+
284
+ def reset_hidden_image_box():
285
+ return gr.update(value=None)
286
+
287
+ task_type_input = gr.Dropdown(
288
+ choices=list(TASK_TYPES.keys()),
289
+ value="OD/REC",
290
+ label="Prompt Templates",
291
+ info="Select the prompt template for the task, or write your own prompt."
292
+ )
293
+
294
+ prompt_input = gr.Textbox(
295
+ label="Task Prompt",
296
+ value=TASK_TYPES["OD/REC"].format("[WRITE YOUR INPUT HERE]"),
297
+ lines=2,
298
+ )
299
+
300
+ task_type_input.select(
301
+ set_prompt_from_template,
302
+ inputs=task_type_input,
303
+ outputs=prompt_input
304
+ )
305
+
306
+ hidden_image_box = gr.Image(label="Image", type="pil", image_mode="RGBA", visible=False)
307
+
308
+ threshold_input = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.01, label="Detection Model Threshold")
309
+ submit_btn = gr.Button("Submit", variant="primary")
310
+
311
+ valid_examples = get_valid_examples()
312
+ if len(valid_examples) > 0:
313
+ gr.Markdown("### Examples")
314
+ gr.Markdown("Click on the examples below to quickly load images and corresponding prompts:")
315
+
316
+ examples_data = [[example[0], example[1], example[2]] for index, example in enumerate(valid_examples)]
317
+
318
+ examples = gr.Examples(
319
+ examples=examples_data,
320
+ inputs=[hidden_image_box, prompt_input, task_type_input],
321
+ label="Click to load example",
322
+ examples_per_page=5
323
+ )
324
+
325
+ examples.load_input_event.then(
326
+ fn=load_example,
327
+ inputs=[prompt_input, task_type_input, hidden_image_box],
328
+ outputs=[img_input_draw, prompt_input, task_type_input]
329
+ )
330
+
331
+ img_input_draw.upload(
332
+ fn=reset_hidden_image_box,
333
+ outputs=[hidden_image_box]
334
+ )
335
+
336
+ with gr.Column():
337
+ with gr.Accordion("Detection Result", open=True):
338
+ image_with_detection = gr.Image(label="Detection Result", height=200)
339
+
340
+ image_output = gr.AnnotatedImage(label="VLM-FO1 Result", height=400)
341
+
342
+ result_output = gr.Textbox(label="VLM-FO1 Output", lines=5)
343
+ ans_bbox_json = gr.JSON(label="Extracted Detection Output")
344
+
345
+ submit_btn.click(
346
+ update_btn,
347
+ inputs=[gr.State(True)],
348
+ outputs=[submit_btn],
349
+ queue=False
350
+ ).then(
351
+ process,
352
+ inputs=[img_input_draw, hidden_image_box, prompt_input, threshold_input],
353
+ outputs=[image_output, image_with_detection, result_output, ans_bbox_json],
354
+ queue=True
355
+ ).then(
356
+ update_btn,
357
+ inputs=[gr.State(False)],
358
+ outputs=[submit_btn],
359
+ queue=False
360
+ )
361
+
362
+ return demo
363
+
364
+ if __name__ == "__main__":
365
+ model_path = './resources/VLM-FO1_Qwen2.5-VL-3B-v01'
366
+ upn_ckpt_path = "./resources/upn_large.pth"
367
+ tokenizer, model, image_processors = load_pretrained_model(
368
+ model_path=model_path,
369
+ device="cuda:0",
370
+ )
371
+ upn_model = UPNWrapper(upn_ckpt_path)
372
+
373
+ demo = launch_demo()
374
+ demo.launch(server_name="0.0.0.0", share=False, server_port=8000, debug=False)
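A hypothetical way to launch this demo locally (not part of the commit); it assumes the VLM-FO1 and UPN checkpoints referenced in the `__main__` block exist under `./resources/`, and that the command runs from the repository root so `detect_tools` and `vlm_fo1` resolve:

```bash
# Launch the UPN-based demo; it serves Gradio on port 8000 as configured above
PYTHONPATH=. python demo/gradio_demo.py
# then open http://localhost:8000 in a browser
```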
demo/gradio_demo_with_sam3.py ADDED
@@ -0,0 +1,323 @@
1
+ import gradio as gr
2
+ import spaces
3
+ from PIL import Image, ImageDraw, ImageFont
4
+ import re
5
+ import numpy as np
6
+ from skimage.measure import label, regionprops
7
+ from skimage.morphology import binary_dilation, disk
8
+ from sam3.model_builder import build_sam3_image_model
9
+ from sam3.model.sam3_image_processor import Sam3Processor
10
+ from sam3.visualization_utils import plot_bbox, plot_mask, COLORS
11
+ import matplotlib.pyplot as plt
12
+
13
+ from vlm_fo1.model.builder import load_pretrained_model
14
+ from vlm_fo1.mm_utils import (
15
+ prepare_inputs,
16
+ extract_predictions_to_indexes,
17
+ )
18
+ from vlm_fo1.task_templates import *
19
+ import torch
20
+ import os
21
+ from copy import deepcopy
22
+
23
+
24
+ EXAMPLES = [
25
+ ["demo/sam3_examples/00000-72.jpg","airplane with letter AE on its body"],
26
+ ["demo/sam3_examples/00000-32.jpg","the lying cat which is not black"],
27
+ ["demo/sam3_examples/00000-22.jpg","person wearing a black top"],
28
+ ["demo/sam3_examples/000000378453.jpg", "zebra inside the mud puddle"],
29
+ ["demo/sam3_examples/00000-242.jpg", "person who is holding a book"],
30
+ ]
31
+
32
+
33
+ def get_valid_examples():
34
+ valid_examples = []
35
+ demo_dir = os.path.dirname(os.path.abspath(__file__))
36
+ for example in EXAMPLES:
37
+ img_path = example[0]
38
+ full_path = os.path.join(demo_dir, img_path)
39
+ if os.path.exists(full_path):
40
+ valid_examples.append([
41
+ full_path,
42
+ example[1]
44
+ ])
45
+ elif os.path.exists(img_path):
46
+ valid_examples.append([
47
+ img_path,
48
+ example[1]
50
+ ])
51
+ return valid_examples
52
+
53
+
54
+ def detect_model(image, text, threshold=0.3):
55
+ inference_state = sam3_processor.set_image(image)
56
+ output = sam3_processor.set_text_prompt(state=inference_state, prompt=text)
57
+ boxes, scores, masks = output["boxes"], output["scores"], output["masks"]
58
+ sorted_indices = torch.argsort(scores, descending=True)
59
+ boxes = boxes[sorted_indices][:100, :]
60
+ scores = scores[sorted_indices][:100]
61
+ masks = masks[sorted_indices][:100]
62
+ # If the highest confidence score is above 0.75, keep a 0.3 threshold; otherwise relax it to 0.05
63
+ if len(scores) > 0 and scores[0] > 0.75:
64
+ conf_threshold = 0.3
66
+ else:
67
+ conf_threshold = 0.05
68
+ mask = scores > conf_threshold
69
+ boxes = boxes[mask]
70
+ scores = scores[mask]
71
+ masks = masks[mask]
72
+ # Keep boxes with score > 0.8 in a separate list
73
+ high_conf_mask = scores > 0.8
74
+ high_conf_boxes = boxes[high_conf_mask]
75
+
76
+ print("========boxes========\n", boxes.tolist())
77
+ print("========scores========\n", scores.tolist())
78
+ print("========high_conf_boxes (>0.8)========\n", high_conf_boxes.tolist())
79
+
80
+ output = {
81
+ "boxes": boxes,
82
+ "scores": scores,
83
+ "masks": masks,
84
+ }
85
+ return boxes.tolist(), scores.tolist(), high_conf_boxes.tolist(), masks.tolist(), output
86
+
87
+
88
+ def multimodal_model(image, bboxes, scores, text):
89
+ if len(bboxes) == 0:
90
+ return None, {}, []
91
+
92
+ if '<image>' in text:
93
+ print(text)
94
+ parts = [part.replace('\\n', '\n') for part in re.split(rf'(<image>)', text) if part.strip()]
95
+ print(parts)
96
+ content = []
97
+ for part in parts:
98
+ if part == '<image>':
99
+ content.append({"type": "image_url", "image_url": {"url": image}})
100
+ else:
101
+ content.append({"type": "text", "text": part})
102
+ else:
103
+ content = [{
104
+ "type": "image_url",
105
+ "image_url": {
106
+ "url": image
107
+ }
108
+ }, {
109
+ "type": "text",
110
+ "text": text
111
+ }]
112
+
113
+ messages = [
114
+ {
115
+ "role": "user",
116
+ "content": content,
117
+ "bbox_list": bboxes
118
+ }
119
+ ]
120
+ generation_kwargs = prepare_inputs(model_path, model, image_processors, tokenizer, messages,
121
+ max_tokens=4096, top_p=0.05, temperature=0.0, do_sample=False, image_size=1024)
122
+ with torch.inference_mode():
123
+ output_ids = model.generate(**generation_kwargs)
124
+ outputs = tokenizer.decode(output_ids[0, generation_kwargs['inputs'].shape[1]:]).strip()
125
+ print("========output========\n", outputs)
126
+
127
+ if '<ground>' in outputs:
128
+ prediction_dict = extract_predictions_to_indexes(outputs)
129
+ else:
130
+ match_pattern = r"<region(\d+)>"
131
+ matches = re.findall(match_pattern, outputs)
132
+ prediction_dict = {f"<region{m}>": {int(m)} for m in matches}
133
+
134
+ ans_bbox_json = []
135
+ ans_bbox_list = []
136
+ for k, v in prediction_dict.items():
137
+ for box_index in v:
138
+ box_index = int(box_index)
139
+ if box_index < len(bboxes):
140
+ current_bbox = bboxes[box_index]
141
+ current_score = scores[box_index]
142
+ ans_bbox_json.append({
143
+ "region_index": f"<region{box_index}>",
144
+ "xmin": current_bbox[0],
145
+ "ymin": current_bbox[1],
146
+ "xmax": current_bbox[2],
147
+ "ymax": current_bbox[3],
148
+ "label": k,
149
+ "score": current_score
150
+ })
151
+ ans_bbox_list.append(current_bbox)
152
+
153
+ return outputs, ans_bbox_json, ans_bbox_list
154
+
155
+
156
+ def draw_bboxes(img, results):
157
+ fig, ax = plt.subplots(figsize=(12, 8))
158
+ # fig.subplots_adjust(0, 0, 1, 1)
159
+ ax.imshow(img)
160
+ nb_objects = len(results["scores"])
161
+ print(f"found {nb_objects} object(s)")
162
+ for i in range(nb_objects):
163
+ color = COLORS[i % len(COLORS)]
164
+ plot_mask(results["masks"][i].squeeze(0).cpu(), color=color)
165
+ w, h = img.size
166
+ prob = results["scores"][i].item()
167
+ plot_bbox(
168
+ h,
169
+ w,
170
+ results["boxes"][i].cpu(),
171
+ text=f"(id={i}, {prob=:.2f})",
172
+ box_format="XYXY",
173
+ color=color,
174
+ relative_coords=False,
175
+ )
176
+ ax.axis("off")
177
+ fig.tight_layout(pad=0)
178
+
179
+ # Convert matplotlib figure to PIL Image
180
+ fig.canvas.draw()
181
+ buf = fig.canvas.buffer_rgba()
182
+ pil_img = Image.frombytes('RGBA', fig.canvas.get_width_height(), buf)
183
+ plt.close(fig)
184
+
185
+ return pil_img
186
+
187
+
188
+ @spaces.GPU
189
+ def process(image, prompt, threshold=0):
190
+ if image is None:
191
+ error_msg = "Error: Please upload an image or select a valid example."
192
+ print(f"Error: image is None, original input type: {type(image)}")
193
+ return None, None, error_msg, []
194
+
195
+ try:
196
+ image = image.convert('RGB')
197
+ except Exception as e:
198
+ error_msg = f"Error: Cannot process image - {str(e)}"
199
+ return None, None, error_msg, []
200
+
201
+ bboxes, scores, high_conf_bboxes, masks, output = detect_model(image, prompt, threshold)
202
+
203
+ fo1_prompt = OD_Counting_template.format(prompt)
204
+ ans, ans_bbox_json, ans_bbox_list = multimodal_model(image, bboxes, scores, fo1_prompt)
205
+
206
+ detection_image = draw_bboxes(image, output)
207
+
208
+ annotated_bboxes = []
209
+ if len(ans_bbox_json) > 0:
210
+ img_width, img_height = image.size
211
+ for item in ans_bbox_json:
212
+ xmin = max(0, min(img_width, int(item['xmin'])))
213
+ ymin = max(0, min(img_height, int(item['ymin'])))
214
+ xmax = max(0, min(img_width, int(item['xmax'])))
215
+ ymax = max(0, min(img_height, int(item['ymax'])))
216
+ annotated_bboxes.append(
217
+ ((xmin, ymin, xmax, ymax), item['label'])
218
+ )
219
+ annotated_image = (image, annotated_bboxes)
220
+
221
+ return annotated_image, detection_image, ans_bbox_json
222
+
223
+
224
+ def update_btn(is_processing):
225
+ if is_processing:
226
+ return gr.update(value="Processing...", interactive=False)
227
+ else:
228
+ return gr.update(value="Submit", interactive=True)
229
+
230
+
231
+ def launch_demo():
232
+ with gr.Blocks() as demo:
233
+ gr.Markdown("# 🚀 VLM-FO1 + SAM3 Demo")
234
+ gr.Markdown("""
235
+ ### 📋 Instructions
236
+ This demo combines SAM3 detection results with the VLM-FO1 model to enhance detection and segmentation performance on complex text-label tasks.
237
+
238
+ **How it works**
239
+ 1. Upload or pick an example image.
240
+ 2. Describe the target object in natural language.
241
+ 3. Hit **Submit** to run SAM3 + VLM-FO1.
242
+
243
+ **Outputs**
244
+ - `SAM3 Result`: raw detections with masks/bboxes generated by SAM3.
245
+ - `VLM-FO1 Result`: filtered detections plus labels generated by VLM-FO1.
246
+
247
+ **Tips**
248
+ - One prompt at a time is currently supported. Multiple label prompts will be supported soon.
249
+ - Use the examples below to quickly explore the pipeline.
250
+ """)
251
+
252
+ gr.Markdown("""
253
+ ### 🔗 References
254
+ - [SAM3](https://github.com/facebookresearch/sam3)
255
+ - [VLM-FO1](https://github.com/om-ai-lab/VLM-FO1)
256
+ """)
257
+
258
+ with gr.Row():
259
+ with gr.Column():
260
+ img_input_draw = gr.Image(
261
+ label="Image Input",
262
+ type="pil",
263
+ sources=['upload'],
264
+ )
265
+
266
+ gr.Markdown("### Prompt")
267
+
268
+ prompt_input = gr.Textbox(
269
+ label="Label Prompt",
270
+ lines=2,
271
+ )
272
+
273
+ submit_btn = gr.Button("Submit", variant="primary")
274
+
275
+
276
+ examples = gr.Examples(
277
+ examples=EXAMPLES,
278
+ inputs=[img_input_draw, prompt_input],
279
+ label="Click to load example",
280
+ examples_per_page=5
281
+ )
282
+
283
+ with gr.Column():
284
+ with gr.Accordion("SAM3 Result", open=True):
285
+ image_output_detection = gr.Image(label="SAM3 Result", height=400)
286
+
287
+ image_output = gr.AnnotatedImage(label="VLM-FO1 Result", height=400)
288
+
289
+ ans_bbox_json = gr.JSON(label="Extracted Detection Output")
290
+
291
+ submit_btn.click(
292
+ update_btn,
293
+ inputs=[gr.State(True)],
294
+ outputs=[submit_btn],
295
+ queue=False
296
+ ).then(
297
+ process,
298
+ inputs=[img_input_draw, prompt_input],
299
+ outputs=[image_output, image_output_detection, ans_bbox_json],
300
+ queue=True
301
+ ).then(
302
+ update_btn,
303
+ inputs=[gr.State(False)],
304
+ outputs=[submit_btn],
305
+ queue=False
306
+ )
307
+
308
+ return demo
309
+
310
+ if __name__ == "__main__":
311
+ # model_path = './resources/VLM-FO1_Qwen2.5-VL-3B-v01'
312
+ # sam3_model_path = './resources/sam3/sam3.pt'
313
+
314
+ model_path = 'omlab/VLM-FO1_Qwen2.5-VL-3B-v01'
315
+ tokenizer, model, image_processors = load_pretrained_model(
316
+ model_path=model_path,
317
+ device="cuda:0",
318
+ )
319
+ sam3_model = build_sam3_image_model(device="cuda:0")
320
+ sam3_processor = Sam3Processor(sam3_model, confidence_threshold=0.0, device="cuda:0")
321
+
322
+ demo = launch_demo()
323
+ demo.launch()
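A hypothetical local launch outside Hugging Face Spaces (a sketch, not part of the commit); it assumes access to the gated `facebook/sam3` checkpoint and the `omlab/VLM-FO1_Qwen2.5-VL-3B-v01` weights, and that the command runs from the repository root:

```bash
pip install spaces                 # provides the @spaces.GPU decorator used above
pip install -e detect_tools/sam3   # the SAM3 package vendored in this commit
hf auth login                      # authenticate to download the gated SAM3 checkpoint
PYTHONPATH=. python demo/gradio_demo_with_sam3.py
```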
demo/sam3_examples/init.py ADDED
File without changes
detect_tools/sam3/.gitignore ADDED
@@ -0,0 +1,153 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ *.egg-info/
24
+ .installed.cfg
25
+ *.egg
26
+ MANIFEST
27
+
28
+ # PyInstaller
29
+ # Usually these files are written by a python script from a template
30
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
31
+ *.manifest
32
+ *.spec
33
+
34
+ # Installer logs
35
+ pip-log.txt
36
+ pip-delete-this-directory.txt
37
+
38
+ # Unit test / coverage reports
39
+ htmlcov/
40
+ .tox/
41
+ .nox/
42
+ .coverage
43
+ .coverage.*
44
+ .cache
45
+ nosetests.xml
46
+ coverage.xml
47
+ *.cover
48
+ .hypothesis/
49
+ .pytest_cache/
50
+
51
+ # Translations
52
+ *.mo
53
+ *.pot
54
+
55
+ # Django stuff:
56
+ *.log
57
+ local_settings.py
58
+ db.sqlite3
59
+
60
+ # Flask stuff:
61
+ instance/
62
+ .webassets-cache
63
+
64
+ # Scrapy stuff:
65
+ .scrapy
66
+
67
+ # Sphinx documentation
68
+ docs/_build/
69
+
70
+ # PyBuilder
71
+ target/
72
+
73
+ # Jupyter Notebook
74
+ .ipynb_checkpoints
75
+ *-Copy*.ipynb
76
+
77
+ # IPython
78
+ profile_default/
79
+ ipython_config.py
80
+
81
+ # pyenv
82
+ .python-version
83
+
84
+ # celery beat schedule file
85
+ celerybeat-schedule
86
+
87
+ # SageMath parsed files
88
+ *.sage.py
89
+
90
+ # Environments
91
+ .env
92
+ .venv
93
+ env/
94
+ venv/
95
+ ENV/
96
+ env.bak/
97
+ venv.bak/
98
+
99
+ # Spyder project settings
100
+ .spyderproject
101
+ .spyproject
102
+
103
+ # Rope project settings
104
+ .ropeproject
105
+
106
+ # mkdocs documentation
107
+ /site
108
+
109
+ # mypy
110
+ .mypy_cache/
111
+ .dmypy.json
112
+ dmypy.json
113
+
114
+ # Pyre type checker
115
+ .pyre/
116
+
117
+ # PyCharm
118
+ .idea/
119
+
120
+ # VS Code
121
+ .vscode/
122
+ *.code-workspace
123
+
124
+ # Model weights and checkpoints
125
+ *.pth
126
+ *.pt
127
+ *.bin
128
+ *.ckpt
129
+ *.safetensors
130
+ weights/
131
+ checkpoints/
132
+ sam3_logs/
133
+
134
+ # Data files
135
+ *.h5
136
+ *.hdf5
137
+ *.pkl
138
+ *.pickle
139
+ *.npy
140
+ *.npz
141
+
142
+ # Logs
143
+ logs/
144
+ runs/
145
+ tensorboard/
146
+
147
+ # OS specific
148
+ .DS_Store
149
+ Thumbs.db
150
+
151
+ # BPE vocabulary files
152
+ *.bpe
153
+ *.vocab
detect_tools/sam3/CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,80 @@
1
+ # Code of Conduct
2
+
3
+ ## Our Pledge
4
+
5
+ In the interest of fostering an open and welcoming environment, we as
6
+ contributors and maintainers pledge to make participation in our project and
7
+ our community a harassment-free experience for everyone, regardless of age, body
8
+ size, disability, ethnicity, sex characteristics, gender identity and expression,
9
+ level of experience, education, socio-economic status, nationality, personal
10
+ appearance, race, religion, or sexual identity and orientation.
11
+
12
+ ## Our Standards
13
+
14
+ Examples of behavior that contributes to creating a positive environment
15
+ include:
16
+
17
+ * Using welcoming and inclusive language
18
+ * Being respectful of differing viewpoints and experiences
19
+ * Gracefully accepting constructive criticism
20
+ * Focusing on what is best for the community
21
+ * Showing empathy towards other community members
22
+
23
+ Examples of unacceptable behavior by participants include:
24
+
25
+ * The use of sexualized language or imagery and unwelcome sexual attention or
26
+ advances
27
+ * Trolling, insulting/derogatory comments, and personal or political attacks
28
+ * Public or private harassment
29
+ * Publishing others' private information, such as a physical or electronic
30
+ address, without explicit permission
31
+ * Other conduct which could reasonably be considered inappropriate in a
32
+ professional setting
33
+
34
+ ## Our Responsibilities
35
+
36
+ Project maintainers are responsible for clarifying the standards of acceptable
37
+ behavior and are expected to take appropriate and fair corrective action in
38
+ response to any instances of unacceptable behavior.
39
+
40
+ Project maintainers have the right and responsibility to remove, edit, or
41
+ reject comments, commits, code, wiki edits, issues, and other contributions
42
+ that are not aligned to this Code of Conduct, or to ban temporarily or
43
+ permanently any contributor for other behaviors that they deem inappropriate,
44
+ threatening, offensive, or harmful.
45
+
46
+ ## Scope
47
+
48
+ This Code of Conduct applies within all project spaces, and it also applies when
49
+ an individual is representing the project or its community in public spaces.
50
+ Examples of representing a project or community include using an official
51
+ project e-mail address, posting via an official social media account, or acting
52
+ as an appointed representative at an online or offline event. Representation of
53
+ a project may be further defined and clarified by project maintainers.
54
+
55
+ This Code of Conduct also applies outside the project spaces when there is a
56
+ reasonable belief that an individual's behavior may have a negative impact on
57
+ the project or its community.
58
+
59
+ ## Enforcement
60
+
61
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
62
+ reported by contacting the project team at <[email protected]>. All
63
+ complaints will be reviewed and investigated and will result in a response that
64
+ is deemed necessary and appropriate to the circumstances. The project team is
65
+ obligated to maintain confidentiality with regard to the reporter of an incident.
66
+ Further details of specific enforcement policies may be posted separately.
67
+
68
+ Project maintainers who do not follow or enforce the Code of Conduct in good
69
+ faith may face temporary or permanent repercussions as determined by other
70
+ members of the project's leadership.
71
+
72
+ ## Attribution
73
+
74
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
75
+ available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
76
+
77
+ [homepage]: https://www.contributor-covenant.org
78
+
79
+ For answers to common questions about this code of conduct, see
80
+ https://www.contributor-covenant.org/faq
detect_tools/sam3/CONTRIBUTING.md ADDED
@@ -0,0 +1,30 @@
1
+ # Contributing to sam3
2
+ We want to make contributing to this project as easy and transparent as
3
+ possible.
4
+
5
+ ## Pull Requests
6
+ We actively welcome your pull requests.
7
+
8
+ 1. Fork the repo and create your branch from `main`.
9
+ 2. If you've added code that should be tested, add tests.
10
+ 3. If you've changed APIs, update the documentation.
11
+ 4. Make sure your code lints.
12
+ 5. If you haven't already, complete the Contributor License Agreement ("CLA").
13
+
14
+ ## Contributor License Agreement ("CLA")
15
+ In order to accept your pull request, we need you to submit a CLA. You only need
16
+ to do this once to work on any of Facebook's open source projects.
17
+
18
+ Complete your CLA here: <https://code.facebook.com/cla>
19
+
20
+ ## Issues
21
+ We use GitHub issues to track public bugs. Please ensure your description is
22
+ clear and has sufficient instructions to be able to reproduce the issue.
23
+
24
+ Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
25
+ disclosure of security bugs. In those cases, please go through the process
26
+ outlined on that page and do not file a public issue.
27
+
28
+ ## License
29
+ By contributing to sam3, you agree that your contributions will be licensed
30
+ under the LICENSE file in the root directory of this source tree.
detect_tools/sam3/LICENSE ADDED
@@ -0,0 +1,61 @@
1
+ SAM License
2
+ Last Updated: November 19, 2025
3
+
4
+ “Agreement” means the terms and conditions for use, reproduction, distribution and modification of the SAM Materials set forth herein.
5
+
6
+
7
+ “SAM Materials” means, collectively, Documentation and the models, software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code, and other elements of the foregoing distributed by Meta and made available under this Agreement.
8
+
9
+ “Documentation” means the specifications, manuals and documentation accompanying
10
+ SAM Materials distributed by Meta.
11
+
12
+
13
+ “Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
14
+
15
+
16
+ “Meta” or “we” means Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) or Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland).
17
+
18
+
19
+ “Sanctions” means any economic or trade sanctions or restrictions administered or enforced by the United States (including the Office of Foreign Assets Control of the U.S. Department of the Treasury (“OFAC”), the U.S. Department of State and the U.S. Department of Commerce), the United Nations, the European Union, or the United Kingdom.
20
+
21
+
22
+ “Trade Controls” means any of the following: Sanctions and applicable export and import controls.
23
+
24
+ By using or distributing any portion or element of the SAM Materials, you agree to be bound by this Agreement.
25
+
26
+
27
+ 1. License Rights and Redistribution.
28
+
29
+
30
+ a. Grant of Rights. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta’s intellectual property or other rights owned by Meta embodied in the SAM Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the SAM Materials.
31
+
32
+ b. Redistribution and Use.
33
+ i. Distribution of SAM Materials, and any derivative works thereof, are subject to the terms of this Agreement. If you distribute or make the SAM Materials, or any derivative works thereof, available to a third party, you may only do so under the terms of this Agreement and you shall provide a copy of this Agreement with any such SAM Materials.
34
+
35
+
36
+ ii. If you submit for publication the results of research you perform on, using, or otherwise in connection with SAM Materials, you must acknowledge the use of SAM Materials in your publication.
37
+
38
+
39
+ iii. Your use of the SAM Materials must comply with applicable laws and regulations, including Trade Control Laws and applicable privacy and data protection laws.
40
+ iv. Your use of the SAM Materials will not involve or encourage others to reverse engineer, decompile or discover the underlying components of the SAM Materials.
41
+ v. You are not the target of Trade Controls and your use of SAM Materials must comply with Trade Controls. You agree not to use, or permit others to use, SAM Materials for any activities subject to the International Traffic in Arms Regulations (ITAR) or end uses prohibited by Trade Controls, including those related to military or warfare purposes, nuclear industries or applications, espionage, or the development or use of guns or illegal weapons.
42
+ 2. User Support. Your use of the SAM Materials is done at your own discretion; Meta does not process any information nor provide any service in relation to such use. Meta is under no obligation to provide any support services for the SAM Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind.
43
+
44
+
45
+ 3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SAM MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SAM MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SAM MATERIALS AND ANY OUTPUT AND RESULTS.
46
+
47
+ 4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT OR INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
48
+
49
+ 5. Intellectual Property.
50
+
51
+
52
+ a. Subject to Meta’s ownership of SAM Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the SAM Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications.
53
+
54
+ b. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the SAM Materials, outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the SAM Materials.
55
+
56
+ 6. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the SAM Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the SAM Materials. Sections 3, 4 and 7 shall survive the termination of this Agreement.
57
+
58
+ 7. Governing Law and Jurisdiction. This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement.
59
+
60
+
61
+ 8. Modifications and Amendments. Meta may modify this Agreement from time to time; provided that they are similar in spirit to the current version of the Agreement, but may differ in detail to address new problems or concerns. All such changes will be effective immediately. Your continued use of the SAM Materials after any modification to this Agreement constitutes your agreement to such modification. Except as provided in this Agreement, no modification or addition to any provision of this Agreement will be binding unless it is in writing and signed by an authorized representative of both you and Meta.
detect_tools/sam3/MANIFEST.in ADDED
@@ -0,0 +1,6 @@
1
+ include LICENSE
2
+ include README.md
3
+ recursive-include examples *.py
4
+ recursive-include examples *.ipynb
5
+ recursive-include examples *.md
6
+ recursive-include tests *.py
detect_tools/sam3/README.md ADDED
@@ -0,0 +1,387 @@
1
+ # SAM 3: Segment Anything with Concepts
2
+
3
+ Meta Superintelligence Labs
4
+
5
+ [Nicolas Carion](https://www.nicolascarion.com/)\*,
6
+ [Laura Gustafson](https://scholar.google.com/citations?user=c8IpF9gAAAAJ&hl=en)\*,
7
+ [Yuan-Ting Hu](https://scholar.google.com/citations?user=E8DVVYQAAAAJ&hl=en)\*,
8
+ [Shoubhik Debnath](https://scholar.google.com/citations?user=fb6FOfsAAAAJ&hl=en)\*,
9
+ [Ronghang Hu](https://ronghanghu.com/)\*,
10
+ [Didac Suris](https://www.didacsuris.com/)\*,
11
+ [Chaitanya Ryali](https://scholar.google.com/citations?user=4LWx24UAAAAJ&hl=en)\*,
12
+ [Kalyan Vasudev Alwala](https://scholar.google.co.in/citations?user=m34oaWEAAAAJ&hl=en)\*,
13
+ [Haitham Khedr](https://hkhedr.com/)\*, Andrew Huang,
14
+ [Jie Lei](https://jayleicn.github.io/),
15
+ [Tengyu Ma](https://scholar.google.com/citations?user=VeTSl0wAAAAJ&hl=en),
16
+ [Baishan Guo](https://scholar.google.com/citations?user=BC5wDu8AAAAJ&hl=en),
17
+ Arpit Kalla, [Markus Marks](https://damaggu.github.io/),
18
+ [Joseph Greer](https://scholar.google.com/citations?user=guL96CkAAAAJ&hl=en),
19
+ Meng Wang, [Peize Sun](https://peizesun.github.io/),
20
+ [Roman Rädle](https://scholar.google.com/citations?user=Tpt57v0AAAAJ&hl=en),
21
+ [Triantafyllos Afouras](https://www.robots.ox.ac.uk/~afourast/),
22
+ [Effrosyni Mavroudi](https://scholar.google.com/citations?user=vYRzGGEAAAAJ&hl=en),
23
+ [Katherine Xu](https://k8xu.github.io/)°,
24
+ [Tsung-Han Wu](https://patrickthwu.com/)°,
25
+ [Yu Zhou](https://yu-bryan-zhou.github.io/)°,
26
+ [Liliane Momeni](https://scholar.google.com/citations?user=Lb-KgVYAAAAJ&hl=en)°,
27
+ [Rishi Hazra](https://rishihazra.github.io/)°,
28
+ [Shuangrui Ding](https://mark12ding.github.io/)°,
29
+ [Sagar Vaze](https://sgvaze.github.io/)°,
30
+ [Francois Porcher](https://scholar.google.com/citations?user=LgHZ8hUAAAAJ&hl=en)°,
31
+ [Feng Li](https://fengli-ust.github.io/)°,
32
+ [Siyuan Li](https://siyuanliii.github.io/)°,
33
+ [Aishwarya Kamath](https://ashkamath.github.io/)°,
34
+ [Ho Kei Cheng](https://hkchengrex.com/)°,
35
+ [Piotr Dollar](https://pdollar.github.io/)†,
36
+ [Nikhila Ravi](https://nikhilaravi.com/)†,
37
+ [Kate Saenko](https://ai.bu.edu/ksaenko.html)†,
38
+ [Pengchuan Zhang](https://pzzhang.github.io/pzzhang/)†,
39
+ [Christoph Feichtenhofer](https://feichtenhofer.github.io/)†
40
+
41
+ \* core contributor, ° intern, † project lead, order is random within groups
42
+
43
+ [[`Paper`](https://ai.meta.com/research/publications/sam-3-segment-anything-with-concepts/)]
44
+ [[`Project`](https://ai.meta.com/sam3)]
45
+ [[`Demo`](https://segment-anything.com/)]
46
+ [[`Blog`](https://ai.meta.com/blog/segment-anything-model-3/)]
47
+ <!-- [[`BibTeX`](#citing-sam-3)] -->
48
+
49
+ ![SAM 3 architecture](assets/model_diagram.png?raw=true) SAM 3 is a unified foundation model for promptable segmentation in images and videos. It can detect, segment, and track objects using text or visual prompts such as points, boxes, and masks. Compared to its predecessor [SAM 2](https://github.com/facebookresearch/sam2), SAM 3 introduces the ability to exhaustively segment all instances of an open-vocabulary concept specified by a short text phrase or exemplars. Unlike prior work, SAM 3 can handle a vastly larger set of open-vocabulary prompts. It achieves 75-80% of human performance on our new [SA-CO benchmark](https://github.com/facebookresearch/sam3/edit/main_readme/README.md#sa-co-dataset) which contains 270K unique concepts, over 50 times more than existing benchmarks.
50
+
51
+ This breakthrough is driven by an innovative data engine that has automatically annotated over 4 million unique concepts, creating the largest high-quality open-vocabulary segmentation dataset to date. In addition, SAM 3 introduces a new model architecture featuring a presence token that improves discrimination between closely related text prompts (e.g., “a player in white” vs. “a player in red”), as well as a decoupled detector–tracker design that minimizes task interference and scales efficiently with data.
52
+
53
+ <p align="center">
54
+ <img src="assets/dog.gif" width=380 />
55
+ <img src="assets/player.gif" width=380 />
56
+ </p>
57
+
58
+ ## Installation
59
+
60
+ ### Prerequisites
61
+
62
+ - Python 3.12 or higher
63
+ - PyTorch 2.7 or higher
64
+ - CUDA-compatible GPU with CUDA 12.6 or higher
65
+
66
+ 1. **Create a new Conda environment:**
67
+
68
+ ```bash
69
+ conda create -n sam3 python=3.12
70
+ conda deactivate
71
+ conda activate sam3
72
+ ```
73
+
74
+ 2. **Install PyTorch with CUDA support:**
75
+
76
+ ```bash
77
+ pip install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
78
+ ```
79
+
80
+ 3. **Clone the repository and install the package:**
81
+
82
+ ```bash
83
+ git clone https://github.com/facebookresearch/sam3.git
84
+ cd sam3
85
+ pip install -e .
86
+ ```
87
+
88
+ 4. **Install additional dependencies for example notebooks or development:**
89
+
90
+ ```bash
91
+ # For running example notebooks
92
+ pip install -e ".[notebooks]"
93
+
94
+ # For development
95
+ pip install -e ".[train,dev]"
96
+ ```
97
+
98
+ ## Getting Started
99
+
100
+ ⚠️ Before using SAM 3, please request access to the checkpoints on the SAM 3
101
+ Hugging Face [repo](https://huggingface.co/facebook/sam3). Once accepted, you
102
+ need to be authenticated to download the checkpoints. You can do this by running
103
+ the following [steps](https://huggingface.co/docs/huggingface_hub/en/quick-start#authentication)
104
+ (e.g. `hf auth login` after generating an access token.)
105
+
106
+ ### Basic Usage
107
+
108
+ ```python
109
+ import torch
110
+ #################################### For Image ####################################
111
+ from PIL import Image
112
+ from sam3.model_builder import build_sam3_image_model
113
+ from sam3.model.sam3_image_processor import Sam3Processor
114
+ # Load the model
115
+ model = build_sam3_image_model()
116
+ processor = Sam3Processor(model)
117
+ # Load an image
118
+ image = Image.open("<YOUR_IMAGE_PATH.jpg>")
119
+ inference_state = processor.set_image(image)
120
+ # Prompt the model with text
121
+ output = processor.set_text_prompt(state=inference_state, prompt="<YOUR_TEXT_PROMPT>")
122
+
123
+ # Get the masks, bounding boxes, and scores
124
+ masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
125
+
126
+ #################################### For Video ####################################
127
+
128
+ from sam3.model_builder import build_sam3_video_predictor
129
+
130
+ video_predictor = build_sam3_video_predictor()
131
+ video_path = "<YOUR_VIDEO_PATH>" # a JPEG folder or an MP4 video file
132
+ # Start a session
133
+ response = video_predictor.handle_request(
134
+ request=dict(
135
+ type="start_session",
136
+ resource_path=video_path,
137
+ )
138
+ )
139
+ response = video_predictor.handle_request(
140
+ request=dict(
141
+ type="add_prompt",
142
+ session_id=response["session_id"],
143
+ frame_index=0, # Arbitrary frame index
144
+ text="<YOUR_TEXT_PROMPT>",
145
+ )
146
+ )
147
+ output = response["outputs"]
148
+ ```
149
+
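The detections returned above can be overlaid on the input with the plotting helpers in `sam3.visualization_utils` (the same helpers used by `demo/gradio_demo_with_sam3.py` in this commit); a minimal sketch, assuming `image`, `masks`, `boxes`, and `scores` come from the image example above:

```python
import matplotlib.pyplot as plt
from sam3.visualization_utils import COLORS, plot_bbox, plot_mask

fig, ax = plt.subplots(figsize=(12, 8))
ax.imshow(image)
w, h = image.size
for i in range(len(scores)):
    color = COLORS[i % len(COLORS)]
    # overlay the i-th instance mask and its box, labeled with the predicted score
    plot_mask(masks[i].squeeze(0).cpu(), color=color)
    plot_bbox(h, w, boxes[i].cpu(), text=f"{scores[i].item():.2f}",
              box_format="XYXY", color=color, relative_coords=False)
ax.axis("off")
plt.show()
```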
150
+ ## Examples
151
+
152
+ The `examples` directory contains notebooks demonstrating how to use SAM3 with
153
+ various types of prompts:
154
+
155
+ - [`sam3_image_predictor_example.ipynb`](examples/sam3_image_predictor_example.ipynb)
156
+ : Demonstrates how to prompt SAM 3 with text and visual box prompts on images.
157
+ - [`sam3_video_predictor_example.ipynb`](examples/sam3_video_predictor_example.ipynb)
158
+ : Demonstrates how to prompt SAM 3 with text prompts on videos, and how to do
159
+ further interactive refinements with points.
160
+ - [`sam3_image_batched_inference.ipynb`](examples/sam3_image_batched_inference.ipynb)
161
+ : Demonstrates how to run batched inference with SAM 3 on images.
162
+ - [`sam3_agent.ipynb`](examples/sam3_agent.ipynb): Demonstrates the use of SAM
163
+ 3 Agent to segment complex text prompts on images.
164
+ - [`saco_gold_silver_vis_example.ipynb`](examples/saco_gold_silver_vis_example.ipynb)
165
+ : Shows a few examples from SA-Co image evaluation set.
166
+ - [`saco_veval_vis_example.ipynb`](examples/saco_veval_vis_example.ipynb) :
167
+ Shows a few examples from SA-Co video evaluation set.
168
+
169
+ There are additional notebooks in the examples directory that demonstrate how to
170
+ use SAM 3 for interactive instance segmentation in images and videos (SAM 1/2
171
+ tasks), or as a tool for an MLLM, and how to run evaluations on the SA-Co
172
+ dataset.
173
+
174
+ To run the Jupyter notebook examples:
175
+
176
+ ```bash
177
+ # Make sure you have the notebooks dependencies installed
178
+ pip install -e ".[notebooks]"
179
+
180
+ # Start Jupyter notebook
181
+ jupyter notebook examples/sam3_image_predictor_example.ipynb
182
+ ```
183
+
184
+ ## Model
185
+
186
+ SAM 3 consists of a detector and a tracker that share a vision encoder. It has 848M parameters. The
187
+ detector is a DETR-based model conditioned on text, geometry, and image
188
+ exemplars. The tracker inherits the SAM 2 transformer encoder-decoder
189
+ architecture, supporting video segmentation and interactive refinement.
190
+
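A quick way to sanity-check the reported model size (a sketch; it reuses the image model builder from Basic Usage, so it counts the detector and shared vision encoder only, not the full detector plus tracker pair):

```python
from sam3.model_builder import build_sam3_image_model

# Count the parameters of the loaded image model, in millions
model = build_sam3_image_model()
n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params / 1e6:.0f}M parameters")
```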
191
+ ## Image Results
192
+
193
+ <div align="center">
194
+ <table style="min-width: 80%; border: 2px solid #ddd; border-collapse: collapse">
195
+ <thead>
196
+ <tr>
197
+ <th rowspan="3" style="border-right: 2px solid #ddd; padding: 12px 20px">Model</th>
198
+ <th colspan="3" style="text-align: center; border-right: 2px solid #ddd; padding: 12px 20px">Instance Segmentation</th>
199
+ <th colspan="5" style="text-align: center; padding: 12px 20px">Box Detection</th>
200
+ </tr>
201
+ <tr>
202
+ <th colspan="2" style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">LVIS</th>
203
+ <th style="text-align: center; border-right: 2px solid #ddd; padding: 12px 20px">SA-Co/Gold</th>
204
+ <th colspan="2" style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">LVIS</th>
205
+ <th colspan="2" style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">COCO</th>
206
+ <th style="text-align: center; padding: 12px 20px">SA-Co/Gold</th>
207
+ </tr>
208
+ <tr>
209
+ <th style="text-align: center; padding: 12px 20px">cgF1</th>
210
+ <th style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">AP</th>
211
+ <th style="text-align: center; border-right: 2px solid #ddd; padding: 12px 20px">cgF1</th>
212
+ <th style="text-align: center; padding: 12px 20px">cgF1</th>
213
+ <th style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">AP</th>
214
+ <th style="text-align: center; padding: 12px 20px">AP</th>
215
+ <th style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">AP<sub>o</sub>
216
+ </th>
217
+ <th style="text-align: center; padding: 12px 20px">cgF1</th>
218
+ </tr>
219
+ </thead>
220
+ <tbody>
221
+ <tr>
222
+ <td style="border-right: 2px solid #ddd; padding: 10px 20px">Human</td>
223
+ <td style="text-align: center; padding: 10px 20px">-</td>
224
+ <td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">-</td>
225
+ <td style="text-align: center; border-right: 2px solid #ddd; padding: 10px 20px">72.8</td>
226
+ <td style="text-align: center; padding: 10px 20px">-</td>
227
+ <td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">-</td>
228
+ <td style="text-align: center; padding: 10px 20px">-</td>
229
+ <td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">-</td>
230
+ <td style="text-align: center; padding: 10px 20px">74.0</td>
231
+ </tr>
232
+ <tr>
233
+ <td style="border-right: 2px solid #ddd; padding: 10px 20px">OWLv2*</td>
234
+ <td style="text-align: center; padding: 10px 20px; color: #999">29.3</td>
235
+ <td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px; color: #999">43.4</td>
236
+ <td style="text-align: center; border-right: 2px solid #ddd; padding: 10px 20px">24.6</td>
237
+ <td style="text-align: center; padding: 10px 20px; color: #999">30.2</td>
238
+ <td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px; color: #999">45.5</td>
239
+ <td style="text-align: center; padding: 10px 20px">46.1</td>
240
+ <td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">23.9</td>
241
+ <td style="text-align: center; padding: 10px 20px">24.5</td>
242
+ </tr>
243
+ <tr>
244
+ <td style="border-right: 2px solid #ddd; padding: 10px 20px">DINO-X</td>
245
+ <td style="text-align: center; padding: 10px 20px">-</td>
246
+ <td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">38.5</td>
247
+ <td style="text-align: center; border-right: 2px solid #ddd; padding: 10px 20px">21.3</td>
248
+ <td style="text-align: center; padding: 10px 20px">-</td>
249
+ <td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">52.4</td>
250
+ <td style="text-align: center; padding: 10px 20px">56.0</td>
251
+ <td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">-</td>
252
+ <td style="text-align: center; padding: 10px 20px">22.5</td>
253
+ </tr>
254
+ <tr>
255
+ <td style="border-right: 2px solid #ddd; padding: 10px 20px">Gemini 2.5</td>
256
+ <td style="text-align: center; padding: 10px 20px">13.4</td>
257
+ <td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">-</td>
258
+ <td style="text-align: center; border-right: 2px solid #ddd; padding: 10px 20px">13.0</td>
259
+ <td style="text-align: center; padding: 10px 20px">16.1</td>
260
+ <td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">-</td>
261
+ <td style="text-align: center; padding: 10px 20px">-</td>
262
+ <td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">-</td>
263
+ <td style="text-align: center; padding: 10px 20px">14.4</td>
264
+ </tr>
265
+ <tr style="border-top: 2px solid #b19c9cff">
266
+ <td style="border-right: 2px solid #ddd; padding: 10px 20px">SAM 3</td>
267
+ <td style="text-align: center; padding: 10px 20px">37.2</td>
268
+ <td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">48.5</td>
269
+ <td style="text-align: center; border-right: 2px solid #ddd; padding: 10px 20px">54.1</td>
270
+ <td style="text-align: center; padding: 10px 20px">40.6</td>
271
+ <td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">53.6</td>
272
+ <td style="text-align: center; padding: 10px 20px">56.4</td>
273
+ <td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">55.7</td>
274
+ <td style="text-align: center; padding: 10px 20px">55.7</td>
275
+ </tr>
276
+ </tbody>
277
+ </table>
278
+
279
+ <p style="text-align: center; margin-top: 10px; font-size: 0.9em; color: #ddd;">* Partially trained on LVIS; AP<sub>o</sub> refers to COCO-O accuracy</p>
280
+
281
+ </div>
282
+
283
+ ## Video Results
284
+
285
+ <div align="center">
286
+ <table style="min-width: 80%; border: 2px solid #ddd; border-collapse: collapse">
287
+ <thead>
288
+ <tr>
289
+ <th rowspan="2" style="border-right: 2px solid #ddd; padding: 12px 20px">Model</th>
290
+ <th colspan="2" style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">SA-V test</th>
291
+ <th colspan="2" style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">YT-Temporal-1B test</th>
292
+ <th colspan="2" style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">SmartGlasses test</th>
293
+ <th style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">LVVIS test</th>
294
+ <th style="text-align: center; padding: 12px 20px">BURST test</th>
295
+ </tr>
296
+ <tr>
297
+ <th style="text-align: center; padding: 12px 20px">cgF1</th>
298
+ <th style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">pHOTA</th>
299
+ <th style="text-align: center; padding: 12px 20px">cgF1</th>
300
+ <th style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">pHOTA</th>
301
+ <th style="text-align: center; padding: 12px 20px">cgF1</th>
302
+ <th style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">pHOTA</th>
303
+ <th style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">mAP</th>
304
+ <th style="text-align: center; padding: 12px 20px">HOTA</th>
305
+ </tr>
306
+ </thead>
307
+ <tbody>
308
+ <tr>
309
+ <td style="border-right: 2px solid #ddd; padding: 10px 20px">Human</td>
310
+ <td style="text-align: center; padding: 10px 20px">53.1</td>
311
+ <td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">70.5</td>
312
+ <td style="text-align: center; padding: 10px 20px">71.2</td>
313
+ <td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">78.4</td>
314
+ <td style="text-align: center; padding: 10px 20px">58.5</td>
315
+ <td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">72.3</td>
316
+ <td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">-</td>
317
+ <td style="text-align: center; padding: 10px 20px">-</td>
318
+ </tr>
319
+ <tr style="border-top: 2px solid #b19c9cff">
320
+ <td style="border-right: 2px solid #ddd; padding: 10px 20px">SAM 3</td>
321
+ <td style="text-align: center; padding: 10px 20px">30.3</td>
322
+ <td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">58.0</td>
323
+ <td style="text-align: center; padding: 10px 20px">50.8</td>
324
+ <td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">69.9</td>
325
+ <td style="text-align: center; padding: 10px 20px">36.4</td>
326
+ <td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">63.6</td>
327
+ <td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">36.3</td>
328
+ <td style="text-align: center; padding: 10px 20px">44.5</td>
329
+ </tr>
330
+ </tbody>
331
+ </table>
332
+ </div>
333
+
334
+ ## SA-Co Dataset
335
+
336
+ We release two image benchmarks, [SA-Co/Gold](scripts/eval/gold/README.md) and
337
+ [SA-Co/Silver](scripts/eval/silver/README.md), and a video benchmark
338
+ [SA-Co/VEval](scripts/eval/veval/README.md). The datasets contain images (or videos) with annotated noun phrases. Each image/video and noun phrase pair is annotated with instance masks and a unique ID for each object matching the phrase. Phrases with no matching objects (negative prompts) have no masks and are shown in red font in the figure. See the linked READMEs for details on how to download the datasets and run evaluations on them.
339
+
340
+ * HuggingFace host: [SA-Co/Gold](https://huggingface.co/datasets/facebook/SACo-Gold), [SA-Co/Silver](https://huggingface.co/datasets/facebook/SACo-Silver) and [SA-Co/VEval](https://huggingface.co/datasets/facebook/SACo-VEval)
341
+ * Roboflow host: [SA-Co/Gold](https://universe.roboflow.com/sa-co-gold), [SA-Co/Silver](https://universe.roboflow.com/sa-co-silver) and [SA-Co/VEval](https://universe.roboflow.com/sa-co-veval)
342
+
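+ As an example, the Hugging Face copies listed above can be fetched with `huggingface_hub`.
+ This is a hedged sketch only: it assumes the dataset repos are accessible to you and
+ downloadable as plain snapshots; follow the linked READMEs for the authoritative
+ download and evaluation instructions.
+
+ ```python
+ from huggingface_hub import snapshot_download
+
+ # Download the SA-Co/Gold evaluation set into the local HF cache and return its path.
+ gold_path = snapshot_download(repo_id="facebook/SACo-Gold", repo_type="dataset")
+ print(gold_path)
+ ```
+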
343
+ ![SA-Co dataset](assets/sa_co_dataset.jpg?raw=true)
344
+
345
+ ## Development
346
+
347
+ To set up the development environment:
348
+
349
+ ```bash
350
+ pip install -e ".[dev,train]"
351
+ ```
352
+
353
+ To format the code:
354
+
355
+ ```bash
356
+ ufmt format .
357
+ ```
358
+
359
+ ## Contributing
360
+
361
+ See [contributing](CONTRIBUTING.md) and the
362
+ [code of conduct](CODE_OF_CONDUCT.md).
363
+
364
+ ## License
365
+
366
+ This project is licensed under the SAM License - see the [LICENSE](LICENSE) file
367
+ for details.
368
+
369
+ ## Acknowledgements
370
+
371
+ We would like to thank the following people for their contributions to the SAM 3 project: Alex He, Alexander Kirillov,
372
+ Alyssa Newcomb, Ana Paula Kirschner Mofarrej, Andrea Madotto, Andrew Westbury, Ashley Gabriel, Azita Shokpour,
373
+ Ben Samples, Bernie Huang, Carleigh Wood, Ching-Feng Yeh, Christian Puhrsch, Claudette Ward, Daniel Bolya,
374
+ Daniel Li, Facundo Figueroa, Fazila Vhora, George Orlin, Hanzi Mao, Helen Klein, Hu Xu, Ida Cheng, Jake Kinney,
375
+ Jiale Zhi, Jo Sampaio, Joel Schlosser, Justin Johnson, Kai Brown, Karen Bergan, Karla Martucci, Kenny Lehmann,
376
+ Maddie Mintz, Mallika Malhotra, Matt Ward, Michelle Chan, Michelle Restrepo, Miranda Hartley, Muhammad Maaz,
377
+ Nisha Deo, Peter Park, Phillip Thomas, Raghu Nayani, Rene Martinez Doehner, Robbie Adkins, Ross Girshick, Sasha
378
+ Mitts, Shashank Jain, Spencer Whitehead, Ty Toledano, Valentin Gabeur, Vincent Cho, Vivian Lee, William Ngan,
379
+ Xuehai He, Yael Yungster, Ziqi Pang, Ziyi Dou, Zoe Quake.
380
+
381
+ <!-- ## Citing SAM 3
382
+
383
+ If you use SAM 3 or the SA-Co dataset in your research, please use the following BibTeX entry.
384
+
385
+ ```bibtex
386
+ TODO
387
+ ``` -->
detect_tools/sam3/README_TRAIN.md ADDED
@@ -0,0 +1,190 @@
1
+ # Training
2
+
3
+ This repository supports finetuning SAM3 models on custom datasets, either locally or in a multi-node cluster setup. The training script is located at `sam3/train.py` and uses Hydra configuration management to handle complex training setups.
4
+
5
+
6
+ ## Installation
7
+
8
+ ```bash
9
+ cd sam3
10
+ pip install -e ".[train]"
11
+ ```
12
+
13
+ ### Training Script Usage
14
+
15
+ The main training script is located at `sam3/train.py`. It uses Hydra configuration management to handle complex training setups.
16
+
17
+ #### Basic Usage
18
+
19
+ ```bash
20
+ # Example: Train on Roboflow dataset
21
+ python sam3/train/train.py -c configs/roboflow_v100/roboflow_v100_full_ft_100_images.yaml
22
+ # Example: Train on ODinW13 dataset
23
+ python sam3/train/train.py -c configs/odinw13/odinw_text_only_train.yaml
24
+ ```
25
+ Follow [`Roboflow 100-VL`](https://github.com/roboflow/rf100-vl/) to download the Roboflow 100-VL datasets and [`GLIP`](https://github.com/microsoft/GLIP) to download the ODinW datasets. Organize the data folders as shown below, and set `roboflow_vl_100_root` and `odinw_data_root` in the job configs.
26
+ ```
27
+ roboflow_vl_100_root:
28
+ 13-lkc01
29
+ train
30
+ valid
31
+ test
32
+ 2024-frc
33
+ actions
34
+ ...
35
+ odinw_data_root:
36
+ AerialMaritimeDrone
37
+ large
38
+ train
39
+ valid
40
+ test
41
+ Aquarium
42
+ ...
43
+ ```
44
+
45
+ #### Command Line Arguments
46
+
47
+ The training script supports several command line arguments:
48
+
49
+ ```bash
50
+ python sam3/train/train.py \
51
+ -c CONFIG_NAME \
52
+ [--use-cluster 0|1] \
53
+ [--partition PARTITION_NAME] \
54
+ [--account ACCOUNT_NAME] \
55
+ [--qos QOS_NAME] \
56
+ [--num-gpus NUM_GPUS] \
57
+ [--num-nodes NUM_NODES]
58
+ ```
59
+
60
+ **Arguments:**
61
+ - `-c, --config`: **Required.** Path to the configuration file (e.g., `sam3/train/configs/roboflow_v100_full_ft_100_images.yaml`)
62
+ - `--use-cluster`: Whether to launch on a cluster (0: local, 1: cluster). Default: uses config setting
63
+ - `--partition`: SLURM partition name for cluster execution
64
+ - `--account`: SLURM account name for cluster execution
65
+ - `--qos`: SLURM QOS (Quality of Service) setting
66
+ - `--num-gpus`: Number of GPUs per node. Default: uses config setting
67
+ - `--num-nodes`: Number of nodes for distributed training. Default: uses config setting
68
+
69
+ #### Local Training Examples
70
+
71
+ ```bash
72
+ # Single GPU training
73
+ python sam3/train/train.py -c configs/roboflow_v100/roboflow_v100_full_ft_100_images.yaml --use-cluster 0 --num-gpus 1
74
+
75
+ # Multi-GPU training on a single node
76
+ python sam3/train/train.py -c configs/roboflow_v100/roboflow_v100_full_ft_100_images.yaml --use-cluster 0 --num-gpus 4
77
+
78
+ # Force local execution even if config specifies GPUs
79
+ python sam3/train/train.py -c configs/roboflow_v100/roboflow_v100_full_ft_100_images.yaml --use-cluster 0
80
+ ```
81
+
82
+ #### Cluster Training Examples
83
+
84
+ ```bash
85
+ # Basic cluster training with default settings from config
86
+ python sam3/train/train.py -c configs/roboflow_v100/roboflow_v100_full_ft_100_images.yaml --use-cluster 1
87
+
88
+ # Cluster training with specific SLURM settings
89
+ python sam3/train/train.py -c configs/roboflow_v100/roboflow_v100_full_ft_100_images.yaml \
90
+ --use-cluster 1 \
91
+ --partition gpu_partition \
92
+ --account my_account \
93
+ --qos high_priority \
94
+ --num-gpus 8 \
95
+ --num-nodes 2
96
+ ```
97
+
98
+ ### Configuration Files
99
+
100
+ Training configurations are stored in `sam3/train/configs/`. The configuration files use Hydra's YAML format and support:
101
+
102
+ - **Dataset Configuration**: Data paths, transforms, and loading parameters
103
+ - **Model Configuration**: Architecture settings, checkpoint paths, and model parameters
104
+ - **Training Configuration**: Batch sizes, learning rates, optimization settings
105
+ - **Launcher Configuration**: Distributed training and cluster settings
106
+ - **Logging Configuration**: TensorBoard, experiment tracking, and output directories
107
+
108
+ #### Key Configuration Sections
109
+
110
+ ```yaml
111
+ # Paths to datasets and checkpoints
112
+ paths:
113
+ bpe_path: /path/to/bpe/file
114
+ dataset_root: /path/to/dataset
115
+ experiment_log_dir: /path/to/logs
116
+
117
+ # Launcher settings for local/cluster execution
118
+ launcher:
119
+ num_nodes: 1
120
+ gpus_per_node: 2
121
+ experiment_log_dir: ${paths.experiment_log_dir}
122
+
123
+ # Cluster execution settings
124
+ submitit:
125
+ use_cluster: True
126
+ timeout_hour: 72
127
+ cpus_per_task: 10
128
+ partition: null
129
+ account: null
130
+ ```
131
+
132
+ ### Monitoring Training
133
+
134
+ The training script automatically sets up logging and saves outputs to the experiment directory:
135
+
136
+ ```bash
137
+ # Logs are saved to the experiment_log_dir specified in config
138
+ experiment_log_dir/
139
+ ├── config.yaml # Original configuration
140
+ ├── config_resolved.yaml # Resolved configuration with all variables expanded
141
+ ├── checkpoints/ # Model checkpoints (if skip_checkpointing=False)
142
+ ├── tensorboard/ # TensorBoard logs
143
+ ├── logs/ # Text logs
144
+ └── submitit_logs/ # Cluster job logs (if using cluster)
145
+ ```
146
+
147
+ You can monitor training progress using TensorBoard:
148
+
149
+ ```bash
150
+ tensorboard --logdir /path/to/experiment_log_dir/tensorboard
151
+ ```
152
+
153
+ ### Job Arrays for Dataset Sweeps
154
+
155
+ The Roboflow and ODinW configurations support job arrays for training multiple models on different datasets.
156
+
157
+ This feature is enabled via the following config:
158
+ ```yaml
159
+ submitit:
160
+ job_array:
161
+ num_tasks: 100
162
+ task_index: 0
163
+ ```
164
+
165
+ The configuration includes a complete list of 100 Roboflow supercategories, and the `submitit.job_array.task_index` automatically selects which dataset to use based on the array job index.
166
+
167
+ ```bash
168
+ # Submit job array to train on different Roboflow datasets
169
+ # The job array index selects which dataset from all_roboflow_supercategories
170
+ python sam3/train/train.py -c configs/roboflow_v100/roboflow_v100_full_ft_100_images.yaml \
171
+ --use-cluster 1
172
+ ```
173
+
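+ For intuition, the dataset selection amounts to indexing the supercategory list with the
+ array task index. The snippet below is illustrative only (the list is truncated and the
+ actual selection logic lives in the training config and launcher):
+
+ ```python
+ # Illustrative: map a job-array task index to a Roboflow supercategory.
+ all_roboflow_supercategories = ["13-lkc01", "2024-frc", "actions"]  # truncated; the config lists all 100
+ task_index = 2  # corresponds to submitit.job_array.task_index for this array job
+ dataset_name = all_roboflow_supercategories[task_index % len(all_roboflow_supercategories)]
+ print(dataset_name)  # -> "actions"
+ ```
+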
174
+ ### Reproduce ODinW13 10-shot results
175
+ Running the following job gives the results for ODinW13 with seed 300 (see `odinw_train.train_file: fewshot_train_shot10_seed300` in the config file).
176
+ ```bash
177
+ # Example: Train on ODinW13 dataset
178
+ python sam3/train/train.py -c configs/odinw13/odinw_text_only_train.yaml
179
+ ```
180
+ Change `odinw_train.train_file` to `fewshot_train_shot10_seed30` or `fewshot_train_shot10_seed3` to get the results for the other two seeds; the final results are aggregated over the three seeds. Note that a small number of jobs may diverge during training; in that case, use the result from the last checkpoint before divergence.
181
+
182
+
183
+ ### Eval Script Usage
184
+ With a setup similar to the training config, the training script `sam3/train.py` can also be used for evaluation by setting `trainer.mode = val` in the job config. Running the following jobs produces zero-shot results on the RF100-VL and ODinW13 datasets.
185
+ ```bash
186
+ # Example: Evaluate on Roboflow dataset
187
+ python sam3/train/train.py -c configs/roboflow_v100/roboflow_v100_eval.yaml
188
+ # Example: Evaluate on ODinW13 dataset
189
+ python sam3/train/train.py -c configs/odinw13/odinw_text_only.yaml
190
+ ```
detect_tools/sam3/assets/init.py ADDED
File without changes
detect_tools/sam3/pyproject.toml ADDED
@@ -0,0 +1,131 @@
1
+ [build-system]
2
+ requires = ["setuptools>=61", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "sam3"
7
+ dynamic = ["version"]
8
+ description = "SAM3 (Segment Anything Model 3) implementation"
9
+ readme = "README.md"
10
+ requires-python = ">=3.8"
11
+ license = {file = "LICENSE"}
12
+ authors = [
13
+ {name = "Meta AI Research"}
14
+ ]
15
+ classifiers = [
16
+ "Development Status :: 4 - Beta",
17
+ "Intended Audience :: Science/Research",
18
+ "License :: OSI Approved :: MIT License",
19
+ "Programming Language :: Python :: 3",
20
+ "Programming Language :: Python :: 3.8",
21
+ "Programming Language :: Python :: 3.9",
22
+ "Programming Language :: Python :: 3.10",
23
+ "Programming Language :: Python :: 3.11",
24
+ "Programming Language :: Python :: 3.12",
25
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
26
+ ]
27
+ dependencies = [
28
+ "timm>=1.0.17",
29
+ "numpy==1.26",
30
+ "tqdm",
31
+ "ftfy==6.1.1",
32
+ "regex",
33
+ "iopath>=0.1.10",
34
+ "typing_extensions",
35
+ "huggingface_hub",
36
+ ]
37
+
38
+ [project.optional-dependencies]
39
+ dev = [
40
+ "pytest",
41
+ "pytest-cov",
42
+ "black==24.2.0",
43
+ "ufmt==2.8.0",
44
+ "ruff-api==0.1.0",
45
+ "usort==1.0.2",
46
+ "gitpython==3.1.31",
47
+ "yt-dlp",
48
+ "pandas",
49
+ "opencv-python",
50
+ "pycocotools",
51
+ "numba",
52
+ "python-rapidjson",
53
+ ]
54
+ notebooks = [
55
+ "matplotlib",
56
+ "jupyter",
57
+ "notebook",
58
+ "ipywidgets",
59
+ "ipycanvas",
60
+ "ipympl",
61
+ "pycocotools",
62
+ "decord",
63
+ "opencv-python",
64
+ "einops",
65
+ "scikit-image",
66
+ "scikit-learn",
67
+ ]
68
+ train = [
69
+ "hydra-core",
70
+ "submitit",
71
+ "tensorboard",
72
+ "zstandard",
73
+ "scipy",
74
+ "torchmetrics",
75
+ "fvcore",
76
+ "fairscale",
77
+ "scikit-image",
78
+ "scikit-learn",
79
+ ]
80
+
81
+ [project.urls]
82
+ "Homepage" = "https://github.com/facebookresearch/sam3"
83
+ "Bug Tracker" = "https://github.com/facebookresearch/sam3/issues"
84
+
85
+ [tool.setuptools]
86
+ packages = ["sam3", "sam3.model"]
87
+
88
+ [tool.setuptools.dynamic]
89
+ version = {attr = "sam3.__version__"}
90
+
91
+ [tool.black]
92
+ line-length = 88
93
+ target-version = ['py38', 'py39', 'py310', 'py311', 'py312']
94
+ include = '\.pyi?$'
95
+
96
+ [tool.isort]
97
+ profile = "black"
98
+ multi_line_output = 3
99
+
100
+ [tool.usort]
101
+ first_party_detection = false
102
+
103
+ [tool.ufmt]
104
+ formatter = "ruff-api"
105
+
106
+ [tool.mypy]
107
+ python_version = "3.12"
108
+ warn_return_any = true
109
+ warn_unused_configs = true
110
+ disallow_untyped_defs = true
111
+ disallow_incomplete_defs = true
112
+
113
+ [[tool.mypy.overrides]]
114
+ module = [
115
+ "torch.*",
116
+ "torchvision.*",
117
+ "timm.*",
118
+ "numpy.*",
119
+ "PIL.*",
120
+ "tqdm.*",
121
+ "ftfy.*",
122
+ "regex.*",
123
+ "iopath.*",
124
+ ]
125
+ ignore_missing_imports = true
126
+
127
+ [tool.pytest.ini_options]
128
+ testpaths = ["tests"]
129
+ python_files = "test_*.py"
130
+ python_classes = "Test*"
131
+ python_functions = "test_*"
detect_tools/sam3/sam3/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ from .model_builder import build_sam3_image_model
4
+
5
+ __version__ = "0.1.0"
6
+
7
+ __all__ = ["build_sam3_image_model"]
detect_tools/sam3/sam3/logger.py ADDED
@@ -0,0 +1,54 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+ import logging
3
+ import os
4
+
5
+ LOG_LEVELS = {
6
+ "DEBUG": logging.DEBUG,
7
+ "INFO": logging.INFO,
8
+ "WARNING": logging.WARNING,
9
+ "ERROR": logging.ERROR,
10
+ "CRITICAL": logging.CRITICAL,
11
+ }
12
+
13
+
14
+ class ColoredFormatter(logging.Formatter):
15
+ """A command line formatter with different colors for each level."""
16
+
17
+ def __init__(self):
18
+ super().__init__()
19
+ reset = "\033[0m"
20
+ colors = {
21
+ logging.DEBUG: f"{reset}\033[36m", # cyan,
22
+ logging.INFO: f"{reset}\033[32m", # green
23
+ logging.WARNING: f"{reset}\033[33m", # yellow
24
+ logging.ERROR: f"{reset}\033[31m", # red
25
+ logging.CRITICAL: f"{reset}\033[35m", # magenta
26
+ }
27
+ fmt_str = "{color}%(levelname)s %(asctime)s %(process)d %(filename)s:%(lineno)4d:{reset} %(message)s"
28
+ self.formatters = {
29
+ level: logging.Formatter(fmt_str.format(color=color, reset=reset))
30
+ for level, color in colors.items()
31
+ }
32
+ self.default_formatter = self.formatters[logging.INFO]
33
+
34
+ def format(self, record):
35
+ formatter = self.formatters.get(record.levelno, self.default_formatter)
36
+ return formatter.format(record)
37
+
38
+
39
+ def get_logger(name, level=logging.INFO):
40
+ """A command line logger."""
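+ # Example (illustrative): logger = get_logger(__name__); logger.info("hello")
+ # The level defaults to INFO and can be overridden via the LOG_LEVEL environment variable.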
41
+ if "LOG_LEVEL" in os.environ:
42
+ level = os.environ["LOG_LEVEL"].upper()
43
+ assert (
44
+ level in LOG_LEVELS
45
+ ), f"Invalid LOG_LEVEL: {level}, must be one of {list(LOG_LEVELS.keys())}"
46
+ level = LOG_LEVELS[level]
47
+ logger = logging.getLogger(name)
48
+ logger.setLevel(level)
49
+ logger.propagate = False
50
+ ch = logging.StreamHandler()
51
+ ch.setLevel(level)
52
+ ch.setFormatter(ColoredFormatter())
53
+ logger.addHandler(ch)
54
+ return logger
detect_tools/sam3/sam3/model/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
detect_tools/sam3/sam3/model/act_ckpt_utils.py ADDED
@@ -0,0 +1,114 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ import inspect
4
+ from functools import wraps
5
+ from typing import Callable, TypeVar, Union
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.utils.checkpoint as checkpoint
10
+ from torch.utils._pytree import tree_map_only
11
+
12
+ # Type variables for better type hinting
13
+ T = TypeVar("T")
14
+ Module = TypeVar("Module", bound=nn.Module)
15
+
16
+
17
+ def activation_ckpt_wrapper(module: Union[nn.Module, Callable]) -> Callable:
18
+ """
19
+ Wraps a given module to enable or disable activation checkpointing.
20
+
21
+ Activation checkpointing (gradient checkpointing) trades compute for memory by
22
+ recomputing intermediate activations during the backward pass instead of storing
23
+ them in memory during the forward pass.
24
+
25
+ When activation checkpointing is enabled, the wrapper expects only keyword arguments,
26
+ and it maps these to positional arguments based on the module's signature.
27
+
28
+ Args:
29
+ module: The module or function to wrap with activation checkpointing
30
+
31
+ Returns:
32
+ A wrapped callable that supports activation checkpointing
33
+
34
+ Usage:
35
+ The returned wrapper function can be called with the same arguments as the
36
+ original module, with an additional `act_ckpt_enable` keyword argument to control
37
+ activation checkpointing and optional `use_reentrant` parameter.
38
+
39
+ Example:
40
+ ```python
41
+ wrapped_module = activation_ckpt_wrapper(my_module)
42
+ output = wrapped_module(x=input_tensor, y=another_tensor, act_ckpt_enable=True)
43
+ ```
44
+ """
45
+
46
+ @wraps(module)
47
+ def act_ckpt_wrapper(
48
+ *args, act_ckpt_enable: bool = True, use_reentrant: bool = False, **kwargs
49
+ ):
50
+ if act_ckpt_enable:
51
+ if len(args) > 0:
52
+ raise ValueError(
53
+ "This wrapper expects keyword arguments only when `act_ckpt_enable=True`"
54
+ )
55
+ # Get the signature of the target function/module
56
+ callable_fn = module.forward if isinstance(module, nn.Module) else module
57
+ sig = inspect.signature(callable_fn)
58
+ # Create a mapping of parameter names to their default values
59
+ param_defaults = {
60
+ name: param.default for name, param in sig.parameters.items()
61
+ }
62
+ args = []
63
+ for p_name in param_defaults.keys():
64
+ if p_name in kwargs:
65
+ args.append(kwargs.pop(p_name))
66
+ elif param_defaults[p_name] is not inspect.Parameter.empty:
67
+ # Set arg to default value if it's not in kwargs. Useful for primitive types or args that default to None
68
+ args.append(param_defaults[p_name])
69
+ elif (
70
+ sig.parameters[p_name].kind is not inspect.Parameter.VAR_KEYWORD
71
+ ): # Skip **kwargs parameter
72
+ raise ValueError(f"Missing positional argument: {p_name}")
73
+
74
+ # Scan remaining kwargs for torch.Tensor
75
+ remaining_keys = list(kwargs.keys())
76
+ for key in remaining_keys:
77
+ if isinstance(kwargs[key], torch.Tensor):
78
+ # Replace the tensor in kwargs with a sentinel string, assuming it's not actually required by the module.
79
+ # If it is required, the module's signature should be modified to accept it as a positional or keyword argument.
80
+ kwargs[key] = "_REMOVED_BY_ACT_CKPT_WRAPPER_"
81
+
82
+ ret = checkpoint.checkpoint(
83
+ module, *args, use_reentrant=use_reentrant, **kwargs
84
+ )
85
+ else:
86
+ ret = module(*args, **kwargs)
87
+
88
+ return ret
89
+
90
+ return act_ckpt_wrapper
91
+
92
+
93
+ def clone_output_wrapper(f: Callable[..., T]) -> Callable[..., T]:
94
+ """
95
+ Clone the CUDA output tensors of a function to avoid in-place operations.
96
+
97
+ This wrapper is useful when working with torch.compile to prevent errors
98
+ related to in-place operations on tensors.
99
+
100
+ Args:
101
+ f: The function whose CUDA tensor outputs should be cloned
102
+
103
+ Returns:
104
+ A wrapped function that clones any CUDA tensor outputs
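+
+ Example (illustrative):
+ safe_forward = clone_output_wrapper(model.forward)
+ out = safe_forward(x)  # CUDA tensor outputs are cloned; CPU tensors pass through unchanged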
105
+ """
106
+
107
+ @wraps(f)
108
+ def wrapped(*args, **kwargs):
109
+ outputs = f(*args, **kwargs)
110
+ return tree_map_only(
111
+ torch.Tensor, lambda t: t.clone() if t.is_cuda else t, outputs
112
+ )
113
+
114
+ return wrapped
detect_tools/sam3/sam3/model/box_ops.py ADDED
@@ -0,0 +1,217 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+ """
3
+ Utilities for bounding box manipulation and GIoU.
4
+ """
5
+
6
+ from typing import Tuple
7
+
8
+ import torch
9
+
10
+
11
+ def box_cxcywh_to_xyxy(x):
12
+ x_c, y_c, w, h = x.unbind(-1)
13
+ b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
14
+ return torch.stack(b, dim=-1)
15
+
16
+
17
+ def box_cxcywh_to_xywh(x):
18
+ x_c, y_c, w, h = x.unbind(-1)
19
+ b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (w), (h)]
20
+ return torch.stack(b, dim=-1)
21
+
22
+
23
+ def box_xywh_to_xyxy(x):
24
+ x, y, w, h = x.unbind(-1)
25
+ b = [(x), (y), (x + w), (y + h)]
26
+ return torch.stack(b, dim=-1)
27
+
28
+
29
+ def box_xywh_to_cxcywh(x):
30
+ x, y, w, h = x.unbind(-1)
31
+ b = [(x + 0.5 * w), (y + 0.5 * h), (w), (h)]
32
+ return torch.stack(b, dim=-1)
33
+
34
+
35
+ def box_xyxy_to_xywh(x):
36
+ x, y, X, Y = x.unbind(-1)
37
+ b = [(x), (y), (X - x), (Y - y)]
38
+ return torch.stack(b, dim=-1)
39
+
40
+
41
+ def box_xyxy_to_cxcywh(x):
42
+ x0, y0, x1, y1 = x.unbind(-1)
43
+ b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
44
+ return torch.stack(b, dim=-1)
45
+
46
+
47
+ def box_area(boxes):
48
+ """
49
+ Batched version of box area. Boxes should be in [x0, y0, x1, y1] format.
50
+
51
+ Inputs:
52
+ - boxes: Tensor of shape (..., 4)
53
+
54
+ Returns:
55
+ - areas: Tensor of shape (...,)
56
+ """
57
+ x0, y0, x1, y1 = boxes.unbind(-1)
58
+ return (x1 - x0) * (y1 - y0)
59
+
60
+
61
+ def masks_to_boxes(masks):
62
+ """Compute the bounding boxes around the provided masks
63
+
64
+ The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
65
+
66
+ Returns a [N, 4] tensors, with the boxes in xyxy format
67
+ """
68
+ if masks.numel() == 0:
69
+ return torch.zeros((0, 4), device=masks.device)
70
+
71
+ h, w = masks.shape[-2:]
72
+
73
+ y = torch.arange(0, h, dtype=torch.float, device=masks.device)
74
+ x = torch.arange(0, w, dtype=torch.float, device=masks.device)
75
+ y, x = torch.meshgrid(y, x)
76
+
77
+ x_mask = masks * x.unsqueeze(0)
78
+ x_max = x_mask.flatten(1).max(-1)[0] + 1
79
+ x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
80
+
81
+ y_mask = masks * y.unsqueeze(0)
82
+ y_max = y_mask.flatten(1).max(-1)[0] + 1
83
+ y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
84
+
85
+ boxes = torch.stack([x_min, y_min, x_max, y_max], 1)
86
+ # Invalidate boxes corresponding to empty masks.
87
+ boxes = boxes * masks.flatten(-2).any(-1)
88
+ return boxes
89
+
90
+
91
+ def box_iou(boxes1, boxes2):
92
+ """
93
+ Batched version of box_iou. Boxes should be in [x0, y0, x1, y1] format.
94
+
95
+ Inputs:
96
+ - boxes1: Tensor of shape (..., N, 4)
97
+ - boxes2: Tensor of shape (..., M, 4)
98
+
99
+ Returns:
100
+ - iou, union: Tensors of shape (..., N, M)
101
+ """
102
+ area1 = box_area(boxes1)
103
+ area2 = box_area(boxes2)
104
+
105
+ # boxes1: (..., N, 4) -> (..., N, 1, 2)
106
+ # boxes2: (..., M, 4) -> (..., 1, M, 2)
107
+ lt = torch.max(boxes1[..., :, None, :2], boxes2[..., None, :, :2])
108
+ rb = torch.min(boxes1[..., :, None, 2:], boxes2[..., None, :, 2:])
109
+
110
+ wh = (rb - lt).clamp(min=0) # (..., N, M, 2)
111
+ inter = wh[..., 0] * wh[..., 1] # (..., N, M)
112
+
113
+ union = area1[..., None] + area2[..., None, :] - inter
114
+
115
+ iou = inter / union
116
+ return iou, union
117
+
118
+
119
+ def generalized_box_iou(boxes1, boxes2):
120
+ """
121
+ Batched version of Generalized IoU from https://giou.stanford.edu/
122
+
123
+ Boxes should be in [x0, y0, x1, y1] format
124
+
125
+ Inputs:
126
+ - boxes1: Tensor of shape (..., N, 4)
127
+ - boxes2: Tensor of shape (..., M, 4)
128
+
129
+ Returns:
130
+ - giou: Tensor of shape (..., N, M)
131
+ """
132
+ iou, union = box_iou(boxes1, boxes2)
133
+
134
+ # boxes1: (..., N, 4) -> (..., N, 1, 2)
135
+ # boxes2: (..., M, 4) -> (..., 1, M, 2)
136
+ lt = torch.min(boxes1[..., :, None, :2], boxes2[..., None, :, :2])
137
+ rb = torch.max(boxes1[..., :, None, 2:], boxes2[..., None, :, 2:])
138
+
139
+ wh = (rb - lt).clamp(min=0) # (..., N, M, 2)
140
+ area = wh[..., 0] * wh[..., 1] # (..., N, M)
141
+
142
+ return iou - (area - union) / area
143
+
144
+
145
+ @torch.jit.script
146
+ def fast_diag_generalized_box_iou(boxes1, boxes2):
147
+ assert len(boxes1) == len(boxes2)
148
+ box1_xy = boxes1[:, 2:]
149
+ box1_XY = boxes1[:, :2]
150
+ box2_xy = boxes2[:, 2:]
151
+ box2_XY = boxes2[:, :2]
152
+ # assert (box1_xy >= box1_XY).all()
153
+ # assert (box2_xy >= box2_XY).all()
154
+ area1 = (box1_xy - box1_XY).prod(-1)
155
+ area2 = (box2_xy - box2_XY).prod(-1)
156
+
157
+ lt = torch.max(box1_XY, box2_XY) # [N,2]
158
+ lt2 = torch.min(box1_XY, box2_XY)
159
+ rb = torch.min(box1_xy, box2_xy) # [N,2]
160
+ rb2 = torch.max(box1_xy, box2_xy)
161
+
162
+ inter = (rb - lt).clamp(min=0).prod(-1)
163
+ tot_area = (rb2 - lt2).clamp(min=0).prod(-1)
164
+
165
+ union = area1 + area2 - inter
166
+
167
+ iou = inter / union
168
+
169
+ return iou - (tot_area - union) / tot_area
170
+
171
+
172
+ @torch.jit.script
173
+ def fast_diag_box_iou(boxes1, boxes2):
174
+ assert len(boxes1) == len(boxes2)
175
+ box1_xy = boxes1[:, 2:]
176
+ box1_XY = boxes1[:, :2]
177
+ box2_xy = boxes2[:, 2:]
178
+ box2_XY = boxes2[:, :2]
179
+ # assert (box1_xy >= box1_XY).all()
180
+ # assert (box2_xy >= box2_XY).all()
181
+ area1 = (box1_xy - box1_XY).prod(-1)
182
+ area2 = (box2_xy - box2_XY).prod(-1)
183
+
184
+ lt = torch.max(box1_XY, box2_XY) # [N,2]
185
+ rb = torch.min(box1_xy, box2_xy) # [N,2]
186
+
187
+ inter = (rb - lt).clamp(min=0).prod(-1)
188
+
189
+ union = area1 + area2 - inter
190
+
191
+ iou = inter / union
192
+
193
+ return iou
194
+
195
+
196
+ def box_xywh_inter_union(
197
+ boxes1: torch.Tensor, boxes2: torch.Tensor
198
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
199
+ # Assumes boxes are in xywh format
200
+ assert boxes1.size(-1) == 4 and boxes2.size(-1) == 4
201
+ boxes1 = box_xywh_to_xyxy(boxes1)
202
+ boxes2 = box_xywh_to_xyxy(boxes2)
203
+ box1_tl_xy = boxes1[..., :2]
204
+ box1_br_xy = boxes1[..., 2:]
205
+ box2_tl_xy = boxes2[..., :2]
206
+ box2_br_xy = boxes2[..., 2:]
207
+ area1 = (box1_br_xy - box1_tl_xy).prod(-1)
208
+ area2 = (box2_br_xy - box2_tl_xy).prod(-1)
209
+
210
+ assert (area1 >= 0).all() and (area2 >= 0).all()
211
+ tl = torch.max(box1_tl_xy, box2_tl_xy)
212
+ br = torch.min(box1_br_xy, box2_br_xy)
213
+
214
+ inter = (br - tl).clamp(min=0).prod(-1)
215
+ union = area1 + area2 - inter
216
+
217
+ return inter, union
detect_tools/sam3/sam3/model/data_misc.py ADDED
@@ -0,0 +1,209 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+ """
3
+ Misc functions, including distributed helpers.
4
+ """
5
+
6
+ import collections
7
+ import re
8
+
9
+ from dataclasses import dataclass, field as field_ptr_behaviour, fields, is_dataclass
10
+ from typing import Any, get_args, get_origin, List, Mapping, Optional, Sequence, Union
11
+
12
+ import torch
13
+
14
+
15
+ MyTensor = Union[torch.Tensor, List[Any]]
16
+
17
+
18
+ def interpolate(
19
+ input, size=None, scale_factor=None, mode="nearest", align_corners=None
20
+ ):
21
+ # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
22
+ """
23
+ Equivalent to nn.functional.interpolate, but with support for empty channel sizes.
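+
+ For example (illustrative), an input of shape (2, 0, 8, 8) can still be resized to
+ (2, 0, 16, 16) here, whereas calling F.interpolate directly on an empty channel
+ dimension is not supported (hence the transpose workaround below).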
24
+ """
25
+ if input.numel() > 0:
26
+ return torch.nn.functional.interpolate(
27
+ input, size, scale_factor, mode, align_corners
28
+ )
29
+
30
+ assert (
31
+ input.shape[0] != 0 or input.shape[1] != 0
32
+ ), "At least one of the first two dimensions must be non-zero"
33
+
34
+ if input.shape[1] == 0:
35
+ # Pytorch doesn't support null dimension on the channel dimension, so we transpose to fake a null batch dim
36
+ return torch.nn.functional.interpolate(
37
+ input.transpose(0, 1), size, scale_factor, mode, align_corners
38
+ ).transpose(0, 1)
39
+
40
+ # empty batch dimension is now supported in pytorch
41
+ return torch.nn.functional.interpolate(
42
+ input, size, scale_factor, mode, align_corners
43
+ )
44
+
45
+
46
+ @dataclass
47
+ class BatchedPointer:
48
+ stage_ids: MyTensor
49
+ stage_ids__type = torch.long
50
+ query_ids: MyTensor
51
+ query_ids__type = torch.long
52
+ object_ids: MyTensor
53
+ object_ids__type = torch.long
54
+ ptr_mask: MyTensor
55
+ ptr_mask__type = torch.bool
56
+ ptr_types: MyTensor
57
+ ptr_types__type = torch.long
58
+
59
+
60
+ @dataclass
61
+ class FindStage:
62
+ img_ids: MyTensor
63
+ img_ids__type = torch.long
64
+ text_ids: MyTensor
65
+ text_ids__type = torch.long
66
+
67
+ input_boxes: MyTensor
68
+ input_boxes__type = torch.float
69
+ input_boxes_mask: MyTensor
70
+ input_boxes_mask__type = torch.bool
71
+ input_boxes_label: MyTensor
72
+ input_boxes_label__type = torch.long
73
+
74
+ input_points: MyTensor
75
+ input_points__type = torch.float
76
+ input_points_mask: MyTensor
77
+ input_points_mask__type = torch.bool
78
+
79
+ # We track the object ids referred to by this query.
80
+ # This is beneficial for tracking in videos without the need for pointers.
81
+ object_ids: Optional[List[List]] = None # List of objects per query
82
+
83
+
84
+ @dataclass
85
+ class BatchedFindTarget:
86
+ # The number of boxes in each find query
87
+ num_boxes: MyTensor
88
+ num_boxes__type = torch.long
89
+
90
+ # Target boxes in normalized CxCywh format
91
+ boxes: MyTensor
92
+ boxes__type = torch.float
93
+ # Target boxes in normalized CxCywh format but in padded representation
94
+ # as used in BinaryHungarianMatcherV2 (unlike the packed ones in `boxes`)
95
+ boxes_padded: MyTensor
96
+ boxes_padded__type = torch.float
97
+
98
+ # For hybrid matching, we repeat the boxes
99
+ repeated_boxes: MyTensor
100
+ repeated_boxes__type = torch.float
101
+
102
+ # Target Segmentation masks
103
+ segments: Optional[MyTensor]
104
+ segments__type = torch.bool
105
+
106
+ # Target Semantic Segmentation masks
107
+ semantic_segments: Optional[MyTensor]
108
+ semantic_segments__type = torch.bool
109
+
110
+ is_valid_segment: Optional[MyTensor]
111
+ is_valid_segment__type = torch.bool
112
+
113
+ # Whether annotations are exhaustive for each query
114
+ is_exhaustive: MyTensor
115
+ is_exhaustive__type = torch.bool
116
+
117
+ # The object id for each ground-truth box, in both packed and padded representations
118
+ object_ids: MyTensor
119
+ object_ids__type = torch.long
120
+ object_ids_padded: MyTensor
121
+ object_ids_padded__type = torch.long
122
+
123
+
124
+ @dataclass
125
+ class BatchedInferenceMetadata:
126
+ """All metadata required to post-process a find stage"""
127
+
128
+ # Coco id that corresponds to the "image" for evaluation by the coco evaluator
129
+ coco_image_id: MyTensor
130
+ coco_image_id__type = torch.long
131
+
132
+ # id in the original dataset, such that we can use the original evaluator
133
+ original_image_id: MyTensor
134
+ original_image_id__type = torch.long
135
+
136
+ # Original category id (if we want to use the original evaluator)
137
+ original_category_id: MyTensor
138
+ original_category_id__type = torch.int
139
+
140
+ # Size of the raw image (height, width)
141
+ original_size: MyTensor
142
+ original_size__type = torch.long
143
+
144
+ # id of the object in the media (track_id for a video)
145
+ object_id: MyTensor
146
+ object_id__type = torch.long
147
+
148
+ # index of the frame in the media (0 in the case of a single-frame media)
149
+ frame_index: MyTensor
150
+ frame_index__type = torch.long
151
+
152
+ # Adding for relations inference
153
+ # get_text_input: List[Optional[str]]
154
+
155
+ # Adding for TA conditional inference
156
+ is_conditioning_only: List[Optional[bool]]
157
+
158
+
159
+ @dataclass
160
+ class BatchedDatapoint:
161
+ img_batch: torch.Tensor
162
+ find_text_batch: List[str]
163
+ find_inputs: List[FindStage]
164
+ find_targets: List[BatchedFindTarget]
165
+ find_metadatas: List[BatchedInferenceMetadata]
166
+ raw_images: Optional[List[Any]] = None
167
+
168
+
169
+ def convert_my_tensors(obj):
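+ """Convert every `MyTensor` (list-typed) field of a dataclass into a torch tensor,
+ recursing into nested dataclass fields. The target dtype is read from the matching
+ `<field>__type` class attribute; lists of tensors are stacked, and other values go
+ through `torch.as_tensor`."""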
170
+ def is_optional_field(field) -> bool:
171
+ return get_origin(field) is Union and type(None) in get_args(field)
172
+
173
+ for field in fields(obj):
174
+ if is_dataclass(getattr(obj, field.name)):
175
+ convert_my_tensors(getattr(obj, field.name))
176
+ continue
177
+
178
+ field_type = field.type
179
+ if is_optional_field(field.type):
180
+ field_type = Union[get_args(field.type)[:-1]] # Get the Optional field type
181
+
182
+ if field_type != MyTensor or getattr(obj, field.name) is None:
183
+ continue
184
+
185
+ elif len(getattr(obj, field.name)) and isinstance(
186
+ getattr(obj, field.name)[0], torch.Tensor
187
+ ):
188
+ stack_dim = 0
189
+ if field.name in [
190
+ "input_boxes",
191
+ "input_boxes_label",
192
+ ]:
193
+ stack_dim = 1
194
+ setattr(
195
+ obj,
196
+ field.name,
197
+ torch.stack(getattr(obj, field.name), dim=stack_dim).to(
198
+ getattr(obj, field.name + "__type")
199
+ ),
200
+ )
201
+ else:
202
+ setattr(
203
+ obj,
204
+ field.name,
205
+ torch.as_tensor(
206
+ getattr(obj, field.name), dtype=getattr(obj, field.name + "__type")
207
+ ),
208
+ )
209
+ return obj
detect_tools/sam3/sam3/model/decoder.py ADDED
@@ -0,0 +1,956 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+ """
3
+ Transformer decoder.
4
+ Inspired from Pytorch's version, adds the pre-norm variant
5
+ """
6
+
7
+ from typing import Any, Dict, List, Optional
8
+
9
+ import numpy as np
10
+
11
+ import torch
12
+
13
+ from sam3.sam.transformer import RoPEAttention
14
+
15
+ from torch import nn, Tensor
16
+ from torchvision.ops.roi_align import RoIAlign
17
+
18
+ from .act_ckpt_utils import activation_ckpt_wrapper
19
+
20
+ from .box_ops import box_cxcywh_to_xyxy
21
+
22
+ from .model_misc import (
23
+ gen_sineembed_for_position,
24
+ get_activation_fn,
25
+ get_clones,
26
+ inverse_sigmoid,
27
+ MLP,
28
+ )
29
+
30
+
31
+ class TransformerDecoderLayer(nn.Module):
32
+ def __init__(
33
+ self,
34
+ activation: str,
35
+ d_model: int,
36
+ dim_feedforward: int,
37
+ dropout: float,
38
+ cross_attention: nn.Module,
39
+ n_heads: int,
40
+ use_text_cross_attention: bool = False,
41
+ ):
42
+ super().__init__()
43
+
44
+ # cross attention
45
+ self.cross_attn = cross_attention
46
+ self.dropout1 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
47
+ self.norm1 = nn.LayerNorm(d_model)
48
+
49
+ # cross attention text
50
+ self.use_text_cross_attention = use_text_cross_attention
51
+ if use_text_cross_attention:
52
+ self.ca_text = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
53
+ self.catext_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
54
+ self.catext_norm = nn.LayerNorm(d_model)
55
+
56
+ # self attention
57
+ self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
58
+ self.dropout2 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
59
+ self.norm2 = nn.LayerNorm(d_model)
60
+
61
+ # ffn
62
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
63
+ self.activation = get_activation_fn(activation)
64
+ self.dropout3 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
65
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
66
+ self.dropout4 = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
67
+ self.norm3 = nn.LayerNorm(d_model)
68
+
69
+ @staticmethod
70
+ def with_pos_embed(tensor, pos):
71
+ return tensor if pos is None else tensor + pos
72
+
73
+ def forward_ffn(self, tgt):
74
+ with torch.amp.autocast(device_type="cuda", enabled=False):
75
+ tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
76
+ tgt = tgt + self.dropout4(tgt2)
77
+ tgt = self.norm3(tgt)
78
+ return tgt
79
+
80
+ def forward(
81
+ self,
82
+ # for tgt
83
+ tgt: Optional[Tensor], # nq, bs, d_model
84
+ tgt_query_pos: Optional[Tensor] = None, # pos for query. MLP(Sine(pos))
85
+ tgt_query_sine_embed: Optional[Tensor] = None, # pos for query. Sine(pos)
86
+ tgt_key_padding_mask: Optional[Tensor] = None,
87
+ tgt_reference_points: Optional[Tensor] = None, # nq, bs, 4
88
+ memory_text: Optional[Tensor] = None, # num_token, bs, d_model
89
+ text_attention_mask: Optional[Tensor] = None, # bs, num_token
90
+ # for memory
91
+ memory: Optional[Tensor] = None, # hw, bs, d_model
92
+ memory_key_padding_mask: Optional[Tensor] = None,
93
+ memory_level_start_index: Optional[Tensor] = None, # num_levels
94
+ memory_spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2
95
+ memory_pos: Optional[Tensor] = None, # pos for memory
96
+ # sa
97
+ self_attn_mask: Optional[Tensor] = None, # mask used for self-attention
98
+ cross_attn_mask: Optional[Tensor] = None, # mask used for cross-attention
99
+ # dac
100
+ dac=False,
101
+ dac_use_selfatt_ln=True,
102
+ presence_token=None,
103
+ # skip inside deformable attn
104
+ identity=0.0,
105
+ **kwargs, # additional kwargs for compatibility
106
+ ):
107
+ """
108
+ Input:
109
+ - tgt/tgt_query_pos: nq, bs, d_model
110
+ -
111
+ """
112
+ # self attention
113
+ if self.self_attn is not None:
114
+ if dac:
115
+ # we only apply self attention to the first half of the queries
116
+ assert tgt.shape[0] % 2 == 0
117
+ num_o2o_queries = tgt.shape[0] // 2
118
+ tgt_o2o = tgt[:num_o2o_queries]
119
+ tgt_query_pos_o2o = tgt_query_pos[:num_o2o_queries]
120
+ tgt_o2m = tgt[num_o2o_queries:]
121
+ else:
122
+ tgt_o2o = tgt
123
+ tgt_query_pos_o2o = tgt_query_pos
124
+
125
+ if presence_token is not None:
126
+ tgt_o2o = torch.cat([presence_token, tgt_o2o], dim=0)
127
+ tgt_query_pos_o2o = torch.cat(
128
+ [torch.zeros_like(presence_token), tgt_query_pos_o2o], dim=0
129
+ )
130
+ tgt_query_pos = torch.cat(
131
+ [torch.zeros_like(presence_token), tgt_query_pos], dim=0
132
+ )
133
+
134
+ q = k = self.with_pos_embed(tgt_o2o, tgt_query_pos_o2o)
135
+ tgt2 = self.self_attn(q, k, tgt_o2o, attn_mask=self_attn_mask)[0]
136
+ tgt_o2o = tgt_o2o + self.dropout2(tgt2)
137
+ if dac:
138
+ if not dac_use_selfatt_ln:
139
+ tgt_o2o = self.norm2(tgt_o2o)
140
+ tgt = torch.cat((tgt_o2o, tgt_o2m), dim=0) # Recombine
141
+ if dac_use_selfatt_ln:
142
+ tgt = self.norm2(tgt)
143
+ else:
144
+ tgt = tgt_o2o
145
+ tgt = self.norm2(tgt)
146
+
147
+ if self.use_text_cross_attention:
148
+ tgt2 = self.ca_text(
149
+ self.with_pos_embed(tgt, tgt_query_pos),
150
+ memory_text,
151
+ memory_text,
152
+ key_padding_mask=text_attention_mask,
153
+ )[0]
154
+ tgt = tgt + self.catext_dropout(tgt2)
155
+ tgt = self.catext_norm(tgt)
156
+
157
+ if presence_token is not None:
158
+ presence_token_mask = torch.zeros_like(cross_attn_mask[:, :1, :])
159
+ cross_attn_mask = torch.cat(
160
+ [presence_token_mask, cross_attn_mask], dim=1
161
+ ) # (bs*nheads, 1+nq, hw)
162
+
163
+ # Cross attention to image
164
+ tgt2 = self.cross_attn(
165
+ query=self.with_pos_embed(tgt, tgt_query_pos),
166
+ key=self.with_pos_embed(memory, memory_pos),
167
+ value=memory,
168
+ attn_mask=cross_attn_mask,
169
+ key_padding_mask=(
170
+ memory_key_padding_mask.transpose(0, 1)
171
+ if memory_key_padding_mask is not None
172
+ else None
173
+ ),
174
+ )[0]
175
+
176
+ tgt = tgt + self.dropout1(tgt2)
177
+ tgt = self.norm1(tgt)
178
+
179
+ # ffn
180
+ tgt = self.forward_ffn(tgt)
181
+
182
+ presence_token_out = None
183
+ if presence_token is not None:
184
+ presence_token_out = tgt[:1]
185
+ tgt = tgt[1:]
186
+
187
+ return tgt, presence_token_out
188
+
189
+
190
+ class TransformerDecoder(nn.Module):
191
+ def __init__(
192
+ self,
193
+ d_model: int,
194
+ frozen: bool,
195
+ interaction_layer,
196
+ layer,
197
+ num_layers: int,
198
+ num_queries: int,
199
+ return_intermediate: bool,
200
+ box_refine: bool = False,
201
+ num_o2m_queries: int = 0,
202
+ dac: bool = False,
203
+ boxRPB: str = "none",
204
+ # Experimental: An object query for SAM 2 tasks
205
+ instance_query: bool = False,
206
+ # Defines the number of additional instance queries,
207
+ # 1 or 4 are the most likely for single vs multi mask support
208
+ num_instances: int = 1, # Irrelevant if instance_query is False
209
+ dac_use_selfatt_ln: bool = True,
210
+ use_act_checkpoint: bool = False,
211
+ compile_mode=None,
212
+ presence_token: bool = False,
213
+ clamp_presence_logits: bool = True,
214
+ clamp_presence_logit_max_val: float = 10.0,
215
+ use_normed_output_consistently: bool = True,
216
+ separate_box_head_instance: bool = False,
217
+ separate_norm_instance: bool = False,
218
+ resolution: Optional[int] = None,
219
+ stride: Optional[int] = None,
220
+ ):
221
+ super().__init__()
222
+ self.d_model = d_model
223
+ self.layers = get_clones(layer, num_layers)
224
+ self.fine_layers = (
225
+ get_clones(interaction_layer, num_layers)
226
+ if interaction_layer is not None
227
+ else [None] * num_layers
228
+ )
229
+ self.num_layers = num_layers
230
+ self.num_queries = num_queries
231
+ self.dac = dac
232
+ if dac:
233
+ self.num_o2m_queries = num_queries
234
+ tot_num_queries = num_queries
235
+ else:
236
+ self.num_o2m_queries = num_o2m_queries
237
+ tot_num_queries = num_queries + num_o2m_queries
238
+ self.norm = nn.LayerNorm(d_model)
239
+ self.return_intermediate = return_intermediate
240
+ self.bbox_embed = MLP(d_model, d_model, 4, 3)
241
+ self.query_embed = nn.Embedding(tot_num_queries, d_model)
242
+ self.instance_query_embed = None
243
+ self.instance_query_reference_points = None
244
+ self.use_instance_query = instance_query
245
+ self.num_instances = num_instances
246
+ self.use_normed_output_consistently = use_normed_output_consistently
247
+
248
+ self.instance_norm = nn.LayerNorm(d_model) if separate_norm_instance else None
249
+ self.instance_bbox_embed = None
250
+ if separate_box_head_instance:
251
+ self.instance_bbox_embed = MLP(d_model, d_model, 4, 3)
252
+ if instance_query:
253
+ self.instance_query_embed = nn.Embedding(num_instances, d_model)
254
+ self.box_refine = box_refine
255
+ if box_refine:
256
+ nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
257
+ nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
258
+
259
+ self.reference_points = nn.Embedding(num_queries, 4)
260
+ if instance_query:
261
+ self.instance_reference_points = nn.Embedding(num_instances, 4)
262
+
263
+ assert boxRPB in ["none", "log", "linear", "both"]
264
+ self.boxRPB = boxRPB
265
+ if boxRPB != "none":
266
+ try:
267
+ nheads = self.layers[0].cross_attn_image.num_heads
268
+ except AttributeError:
269
+ nheads = self.layers[0].cross_attn.num_heads
270
+
271
+ n_input = 4 if boxRPB == "both" else 2
272
+ self.boxRPB_embed_x = MLP(n_input, d_model, nheads, 2)
273
+ self.boxRPB_embed_y = MLP(n_input, d_model, nheads, 2)
274
+ self.compilable_cord_cache = None
275
+ self.compilable_stored_size = None
276
+ self.coord_cache = {}
277
+
278
+ if resolution is not None and stride is not None:
279
+ feat_size = resolution // stride
280
+ coords_h, coords_w = self._get_coords(
281
+ feat_size, feat_size, device="cuda"
282
+ )
283
+ self.compilable_cord_cache = (coords_h, coords_w)
284
+ self.compilable_stored_size = (feat_size, feat_size)
285
+
286
+ self.roi_pooler = (
287
+ RoIAlign(output_size=7, spatial_scale=1, sampling_ratio=-1, aligned=True)
288
+ if interaction_layer is not None
289
+ else None
290
+ )
291
+ if frozen:
292
+ for p in self.parameters():
293
+ p.requires_grad_(False)
294
+
295
+ self.presence_token = None
296
+ self.clamp_presence_logits = clamp_presence_logits
297
+ self.clamp_presence_logit_max_val = clamp_presence_logit_max_val
298
+ if presence_token:
299
+ self.presence_token = nn.Embedding(1, d_model)
300
+ self.presence_token_head = MLP(d_model, d_model, 1, 3)
301
+ self.presence_token_out_norm = nn.LayerNorm(d_model)
302
+
303
+ self.ref_point_head = MLP(2 * self.d_model, self.d_model, self.d_model, 2)
304
+ self.dac_use_selfatt_ln = dac_use_selfatt_ln
305
+ self.use_act_checkpoint = use_act_checkpoint
306
+
307
+ nn.init.normal_(self.query_embed.weight.data)
308
+ if self.instance_query_embed is not None:
309
+ nn.init.normal_(self.instance_query_embed.weight.data)
310
+
311
+ assert self.roi_pooler is None
312
+ assert self.return_intermediate, "support return_intermediate only"
313
+ assert self.box_refine, "support box refine only"
314
+
315
+ self.compile_mode = compile_mode
316
+ self.compiled = False
317
+ # We defer compilation till after the first forward, to first warm-up the boxRPB cache
318
+
319
+ # assign layer index to each layer so that some layers can decide what to do
320
+ # based on which layer index they are (e.g. cross attention to memory bank only
321
+ # in selected layers)
322
+ for layer_idx, layer in enumerate(self.layers):
323
+ layer.layer_idx = layer_idx
324
+
325
+ @staticmethod
326
+ def _get_coords(H, W, device):
327
+ coords_h = torch.arange(0, H, device=device, dtype=torch.float32) / H
328
+ coords_w = torch.arange(0, W, device=device, dtype=torch.float32) / W
329
+ return coords_h, coords_w
330
+
331
+ def _get_rpb_matrix(self, reference_boxes, feat_size):
332
+ H, W = feat_size
333
+ boxes_xyxy = box_cxcywh_to_xyxy(reference_boxes).transpose(0, 1)
334
+ bs, num_queries, _ = boxes_xyxy.shape
335
+ if self.compilable_cord_cache is None:
336
+ self.compilable_cord_cache = self._get_coords(H, W, reference_boxes.device)
337
+ self.compilable_stored_size = (H, W)
338
+
339
+ if torch.compiler.is_dynamo_compiling() or self.compilable_stored_size == (
340
+ H,
341
+ W,
342
+ ):
343
+ # good, hitting the cache, will be compilable
344
+ coords_h, coords_w = self.compilable_cord_cache
345
+ else:
346
+ # cache miss, will create compilation issue
347
+ # In case we're not compiling, we'll still rely on the dict-based cache
348
+ if feat_size not in self.coord_cache:
349
+ self.coord_cache[feat_size] = self._get_coords(
350
+ H, W, reference_boxes.device
351
+ )
352
+ coords_h, coords_w = self.coord_cache[feat_size]
353
+
354
+ assert coords_h.shape == (H,)
355
+ assert coords_w.shape == (W,)
356
+
357
+ deltas_y = coords_h.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 1:4:2]
358
+ deltas_y = deltas_y.view(bs, num_queries, -1, 2)
359
+ deltas_x = coords_w.view(1, -1, 1) - boxes_xyxy.reshape(-1, 1, 4)[:, :, 0:3:2]
360
+ deltas_x = deltas_x.view(bs, num_queries, -1, 2)
361
+
362
+ if self.boxRPB in ["log", "both"]:
363
+ deltas_x_log = deltas_x * 8 # normalize to -8, 8
364
+ deltas_x_log = (
365
+ torch.sign(deltas_x_log)
366
+ * torch.log2(torch.abs(deltas_x_log) + 1.0)
367
+ / np.log2(8)
368
+ )
369
+
370
+ deltas_y_log = deltas_y * 8 # normalize to -8, 8
371
+ deltas_y_log = (
372
+ torch.sign(deltas_y_log)
373
+ * torch.log2(torch.abs(deltas_y_log) + 1.0)
374
+ / np.log2(8)
375
+ )
376
+ if self.boxRPB == "log":
377
+ deltas_x = deltas_x_log
378
+ deltas_y = deltas_y_log
379
+ else:
380
+ deltas_x = torch.cat([deltas_x, deltas_x_log], dim=-1)
381
+ deltas_y = torch.cat([deltas_y, deltas_y_log], dim=-1)
382
+
383
+ if self.training:
384
+ assert self.use_act_checkpoint, "activation ckpt not enabled in decoder"
385
+ deltas_x = activation_ckpt_wrapper(self.boxRPB_embed_x)(
386
+ x=deltas_x,
387
+ act_ckpt_enable=self.training and self.use_act_checkpoint,
388
+ ) # bs, num_queries, W, n_heads
389
+ deltas_y = activation_ckpt_wrapper(self.boxRPB_embed_y)(
390
+ x=deltas_y,
391
+ act_ckpt_enable=self.training and self.use_act_checkpoint,
392
+ ) # bs, num_queries, H, n_heads
393
+
394
+ if not torch.compiler.is_dynamo_compiling():
395
+ assert deltas_x.shape[:3] == (bs, num_queries, W)
396
+ assert deltas_y.shape[:3] == (bs, num_queries, H)
397
+
398
+ B = deltas_y.unsqueeze(3) + deltas_x.unsqueeze(
399
+ 2
400
+ ) # bs, num_queries, H, W, n_heads
401
+ if not torch.compiler.is_dynamo_compiling():
402
+ assert B.shape[:4] == (bs, num_queries, H, W)
403
+ B = B.flatten(2, 3) # bs, num_queries, H*W, n_heads
404
+ B = B.permute(0, 3, 1, 2) # bs, n_heads, num_queries, H*W
405
+ B = B.contiguous() # memory-efficient attention prefers contiguous, well-ordered strides
406
+ if not torch.compiler.is_dynamo_compiling():
407
+ assert B.shape[2:] == (num_queries, H * W)
408
+ return B
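For reference, the signed-log scaling applied to the box deltas above can be summarized by the small standalone sketch below (illustrative only; signed_log_scale is a hypothetical helper, not part of this module):

import torch

def signed_log_scale(delta: torch.Tensor) -> torch.Tensor:
    # Map a normalized coordinate delta (roughly in [-1, 1]) onto a symmetric log scale,
    # compressing large offsets while keeping resolution near zero, as in the boxRPB "log" mode.
    d = delta * 8  # rescale to roughly [-8, 8]
    return torch.sign(d) * torch.log2(d.abs() + 1.0) / torch.log2(torch.tensor(8.0))

# signed_log_scale(torch.tensor([0.0, 0.1, 1.0])) ~= tensor([0.00, 0.28, 1.06])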
409
+
410
+ def forward(
411
+ self,
412
+ tgt,
413
+ memory,
414
+ tgt_mask: Optional[Tensor] = None,
415
+ memory_mask: Optional[Tensor] = None,
416
+ tgt_key_padding_mask: Optional[Tensor] = None,
417
+ memory_key_padding_mask: Optional[Tensor] = None,
418
+ pos: Optional[Tensor] = None,
419
+ reference_boxes: Optional[Tensor] = None, # num_queries, bs, 4
420
+ # for memory
421
+ level_start_index: Optional[Tensor] = None, # num_levels
422
+ spatial_shapes: Optional[Tensor] = None, # bs, num_levels, 2
423
+ valid_ratios: Optional[Tensor] = None,
424
+ # for text
425
+ memory_text: Optional[Tensor] = None,
426
+ text_attention_mask: Optional[Tensor] = None,
427
+ # if `apply_dac` is None, it will default to `self.dac`
428
+ apply_dac: Optional[bool] = None,
429
+ is_instance_prompt=False,
430
+ decoder_extra_kwargs: Optional[Dict] = None,
431
+ # ROI memory bank
432
+ obj_roi_memory_feat=None,
433
+ obj_roi_memory_mask=None,
434
+ box_head_trk=None,
435
+ ):
436
+ """
437
+ Input:
438
+ - tgt: nq, bs, d_model
439
+ - memory: \\sum{hw}, bs, d_model
440
+ - pos: \\sum{hw}, bs, d_model
441
+ - reference_boxes: nq, bs, 4 (after sigmoid)
442
+ - valid_ratios/spatial_shapes: bs, nlevel, 2
443
+ """
444
+ if memory_mask is not None:
445
+ assert (
446
+ self.boxRPB == "none"
447
+ ), "inputting a memory_mask in the presence of boxRPB is unexpected/not implemented"
448
+
449
+ apply_dac = apply_dac if apply_dac is not None else self.dac
450
+ if apply_dac:
451
+ assert (tgt.shape[0] == self.num_queries) or (
452
+ self.use_instance_query
453
+ and (tgt.shape[0] == self.instance_query_embed.num_embeddings)
454
+ )
455
+
456
+ tgt = tgt.repeat(2, 1, 1)
457
+ # note that we don't tile tgt_mask, since DAC doesn't
458
+ # use self-attention on the one-to-many (o2m) queries
459
+ if reference_boxes is not None:
460
+ assert (reference_boxes.shape[0] == self.num_queries) or (
461
+ self.use_instance_query
462
+ and (
463
+ reference_boxes.shape[0]
464
+ == self.instance_query_embed.num_embeddings
465
+ )
466
+ )
467
+ reference_boxes = reference_boxes.repeat(2, 1, 1)
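As a rough standalone illustration of the DAC duplication above (hypothetical shapes, not part of this module):

import torch

nq, bs, d = 300, 2, 256                                  # assumed query count / batch size / width
tgt_dac = torch.randn(nq, bs, d).repeat(2, 1, 1)         # (2*nq, bs, d): o2o queries followed by o2m copies
ref_dac = torch.rand(nq, bs, 4).repeat(2, 1, 1)          # (2*nq, bs, 4)
o2o_queries, o2m_queries = tgt_dac[:nq], tgt_dac[nq:]    # only the o2o half goes through self-attention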
468
+
469
+ bs = tgt.shape[1]
470
+ intermediate = []
471
+ intermediate_presence_logits = []
472
+ presence_feats = None
473
+
474
+ if self.box_refine:
475
+ if reference_boxes is None:
476
+ # In this case, we're in a one-stage model, so we generate the reference boxes
477
+ reference_boxes = self.reference_points.weight.unsqueeze(1)
478
+ reference_boxes = (
479
+ reference_boxes.repeat(2, bs, 1)
480
+ if apply_dac
481
+ else reference_boxes.repeat(1, bs, 1)
482
+ )
483
+ reference_boxes = reference_boxes.sigmoid()
484
+ intermediate_ref_boxes = [reference_boxes]
485
+ else:
486
+ reference_boxes = None
487
+ intermediate_ref_boxes = None
488
+
489
+ output = tgt
490
+ presence_out = None
491
+ if self.presence_token is not None and is_instance_prompt is False:
492
+ # expand to batch dim
493
+ presence_out = self.presence_token.weight[None].expand(1, bs, -1)
494
+
495
+ box_head = self.bbox_embed
496
+ if is_instance_prompt and self.instance_bbox_embed is not None:
497
+ box_head = self.instance_bbox_embed
498
+
499
+ out_norm = self.norm
500
+ if is_instance_prompt and self.instance_norm is not None:
501
+ out_norm = self.instance_norm
502
+
503
+ for layer_idx, layer in enumerate(self.layers):
504
+ reference_points_input = (
505
+ reference_boxes[:, :, None]
506
+ * torch.cat([valid_ratios, valid_ratios], -1)[None, :]
507
+ ) # nq, bs, nlevel, 4
508
+
509
+ query_sine_embed = gen_sineembed_for_position(
510
+ reference_points_input[:, :, 0, :], self.d_model
511
+ ) # nq, bs, d_model*2
512
+
513
+ # conditional query
514
+ query_pos = self.ref_point_head(query_sine_embed) # nq, bs, d_model
515
+
516
+ if self.boxRPB != "none" and reference_boxes is not None:
517
+ assert (
518
+ spatial_shapes.shape[0] == 1
519
+ ), "only single scale support implemented"
520
+ memory_mask = self._get_rpb_matrix(
521
+ reference_boxes,
522
+ (spatial_shapes[0, 0], spatial_shapes[0, 1]),
523
+ )
524
+ memory_mask = memory_mask.flatten(0, 1) # (bs*n_heads, nq, H*W)
525
+ if self.training:
526
+ assert (
527
+ self.use_act_checkpoint
528
+ ), "Activation checkpointing not enabled in the decoder"
529
+ output, presence_out = activation_ckpt_wrapper(layer)(
530
+ tgt=output,
531
+ tgt_query_pos=query_pos,
532
+ tgt_query_sine_embed=query_sine_embed,
533
+ tgt_key_padding_mask=tgt_key_padding_mask,
534
+ tgt_reference_points=reference_points_input,
535
+ memory_text=memory_text,
536
+ text_attention_mask=text_attention_mask,
537
+ memory=memory,
538
+ memory_key_padding_mask=memory_key_padding_mask,
539
+ memory_level_start_index=level_start_index,
540
+ memory_spatial_shapes=spatial_shapes,
541
+ memory_pos=pos,
542
+ self_attn_mask=tgt_mask,
543
+ cross_attn_mask=memory_mask,
544
+ dac=apply_dac,
545
+ dac_use_selfatt_ln=self.dac_use_selfatt_ln,
546
+ presence_token=presence_out,
547
+ **(decoder_extra_kwargs or {}),
548
+ act_ckpt_enable=self.training and self.use_act_checkpoint,
549
+ # ROI memory bank
550
+ obj_roi_memory_feat=obj_roi_memory_feat,
551
+ obj_roi_memory_mask=obj_roi_memory_mask,
552
+ )
553
+
554
+ # iter update
555
+ if self.box_refine:
556
+ reference_before_sigmoid = inverse_sigmoid(reference_boxes)
557
+ if box_head_trk is None:
558
+ # delta_unsig = self.bbox_embed(output)
559
+ if not self.use_normed_output_consistently:
560
+ delta_unsig = box_head(output)
561
+ else:
562
+ delta_unsig = box_head(out_norm(output))
563
+ else:
564
+ # box_head_trk use a separate box head for tracking queries
565
+ Q_det = decoder_extra_kwargs["Q_det"]
566
+ assert output.size(0) >= Q_det
567
+ delta_unsig_det = self.bbox_embed(output[:Q_det])
568
+ delta_unsig_trk = box_head_trk(output[Q_det:])
569
+ delta_unsig = torch.cat([delta_unsig_det, delta_unsig_trk], dim=0)
570
+ outputs_unsig = delta_unsig + reference_before_sigmoid
571
+ new_reference_points = outputs_unsig.sigmoid()
572
+
573
+ reference_boxes = new_reference_points.detach()
574
+ if layer_idx != self.num_layers - 1:
575
+ intermediate_ref_boxes.append(new_reference_points)
576
+ else:
577
+ raise NotImplementedError("not implemented yet")
578
+
579
+ intermediate.append(out_norm(output))
580
+ if self.presence_token is not None and is_instance_prompt is False:
581
+ # norm, mlp head
582
+ intermediate_layer_presence_logits = self.presence_token_head(
583
+ self.presence_token_out_norm(presence_out)
584
+ ).squeeze(-1)
585
+
586
+ # clamp to mitigate numerical issues
587
+ if self.clamp_presence_logits:
588
+ intermediate_layer_presence_logits = intermediate_layer_presence_logits.clamp(
589
+ min=-self.clamp_presence_logit_max_val,
590
+ max=self.clamp_presence_logit_max_val,
591
+ )
592
+
593
+ intermediate_presence_logits.append(intermediate_layer_presence_logits)
594
+ presence_feats = presence_out.clone()
595
+
596
+ if not self.compiled and self.compile_mode is not None:
597
+ self.forward = torch.compile(
598
+ self.forward, mode=self.compile_mode, fullgraph=True
599
+ )
600
+ self.compiled = True
601
+
602
+ return (
603
+ torch.stack(intermediate),
604
+ torch.stack(intermediate_ref_boxes),
605
+ (
606
+ torch.stack(intermediate_presence_logits)
607
+ if self.presence_token is not None and is_instance_prompt is False
608
+ else None
609
+ ),
610
+ presence_feats,
611
+ )
612
+
613
+
614
+ class TransformerEncoderCrossAttention(nn.Module):
615
+ def __init__(
616
+ self,
617
+ d_model: int,
618
+ frozen: bool,
619
+ pos_enc_at_input: bool,
620
+ layer,
621
+ num_layers: int,
622
+ use_act_checkpoint: bool = False,
623
+ batch_first: bool = False, # Do layers expect batch first input?
624
+ # which layers to exclude cross attention? default: None, means all
625
+ # layers use cross attention
626
+ remove_cross_attention_layers: Optional[list] = None,
627
+ ):
628
+ super().__init__()
629
+ self.d_model = d_model
630
+ self.layers = get_clones(layer, num_layers)
631
+ self.num_layers = num_layers
632
+ self.norm = nn.LayerNorm(d_model)
633
+ self.pos_enc_at_input = pos_enc_at_input
634
+ self.use_act_checkpoint = use_act_checkpoint
635
+
636
+ if frozen:
637
+ for p in self.parameters():
638
+ p.requires_grad_(False)
639
+
640
+ self.batch_first = batch_first
641
+
642
+ # remove cross attention layers if specified
643
+ self.remove_cross_attention_layers = [False] * self.num_layers
644
+ if remove_cross_attention_layers is not None:
645
+ for i in remove_cross_attention_layers:
646
+ self.remove_cross_attention_layers[i] = True
647
+ assert len(self.remove_cross_attention_layers) == len(self.layers)
648
+
649
+ for i, remove_cross_attention in enumerate(self.remove_cross_attention_layers):
650
+ if remove_cross_attention:
651
+ self.layers[i].cross_attn_image = None
652
+ self.layers[i].norm2 = None
653
+ self.layers[i].dropout2 = None
654
+
655
+ def forward(
656
+ self,
657
+ src, # self-attention inputs
658
+ prompt, # cross-attention inputs
659
+ src_mask: Optional[Tensor] = None, # att.mask for self-attention inputs
660
+ prompt_mask: Optional[Tensor] = None, # att.mask for cross-attention inputs
661
+ src_key_padding_mask: Optional[Tensor] = None,
662
+ prompt_key_padding_mask: Optional[Tensor] = None,
663
+ src_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs
664
+ prompt_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs
665
+ feat_sizes: Optional[list] = None,
666
+ num_obj_ptr_tokens: int = 0, # number of object pointer *tokens*
667
+ ):
668
+ if isinstance(src, list):
669
+ assert isinstance(src_key_padding_mask, list) and isinstance(src_pos, list)
670
+ assert len(src) == len(src_key_padding_mask) == len(src_pos) == 1
671
+ src, src_key_padding_mask, src_pos = (
672
+ src[0],
673
+ src_key_padding_mask[0],
674
+ src_pos[0],
675
+ )
676
+
677
+ assert (
678
+ src.shape[1] == prompt.shape[1]
679
+ ), "Batch size must be the same for src and prompt"
680
+
681
+ output = src
682
+
683
+ if self.pos_enc_at_input and src_pos is not None:
684
+ output = output + 0.1 * src_pos
685
+
686
+ if self.batch_first:
687
+ # Convert to batch first
688
+ output = output.transpose(0, 1)
689
+ src_pos = src_pos.transpose(0, 1)
690
+ prompt = prompt.transpose(0, 1)
691
+ prompt_pos = prompt_pos.transpose(0, 1)
692
+
693
+ for layer in self.layers:
694
+ kwds = {}
695
+ if isinstance(layer.cross_attn_image, RoPEAttention):
696
+ kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}
697
+
698
+ output = activation_ckpt_wrapper(layer)(
699
+ tgt=output,
700
+ memory=prompt,
701
+ tgt_mask=src_mask,
702
+ memory_mask=prompt_mask,
703
+ tgt_key_padding_mask=src_key_padding_mask,
704
+ memory_key_padding_mask=prompt_key_padding_mask,
705
+ pos=prompt_pos,
706
+ query_pos=src_pos,
707
+ dac=False,
708
+ attn_bias=None,
709
+ act_ckpt_enable=self.training and self.use_act_checkpoint,
710
+ **kwds,
711
+ )
712
+ normed_output = self.norm(output)
713
+
714
+ if self.batch_first:
715
+ # Convert back to seq first
716
+ normed_output = normed_output.transpose(0, 1)
717
+ src_pos = src_pos.transpose(0, 1)
718
+
719
+ return {
720
+ "memory": normed_output,
721
+ "pos_embed": src_pos,
722
+ "padding_mask": src_key_padding_mask,
723
+ }
724
+
725
+
726
+ class TransformerDecoderLayerv1(nn.Module):
727
+ def __init__(
728
+ self,
729
+ activation: str,
730
+ cross_attention: nn.Module,
731
+ d_model: int,
732
+ dim_feedforward: int,
733
+ dropout: float,
734
+ pos_enc_at_attn: bool,
735
+ pos_enc_at_cross_attn_keys: bool,
736
+ pos_enc_at_cross_attn_queries: bool,
737
+ pre_norm: bool,
738
+ self_attention: nn.Module,
739
+ ):
740
+ super().__init__()
741
+ self.d_model = d_model
742
+ self.dim_feedforward = dim_feedforward
743
+ self.dropout_value = dropout
744
+ self.self_attn = self_attention
745
+ self.cross_attn_image = cross_attention
746
+
747
+ # Implementation of Feedforward model
748
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
749
+ self.dropout = nn.Dropout(dropout)
750
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
751
+
752
+ self.norm1 = nn.LayerNorm(d_model)
753
+ self.norm2 = nn.LayerNorm(d_model)
754
+ self.norm3 = nn.LayerNorm(d_model)
755
+ self.dropout1 = nn.Dropout(dropout)
756
+ self.dropout2 = nn.Dropout(dropout)
757
+ self.dropout3 = nn.Dropout(dropout)
758
+
759
+ self.activation_str = activation
760
+ self.activation = get_activation_fn(activation)
761
+ self.pre_norm = pre_norm
762
+
763
+ self.pos_enc_at_attn = pos_enc_at_attn
764
+ self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
765
+ self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys
766
+
767
+ def forward_post(
768
+ self,
769
+ tgt,
770
+ memory,
771
+ tgt_mask: Optional[Tensor] = None,
772
+ memory_mask: Optional[Tensor] = None,
773
+ tgt_key_padding_mask: Optional[Tensor] = None,
774
+ memory_key_padding_mask: Optional[Tensor] = None,
775
+ pos: Optional[Tensor] = None,
776
+ query_pos: Optional[Tensor] = None,
777
+ **kwargs,
778
+ ):
779
+ q = k = tgt + query_pos if self.pos_enc_at_attn else tgt
780
+
781
+ # Self attention
782
+ tgt2 = self.self_attn(
783
+ q,
784
+ k,
785
+ value=tgt,
786
+ attn_mask=tgt_mask,
787
+ key_padding_mask=tgt_key_padding_mask,
788
+ )[0]
789
+ tgt = tgt + self.dropout1(tgt2)
790
+ tgt = self.norm1(tgt)
791
+
792
+ # Cross attention to image
793
+ tgt2 = self.cross_attn_image(
794
+ query=tgt + query_pos if self.pos_enc_at_cross_attn_queries else tgt,
795
+ key=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
796
+ value=memory,
797
+ attn_mask=memory_mask,
798
+ key_padding_mask=memory_key_padding_mask,
799
+ )[0]
800
+ tgt = tgt + self.dropout2(tgt2)
801
+ tgt = self.norm2(tgt)
802
+
803
+ # FFN
804
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
805
+ tgt = tgt + self.dropout3(tgt2)
806
+ tgt = self.norm3(tgt)
807
+ return tgt
808
+
809
+ def forward_pre(
810
+ self,
811
+ tgt,
812
+ memory,
813
+ dac: bool = False,
814
+ tgt_mask: Optional[Tensor] = None,
815
+ memory_mask: Optional[Tensor] = None,
816
+ tgt_key_padding_mask: Optional[Tensor] = None,
817
+ memory_key_padding_mask: Optional[Tensor] = None,
818
+ pos: Optional[Tensor] = None,
819
+ query_pos: Optional[Tensor] = None,
820
+ attn_bias: Optional[Tensor] = None,
821
+ **kwargs,
822
+ ):
823
+ if dac:
824
+ # we only apply self attention to the first half of the queries
825
+ assert tgt.shape[0] % 2 == 0
826
+ other_tgt = tgt[tgt.shape[0] // 2 :]
827
+ tgt = tgt[: tgt.shape[0] // 2]
828
+ tgt2 = self.norm1(tgt)
829
+ q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
830
+ tgt2 = self.self_attn(
831
+ q,
832
+ k,
833
+ value=tgt2,
834
+ attn_mask=tgt_mask,
835
+ key_padding_mask=tgt_key_padding_mask,
836
+ )[0]
837
+ tgt = tgt + self.dropout1(tgt2)
838
+ if dac:
839
+ # Recombine
840
+ tgt = torch.cat((tgt, other_tgt), dim=0)
841
+ tgt2 = self.norm2(tgt)
842
+ tgt2 = self.cross_attn_image(
843
+ query=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
844
+ key=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
845
+ value=memory,
846
+ attn_mask=memory_mask,
847
+ key_padding_mask=memory_key_padding_mask,
848
+ attn_bias=attn_bias,
849
+ )[0]
850
+ tgt = tgt + self.dropout2(tgt2)
851
+ tgt2 = self.norm3(tgt)
852
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
853
+ tgt = tgt + self.dropout3(tgt2)
854
+ return tgt
855
+
856
+ def forward(
857
+ self,
858
+ tgt,
859
+ memory,
860
+ dac: bool = False,
861
+ tgt_mask: Optional[Tensor] = None,
862
+ memory_mask: Optional[Tensor] = None,
863
+ tgt_key_padding_mask: Optional[Tensor] = None,
864
+ memory_key_padding_mask: Optional[Tensor] = None,
865
+ pos: Optional[Tensor] = None,
866
+ query_pos: Optional[Tensor] = None,
867
+ attn_bias: Optional[Tensor] = None,
868
+ **kwds: Any,
869
+ ) -> torch.Tensor:
870
+ fwd_fn = self.forward_pre if self.pre_norm else self.forward_post
871
+ return fwd_fn(
872
+ tgt,
873
+ memory,
874
+ dac=dac,
875
+ tgt_mask=tgt_mask,
876
+ memory_mask=memory_mask,
877
+ tgt_key_padding_mask=tgt_key_padding_mask,
878
+ memory_key_padding_mask=memory_key_padding_mask,
879
+ pos=pos,
880
+ query_pos=query_pos,
881
+ attn_bias=attn_bias,
882
+ **kwds,
883
+ )
884
+
885
+
886
+ class TransformerDecoderLayerv2(TransformerDecoderLayerv1):
887
+ def __init__(self, cross_attention_first=False, *args: Any, **kwds: Any):
888
+ super().__init__(*args, **kwds)
889
+ self.cross_attention_first = cross_attention_first
890
+
891
+ def _forward_sa(self, tgt, query_pos):
892
+ # Self-Attention
893
+ tgt2 = self.norm1(tgt)
894
+ q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
895
+ tgt2 = self.self_attn(q, k, v=tgt2)
896
+ tgt = tgt + self.dropout1(tgt2)
897
+ return tgt
898
+
899
+ def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0):
900
+ if self.cross_attn_image is None:
901
+ return tgt
902
+
903
+ kwds = {}
904
+ if num_k_exclude_rope > 0:
905
+ assert isinstance(self.cross_attn_image, RoPEAttention)
906
+ kwds = {"num_k_exclude_rope": num_k_exclude_rope}
907
+
908
+ # Cross-Attention
909
+ tgt2 = self.norm2(tgt)
910
+ tgt2 = self.cross_attn_image(
911
+ q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
912
+ k=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
913
+ v=memory,
914
+ **kwds,
915
+ )
916
+ tgt = tgt + self.dropout2(tgt2)
917
+ return tgt
918
+
919
+ def forward_pre(
920
+ self,
921
+ tgt,
922
+ memory,
923
+ dac: bool,
924
+ tgt_mask: Optional[Tensor] = None,
925
+ memory_mask: Optional[Tensor] = None,
926
+ tgt_key_padding_mask: Optional[Tensor] = None,
927
+ memory_key_padding_mask: Optional[Tensor] = None,
928
+ pos: Optional[Tensor] = None,
929
+ query_pos: Optional[Tensor] = None,
930
+ attn_bias: Optional[Tensor] = None,
931
+ num_k_exclude_rope: int = 0,
932
+ ):
933
+ assert dac is False
934
+ assert tgt_mask is None
935
+ assert memory_mask is None
936
+ assert tgt_key_padding_mask is None
937
+ assert memory_key_padding_mask is None
938
+ assert attn_bias is None
939
+
940
+ if self.cross_attention_first:
941
+ tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
942
+ tgt = self._forward_sa(tgt, query_pos)
943
+ else:
944
+ tgt = self._forward_sa(tgt, query_pos)
945
+ tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
946
+
947
+ # MLP
948
+ tgt2 = self.norm3(tgt)
949
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
950
+ tgt = tgt + self.dropout3(tgt2)
951
+ return tgt
952
+
953
+ def forward(self, *args: Any, **kwds: Any) -> torch.Tensor:
954
+ if self.pre_norm:
955
+ return self.forward_pre(*args, **kwds)
956
+ raise NotImplementedError
detect_tools/sam3/sam3/model/edt.py ADDED
@@ -0,0 +1,173 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ """Triton kernel for euclidean distance transform (EDT)"""
4
+
5
+ import torch
6
+ import triton
7
+ import triton.language as tl
8
+
9
+ """
10
+ Disclaimer: This implementation is not meant to be extremely efficient. A CUDA kernel would likely be more efficient.
11
+ Even in Triton, there may be more suitable algorithms.
12
+
13
+ The goal of this kernel is to mimic cv2.distanceTransform(input, cv2.DIST_L2, 0).
14
+ Recall that the euclidean distance transform (EDT) calculates the L2 distance to the closest zero pixel for each pixel of the source image.
15
+
16
+ For images of size NxN, the naive algorithm would compute pairwise distances between every pair of points, leading to an O(N^4) algorithm, which is obviously impractical.
17
+ One can do better using the following approach:
18
+ - First, compute the distance to the closest zero pixel in the same row. We can write it as Row_EDT[i,j] = min_k (sqrt((k-j)^2) if input[i,k]==0 else +infinity). With a naive implementation, this step has O(N^3) complexity
19
+ - Then, we observe that the closest zero to [i,j] lies in some row k, at distance sqrt((k-i)^2 + Row_EDT[k,j]^2). So EDT[i,j] = min_k sqrt((k-i)^2 + Row_EDT[k,j]^2), which is again O(N^3) when computed naively
20
+
21
+ Overall, this algorithm is quite amenable to parallelization, and has a complexity O(N^3). Can we do better?
22
+
23
+ It turns out that we can leverage the structure of the L2 distance (nice and convex) to find the minimum in a more efficient way.
24
+ We follow the algorithm from "Distance Transforms of Sampled Functions" (https://cs.brown.edu/people/pfelzens/papers/dt-final.pdf), which is also what is implemented in OpenCV
25
+
26
+ For a one-dimensional EDT, we can compute the EDT of an arbitrary function F that we discretize over the grid. Note that for the binary EDT we're interested in, we can set F(i,j) = 0 if input[i,j]==0 else +infinity
27
+ For now, we'll compute the EDT squared, and will take the sqrt only at the very end.
28
+ The basic idea is that each point at location i spawns a parabola around itself, with a bias equal to F(i). So specifically, we're looking at the parabola (x - i)^2 + F(i)
29
+ When we're looking for the row EDT at location j, we're effectively looking for min_i (j-i)^2 + F(i). In other words, we want to find the lowest parabola at location j.
30
+
31
+ To do this efficiently, we need to maintain the lower envelope of the union of parabolas. This can be constructed on the fly using a sort of stack approach:
32
+ - every time we want to add a new parabola, we check if it may be covering the current right-most parabola. If so, then that parabola was useless, so we can pop it from the stack
33
+ - repeat until there are no more parabolas to pop, then push the new one.
34
+
35
+ This algorithm runs in O(N) for a single row, so overall O(N^2) when applied to all rows
36
+ As before, we notice that the algorithm decomposes over rows and columns, leading to an overall run-time of O(N^2)
37
+
38
+ This algorithm is less suited to GPUs, since the one-dimensional EDT computation is quite sequential in nature. However, we can parallelize over the batch and row dimensions.
39
+ In Triton, things are particularly bad at the moment, since there is no support for reading/writing to local memory at a specific index (a local gather is coming soon, see https://github.com/triton-lang/triton/issues/974, but there is no mention of writing, i.e. scatter)
40
+ One could emulate these operations with masking, but in initial tests it proved to be worse than naively reading and writing to global memory. My guess is that the cache compensates somewhat for the repeated single-point accesses.
41
+
42
+
43
+ The timings obtained on an H100 for a random batch of masks of size 256 x 1024 x 1024 are as follows:
44
+ - OpenCV: 1780ms (including the round-trip to CPU, but discounting the fact that it introduces a synchronization point)
45
+ - triton, O(N^3) algo: 627ms
46
+ - triton, O(N^2) algo: 322ms
47
+
48
+ Overall, despite being quite naive, this implementation is roughly 5.5x faster than the OpenCV CPU implementation
49
+
50
+ """
51
+
52
+
53
+ @triton.jit
54
+ def edt_kernel(inputs_ptr, outputs_ptr, v, z, height, width, horizontal: tl.constexpr):
55
+ # This is a somewhat verbatim implementation of the efficient 1D EDT algorithm described above
56
+ # It can be applied horizontally or vertically, depending on whether we're doing the first or the second stage.
57
+ # It's parallelized across batch+row (or batch+col if horizontal=False)
58
+ # TODO: perhaps the implementation can be revisited if/when local gather/scatter become available in triton
59
+ batch_id = tl.program_id(axis=0)
60
+ if horizontal:
61
+ row_id = tl.program_id(axis=1)
62
+ block_start = (batch_id * height * width) + row_id * width
63
+ length = width
64
+ stride = 1
65
+ else:
66
+ col_id = tl.program_id(axis=1)
67
+ block_start = (batch_id * height * width) + col_id
68
+ length = height
69
+ stride = width
70
+
71
+ # This will be the index of the right most parabola in the envelope ("the top of the stack")
72
+ k = 0
73
+ for q in range(1, length):
74
+ # Read the function value at the current location. Note that this is a single-element read, which is not very efficient
75
+ cur_input = tl.load(inputs_ptr + block_start + (q * stride))
76
+ # location of the parabola on top of the stack
77
+ r = tl.load(v + block_start + (k * stride))
78
+ # associated boundary
79
+ z_k = tl.load(z + block_start + (k * stride))
80
+ # value of the function at the parabola location
81
+ previous_input = tl.load(inputs_ptr + block_start + (r * stride))
82
+ # intersection between the two parabolas
83
+ s = (cur_input - previous_input + q * q - r * r) / (q - r) / 2
84
+
85
+ # we'll pop as many parabolas as required
86
+ while s <= z_k and k - 1 >= 0:
87
+ k = k - 1
88
+ r = tl.load(v + block_start + (k * stride))
89
+ z_k = tl.load(z + block_start + (k * stride))
90
+ previous_input = tl.load(inputs_ptr + block_start + (r * stride))
91
+ s = (cur_input - previous_input + q * q - r * r) / (q - r) / 2
92
+
93
+ # Store the new one
94
+ k = k + 1
95
+ tl.store(v + block_start + (k * stride), q)
96
+ tl.store(z + block_start + (k * stride), s)
97
+ if k + 1 < length:
98
+ tl.store(z + block_start + ((k + 1) * stride), 1e9)
99
+
100
+ # Last step, we read the envelope to find the min in every location
101
+ k = 0
102
+ for q in range(length):
103
+ while (
104
+ k + 1 < length
105
+ and tl.load(
106
+ z + block_start + ((k + 1) * stride), mask=(k + 1) < length, other=q
107
+ )
108
+ < q
109
+ ):
110
+ k += 1
111
+ r = tl.load(v + block_start + (k * stride))
112
+ d = q - r
113
+ old_value = tl.load(inputs_ptr + block_start + (r * stride))
114
+ tl.store(outputs_ptr + block_start + (q * stride), old_value + d * d)
115
+
116
+
117
+ def edt_triton(data: torch.Tensor):
118
+ """
119
+ Computes the Euclidean Distance Transform (EDT) of a batch of binary images.
120
+
121
+ Args:
122
+ data: A tensor of shape (B, H, W) representing a batch of binary images.
123
+
124
+ Returns:
125
+ A tensor of the same shape as data containing the EDT.
126
+ It should be equivalent to a batched version of cv2.distanceTransform(input, cv2.DIST_L2, 0)
127
+ """
128
+ assert data.dim() == 3
129
+ assert data.is_cuda
130
+ B, H, W = data.shape
131
+ data = data.contiguous()
132
+
133
+ # Allocate the "function" tensor. Implicitly the function is 0 if data[i,j]==0 else +infinity
134
+ output = torch.where(data, 1e18, 0.0)
135
+ assert output.is_contiguous()
136
+
137
+ # Scratch tensors for the parabola stacks
138
+ parabola_loc = torch.zeros(B, H, W, dtype=torch.uint32, device=data.device)
139
+ parabola_inter = torch.empty(B, H, W, dtype=torch.float, device=data.device)
140
+ parabola_inter[:, :, 0] = -1e18
141
+ parabola_inter[:, :, 1] = 1e18
142
+
143
+ # Grid size (number of blocks)
144
+ grid = (B, H)
145
+
146
+ # Launch the first (horizontal, row-wise) pass
147
+ edt_kernel[grid](
148
+ output.clone(),
149
+ output,
150
+ parabola_loc,
151
+ parabola_inter,
152
+ H,
153
+ W,
154
+ horizontal=True,
155
+ )
156
+
157
+ # reset the parabola stacks
158
+ parabola_loc.zero_()
159
+ parabola_inter[:, :, 0] = -1e18
160
+ parabola_inter[:, :, 1] = 1e18
161
+
162
+ grid = (B, W)
163
+ edt_kernel[grid](
164
+ output.clone(),
165
+ output,
166
+ parabola_loc,
167
+ parabola_inter,
168
+ H,
169
+ W,
170
+ horizontal=False,
171
+ )
172
+ # don't forget to take sqrt at the end
173
+ return output.sqrt()
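A minimal usage sketch (assuming a CUDA device with Triton available; the import path is an assumption):

import torch
# from sam3.model.edt import edt_triton   # assumed import path

masks = torch.rand(4, 256, 256, device="cuda") > 0.5   # batch of binary masks
dist = edt_triton(masks)                                # (4, 256, 256): L2 distance to the nearest zero pixel

# Optional sanity check of one mask against OpenCV:
# import cv2, numpy as np
# ref = cv2.distanceTransform(masks[0].to(torch.uint8).cpu().numpy(), cv2.DIST_L2, 0)
# print(np.abs(ref - dist[0].cpu().numpy()).max())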
detect_tools/sam3/sam3/model/encoder.py ADDED
@@ -0,0 +1,594 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+ # Based on https://github.com/IDEA-Research/GroundingDINO
3
+
4
+ from typing import Any, Dict, List, Optional, Tuple
5
+
6
+ import torch
7
+ from torch import nn, Tensor
8
+
9
+ from .act_ckpt_utils import activation_ckpt_wrapper
10
+ from .model_misc import get_activation_fn, get_clones, get_valid_ratio
11
+
12
+
13
+ class TransformerEncoderLayer(nn.Module):
14
+ """
15
+ Transformer encoder layer that performs self-attention followed by cross-attention.
16
+
17
+ This layer was previously called TransformerDecoderLayer but was renamed to better
18
+ reflect its role in the architecture. It processes input sequences through self-attention
19
+ and then cross-attention with another input (typically image features).
20
+
21
+ The layer supports both pre-norm and post-norm configurations, as well as
22
+ positional encoding at different stages of the attention mechanism.
23
+ """
24
+
25
+ def __init__(
26
+ self,
27
+ activation: str,
28
+ cross_attention: nn.Module,
29
+ d_model: int,
30
+ dim_feedforward: int,
31
+ dropout: float,
32
+ pos_enc_at_attn: bool,
33
+ pos_enc_at_cross_attn_keys: bool,
34
+ pos_enc_at_cross_attn_queries: bool,
35
+ pre_norm: bool,
36
+ self_attention: nn.Module,
37
+ ):
38
+ """
39
+ Initialize a transformer encoder layer.
40
+
41
+ Args:
42
+ activation: Activation function to use in the feedforward network
43
+ cross_attention: Cross-attention module for attending to image features
44
+ d_model: Model dimension/hidden size
45
+ dim_feedforward: Dimension of the feedforward network
46
+ dropout: Dropout probability
47
+ pos_enc_at_attn: Whether to add positional encodings at self-attention
48
+ pos_enc_at_cross_attn_keys: Whether to add positional encodings to keys in cross-attention
49
+ pos_enc_at_cross_attn_queries: Whether to add positional encodings to queries in cross-attention
50
+ pre_norm: Whether to use pre-norm (True) or post-norm (False) architecture
51
+ self_attention: Self-attention module
52
+ """
53
+ super().__init__()
54
+ self.d_model = d_model
55
+ self.dim_feedforward = dim_feedforward
56
+ self.dropout_value = dropout
57
+ self.self_attn = self_attention
58
+ self.cross_attn_image = cross_attention
59
+
60
+ # Implementation of Feedforward model
61
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
62
+ self.dropout = nn.Dropout(dropout)
63
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
64
+
65
+ self.norm1 = nn.LayerNorm(d_model)
66
+ self.norm2 = nn.LayerNorm(d_model)
67
+ self.norm3 = nn.LayerNorm(d_model)
68
+ self.dropout1 = nn.Dropout(dropout)
69
+ self.dropout2 = nn.Dropout(dropout)
70
+ self.dropout3 = nn.Dropout(dropout)
71
+
72
+ self.activation_str = activation
73
+ self.activation = get_activation_fn(activation)
74
+ self.pre_norm = pre_norm
75
+
76
+ self.pos_enc_at_attn = pos_enc_at_attn
77
+ self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
78
+ self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys
79
+
80
+ self.layer_idx = None
81
+
82
+ def forward_post(
83
+ self,
84
+ tgt: Tensor,
85
+ memory: Tensor,
86
+ tgt_mask: Optional[Tensor] = None,
87
+ memory_mask: Optional[Tensor] = None,
88
+ tgt_key_padding_mask: Optional[Tensor] = None,
89
+ memory_key_padding_mask: Optional[Tensor] = None,
90
+ pos: Optional[Tensor] = None,
91
+ query_pos: Optional[Tensor] = None,
92
+ **kwargs,
93
+ ) -> Tensor:
94
+ """
95
+ Forward pass for post-norm architecture.
96
+
97
+ In post-norm architecture, normalization is applied after attention and feedforward operations.
98
+
99
+ Args:
100
+ tgt: Input tensor to be processed
101
+ memory: Memory tensor for cross-attention
102
+ tgt_mask: Mask for self-attention
103
+ memory_mask: Mask for cross-attention
104
+ tgt_key_padding_mask: Key padding mask for self-attention
105
+ memory_key_padding_mask: Key padding mask for cross-attention
106
+ pos: Positional encoding for memory
107
+ query_pos: Positional encoding for query
108
+ **kwargs: Additional keyword arguments
109
+
110
+ Returns:
111
+ Processed tensor
112
+ """
113
+ q = k = tgt + query_pos if self.pos_enc_at_attn else tgt
114
+
115
+ # Self attention
116
+ tgt2 = self.self_attn(
117
+ q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
118
+ )[0]
119
+ tgt = tgt + self.dropout1(tgt2)
120
+ tgt = self.norm1(tgt)
121
+
122
+ # Cross attention to image
123
+ tgt2 = self.cross_attn_image(
124
+ query=tgt + query_pos if self.pos_enc_at_cross_attn_queries else tgt,
125
+ key=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
126
+ value=memory,
127
+ attn_mask=memory_mask,
128
+ key_padding_mask=memory_key_padding_mask,
129
+ )[0]
130
+ tgt = tgt + self.dropout2(tgt2)
131
+ tgt = self.norm2(tgt)
132
+
133
+ # FFN
134
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
135
+ tgt = tgt + self.dropout3(tgt2)
136
+ tgt = self.norm3(tgt)
137
+ return tgt
138
+
139
+ def forward_pre(
140
+ self,
141
+ tgt: Tensor,
142
+ memory: Tensor,
143
+ dac: bool = False,
144
+ tgt_mask: Optional[Tensor] = None,
145
+ memory_mask: Optional[Tensor] = None,
146
+ tgt_key_padding_mask: Optional[Tensor] = None,
147
+ memory_key_padding_mask: Optional[Tensor] = None,
148
+ pos: Optional[Tensor] = None,
149
+ query_pos: Optional[Tensor] = None,
150
+ # attn_bias: Optional[Tensor] = None,
151
+ # **kwargs,
152
+ ) -> Tensor:
153
+ """
154
+ Forward pass for pre-norm architecture.
155
+
156
+ In pre-norm architecture, normalization is applied before attention and feedforward operations.
157
+
158
+ Args:
159
+ tgt: Input tensor to be processed
160
+ memory: Memory tensor for cross-attention
161
+ dac: Whether to use Divide-and-Conquer attention
162
+ tgt_mask: Mask for self-attention
163
+ memory_mask: Mask for cross-attention
164
+ tgt_key_padding_mask: Key padding mask for self-attention
165
+ memory_key_padding_mask: Key padding mask for cross-attention
166
+ pos: Positional encoding for memory
167
+ query_pos: Positional encoding for query
168
+ attn_bias: Optional attention bias tensor
169
+ **kwargs: Additional keyword arguments
170
+
171
+ Returns:
172
+ Processed tensor
173
+ """
174
+ if dac:
175
+ # we only apply self attention to the first half of the queries
176
+ assert tgt.shape[0] % 2 == 0
177
+ other_tgt = tgt[tgt.shape[0] // 2 :]
178
+ tgt = tgt[: tgt.shape[0] // 2]
179
+ tgt2 = self.norm1(tgt)
180
+ q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
181
+ tgt2 = self.self_attn(
182
+ q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
183
+ )[0]
184
+ tgt = tgt + self.dropout1(tgt2)
185
+ if dac:
186
+ # Recombine
187
+ tgt = torch.cat((tgt, other_tgt), dim=0)
188
+ tgt2 = self.norm2(tgt)
189
+ tgt2 = self.cross_attn_image(
190
+ query=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
191
+ key=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
192
+ value=memory,
193
+ attn_mask=memory_mask,
194
+ key_padding_mask=memory_key_padding_mask,
195
+ # attn_bias=attn_bias,
196
+ )[0]
197
+ tgt = tgt + self.dropout2(tgt2)
198
+ tgt2 = self.norm3(tgt)
199
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
200
+ tgt = tgt + self.dropout3(tgt2)
201
+ return tgt
202
+
203
+ def forward(
204
+ self,
205
+ tgt: Tensor,
206
+ memory: Tensor,
207
+ dac: bool = False,
208
+ tgt_mask: Optional[Tensor] = None,
209
+ memory_mask: Optional[Tensor] = None,
210
+ tgt_key_padding_mask: Optional[Tensor] = None,
211
+ memory_key_padding_mask: Optional[Tensor] = None,
212
+ pos: Optional[Tensor] = None,
213
+ query_pos: Optional[Tensor] = None,
214
+ # attn_bias: Optional[Tensor] = None,
215
+ # **kwds: Any,
216
+ ) -> torch.Tensor:
217
+ """
218
+ Forward pass for the transformer encoder layer.
219
+
220
+ Args:
221
+ tgt: Input tensor to be processed
222
+ memory: Memory tensor (e.g., image features) for cross-attention
223
+ dac: Whether to use Divide-and-Conquer attention (only apply self-attention to first half)
224
+ tgt_mask: Mask for self-attention
225
+ memory_mask: Mask for cross-attention
226
+ tgt_key_padding_mask: Key padding mask for self-attention
227
+ memory_key_padding_mask: Key padding mask for cross-attention
228
+ pos: Positional encoding for memory
229
+ query_pos: Positional encoding for query
230
+ attn_bias: Optional attention bias tensor
231
+ **kwds: Additional keyword arguments
232
+
233
+ Returns:
234
+ Processed tensor after self-attention, cross-attention, and feedforward network
235
+ """
236
+ fwd_fn = self.forward_pre if self.pre_norm else self.forward_post
237
+ return fwd_fn(
238
+ tgt,
239
+ memory,
240
+ dac=dac,
241
+ tgt_mask=tgt_mask,
242
+ memory_mask=memory_mask,
243
+ tgt_key_padding_mask=tgt_key_padding_mask,
244
+ memory_key_padding_mask=memory_key_padding_mask,
245
+ pos=pos,
246
+ query_pos=query_pos,
247
+ # attn_bias=attn_bias,
248
+ # **kwds,
249
+ )
250
+
251
+
252
+ class TransformerEncoder(nn.Module):
253
+ """
254
+ Transformer encoder that processes multi-level features.
255
+
256
+ This encoder takes multi-level features (e.g., from a backbone network) and processes
257
+ them through a stack of transformer encoder layers. It supports features from multiple
258
+ levels (e.g., different resolutions) and can apply activation checkpointing for memory
259
+ efficiency during training.
260
+
261
+ Args:
262
+ layer: The encoder layer to be stacked multiple times
263
+ num_layers: Number of encoder layers to stack
264
+ d_model: Model dimension/hidden size
265
+ num_feature_levels: Number of feature levels to process
266
+ frozen: Whether to freeze the parameters of this module
267
+ use_act_checkpoint: Whether to use activation checkpointing during training
268
+ """
269
+
270
+ def __init__(
271
+ self,
272
+ layer: nn.Module,
273
+ num_layers: int,
274
+ d_model: int,
275
+ num_feature_levels: int,
276
+ frozen: bool = False,
277
+ use_act_checkpoint: bool = False,
278
+ ):
279
+ super().__init__()
280
+ self.layers = get_clones(layer, num_layers)
281
+ self.num_layers = num_layers
282
+
283
+ self.num_feature_levels = num_feature_levels
284
+ self.level_embed = None
285
+ if num_feature_levels > 1:
286
+ self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
287
+
288
+ if frozen:
289
+ for p in self.parameters():
290
+ p.requires_grad_(False)
291
+
292
+ self.use_act_checkpoint = use_act_checkpoint
293
+
294
+ # assign layer index to each layer so that some layers can decide what to do
295
+ # based on which layer index they are (e.g. cross attention to memory bank only
296
+ # in selected layers)
297
+ for layer_idx, layer in enumerate(self.layers):
298
+ layer.layer_idx = layer_idx
299
+
300
+ @staticmethod
301
+ def get_reference_points(spatial_shapes, valid_ratios, device):
302
+ with torch.no_grad():
303
+ reference_points_list = []
304
+ for lvl, (H_, W_) in enumerate(spatial_shapes):
305
+ ref_y, ref_x = torch.meshgrid(
306
+ torch.linspace(
307
+ 0.5, H_ - 0.5, H_, dtype=torch.float32, device=device
308
+ ),
309
+ torch.linspace(
310
+ 0.5, W_ - 0.5, W_, dtype=torch.float32, device=device
311
+ ),
312
+ )
313
+ ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
314
+ ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
315
+ ref = torch.stack((ref_x, ref_y), -1)
316
+ reference_points_list.append(ref)
317
+ reference_points = torch.cat(reference_points_list, 1)
318
+ reference_points = reference_points[:, :, None] * valid_ratios[:, None]
319
+
320
+ return reference_points
321
+
322
+ def _prepare_multilevel_features(self, srcs, masks, pos_embeds):
323
+ assert (
324
+ len(srcs) == self.num_feature_levels
325
+ ), "mismatch between expected and received # of feature levels"
326
+
327
+ src_flatten = []
328
+ mask_flatten = []
329
+ lvl_pos_embed_flatten = []
330
+ spatial_shapes = []
331
+ has_mask = masks is not None and masks[0] is not None
332
+ for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
333
+ bs, c, h, w = src.shape
334
+ spatial_shape = (h, w)
335
+ spatial_shapes.append(spatial_shape)
336
+
337
+ src = src.flatten(2).transpose(1, 2) # bs, hw, c
338
+ if has_mask:
339
+ mask = mask.flatten(1)
340
+ pos_embed = pos_embed.flatten(2).transpose(1, 2) # bs, hw, c
341
+ if self.level_embed is not None:
342
+ lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
343
+ else:
344
+ lvl_pos_embed = pos_embed
345
+ lvl_pos_embed_flatten.append(lvl_pos_embed)
346
+ src_flatten.append(src)
347
+ if has_mask:
348
+ mask_flatten.append(mask)
349
+ src_flatten = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c
350
+ mask_flatten = torch.cat(mask_flatten, 1) if has_mask else None # bs, \sum{hxw}
351
+ lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) # bs, \sum{hxw}, c
352
+ spatial_shapes = torch.tensor(
353
+ spatial_shapes, dtype=torch.long, device=src_flatten.device
354
+ )
355
+ level_start_index = torch.cat(
356
+ (
357
+ spatial_shapes.new_zeros((1,)),
358
+ spatial_shapes.prod(1).cumsum(0)[:-1],
359
+ )
360
+ )
361
+ if has_mask:
362
+ valid_ratios = torch.stack([get_valid_ratio(m) for m in masks], 1)
363
+ else:
364
+ valid_ratios = torch.ones(
365
+ (src_flatten.shape[0], self.num_feature_levels, 2),
366
+ device=src_flatten.device,
367
+ )
368
+
369
+ return (
370
+ src_flatten,
371
+ mask_flatten,
372
+ lvl_pos_embed_flatten,
373
+ level_start_index,
374
+ valid_ratios,
375
+ spatial_shapes,
376
+ )
377
+
378
+ def forward(
379
+ self,
380
+ src: List[Tensor],
381
+ src_key_padding_masks: Optional[List[Tensor]] = None,
382
+ pos: Optional[List[Tensor]] = None,
383
+ prompt: Optional[Tensor] = None,
384
+ prompt_key_padding_mask: Optional[Tensor] = None,
385
+ encoder_extra_kwargs: Optional[Dict] = None,
386
+ ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor, Tensor]:
387
+ """
388
+ Process multi-level features through the transformer encoder.
389
+
390
+ Args:
391
+ src: List of multi-level features, each with shape (batch_size, channels, height, width)
392
+ src_key_padding_masks: List of padding masks for each feature level, each with shape (batch_size, height, width)
393
+ pos: List of positional embeddings for each feature level, each with shape (batch_size, channels, height, width)
394
+ prompt: Optional text/prompt features to attend to, with shape (seq_len, batch_size, d_model)
395
+ prompt_key_padding_mask: Optional padding mask for prompt, with shape (batch_size, seq_len)
396
+ encoder_extra_kwargs: Optional additional arguments to pass to each encoder layer
397
+
398
+ Returns:
399
+ A tuple containing:
400
+ - output: Processed features with shape (seq_len, batch_size, d_model)
401
+ - key_padding_masks_flatten: Flattened padding masks
402
+ - lvl_pos_embed_flatten: Flattened positional embeddings
403
+ - level_start_index: Starting indices for each feature level
404
+ - spatial_shapes: Spatial dimensions of each feature level
405
+ - valid_ratios: Valid ratios for each feature level
406
+ """
407
+ assert (
408
+ len(src) == self.num_feature_levels
409
+ ), "must be equal to num_feature_levels"
410
+ if src_key_padding_masks is not None:
411
+ assert len(src_key_padding_masks) == self.num_feature_levels
412
+ if pos is not None:
413
+ assert len(pos) == self.num_feature_levels
414
+ # Flatten multilevel feats and add level pos embeds
415
+ (
416
+ src_flatten,
417
+ key_padding_masks_flatten,
418
+ lvl_pos_embed_flatten,
419
+ level_start_index,
420
+ valid_ratios,
421
+ spatial_shapes,
422
+ ) = self._prepare_multilevel_features(src, src_key_padding_masks, pos)
423
+
424
+ reference_points = self.get_reference_points(
425
+ spatial_shapes, valid_ratios, device=src_flatten.device
426
+ )
427
+
428
+ output = src_flatten
429
+ for layer in self.layers:
430
+ layer_kwargs = {}
431
+
432
+ assert isinstance(layer, TransformerEncoderLayer)
433
+ layer_kwargs["memory"] = prompt
434
+ layer_kwargs["memory_key_padding_mask"] = prompt_key_padding_mask
435
+ layer_kwargs["query_pos"] = lvl_pos_embed_flatten
436
+ layer_kwargs["tgt"] = output
437
+ layer_kwargs["tgt_key_padding_mask"] = key_padding_masks_flatten
438
+
439
+ if self.training:
440
+ assert self.use_act_checkpoint, "activation ckpt not enabled in encoder"
441
+ if encoder_extra_kwargs is not None:
442
+ layer_kwargs.update(encoder_extra_kwargs)
443
+ output = activation_ckpt_wrapper(layer)(
444
+ **layer_kwargs,
445
+ act_ckpt_enable=self.training and self.use_act_checkpoint,
446
+ )
447
+ # return as seq first
448
+ return (
449
+ output.transpose(0, 1),
450
+ (
451
+ key_padding_masks_flatten.transpose(0, 1)
452
+ if key_padding_masks_flatten is not None
453
+ else None
454
+ ),
455
+ lvl_pos_embed_flatten.transpose(0, 1),
456
+ level_start_index,
457
+ spatial_shapes,
458
+ valid_ratios,
459
+ )
460
+
461
+
462
+ class TransformerEncoderFusion(TransformerEncoder):
463
+ """
464
+ Transformer encoder that fuses text and image features.
465
+
466
+ This encoder extends TransformerEncoder to handle both text and image features,
467
+ with the ability to add pooled text features to image features for better
468
+ cross-modal fusion. It supports torch.compile for performance optimization.
469
+
470
+ Args:
471
+ layer: The encoder layer to be stacked multiple times
472
+ num_layers: Number of encoder layers to stack
473
+ d_model: Model dimension/hidden size
474
+ num_feature_levels: Number of feature levels to process
475
+ add_pooled_text_to_img_feat: Whether to add pooled text features to image features
476
+ pool_text_with_mask: Whether to use the mask when pooling text features
477
+ compile_mode: Mode for torch.compile, or None to disable compilation
478
+ **kwargs: Additional arguments to pass to the parent class
479
+ """
480
+
481
+ def __init__(
482
+ self,
483
+ layer: nn.Module,
484
+ num_layers: int,
485
+ d_model: int,
486
+ num_feature_levels: int,
487
+ add_pooled_text_to_img_feat: bool = True,
488
+ pool_text_with_mask: bool = False,
489
+ compile_mode: Optional[str] = None,
490
+ **kwargs,
491
+ ):
492
+ super().__init__(
493
+ layer,
494
+ num_layers,
495
+ d_model,
496
+ num_feature_levels,
497
+ **kwargs,
498
+ )
499
+ self.add_pooled_text_to_img_feat = add_pooled_text_to_img_feat
500
+ if self.add_pooled_text_to_img_feat:
501
+ self.text_pooling_proj = nn.Linear(d_model, d_model)
502
+ self.pool_text_with_mask = pool_text_with_mask
503
+ if compile_mode is not None:
504
+ self.forward = torch.compile(
505
+ self.forward, mode=compile_mode, fullgraph=True
506
+ )
507
+
508
+ @staticmethod
509
+ def get_reference_points(spatial_shapes, valid_ratios, device):
510
+ # Not needed here
511
+ return None
512
+
513
+ def forward(
514
+ self,
515
+ src: List[Tensor],
516
+ prompt: Tensor,
517
+ src_key_padding_mask: Optional[List[Tensor]] = None,
518
+ src_pos: Optional[List[Tensor]] = None,
519
+ prompt_key_padding_mask: Optional[Tensor] = None,
520
+ prompt_pos: Optional[Tensor] = None,
521
+ feat_sizes: Optional[List[int]] = None,
522
+ encoder_extra_kwargs: Optional[Dict] = None,
523
+ ):
524
+ # Restore spatial shapes of vision
525
+ bs = src[0].shape[1] # seq first
526
+ if feat_sizes is not None:
527
+ assert len(feat_sizes) == len(src)
528
+ if src_key_padding_mask is None:
529
+ src_key_padding_mask = [None] * len(src)
530
+ for i, (h, w) in enumerate(feat_sizes):
531
+ src[i] = src[i].reshape(h, w, bs, -1).permute(2, 3, 0, 1)
532
+ src_pos[i] = src_pos[i].reshape(h, w, bs, -1).permute(2, 3, 0, 1)
533
+ src_key_padding_mask[i] = (
534
+ src_key_padding_mask[i].reshape(h, w, bs).permute(2, 0, 1)
535
+ if src_key_padding_mask[i] is not None
536
+ else None
537
+ )
538
+ else:
539
+ assert all(
540
+ x.dim == 4 for x in src
541
+ ), "expected list of (bs, c, h, w) tensors"
542
+
543
+ if self.add_pooled_text_to_img_feat:
544
+ # Fusion: Add mean pooled text to image features
545
+ pooled_text = pool_text_feat(
546
+ prompt, prompt_key_padding_mask, self.pool_text_with_mask
547
+ )
548
+ pooled_text = self.text_pooling_proj(pooled_text)[
549
+ ..., None, None
550
+ ] # prompt is seq first
551
+ src = [x.add_(pooled_text) for x in src]
552
+
553
+ (
554
+ out,
555
+ key_padding_masks_flatten,
556
+ lvl_pos_embed_flatten,
557
+ level_start_index,
558
+ spatial_shapes,
559
+ valid_ratios,
560
+ ) = super().forward(
561
+ src,
562
+ src_key_padding_masks=src_key_padding_mask,
563
+ pos=src_pos,
564
+ prompt=prompt.transpose(0, 1),
565
+ prompt_key_padding_mask=prompt_key_padding_mask,
566
+ encoder_extra_kwargs=encoder_extra_kwargs,
567
+ )
568
+
569
+ return {
570
+ "memory": out,
571
+ "padding_mask": key_padding_masks_flatten,
572
+ "pos_embed": lvl_pos_embed_flatten,
573
+ "memory_text": prompt,
574
+ "level_start_index": level_start_index,
575
+ "spatial_shapes": spatial_shapes,
576
+ "valid_ratios": valid_ratios,
577
+ }
578
+
579
+
580
+ def pool_text_feat(prompt, prompt_mask, pool_with_mask):
581
+ # prompt has shape (seq, bs, dim)
582
+ if not pool_with_mask:
583
+ return prompt.mean(dim=0)
584
+
585
+ # prompt_mask has shape (bs, seq), where False is valid and True is padding
586
+ assert prompt_mask.dim() == 2
587
+ # is_valid has shape (seq, bs, 1), where 1 is valid and 0 is padding
588
+ is_valid = (~prompt_mask).float().permute(1, 0)[..., None]
589
+ # num_valid has shape (bs, 1)
590
+ num_valid = torch.clamp(torch.sum(is_valid, dim=0), min=1.0)
591
+
592
+ # mean pool over all the valid tokens
593
+ pooled_text = (prompt * is_valid).sum(dim=0) / num_valid
594
+ return pooled_text
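A tiny illustration of the masked mean pooling above (shapes follow the (seq, bs, dim) convention used throughout this file):

import torch

prompt = torch.randn(5, 2, 256)                    # (seq, bs, dim)
prompt_mask = torch.zeros(2, 5, dtype=torch.bool)  # False = valid token, True = padding
prompt_mask[1, 3:] = True                          # second sample only has 3 valid tokens
pooled = pool_text_feat(prompt, prompt_mask, pool_with_mask=True)  # (2, 256)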
detect_tools/sam3/sam3/model/geometry_encoders.py ADDED
@@ -0,0 +1,850 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ from typing import Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torchvision
8
+ from typing_extensions import override
9
+
10
+ from .act_ckpt_utils import activation_ckpt_wrapper
11
+ from .box_ops import box_cxcywh_to_xyxy
12
+
13
+ from .model_misc import get_clones
14
+
15
+
16
+ def is_right_padded(mask):
17
+ """Given a padding mask (following pytorch convention, 1s for padded values),
18
+ returns whether the padding is on the right or not."""
19
+ return (mask.long() == torch.sort(mask.long(), dim=-1)[0]).all()
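For instance (illustrative calls, assuming torch is imported as above):

import torch

assert is_right_padded(torch.tensor([[False, True, True]]))      # padding is trailing -> True
assert not is_right_padded(torch.tensor([[True, False, True]]))  # padding appears mid-sequence -> False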
20
+
21
+
22
+ def concat_padded_sequences(seq1, mask1, seq2, mask2, return_index: bool = False):
23
+ """
24
+ Concatenates two right-padded sequences, such that the resulting sequence
25
+ is contiguous and also right-padded.
26
+
27
+ Following pytorch's convention, tensors are sequence first, and the mask are
28
+ batch first, with 1s for padded values.
29
+
30
+ :param seq1: A tensor of shape (seq1_length, batch_size, hidden_size).
31
+ :param mask1: A tensor of shape (batch_size, seq1_length).
32
+ :param seq2: A tensor of shape (seq2_length, batch_size, hidden_size).
33
+ :param mask2: A tensor of shape (batch_size, seq2_length).
34
+ :param return_index: If True, also returns the indices of the elements of seq2
35
+ in the concatenated sequence. These can be used to retrieve the elements of seq2
36
+ :return: A tuple (concatenated_sequence, concatenated_mask) if return_index is False,
37
+ otherwise (concatenated_sequence, concatenated_mask, index).
38
+ """
39
+ seq1_length, batch_size, hidden_size = seq1.shape
40
+ seq2_length, batch_size, hidden_size = seq2.shape
41
+
42
+ assert batch_size == seq1.size(1) == seq2.size(1) == mask1.size(0) == mask2.size(0)
43
+ assert hidden_size == seq1.size(2) == seq2.size(2)
44
+ assert seq1_length == mask1.size(1)
45
+ assert seq2_length == mask2.size(1)
46
+
47
+ torch._assert_async(is_right_padded(mask1))
48
+ torch._assert_async(is_right_padded(mask2))
49
+
50
+ actual_seq1_lengths = (~mask1).sum(dim=-1)
51
+ actual_seq2_lengths = (~mask2).sum(dim=-1)
52
+
53
+ final_lengths = actual_seq1_lengths + actual_seq2_lengths
54
+ max_length = seq1_length + seq2_length
55
+ concatenated_mask = (
56
+ torch.arange(max_length, device=seq2.device)[None].repeat(batch_size, 1)
57
+ >= final_lengths[:, None]
58
+ )
59
+
60
+ # (max_len, batch_size, hidden_size)
61
+ concatenated_sequence = torch.zeros(
62
+ (max_length, batch_size, hidden_size), device=seq2.device, dtype=seq2.dtype
63
+ )
64
+ concatenated_sequence[:seq1_length, :, :] = seq1
65
+
66
+ # At this point, the element of seq1 are in the right place
67
+ # We just need to shift the elements of seq2
68
+
69
+ index = torch.arange(seq2_length, device=seq2.device)[:, None].repeat(1, batch_size)
70
+ index = index + actual_seq1_lengths[None]
71
+
72
+ concatenated_sequence = concatenated_sequence.scatter(
73
+ 0, index[:, :, None].expand(-1, -1, hidden_size), seq2
74
+ )
75
+
76
+ if return_index:
77
+ return concatenated_sequence, concatenated_mask, index
78
+
79
+ return concatenated_sequence, concatenated_mask
80
+
81
+
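A minimal usage sketch for `concat_padded_sequences` (assuming the function is imported from this module; shapes and values are illustrative only):

```python
import torch

# seq1: 2 tokens, batch 1, hidden 3; the second token is padding
seq1 = torch.arange(6, dtype=torch.float32).view(2, 1, 3)
mask1 = torch.tensor([[False, True]])   # (batch, seq1_len), True = padded

# seq2: a single real token
seq2 = torch.full((1, 1, 3), 7.0)
mask2 = torch.tensor([[False]])

out, out_mask = concat_padded_sequences(seq1, mask1, seq2, mask2)
# out has shape (3, 1, 3): the real token of seq2 is placed right after the
# single real token of seq1, and all padding stays on the right, so
# out_mask == tensor([[False, False, True]])
```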
82
+ class Prompt:
83
+ """Utility class to manipulate geometric prompts.
84
+
85
+ We expect the sequences in pytorch convention, that is sequence first, batch second
86
+ The dimensions are expected as follows:
87
+ box_embeddings shape: N_boxes x B x C_box
88
+ box_mask shape: B x N_boxes. Can be None if nothing is masked out
89
+ point_embeddings shape: N_points x B x C_point
90
+ point_mask shape: B x N_points. Can be None if nothing is masked out
91
+ mask_embeddings shape: N_masks x B x 1 x H_mask x W_mask
92
+ mask_mask shape: B x N_masks. Can be None if nothing is masked out
93
+
94
+ We also store positive/negative labels. These tensors are also stored batch-first
95
+ If they are None, we'll assume positive labels everywhere
96
+ box_labels: long tensor of shape N_boxes x B
97
+ point_labels: long tensor of shape N_points x B
98
+ mask_labels: long tensor of shape N_masks x B
99
+ """
100
+
101
+ def __init__(
102
+ self,
103
+ box_embeddings=None,
104
+ box_mask=None,
105
+ point_embeddings=None,
106
+ point_mask=None,
107
+ box_labels=None,
108
+ point_labels=None,
109
+ mask_embeddings=None,
110
+ mask_mask=None, # Attention mask for mask prompt
111
+ mask_labels=None,
112
+ ):
113
+ # Check for null prompt
114
+ if (
115
+ box_embeddings is None
116
+ and point_embeddings is None
117
+ and mask_embeddings is None
118
+ ):
119
+ self.box_embeddings = None
120
+ self.box_labels = None
121
+ self.box_mask = None
122
+ self.point_embeddings = None
123
+ self.point_labels = None
124
+ self.point_mask = None
125
+ self.mask_embeddings = None
126
+ self.mask_mask = None
127
+ # Masks are assumed positive only for now.
128
+ self.mask_labels = None
129
+ return
130
+ # Get sequence lengths and device
131
+ box_seq_len, point_seq_len, mask_seq_len, bs, device = (
132
+ self._init_seq_len_and_device(
133
+ box_embeddings, point_embeddings, mask_embeddings
134
+ )
135
+ )
136
+
137
+ # Initialize embeds, labels, attention masks.
138
+ box_embeddings, box_labels, box_mask = self._init_box(
139
+ box_embeddings, box_labels, box_mask, box_seq_len, bs, device
140
+ )
141
+ point_embeddings, point_labels, point_mask = self._init_point(
142
+ point_embeddings, point_labels, point_mask, point_seq_len, bs, device
143
+ )
144
+ mask_embeddings, mask_labels, mask_mask = self._init_mask(
145
+ mask_embeddings, mask_labels, mask_mask, mask_seq_len, bs, device
146
+ )
147
+
148
+ # Dimension checks
149
+ assert (
150
+ box_embeddings is not None
151
+ and list(box_embeddings.shape[:2])
152
+ == [
153
+ box_seq_len,
154
+ bs,
155
+ ]
156
+ ), f"Wrong dimension for box embeddings. Expected [{box_seq_len}, {bs}, *] got {box_embeddings.shape}"
157
+ assert (
158
+ box_mask is not None
159
+ and list(box_mask.shape)
160
+ == [
161
+ bs,
162
+ box_seq_len,
163
+ ]
164
+ ), f"Wrong dimension for box mask. Expected [{bs}, {box_seq_len}] got {box_mask.shape}"
165
+ assert (
166
+ point_embeddings is not None
167
+ and list(point_embeddings.shape[:2])
168
+ == [
169
+ point_seq_len,
170
+ bs,
171
+ ]
172
+ ), f"Wrong dimension for point embeddings. Expected [{point_seq_len}, {bs}, *] got {point_embeddings.shape}"
173
+ assert (
174
+ point_mask is not None
175
+ and list(point_mask.shape)
176
+ == [
177
+ bs,
178
+ point_seq_len,
179
+ ]
180
+ ), f"Wrong dimension for point mask. Expected [{bs}, {point_seq_len}] got {point_mask.shape}"
181
+ assert (
182
+ box_labels is not None
183
+ and list(box_labels.shape)
184
+ == [
185
+ box_seq_len,
186
+ bs,
187
+ ]
188
+ ), f"Wrong dimension for box labels. Expected [{box_seq_len}, {bs}] got {box_labels.shape}"
189
+ assert (
190
+ point_labels is not None
191
+ and list(point_labels.shape)
192
+ == [
193
+ point_seq_len,
194
+ bs,
195
+ ]
196
+ ), f"Wrong dimension for point labels. Expected [{point_seq_len}, {bs}] got {point_labels.shape}"
197
+ assert (
198
+ # Allowed to be None, we leave it to the encoder to check for validity before encoding.
199
+ mask_embeddings is None
200
+ or list(mask_embeddings.shape[:2])
201
+ == [
202
+ mask_seq_len,
203
+ bs,
204
+ ]
205
+ ), f"Wrong dimension for mask embeddings. Expected [{mask_seq_len}, {bs}, *] got {mask_embeddings.shape}"
206
+ assert (
207
+ mask_mask is None
208
+ or list(mask_mask.shape)
209
+ == [
210
+ bs,
211
+ mask_seq_len,
212
+ ]
213
+ ), f"Wrong dimension for mask attn. mask. Expected [{bs}, {mask_seq_len}] got {mask_mask.shape}"
214
+
215
+ # Device checks
216
+ assert (
217
+ box_embeddings is not None and box_embeddings.device == device
218
+ ), f"Expected box embeddings to be on device {device}, got {box_embeddings.device}"
219
+ assert (
220
+ box_mask is not None and box_mask.device == device
221
+ ), f"Expected box mask to be on device {device}, got {box_mask.device}"
222
+ assert (
223
+ box_labels is not None and box_labels.device == device
224
+ ), f"Expected box labels to be on device {device}, got {box_labels.device}"
225
+ assert (
226
+ point_embeddings is not None and point_embeddings.device == device
227
+ ), f"Expected point embeddings to be on device {device}, got {point_embeddings.device}"
228
+ assert (
229
+ point_mask is not None and point_mask.device == device
230
+ ), f"Expected point mask to be on device {device}, got {point_mask.device}"
231
+ assert (
232
+ point_labels is not None and point_labels.device == device
233
+ ), f"Expected point labels to be on device {device}, got {point_labels.device}"
234
+ assert (
235
+ mask_embeddings is None or mask_embeddings.device == device
236
+ ), f"Expected mask embeddings to be on device {device}, got {mask_embeddings.device}"
237
+ assert (
238
+ mask_mask is None or mask_mask.device == device
239
+ ), f"Expected mask attn. mask to be on device {device}, got {mask_mask.device}"
240
+
241
+ self.box_embeddings = box_embeddings
242
+ self.point_embeddings = point_embeddings
243
+ self.box_mask = box_mask
244
+ self.point_mask = point_mask
245
+ self.box_labels = box_labels
246
+ self.point_labels = point_labels
247
+ self.mask_embeddings = mask_embeddings
248
+ self.mask_labels = mask_labels
249
+ self.mask_mask = mask_mask
250
+
251
+ def _init_seq_len_and_device(
252
+ self, box_embeddings, point_embeddings, mask_embeddings
253
+ ):
254
+ box_seq_len = point_seq_len = mask_seq_len = 0
255
+ bs = None
256
+ device = None
257
+ if box_embeddings is not None:
258
+ bs = box_embeddings.shape[1]
259
+ box_seq_len = box_embeddings.shape[0]
260
+ device = box_embeddings.device
261
+
262
+ if point_embeddings is not None:
263
+ point_seq_len = point_embeddings.shape[0]
264
+ if bs is not None:
265
+ assert (
266
+ bs == point_embeddings.shape[1]
267
+ ), f"Batch size mismatch between box and point embeddings. Got {bs} and {point_embeddings.shape[1]}."
268
+ else:
269
+ bs = point_embeddings.shape[1]
270
+ if device is not None:
271
+ assert (
272
+ device == point_embeddings.device
273
+ ), "Device mismatch between box and point embeddings"
274
+ else:
275
+ device = point_embeddings.device
276
+
277
+ if mask_embeddings is not None:
278
+ mask_seq_len = mask_embeddings.shape[0]
279
+ if bs is not None:
280
+ assert (
281
+ bs == mask_embeddings.shape[1]
282
+ ), f"Batch size mismatch between box/point and mask embedding. Got {bs} and {mask_embeddings.shape[1]}"
283
+ else:
284
+ bs = mask_embeddings.shape[1]
285
+ if device is not None:
286
+ assert (
287
+ device == mask_embeddings.device
288
+ ), "Device mismatch between box/point and mask embeddings."
289
+ else:
290
+ device = mask_embeddings.device
291
+
292
+ return box_seq_len, point_seq_len, mask_seq_len, bs, device
293
+
294
+ def _init_box(self, box_embeddings, box_labels, box_mask, box_seq_len, bs, device):
295
+ if box_embeddings is None:
296
+ box_embeddings = torch.zeros(box_seq_len, bs, 4, device=device)
297
+ if box_labels is None:
298
+ box_labels = torch.ones(box_seq_len, bs, device=device, dtype=torch.long)
299
+ if box_mask is None:
300
+ box_mask = torch.zeros(bs, box_seq_len, device=device, dtype=torch.bool)
301
+ return box_embeddings, box_labels, box_mask
302
+
303
+ def _init_point(
304
+ self, point_embeddings, point_labels, point_mask, point_seq_len, bs, device
305
+ ):
306
+ """
307
+ Identical to _init_box. Except that C=2 for points (vs. 4 for boxes).
308
+ """
309
+ if point_embeddings is None:
310
+ point_embeddings = torch.zeros(point_seq_len, bs, 2, device=device)
311
+ if point_labels is None:
312
+ point_labels = torch.ones(
313
+ point_seq_len, bs, device=device, dtype=torch.long
314
+ )
315
+ if point_mask is None:
316
+ point_mask = torch.zeros(bs, point_seq_len, device=device, dtype=torch.bool)
317
+ return point_embeddings, point_labels, point_mask
318
+
319
+ def _init_mask(
320
+ self, mask_embeddings, mask_labels, mask_mask, mask_seq_len, bs, device
321
+ ):
322
+ # NOTE: Mask embeddings can be of arbitrary resolution, so we don't initialize it here.
323
+ # In case we append a new mask, we check that its resolution matches existing ones (if any).
324
+ # In case mask_embeddings is None, we should never encode it.
325
+ if mask_labels is None:
326
+ mask_labels = torch.ones(mask_seq_len, bs, device=device, dtype=torch.long)
327
+ if mask_mask is None:
328
+ mask_mask = torch.zeros(bs, mask_seq_len, device=device, dtype=torch.bool)
329
+ return mask_embeddings, mask_labels, mask_mask
330
+
331
+ def append_boxes(self, boxes, labels, mask=None):
332
+ if self.box_embeddings is None:
333
+ self.box_embeddings = boxes
334
+ self.box_labels = labels
335
+ self.box_mask = mask
336
+ return
337
+
338
+ bs = self.box_embeddings.shape[1]
339
+ assert boxes.shape[1] == labels.shape[1] == bs
340
+ assert list(boxes.shape[:2]) == list(labels.shape[:2])
341
+ if mask is None:
342
+ mask = torch.zeros(
343
+ bs, boxes.shape[0], dtype=torch.bool, device=boxes.device
344
+ )
345
+
346
+ self.box_labels, _ = concat_padded_sequences(
347
+ self.box_labels.unsqueeze(-1), self.box_mask, labels.unsqueeze(-1), mask
348
+ )
349
+ self.box_labels = self.box_labels.squeeze(-1)
350
+ self.box_embeddings, self.box_mask = concat_padded_sequences(
351
+ self.box_embeddings, self.box_mask, boxes, mask
352
+ )
353
+
354
+ def append_points(self, points, labels, mask=None):
355
+ if self.point_embeddings is None:
356
+ self.point_embeddings = points
357
+ self.point_labels = labels
358
+ self.point_mask = mask
359
+ return
360
+
361
+ bs = self.point_embeddings.shape[1]
362
+ assert points.shape[1] == labels.shape[1] == bs
363
+ assert list(points.shape[:2]) == list(labels.shape[:2])
364
+ if mask is None:
365
+ mask = torch.zeros(
366
+ bs, points.shape[0], dtype=torch.bool, device=points.device
367
+ )
368
+
369
+ self.point_labels, _ = concat_padded_sequences(
370
+ self.point_labels.unsqueeze(-1), self.point_mask, labels.unsqueeze(-1), mask
371
+ )
372
+ self.point_labels = self.point_labels.squeeze(-1)
373
+ self.point_embeddings, self.point_mask = concat_padded_sequences(
374
+ self.point_embeddings, self.point_mask, points, mask
375
+ )
376
+
377
+ def append_masks(self, masks, labels=None, attn_mask=None):
378
+ if labels is not None:
379
+ assert list(masks.shape[:2]) == list(labels.shape[:2])
380
+ if self.mask_embeddings is None:
381
+ self.mask_embeddings = masks
382
+ mask_seq_len, bs = masks.shape[:2]
383
+ if labels is None:
384
+ self.mask_labels = torch.ones(
385
+ mask_seq_len, bs, device=masks.device, dtype=torch.long
386
+ )
387
+ else:
388
+ self.mask_labels = labels
389
+ if attn_mask is None:
390
+ self.mask_mask = torch.zeros(
391
+ bs, mask_seq_len, device=masks.device, dtype=torch.bool
392
+ )
393
+ else:
394
+ self.mask_mask = attn_mask
395
+ else:
396
+ raise NotImplementedError("Only one mask per prompt is supported.")
397
+
398
+ def clone(self):
399
+ return Prompt(
400
+ box_embeddings=(
401
+ None if self.box_embeddings is None else self.box_embeddings.clone()
402
+ ),
403
+ box_mask=None if self.box_mask is None else self.box_mask.clone(),
404
+ point_embeddings=(
405
+ None if self.point_embeddings is None else self.point_embeddings.clone()
406
+ ),
407
+ point_mask=None if self.point_mask is None else self.point_mask.clone(),
408
+ box_labels=None if self.box_labels is None else self.box_labels.clone(),
409
+ point_labels=(
410
+ None if self.point_labels is None else self.point_labels.clone()
411
+ ),
412
+ )
413
+
414
+
415
+ class MaskEncoder(nn.Module):
416
+ """
417
+ Base class for mask encoders.
418
+ """
419
+
420
+ def __init__(
421
+ self,
422
+ mask_downsampler: nn.Module,
423
+ position_encoding: nn.Module,
424
+ ):
425
+ super().__init__()
426
+ self.mask_downsampler = mask_downsampler
427
+ self.position_encoding = position_encoding
428
+
429
+ def forward(self, masks, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
430
+ masks = self.mask_downsampler(masks)
431
+ masks_pos = self.position_encoding(masks).to(masks.dtype)
432
+
433
+ return masks, masks_pos
434
+
435
+
436
+ class FusedMaskEncoder(MaskEncoder):
437
+ """
438
+ Identical to memory.SimpleMaskEncoder but follows the interface of geometry_encoders.MaskEncoder.
439
+ We also remove the `skip_mask_sigmoid` option (to be handled outside the MaskEncoder).
440
+ Fuses backbone image features with mask features.
441
+ """
442
+
443
+ def __init__(
444
+ self,
445
+ mask_downsampler: nn.Module,
446
+ position_encoding: nn.Module,
447
+ fuser: nn.Module,
448
+ in_dim: int = 256,
449
+ out_dim: int = 256,
450
+ ):
451
+ super().__init__(mask_downsampler, position_encoding)
452
+ self.fuser = fuser
453
+ self.out_proj = nn.Identity()
454
+ if out_dim != in_dim:
455
+ self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
456
+ self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
457
+
458
+ @override
459
+ def forward(
460
+ self,
461
+ masks: torch.Tensor,
462
+ pix_feat: torch.Tensor,
463
+ **kwargs,
464
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
465
+ masks = self.mask_downsampler(masks)
466
+
467
+ ## Fuse pix_feats and downsampled masks
468
+ # in case the visual features are on a different device (e.g. CPU), move them to the masks' device
469
+ pix_feat = pix_feat.to(masks.device)
470
+
471
+ x = self.pix_feat_proj(pix_feat)
472
+ x = x + masks
473
+ x = self.fuser(x)
474
+ x = self.out_proj(x)
475
+
476
+ pos = self.position_encoding(x).to(x.dtype)
477
+
478
+ return x, pos
479
+
480
+
481
+ class SequenceGeometryEncoder(nn.Module):
482
+ """
483
+ This is a fully fledged encoder for geometric prompts.
484
+ It assumes boxes are passed in the "normalized CxCyWH" format, and points in normalized xy.
485
+ This allows flexibility in how to encode the features (e.g. pooling).
486
+
487
+ Points and boxes can be encoded with any of the three possibilities:
488
+ - direct projection: we just apply a linear projection from coordinate space to d_model
489
+ - pooling: pool features from the backbone in the requested location.
490
+ For boxes, it's a roi align
491
+ For points it's a grid sample
492
+ - pos encoder: Take the position encoding of the point or box center
493
+
494
+ These three options are mutually compatible. If several are selected, their embeddings are simply summed.
495
+
496
+ As an alternative, we offer the possibility to encode points only.
497
+ In that case, the boxes are converted to two points for the top left and bottom right corners (with appropriate labels)
498
+
499
+ On top of these encodings, we offer the possibility to further encode the prompt sequence with a transformer.
500
+ """
501
+
502
+ def __init__(
503
+ self,
504
+ encode_boxes_as_points: bool,
505
+ points_direct_project: bool,
506
+ points_pool: bool,
507
+ points_pos_enc: bool,
508
+ boxes_direct_project: bool,
509
+ boxes_pool: bool,
510
+ boxes_pos_enc: bool,
511
+ d_model: int,
512
+ pos_enc,
513
+ num_layers: int,
514
+ layer: nn.Module,
515
+ roi_size: int = 7, # for boxes pool
516
+ add_cls: bool = True,
517
+ add_post_encode_proj: bool = True,
518
+ mask_encoder: MaskEncoder = None,
519
+ add_mask_label: bool = False,
520
+ use_act_ckpt: bool = False,
521
+ ):
522
+ super().__init__()
523
+
524
+ self.d_model = d_model
525
+ self.pos_enc = pos_enc
526
+ self.encode_boxes_as_points = encode_boxes_as_points
527
+ self.roi_size = roi_size
528
+ # There are usually two labels: positive and negative.
529
+ # If we encode boxes as points, we have 3 types of points: regular, top left, bottom right
530
+ # These 3 types can be positives or negatives, hence 2*3 = 6 labels
531
+ num_labels = 6 if self.encode_boxes_as_points else 2
532
+ self.label_embed = torch.nn.Embedding(num_labels, self.d_model)
533
+
534
+ # This is a cls token, can be used for pooling if need be.
535
+ # It also ensures that the encoded sequences are always non-empty
536
+ self.cls_embed = None
537
+ if add_cls:
538
+ self.cls_embed = torch.nn.Embedding(1, self.d_model)
539
+
540
+ assert (
541
+ points_direct_project or points_pos_enc or points_pool
542
+ ), "Error: need at least one way to encode points"
543
+ assert (
544
+ encode_boxes_as_points
545
+ or boxes_direct_project
546
+ or boxes_pos_enc
547
+ or boxes_pool
548
+ ), "Error: need at least one way to encode boxes"
549
+
550
+ self.points_direct_project = None
551
+ if points_direct_project:
552
+ self.points_direct_project = nn.Linear(2, self.d_model)
553
+ self.points_pool_project = None
554
+ if points_pool:
555
+ self.points_pool_project = nn.Linear(self.d_model, self.d_model)
556
+ self.points_pos_enc_project = None
557
+ if points_pos_enc:
558
+ self.points_pos_enc_project = nn.Linear(self.d_model, self.d_model)
559
+
560
+ self.boxes_direct_project = None
561
+ self.boxes_pool_project = None
562
+ self.boxes_pos_enc_project = None
563
+ if not encode_boxes_as_points:
564
+ if boxes_direct_project:
565
+ self.boxes_direct_project = nn.Linear(4, self.d_model)
566
+ if boxes_pool:
567
+ self.boxes_pool_project = nn.Conv2d(
568
+ self.d_model, self.d_model, self.roi_size
569
+ )
570
+ if boxes_pos_enc:
571
+ self.boxes_pos_enc_project = nn.Linear(self.d_model + 2, self.d_model)
572
+
573
+ self.final_proj = None
574
+ if add_post_encode_proj:
575
+ self.final_proj = nn.Linear(self.d_model, self.d_model)
576
+ self.norm = nn.LayerNorm(self.d_model)
577
+
578
+ self.img_pre_norm = nn.Identity()
579
+ if self.points_pool_project is not None or self.boxes_pool_project is not None:
580
+ self.img_pre_norm = nn.LayerNorm(self.d_model)
581
+
582
+ self.encode = None
583
+ if num_layers > 0:
584
+ assert (
585
+ add_cls
586
+ ), "It's currently highly recommended to add a CLS when using a transformer"
587
+ self.encode = get_clones(layer, num_layers)
588
+ self.encode_norm = nn.LayerNorm(self.d_model)
589
+
590
+ if mask_encoder is not None:
591
+ assert isinstance(
592
+ mask_encoder, MaskEncoder
593
+ ), f"Expected mask_encoder of type MaskEncoder. Got {type(mask_encoder)}."
594
+ if add_mask_label:
595
+ self.mask_label_embed = torch.nn.Embedding(2, self.d_model)
596
+ self.add_mask_label = add_mask_label
597
+ self.mask_encoder = mask_encoder
598
+ self.use_act_ckpt = use_act_ckpt
599
+
600
+ def _encode_points(self, points, points_mask, points_labels, img_feats):
601
+ points_embed = None
602
+ n_points, bs = points.shape[:2]
603
+
604
+ if self.points_direct_project is not None:
605
+ proj = self.points_direct_project(points)
606
+ assert points_embed is None
607
+ points_embed = proj
608
+
609
+ if self.points_pool_project is not None:
610
+ # points are [Num_points, bs, 2], normalized in [0, 1]
611
+ # the grid needs to be [Bs, H_out, W_out, 2] normalized in [-1,1]
612
+ # Will take H_out = num_points, w_out = 1
613
+ grid = points.transpose(0, 1).unsqueeze(2)
614
+ # re normalize to [-1, 1]
615
+ grid = (grid * 2) - 1
616
+ sampled = torch.nn.functional.grid_sample(
617
+ img_feats, grid, align_corners=False
618
+ )
619
+ assert list(sampled.shape) == [bs, self.d_model, n_points, 1]
620
+ sampled = sampled.squeeze(-1).permute(2, 0, 1)
621
+ proj = self.points_pool_project(sampled)
622
+ if points_embed is None:
623
+ points_embed = proj
624
+ else:
625
+ points_embed = points_embed + proj
626
+
627
+ if self.points_pos_enc_project is not None:
628
+ x, y = points.unbind(-1)
629
+ enc_x, enc_y = self.pos_enc._encode_xy(x.flatten(), y.flatten())
630
+ enc_x = enc_x.view(n_points, bs, enc_x.shape[-1])
631
+ enc_y = enc_y.view(n_points, bs, enc_y.shape[-1])
632
+ enc = torch.cat([enc_x, enc_y], -1)
633
+
634
+ proj = self.points_pos_enc_project(enc)
635
+ if points_embed is None:
636
+ points_embed = proj
637
+ else:
638
+ points_embed = points_embed + proj
639
+
640
+ type_embed = self.label_embed(points_labels.long())
641
+ return type_embed + points_embed, points_mask
642
+
643
+ def _encode_boxes(self, boxes, boxes_mask, boxes_labels, img_feats):
644
+ boxes_embed = None
645
+ n_boxes, bs = boxes.shape[:2]
646
+
647
+ if self.boxes_direct_project is not None:
648
+ proj = self.boxes_direct_project(boxes)
649
+ assert boxes_embed is None
650
+ boxes_embed = proj
651
+
652
+ if self.boxes_pool_project is not None:
653
+ H, W = img_feats.shape[-2:]
654
+
655
+ # boxes are [Num_boxes, bs, 4], normalized in [0, 1]
656
+ # We need to denormalize, and convert to [x, y, x, y]
657
+ boxes_xyxy = box_cxcywh_to_xyxy(boxes)
658
+ scale = torch.tensor([W, H, W, H], dtype=boxes_xyxy.dtype)
659
+ scale = scale.pin_memory().to(device=boxes_xyxy.device, non_blocking=True)
660
+ scale = scale.view(1, 1, 4)
661
+ boxes_xyxy = boxes_xyxy * scale
662
+ sampled = torchvision.ops.roi_align(
663
+ img_feats, boxes_xyxy.float().transpose(0, 1).unbind(0), self.roi_size
664
+ )
665
+ assert list(sampled.shape) == [
666
+ bs * n_boxes,
667
+ self.d_model,
668
+ self.roi_size,
669
+ self.roi_size,
670
+ ]
671
+ proj = self.boxes_pool_project(sampled)
672
+ proj = proj.view(bs, n_boxes, self.d_model).transpose(0, 1)
673
+ if boxes_embed is None:
674
+ boxes_embed = proj
675
+ else:
676
+ boxes_embed = boxes_embed + proj
677
+
678
+ if self.boxes_pos_enc_project is not None:
679
+ cx, cy, w, h = boxes.unbind(-1)
680
+ enc = self.pos_enc.encode_boxes(
681
+ cx.flatten(), cy.flatten(), w.flatten(), h.flatten()
682
+ )
683
+ enc = enc.view(boxes.shape[0], boxes.shape[1], enc.shape[-1])
684
+
685
+ proj = self.boxes_pos_enc_project(enc)
686
+ if boxes_embed is None:
687
+ boxes_embed = proj
688
+ else:
689
+ boxes_embed = boxes_embed + proj
690
+
691
+ type_embed = self.label_embed(boxes_labels.long())
692
+ return type_embed + boxes_embed, boxes_mask
693
+
694
+ def _encode_masks(
695
+ self,
696
+ masks: torch.Tensor,
697
+ attn_mask: torch.Tensor,
698
+ mask_labels: torch.Tensor,
699
+ img_feats: torch.Tensor = None,
700
+ ):
701
+ n_masks, bs = masks.shape[:2]
702
+ assert (
703
+ n_masks == 1
704
+ ), "We assume one mask per prompt for now. Code should still be functional if this assertion is removed."
705
+ assert (
706
+ list(attn_mask.shape)
707
+ == [
708
+ bs,
709
+ n_masks,
710
+ ]
711
+ ), f"Expected attn_mask to be of shape {bs}x{n_masks}. Got {list(attn_mask.shape)}."
712
+ masks, pos = self.mask_encoder(
713
+ masks=masks.flatten(0, 1).float(),
714
+ pix_feat=img_feats,
715
+ )
716
+ H, W = masks.shape[-2:]
717
+ n_tokens_per_mask = H * W
718
+ # NOTE: We directly add pos enc here as we usually don't keep track of pos encoding for the concatenated prompt (text, other geometric prompts). Might need to do some refactoring for more flexibility.
719
+ masks = masks + pos
720
+ masks = masks.view(n_masks, bs, *masks.shape[1:]).flatten(
721
+ -2
722
+ ) # n_masks x bs x C x H*W
723
+ masks = masks.permute(0, 3, 1, 2).flatten(0, 1) # n_masks * H*W x bs x C
724
+ attn_mask = attn_mask.repeat_interleave(n_tokens_per_mask, dim=1)
725
+ if self.add_mask_label:
726
+ masks = masks + self.mask_label_embed(mask_labels.long())
727
+ return masks, attn_mask
728
+
729
+ def forward(self, geo_prompt: Prompt, img_feats, img_sizes, img_pos_embeds=None):
730
+ points = geo_prompt.point_embeddings
731
+ points_mask = geo_prompt.point_mask
732
+ points_labels = geo_prompt.point_labels
733
+ boxes = geo_prompt.box_embeddings
734
+ boxes_mask = geo_prompt.box_mask
735
+ boxes_labels = geo_prompt.box_labels
736
+ masks = geo_prompt.mask_embeddings
737
+ masks_mask = geo_prompt.mask_mask
738
+ masks_labels = geo_prompt.mask_labels
739
+ seq_first_img_feats = img_feats[-1] # [H*W, B, C]
740
+ seq_first_img_pos_embeds = (
741
+ img_pos_embeds[-1]
742
+ if img_pos_embeds is not None
743
+ else torch.zeros_like(seq_first_img_feats)
744
+ )
745
+
746
+ if self.points_pool_project or self.boxes_pool_project:
747
+ assert len(img_feats) == len(img_sizes)
748
+ cur_img_feat = img_feats[-1]
749
+ cur_img_feat = self.img_pre_norm(cur_img_feat)
750
+ H, W = img_sizes[-1]
751
+ assert cur_img_feat.shape[0] == H * W
752
+ N, C = cur_img_feat.shape[-2:]
753
+ # Put back in NxCxHxW
754
+ cur_img_feat = cur_img_feat.permute(1, 2, 0)
755
+ cur_img_feat = cur_img_feat.view(N, C, H, W)
756
+ img_feats = cur_img_feat
757
+
758
+ if self.encode_boxes_as_points:
759
+ assert boxes is not None
760
+ assert geo_prompt.box_mask is not None
761
+ assert geo_prompt.box_labels is not None
762
+ assert boxes.shape[-1] == 4
763
+
764
+ boxes_xyxy = box_cxcywh_to_xyxy(boxes)
765
+ top_left, bottom_right = boxes_xyxy.split(split_size=2, dim=-1)
766
+
767
+ labels_tl = geo_prompt.box_labels + 2
768
+ labels_br = geo_prompt.box_labels + 4
769
+
770
+ # Append to the existing points
771
+ points, _ = concat_padded_sequences(
772
+ points, points_mask, top_left, boxes_mask
773
+ )
774
+ points_labels, points_mask = concat_padded_sequences(
775
+ points_labels.unsqueeze(-1),
776
+ points_mask,
777
+ labels_tl.unsqueeze(-1),
778
+ boxes_mask,
779
+ )
780
+ points_labels = points_labels.squeeze(-1)
781
+
782
+ points, _ = concat_padded_sequences(
783
+ points, points_mask, bottom_right, boxes_mask
784
+ )
785
+ points_labels, points_mask = concat_padded_sequences(
786
+ points_labels.unsqueeze(-1),
787
+ points_mask,
788
+ labels_br.unsqueeze(-1),
789
+ boxes_mask,
790
+ )
791
+ points_labels = points_labels.squeeze(-1)
792
+
793
+ final_embeds, final_mask = self._encode_points(
794
+ points=points,
795
+ points_mask=points_mask,
796
+ points_labels=points_labels,
797
+ img_feats=img_feats,
798
+ )
799
+
800
+ if not self.encode_boxes_as_points:
801
+ boxes_embeds, boxes_mask = self._encode_boxes(
802
+ boxes=boxes,
803
+ boxes_mask=boxes_mask,
804
+ boxes_labels=boxes_labels,
805
+ img_feats=img_feats,
806
+ )
807
+
808
+ final_embeds, final_mask = concat_padded_sequences(
809
+ final_embeds, final_mask, boxes_embeds, boxes_mask
810
+ )
811
+
812
+ if masks is not None and self.mask_encoder is not None:
813
+ masks_embed, masks_mask = self._encode_masks(
814
+ masks=masks,
815
+ attn_mask=masks_mask,
816
+ mask_labels=masks_labels,
817
+ img_feats=img_feats,
818
+ )
819
+ if points.size(0) == boxes.size(0) == 0:
820
+ return masks_embed, masks_mask
821
+ bs = final_embeds.shape[1]
822
+ assert final_mask.shape[0] == bs
823
+ if self.cls_embed is not None:
824
+ cls = self.cls_embed.weight.view(1, 1, self.d_model).repeat(1, bs, 1)
825
+ cls_mask = torch.zeros(
826
+ bs, 1, dtype=final_mask.dtype, device=final_mask.device
827
+ )
828
+ final_embeds, final_mask = concat_padded_sequences(
829
+ final_embeds, final_mask, cls, cls_mask
830
+ )
831
+
832
+ if self.final_proj is not None:
833
+ final_embeds = self.norm(self.final_proj(final_embeds))
834
+
835
+ if self.encode is not None:
836
+ for lay in self.encode:
837
+ final_embeds = activation_ckpt_wrapper(lay)(
838
+ tgt=final_embeds,
839
+ memory=seq_first_img_feats,
840
+ tgt_key_padding_mask=final_mask,
841
+ pos=seq_first_img_pos_embeds,
842
+ act_ckpt_enable=self.training and self.use_act_ckpt,
843
+ )
844
+ final_embeds = self.encode_norm(final_embeds)
845
+ # Finally, concat mask embeddings if any
846
+ if masks is not None and self.mask_encoder is not None:
847
+ final_embeds, final_mask = concat_padded_sequences(
848
+ final_embeds, final_mask, masks_embed, masks_mask
849
+ )
850
+ return final_embeds, final_mask
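For reference, a small standalone sketch of the box-to-corner-points conversion used when `encode_boxes_as_points=True` (the +2/+4 label offsets mirror the logic in `forward` above); the helper name is hypothetical and not part of the module:

```python
import torch

def boxes_to_corner_points(boxes_cxcywh: torch.Tensor, box_labels: torch.Tensor):
    """boxes_cxcywh: (N, B, 4) normalized cx, cy, w, h; box_labels: (N, B) in {0, 1}."""
    cx, cy, w, h = boxes_cxcywh.unbind(-1)
    top_left = torch.stack([cx - 0.5 * w, cy - 0.5 * h], dim=-1)      # (N, B, 2)
    bottom_right = torch.stack([cx + 0.5 * w, cy + 0.5 * h], dim=-1)  # (N, B, 2)
    # label ids: 0/1 regular points, 2/3 top-left corners, 4/5 bottom-right corners
    labels_tl = box_labels + 2
    labels_br = box_labels + 4
    points = torch.cat([top_left, bottom_right], dim=0)               # (2N, B, 2)
    labels = torch.cat([labels_tl, labels_br], dim=0)                 # (2N, B)
    return points, labels
```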
detect_tools/sam3/sam3/model/io_utils.py ADDED
@@ -0,0 +1,709 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ import contextlib
4
+ import os
5
+ import queue
6
+ import re
7
+ import time
8
+ from threading import Condition, get_ident, Lock, Thread
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn.functional as F
13
+ import torchvision.transforms.functional as TF
14
+
15
+ from PIL import Image
16
+
17
+ from sam3.logger import get_logger
18
+ from tqdm import tqdm
19
+
20
+ logger = get_logger(__name__)
21
+
22
+ IS_MAIN_PROCESS = os.getenv("IS_MAIN_PROCESS", "1") == "1"
23
+ RANK = int(os.getenv("RANK", "0"))
24
+
25
+ IMAGE_EXTS = [".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"]
26
+ VIDEO_EXTS = [".mp4", ".mov", ".avi", ".mkv", ".webm"]
27
+
28
+
29
+ def load_resource_as_video_frames(
30
+ resource_path,
31
+ image_size,
32
+ offload_video_to_cpu,
33
+ img_mean=(0.5, 0.5, 0.5),
34
+ img_std=(0.5, 0.5, 0.5),
35
+ async_loading_frames=False,
36
+ video_loader_type="cv2",
37
+ ):
38
+ """
39
+ Load video frames from either a video or an image (as a single-frame video).
40
+ Alternatively, if input is a list of PIL images, convert its format
41
+ """
42
+ if isinstance(resource_path, list):
43
+ img_mean = torch.tensor(img_mean, dtype=torch.float16)[:, None, None]
44
+ img_std = torch.tensor(img_std, dtype=torch.float16)[:, None, None]
45
+ assert all(isinstance(img_pil, Image.Image) for img_pil in resource_path)
46
+ assert len(resource_path) > 0
47
+ orig_height, orig_width = resource_path[0].size
48
+ orig_height, orig_width = (
49
+ orig_width,
50
+ orig_height,
51
+ ) # PIL's Image.size is (width, height), so swap to (height, width)
52
+ images = []
53
+ for img_pil in resource_path:
54
+ img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size)))
55
+ assert img_np.dtype == np.uint8, "np.uint8 is expected for JPEG images"
56
+ img_np = img_np / 255.0
57
+ img = torch.from_numpy(img_np).permute(2, 0, 1)
58
+ # float16 precision should be sufficient for image tensor storage
59
+ img = img.to(dtype=torch.float16)
60
+ # normalize by mean and std
61
+ img -= img_mean
62
+ img /= img_std
63
+ images.append(img)
64
+ images = torch.stack(images)
65
+ if not offload_video_to_cpu:
66
+ images = images.cuda()
67
+ return images, orig_height, orig_width
68
+
69
+ is_image = (
70
+ isinstance(resource_path, str)
71
+ and os.path.splitext(resource_path)[-1].lower() in IMAGE_EXTS
72
+ )
73
+ if is_image:
74
+ return load_image_as_single_frame_video(
75
+ image_path=resource_path,
76
+ image_size=image_size,
77
+ offload_video_to_cpu=offload_video_to_cpu,
78
+ img_mean=img_mean,
79
+ img_std=img_std,
80
+ )
81
+ else:
82
+ return load_video_frames(
83
+ video_path=resource_path,
84
+ image_size=image_size,
85
+ offload_video_to_cpu=offload_video_to_cpu,
86
+ img_mean=img_mean,
87
+ img_std=img_std,
88
+ async_loading_frames=async_loading_frames,
89
+ video_loader_type=video_loader_type,
90
+ )
91
+
92
+
93
+ def load_image_as_single_frame_video(
94
+ image_path,
95
+ image_size,
96
+ offload_video_to_cpu,
97
+ img_mean=(0.5, 0.5, 0.5),
98
+ img_std=(0.5, 0.5, 0.5),
99
+ ):
100
+ """Load an image as a single-frame video."""
101
+ images, image_height, image_width = _load_img_as_tensor(image_path, image_size)
102
+ images = images.unsqueeze(0).half()
103
+
104
+ img_mean = torch.tensor(img_mean, dtype=torch.float16)[:, None, None]
105
+ img_std = torch.tensor(img_std, dtype=torch.float16)[:, None, None]
106
+ if not offload_video_to_cpu:
107
+ images = images.cuda()
108
+ img_mean = img_mean.cuda()
109
+ img_std = img_std.cuda()
110
+ # normalize by mean and std
111
+ images -= img_mean
112
+ images /= img_std
113
+ return images, image_height, image_width
114
+
115
+
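To inspect frames produced by these loaders visually, the mean/std normalization can be inverted; a minimal sketch assuming the default mean and std of 0.5 (the helper is hypothetical, not part of the module):

```python
import torch

def denormalize_frame(frame: torch.Tensor,
                      img_mean=(0.5, 0.5, 0.5),
                      img_std=(0.5, 0.5, 0.5)) -> torch.Tensor:
    """frame: (3, H, W) normalized float tensor; returns a uint8 (H, W, 3) image."""
    mean = torch.tensor(img_mean, dtype=frame.dtype, device=frame.device)[:, None, None]
    std = torch.tensor(img_std, dtype=frame.dtype, device=frame.device)[:, None, None]
    img = frame * std + mean                      # back to roughly [0, 1]
    img = (img.clamp(0, 1) * 255).to(torch.uint8)
    return img.permute(1, 2, 0).cpu()             # (H, W, 3), ready for PIL/matplotlib
```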
116
+ def load_video_frames(
117
+ video_path,
118
+ image_size,
119
+ offload_video_to_cpu,
120
+ img_mean=(0.5, 0.5, 0.5),
121
+ img_std=(0.5, 0.5, 0.5),
122
+ async_loading_frames=False,
123
+ video_loader_type="cv2",
124
+ ):
125
+ """
126
+ Load the video frames from video_path. The frames are resized to image_size as in
127
+ the model and are loaded to GPU if offload_video_to_cpu=False. This is used by the demo.
128
+ """
129
+ assert isinstance(video_path, str)
130
+ if video_path.startswith("<load-dummy-video"):
131
+ # Check for pattern <load-dummy-video-N> where N is an integer
132
+ match = re.match(r"<load-dummy-video-(\d+)>", video_path)
133
+ num_frames = int(match.group(1)) if match else 60
134
+ return load_dummy_video(image_size, offload_video_to_cpu, num_frames=num_frames)
135
+ elif os.path.isdir(video_path):
136
+ return load_video_frames_from_image_folder(
137
+ image_folder=video_path,
138
+ image_size=image_size,
139
+ offload_video_to_cpu=offload_video_to_cpu,
140
+ img_mean=img_mean,
141
+ img_std=img_std,
142
+ async_loading_frames=async_loading_frames,
143
+ )
144
+ elif os.path.splitext(video_path)[-1].lower() in VIDEO_EXTS:
145
+ return load_video_frames_from_video_file(
146
+ video_path=video_path,
147
+ image_size=image_size,
148
+ offload_video_to_cpu=offload_video_to_cpu,
149
+ img_mean=img_mean,
150
+ img_std=img_std,
151
+ async_loading_frames=async_loading_frames,
152
+ video_loader_type=video_loader_type,
153
+ )
154
+ else:
155
+ raise NotImplementedError("Only video files and image folders are supported")
156
+
157
+
158
+ def load_video_frames_from_image_folder(
159
+ image_folder,
160
+ image_size,
161
+ offload_video_to_cpu,
162
+ img_mean,
163
+ img_std,
164
+ async_loading_frames,
165
+ ):
166
+ """
167
+ Load the video frames from a directory of image files ("<frame_index>.<img_ext>" format)
168
+ """
169
+ frame_names = [
170
+ p
171
+ for p in os.listdir(image_folder)
172
+ if os.path.splitext(p)[-1].lower() in IMAGE_EXTS
173
+ ]
174
+ try:
175
+ frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
176
+ except ValueError:
177
+ # fallback to lexicographic sort if the format is not "<frame_index>.<img_ext>"
178
+ logger.warning(
179
+ f'frame names are not in "<frame_index>.<img_ext>" format: {frame_names[:5]=}, '
180
+ f"falling back to lexicographic sort."
181
+ )
182
+ frame_names.sort()
183
+ num_frames = len(frame_names)
184
+ if num_frames == 0:
185
+ raise RuntimeError(f"no images found in {image_folder}")
186
+ img_paths = [os.path.join(image_folder, frame_name) for frame_name in frame_names]
187
+ img_mean = torch.tensor(img_mean, dtype=torch.float16)[:, None, None]
188
+ img_std = torch.tensor(img_std, dtype=torch.float16)[:, None, None]
189
+
190
+ if async_loading_frames:
191
+ lazy_images = AsyncImageFrameLoader(
192
+ img_paths, image_size, offload_video_to_cpu, img_mean, img_std
193
+ )
194
+ return lazy_images, lazy_images.video_height, lazy_images.video_width
195
+
196
+ # float16 precision should be sufficient for image tensor storage
197
+ images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float16)
198
+ video_height, video_width = None, None
199
+ for n, img_path in enumerate(
200
+ tqdm(img_paths, desc=f"frame loading (image folder) [rank={RANK}]")
201
+ ):
202
+ images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
203
+ if not offload_video_to_cpu:
204
+ images = images.cuda()
205
+ img_mean = img_mean.cuda()
206
+ img_std = img_std.cuda()
207
+ # normalize by mean and std
208
+ images -= img_mean
209
+ images /= img_std
210
+ return images, video_height, video_width
211
+
212
+
213
+ def load_video_frames_from_video_file(
214
+ video_path,
215
+ image_size,
216
+ offload_video_to_cpu,
217
+ img_mean,
218
+ img_std,
219
+ async_loading_frames,
220
+ gpu_acceleration=False,
221
+ gpu_device=None,
222
+ video_loader_type="cv2",
223
+ ):
224
+ """Load the video frames from a video file."""
225
+ if video_loader_type == "cv2":
226
+ return load_video_frames_from_video_file_using_cv2(
227
+ video_path=video_path,
228
+ image_size=image_size,
229
+ img_mean=img_mean,
230
+ img_std=img_std,
231
+ offload_video_to_cpu=offload_video_to_cpu,
232
+ )
233
+ elif video_loader_type == "torchcodec":
234
+ logger.info("Using torchcodec to load video file")
235
+ lazy_images = AsyncVideoFileLoaderWithTorchCodec(
236
+ video_path=video_path,
237
+ image_size=image_size,
238
+ offload_video_to_cpu=offload_video_to_cpu,
239
+ img_mean=img_mean,
240
+ img_std=img_std,
241
+ gpu_acceleration=gpu_acceleration,
242
+ gpu_device=gpu_device,
243
+ )
244
+ # The `AsyncVideoFileLoaderWithTorchCodec` class always loads the videos asynchronously,
245
+ # so we just wait for its loading thread to finish if async_loading_frames=False.
246
+ if not async_loading_frames:
247
+ async_thread = lazy_images.thread
248
+ if async_thread is not None:
249
+ async_thread.join()
250
+ return lazy_images, lazy_images.video_height, lazy_images.video_width
251
+ else:
252
+ raise RuntimeError("video_loader_type must be either 'cv2' or 'torchcodec'")
253
+
254
+
255
+ def load_video_frames_from_video_file_using_cv2(
256
+ video_path: str,
257
+ image_size: int,
258
+ img_mean: tuple = (0.5, 0.5, 0.5),
259
+ img_std: tuple = (0.5, 0.5, 0.5),
260
+ offload_video_to_cpu: bool = False,
261
+ ) -> torch.Tensor:
262
+ """
263
+ Load video from path, convert to normalized tensor with specified preprocessing
264
+
265
+ Args:
266
+ video_path: Path to video file
267
+ image_size: Target size for square frames (height and width)
268
+ img_mean: Normalization mean (RGB)
269
+ img_std: Normalization standard deviation (RGB)
270
+
271
+ Returns:
272
+ torch.Tensor: Preprocessed video tensor in shape (T, C, H, W) with float16 dtype
273
+ """
274
+ import cv2 # delay OpenCV import to avoid unnecessary dependency
275
+
276
+ # Initialize video capture
277
+ cap = cv2.VideoCapture(video_path)
278
+ if not cap.isOpened():
279
+ raise ValueError(f"Could not open video: {video_path}")
280
+
281
+ original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
282
+ original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
283
+ num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
284
+ num_frames = num_frames if num_frames > 0 else None
285
+
286
+ frames = []
287
+ pbar = tqdm(desc=f"frame loading (OpenCV) [rank={RANK}]", total=num_frames)
288
+ while True:
289
+ ret, frame = cap.read()
290
+ if not ret:
291
+ break
292
+
293
+ # Convert BGR to RGB and resize
294
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
295
+ frame_resized = cv2.resize(
296
+ frame_rgb, (image_size, image_size), interpolation=cv2.INTER_CUBIC
297
+ )
298
+ frames.append(frame_resized)
299
+ pbar.update(1)
300
+ cap.release()
301
+ pbar.close()
302
+
303
+ # Convert to tensor
304
+ frames_np = np.stack(frames, axis=0).astype(np.float32) # (T, H, W, C)
305
+ video_tensor = torch.from_numpy(frames_np).permute(0, 3, 1, 2) # (T, C, H, W)
306
+
307
+ img_mean = torch.tensor(img_mean, dtype=torch.float16).view(1, 3, 1, 1)
308
+ img_std = torch.tensor(img_std, dtype=torch.float16).view(1, 3, 1, 1)
309
+ if not offload_video_to_cpu:
310
+ video_tensor = video_tensor.cuda()
311
+ img_mean = img_mean.cuda()
312
+ img_std = img_std.cuda()
313
+ # normalize by mean and std
314
+ video_tensor -= img_mean
315
+ video_tensor /= img_std
316
+ return video_tensor, original_height, original_width
317
+
318
+
319
+ def load_dummy_video(image_size, offload_video_to_cpu, num_frames=60):
320
+ """
321
+ Load a dummy video with random frames for testing and compilation warmup purposes.
322
+ """
323
+ video_height, video_width = 480, 640 # dummy original video sizes
324
+ images = torch.randn(num_frames, 3, image_size, image_size, dtype=torch.float16)
325
+ if not offload_video_to_cpu:
326
+ images = images.cuda()
327
+ return images, video_height, video_width
328
+
329
+
330
+ def _load_img_as_tensor(img_path, image_size):
331
+ """Load and resize an image and convert it into a PyTorch tensor."""
332
+ img = Image.open(img_path).convert("RGB")
333
+ orig_width, orig_height = img.width, img.height
334
+ img = TF.resize(img, size=(image_size, image_size))
335
+ img = TF.to_tensor(img)
336
+ return img, orig_height, orig_width
337
+
338
+
339
+ class AsyncImageFrameLoader:
340
+ """
341
+ A list of video frames to be loaded asynchronously without blocking session start.
342
+ """
343
+
344
+ def __init__(self, img_paths, image_size, offload_video_to_cpu, img_mean, img_std):
345
+ self.img_paths = img_paths
346
+ self.image_size = image_size
347
+ self.offload_video_to_cpu = offload_video_to_cpu
348
+ self.img_mean = img_mean
349
+ self.img_std = img_std
350
+ # items in `self.images` will be loaded asynchronously
351
+ self.images = [None] * len(img_paths)
352
+ # catch and raise any exceptions in the async loading thread
353
+ self.exception = None
354
+ # video_height and video_width will be filled when loading the first image
355
+ self.video_height = None
356
+ self.video_width = None
357
+
358
+ # load the first frame to fill video_height and video_width and also
359
+ # to cache it (since it's most likely where the user will click)
360
+ self.__getitem__(0)
361
+
362
+ # load the rest of frames asynchronously without blocking the session start
363
+ def _load_frames():
364
+ try:
365
+ for n in tqdm(
366
+ range(len(self.images)),
367
+ desc=f"frame loading (image folder) [rank={RANK}]",
368
+ ):
369
+ self.__getitem__(n)
370
+ except Exception as e:
371
+ self.exception = e
372
+
373
+ self.thread = Thread(target=_load_frames, daemon=True)
374
+ self.thread.start()
375
+
376
+ def __getitem__(self, index):
377
+ if self.exception is not None:
378
+ raise RuntimeError("Failure in frame loading thread") from self.exception
379
+
380
+ img = self.images[index]
381
+ if img is not None:
382
+ return img
383
+
384
+ img, video_height, video_width = _load_img_as_tensor(
385
+ self.img_paths[index], self.image_size
386
+ )
387
+ self.video_height = video_height
388
+ self.video_width = video_width
389
+ # float16 precision should be sufficient for image tensor storage
390
+ img = img.to(dtype=torch.float16)
391
+ # normalize by mean and std
392
+ img -= self.img_mean
393
+ img /= self.img_std
394
+ if not self.offload_video_to_cpu:
395
+ img = img.cuda()
396
+ self.images[index] = img
397
+ return img
398
+
399
+ def __len__(self):
400
+ return len(self.images)
401
+
402
+
403
+ class TorchCodecDecoder:
404
+ """
405
+ A wrapper to support GPU device and num_threads in TorchCodec decoder,
406
+ which are not supported by `torchcodec.decoders.SimpleVideoDecoder` yet.
407
+ """
408
+
409
+ def __init__(self, source, dimension_order="NCHW", device="cpu", num_threads=1):
410
+ from torchcodec import _core as core
411
+
412
+ self._source = source # hold a reference to the source to prevent it from GC
413
+ if isinstance(source, str):
414
+ self._decoder = core.create_from_file(source, "exact")
415
+ elif isinstance(source, bytes):
416
+ self._decoder = core.create_from_bytes(source, "exact")
417
+ else:
418
+ raise TypeError(f"Unknown source type: {type(source)}.")
419
+ assert dimension_order in ("NCHW", "NHWC")
420
+
421
+ device_string = str(device)
422
+ core.scan_all_streams_to_update_metadata(self._decoder)
423
+ core.add_video_stream(
424
+ self._decoder,
425
+ dimension_order=dimension_order,
426
+ device=device_string,
427
+ num_threads=(1 if "cuda" in device_string else num_threads),
428
+ )
429
+ video_metadata = core.get_container_metadata(self._decoder)
430
+ best_stream_index = video_metadata.best_video_stream_index
431
+ assert best_stream_index is not None
432
+ self.metadata = video_metadata.streams[best_stream_index]
433
+ assert self.metadata.num_frames_from_content is not None
434
+ self._num_frames = self.metadata.num_frames_from_content
435
+
436
+ def __len__(self) -> int:
437
+ return self._num_frames
438
+
439
+ def __getitem__(self, key: int):
440
+ from torchcodec import _core as core
441
+
442
+ if key < 0:
443
+ key += self._num_frames
444
+ if key >= self._num_frames or key < 0:
445
+ raise IndexError(
446
+ f"Index {key} is out of bounds; length is {self._num_frames}"
447
+ )
448
+ frame_data, *_ = core.get_frame_at_index(
449
+ self._decoder,
450
+ frame_index=key,
451
+ )
452
+ return frame_data
453
+
454
+
455
+ class FIFOLock:
456
+ """A lock that ensures FIFO ordering of lock acquisitions."""
457
+
458
+ def __init__(self):
459
+ self._lock = Lock()
460
+ self._waiters = queue.Queue()
461
+ self._condition = Condition()
462
+
463
+ def acquire(self):
464
+ ident = get_ident()
465
+ with self._condition:
466
+ self._waiters.put(ident)
467
+ while self._waiters.queue[0] != ident or not self._lock.acquire(
468
+ blocking=False
469
+ ):
470
+ self._condition.wait()
471
+ # got the lock and it's our turn
472
+
473
+ def release(self):
474
+ with self._condition:
475
+ self._lock.release()
476
+ self._waiters.get()
477
+ self._condition.notify_all()
478
+
479
+ def __enter__(self):
480
+ self.acquire()
481
+
482
+ def __exit__(self, t, v, tb):
483
+ self.release()
484
+
485
+
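A small illustration of the FIFO behaviour (hypothetical standalone snippet, not part of the module): threads should acquire a `FIFOLock` in the order in which they requested it, unlike a plain `threading.Lock`, which makes no ordering guarantee.

```python
import threading
import time

def demo_fifo_lock(lock, n_threads=4):
    """`lock` is expected to be a FIFOLock instance from this module."""
    order = []

    def worker(i):
        with lock:            # FIFOLock supports the context-manager protocol
            order.append(i)
            time.sleep(0.01)

    threads = []
    for i in range(n_threads):
        t = threading.Thread(target=worker, args=(i,))
        t.start()
        time.sleep(0.005)     # stagger the acquisition requests
        threads.append(t)
    for t in threads:
        t.join()
    # with FIFO ordering, `order` should come out as [0, 1, 2, ..., n_threads - 1]
    return order
```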
486
+ class AsyncVideoFileLoaderWithTorchCodec:
487
+ """
488
+ Loading frames from video files asynchronously without blocking session start.
489
+
490
+ Unlike `AsyncVideoFileLoader`, this class uses PyTorch's official TorchCodec library
491
+ for video decoding, which is more efficient and supports more video formats.
492
+ """
493
+
494
+ def __init__(
495
+ self,
496
+ video_path,
497
+ image_size,
498
+ offload_video_to_cpu,
499
+ img_mean,
500
+ img_std,
501
+ gpu_acceleration=True,
502
+ gpu_device=None,
503
+ use_rand_seek_in_loading=False,
504
+ ):
505
+ # Check and possibly infer the output device (and also get its GPU id when applicable)
506
+ assert gpu_device is None or gpu_device.type == "cuda"
507
+ gpu_id = (
508
+ gpu_device.index
509
+ if gpu_device is not None and gpu_device.index is not None
510
+ else torch.cuda.current_device()
511
+ )
512
+ if offload_video_to_cpu:
513
+ out_device = torch.device("cpu")
514
+ else:
515
+ out_device = torch.device("cuda") if gpu_device is None else gpu_device
516
+ self.out_device = out_device
517
+ self.gpu_acceleration = gpu_acceleration
518
+ self.gpu_id = gpu_id
519
+ self.image_size = image_size
520
+ self.offload_video_to_cpu = offload_video_to_cpu
521
+ if not isinstance(img_mean, torch.Tensor):
522
+ img_mean = torch.tensor(img_mean, dtype=torch.float16)[:, None, None]
523
+ self.img_mean = img_mean
524
+ if not isinstance(img_std, torch.Tensor):
525
+ img_std = torch.tensor(img_std, dtype=torch.float16)[:, None, None]
526
+ self.img_std = img_std
527
+
528
+ if gpu_acceleration:
529
+ self.img_mean = self.img_mean.to(f"cuda:{self.gpu_id}")
530
+ self.img_std = self.img_std.to(f"cuda:{self.gpu_id}")
531
+ decoder_option = {"device": f"cuda:{self.gpu_id}"}
532
+ else:
533
+ self.img_mean = self.img_mean.cpu()
534
+ self.img_std = self.img_std.cpu()
535
+ decoder_option = {"num_threads": 1} # use a single thread to save memory
536
+
537
+ self.rank = int(os.environ.get("RANK", "0"))
538
+ self.world_size = int(os.environ.get("WORLD_SIZE", "1"))
539
+ self.async_reader = TorchCodecDecoder(video_path, **decoder_option)
540
+
541
+ # `num_frames_from_content` is the true number of frames in the video content
542
+ # from the scan operation (rather than from the metadata, which could be wrong)
543
+ self.num_frames = self.async_reader.metadata.num_frames_from_content
544
+ self.video_height = self.async_reader.metadata.height
545
+ self.video_width = self.async_reader.metadata.width
546
+
547
+ # items in `self.images` will be loaded asynchronously
548
+ self.images_loaded = [False] * self.num_frames
549
+ self.images = torch.zeros(
550
+ self.num_frames,
551
+ 3,
552
+ self.image_size,
553
+ self.image_size,
554
+ dtype=torch.float16,
555
+ device=self.out_device,
556
+ )
557
+ # catch and raise any exceptions in the async loading thread
558
+ self.exception = None
559
+ self.use_rand_seek_in_loading = use_rand_seek_in_loading
560
+ self.rand_seek_idx_queue = queue.Queue()
561
+ # use a lock to avoid race condition between concurrent access to torchcodec
562
+ # libs (which are not thread-safe); the lock is replaced with a nullcontext
563
+ # when the video is fully loaded
564
+ self.torchcodec_access_lock = FIFOLock()
565
+ self._start_video_loading()
566
+
567
+ def _load_one_frame(self, idx):
568
+ frame_resized = self._transform_frame(self.async_reader[idx])
569
+ return frame_resized
570
+
571
+ @torch.inference_mode()
572
+ def _start_video_loading(self):
573
+ desc = f"frame loading (TorchCodec w/ {'GPU' if self.gpu_acceleration else 'CPU'}) [rank={RANK}]"
574
+ pbar = tqdm(desc=desc, total=self.num_frames)
575
+ self.num_loaded_frames = 0
576
+ # load the first frame synchronously to cache it before the session is opened
577
+ idx = self.num_loaded_frames
578
+ self.images[idx] = self._load_one_frame(idx)
579
+ self.images_loaded[idx] = True
580
+ self.num_loaded_frames += 1
581
+ pbar.update(n=1)
582
+ self.all_frames_loaded = self.num_loaded_frames == self.num_frames
583
+
584
+ # load the frames asynchronously without blocking the session start
585
+ def _load_frames():
586
+ finished = self.all_frames_loaded
587
+ chunk_size = 16
588
+ while not finished:
589
+ # asynchronously load `chunk_size` frames each time we acquire the lock
590
+ with self.torchcodec_access_lock, torch.inference_mode():
591
+ for _ in range(chunk_size):
592
+ try:
593
+ idx = self.num_loaded_frames
594
+ self.images[idx] = self._load_one_frame(idx)
595
+ self.images_loaded[idx] = True
596
+ self.num_loaded_frames += 1
597
+ pbar.update(n=1)
598
+ if self.num_loaded_frames >= self.num_frames:
599
+ finished = True
600
+ break
601
+ except Exception as e:
602
+ self.exception = e
603
+ raise
604
+
605
+ # also read the frame that is being randomly seeked to
606
+ while True:
607
+ try:
608
+ idx = self.rand_seek_idx_queue.get_nowait()
609
+ if not self.images_loaded[idx]:
610
+ self.images[idx] = self._load_one_frame(idx)
611
+ self.images_loaded[idx] = True
612
+ except queue.Empty:
613
+ break
614
+ except Exception as e:
615
+ self.exception = e
616
+ raise
617
+
618
+ # finished -- check whether we have loaded the total number of frames
619
+ if self.num_loaded_frames != self.num_frames:
620
+ raise RuntimeError(
621
+ f"There are {self.num_frames} frames in the video, but only "
622
+ f"{self.num_loaded_frames} frames can be loaded successfully."
623
+ )
624
+ else:
625
+ self.all_frames_loaded = True
626
+ pbar.close()
627
+ with self.torchcodec_access_lock:
628
+ import gc
629
+
630
+ # all frames have been loaded, so we can release the readers and free their memory
631
+ # also remove pbar and thread (which shouldn't be a part of session saving)
632
+ reader = self.async_reader
633
+ if reader is not None:
634
+ reader._source = None
635
+ self.async_reader = None
636
+ self.pbar = None
637
+ self.thread = None
638
+ self.rand_seek_idx_queue = None
639
+ gc.collect()
640
+ # remove the lock (replace it with nullcontext) when the video is fully loaded
641
+ self.torchcodec_access_lock = contextlib.nullcontext()
642
+
643
+ self.thread = Thread(target=_load_frames, daemon=True)
644
+ self.thread.start()
645
+
646
+ def _transform_frame(self, frame):
647
+ frame = frame.clone() # make a copy to avoid modifying the original frame bytes
648
+ frame = frame.float() # convert to float32 before interpolation
649
+ frame_resized = F.interpolate(
650
+ frame[None, :],
651
+ size=(self.image_size, self.image_size),
652
+ mode="bicubic",
653
+ align_corners=False,
654
+ )[0]
655
+ # float16 precision should be sufficient for image tensor storage
656
+ frame_resized = frame_resized.half() # uint8 -> float16
657
+ frame_resized /= 255
658
+ frame_resized -= self.img_mean
659
+ frame_resized /= self.img_std
660
+ if self.offload_video_to_cpu:
661
+ frame_resized = frame_resized.cpu()
662
+ elif frame_resized.device != self.out_device:
663
+ frame_resized = frame_resized.to(device=self.out_device, non_blocking=True)
664
+ return frame_resized
665
+
666
+ def __getitem__(self, index):
667
+ if self.exception is not None:
668
+ raise RuntimeError("Failure in frame loading thread") from self.exception
669
+
670
+ max_tries = 1200
671
+ for _ in range(max_tries):
672
+ # use a lock to avoid race condition between concurrent access to torchcodec
673
+ # libs (which are not thread-safe); the lock is replaced with a nullcontext
674
+ # when the video is fully loaded
675
+ with self.torchcodec_access_lock:
676
+ if self.images_loaded[index]:
677
+ return self.images[index]
678
+
679
+ if self.use_rand_seek_in_loading:
680
+ # async loading hasn't reached this frame yet, so we load this frame individually
681
+ # (it will be loaded by in _load_frames thread and added to self.images[index])
682
+ self.rand_seek_idx_queue.put(index)
683
+
684
+ time.sleep(0.1)
685
+
686
+ raise RuntimeError(f"Failed to load frame {index} after {max_tries} tries")
687
+
688
+ def __len__(self):
689
+ return len(self.images)
690
+
691
+ def __getstate__(self):
692
+ """
693
+ Remove a few attributes during pickling, so that this async video loader can be
694
+ saved and loaded as a part of the model session.
695
+ """
696
+ # wait for async video loading to finish before pickling
697
+ async_thread = self.thread
698
+ if async_thread is not None:
699
+ async_thread.join()
700
+ # release a few objects that cannot be pickled
701
+ reader = self.async_reader
702
+ if reader is not None:
703
+ reader._source = None
704
+ self.async_reader = None
705
+ self.pbar = None
706
+ self.thread = None
707
+ self.rand_seek_idx_queue = None
708
+ self.torchcodec_access_lock = contextlib.nullcontext()
709
+ return self.__dict__.copy()
detect_tools/sam3/sam3/model/maskformer_segmentation.py ADDED
@@ -0,0 +1,323 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ import math
4
+ from typing import Dict, List, Optional
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torch.utils.checkpoint as checkpoint
10
+
11
+ from .model_misc import MLP
12
+
13
+
14
+ class LinearPresenceHead(nn.Sequential):
15
+ def __init__(self, d_model):
16
+ # a hack to make `LinearPresenceHead` compatible with old checkpoints
17
+ super().__init__(nn.Identity(), nn.Identity(), nn.Linear(d_model, 1))
18
+
19
+ def forward(self, hs, prompt, prompt_mask):
20
+ return super().forward(hs)
21
+
22
+
23
+ class MaskPredictor(nn.Module):
24
+ def __init__(self, hidden_dim, mask_dim):
25
+ super().__init__()
26
+ self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
27
+
28
+ def forward(self, obj_queries, pixel_embed):
29
+ if len(obj_queries.shape) == 3:
30
+ if pixel_embed.ndim == 3:
31
+ # batch size was omitted
32
+ mask_preds = torch.einsum(
33
+ "bqc,chw->bqhw", self.mask_embed(obj_queries), pixel_embed
34
+ )
35
+ else:
36
+ mask_preds = torch.einsum(
37
+ "bqc,bchw->bqhw", self.mask_embed(obj_queries), pixel_embed
38
+ )
39
+ else:
40
+ # Assumed to have aux masks
41
+ if pixel_embed.ndim == 3:
42
+ # batch size was omitted
43
+ mask_preds = torch.einsum(
44
+ "lbqc,chw->lbqhw", self.mask_embed(obj_queries), pixel_embed
45
+ )
46
+ else:
47
+ mask_preds = torch.einsum(
48
+ "lbqc,bchw->lbqhw", self.mask_embed(obj_queries), pixel_embed
49
+ )
50
+
51
+ return mask_preds
52
+
53
+
54
+ class SegmentationHead(nn.Module):
55
+ def __init__(
56
+ self,
57
+ hidden_dim,
58
+ upsampling_stages,
59
+ use_encoder_inputs=False,
60
+ aux_masks=False,
61
+ no_dec=False,
62
+ pixel_decoder=None,
63
+ act_ckpt=False,
64
+ shared_conv=False,
65
+ compile_mode_pixel_decoder=None,
66
+ ):
67
+ super().__init__()
68
+ self.use_encoder_inputs = use_encoder_inputs
69
+ self.aux_masks = aux_masks
70
+ if pixel_decoder is not None:
71
+ self.pixel_decoder = pixel_decoder
72
+ else:
73
+ self.pixel_decoder = PixelDecoder(
74
+ hidden_dim,
75
+ upsampling_stages,
76
+ shared_conv=shared_conv,
77
+ compile_mode=compile_mode_pixel_decoder,
78
+ )
79
+ self.no_dec = no_dec
80
+ if no_dec:
81
+ self.mask_predictor = nn.Conv2d(
82
+ hidden_dim, 1, kernel_size=3, stride=1, padding=1
83
+ )
84
+ else:
85
+ self.mask_predictor = MaskPredictor(hidden_dim, mask_dim=hidden_dim)
86
+
87
+ self.act_ckpt = act_ckpt
88
+
89
+ # used to update the output dictionary
90
+ self.instance_keys = ["pred_masks"]
91
+
92
+ @property
93
+ def device(self):
94
+ self._device = getattr(self, "_device", None) or next(self.parameters()).device
95
+ return self._device
96
+
97
+ def to(self, *args, **kwargs):
98
+ # clear cached _device in case the model is moved to a different device
99
+ self._device = None
100
+ return super().to(*args, **kwargs)
101
+
102
+ def _embed_pixels(
103
+ self,
104
+ backbone_feats: List[torch.Tensor],
105
+ image_ids,
106
+ encoder_hidden_states,
107
+ ) -> torch.Tensor:
108
+ feature_device = backbone_feats[0].device # features could be on CPU
109
+ model_device = self.device
110
+ image_ids_ = image_ids.to(feature_device)
111
+ if self.use_encoder_inputs:
112
+ if backbone_feats[0].shape[0] > 1:
113
+ # For bs > 1, we construct the per query backbone features
114
+ backbone_visual_feats = []
115
+ for feat in backbone_feats:
116
+ # Copy the img features per query (pixel decoder won't share img feats)
117
+ backbone_visual_feats.append(feat[image_ids_, ...].to(model_device))
118
+ else:
119
+ # Bs=1, we rely on broadcasting for query-based processing
120
+ backbone_visual_feats = [bb_feat.clone() for bb_feat in backbone_feats]
121
+ # Extract visual embeddings
122
+ encoder_hidden_states = encoder_hidden_states.permute(1, 2, 0)
123
+ spatial_dim = math.prod(backbone_feats[-1].shape[-2:])
124
+ encoder_visual_embed = encoder_hidden_states[..., :spatial_dim].reshape(
125
+ -1, *backbone_feats[-1].shape[1:]
126
+ )
127
+
128
+ backbone_visual_feats[-1] = encoder_visual_embed
129
+ if self.act_ckpt:
130
+ pixel_embed = checkpoint.checkpoint(
131
+ self.pixel_decoder, backbone_visual_feats, use_reentrant=False
132
+ )
133
+ else:
134
+ pixel_embed = self.pixel_decoder(backbone_visual_feats)
135
+ else:
136
+ backbone_feats = [x.to(model_device) for x in backbone_feats]
137
+ pixel_embed = self.pixel_decoder(backbone_feats)
138
+ if pixel_embed.shape[0] == 1:
139
+ # For batch_size=1 training, we can avoid the indexing to save memory
140
+ pixel_embed = pixel_embed.squeeze(0)
141
+ else:
142
+ pixel_embed = pixel_embed[image_ids, ...]
143
+ return pixel_embed
144
+
145
+ def forward(
146
+ self,
147
+ backbone_feats: List[torch.Tensor],
148
+ obj_queries: torch.Tensor,
149
+ image_ids,
150
+ encoder_hidden_states: Optional[torch.Tensor] = None,
151
+ **kwargs,
152
+ ) -> Dict[str, torch.Tensor]:
153
+ if self.use_encoder_inputs:
154
+ assert encoder_hidden_states is not None
155
+
156
+ pixel_embed = self._embed_pixels(
157
+ backbone_feats=backbone_feats,
158
+ image_ids=image_ids,
159
+ encoder_hidden_states=encoder_hidden_states,
160
+ )
161
+
162
+ if self.no_dec:
163
+ mask_pred = self.mask_predictor(pixel_embed)
164
+ elif self.aux_masks:
165
+ mask_pred = self.mask_predictor(obj_queries, pixel_embed)
166
+ else:
167
+ mask_pred = self.mask_predictor(obj_queries[-1], pixel_embed)
168
+
169
+ return {"pred_masks": mask_pred}
170
+
171
+
172
+ class PixelDecoder(nn.Module):
173
+ def __init__(
174
+ self,
175
+ hidden_dim,
176
+ num_upsampling_stages,
177
+ interpolation_mode="nearest",
178
+ shared_conv=False,
179
+ compile_mode=None,
180
+ ):
181
+ super().__init__()
182
+ self.hidden_dim = hidden_dim
183
+ self.num_upsampling_stages = num_upsampling_stages
184
+ self.interpolation_mode = interpolation_mode
185
+ conv_layers = []
186
+ norms = []
187
+ num_convs = 1 if shared_conv else num_upsampling_stages
188
+ for _ in range(num_convs):
189
+ conv_layers.append(nn.Conv2d(self.hidden_dim, self.hidden_dim, 3, 1, 1))
190
+ norms.append(nn.GroupNorm(8, self.hidden_dim))
191
+
192
+ self.conv_layers = nn.ModuleList(conv_layers)
193
+ self.norms = nn.ModuleList(norms)
194
+ self.shared_conv = shared_conv
195
+ self.out_dim = self.conv_layers[-1].out_channels
196
+ if compile_mode is not None:
197
+ self.forward = torch.compile(
198
+ self.forward, mode=compile_mode, dynamic=True, fullgraph=True
199
+ )
200
+ # Needed to make checkpointing happy. But we don't know if the module is checkpointed, so we disable it by default.
201
+ torch._dynamo.config.optimize_ddp = False
202
+
203
+ def forward(self, backbone_feats: List[torch.Tensor]):
204
+ # Assumes backbone features are already projected (C == hidden dim)
205
+
206
+ prev_fpn = backbone_feats[-1]
207
+ fpn_feats = backbone_feats[:-1]
208
+ for layer_idx, bb_feat in enumerate(fpn_feats[::-1]):
209
+ curr_fpn = bb_feat
210
+ prev_fpn = curr_fpn + F.interpolate(
211
+ prev_fpn, size=curr_fpn.shape[-2:], mode=self.interpolation_mode
212
+ )
213
+ if self.shared_conv:
214
+ # only one conv layer
215
+ layer_idx = 0
216
+ prev_fpn = self.conv_layers[layer_idx](prev_fpn)
217
+ prev_fpn = F.relu(self.norms[layer_idx](prev_fpn))
218
+
219
+ return prev_fpn
220
+
221
+
222
+ class UniversalSegmentationHead(SegmentationHead):
223
+ """This module handles semantic+instance segmentation"""
224
+
225
+ def __init__(
226
+ self,
227
+ hidden_dim,
228
+ upsampling_stages,
229
+ pixel_decoder,
230
+ aux_masks=False,
231
+ no_dec=False,
232
+ act_ckpt=False,
233
+ presence_head: bool = False,
234
+ dot_product_scorer=None,
235
+ cross_attend_prompt=None,
236
+ ):
237
+ super().__init__(
238
+ hidden_dim=hidden_dim,
239
+ upsampling_stages=upsampling_stages,
240
+ use_encoder_inputs=True,
241
+ aux_masks=aux_masks,
242
+ no_dec=no_dec,
243
+ pixel_decoder=pixel_decoder,
244
+ act_ckpt=act_ckpt,
245
+ )
246
+ self.d_model = hidden_dim
247
+
248
+ if dot_product_scorer is not None:
249
+ assert presence_head, "Specifying a dot product scorer without a presence head is likely a mistake"
250
+
251
+ self.presence_head = None
252
+ if presence_head:
253
+ self.presence_head = (
254
+ dot_product_scorer
255
+ if dot_product_scorer is not None
256
+ else LinearPresenceHead(self.d_model)
257
+ )
258
+
259
+ self.cross_attend_prompt = cross_attend_prompt
260
+ if self.cross_attend_prompt is not None:
261
+ self.cross_attn_norm = nn.LayerNorm(self.d_model)
262
+
263
+ self.semantic_seg_head = nn.Conv2d(self.pixel_decoder.out_dim, 1, kernel_size=1)
264
+ self.instance_seg_head = nn.Conv2d(
265
+ self.pixel_decoder.out_dim, self.d_model, kernel_size=1
266
+ )
267
+
268
+ def forward(
269
+ self,
270
+ backbone_feats: List[torch.Tensor],
271
+ obj_queries: torch.Tensor,
272
+ image_ids,
273
+ encoder_hidden_states: Optional[torch.Tensor] = None,
274
+ prompt: Optional[torch.Tensor] = None,
275
+ prompt_mask: Optional[torch.Tensor] = None,
276
+ **kwargs,
277
+ ) -> Dict[str, Optional[torch.Tensor]]:
278
+ assert encoder_hidden_states is not None
279
+ bs = encoder_hidden_states.shape[1]
280
+
281
+ if self.cross_attend_prompt is not None:
282
+ tgt2 = self.cross_attn_norm(encoder_hidden_states)
283
+ tgt2 = self.cross_attend_prompt(
284
+ query=tgt2,
285
+ key=prompt,
286
+ value=prompt,
287
+ key_padding_mask=prompt_mask,
288
+ )[0]
289
+ encoder_hidden_states = tgt2 + encoder_hidden_states
290
+
291
+ presence_logit = None
292
+ if self.presence_head is not None:
293
+ pooled_enc = encoder_hidden_states.mean(0)
294
+ presence_logit = (
295
+ self.presence_head(
296
+ pooled_enc.view(1, bs, 1, self.d_model),
297
+ prompt=prompt,
298
+ prompt_mask=prompt_mask,
299
+ )
300
+ .squeeze(0)
301
+ .squeeze(1)
302
+ )
303
+
304
+ pixel_embed = self._embed_pixels(
305
+ backbone_feats=backbone_feats,
306
+ image_ids=image_ids,
307
+ encoder_hidden_states=encoder_hidden_states,
308
+ )
309
+
310
+ instance_embeds = self.instance_seg_head(pixel_embed)
311
+
312
+ if self.no_dec:
313
+ mask_pred = self.mask_predictor(instance_embeds)
314
+ elif self.aux_masks:
315
+ mask_pred = self.mask_predictor(obj_queries, instance_embeds)
316
+ else:
317
+ mask_pred = self.mask_predictor(obj_queries[-1], instance_embeds)
318
+
319
+ return {
320
+ "pred_masks": mask_pred,
321
+ "semantic_seg": self.semantic_seg_head(pixel_embed),
322
+ "presence_logit": presence_logit,
323
+ }
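For reference, `MaskPredictor` above produces one mask per object query by projecting the query through an MLP and taking a dot product with every pixel embedding, which is what the "bqc,bchw->bqhw" einsum expresses. A small standalone shape check of that contraction (all dimensions below are arbitrary, chosen only for illustration):

import torch

B, Q, C, H, W = 2, 5, 16, 32, 32
query_embed = torch.randn(B, Q, C)     # per-query mask embeddings (after the MLP)
pixel_embed = torch.randn(B, C, H, W)  # per-pixel embeddings from the pixel decoder

# each mask logit is a dot product between one query vector and one pixel vector
mask_logits = torch.einsum("bqc,bchw->bqhw", query_embed, pixel_embed)
assert mask_logits.shape == (B, Q, H, W)

# equivalent matmul over flattened pixels
flat = pixel_embed.flatten(2)                           # (B, C, H*W)
mask_logits_ref = (query_embed @ flat).view(B, Q, H, W)
assert torch.allclose(mask_logits, mask_logits_ref, atol=1e-5)
print(mask_logits.shape)  # torch.Size([2, 5, 32, 32])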
detect_tools/sam3/sam3/model/memory.py ADDED
@@ -0,0 +1,201 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ import math
4
+ from typing import Tuple
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ try:
11
+ from timm.layers import DropPath
12
+ except ModuleNotFoundError:
13
+ # compatibility for older timm versions
14
+ from timm.models.layers import DropPath
15
+
16
+ from .model_misc import get_clones, LayerNorm2d
17
+
18
+
19
+ class SimpleMaskDownSampler(nn.Module):
20
+ """
21
+ Progressively downsample a mask by total_stride, each time by stride.
22
+ Note that LayerNorm is applied per *token*, like in ViT.
23
+
24
+ With each downsample (by a factor stride**2), channel capacity increases by the same factor.
25
+ In the end, we linearly project to embed_dim channels.
26
+ """
27
+
28
+ def __init__(
29
+ self,
30
+ embed_dim=256,
31
+ kernel_size=4,
32
+ stride=4,
33
+ padding=0,
34
+ total_stride=16,
35
+ activation=nn.GELU,
36
+ # Option to interpolate the input mask first before downsampling using convs. In that case, the total_stride is assumed to be after interpolation.
37
+ # If set to input resolution or None, we don't interpolate. We default to None to be safe (for older configs or if not explicitly set)
38
+ interpol_size=None,
39
+ ):
40
+ super().__init__()
41
+ num_layers = int(math.log2(total_stride) // math.log2(stride))
42
+ assert stride**num_layers == total_stride
43
+ self.encoder = nn.Sequential()
44
+ mask_in_chans, mask_out_chans = 1, 1
45
+ for _ in range(num_layers):
46
+ mask_out_chans = mask_in_chans * (stride**2)
47
+ self.encoder.append(
48
+ nn.Conv2d(
49
+ mask_in_chans,
50
+ mask_out_chans,
51
+ kernel_size=kernel_size,
52
+ stride=stride,
53
+ padding=padding,
54
+ )
55
+ )
56
+ self.encoder.append(LayerNorm2d(mask_out_chans))
57
+ self.encoder.append(activation())
58
+ mask_in_chans = mask_out_chans
59
+
60
+ self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
61
+ self.interpol_size = interpol_size
62
+ if self.interpol_size is not None:
63
+ assert isinstance(
64
+ self.interpol_size, (list, tuple)
65
+ ), f"Unsupported type {type(self.interpol_size)}. Should be a list or tuple."
66
+ self.interpol_size = list(interpol_size)
67
+ assert len(self.interpol_size) == 2
68
+
69
+ def forward(self, x: torch.Tensor):
70
+ if self.interpol_size is not None and self.interpol_size != list(x.shape[-2:]):
71
+ x = F.interpolate(
72
+ x.float(),
73
+ size=self.interpol_size,
74
+ align_corners=False,
75
+ mode="bilinear",
76
+ antialias=True,
77
+ )
78
+ return self.encoder(x)
79
+
80
+
81
+ # Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt)
82
+ class CXBlock(nn.Module):
83
+ r"""ConvNeXt Block. There are two equivalent implementations:
84
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
85
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
86
+ We use (2) as we find it slightly faster in PyTorch
87
+
88
+ Args:
89
+ dim (int): Number of input channels.
90
+ drop_path (float): Stochastic depth rate. Default: 0.0
91
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
92
+ """
93
+
94
+ def __init__(
95
+ self,
96
+ dim,
97
+ kernel_size=7,
98
+ padding=3,
99
+ drop_path=0.0,
100
+ layer_scale_init_value=1e-6,
101
+ use_dwconv=True,
102
+ ):
103
+ super().__init__()
104
+ self.dwconv = nn.Conv2d(
105
+ dim,
106
+ dim,
107
+ kernel_size=kernel_size,
108
+ padding=padding,
109
+ groups=dim if use_dwconv else 1,
110
+ ) # depthwise conv
111
+ self.norm = LayerNorm2d(dim, eps=1e-6)
112
+ self.pwconv1 = nn.Linear(
113
+ dim, 4 * dim
114
+ ) # pointwise/1x1 convs, implemented with linear layers
115
+ self.act = nn.GELU()
116
+ self.pwconv2 = nn.Linear(4 * dim, dim)
117
+ self.gamma = (
118
+ nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
119
+ if layer_scale_init_value > 0
120
+ else None
121
+ )
122
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
123
+
124
+ def forward(self, x):
125
+ input = x
126
+ x = self.dwconv(x)
127
+ x = self.norm(x)
128
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
129
+ x = self.pwconv1(x)
130
+ x = self.act(x)
131
+ x = self.pwconv2(x)
132
+ if self.gamma is not None:
133
+ x = self.gamma * x
134
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
135
+
136
+ x = input + self.drop_path(x)
137
+ return x
138
+
139
+
140
+ class SimpleFuser(nn.Module):
141
+ def __init__(self, layer, num_layers, dim=None, input_projection=False):
142
+ super().__init__()
143
+ self.proj = nn.Identity()
144
+ self.layers = get_clones(layer, num_layers)
145
+
146
+ if input_projection:
147
+ assert dim is not None
148
+ self.proj = nn.Conv2d(dim, dim, kernel_size=1)
149
+
150
+ def forward(self, x):
151
+ # normally x: (N, C, H, W)
152
+ x = self.proj(x)
153
+ for layer in self.layers:
154
+ x = layer(x)
155
+ return x
156
+
157
+
158
+ class SimpleMaskEncoder(nn.Module):
159
+ def __init__(
160
+ self,
161
+ out_dim,
162
+ mask_downsampler,
163
+ fuser,
164
+ position_encoding,
165
+ in_dim=256, # in_dim of pix_feats
166
+ ):
167
+ super().__init__()
168
+
169
+ self.mask_downsampler = mask_downsampler
170
+
171
+ self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
172
+ self.fuser = fuser
173
+ self.position_encoding = position_encoding
174
+ self.out_proj = nn.Identity()
175
+ if out_dim != in_dim:
176
+ self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
177
+
178
+ def forward(
179
+ self,
180
+ pix_feat: torch.Tensor,
181
+ masks: torch.Tensor,
182
+ skip_mask_sigmoid: bool = False,
183
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
184
+ ## Process masks
185
+ # sigmoid, so that less domain shift from gt masks which are bool
186
+ if not skip_mask_sigmoid:
187
+ masks = F.sigmoid(masks)
188
+ masks = self.mask_downsampler(masks)
189
+
190
+ ## Fuse pix_feats and downsampled masks
191
+ # in case the visual features are on CPU, cast them to CUDA
192
+ pix_feat = pix_feat.to(masks.device)
193
+
194
+ x = self.pix_feat_proj(pix_feat)
195
+ x = x + masks
196
+ x = self.fuser(x)
197
+ x = self.out_proj(x)
198
+
199
+ pos = self.position_encoding(x).to(x.dtype)
200
+
201
+ return {"vision_features": x, "vision_pos_enc": [pos]}
detect_tools/sam3/sam3/model/model_misc.py ADDED
@@ -0,0 +1,428 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ """Various utility models"""
4
+
5
+ import copy
6
+ import math
7
+ import weakref
8
+ from collections.abc import Iterator
9
+ from contextlib import AbstractContextManager
10
+ from enum import auto, Enum
11
+ from typing import Dict, List, Optional, Union
12
+
13
+ import numpy as np
14
+ import torch
15
+ import torch.nn.functional as F
16
+ from torch import nn, Tensor
17
+ from typing_extensions import override
18
+
19
+
20
+ def inverse_sigmoid(x, eps=1e-3):
21
+ """
22
+ The inverse function for sigmoid activation function.
23
+ Note: It might face numerical issues with fp16 and a small eps.
24
+ """
25
+ x = x.clamp(min=0, max=1)
26
+ x1 = x.clamp(min=eps)
27
+ x2 = (1 - x).clamp(min=eps)
28
+ return torch.log(x1 / x2)
29
+
30
+
31
+ class MultiheadAttentionWrapper(nn.MultiheadAttention):
32
+ def forward(self, *args, **kwargs):
33
+ kwargs["need_weights"] = False
34
+ return super().forward(*args, **kwargs)
35
+
36
+
37
+ class DotProductScoring(torch.nn.Module):
38
+ def __init__(
39
+ self,
40
+ d_model,
41
+ d_proj,
42
+ prompt_mlp=None,
43
+ clamp_logits=True,
44
+ clamp_max_val=12.0,
45
+ ):
46
+ super().__init__()
47
+ self.d_proj = d_proj
48
+ assert isinstance(prompt_mlp, torch.nn.Module) or prompt_mlp is None
49
+ self.prompt_mlp = prompt_mlp # an optional MLP projection for prompt
50
+ self.prompt_proj = torch.nn.Linear(d_model, d_proj)
51
+ self.hs_proj = torch.nn.Linear(d_model, d_proj)
52
+ self.scale = float(1.0 / np.sqrt(d_proj))
53
+ self.clamp_logits = clamp_logits
54
+ if self.clamp_logits:
55
+ self.clamp_max_val = clamp_max_val
56
+
57
+ def mean_pool_text(self, prompt, prompt_mask):
58
+ # is_valid has shape (seq, bs, 1), where 1 is valid and 0 is padding
59
+ is_valid = (~prompt_mask).float().permute(1, 0)[..., None]
60
+ # num_valid has shape (bs, 1)
61
+ num_valid = torch.clamp(torch.sum(is_valid, dim=0), min=1.0)
62
+ # mean pool over all the valid tokens -- pooled_prompt has shape (bs, proj_dim)
63
+ pooled_prompt = (prompt * is_valid).sum(dim=0) / num_valid
64
+ return pooled_prompt
65
+
66
+ def forward(self, hs, prompt, prompt_mask):
67
+ # hs has shape (num_layer, bs, num_query, d_model)
68
+ # prompt has shape (seq, bs, d_model)
69
+ # prompt_mask has shape (bs, seq), where True marks padding tokens and False marks valid ones
70
+ assert hs.dim() == 4 and prompt.dim() == 3 and prompt_mask.dim() == 2
71
+
72
+ # apply MLP on prompt if specified
73
+ if self.prompt_mlp is not None:
74
+ prompt = self.prompt_mlp(prompt)
75
+
76
+ # first, get the mean-pooled version of the prompt
77
+ pooled_prompt = self.mean_pool_text(prompt, prompt_mask)
78
+
79
+ # then, project pooled_prompt and hs to d_proj dimensions
80
+ proj_pooled_prompt = self.prompt_proj(pooled_prompt) # (bs, d_proj)
81
+ proj_hs = self.hs_proj(hs) # (num_layer, bs, num_query, d_proj)
82
+
83
+ # finally, get dot-product scores of shape (num_layer, bs, num_query, 1)
84
+ scores = torch.matmul(proj_hs, proj_pooled_prompt.unsqueeze(-1))
85
+ scores *= self.scale
86
+
87
+ # clamp scores to a max value to avoid numerical issues in loss or matcher
88
+ if self.clamp_logits:
89
+ scores.clamp_(min=-self.clamp_max_val, max=self.clamp_max_val)
90
+
91
+ return scores
92
+
93
+
94
+ class LayerScale(nn.Module):
95
+ def __init__(
96
+ self,
97
+ dim: int,
98
+ init_values: Union[float, Tensor] = 1e-5,
99
+ inplace: bool = False,
100
+ ) -> None:
101
+ super().__init__()
102
+ self.inplace = inplace
103
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
104
+
105
+ def forward(self, x: Tensor) -> Tensor:
106
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
107
+
108
+
109
+ class LayerNorm2d(nn.Module):
110
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
111
+ super().__init__()
112
+ self.weight = nn.Parameter(torch.ones(num_channels))
113
+ self.bias = nn.Parameter(torch.zeros(num_channels))
114
+ self.eps = eps
115
+
116
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
117
+ u = x.mean(1, keepdim=True)
118
+ s = (x - u).pow(2).mean(1, keepdim=True)
119
+ x = (x - u) / torch.sqrt(s + self.eps)
120
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
121
+ return x
122
+
123
+
124
+ class TransformerWrapper(nn.Module):
125
+ def __init__(
126
+ self,
127
+ encoder,
128
+ decoder,
129
+ d_model: int,
130
+ two_stage_type="none", # ["none"] only for now
131
+ pos_enc_at_input_dec=True,
132
+ ):
133
+ super().__init__()
134
+
135
+ self.encoder = encoder
136
+ self.decoder = decoder
137
+ self.num_queries = decoder.num_queries if decoder is not None else None
138
+ self.pos_enc_at_input_dec = pos_enc_at_input_dec
139
+
140
+ # for two stage
141
+ assert two_stage_type in ["none"], "unknown param {} of two_stage_type".format(
142
+ two_stage_type
143
+ )
144
+ self.two_stage_type = two_stage_type
145
+
146
+ self._reset_parameters()
147
+ self.d_model = d_model
148
+
149
+ def _reset_parameters(self):
150
+ for n, p in self.named_parameters():
151
+ if p.dim() > 1:
152
+ if (
153
+ "box_embed" not in n
154
+ and "query_embed" not in n
155
+ and "reference_points" not in n
156
+ ):
157
+ nn.init.xavier_uniform_(p)
158
+
159
+
160
+ class MLP(nn.Module):
161
+ """Very simple multi-layer perceptron (also called FFN)"""
162
+
163
+ def __init__(
164
+ self,
165
+ input_dim: int,
166
+ hidden_dim: int,
167
+ output_dim: int,
168
+ num_layers: int,
169
+ dropout: float = 0.0,
170
+ residual: bool = False,
171
+ out_norm: Optional[nn.Module] = None,
172
+ ):
173
+ super().__init__()
174
+ self.num_layers = num_layers
175
+ h = [hidden_dim] * (num_layers - 1)
176
+ self.layers = nn.ModuleList(
177
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
178
+ )
179
+ self.drop = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
180
+ # whether to add the output as a residual connection to the input
181
+ if residual and input_dim != output_dim:
182
+ raise ValueError("residual is only supported if input_dim == output_dim")
183
+ self.residual = residual
184
+ # whether to apply a normalization layer to the output
185
+ assert isinstance(out_norm, nn.Module) or out_norm is None
186
+ self.out_norm = out_norm or nn.Identity()
187
+
188
+ def forward(self, x):
189
+ orig_x = x
190
+ for i, layer in enumerate(self.layers):
191
+ x = self.drop(F.relu(layer(x))) if i < self.num_layers - 1 else layer(x)
192
+ if self.residual:
193
+ x = x + orig_x
194
+ x = self.out_norm(x)
195
+ return x
196
+
197
+
198
+ def get_clones(module, N):
199
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
200
+
201
+
202
+ def get_clones_seq(module, N):
203
+ return nn.Sequential(*[copy.deepcopy(module) for i in range(N)])
204
+
205
+
206
+ def get_activation_fn(activation):
207
+ """Return an activation function given a string"""
208
+ if activation == "relu":
209
+ return F.relu
210
+ if activation == "gelu":
211
+ return F.gelu
212
+ if activation == "glu":
213
+ return F.glu
214
+ raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
215
+
216
+
217
+ def get_activation_module(activation):
218
+ """Return an activation function given a string"""
219
+ if activation == "relu":
220
+ return nn.ReLU
221
+ if activation == "gelu":
222
+ return nn.GELU
223
+ if activation == "glu":
224
+ return nn.GLU
225
+ raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
226
+
227
+
228
+ def get_valid_ratio(mask):
229
+ _, H, W = mask.shape
230
+ valid_H = torch.sum(~mask[:, :, 0], 1)
231
+ valid_W = torch.sum(~mask[:, 0, :], 1)
232
+ valid_ratio_h = valid_H.float() / H
233
+ valid_ratio_w = valid_W.float() / W
234
+ valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
235
+ return valid_ratio
236
+
237
+
238
+ def gen_sineembed_for_position(pos_tensor, num_feats=256):
239
+ assert num_feats % 2 == 0
240
+ num_feats = num_feats // 2
241
+ # n_query, bs, _ = pos_tensor.size()
242
+ # sineembed_tensor = torch.zeros(n_query, bs, 256)
243
+ scale = 2 * math.pi
244
+ dim_t = torch.arange(num_feats, dtype=torch.float32, device=pos_tensor.device)
245
+ dim_t = 10000 ** (2 * (torch.div(dim_t, 2, rounding_mode="floor")) / num_feats)
246
+ x_embed = pos_tensor[:, :, 0] * scale
247
+ y_embed = pos_tensor[:, :, 1] * scale
248
+ pos_x = x_embed[:, :, None] / dim_t
249
+ pos_y = y_embed[:, :, None] / dim_t
250
+ pos_x = torch.stack(
251
+ (pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3
252
+ ).flatten(2)
253
+ pos_y = torch.stack(
254
+ (pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3
255
+ ).flatten(2)
256
+ if pos_tensor.size(-1) == 2:
257
+ pos = torch.cat((pos_y, pos_x), dim=2)
258
+ elif pos_tensor.size(-1) == 4:
259
+ w_embed = pos_tensor[:, :, 2] * scale
260
+ pos_w = w_embed[:, :, None] / dim_t
261
+ pos_w = torch.stack(
262
+ (pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3
263
+ ).flatten(2)
264
+
265
+ h_embed = pos_tensor[:, :, 3] * scale
266
+ pos_h = h_embed[:, :, None] / dim_t
267
+ pos_h = torch.stack(
268
+ (pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3
269
+ ).flatten(2)
270
+
271
+ pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
272
+ else:
273
+ raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1)))
274
+ return pos
275
+
276
+
277
+ class SAM3Output(list):
278
+ """
279
+ A class representing the output of a SAM3 model.
280
+ It provides an iterable interface that supports different iteration modes, including iterating over all steps per stage,
281
+ last step per stage, and flattened output.
282
+ Attributes:
283
+ output: The output of the SAM3 model, represented as a list of lists.
284
+ iter_mode: The current iteration mode.
285
+ Example:
286
+ >>> output = [[1, 2], [3, 4], [5, 6]]
287
+ >>> sam3_output = SAM3Output(output)
288
+ >>> for step in sam3_output:
289
+ ... print(step)
290
+ [1, 2]
291
+ [3, 4]
292
+ [5, 6]
293
+ >>> with SAM3Output.iteration_mode(sam3_output, SAM3Output.IterMode.LAST_STEP_PER_STAGE) as sam3_last_step_out:
294
+ ... for step in sam3_last_step_out:
295
+ ... print(step)
296
+ 2
297
+ 4
298
+ 6
299
+ >>> with SAM3Output.iteration_mode(sam3_output, SAM3Output.IterMode.FLATTENED) as sam3_flattened_out:
300
+ ... for step in sam3_flattened_out:
301
+ ... print(step)
302
+ 1
303
+ 2
304
+ 3
305
+ 4
306
+ 5
307
+ 6
308
+ """
309
+
310
+ class IterMode(Enum):
311
+ # Defines the type of iterator over outputs.
312
+ ALL_STEPS_PER_STAGE = auto()
313
+ LAST_STEP_PER_STAGE = auto()
314
+ FLATTENED = auto() # Returns each interactivity step as if it is a separate stage (this is used in SAM3Image model)
315
+
316
+ def __init__(
317
+ self,
318
+ output: List[List[Dict]] = None,
319
+ iter_mode: IterMode = IterMode.ALL_STEPS_PER_STAGE,
320
+ loss_stages: Optional[List[int]] = None,
321
+ ):
322
+ if output is not None:
323
+ assert (
324
+ isinstance(output, list)
325
+ and len(output) > 0
326
+ and isinstance(output[0], list)
327
+ ), "Expected output to be a list of lists"
328
+ self.output = output
329
+ else:
330
+ self.output = []
331
+ assert isinstance(
332
+ iter_mode, SAM3Output.IterMode
333
+ ), f"iter_mode shoulf be of enum type 'SAM3Output.IterMode'. Got {type(iter_mode)}"
334
+
335
+ self.iter_mode = iter_mode
336
+ # We create a weak reference to self to be used in the lambda functions.
337
+ # This is to avoid cyclic references and let SAM3Output be garbage collected.
338
+ self_ref = weakref.ref(self)
339
+ self._mode2iter = {
340
+ SAM3Output.IterMode.ALL_STEPS_PER_STAGE: lambda: iter(self_ref().output),
341
+ SAM3Output.IterMode.LAST_STEP_PER_STAGE: lambda: (
342
+ inner_list[-1] for inner_list in self_ref().output
343
+ ),
344
+ SAM3Output.IterMode.FLATTENED: lambda: (
345
+ element for inner_list in self_ref().output for element in inner_list
346
+ ),
347
+ }
348
+ self.loss_stages = loss_stages
349
+
350
+ @override
351
+ def __iter__(self) -> Iterator:
352
+ return self._mode2iter[self.iter_mode]()
353
+
354
+ def __getitem__(self, index):
355
+ """
356
+ Returns the item at the specified index.
357
+ Args:
358
+ index (int): The index of the item to return.
359
+ Returns:
360
+ list or element: The item at the specified index.
361
+ """
362
+ assert isinstance(index, int), f"index should be an integer. Got {type(index)}"
363
+ if self.iter_mode == SAM3Output.IterMode.ALL_STEPS_PER_STAGE:
364
+ return self.output[index]
365
+ elif self.iter_mode == SAM3Output.IterMode.LAST_STEP_PER_STAGE:
366
+ return self.output[index][-1]
367
+ elif self.iter_mode == SAM3Output.IterMode.FLATTENED:
368
+ if index == -1:
369
+ return self.output[-1][-1]
370
+ else:
371
+ flattened_output = sum(self.output, [])
372
+ return flattened_output[index]
373
+
374
+ class _IterationMode(AbstractContextManager):
375
+ """
376
+ A context manager that temporarily changes the iteration mode of a SAM3Output object.
377
+ This class is used internally by the SAM3Output.iteration_mode method.
378
+ """
379
+
380
+ def __init__(
381
+ self, model_output: "SAM3Output", iter_mode: "SAM3Output.IterMode"
382
+ ):
383
+ self._model_output = model_output
384
+ self._orig_iter_mode = model_output.iter_mode
385
+ self._new_iter_mode = iter_mode
386
+
387
+ @override
388
+ def __enter__(self) -> "SAM3Output":
389
+ self._model_output.iter_mode = self._new_iter_mode
390
+ return self._model_output
391
+
392
+ @override
393
+ def __exit__(self, exc_type, exc_value, traceback):
394
+ self._model_output.iter_mode = self._orig_iter_mode
395
+ return super().__exit__(exc_type, exc_value, traceback)
396
+
397
+ @staticmethod
398
+ def iteration_mode(
399
+ model_output: "SAM3Output", iter_mode: IterMode
400
+ ) -> _IterationMode:
401
+ """
402
+ Returns a context manager that allows you to temporarily change the iteration mode of the SAM3Output object.
403
+ Args:
404
+ model_output: The SAM3Output object.
405
+ iter_mode: The new iteration mode.
406
+ Returns:
407
+ SAM3Output._IterationMode: A context manager that changes the iteration mode of the SAM3Output object.
408
+ """
409
+ return SAM3Output._IterationMode(model_output=model_output, iter_mode=iter_mode)
410
+
411
+ def append(self, item: list):
412
+ assert isinstance(
413
+ item, list
414
+ ), f"Only list items are supported. Got {type(item)}"
415
+ self.output.append(item)
416
+
417
+ def __repr__(self):
418
+ return self.output.__repr__()
419
+
420
+ def __len__(self):
421
+ if self.iter_mode in [
422
+ SAM3Output.IterMode.ALL_STEPS_PER_STAGE,
423
+ SAM3Output.IterMode.LAST_STEP_PER_STAGE,
424
+ ]:
425
+ return len(self.output)
426
+ elif self.iter_mode == SAM3Output.IterMode.FLATTENED:
427
+ flattened_output = sum(self.output, [])
428
+ return len(flattened_output)
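As a sanity check on `inverse_sigmoid` above: it is the logit function with clamping, so it inverts `sigmoid` only while neither clamp is active, i.e. roughly for |x| below log((1 - eps) / eps) ≈ 6.9 at the default eps. A short standalone sketch (the helper is re-declared here only so the snippet runs on its own):

import torch

def inverse_sigmoid(x, eps=1e-3):
    # same computation as the helper defined above
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)

x = torch.tensor([-4.0, -1.0, 0.0, 1.0, 4.0])
roundtrip = inverse_sigmoid(torch.sigmoid(x))
print((roundtrip - x).abs().max())  # tiny: these values are inside the clamp-free range

# outside that range the output saturates instead of recovering x
print(inverse_sigmoid(torch.sigmoid(torch.tensor([20.0]))))  # ~6.9, not 20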
detect_tools/sam3/sam3/model/necks.py ADDED
@@ -0,0 +1,125 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ """Necks are the interface between a vision backbone and the rest of the detection model"""
4
+
5
+ from copy import deepcopy
6
+ from typing import List, Optional, Tuple
7
+
8
+ import torch
9
+
10
+ import torch.nn as nn
11
+
12
+
13
+ class Sam3DualViTDetNeck(nn.Module):
14
+ def __init__(
15
+ self,
16
+ trunk: nn.Module,
17
+ position_encoding: nn.Module,
18
+ d_model: int,
19
+ scale_factors=(4.0, 2.0, 1.0, 0.5),
20
+ add_sam2_neck: bool = False,
21
+ ):
22
+ """
23
+ SimpleFPN neck a la ViTDet
24
+ (From detectron2, very lightly adapted)
25
+ It supports a "dual neck" setting, where we have two identical necks (for SAM3 and SAM2), with different weights
26
+
27
+ :param trunk: the backbone
28
+ :param position_encoding: the positional encoding to use
29
+ :param d_model: the dimension of the model
30
+ """
31
+ super().__init__()
32
+ self.trunk = trunk
33
+ self.position_encoding = position_encoding
34
+ self.convs = nn.ModuleList()
35
+
36
+ self.scale_factors = scale_factors
37
+ use_bias = True
38
+ dim: int = self.trunk.channel_list[-1]
39
+
40
+ for _, scale in enumerate(scale_factors):
41
+ current = nn.Sequential()
42
+
43
+ if scale == 4.0:
44
+ current.add_module(
45
+ "dconv_2x2_0",
46
+ nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
47
+ )
48
+ current.add_module(
49
+ "gelu",
50
+ nn.GELU(),
51
+ )
52
+ current.add_module(
53
+ "dconv_2x2_1",
54
+ nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2),
55
+ )
56
+ out_dim = dim // 4
57
+ elif scale == 2.0:
58
+ current.add_module(
59
+ "dconv_2x2",
60
+ nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
61
+ )
62
+ out_dim = dim // 2
63
+ elif scale == 1.0:
64
+ out_dim = dim
65
+ elif scale == 0.5:
66
+ current.add_module(
67
+ "maxpool_2x2",
68
+ nn.MaxPool2d(kernel_size=2, stride=2),
69
+ )
70
+ out_dim = dim
71
+ else:
72
+ raise NotImplementedError(f"scale_factor={scale} is not supported yet.")
73
+
74
+ current.add_module(
75
+ "conv_1x1",
76
+ nn.Conv2d(
77
+ in_channels=out_dim,
78
+ out_channels=d_model,
79
+ kernel_size=1,
80
+ bias=use_bias,
81
+ ),
82
+ )
83
+ current.add_module(
84
+ "conv_3x3",
85
+ nn.Conv2d(
86
+ in_channels=d_model,
87
+ out_channels=d_model,
88
+ kernel_size=3,
89
+ padding=1,
90
+ bias=use_bias,
91
+ ),
92
+ )
93
+ self.convs.append(current)
94
+
95
+ self.sam2_convs = None
96
+ if add_sam2_neck:
97
+ # Assumes sam2 neck is just a clone of the original neck
98
+ self.sam2_convs = deepcopy(self.convs)
99
+
100
+ def forward(
101
+ self, tensor_list: List[torch.Tensor]
102
+ ) -> Tuple[
103
+ List[torch.Tensor],
104
+ List[torch.Tensor],
105
+ Optional[List[torch.Tensor]],
106
+ Optional[List[torch.Tensor]],
107
+ ]:
108
+ xs = self.trunk(tensor_list)
109
+ sam3_out, sam3_pos = [], []
110
+ sam2_out, sam2_pos = None, None
111
+ if self.sam2_convs is not None:
112
+ sam2_out, sam2_pos = [], []
113
+ x = xs[-1] # simpleFPN
114
+ for i in range(len(self.convs)):
115
+ sam3_x_out = self.convs[i](x)
116
+ sam3_pos_out = self.position_encoding(sam3_x_out).to(sam3_x_out.dtype)
117
+ sam3_out.append(sam3_x_out)
118
+ sam3_pos.append(sam3_pos_out)
119
+
120
+ if self.sam2_convs is not None:
121
+ sam2_x_out = self.sam2_convs[i](x)
122
+ sam2_pos_out = self.position_encoding(sam2_x_out).to(sam2_x_out.dtype)
123
+ sam2_out.append(sam2_x_out)
124
+ sam2_pos.append(sam2_pos_out)
125
+ return sam3_out, sam3_pos, sam2_out, sam2_pos
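`Sam3DualViTDetNeck` above follows the ViTDet SimpleFPN recipe: every pyramid level is derived from the single highest-level trunk feature, and each `scale_factor` maps to transposed convs (upsampling), identity, or max-pooling before the 1x1/3x3 projections to `d_model`. A shape-only sketch of those branches, assuming a stride-16 trunk feature with 1024 channels at 72x72 (all numbers illustrative):

import torch
import torch.nn as nn

dim = 1024                          # trunk channel dim (illustrative)
x = torch.randn(1, dim, 72, 72)     # e.g. a stride-16 feature map

branches = {
    4.0: nn.Sequential(             # two 2x transposed convs -> stride 4
        nn.ConvTranspose2d(dim, dim // 2, 2, 2), nn.GELU(),
        nn.ConvTranspose2d(dim // 2, dim // 4, 2, 2),
    ),
    2.0: nn.ConvTranspose2d(dim, dim // 2, 2, 2),  # stride 8
    1.0: nn.Identity(),                            # stride 16
    0.5: nn.MaxPool2d(2, 2),                       # stride 32
}
for scale, branch in branches.items():
    print(scale, tuple(branch(x).shape))
# 4.0 (1, 256, 288, 288)
# 2.0 (1, 512, 144, 144)
# 1.0 (1, 1024, 72, 72)
# 0.5 (1, 1024, 36, 36)
# the real neck then projects each level to d_model with the 1x1 and 3x3 convs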
detect_tools/sam3/sam3/model/position_encoding.py ADDED
@@ -0,0 +1,124 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ import math
4
+ from typing import Optional
5
+
6
+ import torch
7
+ from torch import nn
8
+
9
+
10
+ class PositionEmbeddingSine(nn.Module):
11
+ """
12
+ This is a more standard version of the position embedding, very similar to the one
13
+ used by the Attention is all you need paper, generalized to work on images.
14
+ """
15
+
16
+ def __init__(
17
+ self,
18
+ num_pos_feats,
19
+ temperature: int = 10000,
20
+ normalize: bool = True,
21
+ scale: Optional[float] = None,
22
+ precompute_resolution: Optional[int] = None,
23
+ ):
24
+ super().__init__()
25
+ assert num_pos_feats % 2 == 0, "Expecting even model width"
26
+ self.num_pos_feats = num_pos_feats // 2
27
+ self.temperature = temperature
28
+ self.normalize = normalize
29
+ if scale is not None and normalize is False:
30
+ raise ValueError("normalize should be True if scale is passed")
31
+ if scale is None:
32
+ scale = 2 * math.pi
33
+ self.scale = scale
34
+
35
+ self.cache = {}
36
+ # Precompute positional encodings under `precompute_resolution` to fill the cache
37
+ # and avoid symbolic shape tracing errors in torch.compile in PyTorch 2.4 nightly.
38
+ if precompute_resolution is not None:
39
+ # We precompute pos enc for stride 4, 8, 16 and 32 to fill `self.cache`.
40
+ precompute_sizes = [
41
+ (precompute_resolution // 4, precompute_resolution // 4),
42
+ (precompute_resolution // 8, precompute_resolution // 8),
43
+ (precompute_resolution // 16, precompute_resolution // 16),
44
+ (precompute_resolution // 32, precompute_resolution // 32),
45
+ ]
46
+ for size in precompute_sizes:
47
+ tensors = torch.zeros((1, 1) + size, device="cuda")
48
+ self.forward(tensors)
49
+ # further clone and detach it in the cache (just to be safe)
50
+ self.cache[size] = self.cache[size].clone().detach()
51
+
52
+ def _encode_xy(self, x, y):
53
+ # The positions are expected to be normalized
54
+ assert len(x) == len(y) and x.ndim == y.ndim == 1
55
+ x_embed = x * self.scale
56
+ y_embed = y * self.scale
57
+
58
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
59
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
60
+
61
+ pos_x = x_embed[:, None] / dim_t
62
+ pos_y = y_embed[:, None] / dim_t
63
+ pos_x = torch.stack(
64
+ (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2
65
+ ).flatten(1)
66
+ pos_y = torch.stack(
67
+ (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2
68
+ ).flatten(1)
69
+ return pos_x, pos_y
70
+
71
+ @torch.no_grad()
72
+ def encode_boxes(self, x, y, w, h):
73
+ pos_x, pos_y = self._encode_xy(x, y)
74
+ pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
75
+ return pos
76
+
77
+ encode = encode_boxes # Backwards compatibility
78
+
79
+ @torch.no_grad()
80
+ def encode_points(self, x, y, labels):
81
+ (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
82
+ assert bx == by and nx == ny and bx == bl and nx == nl
83
+ pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
84
+ pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
85
+ pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
86
+ return pos
87
+
88
+ @torch.no_grad()
89
+ def forward(self, x):
90
+ cache_key = None
91
+ cache_key = (x.shape[-2], x.shape[-1])
92
+ if cache_key in self.cache:
93
+ return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
94
+ y_embed = (
95
+ torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
96
+ .view(1, -1, 1)
97
+ .repeat(x.shape[0], 1, x.shape[-1])
98
+ )
99
+ x_embed = (
100
+ torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
101
+ .view(1, 1, -1)
102
+ .repeat(x.shape[0], x.shape[-2], 1)
103
+ )
104
+
105
+ if self.normalize:
106
+ eps = 1e-6
107
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
108
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
109
+
110
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
111
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
112
+
113
+ pos_x = x_embed[:, :, :, None] / dim_t
114
+ pos_y = y_embed[:, :, :, None] / dim_t
115
+ pos_x = torch.stack(
116
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
117
+ ).flatten(3)
118
+ pos_y = torch.stack(
119
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
120
+ ).flatten(3)
121
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
122
+ if cache_key is not None:
123
+ self.cache[cache_key] = pos[0]
124
+ return pos
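A minimal usage sketch of `PositionEmbeddingSine` above, assuming the `sam3` package from this repo is importable and running on CPU (so no `precompute_resolution` is passed): the output has as many channels as the `num_pos_feats` given to the constructor, and encodings are cached per spatial size.

import torch
from sam3.model.position_encoding import PositionEmbeddingSine

pe = PositionEmbeddingSine(num_pos_feats=256, normalize=True)

pos = pe(torch.randn(2, 256, 72, 72))   # any (B, C, H, W) feature map
print(pos.shape)                         # torch.Size([2, 256, 72, 72])

# the encoding depends only on (H, W), so a second call with the same
# spatial size is served from the cache and repeated over the batch
pos_again = pe(torch.randn(4, 256, 72, 72))
print(pos_again.shape)                   # torch.Size([4, 256, 72, 72])
print((72, 72) in pe.cache)              # True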
detect_tools/sam3/sam3/model/sam1_task_predictor.py ADDED
@@ -0,0 +1,458 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+
9
+ from typing import List, Optional, Tuple, Union
10
+
11
+ import numpy as np
12
+ import torch
13
+
14
+ import torch.nn as nn
15
+ from PIL.Image import Image
16
+
17
+ from sam3.model.sam3_tracker_base import Sam3TrackerBase
18
+ from sam3.model.utils.sam1_utils import SAM2Transforms
19
+
20
+
21
+ # Adapted from https://github.com/facebookresearch/sam2/blob/main/sam2/sam2_image_predictor.py
22
+ class SAM3InteractiveImagePredictor(nn.Module):
23
+ def __init__(
24
+ self,
25
+ sam_model: Sam3TrackerBase,
26
+ mask_threshold=0.0,
27
+ max_hole_area=256.0,
28
+ max_sprinkle_area=0.0,
29
+ **kwargs,
30
+ ) -> None:
31
+ """
32
+ Uses SAM-3 to calculate the image embedding for an image, and then
33
+ allow repeated, efficient mask prediction given prompts.
34
+
35
+ Arguments:
36
+ sam_model : The model to use for mask prediction.
37
+ mask_threshold (float): The threshold to use when converting mask logits
38
+ to binary masks. Masks are thresholded at 0 by default.
39
+ max_hole_area (int): If max_hole_area > 0, we fill small holes in up to
40
+ the maximum area of max_hole_area in low_res_masks.
41
+ max_sprinkle_area (int): If max_sprinkle_area > 0, we remove small sprinkles up to
42
+ the maximum area of max_sprinkle_area in low_res_masks.
43
+ """
44
+ super().__init__()
45
+ self.model = sam_model
46
+ self._transforms = SAM2Transforms(
47
+ resolution=self.model.image_size,
48
+ mask_threshold=mask_threshold,
49
+ max_hole_area=max_hole_area,
50
+ max_sprinkle_area=max_sprinkle_area,
51
+ )
52
+
53
+ # Predictor state
54
+ self._is_image_set = False
55
+ self._features = None
56
+ self._orig_hw = None
57
+ # Whether the predictor is set for single image or a batch of images
58
+ self._is_batch = False
59
+
60
+ # Predictor config
61
+ self.mask_threshold = mask_threshold
62
+
63
+ # Spatial dim for backbone feature maps
64
+ self._bb_feat_sizes = [
65
+ (288, 288),
66
+ (144, 144),
67
+ (72, 72),
68
+ ]
69
+
70
+ @torch.no_grad()
71
+ def set_image(
72
+ self,
73
+ image: Union[np.ndarray, Image],
74
+ ) -> None:
75
+ """
76
+ Calculates the image embeddings for the provided image, allowing
77
+ masks to be predicted with the 'predict' method.
78
+
79
+ Arguments:
80
+ image (np.ndarray or PIL Image): The input image to embed in RGB format. The image should be in HWC format if np.ndarray, or WHC format if PIL Image
81
+ with pixel values in [0, 255].
83
+ """
84
+ self.reset_predictor()
85
+ # Transform the image to the form expected by the model
86
+ if isinstance(image, np.ndarray):
87
+ logging.info("For numpy array image, we assume (HxWxC) format")
88
+ self._orig_hw = [image.shape[:2]]
89
+ elif isinstance(image, Image):
90
+ w, h = image.size
91
+ self._orig_hw = [(h, w)]
92
+ else:
93
+ raise NotImplementedError("Image format not supported")
94
+
95
+ input_image = self._transforms(image)
96
+ input_image = input_image[None, ...].to(self.device)
97
+
98
+ assert (
99
+ len(input_image.shape) == 4 and input_image.shape[1] == 3
100
+ ), f"input_image must be of size 1x3xHxW, got {input_image.shape}"
101
+ logging.info("Computing image embeddings for the provided image...")
102
+ backbone_out = self.model.forward_image(input_image)
103
+ (
104
+ _,
105
+ vision_feats,
106
+ _,
107
+ _,
108
+ ) = self.model._prepare_backbone_features(backbone_out)
109
+ # Add no_mem_embed, which is added to the lowest-res feat. map during training on videos
110
+ vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
111
+
112
+ feats = [
113
+ feat.permute(1, 2, 0).view(1, -1, *feat_size)
114
+ for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
115
+ ][::-1]
116
+ self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
117
+ self._is_image_set = True
118
+ logging.info("Image embeddings computed.")
119
+
120
+ @torch.no_grad()
121
+ def set_image_batch(
122
+ self,
123
+ image_list: List[np.ndarray],
124
+ ) -> None:
125
+ """
126
+ Calculates the image embeddings for the provided image batch, allowing
127
+ masks to be predicted with the 'predict_batch' method.
128
+
129
+ Arguments:
130
+ image_list (List[np.ndarray]): The input images to embed in RGB format. The image should be in HWC format if np.ndarray
131
+ with pixel values in [0, 255].
132
+ """
133
+ self.reset_predictor()
134
+ assert isinstance(image_list, list)
135
+ self._orig_hw = []
136
+ for image in image_list:
137
+ assert isinstance(
138
+ image, np.ndarray
139
+ ), "Images are expected to be an np.ndarray in RGB format, and of shape HWC"
140
+ self._orig_hw.append(image.shape[:2])
141
+ # Transform the image to the form expected by the model
142
+ img_batch = self._transforms.forward_batch(image_list)
143
+ img_batch = img_batch.to(self.device)
144
+ batch_size = img_batch.shape[0]
145
+ assert (
146
+ len(img_batch.shape) == 4 and img_batch.shape[1] == 3
147
+ ), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}"
148
+ logging.info("Computing image embeddings for the provided images...")
149
+ backbone_out = self.model.forward_image(img_batch)
150
+ (
151
+ _,
152
+ vision_feats,
153
+ _,
154
+ _,
155
+ ) = self.model._prepare_backbone_features(backbone_out)
156
+ # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos
157
+ vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
158
+
159
+ feats = [
160
+ feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
161
+ for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
162
+ ][::-1]
163
+ self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
164
+ self._is_image_set = True
165
+ self._is_batch = True
166
+ logging.info("Image embeddings computed.")
167
+
168
+ def predict_batch(
169
+ self,
170
+ point_coords_batch: List[np.ndarray] = None,
171
+ point_labels_batch: List[np.ndarray] = None,
172
+ box_batch: List[np.ndarray] = None,
173
+ mask_input_batch: List[np.ndarray] = None,
174
+ multimask_output: bool = True,
175
+ return_logits: bool = False,
176
+ normalize_coords=True,
177
+ ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
178
+ """This function is very similar to predict(...), however it is used for batched mode, when the model is expected to generate predictions on multiple images.
179
+ It returns a tuple of lists of masks, ious, and low_res_masks_logits.
180
+ """
181
+ assert self._is_batch, "This function should only be used when in batched mode"
182
+ if not self._is_image_set:
183
+ raise RuntimeError(
184
+ "An image must be set with .set_image_batch(...) before mask prediction."
185
+ )
186
+ num_images = len(self._features["image_embed"])
187
+ all_masks = []
188
+ all_ious = []
189
+ all_low_res_masks = []
190
+ for img_idx in range(num_images):
191
+ # Transform input prompts
192
+ point_coords = (
193
+ point_coords_batch[img_idx] if point_coords_batch is not None else None
194
+ )
195
+ point_labels = (
196
+ point_labels_batch[img_idx] if point_labels_batch is not None else None
197
+ )
198
+ box = box_batch[img_idx] if box_batch is not None else None
199
+ mask_input = (
200
+ mask_input_batch[img_idx] if mask_input_batch is not None else None
201
+ )
202
+ mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
203
+ point_coords,
204
+ point_labels,
205
+ box,
206
+ mask_input,
207
+ normalize_coords,
208
+ img_idx=img_idx,
209
+ )
210
+ masks, iou_predictions, low_res_masks = self._predict(
211
+ unnorm_coords,
212
+ labels,
213
+ unnorm_box,
214
+ mask_input,
215
+ multimask_output,
216
+ return_logits=return_logits,
217
+ img_idx=img_idx,
218
+ )
219
+ masks_np = masks.squeeze(0).float().detach().cpu().numpy()
220
+ iou_predictions_np = (
221
+ iou_predictions.squeeze(0).float().detach().cpu().numpy()
222
+ )
223
+ low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
224
+ all_masks.append(masks_np)
225
+ all_ious.append(iou_predictions_np)
226
+ all_low_res_masks.append(low_res_masks_np)
227
+
228
+ return all_masks, all_ious, all_low_res_masks
229
+
230
+ def predict(
231
+ self,
232
+ point_coords: Optional[np.ndarray] = None,
233
+ point_labels: Optional[np.ndarray] = None,
234
+ box: Optional[np.ndarray] = None,
235
+ mask_input: Optional[np.ndarray] = None,
236
+ multimask_output: bool = True,
237
+ return_logits: bool = False,
238
+ normalize_coords=True,
239
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
240
+ """
241
+ Predict masks for the given input prompts, using the currently set image.
242
+
243
+ Arguments:
244
+ point_coords (np.ndarray or None): A Nx2 array of point prompts to the
245
+ model. Each point is in (X,Y) in pixels.
246
+ point_labels (np.ndarray or None): A length N array of labels for the
247
+ point prompts. 1 indicates a foreground point and 0 indicates a
248
+ background point.
249
+ box (np.ndarray or None): A length 4 array given a box prompt to the
250
+ model, in XYXY format.
251
+ mask_input (np.ndarray): A low resolution mask input to the model, typically
252
+ coming from a previous prediction iteration. Has form 1xHxW, where
253
+ for SAM, H=W=256.
254
+ multimask_output (bool): If true, the model will return three masks.
255
+ For ambiguous input prompts (such as a single click), this will often
256
+ produce better masks than a single prediction. If only a single
257
+ mask is needed, the model's predicted quality score can be used
258
+ to select the best mask. For non-ambiguous prompts, such as multiple
259
+ input prompts, multimask_output=False can give better results.
260
+ return_logits (bool): If true, returns un-thresholded masks logits
261
+ instead of a binary mask.
262
+ normalize_coords (bool): If true, the point coordinates will be normalized to the range [0,1] and point_coords is expected to be wrt. image dimensions.
263
+
264
+ Returns:
265
+ (np.ndarray): The output masks in CxHxW format, where C is the
266
+ number of masks, and (H, W) is the original image size.
267
+ (np.ndarray): An array of length C containing the model's
268
+ predictions for the quality of each mask.
269
+ (np.ndarray): An array of shape CxHxW, where C is the number
270
+ of masks and H=W=256. These low resolution logits can be passed to
271
+ a subsequent iteration as mask input.
272
+ """
273
+ if not self._is_image_set:
274
+ raise RuntimeError(
275
+ "An image must be set with .set_image(...) before mask prediction."
276
+ )
277
+
278
+ # Transform input prompts
279
+
280
+ mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
281
+ point_coords, point_labels, box, mask_input, normalize_coords
282
+ )
283
+
284
+ masks, iou_predictions, low_res_masks = self._predict(
285
+ unnorm_coords,
286
+ labels,
287
+ unnorm_box,
288
+ mask_input,
289
+ multimask_output,
290
+ return_logits=return_logits,
291
+ )
292
+
293
+ masks_np = masks.squeeze(0).float().detach().cpu().numpy()
294
+ iou_predictions_np = iou_predictions.squeeze(0).float().detach().cpu().numpy()
295
+ low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
296
+ return masks_np, iou_predictions_np, low_res_masks_np
297
+
298
+ def _prep_prompts(
299
+ self, point_coords, point_labels, box, mask_logits, normalize_coords, img_idx=-1
300
+ ):
301
+ unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None
302
+ if point_coords is not None:
303
+ assert (
304
+ point_labels is not None
305
+ ), "point_labels must be supplied if point_coords is supplied."
306
+ point_coords = torch.as_tensor(
307
+ point_coords, dtype=torch.float, device=self.device
308
+ )
309
+ unnorm_coords = self._transforms.transform_coords(
310
+ point_coords, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
311
+ )
312
+ labels = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
313
+ if len(unnorm_coords.shape) == 2:
314
+ unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...]
315
+ if box is not None:
316
+ box = torch.as_tensor(box, dtype=torch.float, device=self.device)
317
+ unnorm_box = self._transforms.transform_boxes(
318
+ box, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
319
+ ) # Bx2x2
320
+ if mask_logits is not None:
321
+ mask_input = torch.as_tensor(
322
+ mask_logits, dtype=torch.float, device=self.device
323
+ )
324
+ if len(mask_input.shape) == 3:
325
+ mask_input = mask_input[None, :, :, :]
326
+ return mask_input, unnorm_coords, labels, unnorm_box
327
+
328
+ @torch.no_grad()
329
+ def _predict(
330
+ self,
331
+ point_coords: Optional[torch.Tensor],
332
+ point_labels: Optional[torch.Tensor],
333
+ boxes: Optional[torch.Tensor] = None,
334
+ mask_input: Optional[torch.Tensor] = None,
335
+ multimask_output: bool = True,
336
+ return_logits: bool = False,
337
+ img_idx: int = -1,
338
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
339
+ """
340
+ Predict masks for the given input prompts, using the currently set image.
341
+ Input prompts are batched torch tensors and are expected to already be
342
+ transformed to the input frame using SAM2Transforms.
343
+
344
+ Arguments:
345
+ point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
346
+ model. Each point is in (X,Y) in pixels.
347
+ point_labels (torch.Tensor or None): A BxN array of labels for the
348
+ point prompts. 1 indicates a foreground point and 0 indicates a
349
+ background point.
350
+ boxes (np.ndarray or None): A Bx4 array given a box prompt to the
351
+ model, in XYXY format.
352
+ mask_input (np.ndarray): A low resolution mask input to the model, typically
353
+ coming from a previous prediction iteration. Has form Bx1xHxW, where
354
+ for SAM, H=W=256. Masks returned by a previous iteration of the
355
+ predict method do not need further transformation.
356
+ multimask_output (bool): If true, the model will return three masks.
357
+ For ambiguous input prompts (such as a single click), this will often
358
+ produce better masks than a single prediction. If only a single
359
+ mask is needed, the model's predicted quality score can be used
360
+ to select the best mask. For non-ambiguous prompts, such as multiple
361
+ input prompts, multimask_output=False can give better results.
362
+ return_logits (bool): If true, returns un-thresholded masks logits
363
+ instead of a binary mask.
364
+
365
+ Returns:
366
+ (torch.Tensor): The output masks in BxCxHxW format, where C is the
367
+ number of masks, and (H, W) is the original image size.
368
+ (torch.Tensor): An array of shape BxC containing the model's
369
+ predictions for the quality of each mask.
370
+ (torch.Tensor): An array of shape BxCxHxW, where C is the number
371
+ of masks and H=W=256. These low res logits can be passed to
372
+ a subsequent iteration as mask input.
373
+ """
374
+ if not self._is_image_set:
375
+ raise RuntimeError(
376
+ "An image must be set with .set_image(...) before mask prediction."
377
+ )
378
+
379
+ if point_coords is not None:
380
+ concat_points = (point_coords, point_labels)
381
+ else:
382
+ concat_points = None
383
+
384
+ # Embed prompts
385
+ if boxes is not None:
386
+ box_coords = boxes.reshape(-1, 2, 2)
387
+ box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device)
388
+ box_labels = box_labels.repeat(boxes.size(0), 1)
389
+ # we merge "boxes" and "points" into a single "concat_points" input (where
390
+ # boxes are added at the beginning) to sam_prompt_encoder
391
+ if concat_points is not None:
392
+ concat_coords = torch.cat([box_coords, concat_points[0]], dim=1)
393
+ concat_labels = torch.cat([box_labels, concat_points[1]], dim=1)
394
+ concat_points = (concat_coords, concat_labels)
395
+ else:
396
+ concat_points = (box_coords, box_labels)
397
+
398
+ sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
399
+ points=concat_points,
400
+ boxes=None,
401
+ masks=mask_input,
402
+ )
403
+
404
+ # Predict masks
405
+ batched_mode = (
406
+ concat_points is not None and concat_points[0].shape[0] > 1
407
+ ) # multi object prediction
408
+ high_res_features = [
409
+ feat_level[img_idx].unsqueeze(0)
410
+ for feat_level in self._features["high_res_feats"]
411
+ ]
412
+ low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder(
413
+ image_embeddings=self._features["image_embed"][img_idx].unsqueeze(0),
414
+ image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
415
+ sparse_prompt_embeddings=sparse_embeddings,
416
+ dense_prompt_embeddings=dense_embeddings,
417
+ multimask_output=multimask_output,
418
+ repeat_image=batched_mode,
419
+ high_res_features=high_res_features,
420
+ )
421
+
422
+ # Upscale the masks to the original image resolution
423
+ masks = self._transforms.postprocess_masks(
424
+ low_res_masks, self._orig_hw[img_idx]
425
+ )
426
+ low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0)
427
+ if not return_logits:
428
+ masks = masks > self.mask_threshold
429
+
430
+ return masks, iou_predictions, low_res_masks
431
+
432
+ def get_image_embedding(self) -> torch.Tensor:
433
+ """
434
+ Returns the image embeddings for the currently set image, with
435
+ shape 1xCxHxW, where C is the embedding dimension and (H,W) are
436
+ the embedding spatial dimension of SAM (typically C=256, H=W=64).
437
+ """
438
+ if not self._is_image_set:
439
+ raise RuntimeError(
440
+ "An image must be set with .set_image(...) to generate an embedding."
441
+ )
442
+ assert (
443
+ self._features is not None
444
+ ), "Features must exist if an image has been set."
445
+ return self._features["image_embed"]
446
+
447
+ @property
448
+ def device(self) -> torch.device:
449
+ return self.model.device
450
+
451
+ def reset_predictor(self) -> None:
452
+ """
453
+ Resets the image embeddings and other state variables.
454
+ """
455
+ self._is_image_set = False
456
+ self._features = None
457
+ self._orig_hw = None
458
+ self._is_batch = False
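
The predictor above follows a SAM-2-style interactive flow: `_prep_prompts` normalizes point, box and mask inputs, and `_predict` runs the prompt encoder and mask decoder on the cached image features. The sketch below is a minimal, non-authoritative usage example; it assumes the public `set_image(...)` / `predict(...)` wrappers (the ones invoked by `predict_inst` in `sam3_image.py`) mirror `_predict`'s return order, and the image and click coordinates are placeholders.

    import numpy as np

    # `predictor` is assumed to be an already constructed SAM3InteractiveImagePredictor.
    # predictor.set_image(image)                    # image: HxWx3 uint8 array (placeholder)
    point_coords = np.array([[512, 384]])           # one click, (X, Y) in pixels
    point_labels = np.array([1])                    # 1 = foreground, 0 = background
    masks, scores, low_res_logits = predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        multimask_output=True,                      # 3 candidate masks for an ambiguous single click
    )
    best_mask = masks[np.argmax(scores)]            # keep the mask with the highest predicted quality
    predictor.reset_predictor()                     # clear cached features and state
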
detect_tools/sam3/sam3/model/sam3_image.py ADDED
@@ -0,0 +1,883 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ import os
4
+ from copy import deepcopy
5
+ from typing import Dict, List, Optional, Tuple
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ from sam3.model.model_misc import SAM3Output
11
+
12
+ from sam3.model.sam1_task_predictor import SAM3InteractiveImagePredictor
13
+ from sam3.model.vl_combiner import SAM3VLBackbone
14
+ from sam3.perflib.nms import nms_masks
15
+
16
+ from sam3.train.data.collator import BatchedDatapoint
17
+
18
+ from .act_ckpt_utils import activation_ckpt_wrapper
19
+
20
+ from .box_ops import box_cxcywh_to_xyxy
21
+
22
+ from .geometry_encoders import Prompt
23
+ from .model_misc import inverse_sigmoid
24
+
25
+
26
+ def _update_out(out, out_name, out_value, auxiliary=True, update_aux=True):
27
+ out[out_name] = out_value[-1] if auxiliary else out_value
28
+ if auxiliary and update_aux:
29
+ if "aux_outputs" not in out:
30
+ out["aux_outputs"] = [{} for _ in range(len(out_value) - 1)]
31
+ assert len(out["aux_outputs"]) == len(out_value) - 1
32
+ for aux_output, aux_value in zip(out["aux_outputs"], out_value[:-1]):
33
+ aux_output[out_name] = aux_value
34
+
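
`_update_out` fans a per-decoder-layer list of predictions out into the final output plus one dict per intermediate layer. A minimal sketch of the resulting layout, using hypothetical tensors for a three-layer decoder:

    import torch

    out = {}
    per_layer_logits = [torch.zeros(1, 300, 1) for _ in range(3)]  # hypothetical: one tensor per decoder layer
    _update_out(out, "pred_logits", per_layer_logits, auxiliary=True, update_aux=True)
    # out["pred_logits"] is the last layer's tensor, while
    # out["aux_outputs"] == [{"pred_logits": layer0}, {"pred_logits": layer1}]
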
35
+
36
+ class Sam3Image(torch.nn.Module):
37
+ TEXT_ID_FOR_TEXT = 0
38
+ TEXT_ID_FOR_VISUAL = 1
39
+ TEXT_ID_FOR_GEOMETRIC = 2
40
+
41
+ def __init__(
42
+ self,
43
+ backbone: SAM3VLBackbone,
44
+ transformer,
45
+ input_geometry_encoder,
46
+ segmentation_head=None,
47
+ num_feature_levels=1,
48
+ o2m_mask_predict=True,
49
+ dot_prod_scoring=None,
50
+ use_instance_query: bool = True,
51
+ multimask_output: bool = True,
52
+ use_act_checkpoint_seg_head: bool = True,
53
+ interactivity_in_encoder: bool = True,
54
+ matcher=None,
55
+ use_dot_prod_scoring=True,
56
+ supervise_joint_box_scores: bool = False, # only relevant if using presence token/score
57
+ detach_presence_in_joint_score: bool = False, # only relevant if using presence token/score
58
+ separate_scorer_for_instance: bool = False,
59
+ num_interactive_steps_val: int = 0,
60
+ inst_interactive_predictor: SAM3InteractiveImagePredictor = None,
61
+ **kwargs,
62
+ ):
63
+ super().__init__()
64
+ self.backbone = backbone
65
+ self.geometry_encoder = input_geometry_encoder
66
+ self.transformer = transformer
67
+ self.hidden_dim = transformer.d_model
68
+ self.num_feature_levels = num_feature_levels
69
+ self.segmentation_head = segmentation_head
70
+
71
+ self.o2m_mask_predict = o2m_mask_predict
72
+
73
+ self.dot_prod_scoring = dot_prod_scoring
74
+ self.use_act_checkpoint_seg_head = use_act_checkpoint_seg_head
75
+ self.interactivity_in_encoder = interactivity_in_encoder
76
+ self.matcher = matcher
77
+
78
+ self.num_interactive_steps_val = num_interactive_steps_val
79
+ self.use_dot_prod_scoring = use_dot_prod_scoring
80
+
81
+ if self.use_dot_prod_scoring:
82
+ assert dot_prod_scoring is not None
83
+ self.dot_prod_scoring = dot_prod_scoring
84
+ self.instance_dot_prod_scoring = None
85
+ if separate_scorer_for_instance:
86
+ self.instance_dot_prod_scoring = deepcopy(dot_prod_scoring)
87
+ else:
88
+ self.class_embed = torch.nn.Linear(self.hidden_dim, 1)
89
+ self.instance_class_embed = None
90
+ if separate_scorer_for_instance:
91
+ self.instance_class_embed = deepcopy(self.class_embed)
92
+
93
+ self.supervise_joint_box_scores = supervise_joint_box_scores
94
+ self.detach_presence_in_joint_score = detach_presence_in_joint_score
95
+
96
+ # verify the number of queries for O2O and O2M
97
+ num_o2o_static = self.transformer.decoder.num_queries
98
+ num_o2m_static = self.transformer.decoder.num_o2m_queries
99
+ assert num_o2m_static == (num_o2o_static if self.transformer.decoder.dac else 0)
100
+ self.dac = self.transformer.decoder.dac
101
+
102
+ self.use_instance_query = use_instance_query
103
+ self.multimask_output = multimask_output
104
+
105
+ self.inst_interactive_predictor = inst_interactive_predictor
106
+
107
+ @property
108
+ def device(self):
109
+ self._device = getattr(self, "_device", None) or next(self.parameters()).device
110
+ return self._device
111
+
112
+ def to(self, *args, **kwargs):
113
+ # clear cached _device in case the model is moved to a different device
114
+ self._device = None
115
+ return super().to(*args, **kwargs)
116
+
117
+ def _get_img_feats(self, backbone_out, img_ids):
118
+ """Retrieve correct image features from backbone output."""
119
+ if "backbone_fpn" in backbone_out:
120
+ if "id_mapping" in backbone_out and backbone_out["id_mapping"] is not None:
121
+ img_ids = backbone_out["id_mapping"][img_ids]
122
+ # If this assert fails, it likely means we're requesting different img_ids (perhaps a different frame?)
123
+ # We currently don't expect this to happen. We could technically trigger a recompute here,
124
+ # but likely at the cost of a cpu<->gpu sync point, which would deteriorate perf
125
+ torch._assert_async((img_ids >= 0).all())
126
+
127
+ vis_feats = backbone_out["backbone_fpn"][-self.num_feature_levels :]
128
+ vis_pos_enc = backbone_out["vision_pos_enc"][-self.num_feature_levels :]
129
+ vis_feat_sizes = [x.shape[-2:] for x in vis_pos_enc] # (H, W) shapes
130
+ # index and flatten visual features NxCxHxW => HWxNxC (batch-first => seq-first)
131
+ img_feats = [x[img_ids].flatten(2).permute(2, 0, 1) for x in vis_feats]
132
+ img_pos_embeds = [
133
+ x[img_ids].flatten(2).permute(2, 0, 1) for x in vis_pos_enc
134
+ ]
135
+ return backbone_out, img_feats, img_pos_embeds, vis_feat_sizes
136
+
137
+ # Image features not available in backbone output, so we compute them on the fly
138
+ # This case likely occurs for video. In that case, we want to forward only the current frame
139
+ img_batch = backbone_out["img_batch_all_stages"]
140
+ if img_ids.numel() > 1:
141
+ # Only forward backbone on unique image ids to avoid repetitive computation
142
+ unique_ids, _ = torch.unique(img_ids, return_inverse=True)
143
+ else:
144
+ unique_ids, _ = img_ids, slice(None)
145
+ # Compute the image features on those unique image ids
146
+ # note: we allow using a list (or other indexable types) of tensors as img_batch
147
+ # (e.g. for async frame loading in demo). In this case we index img_batch.tensors directly
148
+ if isinstance(img_batch, torch.Tensor):
149
+ image = img_batch[unique_ids]
150
+ elif unique_ids.numel() == 1:
151
+ image = img_batch[unique_ids.item()].unsqueeze(0)
152
+ else:
153
+ image = torch.stack([img_batch[i] for i in unique_ids.tolist()])
154
+ # `img_batch` might be fp16 and offloaded to CPU
155
+ image = image.to(dtype=torch.float32, device=self.device)
156
+ # Next time we call this function, we want to remember which indices we computed
157
+ id_mapping = torch.full(
158
+ (len(img_batch),), -1, dtype=torch.long, device=self.device
159
+ )
160
+ id_mapping[unique_ids] = torch.arange(len(unique_ids), device=self.device)
161
+ backbone_out = {
162
+ **backbone_out,
163
+ **self.backbone.forward_image(image),
164
+ "id_mapping": id_mapping,
165
+ }
166
+ assert "backbone_fpn" in backbone_out
167
+ return self._get_img_feats(backbone_out, img_ids=img_ids)
168
+
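
On the on-the-fly path above, features are computed only for the unique frame ids, and `id_mapping` records where each original frame id landed in the recomputed batch. A small standalone sketch of that bookkeeping, with assumed toy values:

    import torch

    img_ids = torch.tensor([3, 3, 7])                            # requested frames, with a repeat
    unique_ids, _ = torch.unique(img_ids, return_inverse=True)   # tensor([3, 7])
    id_mapping = torch.full((10,), -1, dtype=torch.long)         # one slot per frame in img_batch
    id_mapping[unique_ids] = torch.arange(len(unique_ids))
    # On a later call, backbone_out["id_mapping"][img_ids] gives tensor([0, 0, 1]),
    # i.e. indices into the freshly computed features; -1 marks frames that were never computed.
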
169
+ def _encode_prompt(
170
+ self,
171
+ backbone_out,
172
+ find_input,
173
+ geometric_prompt,
174
+ visual_prompt_embed=None,
175
+ visual_prompt_mask=None,
176
+ encode_text=True,
177
+ prev_mask_pred=None,
178
+ ):
179
+ # index text features (note that regardless of early or late fusion, the batch size of
180
+ # `txt_feats` is always the number of *prompts* in the encoder)
181
+ txt_ids = find_input.text_ids
182
+ txt_feats = backbone_out["language_features"][:, txt_ids]
183
+ txt_masks = backbone_out["language_mask"][txt_ids]
184
+
185
+ feat_tuple = self._get_img_feats(backbone_out, find_input.img_ids)
186
+ backbone_out, img_feats, img_pos_embeds, vis_feat_sizes = feat_tuple
187
+
188
+ if prev_mask_pred is not None:
189
+ img_feats = [img_feats[-1] + prev_mask_pred]
190
+ # Encode geometry
191
+ geo_feats, geo_masks = self.geometry_encoder(
192
+ geo_prompt=geometric_prompt,
193
+ img_feats=img_feats,
194
+ img_sizes=vis_feat_sizes,
195
+ img_pos_embeds=img_pos_embeds,
196
+ )
197
+ if visual_prompt_embed is None:
198
+ visual_prompt_embed = torch.zeros(
199
+ (0, *geo_feats.shape[1:]), device=geo_feats.device
200
+ )
201
+ visual_prompt_mask = torch.zeros(
202
+ (*geo_masks.shape[:-1], 0),
203
+ device=geo_masks.device,
204
+ dtype=geo_masks.dtype,
205
+ )
206
+ if encode_text:
207
+ prompt = torch.cat([txt_feats, geo_feats, visual_prompt_embed], dim=0)
208
+ prompt_mask = torch.cat([txt_masks, geo_masks, visual_prompt_mask], dim=1)
209
+ else:
210
+ prompt = torch.cat([geo_feats, visual_prompt_embed], dim=0)
211
+ prompt_mask = torch.cat([geo_masks, visual_prompt_mask], dim=1)
212
+ return prompt, prompt_mask, backbone_out
213
+
214
+ def _run_encoder(
215
+ self,
216
+ backbone_out,
217
+ find_input,
218
+ prompt,
219
+ prompt_mask,
220
+ encoder_extra_kwargs: Optional[Dict] = None,
221
+ ):
222
+ feat_tuple = self._get_img_feats(backbone_out, find_input.img_ids)
223
+ backbone_out, img_feats, img_pos_embeds, vis_feat_sizes = feat_tuple
224
+
225
+ # Run the encoder
226
+ prompt_pos_embed = torch.zeros_like(prompt)
227
+ # make a copy of the image feature lists since the encoder may modify these lists in-place
228
+ memory = self.transformer.encoder(
229
+ src=img_feats.copy(),
230
+ src_key_padding_mask=None,
231
+ src_pos=img_pos_embeds.copy(),
232
+ prompt=prompt,
233
+ prompt_pos=prompt_pos_embed,
234
+ prompt_key_padding_mask=prompt_mask,
235
+ feat_sizes=vis_feat_sizes,
236
+ encoder_extra_kwargs=encoder_extra_kwargs,
237
+ )
238
+ encoder_out = {
239
+ # encoded image features
240
+ "encoder_hidden_states": memory["memory"],
241
+ "pos_embed": memory["pos_embed"],
242
+ "padding_mask": memory["padding_mask"],
243
+ "level_start_index": memory["level_start_index"],
244
+ "spatial_shapes": memory["spatial_shapes"],
245
+ "valid_ratios": memory["valid_ratios"],
246
+ "vis_feat_sizes": vis_feat_sizes,
247
+ # encoded text features (or other prompts)
248
+ "prompt_before_enc": prompt,
249
+ "prompt_after_enc": memory.get("memory_text", prompt),
250
+ "prompt_mask": prompt_mask,
251
+ }
252
+ return backbone_out, encoder_out, feat_tuple
253
+
254
+ def _run_decoder(
255
+ self,
256
+ pos_embed,
257
+ memory,
258
+ src_mask,
259
+ out,
260
+ prompt,
261
+ prompt_mask,
262
+ encoder_out,
263
+ ):
264
+ bs = memory.shape[1]
265
+ query_embed = self.transformer.decoder.query_embed.weight
266
+ tgt = query_embed.unsqueeze(1).repeat(1, bs, 1)
267
+
268
+ apply_dac = self.transformer.decoder.dac and self.training
269
+ hs, reference_boxes, dec_presence_out, dec_presence_feats = (
270
+ self.transformer.decoder(
271
+ tgt=tgt,
272
+ memory=memory,
273
+ memory_key_padding_mask=src_mask,
274
+ pos=pos_embed,
275
+ reference_boxes=None,
276
+ level_start_index=encoder_out["level_start_index"],
277
+ spatial_shapes=encoder_out["spatial_shapes"],
278
+ valid_ratios=encoder_out["valid_ratios"],
279
+ tgt_mask=None,
280
+ memory_text=prompt,
281
+ text_attention_mask=prompt_mask,
282
+ apply_dac=apply_dac,
283
+ )
284
+ )
285
+ hs = hs.transpose(1, 2) # seq-first to batch-first
286
+ reference_boxes = reference_boxes.transpose(1, 2) # seq-first to batch-first
287
+ if dec_presence_out is not None:
288
+ # seq-first to batch-first
289
+ dec_presence_out = dec_presence_out.transpose(1, 2)
290
+
291
+ out["presence_feats"] = dec_presence_feats
292
+ self._update_scores_and_boxes(
293
+ out,
294
+ hs,
295
+ reference_boxes,
296
+ prompt,
297
+ prompt_mask,
298
+ dec_presence_out=dec_presence_out,
299
+ )
300
+ return out, hs
301
+
302
+ def _update_scores_and_boxes(
303
+ self,
304
+ out,
305
+ hs,
306
+ reference_boxes,
307
+ prompt,
308
+ prompt_mask,
309
+ dec_presence_out=None,
310
+ is_instance_prompt=False,
311
+ ):
312
+ apply_dac = self.transformer.decoder.dac and self.training
313
+ num_o2o = (hs.size(2) // 2) if apply_dac else hs.size(2)
314
+ num_o2m = hs.size(2) - num_o2o
315
+ assert num_o2m == (num_o2o if apply_dac else 0)
316
+ out["queries"] = hs[-1][:, :num_o2o] # remove o2m queries if there are any
317
+ # score prediction
318
+ if self.use_dot_prod_scoring:
319
+ dot_prod_scoring_head = self.dot_prod_scoring
320
+ if is_instance_prompt and self.instance_dot_prod_scoring is not None:
321
+ dot_prod_scoring_head = self.instance_dot_prod_scoring
322
+ outputs_class = dot_prod_scoring_head(hs, prompt, prompt_mask)
323
+ else:
324
+ class_embed_head = self.class_embed
325
+ if is_instance_prompt and self.instance_class_embed is not None:
326
+ class_embed_head = self.instance_class_embed
327
+ outputs_class = class_embed_head(hs)
328
+
329
+ # box prediction
330
+ box_head = self.transformer.decoder.bbox_embed
331
+ if (
332
+ is_instance_prompt
333
+ and self.transformer.decoder.instance_bbox_embed is not None
334
+ ):
335
+ box_head = self.transformer.decoder.instance_bbox_embed
336
+ anchor_box_offsets = box_head(hs)
337
+ reference_boxes_inv_sig = inverse_sigmoid(reference_boxes)
338
+ outputs_coord = (reference_boxes_inv_sig + anchor_box_offsets).sigmoid()
339
+ outputs_boxes_xyxy = box_cxcywh_to_xyxy(outputs_coord)
340
+
341
+ if dec_presence_out is not None:
342
+ _update_out(
343
+ out, "presence_logit_dec", dec_presence_out, update_aux=self.training
344
+ )
345
+
346
+ if self.supervise_joint_box_scores:
347
+ assert dec_presence_out is not None
348
+ prob_dec_presence_out = dec_presence_out.clone().sigmoid()
349
+ if self.detach_presence_in_joint_score:
350
+ prob_dec_presence_out = prob_dec_presence_out.detach()
351
+
352
+ outputs_class = inverse_sigmoid(
353
+ outputs_class.sigmoid() * prob_dec_presence_out.unsqueeze(2)
354
+ ).clamp(min=-10.0, max=10.0)
355
+
356
+ _update_out(
357
+ out, "pred_logits", outputs_class[:, :, :num_o2o], update_aux=self.training
358
+ )
359
+ _update_out(
360
+ out, "pred_boxes", outputs_coord[:, :, :num_o2o], update_aux=self.training
361
+ )
362
+ _update_out(
363
+ out,
364
+ "pred_boxes_xyxy",
365
+ outputs_boxes_xyxy[:, :, :num_o2o],
366
+ update_aux=self.training,
367
+ )
368
+ if num_o2m > 0 and self.training:
369
+ _update_out(
370
+ out,
371
+ "pred_logits_o2m",
372
+ outputs_class[:, :, num_o2o:],
373
+ update_aux=self.training,
374
+ )
375
+ _update_out(
376
+ out,
377
+ "pred_boxes_o2m",
378
+ outputs_coord[:, :, num_o2o:],
379
+ update_aux=self.training,
380
+ )
381
+ _update_out(
382
+ out,
383
+ "pred_boxes_xyxy_o2m",
384
+ outputs_boxes_xyxy[:, :, num_o2o:],
385
+ update_aux=self.training,
386
+ )
387
+
388
+ def _run_segmentation_heads(
389
+ self,
390
+ out,
391
+ backbone_out,
392
+ img_ids,
393
+ vis_feat_sizes,
394
+ encoder_hidden_states,
395
+ prompt,
396
+ prompt_mask,
397
+ hs,
398
+ ):
399
+ apply_dac = self.transformer.decoder.dac and self.training
400
+ if self.segmentation_head is not None:
401
+ num_o2o = (hs.size(2) // 2) if apply_dac else hs.size(2)
402
+ num_o2m = hs.size(2) - num_o2o
403
+ obj_queries = hs if self.o2m_mask_predict else hs[:, :, :num_o2o]
404
+ seg_head_outputs = activation_ckpt_wrapper(self.segmentation_head)(
405
+ backbone_feats=backbone_out["backbone_fpn"],
406
+ obj_queries=obj_queries,
407
+ image_ids=img_ids,
408
+ encoder_hidden_states=encoder_hidden_states,
409
+ act_ckpt_enable=self.training and self.use_act_checkpoint_seg_head,
410
+ prompt=prompt,
411
+ prompt_mask=prompt_mask,
412
+ )
413
+ aux_masks = False # self.aux_loss and self.segmentation_head.aux_masks
414
+ for k, v in seg_head_outputs.items():
415
+ if k in self.segmentation_head.instance_keys:
416
+ _update_out(out, k, v[:, :num_o2o], auxiliary=aux_masks)
417
+ if (
418
+ self.o2m_mask_predict and num_o2m > 0
419
+ ): # handle o2m mask prediction
420
+ _update_out(
421
+ out, f"{k}_o2m", v[:, num_o2o:], auxiliary=aux_masks
422
+ )
423
+ else:
424
+ out[k] = v
425
+ else:
426
+ backbone_out.pop("backbone_fpn", None)
427
+
428
+ def _get_best_mask(self, out):
429
+ prev_mask_idx = out["pred_logits"].argmax(dim=1).squeeze(1)
430
+ batch_idx = torch.arange(
431
+ out["pred_logits"].shape[0], device=prev_mask_idx.device
432
+ )
433
+ prev_mask_pred = out["pred_masks"][batch_idx, prev_mask_idx][:, None]
434
+ # Downsample mask to match image resolution.
435
+ prev_mask_pred = self.geometry_encoder.mask_encoder.mask_downsampler(
436
+ prev_mask_pred
437
+ )
438
+ prev_mask_pred = prev_mask_pred.flatten(-2).permute(2, 0, 1)
439
+
440
+ return prev_mask_pred
441
+
442
+ def forward_grounding(
443
+ self,
444
+ backbone_out,
445
+ find_input,
446
+ find_target,
447
+ geometric_prompt: Prompt,
448
+ ):
449
+ with torch.profiler.record_function("SAM3Image._encode_prompt"):
450
+ prompt, prompt_mask, backbone_out = self._encode_prompt(
451
+ backbone_out, find_input, geometric_prompt
452
+ )
453
+ # Run the encoder
454
+ with torch.profiler.record_function("SAM3Image._run_encoder"):
455
+ backbone_out, encoder_out, _ = self._run_encoder(
456
+ backbone_out, find_input, prompt, prompt_mask
457
+ )
458
+ out = {
459
+ "encoder_hidden_states": encoder_out["encoder_hidden_states"],
460
+ "prev_encoder_out": {
461
+ "encoder_out": encoder_out,
462
+ "backbone_out": backbone_out,
463
+ },
464
+ }
465
+
466
+ # Run the decoder
467
+ with torch.profiler.record_function("SAM3Image._run_decoder"):
468
+ out, hs = self._run_decoder(
469
+ memory=out["encoder_hidden_states"],
470
+ pos_embed=encoder_out["pos_embed"],
471
+ src_mask=encoder_out["padding_mask"],
472
+ out=out,
473
+ prompt=prompt,
474
+ prompt_mask=prompt_mask,
475
+ encoder_out=encoder_out,
476
+ )
477
+
478
+ # Run segmentation heads
479
+ with torch.profiler.record_function("SAM3Image._run_segmentation_heads"):
480
+ self._run_segmentation_heads(
481
+ out=out,
482
+ backbone_out=backbone_out,
483
+ img_ids=find_input.img_ids,
484
+ vis_feat_sizes=encoder_out["vis_feat_sizes"],
485
+ encoder_hidden_states=out["encoder_hidden_states"],
486
+ prompt=prompt,
487
+ prompt_mask=prompt_mask,
488
+ hs=hs,
489
+ )
490
+
491
+ if self.training or self.num_interactive_steps_val > 0:
492
+ self._compute_matching(out, self.back_convert(find_target))
493
+ return out
494
+
495
+ def _postprocess_out(self, out: Dict, multimask_output: bool = False):
496
+ # For multimask output during eval, we return the single best mask under the dict keys expected by the evaluators, and additionally expose the full multi-mask outputs under new `multi_*` keys.
497
+ num_mask_boxes = out["pred_boxes"].size(1)
498
+ if not self.training and multimask_output and num_mask_boxes > 1:
499
+ out["multi_pred_logits"] = out["pred_logits"]
500
+ if "pred_masks" in out:
501
+ out["multi_pred_masks"] = out["pred_masks"]
502
+ out["multi_pred_boxes"] = out["pred_boxes"]
503
+ out["multi_pred_boxes_xyxy"] = out["pred_boxes_xyxy"]
504
+
505
+ best_mask_idx = out["pred_logits"].argmax(1).squeeze(1)
506
+ batch_idx = torch.arange(len(best_mask_idx), device=best_mask_idx.device)
507
+
508
+ out["pred_logits"] = out["pred_logits"][batch_idx, best_mask_idx].unsqueeze(
509
+ 1
510
+ )
511
+ if "pred_masks" in out:
512
+ out["pred_masks"] = out["pred_masks"][
513
+ batch_idx, best_mask_idx
514
+ ].unsqueeze(1)
515
+ out["pred_boxes"] = out["pred_boxes"][batch_idx, best_mask_idx].unsqueeze(1)
516
+ out["pred_boxes_xyxy"] = out["pred_boxes_xyxy"][
517
+ batch_idx, best_mask_idx
518
+ ].unsqueeze(1)
519
+
520
+ return out
521
+
522
+ def _get_dummy_prompt(self, num_prompts=1):
523
+ device = self.device
524
+ geometric_prompt = Prompt(
525
+ box_embeddings=torch.zeros(0, num_prompts, 4, device=device),
526
+ box_mask=torch.zeros(num_prompts, 0, device=device, dtype=torch.bool),
527
+ )
528
+ return geometric_prompt
529
+
530
+ def forward(self, input: BatchedDatapoint):
531
+ device = self.device
532
+ backbone_out = {"img_batch_all_stages": input.img_batch}
533
+ backbone_out.update(self.backbone.forward_image(input.img_batch))
534
+ num_frames = len(input.find_inputs)
535
+ assert num_frames == 1
536
+
537
+ text_outputs = self.backbone.forward_text(input.find_text_batch, device=device)
538
+ backbone_out.update(text_outputs)
539
+
540
+ previous_stages_out = SAM3Output(
541
+ iter_mode=SAM3Output.IterMode.LAST_STEP_PER_STAGE
542
+ )
543
+
544
+ find_input = input.find_inputs[0]
545
+ find_target = input.find_targets[0]
546
+
547
+ if find_input.input_points is not None and find_input.input_points.numel() > 0:
548
+ print("Warning: Point prompts are ignored in PCS.")
549
+
550
+ num_interactive_steps = 0 if self.training else self.num_interactive_steps_val
551
+ geometric_prompt = Prompt(
552
+ box_embeddings=find_input.input_boxes,
553
+ box_mask=find_input.input_boxes_mask,
554
+ box_labels=find_input.input_boxes_label,
555
+ )
556
+
557
+ # Init vars that are shared across the loop.
558
+ stage_outs = []
559
+ for cur_step in range(num_interactive_steps + 1):
560
+ if cur_step > 0:
561
+ # We sample interactive geometric prompts (boxes, points)
562
+ geometric_prompt, _ = self.interactive_prompt_sampler.sample(
563
+ geo_prompt=geometric_prompt,
564
+ find_target=find_target,
565
+ previous_out=stage_outs[-1],
566
+ )
567
+ out = self.forward_grounding(
568
+ backbone_out=backbone_out,
569
+ find_input=find_input,
570
+ find_target=find_target,
571
+ geometric_prompt=geometric_prompt.clone(),
572
+ )
573
+ stage_outs.append(out)
574
+
575
+ previous_stages_out.append(stage_outs)
576
+ return previous_stages_out
577
+
578
+ def _compute_matching(self, out, targets):
579
+ out["indices"] = self.matcher(out, targets)
580
+ for aux_out in out.get("aux_outputs", []):
581
+ aux_out["indices"] = self.matcher(aux_out, targets)
582
+
583
+ def back_convert(self, targets):
584
+ batched_targets = {
585
+ "boxes": targets.boxes.view(-1, 4),
586
+ "boxes_xyxy": box_cxcywh_to_xyxy(targets.boxes.view(-1, 4)),
587
+ "boxes_padded": targets.boxes_padded,
588
+ "positive_map": targets.boxes.new_ones(len(targets.boxes), 1),
589
+ "num_boxes": targets.num_boxes,
590
+ "masks": targets.segments,
591
+ "semantic_masks": targets.semantic_segments,
592
+ "is_valid_mask": targets.is_valid_segment,
593
+ "is_exhaustive": targets.is_exhaustive,
594
+ "object_ids_packed": targets.object_ids,
595
+ "object_ids_padded": targets.object_ids_padded,
596
+ }
597
+ return batched_targets
598
+
599
+ def predict_inst(
600
+ self,
601
+ inference_state,
602
+ **kwargs,
603
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
604
+ orig_h, orig_w = (
605
+ inference_state["original_height"],
606
+ inference_state["original_width"],
607
+ )
608
+ backbone_out = inference_state["backbone_out"]["sam2_backbone_out"]
609
+ (
610
+ _,
611
+ vision_feats,
612
+ _,
613
+ _,
614
+ ) = self.inst_interactive_predictor.model._prepare_backbone_features(
615
+ backbone_out
616
+ )
617
+ # Add no_mem_embed, which is added to the lowest res feat. map during training on videos
618
+ vision_feats[-1] = (
619
+ vision_feats[-1] + self.inst_interactive_predictor.model.no_mem_embed
620
+ )
621
+ feats = [
622
+ feat.permute(1, 2, 0).view(1, -1, *feat_size)
623
+ for feat, feat_size in zip(
624
+ vision_feats[::-1], self.inst_interactive_predictor._bb_feat_sizes[::-1]
625
+ )
626
+ ][::-1]
627
+ self.inst_interactive_predictor._features = {
628
+ "image_embed": feats[-1],
629
+ "high_res_feats": feats[:-1],
630
+ }
631
+ self.inst_interactive_predictor._is_image_set = True
632
+ self.inst_interactive_predictor._orig_hw = [(orig_h, orig_w)]
633
+ res = self.inst_interactive_predictor.predict(**kwargs)
634
+ self.inst_interactive_predictor._features = None
635
+ self.inst_interactive_predictor._is_image_set = False
636
+ return res
637
+
638
+ def predict_inst_batch(
639
+ self,
640
+ inference_state,
641
+ *args,
642
+ **kwargs,
643
+ ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
644
+ backbone_out = inference_state["backbone_out"]["sam2_backbone_out"]
645
+ (
646
+ _,
647
+ vision_feats,
648
+ _,
649
+ _,
650
+ ) = self.inst_interactive_predictor.model._prepare_backbone_features(
651
+ backbone_out
652
+ )
653
+ # Add no_mem_embed, which is added to the lowest res feat. map during training on videos
654
+ vision_feats[-1] = (
655
+ vision_feats[-1] + self.inst_interactive_predictor.model.no_mem_embed
656
+ )
657
+ batch_size = vision_feats[-1].shape[1]
658
+ orig_heights, orig_widths = (
659
+ inference_state["original_heights"],
660
+ inference_state["original_widths"],
661
+ )
662
+ assert (
663
+ batch_size == len(orig_heights) == len(orig_widths)
664
+ ), f"Batch size mismatch in predict_inst_batch. Got {batch_size}, {len(orig_heights)}, {len(orig_widths)}"
665
+ feats = [
666
+ feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
667
+ for feat, feat_size in zip(
668
+ vision_feats[::-1], self.inst_interactive_predictor._bb_feat_sizes[::-1]
669
+ )
670
+ ][::-1]
671
+ self.inst_interactive_predictor._features = {
672
+ "image_embed": feats[-1],
673
+ "high_res_feats": feats[:-1],
674
+ }
675
+ self.inst_interactive_predictor._is_image_set = True
676
+ self.inst_interactive_predictor._is_batch = True
677
+ self.inst_interactive_predictor._orig_hw = [
678
+ (orig_h, orig_w) for orig_h, orig_w in zip(orig_heights, orig_widths)
679
+ ]
680
+ res = self.inst_interactive_predictor.predict_batch(*args, **kwargs)
681
+ self.inst_interactive_predictor._features = None
682
+ self.inst_interactive_predictor._is_image_set = False
683
+ self.inst_interactive_predictor._is_batch = False
684
+ return res
685
+
686
+
687
+ class Sam3ImageOnVideoMultiGPU(Sam3Image):
688
+ def __init__(
689
+ self, *args, async_all_gather=True, gather_backbone_out=None, **kwargs
690
+ ):
691
+ super().__init__(*args, **kwargs)
692
+ self.rank = int(os.getenv("RANK", "0"))
693
+ self.world_size = int(os.getenv("WORLD_SIZE", "1"))
694
+ self.async_all_gather = async_all_gather
695
+
696
+ # if gather_backbone is not set, default to gathering only for `SAM3VLBackbone`
697
+ if gather_backbone_out is None:
698
+ gather_backbone_out = isinstance(self.backbone, SAM3VLBackbone)
699
+ self.gather_backbone_out = gather_backbone_out
700
+
701
+ def forward_video_grounding_multigpu(
702
+ self,
703
+ backbone_out,
704
+ find_inputs,
705
+ geometric_prompt: Prompt,
706
+ frame_idx,
707
+ num_frames,
708
+ # `multigpu_buffer` is a dict to cache detector's outputs in a chunk between different calls
709
+ multigpu_buffer,
710
+ track_in_reverse=False,
711
+ # whether to also return the SAM2 backbone features
712
+ return_sam2_backbone_feats=False,
713
+ # whether to perform NMS and suppress the scores of those detections removed by NMS
714
+ run_nms=False,
715
+ nms_prob_thresh=None,
716
+ nms_iou_thresh=None,
717
+ **kwargs,
718
+ ):
719
+ """
720
+ Compute the detector's detection outputs in a distributed manner, where all GPUs process
721
+ a chunk of frames (equal to the number of GPUs) at once and store them in cache.
722
+ """
723
+ # Step 1: fetch the detector outputs in the current chunk from buffer
724
+ frame_idx_curr_b = frame_idx - frame_idx % self.world_size
725
+ frame_idx_curr_e = min(frame_idx_curr_b + self.world_size, num_frames)
726
+ # in case the current frame's detection results are not in the buffer yet, build the current chunk
727
+ # (this should only happen on the first chunk, since we are also building the next chunk below)
728
+ if frame_idx not in multigpu_buffer:
729
+ with torch.profiler.record_function("build_multigpu_buffer_next_chunk1"):
730
+ self._build_multigpu_buffer_next_chunk(
731
+ backbone_out=backbone_out,
732
+ find_inputs=find_inputs,
733
+ geometric_prompt=geometric_prompt,
734
+ frame_idx_begin=frame_idx_curr_b,
735
+ frame_idx_end=frame_idx_curr_e,
736
+ num_frames=num_frames,
737
+ multigpu_buffer=multigpu_buffer,
738
+ run_nms=run_nms,
739
+ nms_prob_thresh=nms_prob_thresh,
740
+ nms_iou_thresh=nms_iou_thresh,
741
+ )
742
+
743
+ # read out the current frame's results from `multigpu_buffer`
744
+ out = {}
745
+ for k, (v, handle) in multigpu_buffer[frame_idx].items():
746
+ if k.startswith("sam2_backbone_") and not return_sam2_backbone_feats:
747
+ continue
748
+ if handle is not None:
749
+ handle.wait() # wait for async all-gather to finish
750
+ out[k] = v
751
+
752
+ # Step 2: remove detection outputs of the previous chunk from cache to save GPU memory
753
+ if not track_in_reverse and frame_idx_curr_b - self.world_size >= 0:
754
+ frame_idx_prev_e = frame_idx_curr_b
755
+ frame_idx_prev_b = frame_idx_curr_b - self.world_size
756
+ elif track_in_reverse and frame_idx_curr_e < num_frames:
757
+ frame_idx_prev_b = frame_idx_curr_e
758
+ frame_idx_prev_e = min(frame_idx_prev_b + self.world_size, num_frames)
759
+ else:
760
+ frame_idx_prev_b = frame_idx_prev_e = None
761
+ if frame_idx_prev_b is not None:
762
+ for frame_idx_rm in range(frame_idx_prev_b, frame_idx_prev_e):
763
+ multigpu_buffer.pop(frame_idx_rm, None)
764
+
765
+ # Step 3: compute and cache detection outputs of the next chunk ahead of time
766
+ # (so that we can overlap computation with all-gather transfer)
767
+ if not track_in_reverse and frame_idx_curr_e < num_frames:
768
+ frame_idx_next_b = frame_idx_curr_e
769
+ frame_idx_next_e = min(frame_idx_next_b + self.world_size, num_frames)
770
+ elif track_in_reverse and frame_idx_curr_b - self.world_size >= 0:
771
+ frame_idx_next_e = frame_idx_curr_b
772
+ frame_idx_next_b = frame_idx_curr_b - self.world_size
773
+ else:
774
+ frame_idx_next_b = frame_idx_next_e = None
775
+ if frame_idx_next_b is not None and frame_idx_next_b not in multigpu_buffer:
776
+ with torch.profiler.record_function("build_multigpu_buffer_next_chunk2"):
777
+ self._build_multigpu_buffer_next_chunk(
778
+ backbone_out=backbone_out,
779
+ find_inputs=find_inputs,
780
+ geometric_prompt=geometric_prompt,
781
+ frame_idx_begin=frame_idx_next_b,
782
+ frame_idx_end=frame_idx_next_e,
783
+ num_frames=num_frames,
784
+ multigpu_buffer=multigpu_buffer,
785
+ run_nms=run_nms,
786
+ nms_prob_thresh=nms_prob_thresh,
787
+ nms_iou_thresh=nms_iou_thresh,
788
+ )
789
+
790
+ return out, backbone_out
791
+
792
+ def _build_multigpu_buffer_next_chunk(
793
+ self,
794
+ backbone_out,
795
+ find_inputs,
796
+ geometric_prompt: Prompt,
797
+ frame_idx_begin,
798
+ frame_idx_end,
799
+ num_frames,
800
+ multigpu_buffer,
801
+ run_nms=False,
802
+ nms_prob_thresh=None,
803
+ nms_iou_thresh=None,
804
+ ):
805
+ """Compute detection outputs on a chunk of frames and store their results in multigpu_buffer."""
806
+ # each GPU computes detections on one frame in the chunk (in a round-robin manner)
807
+ frame_idx_local_gpu = min(frame_idx_begin + self.rank, frame_idx_end - 1)
808
+ # `forward_grounding` (from base class `Sam3ImageOnVideo`) runs the detector on a single frame
809
+ with torch.profiler.record_function("forward_grounding"):
810
+ out_local = self.forward_grounding(
811
+ backbone_out=backbone_out,
812
+ find_input=find_inputs[frame_idx_local_gpu],
813
+ find_target=None,
814
+ geometric_prompt=geometric_prompt,
815
+ )
816
+ if run_nms:
817
+ with torch.profiler.record_function("nms_masks"):
818
+ # run NMS as a post-processing step on top of the detection outputs
819
+ assert nms_prob_thresh is not None and nms_iou_thresh is not None
820
+ pred_probs = out_local["pred_logits"].squeeze(-1).sigmoid()
821
+ pred_masks = out_local["pred_masks"]
822
+ # loop over text prompts (not an overhead for demo where there's only 1 prompt)
823
+ for prompt_idx in range(pred_probs.size(0)):
824
+ keep = nms_masks(
825
+ pred_probs=pred_probs[prompt_idx],
826
+ pred_masks=pred_masks[prompt_idx],
827
+ prob_threshold=nms_prob_thresh,
828
+ iou_threshold=nms_iou_thresh,
829
+ )
830
+ # set a very low threshold for those detections removed by NMS
831
+ out_local["pred_logits"][prompt_idx, :, 0] -= 1e4 * (~keep).float()
832
+
833
+ if self.gather_backbone_out:
834
+ # gather the SAM 2 backbone features across GPUs
835
+ feats = out_local["prev_encoder_out"]["backbone_out"]["sam2_backbone_out"]
836
+ assert len(feats["backbone_fpn"]) == 3 # SAM2 backbone always has 3 levels
837
+ # cast the SAM2 backbone features to bfloat16 for all-gather (this is usually
838
+ # a no-op, SAM2 backbone features are likely already in bfloat16 due to AMP)
839
+ backbone_fpn_bf16 = [x.to(torch.bfloat16) for x in feats["backbone_fpn"]]
840
+ fpn0, fpn_handle0 = self._gather_tensor(backbone_fpn_bf16[0])
841
+ fpn1, fpn_handle1 = self._gather_tensor(backbone_fpn_bf16[1])
842
+ fpn2, fpn_handle2 = self._gather_tensor(backbone_fpn_bf16[2])
843
+ # vision_pos_enc is the same on all frames, so no need to all-gather them
844
+ vision_pos_enc = feats["vision_pos_enc"]
845
+
846
+ # trim the detector output to only include the necessary keys
847
+ out_local = {
848
+ "pred_logits": out_local["pred_logits"],
849
+ "pred_boxes": out_local["pred_boxes"],
850
+ "pred_boxes_xyxy": out_local["pred_boxes_xyxy"],
851
+ "pred_masks": out_local["pred_masks"],
852
+ }
853
+
854
+ # gather the results: after this step, each GPU will receive detector outputs on
855
+ # all frames in the chunk and store them in `multigpu_buffer`
856
+ out_gathered = {k: self._gather_tensor(v) for k, v in out_local.items()}
857
+ for rank in range(self.world_size):
858
+ frame_idx_to_save = frame_idx_begin + rank
859
+ if frame_idx_to_save >= num_frames:
860
+ continue
861
+ frame_buffer = {
862
+ k: (v[rank], handle) for k, (v, handle) in out_gathered.items()
863
+ }
864
+ if self.gather_backbone_out:
865
+ # also add gathered SAM 2 backbone features to frame_buffer
866
+ frame_buffer["tracker_backbone_fpn_0"] = (fpn0[rank], fpn_handle0)
867
+ frame_buffer["tracker_backbone_fpn_1"] = (fpn1[rank], fpn_handle1)
868
+ frame_buffer["tracker_backbone_fpn_2"] = (fpn2[rank], fpn_handle2)
869
+ frame_buffer["tracker_backbone_pos_enc"] = (vision_pos_enc, None)
870
+
871
+ multigpu_buffer[frame_idx_to_save] = frame_buffer
872
+
873
+ def _gather_tensor(self, x):
874
+ if self.world_size == 1:
875
+ return [x], None
876
+
877
+ async_op = self.async_all_gather
878
+ # here `.contiguous()` is required -- otherwise NCCL all_gather
879
+ # sometimes gives wrong results
880
+ x = x.contiguous() # ensure contiguous memory for NCCL
881
+ output_list = [torch.empty_like(x) for _ in range(self.world_size)]
882
+ handle = torch.distributed.all_gather(output_list, x, async_op=async_op)
883
+ return output_list, handle
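
The multi-GPU path above is a round-robin scheme: frames are grouped into chunks of `world_size`, each rank runs the detector on one frame of its chunk, and an (optionally async) all_gather distributes the per-frame outputs to every rank. A minimal sketch of the index arithmetic only, with assumed values and no collectives:

    world_size = 4
    num_frames = 10

    def chunk_bounds(frame_idx: int):
        # Same arithmetic as forward_video_grounding_multigpu: chunks are aligned to world_size.
        begin = frame_idx - frame_idx % world_size
        end = min(begin + world_size, num_frames)
        return begin, end

    print(chunk_bounds(5))   # (4, 8): ranks 0..3 detect frames 4, 5, 6 and 7
    print(chunk_bounds(9))   # (8, 10): ranks whose frame would fall outside the chunk clamp to frame 9
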
detect_tools/sam3/sam3/model/sam3_image_processor.py ADDED
@@ -0,0 +1,222 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+ from typing import Dict, List
3
+
4
+ import numpy as np
5
+ import PIL
6
+ import torch
7
+
8
+ from sam3.model import box_ops
9
+
10
+ from sam3.model.data_misc import FindStage, interpolate
11
+ from torchvision.transforms import v2
12
+
13
+
14
+ class Sam3Processor:
15
+ """ """
16
+
17
+ def __init__(self, model, resolution=1008, device="cuda", confidence_threshold=0.5):
18
+ self.model = model
19
+ self.resolution = resolution
20
+ self.device = device
21
+ self.transform = v2.Compose(
22
+ [
23
+ v2.ToDtype(torch.uint8, scale=True),
24
+ v2.Resize(size=(resolution, resolution)),
25
+ v2.ToDtype(torch.float32, scale=True),
26
+ v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
27
+ ]
28
+ )
29
+ self.confidence_threshold = confidence_threshold
30
+
31
+ self.find_stage = FindStage(
32
+ img_ids=torch.tensor([0], device=device, dtype=torch.long),
33
+ text_ids=torch.tensor([0], device=device, dtype=torch.long),
34
+ input_boxes=None,
35
+ input_boxes_mask=None,
36
+ input_boxes_label=None,
37
+ input_points=None,
38
+ input_points_mask=None,
39
+ )
40
+
41
+ @torch.inference_mode()
42
+ def set_image(self, image, state=None):
43
+ """Sets the image on which we want to do predictions."""
44
+ if state is None:
45
+ state = {}
46
+
47
+ if isinstance(image, PIL.Image.Image):
48
+ width, height = image.size
49
+ elif isinstance(image, (torch.Tensor, np.ndarray)):
50
+ height, width = image.shape[-2:]
51
+ else:
52
+ raise ValueError("Image must be a PIL image or a tensor")
53
+
54
+ image = v2.functional.to_image(image).to(self.device)
55
+ image = self.transform(image).unsqueeze(0)
56
+
57
+ state["original_height"] = height
58
+ state["original_width"] = width
59
+ state["backbone_out"] = self.model.backbone.forward_image(image)
60
+ inst_interactivity_en = self.model.inst_interactive_predictor is not None
61
+ if inst_interactivity_en and "sam2_backbone_out" in state["backbone_out"]:
62
+ sam2_backbone_out = state["backbone_out"]["sam2_backbone_out"]
63
+ sam2_backbone_out["backbone_fpn"][0] = (
64
+ self.model.inst_interactive_predictor.model.sam_mask_decoder.conv_s0(
65
+ sam2_backbone_out["backbone_fpn"][0]
66
+ )
67
+ )
68
+ sam2_backbone_out["backbone_fpn"][1] = (
69
+ self.model.inst_interactive_predictor.model.sam_mask_decoder.conv_s1(
70
+ sam2_backbone_out["backbone_fpn"][1]
71
+ )
72
+ )
73
+ return state
74
+
75
+ @torch.inference_mode()
76
+ def set_image_batch(self, images: List[np.ndarray], state=None):
77
+ """Sets the image batch on which we want to do predictions."""
78
+ if state is None:
79
+ state = {}
80
+
81
+ if not isinstance(images, list):
82
+ raise ValueError("Images must be a list of PIL images or tensors")
83
+ assert len(images) > 0, "Images list must not be empty"
84
+ assert isinstance(
85
+ images[0], PIL.Image.Image
86
+ ), "Images must be a list of PIL images"
87
+
88
+ state["original_heights"] = [image.height for image in images]
89
+ state["original_widths"] = [image.width for image in images]
90
+
91
+ images = [
92
+ self.transform(v2.functional.to_image(image).to(self.device))
93
+ for image in images
94
+ ]
95
+ images = torch.stack(images, dim=0)
96
+ state["backbone_out"] = self.model.backbone.forward_image(images)
97
+ inst_interactivity_en = self.model.inst_interactive_predictor is not None
98
+ if inst_interactivity_en and "sam2_backbone_out" in state["backbone_out"]:
99
+ sam2_backbone_out = state["backbone_out"]["sam2_backbone_out"]
100
+ sam2_backbone_out["backbone_fpn"][0] = (
101
+ self.model.inst_interactive_predictor.model.sam_mask_decoder.conv_s0(
102
+ sam2_backbone_out["backbone_fpn"][0]
103
+ )
104
+ )
105
+ sam2_backbone_out["backbone_fpn"][1] = (
106
+ self.model.inst_interactive_predictor.model.sam_mask_decoder.conv_s1(
107
+ sam2_backbone_out["backbone_fpn"][1]
108
+ )
109
+ )
110
+ return state
111
+
112
+ @torch.inference_mode()
113
+ def set_text_prompt(self, prompt: str, state: Dict):
114
+ """Sets the text prompt and run the inference"""
115
+
116
+ if "backbone_out" not in state:
117
+ raise ValueError("You must call set_image before set_text_prompt")
118
+
119
+ text_outputs = self.model.backbone.forward_text([prompt], device=self.device)
120
+ # will erase the previous text prompt if any
121
+ state["backbone_out"].update(text_outputs)
122
+ if "geometric_prompt" not in state:
123
+ state["geometric_prompt"] = self.model._get_dummy_prompt()
124
+
125
+ return self._forward_grounding(state)
126
+
127
+ @torch.inference_mode()
128
+ def add_geometric_prompt(self, box: List, label: bool, state: Dict):
129
+ """Adds a box prompt and run the inference.
130
+ The image needs to be set, but not necessarily the text prompt.
131
+ The box is assumed to be in [center_x, center_y, width, height] format and normalized in [0, 1] range.
132
+ The label is True for a positive box, False for a negative box.
133
+ """
134
+ if "backbone_out" not in state:
135
+ raise ValueError("You must call set_image before set_text_prompt")
136
+
137
+ if "language_features" not in state["backbone_out"]:
138
+ # Looks like we don't have a text prompt yet. This is allowed, but we need to set the text prompt to "visual" for the model to rely only on the geometric prompt
139
+ dummy_text_outputs = self.model.backbone.forward_text(
140
+ ["visual"], device=self.device
141
+ )
142
+ state["backbone_out"].update(dummy_text_outputs)
143
+
144
+ if "geometric_prompt" not in state:
145
+ state["geometric_prompt"] = self.model._get_dummy_prompt()
146
+
147
+ # adding a batch and sequence dimension
148
+ boxes = torch.tensor(box, device=self.device, dtype=torch.float32).view(1, 1, 4)
149
+ labels = torch.tensor([label], device=self.device, dtype=torch.bool).view(1, 1)
150
+ state["geometric_prompt"].append_boxes(boxes, labels)
151
+
152
+ return self._forward_grounding(state)
153
+
154
+ def reset_all_prompts(self, state: Dict):
155
+ """Removes all the prompts and results"""
156
+ if "backbone_out" in state:
157
+ backbone_keys_to_del = [
158
+ "language_features",
159
+ "language_mask",
160
+ "language_embeds",
161
+ ]
162
+ for key in backbone_keys_to_del:
163
+ if key in state["backbone_out"]:
164
+ del state["backbone_out"][key]
165
+
166
+ keys_to_del = ["geometric_prompt", "boxes", "masks", "masks_logits", "scores"]
167
+ for key in keys_to_del:
168
+ if key in state:
169
+ del state[key]
170
+
171
+ @torch.inference_mode()
172
+ def set_confidence_threshold(self, threshold: float, state=None):
173
+ """Sets the confidence threshold for the masks"""
174
+ self.confidence_threshold = threshold
175
+ if state is not None and "boxes" in state:
176
+ # we need to filter the boxes again
177
+ # In principle we could do this more efficiently since we would only need
178
+ # to rerun the heads. But this is simpler and not too inefficient
179
+ return self._forward_grounding(state)
180
+ return state
181
+
182
+ @torch.inference_mode()
183
+ def _forward_grounding(self, state: Dict):
184
+ outputs = self.model.forward_grounding(
185
+ backbone_out=state["backbone_out"],
186
+ find_input=self.find_stage,
187
+ geometric_prompt=state["geometric_prompt"],
188
+ find_target=None,
189
+ )
190
+
191
+ out_bbox = outputs["pred_boxes"]
192
+ out_logits = outputs["pred_logits"]
193
+ out_masks = outputs["pred_masks"]
194
+ out_probs = out_logits.sigmoid()
195
+ presence_score = outputs["presence_logit_dec"].sigmoid().unsqueeze(1)
196
+ out_probs = (out_probs * presence_score).squeeze(-1)
197
+
198
+ keep = out_probs > self.confidence_threshold
199
+ out_probs = out_probs[keep]
200
+ out_masks = out_masks[keep]
201
+ out_bbox = out_bbox[keep]
202
+
203
+ # convert to [x0, y0, x1, y1] format
204
+ boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
205
+
206
+ img_h = state["original_height"]
207
+ img_w = state["original_width"]
208
+ scale_fct = torch.tensor([img_w, img_h, img_w, img_h]).to(self.device)
209
+ boxes = boxes * scale_fct[None, :]
210
+
211
+ out_masks = interpolate(
212
+ out_masks.unsqueeze(1),
213
+ (img_h, img_w),
214
+ mode="bilinear",
215
+ align_corners=False,
216
+ ).sigmoid()
217
+
218
+ state["masks_logits"] = out_masks
219
+ state["masks"] = out_masks > 0.5
220
+ state["boxes"] = boxes
221
+ state["scores"] = out_probs
222
+ return state
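
`Sam3Processor` keeps all per-image state in a plain dict, so a typical session is set_image, then one or more prompt calls, then reading the thresholded outputs. A minimal usage sketch, assuming an already constructed `Sam3Image` model (the builder entry point is not part of this diff) and a placeholder image path:

    from PIL import Image

    processor = Sam3Processor(model, resolution=1008, device="cuda", confidence_threshold=0.5)

    state = processor.set_image(Image.open("example.jpg"))   # placeholder image
    state = processor.set_text_prompt("a red car", state)    # open-vocabulary text prompt
    boxes = state["boxes"]                                   # Nx4, xyxy in original-image pixels
    scores = state["scores"]                                 # per-detection confidence
    masks = state["masks"]                                   # boolean masks at original resolution

    # Optionally refine with a positive box prompt, given as normalized [cx, cy, w, h] in [0, 1]:
    state = processor.add_geometric_prompt([0.5, 0.5, 0.2, 0.3], label=True, state=state)

    processor.reset_all_prompts(state)                       # drop prompts and cached results
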
detect_tools/sam3/sam3/model/sam3_tracker_base.py ADDED
@@ -0,0 +1,1188 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ import logging
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+ from sam3.model.memory import SimpleMaskEncoder
9
+
10
+ from sam3.model.sam3_tracker_utils import get_1d_sine_pe, select_closest_cond_frames
11
+
12
+ from sam3.sam.mask_decoder import MaskDecoder, MLP
13
+ from sam3.sam.prompt_encoder import PromptEncoder
14
+ from sam3.sam.transformer import TwoWayTransformer
15
+ from sam3.train.data.collator import BatchedDatapoint
16
+
17
+ try:
18
+ from timm.layers import trunc_normal_
19
+ except ModuleNotFoundError:
20
+ # compatibility for older timm versions
21
+ from timm.models.layers import trunc_normal_
22
+
23
+ # a large negative value as a placeholder score for missing objects
24
+ NO_OBJ_SCORE = -1024.0
25
+
26
+
27
+ class Sam3TrackerBase(torch.nn.Module):
28
+ def __init__(
29
+ self,
30
+ backbone,
31
+ transformer,
32
+ maskmem_backbone,
33
+ num_maskmem=7, # default 1 input frame + 6 previous frames as in CAE
34
+ image_size=1008,
35
+ backbone_stride=14, # stride of the image backbone output
36
+ # The maximum number of conditioning frames to participate in the memory attention (-1 means no limit; if there are more conditioning frames than this limit,
37
+ # we only cross-attend to the temporally closest `max_cond_frames_in_attn` conditioning frames in the encoder when tracking each frame). This gives the model
38
+ # a temporal locality when handling a large number of annotated frames (since closer frames should be more important) and also avoids GPU OOM.
39
+ max_cond_frames_in_attn=-1,
40
+ # Whether to always keep the first conditioning frame in case we exceed the maximum number of conditioning frames allowed
41
+ keep_first_cond_frame=False,
42
+ # whether to output multiple (3) masks for the first click on initial conditioning frames
43
+ multimask_output_in_sam=False,
44
+ # the minimum and maximum number of clicks to use multimask_output_in_sam (only relevant when `multimask_output_in_sam=True`;
45
+ # default is 1 for both, meaning that only the first click gives multimask output; also note that a box counts as two points)
46
+ multimask_min_pt_num=1,
47
+ multimask_max_pt_num=1,
48
+ # whether to also use multimask output for tracking (not just for the first click on initial conditioning frames; only relevant when `multimask_output_in_sam=True`)
49
+ multimask_output_for_tracking=False,
50
+ # whether to forward image features per frame (as it's being tracked) during evaluation, instead of forwarding image features
51
+ # of all frames at once. This avoids backbone OOM errors on very long videos in evaluation, but could be slightly slower.
52
+ forward_backbone_per_frame_for_eval=False,
53
+ # The memory bank's temporal stride during evaluation (i.e. the `r` parameter in XMem and Cutie; XMem and Cutie use r=5).
54
+ # For r>1, the (self.num_maskmem - 1) non-conditioning memory frames consist of
55
+ # (self.num_maskmem - 2) nearest frames from every r-th frames, plus the last frame.
56
+ memory_temporal_stride_for_eval=1,
57
+ # whether to offload outputs to CPU memory during evaluation, to avoid GPU OOM on very long videos or very large resolutions or too many objects
58
+ # (it's recommended to use `forward_backbone_per_frame_for_eval=True` first before setting this option to True)
59
+ offload_output_to_cpu_for_eval=False,
60
+ # whether to trim the output of past non-conditioning frames (num_maskmem frames before the current frame) during evaluation
61
+ # (this helps save GPU or CPU memory on very long videos for semi-supervised VOS eval, where only the first frame receives prompts)
62
+ trim_past_non_cond_mem_for_eval=False,
63
+ # whether to apply non-overlapping constraints on the object masks in the memory encoder during evaluation (to avoid/alleviate superposing masks)
64
+ non_overlap_masks_for_mem_enc=False,
65
+ # the maximum number of object pointers from other frames in encoder cross attention
66
+ max_obj_ptrs_in_encoder=16,
67
+ # extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class.
68
+ sam_mask_decoder_extra_args=None,
69
+ # whether to compile all the model components
70
+ compile_all_components=False,
71
+ # whether to select memory frames based on predicted object presence (SAM2Long-style frame filtering)
72
+ use_memory_selection=False,
73
+ # when using memory selection, the threshold used to decide whether a frame is reliable enough to keep as memory
74
+ mf_threshold=0.01,
75
+ ):
76
+ super().__init__()
77
+
78
+ # Part 1: the image backbone
79
+ self.backbone = backbone
80
+ self.num_feature_levels = 3
81
+ self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder
82
+ # A conv layer to downsample the GT mask prompt to stride 4 (the same stride as
83
+ # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale,
84
+ # so that it can be fed into the SAM mask decoder to generate a pointer.
85
+ self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4)
86
+
87
+ # Part 2: encoder-only transformer to fuse current frame's visual features
88
+ # with memories from past frames
89
+ assert transformer.decoder is None, "transformer should be encoder-only"
90
+ self.transformer = transformer
91
+ self.hidden_dim = transformer.d_model
92
+
93
+ # Part 3: memory encoder for the previous frame's outputs
94
+ self.maskmem_backbone = maskmem_backbone
95
+ self.mem_dim = self.hidden_dim
96
+ if hasattr(self.maskmem_backbone, "out_proj") and hasattr(
97
+ self.maskmem_backbone.out_proj, "weight"
98
+ ):
99
+ # if there is compression of memories along channel dim
100
+ self.mem_dim = self.maskmem_backbone.out_proj.weight.shape[0]
101
+ self.num_maskmem = num_maskmem # Number of memories accessible
102
+
103
+ # Temporal encoding of the memories
104
+ self.maskmem_tpos_enc = torch.nn.Parameter(
105
+ torch.zeros(num_maskmem, 1, 1, self.mem_dim)
106
+ )
107
+ trunc_normal_(self.maskmem_tpos_enc, std=0.02)
108
+
109
+ # a single token to indicate no memory embedding from previous frames
110
+ self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
111
+ self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
112
+ trunc_normal_(self.no_mem_embed, std=0.02)
113
+ trunc_normal_(self.no_mem_pos_enc, std=0.02)
114
+ # Apply sigmoid to the output raw mask logits (to turn them from
115
+ # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder
116
+ self.sigmoid_scale_for_mem_enc = 20.0
117
+ self.sigmoid_bias_for_mem_enc = -10.0
118
+ self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc
119
+ self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval
120
+ # Whether to output multiple (3) candidate masks from the SAM head
121
+ # (see `_use_multimask` below for when multimask output is enabled)
122
+ self.multimask_output_in_sam = multimask_output_in_sam
123
+ self.multimask_min_pt_num = multimask_min_pt_num
124
+ self.multimask_max_pt_num = multimask_max_pt_num
125
+ self.multimask_output_for_tracking = multimask_output_for_tracking
126
+
127
+ # Part 4: SAM-style prompt encoder (for both mask and point inputs)
128
+ # and SAM-style mask decoder for the final mask output
129
+ self.image_size = image_size
130
+ self.backbone_stride = backbone_stride
131
+ self.low_res_mask_size = self.image_size // self.backbone_stride * 4
132
+ # we resize the mask if it doesn't match `self.input_mask_size` (which is always 4x
133
+ # the low-res mask size, regardless of the actual input image size); this is because
134
+ # `_use_mask_as_output` always downsamples the input masks by 4x
135
+ self.input_mask_size = self.low_res_mask_size * 4
136
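+ # e.g. with image_size=1024 and backbone_stride=16: low_res_mask_size = 256 and input_mask_size = 1024 (the full image resolution)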
+ self.forward_backbone_per_frame_for_eval = forward_backbone_per_frame_for_eval
137
+ self.offload_output_to_cpu_for_eval = offload_output_to_cpu_for_eval
138
+ self.trim_past_non_cond_mem_for_eval = trim_past_non_cond_mem_for_eval
139
+ self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args
140
+ self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
141
+ trunc_normal_(self.no_obj_ptr, std=0.02)
142
+ self.no_obj_embed_spatial = torch.nn.Parameter(torch.zeros(1, self.mem_dim))
143
+ trunc_normal_(self.no_obj_embed_spatial, std=0.02)
144
+
145
+ self._build_sam_heads()
146
+ self.max_cond_frames_in_attn = max_cond_frames_in_attn
147
+ self.keep_first_cond_frame = keep_first_cond_frame
148
+
149
+ # Use frame filtering according to SAM2Long
150
+ self.use_memory_selection = use_memory_selection
151
+ self.mf_threshold = mf_threshold
152
+
153
+ # Compile all components of the model
154
+ self.compile_all_components = compile_all_components
155
+ if self.compile_all_components:
156
+ self._compile_all_components()
157
+
158
+ @property
159
+ def device(self):
160
+ return next(self.parameters()).device
161
+
162
+ def _get_tpos_enc(self, rel_pos_list, device, max_abs_pos=None, dummy=False):
163
+ if dummy:
164
+ return torch.zeros(len(rel_pos_list), self.mem_dim, device=device)
165
+
166
+ t_diff_max = max_abs_pos - 1 if max_abs_pos is not None else 1
167
+ pos_enc = (
168
+ torch.tensor(rel_pos_list).pin_memory().to(device=device, non_blocking=True)
169
+ / t_diff_max
170
+ )
171
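+ # normalize the relative frame offsets by the largest possible offset so they fall roughly within [-1, 1] before the sine embedding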
+ tpos_dim = self.hidden_dim
172
+ pos_enc = get_1d_sine_pe(pos_enc, dim=tpos_dim)
173
+ pos_enc = self.obj_ptr_tpos_proj(pos_enc)
174
+
175
+ return pos_enc
176
+
177
+ def _build_sam_heads(self):
178
+ """Build SAM-style prompt encoder and mask decoder."""
179
+ self.sam_prompt_embed_dim = self.hidden_dim
180
+ self.sam_image_embedding_size = self.image_size // self.backbone_stride
181
+
182
+ # build PromptEncoder and MaskDecoder from SAM
183
+ # (their hyperparameters like `mask_in_chans=16` are from SAM code)
184
+ self.sam_prompt_encoder = PromptEncoder(
185
+ embed_dim=self.sam_prompt_embed_dim,
186
+ image_embedding_size=(
187
+ self.sam_image_embedding_size,
188
+ self.sam_image_embedding_size,
189
+ ),
190
+ input_image_size=(self.image_size, self.image_size),
191
+ mask_in_chans=16,
192
+ )
193
+ self.sam_mask_decoder = MaskDecoder(
194
+ num_multimask_outputs=3,
195
+ transformer=TwoWayTransformer(
196
+ depth=2,
197
+ embedding_dim=self.sam_prompt_embed_dim,
198
+ mlp_dim=2048,
199
+ num_heads=8,
200
+ ),
201
+ transformer_dim=self.sam_prompt_embed_dim,
202
+ iou_head_depth=3,
203
+ iou_head_hidden_dim=256,
204
+ use_high_res_features=True,
205
+ iou_prediction_use_sigmoid=True,
206
+ pred_obj_scores=True,
207
+ pred_obj_scores_mlp=True,
208
+ use_multimask_token_for_obj_ptr=True,
209
+ **(self.sam_mask_decoder_extra_args or {}),
210
+ )
211
+ # an MLP projection on SAM output tokens to turn them into object pointers
212
+ self.obj_ptr_proj = MLP(self.hidden_dim, self.hidden_dim, self.hidden_dim, 3)
214
+ # a linear projection on temporal positional encoding in object pointers to
215
+ # avoid potential interference with spatial positional encoding
216
+ self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim)
217
+
218
+ def _forward_sam_heads(
219
+ self,
220
+ backbone_features,
221
+ point_inputs=None,
222
+ mask_inputs=None,
223
+ high_res_features=None,
224
+ multimask_output=False,
225
+ gt_masks=None,
226
+ ):
227
+ """
228
+ Forward SAM prompt encoders and mask heads.
229
+
230
+ Inputs:
231
+ - backbone_features: image features of [B, C, H, W] shape
232
+ - point_inputs: a dictionary with "point_coords" and "point_labels", where
233
+ 1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the
234
+ absolute pixel-unit coordinate in (x, y) format of the P input points
235
+ 2) "point_labels" has shape [B, P] and int32 dtype, where 1 means
236
+ positive clicks, 0 means negative clicks, and -1 means padding
237
+ - mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the
238
+ same spatial size as the image.
239
+ - high_res_features: either 1) None or 2) a list of length 2 containing
240
+ two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively,
241
+ which will be used as high-resolution feature maps for SAM decoder.
242
+ - multimask_output: if it's True, we output 3 candidate masks and their 3
243
+ corresponding IoU estimates, and if it's False, we output only 1 mask and
244
+ its corresponding IoU estimate.
245
+
246
+ Outputs:
247
+ - low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if
248
+ `multimask_output=True` and M = 1 if `multimask_output=False`), the SAM
249
+ output mask logits (before sigmoid) for the low-resolution masks, with 4x
250
+ the resolution (1/4 stride) of the input backbone_features.
251
+ - high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3
252
+ if `multimask_output=True` and M = 1 if `multimask_output=False`),
253
+ upsampled from the low-resolution masks, with the same size as the image
254
+ (stride is 1 pixel).
255
+ - ious: [B, M] shape (where M = 3 if `multimask_output=True` and M = 1
256
+ if `multimask_output=False`), the estimated IoU of each output mask.
257
+ - low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`.
258
+ If `multimask_output=True`, it's the mask with the highest IoU estimate.
259
+ If `multimask_output=False`, it's the same as `low_res_multimasks`.
260
+ - high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`.
261
+ If `multimask_output=True`, it's the mask with the highest IoU estimate.
262
+ If `multimask_output=False`, it's the same as `high_res_multimasks`.
263
+ - obj_ptr: [B, C] shape, the object pointer vector for the output mask, extracted
264
+ based on the output token from the SAM mask decoder.
265
+ """
266
+ B = backbone_features.size(0)
267
+ device = backbone_features.device
268
+ assert backbone_features.size(1) == self.sam_prompt_embed_dim
269
+ assert backbone_features.size(2) == self.sam_image_embedding_size
270
+ assert backbone_features.size(3) == self.sam_image_embedding_size
271
+
272
+ # a) Handle point prompts
273
+ if point_inputs is not None:
274
+ sam_point_coords = point_inputs["point_coords"]
275
+ sam_point_labels = point_inputs["point_labels"]
276
+ assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
277
+ else:
278
+ # If no points are provided, pad with an empty point (with label -1)
279
+ sam_point_coords = torch.zeros(B, 1, 2, device=device)
280
+ sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
281
+
282
+ # b) Handle mask prompts
283
+ if mask_inputs is not None:
284
+ # If mask_inputs is provided, downsize it into low-res mask input if needed
285
+ # and feed it as a dense mask prompt into the SAM mask encoder
286
+ assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
287
+ if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
288
+ sam_mask_prompt = F.interpolate(
289
+ mask_inputs.float(),
290
+ size=self.sam_prompt_encoder.mask_input_size,
291
+ align_corners=False,
292
+ mode="bilinear",
293
+ antialias=True, # use antialias for downsampling
294
+ )
295
+ else:
296
+ sam_mask_prompt = mask_inputs
297
+ else:
298
+ # Otherwise, simply feed None (and SAM's prompt encoder will add
299
+ # a learned `no_mask_embed` to indicate no mask input in this case).
300
+ sam_mask_prompt = None
301
+
302
+ sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
303
+ points=(sam_point_coords, sam_point_labels),
304
+ boxes=None,
305
+ masks=sam_mask_prompt,
306
+ )
307
+ # Clone image_pe and the outputs of sam_prompt_encoder
308
+ # to enable compilation
309
+ sparse_embeddings = self._maybe_clone(sparse_embeddings)
310
+ dense_embeddings = self._maybe_clone(dense_embeddings)
311
+ image_pe = self._maybe_clone(self.sam_prompt_encoder.get_dense_pe())
312
+ with torch.profiler.record_function("sam_mask_decoder"):
313
+ (
314
+ low_res_multimasks,
315
+ ious,
316
+ sam_output_tokens,
317
+ object_score_logits,
318
+ ) = self.sam_mask_decoder(
319
+ image_embeddings=backbone_features,
320
+ image_pe=image_pe,
321
+ sparse_prompt_embeddings=sparse_embeddings,
322
+ dense_prompt_embeddings=dense_embeddings,
323
+ multimask_output=multimask_output,
324
+ repeat_image=False, # the image is already batched
325
+ high_res_features=high_res_features,
326
+ )
327
+ # Clone the output of sam_mask_decoder
328
+ # to enable compilation
329
+ low_res_multimasks = self._maybe_clone(low_res_multimasks)
330
+ ious = self._maybe_clone(ious)
331
+ sam_output_tokens = self._maybe_clone(sam_output_tokens)
332
+ object_score_logits = self._maybe_clone(object_score_logits)
333
+
334
+ if self.training and self.teacher_force_obj_scores_for_mem:
335
+ # we use gt to detect if there is an object or not to
336
+ # select no obj ptr and use an empty mask for spatial memory
337
+ is_obj_appearing = torch.any(gt_masks.float().flatten(1) > 0, dim=1)
338
+ is_obj_appearing = is_obj_appearing[..., None]
339
+ else:
340
+ is_obj_appearing = object_score_logits > 0
341
+
342
+ # Mask used for spatial memories is always a *hard* choice between obj and no obj,
343
+ # consistent with the actual mask prediction
344
+ low_res_multimasks = torch.where(
345
+ is_obj_appearing[:, None, None],
346
+ low_res_multimasks,
347
+ NO_OBJ_SCORE,
348
+ )
349
+
350
+ # convert masks from possibly bfloat16 (or float16) to float32
351
+ # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
352
+ low_res_multimasks = low_res_multimasks.float()
353
+ high_res_multimasks = F.interpolate(
354
+ low_res_multimasks,
355
+ size=(self.image_size, self.image_size),
356
+ mode="bilinear",
357
+ align_corners=False,
358
+ )
359
+
360
+ sam_output_token = sam_output_tokens[:, 0]
361
+ if multimask_output:
362
+ # take the best mask prediction (with the highest IoU estimation)
363
+ best_iou_inds = torch.argmax(ious, dim=-1)
364
+ batch_inds = torch.arange(B, device=device)
365
+ low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
366
+ high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
367
+ if sam_output_tokens.size(1) > 1:
368
+ sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
369
+ else:
370
+ low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
371
+
372
+ # Extract object pointer from the SAM output token (with occlusion handling)
373
+ obj_ptr = self.obj_ptr_proj(sam_output_token)
374
+ lambda_is_obj_appearing = is_obj_appearing.float()
375
+
376
+ obj_ptr = lambda_is_obj_appearing * obj_ptr
377
+ obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
378
+
379
+ return (
380
+ low_res_multimasks,
381
+ high_res_multimasks,
382
+ ious,
383
+ low_res_masks,
384
+ high_res_masks,
385
+ obj_ptr,
386
+ object_score_logits,
387
+ )
388
+
389
+ def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
390
+ """
391
+ Directly turn binary `mask_inputs` into output mask logits without using SAM.
392
+ (same input and output shapes as in _forward_sam_heads above).
393
+ """
394
+ # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
395
+ out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05
396
+ mask_inputs_float = mask_inputs.float()
397
+ high_res_masks = mask_inputs_float * out_scale + out_bias
398
+ low_res_masks = F.interpolate(
399
+ high_res_masks,
400
+ size=(
401
+ high_res_masks.size(-2) // self.backbone_stride * 4,
402
+ high_res_masks.size(-1) // self.backbone_stride * 4,
403
+ ),
404
+ align_corners=False,
405
+ mode="bilinear",
406
+ antialias=True, # use antialias for downsampling
407
+ )
408
+ # a dummy IoU prediction of all 1's under mask input
409
+ ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
410
+ # produce an object pointer using the SAM decoder from the mask input
411
+ _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
412
+ backbone_features=backbone_features,
413
+ mask_inputs=self.mask_downsample(mask_inputs_float),
414
+ high_res_features=high_res_features,
415
+ gt_masks=mask_inputs,
416
+ )
417
+ # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem;
418
+ # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying
419
+ # on the object_scores from the SAM decoder.
420
+ is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1)
421
+ is_obj_appearing = is_obj_appearing[..., None]
422
+ lambda_is_obj_appearing = is_obj_appearing.float()
423
+ object_score_logits = out_scale * lambda_is_obj_appearing + out_bias
424
+ obj_ptr = lambda_is_obj_appearing * obj_ptr
425
+ obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
426
+
427
+ return (
428
+ low_res_masks,
429
+ high_res_masks,
430
+ ious,
431
+ low_res_masks,
432
+ high_res_masks,
433
+ obj_ptr,
434
+ object_score_logits,
435
+ )
436
+
437
+ def forward(self, input: BatchedDatapoint, is_inference=False):
438
+ raise NotImplementedError(
439
+ "Please use the corresponding methods in SAM3VideoPredictor for inference."
440
+ "See examples/sam3_dense_video_tracking.ipynb for an inference example."
441
+ )
442
+
443
+ def forward_image(self, img_batch):
444
+ """Get the image feature on the input batch."""
445
+ # This line is the only change from the parent class
446
+ # to use the SAM3 backbone instead of the SAM2 backbone.
447
+ backbone_out = self.backbone.forward_image(img_batch)["sam2_backbone_out"]
448
+ # precompute projected level 0 and level 1 features in SAM decoder
449
+ # to avoid running it again on every SAM click
450
+ backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(
451
+ backbone_out["backbone_fpn"][0]
452
+ )
453
+ backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(
454
+ backbone_out["backbone_fpn"][1]
455
+ )
456
+ # Clone to help torch.compile
457
+ for i in range(len(backbone_out["backbone_fpn"])):
458
+ backbone_out["backbone_fpn"][i] = self._maybe_clone(
459
+ backbone_out["backbone_fpn"][i]
460
+ )
461
+ backbone_out["vision_pos_enc"][i] = self._maybe_clone(
462
+ backbone_out["vision_pos_enc"][i]
463
+ )
464
+ return backbone_out
465
+
466
+ def _prepare_backbone_features(self, backbone_out):
467
+ """Prepare and flatten visual features (same as in MDETR_API model)."""
468
+ backbone_out = backbone_out.copy()
469
+ assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
470
+ assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels
471
+
472
+ feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :]
473
+ vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :]
474
+
475
+ feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
476
+ # flatten NxCxHxW to HWxNxC
477
+ vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
478
+ vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]
479
+
480
+ return backbone_out, vision_feats, vision_pos_embeds, feat_sizes
481
+
482
+ def _prepare_backbone_features_per_frame(self, img_batch, img_ids):
483
+ """Compute the image backbone features on the fly for the given img_ids."""
484
+ # Only forward the backbone on unique image ids to avoid repetitive computation
485
+ # (if `img_ids` has only one element, it's already unique so we skip this step).
486
+ if img_ids.numel() > 1:
487
+ unique_img_ids, inv_ids = torch.unique(img_ids, return_inverse=True)
488
+ else:
489
+ unique_img_ids, inv_ids = img_ids, None
490
+
491
+ # Compute the image features on those unique image ids
492
+ image = img_batch[unique_img_ids]
493
+ backbone_out = self.forward_image(image)
494
+ (
495
+ _,
496
+ vision_feats,
497
+ vision_pos_embeds,
498
+ feat_sizes,
499
+ ) = self._prepare_backbone_features(backbone_out)
500
+ # Inverse-map image features for `unique_img_ids` to the final image features
501
+ # for the original input `img_ids`.
502
+ if inv_ids is not None:
503
+ image = image[inv_ids]
504
+ vision_feats = [x[:, inv_ids] for x in vision_feats]
505
+ vision_pos_embeds = [x[:, inv_ids] for x in vision_pos_embeds]
506
+
507
+ return image, vision_feats, vision_pos_embeds, feat_sizes
508
+
509
+ def cal_mem_score(self, object_score_logits, iou_score):
510
+ object_score_norm = torch.where(
511
+ object_score_logits > 0,
512
+ object_score_logits.sigmoid() * 2 - 1, ## rescale to [0, 1]
513
+ torch.zeros_like(object_score_logits),
514
+ )
515
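+ # for positive logits, 2 * sigmoid(x) - 1 lies in (0, 1); the per-frame score is this object score times the IoU estimate, averaged over objects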
+ score_per_frame = (object_score_norm * iou_score).mean()
516
+ return score_per_frame
517
+
518
+ def frame_filter(self, output_dict, track_in_reverse, frame_idx, num_frames, r):
519
+ if (frame_idx == 0 and not track_in_reverse) or (
520
+ frame_idx == num_frames - 1 and track_in_reverse
521
+ ):
522
+ return []
523
+
524
+ max_num = min(
525
+ num_frames, self.max_obj_ptrs_in_encoder
526
+ ) ## maximum number of pointer memory frames to consider
527
+
528
+ if not track_in_reverse:
529
+ start = frame_idx - 1
530
+ end = 0
531
+ step = -r
532
+ must_include = frame_idx - 1
533
+ else:
534
+ start = frame_idx + 1
535
+ end = num_frames
536
+ step = r
537
+ must_include = frame_idx + 1
538
+
539
+ valid_indices = []
540
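+ # walk away from the current frame in steps of r; insert(0, i) keeps the temporally closest valid frame at the end of valid_indices (accessed later via negative indexing)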
+ for i in range(start, end, step):
541
+ if (
542
+ i not in output_dict["non_cond_frame_outputs"]
543
+ or "eff_iou_score" not in output_dict["non_cond_frame_outputs"][i]
544
+ ):
545
+ continue
546
+
547
+ score_per_frame = output_dict["non_cond_frame_outputs"][i]["eff_iou_score"]
548
+
549
+ if score_per_frame > self.mf_threshold: # threshold
550
+ valid_indices.insert(0, i)
551
+
552
+ if len(valid_indices) >= max_num - 1:
553
+ break
554
+
555
+ if must_include not in valid_indices:
556
+ valid_indices.append(must_include)
557
+
558
+ return valid_indices
559
+
560
+ def _prepare_memory_conditioned_features(
561
+ self,
562
+ frame_idx,
563
+ is_init_cond_frame,
564
+ current_vision_feats,
565
+ current_vision_pos_embeds,
566
+ feat_sizes,
567
+ output_dict,
568
+ num_frames,
569
+ track_in_reverse=False, # tracking in reverse time order (for demo usage)
570
+ use_prev_mem_frame=True,
571
+ ):
572
+ """Fuse the current frame's visual feature map with previous memory."""
573
+ B = current_vision_feats[-1].size(1) # batch size on this frame
574
+ C = self.hidden_dim
575
+ H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
576
+ device = current_vision_feats[-1].device
577
+ # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images.
578
+ # In this case, we skip the fusion with any memory.
579
+ if self.num_maskmem == 0: # Disable memory and skip fusion
580
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
581
+ return pix_feat
582
+
583
+ num_obj_ptr_tokens = 0
584
+ tpos_sign_mul = -1 if track_in_reverse else 1
585
+ # Step 1: condition the visual features of the current frame on previous memories
586
+ if not is_init_cond_frame and use_prev_mem_frame:
587
+ # Retrieve the memories encoded with the maskmem backbone
588
+ to_cat_prompt, to_cat_prompt_mask, to_cat_prompt_pos_embed = [], [], []
589
+ # Add the conditioning frames' outputs first (all cond frames have t_pos=0
590
+ # when getting the temporal positional embedding below)
591
+ assert len(output_dict["cond_frame_outputs"]) > 0
592
+ # Select a maximum number of temporally closest cond frames for cross attention
593
+ cond_outputs = output_dict["cond_frame_outputs"]
594
+ selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames(
595
+ frame_idx,
596
+ cond_outputs,
597
+ self.max_cond_frames_in_attn,
598
+ keep_first_cond_frame=self.keep_first_cond_frame,
599
+ )
600
+ t_pos_and_prevs = [
601
+ ((frame_idx - t) * tpos_sign_mul, out, True)
602
+ for t, out in selected_cond_outputs.items()
603
+ ]
604
+ # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory
605
+ # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1
606
+ # We also allow taking the memory frame non-consecutively (with r>1), in which case
607
+ # we take (self.num_maskmem - 2) frames among every r-th frames plus the last frame.
608
+ r = 1 if self.training else self.memory_temporal_stride_for_eval
609
+
610
+ if self.use_memory_selection:
611
+ valid_indices = self.frame_filter(
612
+ output_dict, track_in_reverse, frame_idx, num_frames, r
613
+ )
614
+
615
+ for t_pos in range(1, self.num_maskmem):
616
+ t_rel = self.num_maskmem - t_pos # how many frames before current frame
617
+ if self.use_memory_selection:
618
+ if t_rel > len(valid_indices):
619
+ continue
620
+ prev_frame_idx = valid_indices[-t_rel]
621
+ else:
622
+ if t_rel == 1:
623
+ # for t_rel == 1, we take the last frame (regardless of r)
624
+ if not track_in_reverse:
625
+ # the frame immediately before this frame (i.e. frame_idx - 1)
626
+ prev_frame_idx = frame_idx - t_rel
627
+ else:
628
+ # the frame immediately after this frame (i.e. frame_idx + 1)
629
+ prev_frame_idx = frame_idx + t_rel
630
+ else:
631
+ # for t_rel >= 2, we take the memory frame from every r-th frames
632
+ if not track_in_reverse:
633
+ # first find the nearest frame among every r-th frames before this frame
634
+ # for r=1, this would be (frame_idx - 2)
635
+ prev_frame_idx = ((frame_idx - 2) // r) * r
636
+ # then seek further among every r-th frames
637
+ prev_frame_idx = prev_frame_idx - (t_rel - 2) * r
638
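+ # e.g. frame_idx=10, r=4, t_rel=3: ((10 - 2) // 4) * 4 = 8, then 8 - (3 - 2) * 4 = 4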
+ else:
639
+ # first find the nearest frame among every r-th frames after this frame
640
+ # for r=1, this would be (frame_idx + 2)
641
+ prev_frame_idx = -(-(frame_idx + 2) // r) * r
642
+ # then seek further among every r-th frames
643
+ prev_frame_idx = prev_frame_idx + (t_rel - 2) * r
644
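+ # note: -(-x // r) * r rounds x up to the nearest multiple of r (ceiling division), mirroring the forward-tracking case above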
+
645
+ out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
646
+ if out is None:
647
+ # If an unselected conditioning frame is among the last (self.num_maskmem - 1)
648
+ # frames, we still attend to it as if it's a non-conditioning frame.
649
+ out = unselected_cond_outputs.get(prev_frame_idx, None)
650
+ t_pos_and_prevs.append((t_pos, out, False))
651
+
652
+ for t_pos, prev, is_selected_cond_frame in t_pos_and_prevs:
653
+ if prev is None:
654
+ continue # skip padding frames
655
+ # "maskmem_features" might have been offloaded to CPU in demo use cases,
656
+ # so we load it back to GPU (it's a no-op if it's already on GPU).
657
+ feats = prev["maskmem_features"].cuda(non_blocking=True)
658
+ seq_len = feats.shape[-2] * feats.shape[-1]
659
+ to_cat_prompt.append(feats.flatten(2).permute(2, 0, 1))
660
+ to_cat_prompt_mask.append(
661
+ torch.zeros(B, seq_len, device=device, dtype=bool)
662
+ )
663
+ # Spatial positional encoding (it might have been offloaded to CPU in eval)
664
+ maskmem_enc = prev["maskmem_pos_enc"][-1].cuda()
665
+ maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
666
+
667
+ if (
668
+ is_selected_cond_frame
669
+ and getattr(self, "cond_frame_spatial_embedding", None) is not None
670
+ ):
671
+ # add a spatial embedding for the conditioning frame
672
+ maskmem_enc = maskmem_enc + self.cond_frame_spatial_embedding
673
+
674
+ # Temporal positional encoding
675
+ t = t_pos if not is_selected_cond_frame else 0
676
+ maskmem_enc = (
677
+ maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t - 1]
678
+ )
679
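+ # index num_maskmem - t - 1: conditioning frames (t=0) use the last temporal-embedding slot, while the most recent non-conditioning memory (t_pos = num_maskmem - 1) uses slot 0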
+ to_cat_prompt_pos_embed.append(maskmem_enc)
680
+
681
+ # Construct the list of past object pointers
682
+ # Optionally, select only a subset of spatial memory frames during training
683
+ if (
684
+ self.training
685
+ and self.prob_to_dropout_spatial_mem > 0
686
+ and self.rng.random() < self.prob_to_dropout_spatial_mem
687
+ ):
688
+ num_spatial_mem_keep = self.rng.integers(len(to_cat_prompt) + 1)
689
+ keep = self.rng.choice(
690
+ range(len(to_cat_prompt)), num_spatial_mem_keep, replace=False
691
+ ).tolist()
692
+ to_cat_prompt = [to_cat_prompt[i] for i in keep]
693
+ to_cat_prompt_mask = [to_cat_prompt_mask[i] for i in keep]
694
+ to_cat_prompt_pos_embed = [to_cat_prompt_pos_embed[i] for i in keep]
695
+
696
+ max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder)
697
+ # First add those object pointers from selected conditioning frames
698
+ # (optionally, only include object pointers in the past during evaluation)
699
+ if not self.training:
700
+ ptr_cond_outputs = {
701
+ t: out
702
+ for t, out in selected_cond_outputs.items()
703
+ if (t >= frame_idx if track_in_reverse else t <= frame_idx)
704
+ }
705
+ else:
706
+ ptr_cond_outputs = selected_cond_outputs
707
+ pos_and_ptrs = [
708
+ # Temporal pos encoding contains how far away each pointer is from current frame
709
+ (
710
+ (frame_idx - t) * tpos_sign_mul,
711
+ out["obj_ptr"],
712
+ True, # is_selected_cond_frame
713
+ )
714
+ for t, out in ptr_cond_outputs.items()
715
+ ]
716
+
717
+ # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame
718
+ for t_diff in range(1, max_obj_ptrs_in_encoder):
719
+ if not self.use_memory_selection:
720
+ t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
721
+ if t < 0 or (num_frames is not None and t >= num_frames):
722
+ break
723
+ else:
724
+ if -t_diff <= -len(valid_indices):
725
+ break
726
+ t = valid_indices[-t_diff]
727
+
728
+ out = output_dict["non_cond_frame_outputs"].get(
729
+ t, unselected_cond_outputs.get(t, None)
730
+ )
731
+ if out is not None:
732
+ pos_and_ptrs.append((t_diff, out["obj_ptr"], False))
733
+
734
+ # If we have at least one object pointer, add them to the cross attention
735
+ if len(pos_and_ptrs) > 0:
736
+ pos_list, ptrs_list, is_selected_cond_frame_list = zip(*pos_and_ptrs)
737
+ # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape
738
+ obj_ptrs = torch.stack(ptrs_list, dim=0)
739
+ if getattr(self, "cond_frame_obj_ptr_embedding", None) is not None:
740
+ obj_ptrs = (
741
+ obj_ptrs
742
+ + self.cond_frame_obj_ptr_embedding
743
+ * torch.tensor(is_selected_cond_frame_list, device=device)[
744
+ ..., None, None
745
+ ].float()
746
+ )
747
+ # a temporal positional embedding based on how far each object pointer is from
748
+ # the current frame (sine embedding normalized by the max pointer num).
749
+ obj_pos = self._get_tpos_enc(
750
+ pos_list,
751
+ max_abs_pos=max_obj_ptrs_in_encoder,
752
+ device=device,
753
+ )
754
+ # expand to batch size
755
+ obj_pos = obj_pos.unsqueeze(1).expand(-1, B, -1)
756
+
757
+ if self.mem_dim < C:
758
+ # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C
759
+ obj_ptrs = obj_ptrs.reshape(-1, B, C // self.mem_dim, self.mem_dim)
760
+ obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
761
+ obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
762
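+ # e.g. with C=256 and mem_dim=64, each pointer becomes 4 tokens of dim 64, and its temporal positional encoding is repeated 4 times to match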
+ to_cat_prompt.append(obj_ptrs)
763
+ to_cat_prompt_mask.append(None) # "to_cat_prompt_mask" is not used
764
+ to_cat_prompt_pos_embed.append(obj_pos)
765
+ num_obj_ptr_tokens = obj_ptrs.shape[0]
766
+ else:
767
+ num_obj_ptr_tokens = 0
768
+ else:
769
+ # directly add no-mem embedding (instead of using the transformer encoder)
770
+ pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed
771
+ pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
772
+ return pix_feat_with_mem
773
+
774
+ # Use a dummy token on the first frame (to avoid empty memory input to the transformer encoder)
775
+ to_cat_prompt = [self.no_mem_embed.expand(1, B, self.mem_dim)]
776
+ to_cat_prompt_mask = [torch.zeros(B, 1, device=device, dtype=bool)]
777
+ to_cat_prompt_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]
778
+
779
+ # Step 2: Concatenate the memories and forward through the transformer encoder
780
+ prompt = torch.cat(to_cat_prompt, dim=0)
781
+ prompt_mask = None  # for now, the masks are always zeros anyway
782
+ prompt_pos_embed = torch.cat(to_cat_prompt_pos_embed, dim=0)
783
+ encoder_out = self.transformer.encoder(
784
+ src=current_vision_feats,
785
+ src_key_padding_mask=[None],
786
+ src_pos=current_vision_pos_embeds,
787
+ prompt=prompt,
788
+ prompt_pos=prompt_pos_embed,
789
+ prompt_key_padding_mask=prompt_mask,
790
+ feat_sizes=feat_sizes,
791
+ num_obj_ptr_tokens=num_obj_ptr_tokens,
792
+ )
793
+ # reshape the output (HW)BC => BCHW
794
+ pix_feat_with_mem = encoder_out["memory"].permute(1, 2, 0).view(B, C, H, W)
795
+ return pix_feat_with_mem
796
+
797
+ def _encode_new_memory(
798
+ self,
799
+ image,
800
+ current_vision_feats,
801
+ feat_sizes,
802
+ pred_masks_high_res,
803
+ object_score_logits,
804
+ is_mask_from_pts,
805
+ output_dict=None,
806
+ is_init_cond_frame=False,
807
+ ):
808
+ """Encode the current image and its prediction into a memory feature."""
809
+ B = current_vision_feats[-1].size(1) # batch size on this frame
810
+ C = self.hidden_dim
811
+ H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
812
+ # top-level feature, (HW)BC => BCHW
813
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
814
+ if self.non_overlap_masks_for_mem_enc and not self.training:
815
+ # optionally, apply non-overlapping constraints to the masks (it's applied
816
+ # in the batch dimension and should only be used during eval, where all
817
+ # the objects come from the same video under batch size 1).
818
+ pred_masks_high_res = self._apply_non_overlapping_constraints(
819
+ pred_masks_high_res
820
+ )
821
+ # scale the raw mask logits with a temperature before applying sigmoid
822
+ if is_mask_from_pts and not self.training:
823
+ mask_for_mem = (pred_masks_high_res > 0).float()
824
+ else:
825
+ # apply sigmoid on the raw mask logits to turn them into range (0, 1)
826
+ mask_for_mem = torch.sigmoid(pred_masks_high_res)
827
+ # apply scale and bias terms to the sigmoid probabilities
828
+ if self.sigmoid_scale_for_mem_enc != 1.0:
829
+ mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
830
+ if self.sigmoid_bias_for_mem_enc != 0.0:
831
+ mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
832
+
833
+ if isinstance(self.maskmem_backbone, SimpleMaskEncoder):
834
+ pix_feat = pix_feat.view_as(pix_feat)
835
+ maskmem_out = self.maskmem_backbone(
836
+ pix_feat, mask_for_mem, skip_mask_sigmoid=True
837
+ )
838
+ else:
839
+ maskmem_out = self.maskmem_backbone(image, pix_feat, mask_for_mem)
840
+ # Clone the feats and pos_enc to enable compilation
841
+ maskmem_features = self._maybe_clone(maskmem_out["vision_features"])
842
+ maskmem_pos_enc = [self._maybe_clone(m) for m in maskmem_out["vision_pos_enc"]]
843
+ # add a no-object embedding to the spatial memory to indicate that the frame
844
+ # is predicted to be occluded (i.e. no object is appearing in the frame)
845
+ is_obj_appearing = (object_score_logits > 0).float()
846
+ maskmem_features += (
847
+ 1 - is_obj_appearing[..., None, None]
848
+ ) * self.no_obj_embed_spatial[..., None, None].expand(*maskmem_features.shape)
849
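+ # when the object is predicted to be present (logits > 0) this term is zero; otherwise the learned no-object embedding is broadcast to every spatial location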
+
850
+ return maskmem_features, maskmem_pos_enc
851
+
852
+ def forward_tracking(self, backbone_out, input, return_dict=False):
853
+ """Forward video tracking on each frame (and sample correction clicks)."""
854
+ img_feats_already_computed = backbone_out["backbone_fpn"] is not None
855
+ if img_feats_already_computed:
856
+ # Prepare the backbone features
857
+ # - vision_feats and vision_pos_embeds are in (HW)BC format
858
+ (
859
+ _,
860
+ vision_feats,
861
+ vision_pos_embeds,
862
+ feat_sizes,
863
+ ) = self._prepare_backbone_features(backbone_out)
864
+
865
+ # Starting the stage loop
866
+ num_frames = backbone_out["num_frames"]
867
+ init_cond_frames = backbone_out["init_cond_frames"]
868
+ frames_to_add_correction_pt = backbone_out["frames_to_add_correction_pt"]
869
+ # first process all the initial conditioning frames to encode them as memory,
870
+ # and then conditioning on them to track the remaining frames
871
+ processing_order = init_cond_frames + backbone_out["frames_not_in_init_cond"]
872
+ output_dict = {
873
+ "cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
874
+ "non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
875
+ }
876
+ for stage_id in processing_order:
877
+ # Get the image features for the current frames
878
+ img_ids = input.find_inputs[stage_id].img_ids
879
+ if img_feats_already_computed:
880
+ # Retrieve image features according to img_ids (if they are already computed).
881
+ current_image = input.img_batch[img_ids]
882
+ current_vision_feats = [x[:, img_ids] for x in vision_feats]
883
+ current_vision_pos_embeds = [x[:, img_ids] for x in vision_pos_embeds]
884
+ else:
885
+ # Otherwise, compute the image features on the fly for the given img_ids
886
+ # (this might be used for evaluation on long videos to avoid backbone OOM).
887
+ (
888
+ current_image,
889
+ current_vision_feats,
890
+ current_vision_pos_embeds,
891
+ feat_sizes,
892
+ ) = self._prepare_backbone_features_per_frame(input.img_batch, img_ids)
893
+ # Get output masks based on this frame's prompts and previous memory
894
+ current_out = self.track_step(
895
+ frame_idx=stage_id,
896
+ is_init_cond_frame=stage_id in init_cond_frames,
897
+ current_vision_feats=current_vision_feats,
898
+ current_vision_pos_embeds=current_vision_pos_embeds,
899
+ feat_sizes=feat_sizes,
900
+ image=current_image,
901
+ point_inputs=backbone_out["point_inputs_per_frame"].get(stage_id, None),
902
+ mask_inputs=backbone_out["mask_inputs_per_frame"].get(stage_id, None),
903
+ gt_masks=backbone_out["gt_masks_per_frame"].get(stage_id, None),
904
+ frames_to_add_correction_pt=frames_to_add_correction_pt,
905
+ output_dict=output_dict,
906
+ num_frames=num_frames,
907
+ )
908
+ # Append the output, depending on whether it's a conditioning frame
909
+ add_output_as_cond_frame = stage_id in init_cond_frames or (
910
+ self.add_all_frames_to_correct_as_cond
911
+ and stage_id in frames_to_add_correction_pt
912
+ )
913
+ if add_output_as_cond_frame:
914
+ output_dict["cond_frame_outputs"][stage_id] = current_out
915
+ else:
916
+ output_dict["non_cond_frame_outputs"][stage_id] = current_out
917
+
918
+ if return_dict:
919
+ return output_dict
920
+ # turn `output_dict` into a list for loss function
921
+ all_frame_outputs = {}
922
+ all_frame_outputs.update(output_dict["cond_frame_outputs"])
923
+ all_frame_outputs.update(output_dict["non_cond_frame_outputs"])
924
+ all_frame_outputs = [all_frame_outputs[t] for t in range(num_frames)]
925
+ # Make DDP happy with activation checkpointing by removing unused keys
926
+ all_frame_outputs = [
927
+ {k: v for k, v in d.items() if k != "obj_ptr"} for d in all_frame_outputs
928
+ ]
929
+
930
+ return all_frame_outputs
931
+
932
+ def track_step(
933
+ self,
934
+ frame_idx,
935
+ is_init_cond_frame,
936
+ current_vision_feats,
937
+ current_vision_pos_embeds,
938
+ feat_sizes,
939
+ image,
940
+ point_inputs,
941
+ mask_inputs,
942
+ output_dict,
943
+ num_frames,
944
+ track_in_reverse=False, # tracking in reverse time order (for demo usage)
945
+ # Whether to run the memory encoder on the predicted masks. Sometimes we might want
946
+ # to skip the memory encoder with `run_mem_encoder=False`. For example,
947
+ # in demo we might call `track_step` multiple times for each user click,
948
+ # and only encode the memory when the user finalizes their clicks. And in ablation
949
+ # settings like SAM training on static images, we don't need the memory encoder.
950
+ run_mem_encoder=True,
951
+ # The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
952
+ prev_sam_mask_logits=None,
953
+ use_prev_mem_frame=True,
954
+ ):
955
+ current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
956
+ # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
957
+ if len(current_vision_feats) > 1:
958
+ high_res_features = [
959
+ x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
960
+ for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
961
+ ]
962
+ else:
963
+ high_res_features = None
964
+ if mask_inputs is not None:
965
+ # Directly output the input mask (treating it as a GT mask) without using a SAM prompt encoder + mask decoder.
966
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0)
967
+ pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
968
+ sam_outputs = self._use_mask_as_output(
969
+ pix_feat, high_res_features, mask_inputs
970
+ )
971
+ else:
972
+ # fused the visual feature with previous memory features in the memory bank
973
+ pix_feat_with_mem = self._prepare_memory_conditioned_features(
974
+ frame_idx=frame_idx,
975
+ is_init_cond_frame=is_init_cond_frame,
976
+ current_vision_feats=current_vision_feats[-1:],
977
+ current_vision_pos_embeds=current_vision_pos_embeds[-1:],
978
+ feat_sizes=feat_sizes[-1:],
979
+ output_dict=output_dict,
980
+ num_frames=num_frames,
981
+ track_in_reverse=track_in_reverse,
982
+ use_prev_mem_frame=use_prev_mem_frame,
983
+ )
984
+ # apply SAM-style segmentation head
985
+ # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
986
+ # e.g. in demo where such logits come from earlier interaction instead of correction sampling
987
+ # (in this case, the SAM mask decoder should have `self.iter_use_prev_mask_pred=True`, and
988
+ # any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
989
+ if prev_sam_mask_logits is not None:
990
+ assert self.iter_use_prev_mask_pred
991
+ assert point_inputs is not None and mask_inputs is None
992
+ mask_inputs = prev_sam_mask_logits
993
+ multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
994
+ sam_outputs = self._forward_sam_heads(
995
+ backbone_features=pix_feat_with_mem,
996
+ point_inputs=point_inputs,
997
+ mask_inputs=mask_inputs,
998
+ high_res_features=high_res_features,
999
+ multimask_output=multimask_output,
1000
+ )
1001
+ (
1002
+ _,
1003
+ high_res_multimasks,
1004
+ ious,
1005
+ low_res_masks,
1006
+ high_res_masks,
1007
+ obj_ptr,
1008
+ object_score_logits,
1009
+ ) = sam_outputs
1010
+ # Use the final prediction (after all correction steps) for output and eval
1011
+ current_out["pred_masks"] = low_res_masks
1012
+ current_out["pred_masks_high_res"] = high_res_masks
1013
+ current_out["obj_ptr"] = obj_ptr
1014
+ if self.use_memory_selection:
1015
+ current_out["object_score_logits"] = object_score_logits
1016
+ iou_score = ious.max(-1)[0]
1017
+ current_out["iou_score"] = iou_score
1018
+ current_out["eff_iou_score"] = self.cal_mem_score(
1019
+ object_score_logits, iou_score
1020
+ )
1021
+ if not self.training:
1022
+ # Only add this in inference (to avoid unused param in activation checkpointing;
1023
+ # it's mainly used in the demo to encode spatial memories w/ consolidated masks)
1024
+ current_out["object_score_logits"] = object_score_logits
1025
+
1026
+ # Finally run the memory encoder on the predicted mask to encode
1027
+ # it into a new memory feature (that can be used in future frames)
1028
+ # (note that `self.num_maskmem == 0` is primarily used for reproducing SAM on
1029
+ # images, in which case we'll just skip memory encoder to save compute).
1030
+ if run_mem_encoder and self.num_maskmem > 0:
1031
+ high_res_masks_for_mem_enc = high_res_masks
1032
+ maskmem_features, maskmem_pos_enc = self._encode_new_memory(
1033
+ image=image,
1034
+ current_vision_feats=current_vision_feats,
1035
+ feat_sizes=feat_sizes,
1036
+ pred_masks_high_res=high_res_masks_for_mem_enc,
1037
+ object_score_logits=object_score_logits,
1038
+ is_mask_from_pts=(point_inputs is not None),
1039
+ output_dict=output_dict,
1040
+ is_init_cond_frame=is_init_cond_frame,
1041
+ )
1042
+ current_out["maskmem_features"] = maskmem_features
1043
+ current_out["maskmem_pos_enc"] = maskmem_pos_enc
1044
+ else:
1045
+ current_out["maskmem_features"] = None
1046
+ current_out["maskmem_pos_enc"] = None
1047
+
1048
+ # Optionally, offload the outputs to CPU memory during evaluation to avoid
1049
+ # GPU OOM on very long videos or very large resolution or too many objects
1050
+ if self.offload_output_to_cpu_for_eval and not self.training:
1051
+ # Here we only keep those keys needed for evaluation to get a compact output
1052
+ trimmed_out = {
1053
+ "pred_masks": current_out["pred_masks"].cpu(),
1054
+ "pred_masks_high_res": current_out["pred_masks_high_res"].cpu(),
1055
+ # other items for evaluation (these are small tensors so we keep them on GPU)
1056
+ "obj_ptr": current_out["obj_ptr"],
1057
+ "object_score_logits": current_out["object_score_logits"],
1058
+ }
1059
+ if run_mem_encoder and self.num_maskmem > 0:
1060
+ trimmed_out["maskmem_features"] = maskmem_features.cpu()
1061
+ trimmed_out["maskmem_pos_enc"] = [x.cpu() for x in maskmem_pos_enc]
1062
+ if self.use_memory_selection:
1063
+ trimmed_out["iou_score"] = current_out["iou_score"].cpu()
1064
+ trimmed_out["eff_iou_score"] = current_out["eff_iou_score"].cpu()
1065
+ current_out = trimmed_out
1066
+
1067
+ # Optionally, trim the output of the past non-conditioning frame (r * num_maskmem frames
1068
+ # before the current frame) during evaluation. This is intended to save GPU or CPU
1069
+ # memory for semi-supervised VOS eval, where only the first frame receives prompts.
1070
+ def _trim_past_out(past_out, current_out):
1071
+ if past_out is None:
1072
+ return None
1073
+ return {
1074
+ "pred_masks": past_out["pred_masks"],
1075
+ "obj_ptr": past_out["obj_ptr"],
1076
+ "object_score_logits": past_out["object_score_logits"],
1077
+ }
1078
+
1079
+ if self.trim_past_non_cond_mem_for_eval and not self.training:
1080
+ r = self.memory_temporal_stride_for_eval
1081
+ past_frame_idx = frame_idx - r * self.num_maskmem
1082
+ past_out = output_dict["non_cond_frame_outputs"].get(past_frame_idx, None)
1083
+
1084
+ if past_out is not None:
1085
+ if (
1087
+ self.use_memory_selection
1088
+ and past_out.get("eff_iou_score", 0) < self.mf_threshold
1089
+ ) or not self.use_memory_selection:
1090
+ output_dict["non_cond_frame_outputs"][past_frame_idx] = (
1091
+ _trim_past_out(past_out, current_out)
1092
+ )
1093
+
1094
+ if (
1095
+ self.use_memory_selection and not self.offload_output_to_cpu_for_eval
1096
+ ):  ## for memory selection: trim frames that are too old, to save memory
1097
+ far_old_frame_idx = frame_idx - 20 * self.max_obj_ptrs_in_encoder
1098
+ past_out = output_dict["non_cond_frame_outputs"].get(
1099
+ far_old_frame_idx, None
1100
+ )
1101
+ if past_out is not None:
1102
+ output_dict["non_cond_frame_outputs"][far_old_frame_idx] = (
1103
+ _trim_past_out(past_out, current_out)
1104
+ )
1105
+
1106
+ return current_out
1107
+
1108
+ def _use_multimask(self, is_init_cond_frame, point_inputs):
1109
+ """Whether to use multimask output in the SAM head."""
1110
+ num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
1111
+ multimask_output = (
1112
+ self.multimask_output_in_sam
1113
+ and (is_init_cond_frame or self.multimask_output_for_tracking)
1114
+ and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num)
1115
+ )
1116
+ return multimask_output
1117
+
1118
+ def _apply_non_overlapping_constraints(self, pred_masks):
1119
+ """
1120
+ Apply non-overlapping constraints to the object scores in pred_masks. Here we
1121
+ keep only the highest scoring object at each spatial location in pred_masks.
1122
+ """
1123
+ batch_size = pred_masks.size(0)
1124
+ if batch_size == 1:
1125
+ return pred_masks
1126
+
1127
+ device = pred_masks.device
1128
+ # "max_obj_inds": object index of the object with the highest score at each location
1129
+ max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)
1130
+ # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks`
1131
+ batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None]
1132
+ keep = max_obj_inds == batch_obj_inds
1133
+ # clamp overlapping (non-maximal) regions' scores to at most -10.0 so that the foreground regions
1134
+ # don't overlap (here sigmoid(-10.0)=4.5398e-05)
1135
+ pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
1136
+ return pred_masks
1137
+
1138
+ def _compile_all_components(self):
1139
+ """Compile all model components for faster inference."""
1140
+ # a larger cache size to hold varying number of shapes for torch.compile
1141
+ # see https://github.com/pytorch/pytorch/blob/v2.5.1/torch/_dynamo/config.py#L42-L49
1142
+ torch._dynamo.config.cache_size_limit = 64
1143
+ torch._dynamo.config.accumulated_cache_size_limit = 2048
1144
+ from sam3.perflib.compile import compile_wrapper
1145
+
1146
+ logging.info("Compiling all components. First time may be very slow.")
1147
+
1148
+ self.maskmem_backbone.forward = compile_wrapper(
1149
+ self.maskmem_backbone.forward,
1150
+ mode="max-autotune",
1151
+ fullgraph=True,
1152
+ dynamic=False,
1153
+ )
1154
+ self.transformer.encoder.forward = compile_wrapper(
1155
+ self.transformer.encoder.forward,
1156
+ mode="max-autotune",
1157
+ fullgraph=True,
1158
+ dynamic=True, # Num. of memories varies
1159
+ )
1160
+ # We disable compilation of sam_prompt_encoder as it sometimes gives a large accuracy regression,
1161
+ # especially when sam_mask_prompt (previous mask logits) is not None
1162
+ # self.sam_prompt_encoder.forward = torch.compile(
1163
+ # self.sam_prompt_encoder.forward,
1164
+ # mode="max-autotune",
1165
+ # fullgraph=True,
1166
+ # dynamic=False, # Accuracy regression on True
1167
+ # )
1168
+ self.sam_mask_decoder.forward = compile_wrapper(
1169
+ self.sam_mask_decoder.forward,
1170
+ mode="max-autotune",
1171
+ fullgraph=True,
1172
+ dynamic=False, # Accuracy regression on True
1173
+ )
1174
+
1175
+ def _maybe_clone(self, x):
1176
+ """Clone a tensor if and only if `self.compile_all_components` is True."""
1177
+ return x.clone() if self.compile_all_components else x
1178
+
1179
+
1180
+ def concat_points(old_point_inputs, new_points, new_labels):
1181
+ """Add new points and labels to previous point inputs (add at the end)."""
1182
+ if old_point_inputs is None:
1183
+ points, labels = new_points, new_labels
1184
+ else:
1185
+ points = torch.cat([old_point_inputs["point_coords"], new_points], dim=1)
1186
+ labels = torch.cat([old_point_inputs["point_labels"], new_labels], dim=1)
1187
+
1188
+ return {"point_coords": points, "point_labels": labels}
detect_tools/sam3/sam3/model/sam3_tracker_utils.py ADDED
@@ -0,0 +1,427 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from numpy.typing import NDArray
7
+
8
+ from sam3.model.edt import edt_triton
9
+
10
+
11
+ def sample_box_points(
12
+ masks: torch.Tensor,
13
+ noise: float = 0.1, # SAM default
14
+ noise_bound: int = 20, # SAM default
15
+ top_left_label: int = 2,
16
+ bottom_right_label: int = 3,
17
+ ) -> tuple[NDArray, NDArray]:
18
+ """
19
+ Sample a noised version of the top left and bottom right corners of a given `bbox`
20
+
21
+ Inputs:
22
+ - masks: [B, 1, H, W] tensor
23
+ - noise: noise as a fraction of box width and height, dtype=float
24
+ - noise_bound: maximum amount of noise (in pure pixels), dtype=int
25
+
26
+ Returns:
27
+ - box_coords: [B, num_pt, 2], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.float
28
+ - box_labels: [B, num_pt], label 2 is reserved for top left and 3 for bottom right corners, dtype=torch.int32
29
+ """
30
+ device = masks.device
31
+ box_coords = mask_to_box(masks)
32
+ B, _, H, W = masks.shape
33
+ box_labels = torch.tensor(
34
+ [top_left_label, bottom_right_label], dtype=torch.int, device=device
35
+ ).repeat(B)
36
+ if noise > 0.0:
37
+ if not isinstance(noise_bound, torch.Tensor):
38
+ noise_bound = torch.tensor(noise_bound, device=device)
39
+ bbox_w = box_coords[..., 2] - box_coords[..., 0]
40
+ bbox_h = box_coords[..., 3] - box_coords[..., 1]
41
+ max_dx = torch.min(bbox_w * noise, noise_bound)
42
+ max_dy = torch.min(bbox_h * noise, noise_bound)
43
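+ # uniform noise in [-1, 1) per corner coordinate, scaled below by the per-axis bounds min(noise * box size, noise_bound)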
+ box_noise = 2 * torch.rand(B, 1, 4, device=device) - 1
44
+ box_noise = box_noise * torch.stack((max_dx, max_dy, max_dx, max_dy), dim=-1)
45
+
46
+ box_coords = box_coords + box_noise
47
+ img_bounds = (
48
+ torch.tensor([W, H, W, H], device=device) - 1
49
+ ) # uncentered pixel coords
50
+ box_coords.clamp_(torch.zeros_like(img_bounds), img_bounds) # In place clamping
51
+
52
+ box_coords = box_coords.reshape(-1, 2, 2) # always 2 points
53
+ box_labels = box_labels.reshape(-1, 2)
54
+ return box_coords, box_labels
55
+
56
+
57
+ def mask_to_box(masks: torch.Tensor):
58
+ """
59
+ compute bounding box given an input mask
60
+
61
+ Inputs:
62
+ - masks: [B, 1, H, W] tensor
63
+
64
+ Returns:
65
+ - box_coords: [B, 1, 4] tensor, contains (x, y) coordinates of the top left and bottom right box corners (integer dtype)
66
+ """
67
+ B, _, h, w = masks.shape
68
+ device = masks.device
69
+ mask_area = masks.sum(dim=(-1, -2))
70
+ xs = torch.arange(w, device=device, dtype=torch.int32)
71
+ ys = torch.arange(h, device=device, dtype=torch.int32)
72
+ grid_xs, grid_ys = torch.meshgrid(xs, ys, indexing="xy")
73
+ grid_xs = grid_xs[None, None, ...].expand(B, 1, h, w)
74
+ grid_ys = grid_ys[None, None, ...].expand(B, 1, h, w)
75
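+ # replace background pixels with sentinel values (w or h for the mins, -1 for the maxes) so the reductions only consider foreground pixels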
+ min_xs, _ = torch.min(torch.where(masks, grid_xs, w).flatten(-2), dim=-1)
76
+ max_xs, _ = torch.max(torch.where(masks, grid_xs, -1).flatten(-2), dim=-1)
77
+ min_ys, _ = torch.min(torch.where(masks, grid_ys, h).flatten(-2), dim=-1)
78
+ max_ys, _ = torch.max(torch.where(masks, grid_ys, -1).flatten(-2), dim=-1)
79
+ bbox_coords = torch.stack((min_xs, min_ys, max_xs, max_ys), dim=-1)
80
+ bbox_coords = torch.where(
81
+ mask_area[..., None] > 0, bbox_coords, torch.zeros_like(bbox_coords)
82
+ )
83
+ return bbox_coords
84
+
85
+
86
+ def sample_random_points_from_errors(gt_masks, pred_masks, num_pt=1):
87
+ """
88
+ Sample `num_pt` random points (along with their labels) independently from the error regions.
89
+
90
+ Inputs:
91
+ - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool
92
+ - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None
93
+ - num_pt: int, number of points to sample independently for each of the B error maps
94
+
95
+ Outputs:
96
+ - points: [B, num_pt, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point
97
+ - labels: [B, num_pt], dtype=torch.int32, where 1 means positive clicks and 0 means
98
+ negative clicks
99
+ """
100
+ if pred_masks is None: # if pred_masks is not provided, treat it as empty
101
+ pred_masks = torch.zeros_like(gt_masks)
102
+ assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1
103
+ assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape
104
+ assert num_pt >= 0
105
+
106
+ B, _, H_im, W_im = gt_masks.shape
107
+ device = gt_masks.device
108
+
109
+ # false positive region, a new point sampled in this region should have
110
+ # negative label to correct the FP error
111
+ fp_masks = ~gt_masks & pred_masks
112
+ # false negative region, a new point sampled in this region should have
113
+ # positive label to correct the FN error
114
+ fn_masks = gt_masks & ~pred_masks
115
+ # whether the prediction completely match the ground-truth on each mask
116
+ all_correct = torch.all((gt_masks == pred_masks).flatten(2), dim=2)
117
+ all_correct = all_correct[..., None, None]
118
+
119
+ # channel 0 is FP map, while channel 1 is FN map
120
+ pts_noise = torch.rand(B, num_pt, H_im, W_im, 2, device=device)
121
+ # sample a negative new click from FP region or a positive new click
122
+ # from FN region, depending on where the maximum falls,
123
+ # and in case the predictions are all correct (no FP or FN), we just
124
+ # sample a negative click from the background region
125
+ pts_noise[..., 0] *= fp_masks | (all_correct & ~gt_masks)
126
+ pts_noise[..., 1] *= fn_masks
127
+ pts_idx = pts_noise.flatten(2).argmax(dim=2)
128
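+ # after flattening, the last dim interleaves (FP, FN) per pixel, so an odd argmax index lands in the FN channel (positive click, label 1) and an even index in the FP channel (negative click, label 0)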
+ labels = (pts_idx % 2).to(torch.int32)
129
+ pts_idx = pts_idx // 2
130
+ pts_x = pts_idx % W_im
131
+ pts_y = pts_idx // W_im
132
+ points = torch.stack([pts_x, pts_y], dim=2).to(torch.float)
133
+ return points, labels
134
+
135
+
136
+ def sample_one_point_from_error_center(gt_masks, pred_masks, padding=True):
137
+ """
138
+ Sample 1 random point (along with its label) from the center of each error region,
139
+ that is, the point with the largest distance to the boundary of each error region.
140
+ This is the RITM sampling method from https://github.com/saic-vul/ritm_interactive_segmentation/blob/master/isegm/inference/clicker.py
141
+
142
+ Inputs:
143
+ - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool
144
+ - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None
145
+ - padding: if True, pad with boundary of 1 px for distance transform
146
+
147
+ Outputs:
148
+ - points: [B, 1, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point
149
+ - labels: [B, 1], dtype=torch.int32, where 1 means positive clicks and 0 means negative clicks
150
+ """
151
+ if pred_masks is None:
152
+ pred_masks = torch.zeros_like(gt_masks)
153
+ assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1
154
+ assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape
155
+
156
+ B, _, H, W = gt_masks.shape
157
+
158
+ # false positive region, a new point sampled in this region should have
159
+ # negative label to correct the FP error
160
+ fp_masks = (~gt_masks & pred_masks).squeeze(1)
161
+ # false negative region, a new point sampled in this region should have
162
+ # positive label to correct the FN error
163
+ fn_masks = (gt_masks & ~pred_masks).squeeze(1)
164
+
165
+ if padding:
166
+ padded_fp_masks = torch.zeros(
167
+ B, H + 2, W + 2, dtype=fp_masks.dtype, device=fp_masks.device
168
+ )
169
+ padded_fp_masks[:, 1 : H + 1, 1 : W + 1] = fp_masks
170
+ padded_fn_masks = torch.zeros(
171
+ B, H + 2, W + 2, dtype=fp_masks.dtype, device=fp_masks.device
172
+ )
173
+ padded_fn_masks[:, 1 : H + 1, 1 : W + 1] = fn_masks
174
+ else:
175
+ padded_fp_masks = fp_masks
176
+ padded_fn_masks = fn_masks
177
+
178
+ fn_mask_dt = edt_triton(padded_fn_masks)
179
+ fp_mask_dt = edt_triton(padded_fp_masks)
180
+ if padding:
181
+ fn_mask_dt = fn_mask_dt[:, 1:-1, 1:-1]
182
+ fp_mask_dt = fp_mask_dt[:, 1:-1, 1:-1]
183
+
184
+ fn_max, fn_argmax = fn_mask_dt.reshape(B, -1).max(dim=-1)
185
+ fp_max, fp_argmax = fp_mask_dt.reshape(B, -1).max(dim=-1)
186
+ is_positive = fn_max > fp_max
187
+ chosen = torch.where(is_positive, fn_argmax, fp_argmax)
188
+ points_x = chosen % W
189
+ points_y = chosen // W
190
+
191
+ labels = is_positive.long()
192
+ points = torch.stack([points_x, points_y], -1)
193
+ return points.unsqueeze(1), labels.unsqueeze(1)
194
+
195
+
196
+ def sample_one_point_from_error_center_slow(gt_masks, pred_masks, padding=True):
197
+ """
198
+ Sample 1 random point (along with its label) from the center of each error region,
199
+ that is, the point with the largest distance to the boundary of each error region.
200
+ This is the RITM sampling method from https://github.com/saic-vul/ritm_interactive_segmentation/blob/master/isegm/inference/clicker.py
201
+
202
+ Inputs:
203
+ - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool
204
+ - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None
205
+ - padding: if True, pad with boundary of 1 px for distance transform
206
+
207
+ Outputs:
208
+ - points: [B, 1, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point
209
+ - labels: [B, 1], dtype=torch.int32, where 1 means positive clicks and 0 means negative clicks
210
+ """
211
+ import cv2 # delay OpenCV import to avoid unnecessary dependency
212
+
213
+ if pred_masks is None:
214
+ pred_masks = torch.zeros_like(gt_masks)
215
+ assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1
216
+ assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape
217
+
218
+ B, _, _, W_im = gt_masks.shape
219
+ device = gt_masks.device
220
+
221
+ # false positive region, a new point sampled in this region should have
222
+ # negative label to correct the FP error
223
+ fp_masks = ~gt_masks & pred_masks
224
+ # false negative region, a new point sampled in this region should have
225
+ # positive label to correct the FN error
226
+ fn_masks = gt_masks & ~pred_masks
227
+
228
+ fp_masks = fp_masks.cpu().numpy()
229
+ fn_masks = fn_masks.cpu().numpy()
230
+ points = torch.zeros(B, 1, 2, dtype=torch.float)
231
+ labels = torch.ones(B, 1, dtype=torch.int32)
232
+ for b in range(B):
233
+ fn_mask = fn_masks[b, 0]
234
+ fp_mask = fp_masks[b, 0]
235
+ if padding:
236
+ fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), "constant")
237
+ fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), "constant")
238
+ # compute the distance of each point in FN/FP region to its boundary
239
+ fn_mask_dt = cv2.distanceTransform(fn_mask.astype(np.uint8), cv2.DIST_L2, 0)
240
+ fp_mask_dt = cv2.distanceTransform(fp_mask.astype(np.uint8), cv2.DIST_L2, 0)
241
+ if padding:
242
+ fn_mask_dt = fn_mask_dt[1:-1, 1:-1]
243
+ fp_mask_dt = fp_mask_dt[1:-1, 1:-1]
244
+
245
+ # take the point in FN/FP region with the largest distance to its boundary
246
+ fn_mask_dt_flat = fn_mask_dt.reshape(-1)
247
+ fp_mask_dt_flat = fp_mask_dt.reshape(-1)
248
+ fn_argmax = np.argmax(fn_mask_dt_flat)
249
+ fp_argmax = np.argmax(fp_mask_dt_flat)
250
+ is_positive = fn_mask_dt_flat[fn_argmax] > fp_mask_dt_flat[fp_argmax]
251
+ pt_idx = fn_argmax if is_positive else fp_argmax
252
+ points[b, 0, 0] = pt_idx % W_im # x
253
+ points[b, 0, 1] = pt_idx // W_im # y
254
+ labels[b, 0] = int(is_positive)
255
+
256
+ points = points.to(device)
257
+ labels = labels.to(device)
258
+ return points, labels
259
+
260
+
261
+ def get_next_point(gt_masks, pred_masks, method):
262
+ if method == "uniform":
263
+ return sample_random_points_from_errors(gt_masks, pred_masks)
264
+ elif method == "center":
265
+ return sample_one_point_from_error_center(gt_masks, pred_masks)
266
+ else:
267
+ raise ValueError(f"unknown sampling method {method}")
268
+
269
+
270
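A minimal usage sketch for these click samplers (assuming they are importable from `sam3.model.sam3_tracker_utils`, the module imported later in this diff); the shifted prediction below creates both FP and FN regions to correct:

```python
import torch

# Assumed import path, matching the `fill_holes_in_mask_scores` import later in this diff.
from sam3.model.sam3_tracker_utils import get_next_point

B, H, W = 2, 64, 64
gt_masks = torch.zeros(B, 1, H, W, dtype=torch.bool)
gt_masks[:, :, 16:48, 16:48] = True          # ground-truth squares
pred_masks = torch.zeros_like(gt_masks)
pred_masks[:, :, 8:40, 8:40] = True          # shifted predictions -> FP and FN error regions

# "uniform" samples a random click inside the FP/FN regions;
# "center" picks the click deepest inside the larger error region (RITM-style).
points, labels = get_next_point(gt_masks, pred_masks, method="uniform")
# points holds (x, y) pixel coordinates; labels are 1 for positive (FN) and 0 for negative (FP) clicks.
print(points.shape, labels.shape)
```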
+ def select_closest_cond_frames(
271
+ frame_idx, cond_frame_outputs, max_cond_frame_num, keep_first_cond_frame=False
272
+ ):
273
+ """
274
+ Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs`
275
+ that are temporally closest to the current frame at `frame_idx`. Here, we take
276
+ - a) the closest conditioning frame before `frame_idx` (if any);
277
+ - b) the closest conditioning frame after `frame_idx` (if any);
278
+ - c) any other temporally closest conditioning frames until reaching a total
279
+ of `max_cond_frame_num` conditioning frames.
280
+
281
+ Outputs:
282
+ - selected_outputs: selected items (keys & values) from `cond_frame_outputs`.
283
+ - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`.
284
+ """
285
+ if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num:
286
+ selected_outputs = cond_frame_outputs
287
+ unselected_outputs = {}
288
+ else:
289
+ assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames"
290
+ selected_outputs = {}
291
+ if keep_first_cond_frame:
292
+ idx_first = min(
293
+ (t for t in cond_frame_outputs if t < frame_idx), default=None
294
+ )
295
+ if idx_first is None:
296
+ # Maybe we are tracking in reverse
297
+ idx_first = max(
298
+ (t for t in cond_frame_outputs if t > frame_idx), default=None
299
+ )
300
+ if idx_first is not None:
301
+ selected_outputs[idx_first] = cond_frame_outputs[idx_first]
302
+ # the closest conditioning frame before `frame_idx` (if any)
303
+ idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None)
304
+ if idx_before is not None:
305
+ selected_outputs[idx_before] = cond_frame_outputs[idx_before]
306
+
307
+ # the closest conditioning frame after `frame_idx` (if any)
308
+ idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None)
309
+ if idx_after is not None:
310
+ selected_outputs[idx_after] = cond_frame_outputs[idx_after]
311
+
312
+ # add other temporally closest conditioning frames until reaching a total
313
+ # of `max_cond_frame_num` conditioning frames.
314
+ num_remain = max_cond_frame_num - len(selected_outputs)
315
+ inds_remain = sorted(
316
+ (t for t in cond_frame_outputs if t not in selected_outputs),
317
+ key=lambda x: abs(x - frame_idx),
318
+ )[:num_remain]
319
+ selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain)
320
+ unselected_outputs = {
321
+ t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs
322
+ }
323
+
324
+ return selected_outputs, unselected_outputs
325
+
326
+
327
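A small worked example of the selection rule above (a sketch, assuming the same `sam3.model.sam3_tracker_utils` import path): with a budget of 3 conditioning frames around frame 12, the closest frame on each side is kept first, then the next temporally closest one.

```python
from sam3.model.sam3_tracker_utils import select_closest_cond_frames  # assumed import path

# Dummy per-frame outputs keyed by frame index (real values would be memory dicts).
cond_frame_outputs = {0: "out0", 4: "out4", 10: "out10", 15: "out15", 30: "out30"}

selected, unselected = select_closest_cond_frames(
    frame_idx=12,
    cond_frame_outputs=cond_frame_outputs,
    max_cond_frame_num=3,
)
print(sorted(selected))    # [4, 10, 15]: closest before (10), closest at/after (15), then next closest (4)
print(sorted(unselected))  # [0, 30]
```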
+ def get_1d_sine_pe(pos_inds, dim, temperature=10000):
328
+ """
329
+ Get 1D sine positional embedding as in the original Transformer paper.
330
+ """
331
+ pe_dim = dim // 2
332
+ dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device)
333
+ dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)
334
+
335
+ pos_embed = pos_inds.unsqueeze(-1) / dim_t
336
+ pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1)
337
+ return pos_embed
338
+
339
+
340
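For reference, a quick sketch of the embedding's shape (same assumed import path); the first half of the channels carries the sine terms and the second half the cosine terms:

```python
import torch
from sam3.model.sam3_tracker_utils import get_1d_sine_pe  # assumed import path

# Encode the temporal distance of four memory frames relative to the current frame.
pos_inds = torch.tensor([1.0, 2.0, 3.0, 4.0])
pe = get_1d_sine_pe(pos_inds, dim=256)
print(pe.shape)  # torch.Size([4, 256]); dims 0-127 are sin terms, dims 128-255 are cos terms
```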
+ def get_best_gt_match_from_multimasks(pred_multimasks, gt_masks, pred_scores=None):
341
+ """
342
+ Get the mask with the best match to GT masks (based on IoU) from pred_multimasks.
343
+ Optionally, use `pred_scores` to break ties in case all IoUs are zero.
344
+ """
345
+ assert pred_multimasks.ndim == 4 and gt_masks.ndim == 4
346
+ if pred_multimasks.size(1) == 1:
347
+ return pred_multimasks # only a single mask channel, nothing to select
348
+
349
+ pred_multimasks_binary = pred_multimasks > 0
350
+ area_i = torch.sum(pred_multimasks_binary & gt_masks, dim=(2, 3)).float()
351
+ area_u = torch.sum(pred_multimasks_binary | gt_masks, dim=(2, 3)).float()
352
+ ious = area_i / torch.clamp(area_u, min=1.0)
353
+
354
+ # In case all IoUs are zero (e.g. because the GT mask is empty), use pred_scores
355
+ # to break ties and select the best mask
356
+ if pred_scores is not None:
357
+ has_nonzero_ious = torch.any(ious > 0).expand_as(ious)
358
+ scores = torch.where(has_nonzero_ious, ious, pred_scores)
359
+ else:
360
+ scores = ious
361
+
362
+ # Finally, take the best mask prediction (with the highest score)
363
+ best_scores_inds = torch.argmax(scores, dim=-1)
364
+ batch_inds = torch.arange(scores.size(0), device=scores.device)
365
+ best_pred_mask = pred_multimasks[batch_inds, best_scores_inds].unsqueeze(1)
366
+ return best_pred_mask
367
+
368
+
369
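A toy sketch of the IoU-based selection (same assumed import path): of three candidate mask channels, only the second overlaps the GT square, so it is the one returned.

```python
import torch
from sam3.model.sam3_tracker_utils import get_best_gt_match_from_multimasks  # assumed import path

B, M, H, W = 1, 3, 32, 32
gt_masks = torch.zeros(B, 1, H, W, dtype=torch.bool)
gt_masks[:, :, 8:24, 8:24] = True

# Three candidate mask logits; only channel 1 covers the GT square.
pred_multimasks = torch.full((B, M, H, W), -10.0)
pred_multimasks[:, 1, 8:24, 8:24] = 10.0

best = get_best_gt_match_from_multimasks(pred_multimasks, gt_masks)
print(best.shape)                                    # torch.Size([1, 1, 32, 32])
print(torch.equal(best[:, 0] > 0, gt_masks[:, 0]))   # True: channel 1 (highest IoU) was selected
```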
+ def fill_holes_in_mask_scores(mask, max_area, fill_holes=True, remove_sprinkles=True):
370
+ """
371
+ A post processor to fill small holes in mask scores with area under `max_area`.
372
+ Holes are those small connected components in either background or foreground.
373
+
374
+ Note that it relies on the "cc_torch" package to find connected components fast. You can
375
+ install it via the following command (in `TORCH_CUDA_ARCH_LIST`, `8.0` targets A100 GPUs and `9.0` targets H100 GPUs):
376
+ ```
377
+ pip uninstall -y cc_torch; TORCH_CUDA_ARCH_LIST="8.0 9.0" pip install git+https://github.com/ronghanghu/cc_torch
378
+ ```
379
+ Otherwise, it falls back to a slightly slower Triton implementation, or to skimage if the tensor is on CPU.
380
+ """
381
+
382
+ if max_area <= 0:
383
+ return mask # nothing to fill in this case
384
+
385
+ if fill_holes:
386
+ # We remove small connected components in background by changing them to foreground
387
+ # with a small positive mask score (0.1).
388
+ mask_bg = mask <= 0
389
+ bg_area_thresh = max_area
390
+ _, areas_bg = _get_connected_components_with_padding(mask_bg)
391
+ small_components_bg = mask_bg & (areas_bg <= bg_area_thresh)
392
+ mask = torch.where(small_components_bg, 0.1, mask)
393
+
394
+ if remove_sprinkles:
395
+ # We remove small connected components in foreground by changing them to background
396
+ # with a small negative mask score (-0.1). Here we only remove connected components
397
+ # whose areas are under both `max_area` and half of the entire mask's area. This
398
+ # removes sprinkles while avoiding filtering out tiny objects that we want to track.
399
+ mask_fg = mask > 0
400
+ fg_area_thresh = torch.sum(mask_fg, dim=(2, 3), keepdim=True, dtype=torch.int32)
401
+ fg_area_thresh.floor_divide_(2).clamp_(max=max_area)
402
+ _, areas_fg = _get_connected_components_with_padding(mask_fg)
403
+ small_components_fg = mask_fg & (areas_fg <= fg_area_thresh)
404
+ mask = torch.where(small_components_fg, -0.1, mask)
405
+ return mask
406
+
407
+
408
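A minimal sketch of the post-processor on synthetic logits (same assumed import path; it needs one of the connected-components backends described in the docstring above, i.e. cc_torch, the Triton fallback on GPU, or the skimage path on CPU):

```python
import torch
from sam3.model.sam3_tracker_utils import fill_holes_in_mask_scores  # assumed import path

# Mask logits for one object: positive = foreground, negative = background.
mask = torch.full((1, 1, 64, 64), -10.0)
mask[:, :, 16:48, 16:48] = 10.0   # solid foreground square ...
mask[:, :, 30:34, 30:34] = -10.0  # ... with a small 4x4 hole inside it
mask[:, :, 2:4, 2:4] = 10.0       # plus an isolated 2x2 foreground sprinkle

# Background holes up to 32 px get score 0.1 (filled); small foreground islands get -0.1 (removed).
cleaned = fill_holes_in_mask_scores(mask, max_area=32)
print((cleaned[:, :, 30:34, 30:34] > 0).all().item())  # True: the hole was filled
print((cleaned[:, :, 2:4, 2:4] > 0).any().item())      # False: the sprinkle was removed
```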
+ def _get_connected_components_with_padding(mask):
409
+ """Get connected components from masks (possibly padding them to an even size)."""
410
+ from sam3.perflib.connected_components import connected_components
411
+
412
+ mask = mask.to(torch.uint8)
413
+ _, _, H, W = mask.shape
414
+ # make sure both height and width are even (to be compatible with cc_torch)
415
+ pad_h = H % 2
416
+ pad_w = W % 2
417
+ if pad_h == 0 and pad_w == 0:
418
+ labels, counts = connected_components(mask)
419
+ else:
420
+ # pad the mask to make its height and width even
421
+ # padding format is (padding_left,padding_right,padding_top,padding_bottom)
422
+ mask_pad = F.pad(mask, (0, pad_w, 0, pad_h), mode="constant", value=0)
423
+ labels, counts = connected_components(mask_pad)
424
+ labels = labels[:, :, :H, :W]
425
+ counts = counts[:, :, :H, :W]
426
+
427
+ return labels, counts
detect_tools/sam3/sam3/model/sam3_tracking_predictor.py ADDED
@@ -0,0 +1,1370 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ import logging
4
+ from collections import OrderedDict
5
+
6
+ import torch
7
+
8
+ from sam3.model.sam3_tracker_base import concat_points, NO_OBJ_SCORE, Sam3TrackerBase
9
+ from sam3.model.sam3_tracker_utils import fill_holes_in_mask_scores
10
+ from sam3.model.utils.sam2_utils import load_video_frames
11
+ from tqdm.auto import tqdm
12
+
13
+
14
+ class Sam3TrackerPredictor(Sam3TrackerBase):
15
+ """
16
+ The demo class that extends the `Sam3TrackerBase` to handle user interactions
17
+ and manage inference states, with support for multi-object tracking.
18
+ """
19
+
20
+ def __init__(
21
+ self,
22
+ # whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks;
23
+ # note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True
24
+ clear_non_cond_mem_around_input=False,
25
+ # whether to also clear non-conditioning memory of the surrounding frames for multi-object tracking (only effective when `clear_non_cond_mem_around_input` is True).
26
+ clear_non_cond_mem_for_multi_obj=False,
27
+ # if fill_hole_area > 0, we fill small holes in the final masks up to this area (after resizing them to the original video resolution)
28
+ fill_hole_area=0,
29
+ # if always_start_from_first_ann_frame is True, we always start tracking from the frame where we receive the first annotation (clicks or mask)
30
+ # and ignore the `start_frame_idx` passed to `propagate_in_video`
31
+ always_start_from_first_ann_frame=False,
32
+ # the maximum number of points to be used in the prompt encoder, which reduce the domain gap between training (that only has 8 points)
33
+ # - if it's set to a positive integer, we only take the first `max_point_num_in_prompt_enc//2` points and
34
+ # the last `(max_point_num_in_prompt_enc - max_point_num_in_prompt_enc//2)` points in the prompt encoder
35
+ # - if it's set to 0 or negative, this option is turned off and we use all points in the prompt encoder
36
+ max_point_num_in_prompt_enc=16,
37
+ non_overlap_masks_for_output=True,
38
+ # checkpoint_file=None,
39
+ **kwargs,
40
+ ):
41
+ super().__init__(**kwargs)
42
+ self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
43
+ self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj
44
+ self.fill_hole_area = fill_hole_area
45
+ self.always_start_from_first_ann_frame = always_start_from_first_ann_frame
46
+ self.max_point_num_in_prompt_enc = max_point_num_in_prompt_enc
47
+ self.non_overlap_masks_for_output = non_overlap_masks_for_output
48
+
49
+ self.bf16_context = torch.autocast(device_type="cuda", dtype=torch.bfloat16)
50
+ self.bf16_context.__enter__()  # keep the bfloat16 autocast context active for the model's entire lifetime
51
+
52
+ self.iter_use_prev_mask_pred = True
53
+ self.add_all_frames_to_correct_as_cond = True
54
+
55
+ @torch.inference_mode()
56
+ def init_state(
57
+ self,
58
+ video_height=None,
59
+ video_width=None,
60
+ num_frames=None,
61
+ video_path=None,
62
+ cached_features=None,
63
+ offload_video_to_cpu=False,
64
+ offload_state_to_cpu=False,
65
+ async_loading_frames=False,
66
+ ):
67
+ """Initialize a inference state."""
68
+ inference_state = {}
69
+ # whether to offload the video frames to CPU memory
70
+ # turning on this option saves the GPU memory with only a very small overhead
71
+ inference_state["offload_video_to_cpu"] = offload_video_to_cpu
72
+ # whether to offload the inference state to CPU memory
73
+ # turning on this option saves the GPU memory at the cost of a lower tracking fps
74
+ # (e.g. in a test case with a 768x768 model, fps dropped from 27 to 24 when tracking one object
75
+ # and from 24 to 21 when tracking two objects)
76
+ inference_state["offload_state_to_cpu"] = offload_state_to_cpu
77
+ inference_state["device"] = self.device
78
+ if offload_state_to_cpu:
79
+ inference_state["storage_device"] = torch.device("cpu")
80
+ else:
81
+ inference_state["storage_device"] = torch.device("cuda")
82
+
83
+ if video_path is not None:
84
+ images, video_height, video_width = load_video_frames(
85
+ video_path=video_path,
86
+ image_size=self.image_size,
87
+ offload_video_to_cpu=offload_video_to_cpu,
88
+ async_loading_frames=async_loading_frames,
89
+ compute_device=inference_state["storage_device"],
90
+ )
91
+ inference_state["images"] = images
92
+ inference_state["num_frames"] = len(images)
93
+ inference_state["video_height"] = video_height
94
+ inference_state["video_width"] = video_width
95
+ else:
96
+ # the original video height and width, used for resizing final output scores
97
+ inference_state["video_height"] = video_height
98
+ inference_state["video_width"] = video_width
99
+ inference_state["num_frames"] = num_frames
100
+ # inputs on each frame
101
+ inference_state["point_inputs_per_obj"] = {}
102
+ inference_state["mask_inputs_per_obj"] = {}
103
+ # visual features on a small number of recently visited frames for quick interactions
104
+ inference_state["cached_features"] = (
105
+ {} if cached_features is None else cached_features
106
+ )
107
+ # values that don't change across frames (so we only need to hold one copy of them)
108
+ inference_state["constants"] = {}
109
+ # mapping between client-side object id and model-side object index
110
+ inference_state["obj_id_to_idx"] = OrderedDict()
111
+ inference_state["obj_idx_to_id"] = OrderedDict()
112
+ inference_state["obj_ids"] = []
113
+ # A storage to hold the model's tracking results and states on each frame
114
+ inference_state["output_dict"] = {
115
+ "cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
116
+ "non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
117
+ }
118
+ # The index of the frame that received the first annotation
119
+ inference_state["first_ann_frame_idx"] = None
120
+ # Slice (view) of each object tracking results, sharing the same memory with "output_dict"
121
+ inference_state["output_dict_per_obj"] = {}
122
+ # A temporary storage to hold new outputs when the user interacts with a frame
123
+ # to add clicks or mask (it's merged into "output_dict" before propagation starts)
124
+ inference_state["temp_output_dict_per_obj"] = {}
125
+ # Frames that already hold consolidated outputs from click or mask inputs
126
+ # (we directly use their consolidated outputs during tracking)
127
+ inference_state["consolidated_frame_inds"] = {
128
+ "cond_frame_outputs": set(), # set containing frame indices
129
+ "non_cond_frame_outputs": set(), # set containing frame indices
130
+ }
131
+ # metadata for each tracking frame (e.g. which direction it's tracked)
132
+ inference_state["tracking_has_started"] = False
133
+ inference_state["frames_already_tracked"] = {}
134
+ self.clear_all_points_in_video(inference_state)
135
+ return inference_state
136
+
137
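A minimal sketch of setting up a tracking session (the `build_predictor` helper below is a hypothetical placeholder for however a `Sam3TrackerPredictor` is constructed in this repo, e.g. via `sam3/model_builder.py`; the video path format depends on what `load_video_frames` accepts):

```python
predictor = build_predictor()  # hypothetical helper, not part of this diff

# Keep decoded frames and the per-frame tracking state in CPU memory to trade a
# small amount of speed for a lower GPU memory footprint.
state = predictor.init_state(
    video_path="path/to/video.mp4",
    offload_video_to_cpu=True,
    offload_state_to_cpu=True,
)
print(state["num_frames"], state["video_height"], state["video_width"])
```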
+ def _obj_id_to_idx(self, inference_state, obj_id):
138
+ """Map client-side object id to model-side object index."""
139
+ obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
140
+ if obj_idx is not None:
141
+ return obj_idx
142
+
143
+ # This is a new object id not sent to the server before. We only allow adding
144
+ # new objects *before* the tracking starts.
145
+ allow_new_object = not inference_state["tracking_has_started"]
146
+ if allow_new_object:
147
+ # get the next object slot
148
+ obj_idx = len(inference_state["obj_id_to_idx"])
149
+ inference_state["obj_id_to_idx"][obj_id] = obj_idx
150
+ inference_state["obj_idx_to_id"][obj_idx] = obj_id
151
+ inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"])
152
+ # set up input and output structures for this object
153
+ inference_state["point_inputs_per_obj"][obj_idx] = {}
154
+ inference_state["mask_inputs_per_obj"][obj_idx] = {}
155
+ inference_state["output_dict_per_obj"][obj_idx] = {
156
+ "cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
157
+ "non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
158
+ }
159
+ inference_state["temp_output_dict_per_obj"][obj_idx] = {
160
+ "cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
161
+ "non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
162
+ }
163
+ return obj_idx
164
+ else:
165
+ raise RuntimeError(
166
+ f"Cannot add new object id {obj_id} after tracking starts. "
167
+ f"All existing object ids: {inference_state['obj_ids']}."
168
+ )
169
+
170
+ def _obj_idx_to_id(self, inference_state, obj_idx):
171
+ """Map model-side object index to client-side object id."""
172
+ return inference_state["obj_idx_to_id"][obj_idx]
173
+
174
+ def _get_obj_num(self, inference_state):
175
+ """Get the total number of unique object ids received so far in this session."""
176
+ return len(inference_state["obj_idx_to_id"])
177
+
178
+ @torch.inference_mode()
179
+ def add_new_points_or_box(
180
+ self,
181
+ inference_state,
182
+ frame_idx,
183
+ obj_id,
184
+ points=None,
185
+ labels=None,
186
+ clear_old_points=True,
187
+ rel_coordinates=True,
188
+ use_prev_mem_frame=False,
189
+ normalize_coords=True,
190
+ box=None,
191
+ ):
192
+ """Add new points to a frame."""
193
+ obj_idx = self._obj_id_to_idx(inference_state, obj_id)
194
+ point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
195
+ mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
196
+
197
+ if (points is not None) != (labels is not None):
198
+ raise ValueError("points and labels must be provided together")
199
+ if points is None and box is None:
200
+ raise ValueError("at least one of points or box must be provided as input")
201
+
202
+ if points is None:
203
+ points = torch.zeros(0, 2, dtype=torch.float32)
204
+ elif not isinstance(points, torch.Tensor):
205
+ points = torch.tensor(points, dtype=torch.float32)
206
+ if labels is None:
207
+ labels = torch.zeros(0, dtype=torch.int32)
208
+ elif not isinstance(labels, torch.Tensor):
209
+ labels = torch.tensor(labels, dtype=torch.int32)
210
+ if points.dim() == 2:
211
+ points = points.unsqueeze(0) # add batch dimension
212
+ if labels.dim() == 1:
213
+ labels = labels.unsqueeze(0) # add batch dimension
214
+
215
+ if rel_coordinates:
216
+ # convert the points from relative coordinates to absolute coordinates
217
+ if points is not None:
218
+ points = points * self.image_size
219
+ if box is not None:
220
+ box = box * self.image_size
221
+
222
+ # If `box` is provided, we add it as the first two points with labels 2 and 3
223
+ # along with the user-provided points (consistent with how SAM 2 is trained).
224
+ if box is not None:
225
+ if not clear_old_points:
226
+ raise ValueError(
227
+ "cannot add box without clearing old points, since "
228
+ "box prompt must be provided before any point prompt "
229
+ "(please use clear_old_points=True instead)"
230
+ )
231
+ if not isinstance(box, torch.Tensor):
232
+ box = torch.tensor(box, dtype=torch.float32, device=points.device)
233
+ box_coords = box.reshape(1, 2, 2)
234
+ box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device)
235
+ box_labels = box_labels.reshape(1, 2)
236
+ points = torch.cat([box_coords, points], dim=1)
237
+ labels = torch.cat([box_labels, labels], dim=1)
238
+
239
+ points = points.to(inference_state["device"])
240
+ labels = labels.to(inference_state["device"])
241
+
242
+ if not clear_old_points:
243
+ point_inputs = point_inputs_per_frame.get(frame_idx, None)
244
+ else:
245
+ point_inputs = None
246
+ point_inputs = concat_points(point_inputs, points, labels)
247
+
248
+ point_inputs_per_frame[frame_idx] = point_inputs
249
+ mask_inputs_per_frame.pop(frame_idx, None)
250
+ # If this frame hasn't been tracked before, we treat it as an initial conditioning
251
+ # frame, meaning that the input points are used to generate segments on this frame without
252
+ # using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
253
+ # the input points will be used to correct the already tracked masks.
254
+ is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
255
+ # whether to track in reverse time order
256
+ if is_init_cond_frame:
257
+ reverse = False
258
+ else:
259
+ reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
260
+ obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
261
+ obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
262
+ # Add a frame to conditioning output if it's an initial conditioning frame or
263
+ # if the model sees all frames receiving clicks/mask as conditioning frames.
264
+ is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
265
+ storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
266
+
267
+ # Limit to a maximum number of input points to the prompt encoder (to reduce domain gap)
268
+ num_points = point_inputs["point_coords"].size(1)
269
+ if num_points > self.max_point_num_in_prompt_enc > 0:
270
+ num_first = self.max_point_num_in_prompt_enc // 2
271
+ num_last = self.max_point_num_in_prompt_enc - num_first
272
+ point_inputs["point_coords"] = torch.cat(
273
+ [
274
+ point_inputs["point_coords"][:, :num_first],
275
+ point_inputs["point_coords"][:, -num_last:],
276
+ ],
277
+ dim=1,
278
+ )
279
+ point_inputs["point_labels"] = torch.cat(
280
+ [
281
+ point_inputs["point_labels"][:, :num_first],
282
+ point_inputs["point_labels"][:, -num_last:],
283
+ ],
284
+ dim=1,
285
+ )
286
+ logging.warning(
287
+ f"Too many points ({num_points}) are provided on frame {frame_idx}. Only "
288
+ f"the first {num_first} points and the last {num_last} points will be used."
289
+ )
290
+ # Get any previously predicted mask logits on this object and feed it along with
291
+ # the new clicks into the SAM mask decoder when `self.iter_use_prev_mask_pred=True`.
292
+ prev_sam_mask_logits = None
293
+ if self.iter_use_prev_mask_pred:
294
+ # lookup temporary output dict first, which contains the most recent output
295
+ # (if not found, then lookup conditioning and non-conditioning frame output)
296
+ prev_out = obj_temp_output_dict[storage_key].get(frame_idx)
297
+ if prev_out is None:
298
+ prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx)
299
+ if prev_out is None:
300
+ prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
301
+
302
+ if prev_out is not None and prev_out["pred_masks"] is not None:
303
+ prev_sam_mask_logits = prev_out["pred_masks"].cuda(non_blocking=True)
304
+ # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
305
+ prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
306
+ current_out, _ = self._run_single_frame_inference(
307
+ inference_state=inference_state,
308
+ output_dict=obj_output_dict, # run on the slice of a single object
309
+ frame_idx=frame_idx,
310
+ batch_size=1, # run on the slice of a single object
311
+ is_init_cond_frame=is_init_cond_frame,
312
+ point_inputs=point_inputs,
313
+ mask_inputs=None,
314
+ reverse=reverse,
315
+ # Skip the memory encoder when adding clicks or mask. We execute the memory encoder
316
+ # at the beginning of `propagate_in_video` (after the user finalizes their clicks). This
317
+ # allows us to enforce non-overlapping constraints on all objects before encoding
318
+ # them into memory.
319
+ run_mem_encoder=False,
320
+ prev_sam_mask_logits=prev_sam_mask_logits,
321
+ use_prev_mem_frame=use_prev_mem_frame,
322
+ )
323
+ # Add the output to the output dict (to be used as future memory)
324
+ obj_temp_output_dict[storage_key][frame_idx] = current_out
325
+
326
+ # Resize the output mask to the original video resolution
327
+ obj_ids = inference_state["obj_ids"]
328
+ consolidated_out = self._consolidate_temp_output_across_obj(
329
+ inference_state,
330
+ frame_idx,
331
+ is_cond=is_cond,
332
+ run_mem_encoder=False,
333
+ consolidate_at_video_res=True,
334
+ )
335
+ _, video_res_masks = self._get_orig_video_res_output(
336
+ inference_state, consolidated_out["pred_masks_video_res"]
337
+ )
338
+ low_res_masks = None # not needed by the demo
339
+ return frame_idx, obj_ids, low_res_masks, video_res_masks
340
+
341
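Continuing the hypothetical `predictor`/`state` sketch from `init_state` above, prompts use relative coordinates by default (`rel_coordinates=True`), so a box or click is given in the 0-1 range and scaled to `image_size` internally:

```python
import torch

# Prompt object 1 on frame 0 with a box (x_min, y_min, x_max, y_max) in relative coordinates.
_, obj_ids, _, video_res_masks = predictor.add_new_points_or_box(
    state, frame_idx=0, obj_id=1, box=torch.tensor([0.2, 0.3, 0.6, 0.8])
)

# Refine the same object with one positive click (label 1), keeping the earlier box prompt.
_, obj_ids, _, video_res_masks = predictor.add_new_points_or_box(
    state, frame_idx=0, obj_id=1,
    points=[[0.4, 0.55]], labels=[1],
    clear_old_points=False,
)
# video_res_masks: [num_objects, 1, video_H, video_W] logits at the original video resolution.
```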
+ @torch.inference_mode()
342
+ def add_new_mask(
343
+ self,
344
+ inference_state,
345
+ frame_idx,
346
+ obj_id,
347
+ mask,
348
+ add_mask_to_memory=False,
349
+ ):
350
+ """Add new mask to a frame."""
351
+ obj_idx = self._obj_id_to_idx(inference_state, obj_id)
352
+ point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
353
+ mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
354
+
355
+ assert mask.dim() == 2
356
+ mask_H, mask_W = mask.shape
357
+ mask_inputs_orig = mask[None, None] # add batch and channel dimension
358
+ mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"])
359
+
360
+ # resize the mask if it doesn't match the model's input mask size
361
+ if mask_H != self.input_mask_size or mask_W != self.input_mask_size:
362
+ mask_inputs = torch.nn.functional.interpolate(
363
+ mask_inputs_orig,
364
+ size=(self.input_mask_size, self.input_mask_size),
365
+ align_corners=False,
366
+ mode="bilinear",
367
+ antialias=True, # use antialias for downsampling
368
+ )
369
+ else:
370
+ mask_inputs = mask_inputs_orig
371
+
372
+ # also get the mask at the original video resolution (for outputting)
373
+ video_H = inference_state["video_height"]
374
+ video_W = inference_state["video_width"]
375
+ if mask_H != video_H or mask_W != video_W:
376
+ mask_inputs_video_res = torch.nn.functional.interpolate(
377
+ mask_inputs_orig,
378
+ size=(video_H, video_W),
379
+ align_corners=False,
380
+ mode="bilinear",
381
+ antialias=True, # use antialias for potential downsampling
382
+ )
383
+ else:
384
+ mask_inputs_video_res = mask_inputs_orig
385
+ # convert mask_inputs_video_res to binary (threshold at 0.5 as it is in range 0~1)
386
+ mask_inputs_video_res = mask_inputs_video_res > 0.5
387
+
388
+ mask_inputs_per_frame[frame_idx] = mask_inputs_video_res
389
+ point_inputs_per_frame.pop(frame_idx, None)
390
+ # If this frame hasn't been tracked before, we treat it as an initial conditioning
391
+ # frame, meaning that the input points are used to generate segments on this frame without
392
+ # using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
393
+ # the input points will be used to correct the already tracked masks.
394
+ is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
395
+ # whether to track in reverse time order
396
+ if is_init_cond_frame:
397
+ reverse = False
398
+ else:
399
+ reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
400
+ obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
401
+ obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
402
+ # Add a frame to conditioning output if it's an initial conditioning frame or
403
+ # if the model sees all frames receiving clicks/mask as conditioning frames.
404
+ is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
405
+ storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
406
+
407
+ current_out, _ = self._run_single_frame_inference(
408
+ inference_state=inference_state,
409
+ output_dict=obj_output_dict, # run on the slice of a single object
410
+ frame_idx=frame_idx,
411
+ batch_size=1, # run on the slice of a single object
412
+ is_init_cond_frame=is_init_cond_frame,
413
+ point_inputs=None,
414
+ mask_inputs=mask_inputs,
415
+ reverse=reverse,
416
+ # Skip the memory encoder when adding clicks or mask. We execute the memory encoder
417
+ # at the beginning of `propagate_in_video` (after the user finalizes their clicks). This
418
+ # allows us to enforce non-overlapping constraints on all objects before encoding
419
+ # them into memory.
420
+ run_mem_encoder=False,
421
+ )
422
+ # We directly use the input mask at video resolution as the output mask for a better
423
+ # video editing experience (so that the masks don't change after each brushing).
424
+ # Here NO_OBJ_SCORE is a large negative value to represent the background and
425
+ # similarly -NO_OBJ_SCORE is a large positive value to represent the foreground.
426
+ current_out["pred_masks"] = None
427
+ current_out["pred_masks_video_res"] = torch.where(
428
+ mask_inputs_video_res, -NO_OBJ_SCORE, NO_OBJ_SCORE
429
+ )
430
+ # Add the output to the output dict (to be used as future memory)
431
+ obj_temp_output_dict[storage_key][frame_idx] = current_out
432
+ # Remove the portion of this input mask that overlaps other objects' temporary output masks on this frame
433
+ temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
434
+ for obj_idx2, obj_temp_output_dict2 in temp_output_dict_per_obj.items():
435
+ if obj_idx2 == obj_idx:
436
+ continue
437
+ current_out2 = obj_temp_output_dict2[storage_key].get(frame_idx, None)
438
+ if current_out2 is not None and "pred_masks_video_res" in current_out2:
439
+ current_out2["pred_masks_video_res"] = torch.where(
440
+ mask_inputs_video_res,
441
+ NO_OBJ_SCORE,
442
+ current_out2["pred_masks_video_res"],
443
+ )
444
+
445
+ # Resize the output mask to the original video resolution
446
+ obj_ids = inference_state["obj_ids"]
447
+ consolidated_out = self._consolidate_temp_output_across_obj(
448
+ inference_state,
449
+ frame_idx,
450
+ is_cond=is_cond,
451
+ run_mem_encoder=False,
452
+ consolidate_at_video_res=True,
453
+ )
454
+ _, video_res_masks = self._get_orig_video_res_output(
455
+ inference_state, consolidated_out["pred_masks_video_res"]
456
+ )
457
+ low_res_masks = None # not needed by the demo
458
+ return frame_idx, obj_ids, low_res_masks, video_res_masks
459
+
460
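Alternatively (same hypothetical session), an object can be prompted with a full 2D binary mask at any resolution; it is resized internally for the model and thresholded at 0.5 for the video-resolution output:

```python
import torch

# A binary mask at the original video resolution for a second object
# (indices are illustrative and assume the video is large enough).
mask = torch.zeros(state["video_height"], state["video_width"], dtype=torch.bool)
mask[100:200, 150:300] = True
frame_idx, obj_ids, _, video_res_masks = predictor.add_new_mask(
    state, frame_idx=0, obj_id=2, mask=mask
)
```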
+ def add_new_points(self, *args, **kwargs):
461
+ """Deprecated method. Please use `add_new_points_or_box` instead."""
462
+ return self.add_new_points_or_box(*args, **kwargs)
463
+
464
+ def _get_orig_video_res_output(self, inference_state, any_res_masks):
465
+ """
466
+ Resize the object scores to the original video resolution (video_res_masks)
467
+ and apply non-overlapping constraints for final output.
468
+ """
469
+ device = inference_state["device"]
470
+ video_H = inference_state["video_height"]
471
+ video_W = inference_state["video_width"]
472
+ any_res_masks = any_res_masks.to(device, non_blocking=True)
473
+ if any_res_masks.shape[-2:] == (video_H, video_W):
474
+ video_res_masks = any_res_masks
475
+ else:
476
+ video_res_masks = torch.nn.functional.interpolate(
477
+ any_res_masks,
478
+ size=(video_H, video_W),
479
+ mode="bilinear",
480
+ align_corners=False,
481
+ )
482
+ if self.non_overlap_masks_for_output:
483
+ video_res_masks = self._apply_non_overlapping_constraints(video_res_masks)
484
+ # potentially fill holes in the predicted masks
485
+ if self.fill_hole_area > 0:
486
+ video_res_masks = fill_holes_in_mask_scores(
487
+ video_res_masks, self.fill_hole_area
488
+ )
489
+ return any_res_masks, video_res_masks
490
+
491
+ def _consolidate_temp_output_across_obj(
492
+ self,
493
+ inference_state,
494
+ frame_idx,
495
+ is_cond,
496
+ run_mem_encoder,
497
+ consolidate_at_video_res=False,
498
+ ):
499
+ """
500
+ Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on
501
+ a frame into a single output for all objects, including
502
+ 1) fill any missing objects either from `output_dict_per_obj` (if they exist in
503
+ `output_dict_per_obj` for this frame) or leave them as placeholder values
504
+ (if they don't exist in `output_dict_per_obj` for this frame);
505
+ 2) if specified, rerun the memory encoder after applying non-overlapping constraints
506
+ on the object scores.
507
+ """
508
+ batch_size = self._get_obj_num(inference_state)
509
+ storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
510
+ # Optionally, we allow consolidating the temporary outputs at the original
511
+ # video resolution (to provide a better editing experience for mask prompts).
512
+ if consolidate_at_video_res:
513
+ assert not run_mem_encoder, "memory encoder cannot run at video resolution"
514
+ consolidated_H = inference_state["video_height"]
515
+ consolidated_W = inference_state["video_width"]
516
+ consolidated_mask_key = "pred_masks_video_res"
517
+ else:
518
+ consolidated_H = consolidated_W = self.low_res_mask_size
519
+ consolidated_mask_key = "pred_masks"
520
+
521
+ # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc"
522
+ # will be added when rerunning the memory encoder after applying non-overlapping
523
+ # constraints to object scores. Its "pred_masks" are prefilled with a large
524
+ # negative value (NO_OBJ_SCORE) to represent missing objects.
525
+ consolidated_out = {
526
+ "maskmem_features": None,
527
+ "maskmem_pos_enc": None,
528
+ consolidated_mask_key: torch.full(
529
+ size=(batch_size, 1, consolidated_H, consolidated_W),
530
+ fill_value=NO_OBJ_SCORE,
531
+ dtype=torch.float32,
532
+ device=inference_state["storage_device"],
533
+ ),
534
+ "obj_ptr": torch.full(
535
+ size=(batch_size, self.hidden_dim),
536
+ fill_value=NO_OBJ_SCORE,
537
+ dtype=torch.float32,
538
+ device=inference_state["device"],
539
+ ),
540
+ "object_score_logits": torch.full(
541
+ size=(batch_size, 1),
542
+ # default to 10.0 for object_score_logits, i.e. assuming the object is
543
+ # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
544
+ fill_value=10.0,
545
+ dtype=torch.float32,
546
+ device=inference_state["device"],
547
+ ),
548
+ }
549
+ if self.use_memory_selection:
550
+ consolidated_out["iou_score"] = torch.full(
551
+ size=(batch_size, 1),
552
+ fill_value=0.0,
553
+ dtype=torch.float32,
554
+ device=inference_state["device"],
555
+ )
556
+ empty_mask_ptr = None
557
+ for obj_idx in range(batch_size):
558
+ obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
559
+ obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
560
+ out = obj_temp_output_dict[storage_key].get(frame_idx, None)
561
+ # If the object doesn't appear in "temp_output_dict_per_obj" on this frame,
562
+ # we fall back and look up its previous output in "output_dict_per_obj".
563
+ # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in
564
+ # "output_dict_per_obj" to find a previous output for this object.
565
+ if out is None:
566
+ out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None)
567
+ if out is None:
568
+ out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None)
569
+ # If the object doesn't appear in "output_dict_per_obj" either, we skip it
570
+ # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE
571
+ # placeholder above) and set its object pointer to be a dummy pointer.
572
+ if out is None:
573
+ # Fill in dummy object pointers for those objects without any inputs or
574
+ # tracking outcomes on this frame (only do it under `run_mem_encoder=True`,
575
+ # i.e. when we need to build the memory for tracking).
576
+ if run_mem_encoder:
577
+ if empty_mask_ptr is None:
578
+ empty_mask_ptr = self._get_empty_mask_ptr(
579
+ inference_state, frame_idx
580
+ )
581
+ # fill object pointer with a dummy pointer (based on an empty mask)
582
+ consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr
583
+ continue
584
+ # Add the temporary object output mask to consolidated output mask
585
+ # (use "pred_masks_video_res" if it's available)
586
+ obj_mask = out.get("pred_masks_video_res", out["pred_masks"])
587
+ consolidated_pred_masks = consolidated_out[consolidated_mask_key]
588
+ if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]:
589
+ consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask
590
+ else:
591
+ # Resize first if temporary object mask has a different resolution
592
+ is_downsampling = "pred_masks_video_res" in out
593
+ resized_obj_mask = torch.nn.functional.interpolate(
594
+ obj_mask,
595
+ size=consolidated_pred_masks.shape[-2:],
596
+ mode="bilinear",
597
+ align_corners=False,
598
+ antialias=is_downsampling, # use antialias for downsampling
599
+ )
600
+ consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask
601
+ consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]
602
+ consolidated_out["object_score_logits"][obj_idx : obj_idx + 1] = out[
603
+ "object_score_logits"
604
+ ]
605
+ if self.use_memory_selection:
606
+ consolidated_out["iou_score"][obj_idx : obj_idx + 1] = out["iou_score"]
607
+ # Optionally, apply non-overlapping constraints on the consolidated scores
608
+ # and rerun the memory encoder
609
+ if run_mem_encoder:
610
+ device = inference_state["device"]
611
+ high_res_masks = torch.nn.functional.interpolate(
612
+ consolidated_out["pred_masks"].to(device, non_blocking=True),
613
+ size=(self.image_size, self.image_size),
614
+ mode="bilinear",
615
+ align_corners=False,
616
+ )
617
+ high_res_masks = self._apply_non_overlapping_constraints(high_res_masks)
618
+ maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
619
+ inference_state=inference_state,
620
+ frame_idx=frame_idx,
621
+ batch_size=batch_size,
622
+ high_res_masks=high_res_masks,
623
+ object_score_logits=consolidated_out["object_score_logits"],
624
+ is_mask_from_pts=True, # these frames are what the user interacted with
625
+ )
626
+ consolidated_out["maskmem_features"] = maskmem_features
627
+ consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc
628
+
629
+ return consolidated_out
630
+
631
+ def _get_empty_mask_ptr(self, inference_state, frame_idx):
632
+ """Get a dummy object pointer based on an empty mask on the current frame."""
633
+ # A dummy (empty) mask with a single object
634
+ batch_size = 1
635
+ mask_inputs = torch.zeros(
636
+ (batch_size, 1, self.image_size, self.image_size),
637
+ dtype=torch.float32,
638
+ device=inference_state["device"],
639
+ )
640
+
641
+ # Retrieve correct image features
642
+ (
643
+ image,
644
+ _,
645
+ current_vision_feats,
646
+ current_vision_pos_embeds,
647
+ feat_sizes,
648
+ ) = self._get_image_feature(inference_state, frame_idx, batch_size)
649
+
650
+ # Feed the empty mask and image feature above to get a dummy object pointer
651
+ current_out = self.track_step(
652
+ frame_idx=frame_idx,
653
+ is_init_cond_frame=True,
654
+ current_vision_feats=current_vision_feats,
655
+ current_vision_pos_embeds=current_vision_pos_embeds,
656
+ feat_sizes=feat_sizes,
657
+ image=image,
658
+ point_inputs=None,
659
+ mask_inputs=mask_inputs,
660
+ gt_masks=None,
661
+ frames_to_add_correction_pt=[],
662
+ output_dict={
663
+ "cond_frame_outputs": {},
664
+ "non_cond_frame_outputs": {},
665
+ },
666
+ num_frames=inference_state["num_frames"],
667
+ track_in_reverse=False,
668
+ run_mem_encoder=False,
669
+ prev_sam_mask_logits=None,
670
+ )
671
+ return current_out["obj_ptr"]
672
+
673
+ @torch.inference_mode()
674
+ def propagate_in_video_preflight(self, inference_state, run_mem_encoder=True):
675
+ """Prepare inference_state and consolidate temporary outputs before tracking."""
676
+ # Tracking has started and we don't allow adding new objects until the session is reset.
677
+ inference_state["tracking_has_started"] = True
678
+ batch_size = self._get_obj_num(inference_state)
679
+
680
+ # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
681
+ # add them into "output_dict".
682
+ temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
683
+ output_dict = inference_state["output_dict"]
684
+ # "consolidated_frame_inds" contains indices of those frames where consolidated
685
+ # temporary outputs have been added (either in this call or any previous calls
686
+ # to `propagate_in_video_preflight`).
687
+ consolidated_frame_inds = inference_state["consolidated_frame_inds"]
688
+ for is_cond in [False, True]:
689
+ # Separately consolidate conditioning and non-conditioning temp outputs
690
+ storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
691
+ # Find all the frames that contain temporary outputs for any objects
692
+ # (these should be the frames that have just received clicks or mask inputs
693
+ # via `add_new_points` or `add_new_mask`)
694
+ temp_frame_inds = set()
695
+ for obj_temp_output_dict in temp_output_dict_per_obj.values():
696
+ temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
697
+ consolidated_frame_inds[storage_key].update(temp_frame_inds)
698
+ # consolidate the temporary output across all objects on this frame
699
+ for frame_idx in temp_frame_inds:
700
+ consolidated_out = self._consolidate_temp_output_across_obj(
701
+ inference_state,
702
+ frame_idx,
703
+ is_cond=is_cond,
704
+ run_mem_encoder=run_mem_encoder,
705
+ )
706
+ # merge them into "output_dict" and also create per-object slices
707
+ output_dict[storage_key][frame_idx] = consolidated_out
708
+ self._add_output_per_object(
709
+ inference_state, frame_idx, consolidated_out, storage_key
710
+ )
711
+ clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
712
+ self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
713
+ )
714
+ if clear_non_cond_mem:
715
+ # clear non-conditioning memory of the surrounding frames
716
+ self._clear_non_cond_mem_around_input(inference_state, frame_idx)
717
+
718
+ # clear temporary outputs in `temp_output_dict_per_obj`
719
+ for obj_temp_output_dict in temp_output_dict_per_obj.values():
720
+ obj_temp_output_dict[storage_key].clear()
721
+
722
+ # edge case: if an output is added to "cond_frame_outputs", we remove any prior
723
+ # output on the same frame in "non_cond_frame_outputs"
724
+ for frame_idx in output_dict["cond_frame_outputs"]:
725
+ output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
726
+ for obj_output_dict in inference_state["output_dict_per_obj"].values():
727
+ for frame_idx in obj_output_dict["cond_frame_outputs"]:
728
+ obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
729
+ for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
730
+ assert frame_idx in output_dict["cond_frame_outputs"]
731
+ consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
732
+
733
+ # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames
734
+ # with either points or mask inputs (which should be true under a correct demo workflow).
735
+ all_consolidated_frame_inds = (
736
+ consolidated_frame_inds["cond_frame_outputs"]
737
+ | consolidated_frame_inds["non_cond_frame_outputs"]
738
+ )
739
+ input_frames_inds = set()
740
+ for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values():
741
+ input_frames_inds.update(point_inputs_per_frame.keys())
742
+ for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values():
743
+ input_frames_inds.update(mask_inputs_per_frame.keys())
744
+ assert all_consolidated_frame_inds == input_frames_inds
745
+ # Record the first interacted frame index (for tracking start)
746
+ if inference_state["first_ann_frame_idx"] is None:
747
+ inference_state["first_ann_frame_idx"] = min(
748
+ input_frames_inds, default=None
749
+ )
750
+ # In case `first_ann_frame_idx` is not in the conditioning frames (e.g. because
751
+ # we cleared the input points on that frame), pick the first conditioning frame
752
+ if (
753
+ inference_state["first_ann_frame_idx"]
754
+ not in output_dict["cond_frame_outputs"]
755
+ ):
756
+ inference_state["first_ann_frame_idx"] = min(
757
+ output_dict["cond_frame_outputs"], default=None
758
+ )
759
+
760
+ def _get_processing_order(
761
+ self, inference_state, start_frame_idx, max_frame_num_to_track, reverse
762
+ ):
763
+ num_frames = inference_state["num_frames"]
764
+ # set start index, end index, and processing order
765
+ if self.always_start_from_first_ann_frame:
766
+ # in this case, we always start tracking from the frame where we receive
767
+ # the initial annotation and ignore the provided start_frame_idx
768
+ start_frame_idx = inference_state["first_ann_frame_idx"]
769
+ if start_frame_idx is None:
770
+ # default: start from the earliest frame with input points
771
+ start_frame_idx = min(inference_state["output_dict"]["cond_frame_outputs"])
772
+ if max_frame_num_to_track is None:
773
+ # default: track all the frames in the video
774
+ max_frame_num_to_track = num_frames
775
+ if reverse:
776
+ end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0)
777
+ if start_frame_idx > 0:
778
+ processing_order = range(start_frame_idx, end_frame_idx - 1, -1)
779
+ else:
780
+ # this is the edge case where we start from frame 0 and track in reverse order;
781
+ # in this case, we track a single frame (frame 0)
782
+ processing_order = [0]
783
+ else:
784
+ end_frame_idx = min(
785
+ start_frame_idx + max_frame_num_to_track, num_frames - 1
786
+ )
787
+ processing_order = range(start_frame_idx, end_frame_idx + 1)
788
+ return processing_order
789
+
790
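For clarity, a standalone illustration of the resulting frame order (it simply mirrors the arithmetic above for a 100-frame video):

```python
# Forward from frame 40 with a 10-frame budget vs. tracking the same span in reverse.
num_frames, start, max_track = 100, 40, 10
forward = list(range(start, min(start + max_track, num_frames - 1) + 1))
backward = list(range(start, max(start - max_track, 0) - 1, -1))
print(forward)   # [40, 41, ..., 50]  (the end index is inclusive)
print(backward)  # [40, 39, ..., 30]
```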
+ @torch.inference_mode()
791
+ def propagate_in_video(
792
+ self,
793
+ inference_state,
794
+ start_frame_idx,
795
+ max_frame_num_to_track,
796
+ reverse,
797
+ tqdm_disable=False,
798
+ obj_ids=None,
799
+ run_mem_encoder=True,
800
+ propagate_preflight=False,
801
+ ):
802
+ """Propagate the input points across frames to track in the entire video."""
803
+ if propagate_preflight:
804
+ self.propagate_in_video_preflight(inference_state)
805
+ # NOTE: This is a copy from the parent class, except that we return object scores as well.
806
+ output_dict = inference_state["output_dict"]
807
+ consolidated_frame_inds = inference_state["consolidated_frame_inds"]
808
+ if obj_ids is not None:
809
+ raise NotImplementedError(
810
+ "Per-object tracking yet for batched inference if not implemented."
811
+ )
812
+ obj_ids = inference_state["obj_ids"]
813
+ batch_size = self._get_obj_num(inference_state)
814
+ if len(output_dict["cond_frame_outputs"]) == 0:
815
+ raise RuntimeError("No points are provided; please add points first")
816
+ clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
817
+ self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
818
+ )
819
+
820
+ processing_order = self._get_processing_order(
821
+ inference_state,
822
+ start_frame_idx,
823
+ max_frame_num_to_track,
824
+ reverse,
825
+ )
826
+
827
+ for frame_idx in tqdm(
828
+ processing_order, desc="propagate in video", disable=tqdm_disable
829
+ ):
830
+ # We skip those frames already in consolidated outputs (these are frames
831
+ # that received input clicks or mask). Note that we cannot directly run
832
+ # batched forward on them via `_run_single_frame_inference` because the
833
+ # number of clicks on each object might be different.
834
+ if frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
835
+ storage_key = "cond_frame_outputs"
836
+ current_out = output_dict[storage_key][frame_idx]
837
+ pred_masks = current_out["pred_masks"]
838
+ obj_scores = current_out["object_score_logits"]
839
+ if clear_non_cond_mem:
840
+ # clear non-conditioning memory of the surrounding frames
841
+ self._clear_non_cond_mem_around_input(inference_state, frame_idx)
842
+ elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]:
843
+ storage_key = "non_cond_frame_outputs"
844
+ current_out = output_dict[storage_key][frame_idx]
845
+ pred_masks = current_out["pred_masks"]
846
+ obj_scores = current_out["object_score_logits"]
847
+ else:
848
+ storage_key = "non_cond_frame_outputs"
849
+ current_out, pred_masks = self._run_single_frame_inference(
850
+ inference_state=inference_state,
851
+ output_dict=output_dict,
852
+ frame_idx=frame_idx,
853
+ batch_size=batch_size,
854
+ is_init_cond_frame=False,
855
+ point_inputs=None,
856
+ mask_inputs=None,
857
+ reverse=reverse,
858
+ run_mem_encoder=run_mem_encoder,
859
+ )
860
+ obj_scores = current_out["object_score_logits"]
861
+ output_dict[storage_key][frame_idx] = current_out
862
+ # Create slices of per-object outputs for subsequent interaction with each
863
+ # individual object after tracking.
864
+ self._add_output_per_object(
865
+ inference_state, frame_idx, current_out, storage_key
866
+ )
867
+ inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse}
868
+
869
+ # Resize the output mask to the original video resolution (we directly use
870
+ # the mask scores on GPU for output to avoid any CPU conversion in between)
871
+ low_res_masks, video_res_masks = self._get_orig_video_res_output(
872
+ inference_state, pred_masks
873
+ )
874
+ yield frame_idx, obj_ids, low_res_masks, video_res_masks, obj_scores
875
+
876
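Putting the pieces together (still the hypothetical `predictor`/`state` session), propagation is a generator that streams per-frame results; `propagate_preflight=True` consolidates the temporary prompt outputs before tracking starts:

```python
# Stream masks for every frame, starting from frame 0 where the prompts were added.
for frame_idx, obj_ids, _, video_res_masks, obj_scores in predictor.propagate_in_video(
    state,
    start_frame_idx=0,
    max_frame_num_to_track=None,
    reverse=False,
    propagate_preflight=True,  # runs propagate_in_video_preflight before tracking
):
    # video_res_masks: [num_objects, 1, H, W] logits; > 0 yields one binary mask per object id.
    binary_masks = (video_res_masks > 0).cpu()
```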
+ def _add_output_per_object(
877
+ self, inference_state, frame_idx, current_out, storage_key
878
+ ):
879
+ """
880
+ Split a multi-object output into per-object output slices and add them into
881
+ `output_dict_per_obj`. The resulting slices share the same tensor storage.
882
+ """
883
+ maskmem_features = current_out["maskmem_features"]
884
+ assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)
885
+
886
+ maskmem_pos_enc = current_out["maskmem_pos_enc"]
887
+ assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)
888
+
889
+ output_dict_per_obj = inference_state["output_dict_per_obj"]
890
+ for obj_idx, obj_output_dict in output_dict_per_obj.items():
891
+ obj_slice = slice(obj_idx, obj_idx + 1)
892
+ obj_out = {
893
+ "maskmem_features": None,
894
+ "maskmem_pos_enc": None,
895
+ "pred_masks": current_out["pred_masks"][obj_slice],
896
+ "obj_ptr": current_out["obj_ptr"][obj_slice],
897
+ "object_score_logits": current_out["object_score_logits"][obj_slice],
898
+ }
899
+ if self.use_memory_selection:
900
+ obj_out["iou_score"] = current_out["iou_score"][obj_slice]
901
+ if maskmem_features is not None:
902
+ obj_out["maskmem_features"] = maskmem_features[obj_slice]
903
+ if maskmem_pos_enc is not None:
904
+ obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
905
+ obj_output_dict[storage_key][frame_idx] = obj_out
906
+
907
+ @torch.inference_mode()
908
+ def clear_all_points_in_frame(
909
+ self, inference_state, frame_idx, obj_id, need_output=True
910
+ ):
911
+ """Remove all input points or mask in a specific frame for a given object."""
912
+ obj_idx = self._obj_id_to_idx(inference_state, obj_id)
913
+
914
+ # Clear the conditioning information on the given frame
915
+ inference_state["point_inputs_per_obj"][obj_idx].pop(frame_idx, None)
916
+ inference_state["mask_inputs_per_obj"][obj_idx].pop(frame_idx, None)
917
+
918
+ temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
919
+ temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None)
920
+ temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None)
921
+
922
+ # Check and see if there are still any inputs left on this frame
923
+ batch_size = self._get_obj_num(inference_state)
924
+ frame_has_input = False
925
+ for obj_idx2 in range(batch_size):
926
+ if frame_idx in inference_state["point_inputs_per_obj"][obj_idx2]:
927
+ frame_has_input = True
928
+ break
929
+ if frame_idx in inference_state["mask_inputs_per_obj"][obj_idx2]:
930
+ frame_has_input = True
931
+ break
932
+
933
+ # If this frame has no remaining inputs for any objects, we further clear its
934
+ # conditioning frame status
935
+ if not frame_has_input:
936
+ output_dict = inference_state["output_dict"]
937
+ consolidated_frame_inds = inference_state["consolidated_frame_inds"]
938
+ consolidated_frame_inds["cond_frame_outputs"].discard(frame_idx)
939
+ consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
940
+ # Remove the frame's conditioning output (possibly downgrading it to non-conditioning)
941
+ out = output_dict["cond_frame_outputs"].pop(frame_idx, None)
942
+ if out is not None:
943
+ # The frame is not a conditioning frame anymore since it's not receiving inputs,
944
+ # so we "downgrade" its output (if exists) to a non-conditioning frame output.
945
+ output_dict["non_cond_frame_outputs"][frame_idx] = out
946
+ inference_state["frames_already_tracked"].pop(frame_idx, None)
947
+ # Similarly, do it for the sliced output on each object.
948
+ for obj_idx2 in range(batch_size):
949
+ obj_output_dict = inference_state["output_dict_per_obj"][obj_idx2]
950
+ obj_out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None)
951
+ if obj_out is not None:
952
+ obj_output_dict["non_cond_frame_outputs"][frame_idx] = obj_out
953
+
954
+ # If all the conditioning frames have been removed, we also clear the tracking outputs
955
+ if len(output_dict["cond_frame_outputs"]) == 0:
956
+ self._reset_tracking_results(inference_state)
957
+
958
+ if not need_output:
959
+ return
960
+ # Finally, output updated masks per object (after removing the inputs above)
961
+ obj_ids = inference_state["obj_ids"]
962
+ is_cond = any(
963
+ frame_idx in obj_temp_output_dict["cond_frame_outputs"]
964
+ for obj_temp_output_dict in temp_output_dict_per_obj.values()
965
+ )
966
+ consolidated_out = self._consolidate_temp_output_across_obj(
967
+ inference_state,
968
+ frame_idx,
969
+ is_cond=is_cond,
970
+ run_mem_encoder=False,
971
+ consolidate_at_video_res=True,
972
+ )
973
+ _, video_res_masks = self._get_orig_video_res_output(
974
+ inference_state, consolidated_out["pred_masks_video_res"]
975
+ )
976
+ low_res_masks = None # not needed by the demo
977
+ return frame_idx, obj_ids, low_res_masks, video_res_masks
978
+
979
+ @torch.inference_mode()
980
+ def clear_all_points_in_video(self, inference_state):
981
+ """Remove all input points or mask in all frames throughout the video."""
982
+ self._reset_tracking_results(inference_state)
983
+ # Remove all object ids
984
+ inference_state["obj_id_to_idx"].clear()
985
+ inference_state["obj_idx_to_id"].clear()
986
+ inference_state["obj_ids"].clear()
987
+ inference_state["point_inputs_per_obj"].clear()
988
+ inference_state["mask_inputs_per_obj"].clear()
989
+ inference_state["output_dict_per_obj"].clear()
990
+ inference_state["temp_output_dict_per_obj"].clear()
991
+
992
+ def _reset_tracking_results(self, inference_state):
993
+ """Reset all tracking inputs and results across the videos."""
994
+ for v in inference_state["point_inputs_per_obj"].values():
995
+ v.clear()
996
+ for v in inference_state["mask_inputs_per_obj"].values():
997
+ v.clear()
998
+ for v in inference_state["output_dict_per_obj"].values():
999
+ v["cond_frame_outputs"].clear()
1000
+ v["non_cond_frame_outputs"].clear()
1001
+ for v in inference_state["temp_output_dict_per_obj"].values():
1002
+ v["cond_frame_outputs"].clear()
1003
+ v["non_cond_frame_outputs"].clear()
1004
+ inference_state["output_dict"]["cond_frame_outputs"].clear()
1005
+ inference_state["output_dict"]["non_cond_frame_outputs"].clear()
1006
+ inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear()
1007
+ inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear()
1008
+ inference_state["tracking_has_started"] = False
1009
+ inference_state["frames_already_tracked"].clear()
1010
+ inference_state["first_ann_frame_idx"] = None
1011
+
1012
+ def _get_image_feature(self, inference_state, frame_idx, batch_size):
1013
+ """Compute the image features on a given frame."""
1014
+ # Look up in the cache
1015
+ image, backbone_out = inference_state["cached_features"].get(
1016
+ frame_idx, (None, None)
1017
+ )
1018
+ if backbone_out is None:
1019
+ if self.backbone is None:
1020
+ raise RuntimeError(
1021
+ f"Image features for frame {frame_idx} are not cached. "
1022
+ "Please run inference on this frame first."
1023
+ )
1024
+ else:
1025
+ # Cache miss -- we will run inference on a single image
1026
+ image = inference_state["images"][frame_idx].cuda().float().unsqueeze(0)
1027
+ backbone_out = self.forward_image(image)
1028
+ # Cache the most recent frame's feature (for repeated interactions with
1029
+ # a frame; we can use an LRU cache for more frames in the future).
1030
+ inference_state["cached_features"] = {frame_idx: (image, backbone_out)}
1031
+ if "tracker_backbone_out" in backbone_out:
1032
+ backbone_out = backbone_out["tracker_backbone_out"] # get backbone output
1033
+
1034
+ # expand the features to have the same dimension as the number of objects
1035
+ expanded_image = image.expand(batch_size, -1, -1, -1)
1036
+ expanded_backbone_out = {
1037
+ "backbone_fpn": backbone_out["backbone_fpn"].copy(),
1038
+ "vision_pos_enc": backbone_out["vision_pos_enc"].copy(),
1039
+ }
1040
+ for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]):
1041
+ feat = feat.expand(batch_size, -1, -1, -1)
1042
+ expanded_backbone_out["backbone_fpn"][i] = feat
1043
+ for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]):
1044
+ pos = pos.expand(batch_size, -1, -1, -1)
1045
+ expanded_backbone_out["vision_pos_enc"][i] = pos
1046
+
1047
+ features = self._prepare_backbone_features(expanded_backbone_out)
1048
+ features = (expanded_image,) + features
1049
+ return features
1050
+
1051
+ def _run_single_frame_inference(
1052
+ self,
1053
+ inference_state,
1054
+ output_dict,
1055
+ frame_idx,
1056
+ batch_size,
1057
+ is_init_cond_frame,
1058
+ point_inputs,
1059
+ mask_inputs,
1060
+ reverse,
1061
+ run_mem_encoder,
1062
+ prev_sam_mask_logits=None,
1063
+ use_prev_mem_frame=True,
1064
+ ):
1065
+ """Run tracking on a single frame based on current inputs and previous memory."""
1066
+ # Retrieve correct image features
1067
+ (
1068
+ image,
1069
+ _,
1070
+ current_vision_feats,
1071
+ current_vision_pos_embeds,
1072
+ feat_sizes,
1073
+ ) = self._get_image_feature(inference_state, frame_idx, batch_size)
1074
+
1075
+ # point and mask should not appear as input simultaneously on the same frame
1076
+ assert point_inputs is None or mask_inputs is None
1077
+ current_out = self.track_step(
1078
+ frame_idx=frame_idx,
1079
+ is_init_cond_frame=is_init_cond_frame,
1080
+ current_vision_feats=current_vision_feats,
1081
+ current_vision_pos_embeds=current_vision_pos_embeds,
1082
+ feat_sizes=feat_sizes,
1083
+ image=image,
1084
+ point_inputs=point_inputs,
1085
+ mask_inputs=mask_inputs,
1086
+ output_dict=output_dict,
1087
+ num_frames=inference_state["num_frames"],
1088
+ track_in_reverse=reverse,
1089
+ run_mem_encoder=run_mem_encoder,
1090
+ prev_sam_mask_logits=prev_sam_mask_logits,
1091
+ use_prev_mem_frame=use_prev_mem_frame,
1092
+ )
1093
+
1094
+ # optionally offload the output to CPU memory to save GPU space
1095
+ storage_device = inference_state["storage_device"]
1096
+ maskmem_features = current_out["maskmem_features"]
1097
+ if maskmem_features is not None:
1098
+ maskmem_features = maskmem_features.to(torch.bfloat16)
1099
+ maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
1100
+ pred_masks_gpu = current_out["pred_masks"]
1101
+ pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
1102
+ # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
1103
+ maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
1104
+ # object pointer is a small tensor, so we always keep it on GPU memory for fast access
1105
+ obj_ptr = current_out["obj_ptr"]
1106
+ object_score_logits = current_out["object_score_logits"]
1107
+ # make a compact version of this frame's output to reduce the state size
1108
+ compact_current_out = {
1109
+ "maskmem_features": maskmem_features,
1110
+ "maskmem_pos_enc": maskmem_pos_enc,
1111
+ "pred_masks": pred_masks,
1112
+ "obj_ptr": obj_ptr,
1113
+ "object_score_logits": object_score_logits,
1114
+ }
1115
+ if self.use_memory_selection:
1116
+ compact_current_out["iou_score"] = current_out["iou_score"]
1117
+ compact_current_out["eff_iou_score"] = current_out["eff_iou_score"]
1118
+ return compact_current_out, pred_masks_gpu
1119
+
1120
+ def _run_memory_encoder(
1121
+ self,
1122
+ inference_state,
1123
+ frame_idx,
1124
+ batch_size,
1125
+ high_res_masks,
1126
+ object_score_logits,
1127
+ is_mask_from_pts,
1128
+ ):
1129
+ """
1130
+ Run the memory encoder on `high_res_masks`. This is usually after applying
1131
+ non-overlapping constraints to object scores. Since their scores changed, their
1132
+ memory also needs to be computed again with the memory encoder.
1133
+ """
1134
+ # Retrieve correct image features
1135
+ image, _, current_vision_feats, _, feat_sizes = self._get_image_feature(
1136
+ inference_state, frame_idx, batch_size
1137
+ )
1138
+ maskmem_features, maskmem_pos_enc = self._encode_new_memory(
1139
+ image=image,
1140
+ current_vision_feats=current_vision_feats,
1141
+ feat_sizes=feat_sizes,
1142
+ pred_masks_high_res=high_res_masks,
1143
+ object_score_logits=object_score_logits,
1144
+ is_mask_from_pts=is_mask_from_pts,
1145
+ )
1146
+
1147
+ # optionally offload the output to CPU memory to save GPU space
1148
+ storage_device = inference_state["storage_device"]
1149
+ maskmem_features = maskmem_features.to(torch.bfloat16)
1150
+ maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
1151
+ # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
1152
+ maskmem_pos_enc = self._get_maskmem_pos_enc(
1153
+ inference_state, {"maskmem_pos_enc": maskmem_pos_enc}
1154
+ )
1155
+ return maskmem_features, maskmem_pos_enc
1156
+
1157
+ def _get_maskmem_pos_enc(self, inference_state, current_out):
1158
+ """
1159
+ `maskmem_pos_enc` is the same across frames and objects, so we cache it as
1160
+ a constant in the inference session to reduce session storage size.
1161
+ """
1162
+ model_constants = inference_state["constants"]
1163
+ # "out_maskmem_pos_enc" should be either a list of tensors or None
1164
+ out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
1165
+ if out_maskmem_pos_enc is not None:
1166
+ if "maskmem_pos_enc" not in model_constants:
1167
+ assert isinstance(out_maskmem_pos_enc, list)
1168
+ # only take the slice for one object, since it's same across objects
1169
+ maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
1170
+ model_constants["maskmem_pos_enc"] = maskmem_pos_enc
1171
+ else:
1172
+ maskmem_pos_enc = model_constants["maskmem_pos_enc"]
1173
+ # expand the cached maskmem_pos_enc to the actual batch size
1174
+ batch_size = out_maskmem_pos_enc[0].size(0)
1175
+ expanded_maskmem_pos_enc = [
1176
+ x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc
1177
+ ]
1178
+ else:
1179
+ expanded_maskmem_pos_enc = None
1180
+ return expanded_maskmem_pos_enc
1181
+
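As a side note on the caching trick above: `Tensor.expand` returns a stride-0 view along the expanded dimension, so re-expanding the single cached copy to the current batch size allocates no new memory. A small illustrative sketch with assumed shapes (not taken from the repo):

import torch

pos_enc_batch = torch.randn(4, 64, 16, 16)        # identical across the 4 objects
cached = pos_enc_batch[0:1].clone()               # store one copy as a constant

batch_size = 4
expanded = cached.expand(batch_size, -1, -1, -1)  # view, no data copy
assert expanded.shape == (4, 64, 16, 16)
assert expanded.stride(0) == 0                    # batch dim is broadcast, not materialized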
1182
+ @torch.inference_mode()
1183
+ def remove_object(self, inference_state, obj_id, strict=False, need_output=True):
1184
+ """
1185
+ Remove an object id from the tracking state. If strict is True, we check whether
1186
+ the object id actually exists and raise an error if it doesn't exist.
1187
+ """
1188
+ old_obj_idx_to_rm = inference_state["obj_id_to_idx"].get(obj_id, None)
1189
+ updated_frames = []
1190
+ # Check whether this object_id to remove actually exists and possibly raise an error.
1191
+ if old_obj_idx_to_rm is None:
1192
+ if not strict:
1193
+ return inference_state["obj_ids"], updated_frames
1194
+ raise RuntimeError(
1195
+ f"Cannot remove object id {obj_id} as it doesn't exist. "
1196
+ f"All existing object ids: {inference_state['obj_ids']}."
1197
+ )
1198
+
1199
+ # If this is the only remaining object id, we simply reset the state.
1200
+ if len(inference_state["obj_id_to_idx"]) == 1:
1201
+ self.clear_all_points_in_video(inference_state)
1202
+ return inference_state["obj_ids"], updated_frames
1203
+
1204
+ # There are still remaining objects after removing this object id. In this case,
1205
+ # we need to delete the object storage from inference state tensors.
1206
+ # Step 0: clear the input on those frames where this object id has point or mask input
1207
+ # (note that this step is required as it might downgrade conditioning frames to
1208
+ # non-conditioning ones)
1209
+ obj_input_frames_inds = set()
1210
+ obj_input_frames_inds.update(
1211
+ inference_state["point_inputs_per_obj"][old_obj_idx_to_rm]
1212
+ )
1213
+ obj_input_frames_inds.update(
1214
+ inference_state["mask_inputs_per_obj"][old_obj_idx_to_rm]
1215
+ )
1216
+ for frame_idx in obj_input_frames_inds:
1217
+ self.clear_all_points_in_frame(
1218
+ inference_state, frame_idx, obj_id, need_output=False
1219
+ )
1220
+
1221
+ # Step 1: Update the object id mapping (note that it must be done after Step 0,
1222
+ # since Step 0 still requires the old object id mappings in inference_state)
1223
+ old_obj_ids = inference_state["obj_ids"]
1224
+ old_obj_inds = list(range(len(old_obj_ids)))
1225
+ remain_old_obj_inds = old_obj_inds.copy()
1226
+ remain_old_obj_inds.remove(old_obj_idx_to_rm)
1227
+ new_obj_ids = [old_obj_ids[old_idx] for old_idx in remain_old_obj_inds]
1228
+ new_obj_inds = list(range(len(new_obj_ids)))
1229
+ # build new mappings
1230
+ old_idx_to_new_idx = dict(zip(remain_old_obj_inds, new_obj_inds))
1231
+ inference_state["obj_id_to_idx"] = dict(zip(new_obj_ids, new_obj_inds))
1232
+ inference_state["obj_idx_to_id"] = dict(zip(new_obj_inds, new_obj_ids))
1233
+ inference_state["obj_ids"] = new_obj_ids
1234
+
1235
+ # Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys.
1236
+ # (note that "consolidated_frame_inds" doesn't need to be updated in this step as
1237
+ # it's already handled in Step 0)
1238
+ def _map_keys(container):
1239
+ new_kvs = []
1240
+ for k in old_obj_inds:
1241
+ v = container.pop(k)
1242
+ if k in old_idx_to_new_idx:
1243
+ new_kvs.append((old_idx_to_new_idx[k], v))
1244
+ container.update(new_kvs)
1245
+
1246
+ _map_keys(inference_state["point_inputs_per_obj"])
1247
+ _map_keys(inference_state["mask_inputs_per_obj"])
1248
+ _map_keys(inference_state["output_dict_per_obj"])
1249
+ _map_keys(inference_state["temp_output_dict_per_obj"])
1250
+
1251
+ # Step 3: For packed tensor storage, we index the remaining ids and rebuild the per-object slices.
1252
+ def _slice_state(output_dict, storage_key):
1253
+ for frame_idx, out in output_dict[storage_key].items():
1254
+ out["maskmem_features"] = out["maskmem_features"][remain_old_obj_inds]
1255
+ out["maskmem_pos_enc"] = [
1256
+ x[remain_old_obj_inds] for x in out["maskmem_pos_enc"]
1257
+ ]
1258
+ # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
1259
+ out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(inference_state, out)
1260
+ out["pred_masks"] = out["pred_masks"][remain_old_obj_inds]
1261
+ out["obj_ptr"] = out["obj_ptr"][remain_old_obj_inds]
1262
+ out["object_score_logits"] = out["object_score_logits"][
1263
+ remain_old_obj_inds
1264
+ ]
1265
+ if self.use_memory_selection:
1266
+ out["iou_score"] = out["iou_score"][remain_old_obj_inds]
1267
+ out["eff_iou_score"] = self.cal_mem_score(
1268
+ out["object_score_logits"], out["iou_score"]
1269
+ ) # recalculate the memory frame score
1270
+ # also update the per-object slices
1271
+ self._add_output_per_object(
1272
+ inference_state, frame_idx, out, storage_key
1273
+ )
1274
+
1275
+ _slice_state(inference_state["output_dict"], "cond_frame_outputs")
1276
+ _slice_state(inference_state["output_dict"], "non_cond_frame_outputs")
1277
+
1278
+ # Step 4: Further collect the outputs on those frames in `obj_input_frames_inds`, which
1279
+ # could show an updated mask for objects previously occluded by the object being removed
1280
+ if need_output:
1281
+ temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
1282
+ for frame_idx in obj_input_frames_inds:
1283
+ is_cond = any(
1284
+ frame_idx in obj_temp_output_dict["cond_frame_outputs"]
1285
+ for obj_temp_output_dict in temp_output_dict_per_obj.values()
1286
+ )
1287
+ consolidated_out = self._consolidate_temp_output_across_obj(
1288
+ inference_state,
1289
+ frame_idx,
1290
+ is_cond=is_cond,
1291
+ run_mem_encoder=False,
1292
+ consolidate_at_video_res=True,
1293
+ )
1294
+ _, video_res_masks = self._get_orig_video_res_output(
1295
+ inference_state, consolidated_out["pred_masks_video_res"]
1296
+ )
1297
+ updated_frames.append((frame_idx, video_res_masks))
1298
+
1299
+ return inference_state["obj_ids"], updated_frames
1300
+
1301
+ def _clear_non_cond_mem_around_input(self, inference_state, frame_idx):
1302
+ """
1303
+ Remove the non-conditioning memory around the input frame. When users provide
1304
+ correction clicks, the surrounding frames' non-conditioning memories can still
1305
+ contain outdated object appearance information and could confuse the model.
1306
+
1307
+ This method clears those non-conditioning memories surrounding the interacted
1308
+ frame to avoid giving the model both old and new information about the object.
1309
+ """
1310
+ r = self.memory_temporal_stride_for_eval
1311
+ frame_idx_begin = frame_idx - r * self.num_maskmem
1312
+ frame_idx_end = frame_idx + r * self.num_maskmem
1313
+ batch_size = self._get_obj_num(inference_state)
1314
+ for obj_idx in range(batch_size):
1315
+ obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
1316
+ non_cond_frame_outputs = obj_output_dict["non_cond_frame_outputs"]
1317
+ for t in range(frame_idx_begin, frame_idx_end + 1):
1318
+ non_cond_frame_outputs.pop(t, None)
1319
+
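For intuition, the window cleared above spans `2 * r * num_maskmem + 1` frames centered on the interacted frame. A toy calculation with hypothetical values (the real `memory_temporal_stride_for_eval` and `num_maskmem` come from the model config):

r, num_maskmem, frame_idx = 1, 7, 50      # hypothetical values for illustration
frame_idx_begin = frame_idx - r * num_maskmem
frame_idx_end = frame_idx + r * num_maskmem
cleared = range(frame_idx_begin, frame_idx_end + 1)
print(frame_idx_begin, frame_idx_end, len(cleared))   # 43 57 15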
1320
+ def _suppress_shrinked_masks(
1321
+ self, pred_masks, new_pred_masks, shrink_threshold=0.3
1322
+ ):
1323
+ area_before = (pred_masks > 0).sum(dim=(-1, -2))
1324
+ area_after = (new_pred_masks > 0).sum(dim=(-1, -2))
1325
+ area_before = torch.clamp(area_before, min=1.0)
1326
+ area_ratio = area_after / area_before
1327
+ keep = area_ratio >= shrink_threshold
1328
+ keep_mask = keep[..., None, None].expand_as(pred_masks)
1329
+ pred_masks_after = torch.where(
1330
+ keep_mask, pred_masks, torch.clamp(pred_masks, max=-10.0)
1331
+ )
1332
+ return pred_masks_after
1333
+
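To make the shrink test above concrete, here is a toy example (standalone, with made-up logits) where one mask keeps its area after the pixel-wise constraint and the other collapses below the default `shrink_threshold` of 0.3:

import torch

before = torch.full((2, 1, 4, 4), 5.0)    # two masks, all 16 pixels positive
after = before.clone()
after[1, :, :, 1:] = -5.0                 # second mask shrinks to 4 of 16 pixels

area_before = (before > 0).sum(dim=(-1, -2)).clamp(min=1).float()
area_after = (after > 0).sum(dim=(-1, -2)).float()
keep = (area_after / area_before) >= 0.3  # shrink_threshold
print(keep.squeeze(-1))                   # tensor([ True, False])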
1334
+ def _suppress_object_pw_area_shrinkage(self, pred_masks):
1335
+ """
1336
+ This function suppresses masks that shrink in area after applying pixel-wise non-overlapping constraints.
1337
+ Note that the final output can still be overlapping.
1338
+ """
1339
+ # Apply pixel-wise non-overlapping constraint based on mask scores
1340
+ pixel_level_non_overlapping_masks = super()._apply_non_overlapping_constraints(
1341
+ pred_masks
1342
+ )
1343
+ # Fully suppress masks with high shrinkage (probably noisy) based on the pixel-wise non-overlapping constraints
1344
+ # NOTE: The output of this function can be a no-op if none of the masks shrank by a large factor.
1345
+ pred_masks = self._suppress_shrinked_masks(
1346
+ pred_masks, pixel_level_non_overlapping_masks
1347
+ )
1348
+ return pred_masks
1349
+
1350
+ def _apply_object_wise_non_overlapping_constraints(
1351
+ self, pred_masks, obj_scores, background_value=-10.0
1352
+ ):
1353
+ """
1354
+ Applies non-overlapping constraints object-wise (i.e. only one object can claim the overlapping region)
1355
+ """
1356
+ # Replace pixel scores with object scores
1357
+ pred_masks_single_score = torch.where(
1358
+ pred_masks > 0, obj_scores[..., None, None], background_value
1359
+ )
1360
+ # Apply pixel-wise non-overlapping constraint based on mask scores
1361
+ pixel_level_non_overlapping_masks = super()._apply_non_overlapping_constraints(
1362
+ pred_masks_single_score
1363
+ )
1364
+ # Replace object scores with pixel scores. Note that now only one object can claim the overlapping region
1365
+ pred_masks = torch.where(
1366
+ pixel_level_non_overlapping_masks > 0,
1367
+ pred_masks,
1368
+ torch.clamp(pred_masks, max=background_value),
1369
+ )
1370
+ return pred_masks
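A self-contained restatement of the object-wise constraint just above, using a plain argmax in place of the parent class's `_apply_non_overlapping_constraints` (so this is an illustrative approximation, not the model's exact code path): in a contested region, only the object with the highest object score keeps positive logits.

import torch

def object_wise_non_overlap(pred_masks, obj_scores, background_value=-10.0):
    # pred_masks: (N, 1, H, W) logits; obj_scores: (N, 1) per-object scores
    scored = torch.where(
        pred_masks > 0,
        obj_scores[..., None, None],
        torch.full_like(pred_masks, background_value),
    )
    winner = scored.argmax(dim=0, keepdim=True)              # which object claims each pixel
    obj_idx = torch.arange(pred_masks.size(0))[:, None, None, None]
    keep = winner == obj_idx
    return torch.where(keep, pred_masks, pred_masks.clamp(max=background_value))

masks = torch.zeros(2, 1, 4, 4)
masks[:, :, 1:3, 1:3] = 5.0                                  # fully overlapping 2x2 blobs
scores = torch.tensor([[0.9], [0.2]])
out = object_wise_non_overlap(masks, scores)
print((out[0] > 0).sum().item(), (out[1] > 0).sum().item())  # 4 0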
detect_tools/sam3/sam3/model/sam3_video_base.py ADDED
@@ -0,0 +1,1767 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ import datetime
4
+ import logging
5
+ import math
6
+ import os
7
+ from collections import defaultdict
8
+ from copy import deepcopy
9
+ from enum import Enum
10
+ from typing import Any, Dict, List, Set
11
+
12
+ import numpy as np
13
+ import numpy.typing as npt
14
+ import torch
15
+ import torch.distributed as dist
16
+ import torch.nn.functional as F
17
+
18
+ from sam3 import perflib
19
+ from sam3.logger import get_logger
20
+ from sam3.model.box_ops import fast_diag_box_iou
21
+ from sam3.model.data_misc import BatchedDatapoint
22
+ from sam3.model.sam3_tracker_utils import fill_holes_in_mask_scores, mask_to_box
23
+ from sam3.perflib.masks_ops import mask_iou
24
+ from sam3.train.masks_ops import rle_encode
25
+ from torch import nn, Tensor
26
+
27
+ logger = get_logger(__name__)
28
+
29
+
30
+ class MaskletConfirmationStatus(Enum):
31
+ UNCONFIRMED = 1 # newly added masklet, not confirmed by any detection yet
32
+ CONFIRMED = 2 # confirmed by at least one detection
33
+
34
+
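A toy restatement of how this status is intended to flip (per the `masklet_confirmation_consecutive_det_thresh` parameter documented below); the helper name and counter handling here are illustrative assumptions, since the actual logic lives in `update_masklet_confirmation_status`, whose body is not shown in this hunk:

from enum import Enum

class Status(Enum):
    UNCONFIRMED = 1
    CONFIRMED = 2

def update_status(status, consecutive_matches, matched_this_frame, thresh=3):
    # illustrative: promote a masklet once it has been matched by a detection
    # on `thresh` consecutive frames; any miss resets the streak
    consecutive_matches = consecutive_matches + 1 if matched_this_frame else 0
    if status is Status.UNCONFIRMED and consecutive_matches >= thresh:
        status = Status.CONFIRMED
    return status, consecutive_matches

s, n = Status.UNCONFIRMED, 0
for matched in [True, True, True]:
    s, n = update_status(s, n, matched)
print(s)   # Status.CONFIRMED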
35
+ class Sam3VideoBase(nn.Module):
36
+ def __init__(
37
+ self,
38
+ detector: nn.Module,
39
+ tracker: nn.Module,
40
+ # prob threshold for detection outputs -- only keep detections above this threshold
41
+ # enters NMS and det-to-track matching
42
+ score_threshold_detection=0.5,
43
+ # IoU threshold for detection NMS
44
+ det_nms_thresh=0.0,
45
+ # IoU threshold for det-to-track matching -- a detection is considered "matched" to a tracklet it
46
+ # overlaps with a tracklet above this threshold -- it is often a loose threshold like 0.1
47
+ assoc_iou_thresh=0.5,
48
+ # IoU threshold for det-to-track matching, which is used to determine whether a masklet is "unmatched"
49
+ # by any detections -- it is often a stricter threshold like 0.5
50
+ trk_assoc_iou_thresh=0.5,
51
+ # prob threshold for a detection to be added as a new object
52
+ new_det_thresh=0.0,
53
+ # hotstart parameters: we hold off the outputs for `hotstart_delay` frames and
54
+ # 1) remove those tracklets unmatched by any detections based on `hotstart_unmatch_thresh`
55
+ # 2) remove those tracklets overlapping with one another based on `hotstart_dup_thresh`
56
+ hotstart_delay=0,
57
+ hotstart_unmatch_thresh=3,
58
+ hotstart_dup_thresh=3,
59
+ # Whether to suppress masks only within hotstart. If False, we can suppress masks even if they start before hotstart period.
60
+ suppress_unmatched_only_within_hotstart=True,
61
+ init_trk_keep_alive=0,
62
+ max_trk_keep_alive=8,
63
+ min_trk_keep_alive=-4,
64
+ # Threshold for suppressing overlapping objects based on recent occlusion
65
+ suppress_overlapping_based_on_recent_occlusion_threshold=0.0,
66
+ decrease_trk_keep_alive_for_empty_masklets=False,
67
+ o2o_matching_masklets_enable=False, # Enable hungarian matching to match existing masklets
68
+ suppress_det_close_to_boundary=False,
69
+ fill_hole_area=16,
70
+ # The maximum number of objects (masklets) to track across all GPUs (for no limit, set it to -1)
71
+ max_num_objects=-1,
72
+ recondition_every_nth_frame=-1,
73
+ # masket confirmation status (to suppress unconfirmed masklets)
74
+ masklet_confirmation_enable=False,
75
+ # a masklet is confirmed after being consecutively detected and matched for
76
+ # `masklet_confirmation_consecutive_det_thresh`
77
+ masklet_confirmation_consecutive_det_thresh=3,
78
+ # bbox heuristic parameters
79
+ reconstruction_bbox_iou_thresh=0.0,
80
+ reconstruction_bbox_det_score=0.0,
81
+ ):
82
+ super().__init__()
83
+ self.detector = detector
84
+ self.tracker = tracker
85
+ self.score_threshold_detection = score_threshold_detection
86
+ self.det_nms_thresh = det_nms_thresh
87
+ self.assoc_iou_thresh = assoc_iou_thresh
88
+ self.trk_assoc_iou_thresh = trk_assoc_iou_thresh
89
+ self.new_det_thresh = new_det_thresh
90
+
91
+ # hotstart parameters
92
+ if hotstart_delay > 0:
93
+ assert hotstart_unmatch_thresh <= hotstart_delay
94
+ assert hotstart_dup_thresh <= hotstart_delay
95
+ self.hotstart_delay = hotstart_delay
96
+ self.hotstart_unmatch_thresh = hotstart_unmatch_thresh
97
+ self.hotstart_dup_thresh = hotstart_dup_thresh
98
+ self.suppress_unmatched_only_within_hotstart = (
99
+ suppress_unmatched_only_within_hotstart
100
+ )
101
+ self.init_trk_keep_alive = init_trk_keep_alive
102
+ self.max_trk_keep_alive = max_trk_keep_alive
103
+ self.min_trk_keep_alive = min_trk_keep_alive
104
+ self.suppress_overlapping_based_on_recent_occlusion_threshold = (
105
+ suppress_overlapping_based_on_recent_occlusion_threshold
106
+ )
107
+ self.suppress_det_close_to_boundary = suppress_det_close_to_boundary
108
+ self.decrease_trk_keep_alive_for_empty_masklets = (
109
+ decrease_trk_keep_alive_for_empty_masklets
110
+ )
111
+ self.o2o_matching_masklets_enable = o2o_matching_masklets_enable
112
+ self.fill_hole_area = fill_hole_area
113
+ self.eval()
114
+ self.rank = int(os.getenv("RANK", "0"))
115
+ self.world_size = int(os.getenv("WORLD_SIZE", "1"))
116
+ self._dist_pg_cpu = None # CPU process group (lazy-initialized on first use)
117
+
118
+ # the maximum object number
119
+ if max_num_objects > 0:
120
+ num_obj_for_compile = math.ceil(max_num_objects / self.world_size)
121
+ else:
122
+ max_num_objects = 10000 # no limit
123
+ num_obj_for_compile = 16
124
+ logger.info(f"setting {max_num_objects=} and {num_obj_for_compile=}")
125
+ self.max_num_objects = max_num_objects
126
+ self.num_obj_for_compile = num_obj_for_compile
127
+ self.recondition_every_nth_frame = recondition_every_nth_frame
128
+ self.masklet_confirmation_enable = masklet_confirmation_enable
129
+ self.masklet_confirmation_consecutive_det_thresh = (
130
+ masklet_confirmation_consecutive_det_thresh
131
+ )
132
+ self.reconstruction_bbox_iou_thresh = reconstruction_bbox_iou_thresh
133
+ self.reconstruction_bbox_det_score = reconstruction_bbox_det_score
134
+
135
+ @property
136
+ def device(self):
137
+ self._device = getattr(self, "_device", None) or next(self.parameters()).device
138
+ return self._device
139
+
140
+ def _init_dist_pg_cpu(self):
141
+ # a short 3-min timeout to quickly detect any synchronization failures
142
+ timeout_sec = int(os.getenv("SAM3_COLLECTIVE_OP_TIMEOUT_SEC", "180"))
143
+ timeout = datetime.timedelta(seconds=timeout_sec)
144
+ self._dist_pg_cpu = dist.new_group(backend="gloo", timeout=timeout)
145
+
146
+ def broadcast_python_obj_cpu(self, python_obj_list, src):
147
+ if self._dist_pg_cpu is None:
148
+ self._init_dist_pg_cpu()
149
+ dist.broadcast_object_list(python_obj_list, src=src, group=self._dist_pg_cpu)
150
+
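For readers unfamiliar with the pattern: `torch.distributed.broadcast_object_list` fills a pre-sized list in place on every rank, which is how the update plan is shipped from rank 0 later in this file. A minimal usage sketch (assumes the process group is already initialized, e.g. under `torchrun`; the helper name is made up):

import torch.distributed as dist

def broadcast_plan(plan_items, rank, src=0, group=None):
    # non-source ranks must provide a placeholder list of the same length;
    # broadcast_object_list then overwrites it in place with rank `src`'s objects
    if rank != src:
        plan_items = [None] * len(plan_items)
    dist.broadcast_object_list(plan_items, src=src, group=group)
    return plan_items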
151
+ def _det_track_one_frame(
152
+ self,
153
+ frame_idx: int,
154
+ num_frames: int,
155
+ reverse: bool,
156
+ input_batch: BatchedDatapoint,
157
+ geometric_prompt: Any,
158
+ tracker_states_local: List[Any],
159
+ tracker_metadata_prev: Dict[str, Any],
160
+ feature_cache: Dict,
161
+ orig_vid_height: int,
162
+ orig_vid_width: int,
163
+ is_image_only: bool = False,
164
+ allow_new_detections: bool = True,
165
+ ):
166
+ """
167
+ This function handles one-step inference for the DenseTracking model in an SPMD manner.
168
+ At a high level, all GPUs execute the same function calls as if inference were done on a single GPU,
169
+ while under the hood, some function calls involve distributed computation based on sharded
170
+ SAM2 states.
171
+
172
+ - `input_batch` contains the images and other inputs for the entire video; it should be identical across GPUs
173
+ - `tracker_states_local` holds the local masklet information in this GPU shard
174
+ - `tracker_metadata_prev` manages the metadata for SAM2 objects, such as which masklet is held on which GPU;
175
+ it contains both global and local masklet information
176
+ """
177
+
178
+ # Step 1: run backbone and detector in a distributed manner -- this is done via Sam3ImageOnVideoMultiGPU,
179
+ # a MultiGPU model (assigned to `self.detector`) that shards frames in a round-robin manner.
180
+ # It returns a "det_out" dict for `frame_idx` and fills SAM2 backbone features for `frame_idx`
181
+ # into `feature_cache`. Despite its distributed inference under the hood, the results would be
182
+ # the same as if it is running backbone and detector for every frame on a single GPU.
183
+ det_out = self.run_backbone_and_detection(
184
+ frame_idx=frame_idx,
185
+ num_frames=num_frames,
186
+ reverse=reverse,
187
+ input_batch=input_batch,
188
+ geometric_prompt=geometric_prompt,
189
+ feature_cache=feature_cache,
190
+ allow_new_detections=allow_new_detections,
191
+ )
192
+
193
+ # Step 2: each GPU propagates its local SAM2 states to get the SAM2 prediction masks.
194
+ # the returned `tracker_low_res_masks_global` contains the concatenated masklet predictions
195
+ # gathered from all GPUs (as if they are propagated on a single GPU). Note that this step only
196
+ # runs the SAM2 propagation step, but doesn't encode new memory for the predicted masks;
197
+ # we defer memory encoding to `run_tracker_update_execution_phase` after resolving all heuristics.
198
+ if tracker_metadata_prev == {}:
199
+ # initialize masklet metadata if it's uninitialized (empty dict)
200
+ tracker_metadata_prev.update(self._initialize_metadata())
201
+ tracker_low_res_masks_global, tracker_obj_scores_global = (
202
+ self.run_tracker_propagation(
203
+ frame_idx=frame_idx,
204
+ num_frames=num_frames,
205
+ reverse=reverse,
206
+ tracker_states_local=tracker_states_local,
207
+ tracker_metadata_prev=tracker_metadata_prev,
208
+ )
209
+ )
210
+
211
+ # Step 3: based on detection outputs and the propagated SAM2 prediction masks, we make plans
212
+ # for SAM2 masklet updates (i.e. which objects to add and remove, how to load-balance them, etc).
213
+ # We also run SAM2 memory encoder globally in this step to resolve non-overlapping constraints.
214
+ # **This step should involve all the heuristics needed for any updates.** Most of the update
215
+ # planning will be done on the master rank (GPU 0) and the resulting plan `tracker_update_plan` is
216
+ # broadcasted to other GPUs (to be executed in a distributed manner). This step also generates the
217
+ # new masklet metadata `tracker_metadata_new` (based on its previous version `tracker_metadata_prev`).
218
+ tracker_update_plan, tracker_metadata_new = (
219
+ self.run_tracker_update_planning_phase(
220
+ frame_idx=frame_idx,
221
+ num_frames=num_frames,
222
+ reverse=reverse,
223
+ det_out=det_out,
224
+ tracker_low_res_masks_global=tracker_low_res_masks_global,
225
+ tracker_obj_scores_global=tracker_obj_scores_global,
226
+ tracker_metadata_prev=tracker_metadata_prev,
227
+ tracker_states_local=tracker_states_local,
228
+ is_image_only=is_image_only,
229
+ )
230
+ )
231
+
232
+ # Get reconditioning info from the update plan
233
+ reconditioned_obj_ids = tracker_update_plan.get("reconditioned_obj_ids", set())
234
+ det_to_matched_trk_obj_ids = tracker_update_plan.get(
235
+ "det_to_matched_trk_obj_ids", {}
236
+ )
237
+
238
+ # Step 4: based on `tracker_update_plan`, each GPU executes the update w.r.t. its local SAM2 inference states
239
+ tracker_states_local_new = self.run_tracker_update_execution_phase(
240
+ frame_idx=frame_idx,
241
+ num_frames=num_frames,
242
+ reverse=reverse,
243
+ det_out=det_out,
244
+ tracker_states_local=tracker_states_local,
245
+ tracker_update_plan=tracker_update_plan,
246
+ orig_vid_height=orig_vid_height,
247
+ orig_vid_width=orig_vid_width,
248
+ feature_cache=feature_cache,
249
+ )
250
+
251
+ # Step 5: finally, build the outputs for this frame (it only needs to be done on GPU 0 since
252
+ # only GPU 0 will send outputs to the server).
253
+ if self.rank == 0:
254
+ obj_id_to_mask = self.build_outputs(
255
+ frame_idx=frame_idx,
256
+ num_frames=num_frames,
257
+ reverse=reverse,
258
+ det_out=det_out,
259
+ tracker_low_res_masks_global=tracker_low_res_masks_global,
260
+ tracker_obj_scores_global=tracker_obj_scores_global,
261
+ tracker_metadata_prev=tracker_metadata_prev,
262
+ tracker_update_plan=tracker_update_plan,
263
+ orig_vid_height=orig_vid_height,
264
+ orig_vid_width=orig_vid_width,
265
+ reconditioned_obj_ids=reconditioned_obj_ids,
266
+ det_to_matched_trk_obj_ids=det_to_matched_trk_obj_ids,
267
+ )
268
+ obj_id_to_score = tracker_metadata_new["obj_id_to_score"]
269
+ else:
270
+ obj_id_to_mask, obj_id_to_score = {}, {} # dummy outputs on other GPUs
271
+ # a few statistics for the current frame as a part of the output
272
+ frame_stats = {
273
+ "num_obj_tracked": np.sum(tracker_metadata_new["num_obj_per_gpu"]),
274
+ "num_obj_dropped": tracker_update_plan["num_obj_dropped_due_to_limit"],
275
+ }
276
+ # add tracker scores to the metadata; this fires on every frame except the first
277
+ if tracker_obj_scores_global.shape[0] > 0:
278
+ # Convert tracker_obj_scores_global to sigmoid scores before updating
279
+ tracker_obj_scores_global = tracker_obj_scores_global.sigmoid().tolist()
280
+ tracker_obj_ids = tracker_metadata_prev["obj_ids_all_gpu"]
281
+ tracker_metadata_new["obj_id_to_tracker_score_frame_wise"][
282
+ frame_idx
283
+ ].update(dict(zip(tracker_obj_ids, tracker_obj_scores_global)))
284
+ return (
285
+ obj_id_to_mask, # a dict: obj_id --> output mask
286
+ obj_id_to_score, # a dict: obj_id --> output score (prob)
287
+ tracker_states_local_new,
288
+ tracker_metadata_new,
289
+ frame_stats,
290
+ tracker_obj_scores_global,  # per-object tracker scores for this frame
291
+ )
292
+
293
+ def _suppress_detections_close_to_boundary(self, boxes, margin=0.025):
294
+ """
295
+ Suppress detections too close to image edges (for normalized boxes).
296
+
297
+ boxes: (N, 4) in xyxy format, normalized [0,1]
298
+ margin: fraction of image
299
+ """
300
+ x_min, y_min, x_max, y_max = boxes.unbind(-1)
301
+ x_c = (x_min + x_max) / 2
302
+ y_c = (y_min + y_max) / 2
303
+ keep = (
304
+ (x_c > margin)
305
+ & (x_c < 1.0 - margin)
306
+ & (y_c > margin)
307
+ & (y_c < 1.0 - margin)
308
+ )
309
+
310
+ return keep
311
+
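A quick toy check of the boundary filter above (standalone, with made-up normalized boxes): a centered box is kept, while a box whose center falls inside the 2.5% margin of the left edge is dropped.

import torch

boxes = torch.tensor([[0.40, 0.40, 0.60, 0.60],   # centered -> keep
                      [0.00, 0.45, 0.03, 0.55]])  # hugging the left edge -> drop
margin = 0.025
x_c = (boxes[:, 0] + boxes[:, 2]) / 2
y_c = (boxes[:, 1] + boxes[:, 3]) / 2
keep = (x_c > margin) & (x_c < 1 - margin) & (y_c > margin) & (y_c < 1 - margin)
print(keep)   # tensor([ True, False])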
312
+ def run_backbone_and_detection(
313
+ self,
314
+ frame_idx: int,
315
+ num_frames: int,
316
+ input_batch: BatchedDatapoint,
317
+ geometric_prompt: Any,
318
+ feature_cache: Dict,
319
+ reverse: bool,
320
+ allow_new_detections: bool,
321
+ ):
322
+ # Step 1: if text feature is not cached in `feature_cache`, compute and cache it
323
+ text_batch_key = tuple(input_batch.find_text_batch)
324
+ if "text" not in feature_cache or text_batch_key not in feature_cache["text"]:
325
+ text_outputs = self.detector.backbone.forward_text(
326
+ input_batch.find_text_batch, device=self.device
327
+ )
328
+ # note: we only cache the text feature of the most recent prompt
329
+ feature_cache["text"] = {text_batch_key: text_outputs}
330
+ else:
331
+ text_outputs = feature_cache["text"][text_batch_key]
332
+
333
+ # Step 2: run backbone, detector, and post-processing with NMS
334
+ if "multigpu_buffer" not in feature_cache:
335
+ # "multigpu_buffer" is a buffer cache used by `self.detector` and it needs
336
+ # to be passed to `forward_video_grounding_multigpu` for every call
337
+ feature_cache["multigpu_buffer"] = {}
338
+
339
+ # Extract max_frame_num_to_track from feature_cache if available
340
+ tracking_bounds = feature_cache.get("tracking_bounds", {})
341
+ max_frame_num_to_track = tracking_bounds.get("max_frame_num_to_track")
342
+ start_frame_idx = tracking_bounds.get("propagate_in_video_start_frame_idx")
343
+
344
+ sam3_image_out, _ = self.detector.forward_video_grounding_multigpu(
345
+ backbone_out={
346
+ "img_batch_all_stages": input_batch.img_batch,
347
+ **text_outputs,
348
+ },
349
+ find_inputs=input_batch.find_inputs,
350
+ geometric_prompt=geometric_prompt,
351
+ frame_idx=frame_idx,
352
+ num_frames=num_frames,
353
+ multigpu_buffer=feature_cache["multigpu_buffer"],
354
+ track_in_reverse=reverse,
355
+ # also get the SAM2 backbone features
356
+ return_tracker_backbone_feats=True,
357
+ # run NMS as a part of distributed computation
358
+ run_nms=self.det_nms_thresh > 0.0,
359
+ nms_prob_thresh=self.score_threshold_detection,
360
+ nms_iou_thresh=self.det_nms_thresh,
361
+ # pass max_frame_num_to_track to respect tracking limits
362
+ max_frame_num_to_track=max_frame_num_to_track,
363
+ propagate_in_video_start_frame_idx=start_frame_idx,
364
+ )
365
+ # note: detections in `sam3_image_out` has already gone through NMS
366
+ pred_probs = sam3_image_out["pred_logits"].squeeze(-1).sigmoid()
367
+ if not allow_new_detections:
368
+ pred_probs = pred_probs - 1e8 # make sure no detections are kept
369
+ pred_boxes_xyxy = sam3_image_out["pred_boxes_xyxy"]
370
+ pred_masks = sam3_image_out["pred_masks"]
371
+ # get the positive detection outputs above threshold
372
+ pos_pred_idx = torch.where(pred_probs > self.score_threshold_detection)
373
+ det_out = {
374
+ "bbox": pred_boxes_xyxy[pos_pred_idx[0], pos_pred_idx[1]],
375
+ "mask": pred_masks[pos_pred_idx[0], pos_pred_idx[1]],
376
+ "scores": pred_probs[pos_pred_idx[0], pos_pred_idx[1]],
377
+ }
378
+
379
+ # Step 3: build SAM2 backbone features and store them in `feature_cache`
380
+ backbone_cache = {}
381
+ sam_mask_decoder = self.tracker.sam_mask_decoder
382
+ tracker_backbone_fpn = [
383
+ sam_mask_decoder.conv_s0(sam3_image_out["tracker_backbone_fpn_0"]),
384
+ sam_mask_decoder.conv_s1(sam3_image_out["tracker_backbone_fpn_1"]),
385
+ sam3_image_out["tracker_backbone_fpn_2"], # fpn_2 doesn't need conv
386
+ ]
387
+ tracker_backbone_out = {
388
+ "vision_features": tracker_backbone_fpn[-1], # top-level feature
389
+ "vision_pos_enc": sam3_image_out["tracker_backbone_pos_enc"],
390
+ "backbone_fpn": tracker_backbone_fpn,
391
+ }
392
+ backbone_cache["tracker_backbone_out"] = tracker_backbone_out
393
+ feature_cache[frame_idx] = (
394
+ input_batch.img_batch[frame_idx],
395
+ backbone_cache,
396
+ )
397
+ # remove from `feature_cache` old features to save GPU memory
398
+ feature_cache.pop(frame_idx - 1 if not reverse else frame_idx + 1, None)
399
+ return det_out
400
+
401
+ def run_tracker_propagation(
402
+ self,
403
+ frame_idx: int,
404
+ num_frames: int,
405
+ reverse: bool,
406
+ tracker_states_local: List[Any],
407
+ tracker_metadata_prev: Dict[str, npt.NDArray],
408
+ ):
409
+ # Step 1: propagate the local SAM2 states to get the current frame's prediction
410
+ # `low_res_masks_local` of the existing masklets on this GPU
411
+ # - obj_ids_local: List[int] -- list of object IDs
412
+ # - low_res_masks_local: Tensor -- (num_local_obj, H_mask, W_mask)
413
+ obj_ids_local, low_res_masks_local, obj_scores_local = (
414
+ self._propogate_tracker_one_frame_local_gpu(
415
+ tracker_states_local, frame_idx=frame_idx, reverse=reverse
416
+ )
417
+ )
418
+
419
+ assert np.all(
420
+ obj_ids_local == tracker_metadata_prev["obj_ids_per_gpu"][self.rank]
421
+ ), "{} != {}".format(
422
+ obj_ids_local, tracker_metadata_prev["obj_ids_per_gpu"][self.rank]
423
+ )
424
+
425
+ # Step 2: all-gather `low_res_masks_local` into `low_res_masks_global`
426
+ # - low_res_masks_global: Tensor -- (num_global_obj, H_mask, W_mask)
427
+ _, H_mask, W_mask = low_res_masks_local.shape
428
+ if self.world_size > 1:
429
+ # `low_res_masks_local` and `obj_scores_local` need to be contiguous and float32
430
+ # (they could be non-contiguous due to slicing and/or bfloat16 due to autocast)
431
+ low_res_masks_local = low_res_masks_local.float().contiguous()
432
+ obj_scores_local = obj_scores_local.float().contiguous()
433
+ num_obj_this_gpu = tracker_metadata_prev["num_obj_per_gpu"][self.rank]
434
+ assert low_res_masks_local.size(0) == num_obj_this_gpu
435
+ assert obj_scores_local.size(0) == num_obj_this_gpu
436
+ low_res_masks_peers = [
437
+ low_res_masks_local.new_empty(num_obj, H_mask, W_mask)
438
+ for num_obj in tracker_metadata_prev["num_obj_per_gpu"]
439
+ ]
440
+ obj_scores_peers = [
441
+ obj_scores_local.new_empty(num_obj)
442
+ for num_obj in tracker_metadata_prev["num_obj_per_gpu"]
443
+ ]
444
+ dist.all_gather(low_res_masks_peers, low_res_masks_local)
445
+ dist.all_gather(obj_scores_peers, obj_scores_local)
446
+ low_res_masks_global = torch.cat(low_res_masks_peers, dim=0)
447
+ obj_scores_global = torch.cat(obj_scores_peers, dim=0)
448
+ else:
449
+ low_res_masks_global = low_res_masks_local
450
+ obj_scores_global = obj_scores_local
451
+ return low_res_masks_global, obj_scores_global
452
+
453
+ def _recondition_masklets(
454
+ self,
455
+ frame_idx,
456
+ det_out: Dict[str, Tensor],
457
+ trk_id_to_max_iou_high_conf_det: List[int],
458
+ tracker_states_local: List[Any],
459
+ tracker_metadata: Dict[str, npt.NDArray],
460
+ tracker_obj_scores_global: Tensor,
461
+ ):
462
+ # Recondition the masklets based on the new detections
463
+ for trk_obj_id, det_idx in trk_id_to_max_iou_high_conf_det.items():
464
+ new_mask = det_out["mask"][det_idx : det_idx + 1]
465
+ input_mask_res = self.tracker.input_mask_size
466
+ new_mask_binary = (
467
+ F.interpolate(
468
+ new_mask.unsqueeze(1),
469
+ size=(input_mask_res, input_mask_res),
470
+ mode="bilinear",
471
+ align_corners=False,
472
+ ).squeeze(1)[0]
473
+ > 0
474
+ )
475
+ HIGH_CONF_THRESH = 0.8
476
+ reconditioned_states_idx = set()
477
+ obj_idx = np.where(tracker_metadata["obj_ids_all_gpu"] == trk_obj_id)[
478
+ 0
479
+ ].item()
480
+ obj_score = tracker_obj_scores_global[obj_idx]
481
+ for state_idx, inference_state in enumerate(tracker_states_local):
482
+ if (
483
+ trk_obj_id in inference_state["obj_ids"]
484
+ # NOTE: The goal of this condition is to avoid reconditioning masks that are occluded or low quality.
485
+ # Unfortunately, these can get reconditioned anyway due to batching. We should consider removing these heuristics.
486
+ and obj_score > HIGH_CONF_THRESH
487
+ ):
488
+ logger.debug(
489
+ f"Adding new mask for track {trk_obj_id} at frame {frame_idx}. Objects {inference_state['obj_ids']} are all reconditioned."
490
+ )
491
+ self.tracker.add_new_mask(
492
+ inference_state=inference_state,
493
+ frame_idx=frame_idx,
494
+ obj_id=trk_obj_id,
495
+ mask=new_mask_binary,
496
+ )
497
+ reconditioned_states_idx.add(state_idx)
498
+
499
+ for idx in reconditioned_states_idx:
500
+ self.tracker.propagate_in_video_preflight(
501
+ tracker_states_local[idx], run_mem_encoder=True
502
+ )
503
+ return tracker_states_local
504
+
505
+ def run_tracker_update_planning_phase(
506
+ self,
507
+ frame_idx: int,
508
+ num_frames: int,
509
+ reverse: bool,
510
+ det_out: Dict[str, Tensor],
511
+ tracker_low_res_masks_global: Tensor,
512
+ tracker_obj_scores_global: Tensor,
513
+ tracker_metadata_prev: Dict[str, npt.NDArray],
514
+ tracker_states_local: List[Any],
515
+ is_image_only: bool = False,
516
+ ):
517
+ # initialize new metadata from previous metadata (its values will be updated later)
518
+ tracker_metadata_new = {
519
+ "obj_ids_per_gpu": deepcopy(tracker_metadata_prev["obj_ids_per_gpu"]),
520
+ "obj_ids_all_gpu": None, # will be filled later
521
+ "num_obj_per_gpu": deepcopy(tracker_metadata_prev["num_obj_per_gpu"]),
522
+ "obj_id_to_score": deepcopy(tracker_metadata_prev["obj_id_to_score"]),
523
+ "obj_id_to_tracker_score_frame_wise": deepcopy(
524
+ tracker_metadata_prev["obj_id_to_tracker_score_frame_wise"]
525
+ ),
526
+ "obj_id_to_last_occluded": {}, # will be filled later
527
+ "max_obj_id": deepcopy(tracker_metadata_prev["max_obj_id"]),
528
+ }
529
+
530
+ # Initialize reconditioned_obj_ids early to avoid UnboundLocalError
531
+ reconditioned_obj_ids = set()
532
+
533
+ # Step 1: make the update plan and resolve heuristics on GPU 0
534
+ det_mask_preds: Tensor = det_out["mask"] # low-res mask logits
535
+ det_scores_np: npt.NDArray = det_out["scores"].float().cpu().numpy()
536
+ det_bbox_xyxy: Tensor = det_out["bbox"]
537
+ if self.rank == 0:
538
+ # a) match detector and tracker masks and find new objects
539
+ (
540
+ new_det_fa_inds,
541
+ unmatched_trk_obj_ids,
542
+ det_to_matched_trk_obj_ids,
543
+ trk_id_to_max_iou_high_conf_det,
544
+ empty_trk_obj_ids,
545
+ ) = self._associate_det_trk(
546
+ det_masks=det_mask_preds,
547
+ det_scores_np=det_scores_np,
548
+ trk_masks=tracker_low_res_masks_global,
549
+ trk_obj_ids=tracker_metadata_prev["obj_ids_all_gpu"],
550
+ )
551
+ if self.suppress_det_close_to_boundary:
552
+ keep = self._suppress_detections_close_to_boundary(
553
+ det_bbox_xyxy[new_det_fa_inds]
554
+ )
555
+ new_det_fa_inds = new_det_fa_inds[keep.cpu().numpy()]
556
+
557
+ # check whether we've hit the maximum number of objects we can track (and if so, drop some detections)
558
+ prev_obj_num = np.sum(tracker_metadata_prev["num_obj_per_gpu"])
559
+ new_det_num = len(new_det_fa_inds)
560
+ num_obj_dropped_due_to_limit = 0
561
+ if not is_image_only and prev_obj_num + new_det_num > self.max_num_objects:
562
+ logger.warning(
563
+ f"hitting {self.max_num_objects=} with {new_det_num=} and {prev_obj_num=}"
564
+ )
565
+ new_det_num_to_keep = self.max_num_objects - prev_obj_num
566
+ num_obj_dropped_due_to_limit = new_det_num - new_det_num_to_keep
567
+ new_det_fa_inds = self._drop_new_det_with_obj_limit(
568
+ new_det_fa_inds, det_scores_np, new_det_num_to_keep
569
+ )
570
+ assert len(new_det_fa_inds) == new_det_num_to_keep
571
+ new_det_num = len(new_det_fa_inds)
572
+
573
+ # assign object IDs to new detections and decide which GPU to place them
574
+ new_det_start_obj_id = tracker_metadata_prev["max_obj_id"] + 1
575
+ new_det_obj_ids = new_det_start_obj_id + np.arange(new_det_num)
576
+ prev_workload_per_gpu = tracker_metadata_prev["num_obj_per_gpu"]
577
+ new_det_gpu_ids = self._assign_new_det_to_gpus(
578
+ new_det_num=new_det_num,
579
+ prev_workload_per_gpu=prev_workload_per_gpu,
580
+ )
581
+
582
+ # b) handle hotstart heuristics to remove objects
583
+ # here `rank0_metadata` contains metadata stored on (and only accessible to) GPU 0;
584
+ # we avoid broadcasting them to other GPUs to save communication cost, assuming
585
+ # that `rank0_metadata` is not needed by other GPUs
586
+ rank0_metadata_new = deepcopy(tracker_metadata_prev["rank0_metadata"])
587
+ if not hasattr(self, "_warm_up_complete") or self._warm_up_complete:
588
+ obj_ids_newly_removed, rank0_metadata_new = self._process_hotstart(
589
+ frame_idx=frame_idx,
590
+ num_frames=num_frames,
591
+ reverse=reverse,
592
+ det_to_matched_trk_obj_ids=det_to_matched_trk_obj_ids,
593
+ new_det_obj_ids=new_det_obj_ids,
594
+ empty_trk_obj_ids=empty_trk_obj_ids,
595
+ unmatched_trk_obj_ids=unmatched_trk_obj_ids,
596
+ rank0_metadata=rank0_metadata_new,
597
+ tracker_metadata=tracker_metadata_prev,
598
+ )
599
+ else:
600
+ # if warm-up is not complete, we don't remove any objects
601
+ obj_ids_newly_removed = set()
602
+ tracker_metadata_new["rank0_metadata"] = rank0_metadata_new
603
+
604
+ # Step 2: broadcast the update plan to other GPUs
605
+ NUM_BROADCAST_ITEMS = 9
606
+ if self.rank == 0 and self.world_size > 1:
607
+ # `num_obj_per_gpu_on_rank0` is used for metadata consistency check on other GPUs
608
+ # (it's a small array with length==self.world_size, so broadcasting it is cheap)
609
+ num_obj_per_gpu_on_rank0 = tracker_metadata_prev["num_obj_per_gpu"]
610
+ update_plan = [
611
+ new_det_fa_inds,
612
+ new_det_obj_ids,
613
+ new_det_gpu_ids,
614
+ num_obj_per_gpu_on_rank0,
615
+ unmatched_trk_obj_ids,
616
+ det_to_matched_trk_obj_ids,
617
+ obj_ids_newly_removed,
618
+ num_obj_dropped_due_to_limit,
619
+ trk_id_to_max_iou_high_conf_det,
620
+ ]
621
+ assert (
622
+ len(update_plan) == NUM_BROADCAST_ITEMS
623
+ ), f"Manually update NUM_BROADCAST_ITEMS to be: {len(update_plan)}"
624
+ self.broadcast_python_obj_cpu(update_plan, src=0)
625
+ elif self.rank > 0 and self.world_size > 1:
626
+ update_plan = [
627
+ None
628
+ ] * NUM_BROADCAST_ITEMS # other ranks receive the plan from rank 0
629
+ self.broadcast_python_obj_cpu(update_plan, src=0)
630
+ (
631
+ new_det_fa_inds,
632
+ new_det_obj_ids,
633
+ new_det_gpu_ids,
634
+ num_obj_per_gpu_on_rank0,
635
+ unmatched_trk_obj_ids,
636
+ det_to_matched_trk_obj_ids,
637
+ obj_ids_newly_removed,
638
+ num_obj_dropped_due_to_limit,
639
+ trk_id_to_max_iou_high_conf_det,
640
+ ) = update_plan
641
+ # metadata consistency check: verify that the received `num_obj_per_gpu_on_rank0` is consistent with the local metadata
642
+ # it's critical that all GPUs agree on the previous number of objects (otherwise the inference might hang or fail silently)
643
+ if not np.all(
644
+ num_obj_per_gpu_on_rank0 == tracker_metadata_prev["num_obj_per_gpu"]
645
+ ):
646
+ raise RuntimeError(
647
+ f"{self.rank=} received {num_obj_per_gpu_on_rank0=}, which is inconsistent with local record "
648
+ f"{tracker_metadata_prev['num_obj_per_gpu']=}. There's likely a bug in update planning or execution."
649
+ )
650
+
651
+ # `tracker_update_plan` should be identical on all GPUs after broadcasting
652
+ tracker_update_plan = {
653
+ "new_det_fa_inds": new_det_fa_inds, # npt.NDArray
654
+ "new_det_obj_ids": new_det_obj_ids, # npt.NDArray
655
+ "new_det_gpu_ids": new_det_gpu_ids, # npt.NDArray
656
+ "unmatched_trk_obj_ids": unmatched_trk_obj_ids, # npt.NDArray
657
+ "det_to_matched_trk_obj_ids": det_to_matched_trk_obj_ids, # dict
658
+ "obj_ids_newly_removed": obj_ids_newly_removed, # set
659
+ "num_obj_dropped_due_to_limit": num_obj_dropped_due_to_limit, # int
660
+ "trk_id_to_max_iou_high_conf_det": trk_id_to_max_iou_high_conf_det, # dict
661
+ "reconditioned_obj_ids": reconditioned_obj_ids, # set
662
+ }
663
+
664
+ # Step 3 (optional): recondition masklets based on high-confidence detections before memory encoding
665
+ # NOTE: Running this in execution phase (after memory encoding) can lead to suboptimal results
666
+ should_recondition_iou = False
667
+
668
+ # Evaluate tracklets for reconditioning based on bbox IoU mismatch with detections
669
+ if (
670
+ self.reconstruction_bbox_iou_thresh > 0
671
+ and len(trk_id_to_max_iou_high_conf_det) > 0
672
+ ):
673
+ for trk_obj_id, det_idx in trk_id_to_max_iou_high_conf_det.items():
674
+ det_box = det_out["bbox"][det_idx]
675
+ det_score = det_out["scores"][det_idx]
676
+
677
+ try:
678
+ trk_idx = list(tracker_metadata_prev["obj_ids_all_gpu"]).index(
679
+ trk_obj_id
680
+ )
681
+ except ValueError:
682
+ continue # Skip if tracklet not found
683
+
684
+ tracker_mask = tracker_low_res_masks_global[trk_idx]
685
+ mask_binary = tracker_mask > 0
686
+ mask_area = mask_binary.sum().item()
687
+
688
+ if mask_area == 0:
689
+ continue # Skip tracklets with zero mask area
690
+
691
+ # Get bounding box from SAM2 mask and convert to normalized coordinates
692
+ tracker_box_pixels = (
693
+ mask_to_box(mask_binary.unsqueeze(0).unsqueeze(0))
694
+ .squeeze(0)
695
+ .squeeze(0)
696
+ )
697
+ mask_height, mask_width = tracker_mask.shape[-2:]
698
+ tracker_box_normalized = torch.tensor(
699
+ [
700
+ tracker_box_pixels[0] / mask_width,
701
+ tracker_box_pixels[1] / mask_height,
702
+ tracker_box_pixels[2] / mask_width,
703
+ tracker_box_pixels[3] / mask_height,
704
+ ],
705
+ device=tracker_box_pixels.device,
706
+ )
707
+
708
+ # Compute IoU between detection and SAM2 tracklet bounding boxes
709
+ det_box_batch = det_box.unsqueeze(0)
710
+ tracker_box_batch = tracker_box_normalized.unsqueeze(0)
711
+ iou = fast_diag_box_iou(det_box_batch, tracker_box_batch)[0]
712
+
713
+ if (
714
+ iou < self.reconstruction_bbox_iou_thresh
715
+ and det_score >= self.reconstruction_bbox_det_score
716
+ ):
717
+ should_recondition_iou = True
718
+ reconditioned_obj_ids.add(trk_obj_id)
719
+
720
+ should_recondition_periodic = (
721
+ self.recondition_every_nth_frame > 0
722
+ and frame_idx % self.recondition_every_nth_frame == 0
723
+ and len(trk_id_to_max_iou_high_conf_det) > 0
724
+ )
725
+
726
+ # Recondition if periodic or IoU condition met
727
+ if should_recondition_periodic or should_recondition_iou:
728
+ self._recondition_masklets(
729
+ frame_idx,
730
+ det_out,
731
+ trk_id_to_max_iou_high_conf_det,
732
+ tracker_states_local,
733
+ tracker_metadata_prev,
734
+ tracker_obj_scores_global,
735
+ )
736
+
737
+ # Step 4: Run SAM2 memory encoder on the current frame's prediction masks
738
+ # This is done on all GPUs
739
+ batch_size = tracker_low_res_masks_global.size(0)
740
+ if batch_size > 0:
741
+ if not hasattr(self, "_warm_up_complete") or self._warm_up_complete:
742
+ if self.suppress_overlapping_based_on_recent_occlusion_threshold > 0.0:
743
+ # NOTE: tracker_low_res_masks_global is updated in-place then returned
744
+ tracker_low_res_masks_global = (
745
+ self._suppress_overlapping_based_on_recent_occlusion(
746
+ frame_idx,
747
+ tracker_low_res_masks_global,
748
+ tracker_metadata_prev,
749
+ tracker_metadata_new,
750
+ obj_ids_newly_removed,
751
+ reverse,
752
+ )
753
+ )
754
+
755
+ self._tracker_update_memories(
756
+ tracker_states_local,
757
+ frame_idx,
758
+ tracker_metadata=tracker_metadata_prev,
759
+ low_res_masks=tracker_low_res_masks_global,
760
+ )
761
+
762
+ # Step 5: update the SAM2 metadata based on the update plan
763
+ # note: except for "rank0_metadata" (that is only available on GPU 0),
764
+ # the updated `tracker_metadata_new` should be identical on all GPUs
765
+ for rank in range(self.world_size):
766
+ new_det_obj_ids_this_gpu = new_det_obj_ids[new_det_gpu_ids == rank]
767
+ updated_obj_ids_this_gpu = tracker_metadata_new["obj_ids_per_gpu"][rank]
768
+ if len(new_det_obj_ids_this_gpu) > 0:
769
+ updated_obj_ids_this_gpu = np.concatenate(
770
+ [updated_obj_ids_this_gpu, new_det_obj_ids_this_gpu]
771
+ )
772
+ if len(obj_ids_newly_removed) > 0:
773
+ is_removed = np.isin(
774
+ updated_obj_ids_this_gpu, list(obj_ids_newly_removed)
775
+ )
776
+ updated_obj_ids_this_gpu = updated_obj_ids_this_gpu[~is_removed]
777
+ tracker_metadata_new["obj_ids_per_gpu"][rank] = updated_obj_ids_this_gpu
778
+ tracker_metadata_new["num_obj_per_gpu"][rank] = len(
779
+ updated_obj_ids_this_gpu
780
+ )
781
+ tracker_metadata_new["obj_ids_all_gpu"] = np.concatenate(
782
+ tracker_metadata_new["obj_ids_per_gpu"]
783
+ )
784
+ # update object scores and the maximum object ID assigned so far
785
+ if len(new_det_obj_ids) > 0:
786
+ tracker_metadata_new["obj_id_to_score"].update(
787
+ zip(new_det_obj_ids, det_scores_np[new_det_fa_inds])
788
+ )
789
+ # tracker scores are not available for new objects, use det score instead.
790
+ tracker_metadata_new["obj_id_to_tracker_score_frame_wise"][
791
+ frame_idx
792
+ ].update(zip(new_det_obj_ids, det_scores_np[new_det_fa_inds]))
793
+ tracker_metadata_new["max_obj_id"] = max(
794
+ tracker_metadata_new["max_obj_id"],
795
+ np.max(new_det_obj_ids),
796
+ )
797
+ # for removed objects, we set their scores to a very low value (-1e4) but still
798
+ # keep them in "obj_id_to_score" (it's easier to handle outputs this way)
799
+ for obj_id in obj_ids_newly_removed:
800
+ tracker_metadata_new["obj_id_to_score"][obj_id] = -1e4
801
+ tracker_metadata_new["obj_id_to_tracker_score_frame_wise"][frame_idx][
802
+ obj_id
803
+ ] = -1e4
804
+ tracker_metadata_new["obj_id_to_last_occluded"].pop(obj_id, None)
805
+ # check that "rank0_metadata" is in tracker_metadata_new if and only if it's GPU 0
806
+ assert ("rank0_metadata" in tracker_metadata_new) == (self.rank == 0)
807
+ if self.rank == 0 and self.masklet_confirmation_enable:
808
+ rank0_metadata = self.update_masklet_confirmation_status(
809
+ rank0_metadata=tracker_metadata_new["rank0_metadata"],
810
+ obj_ids_all_gpu_prev=tracker_metadata_prev["obj_ids_all_gpu"],
811
+ obj_ids_all_gpu_updated=tracker_metadata_new["obj_ids_all_gpu"],
812
+ det_to_matched_trk_obj_ids=det_to_matched_trk_obj_ids,
813
+ new_det_obj_ids=new_det_obj_ids,
814
+ )
815
+ tracker_metadata_new["rank0_metadata"] = rank0_metadata
816
+
817
+ return tracker_update_plan, tracker_metadata_new
818
+
819
+ def _suppress_overlapping_based_on_recent_occlusion(
820
+ self,
821
+ frame_idx: int,
822
+ tracker_low_res_masks_global: Tensor,
823
+ tracker_metadata_prev: Dict[str, Any],
824
+ tracker_metadata_new: Dict[str, Any],
825
+ obj_ids_newly_removed: Set[int],
826
+ reverse: bool = False,
827
+ ):
828
+ """
829
+ Suppress overlapping masks based on the most recent occlusion information. If an object is removed by hotstart, we always suppress it if it overlaps with any other object.
830
+ Args:
831
+ frame_idx (int): The current frame index.
832
+ tracker_low_res_masks_global (Tensor): The low-resolution masks for the current frame.
833
+ tracker_metadata_prev (Dict[str, Any]): The metadata from the previous frame.
834
+ tracker_metadata_new (Dict[str, Any]): The metadata for the current frame.
835
+ obj_ids_newly_removed (Set[int]): The object IDs that have been removed.
836
+ Return:
837
+ Tensor: The updated low-resolution masks with some objects suppressed.
838
+ """
839
+ obj_ids_global = tracker_metadata_prev["obj_ids_all_gpu"]
840
+ binary_tracker_low_res_masks_global = tracker_low_res_masks_global > 0
841
+ batch_size = tracker_low_res_masks_global.size(0)
842
+ if batch_size > 0:
843
+ assert (
844
+ len(obj_ids_global) == batch_size
845
+ ), f"Mismatch in number of objects: {len(obj_ids_global)} vs {batch_size}"
846
+ NEVER_OCCLUDED = -1
847
+ ALWAYS_OCCLUDED = 100000 # This value should be larger than any possible frame index; it indicates that the object was removed by the hotstart logic
848
+ last_occluded_prev = torch.cat(
849
+ [
850
+ tracker_metadata_prev["obj_id_to_last_occluded"].get(
851
+ obj_id,
852
+ torch.full(
853
+ (1,),
854
+ fill_value=(
855
+ NEVER_OCCLUDED
856
+ if obj_id not in obj_ids_newly_removed
857
+ else ALWAYS_OCCLUDED
858
+ ),
859
+ device=binary_tracker_low_res_masks_global.device,
860
+ dtype=torch.long,
861
+ ),
862
+ )
863
+ for obj_id in obj_ids_global
864
+ ],
865
+ dim=0,
866
+ )
867
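+ # Among masks that overlap above the threshold, keep the object that has stayed visible the longest and suppress the one occluded most recently (the comparison is flipped when tracking in reverse).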
+ to_suppress = self._get_objects_to_suppress_based_on_most_recently_occluded(
868
+ binary_tracker_low_res_masks_global,
869
+ last_occluded_prev,
870
+ obj_ids_global,
871
+ frame_idx,
872
+ reverse,
873
+ )
874
+
875
+ # Update metadata with occlusion information
876
+ is_obj_occluded = ~(binary_tracker_low_res_masks_global.any(dim=(-1, -2)))
877
+ is_obj_occluded_or_suppressed = is_obj_occluded | to_suppress
878
+ last_occluded_new = last_occluded_prev.clone()
879
+ last_occluded_new[is_obj_occluded_or_suppressed] = frame_idx
880
+ # Slice out the last occluded frame for each object
881
+ tracker_metadata_new["obj_id_to_last_occluded"] = {
882
+ obj_id: last_occluded_new[obj_idx : obj_idx + 1]
883
+ for obj_idx, obj_id in enumerate(obj_ids_global)
884
+ }
885
+
886
+ # Overwrite suppressed masks with a no-object logit before memory encoding
887
+ NO_OBJ_LOGIT = -10
888
+ tracker_low_res_masks_global[to_suppress] = NO_OBJ_LOGIT
889
+
890
+ return tracker_low_res_masks_global
891
+
892
+ def run_tracker_update_execution_phase(
893
+ self,
894
+ frame_idx: int,
895
+ num_frames: int,
896
+ reverse: bool,
897
+ det_out: Dict[str, Tensor],
898
+ tracker_states_local: List[Any],
899
+ tracker_update_plan: Dict[str, npt.NDArray],
900
+ orig_vid_height: int,
901
+ orig_vid_width: int,
902
+ feature_cache: Dict,
903
+ ):
904
+ # initialize tracking scores with detection scores
905
+ new_det_fa_inds: npt.NDArray = tracker_update_plan["new_det_fa_inds"]
906
+ new_det_obj_ids: npt.NDArray = tracker_update_plan["new_det_obj_ids"]
907
+ new_det_gpu_ids: npt.NDArray = tracker_update_plan["new_det_gpu_ids"]
908
+ is_on_this_gpu: npt.NDArray = new_det_gpu_ids == self.rank
909
+ new_det_obj_ids_local: npt.NDArray = new_det_obj_ids[is_on_this_gpu]
910
+ new_det_fa_inds_local: npt.NDArray = new_det_fa_inds[is_on_this_gpu]
911
+ obj_ids_newly_removed: Set[int] = tracker_update_plan["obj_ids_newly_removed"]
912
+
913
+ # Step 1: add new objects from the detector to SAM2 inference states
914
+ if len(new_det_fa_inds_local) > 0:
915
+ new_det_fa_inds_local_t = torch.from_numpy(new_det_fa_inds_local)
916
+ new_det_masks: Tensor = det_out["mask"][new_det_fa_inds_local_t]
917
+ # initialize SAM2 with new object masks
918
+ tracker_states_local = self._tracker_add_new_objects(
919
+ frame_idx=frame_idx,
920
+ num_frames=num_frames,
921
+ new_obj_ids=new_det_obj_ids_local,
922
+ new_obj_masks=new_det_masks,
923
+ tracker_states_local=tracker_states_local,
924
+ orig_vid_height=orig_vid_height,
925
+ orig_vid_width=orig_vid_width,
926
+ feature_cache=feature_cache,
927
+ )
928
+
929
+ # Step 2: remove from SAM2 inference states those objects removed by heuristics
930
+ if len(obj_ids_newly_removed) > 0:
931
+ self._tracker_remove_objects(tracker_states_local, obj_ids_newly_removed)
932
+
933
+ return tracker_states_local
934
+
935
+ def build_outputs(
936
+ self,
937
+ frame_idx: int,
938
+ num_frames: int,
939
+ reverse: bool,
940
+ det_out: Dict[str, Tensor],
941
+ tracker_low_res_masks_global: Tensor,
942
+ tracker_obj_scores_global: Tensor,
943
+ tracker_metadata_prev: Dict[str, npt.NDArray],
944
+ tracker_update_plan: Dict[str, npt.NDArray],
945
+ orig_vid_height: int,
946
+ orig_vid_width: int,
947
+ reconditioned_obj_ids: set = None,
948
+ det_to_matched_trk_obj_ids: dict = None,
949
+ ):
950
+ new_det_fa_inds: npt.NDArray = tracker_update_plan["new_det_fa_inds"]
951
+ new_det_obj_ids: npt.NDArray = tracker_update_plan["new_det_obj_ids"]
952
+ obj_id_to_mask = {} # obj_id --> output mask tensor
953
+
954
+ # Part 1: masks from previous SAM2 propagation
955
+ existing_masklet_obj_ids = tracker_metadata_prev["obj_ids_all_gpu"]
956
+ existing_masklet_video_res_masks = F.interpolate(
957
+ tracker_low_res_masks_global.unsqueeze(1),
958
+ size=(orig_vid_height, orig_vid_width),
959
+ mode="bilinear",
960
+ align_corners=False,
961
+ ) # (num_obj, 1, H_video, W_video)
962
+ existing_masklet_binary = existing_masklet_video_res_masks > 0
963
+ assert len(existing_masklet_obj_ids) == len(existing_masklet_binary)
964
+ for obj_id, mask in zip(existing_masklet_obj_ids, existing_masklet_binary):
965
+ obj_id_to_mask[obj_id] = mask # (1, H_video, W_video)
966
+
967
+ # Part 2: masks from new detections
968
+ new_det_fa_inds_t = torch.from_numpy(new_det_fa_inds)
969
+ new_det_low_res_masks = det_out["mask"][new_det_fa_inds_t].unsqueeze(1)
970
+ new_det_low_res_masks = fill_holes_in_mask_scores(
971
+ new_det_low_res_masks,
972
+ max_area=self.fill_hole_area,
973
+ fill_holes=True,
974
+ remove_sprinkles=True,
975
+ )
976
+ new_masklet_video_res_masks = F.interpolate(
977
+ new_det_low_res_masks,
978
+ size=(orig_vid_height, orig_vid_width),
979
+ mode="bilinear",
980
+ align_corners=False,
981
+ ) # (num_obj, 1, H_video, W_video)
982
+
983
+ new_masklet_binary = new_masklet_video_res_masks > 0
984
+ assert len(new_det_obj_ids) == len(new_masklet_video_res_masks)
985
+ for obj_id, mask in zip(new_det_obj_ids, new_masklet_binary):
986
+ obj_id_to_mask[obj_id] = mask # (1, H_video, W_video)
987
+
988
+ # Part 3: Override masks for reconditioned objects using detection masks
989
+ if reconditioned_obj_ids is not None and len(reconditioned_obj_ids) > 0:
990
+ trk_id_to_max_iou_high_conf_det = tracker_update_plan.get(
991
+ "trk_id_to_max_iou_high_conf_det", {}
992
+ )
993
+
994
+ for obj_id in reconditioned_obj_ids:
995
+ det_idx = trk_id_to_max_iou_high_conf_det.get(obj_id)
996
+
997
+ if det_idx is not None:
998
+ det_mask = det_out["mask"][det_idx]
999
+ det_mask = det_mask.unsqueeze(0).unsqueeze(0)
1000
+ det_mask_resized = (
1001
+ F.interpolate(
1002
+ det_mask.float(),
1003
+ size=(orig_vid_height, orig_vid_width),
1004
+ mode="bilinear",
1005
+ align_corners=False,
1006
+ )
1007
+ > 0
1008
+ )
1009
+
1010
+ det_mask_final = det_mask_resized.squeeze(0)
1011
+ obj_id_to_mask[obj_id] = det_mask_final
1012
+
1013
+ return obj_id_to_mask
1014
+
1015
+ def _get_objects_to_suppress_based_on_most_recently_occluded(
1016
+ self,
1017
+ binary_low_res_masks: Tensor,
1018
+ last_occluded: List[int],
1019
+ obj_ids: List[int],
1020
+ frame_idx: int = None,
1021
+ reverse: bool = False,
1022
+ ):
1023
+ # Suppress overlapping masks for objects that were most recently occluded
1024
+ assert (
1025
+ binary_low_res_masks.dtype == torch.bool
1026
+ ), f"Expected boolean tensor, got {binary_low_res_masks.dtype}"
1027
+ to_suppress = torch.zeros(
1028
+ binary_low_res_masks.size(0),
1029
+ device=binary_low_res_masks.device,
1030
+ dtype=torch.bool,
1031
+ )
1032
+ if len(obj_ids) <= 1:
1033
+ return to_suppress
1034
+
1035
+ iou = mask_iou(binary_low_res_masks, binary_low_res_masks) # [N,N]
1036
+
1037
+ # Create masks for upper triangular matrix (i < j) and IoU threshold
1038
+ mask_iou_thresh = (
1039
+ iou >= self.suppress_overlapping_based_on_recent_occlusion_threshold
1040
+ )
1041
+ overlapping_pairs = torch.triu(mask_iou_thresh, diagonal=1) # [N,N]
1042
+
1043
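+ # Illustration: with last_occluded = [5, 2] and an overlapping pair above the IoU threshold, object 0 (occluded more recently, at frame 5) is suppressed in favor of object 1 when tracking forward.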
+ last_occ_expanded_i = last_occluded.unsqueeze(1) # (N, 1)
1044
+ last_occ_expanded_j = last_occluded.unsqueeze(0) # (1, N)
1045
+ # Suppress most recently occluded
1046
+ cmp_op = torch.gt if not reverse else torch.lt
1047
+ suppress_i_mask = (
1048
+ overlapping_pairs
1049
+ & cmp_op(
1050
+ last_occ_expanded_i, last_occ_expanded_j
1051
+ ) # (last_occ_expanded_i > last_occ_expanded_j)
1052
+ & (
1053
+ last_occ_expanded_j > -1
1054
+ ) # j can suppress i only if i was previously occluded
1055
+ )
1056
+ suppress_j_mask = (
1057
+ overlapping_pairs
1058
+ & cmp_op(last_occ_expanded_j, last_occ_expanded_i)
1059
+ & (
1060
+ last_occ_expanded_i > -1
1061
+ ) # i can suppress j only if j was previously occluded
1062
+ )
1063
+ # Apply suppression
1064
+ to_suppress = suppress_i_mask.any(dim=1) | suppress_j_mask.any(dim=0)
1065
+
1066
+ # Log for debugging
1067
+ if (
1068
+ self.rank == 0
1069
+ and logger.isEnabledFor(logging.DEBUG)
1070
+ and frame_idx is not None
1071
+ ):
1072
+ suppress_i_mask = suppress_i_mask.cpu().numpy()
1073
+ suppress_j_mask = suppress_j_mask.cpu().numpy()
1074
+ last_occluded = last_occluded.cpu().numpy()
1075
+
1076
+ # Find all suppression pairs without using torch.where
1077
+ batch_size = suppress_i_mask.shape[0]
1078
+
1079
+ # Log i-suppression cases (where i gets suppressed in favor of j)
1080
+ for i in range(batch_size):
1081
+ for j in range(batch_size):
1082
+ if suppress_i_mask[i, j]:
1083
+ logger.debug(
1084
+ f"{frame_idx=}: Suppressing obj {obj_ids[i]} last occluded {last_occluded[i]} in favor of {obj_ids[j]} last occluded {last_occluded[j]}"
1085
+ )
1086
+
1087
+ # Log j-suppression cases (where j gets suppressed in favor of i)
1088
+ for i in range(batch_size):
1089
+ for j in range(batch_size):
1090
+ if suppress_j_mask[i, j]:
1091
+ logger.debug(
1092
+ f"{frame_idx=}: Suppressing obj {obj_ids[j]} last occluded {last_occluded[j]} in favor of {obj_ids[i]} last occluded {last_occluded[i]}"
1093
+ )
1094
+
1095
+ return to_suppress
1096
+
1097
+ def _propogate_tracker_one_frame_local_gpu(
1098
+ self,
1099
+ inference_states: List[Any],
1100
+ frame_idx: int,
1101
+ reverse: bool,
1102
+ # by default, we disable memory encoding until we gather all outputs
1103
+ run_mem_encoder: bool = False,
1104
+ ):
1105
+ """
1106
+ inference_states: List of inference states, each state corresponds to a different set of objects.
1107
+ """
1108
+ obj_ids_local = []
1109
+ low_res_masks_list = []
1110
+ obj_scores_list = []
1111
+ for inference_state in inference_states:
1112
+ if len(inference_state["obj_ids"]) == 0:
1113
+ continue # skip propagation on empty inference states
1114
+
1115
+ # propagate one frame
1116
+ num_frames_propagated = 0
1117
+ for out in self.tracker.propagate_in_video(
1118
+ inference_state,
1119
+ start_frame_idx=frame_idx,
1120
+ # end_frame_idx = start_frame_idx + max_frame_num_to_track
1121
+ # (i.e. propagating 1 frame since end_frame_idx is inclusive)
1122
+ max_frame_num_to_track=0,
1123
+ reverse=reverse,
1124
+ tqdm_disable=True,
1125
+ run_mem_encoder=run_mem_encoder,
1126
+ ):
1127
+ out_frame_idx, out_obj_ids, out_low_res_masks, _, out_obj_scores = out
1128
+ num_frames_propagated += 1
1129
+
1130
+ # only 1 frame should be propagated
1131
+ assert (
1132
+ num_frames_propagated == 1 and out_frame_idx == frame_idx
1133
+ ), f"num_frames_propagated: {num_frames_propagated}, out_frame_idx: {out_frame_idx}, frame_idx: {frame_idx}"
1134
+ assert isinstance(out_obj_ids, list)
1135
+ obj_ids_local.extend(out_obj_ids)
1136
+ low_res_masks_list.append(out_low_res_masks.squeeze(1))
1137
+ obj_scores_list.append(out_obj_scores.squeeze(1))
1138
+
1139
+ # concatenate the output masklets from all local inference states
1140
+ H_mask = W_mask = self.tracker.low_res_mask_size
1141
+ if len(low_res_masks_list) > 0:
1142
+ low_res_masks_local = torch.cat(low_res_masks_list, dim=0)
1143
+ obj_scores_local = torch.cat(obj_scores_list, dim=0)
1144
+ assert low_res_masks_local.shape[1:] == (H_mask, W_mask)
1145
+
1146
+ # Apply hole filling to the masks
1147
+ low_res_masks_local = fill_holes_in_mask_scores(
1148
+ low_res_masks_local.unsqueeze(1),
1149
+ max_area=self.fill_hole_area,
1150
+ fill_holes=True,
1151
+ remove_sprinkles=True,
1152
+ )
1153
+ low_res_masks_local = low_res_masks_local.squeeze(1)
1154
+ else:
1155
+ low_res_masks_local = torch.zeros(0, H_mask, W_mask, device=self.device)
1156
+ obj_scores_local = torch.zeros(0, device=self.device)
1157
+
1158
+ return obj_ids_local, low_res_masks_local, obj_scores_local
1159
+
1160
+ def _associate_det_trk(
1161
+ self,
1162
+ det_masks: Tensor,
1163
+ det_scores_np: npt.NDArray,
1164
+ trk_masks: Tensor,
1165
+ trk_obj_ids: npt.NDArray,
1166
+ ):
1167
+ """
1168
+ Match detections on the current frame with the existing masklets.
1169
+
1170
+ Args:
1171
+ - det_masks: (N, H, W) tensor of predicted masks
1172
+ - det_scores_np: (N,) array of detection scores
1173
+ - trk_masks: (M, H, W) tensor of track masks
1174
+ - trk_obj_ids: (M,) array of object IDs corresponding to trk_masks
1175
+
1176
+ Returns:
1177
+ - new_det_fa_inds: array of new object indices.
1178
+ - unmatched_trk_obj_ids: array of existing masklet object IDs that are not matched
1179
+ to any detections on this frame (for unmatched, we only count masklets with >0 area)
1180
+ - det_to_matched_trk_obj_ids: dict[int, npt.NDArray]: mapping from detector's detection indices
1181
+ to the list of matched tracklet object IDs
1182
+ - empty_trk_obj_ids: array of existing masklet object IDs with zero area in SAM2 prediction
1183
+ """
1184
+ iou_threshold = self.assoc_iou_thresh
1185
+ iou_threshold_trk = self.trk_assoc_iou_thresh
1186
+ new_det_thresh = self.new_det_thresh
1187
+
1188
+ assert det_masks.is_floating_point(), "float tensor expected (do not binarize)"
1189
+ assert trk_masks.is_floating_point(), "float tensor expected (do not binarize)"
1190
+ assert (
1191
+ trk_masks.size(0) == len(trk_obj_ids)
1192
+ ), f"trk_masks and trk_obj_ids should have the same length, {trk_masks.size(0)} vs {len(trk_obj_ids)}"
1193
+ if trk_masks.size(0) == 0:
1194
+ # all detections are new
1195
+ new_det_fa_inds = np.arange(det_masks.size(0))
1196
+ unmatched_trk_obj_ids = np.array([], np.int64)
1197
+ empty_trk_obj_ids = np.array([], np.int64)
1198
+ det_to_matched_trk_obj_ids = {}
1199
+ trk_id_to_max_iou_high_conf_det = {}
1200
+ return (
1201
+ new_det_fa_inds,
1202
+ unmatched_trk_obj_ids,
1203
+ det_to_matched_trk_obj_ids,
1204
+ trk_id_to_max_iou_high_conf_det,
1205
+ empty_trk_obj_ids,
1206
+ )
1207
+ elif det_masks.size(0) == 0:
1208
+ # all previous tracklets are unmatched if they have a non-zero area
1209
+ new_det_fa_inds = np.array([], np.int64)
1210
+ trk_is_nonempty = (trk_masks > 0).any(dim=(1, 2)).cpu().numpy()
1211
+ unmatched_trk_obj_ids = trk_obj_ids[trk_is_nonempty]
1212
+ empty_trk_obj_ids = trk_obj_ids[~trk_is_nonempty]
1213
+ det_to_matched_trk_obj_ids = {}
1214
+ trk_id_to_max_iou_high_conf_det = {}
1215
+ return (
1216
+ new_det_fa_inds,
1217
+ unmatched_trk_obj_ids,
1218
+ det_to_matched_trk_obj_ids,
1219
+ trk_id_to_max_iou_high_conf_det,
1220
+ empty_trk_obj_ids,
1221
+ )
1222
+
1223
+ if det_masks.shape[-2:] != trk_masks.shape[-2:]:
1224
+ # resize to the smaller size to save GPU memory
1225
+ if np.prod(det_masks.shape[-2:]) < np.prod(trk_masks.shape[-2:]):
1226
+ trk_masks = F.interpolate(
1227
+ trk_masks.unsqueeze(1),
1228
+ size=det_masks.shape[-2:],
1229
+ mode="bilinear",
1230
+ align_corners=False,
1231
+ ).squeeze(1)
1232
+ else:
1233
+ # resize detections to track size
1234
+ det_masks = F.interpolate(
1235
+ det_masks.unsqueeze(1),
1236
+ size=trk_masks.shape[-2:],
1237
+ mode="bilinear",
1238
+ align_corners=False,
1239
+ ).squeeze(1)
1240
+
1241
+ det_masks_binary = det_masks > 0
1242
+ trk_masks_binary = trk_masks > 0
1243
+ ious = mask_iou(det_masks_binary, trk_masks_binary) # (N, M)
1244
+
1245
+ ious_np = ious.cpu().numpy()
1246
+ if self.o2o_matching_masklets_enable:
1247
+ from scipy.optimize import linear_sum_assignment
1248
+
1249
+ # Hungarian matching for tracks (one-to-one: each track matches at most one detection)
1250
+ cost_matrix = 1 - ious_np # Hungarian solves for minimum cost
1251
+ row_ind, col_ind = linear_sum_assignment(cost_matrix)
1252
+ trk_is_matched = np.zeros(trk_masks.size(0), dtype=bool)
1253
+ for d, t in zip(row_ind, col_ind):
1254
+ if ious_np[d, t] >= iou_threshold_trk:
1255
+ trk_is_matched[t] = True
1256
+ else:
1257
+ trk_is_matched = (ious_np >= iou_threshold_trk).any(axis=0)
1258
+ # Non-empty tracks that were not matched above the threshold are considered unmatched
1259
+ trk_is_nonempty = trk_masks_binary.any(dim=(1, 2)).cpu().numpy()
1260
+ trk_is_unmatched = np.logical_and(trk_is_nonempty, ~trk_is_matched)
1261
+ unmatched_trk_obj_ids = trk_obj_ids[trk_is_unmatched]
1262
+ # also record masklets that have zero area in SAM 2 prediction
1263
+ empty_trk_obj_ids = trk_obj_ids[~trk_is_nonempty]
1264
+
1265
+ # For detections: allow many tracks to match to the same detection (many-to-one)
1266
+ # So, a detection is 'new' if it does not match any track above threshold
1267
+ is_new_det = np.logical_and(
1268
+ det_scores_np >= new_det_thresh,
1269
+ np.logical_not(np.any(ious_np >= iou_threshold, axis=1)),
1270
+ )
1271
+ new_det_fa_inds = np.nonzero(is_new_det)[0]
1272
+
1273
+ # for each detection, which tracks it matched to (above threshold)
1274
+ det_to_matched_trk_obj_ids = {}
1275
+ trk_id_to_max_iou_high_conf_det = {} # trk id --> exactly one detection idx
1276
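+ # Only detections that are both high-confidence and high-IoU with an existing tracklet are recorded here; they are later used to recondition (re-anchor) that tracklet.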
+ HIGH_CONF_THRESH = 0.8
1277
+ HIGH_IOU_THRESH = 0.8
1278
+ det_to_max_iou_trk_idx = np.argmax(ious_np, axis=1)
1279
+ det_is_high_conf = (det_scores_np >= HIGH_CONF_THRESH) & ~is_new_det
1280
+ det_is_high_iou = np.max(ious_np, axis=1) >= HIGH_IOU_THRESH
1281
+ det_is_high_conf_and_iou = set(
1282
+ np.nonzero(det_is_high_conf & det_is_high_iou)[0]
1283
+ )
1284
+ for d in range(det_masks.size(0)):
1285
+ det_to_matched_trk_obj_ids[d] = trk_obj_ids[ious_np[d, :] >= iou_threshold]
1286
+ if d in det_is_high_conf_and_iou:
1287
+ trk_obj_id = trk_obj_ids[det_to_max_iou_trk_idx[d]].item()
1288
+ trk_id_to_max_iou_high_conf_det[trk_obj_id] = d
1289
+
1290
+ return (
1291
+ new_det_fa_inds,
1292
+ unmatched_trk_obj_ids,
1293
+ det_to_matched_trk_obj_ids,
1294
+ trk_id_to_max_iou_high_conf_det,
1295
+ empty_trk_obj_ids,
1296
+ )
1297
+
1298
+ def _assign_new_det_to_gpus(self, new_det_num, prev_workload_per_gpu):
1299
+ """Distribute the new objects to the GPUs with the least workload."""
1300
+ workload_per_gpu: npt.NDArray = prev_workload_per_gpu.copy()
1301
+ new_det_gpu_ids = np.zeros(new_det_num, np.int64)
1302
+
1303
+ # assign the objects one by one
1304
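+ # e.g., prev_workload_per_gpu = [3, 1, 2] with 3 new objects -> new_det_gpu_ids = [1, 1, 2] (ties go to the lowest GPU id)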
+ for i in range(len(new_det_gpu_ids)):
1305
+ # find the GPU with the least workload
1306
+ min_gpu = np.argmin(workload_per_gpu)
1307
+ new_det_gpu_ids[i] = min_gpu
1308
+ workload_per_gpu[min_gpu] += 1
1309
+ return new_det_gpu_ids
1310
+
1311
+ def _process_hotstart(
1312
+ self,
1313
+ frame_idx: int,
1314
+ num_frames: int,
1315
+ reverse: bool,
1316
+ det_to_matched_trk_obj_ids: Dict[int, npt.NDArray],
1317
+ new_det_obj_ids: npt.NDArray,
1318
+ empty_trk_obj_ids: npt.NDArray,
1319
+ unmatched_trk_obj_ids: npt.NDArray,
1320
+ rank0_metadata: Dict[str, Any],
1321
+ tracker_metadata: Dict[str, Any],
1322
+ ):
1323
+ """Handle hotstart heuristics to remove unmatched or duplicated objects."""
1324
+ # obj_id --> first frame index where the object was detected
1325
+ obj_first_frame_idx = rank0_metadata["obj_first_frame_idx"]
1326
+ # obj_id --> [mismatched frame indices]
1327
+ unmatched_frame_inds = rank0_metadata["unmatched_frame_inds"]
1328
+ trk_keep_alive = rank0_metadata["trk_keep_alive"]
1329
+ # (first_appear_obj_id, obj_id) --> [overlap frame indices]
1330
+ overlap_pair_to_frame_inds = rank0_metadata["overlap_pair_to_frame_inds"]
1331
+ # removed_obj_ids: object IDs that are suppressed via hot-start
1332
+ removed_obj_ids = rank0_metadata["removed_obj_ids"]
1333
+ suppressed_obj_ids = rank0_metadata["suppressed_obj_ids"][frame_idx]
1334
+
1335
+ obj_ids_newly_removed = set() # object IDs to be newly removed on this frame
1336
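+ # Objects whose first appearance falls within the last `hotstart_delay` frames (relative to the tracking direction) are still inside their hotstart period and remain eligible for removal by the heuristics below.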
+ hotstart_diff = (
1337
+ frame_idx - self.hotstart_delay
1338
+ if not reverse
1339
+ else frame_idx + self.hotstart_delay
1340
+ )
1341
+
1342
+ # Step 1: log the frame index where each object ID first appears
1343
+ for obj_id in new_det_obj_ids:
1344
+ if obj_id not in obj_first_frame_idx:
1345
+ obj_first_frame_idx[obj_id] = frame_idx
1346
+ assert obj_id not in trk_keep_alive
1347
+ trk_keep_alive[obj_id] = self.init_trk_keep_alive
1348
+
1349
+ matched_trks = set()
1350
+ # We use the det-->tracks match lists to find matched objects; otherwise we would need to compute mask areas to decide whether they're occluded
1351
+ for matched_trks_per_det in det_to_matched_trk_obj_ids.values():
1352
+ matched_trks.update(matched_trks_per_det)
1353
+ for obj_id in matched_trks:
1354
+ # NOTE: To minimize number of configurable params, we use the hotstart_unmatch_thresh to set the max value of trk_keep_alive
1355
+ trk_keep_alive[obj_id] = min(
1356
+ self.max_trk_keep_alive, trk_keep_alive[obj_id] + 1
1357
+ )
1358
+ for obj_id in unmatched_trk_obj_ids:
1359
+ unmatched_frame_inds[obj_id].append(frame_idx)
1360
+ # NOTE: To minimize number of configurable params, we use the hotstart_unmatch_thresh to set the min value of trk_keep_alive
1361
+ # The max keep-alive is 2x the min, which means the model prefers to keep a prediction rather than suppress it if it has been matched for long enough.
1362
+ trk_keep_alive[obj_id] = max(
1363
+ self.min_trk_keep_alive, trk_keep_alive[obj_id] - 1
1364
+ )
1365
+ if self.decrease_trk_keep_alive_for_empty_masklets:
1366
+ for obj_id in empty_trk_obj_ids:
1367
+ # NOTE: To minimize number of configurable params, we use the hotstart_unmatch_thresh to set the min value of trk_keep_alive
1368
+ trk_keep_alive[obj_id] = max(
1369
+ self.min_trk_keep_alive, trk_keep_alive[obj_id] - 1
1370
+ )
1371
+
1372
+ # Step 2: remove tracks that have not matched any detection for `hotstart_unmatch_thresh` frames within the hotstart period
1373
+ # a) add unmatched frame indices for each existing object ID
1374
+ # note that an object enters `unmatched_trk_obj_ids` only on frames where its SAM2 output mask
1375
+ # doesn't match any detection; frames where SAM2 gives an empty mask are excluded
1376
+ # b) remove a masklet if it first appears after `hotstart_diff` and is unmatched for more
1377
+ # than `self.hotstart_unmatch_thresh` frames
1378
+ for obj_id, frame_indices in unmatched_frame_inds.items():
1379
+ if obj_id in removed_obj_ids or obj_id in obj_ids_newly_removed:
1380
+ continue # skip if the object is already removed
1381
+ if len(frame_indices) >= self.hotstart_unmatch_thresh:
1382
+ is_within_hotstart = (
1383
+ obj_first_frame_idx[obj_id] > hotstart_diff and not reverse
1384
+ ) or (obj_first_frame_idx[obj_id] < hotstart_diff and reverse)
1385
+ if is_within_hotstart:
1386
+ obj_ids_newly_removed.add(obj_id)
1387
+ logger.debug(
1388
+ f"Removing object {obj_id} at frame {frame_idx} "
1389
+ f"since it is unmatched for frames: {frame_indices}"
1390
+ )
1391
+ if (
1392
+ trk_keep_alive[obj_id] <= 0 # Object has not been matched for too long
1393
+ and not self.suppress_unmatched_only_within_hotstart
1394
+ and obj_id not in removed_obj_ids
1395
+ and obj_id not in obj_ids_newly_removed
1396
+ ):
1397
+ logger.debug(
1398
+ f"Suppressing object {obj_id} at frame {frame_idx}, due to being unmatched"
1399
+ )
1400
+ suppressed_obj_ids.add(obj_id)
1401
+
1402
+ # Step 3: remove tracks that overlap with another track for `hotstart_dup_thresh` frames
1403
+ # a) find overlapping tracks -- we consider two tracks overlapping if they match the same detection
1404
+ for _, matched_trk_obj_ids in det_to_matched_trk_obj_ids.items():
1405
+ if len(matched_trk_obj_ids) < 2:
1406
+ continue # only count detections that are matched to multiple (>=2) masklets
1407
+ # if there are multiple matched track ids, we need to find the one that appeared first;
1408
+ # these later appearing ids may be removed since they may be considered as duplicates
1409
+ first_appear_obj_id = (
1410
+ min(matched_trk_obj_ids, key=lambda x: obj_first_frame_idx[x])
1411
+ if not reverse
1412
+ else max(matched_trk_obj_ids, key=lambda x: obj_first_frame_idx[x])
1413
+ )
1414
+ for obj_id in matched_trk_obj_ids:
1415
+ if obj_id != first_appear_obj_id:
1416
+ key = (first_appear_obj_id, obj_id)
1417
+ overlap_pair_to_frame_inds[key].append(frame_idx)
1418
+
1419
+ # b) remove a masklet if it first appears after `hotstart_diff` and it overlaps with another
1420
+ # masklet (that appears earlier) for more than `self.hotstart_dup_thresh` frames
1421
+ for (first_obj_id, obj_id), frame_indices in overlap_pair_to_frame_inds.items():
1422
+ if obj_id in removed_obj_ids or obj_id in obj_ids_newly_removed:
1423
+ continue # skip if the object is already removed
1424
+ if (obj_first_frame_idx[obj_id] > hotstart_diff and not reverse) or (
1425
+ obj_first_frame_idx[obj_id] < hotstart_diff and reverse
1426
+ ):
1427
+ if len(frame_indices) >= self.hotstart_dup_thresh:
1428
+ obj_ids_newly_removed.add(obj_id)
1429
+ logger.debug(
1430
+ f"Removing object {obj_id} at frame {frame_idx} "
1431
+ f"since it overlaps with another track {first_obj_id} at frames: {frame_indices}"
1432
+ )
1433
+
1434
+ removed_obj_ids.update(obj_ids_newly_removed)
1435
+ return obj_ids_newly_removed, rank0_metadata
1436
+
1437
+ def _tracker_update_memories(
1438
+ self,
1439
+ tracker_inference_states: List[Any],
1440
+ frame_idx: int,
1441
+ tracker_metadata: Dict[str, Any],
1442
+ low_res_masks: Tensor,
1443
+ ):
1444
+ """
1445
+ Run Sam2 memory encoder, enforcing non-overlapping constraints globally.
1446
+ """
1447
+ if len(tracker_inference_states) == 0:
1448
+ return
1449
+ # Avoid an extra interpolation step by directly interpolating to `interpol_size`
1450
+ high_res_H, high_res_W = (
1451
+ self.tracker.maskmem_backbone.mask_downsampler.interpol_size
1452
+ )
1453
+ # NOTE: inspect this part if we observe OOMs in the demo
1454
+ high_res_masks = F.interpolate(
1455
+ low_res_masks.unsqueeze(1),
1456
+ size=(high_res_H, high_res_W),
1457
+ mode="bilinear",
1458
+ align_corners=False,
1459
+ )
1460
+ # We first apply non-overlapping constraints before memory encoding. This may include some suppression heuristics.
1461
+ if not hasattr(self, "_warm_up_complete") or self._warm_up_complete:
1462
+ high_res_masks = self.tracker._suppress_object_pw_area_shrinkage(
1463
+ high_res_masks
1464
+ )
1465
+ # Instead of gathering the predicted object scores, we use mask areas as a proxy.
1466
+ object_score_logits = torch.where(
1467
+ (high_res_masks > 0).any(dim=(-1, -2)), 10.0, -10.0
1468
+ )
1469
+
1470
+ # Run the memory encoder on local slices for each GPU
1471
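+ # `high_res_masks` concatenates the objects of all ranks in rank order, so this GPU's slice starts after the objects owned by lower ranks.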
+ start_idx_gpu = sum(tracker_metadata["num_obj_per_gpu"][: self.rank])
1472
+ start_idx_state = start_idx_gpu
1473
+ for tracker_state in tracker_inference_states:
1474
+ num_obj_per_state = len(tracker_state["obj_ids"])
1475
+ if num_obj_per_state == 0:
1476
+ continue
1477
+ # Get the local high-res masks and object score logits for this inference state
1478
+ end_idx_state = start_idx_state + num_obj_per_state
1479
+ local_high_res_masks = high_res_masks[start_idx_state:end_idx_state]
1480
+ local_object_score_logits = object_score_logits[
1481
+ start_idx_state:end_idx_state
1482
+ ]
1483
+ local_batch_size = local_high_res_masks.size(0)
1484
+ # Run Sam2 memory encoder. Note that we do not re-enforce the non-overlapping constraint as it is turned off by default
1485
+
1486
+ encoded_mem = self.tracker._run_memory_encoder(
1487
+ tracker_state,
1488
+ frame_idx,
1489
+ local_batch_size,
1490
+ local_high_res_masks,
1491
+ local_object_score_logits,
1492
+ is_mask_from_pts=False,
1493
+ )
1494
+ local_maskmem_features, local_maskmem_pos_enc = encoded_mem
1495
+ # Store encoded memories in the local inference state
1496
+ output_dict = tracker_state["output_dict"]
1497
+ for storage_key in ["cond_frame_outputs", "non_cond_frame_outputs"]:
1498
+ if frame_idx not in output_dict[storage_key]:
1499
+ continue
1500
+ output_dict[storage_key][frame_idx]["maskmem_features"] = (
1501
+ local_maskmem_features
1502
+ )
1503
+ output_dict[storage_key][frame_idx]["maskmem_pos_enc"] = [
1504
+ pos for pos in local_maskmem_pos_enc
1505
+ ]
1506
+ # for batched inference state, we also need to add per-object
1507
+ # memory slices to support instance interactivity
1508
+ self.tracker._add_output_per_object(
1509
+ inference_state=tracker_state,
1510
+ frame_idx=frame_idx,
1511
+ current_out=output_dict[storage_key][frame_idx],
1512
+ storage_key=storage_key,
1513
+ )
1514
+ start_idx_state += num_obj_per_state
1515
+
1516
+ def _tracker_add_new_objects(
1517
+ self,
1518
+ frame_idx: int,
1519
+ num_frames: int,
1520
+ new_obj_ids: List[int],
1521
+ new_obj_masks: Tensor,
1522
+ tracker_states_local: List[Any],
1523
+ orig_vid_height: int,
1524
+ orig_vid_width: int,
1525
+ feature_cache: Dict,
1526
+ ):
1527
+ """Add a new object to SAM2 inference states."""
1528
+ prev_tracker_state = (
1529
+ tracker_states_local[0] if len(tracker_states_local) > 0 else None
1530
+ )
1531
+
1532
+ # prepare inference_state
1533
+ # batch objects that first appear on the same frame together
1534
+ # Clear inference state. Keep the cached image features if available.
1535
+ new_tracker_state = self.tracker.init_state(
1536
+ cached_features=feature_cache,
1537
+ video_height=orig_vid_height,
1538
+ video_width=orig_vid_width,
1539
+ num_frames=num_frames,
1540
+ )
1541
+ new_tracker_state["backbone_out"] = (
1542
+ prev_tracker_state.get("backbone_out", None)
1543
+ if prev_tracker_state is not None
1544
+ else None
1545
+ )
1546
+
1547
+ assert len(new_obj_ids) == new_obj_masks.size(0)
1548
+ assert new_obj_masks.is_floating_point()
1549
+ input_mask_res = self.tracker.input_mask_size
1550
+ new_obj_masks = F.interpolate(
1551
+ new_obj_masks.unsqueeze(1),
1552
+ size=(input_mask_res, input_mask_res),
1553
+ mode="bilinear",
1554
+ align_corners=False,
1555
+ ).squeeze(1)
1556
+ new_obj_masks = new_obj_masks > 0
1557
+
1558
+ # add object one by one
1559
+ for new_obj_id, new_mask in zip(new_obj_ids, new_obj_masks):
1560
+ self.tracker.add_new_mask(
1561
+ inference_state=new_tracker_state,
1562
+ frame_idx=frame_idx,
1563
+ obj_id=new_obj_id,
1564
+ mask=new_mask,
1565
+ add_mask_to_memory=True,
1566
+ )
1567
+ # NOTE: we skip enforcing the non-overlapping constraint **globally** when adding new objects.
1568
+ self.tracker.propagate_in_video_preflight(
1569
+ new_tracker_state, run_mem_encoder=True
1570
+ )
1571
+ tracker_states_local.append(new_tracker_state)
1572
+ return tracker_states_local
1573
+
1574
+ def _tracker_remove_object(self, tracker_states_local: List[Any], obj_id: int):
1575
+ """
1576
+ Remove an object from SAM2 inference states. This would remove the object from
1577
+ all frames in the video.
1578
+ """
1579
+ tracker_states_local_before_removal = tracker_states_local.copy()
1580
+ tracker_states_local.clear()
1581
+ for tracker_inference_state in tracker_states_local_before_removal:
1582
+ # we try to remove `obj_id` on every inference state with `strict=False`
1583
+ # it will not do anything if an inference state doesn't contain `obj_id`
1584
+ new_obj_ids, _ = self.tracker.remove_object(
1585
+ tracker_inference_state, obj_id, strict=False, need_output=False
1586
+ )
1587
+ # only keep an inference state if it's non-empty after object removal
1588
+ if len(new_obj_ids) > 0:
1589
+ tracker_states_local.append(tracker_inference_state)
1590
+
1591
+ def _tracker_remove_objects(
1592
+ self, tracker_states_local: List[Any], obj_ids: list[int]
1593
+ ):
1594
+ """
1595
+ Remove multiple objects from SAM2 inference states. This removes each object from
1596
+ all frames in the video.
1597
+ """
1598
+ for obj_id in obj_ids:
1599
+ self._tracker_remove_object(tracker_states_local, obj_id)
1600
+
1601
+ def _initialize_metadata(self):
1602
+ """Initialize metadata for the masklets."""
1603
+ tracker_metadata = {
1604
+ "obj_ids_per_gpu": [np.array([], np.int64) for _ in range(self.world_size)],
1605
+ "obj_ids_all_gpu": np.array([], np.int64),
1606
+ "num_obj_per_gpu": np.zeros(self.world_size, np.int64),
1607
+ "max_obj_id": -1,
1608
+ "obj_id_to_score": {},
1609
+ "obj_id_to_tracker_score_frame_wise": defaultdict(dict),
1610
+ "obj_id_to_last_occluded": {},
1611
+ }
1612
+ if self.rank == 0:
1613
+ # "rank0_metadata" contains metadata that is only stored on (and accessible to) GPU 0
1614
+ # - obj_first_frame_idx: obj_id --> first frame index where the object was detected
1615
+ # - unmatched_frame_inds: obj_id --> [mismatched frame indices]
1616
+ # - overlap_pair_to_frame_inds: (first_appear_obj_id, obj_id) --> [overlap frame indices]
1617
+ # - removed_obj_ids: object IDs that are removed via the hot-start heuristics
1618
+ rank0_metadata = {
1619
+ "obj_first_frame_idx": {},
1620
+ "unmatched_frame_inds": defaultdict(list),
1621
+ "trk_keep_alive": defaultdict(
1622
+ int
1623
+ ), # This is used only for object suppression not for removal
1624
+ "overlap_pair_to_frame_inds": defaultdict(list),
1625
+ "removed_obj_ids": set(),
1626
+ "suppressed_obj_ids": defaultdict(
1627
+ set
1628
+ ), # frame_idx --> set of objects with suppressed outputs, but still continue to be tracked
1629
+ }
1630
+ if self.masklet_confirmation_enable:
1631
+ # all the following are npt.NDArray with the same shape as `obj_ids_all_gpu`
1632
+ rank0_metadata["masklet_confirmation"] = {
1633
+ # "status" is the confirmation status of each masklet (in `MaskletConfirmationStatus`)
1634
+ "status": np.array([], np.int64),
1635
+ # "consecutive_det_num" is the number of consecutive frames where the masklet is
1636
+ # detected by the detector (with a matched detection)
1637
+ "consecutive_det_num": np.array([], np.int64),
1638
+ }
1639
+ tracker_metadata["rank0_metadata"] = rank0_metadata
1640
+
1641
+ return tracker_metadata
1642
+
1643
+ def update_masklet_confirmation_status(
1644
+ self,
1645
+ rank0_metadata: Dict[str, Any],
1646
+ obj_ids_all_gpu_prev: npt.NDArray,
1647
+ obj_ids_all_gpu_updated: npt.NDArray,
1648
+ det_to_matched_trk_obj_ids: Dict[int, npt.NDArray],
1649
+ new_det_obj_ids: npt.NDArray,
1650
+ ):
1651
+ confirmation_data = rank0_metadata["masklet_confirmation"]
1652
+
1653
+ # a) first, expand "confirmation_data" to include new masklets added in this frame
1654
+ status_prev = confirmation_data["status"]
1655
+ consecutive_det_num_prev = confirmation_data["consecutive_det_num"]
1656
+ assert (
1657
+ status_prev.shape == obj_ids_all_gpu_prev.shape
1658
+ ), f"Got {status_prev.shape} vs {obj_ids_all_gpu_prev.shape}"
1659
+
1660
+ obj_id_to_updated_idx = {
1661
+ obj_id: idx for idx, obj_id in enumerate(obj_ids_all_gpu_updated)
1662
+ }
1663
+ prev_elem_is_in_updated = np.isin(obj_ids_all_gpu_prev, obj_ids_all_gpu_updated)
1664
+ prev_elem_obj_ids_in_updated = obj_ids_all_gpu_prev[prev_elem_is_in_updated]
1665
+ prev_elem_inds_in_updated = np.array(
1666
+ [obj_id_to_updated_idx[obj_id] for obj_id in prev_elem_obj_ids_in_updated],
1667
+ dtype=np.int64,
1668
+ )
1669
+ # newly added masklets are initialized to "UNCONFIRMED" status
1670
+ unconfirmed_val = MaskletConfirmationStatus.UNCONFIRMED.value
1671
+ status = np.full_like(obj_ids_all_gpu_updated, fill_value=unconfirmed_val)
1672
+ status[prev_elem_inds_in_updated] = status_prev[prev_elem_is_in_updated]
1673
+ consecutive_det_num = np.zeros_like(obj_ids_all_gpu_updated)
1674
+ consecutive_det_num[prev_elem_inds_in_updated] = consecutive_det_num_prev[
1675
+ prev_elem_is_in_updated
1676
+ ]
1677
+
1678
+ # b) update the confirmation status of all masklets based on the current frame
1679
+ # b.1) update "consecutive_det_num"
1680
+ # "is_matched": whether a masklet is matched to a detection on this frame
1681
+ is_matched = np.isin(obj_ids_all_gpu_updated, new_det_obj_ids)
1682
+ for matched_trk_obj_ids in det_to_matched_trk_obj_ids.values():
1683
+ is_matched |= np.isin(obj_ids_all_gpu_updated, matched_trk_obj_ids)
1684
+ consecutive_det_num = np.where(is_matched, consecutive_det_num + 1, 0)
1685
+
1686
+ # b.2) update "status"
1687
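+ # A masklet flips to CONFIRMED once it has been matched to a detection on `masklet_confirmation_consecutive_det_thresh` consecutive frames.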
+ change_to_confirmed = (
1688
+ consecutive_det_num >= self.masklet_confirmation_consecutive_det_thresh
1689
+ )
1690
+ status[change_to_confirmed] = MaskletConfirmationStatus.CONFIRMED.value
1691
+
1692
+ confirmation_data["status"] = status
1693
+ confirmation_data["consecutive_det_num"] = consecutive_det_num
1694
+ return rank0_metadata
1695
+
1696
+ def forward(self, input: BatchedDatapoint, is_inference: bool = False):
1697
+ raise NotImplementedError("Evaluation outside demo is not implemented yet")
1698
+
1699
+ def _load_checkpoint(self, ckpt_path: str, strict: bool = True):
1700
+ sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"]
1701
+ missing_keys, unexpected_keys = self.load_state_dict(sd, strict=strict)
1702
+ if len(missing_keys) > 0 or len(unexpected_keys) > 0:
1703
+ logger.warning(f"Loaded ckpt with {missing_keys=}, {unexpected_keys=}")
1704
+ else:
1705
+ logger.info("Loaded ckpt successfully without missing or unexpected keys")
1706
+
1707
+ def prep_for_evaluator(self, video_frames, tracking_res, scores_labels):
1708
+ """This method is only used for benchmark eval (not used in the demo)."""
1709
+ num_frames = len(video_frames)
1710
+ w, h = video_frames[0].size
1711
+ zero_mask = torch.zeros((1, h, w), dtype=torch.bool)
1712
+ object_ids = list(scores_labels.keys())
1713
+ preds = {"scores": [], "labels": [], "boxes": [], "masks_rle": []}
1714
+ for oid in object_ids:
1715
+ o_masks = []
1716
+ o_score = scores_labels[oid][0].item()
1717
+ o_label = scores_labels[oid][1]
1718
+ for frame_idx in range(num_frames):
1719
+ if frame_idx not in tracking_res:
1720
+ o_masks.append(zero_mask)
1721
+ else:
1722
+ o_masks.append(tracking_res[frame_idx].get(oid, zero_mask))
1723
+
1724
+ o_masks = torch.cat(o_masks, dim=0) # (n_frames, H, W)
1725
+ preds["scores"].append(o_score)
1726
+ preds["labels"].append(o_label)
1727
+ preds["boxes"].append(mask_to_box(o_masks.unsqueeze(1)).squeeze())
1728
+ preds["masks_rle"].append(rle_encode(o_masks, return_areas=True))
1729
+
1730
+ preds["boxes"] = (
1731
+ torch.stack(preds["boxes"], dim=0)
1732
+ if len(preds["boxes"]) > 0
1733
+ else torch.empty(
1734
+ (0, num_frames, 4), dtype=torch.float32, device=self.device
1735
+ )
1736
+ )
1737
+ preds["scores"] = (
1738
+ torch.tensor(preds["scores"], device=self.device)
1739
+ if len(preds["scores"]) > 0
1740
+ else torch.empty((0,), device=self.device)
1741
+ )
1742
+ preds["per_frame_scores"] = preds["scores"]
1743
+ preds["labels"] = (
1744
+ torch.tensor(preds["labels"], device=self.device)
1745
+ if len(preds["labels"]) > 0
1746
+ else torch.empty((0,), device=self.device)
1747
+ )
1748
+ return preds
1749
+
1750
+ def _encode_prompt(self, **kwargs):
1751
+ return self.detector._encode_prompt(**kwargs)
1752
+
1753
+ def _drop_new_det_with_obj_limit(self, new_det_fa_inds, det_scores_np, num_to_keep):
1754
+ """
1755
+ Drop a few new detections based on the maximum number of objects. We drop new objects based
1756
+ on their detection scores, keeping the high-scoring ones and dropping the low-scoring ones.
1757
+ """
1758
+ assert 0 <= num_to_keep <= len(new_det_fa_inds)
1759
+ if num_to_keep == 0:
1760
+ return np.array([], np.int64) # keep none
1761
+ if num_to_keep == len(new_det_fa_inds):
1762
+ return new_det_fa_inds # keep all
1763
+
1764
+ # keep the top-scoring detections
1765
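+ # e.g., new_det_fa_inds = [10, 11, 12] with scores [0.9, 0.2, 0.7] and num_to_keep = 2 keeps indices [10, 12]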
+ score_order = np.argsort(det_scores_np[new_det_fa_inds])[::-1]
1766
+ new_det_fa_inds = new_det_fa_inds[score_order[:num_to_keep]]
1767
+ return new_det_fa_inds
detect_tools/sam3/sam3/model/sam3_video_inference.py ADDED
@@ -0,0 +1,1709 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ import logging
4
+ from collections import defaultdict
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.distributed as dist
9
+ import torch.nn.functional as F
10
+
11
+ from sam3 import perflib
12
+ from sam3.logger import get_logger
13
+ from sam3.model.act_ckpt_utils import clone_output_wrapper
14
+ from sam3.model.box_ops import box_xywh_to_cxcywh, box_xyxy_to_xywh
15
+ from sam3.model.data_misc import BatchedDatapoint, convert_my_tensors, FindStage
16
+ from sam3.model.geometry_encoders import Prompt
17
+ from sam3.model.io_utils import IMAGE_EXTS, load_resource_as_video_frames
18
+ from sam3.model.sam3_tracker_utils import fill_holes_in_mask_scores
19
+ from sam3.model.sam3_video_base import MaskletConfirmationStatus, Sam3VideoBase
20
+ from sam3.model.utils.misc import copy_data_to_device
21
+ from sam3.perflib.compile import compile_wrapper, shape_logging_wrapper
22
+ from sam3.perflib.masks_ops import masks_to_boxes as perf_masks_to_boxes
23
+ from torchvision.ops import masks_to_boxes
24
+ from tqdm.auto import tqdm
25
+
26
+ logger = get_logger(__name__)
27
+
28
+
29
+ class Sam3VideoInference(Sam3VideoBase):
30
+ TEXT_ID_FOR_TEXT = 0
31
+ TEXT_ID_FOR_VISUAL = 1
32
+
33
+ def __init__(
34
+ self,
35
+ image_size=1008,
36
+ image_mean=(0.5, 0.5, 0.5),
37
+ image_std=(0.5, 0.5, 0.5),
38
+ compile_model=False,
39
+ **kwargs,
40
+ ):
41
+ """
42
+ hotstart_delay: int, the delay (in #frames) before the model starts to yield output, 0 to disable hotstart delay.
43
+ hotstart_unmatch_thresh: int, remove the object if it has this many unmatched frames within its hotstart_delay period.
44
+ If `hotstart_delay` is set to 0, this parameter is ignored.
45
+ hotstart_dup_thresh: int, remove the object if it has overlapped with another object this many frames within its hotstart_delay period.
46
+ """
47
+ super().__init__(**kwargs)
48
+ self.image_size = image_size
49
+ self.image_mean = image_mean
50
+ self.image_std = image_std
51
+ self.compile_model = compile_model
52
+
53
+ @torch.inference_mode()
54
+ def init_state(
55
+ self,
56
+ resource_path,
57
+ offload_video_to_cpu=False,
58
+ async_loading_frames=False,
59
+ video_loader_type="cv2",
60
+ ):
61
+ """Initialize an inference state from `resource_path` (an image or a video)."""
62
+ images, orig_height, orig_width = load_resource_as_video_frames(
63
+ resource_path=resource_path,
64
+ image_size=self.image_size,
65
+ offload_video_to_cpu=offload_video_to_cpu,
66
+ img_mean=self.image_mean,
67
+ img_std=self.image_std,
68
+ async_loading_frames=async_loading_frames,
69
+ video_loader_type=video_loader_type,
70
+ )
71
+ inference_state = {}
72
+ inference_state["image_size"] = self.image_size
73
+ inference_state["num_frames"] = len(images)
74
+ # the original video height and width, used for resizing final output scores
75
+ inference_state["orig_height"] = orig_height
76
+ inference_state["orig_width"] = orig_width
77
+ # values that don't change across frames (so we only need to hold one copy of them)
78
+ inference_state["constants"] = {}
79
+ # inputs on each frame
80
+ self._construct_initial_input_batch(inference_state, images)
81
+ # initialize extra states
82
+ inference_state["tracker_inference_states"] = []
83
+ inference_state["tracker_metadata"] = {}
84
+ inference_state["feature_cache"] = {}
85
+ inference_state["cached_frame_outputs"] = {}
86
+ inference_state["action_history"] = [] # for logging user actions
87
+ inference_state["is_image_only"] = is_image_type(resource_path)
88
+ return inference_state
89
+
90
+ @torch.inference_mode()
91
+ def reset_state(self, inference_state):
92
+ """Revert `inference_state` to what it was right after initialization."""
93
+ inference_state["input_batch"].find_text_batch[0] = "<text placeholder>"
94
+ inference_state["text_prompt"] = None
95
+ for t in range(inference_state["num_frames"]):
96
+ inference_state["input_batch"].find_inputs[t].text_ids[...] = 0
97
+ # constructing an output list in inference state (we start with an empty list)
98
+ inference_state["previous_stages_out"][t] = None
99
+ inference_state["per_frame_raw_point_input"][t] = None
100
+ inference_state["per_frame_raw_box_input"][t] = None
101
+ inference_state["per_frame_visual_prompt"][t] = None
102
+ inference_state["per_frame_geometric_prompt"][t] = None
103
+ inference_state["per_frame_cur_step"][t] = 0
104
+
105
+ inference_state["visual_prompt_embed"] = None
106
+ inference_state["visual_prompt_mask"] = None
107
+ inference_state["tracker_inference_states"].clear()
108
+ inference_state["tracker_metadata"].clear()
109
+ inference_state["feature_cache"].clear()
110
+ inference_state["cached_frame_outputs"].clear()
111
+ inference_state["action_history"].clear() # for logging user actions
112
+
113
+ def _construct_initial_input_batch(self, inference_state, images):
114
+ """Construct an initial `BatchedDatapoint` instance as input."""
115
+ # 1) img_batch
116
+ num_frames = len(images)
117
+ device = self.device
118
+
119
+ # 2) find_text_batch
120
+ # "<text placeholder>" will be replaced by the actual text prompt when adding prompts
121
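+ # index 0 (TEXT_ID_FOR_TEXT) is the slot that gets replaced by the user's text prompt; index 1 (TEXT_ID_FOR_VISUAL) is used when the prompt is a visual exemplar rather than text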
+ find_text_batch = ["<text placeholder>", "visual"]
122
+
123
+ # 3) find_inputs
124
+ input_box_embedding_dim = 258 # historical default
125
+ input_points_embedding_dim = 257 # historical default
126
+ stages = [
127
+ FindStage(
128
+ img_ids=[stage_id],
129
+ text_ids=[0],
130
+ input_boxes=[torch.zeros(input_box_embedding_dim)],
131
+ input_boxes_mask=[torch.empty(0, dtype=torch.bool)],
132
+ input_boxes_label=[torch.empty(0, dtype=torch.long)],
133
+ input_points=[torch.empty(0, input_points_embedding_dim)],
134
+ input_points_mask=[torch.empty(0)],
135
+ object_ids=[],
136
+ )
137
+ for stage_id in range(num_frames)
138
+ ]
139
+ for i in range(len(stages)):
140
+ stages[i] = convert_my_tensors(stages[i])
141
+
142
+ # construct the final `BatchedDatapoint` and cast to GPU
143
+ input_batch = BatchedDatapoint(
144
+ img_batch=images,
145
+ find_text_batch=find_text_batch,
146
+ find_inputs=stages,
147
+ find_targets=[None] * num_frames,
148
+ find_metadatas=[None] * num_frames,
149
+ )
150
+ input_batch = copy_data_to_device(input_batch, device, non_blocking=True)
151
+ inference_state["input_batch"] = input_batch
152
+
153
+ # construct the placeholder interactive prompts and tracking queries
154
+ bs = 1
155
+ inference_state["constants"]["empty_geometric_prompt"] = Prompt(
156
+ box_embeddings=torch.zeros(0, bs, 4, device=device),
157
+ box_mask=torch.zeros(bs, 0, device=device, dtype=torch.bool),
158
+ box_labels=torch.zeros(0, bs, device=device, dtype=torch.long),
159
+ point_embeddings=torch.zeros(0, bs, 2, device=device),
160
+ point_mask=torch.zeros(bs, 0, device=device, dtype=torch.bool),
161
+ point_labels=torch.zeros(0, bs, device=device, dtype=torch.long),
162
+ )
163
+
164
+ # constructing an output list in inference state (we start with an empty list)
165
+ inference_state["previous_stages_out"] = [None] * num_frames
166
+ inference_state["text_prompt"] = None
167
+ inference_state["per_frame_raw_point_input"] = [None] * num_frames
168
+ inference_state["per_frame_raw_box_input"] = [None] * num_frames
169
+ inference_state["per_frame_visual_prompt"] = [None] * num_frames
170
+ inference_state["per_frame_geometric_prompt"] = [None] * num_frames
171
+ inference_state["per_frame_cur_step"] = [0] * num_frames
172
+
173
+ # placeholders for cached outputs
174
+ # (note: currently, a single visual prompt embedding is shared for all frames)
175
+ inference_state["visual_prompt_embed"] = None
176
+ inference_state["visual_prompt_mask"] = None
177
+
178
+ def _get_visual_prompt(self, inference_state, frame_idx, boxes_cxcywh, box_labels):
179
+ """
180
+ Handle the case of visual prompt. Currently, in the inference API we do not
181
+ explicitly distinguish between initial box as visual prompt vs subsequent boxes
182
+ or boxes after inference for refinement.
183
+ """
184
+ # If the frame hasn't had any inference results before (prompting or propagation),
185
+ # we treat the first added box prompt as a visual prompt; otherwise, we treat
186
+ # the first box just as a refinement prompt.
187
+ is_new_visual_prompt = (
188
+ inference_state["per_frame_visual_prompt"][frame_idx] is None
189
+ and inference_state["previous_stages_out"][frame_idx] is None
190
+ )
191
+ if is_new_visual_prompt:
192
+ if boxes_cxcywh.size(0) != 1:
193
+ raise RuntimeError(
194
+ "visual prompts (box as an initial prompt) should only have one box, "
195
+ f"but got {boxes_cxcywh.shape=}"
196
+ )
197
+ if not box_labels.item():
198
+ logging.warning("A negative box is added as a visual prompt.")
199
+ # take the first box prompt as a visual prompt
200
+ device = self.device
201
+ new_visual_prompt = Prompt(
202
+ box_embeddings=boxes_cxcywh[None, 0:1, :].to(device), # (seq, bs, 4)
203
+ box_mask=None,
204
+ box_labels=box_labels[None, 0:1].to(device), # (seq, bs)
205
+ point_embeddings=None,
206
+ point_mask=None,
207
+ point_labels=None,
208
+ )
209
+ inference_state["per_frame_visual_prompt"][frame_idx] = new_visual_prompt
210
+ else:
211
+ new_visual_prompt = None
212
+
213
+ # `boxes_cxcywh` and `box_labels` contains all the raw box inputs added so far
214
+ # strip any visual prompt from the input boxes (for geometric prompt encoding)
215
+ if inference_state["per_frame_visual_prompt"][frame_idx] is not None:
216
+ boxes_cxcywh = boxes_cxcywh[1:]
217
+ box_labels = box_labels[1:]
218
+
219
+ return boxes_cxcywh, box_labels, new_visual_prompt
220
+
221
+ def _get_processing_order(
222
+ self, inference_state, start_frame_idx, max_frame_num_to_track, reverse
223
+ ):
224
+ num_frames = inference_state["num_frames"]
225
+ previous_stages_out = inference_state["previous_stages_out"]
226
+ if all(out is None for out in previous_stages_out) and start_frame_idx is None:
227
+ raise RuntimeError(
228
+ "No prompts are received on any frames. Please add prompt on at least one frame before propagation."
229
+ )
230
+ # set start index, end index, and processing order
231
+ if start_frame_idx is None:
232
+ # default: start from the earliest frame with input points
233
+ start_frame_idx = min(
234
+ t for t, out in enumerate(previous_stages_out) if out is not None
235
+ )
236
+ if max_frame_num_to_track is None:
237
+ # default: track all the frames in the video
238
+ max_frame_num_to_track = num_frames
239
+ if reverse:
240
+ end_frame_idx = start_frame_idx - max_frame_num_to_track
241
+ end_frame_idx = max(end_frame_idx, 0)
242
+ processing_order = range(start_frame_idx - 1, end_frame_idx - 1, -1)
243
+ else:
244
+ end_frame_idx = start_frame_idx + max_frame_num_to_track
245
+ end_frame_idx = min(end_frame_idx, num_frames - 1)
246
+ processing_order = range(start_frame_idx, end_frame_idx + 1)
247
+ return processing_order, end_frame_idx
248
+
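For reference, a minimal standalone sketch of the index arithmetic in `_get_processing_order` above (a hypothetical `processing_order` helper, not part of this diff; the real method also validates that prompts exist in `inference_state`):

```python
# Hypothetical sketch of the frame-ordering arithmetic used above.
def processing_order(start_frame_idx, max_frame_num_to_track, num_frames, reverse):
    if reverse:
        end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0)
        return list(range(start_frame_idx - 1, end_frame_idx - 1, -1)), end_frame_idx
    end_frame_idx = min(start_frame_idx + max_frame_num_to_track, num_frames - 1)
    return list(range(start_frame_idx, end_frame_idx + 1)), end_frame_idx

# starting at frame 10 of a 30-frame clip, tracking at most 5 frames:
print(processing_order(10, 5, 30, reverse=False))  # ([10, 11, 12, 13, 14, 15], 15)
print(processing_order(10, 5, 30, reverse=True))   # ([9, 8, 7, 6, 5], 5)
```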
249
+ @torch.inference_mode()
250
+ def propagate_in_video(
251
+ self,
252
+ inference_state,
253
+ start_frame_idx=None,
254
+ max_frame_num_to_track=None,
255
+ reverse=False,
256
+ ):
257
+ """
258
+ Propagate the prompts to get grounding results for the entire video. This method
259
+ is a generator and yields inference outputs for all frames in the range specified
260
+ by `start_frame_idx`, `max_frame_num_to_track`, and `reverse`.
261
+ """
262
+ # compile the model (it's a no-op if the model is already compiled)
263
+ # note that it's intentionally added to `self.propagate_in_video`, so that the first
264
+ # `self.add_prompt` call will be done in eager mode to fill in the decoder buffers
265
+ # (such as the positional encoding cache)
266
+ self._compile_model()
267
+
268
+ processing_order, end_frame_idx = self._get_processing_order(
269
+ inference_state,
270
+ start_frame_idx,
271
+ max_frame_num_to_track,
272
+ reverse=reverse,
273
+ )
274
+
275
+ # Store max_frame_num_to_track in feature_cache for downstream methods
276
+ inference_state["feature_cache"]["tracking_bounds"] = {
277
+ "max_frame_num_to_track": max_frame_num_to_track,
278
+ "propagate_in_video_start_frame_idx": start_frame_idx,
279
+ }
280
+
281
+ hotstart_buffer = []
282
+ hotstart_removed_obj_ids = set()
283
+ # when deciding whether to output a masklet on `yield_frame_idx`, we check whether the object is confirmed
284
+ # in a future frame (`unconfirmed_status_delay` frames after the current frame). For example, if we require
285
+ # an object to be detected in 3 consecutive frames to be confirmed, then we look 2 frames in the future --
286
+ # e.g., we output an object on frame 4 only if it becomes confirmed on frame 6.
287
+ unconfirmed_status_delay = self.masklet_confirmation_consecutive_det_thresh - 1
288
+ unconfirmed_obj_ids_per_frame = {} # frame_idx -> hidden_obj_ids
289
+ for frame_idx in tqdm(
290
+ processing_order, desc="propagate_in_video", disable=self.rank > 0
291
+ ):
292
+ out = self._run_single_frame_inference(inference_state, frame_idx, reverse)
293
+
294
+ if self.hotstart_delay > 0:
295
+ # accumulate the outputs for the first `hotstart_delay` frames
296
+ hotstart_buffer.append([frame_idx, out])
297
+ # update the object IDs removed by hotstart so that we don't output them
298
+ if self.rank == 0:
299
+ hotstart_removed_obj_ids.update(out["removed_obj_ids"])
300
+ unconfirmed_obj_ids = out.get("unconfirmed_obj_ids", None)
301
+ if unconfirmed_obj_ids is not None:
302
+ unconfirmed_obj_ids_per_frame[frame_idx] = unconfirmed_obj_ids
303
+
304
+ if frame_idx == end_frame_idx:
305
+ # we reached the end of propagation -- yield all frames in the buffer
306
+ yield_list = hotstart_buffer
307
+ hotstart_buffer = []
308
+ elif len(hotstart_buffer) >= self.hotstart_delay:
309
+ # we have enough frames -- yield and remove the first (oldest) frame from the buffer
310
+ yield_list = hotstart_buffer[:1]
311
+ hotstart_buffer = hotstart_buffer[1:]
312
+ else:
313
+ # not enough frames yet -- skip yielding
314
+ yield_list = []
315
+ else:
316
+ yield_list = [(frame_idx, out)] # output the current frame
317
+
318
+ for yield_frame_idx, yield_out in yield_list:
319
+ # post-process the output and yield it
320
+ if self.rank == 0:
321
+ suppressed_obj_ids = yield_out["suppressed_obj_ids"]
322
+ unconfirmed_status_frame_idx = (
323
+ yield_frame_idx + unconfirmed_status_delay
324
+ if not reverse
325
+ else yield_frame_idx - unconfirmed_status_delay
326
+ )
327
+
328
+ # Clamp the frame index to stay within video bounds
329
+ num_frames = inference_state["num_frames"]
330
+ unconfirmed_status_frame_idx = max(
331
+ 0, min(unconfirmed_status_frame_idx, num_frames - 1)
332
+ )
333
+
334
+ unconfirmed_obj_ids = unconfirmed_obj_ids_per_frame.get(
335
+ unconfirmed_status_frame_idx, None
336
+ )
337
+ postprocessed_out = self._postprocess_output(
338
+ inference_state,
339
+ yield_out,
340
+ hotstart_removed_obj_ids,
341
+ suppressed_obj_ids,
342
+ unconfirmed_obj_ids,
343
+ )
344
+
345
+ self._cache_frame_outputs(
346
+ inference_state,
347
+ yield_frame_idx,
348
+ yield_out["obj_id_to_mask"],
349
+ suppressed_obj_ids=suppressed_obj_ids,
350
+ removed_obj_ids=hotstart_removed_obj_ids,
351
+ unconfirmed_obj_ids=unconfirmed_obj_ids,
352
+ )
353
+ else:
354
+ postprocessed_out = None # no output on other GPUs
355
+ yield yield_frame_idx, postprocessed_out
356
+
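A hedged usage sketch of the prompting-plus-propagation loop above; `model` (a constructed `Sam3VideoInference`-style instance) and the video path are assumptions, and only rank 0 yields post-processed outputs:

```python
import torch

# Hypothetical end-to-end call pattern; `model` and "video.mp4" are assumptions.
with torch.inference_mode():
    state = model.init_state(resource_path="video.mp4")
    model.add_prompt(state, frame_idx=0, text_str="cat")

    masks_per_frame = {}
    for frame_idx, out in model.propagate_in_video(state, start_frame_idx=0):
        if out is not None:  # ranks other than 0 yield None
            masks_per_frame[frame_idx] = out["out_binary_masks"]  # (N, H, W) numpy bool masks
    model.reset_state(state)
```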
357
+ def _run_single_frame_inference(self, inference_state, frame_idx, reverse):
358
+ """
359
+ Perform inference on a single frame and get its inference results. This would
360
+ also update `inference_state`.
361
+ """
362
+ # prepare inputs
363
+ input_batch = inference_state["input_batch"]
364
+ tracker_states_local = inference_state["tracker_inference_states"]
365
+ has_text_prompt = inference_state["text_prompt"] is not None
366
+ has_geometric_prompt = (
367
+ inference_state["per_frame_geometric_prompt"][frame_idx] is not None
368
+ )
369
+ # run inference for the current frame
370
+ (
371
+ obj_id_to_mask,
372
+ obj_id_to_score,
373
+ tracker_states_local_new,
374
+ tracker_metadata_new,
375
+ frame_stats,
376
+ _,
377
+ ) = self._det_track_one_frame(
378
+ frame_idx=frame_idx,
379
+ num_frames=inference_state["num_frames"],
380
+ reverse=reverse,
381
+ input_batch=input_batch,
382
+ geometric_prompt=(
383
+ inference_state["constants"]["empty_geometric_prompt"]
384
+ if not has_geometric_prompt
385
+ else inference_state["per_frame_geometric_prompt"][frame_idx]
386
+ ),
387
+ tracker_states_local=tracker_states_local,
388
+ tracker_metadata_prev=inference_state["tracker_metadata"],
389
+ feature_cache=inference_state["feature_cache"],
390
+ orig_vid_height=inference_state["orig_height"],
391
+ orig_vid_width=inference_state["orig_width"],
392
+ is_image_only=inference_state["is_image_only"],
393
+ allow_new_detections=has_text_prompt or has_geometric_prompt,
394
+ )
395
+ # update inference state
396
+ inference_state["tracker_inference_states"] = tracker_states_local_new
397
+ inference_state["tracker_metadata"] = tracker_metadata_new
398
+ # use a dummy string in "previous_stages_out" to indicate this frame has outputs
399
+ inference_state["previous_stages_out"][frame_idx] = "_THIS_FRAME_HAS_OUTPUTS_"
400
+
401
+ if self.rank == 0:
402
+ self._cache_frame_outputs(inference_state, frame_idx, obj_id_to_mask)
403
+
404
+ out = {
405
+ "obj_id_to_mask": obj_id_to_mask,
406
+ "obj_id_to_score": obj_id_to_score, # first frame detection score
407
+ "obj_id_to_tracker_score": tracker_metadata_new[
408
+ "obj_id_to_tracker_score_frame_wise"
409
+ ][frame_idx],
410
+ }
411
+ # removed_obj_ids is only needed on rank 0 to handle hotstart delay buffer
412
+ if self.rank == 0:
413
+ rank0_metadata = tracker_metadata_new["rank0_metadata"]
414
+ removed_obj_ids = rank0_metadata["removed_obj_ids"]
415
+ out["removed_obj_ids"] = removed_obj_ids
416
+ out["suppressed_obj_ids"] = rank0_metadata["suppressed_obj_ids"][frame_idx]
417
+ out["frame_stats"] = frame_stats
418
+ if self.masklet_confirmation_enable:
419
+ status = rank0_metadata["masklet_confirmation"]["status"]
420
+ is_unconfirmed = status == MaskletConfirmationStatus.UNCONFIRMED.value
421
+ out["unconfirmed_obj_ids"] = tracker_metadata_new["obj_ids_all_gpu"][
422
+ is_unconfirmed
423
+ ].tolist()
424
+ else:
425
+ out["unconfirmed_obj_ids"] = []
426
+
427
+ return out
428
+
429
+ def _postprocess_output(
430
+ self,
431
+ inference_state,
432
+ out,
433
+ removed_obj_ids=None,
434
+ suppressed_obj_ids=None,
435
+ unconfirmed_obj_ids=None,
436
+ ):
437
+ obj_id_to_mask = out["obj_id_to_mask"] # low res masks
438
+ curr_obj_ids = sorted(obj_id_to_mask.keys())
439
+ H_video, W_video = inference_state["orig_height"], inference_state["orig_width"]
440
+ if len(curr_obj_ids) == 0:
441
+ out_obj_ids = torch.zeros(0, dtype=torch.int64)
442
+ out_probs = torch.zeros(0, dtype=torch.float32)
443
+ out_binary_masks = torch.zeros(0, H_video, W_video, dtype=torch.bool)
444
+ out_boxes_xywh = torch.zeros(0, 4, dtype=torch.float32)
445
+ else:
446
+ out_obj_ids = torch.tensor(curr_obj_ids, dtype=torch.int64)
447
+ out_probs = torch.tensor(
448
+ [out["obj_id_to_score"][obj_id] for obj_id in curr_obj_ids]
449
+ )
450
+ out_tracker_probs = torch.tensor(
451
+ [
452
+ (
453
+ out["obj_id_to_tracker_score"][obj_id]
454
+ if obj_id in out["obj_id_to_tracker_score"]
455
+ else 0.0
456
+ )
457
+ for obj_id in curr_obj_ids
458
+ ]
459
+ )
460
+ out_binary_masks = torch.cat(
461
+ [obj_id_to_mask[obj_id] for obj_id in curr_obj_ids], dim=0
462
+ )
463
+
464
+ assert out_binary_masks.dtype == torch.bool
465
+ keep = out_binary_masks.any(dim=(1, 2)).cpu() # remove masks with 0 areas
466
+ # hide outputs for those object IDs in `obj_ids_to_hide`
467
+ obj_ids_to_hide = []
468
+ if suppressed_obj_ids is not None:
469
+ obj_ids_to_hide.extend(suppressed_obj_ids)
470
+ if removed_obj_ids is not None:
471
+ obj_ids_to_hide.extend(removed_obj_ids)
472
+ if unconfirmed_obj_ids is not None:
473
+ obj_ids_to_hide.extend(unconfirmed_obj_ids)
474
+ if len(obj_ids_to_hide) > 0:
475
+ obj_ids_to_hide_t = torch.tensor(obj_ids_to_hide, dtype=torch.int64)
476
+ keep &= ~torch.isin(out_obj_ids, obj_ids_to_hide_t)
477
+
478
+ # slice those valid entries from the original outputs
479
+ keep_idx = torch.nonzero(keep, as_tuple=True)[0]
480
+ keep_idx_gpu = keep_idx.pin_memory().to(
481
+ device=out_binary_masks.device, non_blocking=True
482
+ )
483
+
484
+ out_obj_ids = torch.index_select(out_obj_ids, 0, keep_idx)
485
+ out_probs = torch.index_select(out_probs, 0, keep_idx)
486
+ out_tracker_probs = torch.index_select(out_tracker_probs, 0, keep_idx)
487
+ out_binary_masks = torch.index_select(out_binary_masks, 0, keep_idx_gpu)
488
+
489
+ if perflib.is_enabled:
490
+ out_boxes_xyxy = perf_masks_to_boxes(
491
+ out_binary_masks, out_obj_ids.tolist()
492
+ )
493
+ else:
494
+ out_boxes_xyxy = masks_to_boxes(out_binary_masks)
495
+
496
+ out_boxes_xywh = box_xyxy_to_xywh(out_boxes_xyxy) # convert to xywh format
497
+ # normalize boxes
498
+ out_boxes_xywh[..., 0] /= W_video
499
+ out_boxes_xywh[..., 1] /= H_video
500
+ out_boxes_xywh[..., 2] /= W_video
501
+ out_boxes_xywh[..., 3] /= H_video
502
+
503
+ # apply non-overlapping constraints on the existing masklets
504
+ if out_binary_masks.shape[0] > 1:
505
+ assert len(out_binary_masks) == len(out_tracker_probs)
506
+ out_binary_masks = (
507
+ self.tracker._apply_object_wise_non_overlapping_constraints(
508
+ out_binary_masks.unsqueeze(1),
509
+ out_tracker_probs.unsqueeze(1).to(out_binary_masks.device),
510
+ background_value=0,
511
+ ).squeeze(1)
512
+ ) > 0
513
+
514
+ outputs = {
515
+ "out_obj_ids": out_obj_ids.cpu().numpy(),
516
+ "out_probs": out_probs.cpu().numpy(),
517
+ "out_boxes_xywh": out_boxes_xywh.cpu().numpy(),
518
+ "out_binary_masks": out_binary_masks.cpu().numpy(),
519
+ "frame_stats": out.get("frame_stats", None),
520
+ }
521
+ return outputs
522
+
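A standalone sketch of the mask-to-normalized-box step in `_postprocess_output` above, assuming the `masks_to_boxes` used there behaves like `torchvision.ops.masks_to_boxes` (the repo may use its own helper):

```python
import torch
from torchvision.ops import masks_to_boxes  # assumption: same semantics as the helper above

H, W = 4, 8
mask = torch.zeros(1, H, W, dtype=torch.bool)
mask[0, 1:3, 2:6] = True                          # one small rectangular object
xyxy = masks_to_boxes(mask)                       # tensor([[2., 1., 5., 2.]])
xywh = torch.cat([xyxy[:, :2], xyxy[:, 2:] - xyxy[:, :2]], dim=1)
xywh[:, 0::2] /= W                                # normalize x and width by frame width
xywh[:, 1::2] /= H                                # normalize y and height by frame height
print(xywh)                                       # normalized [x, y, w, h] per object
```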
523
+ def _cache_frame_outputs(
524
+ self,
525
+ inference_state,
526
+ frame_idx,
527
+ obj_id_to_mask,
528
+ suppressed_obj_ids=None,
529
+ removed_obj_ids=None,
530
+ unconfirmed_obj_ids=None,
531
+ ):
532
+ # Filter out suppressed, removed, and unconfirmed objects from the cache
533
+ filtered_obj_id_to_mask = obj_id_to_mask.copy()
534
+
535
+ objects_to_exclude = set()
536
+ if suppressed_obj_ids is not None:
537
+ objects_to_exclude.update(suppressed_obj_ids)
538
+ if removed_obj_ids is not None:
539
+ objects_to_exclude.update(removed_obj_ids)
540
+ if unconfirmed_obj_ids is not None:
541
+ objects_to_exclude.update(unconfirmed_obj_ids)
542
+
543
+ if objects_to_exclude:
544
+ for obj_id in objects_to_exclude:
545
+ if obj_id in filtered_obj_id_to_mask:
546
+ del filtered_obj_id_to_mask[obj_id]
547
+
548
+ inference_state["cached_frame_outputs"][frame_idx] = filtered_obj_id_to_mask
549
+
550
+ def _build_tracker_output(
551
+ self, inference_state, frame_idx, refined_obj_id_to_mask=None
552
+ ):
553
+ assert (
554
+ "cached_frame_outputs" in inference_state
555
+ and frame_idx in inference_state["cached_frame_outputs"]
556
+ ), "No cached outputs found. Ensure normal propagation has run first to populate the cache."
557
+ cached_outputs = inference_state["cached_frame_outputs"][frame_idx]
558
+
559
+ obj_id_to_mask = cached_outputs.copy()
560
+
561
+ # Update with refined masks if provided
562
+ if refined_obj_id_to_mask is not None:
563
+ for obj_id, refined_mask in refined_obj_id_to_mask.items():
564
+ assert (
565
+ refined_mask is not None
566
+ ), f"Refined mask data must be provided for obj_id {obj_id}"
567
+ obj_id_to_mask[obj_id] = refined_mask
568
+
569
+ return obj_id_to_mask
570
+
571
+ def _compile_model(self):
572
+ """Compile the SAM model with torch.compile for speedup."""
573
+ is_compiled = getattr(self, "_model_is_compiled", False)
574
+ if is_compiled or not self.compile_model:
575
+ return
576
+
577
+ import torch._dynamo
578
+
579
+ # a larger cache size to hold varying number of shapes for torch.compile
580
+ # see https://github.com/pytorch/pytorch/blob/v2.5.1/torch/_dynamo/config.py#L42-L49
581
+ torch._dynamo.config.cache_size_limit = 128
582
+ torch._dynamo.config.accumulated_cache_size_limit = 2048
583
+ torch._dynamo.config.capture_scalar_outputs = True
584
+ torch._dynamo.config.suppress_errors = True
585
+
586
+ # Compile module components
587
+ # skip compilation of `_encode_prompt` since it sometimes triggers SymInt errors
588
+ # self._encode_prompt = clone_output_wrapper(
589
+ # torch.compile(self._encode_prompt, fullgraph=True, mode="max-autotune")
590
+ # )
591
+
592
+ ## Compile SAM3 model components
593
+ self.detector.backbone.vision_backbone.forward = clone_output_wrapper(
594
+ torch.compile(
595
+ self.detector.backbone.vision_backbone.forward,
596
+ fullgraph=True,
597
+ mode="max-autotune",
598
+ )
599
+ )
600
+ self.detector.transformer.encoder.forward = clone_output_wrapper(
601
+ torch.compile(
602
+ self.detector.transformer.encoder.forward,
603
+ fullgraph=True,
604
+ mode="max-autotune",
605
+ )
606
+ )
607
+ self.detector.transformer.decoder.forward = clone_output_wrapper(
608
+ torch.compile(
609
+ self.detector.transformer.decoder.forward,
610
+ fullgraph=True,
611
+ mode="max-autotune",
612
+ dynamic=False,
613
+ )
614
+ )
615
+
616
+ self.detector.segmentation_head.forward = clone_output_wrapper(
617
+ torch.compile(
618
+ self.detector.segmentation_head.forward,
619
+ fullgraph=True,
620
+ mode="max-autotune",
621
+ )
622
+ )
623
+
624
+ ## Compile Tracker model components
625
+ self.tracker.maskmem_backbone.forward = compile_wrapper(
626
+ self.tracker.maskmem_backbone.forward,
627
+ mode="max-autotune",
628
+ fullgraph=True,
629
+ dynamic=False,
630
+ )
631
+
632
+ self.tracker.transformer.encoder.forward = shape_logging_wrapper(
633
+ compile_wrapper(
634
+ self.tracker.transformer.encoder.forward,
635
+ mode="max-autotune-no-cudagraphs",
636
+ fullgraph=True,
637
+ dynamic=True,
638
+ ),
639
+ keep_kwargs=["src", "src_pos", "prompt", "prompt_pos"],
640
+ )
641
+
642
+ self.tracker.sam_mask_decoder.forward = compile_wrapper(
643
+ self.tracker.sam_mask_decoder.forward,
644
+ mode="max-autotune",
645
+ fullgraph=True,
646
+ dynamic=False, # Accuracy regression on True
647
+ )
648
+
649
+ self._model_is_compiled = True
650
+
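The `compile_wrapper` and `clone_output_wrapper` helpers are defined elsewhere in the repo; the hypothetical sketches below only illustrate the calling convention used in `_compile_model` above and are not the actual implementations:

```python
import torch

def compile_wrapper(fn, **compile_kwargs):
    # sketch: forward the options (mode, fullgraph, dynamic, ...) to torch.compile
    return torch.compile(fn, **compile_kwargs)

def clone_output_wrapper(fn):
    # sketch: clone tensor outputs so callers never alias compiler-managed buffers
    def wrapped(*args, **kwargs):
        out = fn(*args, **kwargs)
        return out.clone() if isinstance(out, torch.Tensor) else out
    return wrapped
```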
651
+ def _warm_up_vg_propagation(self, inference_state, start_frame_idx=0):
652
+ # use different tracking score thresholds for each round to simulate different numbers of output objects
653
+ num_objects_list = range(self.num_obj_for_compile + 1)
654
+ new_det_score_thresh_list = [0.3, 0.5, 0.7]
655
+ num_rounds = len(new_det_score_thresh_list)
656
+ orig_new_det_thresh = self.new_det_thresh
657
+
658
+ for i, thresh in enumerate(new_det_score_thresh_list):
659
+ self.new_det_thresh = thresh
660
+ for num_objects in num_objects_list:
661
+ logger.info(f"{i+1}/{num_rounds} warming up model compilation")
662
+ self.add_prompt(
663
+ inference_state, frame_idx=start_frame_idx, text_str="cat"
664
+ )
665
+ logger.info(
666
+ f"{i+1}/{num_rounds} warming up model compilation -- simulating {num_objects}/{self.num_obj_for_compile} objects"
667
+ )
668
+ inference_state = self.add_fake_objects_to_inference_state(
669
+ inference_state, num_objects, frame_idx=start_frame_idx
670
+ )
671
+ inference_state["tracker_metadata"]["rank0_metadata"].update(
672
+ {
673
+ "masklet_confirmation": {
674
+ "status": np.zeros(num_objects, dtype=np.int64),
675
+ "consecutive_det_num": np.zeros(
676
+ num_objects, dtype=np.int64
677
+ ),
678
+ }
679
+ }
680
+ )
681
+ for _ in self.propagate_in_video(
682
+ inference_state, start_frame_idx, reverse=False
683
+ ):
684
+ pass
685
+ for _ in self.propagate_in_video(
686
+ inference_state, start_frame_idx, reverse=True
687
+ ):
688
+ pass
689
+ self.reset_state(inference_state)
690
+ logger.info(
691
+ f"{i+1}/{num_rounds} warming up model compilation -- completed round {i+1} out of {num_rounds}"
692
+ )
693
+
694
+ # Warm up Tracker memory encoder with varying input shapes
695
+ num_iters = 3
696
+ feat_size = self.tracker.sam_image_embedding_size**2 # 72 * 72 = 5184
697
+ hidden_dim = self.tracker.hidden_dim # 256
698
+ mem_dim = self.tracker.mem_dim # 64
699
+ for _ in tqdm(range(num_iters)):
700
+ for b in range(1, self.num_obj_for_compile + 1):
701
+ for i in range(
702
+ 1,
703
+ self.tracker.max_cond_frames_in_attn + self.tracker.num_maskmem,
704
+ ):
705
+ for j in range(
706
+ self.tracker.max_cond_frames_in_attn
707
+ + self.tracker.max_obj_ptrs_in_encoder
708
+ ):
709
+ num_obj_ptr_tokens = (hidden_dim // mem_dim) * j
710
+ src = torch.randn(feat_size, b, hidden_dim, device=self.device)
711
+ src_pos = torch.randn(
712
+ feat_size, b, hidden_dim, device=self.device
713
+ )
714
+ prompt = torch.randn(
715
+ feat_size * i + num_obj_ptr_tokens,
716
+ b,
717
+ mem_dim,
718
+ device=self.device,
719
+ )
720
+ prompt_pos = torch.randn(
721
+ feat_size * i + num_obj_ptr_tokens,
722
+ b,
723
+ mem_dim,
724
+ device=self.device,
725
+ )
726
+
727
+ self.tracker.transformer.encoder.forward(
728
+ src=src,
729
+ src_pos=src_pos,
730
+ prompt=prompt,
731
+ prompt_pos=prompt_pos,
732
+ num_obj_ptr_tokens=num_obj_ptr_tokens,
733
+ )
734
+
735
+ self.new_det_thresh = orig_new_det_thresh
736
+ return inference_state
737
+
738
+ def add_fake_objects_to_inference_state(
739
+ self, inference_state, num_objects, frame_idx
740
+ ):
741
+ new_det_obj_ids_local = np.arange(num_objects)
742
+ high_res_H, high_res_W = (
743
+ self.tracker.maskmem_backbone.mask_downsampler.interpol_size
744
+ )
745
+ new_det_masks = torch.ones(
746
+ len(new_det_obj_ids_local), high_res_H, high_res_W
747
+ ).to(self.device)
748
+
749
+ inference_state["tracker_inference_states"] = self._tracker_add_new_objects(
750
+ frame_idx=frame_idx,
751
+ num_frames=inference_state["num_frames"],
752
+ new_obj_ids=new_det_obj_ids_local,
753
+ new_obj_masks=new_det_masks,
754
+ tracker_states_local=inference_state["tracker_inference_states"],
755
+ orig_vid_height=inference_state["orig_height"],
756
+ orig_vid_width=inference_state["orig_width"],
757
+ feature_cache=inference_state["feature_cache"],
758
+ )
759
+
760
+ # Synthesize obj_id_to_mask data for cached_frame_outputs to support _build_tracker_output during warmup
761
+ obj_id_to_mask = {}
762
+ if num_objects > 0:
763
+ H_video = inference_state["orig_height"]
764
+ W_video = inference_state["orig_width"]
765
+
766
+ video_res_masks = F.interpolate(
767
+ new_det_masks.unsqueeze(1), # Add channel dimension for interpolation
768
+ size=(H_video, W_video),
769
+ mode="bilinear",
770
+ align_corners=False,
771
+ ) # (num_objects, 1, H_video, W_video)
772
+ for i, obj_id in enumerate(new_det_obj_ids_local):
773
+ obj_id_to_mask[obj_id] = (video_res_masks[i] > 0.0).to(torch.bool)
774
+ if self.rank == 0:
775
+ for fidx in range(inference_state["num_frames"]):
776
+ self._cache_frame_outputs(inference_state, fidx, obj_id_to_mask)
777
+
778
+ inference_state["tracker_metadata"].update(
779
+ {
780
+ "obj_ids_per_gpu": [np.arange(num_objects)],
781
+ "obj_ids_all_gpu": np.arange(num_objects), # Same as 1 GPU
782
+ "num_obj_per_gpu": [num_objects],
783
+ "obj_id_to_score": {i: 1.0 for i in range(num_objects)},
784
+ "max_obj_id": num_objects,
785
+ "rank0_metadata": {
786
+ "masklet_confirmation": {
787
+ "status": np.zeros(num_objects, dtype=np.int64),
788
+ "consecutive_det_num": np.zeros(num_objects, dtype=np.int64),
789
+ },
790
+ "removed_obj_ids": set(),
791
+ "suppressed_obj_ids": defaultdict(set),
792
+ },
793
+ }
794
+ )
795
+ return inference_state
796
+
797
+ @torch.inference_mode()
798
+ @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
799
+ def warm_up_compilation(self):
800
+ """
801
+ Warm up the model by running a dummy inference to compile the model. This is
802
+ useful to avoid the compilation overhead in the first inference call.
803
+ """
804
+ if not self.compile_model:
805
+ return
806
+ self._warm_up_complete = False
807
+ if self.device.type != "cuda":
808
+ raise RuntimeError(
809
+ f"The model must be on CUDA for warm-up compilation, got {self.device=}."
810
+ )
811
+
812
+ # temporarily set to a single GPU for warm-up compilation
813
+ orig_rank = self.rank
814
+ orig_world_size = self.world_size
815
+ self.rank = self.detector.rank = 0
816
+ self.world_size = self.detector.world_size = 1
817
+ orig_recondition_every_nth_frame = self.recondition_every_nth_frame
818
+ # self.recondition_every_nth_frame = 2
819
+
820
+ # Get a random video
821
+ inference_state = self.init_state(resource_path="<load-dummy-video-30>")
822
+ start_frame_idx = 0
823
+
824
+ # Run basic propagation warm-up
825
+ inference_state = self._warm_up_vg_propagation(inference_state, start_frame_idx)
826
+
827
+ logger.info("Warm-up compilation completed.")
828
+
829
+ # revert to the original GPU and rank
830
+ self.rank = self.detector.rank = orig_rank
831
+ self.world_size = self.detector.world_size = orig_world_size
832
+ self.recondition_every_nth_frame = orig_recondition_every_nth_frame
833
+ self._warm_up_complete = True
834
+ self.tracker.transformer.encoder.forward.set_logging(True)
835
+
836
+ @torch.inference_mode()
837
+ def add_prompt(
838
+ self,
839
+ inference_state,
840
+ frame_idx,
841
+ text_str=None,
842
+ boxes_xywh=None,
843
+ box_labels=None,
844
+ ):
845
+ """
846
+ Add text or box prompts on a single frame. This method returns the inference
847
+ outputs only on the prompted frame.
848
+
849
+ Note that text prompts are NOT associated with a particular frame (i.e. they apply
850
+ to all frames). However, we only run inference on the frame specified in `frame_idx`.
851
+ """
852
+ logger.debug("Running add_prompt on frame %d", frame_idx)
853
+
854
+ num_frames = inference_state["num_frames"]
855
+ assert (
856
+ text_str is not None or boxes_xywh is not None
857
+ ), "at least one type of prompt (text, boxes) must be provided"
858
+ assert (
859
+ 0 <= frame_idx < num_frames
860
+ ), f"{frame_idx=} is out of range for a total of {num_frames} frames"
861
+
862
+ # since it's a semantic prompt, we start over
863
+ self.reset_state(inference_state)
864
+
865
+ # 1) add text prompt
866
+ if text_str is not None and text_str != "visual":
867
+ inference_state["text_prompt"] = text_str
868
+ inference_state["input_batch"].find_text_batch[0] = text_str
869
+ text_id = self.TEXT_ID_FOR_TEXT
870
+ else:
871
+ inference_state["text_prompt"] = None
872
+ inference_state["input_batch"].find_text_batch[0] = "<text placeholder>"
873
+ text_id = self.TEXT_ID_FOR_VISUAL
874
+ for t in range(inference_state["num_frames"]):
875
+ inference_state["input_batch"].find_inputs[t].text_ids[...] = text_id
876
+
877
+ # 2) handle box prompt
878
+ assert (boxes_xywh is not None) == (box_labels is not None)
879
+ if boxes_xywh is not None:
880
+ boxes_xywh = torch.as_tensor(boxes_xywh, dtype=torch.float32)
881
+ box_labels = torch.as_tensor(box_labels, dtype=torch.long)
882
+ # input boxes are expected to be in [xmin, ymin, width, height] format
883
+ # in normalized coordinates of range 0~1, similar to FA
884
+ assert boxes_xywh.dim() == 2
885
+ assert boxes_xywh.size(0) > 0 and boxes_xywh.size(-1) == 4
886
+ assert box_labels.dim() == 1 and box_labels.size(0) == boxes_xywh.size(0)
887
+ boxes_cxcywh = box_xywh_to_cxcywh(boxes_xywh)
888
+ assert (boxes_xywh >= 0).all().item() and (boxes_xywh <= 1).all().item()
889
+ assert (boxes_cxcywh >= 0).all().item() and (boxes_cxcywh <= 1).all().item()
890
+
891
+ new_box_input = boxes_cxcywh, box_labels
892
+ inference_state["per_frame_raw_box_input"][frame_idx] = new_box_input
893
+
894
+ # handle the case of visual prompt (also added as an input box from the UI)
895
+ boxes_cxcywh, box_labels, geometric_prompt = self._get_visual_prompt(
896
+ inference_state, frame_idx, boxes_cxcywh, box_labels
897
+ )
898
+
899
+ inference_state["per_frame_geometric_prompt"][frame_idx] = geometric_prompt
900
+
901
+ out = self._run_single_frame_inference(
902
+ inference_state, frame_idx, reverse=False
903
+ )
904
+ return frame_idx, self._postprocess_output(inference_state, out)
905
+
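For clarity, the normalized box format expected by `add_prompt` above, together with a simplified version of the xywh-to-cxcywh conversion it performs internally (the hypothetical `xywh_to_cxcywh` below only mirrors what `box_xywh_to_cxcywh` is expected to do):

```python
import torch

def xywh_to_cxcywh(boxes: torch.Tensor) -> torch.Tensor:
    # [xmin, ymin, w, h] -> [cx, cy, w, h], all in normalized 0..1 coordinates
    x, y, w, h = boxes.unbind(-1)
    return torch.stack([x + w / 2, y + h / 2, w, h], dim=-1)

boxes_xywh = torch.tensor([[0.0, 0.0, 0.5, 1.0]], dtype=torch.float32)  # left half of the frame
box_labels = torch.tensor([1], dtype=torch.long)                        # 1 = positive, 0 = negative
print(xywh_to_cxcywh(boxes_xywh))  # tensor([[0.2500, 0.5000, 0.5000, 1.0000]])
```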
906
+ @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
907
+ def forward(self, input: BatchedDatapoint, is_inference: bool = False):
908
+ """This method is only used for benchmark eval (not used in the demo)."""
909
+ # set the model to single GPU for benchmark evaluation (to be compatible with trainer)
910
+ orig_rank = self.rank
911
+ orig_world_size = self.world_size
912
+ self.rank = self.detector.rank = 0
913
+ self.world_size = self.detector.world_size = 1
914
+
915
+ # get data
916
+ text_prompt_ids = input.find_metadatas[0].original_category_id
917
+ text_prompt_list = input.find_text_batch
918
+
919
+ # loop over text prompts
920
+ tracking_res = defaultdict(dict) # frame_idx --> {obj_id: mask}
921
+ scores_labels = defaultdict(tuple) # obj_id --> (score, text_prompt_id)
922
+ inference_state = self.init_state(resource_path=input.raw_images)
923
+ for prompt_id, prompt in zip(text_prompt_ids, text_prompt_list):
924
+ self.add_prompt(inference_state, frame_idx=0, text_str=prompt)
925
+ start_obj_id = max(scores_labels.keys(), default=-1) + 1 # prev max + 1
926
+
927
+ # propagate the prompts
928
+ obj_ids_this_prompt = set()
929
+ for frame_idx, out in self.propagate_in_video(
930
+ inference_state,
931
+ start_frame_idx=0,
932
+ max_frame_num_to_track=inference_state["num_frames"],
933
+ reverse=False,
934
+ ):
935
+ current_frame_res = tracking_res[frame_idx]
936
+ for obj_id, mask in zip(out["out_obj_ids"], out["out_binary_masks"]):
937
+ mask_tensor = torch.tensor(mask[None], dtype=torch.bool)
938
+ current_frame_res[obj_id + start_obj_id] = mask_tensor
939
+ obj_ids_this_prompt.update(current_frame_res.keys())
940
+
941
+ obj_id_to_score = inference_state["tracker_metadata"]["obj_id_to_score"]
942
+ for obj_id, score in obj_id_to_score.items():
943
+ if obj_id + start_obj_id in obj_ids_this_prompt:
944
+ score_tensor = torch.tensor(score, dtype=torch.float32)
945
+ scores_labels[obj_id + start_obj_id] = (score_tensor, prompt_id)
946
+
947
+ self.reset_state(inference_state)
948
+
949
+ video_id = input.find_metadatas[0].original_image_id[0].cpu().item()
950
+ preds = self.prep_for_evaluator(input.raw_images, tracking_res, scores_labels)
951
+
952
+ # revert the model to the original GPU and rank
953
+ self.rank = self.detector.rank = orig_rank
954
+ self.world_size = self.detector.world_size = orig_world_size
955
+ return {video_id: preds}
956
+
957
+ def back_convert(self, targets):
958
+ # Needed for retraining compatibility with trainer
959
+ return targets
960
+
961
+
962
+ class Sam3VideoInferenceWithInstanceInteractivity(Sam3VideoInference):
963
+ def __init__(
964
+ self,
965
+ use_prev_mem_frame=False,
966
+ use_stateless_refinement=False,
967
+ refinement_detector_cond_frame_removal_window=16,
968
+ **kwargs,
969
+ ):
970
+ """
971
+ use_prev_mem_frame: bool, whether to condition on previous memory frames for adding points
972
+ use_stateless_refinement: bool, whether to enable stateless refinement behavior
973
+ refinement_detector_cond_frame_removal_window: int, we remove a detector conditioning frame if it
974
+ is within this many frames of a user refined frame. Set to a large value (e.g. 10000) to
975
+ always remove detector conditioning frames if there is any user refinement in the video.
976
+ """
977
+ super().__init__(**kwargs)
978
+ self.use_prev_mem_frame = use_prev_mem_frame
979
+ self.use_stateless_refinement = use_stateless_refinement
980
+ self.refinement_detector_cond_frame_removal_window = (
981
+ refinement_detector_cond_frame_removal_window
982
+ )
983
+
984
+ def _init_new_tracker_state(self, inference_state):
985
+ return self.tracker.init_state(
986
+ cached_features=inference_state["feature_cache"],
987
+ video_height=inference_state["orig_height"],
988
+ video_width=inference_state["orig_width"],
989
+ num_frames=inference_state["num_frames"],
990
+ )
991
+
992
+ @torch.inference_mode()
993
+ def propagate_in_video(
994
+ self,
995
+ inference_state,
996
+ start_frame_idx=None,
997
+ max_frame_num_to_track=None,
998
+ reverse=False,
999
+ ):
1000
+ # step 1: check which type of propagation to run, should be the same for all GPUs.
1001
+ propagation_type, obj_ids = self.parse_action_history_for_propagation(
1002
+ inference_state
1003
+ )
1004
+ self.add_action_history(
1005
+ inference_state,
1006
+ action_type=propagation_type,
1007
+ obj_ids=obj_ids,
1008
+ frame_idx=start_frame_idx,
1009
+ )
1010
+
1011
+ # step 2: run full VG propagation
1012
+ if propagation_type == "propagation_full":
1013
+ logger.debug(f"Running full VG propagation (reverse={reverse}).")
1014
+ yield from super().propagate_in_video(
1015
+ inference_state,
1016
+ start_frame_idx=start_frame_idx,
1017
+ max_frame_num_to_track=max_frame_num_to_track,
1018
+ reverse=reverse,
1019
+ )
1020
+ return
1021
+
1022
+ # step 3: run Tracker partial propagation or direct fetch existing predictions
1023
+ assert propagation_type in ["propagation_partial", "propagation_fetch"]
1024
+ logger.debug(
1025
+ f"Running Tracker propagation for objects {obj_ids} and merging it with existing VG predictions (reverse={reverse})."
1026
+ if propagation_type == "propagation_partial"
1027
+ else f"Fetching existing VG predictions without running any propagation (reverse={reverse})."
1028
+ )
1029
+ processing_order, _ = self._get_processing_order(
1030
+ inference_state,
1031
+ start_frame_idx=start_frame_idx,
1032
+ max_frame_num_to_track=max_frame_num_to_track,
1033
+ reverse=reverse,
1034
+ )
1035
+
1036
+ tracker_metadata = inference_state["tracker_metadata"]
1037
+
1038
+ # if fetching, just return the cached outputs
1039
+ if propagation_type == "propagation_fetch":
1040
+ for frame_idx in tqdm(processing_order):
1041
+ if self.rank == 0:
1042
+ obj_id_to_mask = inference_state["cached_frame_outputs"].get(
1043
+ frame_idx, {}
1044
+ )
1045
+ # post processing - remove suppressed obj_ids
1046
+ obj_id_to_score = tracker_metadata["obj_id_to_score"]
1047
+ suppressed_obj_ids = tracker_metadata["rank0_metadata"][
1048
+ "suppressed_obj_ids"
1049
+ ][frame_idx]
1050
+ obj_id_to_tracker_score = tracker_metadata[
1051
+ "obj_id_to_tracker_score_frame_wise"
1052
+ ][frame_idx]
1053
+
1054
+ out = {
1055
+ "obj_id_to_mask": obj_id_to_mask,
1056
+ "obj_id_to_score": obj_id_to_score,
1057
+ "obj_id_to_tracker_score": obj_id_to_tracker_score,
1058
+ }
1059
+ yield (
1060
+ frame_idx,
1061
+ self._postprocess_output(
1062
+ inference_state, out, suppressed_obj_ids=suppressed_obj_ids
1063
+ ),
1064
+ )
1065
+ else:
1066
+ yield frame_idx, None
1067
+
1068
+ return
1069
+
1070
+ # get Tracker inference states containing selected obj_ids
1071
+ if propagation_type == "propagation_partial":
1072
+ # can be empty for GPUs where objects are not in their inference states
1073
+ tracker_states_local = self._get_tracker_inference_states_by_obj_ids(
1074
+ inference_state, obj_ids
1075
+ )
1076
+ for tracker_state in tracker_states_local:
1077
+ self.tracker.propagate_in_video_preflight(
1078
+ tracker_state, run_mem_encoder=True
1079
+ )
1080
+
1081
+ for frame_idx in tqdm(processing_order):
1082
+ # run Tracker propagation
1083
+ if propagation_type == "propagation_partial":
1084
+ self._prepare_backbone_feats(inference_state, frame_idx, reverse)
1085
+ obj_ids_local, low_res_masks_local, tracker_scores_local = (
1086
+ self._propogate_tracker_one_frame_local_gpu(
1087
+ tracker_states_local,
1088
+ frame_idx=frame_idx,
1089
+ reverse=reverse,
1090
+ run_mem_encoder=True,
1091
+ )
1092
+ )
1093
+
1094
+ # broadcast refined object tracker scores and masks to all GPUs
1095
+ # handle multiple objects that can be located on different GPUs
1096
+ refined_obj_data = {} # obj_id -> (score, mask_video_res)
1097
+
1098
+ # Collect data for objects on this GPU
1099
+ local_obj_data = {}
1100
+ for obj_id in obj_ids:
1101
+ obj_rank = self._get_gpu_id_by_obj_id(inference_state, obj_id)
1102
+ if self.rank == obj_rank and obj_id in obj_ids_local:
1103
+ refined_obj_idx = obj_ids_local.index(obj_id)
1104
+ refined_mask_low_res = low_res_masks_local[
1105
+ refined_obj_idx
1106
+ ] # (H_low_res, W_low_res)
1107
+ refined_score = tracker_scores_local[refined_obj_idx]
1108
+
1109
+ # Keep low resolution for broadcasting to reduce communication cost
1110
+ local_obj_data[obj_id] = (refined_score, refined_mask_low_res)
1111
+
1112
+ # Broadcast data from each GPU that has refined objects
1113
+ if self.world_size > 1:
1114
+ for obj_id in obj_ids:
1115
+ obj_rank = self._get_gpu_id_by_obj_id(inference_state, obj_id)
1116
+ if self.rank == obj_rank:
1117
+ # This GPU has the object, broadcast its data
1118
+ data_to_broadcast = local_obj_data.get(obj_id, None)
1119
+ data_list = [
1120
+ (data_to_broadcast[0].cpu(), data_to_broadcast[1].cpu())
1121
+ ]
1122
+ self.broadcast_python_obj_cpu(data_list, src=obj_rank)
1123
+ if data_to_broadcast is not None:
1124
+ refined_obj_data[obj_id] = data_to_broadcast
1125
+ elif self.rank != obj_rank:
1126
+ # This GPU doesn't have the object, receive data
1127
+ data_list = [None]
1128
+ self.broadcast_python_obj_cpu(data_list, src=obj_rank)
1129
+ refined_obj_data[obj_id] = (
1130
+ data_list[0][0].to(self.device),
1131
+ data_list[0][1].to(self.device),
1132
+ )
1133
+ else:
1134
+ # Single GPU case
1135
+ refined_obj_data = local_obj_data
1136
+
1137
+ # Update Tracker scores for all refined objects
1138
+ for obj_id, (refined_score, _) in refined_obj_data.items():
1139
+ tracker_metadata["obj_id_to_tracker_score_frame_wise"][
1140
+ frame_idx
1141
+ ].update({obj_id: refined_score.item()})
1142
+
1143
+ if self.rank == 0:
1144
+ # get predictions from Tracker inference states, which include the original
1145
+ # VG predictions and the refined predictions from interactivity.
1146
+
1147
+ # Prepare refined masks dictionary - upscale to video resolution after broadcast
1148
+ refined_obj_id_to_mask = {}
1149
+ for obj_id, (_, refined_mask_low_res) in refined_obj_data.items():
1150
+ refined_mask_video_res = (
1151
+ self._convert_low_res_mask_to_video_res(
1152
+ refined_mask_low_res, inference_state
1153
+ )
1154
+ ) # (1, H_video, W_video) bool
1155
+ refined_obj_id_to_mask[obj_id] = refined_mask_video_res
1156
+
1157
+ obj_id_to_mask = self._build_tracker_output(
1158
+ inference_state, frame_idx, refined_obj_id_to_mask
1159
+ )
1160
+ out = {
1161
+ "obj_id_to_mask": obj_id_to_mask,
1162
+ "obj_id_to_score": tracker_metadata["obj_id_to_score"],
1163
+ "obj_id_to_tracker_score": tracker_metadata[
1164
+ "obj_id_to_tracker_score_frame_wise"
1165
+ ][frame_idx],
1166
+ }
1167
+ suppressed_obj_ids = tracker_metadata["rank0_metadata"][
1168
+ "suppressed_obj_ids"
1169
+ ][frame_idx]
1170
+ self._cache_frame_outputs(
1171
+ inference_state,
1172
+ frame_idx,
1173
+ obj_id_to_mask,
1174
+ suppressed_obj_ids=suppressed_obj_ids,
1175
+ )
1176
+ suppressed_obj_ids = tracker_metadata["rank0_metadata"][
1177
+ "suppressed_obj_ids"
1178
+ ][frame_idx]
1179
+ yield (
1180
+ frame_idx,
1181
+ self._postprocess_output(
1182
+ inference_state, out, suppressed_obj_ids=suppressed_obj_ids
1183
+ ),
1184
+ )
1185
+ else:
1186
+ yield frame_idx, None
1187
+
1188
+ def add_action_history(
1189
+ self, inference_state, action_type, frame_idx=None, obj_ids=None
1190
+ ):
1191
+ """
1192
+ action_history is used to automatically decide what to do during propagation.
1193
+ action_type: one of ["add", "remove", "refine"] + ["propagation_full", "propagation_partial", "propagation_fetch"]
1194
+ """
1195
+ instance_actions = ["add", "remove", "refine"]
1196
+ propagation_actions = [
1197
+ "propagation_full",
1198
+ "propagation_partial",
1199
+ "propagation_fetch",
1200
+ ]
1201
+ assert (
1202
+ action_type in instance_actions + propagation_actions
1203
+ ), f"Invalid action type: {action_type}, must be one of {instance_actions + propagation_actions}"
1204
+ action = {
1205
+ "type": action_type,
1206
+ "frame_idx": frame_idx,
1207
+ "obj_ids": obj_ids,
1208
+ }
1209
+ inference_state["action_history"].append(action)
1210
+
1211
+ def _has_object_been_refined(self, inference_state, obj_id):
1212
+ action_history = inference_state["action_history"]
1213
+ for action in action_history:
1214
+ if action["type"] in ["add", "refine"] and action.get("obj_ids"):
1215
+ if obj_id in action["obj_ids"]:
1216
+ return True
1217
+ return False
1218
+
1219
+ def parse_action_history_for_propagation(self, inference_state):
1220
+ """
1221
+ Parse the actions in history before the last propagation and prepare for the next propagation.
1222
+ We support multiple actions (add/remove/refine) between two propagations. If we had an action
1223
+ history similar to this ["propagate", "add", "refine", "remove", "add"], the next propagation
1224
+ would remove the removed object, and also propagate the two added/refined objects.
1225
+
1226
+ Returns:
1227
+ propagation_type: one of ["propagation_full", "propagation_partial", "propagation_fetch"]
1228
+ - "propagation_full": run VG propagation for all objects
1229
+ - "propagation_partial": run Tracker propagation for selected objects, useful for add/refine actions
1230
+ - "propagation_fetch": fetch existing VG predictions without running any propagation
1231
+ obj_ids: list of object ids to run Tracker propagation on if propagation_type is "propagation_partial".
1232
+ """
1233
+ action_history = inference_state["action_history"]
1234
+ if len(action_history) == 0:
1235
+ # we run propagation for the first time
1236
+ return "propagation_full", None
1237
+
1238
+ if "propagation" in action_history[-1]["type"]:
1239
+ if action_history[-1]["type"] in ["propagation_fetch"]:
1240
+ # last propagation is direct fetch, we fetch existing predictions
1241
+ return "propagation_fetch", None
1242
+ elif action_history[-1]["type"] in [
1243
+ "propagation_partial",
1244
+ "propagation_full",
1245
+ ]:
1246
+ # we fetch existing predictions if we have already run propagation twice, or we have run
1247
+ # propagation once starting from the first or the last frame.
1248
+ if (
1249
+ len(action_history) > 1
1250
+ and action_history[-2]["type"]
1251
+ in ["propagation_partial", "propagation_full"]
1252
+ ) or action_history[-1]["frame_idx"] in [
1253
+ 0,
1254
+ inference_state["num_frames"] - 1,
1255
+ ]:
1256
+ # we have run both forward and backward partial/full propagation
1257
+ return "propagation_fetch", None
1258
+ else:
1259
+ # we have run partial/full forward or backward propagation once; we still need to run it for the rest of the frames
1260
+ return action_history[-1]["type"], action_history[-1]["obj_ids"]
1261
+
1262
+ # parse actions since last propagation
1263
+ obj_ids = []
1264
+ for action in action_history[::-1]:
1265
+ if "propagation" in action["type"]:
1266
+ # we reached the last propagation action, stop parsing
1267
+ break
1268
+ if action["type"] in ["add", "refine"]:
1269
+ obj_ids.extend(action["obj_ids"])
1270
+ # else action["type"] == "remove": noop
1271
+ obj_ids = list(set(obj_ids)) if len(obj_ids) > 0 else None
1272
+ propagation_type = (
1273
+ "propagation_partial" if obj_ids is not None else "propagation_fetch"
1274
+ )
1275
+ return propagation_type, obj_ids
1276
+
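A toy illustration of the parsing rule above, standalone and not the actual method:

```python
# Actions recorded since the last propagation decide the next propagation type.
history = [
    {"type": "propagation_full", "obj_ids": None},
    {"type": "add", "obj_ids": [3]},
    {"type": "refine", "obj_ids": [1]},
    {"type": "remove", "obj_ids": [2]},
]

obj_ids = []
for action in reversed(history):
    if "propagation" in action["type"]:
        break
    if action["type"] in ("add", "refine"):
        obj_ids.extend(action["obj_ids"])

print(sorted(set(obj_ids)))  # [1, 3] -> run "propagation_partial" on just these objects
```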
1277
+ def remove_object(self, inference_state, obj_id, is_user_action=False):
1278
+ """
1279
+ We try to remove the object from the tracker states on every GPU; it does nothing
1280
+ for states without this object.
1281
+ """
1282
+ obj_rank = self._get_gpu_id_by_obj_id(inference_state, obj_id)
1283
+ assert obj_rank is not None, f"Object {obj_id} not found in any GPU."
1284
+
1285
+ tracker_states_local = inference_state["tracker_inference_states"]
1286
+ if self.rank == obj_rank:
1287
+ self._tracker_remove_object(tracker_states_local, obj_id)
1288
+
1289
+ if is_user_action:
1290
+ self.add_action_history(
1291
+ inference_state, action_type="remove", obj_ids=[obj_id]
1292
+ )
1293
+
1294
+ # update metadata
1295
+ tracker_metadata = inference_state["tracker_metadata"]
1296
+ _obj_ids = tracker_metadata["obj_ids_per_gpu"][obj_rank]
1297
+ tracker_metadata["obj_ids_per_gpu"][obj_rank] = _obj_ids[_obj_ids != obj_id]
1298
+ tracker_metadata["num_obj_per_gpu"][obj_rank] = len(
1299
+ tracker_metadata["obj_ids_per_gpu"][obj_rank]
1300
+ )
1301
+ tracker_metadata["obj_ids_all_gpu"] = np.concatenate(
1302
+ tracker_metadata["obj_ids_per_gpu"]
1303
+ )
1304
+ tracker_metadata["obj_id_to_score"].pop(obj_id, None)
1305
+ # tracker_metadata["max_obj_id"] # we do not reuse the object id, so we do not update it here
1306
+
1307
+ # Clean up cached frame outputs to remove references to the deleted object
1308
+ if "cached_frame_outputs" in inference_state:
1309
+ for frame_idx in inference_state["cached_frame_outputs"]:
1310
+ frame_cache = inference_state["cached_frame_outputs"][frame_idx]
1311
+ if obj_id in frame_cache:
1312
+ del frame_cache[obj_id]
1313
+
1314
+ def _get_gpu_id_by_obj_id(self, inference_state, obj_id):
1315
+ """
1316
+ Locate GPU ID for a given object.
1317
+ """
1318
+ obj_ids_per_gpu = inference_state["tracker_metadata"]["obj_ids_per_gpu"]
1319
+ for rank, obj_ids in enumerate(obj_ids_per_gpu):
1320
+ if obj_id in obj_ids:
1321
+ return rank
1322
+ return None # object not found in any GPU
1323
+
1324
+ def _get_tracker_inference_states_by_obj_ids(self, inference_state, obj_ids):
1325
+ """
1326
+ Get the Tracker inference states that contain the given object ids.
1327
+ This is used to run partial Tracker propagation on a single object/bucket.
1328
+ Possibly multiple or zero states can be returned.
1329
+ """
1330
+ states = [
1331
+ state
1332
+ for state in inference_state["tracker_inference_states"]
1333
+ if set(obj_ids) & set(state["obj_ids"])
1334
+ ]
1335
+ return states
1336
+
1337
+ def _prepare_backbone_feats(self, inference_state, frame_idx, reverse):
1338
+ input_batch = inference_state["input_batch"]
1339
+ feature_cache = inference_state["feature_cache"]
1340
+ num_frames = inference_state["num_frames"]
1341
+ geometric_prompt = (
1342
+ inference_state["constants"]["empty_geometric_prompt"]
1343
+ if inference_state["per_frame_geometric_prompt"][frame_idx] is None
1344
+ else inference_state["per_frame_geometric_prompt"][frame_idx]
1345
+ )
1346
+ _ = self.run_backbone_and_detection(
1347
+ frame_idx=frame_idx,
1348
+ num_frames=num_frames,
1349
+ input_batch=input_batch,
1350
+ geometric_prompt=geometric_prompt,
1351
+ feature_cache=feature_cache,
1352
+ reverse=reverse,
1353
+ allow_new_detections=True,
1354
+ )
1355
+
1356
+ @torch.inference_mode()
1357
+ def add_prompt(
1358
+ self,
1359
+ inference_state,
1360
+ frame_idx,
1361
+ text_str=None,
1362
+ boxes_xywh=None,
1363
+ box_labels=None,
1364
+ points=None,
1365
+ point_labels=None,
1366
+ obj_id=None,
1367
+ rel_coordinates=True,
1368
+ ):
1369
+ if points is not None:
1370
+ # Tracker instance prompts
1371
+ assert (
1372
+ text_str is None and boxes_xywh is None
1373
+ ), "When points are provided, text_str and boxes_xywh must be None."
1374
+ assert (
1375
+ obj_id is not None
1376
+ ), "When points are provided, obj_id must be provided."
1377
+ return self.add_tracker_new_points(
1378
+ inference_state,
1379
+ frame_idx,
1380
+ obj_id=obj_id,
1381
+ points=points,
1382
+ labels=point_labels,
1383
+ rel_coordinates=rel_coordinates,
1384
+ use_prev_mem_frame=self.use_prev_mem_frame,
1385
+ )
1386
+ else:
1387
+ # SAM3 prompts
1388
+ return super().add_prompt(
1389
+ inference_state,
1390
+ frame_idx,
1391
+ text_str=text_str,
1392
+ boxes_xywh=boxes_xywh,
1393
+ box_labels=box_labels,
1394
+ )
1395
+
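A hedged sketch of an instance-level refinement call through the `add_prompt` override above; `model`, `state`, and all argument values are illustrative assumptions (an existing `obj_id` refines a masklet, a new `obj_id` adds one):

```python
# Assumes `state` was created with init_state() and has already been propagated once.
frame_idx, out = model.add_prompt(
    state,
    frame_idx=12,
    points=[[0.40, 0.55]],   # clicks; relative (normalized) coords when rel_coordinates=True
    point_labels=[1],        # 1 = positive click, 0 = negative click (standard SAM convention)
    obj_id=7,                # refine masklet 7; pass an unused id to add a new object
)
```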
1396
+ @torch.inference_mode()
1397
+ def add_tracker_new_points(
1398
+ self,
1399
+ inference_state,
1400
+ frame_idx,
1401
+ obj_id,
1402
+ points,
1403
+ labels,
1404
+ rel_coordinates=True,
1405
+ use_prev_mem_frame=False,
1406
+ ):
1407
+ """Add a new point prompt to the Tracker, supporting instance refinement of existing
1408
+ objects by passing an existing obj_id, or adding a new object by passing a new obj_id.
1409
+ use_prev_mem_frame=False to disable cross attention to previous memory frames.
1410
+ Every GPU returns the same results, and results should contain all masks including
1411
+ masks that were not refined or added by the current user points.
1412
+ """
1413
+ assert obj_id is not None, "obj_id must be provided to add new points"
1414
+ tracker_metadata = inference_state["tracker_metadata"]
1415
+ if tracker_metadata == {}:
1416
+ # initialize masklet metadata if it's uninitialized (empty dict)
1417
+ tracker_metadata.update(self._initialize_metadata())
1418
+
1419
+ obj_rank = self._get_gpu_id_by_obj_id(inference_state, obj_id)
1420
+
1421
+ # prepare feature
1422
+ self._prepare_backbone_feats(inference_state, frame_idx, reverse=False)
1423
+
1424
+ object_has_been_refined = self._has_object_been_refined(inference_state, obj_id)
1425
+ if (
1426
+ obj_rank is not None
1427
+ and self.use_stateless_refinement
1428
+ and not object_has_been_refined
1429
+ ):
1430
+ # The first time we start refinement on the object, we remove it.
1431
+ logger.debug(
1432
+ f"[rank={self.rank}] Removing object {obj_id} before refinement."
1433
+ )
1434
+ self.remove_object(inference_state, obj_id, is_user_action=False)
1435
+ obj_rank = None
1436
+
1437
+ if obj_rank is None:
1438
+ # new object, we assign it a GPU and create a new inference state if limit allows
1439
+ num_prev_obj = np.sum(tracker_metadata["num_obj_per_gpu"])
1440
+ if num_prev_obj >= self.max_num_objects:
1441
+ logger.warning(
1442
+ f"add_tracker_new_points: cannot add a new object as we are already tracking {num_prev_obj=} "
1443
+ f"masklets (under {self.max_num_objects=})"
1444
+ )
1445
+ obj_ids = []
1446
+ H_low_res = W_low_res = self.tracker.low_res_mask_size
1447
+ H_video_res = inference_state["orig_height"]
1448
+ W_video_res = inference_state["orig_width"]
1449
+ low_res_masks = torch.zeros(0, 1, H_low_res, W_low_res)
1450
+ video_res_masks = torch.zeros(0, 1, H_video_res, W_video_res)
1451
+ return frame_idx, obj_ids, low_res_masks, video_res_masks
1452
+
1453
+ new_det_gpu_ids = self._assign_new_det_to_gpus(
1454
+ new_det_num=1,
1455
+ prev_workload_per_gpu=tracker_metadata["num_obj_per_gpu"],
1456
+ )
1457
+ obj_rank = new_det_gpu_ids[0]
1458
+
1459
+ # get tracker inference state for the new object
1460
+ if self.rank == obj_rank:
1461
+ # for batched inference, we create a new inference state
1462
+ tracker_state = self._init_new_tracker_state(inference_state)
1463
+ inference_state["tracker_inference_states"].append(tracker_state)
1464
+
1465
+ # update metadata
1466
+ tracker_metadata["obj_ids_per_gpu"][obj_rank] = np.concatenate(
1467
+ [
1468
+ tracker_metadata["obj_ids_per_gpu"][obj_rank],
1469
+ np.array([obj_id], dtype=np.int64),
1470
+ ]
1471
+ )
1472
+ tracker_metadata["num_obj_per_gpu"][obj_rank] = len(
1473
+ tracker_metadata["obj_ids_per_gpu"][obj_rank]
1474
+ )
1475
+ tracker_metadata["obj_ids_all_gpu"] = np.concatenate(
1476
+ tracker_metadata["obj_ids_per_gpu"]
1477
+ )
1478
+ tracker_metadata["max_obj_id"] = max(tracker_metadata["max_obj_id"], obj_id)
1479
+
1480
+ logger.debug(
1481
+ f"[rank={self.rank}] Adding new object with id {obj_id} at frame {frame_idx}."
1482
+ )
1483
+ self.add_action_history(
1484
+ inference_state, "add", frame_idx=frame_idx, obj_ids=[obj_id]
1485
+ )
1486
+ else:
1487
+ # existing object, for refinement
1488
+ if self.rank == obj_rank:
1489
+ tracker_states = self._get_tracker_inference_states_by_obj_ids(
1490
+ inference_state, [obj_id]
1491
+ )
1492
+ assert (
1493
+ len(tracker_states) == 1
1494
+ ), f"[rank={self.rank}] Multiple Tracker inference states found for the same object id."
1495
+ tracker_state = tracker_states[0]
1496
+
1497
+ # log
1498
+ logger.debug(
1499
+ f"[rank={self.rank}] Refining existing object with id {obj_id} at frame {frame_idx}."
1500
+ )
1501
+ self.add_action_history(
1502
+ inference_state, "refine", frame_idx=frame_idx, obj_ids=[obj_id]
1503
+ )
1504
+
1505
+ # assign higher score to added/refined object
1506
+ tracker_metadata["obj_id_to_score"][obj_id] = 1.0
1507
+ tracker_metadata["obj_id_to_tracker_score_frame_wise"][frame_idx][obj_id] = 1.0
1508
+
1509
+ if self.rank == 0:
1510
+ rank0_metadata = tracker_metadata.get("rank0_metadata", {})
1511
+
1512
+ if "removed_obj_ids" in rank0_metadata:
1513
+ rank0_metadata["removed_obj_ids"].discard(obj_id)
1514
+
1515
+ if "suppressed_obj_ids" in rank0_metadata:
1516
+ for frame_id in rank0_metadata["suppressed_obj_ids"]:
1517
+ rank0_metadata["suppressed_obj_ids"][frame_id].discard(obj_id)
1518
+
1519
+ if "masklet_confirmation" in rank0_metadata:
1520
+ obj_ids_all_gpu = tracker_metadata["obj_ids_all_gpu"]
1521
+ obj_indices = np.where(obj_ids_all_gpu == obj_id)[0]
1522
+ if len(obj_indices) > 0:
1523
+ obj_idx = obj_indices[0]
1524
+ if obj_idx < len(rank0_metadata["masklet_confirmation"]["status"]):
1525
+ rank0_metadata["masklet_confirmation"]["status"][obj_idx] = 1
1526
+ rank0_metadata["masklet_confirmation"]["consecutive_det_num"][
1527
+ obj_idx
1528
+ ] = self.masklet_confirmation_consecutive_det_thresh
1529
+
1530
+ if self.rank == obj_rank:
1531
+ frame_idx, obj_ids, low_res_masks, video_res_masks = (
1532
+ self.tracker.add_new_points(
1533
+ inference_state=tracker_state,
1534
+ frame_idx=frame_idx,
1535
+ obj_id=obj_id,
1536
+ points=points,
1537
+ labels=labels,
1538
+ clear_old_points=True,
1539
+ rel_coordinates=rel_coordinates,
1540
+ use_prev_mem_frame=use_prev_mem_frame,
1541
+ )
1542
+ )
1543
+
1544
+ if video_res_masks is not None and len(video_res_masks) > 0:
1545
+ video_res_masks = fill_holes_in_mask_scores(
1546
+ video_res_masks, # shape (N, 1, H_video, W_video)
1547
+ max_area=self.fill_hole_area,
1548
+ fill_holes=True,
1549
+ remove_sprinkles=True,
1550
+ )
1551
+
1552
+ # Since the mem encoder has already run for the current input points?
1553
+ self.tracker.propagate_in_video_preflight(
1554
+ tracker_state, run_mem_encoder=True
1555
+ )
1556
+ # Clear detector conditioning frames when user clicks are received to allow the
1557
+ # model to update masks on these frames. It is a no-op if the user is refining on the
1558
+ # detector conditioning frames or adding new objects.
1559
+ self.clear_detector_added_cond_frame_in_tracker(
1560
+ tracker_state, obj_id, frame_idx
1561
+ )
1562
+
1563
+ # fetch results from states and gather across GPUs
1564
+ # Use optimized caching approach to avoid reprocessing unmodified objects
1565
+ if self.rank == obj_rank and len(obj_ids) > 0:
1566
+ new_mask_data = (video_res_masks[obj_ids.index(obj_id)] > 0.0).to(
1567
+ torch.bool
1568
+ )
1569
+ else:
1570
+ new_mask_data = None
1571
+ # Broadcast the new mask data across all ranks for consistency
1572
+ if self.world_size > 1:
1573
+ data_list = [new_mask_data.cpu() if new_mask_data is not None else None]
1574
+ self.broadcast_python_obj_cpu(data_list, src=obj_rank)
1575
+ new_mask_data = data_list[0].to(self.device)
1576
+
1577
+ if self.rank == 0:
1578
+ obj_id_to_mask = self._build_tracker_output(
1579
+ inference_state,
1580
+ frame_idx,
1581
+ {obj_id: new_mask_data} if new_mask_data is not None else None,
1582
+ )
1583
+ # post processing - remove suppressed obj_ids
1584
+ obj_id_to_score = tracker_metadata["obj_id_to_score"]
1585
+ suppressed_obj_ids = tracker_metadata["rank0_metadata"][
1586
+ "suppressed_obj_ids"
1587
+ ][frame_idx]
1588
+ obj_id_to_tracker_score = tracker_metadata[
1589
+ "obj_id_to_tracker_score_frame_wise"
1590
+ ][frame_idx]
1591
+
1592
+ out = {
1593
+ "obj_id_to_mask": obj_id_to_mask,
1594
+ "obj_id_to_score": obj_id_to_score,
1595
+ "obj_id_to_tracker_score": obj_id_to_tracker_score,
1596
+ }
1597
+ self._cache_frame_outputs(
1598
+ inference_state,
1599
+ frame_idx,
1600
+ obj_id_to_mask,
1601
+ suppressed_obj_ids=suppressed_obj_ids,
1602
+ )
1603
+ return frame_idx, self._postprocess_output(
1604
+ inference_state, out, suppressed_obj_ids=suppressed_obj_ids
1605
+ )
1606
+ else:
1607
+ return frame_idx, None # no output on other GPUs
1608
+
1609
+ def _gather_obj_id_to_mask_across_gpus(self, inference_state, obj_id_to_mask_local):
1610
+ """Gather obj_id_to_mask (low-res masks) from all GPUs into a single tensor."""
1611
+ tracker_metadata = inference_state["tracker_metadata"]
1612
+
1613
+ # concatenate the output masklets from all local inference states
1614
+ H_mask = W_mask = self.tracker.low_res_mask_size
1615
+ obj_ids_local = tracker_metadata["obj_ids_per_gpu"][self.rank]
1616
+ low_res_masks_local = []
1617
+ for obj_id in obj_ids_local:
1618
+ if obj_id in obj_id_to_mask_local:
1619
+ low_res_masks_local.append(obj_id_to_mask_local[obj_id])
1620
+ else:
1621
+ low_res_masks_local.append(
1622
+ torch.full((H_mask, W_mask), -1024.0, device=self.device)
1623
+ )
1624
+ if len(low_res_masks_local) > 0:
1625
+ low_res_masks_local = torch.stack(low_res_masks_local, dim=0) # (N, H, W)
1626
+ assert low_res_masks_local.shape[1:] == (H_mask, W_mask)
1627
+ else:
1628
+ low_res_masks_local = torch.zeros(0, H_mask, W_mask, device=self.device)
1629
+
1630
+ # all-gather `low_res_masks_local` into `low_res_masks_global`
1631
+ # - low_res_masks_global: Tensor -- (num_global_obj, H_mask, W_mask)
1632
+ if self.world_size > 1:
1633
+ low_res_masks_local = low_res_masks_local.float().contiguous()
1634
+ low_res_masks_peers = [
1635
+ low_res_masks_local.new_empty(num_obj, H_mask, W_mask)
1636
+ for num_obj in tracker_metadata["num_obj_per_gpu"]
1637
+ ]
1638
+ dist.all_gather(low_res_masks_peers, low_res_masks_local)
1639
+ low_res_masks_global = torch.cat(low_res_masks_peers, dim=0)
1640
+ else:
1641
+ low_res_masks_global = low_res_masks_local
1642
+ return low_res_masks_global
1643
+
1644
+ def _convert_low_res_mask_to_video_res(self, low_res_mask, inference_state):
1645
+ """
1646
+ Convert a low-res mask to video resolution, matching the format expected by _build_tracker_output.
1647
+
1648
+ Args:
1649
+ low_res_mask: Tensor of shape (H_low_res, W_low_res)
1650
+ inference_state: Contains video dimensions
1651
+
1652
+ Returns:
1653
+ video_res_mask: Tensor of shape (1, H_video, W_video) bool
1654
+ """
1655
+ if low_res_mask is None:
1656
+ return None
1657
+
1658
+ # Convert to 4D for interpolation: (H_low_res, W_low_res) -> (1, 1, H_low_res, W_low_res)
1659
+ low_res_mask_3d = low_res_mask.unsqueeze(0).unsqueeze(0)
1660
+
1661
+ # Get video dimensions
1662
+ H_video = inference_state["orig_height"]
1663
+ W_video = inference_state["orig_width"]
1664
+
1665
+ video_res_mask = F.interpolate(
1666
+ low_res_mask_3d.float(),
1667
+ size=(H_video, W_video),
1668
+ mode="bilinear",
1669
+ align_corners=False,
1670
+ ) # (1, 1, H_video, W_video)
1671
+
1672
+ # Convert to boolean - already in the right shape!
1673
+ return (video_res_mask.squeeze(0) > 0.0).to(torch.bool)
1674
+
1675
+ def clear_detector_added_cond_frame_in_tracker(
1676
+ self, tracker_state, obj_id, refined_frame_idx
1677
+ ):
1678
+ """Clear detector-added conditioning frames if they are within a predefined window
1679
+ of the refined frame. This allows the model to update masks on these frames."""
1680
+ obj_idx = self.tracker._obj_id_to_idx(tracker_state, obj_id)
1681
+
1682
+ mask_only_cond_frame_indices = []
1683
+ window = self.refinement_detector_cond_frame_removal_window
1684
+ for frame_idx in tracker_state["mask_inputs_per_obj"][obj_idx]:
1685
+ if frame_idx not in tracker_state["point_inputs_per_obj"][obj_idx]:
1686
+ # clear conditioning frames within a window of the refined frame
1687
+ if abs(frame_idx - refined_frame_idx) <= window:
1688
+ mask_only_cond_frame_indices.append(frame_idx)
1689
+
1690
+ # clear
1691
+ if len(mask_only_cond_frame_indices) > 0:
1692
+ for frame_idx in mask_only_cond_frame_indices:
1693
+ # obj_ids_on_this_frame is essentially all obj_ids in the state
1694
+ # since they are bucket batched
1695
+ obj_ids_on_this_frame = tracker_state["obj_id_to_idx"].keys()
1696
+ for obj_id2 in obj_ids_on_this_frame:
1697
+ self.tracker.clear_all_points_in_frame(
1698
+ tracker_state, frame_idx, obj_id2, need_output=False
1699
+ )
1700
+ logger.debug(
1701
+ f"Cleared detector mask only conditioning frames ({mask_only_cond_frame_indices}) in Tracker."
1702
+ )
1703
+ return
1704
+
1705
+
1706
+ def is_image_type(resource_path: str) -> bool:
1707
+ if isinstance(resource_path, list):
1708
+ return len(resource_path) == 1
1709
+ return resource_path.lower().endswith(tuple(IMAGE_EXTS))
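The helper `_convert_low_res_mask_to_video_res` above upsamples a low-resolution logit map and thresholds it at 0.0. A minimal standalone sketch of that conversion (the 288x288 low-res size and 720x1280 video size are made-up example values):

import torch
import torch.nn.functional as F

low_res_logits = torch.randn(288, 288)               # (H_low_res, W_low_res) mask logits
x = low_res_logits[None, None].float()               # add batch and channel dims -> (1, 1, H, W)
video_res = F.interpolate(x, size=(720, 1280), mode="bilinear", align_corners=False)
video_res_mask = (video_res.squeeze(0) > 0.0).to(torch.bool)   # (1, H_video, W_video) bool
print(video_res_mask.shape)                           # torch.Size([1, 720, 1280])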
detect_tools/sam3/sam3/model/sam3_video_predictor.py ADDED
@@ -0,0 +1,521 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ import datetime
4
+ import gc
5
+ import multiprocessing as mp
6
+ import os
7
+ import queue
8
+ import socket
9
+ import sys
10
+ import time
11
+ import uuid
12
+ from contextlib import closing
13
+ from typing import List, Optional
14
+
15
+ import psutil
16
+ import torch
17
+
18
+ from sam3.logger import get_logger
19
+
20
+ logger = get_logger(__name__)
21
+
22
+
23
+ class Sam3VideoPredictor:
24
+ # a global dictionary that holds all inference states for this model (key is session_id)
25
+ _ALL_INFERENCE_STATES = {}
26
+
27
+ def __init__(
28
+ self,
29
+ checkpoint_path=None,
30
+ bpe_path=None,
31
+ has_presence_token=True,
32
+ geo_encoder_use_img_cross_attn=True,
33
+ strict_state_dict_loading=True,
34
+ async_loading_frames=False,
35
+ video_loader_type="cv2",
36
+ apply_temporal_disambiguation: bool = True,
37
+ ):
38
+ self.async_loading_frames = async_loading_frames
39
+ self.video_loader_type = video_loader_type
40
+ from sam3.model_builder import build_sam3_video_model
41
+
42
+ self.model = (
43
+ build_sam3_video_model(
44
+ checkpoint_path=checkpoint_path,
45
+ bpe_path=bpe_path,
46
+ has_presence_token=has_presence_token,
47
+ geo_encoder_use_img_cross_attn=geo_encoder_use_img_cross_attn,
48
+ strict_state_dict_loading=strict_state_dict_loading,
49
+ apply_temporal_disambiguation=apply_temporal_disambiguation,
50
+ )
51
+ .cuda()
52
+ .eval()
53
+ )
54
+
55
+ @torch.inference_mode()
56
+ def handle_request(self, request):
57
+ """Dispatch a request based on its type."""
58
+ request_type = request["type"]
59
+ if request_type == "start_session":
60
+ return self.start_session(
61
+ resource_path=request["resource_path"],
62
+ session_id=request.get("session_id", None),
63
+ )
64
+ elif request_type == "add_prompt":
65
+ return self.add_prompt(
66
+ session_id=request["session_id"],
67
+ frame_idx=request["frame_index"],
68
+ text=request.get("text", None),
69
+ points=request.get("points", None),
70
+ point_labels=request.get("point_labels", None),
71
+ bounding_boxes=request.get("bounding_boxes", None),
72
+ bounding_box_labels=request.get("bounding_box_labels", None),
73
+ obj_id=request.get("obj_id", None),
74
+ )
75
+ elif request_type == "remove_object":
76
+ return self.remove_object(
77
+ session_id=request["session_id"],
78
+ obj_id=request["obj_id"],
79
+ is_user_action=request.get("is_user_action", True),
80
+ )
81
+ elif request_type == "reset_session":
82
+ return self.reset_session(session_id=request["session_id"])
83
+ elif request_type == "close_session":
84
+ return self.close_session(session_id=request["session_id"])
85
+ else:
86
+ raise RuntimeError(f"invalid request type: {request_type}")
87
+
88
+ @torch.inference_mode()
89
+ def handle_stream_request(self, request):
90
+ """Dispatch a stream request based on its type."""
91
+ request_type = request["type"]
92
+ if request_type == "propagate_in_video":
93
+ yield from self.propagate_in_video(
94
+ session_id=request["session_id"],
95
+ propagation_direction=request.get("propagation_direction", "both"),
96
+ start_frame_idx=request.get("start_frame_index", None),
97
+ max_frame_num_to_track=request.get("max_frame_num_to_track", None),
98
+ )
99
+ else:
100
+ raise RuntimeError(f"invalid request type: {request_type}")
101
+
102
+ def start_session(self, resource_path, session_id=None):
103
+ """
104
+ Start a new inference session on an image or a video. Here `resource_path`
105
+ can be either a path to an image file (for image inference) or an MP4 file
106
+ or directory with JPEG video frames (for video inference).
107
+
108
+ If `session_id` is defined, it will be used as the identifier for the
109
+ session. If it is not defined, the start_session function will create
110
+ a session id and return it.
111
+ """
112
+ # get an initial inference_state from the model
113
+ inference_state = self.model.init_state(
114
+ resource_path=resource_path,
115
+ async_loading_frames=self.async_loading_frames,
116
+ video_loader_type=self.video_loader_type,
117
+ )
118
+ if not session_id:
119
+ session_id = str(uuid.uuid4())
120
+ self._ALL_INFERENCE_STATES[session_id] = {
121
+ "state": inference_state,
122
+ "session_id": session_id,
123
+ "start_time": time.time(),
124
+ }
125
+ logger.debug(
126
+ f"started new session {session_id}; {self._get_session_stats()}; "
127
+ f"{self._get_torch_and_gpu_properties()}"
128
+ )
129
+ return {"session_id": session_id}
130
+
131
+ def add_prompt(
132
+ self,
133
+ session_id: str,
134
+ frame_idx: int,
135
+ text: Optional[str] = None,
136
+ points: Optional[List[List[float]]] = None,
137
+ point_labels: Optional[List[int]] = None,
138
+ bounding_boxes: Optional[List[List[float]]] = None,
139
+ bounding_box_labels: Optional[List[int]] = None,
140
+ obj_id: Optional[int] = None,
141
+ ):
142
+ """Add text, box and/or point prompt on a specific video frame."""
143
+ logger.debug(
144
+ f"add prompt on frame {frame_idx} in session {session_id}: "
145
+ f"{text=}, {points=}, {point_labels=}, "
146
+ f"{bounding_boxes=}, {bounding_box_labels=}"
147
+ )
148
+ session = self._get_session(session_id)
149
+ inference_state = session["state"]
150
+
151
+ frame_idx, outputs = self.model.add_prompt(
152
+ inference_state=inference_state,
153
+ frame_idx=frame_idx,
154
+ text_str=text,
155
+ points=points,
156
+ point_labels=point_labels,
157
+ boxes_xywh=bounding_boxes,
158
+ box_labels=bounding_box_labels,
159
+ obj_id=obj_id,
160
+ )
161
+ return {"frame_index": frame_idx, "outputs": outputs}
162
+
163
+ def remove_object(
164
+ self,
165
+ session_id: str,
166
+ obj_id: int,
167
+ is_user_action: bool = True,
168
+ ):
169
+ """Remove an object from tracking."""
170
+ logger.debug(
171
+ f"remove object {obj_id} in session {session_id}: " f"{is_user_action=}"
172
+ )
173
+ session = self._get_session(session_id)
174
+ inference_state = session["state"]
175
+
176
+ self.model.remove_object(
177
+ inference_state=inference_state,
178
+ obj_id=obj_id,
179
+ is_user_action=is_user_action,
180
+ )
181
+ return {"is_success": True}
182
+
183
+ def propagate_in_video(
184
+ self,
185
+ session_id,
186
+ propagation_direction,
187
+ start_frame_idx,
188
+ max_frame_num_to_track,
189
+ ):
190
+ """Propagate the added prompts to get grounding results on all video frames."""
191
+ logger.debug(
192
+ f"propagate in video in session {session_id}: "
193
+ f"{propagation_direction=}, {start_frame_idx=}, {max_frame_num_to_track=}"
194
+ )
195
+ try:
196
+ session = self._get_session(session_id)
197
+ inference_state = session["state"]
198
+ if propagation_direction not in ["both", "forward", "backward"]:
199
+ raise ValueError(
200
+ f"invalid propagation direction: {propagation_direction}"
201
+ )
202
+
203
+ # First doing the forward propagation
204
+ if propagation_direction in ["both", "forward"]:
205
+ for frame_idx, outputs in self.model.propagate_in_video(
206
+ inference_state=inference_state,
207
+ start_frame_idx=start_frame_idx,
208
+ max_frame_num_to_track=max_frame_num_to_track,
209
+ reverse=False,
210
+ ):
211
+ yield {"frame_index": frame_idx, "outputs": outputs}
212
+ # Then doing the backward propagation (reverse in time)
213
+ if propagation_direction in ["both", "backward"]:
214
+ for frame_idx, outputs in self.model.propagate_in_video(
215
+ inference_state=inference_state,
216
+ start_frame_idx=start_frame_idx,
217
+ max_frame_num_to_track=max_frame_num_to_track,
218
+ reverse=True,
219
+ ):
220
+ yield {"frame_index": frame_idx, "outputs": outputs}
221
+ finally:
222
+ # Log upon completion (so that e.g. we can see if two propagations happen in parallel).
223
+ # Using `finally` here to log even when the tracking is aborted with GeneratorExit.
224
+ logger.debug(
225
+ f"propagation ended in session {session_id}; {self._get_session_stats()}"
226
+ )
227
+
228
+ def reset_session(self, session_id):
229
+ """Reset the session to its initial state (as when it's initial opened)."""
230
+ logger.debug(f"reset session {session_id}")
231
+ session = self._get_session(session_id)
232
+ inference_state = session["state"]
233
+ self.model.reset_state(inference_state)
234
+ return {"is_success": True}
235
+
236
+ def close_session(self, session_id):
237
+ """
238
+ Close a session. This method is idempotent and can be called multiple
239
+ times on the same "session_id".
240
+ """
241
+ session = self._ALL_INFERENCE_STATES.pop(session_id, None)
242
+ if session is None:
243
+ logger.warning(
244
+ f"cannot close session {session_id} as it does not exist (it might have expired); "
245
+ f"{self._get_session_stats()}"
246
+ )
247
+ else:
248
+ del session
249
+ gc.collect()
250
+ logger.info(f"removed session {session_id}; {self._get_session_stats()}")
251
+ return {"is_success": True}
252
+
253
+ def _get_session(self, session_id):
254
+ session = self._ALL_INFERENCE_STATES.get(session_id, None)
255
+ if session is None:
256
+ raise RuntimeError(
257
+ f"Cannot find session {session_id}; it might have expired"
258
+ )
259
+ return session
260
+
261
+ def _get_session_stats(self):
262
+ """Get a statistics string for live sessions and their GPU usage."""
263
+ # print both the session ids and their video frame numbers
264
+ live_session_strs = [
265
+ f"'{session_id}' ({session['state']['num_frames']} frames)"
266
+ for session_id, session in self._ALL_INFERENCE_STATES.items()
267
+ ]
268
+ session_stats_str = (
269
+ f"live sessions: [{', '.join(live_session_strs)}], GPU memory: "
270
+ f"{torch.cuda.memory_allocated() // 1024**2} MiB used and "
271
+ f"{torch.cuda.memory_reserved() // 1024**2} MiB reserved"
272
+ f" (max over time: {torch.cuda.max_memory_allocated() // 1024**2} MiB used "
273
+ f"and {torch.cuda.max_memory_reserved() // 1024**2} MiB reserved)"
274
+ )
275
+ return session_stats_str
276
+
277
+ def _get_torch_and_gpu_properties(self):
278
+ """Get a string for PyTorch and GPU properties (for logging and debugging)."""
279
+ torch_and_gpu_str = (
280
+ f"torch: {torch.__version__} with CUDA arch {torch.cuda.get_arch_list()}, "
281
+ f"GPU device: {torch.cuda.get_device_properties(torch.cuda.current_device())}"
282
+ )
283
+ return torch_and_gpu_str
284
+
285
+ def shutdown(self):
286
+ """Shutdown the predictor and clear all sessions."""
287
+ self._ALL_INFERENCE_STATES.clear()
288
+
289
+
290
+ class Sam3VideoPredictorMultiGPU(Sam3VideoPredictor):
291
+ def __init__(self, *model_args, gpus_to_use=None, **model_kwargs):
292
+ if gpus_to_use is None:
293
+ # if not specified, use only the current GPU by default
294
+ gpus_to_use = [torch.cuda.current_device()]
295
+
296
+ IS_MAIN_PROCESS = os.getenv("IS_MAIN_PROCESS", "1") == "1"
297
+ if IS_MAIN_PROCESS:
298
+ gpus_to_use = sorted(set(gpus_to_use))
299
+ logger.info(f"using the following GPU IDs: {gpus_to_use}")
300
+ assert len(gpus_to_use) > 0 and all(isinstance(i, int) for i in gpus_to_use)
301
+ assert all(0 <= i < torch.cuda.device_count() for i in gpus_to_use)
302
+ os.environ["MASTER_ADDR"] = "localhost"
303
+ os.environ["MASTER_PORT"] = f"{self._find_free_port()}"
304
+ os.environ["RANK"] = "0"
305
+ os.environ["WORLD_SIZE"] = f"{len(gpus_to_use)}"
306
+
307
+ self.gpus_to_use = gpus_to_use
308
+ self.rank = int(os.environ["RANK"])
309
+ self.world_size = int(os.environ["WORLD_SIZE"])
310
+ self.rank_str = f"rank={self.rank} with world_size={self.world_size}"
311
+ self.device = torch.device(f"cuda:{self.gpus_to_use[self.rank]}")
312
+ torch.cuda.set_device(self.device)
313
+ self.has_shutdown = False
314
+ if self.rank == 0:
315
+ logger.info("\n\n\n\t*** START loading model on all ranks ***\n\n")
316
+
317
+ logger.info(f"loading model on {self.rank_str} -- this could take a while ...")
318
+ super().__init__(*model_args, **model_kwargs)
319
+ logger.info(f"loading model on {self.rank_str} -- DONE locally")
320
+
321
+ if self.world_size > 1 and self.rank == 0:
322
+ # start the worker processes *after* the model is loaded in the main process
323
+ # so that the main process can run torch.compile and fill the cache first
324
+ self._start_worker_processes(*model_args, **model_kwargs)
325
+ for rank in range(1, self.world_size):
326
+ self.command_queues[rank].put(("start_nccl_process_group", None))
327
+ self._start_nccl_process_group()
328
+
329
+ if self.rank == 0:
330
+ logger.info("\n\n\n\t*** DONE loading model on all ranks ***\n\n")
331
+
332
+ @torch.inference_mode()
333
+ def handle_request(self, request):
334
+ """Dispatch a request based on its type."""
335
+ if self.has_shutdown:
336
+ raise RuntimeError(
337
+ "cannot handle request after the predictor has shutdown; please create a new predictor"
338
+ )
339
+
340
+ # when starting a session, we need to create a session id before dispatching
341
+ # the request to the workers
342
+ if request["type"] == "start_session" and request.get("session_id") is None:
343
+ request["session_id"] = str(uuid.uuid4())
344
+ # dispatch the request to all worker processes
345
+ if self.world_size > 1 and self.rank == 0:
346
+ for rank in range(1, self.world_size):
347
+ self.command_queues[rank].put((request, False))
348
+
349
+ response = super().handle_request(request)
350
+
351
+ if self.world_size > 1:
352
+ torch.distributed.barrier() # wait for all ranks to finish
353
+ return response
354
+
355
+ @torch.inference_mode()
356
+ def handle_stream_request(self, request):
357
+ """Dispatch a stream request based on its type."""
358
+ if self.has_shutdown:
359
+ raise RuntimeError(
360
+ "cannot handle request after the predictor has shutdown; please create a new predictor"
361
+ )
362
+
363
+ # dispatch the request to all worker processes
364
+ if self.world_size > 1 and self.rank == 0:
365
+ for rank in range(1, self.world_size):
366
+ self.command_queues[rank].put((request, True))
367
+
368
+ yield from super().handle_stream_request(request)
369
+
370
+ if self.world_size > 1:
371
+ torch.distributed.barrier() # wait for all ranks to finish
372
+
373
+ def _start_worker_processes(self, *model_args, **model_kwargs):
374
+ """Start worker processes for handling model inference."""
375
+ world_size = self.world_size
376
+ logger.info(f"spawning {world_size - 1} worker processes")
377
+ # Use "spawn" (instead of "fork") for different PyTorch or CUDA context
378
+ mp_ctx = mp.get_context("spawn")
379
+ self.command_queues = {rank: mp_ctx.Queue() for rank in range(1, world_size)}
380
+ self.result_queues = {rank: mp_ctx.Queue() for rank in range(1, world_size)}
381
+ parent_pid = os.getpid()
382
+ for rank in range(1, world_size):
383
+ # set the environment variables for each worker process
384
+ os.environ["IS_MAIN_PROCESS"] = "0" # mark this as a worker process
385
+ os.environ["RANK"] = f"{rank}"
386
+ worker_process = mp_ctx.Process(
387
+ target=Sam3VideoPredictorMultiGPU._worker_process_command_loop,
388
+ args=(
389
+ rank,
390
+ world_size,
391
+ self.command_queues[rank],
392
+ self.result_queues[rank],
393
+ model_args,
394
+ model_kwargs,
395
+ self.gpus_to_use,
396
+ parent_pid,
397
+ ),
398
+ daemon=True,
399
+ )
400
+ worker_process.start()
401
+ # revert the environment variables for the main process
402
+ os.environ["IS_MAIN_PROCESS"] = "1"
403
+ os.environ["RANK"] = "0"
404
+ # wait for all the worker processes to load the model and collect their PIDs
405
+ self.worker_pids = {}
406
+ for rank in range(1, self.world_size):
407
+ # a large timeout to cover potentially long model loading time due to compilation
408
+ _, worker_pid = self.result_queues[rank].get(timeout=7200)
409
+ self.worker_pids[rank] = worker_pid
410
+ logger.info(f"spawned {world_size - 1} worker processes")
411
+
412
+ def _start_nccl_process_group(self):
413
+ rank = int(os.environ["RANK"])
414
+ world_size = int(os.environ["WORLD_SIZE"])
415
+ if world_size == 1:
416
+ return
417
+
418
+ logger.debug(f"starting NCCL process group on {rank=} with {world_size=}")
419
+ assert not torch.distributed.is_initialized()
420
+ # use the "env://" init method with environment variables set in start_worker_processes
421
+ # a short 3-min timeout to quickly detect any synchronization failures
422
+ timeout_sec = int(os.getenv("SAM3_COLLECTIVE_OP_TIMEOUT_SEC", "180"))
423
+ timeout = datetime.timedelta(seconds=timeout_sec)
424
+ torch.distributed.init_process_group(
425
+ backend="nccl",
426
+ init_method="env://",
427
+ timeout=timeout,
428
+ device_id=self.device,
429
+ )
430
+ # warm-up the NCCL process group by running a dummy all-reduce
431
+ tensor = torch.ones(1024, 1024).cuda()
432
+ torch.distributed.all_reduce(tensor)
433
+ logger.debug(f"started NCCL process group on {rank=} with {world_size=}")
434
+
435
+ def _find_free_port(self) -> int:
436
+ """
437
+ Find a free port (a random free port from 1024 to 65535 will be selected)
438
+ (see https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number)
439
+ """
440
+ with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
441
+ s.bind(("", 0))
442
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
443
+ return s.getsockname()[1]
444
+
445
+ @staticmethod
446
+ def _worker_process_command_loop(
447
+ rank,
448
+ world_size,
449
+ command_queue,
450
+ result_queue,
451
+ model_args,
452
+ model_kwargs,
453
+ gpus_to_use,
454
+ parent_pid,
455
+ ):
456
+ """
457
+ The command loop for each worker process. It listens to commands from the main process
458
+ and executes them using the model.
459
+ """
460
+ logger.info(f"starting worker process {rank=} with {world_size=}")
461
+ # verify that the environment variables are set correctly
462
+ assert int(os.environ["IS_MAIN_PROCESS"]) == 0
463
+ assert int(os.environ["RANK"]) == rank
464
+ assert int(os.environ["WORLD_SIZE"]) == world_size
465
+ # load the model in this worker process
466
+ predictor = Sam3VideoPredictorMultiGPU(
467
+ *model_args, gpus_to_use=gpus_to_use, **model_kwargs
468
+ )
469
+ logger.info(f"started worker {rank=} with {world_size=}")
470
+ # return the worker process id to the main process for bookkeeping
471
+ worker_pid = os.getpid()
472
+ result_queue.put(("load_model", worker_pid))
473
+
474
+ # wait for the command to start the NCCL process group
475
+ request_type, _ = command_queue.get(timeout=7200)
476
+ assert request_type == "start_nccl_process_group"
477
+ predictor._start_nccl_process_group()
478
+
479
+ # keep listening to commands from the main process
480
+ while True:
481
+ try:
482
+ request, is_stream_request = command_queue.get(timeout=5.0)
483
+ if request == "shutdown":
484
+ logger.info(f"worker {rank=} shutting down")
485
+ torch.distributed.destroy_process_group()
486
+ result_queue.put(("shutdown", True)) # acknowledge the shutdown
487
+ sys.exit(0)
488
+
489
+ logger.debug(f"worker {rank=} received request {request['type']=}")
490
+ if is_stream_request:
491
+ for _ in predictor.handle_stream_request(request):
492
+ pass # handle stream requests in a generator fashion
493
+ else:
494
+ predictor.handle_request(request)
495
+ except queue.Empty:
496
+ # Usually Python's multiprocessing module will shut down all the daemon worker
497
+ # processes when the main process exits gracefully. However, the user may kill
498
+ # the main process with SIGKILL, thereby leaving no chance for the main process
499
+ # to clean up its daemon child processes. So here we manually check whether the
500
+ # parent process still exists (every 5 sec as in `command_queue.get` timeout).
501
+ if not psutil.pid_exists(parent_pid):
502
+ logger.info(
503
+ f"stopping worker {rank=} as its parent process has exited"
504
+ )
505
+ sys.exit(1)
506
+ except Exception as e:
507
+ logger.error(f"worker {rank=} exception: {e}", exc_info=True)
508
+
509
+ def shutdown(self):
510
+ """Shutdown all worker processes."""
511
+ if self.rank == 0 and self.world_size > 1:
512
+ logger.info(f"shutting down {self.world_size - 1} worker processes")
513
+ for rank in range(1, self.world_size):
514
+ self.command_queues[rank].put(("shutdown", False))
515
+ torch.distributed.destroy_process_group()
516
+ for rank in range(1, self.world_size):
517
+ self.result_queues[rank].get() # wait for the worker to acknowledge
518
+ logger.info(f"shut down {self.world_size - 1} worker processes")
519
+ self.has_shutdown = True
520
+
521
+ super().shutdown()
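A hypothetical end-to-end sketch of driving the predictor above through its request API. The checkpoint, BPE, and video paths are placeholders, a CUDA GPU is required since the model is moved to `.cuda()`, and the exact structure of `outputs` depends on the model's post-processing:

predictor = Sam3VideoPredictor(
    checkpoint_path="/path/to/sam3_checkpoint.pt",   # placeholder
    bpe_path="/path/to/bpe_vocab.gz",                # placeholder
)
session = predictor.handle_request(
    {"type": "start_session", "resource_path": "/path/to/video.mp4"}
)
session_id = session["session_id"]

# add a text prompt on frame 0
predictor.handle_request(
    {"type": "add_prompt", "session_id": session_id, "frame_index": 0, "text": "red car"}
)

# stream per-frame outputs while propagating forward through the video
for result in predictor.handle_stream_request(
    {"type": "propagate_in_video", "session_id": session_id, "propagation_direction": "forward"}
):
    print(result["frame_index"])

predictor.handle_request({"type": "close_session", "session_id": session_id})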
detect_tools/sam3/sam3/model/text_encoder_ve.py ADDED
@@ -0,0 +1,328 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ from collections import OrderedDict
4
+ from typing import Callable, List, Optional, Tuple, Union
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.utils.checkpoint import checkpoint
9
+
10
+ from .model_misc import LayerScale
11
+
12
+
13
+ class ResidualAttentionBlock(nn.Module):
14
+ def __init__(
15
+ self,
16
+ d_model: int,
17
+ n_head: int,
18
+ mlp_ratio: float = 4.0,
19
+ ls_init_value: Optional[float] = None,
20
+ act_layer: Callable[[], nn.Module] = nn.GELU,
21
+ norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
22
+ ):
23
+ super().__init__()
24
+ # Attention
25
+ self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)
26
+
27
+ # LayerNorm, LayerScale
28
+ self.ln_1 = norm_layer(d_model)
29
+ self.ln_2 = norm_layer(d_model)
30
+
31
+ self.ls_1 = (
32
+ LayerScale(d_model, ls_init_value)
33
+ if ls_init_value is not None
34
+ else nn.Identity()
35
+ )
36
+ self.ls_2 = (
37
+ LayerScale(d_model, ls_init_value)
38
+ if ls_init_value is not None
39
+ else nn.Identity()
40
+ )
41
+
42
+ # MLP
43
+ mlp_width = int(d_model * mlp_ratio)
44
+ self.mlp = nn.Sequential(
45
+ OrderedDict(
46
+ [
47
+ ("c_fc", nn.Linear(d_model, mlp_width)),
48
+ ("gelu", act_layer()),
49
+ ("c_proj", nn.Linear(mlp_width, d_model)),
50
+ ]
51
+ )
52
+ )
53
+
54
+ def attention(
55
+ self,
56
+ q_x: torch.Tensor,
57
+ k_x: Optional[torch.Tensor] = None,
58
+ v_x: Optional[torch.Tensor] = None,
59
+ attn_mask: Optional[torch.Tensor] = None,
60
+ ) -> torch.Tensor:
61
+ k_x = k_x if k_x is not None else q_x
62
+ v_x = v_x if v_x is not None else q_x
63
+ if attn_mask is not None:
64
+ # Leave boolean masks as is
65
+ if not attn_mask.dtype == torch.bool:
66
+ attn_mask = attn_mask.to(q_x.dtype)
67
+
68
+ return self.attn(q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask)[0]
69
+
70
+ def forward(
71
+ self,
72
+ q_x: torch.Tensor,
73
+ k_x: Optional[torch.Tensor] = None,
74
+ v_x: Optional[torch.Tensor] = None,
75
+ attn_mask: Optional[torch.Tensor] = None,
76
+ ) -> torch.Tensor:
77
+ k_x = (
78
+ self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
79
+ )
80
+ v_x = (
81
+ self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
82
+ )
83
+ x = q_x + self.ls_1(
84
+ self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)
85
+ )
86
+ x = x + self.ls_2(self.mlp(self.ln_2(x)))
87
+ return x
88
+
89
+
90
+ class Transformer(nn.Module):
91
+ def __init__(
92
+ self,
93
+ width: int,
94
+ layers: int,
95
+ heads: int,
96
+ mlp_ratio: float = 4.0,
97
+ ls_init_value: Optional[float] = None,
98
+ act_layer: Callable[[], nn.Module] = nn.GELU,
99
+ norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
100
+ compile_mode: Optional[str] = None,
101
+ use_act_checkpoint: bool = False,
102
+ ):
103
+ super().__init__()
104
+ self.width = width
105
+ self.layers = layers
106
+ self.grad_checkpointing = use_act_checkpoint
107
+ self.resblocks = nn.ModuleList(
108
+ [
109
+ ResidualAttentionBlock(
110
+ width,
111
+ heads,
112
+ mlp_ratio,
113
+ ls_init_value=ls_init_value,
114
+ act_layer=act_layer,
115
+ norm_layer=norm_layer,
116
+ )
117
+ for _ in range(layers)
118
+ ]
119
+ )
120
+
121
+ if compile_mode is not None:
122
+ self.forward = torch.compile(
123
+ self.forward, mode=compile_mode, fullgraph=True
124
+ )
125
+ if self.grad_checkpointing:
126
+ torch._dynamo.config.optimize_ddp = False
127
+
128
+ def forward(
129
+ self,
130
+ x: torch.Tensor,
131
+ attn_mask: Optional[torch.Tensor] = None,
132
+ ) -> torch.Tensor:
133
+ for _, r in enumerate(self.resblocks):
134
+ if (
135
+ self.grad_checkpointing
136
+ and not torch.jit.is_scripting()
137
+ and self.training
138
+ ):
139
+ x = checkpoint(r, x, None, None, attn_mask, use_reentrant=False)
140
+ else:
141
+ x = r(
142
+ x,
143
+ attn_mask=attn_mask,
144
+ )
145
+ return x
146
+
147
+
148
+ def text_global_pool(
149
+ x: torch.Tensor, text: Optional[torch.Tensor] = None, pool_type: str = "argmax"
150
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
151
+ if pool_type == "first":
152
+ pooled, tokens = x[:, 0], x[:, 1:]
153
+ elif pool_type == "last":
154
+ pooled, tokens = x[:, -1], x[:, :-1]
155
+ elif pool_type == "argmax":
156
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
157
+ assert text is not None
158
+ pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x
159
+ else:
160
+ pooled = tokens = x
161
+ return pooled, tokens
162
+
163
+
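# A small standalone check of the "argmax" pooling above: with CLIP-style tokenization the
# end-of-text token has the largest id in each row, so text.argmax(dim=-1) picks the EOT
# position (the token ids below are made up for illustration).
import torch

x = torch.randn(2, 5, 8)                                   # (batch, seq_len, width) token features
text = torch.tensor([[49406, 320, 1929, 49407, 0],         # EOT (49407) at position 3
                     [49406, 2368, 49407, 0, 0]])          # EOT at position 2
pooled = x[torch.arange(x.shape[0]), text.argmax(dim=-1)]  # (2, 8): features at the EOT positions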
164
+ class TextTransformer(nn.Module):
165
+ def __init__(
166
+ self,
167
+ context_length: int = 77,
168
+ vocab_size: int = 49408,
169
+ width: int = 512,
170
+ heads: int = 8,
171
+ layers: int = 12,
172
+ mlp_ratio: float = 4.0,
173
+ ls_init_value: Optional[float] = None,
174
+ output_dim: int = 512,
175
+ no_causal_mask: bool = False,
176
+ pool_type: str = "none", # no pooling
177
+ proj_bias: bool = False,
178
+ act_layer: Callable = nn.GELU,
179
+ norm_layer: Callable = nn.LayerNorm,
180
+ output_tokens: bool = False,
181
+ use_ln_post: bool = True,
182
+ compile_mode: Optional[str] = None,
183
+ use_act_checkpoint: bool = False,
184
+ ):
185
+ super().__init__()
186
+ assert pool_type in ("first", "last", "argmax", "none")
187
+ self.output_tokens = output_tokens
188
+ self.num_pos = self.context_length = context_length
189
+ self.vocab_size = vocab_size
190
+ self.width = width
191
+ self.output_dim = output_dim
192
+ self.heads = heads
193
+ self.pool_type = pool_type
194
+
195
+ self.token_embedding = nn.Embedding(self.vocab_size, width)
196
+ self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
197
+ self.transformer = Transformer(
198
+ width=width,
199
+ layers=layers,
200
+ heads=heads,
201
+ mlp_ratio=mlp_ratio,
202
+ ls_init_value=ls_init_value,
203
+ act_layer=act_layer,
204
+ norm_layer=norm_layer,
205
+ compile_mode=compile_mode,
206
+ use_act_checkpoint=use_act_checkpoint,
207
+ )
208
+ self.ln_final = norm_layer(width) if use_ln_post else nn.Identity()
209
+ if no_causal_mask:
210
+ self.attn_mask = None
211
+ else:
212
+ self.register_buffer(
213
+ "attn_mask", self.build_causal_mask(), persistent=False
214
+ )
215
+ if proj_bias:
216
+ self.text_projection = nn.Linear(width, output_dim)
217
+ else:
218
+ self.text_projection = nn.Parameter(torch.empty(width, output_dim))
219
+
220
+ def build_causal_mask(self) -> torch.Tensor:
221
+ # lazily create causal attention mask, with full attention between the tokens
222
+ # pytorch uses additive attention mask; fill with -inf
223
+ mask = torch.empty(self.num_pos, self.num_pos)
224
+ mask.fill_(float("-inf"))
225
+ mask.triu_(1) # zero out the lower diagonal
226
+ return mask
227
+
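# Quick illustration of the additive causal mask built above, for num_pos = 4: zeros on and
# below the diagonal (attended), -inf strictly above it (blocked).
import torch

mask = torch.empty(4, 4).fill_(float("-inf")).triu_(1)
# tensor([[0., -inf, -inf, -inf],
#         [0.,   0., -inf, -inf],
#         [0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.]])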
228
+ def forward(
229
+ self, text: torch.Tensor
230
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
231
+ seq_len = text.shape[1]
232
+ x = self.token_embedding(text) # [batch_size, n_ctx, d_model]
233
+
234
+ attn_mask = self.attn_mask
235
+ if attn_mask is not None:
236
+ attn_mask = attn_mask[:seq_len, :seq_len]
237
+
238
+ x = x + self.positional_embedding[:seq_len]
239
+ x = self.transformer(x, attn_mask=attn_mask)
240
+
241
+ x = self.ln_final(x)
242
+ pooled, tokens = text_global_pool(x, text, pool_type=self.pool_type)
243
+ if self.text_projection is not None:
244
+ if isinstance(self.text_projection, nn.Linear):
245
+ pooled = self.text_projection(pooled)
246
+ else:
247
+ pooled = pooled @ self.text_projection
248
+ if self.output_tokens:
249
+ return pooled, tokens
250
+ return pooled
251
+
252
+
253
+ class VETextEncoder(nn.Module):
254
+ def __init__(
255
+ self,
256
+ d_model: int,
257
+ tokenizer: Callable,
258
+ width: int = 1024,
259
+ heads: int = 16,
260
+ layers: int = 24,
261
+ context_length: int = 32,
262
+ vocab_size: int = 49408,
263
+ use_ln_post: bool = True,
264
+ compile_mode: Optional[str] = None,
265
+ use_act_checkpoint: bool = True,
266
+ ):
267
+ super().__init__()
268
+ self.context_length = context_length
269
+ self.use_ln_post = use_ln_post
270
+ self.tokenizer = tokenizer
271
+
272
+ self.encoder = TextTransformer(
273
+ context_length=self.context_length,
274
+ vocab_size=vocab_size,
275
+ width=width,
276
+ heads=heads,
277
+ layers=layers,
278
+ # we want the tokens, not just the pooled output
279
+ output_tokens=True,
280
+ use_ln_post=use_ln_post,
281
+ compile_mode=compile_mode,
282
+ use_act_checkpoint=use_act_checkpoint,
283
+ )
284
+ self.resizer = nn.Linear(self.encoder.width, d_model)
285
+
286
+ def forward(
287
+ self,
288
+ text: Union[List[str], Tuple[torch.Tensor, torch.Tensor, dict]],
289
+ input_boxes: Optional[List] = None,
290
+ device: torch.device = None,
291
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
292
+ if isinstance(text[0], str):
293
+ # no use case for this
294
+ assert input_boxes is None or len(input_boxes) == 0, "not supported"
295
+
296
+ # Encode the text
297
+ tokenized = self.tokenizer(text, context_length=self.context_length).to(
298
+ device
299
+ ) # [b, seq_len]
300
+ text_attention_mask = (tokenized != 0).bool()
301
+
302
+ # manually embed the tokens
303
+ inputs_embeds = self.encoder.token_embedding(
304
+ tokenized
305
+ ) # [b, seq_len, d=1024]
306
+ _, text_memory = self.encoder(tokenized) # [b, seq_len, d=1024]
307
+
308
+ assert text_memory.shape[1] == inputs_embeds.shape[1]
309
+ # Invert the attention mask because PyTorch's transformer uses the opposite convention
310
+ text_attention_mask = text_attention_mask.ne(1)
311
+ # Transpose memory because pytorch's attention expects sequence first
312
+ text_memory = text_memory.transpose(0, 1)
313
+ # Resize the encoder hidden states to be of the same d_model as the decoder
314
+ text_memory_resized = self.resizer(text_memory)
315
+ else:
316
+ # The text is already encoded, use as is.
317
+ text_attention_mask, text_memory_resized, tokenized = text
318
+ inputs_embeds = tokenized["inputs_embeds"]
319
+ assert (
320
+ input_boxes is None or len(input_boxes) == 0
321
+ ), "Can't replace boxes in text if it's already encoded"
322
+
323
+ # Note that the input_embeds are returned in pytorch's convention (sequence first)
324
+ return (
325
+ text_attention_mask,
326
+ text_memory_resized,
327
+ inputs_embeds.transpose(0, 1),
328
+ )
detect_tools/sam3/sam3/model/tokenizer_ve.py ADDED
@@ -0,0 +1,253 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ """
4
+ Text Tokenizer.
5
+
6
+ Copied and lightly adapted from VE repo, which in turn copied
7
+ from open_clip and openAI CLIP.
8
+ """
9
+
10
+ import gzip
11
+ import html
12
+ import io
13
+ import os
14
+ import string
15
+ from functools import lru_cache
16
+ from typing import List, Optional, Union
17
+
18
+ import ftfy
19
+ import regex as re
20
+ import torch
21
+ from iopath.common.file_io import g_pathmgr
22
+
23
+
24
+ # https://stackoverflow.com/q/62691279
25
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
26
+ DEFAULT_CONTEXT_LENGTH = 77
27
+
28
+
29
+ @lru_cache()
30
+ def bytes_to_unicode():
31
+ """
32
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
33
+ The reversible bpe codes work on unicode strings.
34
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
35
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
36
+ This is a significant percentage of your normal, say, 32K bpe vocab.
37
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
38
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
39
+ """
40
+ bs = (
41
+ list(range(ord("!"), ord("~") + 1))
42
+ + list(range(ord("¡"), ord("¬") + 1))
43
+ + list(range(ord("®"), ord("ÿ") + 1))
44
+ )
45
+ cs = bs[:]
46
+ n = 0
47
+ for b in range(2**8):
48
+ if b not in bs:
49
+ bs.append(b)
50
+ cs.append(2**8 + n)
51
+ n += 1
52
+ cs = [chr(n) for n in cs]
53
+ return dict(zip(bs, cs))
54
+
55
+
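# Sanity-check sketch for bytes_to_unicode(): it is a bijection from all 256 byte values to
# printable unicode characters, which is what makes the BPE byte-level and fully reversible.
b2u = bytes_to_unicode()
assert len(b2u) == 256 and len(set(b2u.values())) == 256
u2b = {v: k for k, v in b2u.items()}
assert all(u2b[b2u[b]] == b for b in range(256))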
56
+ def get_pairs(word):
57
+ """Return set of symbol pairs in a word.
58
+ Word is represented as tuple of symbols (symbols being variable-length strings).
59
+ """
60
+ pairs = set()
61
+ prev_char = word[0]
62
+ for char in word[1:]:
63
+ pairs.add((prev_char, char))
64
+ prev_char = char
65
+ return pairs
66
+
67
+
68
+ def basic_clean(text):
69
+ text = ftfy.fix_text(text)
70
+ text = html.unescape(html.unescape(text))
71
+ return text.strip()
72
+
73
+
74
+ def whitespace_clean(text):
75
+ text = re.sub(r"\s+", " ", text)
76
+ text = text.strip()
77
+ return text
78
+
79
+
80
+ def _clean_canonicalize(x):
81
+ # basic, remove whitespace, remove punctuation, lower case
82
+ return canonicalize_text(basic_clean(x))
83
+
84
+
85
+ def _clean_lower(x):
86
+ # basic, remove whitespace, lower case
87
+ return whitespace_clean(basic_clean(x)).lower()
88
+
89
+
90
+ def _clean_whitespace(x):
91
+ # basic, remove whitespace
92
+ return whitespace_clean(basic_clean(x))
93
+
94
+
95
+ def get_clean_fn(type: str):
96
+ if type == "canonicalize":
97
+ return _clean_canonicalize
98
+ elif type == "lower":
99
+ return _clean_lower
100
+ elif type == "whitespace":
101
+ return _clean_whitespace
102
+ else:
103
+ assert False, f"Invalid clean function ({type})."
104
+
105
+
106
+ def canonicalize_text(text, *, keep_punctuation_exact_string=None):
107
+ """Returns canonicalized `text` (lowercase and punctuation removed).
108
+ From: https://github.com/google-research/big_vision/blob/53f18caf27a9419231bbf08d3388b07671616d3d/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94
109
+ Args:
110
+ text: string to be canonicalized.
111
+ keep_punctuation_exact_string: If provided, then this exact string is kept.
112
+ For example providing '{}' will keep any occurrences of '{}' (but will
113
+ still remove '{' and '}' that appear separately).
114
+ """
115
+ text = text.replace("_", " ")
116
+ if keep_punctuation_exact_string:
117
+ text = keep_punctuation_exact_string.join(
118
+ part.translate(str.maketrans("", "", string.punctuation))
119
+ for part in text.split(keep_punctuation_exact_string)
120
+ )
121
+ else:
122
+ text = text.translate(str.maketrans("", "", string.punctuation))
123
+ text = text.lower()
124
+ text = re.sub(r"\s+", " ", text)
125
+ return text.strip()
126
+
127
+
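# Example behavior of canonicalize_text (expected outputs shown as comments):
print(canonicalize_text("Hello, World_2!"))
# -> "hello world 2" (underscore becomes a space, punctuation removed, lowercased)
print(canonicalize_text("Fill: {} here!", keep_punctuation_exact_string="{}"))
# -> "fill {} here" (the exact "{}" substring survives, other punctuation is stripped)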
128
+ class SimpleTokenizer(object):
129
+ def __init__(
130
+ self,
131
+ bpe_path: Union[str, os.PathLike],
132
+ additional_special_tokens: Optional[List[str]] = None,
133
+ context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH,
134
+ clean: str = "lower",
135
+ ):
136
+ self.byte_encoder = bytes_to_unicode()
137
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
138
+ with g_pathmgr.open(bpe_path, "rb") as fh:
139
+ bpe_bytes = io.BytesIO(fh.read())
140
+ merges = gzip.open(bpe_bytes).read().decode("utf-8").split("\n")
141
+ # merges = gzip.open(bpe_path).read().decode("utf-8").split("\n")
142
+ merges = merges[1 : 49152 - 256 - 2 + 1]
143
+ merges = [tuple(merge.split()) for merge in merges]
144
+ vocab = list(bytes_to_unicode().values())
145
+ vocab = vocab + [v + "</w>" for v in vocab]
146
+ for merge in merges:
147
+ vocab.append("".join(merge))
148
+ special_tokens = ["<start_of_text>", "<end_of_text>"]
149
+ if additional_special_tokens:
150
+ special_tokens += additional_special_tokens
151
+ vocab.extend(special_tokens)
152
+ self.encoder = dict(zip(vocab, range(len(vocab))))
153
+ self.decoder = {v: k for k, v in self.encoder.items()}
154
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
155
+ self.cache = {t: t for t in special_tokens}
156
+ special = "|".join(special_tokens)
157
+ self.pat = re.compile(
158
+ special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
159
+ re.IGNORECASE,
160
+ )
161
+ self.vocab_size = len(self.encoder)
162
+ self.all_special_ids = [self.encoder[t] for t in special_tokens]
163
+ self.sot_token_id = self.all_special_ids[0]
164
+ self.eot_token_id = self.all_special_ids[1]
165
+ self.context_length = context_length
166
+ self.clean_fn = get_clean_fn(clean)
167
+
168
+ def bpe(self, token):
169
+ if token in self.cache:
170
+ return self.cache[token]
171
+ word = tuple(token[:-1]) + (token[-1] + "</w>",)
172
+ pairs = get_pairs(word)
173
+ if not pairs:
174
+ return token + "</w>"
175
+ while True:
176
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
177
+ if bigram not in self.bpe_ranks:
178
+ break
179
+ first, second = bigram
180
+ new_word = []
181
+ i = 0
182
+ while i < len(word):
183
+ try:
184
+ j = word.index(first, i)
185
+ new_word.extend(word[i:j])
186
+ i = j
187
+ except ValueError:  # `first` does not occur in the rest of the word
188
+ new_word.extend(word[i:])
189
+ break
190
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
191
+ new_word.append(first + second)
192
+ i += 2
193
+ else:
194
+ new_word.append(word[i])
195
+ i += 1
196
+ new_word = tuple(new_word)
197
+ word = new_word
198
+ if len(word) == 1:
199
+ break
200
+ else:
201
+ pairs = get_pairs(word)
202
+ word = " ".join(word)
203
+ self.cache[token] = word
204
+ return word
205
+
206
+ def encode(self, text):
207
+ bpe_tokens = []
208
+ text = self.clean_fn(text)
209
+ for token in re.findall(self.pat, text):
210
+ token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
211
+ bpe_tokens.extend(
212
+ self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
213
+ )
214
+ return bpe_tokens
215
+
216
+ def decode(self, tokens):
217
+ text = "".join([self.decoder[token] for token in tokens])
218
+ text = (
219
+ bytearray([self.byte_decoder[c] for c in text])
220
+ .decode("utf-8", errors="replace")
221
+ .replace("</w>", " ")
222
+ )
223
+ return text
224
+
225
+ def __call__(
226
+ self, texts: Union[str, List[str]], context_length: Optional[int] = None
227
+ ) -> torch.LongTensor:
228
+ """Returns the tokenized representation of given input string(s)
229
+ Parameters
230
+ ----------
231
+ texts : Union[str, List[str]]
232
+ An input string or a list of input strings to tokenize
233
+ context_length : int
234
+ The context length to use; all CLIP models use 77 as the context length
235
+ Returns
236
+ -------
237
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
238
+ """
239
+ if isinstance(texts, str):
240
+ texts = [texts]
241
+ context_length = context_length or self.context_length
242
+ assert context_length, "Please set a valid context length"
243
+ all_tokens = [
244
+ [self.sot_token_id] + self.encode(text) + [self.eot_token_id]
245
+ for text in texts
246
+ ]
247
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
248
+ for i, tokens in enumerate(all_tokens):
249
+ if len(tokens) > context_length:
250
+ tokens = tokens[:context_length] # Truncate
251
+ tokens[-1] = self.eot_token_id
252
+ result[i, : len(tokens)] = torch.tensor(tokens)
253
+ return result
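A hypothetical usage sketch of the tokenizer above (the bpe_vocab.gz path is a placeholder; SAM 3 supplies its own BPE file via `bpe_path` when the model is built):

tokenizer = SimpleTokenizer(bpe_path="/path/to/bpe_vocab.gz", context_length=32)  # placeholder path
tokens = tokenizer(["a red car", "two dogs"])           # LongTensor of shape (2, 32)
assert tokens[0, 0].item() == tokenizer.sot_token_id    # every sequence starts with <start_of_text>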
detect_tools/sam3/sam3/model/utils/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
detect_tools/sam3/sam3/model/utils/misc.py ADDED
@@ -0,0 +1,77 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ from collections import defaultdict
4
+ from dataclasses import fields, is_dataclass
5
+ from typing import Any, Mapping, Protocol, runtime_checkable
6
+
7
+ import torch
8
+
9
+
10
+ def _is_named_tuple(x) -> bool:
11
+ return isinstance(x, tuple) and hasattr(x, "_asdict") and hasattr(x, "_fields")
12
+
13
+
14
+ @runtime_checkable
15
+ class _CopyableData(Protocol):
16
+ def to(self, device: torch.device, *args: Any, **kwargs: Any):
17
+ """Copy data to the specified device"""
18
+ ...
19
+
20
+
21
+ def copy_data_to_device(data, device: torch.device, *args: Any, **kwargs: Any):
22
+ """Function that recursively copies data to a torch.device.
23
+
24
+ Args:
25
+ data: The data to copy to device
26
+ device: The device to which the data should be copied
27
+ args: positional arguments that will be passed to the `to` call
28
+ kwargs: keyword arguments that will be passed to the `to` call
29
+
30
+ Returns:
31
+ The data on the correct device
32
+ """
33
+
34
+ if _is_named_tuple(data):
35
+ return type(data)(
36
+ **copy_data_to_device(data._asdict(), device, *args, **kwargs)
37
+ )
38
+ elif isinstance(data, (list, tuple)):
39
+ return type(data)(copy_data_to_device(e, device, *args, **kwargs) for e in data)
40
+ elif isinstance(data, defaultdict):
41
+ return type(data)(
42
+ data.default_factory,
43
+ {
44
+ k: copy_data_to_device(v, device, *args, **kwargs)
45
+ for k, v in data.items()
46
+ },
47
+ )
48
+ elif isinstance(data, Mapping):
49
+ return type(data)(
50
+ {
51
+ k: copy_data_to_device(v, device, *args, **kwargs)
52
+ for k, v in data.items()
53
+ }
54
+ )
55
+ elif is_dataclass(data) and not isinstance(data, type):
56
+ new_data_class = type(data)(
57
+ **{
58
+ field.name: copy_data_to_device(
59
+ getattr(data, field.name), device, *args, **kwargs
60
+ )
61
+ for field in fields(data)
62
+ if field.init
63
+ }
64
+ )
65
+ for field in fields(data):
66
+ if not field.init:
67
+ setattr(
68
+ new_data_class,
69
+ field.name,
70
+ copy_data_to_device(
71
+ getattr(data, field.name), device, *args, **kwargs
72
+ ),
73
+ )
74
+ return new_data_class
75
+ elif isinstance(data, _CopyableData):
76
+ return data.to(device, *args, **kwargs)
77
+ return data
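A usage sketch for `copy_data_to_device`: nested containers are traversed recursively and anything exposing a `.to` method (tensors in particular) is moved to the target device, while other values pass through unchanged.

import torch

batch = {
    "image": torch.zeros(3, 8, 8),
    "meta": {"ids": [torch.tensor(1), torch.tensor(2)]},
    "name": "frame0",                    # non-tensor values are returned as-is
}
batch_on_device = copy_data_to_device(batch, torch.device("cpu"))  # use torch.device("cuda") on GPU
assert batch_on_device["meta"]["ids"][0].device.type == "cpu"
assert batch_on_device["name"] == "frame0"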
detect_tools/sam3/sam3/model/utils/sam1_utils.py ADDED
@@ -0,0 +1,119 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import warnings
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from torchvision.transforms import Normalize, Resize, ToTensor
13
+
14
+
15
+ # Adapted from https://github.com/facebookresearch/sam2/blob/main/sam2/utils/transforms.py
16
+ class SAM2Transforms(nn.Module):
17
+ def __init__(
18
+ self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0
19
+ ):
20
+ """
21
+ Transforms for SAM2.
22
+ """
23
+ super().__init__()
24
+ self.resolution = resolution
25
+ self.mask_threshold = mask_threshold
26
+ self.max_hole_area = max_hole_area
27
+ self.max_sprinkle_area = max_sprinkle_area
28
+ self.mean = [0.5, 0.5, 0.5]
29
+ self.std = [0.5, 0.5, 0.5]
30
+ self.to_tensor = ToTensor()
31
+ self.transforms = torch.jit.script(
32
+ nn.Sequential(
33
+ Resize((self.resolution, self.resolution)),
34
+ Normalize(self.mean, self.std),
35
+ )
36
+ )
37
+
38
+ def __call__(self, x):
39
+ x = self.to_tensor(x)
40
+ return self.transforms(x)
41
+
42
+ def forward_batch(self, img_list):
43
+ img_batch = [self.transforms(self.to_tensor(img)) for img in img_list]
44
+ img_batch = torch.stack(img_batch, dim=0)
45
+ return img_batch
46
+
47
+ def transform_coords(
48
+ self, coords: torch.Tensor, normalize=False, orig_hw=None
49
+ ) -> torch.Tensor:
50
+ """
51
+ Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates.
52
+ If the coords are in absolute image coordinates, set normalize=True and provide the original image size.
53
+
54
+ Returns
55
+ Coordinates scaled to the model input resolution (i.e., in the range [0, self.resolution]), as expected by the SAM2 model.
56
+ """
57
+ if normalize:
58
+ assert orig_hw is not None
59
+ h, w = orig_hw
60
+ coords = coords.clone()
61
+ coords[..., 0] = coords[..., 0] / w
62
+ coords[..., 1] = coords[..., 1] / h
63
+
64
+ coords = coords * self.resolution # unnormalize coords
65
+ return coords
66
+
67
+ def transform_boxes(
68
+ self, boxes: torch.Tensor, normalize=False, orig_hw=None
69
+ ) -> torch.Tensor:
70
+ """
71
+ Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized coordinates.
72
+ If the coords are in absolute image coordinates, set normalize=True and provide the original image size.
73
+ """
74
+ boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw)
75
+ return boxes
76
+
77
+ def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor:
78
+ """
79
+ Perform post-processing (hole filling and small-sprinkle removal) on the output masks.
80
+ """
81
+ masks = masks.float()
82
+ input_masks = masks
83
+ mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image
84
+ try:
85
+ from sam3.perflib.connected_components import connected_components
86
+
87
+ if self.max_hole_area > 0:
88
+ # Holes are those connected components in background with area <= self.fill_hole_area
89
+ # (background regions are those with mask scores <= self.mask_threshold)
90
+ labels, areas = connected_components(
91
+ (mask_flat <= self.mask_threshold).to(torch.uint8)
92
+ )
93
+ is_hole = (labels > 0) & (areas <= self.max_hole_area)
94
+ is_hole = is_hole.reshape_as(masks)
95
+ # We fill holes with a small positive mask score (10.0) to change them to foreground.
96
+ masks = torch.where(is_hole, self.mask_threshold + 10.0, masks)
97
+
98
+ if self.max_sprinkle_area > 0:
99
+ labels, areas = connected_components(
100
+ (mask_flat > self.mask_threshold).to(torch.uint8)
101
+ )
102
+ is_hole = (labels > 0) & (areas <= self.max_sprinkle_area)
103
+ is_hole = is_hole.reshape_as(masks)
104
+ # We fill holes with negative mask score (-10.0) to change them to background.
105
+ masks = torch.where(is_hole, self.mask_threshold - 10.0, masks)
106
+ except Exception as e:
107
+ # Skip the post-processing step if the CUDA kernel fails
108
+ warnings.warn(
109
+ f"{e}\n\nSkipping the post-processing step due to the error above. You can "
110
+ "still use SAM 3 and it's OK to ignore the error above, although some post-processing "
111
+ "functionality may be limited (which doesn't affect the results in most cases; see "
112
+ "https://github.com/facebookresearch/sam3/blob/main/INSTALL.md).",
113
+ category=UserWarning,
114
+ stacklevel=2,
115
+ )
116
+ masks = input_masks
117
+
118
+ masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False)
119
+ return masks
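A small usage sketch for the box transform above: pixel-space boxes from a hypothetical 720x1280 image are first normalized and then scaled to the model input resolution (1024 here is just an example value).

import torch

transforms = SAM2Transforms(resolution=1024, mask_threshold=0.0)
boxes_xyxy = torch.tensor([[100.0, 50.0, 400.0, 300.0]])    # (B, 4) in absolute pixel coordinates
boxes_model = transforms.transform_boxes(boxes_xyxy, normalize=True, orig_hw=(720, 1280))
print(boxes_model.shape)   # torch.Size([1, 2, 2]), box corners in model input coordinates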
detect_tools/sam3/sam3/model/utils/sam2_utils.py ADDED
@@ -0,0 +1,233 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ from threading import Thread
9
+
10
+ import numpy as np
11
+ import torch
12
+ from PIL import Image
13
+ from tqdm import tqdm
14
+
15
+
16
+ def _load_img_as_tensor(img_path, image_size):
17
+ img_pil = Image.open(img_path)
18
+ img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size)))
19
+ if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images
20
+ img_np = img_np / 255.0
21
+ else:
22
+ raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}")
23
+ img = torch.from_numpy(img_np).permute(2, 0, 1)
24
+ video_width, video_height = img_pil.size # the original video size
25
+ return img, video_height, video_width
26
+
27
+
28
+ class AsyncVideoFrameLoader:
29
+ """
30
+ A list of video frames to be loaded asynchronously without blocking the session start.
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ img_paths,
36
+ image_size,
37
+ offload_video_to_cpu,
38
+ img_mean,
39
+ img_std,
40
+ compute_device,
41
+ ):
42
+ self.img_paths = img_paths
43
+ self.image_size = image_size
44
+ self.offload_video_to_cpu = offload_video_to_cpu
45
+ self.img_mean = img_mean
46
+ self.img_std = img_std
47
+ # items in `self.images` will be loaded asynchronously
48
+ self.images = [None] * len(img_paths)
49
+ # catch and raise any exceptions in the async loading thread
50
+ self.exception = None
51
+ # video_height and video_width will be filled when loading the first image
52
+ self.video_height = None
53
+ self.video_width = None
54
+ self.compute_device = compute_device
55
+
56
+ # load the first frame to fill video_height and video_width and also
57
+ # to cache it (since it's most likely where the user will click)
58
+ self.__getitem__(0)
59
+
60
+ # load the rest of frames asynchronously without blocking the session start
61
+ def _load_frames():
62
+ try:
63
+ for n in tqdm(range(len(self.images)), desc="frame loading (JPEG)"):
64
+ self.__getitem__(n)
65
+ except Exception as e:
66
+ self.exception = e
67
+
68
+ self.thread = Thread(target=_load_frames, daemon=True)
69
+ self.thread.start()
70
+
71
+ def __getitem__(self, index):
72
+ if self.exception is not None:
73
+ raise RuntimeError("Failure in frame loading thread") from self.exception
74
+
75
+ img = self.images[index]
76
+ if img is not None:
77
+ return img
78
+
79
+ img, video_height, video_width = _load_img_as_tensor(
80
+ self.img_paths[index], self.image_size
81
+ )
82
+ self.video_height = video_height
83
+ self.video_width = video_width
84
+ # normalize by mean and std
85
+ img -= self.img_mean
86
+ img /= self.img_std
87
+ if not self.offload_video_to_cpu:
88
+ img = img.to(self.compute_device, non_blocking=True)
89
+ self.images[index] = img
90
+ return img
91
+
92
+ def __len__(self):
93
+ return len(self.images)
94
+
95
+
96
+ def load_video_frames(
97
+ video_path,
98
+ image_size,
99
+ offload_video_to_cpu,
100
+ img_mean=(0.485, 0.456, 0.406),
101
+ img_std=(0.229, 0.224, 0.225),
102
+ async_loading_frames=False,
103
+ compute_device=torch.device("cuda"),
104
+ ):
105
+ """
106
+ Load the video frames from video_path. The frames are resized to image_size as in
107
+ the model and are loaded to GPU if offload_video_to_cpu=False. This is used by the demo.
108
+ """
109
+ is_bytes = isinstance(video_path, bytes)
110
+ is_str = isinstance(video_path, str)
111
+ is_mp4_path = is_str and os.path.splitext(video_path)[-1] in [".mp4", ".MP4"]
112
+ if is_bytes or is_mp4_path:
113
+ return load_video_frames_from_video_file(
114
+ video_path=video_path,
115
+ image_size=image_size,
116
+ offload_video_to_cpu=offload_video_to_cpu,
117
+ img_mean=img_mean,
118
+ img_std=img_std,
119
+ compute_device=compute_device,
120
+ )
121
+ elif is_str and os.path.isdir(video_path):
122
+ return load_video_frames_from_jpg_images(
123
+ video_path=video_path,
124
+ image_size=image_size,
125
+ offload_video_to_cpu=offload_video_to_cpu,
126
+ img_mean=img_mean,
127
+ img_std=img_std,
128
+ async_loading_frames=async_loading_frames,
129
+ compute_device=compute_device,
130
+ )
131
+ else:
132
+ raise NotImplementedError(
133
+ "Only MP4 video and JPEG folder are supported at this moment"
134
+ )
135
+
136
+
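# A hypothetical call to the dispatcher above: an MP4 path (or a folder of "<frame_index>.jpg"
# files) is decoded, resized to image_size x image_size, normalized, and optionally kept on CPU.
# Decoding MP4 files requires the optional decord dependency; the path below is a placeholder.
images, video_h, video_w = load_video_frames(
    video_path="/path/to/video.mp4",
    image_size=1024,
    offload_video_to_cpu=True,     # keep frames in CPU RAM instead of GPU memory
)
print(images.shape, (video_h, video_w))   # (num_frames, 3, 1024, 1024) and the original frame size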
137
+ def load_video_frames_from_jpg_images(
138
+ video_path,
139
+ image_size,
140
+ offload_video_to_cpu,
141
+ img_mean=(0.485, 0.456, 0.406),
142
+ img_std=(0.229, 0.224, 0.225),
143
+ async_loading_frames=False,
144
+ compute_device=torch.device("cuda"),
145
+ ):
146
+ """
147
+ Load the video frames from a directory of JPEG files ("<frame_index>.jpg" format).
148
+
149
+ The frames are resized to image_size x image_size and are loaded to GPU if
150
+ `offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`.
151
+
152
+ You can load frames asynchronously by setting `async_loading_frames` to `True`.
153
+ """
154
+ if isinstance(video_path, str) and os.path.isdir(video_path):
155
+ jpg_folder = video_path
156
+ else:
157
+ raise NotImplementedError(
158
+ "Only JPEG frames are supported at this moment. For video files, you may use "
159
+ "ffmpeg (https://ffmpeg.org/) to extract frames into a folder of JPEG files, such as \n"
160
+ "```\n"
161
+ "ffmpeg -i <your_video>.mp4 -q:v 2 -start_number 0 <output_dir>/'%05d.jpg'\n"
162
+ "```\n"
163
+ "where `-q:v` generates high-quality JPEG frames and `-start_number 0` asks "
164
+ "ffmpeg to start the JPEG file from 00000.jpg."
165
+ )
166
+
167
+ frame_names = [
168
+ p
169
+ for p in os.listdir(jpg_folder)
170
+ if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
171
+ ]
172
+ frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
173
+ num_frames = len(frame_names)
174
+ if num_frames == 0:
175
+ raise RuntimeError(f"no images found in {jpg_folder}")
176
+ img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names]
177
+ img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
178
+ img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
179
+
180
+ if async_loading_frames:
181
+ lazy_images = AsyncVideoFrameLoader(
182
+ img_paths,
183
+ image_size,
184
+ offload_video_to_cpu,
185
+ img_mean,
186
+ img_std,
187
+ compute_device,
188
+ )
189
+ return lazy_images, lazy_images.video_height, lazy_images.video_width
190
+
191
+ images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32)
192
+ for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
193
+ images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
194
+ if not offload_video_to_cpu:
195
+ images = images.to(compute_device)
196
+ img_mean = img_mean.to(compute_device)
197
+ img_std = img_std.to(compute_device)
198
+ # normalize by mean and std
199
+ images -= img_mean
200
+ images /= img_std
201
+ return images, video_height, video_width
202
+
203
+
204
+ def load_video_frames_from_video_file(
205
+ video_path,
206
+ image_size,
207
+ offload_video_to_cpu,
208
+ img_mean=(0.485, 0.456, 0.406),
209
+ img_std=(0.229, 0.224, 0.225),
210
+ compute_device=torch.device("cuda"),
211
+ ):
212
+ """Load the video frames from a video file."""
213
+ import decord
214
+
215
+ img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
216
+ img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
217
+ # Get the original video height and width
218
+ decord.bridge.set_bridge("torch")
219
+ video_height, video_width, _ = decord.VideoReader(video_path).next().shape
220
+ # Iterate over all frames in the video
221
+ images = []
222
+ for frame in decord.VideoReader(video_path, width=image_size, height=image_size):
223
+ images.append(frame.permute(2, 0, 1))
224
+
225
+ images = torch.stack(images, dim=0).float() / 255.0
226
+ if not offload_video_to_cpu:
227
+ images = images.to(compute_device)
228
+ img_mean = img_mean.to(compute_device)
229
+ img_std = img_std.to(compute_device)
230
+ # normalize by mean and std
231
+ images -= img_mean
232
+ images /= img_std
233
+ return images, video_height, video_width
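A minimal usage sketch for the loaders added above; the `sam3.model.io_utils` import path is an assumption based on this file's location in the diff, and the frame directory path is a placeholder:

```python
# Assumed import path based on where this file lives in the diff.
from sam3.model.io_utils import load_video_frames

# Accepts either an MP4 path/bytes or a directory of "<frame_index>.jpg" frames.
images, video_height, video_width = load_video_frames(
    video_path="path/to/frames_dir",  # placeholder; or "path/to/video.mp4"
    image_size=1008,                  # frames are resized to image_size x image_size
    offload_video_to_cpu=True,        # keep frames on CPU instead of the compute device
    async_loading_frames=False,       # True loads JPEG frames in a background thread
)
print(images.shape)  # (num_frames, 3, image_size, image_size), ImageNet-normalized
```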
detect_tools/sam3/sam3/model/vitdet.py ADDED
@@ -0,0 +1,879 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ """
4
+ ViTDet backbone adapted from Detectron2.
5
+ This module implements a Vision Transformer (ViT) backbone for object detection.
6
+
7
+ RoPE embedding code adapted from:
8
+ 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py
9
+ 2. https://github.com/naver-ai/rope-vit
10
+ 3. https://github.com/lucidrains/rotary-embedding-torch
11
+ """
12
+
13
+ import math
14
+ from functools import partial
15
+ from typing import Callable, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ import torch.utils.checkpoint as checkpoint
21
+
22
+ try:
23
+ from timm.layers import DropPath, Mlp, trunc_normal_
24
+ except ModuleNotFoundError:
25
+ # compatibility for older timm versions
26
+ from timm.models.layers import DropPath, Mlp, trunc_normal_
27
+ from torch import Tensor
28
+
29
+ from .model_misc import LayerScale
30
+
31
+
32
+ def init_t_xy(
33
+ end_x: int, end_y: int, scale: float = 1.0, offset: int = 0
34
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
35
+ t = torch.arange(end_x * end_y, dtype=torch.float32)
36
+ t_x = (t % end_x).float()
37
+ t_y = torch.div(t, end_x, rounding_mode="floor").float()
38
+ return t_x * scale + offset, t_y * scale + offset
39
+
40
+
41
+ def compute_axial_cis(
42
+ dim: int,
43
+ end_x: int,
44
+ end_y: int,
45
+ theta: float = 10000.0,
46
+ scale_pos: float = 1.0,
47
+ offset: int = 0,
48
+ ) -> torch.Tensor:
49
+ freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
50
+ freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
51
+
52
+ t_x, t_y = init_t_xy(end_x, end_y, scale_pos, offset)
53
+ freqs_x = torch.outer(t_x, freqs_x)
54
+ freqs_y = torch.outer(t_y, freqs_y)
55
+ freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
56
+ freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
57
+ return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
58
+
59
+
60
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
61
+ ndim = x.ndim
62
+ assert 0 <= 1 < ndim
63
+ assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
64
+ shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
65
+ return freqs_cis.view(*shape)
66
+
67
+
68
+ def apply_rotary_enc(
69
+ xq: torch.Tensor,
70
+ xk: torch.Tensor,
71
+ freqs_cis: torch.Tensor,
72
+ repeat_freqs_k: bool = False,
73
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
74
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
75
+ xk_ = (
76
+ torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
77
+ if xk.shape[-2] != 0
78
+ else None
79
+ )
80
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
81
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
82
+ if xk_ is None:
83
+ # no keys to rotate, due to dropout
84
+ return xq_out.type_as(xq).to(xq.device), xk
85
+ # repeat freqs along seq_len dim to match k seq_len
86
+ if repeat_freqs_k:
87
+ r = xk_.shape[-2] // xq_.shape[-2]
88
+ freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
89
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
90
+ return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
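A quick shape sketch for the axial RoPE helpers above; the import path is assumed from this diff, and head_dim=64 with a 4x4 token grid are arbitrary choices for illustration:

```python
import torch
from sam3.model.vitdet import apply_rotary_enc, compute_axial_cis  # path assumed from this diff

freqs_cis = compute_axial_cis(dim=64, end_x=4, end_y=4)     # (16, 32) complex frequencies
q = torch.randn(1, 8, 16, 64)                               # (B, n_heads, H*W, head_dim)
k = torch.randn(1, 8, 16, 64)
q_rot, k_rot = apply_rotary_enc(q, k, freqs_cis=freqs_cis)  # shapes preserved
print(q_rot.shape, k_rot.shape)                             # torch.Size([1, 8, 16, 64]) twice
```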
91
+
92
+
93
+ def window_partition(x: Tensor, window_size: int) -> Tuple[Tensor, Tuple[int, int]]:
94
+ """
95
+ Partition into non-overlapping windows with padding if needed.
96
+ Args:
97
+ x (tensor): input tokens with [B, H, W, C].
98
+ window_size (int): window size.
99
+ Returns:
100
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
101
+ (Hp, Wp): padded height and width before partition
102
+ """
103
+ B, H, W, C = x.shape
104
+
105
+ pad_h = (window_size - H % window_size) % window_size
106
+ pad_w = (window_size - W % window_size) % window_size
107
+ if pad_h > 0 or pad_w > 0:
108
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
109
+ Hp, Wp = H + pad_h, W + pad_w
110
+
111
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
112
+ windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)
113
+ return windows, (Hp, Wp)
114
+
115
+
116
+ def window_unpartition(
117
+ windows: Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
118
+ ) -> Tensor:
119
+ """
120
+ Reverse window partition back to the original sequence and remove padding.
121
+ Args:
122
+ windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
123
+ window_size (int): window size.
124
+ pad_hw (Tuple): padded height and width (Hp, Wp).
125
+ hw (Tuple): original height and width (H, W) before padding.
126
+ Returns:
127
+ x: unpartitioned sequences with [B, H, W, C].
128
+ """
129
+ Hp, Wp = pad_hw
130
+ H, W = hw
131
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
132
+ x = windows.reshape(
133
+ B, Hp // window_size, Wp // window_size, window_size, window_size, -1
134
+ )
135
+ x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1)
136
+
137
+ if Hp > H or Wp > W:
138
+ x = x[:, :H, :W, :]
139
+ return x
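A small round-trip check for the two window helpers above (import path assumed from this diff); partition pads H/W up to a multiple of the window size, and unpartition crops the padding back off:

```python
import torch
from sam3.model.vitdet import window_partition, window_unpartition  # path assumed

x = torch.randn(2, 30, 30, 256)                     # (B, H, W, C), H/W not multiples of 14
windows, (Hp, Wp) = window_partition(x, 14)          # (2 * 3 * 3, 14, 14, 256), Hp = Wp = 42
x_back = window_unpartition(windows, 14, (Hp, Wp), (30, 30))
assert x_back.shape == x.shape and torch.equal(x_back, x)  # padding added, then cropped away
```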
140
+
141
+
142
+ def get_rel_pos(q_size: int, k_size: int, rel_pos: Tensor) -> Tensor:
143
+ """
144
+ Get relative positional embeddings according to the relative positions of
145
+ query and key sizes.
146
+ Args:
147
+ q_size (int): size of query q.
148
+ k_size (int): size of key k.
149
+ rel_pos (Tensor): relative position embeddings (L, C).
150
+ Returns:
151
+ Extracted positional embeddings according to relative positions.
152
+ """
153
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
154
+ # Interpolate rel pos if needed.
155
+ if rel_pos.shape[0] != max_rel_dist:
156
+ # Interpolate rel pos.
157
+ rel_pos_resized = F.interpolate(
158
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
159
+ size=max_rel_dist,
160
+ mode="linear",
161
+ align_corners=False,
162
+ )
163
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
164
+ else:
165
+ rel_pos_resized = rel_pos
166
+
167
+ # Scale the coords with short length if shapes for q and k are different.
168
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
169
+ k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
170
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
171
+
172
+ return rel_pos_resized[relative_coords.long()]
173
+
174
+
175
+ def get_abs_pos(
176
+ abs_pos: Tensor,
177
+ has_cls_token: bool,
178
+ hw: Tuple[int, int],
179
+ retain_cls_token: bool = False,
180
+ tiling: bool = False,
181
+ ) -> Tensor:
182
+ """
183
+ Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
184
+ dimension for the original embeddings.
185
+ Args:
186
+ abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
187
+ has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
188
+ hw (Tuple): size of input image tokens.
189
+ retain_cls_token: whether to retain the cls_token
190
+ tiling: whether to tile the embeddings, *instead* of interpolation (a la abs_win)
191
+ Returns:
192
+ Absolute positional embeddings after processing with shape (1, H, W, C),
193
+ if retain_cls_token is False, otherwise (1, 1+H*W, C)
194
+ """
195
+ if retain_cls_token:
196
+ assert has_cls_token
197
+
198
+ h, w = hw
199
+ if has_cls_token:
200
+ cls_pos = abs_pos[:, :1]
201
+ abs_pos = abs_pos[:, 1:]
202
+
203
+ xy_num = abs_pos.shape[1]
204
+ size = int(math.sqrt(xy_num))
205
+ assert size * size == xy_num
206
+
207
+ if size != h or size != w:
208
+ new_abs_pos = abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2)
209
+ if tiling:
210
+ new_abs_pos = new_abs_pos.tile(
211
+ [1, 1] + [x // y + 1 for x, y in zip((h, w), new_abs_pos.shape[2:])]
212
+ )[:, :, :h, :w]
213
+ else:
214
+ new_abs_pos = F.interpolate(
215
+ new_abs_pos,
216
+ size=(h, w),
217
+ mode="bicubic",
218
+ align_corners=False,
219
+ )
220
+
221
+ if not retain_cls_token:
222
+ return new_abs_pos.permute(0, 2, 3, 1)
223
+ else:
224
+ # add cls_token back, flatten spatial dims
225
+ assert has_cls_token
226
+ return torch.cat(
227
+ [cls_pos, new_abs_pos.permute(0, 2, 3, 1).reshape(1, h * w, -1)],
228
+ dim=1,
229
+ )
230
+
231
+ else:
232
+ if not retain_cls_token:
233
+ return abs_pos.reshape(1, h, w, -1)
234
+ else:
235
+ assert has_cls_token
236
+ return torch.cat([cls_pos, abs_pos], dim=1)
237
+
238
+
239
+ def concat_rel_pos(
240
+ q: Tensor,
241
+ k: Tensor,
242
+ q_hw: Tuple[int, int],
243
+ k_hw: Tuple[int, int],
244
+ rel_pos_h: Tensor,
245
+ rel_pos_w: Tensor,
246
+ rescale: bool = False,
247
+ relative_coords: Optional[Tensor] = None,
248
+ ) -> Tuple[Tensor, Tensor]:
249
+ """
250
+ Concatenate rel pos coeffs to the q & k tensors, so that qk^T
251
+ effectively includes the rel pos biases.
252
+ Args:
253
+ q (Tensor): q tensor with shape (B, L_q, C).
254
+ k (Tensor): k tensor with shape (B, L_k, C).
255
+ q_hw, k_hw: These are spatial size of q & k tensors.
256
+ rel_pos_h, rel_pos_w: These are relative pos embeddings/params of height, width.
257
+ rescale (bool): whether to rescale, e.g. when using sdpa, since PyTorch would
258
+ otherwise scale by the wrong factor due to the concat.
259
+ Returns:
260
+ q, k: padded so that qk^T accounts for the rel pos biases
261
+ """
262
+ q_h, q_w = q_hw
263
+ k_h, k_w = k_hw
264
+
265
+ assert (q_h == q_w) and (k_h == k_w), "only square inputs supported"
266
+
267
+ if relative_coords is not None:
268
+ Rh = rel_pos_h[relative_coords]
269
+ Rw = rel_pos_w[relative_coords]
270
+ else:
271
+ Rh = get_rel_pos(q_h, k_h, rel_pos_h)
272
+ Rw = get_rel_pos(q_w, k_w, rel_pos_w)
273
+
274
+ B, _, dim = q.shape
275
+ r_q = q.reshape(B, q_h, q_w, dim)
276
+
277
+ old_scale = dim**0.5
278
+ new_scale = (dim + k_h + k_w) ** 0.5 if rescale else old_scale # for sdpa
279
+ # attn will be divided by new_scale, but we want to divide q by old_scale
280
+ scale_ratio = new_scale / old_scale
281
+
282
+ rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) * new_scale # (B, q_h, q_w, k_h)
283
+ rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) * new_scale # (B, q_h, q_w, k_w)
284
+
285
+ eye_h = torch.eye(k_h, dtype=q.dtype, device=q.device)
286
+ eye_w = torch.eye(k_w, dtype=q.dtype, device=q.device)
287
+
288
+ eye_h = eye_h.view(1, k_h, 1, k_h).expand([B, k_h, k_w, k_h])
289
+ eye_w = eye_w.view(1, 1, k_w, k_w).expand([B, k_h, k_w, k_w])
290
+
291
+ q = torch.cat([r_q * scale_ratio, rel_h, rel_w], dim=-1).view(B, q_h * q_w, -1)
292
+ k = torch.cat([k.view(B, k_h, k_w, -1), eye_h, eye_w], dim=-1).view(
293
+ B, k_h * k_w, -1
294
+ )
295
+
296
+ return q, k
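The concat trick above folds the decomposed relative-position biases into padded q/k so that a plain scaled-dot-product attention reproduces `q @ k^T / sqrt(C) + q.Rh + q.Rw`. A small numerical check of that equivalence (imports assumed from this diff):

```python
import torch
from sam3.model.vitdet import concat_rel_pos, get_rel_pos  # path assumed from this diff

B, s, C = 2, 7, 32                                    # batch*heads, spatial size, head dim
q, k = torch.randn(B, s * s, C), torch.randn(B, s * s, C)
rel_pos_h, rel_pos_w = torch.randn(2 * s - 1, C), torch.randn(2 * s - 1, C)

q_pad, k_pad = concat_rel_pos(q, k, (s, s), (s, s), rel_pos_h, rel_pos_w, rescale=True)
logits = (q_pad @ k_pad.transpose(-2, -1)) / (q_pad.shape[-1] ** 0.5)  # what sdpa would compute

r_q = q.reshape(B, s, s, C)
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, get_rel_pos(s, s, rel_pos_h))
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, get_rel_pos(s, s, rel_pos_w))
ref = (q @ k.transpose(-2, -1)).view(B, s, s, s, s) / (C ** 0.5)
ref = ref + rel_h[..., :, None] + rel_w[..., None, :]
assert torch.allclose(logits.view(B, s, s, s, s), ref, atol=1e-4)
```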
297
+
298
+
299
+ class PatchEmbed(nn.Module):
300
+ """
301
+ Image to Patch Embedding.
302
+ """
303
+
304
+ def __init__(
305
+ self,
306
+ kernel_size: Tuple[int, int] = (16, 16),
307
+ stride: Tuple[int, int] = (16, 16),
308
+ padding: Tuple[int, int] = (0, 0),
309
+ in_chans: int = 3,
310
+ embed_dim: int = 768,
311
+ bias: bool = True,
312
+ ):
313
+ """
314
+ Args:
315
+ kernel_size (Tuple): kernel size of the projection layer.
316
+ stride (Tuple): stride of the projection layer.
317
+ padding (Tuple): padding size of the projection layer.
318
+ in_chans (int): Number of input image channels.
319
+ embed_dim (int): Patch embedding dimension.
320
+ """
321
+ super().__init__()
322
+
323
+ self.proj = nn.Conv2d(
324
+ in_chans,
325
+ embed_dim,
326
+ kernel_size=kernel_size,
327
+ stride=stride,
328
+ padding=padding,
329
+ bias=bias,
330
+ )
331
+
332
+ def forward(self, x: Tensor) -> Tensor:
333
+ x = self.proj(x)
334
+ # B C H W -> B H W C
335
+ x = x.permute(0, 2, 3, 1)
336
+ return x
337
+
338
+
339
+ class Attention(nn.Module):
340
+ """Multi-head Attention block with relative position embeddings and 2d-rope."""
341
+
342
+ def __init__(
343
+ self,
344
+ dim: int,
345
+ num_heads: int = 8,
346
+ qkv_bias: bool = True,
347
+ use_rel_pos: bool = False,
348
+ rel_pos_zero_init: bool = True,
349
+ input_size: Optional[Tuple[int, int]] = None,
350
+ cls_token: bool = False,
351
+ use_rope: bool = False,
352
+ rope_theta: float = 10000.0,
353
+ rope_pt_size: Optional[Tuple[int, int]] = None,
354
+ rope_interp: bool = False,
355
+ ):
356
+ """
357
+ Args:
358
+ dim (int): Number of input channels.
359
+ num_heads (int): Number of attention heads.
360
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
361
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
362
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
363
+ input_size (int or None): Input resolution for calculating the relative positional
364
+ parameter size or rope size.
365
+ attn_type: Type of attention operation, e.g. "vanilla", "vanilla-xformer".
366
+ cls_token: whether a cls_token is present.
367
+ use_rope: whether to use rope 2d (indep of use_rel_pos, as it can be used together)
368
+ rope_theta: control frequencies of rope
369
+ rope_pt_size: size of rope in previous stage of training, needed for interpolation or tiling
370
+ rope_interp: whether to interpolate (or extrapolate) rope to match input size
371
+ """
372
+ super().__init__()
373
+ self.num_heads = num_heads
374
+ self.head_dim = dim // num_heads
375
+ self.scale = self.head_dim**-0.5
376
+ self.cls_token = cls_token
377
+
378
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
379
+ self.proj = nn.Linear(dim, dim)
380
+
381
+ # rel_pos embeddings and rope
382
+ self.use_rel_pos = use_rel_pos
383
+ self.input_size = input_size
384
+
385
+ self.use_rope = use_rope
386
+ self.rope_theta = rope_theta
387
+ self.rope_pt_size = rope_pt_size
388
+ self.rope_interp = rope_interp
389
+
390
+ # init rel_pos embeddings and rope
391
+ self._setup_rel_pos(rel_pos_zero_init)
392
+ self._setup_rope_freqs()
393
+
394
+ def _setup_rel_pos(self, rel_pos_zero_init: bool = True) -> None:
395
+ if not self.use_rel_pos:
396
+ self.rel_pos_h = None
397
+ self.rel_pos_w = None
398
+ return
399
+
400
+ assert self.input_size is not None
401
+ assert self.cls_token is False, "not supported"
402
+ # initialize relative positional embeddings
403
+ self.rel_pos_h = nn.Parameter(
404
+ torch.zeros(2 * self.input_size[0] - 1, self.head_dim)
405
+ )
406
+ self.rel_pos_w = nn.Parameter(
407
+ torch.zeros(2 * self.input_size[1] - 1, self.head_dim)
408
+ )
409
+
410
+ if not rel_pos_zero_init:
411
+ trunc_normal_(self.rel_pos_h, std=0.02)
412
+ trunc_normal_(self.rel_pos_w, std=0.02)
413
+
414
+ # Precompute the relative coords
415
+ H, W = self.input_size
416
+ q_coords = torch.arange(H)[:, None]
417
+ k_coords = torch.arange(W)[None, :]
418
+ relative_coords = (q_coords - k_coords) + (H - 1)
419
+ self.register_buffer("relative_coords", relative_coords.long())
420
+
421
+ def _setup_rope_freqs(self) -> None:
422
+ if not self.use_rope:
423
+ self.freqs_cis = None
424
+ return
425
+
426
+ assert self.input_size is not None
427
+ # determine rope input size
428
+ if self.rope_pt_size is None:
429
+ self.rope_pt_size = self.input_size
430
+
431
+ # initialize 2d rope freqs
432
+ self.compute_cis = partial(
433
+ compute_axial_cis,
434
+ dim=self.head_dim,
435
+ theta=self.rope_theta,
436
+ )
437
+
438
+ # interpolate rope
439
+ scale_pos = 1.0
440
+ if self.rope_interp:
441
+ scale_pos = self.rope_pt_size[0] / self.input_size[0]
442
+ # get scaled freqs_cis
443
+ freqs_cis = self.compute_cis(
444
+ end_x=self.input_size[0],
445
+ end_y=self.input_size[1],
446
+ scale_pos=scale_pos,
447
+ )
448
+ if self.cls_token:
449
+ t = torch.zeros(
450
+ self.head_dim // 2,
451
+ dtype=torch.float32,
452
+ device=freqs_cis.device,
453
+ )
454
+ cls_freqs_cis = torch.polar(torch.ones_like(t), t)[None, :]
455
+ freqs_cis = torch.cat([cls_freqs_cis, freqs_cis], dim=0)
456
+
457
+ self.register_buffer("freqs_cis", freqs_cis)
458
+
459
+ def _apply_rope(self, q, k) -> Tuple[Tensor, Tensor]:
460
+ if not self.use_rope:
461
+ return q, k
462
+
463
+ assert self.freqs_cis is not None
464
+ return apply_rotary_enc(q, k, freqs_cis=self.freqs_cis)
465
+
466
+ def forward(self, x: Tensor) -> Tensor:
467
+ s = 1 if self.cls_token else 0 # used to exclude cls_token
468
+ if x.ndim == 4:
469
+ B, H, W, _ = x.shape
470
+ assert s == 0 # no cls_token
471
+ L = H * W
472
+ ndim = 4
473
+ else:
474
+ assert x.ndim == 3
475
+ B, L, _ = x.shape
476
+ ndim = 3
477
+ H = W = int(math.sqrt(L - s))
478
+
479
+ # qkv with shape (3, B, nHead, L, C)
480
+ qkv = self.qkv(x).reshape(B, L, 3, self.num_heads, -1)
481
+ # q, k, v with shape (B, nHead, L, C)
482
+ q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
483
+
484
+ # handle rope and rel pos embeddings
485
+ q, k = self._apply_rope(q, k)
486
+ if self.use_rel_pos:
487
+ q, k = concat_rel_pos(
488
+ q.flatten(0, 1),
489
+ k.flatten(0, 1),
490
+ (H, W),
491
+ x.shape[1:3],
492
+ self.rel_pos_h,
493
+ self.rel_pos_w,
494
+ rescale=True,
495
+ relative_coords=self.relative_coords,
496
+ )
497
+
498
+ # sdpa expects [B, nheads, H*W, C] so we transpose back
499
+ q = q.reshape(B, self.num_heads, H * W, -1)
500
+ k = k.reshape(B, self.num_heads, H * W, -1)
501
+
502
+ x = F.scaled_dot_product_attention(q, k, v)
503
+
504
+ if ndim == 4:
505
+ x = (
506
+ x.view(B, self.num_heads, H, W, -1)
507
+ .permute(0, 2, 3, 1, 4)
508
+ .reshape(B, H, W, -1)
509
+ )
510
+ else:
511
+ x = x.view(B, self.num_heads, L, -1).permute(0, 2, 1, 3).reshape(B, L, -1)
512
+
513
+ x = self.proj(x)
514
+
515
+ return x
516
+
517
+
518
+ class Block(nn.Module):
519
+ """Transformer blocks with support of window attention"""
520
+
521
+ def __init__(
522
+ self,
523
+ dim: int,
524
+ num_heads: int,
525
+ mlp_ratio: float = 4.0,
526
+ qkv_bias: bool = True,
527
+ drop_path: float = 0.0,
528
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
529
+ act_layer: Callable[..., nn.Module] = nn.GELU,
530
+ use_rel_pos: bool = False,
531
+ rel_pos_zero_init: bool = True,
532
+ window_size: int = 0,
533
+ input_size: Optional[Tuple[int, int]] = None,
534
+ use_rope: bool = False,
535
+ rope_pt_size: Optional[Tuple[int, int]] = None,
536
+ rope_tiled: bool = False,
537
+ rope_interp: bool = False,
538
+ use_ve_rope: bool = False,
539
+ cls_token: bool = False,
540
+ dropout: float = 0.0,
541
+ init_values: Optional[float] = None,
542
+ ):
543
+ """
544
+ Args:
545
+ dim (int): Number of input channels.
546
+ num_heads (int): Number of attention heads in each ViT block.
547
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
548
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
549
+ drop_path (float): Stochastic depth rate.
550
+ norm_layer (nn.Module): Normalization layer.
551
+ act_layer (nn.Module): Activation layer.
552
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
553
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
554
+ window_size (int): Window size for window attention blocks. If it equals 0, window
555
+ attention is not used.
556
+ input_size (int or None): Input resolution for calculating the relative positional
557
+ parameter size.
558
+ dropout (float): Dropout rate.
559
+ cls_token: whether a cls_token is present.
560
+ use_rope: whether to use rope 2d (indep of use_rel_pos, as it can be used together)
561
+ rope_pt_size: size of rope in previous stage of training, needed for interpolation or tiling
562
+ rope_interp: whether to interpolate (or extrapolate) rope to match target input size,
563
+ expected to specify source size as rope_pt_size.
564
+ """
565
+ super().__init__()
566
+ self.norm1 = norm_layer(dim)
567
+ self.attn = Attention(
568
+ dim,
569
+ num_heads=num_heads,
570
+ qkv_bias=qkv_bias,
571
+ use_rel_pos=use_rel_pos,
572
+ rel_pos_zero_init=rel_pos_zero_init,
573
+ input_size=input_size if window_size == 0 else (window_size, window_size),
574
+ use_rope=use_rope,
575
+ rope_pt_size=rope_pt_size,
576
+ rope_interp=rope_interp,
577
+ cls_token=cls_token,
578
+ )
579
+ self.ls1 = (
580
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
581
+ )
582
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
583
+
584
+ self.norm2 = norm_layer(dim)
585
+ self.mlp = Mlp(
586
+ in_features=dim,
587
+ hidden_features=int(dim * mlp_ratio),
588
+ act_layer=act_layer,
589
+ drop=(dropout, 0.0),
590
+ )
591
+ self.ls2 = (
592
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
593
+ )
594
+ self.dropout = nn.Dropout(dropout)
595
+ self.window_size = window_size
596
+
597
+ def forward(self, x: Tensor) -> Tensor:
598
+ shortcut = x
599
+ x = self.norm1(x)
600
+ # Window partition
601
+ if self.window_size > 0:
602
+ H, W = x.shape[1], x.shape[2]
603
+ x, pad_hw = window_partition(x, self.window_size)
604
+
605
+ x = self.ls1(self.attn(x))
606
+ # Reverse window partition
607
+ if self.window_size > 0:
608
+ x = window_unpartition(x, self.window_size, pad_hw, (H, W))
609
+
610
+ x = shortcut + self.dropout(self.drop_path(x))
611
+ x = x + self.dropout(self.drop_path(self.ls2(self.mlp(self.norm2(x)))))
612
+
613
+ return x
614
+
615
+
616
+ class ViT(nn.Module):
617
+ """
618
+ This module implements the Vision Transformer (ViT) backbone in :paper:`vitdet`.
619
+ "Exploring Plain Vision Transformer Backbones for Object Detection",
620
+ https://arxiv.org/abs/2203.16527
621
+ """
622
+
623
+ def __init__(
624
+ self,
625
+ img_size: int = 1024,
626
+ patch_size: int = 16,
627
+ in_chans: int = 3,
628
+ embed_dim: int = 768,
629
+ depth: int = 12,
630
+ num_heads: int = 12,
631
+ mlp_ratio: float = 4.0,
632
+ qkv_bias: bool = True,
633
+ drop_path_rate: float = 0.0,
634
+ norm_layer: Union[Callable[..., nn.Module], str] = "LayerNorm",
635
+ act_layer: Callable[..., nn.Module] = nn.GELU,
636
+ use_abs_pos: bool = True,
637
+ tile_abs_pos: bool = True,
638
+ rel_pos_blocks: Union[Tuple[int, ...], bool] = (2, 5, 8, 11),
639
+ rel_pos_zero_init: bool = True,
640
+ window_size: int = 14,
641
+ global_att_blocks: Tuple[int, ...] = (2, 5, 8, 11),
642
+ use_rope: bool = False,
643
+ rope_pt_size: Optional[int] = None,
644
+ use_interp_rope: bool = False,
645
+ pretrain_img_size: int = 224,
646
+ pretrain_use_cls_token: bool = True,
647
+ retain_cls_token: bool = True,
648
+ dropout: float = 0.0,
649
+ return_interm_layers: bool = False,
650
+ init_values: Optional[float] = None, # for layerscale
651
+ ln_pre: bool = False,
652
+ ln_post: bool = False,
653
+ bias_patch_embed: bool = True,
654
+ compile_mode: Optional[str] = None,
655
+ use_act_checkpoint: bool = True,
656
+ ):
657
+ """
658
+ Args:
659
+ img_size (int): Input image size. Only relevant for rel pos or rope.
660
+ patch_size (int): Patch size.
661
+ in_chans (int): Number of input image channels.
662
+ embed_dim (int): Patch embedding dimension.
663
+ depth (int): Depth of ViT.
664
+ num_heads (int): Number of attention heads in each ViT block.
665
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
666
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
667
+ drop_path_rate (float): Stochastic depth rate.
668
+ norm_layer (nn.Module): Normalization layer.
669
+ act_layer (nn.Module): Activation layer.
670
+ use_abs_pos (bool): If True, use absolute positional embeddings.
671
+ tile_abs_pos (bool): If True, tile absolute positional embeddings instead of interpolation.
672
+ rel_pos_blocks (list): Blocks which have rel pos embeddings.
673
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
674
+ window_size (int): Window size for window attention blocks.
675
+ global_att_blocks (list): Indexes for blocks using global attention (other blocks use window attention).
676
+ use_rope (bool): whether to use rope 2d (indep of rel_pos_blocks, as it can be used together).
677
+ rope_pt_size (int): size of rope in previous stage of training, needed for interpolation or tiling.
678
+ use_interp_rope: whether to interpolate (or extrapolate) rope to match target input size,
679
+ expected to specify source size as rope_pt_size.
680
+ use_act_checkpoint (bool): If True, use activation checkpointing.
681
+ pretrain_img_size (int): input image size for pretraining models.
682
+ pretrain_use_cls_token (bool): If True, pretraining models use class token.
683
+ retain_cls_token: whether cls_token should be retained.
684
+ dropout (float): Dropout rate. Applied in residual blocks of attn, mlp and inside the mlp.
685
+
686
+ return_interm_layers (bool): Whether to return intermediate layers (all global attention blocks).
687
+ init_values: layer scale init, None for no layer scale.
688
+
689
+ ln_pre (bool): If True, apply layer norm before transformer blocks.
690
+ ln_post (bool): If True, apply layer norm after transformer blocks.
691
+ bias_patch_embed (bool): whether to use a bias in the patch embed conv.
692
+ compile_mode (str): mode to compile the forward
693
+ """
694
+ super().__init__()
695
+ self.pretrain_use_cls_token = pretrain_use_cls_token
696
+
697
+ window_block_indexes = [i for i in range(depth) if i not in global_att_blocks]
698
+ self.full_attn_ids = list(global_att_blocks)
699
+ self.rel_pos_blocks = [False] * depth
700
+ if isinstance(rel_pos_blocks, bool) and rel_pos_blocks:
701
+ self.rel_pos_blocks = [True] * depth
702
+ else:
703
+ for i in rel_pos_blocks:
704
+ self.rel_pos_blocks[i] = True
705
+
706
+ self.retain_cls_token = retain_cls_token
707
+ if self.retain_cls_token:
708
+ assert pretrain_use_cls_token
709
+ assert (
710
+ len(window_block_indexes) == 0
711
+ ), "windowing not supported with cls token"
712
+
713
+ assert sum(self.rel_pos_blocks) == 0, "rel pos not supported with cls token"
714
+
715
+ scale = embed_dim**-0.5
716
+ self.class_embedding = nn.Parameter(scale * torch.randn(1, 1, embed_dim))
717
+
718
+ if isinstance(norm_layer, str):
719
+ norm_layer = partial(getattr(nn, norm_layer), eps=1e-5)
720
+
721
+ self.patch_embed = PatchEmbed(
722
+ kernel_size=(patch_size, patch_size),
723
+ stride=(patch_size, patch_size),
724
+ in_chans=in_chans,
725
+ embed_dim=embed_dim,
726
+ bias=bias_patch_embed,
727
+ )
728
+
729
+ # Handle absolute positional embedding
730
+ self.tile_abs_pos = tile_abs_pos
731
+ self.use_abs_pos = use_abs_pos
732
+ if self.tile_abs_pos:
733
+ assert self.use_abs_pos
734
+
735
+ if self.use_abs_pos:
736
+ # Initialize absolute positional embedding with pretrain image size.
737
+ num_patches = (pretrain_img_size // patch_size) * (
738
+ pretrain_img_size // patch_size
739
+ )
740
+ num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
741
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
742
+ else:
743
+ self.pos_embed = None
744
+
745
+ # stochastic depth decay rule
746
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
747
+
748
+ self.blocks = nn.ModuleList()
749
+ cur_stage = 1
750
+ for i in range(depth):
751
+ block = Block(
752
+ dim=embed_dim,
753
+ num_heads=num_heads,
754
+ mlp_ratio=mlp_ratio,
755
+ qkv_bias=qkv_bias,
756
+ drop_path=dpr[i],
757
+ norm_layer=norm_layer,
758
+ act_layer=act_layer,
759
+ use_rel_pos=self.rel_pos_blocks[i],
760
+ rel_pos_zero_init=rel_pos_zero_init,
761
+ window_size=window_size if i in window_block_indexes else 0,
762
+ input_size=(img_size // patch_size, img_size // patch_size),
763
+ use_rope=use_rope,
764
+ rope_pt_size=(
765
+ (window_size, window_size)
766
+ if rope_pt_size is None
767
+ else (rope_pt_size, rope_pt_size)
768
+ ),
769
+ rope_interp=use_interp_rope,
770
+ cls_token=self.retain_cls_token,
771
+ dropout=dropout,
772
+ init_values=init_values,
773
+ )
774
+
775
+ if i not in window_block_indexes:
776
+ cur_stage += 1
777
+
778
+ self.use_act_checkpoint = use_act_checkpoint
779
+
780
+ self.blocks.append(block)
781
+
782
+ self.return_interm_layers = return_interm_layers
783
+ self.channel_list = (
784
+ [embed_dim] * len(self.full_attn_ids)
785
+ if return_interm_layers
786
+ else [embed_dim]
787
+ )
788
+
789
+ if self.pos_embed is not None:
790
+ trunc_normal_(self.pos_embed, std=0.02)
791
+
792
+ self.ln_pre = norm_layer(embed_dim) if ln_pre else nn.Identity()
793
+ self.ln_post = norm_layer(embed_dim) if ln_post else nn.Identity()
794
+
795
+ self.apply(self._init_weights)
796
+
797
+ if compile_mode is not None:
798
+ self.forward = torch.compile(
799
+ self.forward, mode=compile_mode, fullgraph=True
800
+ )
801
+ if self.use_act_checkpoint and self.training:
802
+ torch._dynamo.config.optimize_ddp = False
803
+
804
+ def _init_weights(self, m: nn.Module) -> None:
805
+ if isinstance(m, nn.Linear):
806
+ trunc_normal_(m.weight, std=0.02)
807
+ if isinstance(m, nn.Linear) and m.bias is not None:
808
+ nn.init.constant_(m.bias, 0)
809
+ elif isinstance(m, nn.LayerNorm):
810
+ nn.init.constant_(m.bias, 0)
811
+ nn.init.constant_(m.weight, 1.0)
812
+
813
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
814
+ x = self.patch_embed(x)
815
+ h, w = x.shape[1], x.shape[2]
816
+
817
+ s = 0
818
+ if self.retain_cls_token:
819
+ # If cls_token is retained, we don't
820
+ # maintain spatial shape
821
+ x = torch.cat([self.class_embedding, x.flatten(1, 2)], dim=1)
822
+ s = 1
823
+
824
+ if self.pos_embed is not None:
825
+ x = x + get_abs_pos(
826
+ self.pos_embed,
827
+ self.pretrain_use_cls_token,
828
+ (h, w),
829
+ self.retain_cls_token,
830
+ tiling=self.tile_abs_pos,
831
+ )
832
+
833
+ x = self.ln_pre(x)
834
+
835
+ outputs = []
836
+ for i, blk in enumerate(self.blocks):
837
+ if self.use_act_checkpoint and self.training:
838
+ x = checkpoint.checkpoint(blk, x, use_reentrant=False)
839
+ else:
840
+ x = blk(x)
841
+ if (i == self.full_attn_ids[-1]) or (
842
+ self.return_interm_layers and i in self.full_attn_ids
843
+ ):
844
+ if i == self.full_attn_ids[-1]:
845
+ x = self.ln_post(x)
846
+
847
+ feats = x[:, s:]
848
+ if feats.ndim == 4:
849
+ feats = feats.permute(0, 3, 1, 2)
850
+ else:
851
+ assert feats.ndim == 3
852
+ h = w = int(math.sqrt(feats.shape[1]))
853
+ feats = feats.reshape(
854
+ feats.shape[0], h, w, feats.shape[-1]
855
+ ).permute(0, 3, 1, 2)
856
+
857
+ outputs.append(feats)
858
+
859
+ return outputs
860
+
861
+ def get_layer_id(self, layer_name: str) -> int:
862
+ # https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
863
+ num_layers = self.get_num_layers()
864
+
865
+ if layer_name.find("rel_pos") != -1:
866
+ return num_layers + 1
867
+ elif layer_name.find("ln_pre") != -1:
868
+ return 0
869
+ elif layer_name.find("pos_embed") != -1 or layer_name.find("cls_token") != -1:
870
+ return 0
871
+ elif layer_name.find("patch_embed") != -1:
872
+ return 0
873
+ elif layer_name.find("blocks") != -1:
874
+ return int(layer_name.split("blocks")[1].split(".")[1]) + 1
875
+ else:
876
+ return num_layers + 1
877
+
878
+ def get_num_layers(self) -> int:
879
+ return len(self.blocks)
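A scaled-down construction sketch for the ViT class above; the settings below mirror the builder configuration further down in this diff (`_create_vit_backbone`) but at a smaller size, and the import path is an assumption:

```python
import torch
from sam3.model.vitdet import ViT  # path assumed from this diff

vit = ViT(
    img_size=224,
    pretrain_img_size=224,
    patch_size=16,
    embed_dim=192,
    depth=4,
    num_heads=3,
    global_att_blocks=(1, 3),   # the remaining blocks use 14x14 window attention
    rel_pos_blocks=(),
    use_rope=True,
    use_interp_rope=True,
    window_size=14,
    retain_cls_token=False,
    ln_pre=True,
    use_act_checkpoint=False,
)
feats = vit(torch.randn(1, 3, 224, 224))
print([f.shape for f in feats])  # [torch.Size([1, 192, 14, 14])] with a single output level
```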
detect_tools/sam3/sam3/model/vl_combiner.py ADDED
@@ -0,0 +1,176 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ """Provides utility to combine a vision backbone with a language backbone."""
4
+
5
+ from copy import copy
6
+ from typing import List, Optional
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from torch.nn.attention import sdpa_kernel, SDPBackend
12
+
13
+ from .act_ckpt_utils import activation_ckpt_wrapper
14
+ from .necks import Sam3DualViTDetNeck
15
+
16
+
17
+ class SAM3VLBackbone(nn.Module):
18
+ """This backbone combines a vision backbone and a language backbone without fusion.
19
+ As such it is more of a convenience wrapper to handle the two backbones together.
20
+
21
+ It adds support for activation checkpointing and compilation.
22
+ """
23
+
24
+ def __init__(
25
+ self,
26
+ visual: Sam3DualViTDetNeck,
27
+ text,
28
+ compile_visual: bool = False,
29
+ act_ckpt_whole_vision_backbone: bool = False,
30
+ act_ckpt_whole_language_backbone: bool = False,
31
+ scalp=0,
32
+ ):
33
+ """Initialize the backbone combiner.
34
+
35
+ :param visual: The vision backbone to use
36
+ :param text: The text encoder to use
37
+ """
38
+ super().__init__()
39
+ self.vision_backbone: Sam3DualViTDetNeck = (
40
+ torch.compile(visual) if compile_visual else visual
41
+ )
42
+ self.language_backbone = text
43
+ self.scalp = scalp
44
+ # allow running activation checkpointing on the entire vision and language backbones
45
+ self.act_ckpt_whole_vision_backbone = act_ckpt_whole_vision_backbone
46
+ self.act_ckpt_whole_language_backbone = act_ckpt_whole_language_backbone
47
+
48
+ def forward(
49
+ self,
50
+ samples: torch.Tensor,
51
+ captions: List[str],
52
+ input_boxes: Optional[torch.Tensor] = None,
53
+ additional_text: Optional[List[str]] = None,
54
+ ):
55
+ """Forward pass of the backbone combiner.
56
+
57
+ :param samples: The input images
58
+ :param captions: The input captions
59
+ :param input_boxes: If the text contains place-holders for boxes, this
60
+ parameter contains the tensor containing their spatial features
61
+ :param additional_text: This can be used to encode some additional text
62
+ (different from the captions) in the same forward of the backbone
63
+ :return: Output dictionary with the following keys:
64
+ - vision_features: The output of the vision backbone
65
+ - language_features: The output of the language backbone
66
+ - language_mask: The attention mask of the language backbone
67
+ - vision_pos_enc: The positional encoding of the vision backbone
68
+ - (optional) additional_text_features: The output of the language
69
+ backbone for the additional text
70
+ - (optional) additional_text_mask: The attention mask of the
71
+ language backbone for the additional text
72
+ """
73
+ output = self.forward_image(samples)
74
+ device = output["vision_features"].device
75
+ output.update(self.forward_text(captions, input_boxes, additional_text, device))
76
+ return output
77
+
78
+ def forward_image(self, samples: torch.Tensor):
79
+ return activation_ckpt_wrapper(self._forward_image_no_act_ckpt)(
80
+ samples=samples,
81
+ act_ckpt_enable=self.act_ckpt_whole_vision_backbone and self.training,
82
+ )
83
+
84
+ def _forward_image_no_act_ckpt(self, samples):
85
+ # Forward through backbone
86
+ sam3_features, sam3_pos, sam2_features, sam2_pos = self.vision_backbone.forward(
87
+ samples
88
+ )
89
+ if self.scalp > 0:
90
+ # Discard the lowest resolution features
91
+ sam3_features, sam3_pos = (
92
+ sam3_features[: -self.scalp],
93
+ sam3_pos[: -self.scalp],
94
+ )
95
+ if sam2_features is not None and sam2_pos is not None:
96
+ sam2_features, sam2_pos = (
97
+ sam2_features[: -self.scalp],
98
+ sam2_pos[: -self.scalp],
99
+ )
100
+
101
+ sam2_output = None
102
+
103
+ if sam2_features is not None and sam2_pos is not None:
104
+ sam2_src = sam2_features[-1]
105
+ sam2_output = {
106
+ "vision_features": sam2_src,
107
+ "vision_pos_enc": sam2_pos,
108
+ "backbone_fpn": sam2_features,
109
+ }
110
+
111
+ sam3_src = sam3_features[-1]
112
+ output = {
113
+ "vision_features": sam3_src,
114
+ "vision_pos_enc": sam3_pos,
115
+ "backbone_fpn": sam3_features,
116
+ "sam2_backbone_out": sam2_output,
117
+ }
118
+
119
+ return output
120
+
121
+ def forward_text(
122
+ self, captions, input_boxes=None, additional_text=None, device="cuda"
123
+ ):
124
+ return activation_ckpt_wrapper(self._forward_text_no_ack_ckpt)(
125
+ captions=captions,
126
+ input_boxes=input_boxes,
127
+ additional_text=additional_text,
128
+ device=device,
129
+ act_ckpt_enable=self.act_ckpt_whole_language_backbone and self.training,
130
+ )
131
+
132
+ def _forward_text_no_ack_ckpt(
133
+ self,
134
+ captions,
135
+ input_boxes=None,
136
+ additional_text=None,
137
+ device="cuda",
138
+ ):
139
+ output = {}
140
+
141
+ # Forward through text_encoder
142
+ text_to_encode = copy(captions)
143
+ if additional_text is not None:
144
+ # if additional_text is given, we piggy-back it into this forward.
145
+ # They'll be used later for output alignment
146
+ text_to_encode += additional_text
147
+
148
+ sdpa_context = sdpa_kernel(
149
+ [
150
+ SDPBackend.MATH,
151
+ SDPBackend.EFFICIENT_ATTENTION,
152
+ SDPBackend.FLASH_ATTENTION,
153
+ ]
154
+ )
155
+
156
+ with sdpa_context:
157
+ text_attention_mask, text_memory, text_embeds = self.language_backbone(
158
+ text_to_encode, input_boxes, device=device
159
+ )
160
+
161
+ if additional_text is not None:
162
+ output["additional_text_features"] = text_memory[:, -len(additional_text) :]
163
+ output["additional_text_mask"] = text_attention_mask[
164
+ -len(additional_text) :
165
+ ]
166
+
167
+ text_memory = text_memory[:, : len(captions)]
168
+ text_attention_mask = text_attention_mask[: len(captions)]
169
+ text_embeds = text_embeds[:, : len(captions)]
170
+ output["language_features"] = text_memory
171
+ output["language_mask"] = text_attention_mask
172
+ output["language_embeds"] = (
173
+ text_embeds # Text embeddings before forward to the encoder
174
+ )
175
+
176
+ return output
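A self-contained sketch of the `additional_text` bookkeeping in `forward_text` above; the (seq_len, batch, d_model) layout of `text_memory` and the (batch, seq_len) mask layout are assumptions read off the slicing pattern:

```python
import torch

num_captions, num_additional, seq_len, d_model = 2, 1, 32, 256
text_memory = torch.randn(seq_len, num_captions + num_additional, d_model)        # assumed (L, B, C)
text_mask = torch.zeros(num_captions + num_additional, seq_len, dtype=torch.bool)  # assumed (B, L)

additional_text_features = text_memory[:, -num_additional:]   # (L, 1, C), split back out
additional_text_mask = text_mask[-num_additional:]            # (1, L)
language_features = text_memory[:, :num_captions]             # (L, 2, C)
language_mask = text_mask[:num_captions]                      # (2, L)
```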
detect_tools/sam3/sam3/model_builder.py ADDED
@@ -0,0 +1,793 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
2
+
3
+ import os
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from huggingface_hub import hf_hub_download
9
+ from iopath.common.file_io import g_pathmgr
10
+ from sam3.model.decoder import (
11
+ TransformerDecoder,
12
+ TransformerDecoderLayer,
13
+ TransformerDecoderLayerv2,
14
+ TransformerEncoderCrossAttention,
15
+ )
16
+ from sam3.model.encoder import TransformerEncoderFusion, TransformerEncoderLayer
17
+ from sam3.model.geometry_encoders import SequenceGeometryEncoder
18
+ from sam3.model.maskformer_segmentation import PixelDecoder, UniversalSegmentationHead
19
+ from sam3.model.memory import (
20
+ CXBlock,
21
+ SimpleFuser,
22
+ SimpleMaskDownSampler,
23
+ SimpleMaskEncoder,
24
+ )
25
+ from sam3.model.model_misc import (
26
+ DotProductScoring,
27
+ MLP,
28
+ MultiheadAttentionWrapper as MultiheadAttention,
29
+ TransformerWrapper,
30
+ )
31
+ from sam3.model.necks import Sam3DualViTDetNeck
32
+ from sam3.model.position_encoding import PositionEmbeddingSine
33
+ from sam3.model.sam1_task_predictor import SAM3InteractiveImagePredictor
34
+ from sam3.model.sam3_image import Sam3Image, Sam3ImageOnVideoMultiGPU
35
+ from sam3.model.sam3_tracking_predictor import Sam3TrackerPredictor
36
+ from sam3.model.sam3_video_inference import Sam3VideoInferenceWithInstanceInteractivity
37
+ from sam3.model.sam3_video_predictor import Sam3VideoPredictorMultiGPU
38
+ from sam3.model.text_encoder_ve import VETextEncoder
39
+ from sam3.model.tokenizer_ve import SimpleTokenizer
40
+ from sam3.model.vitdet import ViT
41
+ from sam3.model.vl_combiner import SAM3VLBackbone
42
+ from sam3.sam.transformer import RoPEAttention
43
+
44
+
45
+ # Setup TensorFloat-32 for Ampere GPUs if available
46
+ def _setup_tf32() -> None:
47
+ """Enable TensorFloat-32 for Ampere GPUs if available."""
48
+ if torch.cuda.is_available():
49
+ device_props = torch.cuda.get_device_properties(0)
50
+ if device_props.major >= 8:
51
+ torch.backends.cuda.matmul.allow_tf32 = True
52
+ torch.backends.cudnn.allow_tf32 = True
53
+
54
+
55
+ _setup_tf32()
56
+
57
+
58
+ def _create_position_encoding(precompute_resolution=None):
59
+ """Create position encoding for visual backbone."""
60
+ return PositionEmbeddingSine(
61
+ num_pos_feats=256,
62
+ normalize=True,
63
+ scale=None,
64
+ temperature=10000,
65
+ precompute_resolution=precompute_resolution,
66
+ )
67
+
68
+
69
+ def _create_vit_backbone(compile_mode=None):
70
+ """Create ViT backbone for visual feature extraction."""
71
+ return ViT(
72
+ img_size=1008,
73
+ pretrain_img_size=336,
74
+ patch_size=14,
75
+ embed_dim=1024,
76
+ depth=32,
77
+ num_heads=16,
78
+ mlp_ratio=4.625,
79
+ norm_layer="LayerNorm",
80
+ drop_path_rate=0.1,
81
+ qkv_bias=True,
82
+ use_abs_pos=True,
83
+ tile_abs_pos=True,
84
+ global_att_blocks=(7, 15, 23, 31),
85
+ rel_pos_blocks=(),
86
+ use_rope=True,
87
+ use_interp_rope=True,
88
+ window_size=24,
89
+ pretrain_use_cls_token=True,
90
+ retain_cls_token=False,
91
+ ln_pre=True,
92
+ ln_post=False,
93
+ return_interm_layers=False,
94
+ bias_patch_embed=False,
95
+ compile_mode=compile_mode,
96
+ )
97
+
98
+
99
+ def _create_vit_neck(position_encoding, vit_backbone, enable_inst_interactivity=False):
100
+ """Create ViT neck for feature pyramid."""
101
+ return Sam3DualViTDetNeck(
102
+ position_encoding=position_encoding,
103
+ d_model=256,
104
+ scale_factors=[4.0, 2.0, 1.0, 0.5],
105
+ trunk=vit_backbone,
106
+ add_sam2_neck=enable_inst_interactivity,
107
+ )
108
+
109
+
110
+ def _create_vl_backbone(vit_neck, text_encoder):
111
+ """Create visual-language backbone."""
112
+ return SAM3VLBackbone(visual=vit_neck, text=text_encoder, scalp=1)
113
+
114
+
115
+ def _create_transformer_encoder() -> TransformerEncoderFusion:
116
+ """Create transformer encoder with its layer."""
117
+ encoder_layer = TransformerEncoderLayer(
118
+ activation="relu",
119
+ d_model=256,
120
+ dim_feedforward=2048,
121
+ dropout=0.1,
122
+ pos_enc_at_attn=True,
123
+ pos_enc_at_cross_attn_keys=False,
124
+ pos_enc_at_cross_attn_queries=False,
125
+ pre_norm=True,
126
+ self_attention=MultiheadAttention(
127
+ num_heads=8,
128
+ dropout=0.1,
129
+ embed_dim=256,
130
+ batch_first=True,
131
+ ),
132
+ cross_attention=MultiheadAttention(
133
+ num_heads=8,
134
+ dropout=0.1,
135
+ embed_dim=256,
136
+ batch_first=True,
137
+ ),
138
+ )
139
+
140
+ encoder = TransformerEncoderFusion(
141
+ layer=encoder_layer,
142
+ num_layers=6,
143
+ d_model=256,
144
+ num_feature_levels=1,
145
+ frozen=False,
146
+ use_act_checkpoint=True,
147
+ add_pooled_text_to_img_feat=False,
148
+ pool_text_with_mask=True,
149
+ )
150
+ return encoder
151
+
152
+
153
+ def _create_transformer_decoder() -> TransformerDecoder:
154
+ """Create transformer decoder with its layer."""
155
+ decoder_layer = TransformerDecoderLayer(
156
+ activation="relu",
157
+ d_model=256,
158
+ dim_feedforward=2048,
159
+ dropout=0.1,
160
+ cross_attention=MultiheadAttention(
161
+ num_heads=8,
162
+ dropout=0.1,
163
+ embed_dim=256,
164
+ ),
165
+ n_heads=8,
166
+ use_text_cross_attention=True,
167
+ )
168
+
169
+ decoder = TransformerDecoder(
170
+ layer=decoder_layer,
171
+ num_layers=6,
172
+ num_queries=200,
173
+ return_intermediate=True,
174
+ box_refine=True,
175
+ num_o2m_queries=0,
176
+ dac=True,
177
+ boxRPB="log",
178
+ d_model=256,
179
+ frozen=False,
180
+ interaction_layer=None,
181
+ dac_use_selfatt_ln=True,
182
+ resolution=1008,
183
+ stride=14,
184
+ use_act_checkpoint=True,
185
+ presence_token=True,
186
+ )
187
+ return decoder
188
+
189
+
190
+ def _create_dot_product_scoring():
191
+ """Create dot product scoring module."""
192
+ prompt_mlp = MLP(
193
+ input_dim=256,
194
+ hidden_dim=2048,
195
+ output_dim=256,
196
+ num_layers=2,
197
+ dropout=0.1,
198
+ residual=True,
199
+ out_norm=nn.LayerNorm(256),
200
+ )
201
+ return DotProductScoring(d_model=256, d_proj=256, prompt_mlp=prompt_mlp)
202
+
203
+
204
+ def _create_segmentation_head(compile_mode=None):
205
+ """Create segmentation head with pixel decoder."""
206
+ pixel_decoder = PixelDecoder(
207
+ num_upsampling_stages=3,
208
+ interpolation_mode="nearest",
209
+ hidden_dim=256,
210
+ compile_mode=compile_mode,
211
+ )
212
+
213
+ cross_attend_prompt = MultiheadAttention(
214
+ num_heads=8,
215
+ dropout=0,
216
+ embed_dim=256,
217
+ )
218
+
219
+ segmentation_head = UniversalSegmentationHead(
220
+ hidden_dim=256,
221
+ upsampling_stages=3,
222
+ aux_masks=False,
223
+ presence_head=False,
224
+ dot_product_scorer=None,
225
+ act_ckpt=True,
226
+ cross_attend_prompt=cross_attend_prompt,
227
+ pixel_decoder=pixel_decoder,
228
+ )
229
+ return segmentation_head
230
+
231
+
232
+ def _create_geometry_encoder():
233
+ """Create geometry encoder with all its components."""
234
+ # Create position encoding for geometry encoder
235
+ geo_pos_enc = _create_position_encoding()
236
+ # Create CX block for fuser
237
+ cx_block = CXBlock(
238
+ dim=256,
239
+ kernel_size=7,
240
+ padding=3,
241
+ layer_scale_init_value=1.0e-06,
242
+ use_dwconv=True,
243
+ )
244
+ # Create geometry encoder layer
245
+ geo_layer = TransformerEncoderLayer(
246
+ activation="relu",
247
+ d_model=256,
248
+ dim_feedforward=2048,
249
+ dropout=0.1,
250
+ pos_enc_at_attn=False,
251
+ pre_norm=True,
252
+ self_attention=MultiheadAttention(
253
+ num_heads=8,
254
+ dropout=0.1,
255
+ embed_dim=256,
256
+ batch_first=False,
257
+ ),
258
+ pos_enc_at_cross_attn_queries=False,
259
+ pos_enc_at_cross_attn_keys=True,
260
+ cross_attention=MultiheadAttention(
261
+ num_heads=8,
262
+ dropout=0.1,
263
+ embed_dim=256,
264
+ batch_first=False,
265
+ ),
266
+ )
267
+
268
+ # Create geometry encoder
269
+ input_geometry_encoder = SequenceGeometryEncoder(
270
+ pos_enc=geo_pos_enc,
271
+ encode_boxes_as_points=False,
272
+ points_direct_project=True,
273
+ points_pool=True,
274
+ points_pos_enc=True,
275
+ boxes_direct_project=True,
276
+ boxes_pool=True,
277
+ boxes_pos_enc=True,
278
+ d_model=256,
279
+ num_layers=3,
280
+ layer=geo_layer,
281
+ use_act_ckpt=True,
282
+ add_cls=True,
283
+ add_post_encode_proj=True,
284
+ )
285
+ return input_geometry_encoder
286
+
287
+
288
+ def _create_sam3_model(
289
+ backbone,
290
+ transformer,
291
+ input_geometry_encoder,
292
+ segmentation_head,
293
+ dot_prod_scoring,
294
+ inst_interactive_predictor,
295
+ eval_mode,
296
+ ):
297
+ """Create the SAM3 image model."""
298
+ common_params = {
299
+ "backbone": backbone,
300
+ "transformer": transformer,
301
+ "input_geometry_encoder": input_geometry_encoder,
302
+ "segmentation_head": segmentation_head,
303
+ "num_feature_levels": 1,
304
+ "o2m_mask_predict": True,
305
+ "dot_prod_scoring": dot_prod_scoring,
306
+ "use_instance_query": False,
307
+ "multimask_output": True,
308
+ "inst_interactive_predictor": inst_interactive_predictor,
309
+ }
310
+
311
+ matcher = None
312
+ if not eval_mode:
313
+ from sam3.train.matcher import BinaryHungarianMatcherV2
314
+
315
+ matcher = BinaryHungarianMatcherV2(
316
+ focal=True,
317
+ cost_class=2.0,
318
+ cost_bbox=5.0,
319
+ cost_giou=2.0,
320
+ alpha=0.25,
321
+ gamma=2,
322
+ stable=False,
323
+ )
324
+ common_params["matcher"] = matcher
325
+ model = Sam3Image(**common_params)
326
+
327
+ return model
328
+
329
+
330
+ def _create_tracker_maskmem_backbone():
331
+ """Create the SAM3 Tracker memory encoder."""
332
+ # Position encoding for mask memory backbone
333
+ position_encoding = PositionEmbeddingSine(
334
+ num_pos_feats=64,
335
+ normalize=True,
336
+ scale=None,
337
+ temperature=10000,
338
+ precompute_resolution=1008,
339
+ )
340
+
341
+ # Mask processing components
342
+ mask_downsampler = SimpleMaskDownSampler(
343
+ kernel_size=3, stride=2, padding=1, interpol_size=[1152, 1152]
344
+ )
345
+
346
+ cx_block_layer = CXBlock(
347
+ dim=256,
348
+ kernel_size=7,
349
+ padding=3,
350
+ layer_scale_init_value=1.0e-06,
351
+ use_dwconv=True,
352
+ )
353
+
354
+ fuser = SimpleFuser(layer=cx_block_layer, num_layers=2)
355
+
356
+ maskmem_backbone = SimpleMaskEncoder(
357
+ out_dim=64,
358
+ position_encoding=position_encoding,
359
+ mask_downsampler=mask_downsampler,
360
+ fuser=fuser,
361
+ )
362
+
363
+ return maskmem_backbone
364
+
365
+
366
+ def _create_tracker_transformer():
367
+ """Create the SAM3 Tracker transformer components."""
368
+ # Self attention
369
+ self_attention = RoPEAttention(
370
+ embedding_dim=256,
371
+ num_heads=1,
372
+ downsample_rate=1,
373
+ dropout=0.1,
374
+ rope_theta=10000.0,
375
+ feat_sizes=[72, 72],
376
+ use_fa3=False,
377
+ use_rope_real=False,
378
+ )
379
+
380
+ # Cross attention
381
+ cross_attention = RoPEAttention(
382
+ embedding_dim=256,
383
+ num_heads=1,
384
+ downsample_rate=1,
385
+ dropout=0.1,
386
+ kv_in_dim=64,
387
+ rope_theta=10000.0,
388
+ feat_sizes=[72, 72],
389
+ rope_k_repeat=True,
390
+ use_fa3=False,
391
+ use_rope_real=False,
392
+ )
393
+
394
+ # Encoder layer
395
+ encoder_layer = TransformerDecoderLayerv2(
396
+ cross_attention_first=False,
397
+ activation="relu",
398
+ dim_feedforward=2048,
399
+ dropout=0.1,
400
+ pos_enc_at_attn=False,
401
+ pre_norm=True,
402
+ self_attention=self_attention,
403
+ d_model=256,
404
+ pos_enc_at_cross_attn_keys=True,
405
+ pos_enc_at_cross_attn_queries=False,
406
+ cross_attention=cross_attention,
407
+ )
408
+
409
+ # Encoder
410
+ encoder = TransformerEncoderCrossAttention(
411
+ remove_cross_attention_layers=[],
412
+ batch_first=True,
413
+ d_model=256,
414
+ frozen=False,
415
+ pos_enc_at_input=True,
416
+ layer=encoder_layer,
417
+ num_layers=4,
418
+ use_act_checkpoint=False,
419
+ )
420
+
421
+ # Transformer wrapper
422
+ transformer = TransformerWrapper(
423
+ encoder=encoder,
424
+ decoder=None,
425
+ d_model=256,
426
+ )
427
+
428
+ return transformer
429
+
430
+
431
+ def build_tracker(
+     apply_temporal_disambiguation: bool, with_backbone: bool = False, compile_mode=None
+ ) -> Sam3TrackerPredictor:
+     """
+     Build the SAM3 Tracker module for video tracking.
+
+     Returns:
+         Sam3TrackerPredictor: Wrapped SAM3 Tracker module
+     """
+
+     # Create model components
+     maskmem_backbone = _create_tracker_maskmem_backbone()
+     transformer = _create_tracker_transformer()
+     backbone = None
+     if with_backbone:
+         vision_backbone = _create_vision_backbone(compile_mode=compile_mode)
+         backbone = SAM3VLBackbone(scalp=1, visual=vision_backbone, text=None)
+     # Create the Tracker module
+     model = Sam3TrackerPredictor(
+         image_size=1008,
+         num_maskmem=7,
+         backbone=backbone,
+         backbone_stride=14,
+         transformer=transformer,
+         maskmem_backbone=maskmem_backbone,
+         # SAM parameters
+         multimask_output_in_sam=True,
+         # Evaluation
+         forward_backbone_per_frame_for_eval=True,
+         trim_past_non_cond_mem_for_eval=False,
+         # Multimask
+         multimask_output_for_tracking=True,
+         multimask_min_pt_num=0,
+         multimask_max_pt_num=1,
+         # Additional settings
+         always_start_from_first_ann_frame=False,
+         # Mask overlap
+         non_overlap_masks_for_mem_enc=False,
+         non_overlap_masks_for_output=False,
+         max_cond_frames_in_attn=4,
+         offload_output_to_cpu_for_eval=False,
+         # SAM decoder settings
+         sam_mask_decoder_extra_args={
+             "dynamic_multimask_via_stability": True,
+             "dynamic_multimask_stability_delta": 0.05,
+             "dynamic_multimask_stability_thresh": 0.98,
+         },
+         clear_non_cond_mem_around_input=True,
+         fill_hole_area=0,
+         use_memory_selection=apply_temporal_disambiguation,
+     )
+
+     return model
+
+
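For orientation, a minimal usage sketch of the builder above (not part of the commit): it assumes the package is importable as `sam3` under this repository layout, and it only constructs the module without loading any checkpoint.

import torch
from sam3.model_builder import build_tracker  # module path assumed

tracker = build_tracker(
    apply_temporal_disambiguation=True,  # enable memory selection in the tracker
    with_backbone=True,                  # also attach a vision-only SAM3VLBackbone
)
tracker = tracker.cuda().eval() if torch.cuda.is_available() else tracker.eval()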
+ def _create_text_encoder(bpe_path: str) -> VETextEncoder:
+     """Create SAM3 text encoder."""
+     tokenizer = SimpleTokenizer(bpe_path=bpe_path)
+     return VETextEncoder(
+         tokenizer=tokenizer,
+         d_model=256,
+         width=1024,
+         heads=16,
+         layers=24,
+     )
+
+
+ def _create_vision_backbone(
+     compile_mode=None, enable_inst_interactivity=True
+ ) -> Sam3DualViTDetNeck:
+     """Create SAM3 visual backbone with ViT and neck."""
+     # Position encoding
+     position_encoding = _create_position_encoding(precompute_resolution=1008)
+     # ViT backbone
+     vit_backbone: ViT = _create_vit_backbone(compile_mode=compile_mode)
+     vit_neck: Sam3DualViTDetNeck = _create_vit_neck(
+         position_encoding,
+         vit_backbone,
+         enable_inst_interactivity=enable_inst_interactivity,
+     )
+     # Visual neck
+     return vit_neck
+
+
+ def _create_sam3_transformer(has_presence_token: bool = True) -> TransformerWrapper:
+     """Create SAM3 transformer encoder and decoder."""
+     encoder: TransformerEncoderFusion = _create_transformer_encoder()
+     decoder: TransformerDecoder = _create_transformer_decoder()
+
+     return TransformerWrapper(encoder=encoder, decoder=decoder, d_model=256)
+
+
+ def _load_checkpoint(model, checkpoint_path):
+     """Load model checkpoint from file."""
+     with g_pathmgr.open(checkpoint_path, "rb") as f:
+         ckpt = torch.load(f, map_location="cpu", weights_only=True)
+     if "model" in ckpt and isinstance(ckpt["model"], dict):
+         ckpt = ckpt["model"]
+     sam3_image_ckpt = {
+         k.replace("detector.", ""): v for k, v in ckpt.items() if "detector" in k
+     }
+     if model.inst_interactive_predictor is not None:
+         sam3_image_ckpt.update(
+             {
+                 k.replace("tracker.", "inst_interactive_predictor.model."): v
+                 for k, v in ckpt.items()
+                 if "tracker" in k
+             }
+         )
+     missing_keys, _ = model.load_state_dict(sam3_image_ckpt, strict=False)
+     if len(missing_keys) > 0:
+         print(
+             f"loaded {checkpoint_path} and found "
+             f"missing and/or unexpected keys:\n{missing_keys=}"
+         )
+
+
+ def _setup_device_and_mode(model, device, eval_mode):
+     """Setup model device and evaluation mode."""
+     if device == "cuda":
+         model = model.cuda()
+     if eval_mode:
+         model.eval()
+     return model
+
+
+ def build_sam3_image_model(
+     bpe_path=None,
+     device="cuda" if torch.cuda.is_available() else "cpu",
+     eval_mode=True,
+     checkpoint_path=None,
+     load_from_HF=True,
+     enable_segmentation=True,
+     enable_inst_interactivity=False,
+     compile=False,
+ ):
+     """
+     Build SAM3 image model.
+
+     Args:
+         bpe_path: Path to the BPE tokenizer vocabulary
+         device: Device to load the model on ('cuda' or 'cpu')
+         eval_mode: Whether to set the model to evaluation mode
+         checkpoint_path: Optional path to model checkpoint
+         load_from_HF: Whether to download the default checkpoint from the
+             Hugging Face Hub when no checkpoint_path is given
+         enable_segmentation: Whether to enable segmentation head
+         enable_inst_interactivity: Whether to enable instance interactivity (SAM 1 task)
+         compile: Whether to compile the vision components (sets compile_mode="default")
+
+     Returns:
+         A SAM3 image model
+     """
+     if bpe_path is None:
+         bpe_path = os.path.join(
+             os.path.dirname(__file__), "..", "assets", "bpe_simple_vocab_16e6.txt.gz"
+         )
+     # Create visual components
+     compile_mode = "default" if compile else None
+     vision_encoder = _create_vision_backbone(
+         compile_mode=compile_mode, enable_inst_interactivity=enable_inst_interactivity
+     )
+
+     # Create text components
+     text_encoder = _create_text_encoder(bpe_path)
+
+     # Create visual-language backbone
+     backbone = _create_vl_backbone(vision_encoder, text_encoder)
+
+     # Create transformer components
+     transformer = _create_sam3_transformer()
+
+     # Create dot product scoring
+     dot_prod_scoring = _create_dot_product_scoring()
+
+     # Create segmentation head if enabled
+     segmentation_head = (
+         _create_segmentation_head(compile_mode=compile_mode)
+         if enable_segmentation
+         else None
+     )
+
+     # Create geometry encoder
+     input_geometry_encoder = _create_geometry_encoder()
+     if enable_inst_interactivity:
+         sam3_pvs_base = build_tracker(apply_temporal_disambiguation=False)
+         inst_predictor = SAM3InteractiveImagePredictor(sam3_pvs_base)
+     else:
+         inst_predictor = None
+     # Create the SAM3 model
+     model = _create_sam3_model(
+         backbone,
+         transformer,
+         input_geometry_encoder,
+         segmentation_head,
+         dot_prod_scoring,
+         inst_predictor,
+         eval_mode,
+     )
+     if load_from_HF and checkpoint_path is None:
+         checkpoint_path = download_ckpt_from_hf()
+     # Load checkpoint if provided
+     if checkpoint_path is not None:
+         _load_checkpoint(model, checkpoint_path)
+
+     # Setup device and mode
+     model = _setup_device_and_mode(model, device, eval_mode)
+
+     return model
+
+
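A hedged usage sketch for the image builder above (not part of the commit). It assumes the `sam3` package and the Hugging Face checkpoint are reachable; the actual prompting calls are omitted because they live in the predictor/processor wrappers elsewhere in the repository.

import torch
from sam3.model_builder import build_sam3_image_model  # module path assumed

model = build_sam3_image_model(
    device="cuda" if torch.cuda.is_available() else "cpu",
    eval_mode=True,
    load_from_HF=True,             # pulls the facebook/sam3 weights via download_ckpt_from_hf()
    enable_segmentation=True,
    enable_inst_interactivity=False,
)
with torch.inference_mode():
    ...  # run text/box-prompted detection through the accompanying image processor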
+ def download_ckpt_from_hf():
+     SAM3_MODEL_ID = "facebook/sam3"
+     SAM3_CKPT_NAME = "sam3.pt"
+     SAM3_CFG_NAME = "config.json"
+     _ = hf_hub_download(repo_id=SAM3_MODEL_ID, filename=SAM3_CFG_NAME)
+     checkpoint_path = hf_hub_download(repo_id=SAM3_MODEL_ID, filename=SAM3_CKPT_NAME)
+     return checkpoint_path
+
+
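Note that `hf_hub_download` writes into the standard Hugging Face cache. If that location is unsuitable, the cache root can usually be redirected via the `HF_HOME` environment variable; a sketch under the assumption that the variable is set before `huggingface_hub` is first imported (the path below is a placeholder):

import os
os.environ["HF_HOME"] = "/path/to/large_disk/hf_cache"  # hypothetical cache root

from sam3.model_builder import download_ckpt_from_hf  # import after setting HF_HOME
ckpt_path = download_ckpt_from_hf()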
+ def build_sam3_video_model(
+     checkpoint_path: Optional[str] = None,
+     load_from_HF=True,
+     bpe_path: Optional[str] = None,
+     has_presence_token: bool = True,
+     geo_encoder_use_img_cross_attn: bool = True,
+     strict_state_dict_loading: bool = True,
+     apply_temporal_disambiguation: bool = True,
+     device="cuda" if torch.cuda.is_available() else "cpu",
+     compile=False,
+ ) -> Sam3VideoInferenceWithInstanceInteractivity:
+     """
+     Build SAM3 dense tracking model.
+
+     Args:
+         checkpoint_path: Optional path to checkpoint file
+         load_from_HF: Whether to download the default checkpoint from the
+             Hugging Face Hub when no checkpoint_path is given
+         bpe_path: Path to the BPE tokenizer file
+         apply_temporal_disambiguation: Whether to enable the temporal heuristics
+             (hotstart, reconditioning, memory selection); disable for ablations
+
+     Returns:
+         Sam3VideoInferenceWithInstanceInteractivity: The instantiated dense tracking model
+     """
+     if bpe_path is None:
+         bpe_path = os.path.join(
+             os.path.dirname(__file__), "..", "assets", "bpe_simple_vocab_16e6.txt.gz"
+         )
+
+     # Build Tracker module
+     tracker = build_tracker(apply_temporal_disambiguation=apply_temporal_disambiguation)
+
+     # Build Detector components
+     visual_neck = _create_vision_backbone()
+     text_encoder = _create_text_encoder(bpe_path)
+     backbone = SAM3VLBackbone(scalp=1, visual=visual_neck, text=text_encoder)
+     transformer = _create_sam3_transformer(has_presence_token=has_presence_token)
+     segmentation_head: UniversalSegmentationHead = _create_segmentation_head()
+     input_geometry_encoder = _create_geometry_encoder()
+
+     # Create main dot product scoring
+     main_dot_prod_mlp = MLP(
+         input_dim=256,
+         hidden_dim=2048,
+         output_dim=256,
+         num_layers=2,
+         dropout=0.1,
+         residual=True,
+         out_norm=nn.LayerNorm(256),
+     )
+     main_dot_prod_scoring = DotProductScoring(
+         d_model=256, d_proj=256, prompt_mlp=main_dot_prod_mlp
+     )
+
+     # Build Detector module
+     detector = Sam3ImageOnVideoMultiGPU(
+         num_feature_levels=1,
+         backbone=backbone,
+         transformer=transformer,
+         segmentation_head=segmentation_head,
+         semantic_segmentation_head=None,
+         input_geometry_encoder=input_geometry_encoder,
+         use_early_fusion=True,
+         use_dot_prod_scoring=True,
+         dot_prod_scoring=main_dot_prod_scoring,
+         supervise_joint_box_scores=has_presence_token,
+     )
+
+     # Build the main SAM3 video model
+     if apply_temporal_disambiguation:
+         model = Sam3VideoInferenceWithInstanceInteractivity(
+             detector=detector,
+             tracker=tracker,
+             score_threshold_detection=0.5,
+             assoc_iou_thresh=0.1,
+             det_nms_thresh=0.1,
+             new_det_thresh=0.7,
+             hotstart_delay=15,
+             hotstart_unmatch_thresh=8,
+             hotstart_dup_thresh=8,
+             suppress_unmatched_only_within_hotstart=True,
+             min_trk_keep_alive=-1,
+             max_trk_keep_alive=30,
+             init_trk_keep_alive=30,
+             suppress_overlapping_based_on_recent_occlusion_threshold=0.7,
+             suppress_det_close_to_boundary=False,
+             fill_hole_area=16,
+             recondition_every_nth_frame=16,
+             masklet_confirmation_enable=False,
+             decrease_trk_keep_alive_for_empty_masklets=False,
+             image_size=1008,
+             image_mean=(0.5, 0.5, 0.5),
+             image_std=(0.5, 0.5, 0.5),
+             compile_model=compile,
+         )
+     else:
+         # a version without any heuristics for ablation studies
+         model = Sam3VideoInferenceWithInstanceInteractivity(
+             detector=detector,
+             tracker=tracker,
+             score_threshold_detection=0.5,
+             assoc_iou_thresh=0.1,
+             det_nms_thresh=0.1,
+             new_det_thresh=0.7,
+             hotstart_delay=0,
+             hotstart_unmatch_thresh=0,
+             hotstart_dup_thresh=0,
+             suppress_unmatched_only_within_hotstart=True,
+             min_trk_keep_alive=-1,
+             max_trk_keep_alive=30,
+             init_trk_keep_alive=30,
+             suppress_overlapping_based_on_recent_occlusion_threshold=0.7,
+             suppress_det_close_to_boundary=False,
+             fill_hole_area=16,
+             recondition_every_nth_frame=0,
+             masklet_confirmation_enable=False,
+             decrease_trk_keep_alive_for_empty_masklets=False,
+             image_size=1008,
+             image_mean=(0.5, 0.5, 0.5),
+             image_std=(0.5, 0.5, 0.5),
+             compile_model=compile,
+         )
+
+     # Load checkpoint if provided
+     if load_from_HF and checkpoint_path is None:
+         checkpoint_path = download_ckpt_from_hf()
+     if checkpoint_path is not None:
+         with g_pathmgr.open(checkpoint_path, "rb") as f:
+             ckpt = torch.load(f, map_location="cpu", weights_only=True)
+         if "model" in ckpt and isinstance(ckpt["model"], dict):
+             ckpt = ckpt["model"]
+
+         missing_keys, unexpected_keys = model.load_state_dict(
+             ckpt, strict=strict_state_dict_loading
+         )
+         if missing_keys:
+             print(f"Missing keys: {missing_keys}")
+         if unexpected_keys:
+             print(f"Unexpected keys: {unexpected_keys}")
+
+     model.to(device=device)
+     return model
+
+
+ def build_sam3_video_predictor(*model_args, gpus_to_use=None, **model_kwargs):
+     return Sam3VideoPredictorMultiGPU(
+         *model_args, gpus_to_use=gpus_to_use, **model_kwargs
+     )
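A minimal sketch of the two video entry points above (not part of the commit; the module path is assumed and a GPU is recommended):

from sam3.model_builder import build_sam3_video_model

# Full dense-tracking model with the default temporal-disambiguation heuristics
# (hotstart, periodic reconditioning, memory selection):
video_model = build_sam3_video_model(apply_temporal_disambiguation=True, compile=False)

# build_sam3_video_predictor(...) simply forwards its arguments to
# Sam3VideoPredictorMultiGPU, so consult that class for the expected arguments.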
detect_tools/sam3/sam3/perflib/__init__.py ADDED
@@ -0,0 +1,8 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
+
+ import os
+
+ is_enabled = False
+ if os.getenv("USE_PERFLIB", "1") == "1":
+     # print("Enabled the use of perflib.\n", end="")
+     is_enabled = True
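Because `is_enabled` is resolved once at import time from the `USE_PERFLIB` environment variable (enabled by default), the flag must be set before `sam3.perflib` is first imported. A minimal sketch:

import os
os.environ["USE_PERFLIB"] = "0"  # must be set before the first import of sam3.perflib

import sam3.perflib as perflib
assert perflib.is_enabled is False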
detect_tools/sam3/sam3/perflib/associate_det_trk.py ADDED
@@ -0,0 +1,137 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
+
+ from collections import defaultdict
+
+ import torch
+ import torch.nn.functional as F
+ from sam3.perflib.masks_ops import mask_iou
+ from scipy.optimize import linear_sum_assignment
+
+
+ def associate_det_trk(
+     det_masks,
+     track_masks,
+     iou_threshold=0.5,
+     iou_threshold_trk=0.5,
+     det_scores=None,
+     new_det_thresh=0.0,
+ ):
+     """
+     Optimized implementation of detection <-> track association that minimizes DtoH syncs.
+
+     Args:
+         det_masks: (N, H, W) tensor of predicted masks
+         track_masks: (M, H, W) tensor of track masks
+         iou_threshold: IoU above which a detection counts as matched to a track
+         iou_threshold_trk: IoU above which a track counts as matched to a detection
+         det_scores: (N,) tensor of detection scores
+         new_det_thresh: minimum score for an unmatched detection to be considered 'new'
+
+     Returns:
+         new_det_indices: list of indices in det_masks considered 'new'
+         unmatched_trk_indices: list of indices in track_masks considered 'unmatched'
+         det_to_matched_trk: dict mapping each detection index to the track indices it matched
+         matched_det_scores: dict mapping each matched track index to [det_score, det_score * iou]
+     """
+     with torch.autograd.profiler.record_function("perflib: associate_det_trk"):
+         assert isinstance(det_masks, torch.Tensor), "det_masks should be a tensor"
+         assert isinstance(track_masks, torch.Tensor), "track_masks should be a tensor"
+         if det_masks.size(0) == 0 or track_masks.size(0) == 0:
+             return list(range(det_masks.size(0))), [], {}, {}  # all detections are new
+
+         if list(det_masks.shape[-2:]) != list(track_masks.shape[-2:]):
+             # resize to the smaller spatial size to save GPU memory
+             if det_masks.shape[-2:].numel() < track_masks.shape[-2:].numel():
+                 track_masks = (
+                     F.interpolate(
+                         track_masks.unsqueeze(1).float(),
+                         size=det_masks.shape[-2:],
+                         mode="bilinear",
+                         align_corners=False,
+                     ).squeeze(1)
+                     > 0
+                 )
+             else:
+                 # resize detections to track size
+                 det_masks = (
+                     F.interpolate(
+                         det_masks.unsqueeze(1).float(),
+                         size=track_masks.shape[-2:],
+                         mode="bilinear",
+                         align_corners=False,
+                     ).squeeze(1)
+                     > 0
+                 )
+
+         det_masks = det_masks > 0
+         track_masks = track_masks > 0
+
+         iou = mask_iou(det_masks, track_masks)  # (N, M)
+         igeit = iou >= iou_threshold
+         igeit_any_dim_1 = igeit.any(dim=1)
+         igeit_trk = iou >= iou_threshold_trk
+
+         iou_list = iou.cpu().numpy().tolist()
+         igeit_list = igeit.cpu().numpy().tolist()
+         igeit_any_dim_1_list = igeit_any_dim_1.cpu().numpy().tolist()
+         igeit_trk_list = igeit_trk.cpu().numpy().tolist()
+
+         det_scores_list = (
+             det_scores
+             if det_scores is None
+             else det_scores.cpu().float().numpy().tolist()
+         )
+
+         # Hungarian matching for tracks (one-to-one: each track matches at most one detection)
+         # For detections: allow many tracks to match to the same detection (many-to-one)
+
+         # If either is empty, return all detections as new
+         if det_masks.size(0) == 0 or track_masks.size(0) == 0:
+             return list(range(det_masks.size(0))), [], {}, {}
+
+         # Hungarian matching: maximize IoU for tracks
+         cost_matrix = 1 - iou.cpu().numpy()  # Hungarian solves for minimum cost
+         row_ind, col_ind = linear_sum_assignment(cost_matrix)
+
+         def branchy_hungarian_better_uses_the_cpu(
+             cost_matrix, row_ind, col_ind, iou_list, det_masks, track_masks
+         ):
+             matched_trk = set()
+             matched_det = set()
+             matched_det_scores = {}  # track index -> [det_score, det_score * iou] of the matched detection mask
+             for d, t in zip(row_ind, col_ind):
+                 matched_det_scores[t] = [
+                     det_scores_list[d],
+                     det_scores_list[d] * iou_list[d][t],
+                 ]
+                 if igeit_trk_list[d][t]:
+                     matched_trk.add(t)
+                     matched_det.add(d)
+
+             # Tracks not matched by Hungarian assignment above threshold are unmatched
+             unmatched_trk_indices = [
+                 t for t in range(track_masks.size(0)) if t not in matched_trk
+             ]
+
+             # For detections: allow many tracks to match to the same detection (many-to-one)
+             # So, a detection is 'new' if it does not match any track above threshold
+             assert track_masks.size(0) == igeit.size(1)  # Needed for loop optimization below
+             new_det_indices = []
+             for d in range(det_masks.size(0)):
+                 if not igeit_any_dim_1_list[d]:
+                     if det_scores is not None and det_scores[d] >= new_det_thresh:
+                         new_det_indices.append(d)
+
+             # for each detection, which tracks it matched to (above threshold)
+             det_to_matched_trk = defaultdict(list)
+             for d in range(det_masks.size(0)):
+                 for t in range(track_masks.size(0)):
+                     if igeit_list[d][t]:
+                         det_to_matched_trk[d].append(t)
+
+             return (
+                 new_det_indices,
+                 unmatched_trk_indices,
+                 det_to_matched_trk,
+                 matched_det_scores,
+             )
+
+         return branchy_hungarian_better_uses_the_cpu(
+             cost_matrix, row_ind, col_ind, iou_list, det_masks, track_masks
+         )
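A small synthetic exercise of the association routine above (not part of the commit): shapes, scores, and thresholds are arbitrary, and it assumes `sam3.perflib.masks_ops.mask_iou` is importable.

import torch
from sam3.perflib.associate_det_trk import associate_det_trk

det_masks = torch.rand(3, 128, 128) > 0.5    # (N, H, W) detection masks
trk_masks = torch.rand(2, 128, 128) > 0.5    # (M, H, W) track masks
det_scores = torch.tensor([0.9, 0.8, 0.75])  # one score per detection

new_dets, unmatched_trks, det_to_trk, matched_scores = associate_det_trk(
    det_masks,
    trk_masks,
    iou_threshold=0.5,
    iou_threshold_trk=0.5,
    det_scores=det_scores,
    new_det_thresh=0.7,
)
print(new_dets, unmatched_trks, dict(det_to_trk), matched_scores)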