apply Black 25.11.0 style in fbcode/deeplearning/projects (21/92)
Summary: Formats the covered files with pyfmt. paintitblack Reviewed By: itamaro Differential Revision: D90476315 fbshipit-source-id: ee94c471788b8e7d067813d8b3e0311214d17f3f
This commit is contained in:
committed by
meta-codesync[bot]
parent
7b89b8fc3f
commit
11dec2936d
@@ -296,9 +296,9 @@ def agent_inference(
|
|||||||
assert LATEST_SAM3_TEXT_PROMPT != ""
|
assert LATEST_SAM3_TEXT_PROMPT != ""
|
||||||
|
|
||||||
# Make sure that the last message is a image
|
# Make sure that the last message is a image
|
||||||
assert (
|
assert messages[-1]["content"][1]["type"] == "image", (
|
||||||
messages[-1]["content"][1]["type"] == "image"
|
"Second content element should be an image"
|
||||||
), "Second content element should be an image"
|
)
|
||||||
messages.pop() # Remove the last user message
|
messages.pop() # Remove the last user message
|
||||||
# Add simplified replacement message
|
# Add simplified replacement message
|
||||||
simplified_message = {
|
simplified_message = {
|
||||||
@@ -318,7 +318,7 @@ def agent_inference(
|
|||||||
|
|
||||||
# MLLM check the mask one by one
|
# MLLM check the mask one by one
|
||||||
for i in range(num_masks):
|
for i in range(num_masks):
|
||||||
print(f"🔍 Checking mask {i+1}/{num_masks}...")
|
print(f"🔍 Checking mask {i + 1}/{num_masks}...")
|
||||||
image_w_mask_i, image_w_zoomed_in_mask_i = visualize(current_outputs, i)
|
image_w_mask_i, image_w_zoomed_in_mask_i = visualize(current_outputs, i)
|
||||||
|
|
||||||
image_w_zoomed_in_mask_i_path = os.path.join(
|
image_w_zoomed_in_mask_i_path = os.path.join(
|
||||||
@@ -363,7 +363,7 @@ def agent_inference(
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Generated text is None, which is unexpected. Please check the Qwen server and the input parameters."
|
"Generated text is None, which is unexpected. Please check the Qwen server and the input parameters."
|
||||||
)
|
)
|
||||||
print(f"Generated text for mask {i+1}: {checking_generated_text}")
|
print(f"Generated text for mask {i + 1}: {checking_generated_text}")
|
||||||
verdict = (
|
verdict = (
|
||||||
checking_generated_text.split("<verdict>")[-1]
|
checking_generated_text.split("<verdict>")[-1]
|
||||||
.split("</verdict>")[0]
|
.split("</verdict>")[0]
|
||||||
@@ -371,11 +371,11 @@ def agent_inference(
|
|||||||
)
|
)
|
||||||
if "Accept" in verdict:
|
if "Accept" in verdict:
|
||||||
assert not "Reject" in verdict
|
assert not "Reject" in verdict
|
||||||
print(f"Mask {i+1} accepted, keeping it in the outputs.")
|
print(f"Mask {i + 1} accepted, keeping it in the outputs.")
|
||||||
masks_to_keep.append(i)
|
masks_to_keep.append(i)
|
||||||
elif "Reject" in verdict:
|
elif "Reject" in verdict:
|
||||||
assert not "Accept" in verdict
|
assert not "Accept" in verdict
|
||||||
print(f"Mask {i+1} rejected, removing it from the outputs.")
|
print(f"Mask {i + 1} rejected, removing it from the outputs.")
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unexpected verdict in generated text: {checking_generated_text}. Expected 'Accept' or 'Reject'."
|
f"Unexpected verdict in generated text: {checking_generated_text}. Expected 'Accept' or 'Reject'."
|
||||||
@@ -397,7 +397,7 @@ def agent_inference(
|
|||||||
sam_output_dir, rf"{LATEST_SAM3_TEXT_PROMPT}.png"
|
sam_output_dir, rf"{LATEST_SAM3_TEXT_PROMPT}.png"
|
||||||
).replace(
|
).replace(
|
||||||
".png",
|
".png",
|
||||||
f"_selected_masks_{'-'.join(map(str, [i+1 for i in masks_to_keep]))}.png".replace(
|
f"_selected_masks_{'-'.join(map(str, [i + 1 for i in masks_to_keep]))}.png".replace(
|
||||||
"/", "_"
|
"/", "_"
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import os
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from sam3.model.box_ops import box_xyxy_to_xywh
|
from sam3.model.box_ops import box_xyxy_to_xywh
|
||||||
from sam3.train.masks_ops import rle_encode
|
from sam3.train.masks_ops import rle_encode
|
||||||
|
|
||||||
|
|||||||
@@ -84,9 +84,9 @@ class BoxMode(IntEnum):
|
|||||||
], "Relative mode not yet supported!"
|
], "Relative mode not yet supported!"
|
||||||
|
|
||||||
if from_mode == BoxMode.XYWHA_ABS and to_mode == BoxMode.XYXY_ABS:
|
if from_mode == BoxMode.XYWHA_ABS and to_mode == BoxMode.XYXY_ABS:
|
||||||
assert (
|
assert arr.shape[-1] == 5, (
|
||||||
arr.shape[-1] == 5
|
"The last dimension of input shape must be 5 for XYWHA format"
|
||||||
), "The last dimension of input shape must be 5 for XYWHA format"
|
)
|
||||||
original_dtype = arr.dtype
|
original_dtype = arr.dtype
|
||||||
arr = arr.double()
|
arr = arr.double()
|
||||||
|
|
||||||
@@ -244,9 +244,9 @@ class Boxes:
|
|||||||
if isinstance(item, int):
|
if isinstance(item, int):
|
||||||
return Boxes(self.tensor[item].view(1, -1))
|
return Boxes(self.tensor[item].view(1, -1))
|
||||||
b = self.tensor[item]
|
b = self.tensor[item]
|
||||||
assert (
|
assert b.dim() == 2, (
|
||||||
b.dim() == 2
|
"Indexing on Boxes with {} failed to return a matrix!".format(item)
|
||||||
), "Indexing on Boxes with {} failed to return a matrix!".format(item)
|
)
|
||||||
return Boxes(b)
|
return Boxes(b)
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
@@ -425,7 +425,7 @@ def matched_pairwise_iou(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor:
|
|||||||
Tensor: iou, sized [N].
|
Tensor: iou, sized [N].
|
||||||
"""
|
"""
|
||||||
assert len(boxes1) == len(boxes2), (
|
assert len(boxes1) == len(boxes2), (
|
||||||
"boxlists should have the same" "number of entries, got {}, {}".format(
|
"boxlists should have the samenumber of entries, got {}, {}".format(
|
||||||
len(boxes1), len(boxes2)
|
len(boxes1), len(boxes2)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ from torch import device
|
|||||||
|
|
||||||
from .boxes import Boxes
|
from .boxes import Boxes
|
||||||
from .memory import retry_if_cuda_oom
|
from .memory import retry_if_cuda_oom
|
||||||
|
|
||||||
from .roi_align import ROIAlign
|
from .roi_align import ROIAlign
|
||||||
|
|
||||||
|
|
||||||
@@ -142,10 +141,10 @@ class BitMasks:
|
|||||||
if isinstance(item, int):
|
if isinstance(item, int):
|
||||||
return BitMasks(self.tensor[item].unsqueeze(0))
|
return BitMasks(self.tensor[item].unsqueeze(0))
|
||||||
m = self.tensor[item]
|
m = self.tensor[item]
|
||||||
assert (
|
assert m.dim() == 3, (
|
||||||
m.dim() == 3
|
"Indexing on BitMasks with {} returns a tensor with shape {}!".format(
|
||||||
), "Indexing on BitMasks with {} returns a tensor with shape {}!".format(
|
item, m.shape
|
||||||
item, m.shape
|
)
|
||||||
)
|
)
|
||||||
return BitMasks(m)
|
return BitMasks(m)
|
||||||
|
|
||||||
|
|||||||
@@ -363,9 +363,9 @@ class RotatedBoxes(Boxes):
|
|||||||
if isinstance(item, int):
|
if isinstance(item, int):
|
||||||
return RotatedBoxes(self.tensor[item].view(1, -1))
|
return RotatedBoxes(self.tensor[item].view(1, -1))
|
||||||
b = self.tensor[item]
|
b = self.tensor[item]
|
||||||
assert (
|
assert b.dim() == 2, (
|
||||||
b.dim() == 2
|
"Indexing on RotatedBoxes with {} failed to return a matrix!".format(item)
|
||||||
), "Indexing on RotatedBoxes with {} failed to return a matrix!".format(item)
|
)
|
||||||
return RotatedBoxes(b)
|
return RotatedBoxes(b)
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ from matplotlib.backends.backend_agg import FigureCanvasAgg
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from .boxes import Boxes, BoxMode
|
from .boxes import Boxes, BoxMode
|
||||||
|
|
||||||
from .color_map import random_color
|
from .color_map import random_color
|
||||||
from .keypoints import Keypoints
|
from .keypoints import Keypoints
|
||||||
from .masks import BitMasks, PolygonMasks
|
from .masks import BitMasks, PolygonMasks
|
||||||
@@ -222,9 +221,9 @@ class _PanopticPrediction:
|
|||||||
empty_ids.append(id)
|
empty_ids.append(id)
|
||||||
if len(empty_ids) == 0:
|
if len(empty_ids) == 0:
|
||||||
return np.zeros(self._seg.shape, dtype=np.uint8)
|
return np.zeros(self._seg.shape, dtype=np.uint8)
|
||||||
assert (
|
assert len(empty_ids) == 1, (
|
||||||
len(empty_ids) == 1
|
">1 ids corresponds to no labels. This is currently not supported"
|
||||||
), ">1 ids corresponds to no labels. This is currently not supported"
|
)
|
||||||
return (self._seg != empty_ids[0]).numpy().astype(np.bool)
|
return (self._seg != empty_ids[0]).numpy().astype(np.bool)
|
||||||
|
|
||||||
def semantic_masks(self):
|
def semantic_masks(self):
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ def run_single_image_inference(
|
|||||||
print(f"Output JSON {output_json_path} already exists. Skipping.")
|
print(f"Output JSON {output_json_path} already exists. Skipping.")
|
||||||
return
|
return
|
||||||
|
|
||||||
print(f"{'-'*30} Starting SAM 3 Agent Session... {'-'*30} ")
|
print(f"{'-' * 30} Starting SAM 3 Agent Session... {'-' * 30} ")
|
||||||
agent_history, final_output_dict, rendered_final_output = agent_inference(
|
agent_history, final_output_dict, rendered_final_output = agent_inference(
|
||||||
image_path,
|
image_path,
|
||||||
text_prompt,
|
text_prompt,
|
||||||
@@ -50,7 +50,7 @@ def run_single_image_inference(
|
|||||||
output_dir=output_dir,
|
output_dir=output_dir,
|
||||||
debug=debug,
|
debug=debug,
|
||||||
)
|
)
|
||||||
print(f"{'-'*30} End of SAM 3 Agent Session... {'-'*30} ")
|
print(f"{'-' * 30} End of SAM 3 Agent Session... {'-' * 30} ")
|
||||||
|
|
||||||
final_output_dict["text_prompt"] = text_prompt
|
final_output_dict["text_prompt"] = text_prompt
|
||||||
final_output_dict["image_path"] = image_path
|
final_output_dict["image_path"] = image_path
|
||||||
|
|||||||
@@ -73,7 +73,9 @@ def visualize(
|
|||||||
idx = int(zoom_in_index)
|
idx = int(zoom_in_index)
|
||||||
num_masks = len(input_json.get("pred_masks", []))
|
num_masks = len(input_json.get("pred_masks", []))
|
||||||
if idx < 0 or idx >= num_masks:
|
if idx < 0 or idx >= num_masks:
|
||||||
raise ValueError(f"zoom_in_index {idx} is out of range (0..{num_masks-1}).")
|
raise ValueError(
|
||||||
|
f"zoom_in_index {idx} is out of range (0..{num_masks - 1})."
|
||||||
|
)
|
||||||
|
|
||||||
# (1) Replicate zoom_in_and_visualize
|
# (1) Replicate zoom_in_and_visualize
|
||||||
object_data = {
|
object_data = {
|
||||||
|
|||||||
@@ -126,9 +126,9 @@ class COCOCustom(COCO):
|
|||||||
# MODIFICATION: faster and cached subset check
|
# MODIFICATION: faster and cached subset check
|
||||||
if not hasattr(self, "img_id_set"):
|
if not hasattr(self, "img_id_set"):
|
||||||
self.img_id_set = set(self.getImgIds())
|
self.img_id_set = set(self.getImgIds())
|
||||||
assert set(annsImgIds).issubset(
|
assert set(annsImgIds).issubset(self.img_id_set), (
|
||||||
self.img_id_set
|
"Results do not correspond to current coco set"
|
||||||
), "Results do not correspond to current coco set"
|
)
|
||||||
# END MODIFICATION
|
# END MODIFICATION
|
||||||
if "caption" in anns[0]:
|
if "caption" in anns[0]:
|
||||||
imgIds = set([img["id"] for img in res.dataset["images"]]) & set(
|
imgIds = set([img["id"] for img in res.dataset["images"]]) & set(
|
||||||
@@ -301,9 +301,9 @@ class CGF1Eval(COCOeval):
|
|||||||
TP = (match_scores >= thresh).sum()
|
TP = (match_scores >= thresh).sum()
|
||||||
FP = len(dt) - TP
|
FP = len(dt) - TP
|
||||||
FN = len(gt) - TP
|
FN = len(gt) - TP
|
||||||
assert (
|
assert FP >= 0 and FN >= 0, (
|
||||||
FP >= 0 and FN >= 0
|
f"FP: {FP}, FN: {FN}, TP: {TP}, match_scores: {match_scores}, len(dt): {len(dt)}, len(gt): {len(gt)}, ious: {ious}"
|
||||||
), f"FP: {FP}, FN: {FN}, TP: {TP}, match_scores: {match_scores}, len(dt): {len(dt)}, len(gt): {len(gt)}, ious: {ious}"
|
)
|
||||||
TPs.append(TP)
|
TPs.append(TP)
|
||||||
FPs.append(FP)
|
FPs.append(FP)
|
||||||
FNs.append(FN)
|
FNs.append(FN)
|
||||||
@@ -599,9 +599,9 @@ class CGF1Evaluator:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
assert len(self.coco_gts) > 0, "No ground truth provided for evaluation."
|
assert len(self.coco_gts) > 0, "No ground truth provided for evaluation."
|
||||||
assert len(self.coco_gts) == len(
|
assert len(self.coco_gts) == len(self.coco_evals), (
|
||||||
self.coco_evals
|
"Mismatch in number of ground truths and evaluators."
|
||||||
), "Mismatch in number of ground truths and evaluators."
|
)
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print(f"Loading predictions from {pred_file}")
|
print(f"Loading predictions from {pred_file}")
|
||||||
@@ -668,17 +668,17 @@ class CGF1Evaluator:
|
|||||||
if len(scorings) == 1:
|
if len(scorings) == 1:
|
||||||
return scorings[0]
|
return scorings[0]
|
||||||
|
|
||||||
assert (
|
assert scorings[0].ndim == 3, (
|
||||||
scorings[0].ndim == 3
|
f"Expecting results in [numCats, numAreas, numImgs] format, got {scorings[0].shape}"
|
||||||
), f"Expecting results in [numCats, numAreas, numImgs] format, got {scorings[0].shape}"
|
)
|
||||||
assert (
|
assert scorings[0].shape[0] == 1, (
|
||||||
scorings[0].shape[0] == 1
|
f"Expecting a single category, got {scorings[0].shape[0]}"
|
||||||
), f"Expecting a single category, got {scorings[0].shape[0]}"
|
)
|
||||||
|
|
||||||
for scoring in scorings:
|
for scoring in scorings:
|
||||||
assert (
|
assert scoring.shape == scorings[0].shape, (
|
||||||
scoring.shape == scorings[0].shape
|
f"Shape mismatch: {scoring.shape}, {scorings[0].shape}"
|
||||||
), f"Shape mismatch: {scoring.shape}, {scorings[0].shape}"
|
)
|
||||||
|
|
||||||
selected_imgs = []
|
selected_imgs = []
|
||||||
for img_id in range(scorings[0].shape[-1]):
|
for img_id in range(scorings[0].shape[-1]):
|
||||||
|
|||||||
@@ -18,19 +18,15 @@ import os
|
|||||||
import pickle
|
import pickle
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from typing import Any, List, Optional
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import pycocotools.mask as mask_utils
|
import pycocotools.mask as mask_utils
|
||||||
import torch
|
import torch
|
||||||
from iopath.common.file_io import g_pathmgr
|
from iopath.common.file_io import g_pathmgr
|
||||||
from pycocotools.coco import COCO
|
from pycocotools.coco import COCO
|
||||||
from pycocotools.cocoeval import COCOeval
|
from pycocotools.cocoeval import COCOeval
|
||||||
|
|
||||||
from sam3.train.masks_ops import rle_encode
|
from sam3.train.masks_ops import rle_encode
|
||||||
|
|
||||||
from sam3.train.utils.distributed import (
|
from sam3.train.utils.distributed import (
|
||||||
all_gather,
|
all_gather,
|
||||||
gather_to_rank_0_via_filesys,
|
gather_to_rank_0_via_filesys,
|
||||||
@@ -755,9 +751,9 @@ def loadRes(self, resFile):
|
|||||||
anns = resFile
|
anns = resFile
|
||||||
assert type(anns) == list, "results in not an array of objects"
|
assert type(anns) == list, "results in not an array of objects"
|
||||||
annsImgIds = [ann["image_id"] for ann in anns]
|
annsImgIds = [ann["image_id"] for ann in anns]
|
||||||
assert set(annsImgIds) == (
|
assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), (
|
||||||
set(annsImgIds) & set(self.getImgIds())
|
"Results do not correspond to current coco set"
|
||||||
), "Results do not correspond to current coco set"
|
)
|
||||||
if "caption" in anns[0]:
|
if "caption" in anns[0]:
|
||||||
imgIds = set([img["id"] for img in res.dataset["images"]]) & set(
|
imgIds = set([img["id"] for img in res.dataset["images"]]) & set(
|
||||||
[ann["image_id"] for ann in anns]
|
[ann["image_id"] for ann in anns]
|
||||||
|
|||||||
@@ -83,9 +83,9 @@ class PredictionDumper:
|
|||||||
self.merge_predictions = merge_predictions
|
self.merge_predictions = merge_predictions
|
||||||
self.pred_file_evaluators = pred_file_evaluators
|
self.pred_file_evaluators = pred_file_evaluators
|
||||||
if self.pred_file_evaluators is not None:
|
if self.pred_file_evaluators is not None:
|
||||||
assert (
|
assert merge_predictions, (
|
||||||
merge_predictions
|
"merge_predictions must be True if pred_file_evaluators are provided"
|
||||||
), "merge_predictions must be True if pred_file_evaluators are provided"
|
)
|
||||||
assert self.dump_dir is not None, "dump_dir must be provided"
|
assert self.dump_dir is not None, "dump_dir must be provided"
|
||||||
|
|
||||||
if is_main_process():
|
if is_main_process():
|
||||||
|
|||||||
@@ -13,11 +13,9 @@ from typing import Optional
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pycocotools.mask as maskUtils
|
import pycocotools.mask as maskUtils
|
||||||
from pycocotools.cocoeval import COCOeval
|
from pycocotools.cocoeval import COCOeval
|
||||||
|
|
||||||
from sam3.eval.coco_eval import CocoEvaluator
|
from sam3.eval.coco_eval import CocoEvaluator
|
||||||
from sam3.train.masks_ops import compute_F_measure
|
from sam3.train.masks_ops import compute_F_measure
|
||||||
from sam3.train.utils.distributed import is_main_process
|
from sam3.train.utils.distributed import is_main_process
|
||||||
|
|
||||||
from scipy.optimize import linear_sum_assignment
|
from scipy.optimize import linear_sum_assignment
|
||||||
|
|
||||||
|
|
||||||
@@ -156,9 +154,9 @@ class DemoEval(COCOeval):
|
|||||||
TP = (match_scores >= thresh).sum()
|
TP = (match_scores >= thresh).sum()
|
||||||
FP = len(dt) - TP
|
FP = len(dt) - TP
|
||||||
FN = len(gt) - TP
|
FN = len(gt) - TP
|
||||||
assert (
|
assert FP >= 0 and FN >= 0, (
|
||||||
FP >= 0 and FN >= 0
|
f"FP: {FP}, FN: {FN}, TP: {TP}, match_scores: {match_scores}, len(dt): {len(dt)}, len(gt): {len(gt)}, ious: {ious}"
|
||||||
), f"FP: {FP}, FN: {FN}, TP: {TP}, match_scores: {match_scores}, len(dt): {len(dt)}, len(gt): {len(gt)}, ious: {ious}"
|
)
|
||||||
TPs.append(TP)
|
TPs.append(TP)
|
||||||
FPs.append(FP)
|
FPs.append(FP)
|
||||||
FNs.append(FN)
|
FNs.append(FN)
|
||||||
@@ -528,17 +526,17 @@ class DemoEvaluator(CocoEvaluator):
|
|||||||
if len(scorings) == 1:
|
if len(scorings) == 1:
|
||||||
return scorings[0]
|
return scorings[0]
|
||||||
|
|
||||||
assert (
|
assert scorings[0].ndim == 3, (
|
||||||
scorings[0].ndim == 3
|
f"Expecting results in [numCats, numAreas, numImgs] format, got {scorings[0].shape}"
|
||||||
), f"Expecting results in [numCats, numAreas, numImgs] format, got {scorings[0].shape}"
|
)
|
||||||
assert (
|
assert scorings[0].shape[0] == 1, (
|
||||||
scorings[0].shape[0] == 1
|
f"Expecting a single category, got {scorings[0].shape[0]}"
|
||||||
), f"Expecting a single category, got {scorings[0].shape[0]}"
|
)
|
||||||
|
|
||||||
for scoring in scorings:
|
for scoring in scorings:
|
||||||
assert (
|
assert scoring.shape == scorings[0].shape, (
|
||||||
scoring.shape == scorings[0].shape
|
f"Shape mismatch: {scoring.shape}, {scorings[0].shape}"
|
||||||
), f"Shape mismatch: {scoring.shape}, {scorings[0].shape}"
|
)
|
||||||
|
|
||||||
selected_imgs = []
|
selected_imgs = []
|
||||||
for img_id in range(scorings[0].shape[-1]):
|
for img_id in range(scorings[0].shape[-1]):
|
||||||
|
|||||||
@@ -255,9 +255,10 @@ class Evaluator:
|
|||||||
if show_progressbar and TQDM_IMPORTED:
|
if show_progressbar and TQDM_IMPORTED:
|
||||||
seq_list_sorted = sorted(seq_list)
|
seq_list_sorted = sorted(seq_list)
|
||||||
|
|
||||||
with Pool(config["NUM_PARALLEL_CORES"]) as pool, tqdm.tqdm(
|
with (
|
||||||
total=len(seq_list)
|
Pool(config["NUM_PARALLEL_CORES"]) as pool,
|
||||||
) as pbar:
|
tqdm.tqdm(total=len(seq_list)) as pbar,
|
||||||
|
):
|
||||||
_eval_sequence = partial(
|
_eval_sequence = partial(
|
||||||
eval_sequence,
|
eval_sequence,
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
|
|||||||
@@ -83,9 +83,9 @@ class PostProcessImage(nn.Module):
|
|||||||
ret_tensordict: Experimental argument. If true, return a tensordict.TensorDict instead of a list of dictionaries for easier manipulation.
|
ret_tensordict: Experimental argument. If true, return a tensordict.TensorDict instead of a list of dictionaries for easier manipulation.
|
||||||
"""
|
"""
|
||||||
if ret_tensordict:
|
if ret_tensordict:
|
||||||
assert (
|
assert consistent is True, (
|
||||||
consistent is True
|
"We don't support returning TensorDict if the outputs have different shapes"
|
||||||
), "We don't support returning TensorDict if the outputs have different shapes" # NOTE: It's possible but we don't support it.
|
) # NOTE: It's possible but we don't support it.
|
||||||
assert self.detection_threshold <= 0.0, "TODO: implement?"
|
assert self.detection_threshold <= 0.0, "TODO: implement?"
|
||||||
try:
|
try:
|
||||||
from tensordict import TensorDict
|
from tensordict import TensorDict
|
||||||
@@ -118,7 +118,9 @@ class PostProcessImage(nn.Module):
|
|||||||
|
|
||||||
if boxes is None:
|
if boxes is None:
|
||||||
assert out_masks is not None
|
assert out_masks is not None
|
||||||
assert not ret_tensordict, "We don't support returning TensorDict if the output does not contain boxes"
|
assert not ret_tensordict, (
|
||||||
|
"We don't support returning TensorDict if the output does not contain boxes"
|
||||||
|
)
|
||||||
B = len(out_masks)
|
B = len(out_masks)
|
||||||
boxes = [None] * B
|
boxes = [None] * B
|
||||||
scores = [None] * B
|
scores = [None] * B
|
||||||
@@ -418,9 +420,9 @@ class PostProcessAPIVideo(PostProcessImage):
|
|||||||
if video_id == -1:
|
if video_id == -1:
|
||||||
video_id = unique_vid_id.item()
|
video_id = unique_vid_id.item()
|
||||||
else:
|
else:
|
||||||
assert (
|
assert video_id == unique_vid_id.item(), (
|
||||||
video_id == unique_vid_id.item()
|
"We can only postprocess one video per datapoint"
|
||||||
), "We can only postprocess one video per datapoint"
|
)
|
||||||
# keeping track of which objects appear in the current frame
|
# keeping track of which objects appear in the current frame
|
||||||
obj_ids_per_frame = frame_outs["pred_object_ids"]
|
obj_ids_per_frame = frame_outs["pred_object_ids"]
|
||||||
assert obj_ids_per_frame.size(-1) == frame_outs["pred_logits"].size(-2)
|
assert obj_ids_per_frame.size(-1) == frame_outs["pred_logits"].size(-2)
|
||||||
|
|||||||
@@ -95,9 +95,9 @@ class YTVIS(COCO):
|
|||||||
anns = resFile
|
anns = resFile
|
||||||
assert type(anns) == list, "results is not an array of objects"
|
assert type(anns) == list, "results is not an array of objects"
|
||||||
annsImgIds = [ann["image_id"] for ann in anns]
|
annsImgIds = [ann["image_id"] for ann in anns]
|
||||||
assert set(annsImgIds) == (
|
assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), (
|
||||||
set(annsImgIds) & set(self.getImgIds())
|
"Results do not correspond to current coco set"
|
||||||
), "Results do not correspond to current coco set"
|
)
|
||||||
if "bboxes" in anns[0] and not anns[0]["bboxes"] == []:
|
if "bboxes" in anns[0] and not anns[0]["bboxes"] == []:
|
||||||
res.dataset["categories"] = copy.deepcopy(self.dataset["categories"])
|
res.dataset["categories"] = copy.deepcopy(self.dataset["categories"])
|
||||||
for id, ann in enumerate(anns):
|
for id, ann in enumerate(anns):
|
||||||
|
|||||||
@@ -109,9 +109,7 @@ class YTVISevalMixin:
|
|||||||
) # Num preds x Num GTS x Num frames
|
) # Num preds x Num GTS x Num frames
|
||||||
inter = inter.sum(-1)
|
inter = inter.sum(-1)
|
||||||
union = union.sum(-1)
|
union = union.sum(-1)
|
||||||
assert (
|
assert (union > 0).all(), (
|
||||||
union > 0
|
|
||||||
).all(), (
|
|
||||||
"There exists a tracklet with zero GTs across time. This is suspicious"
|
"There exists a tracklet with zero GTs across time. This is suspicious"
|
||||||
)
|
)
|
||||||
return inter / union
|
return inter / union
|
||||||
@@ -136,9 +134,9 @@ class YTVISevalMixin:
|
|||||||
iou = inter / union
|
iou = inter / union
|
||||||
assert iou >= 0 and iou <= 1, "Encountered an error in IoU computation"
|
assert iou >= 0 and iou <= 1, "Encountered an error in IoU computation"
|
||||||
else:
|
else:
|
||||||
assert np.isclose(inter, 0) and np.isclose(
|
assert np.isclose(inter, 0) and np.isclose(union, 0), (
|
||||||
union, 0
|
"Encountered an error in IoU computation"
|
||||||
), "Encountered an error in IoU computation"
|
)
|
||||||
iou = 1
|
iou = 1
|
||||||
return iou
|
return iou
|
||||||
|
|
||||||
@@ -206,16 +204,16 @@ class YTVISResultsWriter:
|
|||||||
if len(prediction) == 0:
|
if len(prediction) == 0:
|
||||||
continue
|
continue
|
||||||
for k in ["boxes", "scores", "labels"]:
|
for k in ["boxes", "scores", "labels"]:
|
||||||
assert (
|
assert k in prediction, (
|
||||||
k in prediction
|
f"Expected predictions to have `{k}` key, available keys are {prediction.keys()}"
|
||||||
), f"Expected predictions to have `{k}` key, available keys are {prediction.keys()}"
|
)
|
||||||
if self.save_per_frame_scores:
|
if self.save_per_frame_scores:
|
||||||
assert (
|
assert "per_frame_scores" in prediction, (
|
||||||
"per_frame_scores" in prediction
|
f"Expected predictions to have `per_frame_scores` key, available keys are {prediction.keys()}"
|
||||||
), f"Expected predictions to have `per_frame_scores` key, available keys are {prediction.keys()}"
|
)
|
||||||
assert xor(
|
assert xor("masks" in prediction, "masks_rle" in prediction), (
|
||||||
"masks" in prediction, "masks_rle" in prediction
|
f"Expected predictions to have either `masks` key or `masks_rle` key, available keys are {prediction.keys()}"
|
||||||
), f"Expected predictions to have either `masks` key or `masks_rle` key, available keys are {prediction.keys()}"
|
)
|
||||||
|
|
||||||
boxes = prediction["boxes"]
|
boxes = prediction["boxes"]
|
||||||
boxes = convert_to_xywh(boxes).tolist()
|
boxes = convert_to_xywh(boxes).tolist()
|
||||||
@@ -223,9 +221,9 @@ class YTVISResultsWriter:
|
|||||||
labels = prediction["labels"].tolist()
|
labels = prediction["labels"].tolist()
|
||||||
if "masks" in prediction:
|
if "masks" in prediction:
|
||||||
masks = prediction["masks"].squeeze(2)
|
masks = prediction["masks"].squeeze(2)
|
||||||
assert (
|
assert masks.ndim == 4, (
|
||||||
masks.ndim == 4
|
"Expected masks to be of shape(N_preds,T_frames,H,W)"
|
||||||
), "Expected masks to be of shape(N_preds,T_frames,H,W)"
|
)
|
||||||
|
|
||||||
areas = [mask.flatten(1).sum(1).tolist() for mask in masks]
|
areas = [mask.flatten(1).sum(1).tolist() for mask in masks]
|
||||||
rles = [rle_encode(masklet) for masklet in masks]
|
rles = [rle_encode(masklet) for masklet in masks]
|
||||||
|
|||||||
@@ -42,9 +42,9 @@ def get_logger(name, level=logging.INFO):
|
|||||||
"""A command line logger."""
|
"""A command line logger."""
|
||||||
if "LOG_LEVEL" in os.environ:
|
if "LOG_LEVEL" in os.environ:
|
||||||
level = os.environ["LOG_LEVEL"].upper()
|
level = os.environ["LOG_LEVEL"].upper()
|
||||||
assert (
|
assert level in LOG_LEVELS, (
|
||||||
level in LOG_LEVELS
|
f"Invalid LOG_LEVEL: {level}, must be one of {list(LOG_LEVELS.keys())}"
|
||||||
), f"Invalid LOG_LEVEL: {level}, must be one of {list(LOG_LEVELS.keys())}"
|
)
|
||||||
level = LOG_LEVELS[level]
|
level = LOG_LEVELS[level]
|
||||||
logger = logging.getLogger(name)
|
logger = logging.getLogger(name)
|
||||||
logger.setLevel(level)
|
logger.setLevel(level)
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ Misc functions, including distributed helpers.
|
|||||||
|
|
||||||
import collections
|
import collections
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from dataclasses import dataclass, field as field_ptr_behaviour, fields, is_dataclass
|
from dataclasses import dataclass, field as field_ptr_behaviour, fields, is_dataclass
|
||||||
from typing import Any, get_args, get_origin, List, Mapping, Optional, Sequence, Union
|
from typing import Any, get_args, get_origin, List, Mapping, Optional, Sequence, Union
|
||||||
|
|
||||||
@@ -29,9 +28,9 @@ def interpolate(
|
|||||||
input, size, scale_factor, mode, align_corners
|
input, size, scale_factor, mode, align_corners
|
||||||
)
|
)
|
||||||
|
|
||||||
assert (
|
assert input.shape[0] != 0 or input.shape[1] != 0, (
|
||||||
input.shape[0] != 0 or input.shape[1] != 0
|
"At least one of the two first dimensions must be non zero"
|
||||||
), "At least one of the two first dimensions must be non zero"
|
)
|
||||||
|
|
||||||
if input.shape[1] == 0:
|
if input.shape[1] == 0:
|
||||||
# Pytorch doesn't support null dimension on the channel dimension, so we transpose to fake a null batch dim
|
# Pytorch doesn't support null dimension on the channel dimension, so we transpose to fake a null batch dim
|
||||||
|
|||||||
@@ -9,18 +9,13 @@ Inspired from Pytorch's version, adds the pre-norm variant
|
|||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sam3.sam.transformer import RoPEAttention
|
from sam3.sam.transformer import RoPEAttention
|
||||||
|
|
||||||
from torch import nn, Tensor
|
from torch import nn, Tensor
|
||||||
from torchvision.ops.roi_align import RoIAlign
|
from torchvision.ops.roi_align import RoIAlign
|
||||||
|
|
||||||
from .act_ckpt_utils import activation_ckpt_wrapper
|
from .act_ckpt_utils import activation_ckpt_wrapper
|
||||||
|
|
||||||
from .box_ops import box_cxcywh_to_xyxy
|
from .box_ops import box_cxcywh_to_xyxy
|
||||||
|
|
||||||
from .model_misc import (
|
from .model_misc import (
|
||||||
gen_sineembed_for_position,
|
gen_sineembed_for_position,
|
||||||
get_activation_fn,
|
get_activation_fn,
|
||||||
@@ -444,9 +439,9 @@ class TransformerDecoder(nn.Module):
|
|||||||
- valid_ratios/spatial_shapes: bs, nlevel, 2
|
- valid_ratios/spatial_shapes: bs, nlevel, 2
|
||||||
"""
|
"""
|
||||||
if memory_mask is not None:
|
if memory_mask is not None:
|
||||||
assert (
|
assert self.boxRPB == "none", (
|
||||||
self.boxRPB == "none"
|
"inputting a memory_mask in the presence of boxRPB is unexpected/not implemented"
|
||||||
), "inputting a memory_mask in the presence of boxRPB is unexpected/not implemented"
|
)
|
||||||
|
|
||||||
apply_dac = apply_dac if apply_dac is not None else self.dac
|
apply_dac = apply_dac if apply_dac is not None else self.dac
|
||||||
if apply_dac:
|
if apply_dac:
|
||||||
@@ -516,18 +511,18 @@ class TransformerDecoder(nn.Module):
|
|||||||
query_pos = self.ref_point_head(query_sine_embed) # nq, bs, d_model
|
query_pos = self.ref_point_head(query_sine_embed) # nq, bs, d_model
|
||||||
|
|
||||||
if self.boxRPB != "none" and reference_boxes is not None:
|
if self.boxRPB != "none" and reference_boxes is not None:
|
||||||
assert (
|
assert spatial_shapes.shape[0] == 1, (
|
||||||
spatial_shapes.shape[0] == 1
|
"only single scale support implemented"
|
||||||
), "only single scale support implemented"
|
)
|
||||||
memory_mask = self._get_rpb_matrix(
|
memory_mask = self._get_rpb_matrix(
|
||||||
reference_boxes,
|
reference_boxes,
|
||||||
(spatial_shapes[0, 0], spatial_shapes[0, 1]),
|
(spatial_shapes[0, 0], spatial_shapes[0, 1]),
|
||||||
)
|
)
|
||||||
memory_mask = memory_mask.flatten(0, 1) # (bs*n_heads, nq, H*W)
|
memory_mask = memory_mask.flatten(0, 1) # (bs*n_heads, nq, H*W)
|
||||||
if self.training:
|
if self.training:
|
||||||
assert (
|
assert self.use_act_checkpoint, (
|
||||||
self.use_act_checkpoint
|
"Activation checkpointing not enabled in the decoder"
|
||||||
), "Activation checkpointing not enabled in the decoder"
|
)
|
||||||
output, presence_out = activation_ckpt_wrapper(layer)(
|
output, presence_out = activation_ckpt_wrapper(layer)(
|
||||||
tgt=output,
|
tgt=output,
|
||||||
tgt_query_pos=query_pos,
|
tgt_query_pos=query_pos,
|
||||||
@@ -676,9 +671,9 @@ class TransformerEncoderCrossAttention(nn.Module):
|
|||||||
src_pos[0],
|
src_pos[0],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert (
|
assert src.shape[1] == prompt.shape[1], (
|
||||||
src.shape[1] == prompt.shape[1]
|
"Batch size must be the same for src and prompt"
|
||||||
), "Batch size must be the same for src and prompt"
|
)
|
||||||
|
|
||||||
output = src
|
output = src
|
||||||
|
|
||||||
|
|||||||
@@ -322,9 +322,9 @@ class TransformerEncoder(nn.Module):
|
|||||||
return reference_points
|
return reference_points
|
||||||
|
|
||||||
def _prepare_multilevel_features(self, srcs, masks, pos_embeds):
|
def _prepare_multilevel_features(self, srcs, masks, pos_embeds):
|
||||||
assert (
|
assert len(srcs) == self.num_feature_levels, (
|
||||||
len(srcs) == self.num_feature_levels
|
"mismatch between expected and received # of feature levels"
|
||||||
), "mismatch between expected and received # of feature levels"
|
)
|
||||||
|
|
||||||
src_flatten = []
|
src_flatten = []
|
||||||
mask_flatten = []
|
mask_flatten = []
|
||||||
@@ -406,9 +406,9 @@ class TransformerEncoder(nn.Module):
|
|||||||
- spatial_shapes: Spatial dimensions of each feature level
|
- spatial_shapes: Spatial dimensions of each feature level
|
||||||
- valid_ratios: Valid ratios for each feature level
|
- valid_ratios: Valid ratios for each feature level
|
||||||
"""
|
"""
|
||||||
assert (
|
assert len(src) == self.num_feature_levels, (
|
||||||
len(src) == self.num_feature_levels
|
"must be equal to num_feature_levels"
|
||||||
), "must be equal to num_feature_levels"
|
)
|
||||||
if src_key_padding_masks is not None:
|
if src_key_padding_masks is not None:
|
||||||
assert len(src_key_padding_masks) == self.num_feature_levels
|
assert len(src_key_padding_masks) == self.num_feature_levels
|
||||||
if pos is not None:
|
if pos is not None:
|
||||||
@@ -538,9 +538,9 @@ class TransformerEncoderFusion(TransformerEncoder):
|
|||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assert all(
|
assert all(x.dim == 4 for x in src), (
|
||||||
x.dim == 4 for x in src
|
"expected list of (bs, c, h, w) tensors"
|
||||||
), "expected list of (bs, c, h, w) tensors"
|
)
|
||||||
|
|
||||||
if self.add_pooled_text_to_img_feat:
|
if self.add_pooled_text_to_img_feat:
|
||||||
# Fusion: Add mean pooled text to image features
|
# Fusion: Add mean pooled text to image features
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from typing_extensions import override
|
|||||||
|
|
||||||
from .act_ckpt_utils import activation_ckpt_wrapper
|
from .act_ckpt_utils import activation_ckpt_wrapper
|
||||||
from .box_ops import box_cxcywh_to_xyxy
|
from .box_ops import box_cxcywh_to_xyxy
|
||||||
|
|
||||||
from .model_misc import get_clones
|
from .model_misc import get_clones
|
||||||
|
|
||||||
|
|
||||||
@@ -148,54 +147,42 @@ class Prompt:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Dimension checks
|
# Dimension checks
|
||||||
assert (
|
assert box_embeddings is not None and list(box_embeddings.shape[:2]) == [
|
||||||
box_embeddings is not None
|
box_seq_len,
|
||||||
and list(box_embeddings.shape[:2])
|
bs,
|
||||||
== [
|
], (
|
||||||
box_seq_len,
|
f"Wrong dimension for box embeddings. Expected [{box_seq_len}, {bs}, *] got {box_embeddings.shape}"
|
||||||
bs,
|
)
|
||||||
]
|
assert box_mask is not None and list(box_mask.shape) == [
|
||||||
), f"Wrong dimension for box embeddings. Expected [{box_seq_len}, {bs}, *] got {box_embeddings.shape}"
|
bs,
|
||||||
assert (
|
box_seq_len,
|
||||||
box_mask is not None
|
], (
|
||||||
and list(box_mask.shape)
|
f"Wrong dimension for box mask. Expected [{bs}, {box_seq_len}] got {box_mask.shape}"
|
||||||
== [
|
)
|
||||||
bs,
|
assert point_embeddings is not None and list(point_embeddings.shape[:2]) == [
|
||||||
box_seq_len,
|
point_seq_len,
|
||||||
]
|
bs,
|
||||||
), f"Wrong dimension for box mask. Expected [{bs}, {box_seq_len}] got {box_mask.shape}"
|
], (
|
||||||
assert (
|
f"Wrong dimension for point embeddings. Expected [{point_seq_len}, {bs}, *] got {point_embeddings.shape}"
|
||||||
point_embeddings is not None
|
)
|
||||||
and list(point_embeddings.shape[:2])
|
assert point_mask is not None and list(point_mask.shape) == [
|
||||||
== [
|
bs,
|
||||||
point_seq_len,
|
point_seq_len,
|
||||||
bs,
|
], (
|
||||||
]
|
f"Wrong dimension for point mask. Expected [{bs}, {point_seq_len}] got {point_mask.shape}"
|
||||||
), f"Wrong dimension for point embeddings. Expected [{point_seq_len}, {bs}, *] got {point_embeddings.shape}"
|
)
|
||||||
assert (
|
assert box_labels is not None and list(box_labels.shape) == [
|
||||||
point_mask is not None
|
box_seq_len,
|
||||||
and list(point_mask.shape)
|
bs,
|
||||||
== [
|
], (
|
||||||
bs,
|
f"Wrong dimension for box labels. Expected [{box_seq_len}, {bs}] got {box_labels.shape}"
|
||||||
point_seq_len,
|
)
|
||||||
]
|
assert point_labels is not None and list(point_labels.shape) == [
|
||||||
), f"Wrong dimension for point mask. Expected [{bs}, {point_seq_len}] got {point_mask.shape}"
|
point_seq_len,
|
||||||
assert (
|
bs,
|
||||||
box_labels is not None
|
], (
|
||||||
and list(box_labels.shape)
|
f"Wrong dimension for point labels. Expected [{point_seq_len}, {bs}] got {point_labels.shape}"
|
||||||
== [
|
)
|
||||||
box_seq_len,
|
|
||||||
bs,
|
|
||||||
]
|
|
||||||
), f"Wrong dimension for box labels. Expected [{box_seq_len}, {bs}] got {box_labels.shape}"
|
|
||||||
assert (
|
|
||||||
point_labels is not None
|
|
||||||
and list(point_labels.shape)
|
|
||||||
== [
|
|
||||||
point_seq_len,
|
|
||||||
bs,
|
|
||||||
]
|
|
||||||
), f"Wrong dimension for point labels. Expected [{point_seq_len}, {bs}] got {point_labels.shape}"
|
|
||||||
assert (
|
assert (
|
||||||
# Allowed to be None, we leave it to the encoder to check for validity before encoding.
|
# Allowed to be None, we leave it to the encoder to check for validity before encoding.
|
||||||
mask_embeddings is None
|
mask_embeddings is None
|
||||||
@@ -204,41 +191,41 @@ class Prompt:
|
|||||||
mask_seq_len,
|
mask_seq_len,
|
||||||
bs,
|
bs,
|
||||||
]
|
]
|
||||||
), f"Wrong dimension for mask embeddings. Expected [{mask_seq_len}, {bs}, *] got {mask_embeddings.shape}"
|
), (
|
||||||
assert (
|
f"Wrong dimension for mask embeddings. Expected [{mask_seq_len}, {bs}, *] got {mask_embeddings.shape}"
|
||||||
mask_mask is None
|
)
|
||||||
or list(mask_mask.shape)
|
assert mask_mask is None or list(mask_mask.shape) == [
|
||||||
== [
|
bs,
|
||||||
bs,
|
mask_seq_len,
|
||||||
mask_seq_len,
|
], (
|
||||||
]
|
f"Wrong dimension for mask attn. mask. Expected [{bs}, {mask_seq_len}] got {mask_mask.shape}"
|
||||||
), f"Wrong dimension for mask attn. mask. Expected [{bs}, {mask_seq_len}] got {mask_mask.shape}"
|
)
|
||||||
|
|
||||||
# Device checks
|
# Device checks
|
||||||
assert (
|
assert box_embeddings is not None and box_embeddings.device == device, (
|
||||||
box_embeddings is not None and box_embeddings.device == device
|
f"Expected box embeddings to be on device {device}, got {box_embeddings.device}"
|
||||||
), f"Expected box embeddings to be on device {device}, got {box_embeddings.device}"
|
)
|
||||||
assert (
|
assert box_mask is not None and box_mask.device == device, (
|
||||||
box_mask is not None and box_mask.device == device
|
f"Expected box mask to be on device {device}, got {box_mask.device}"
|
||||||
), f"Expected box mask to be on device {device}, got {box_mask.device}"
|
)
|
||||||
assert (
|
assert box_labels is not None and box_labels.device == device, (
|
||||||
box_labels is not None and box_labels.device == device
|
f"Expected box labels to be on device {device}, got {box_labels.device}"
|
||||||
), f"Expected box labels to be on device {device}, got {box_labels.device}"
|
)
|
||||||
assert (
|
assert point_embeddings is not None and point_embeddings.device == device, (
|
||||||
point_embeddings is not None and point_embeddings.device == device
|
f"Expected point embeddings to be on device {device}, got {point_embeddings.device}"
|
||||||
), f"Expected point embeddings to be on device {device}, got {point_embeddings.device}"
|
)
|
||||||
assert (
|
assert point_mask is not None and point_mask.device == device, (
|
||||||
point_mask is not None and point_mask.device == device
|
f"Expected point mask to be on device {device}, got {point_mask.device}"
|
||||||
), f"Expected point mask to be on device {device}, got {point_mask.device}"
|
)
|
||||||
assert (
|
assert point_labels is not None and point_labels.device == device, (
|
||||||
point_labels is not None and point_labels.device == device
|
f"Expected point labels to be on device {device}, got {point_labels.device}"
|
||||||
), f"Expected point labels to be on device {device}, got {point_labels.device}"
|
)
|
||||||
assert (
|
assert mask_embeddings is None or mask_embeddings.device == device, (
|
||||||
mask_embeddings is None or mask_embeddings.device == device
|
f"Expected mask embeddings to be on device {device}, got {mask_embeddings.device}"
|
||||||
), f"Expected mask embeddings to be on device {device}, got {mask_embeddings.device}"
|
)
|
||||||
assert (
|
assert mask_mask is None or mask_mask.device == device, (
|
||||||
mask_mask is None or mask_mask.device == device
|
f"Expected mask attn. mask to be on device {device}, got {mask_mask.device}"
|
||||||
), f"Expected mask attn. mask to be on device {device}, got {mask_mask.device}"
|
)
|
||||||
|
|
||||||
self.box_embeddings = box_embeddings
|
self.box_embeddings = box_embeddings
|
||||||
self.point_embeddings = point_embeddings
|
self.point_embeddings = point_embeddings
|
||||||
@@ -264,30 +251,30 @@ class Prompt:
|
|||||||
if point_embeddings is not None:
|
if point_embeddings is not None:
|
||||||
point_seq_len = point_embeddings.shape[0]
|
point_seq_len = point_embeddings.shape[0]
|
||||||
if bs is not None:
|
if bs is not None:
|
||||||
assert (
|
assert bs == point_embeddings.shape[1], (
|
||||||
bs == point_embeddings.shape[1]
|
f"Batch size mismatch between box and point embeddings. Got {bs} and {point_embeddings.shape[1]}."
|
||||||
), f"Batch size mismatch between box and point embeddings. Got {bs} and {point_embeddings.shape[1]}."
|
)
|
||||||
else:
|
else:
|
||||||
bs = point_embeddings.shape[1]
|
bs = point_embeddings.shape[1]
|
||||||
if device is not None:
|
if device is not None:
|
||||||
assert (
|
assert device == point_embeddings.device, (
|
||||||
device == point_embeddings.device
|
"Device mismatch between box and point embeddings"
|
||||||
), "Device mismatch between box and point embeddings"
|
)
|
||||||
else:
|
else:
|
||||||
device = point_embeddings.device
|
device = point_embeddings.device
|
||||||
|
|
||||||
if mask_embeddings is not None:
|
if mask_embeddings is not None:
|
||||||
mask_seq_len = mask_embeddings.shape[0]
|
mask_seq_len = mask_embeddings.shape[0]
|
||||||
if bs is not None:
|
if bs is not None:
|
||||||
assert (
|
assert bs == mask_embeddings.shape[1], (
|
||||||
bs == mask_embeddings.shape[1]
|
f"Batch size mismatch between box/point and mask embedding. Got {bs} and {mask_embeddings.shape[1]}"
|
||||||
), f"Batch size mismatch between box/point and mask embedding. Got {bs} and {mask_embeddings.shape[1]}"
|
)
|
||||||
else:
|
else:
|
||||||
bs = mask_embeddings.shape[1]
|
bs = mask_embeddings.shape[1]
|
||||||
if device is not None:
|
if device is not None:
|
||||||
assert (
|
assert device == mask_embeddings.device, (
|
||||||
device == mask_embeddings.device
|
"Device mismatch between box/point and mask embeddings."
|
||||||
), "Device mismatch between box/point and mask embeddings."
|
)
|
||||||
else:
|
else:
|
||||||
device = mask_embeddings.device
|
device = mask_embeddings.device
|
||||||
|
|
||||||
@@ -539,9 +526,9 @@ class SequenceGeometryEncoder(nn.Module):
|
|||||||
if add_cls:
|
if add_cls:
|
||||||
self.cls_embed = torch.nn.Embedding(1, self.d_model)
|
self.cls_embed = torch.nn.Embedding(1, self.d_model)
|
||||||
|
|
||||||
assert (
|
assert points_direct_project or points_pos_enc or points_pool, (
|
||||||
points_direct_project or points_pos_enc or points_pool
|
"Error: need at least one way to encode points"
|
||||||
), "Error: need at least one way to encode points"
|
)
|
||||||
assert (
|
assert (
|
||||||
encode_boxes_as_points
|
encode_boxes_as_points
|
||||||
or boxes_direct_project
|
or boxes_direct_project
|
||||||
@@ -583,16 +570,16 @@ class SequenceGeometryEncoder(nn.Module):
|
|||||||
|
|
||||||
self.encode = None
|
self.encode = None
|
||||||
if num_layers > 0:
|
if num_layers > 0:
|
||||||
assert (
|
assert add_cls, (
|
||||||
add_cls
|
"It's currently highly recommended to add a CLS when using a transformer"
|
||||||
), "It's currently highly recommended to add a CLS when using a transformer"
|
)
|
||||||
self.encode = get_clones(layer, num_layers)
|
self.encode = get_clones(layer, num_layers)
|
||||||
self.encode_norm = nn.LayerNorm(self.d_model)
|
self.encode_norm = nn.LayerNorm(self.d_model)
|
||||||
|
|
||||||
if mask_encoder is not None:
|
if mask_encoder is not None:
|
||||||
assert isinstance(
|
assert isinstance(mask_encoder, MaskEncoder), (
|
||||||
mask_encoder, MaskEncoder
|
f"Expected mask_encoder of type MaskEncoder. Got {type(mask_encoder)}."
|
||||||
), f"Expected mask_encoder of type MaskEncoder. Got {type(mask_encoder)}."
|
)
|
||||||
if add_mask_label:
|
if add_mask_label:
|
||||||
self.mask_label_embed = torch.nn.Embedding(2, self.d_model)
|
self.mask_label_embed = torch.nn.Embedding(2, self.d_model)
|
||||||
self.add_mask_label = add_mask_label
|
self.add_mask_label = add_mask_label
|
||||||
@@ -701,16 +688,15 @@ class SequenceGeometryEncoder(nn.Module):
|
|||||||
img_feats: torch.Tensor = None,
|
img_feats: torch.Tensor = None,
|
||||||
):
|
):
|
||||||
n_masks, bs = masks.shape[:2]
|
n_masks, bs = masks.shape[:2]
|
||||||
assert (
|
assert n_masks == 1, (
|
||||||
n_masks == 1
|
"We assume one mask per prompt for now. Code should still be functional if this assertion is removed."
|
||||||
), "We assume one mask per prompt for now. Code should still be functional if this assertion is removed."
|
)
|
||||||
assert (
|
assert list(attn_mask.shape) == [
|
||||||
list(attn_mask.shape)
|
bs,
|
||||||
== [
|
n_masks,
|
||||||
bs,
|
], (
|
||||||
n_masks,
|
f"Expected attn_mask to be of shape {bs}x{n_masks}. Got {list(attn_mask.shape)}."
|
||||||
]
|
)
|
||||||
), f"Expected attn_mask to be of shape {bs}x{n_masks}. Got {list(attn_mask.shape)}."
|
|
||||||
masks, pos = self.mask_encoder(
|
masks, pos = self.mask_encoder(
|
||||||
masks=masks.flatten(0, 1).float(),
|
masks=masks.flatten(0, 1).float(),
|
||||||
pix_feat=img_feats,
|
pix_feat=img_feats,
|
||||||
|
|||||||
@@ -13,9 +13,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torchvision.transforms.functional as TF
|
import torchvision.transforms.functional as TF
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from sam3.logger import get_logger
|
from sam3.logger import get_logger
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|||||||
@@ -248,7 +248,9 @@ class UniversalSegmentationHead(SegmentationHead):
|
|||||||
self.d_model = hidden_dim
|
self.d_model = hidden_dim
|
||||||
|
|
||||||
if dot_product_scorer is not None:
|
if dot_product_scorer is not None:
|
||||||
assert presence_head, "Specifying a dot product scorer without a presence head is likely a mistake"
|
assert presence_head, (
|
||||||
|
"Specifying a dot product scorer without a presence head is likely a mistake"
|
||||||
|
)
|
||||||
|
|
||||||
self.presence_head = None
|
self.presence_head = None
|
||||||
if presence_head:
|
if presence_head:
|
||||||
|
|||||||
@@ -62,9 +62,9 @@ class SimpleMaskDownSampler(nn.Module):
|
|||||||
self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
|
self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
|
||||||
self.interpol_size = interpol_size
|
self.interpol_size = interpol_size
|
||||||
if self.interpol_size is not None:
|
if self.interpol_size is not None:
|
||||||
assert isinstance(
|
assert isinstance(self.interpol_size, (list, tuple)), (
|
||||||
self.interpol_size, (list, tuple)
|
f"Unsupported type {type(self.interpol_size)}. Should be a list or tuple."
|
||||||
), f"Unsupported type {type(self.interpol_size)}. Should be a list or tuple."
|
)
|
||||||
self.interpol_size = list(interpol_size)
|
self.interpol_size = list(interpol_size)
|
||||||
assert len(self.interpol_size) == 2
|
assert len(self.interpol_size) == 2
|
||||||
|
|
||||||
|
|||||||
@@ -330,9 +330,9 @@ class SAM3Output(list):
|
|||||||
self.output = output
|
self.output = output
|
||||||
else:
|
else:
|
||||||
self.output = []
|
self.output = []
|
||||||
assert isinstance(
|
assert isinstance(iter_mode, SAM3Output.IterMode), (
|
||||||
iter_mode, SAM3Output.IterMode
|
f"iter_mode shoulf be of enum type 'SAM3Output.IterMode'. Got {type(iter_mode)}"
|
||||||
), f"iter_mode shoulf be of enum type 'SAM3Output.IterMode'. Got {type(iter_mode)}"
|
)
|
||||||
|
|
||||||
self.iter_mode = iter_mode
|
self.iter_mode = iter_mode
|
||||||
# We create a weak reference to self to be used in the lambda functions.
|
# We create a weak reference to self to be used in the lambda functions.
|
||||||
@@ -411,9 +411,9 @@ class SAM3Output(list):
|
|||||||
return SAM3Output._IterationMode(model_output=model_output, iter_mode=iter_mode)
|
return SAM3Output._IterationMode(model_output=model_output, iter_mode=iter_mode)
|
||||||
|
|
||||||
def append(self, item: list):
|
def append(self, item: list):
|
||||||
assert isinstance(
|
assert isinstance(item, list), (
|
||||||
item, list
|
f"Only list items are supported. Got {type(item)}"
|
||||||
), f"Only list items are supported. Got {type(item)}"
|
)
|
||||||
self.output.append(item)
|
self.output.append(item)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ from copy import deepcopy
|
|||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -7,15 +7,12 @@
|
|||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from PIL.Image import Image
|
from PIL.Image import Image
|
||||||
|
|
||||||
from sam3.model.sam3_tracker_base import Sam3TrackerBase
|
from sam3.model.sam3_tracker_base import Sam3TrackerBase
|
||||||
from sam3.model.utils.sam1_utils import SAM2Transforms
|
from sam3.model.utils.sam1_utils import SAM2Transforms
|
||||||
|
|
||||||
@@ -97,9 +94,9 @@ class SAM3InteractiveImagePredictor(nn.Module):
|
|||||||
input_image = self._transforms(image)
|
input_image = self._transforms(image)
|
||||||
input_image = input_image[None, ...].to(self.device)
|
input_image = input_image[None, ...].to(self.device)
|
||||||
|
|
||||||
assert (
|
assert len(input_image.shape) == 4 and input_image.shape[1] == 3, (
|
||||||
len(input_image.shape) == 4 and input_image.shape[1] == 3
|
f"input_image must be of size 1x3xHxW, got {input_image.shape}"
|
||||||
), f"input_image must be of size 1x3xHxW, got {input_image.shape}"
|
)
|
||||||
logging.info("Computing image embeddings for the provided image...")
|
logging.info("Computing image embeddings for the provided image...")
|
||||||
backbone_out = self.model.forward_image(input_image)
|
backbone_out = self.model.forward_image(input_image)
|
||||||
(
|
(
|
||||||
@@ -136,17 +133,17 @@ class SAM3InteractiveImagePredictor(nn.Module):
|
|||||||
assert isinstance(image_list, list)
|
assert isinstance(image_list, list)
|
||||||
self._orig_hw = []
|
self._orig_hw = []
|
||||||
for image in image_list:
|
for image in image_list:
|
||||||
assert isinstance(
|
assert isinstance(image, np.ndarray), (
|
||||||
image, np.ndarray
|
"Images are expected to be an np.ndarray in RGB format, and of shape HWC"
|
||||||
), "Images are expected to be an np.ndarray in RGB format, and of shape HWC"
|
)
|
||||||
self._orig_hw.append(image.shape[:2])
|
self._orig_hw.append(image.shape[:2])
|
||||||
# Transform the image to the form expected by the model
|
# Transform the image to the form expected by the model
|
||||||
img_batch = self._transforms.forward_batch(image_list)
|
img_batch = self._transforms.forward_batch(image_list)
|
||||||
img_batch = img_batch.to(self.device)
|
img_batch = img_batch.to(self.device)
|
||||||
batch_size = img_batch.shape[0]
|
batch_size = img_batch.shape[0]
|
||||||
assert (
|
assert len(img_batch.shape) == 4 and img_batch.shape[1] == 3, (
|
||||||
len(img_batch.shape) == 4 and img_batch.shape[1] == 3
|
f"img_batch must be of size Bx3xHxW, got {img_batch.shape}"
|
||||||
), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}"
|
)
|
||||||
logging.info("Computing image embeddings for the provided images...")
|
logging.info("Computing image embeddings for the provided images...")
|
||||||
backbone_out = self.model.forward_image(img_batch)
|
backbone_out = self.model.forward_image(img_batch)
|
||||||
(
|
(
|
||||||
@@ -302,9 +299,9 @@ class SAM3InteractiveImagePredictor(nn.Module):
|
|||||||
):
|
):
|
||||||
unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None
|
unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None
|
||||||
if point_coords is not None:
|
if point_coords is not None:
|
||||||
assert (
|
assert point_labels is not None, (
|
||||||
point_labels is not None
|
"point_labels must be supplied if point_coords is supplied."
|
||||||
), "point_labels must be supplied if point_coords is supplied."
|
)
|
||||||
point_coords = torch.as_tensor(
|
point_coords = torch.as_tensor(
|
||||||
point_coords, dtype=torch.float, device=self.device
|
point_coords, dtype=torch.float, device=self.device
|
||||||
)
|
)
|
||||||
@@ -441,9 +438,9 @@ class SAM3InteractiveImagePredictor(nn.Module):
|
|||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"An image must be set with .set_image(...) to generate an embedding."
|
"An image must be set with .set_image(...) to generate an embedding."
|
||||||
)
|
)
|
||||||
assert (
|
assert self._features is not None, (
|
||||||
self._features is not None
|
"Features must exist if an image has been set."
|
||||||
), "Features must exist if an image has been set."
|
)
|
||||||
return self._features["image_embed"]
|
return self._features["image_embed"]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -8,19 +8,14 @@ from typing import Dict, List, Optional, Tuple
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sam3.model.model_misc import SAM3Output
|
from sam3.model.model_misc import SAM3Output
|
||||||
|
|
||||||
from sam3.model.sam1_task_predictor import SAM3InteractiveImagePredictor
|
from sam3.model.sam1_task_predictor import SAM3InteractiveImagePredictor
|
||||||
from sam3.model.vl_combiner import SAM3VLBackbone
|
from sam3.model.vl_combiner import SAM3VLBackbone
|
||||||
from sam3.perflib.nms import nms_masks
|
from sam3.perflib.nms import nms_masks
|
||||||
|
|
||||||
from sam3.train.data.collator import BatchedDatapoint
|
from sam3.train.data.collator import BatchedDatapoint
|
||||||
|
|
||||||
from .act_ckpt_utils import activation_ckpt_wrapper
|
from .act_ckpt_utils import activation_ckpt_wrapper
|
||||||
|
|
||||||
from .box_ops import box_cxcywh_to_xyxy
|
from .box_ops import box_cxcywh_to_xyxy
|
||||||
|
|
||||||
from .geometry_encoders import Prompt
|
from .geometry_encoders import Prompt
|
||||||
from .model_misc import inverse_sigmoid
|
from .model_misc import inverse_sigmoid
|
||||||
|
|
||||||
@@ -661,9 +656,9 @@ class Sam3Image(torch.nn.Module):
|
|||||||
inference_state["original_heights"],
|
inference_state["original_heights"],
|
||||||
inference_state["original_widths"],
|
inference_state["original_widths"],
|
||||||
)
|
)
|
||||||
assert (
|
assert batch_size == len(orig_heights) == len(orig_widths), (
|
||||||
batch_size == len(orig_heights) == len(orig_widths)
|
f"Batch size mismatch in predict_inst_batch. Got {batch_size}, {len(orig_heights)}, {len(orig_widths)}"
|
||||||
), f"Batch size mismatch in predict_inst_batch. Got {batch_size}, {len(orig_heights)}, {len(orig_widths)}"
|
)
|
||||||
feats = [
|
feats = [
|
||||||
feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
|
feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
|
||||||
for feat, feat_size in zip(
|
for feat, feat_size in zip(
|
||||||
|
|||||||
@@ -6,9 +6,7 @@ from typing import Dict, List
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import PIL
|
import PIL
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sam3.model import box_ops
|
from sam3.model import box_ops
|
||||||
|
|
||||||
from sam3.model.data_misc import FindStage, interpolate
|
from sam3.model.data_misc import FindStage, interpolate
|
||||||
from torchvision.transforms import v2
|
from torchvision.transforms import v2
|
||||||
|
|
||||||
@@ -83,9 +81,9 @@ class Sam3Processor:
|
|||||||
if not isinstance(images, list):
|
if not isinstance(images, list):
|
||||||
raise ValueError("Images must be a list of PIL images or tensors")
|
raise ValueError("Images must be a list of PIL images or tensors")
|
||||||
assert len(images) > 0, "Images list must not be empty"
|
assert len(images) > 0, "Images list must not be empty"
|
||||||
assert isinstance(
|
assert isinstance(images[0], PIL.Image.Image), (
|
||||||
images[0], PIL.Image.Image
|
"Images must be a list of PIL images"
|
||||||
), "Images must be a list of PIL images"
|
)
|
||||||
|
|
||||||
state["original_heights"] = [image.height for image in images]
|
state["original_heights"] = [image.height for image in images]
|
||||||
state["original_widths"] = [image.width for image in images]
|
state["original_widths"] = [image.width for image in images]
|
||||||
|
|||||||
@@ -6,11 +6,8 @@ import logging
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from sam3.model.memory import SimpleMaskEncoder
|
from sam3.model.memory import SimpleMaskEncoder
|
||||||
|
|
||||||
from sam3.model.sam3_tracker_utils import get_1d_sine_pe, select_closest_cond_frames
|
from sam3.model.sam3_tracker_utils import get_1d_sine_pe, select_closest_cond_frames
|
||||||
|
|
||||||
from sam3.sam.mask_decoder import MaskDecoder, MLP
|
from sam3.sam.mask_decoder import MaskDecoder, MLP
|
||||||
from sam3.sam.prompt_encoder import PromptEncoder
|
from sam3.sam.prompt_encoder import PromptEncoder
|
||||||
from sam3.sam.transformer import TwoWayTransformer
|
from sam3.sam.transformer import TwoWayTransformer
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
from sam3.model.edt import edt_triton
|
from sam3.model.edt import edt_triton
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import logging
|
|||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sam3.model.sam3_tracker_base import concat_points, NO_OBJ_SCORE, Sam3TrackerBase
|
from sam3.model.sam3_tracker_base import concat_points, NO_OBJ_SCORE, Sam3TrackerBase
|
||||||
from sam3.model.sam3_tracker_utils import fill_holes_in_mask_scores
|
from sam3.model.sam3_tracker_utils import fill_holes_in_mask_scores
|
||||||
from sam3.model.utils.sam2_utils import load_video_frames
|
from sam3.model.utils.sam2_utils import load_video_frames
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ import numpy.typing as npt
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from sam3 import perflib
|
from sam3 import perflib
|
||||||
from sam3.logger import get_logger
|
from sam3.logger import get_logger
|
||||||
from sam3.model.box_ops import fast_diag_box_iou
|
from sam3.model.box_ops import fast_diag_box_iou
|
||||||
@@ -620,9 +619,9 @@ class Sam3VideoBase(nn.Module):
|
|||||||
num_obj_dropped_due_to_limit,
|
num_obj_dropped_due_to_limit,
|
||||||
trk_id_to_max_iou_high_conf_det,
|
trk_id_to_max_iou_high_conf_det,
|
||||||
]
|
]
|
||||||
assert (
|
assert len(update_plan) == NUM_BROADCAST_ITEMS, (
|
||||||
len(update_plan) == NUM_BROADCAST_ITEMS
|
f"Manually update NUM_BROADCAST_ITEMS to be: {len(update_plan)}"
|
||||||
), f"Manually update NUM_BROADCAST_ITEMS to be: {len(update_plan)}"
|
)
|
||||||
self.broadcast_python_obj_cpu(update_plan, src=0)
|
self.broadcast_python_obj_cpu(update_plan, src=0)
|
||||||
elif self.rank > 0 and self.world_size > 1:
|
elif self.rank > 0 and self.world_size > 1:
|
||||||
update_plan = [
|
update_plan = [
|
||||||
@@ -842,9 +841,9 @@ class Sam3VideoBase(nn.Module):
|
|||||||
binary_tracker_low_res_masks_global = tracker_low_res_masks_global > 0
|
binary_tracker_low_res_masks_global = tracker_low_res_masks_global > 0
|
||||||
batch_size = tracker_low_res_masks_global.size(0)
|
batch_size = tracker_low_res_masks_global.size(0)
|
||||||
if batch_size > 0:
|
if batch_size > 0:
|
||||||
assert (
|
assert len(obj_ids_global) == batch_size, (
|
||||||
len(obj_ids_global) == batch_size
|
f"Mismatch in number of objects: {len(obj_ids_global)} vs {batch_size}"
|
||||||
), f"Mismatch in number of objects: {len(obj_ids_global)} vs {batch_size}"
|
)
|
||||||
NEVER_OCCLUDED = -1
|
NEVER_OCCLUDED = -1
|
||||||
ALWAYS_OCCLUDED = 100000 # This value should be larger than any possible frame index, indicates that the object was removed by hotstart logic
|
ALWAYS_OCCLUDED = 100000 # This value should be larger than any possible frame index, indicates that the object was removed by hotstart logic
|
||||||
last_occluded_prev = torch.cat(
|
last_occluded_prev = torch.cat(
|
||||||
@@ -1023,9 +1022,9 @@ class Sam3VideoBase(nn.Module):
|
|||||||
reverse: bool = False,
|
reverse: bool = False,
|
||||||
):
|
):
|
||||||
# Suppress overlapping masks for objects that were most recently occluded
|
# Suppress overlapping masks for objects that were most recently occluded
|
||||||
assert (
|
assert binary_low_res_masks.dtype == torch.bool, (
|
||||||
binary_low_res_masks.dtype == torch.bool
|
f"Expected boolean tensor, got {binary_low_res_masks.dtype}"
|
||||||
), f"Expected boolean tensor, got {binary_low_res_masks.dtype}"
|
)
|
||||||
to_suppress = torch.zeros(
|
to_suppress = torch.zeros(
|
||||||
binary_low_res_masks.size(0),
|
binary_low_res_masks.size(0),
|
||||||
device=binary_low_res_masks.device,
|
device=binary_low_res_masks.device,
|
||||||
@@ -1130,9 +1129,9 @@ class Sam3VideoBase(nn.Module):
|
|||||||
num_frames_propagated += 1
|
num_frames_propagated += 1
|
||||||
|
|
||||||
# only 1 frames should be propagated
|
# only 1 frames should be propagated
|
||||||
assert (
|
assert num_frames_propagated == 1 and out_frame_idx == frame_idx, (
|
||||||
num_frames_propagated == 1 and out_frame_idx == frame_idx
|
f"num_frames_propagated: {num_frames_propagated}, out_frame_idx: {out_frame_idx}, frame_idx: {frame_idx}"
|
||||||
), f"num_frames_propagated: {num_frames_propagated}, out_frame_idx: {out_frame_idx}, frame_idx: {frame_idx}"
|
)
|
||||||
assert isinstance(out_obj_ids, list)
|
assert isinstance(out_obj_ids, list)
|
||||||
obj_ids_local.extend(out_obj_ids)
|
obj_ids_local.extend(out_obj_ids)
|
||||||
low_res_masks_list.append(out_low_res_masks.squeeze(1))
|
low_res_masks_list.append(out_low_res_masks.squeeze(1))
|
||||||
@@ -1189,9 +1188,9 @@ class Sam3VideoBase(nn.Module):
|
|||||||
|
|
||||||
assert det_masks.is_floating_point(), "float tensor expected (do not binarize)"
|
assert det_masks.is_floating_point(), "float tensor expected (do not binarize)"
|
||||||
assert trk_masks.is_floating_point(), "float tensor expected (do not binarize)"
|
assert trk_masks.is_floating_point(), "float tensor expected (do not binarize)"
|
||||||
assert (
|
assert trk_masks.size(0) == len(trk_obj_ids), (
|
||||||
trk_masks.size(0) == len(trk_obj_ids)
|
f"trk_masks and trk_obj_ids should have the same length, {trk_masks.size(0)} vs {len(trk_obj_ids)}"
|
||||||
), f"trk_masks and trk_obj_ids should have the same length, {trk_masks.size(0)} vs {len(trk_obj_ids)}"
|
)
|
||||||
if trk_masks.size(0) == 0:
|
if trk_masks.size(0) == 0:
|
||||||
# all detections are new
|
# all detections are new
|
||||||
new_det_fa_inds = np.arange(det_masks.size(0))
|
new_det_fa_inds = np.arange(det_masks.size(0))
|
||||||
@@ -1655,9 +1654,9 @@ class Sam3VideoBase(nn.Module):
|
|||||||
# a) first, expand "confirmation_data" to include new masklets added in this frame
|
# a) first, expand "confirmation_data" to include new masklets added in this frame
|
||||||
status_prev = confirmation_data["status"]
|
status_prev = confirmation_data["status"]
|
||||||
consecutive_det_num_prev = confirmation_data["consecutive_det_num"]
|
consecutive_det_num_prev = confirmation_data["consecutive_det_num"]
|
||||||
assert (
|
assert status_prev.shape == obj_ids_all_gpu_prev.shape, (
|
||||||
status_prev.shape == obj_ids_all_gpu_prev.shape
|
f"Got {status_prev.shape} vs {obj_ids_all_gpu_prev.shape}"
|
||||||
), f"Got {status_prev.shape} vs {obj_ids_all_gpu_prev.shape}"
|
)
|
||||||
|
|
||||||
obj_id_to_updated_idx = {
|
obj_id_to_updated_idx = {
|
||||||
obj_id: idx for idx, obj_id in enumerate(obj_ids_all_gpu_updated)
|
obj_id: idx for idx, obj_id in enumerate(obj_ids_all_gpu_updated)
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from sam3 import perflib
|
from sam3 import perflib
|
||||||
from sam3.logger import get_logger
|
from sam3.logger import get_logger
|
||||||
from sam3.model.act_ckpt_utils import clone_output_wrapper
|
from sam3.model.act_ckpt_utils import clone_output_wrapper
|
||||||
@@ -555,7 +554,9 @@ class Sam3VideoInference(Sam3VideoBase):
|
|||||||
assert (
|
assert (
|
||||||
"cached_frame_outputs" in inference_state
|
"cached_frame_outputs" in inference_state
|
||||||
and frame_idx in inference_state["cached_frame_outputs"]
|
and frame_idx in inference_state["cached_frame_outputs"]
|
||||||
), "No cached outputs found. Ensure normal propagation has run first to populate the cache."
|
), (
|
||||||
|
"No cached outputs found. Ensure normal propagation has run first to populate the cache."
|
||||||
|
)
|
||||||
cached_outputs = inference_state["cached_frame_outputs"][frame_idx]
|
cached_outputs = inference_state["cached_frame_outputs"][frame_idx]
|
||||||
|
|
||||||
obj_id_to_mask = cached_outputs.copy()
|
obj_id_to_mask = cached_outputs.copy()
|
||||||
@@ -563,9 +564,9 @@ class Sam3VideoInference(Sam3VideoBase):
|
|||||||
# Update with refined masks if provided
|
# Update with refined masks if provided
|
||||||
if refined_obj_id_to_mask is not None:
|
if refined_obj_id_to_mask is not None:
|
||||||
for obj_id, refined_mask in refined_obj_id_to_mask.items():
|
for obj_id, refined_mask in refined_obj_id_to_mask.items():
|
||||||
assert (
|
assert refined_mask is not None, (
|
||||||
refined_mask is not None
|
f"Refined mask data must be provided for obj_id {obj_id}"
|
||||||
), f"Refined mask data must be provided for obj_id {obj_id}"
|
)
|
||||||
obj_id_to_mask[obj_id] = refined_mask
|
obj_id_to_mask[obj_id] = refined_mask
|
||||||
|
|
||||||
return obj_id_to_mask
|
return obj_id_to_mask
|
||||||
@@ -660,12 +661,12 @@ class Sam3VideoInference(Sam3VideoBase):
|
|||||||
for i, thresh in enumerate(new_det_score_thresh_list):
|
for i, thresh in enumerate(new_det_score_thresh_list):
|
||||||
self.new_det_thresh = thresh
|
self.new_det_thresh = thresh
|
||||||
for num_objects in num_objects_list:
|
for num_objects in num_objects_list:
|
||||||
logger.info(f"{i+1}/{num_rounds} warming up model compilation")
|
logger.info(f"{i + 1}/{num_rounds} warming up model compilation")
|
||||||
self.add_prompt(
|
self.add_prompt(
|
||||||
inference_state, frame_idx=start_frame_idx, text_str="cat"
|
inference_state, frame_idx=start_frame_idx, text_str="cat"
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"{i+1}/{num_rounds} warming up model compilation -- simulating {num_objects}/{self.num_obj_for_compile} objects"
|
f"{i + 1}/{num_rounds} warming up model compilation -- simulating {num_objects}/{self.num_obj_for_compile} objects"
|
||||||
)
|
)
|
||||||
inference_state = self.add_fake_objects_to_inference_state(
|
inference_state = self.add_fake_objects_to_inference_state(
|
||||||
inference_state, num_objects, frame_idx=start_frame_idx
|
inference_state, num_objects, frame_idx=start_frame_idx
|
||||||
@@ -690,7 +691,7 @@ class Sam3VideoInference(Sam3VideoBase):
|
|||||||
pass
|
pass
|
||||||
self.reset_state(inference_state)
|
self.reset_state(inference_state)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"{i+1}/{num_rounds} warming up model compilation -- completed round {i+1} out of {num_rounds}"
|
f"{i + 1}/{num_rounds} warming up model compilation -- completed round {i + 1} out of {num_rounds}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Warm up Tracker memory encoder with varying input shapes
|
# Warm up Tracker memory encoder with varying input shapes
|
||||||
@@ -854,12 +855,12 @@ class Sam3VideoInference(Sam3VideoBase):
|
|||||||
logger.debug("Running add_prompt on frame %d", frame_idx)
|
logger.debug("Running add_prompt on frame %d", frame_idx)
|
||||||
|
|
||||||
num_frames = inference_state["num_frames"]
|
num_frames = inference_state["num_frames"]
|
||||||
assert (
|
assert text_str is not None or boxes_xywh is not None, (
|
||||||
text_str is not None or boxes_xywh is not None
|
"at least one type of prompt (text, boxes) must be provided"
|
||||||
), "at least one type of prompt (text, boxes) must be provided"
|
)
|
||||||
assert (
|
assert 0 <= frame_idx < num_frames, (
|
||||||
0 <= frame_idx < num_frames
|
f"{frame_idx=} is out of range for a total of {num_frames} frames"
|
||||||
), f"{frame_idx=} is out of range for a total of {num_frames} frames"
|
)
|
||||||
|
|
||||||
# since it's a semantic prompt, we start over
|
# since it's a semantic prompt, we start over
|
||||||
self.reset_state(inference_state)
|
self.reset_state(inference_state)
|
||||||
@@ -1200,9 +1201,9 @@ class Sam3VideoInferenceWithInstanceInteractivity(Sam3VideoInference):
|
|||||||
"propagation_partial",
|
"propagation_partial",
|
||||||
"propagation_fetch",
|
"propagation_fetch",
|
||||||
]
|
]
|
||||||
assert (
|
assert action_type in instance_actions + propagation_actions, (
|
||||||
action_type in instance_actions + propagation_actions
|
f"Invalid action type: {action_type}, must be one of {instance_actions + propagation_actions}"
|
||||||
), f"Invalid action type: {action_type}, must be one of {instance_actions + propagation_actions}"
|
)
|
||||||
action = {
|
action = {
|
||||||
"type": action_type,
|
"type": action_type,
|
||||||
"frame_idx": frame_idx,
|
"frame_idx": frame_idx,
|
||||||
@@ -1370,12 +1371,12 @@ class Sam3VideoInferenceWithInstanceInteractivity(Sam3VideoInference):
|
|||||||
):
|
):
|
||||||
if points is not None:
|
if points is not None:
|
||||||
# Tracker instance prompts
|
# Tracker instance prompts
|
||||||
assert (
|
assert text_str is None and boxes_xywh is None, (
|
||||||
text_str is None and boxes_xywh is None
|
"When points are provided, text_str and boxes_xywh must be None."
|
||||||
), "When points are provided, text_str and boxes_xywh must be None."
|
)
|
||||||
assert (
|
assert obj_id is not None, (
|
||||||
obj_id is not None
|
"When points are provided, obj_id must be provided."
|
||||||
), "When points are provided, obj_id must be provided."
|
)
|
||||||
return self.add_tracker_new_points(
|
return self.add_tracker_new_points(
|
||||||
inference_state,
|
inference_state,
|
||||||
frame_idx,
|
frame_idx,
|
||||||
@@ -1491,9 +1492,9 @@ class Sam3VideoInferenceWithInstanceInteractivity(Sam3VideoInference):
|
|||||||
tracker_states = self._get_tracker_inference_states_by_obj_ids(
|
tracker_states = self._get_tracker_inference_states_by_obj_ids(
|
||||||
inference_state, [obj_id]
|
inference_state, [obj_id]
|
||||||
)
|
)
|
||||||
assert (
|
assert len(tracker_states) == 1, (
|
||||||
len(tracker_states) == 1
|
f"[rank={self.rank}] Multiple Tracker inference states found for the same object id."
|
||||||
), f"[rank={self.rank}] Multiple Tracker inference states found for the same object id."
|
)
|
||||||
tracker_state = tracker_states[0]
|
tracker_state = tracker_states[0]
|
||||||
|
|
||||||
# log
|
# log
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ from typing import List, Optional
|
|||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sam3.logger import get_logger
|
from sam3.logger import get_logger
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
@@ -170,7 +169,7 @@ class Sam3VideoPredictor:
|
|||||||
):
|
):
|
||||||
"""Remove an object from tracking."""
|
"""Remove an object from tracking."""
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"remove object {obj_id} in session {session_id}: " f"{is_user_action=}"
|
f"remove object {obj_id} in session {session_id}: {is_user_action=}"
|
||||||
)
|
)
|
||||||
session = self._get_session(session_id)
|
session = self._get_session(session_id)
|
||||||
inference_state = session["state"]
|
inference_state = session["state"]
|
||||||
|
|||||||
@@ -318,9 +318,9 @@ class VETextEncoder(nn.Module):
|
|||||||
# The text is already encoded, use as is.
|
# The text is already encoded, use as is.
|
||||||
text_attention_mask, text_memory_resized, tokenized = text
|
text_attention_mask, text_memory_resized, tokenized = text
|
||||||
inputs_embeds = tokenized["inputs_embeds"]
|
inputs_embeds = tokenized["inputs_embeds"]
|
||||||
assert (
|
assert input_boxes is None or len(input_boxes) == 0, (
|
||||||
input_boxes is None or len(input_boxes) == 0
|
"Can't replace boxes in text if it's already encoded"
|
||||||
), "Can't replace boxes in text if it's already encoded"
|
)
|
||||||
|
|
||||||
# Note that the input_embeds are returned in pytorch's convention (sequence first)
|
# Note that the input_embeds are returned in pytorch's convention (sequence first)
|
||||||
return (
|
return (
|
||||||
|
|||||||
@@ -708,9 +708,9 @@ class ViT(nn.Module):
|
|||||||
self.retain_cls_token = retain_cls_token
|
self.retain_cls_token = retain_cls_token
|
||||||
if self.retain_cls_token:
|
if self.retain_cls_token:
|
||||||
assert pretrain_use_cls_token
|
assert pretrain_use_cls_token
|
||||||
assert (
|
assert len(window_block_indexes) == 0, (
|
||||||
len(window_block_indexes) == 0
|
"windowing not supported with cls token"
|
||||||
), "windowing not supported with cls token"
|
)
|
||||||
|
|
||||||
assert sum(self.rel_pos_blocks) == 0, "rel pos not supported with cls token"
|
assert sum(self.rel_pos_blocks) == 0, "rel pos not supported with cls token"
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from typing import List, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from torch.nn.attention import sdpa_kernel, SDPBackend
|
from torch.nn.attention import sdpa_kernel, SDPBackend
|
||||||
|
|
||||||
from .act_ckpt_utils import activation_ckpt_wrapper
|
from .act_ckpt_utils import activation_ckpt_wrapper
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import os
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import pkg_resources
|
import pkg_resources
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
|
|||||||
@@ -36,9 +36,9 @@ def connected_components_cpu(input_tensor: torch.Tensor):
|
|||||||
if input_tensor.dim() == 4 and input_tensor.shape[1] == 1:
|
if input_tensor.dim() == 4 and input_tensor.shape[1] == 1:
|
||||||
input_tensor = input_tensor.squeeze(1)
|
input_tensor = input_tensor.squeeze(1)
|
||||||
else:
|
else:
|
||||||
assert (
|
assert input_tensor.dim() == 3, (
|
||||||
input_tensor.dim() == 3
|
"Input tensor must be (B, H, W) or (B, 1, H, W)."
|
||||||
), "Input tensor must be (B, H, W) or (B, 1, H, W)."
|
)
|
||||||
|
|
||||||
batch_size = input_tensor.shape[0]
|
batch_size = input_tensor.shape[0]
|
||||||
labels_list = []
|
labels_list = []
|
||||||
@@ -67,9 +67,9 @@ def connected_components(input_tensor: torch.Tensor):
|
|||||||
if input_tensor.dim() == 3:
|
if input_tensor.dim() == 3:
|
||||||
input_tensor = input_tensor.unsqueeze(1)
|
input_tensor = input_tensor.unsqueeze(1)
|
||||||
|
|
||||||
assert (
|
assert input_tensor.dim() == 4 and input_tensor.shape[1] == 1, (
|
||||||
input_tensor.dim() == 4 and input_tensor.shape[1] == 1
|
"Input tensor must be (B, H, W) or (B, 1, H, W)."
|
||||||
), "Input tensor must be (B, H, W) or (B, 1, H, W)."
|
)
|
||||||
|
|
||||||
if input_tensor.is_cuda:
|
if input_tensor.is_cuda:
|
||||||
if HAS_CC_TORCH:
|
if HAS_CC_TORCH:
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import logging
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sam3.perflib.masks_ops import mask_iou
|
from sam3.perflib.masks_ops import mask_iou
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -407,16 +407,16 @@ def connected_components_triton(input_tensor: torch.Tensor):
|
|||||||
- A BxHxW output tensor with dense labels. Background is 0.
|
- A BxHxW output tensor with dense labels. Background is 0.
|
||||||
- A BxHxW tensor with the size of the connected component for each pixel.
|
- A BxHxW tensor with the size of the connected component for each pixel.
|
||||||
"""
|
"""
|
||||||
assert (
|
assert input_tensor.is_cuda and input_tensor.is_contiguous(), (
|
||||||
input_tensor.is_cuda and input_tensor.is_contiguous()
|
"Input tensor must be a contiguous CUDA tensor."
|
||||||
), "Input tensor must be a contiguous CUDA tensor."
|
)
|
||||||
out_shape = input_tensor.shape
|
out_shape = input_tensor.shape
|
||||||
if input_tensor.dim() == 4 and input_tensor.shape[1] == 1:
|
if input_tensor.dim() == 4 and input_tensor.shape[1] == 1:
|
||||||
input_tensor = input_tensor.squeeze(1)
|
input_tensor = input_tensor.squeeze(1)
|
||||||
else:
|
else:
|
||||||
assert (
|
assert input_tensor.dim() == 3, (
|
||||||
input_tensor.dim() == 3
|
"Input tensor must be (B, H, W) or (B, 1, H, W)."
|
||||||
), "Input tensor must be (B, H, W) or (B, 1, H, W)."
|
)
|
||||||
|
|
||||||
B, H, W = input_tensor.shape
|
B, H, W = input_tensor.shape
|
||||||
numel = B * H * W
|
numel = B * H * W
|
||||||
|
|||||||
@@ -202,9 +202,9 @@ class MaskDecoder(nn.Module):
|
|||||||
assert image_embeddings.shape[0] == tokens.shape[0]
|
assert image_embeddings.shape[0] == tokens.shape[0]
|
||||||
src = image_embeddings
|
src = image_embeddings
|
||||||
src = src + dense_prompt_embeddings
|
src = src + dense_prompt_embeddings
|
||||||
assert (
|
assert image_pe.size(0) == 1, (
|
||||||
image_pe.size(0) == 1
|
"image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
|
||||||
), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
|
)
|
||||||
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
|
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
|
||||||
b, c, h, w = src.shape
|
b, c, h, w = src.shape
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ from typing import Tuple, Type
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from sam3.sam.rope import apply_rotary_enc, apply_rotary_enc_real, compute_axial_cis
|
from sam3.sam.rope import apply_rotary_enc, apply_rotary_enc_real, compute_axial_cis
|
||||||
from torch import nn, Tensor
|
from torch import nn, Tensor
|
||||||
|
|
||||||
@@ -205,9 +204,9 @@ class Attention(nn.Module):
|
|||||||
self.internal_dim = embedding_dim // downsample_rate
|
self.internal_dim = embedding_dim // downsample_rate
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.use_fa3 = use_fa3
|
self.use_fa3 = use_fa3
|
||||||
assert (
|
assert self.internal_dim % num_heads == 0, (
|
||||||
self.internal_dim % num_heads == 0
|
"num_heads must divide embedding_dim."
|
||||||
), "num_heads must divide embedding_dim."
|
)
|
||||||
|
|
||||||
self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
|
self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
|
||||||
self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
|
self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
|
||||||
|
|||||||
@@ -142,9 +142,9 @@ class COCO_FROM_JSON:
|
|||||||
self.prompts = {}
|
self.prompts = {}
|
||||||
for loc_dict in prompts:
|
for loc_dict in prompts:
|
||||||
self.prompts[int(loc_dict["id"])] = loc_dict["name"]
|
self.prompts[int(loc_dict["id"])] = loc_dict["name"]
|
||||||
assert len(self.prompts) == len(
|
assert len(self.prompts) == len(self._sorted_cat_ids), (
|
||||||
self._sorted_cat_ids
|
"Number of prompts must match number of categories"
|
||||||
), "Number of prompts must match number of categories"
|
)
|
||||||
|
|
||||||
def getDatapointIds(self):
|
def getDatapointIds(self):
|
||||||
"""Return all datapoint indices for training."""
|
"""Return all datapoint indices for training."""
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ from dataclasses import dataclass, field as field_ptr_behaviour, fields, is_data
|
|||||||
from typing import Any, get_args, get_origin, List, Union
|
from typing import Any, get_args, get_origin, List, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sam3.model.data_misc import (
|
from sam3.model.data_misc import (
|
||||||
BatchedDatapoint,
|
BatchedDatapoint,
|
||||||
BatchedFindTarget,
|
BatchedFindTarget,
|
||||||
@@ -217,9 +216,9 @@ def collate_fn_api(
|
|||||||
text_batch.append(q.query_text)
|
text_batch.append(q.query_text)
|
||||||
stages[stage_id].text_ids.append(text_batch.index(q.query_text))
|
stages[stage_id].text_ids.append(text_batch.index(q.query_text))
|
||||||
|
|
||||||
assert (
|
assert q.inference_metadata is not None, (
|
||||||
q.inference_metadata is not None
|
"inference_metadata must be provided when FindQueryLoaded is created."
|
||||||
), "inference_metadata must be provided when FindQueryLoaded is created."
|
)
|
||||||
for f in fields(q.inference_metadata):
|
for f in fields(q.inference_metadata):
|
||||||
getattr(find_metadatas[stage_id], f.name).append(
|
getattr(find_metadatas[stage_id], f.name).append(
|
||||||
getattr(q.inference_metadata, f.name)
|
getattr(q.inference_metadata, f.name)
|
||||||
|
|||||||
@@ -19,10 +19,8 @@ import torch.utils.data
|
|||||||
import torchvision
|
import torchvision
|
||||||
from decord import cpu, VideoReader
|
from decord import cpu, VideoReader
|
||||||
from iopath.common.file_io import g_pathmgr
|
from iopath.common.file_io import g_pathmgr
|
||||||
|
|
||||||
from PIL import Image as PILImage
|
from PIL import Image as PILImage
|
||||||
from PIL.Image import DecompressionBombError
|
from PIL.Image import DecompressionBombError
|
||||||
|
|
||||||
from sam3.model.box_ops import box_xywh_to_xyxy
|
from sam3.model.box_ops import box_xywh_to_xyxy
|
||||||
from torchvision.datasets.vision import VisionDataset
|
from torchvision.datasets.vision import VisionDataset
|
||||||
|
|
||||||
@@ -234,9 +232,9 @@ class CustomCocoDetectionAPI(VisionDataset):
|
|||||||
if self.coco is not None:
|
if self.coco is not None:
|
||||||
return
|
return
|
||||||
|
|
||||||
assert g_pathmgr.isfile(
|
assert g_pathmgr.isfile(self.annFile), (
|
||||||
self.annFile
|
f"please provide valid annotation file. Missing: {self.annFile}"
|
||||||
), f"please provide valid annotation file. Missing: {self.annFile}"
|
)
|
||||||
annFile = g_pathmgr.get_local_path(self.annFile)
|
annFile = g_pathmgr.get_local_path(self.annFile)
|
||||||
|
|
||||||
if self.coco is not None:
|
if self.coco is not None:
|
||||||
@@ -326,9 +324,9 @@ class CustomCocoDetectionAPI(VisionDataset):
|
|||||||
else:
|
else:
|
||||||
num_queries_per_stage = stage2num_queries.most_common(1)[0][1]
|
num_queries_per_stage = stage2num_queries.most_common(1)[0][1]
|
||||||
for stage, num_queries in stage2num_queries.items():
|
for stage, num_queries in stage2num_queries.items():
|
||||||
assert (
|
assert num_queries == num_queries_per_stage, (
|
||||||
num_queries == num_queries_per_stage
|
f"Number of queries in stage {stage} is {num_queries}, expected {num_queries_per_stage}"
|
||||||
), f"Number of queries in stage {stage} is {num_queries}, expected {num_queries_per_stage}"
|
)
|
||||||
|
|
||||||
for query_id, query in enumerate(queries):
|
for query_id, query in enumerate(queries):
|
||||||
h, w = id2imsize[query["image_id"]]
|
h, w = id2imsize[query["image_id"]]
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
# pyre-unsafe
|
# pyre-unsafe
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
|
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
@@ -16,7 +15,6 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torchvision
|
import torchvision
|
||||||
|
|
||||||
# from decord import cpu, VideoReader
|
# from decord import cpu, VideoReader
|
||||||
|
|
||||||
from iopath.common.file_io import PathManager
|
from iopath.common.file_io import PathManager
|
||||||
@@ -220,9 +218,9 @@ class VideoGroundingDataset(Sam3ImageDataset):
|
|||||||
for query in filtered_queries:
|
for query in filtered_queries:
|
||||||
ptr_x_is_empty = query["ptr_x_query_id"] in [None, -1]
|
ptr_x_is_empty = query["ptr_x_query_id"] in [None, -1]
|
||||||
ptr_y_is_empty = query["ptr_y_query_id"] in [None, -1]
|
ptr_y_is_empty = query["ptr_y_query_id"] in [None, -1]
|
||||||
assert (
|
assert ptr_x_is_empty and ptr_y_is_empty, (
|
||||||
ptr_x_is_empty and ptr_y_is_empty
|
"Remapping stage ids is not supported for queries with non-empty ptr_x or ptr_y pointers"
|
||||||
), "Remapping stage ids is not supported for queries with non-empty ptr_x or ptr_y pointers"
|
)
|
||||||
query["query_processing_order"] = stage_id_old2new[
|
query["query_processing_order"] = stage_id_old2new[
|
||||||
query["query_processing_order"]
|
query["query_processing_order"]
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -9,11 +9,8 @@ import torch
|
|||||||
import torch.distributed
|
import torch.distributed
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torchmetrics
|
import torchmetrics
|
||||||
|
|
||||||
from sam3.model import box_ops
|
from sam3.model import box_ops
|
||||||
|
|
||||||
from sam3.model.data_misc import interpolate
|
from sam3.model.data_misc import interpolate
|
||||||
|
|
||||||
from sam3.train.loss.sigmoid_focal_loss import (
|
from sam3.train.loss.sigmoid_focal_loss import (
|
||||||
triton_sigmoid_focal_loss,
|
triton_sigmoid_focal_loss,
|
||||||
triton_sigmoid_focal_loss_reduce,
|
triton_sigmoid_focal_loss_reduce,
|
||||||
@@ -327,7 +324,9 @@ class IABCEMdetr(LossWithWeights):
|
|||||||
if num_det_queries is not None:
|
if num_det_queries is not None:
|
||||||
logging.warning("note: it's not needed to set num_det_queries anymore")
|
logging.warning("note: it's not needed to set num_det_queries anymore")
|
||||||
if self.use_separate_loss_for_det_and_trk:
|
if self.use_separate_loss_for_det_and_trk:
|
||||||
assert not self.weak_loss, "Do not use weak_loss in this case -- set separate loss for detection and tracking queries instead"
|
assert not self.weak_loss, (
|
||||||
|
"Do not use weak_loss in this case -- set separate loss for detection and tracking queries instead"
|
||||||
|
)
|
||||||
self.det_exhaustive_loss_scale_pos = det_exhaustive_loss_scale_pos
|
self.det_exhaustive_loss_scale_pos = det_exhaustive_loss_scale_pos
|
||||||
self.det_exhaustive_loss_scale_neg = det_exhaustive_loss_scale_neg
|
self.det_exhaustive_loss_scale_neg = det_exhaustive_loss_scale_neg
|
||||||
self.det_non_exhaustive_loss_scale_pos = det_non_exhaustive_loss_scale_pos
|
self.det_non_exhaustive_loss_scale_pos = det_non_exhaustive_loss_scale_pos
|
||||||
@@ -342,7 +341,9 @@ class IABCEMdetr(LossWithWeights):
|
|||||||
and det_non_exhaustive_loss_scale_neg == 1.0
|
and det_non_exhaustive_loss_scale_neg == 1.0
|
||||||
and trk_loss_scale_pos == 1.0
|
and trk_loss_scale_pos == 1.0
|
||||||
and trk_loss_scale_neg == 1.0
|
and trk_loss_scale_neg == 1.0
|
||||||
), "If not using separate loss for detection and tracking queries, separate detection and tracking loss scales should all be 1.0"
|
), (
|
||||||
|
"If not using separate loss for detection and tracking queries, separate detection and tracking loss scales should all be 1.0"
|
||||||
|
)
|
||||||
|
|
||||||
def get_loss(self, outputs, targets, indices, num_boxes):
|
def get_loss(self, outputs, targets, indices, num_boxes):
|
||||||
assert len(outputs["pred_logits"].shape) > 2, "Incorrect predicted logits shape"
|
assert len(outputs["pred_logits"].shape) > 2, "Incorrect predicted logits shape"
|
||||||
@@ -443,7 +444,9 @@ class IABCEMdetr(LossWithWeights):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
if self.weak_loss:
|
if self.weak_loss:
|
||||||
assert not self.use_separate_loss_for_det_and_trk, "Do not use weak_loss in this case -- set separate loss for detection and tracking queries instead"
|
assert not self.use_separate_loss_for_det_and_trk, (
|
||||||
|
"Do not use weak_loss in this case -- set separate loss for detection and tracking queries instead"
|
||||||
|
)
|
||||||
|
|
||||||
# nullify the negative loss for the non-exhaustive classes
|
# nullify the negative loss for the non-exhaustive classes
|
||||||
assert loss_bce.shape[0] == targets["is_exhaustive"].shape[0]
|
assert loss_bce.shape[0] == targets["is_exhaustive"].shape[0]
|
||||||
@@ -497,9 +500,9 @@ class IABCEMdetr(LossWithWeights):
|
|||||||
loss_bce = loss_bce.mean()
|
loss_bce = loss_bce.mean()
|
||||||
else:
|
else:
|
||||||
assert isinstance(self.pad_n_queries, int)
|
assert isinstance(self.pad_n_queries, int)
|
||||||
assert (
|
assert loss_bce.size(1) < self.pad_n_queries, (
|
||||||
loss_bce.size(1) < self.pad_n_queries
|
f"The number of predictions is more than the expected total after padding. Got {loss_bce.size(1)} predictions."
|
||||||
), f"The number of predictions is more than the expected total after padding. Got {loss_bce.size(1)} predictions."
|
)
|
||||||
loss_bce = loss_bce.sum() / (self.pad_n_queries * loss_bce.size(0))
|
loss_bce = loss_bce.sum() / (self.pad_n_queries * loss_bce.size(0))
|
||||||
|
|
||||||
bce_f1 = torchmetrics.functional.f1_score(
|
bce_f1 = torchmetrics.functional.f1_score(
|
||||||
|
|||||||
@@ -3,9 +3,7 @@
|
|||||||
# pyre-unsafe
|
# pyre-unsafe
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sam3.model.model_misc import SAM3Output
|
from sam3.model.model_misc import SAM3Output
|
||||||
|
|
||||||
from sam3.train.utils.distributed import get_world_size
|
from sam3.train.utils.distributed import get_world_size
|
||||||
|
|
||||||
from .loss_fns import CORE_LOSS_KEY, Det2TrkAssoc, Masks
|
from .loss_fns import CORE_LOSS_KEY, Det2TrkAssoc, Masks
|
||||||
|
|||||||
@@ -103,9 +103,9 @@ def dilation(mask, kernel_size):
|
|||||||
|
|
||||||
assert mask.ndim == 3
|
assert mask.ndim == 3
|
||||||
kernel_size = int(kernel_size)
|
kernel_size = int(kernel_size)
|
||||||
assert (
|
assert kernel_size % 2 == 1, (
|
||||||
kernel_size % 2 == 1
|
f"Dilation expects a odd kernel size, got {kernel_size}"
|
||||||
), f"Dilation expects a odd kernel size, got {kernel_size}"
|
)
|
||||||
|
|
||||||
if mask.is_cuda:
|
if mask.is_cuda:
|
||||||
m = mask.unsqueeze(1).to(torch.float16)
|
m = mask.unsqueeze(1).to(torch.float16)
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ Modules to compute the matching cost and solve the corresponding LSAP.
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sam3.model.box_ops import box_cxcywh_to_xyxy, box_iou, generalized_box_iou
|
from sam3.model.box_ops import box_cxcywh_to_xyxy, box_iou, generalized_box_iou
|
||||||
from scipy.optimize import linear_sum_assignment
|
from scipy.optimize import linear_sum_assignment
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -60,9 +59,9 @@ class HungarianMatcher(nn.Module):
|
|||||||
self.cost_bbox = cost_bbox
|
self.cost_bbox = cost_bbox
|
||||||
self.cost_giou = cost_giou
|
self.cost_giou = cost_giou
|
||||||
self.norm = nn.Sigmoid() if focal_loss else nn.Softmax(-1)
|
self.norm = nn.Sigmoid() if focal_loss else nn.Softmax(-1)
|
||||||
assert (
|
assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, (
|
||||||
cost_class != 0 or cost_bbox != 0 or cost_giou != 0
|
"all costs cant be 0"
|
||||||
), "all costs cant be 0"
|
)
|
||||||
self.focal_loss = focal_loss
|
self.focal_loss = focal_loss
|
||||||
self.focal_alpha = focal_alpha
|
self.focal_alpha = focal_alpha
|
||||||
self.focal_gamma = focal_gamma
|
self.focal_gamma = focal_gamma
|
||||||
@@ -197,9 +196,9 @@ class BinaryHungarianMatcher(nn.Module):
|
|||||||
self.cost_bbox = cost_bbox
|
self.cost_bbox = cost_bbox
|
||||||
self.cost_giou = cost_giou
|
self.cost_giou = cost_giou
|
||||||
self.norm = nn.Sigmoid()
|
self.norm = nn.Sigmoid()
|
||||||
assert (
|
assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, (
|
||||||
cost_class != 0 or cost_bbox != 0 or cost_giou != 0
|
"all costs cant be 0"
|
||||||
), "all costs cant be 0"
|
)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(self, outputs, batched_targets, repeats=0, repeat_batch=1):
|
def forward(self, outputs, batched_targets, repeats=0, repeat_batch=1):
|
||||||
@@ -322,9 +321,9 @@ class BinaryFocalHungarianMatcher(nn.Module):
|
|||||||
self.alpha = alpha
|
self.alpha = alpha
|
||||||
self.gamma = gamma
|
self.gamma = gamma
|
||||||
self.stable = stable
|
self.stable = stable
|
||||||
assert (
|
assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, (
|
||||||
cost_class != 0 or cost_bbox != 0 or cost_giou != 0
|
"all costs cant be 0"
|
||||||
), "all costs cant be 0"
|
)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(self, outputs, batched_targets, repeats=1, repeat_batch=1):
|
def forward(self, outputs, batched_targets, repeats=1, repeat_batch=1):
|
||||||
@@ -470,9 +469,9 @@ class BinaryHungarianMatcherV2(nn.Module):
|
|||||||
self.cost_bbox = cost_bbox
|
self.cost_bbox = cost_bbox
|
||||||
self.cost_giou = cost_giou
|
self.cost_giou = cost_giou
|
||||||
self.norm = nn.Sigmoid()
|
self.norm = nn.Sigmoid()
|
||||||
assert (
|
assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, (
|
||||||
cost_class != 0 or cost_bbox != 0 or cost_giou != 0
|
"all costs cant be 0"
|
||||||
), "all costs cant be 0"
|
)
|
||||||
self.focal = focal
|
self.focal = focal
|
||||||
if focal:
|
if focal:
|
||||||
self.alpha = alpha
|
self.alpha = alpha
|
||||||
|
|||||||
@@ -22,7 +22,6 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
import hydra
|
import hydra
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from omegaconf import DictConfig
|
from omegaconf import DictConfig
|
||||||
@@ -212,9 +211,9 @@ def unix_module_cls_pattern_to_parameter_names(
|
|||||||
"match any classes in the model"
|
"match any classes in the model"
|
||||||
)
|
)
|
||||||
matching_parameters = module_cls_to_param_names[module_cls]
|
matching_parameters = module_cls_to_param_names[module_cls]
|
||||||
assert (
|
assert len(matching_parameters) > 0, (
|
||||||
len(matching_parameters) > 0
|
f"module_cls_name {module_cls_name} does not contain any parameters in the model"
|
||||||
), f"module_cls_name {module_cls_name} does not contain any parameters in the model"
|
)
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Matches for module_cls_name [{module_cls_name}]: {matching_parameters} "
|
f"Matches for module_cls_name [{module_cls_name}]: {matching_parameters} "
|
||||||
)
|
)
|
||||||
@@ -240,9 +239,9 @@ def unix_param_pattern_to_parameter_names(
|
|||||||
allowed_parameter_names = []
|
allowed_parameter_names = []
|
||||||
for param_name in filter_param_names:
|
for param_name in filter_param_names:
|
||||||
matching_parameters = set(fnmatch.filter(parameter_names, param_name))
|
matching_parameters = set(fnmatch.filter(parameter_names, param_name))
|
||||||
assert (
|
assert len(matching_parameters) >= 1, (
|
||||||
len(matching_parameters) >= 1
|
f"param_name {param_name} does not match any parameters in the model"
|
||||||
), f"param_name {param_name} does not match any parameters in the model"
|
)
|
||||||
logging.info(f"Matches for param_name [{param_name}]: {matching_parameters}")
|
logging.info(f"Matches for param_name [{param_name}]: {matching_parameters}")
|
||||||
allowed_parameter_names.append(matching_parameters)
|
allowed_parameter_names.append(matching_parameters)
|
||||||
return set.union(*allowed_parameter_names)
|
return set.union(*allowed_parameter_names)
|
||||||
|
|||||||
@@ -12,13 +12,10 @@ from copy import deepcopy
|
|||||||
|
|
||||||
import submitit
|
import submitit
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from hydra import compose, initialize_config_module
|
from hydra import compose, initialize_config_module
|
||||||
from hydra.utils import instantiate
|
from hydra.utils import instantiate
|
||||||
|
|
||||||
from iopath.common.file_io import g_pathmgr
|
from iopath.common.file_io import g_pathmgr
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
from sam3.train.utils.train_utils import makedir, register_omegaconf_resolvers
|
from sam3.train.utils.train_utils import makedir, register_omegaconf_resolvers
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
@@ -212,9 +209,9 @@ def main(args) -> None:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
if "include_nodes" in submitit_conf:
|
if "include_nodes" in submitit_conf:
|
||||||
assert (
|
assert len(submitit_conf["include_nodes"]) >= cfg.launcher.num_nodes, (
|
||||||
len(submitit_conf["include_nodes"]) >= cfg.launcher.num_nodes
|
"Not enough nodes"
|
||||||
), "Not enough nodes"
|
)
|
||||||
job_kwargs["slurm_additional_parameters"]["nodelist"] = " ".join(
|
job_kwargs["slurm_additional_parameters"]["nodelist"] = " ".join(
|
||||||
submitit_conf["include_nodes"]
|
submitit_conf["include_nodes"]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -15,28 +15,22 @@ from dataclasses import dataclass, field
|
|||||||
from typing import Any, Dict, List, Mapping, Optional
|
from typing import Any, Dict, List, Mapping, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from hydra.utils import instantiate
|
from hydra.utils import instantiate
|
||||||
from iopath.common.file_io import g_pathmgr
|
from iopath.common.file_io import g_pathmgr
|
||||||
|
|
||||||
from sam3.model.data_misc import BatchedDatapoint
|
from sam3.model.data_misc import BatchedDatapoint
|
||||||
from sam3.model.model_misc import SAM3Output
|
from sam3.model.model_misc import SAM3Output
|
||||||
from sam3.model.utils.misc import copy_data_to_device
|
from sam3.model.utils.misc import copy_data_to_device
|
||||||
|
|
||||||
from sam3.train.optim.optimizer import construct_optimizer
|
from sam3.train.optim.optimizer import construct_optimizer
|
||||||
|
|
||||||
from sam3.train.utils.checkpoint_utils import (
|
from sam3.train.utils.checkpoint_utils import (
|
||||||
assert_skipped_parameters_are_frozen,
|
assert_skipped_parameters_are_frozen,
|
||||||
exclude_params_matching_unix_pattern,
|
exclude_params_matching_unix_pattern,
|
||||||
load_state_dict_into_model,
|
load_state_dict_into_model,
|
||||||
with_check_parameter_frozen,
|
with_check_parameter_frozen,
|
||||||
)
|
)
|
||||||
|
|
||||||
from sam3.train.utils.distributed import all_reduce_max, barrier, get_rank
|
from sam3.train.utils.distributed import all_reduce_max, barrier, get_rank
|
||||||
|
|
||||||
from sam3.train.utils.logger import Logger, setup_logging
|
from sam3.train.utils.logger import Logger, setup_logging
|
||||||
from sam3.train.utils.train_utils import (
|
from sam3.train.utils.train_utils import (
|
||||||
AverageMeter,
|
AverageMeter,
|
||||||
@@ -215,9 +209,9 @@ class Trainer:
|
|||||||
set_seeds(seed_value, self.max_epochs, self.distributed_rank)
|
set_seeds(seed_value, self.max_epochs, self.distributed_rank)
|
||||||
log_env_variables()
|
log_env_variables()
|
||||||
|
|
||||||
assert (
|
assert is_dist_avail_and_initialized(), (
|
||||||
is_dist_avail_and_initialized()
|
"Torch distributed needs to be initialized before calling the trainer."
|
||||||
), "Torch distributed needs to be initialized before calling the trainer."
|
)
|
||||||
|
|
||||||
self._setup_components() # Except Optimizer everything is setup here.
|
self._setup_components() # Except Optimizer everything is setup here.
|
||||||
self._move_to_device()
|
self._move_to_device()
|
||||||
@@ -227,9 +221,9 @@ class Trainer:
|
|||||||
self.time_elapsed_meter = DurationMeter("Time Elapsed", self.device, ":.2f")
|
self.time_elapsed_meter = DurationMeter("Time Elapsed", self.device, ":.2f")
|
||||||
|
|
||||||
if self.checkpoint_conf.resume_from is not None:
|
if self.checkpoint_conf.resume_from is not None:
|
||||||
assert os.path.exists(
|
assert os.path.exists(self.checkpoint_conf.resume_from), (
|
||||||
self.checkpoint_conf.resume_from
|
f"The 'resume_from' checkpoint {self.checkpoint_conf.resume_from} does not exist!"
|
||||||
), f"The 'resume_from' checkpoint {self.checkpoint_conf.resume_from} does not exist!"
|
)
|
||||||
dst = os.path.join(self.checkpoint_conf.save_dir, "checkpoint.pt")
|
dst = os.path.join(self.checkpoint_conf.save_dir, "checkpoint.pt")
|
||||||
if self.distributed_rank == 0 and not os.path.exists(dst):
|
if self.distributed_rank == 0 and not os.path.exists(dst):
|
||||||
# Copy the "resume_from" checkpoint to the checkpoint folder
|
# Copy the "resume_from" checkpoint to the checkpoint folder
|
||||||
@@ -477,9 +471,9 @@ class Trainer:
|
|||||||
return self.loss[key]
|
return self.loss[key]
|
||||||
|
|
||||||
assert key != "all", "Loss must be specified for key='all'"
|
assert key != "all", "Loss must be specified for key='all'"
|
||||||
assert (
|
assert "default" in self.loss, (
|
||||||
"default" in self.loss
|
f"Key {key} not found in losss, and no default provided"
|
||||||
), f"Key {key} not found in losss, and no default provided"
|
)
|
||||||
return self.loss["default"]
|
return self.loss["default"]
|
||||||
|
|
||||||
def _find_meter(self, phase: str, key: str):
|
def _find_meter(self, phase: str, key: str):
|
||||||
@@ -922,12 +916,12 @@ class Trainer:
|
|||||||
self.optim.zero_grad(set_to_none=True)
|
self.optim.zero_grad(set_to_none=True)
|
||||||
|
|
||||||
if self.gradient_accumulation_steps > 1:
|
if self.gradient_accumulation_steps > 1:
|
||||||
assert isinstance(
|
assert isinstance(batch, list), (
|
||||||
batch, list
|
f"Expected a list of batches, got {type(batch)}"
|
||||||
), f"Expected a list of batches, got {type(batch)}"
|
)
|
||||||
assert (
|
assert len(batch) == self.gradient_accumulation_steps, (
|
||||||
len(batch) == self.gradient_accumulation_steps
|
f"Expected {self.gradient_accumulation_steps} batches, got {len(batch)}"
|
||||||
), f"Expected {self.gradient_accumulation_steps} batches, got {len(batch)}"
|
)
|
||||||
accum_steps = len(batch)
|
accum_steps = len(batch)
|
||||||
else:
|
else:
|
||||||
accum_steps = 1
|
accum_steps = 1
|
||||||
@@ -1039,9 +1033,9 @@ class Trainer:
|
|||||||
def _check_val_key_match(self, val_keys, phase):
|
def _check_val_key_match(self, val_keys, phase):
|
||||||
if val_keys is not None:
|
if val_keys is not None:
|
||||||
# Check if there are any duplicates
|
# Check if there are any duplicates
|
||||||
assert len(val_keys) == len(
|
assert len(val_keys) == len(set(val_keys)), (
|
||||||
set(val_keys)
|
f"Duplicate keys in val datasets, keys: {val_keys}"
|
||||||
), f"Duplicate keys in val datasets, keys: {val_keys}"
|
)
|
||||||
|
|
||||||
# Check that the keys match the meter keys
|
# Check that the keys match the meter keys
|
||||||
if self.meters_conf is not None and phase in self.meters_conf:
|
if self.meters_conf is not None and phase in self.meters_conf:
|
||||||
@@ -1055,9 +1049,9 @@ class Trainer:
|
|||||||
loss_keys = set(self.loss_conf.keys()) - set(["all"])
|
loss_keys = set(self.loss_conf.keys()) - set(["all"])
|
||||||
if "default" not in loss_keys:
|
if "default" not in loss_keys:
|
||||||
for k in val_keys:
|
for k in val_keys:
|
||||||
assert (
|
assert k in loss_keys, (
|
||||||
k in loss_keys
|
f"Error: key {k} is not defined in the losses, and no default is set"
|
||||||
), f"Error: key {k} is not defined in the losses, and no default is set"
|
)
|
||||||
|
|
||||||
def _setup_components(self):
|
def _setup_components(self):
|
||||||
# Get the keys for all the val datasets, if any
|
# Get the keys for all the val datasets, if any
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ import PIL
|
|||||||
import torch
|
import torch
|
||||||
import torchvision.transforms as T
|
import torchvision.transforms as T
|
||||||
import torchvision.transforms.functional as F
|
import torchvision.transforms.functional as F
|
||||||
|
|
||||||
from sam3.model.box_ops import box_xyxy_to_cxcywh
|
from sam3.model.box_ops import box_xyxy_to_cxcywh
|
||||||
from sam3.model.data_misc import interpolate
|
from sam3.model.data_misc import interpolate
|
||||||
|
|
||||||
@@ -277,9 +276,9 @@ class RandomSizeCrop:
|
|||||||
max(0, minY - h + 1), max(maxY - 1, max(0, minY - h + 1))
|
max(0, minY - h + 1), max(maxY - 1, max(0, minY - h + 1))
|
||||||
)
|
)
|
||||||
result_img, result_target = crop(img, target, [j, i, h, w])
|
result_img, result_target = crop(img, target, [j, i, h, w])
|
||||||
assert (
|
assert len(result_target["boxes"]) == init_boxes, (
|
||||||
len(result_target["boxes"]) == init_boxes
|
f"img_w={img.width}\timg_h={img.height}\tminX={minX}\tminY={minY}\tmaxX={maxX}\tmaxY={maxY}\tminW={minW}\tminH={minH}\tmaxW={maxW}\tmaxH={maxH}\tw={w}\th={h}\ti={i}\tj={j}\tinit_boxes={init_boxes_tensor}\tresults={result_target['boxes']}"
|
||||||
), f"img_w={img.width}\timg_h={img.height}\tminX={minX}\tminY={minY}\tmaxX={maxX}\tmaxY={maxY}\tminW={minW}\tminH={minH}\tmaxW={maxW}\tmaxH={maxH}\tw={w}\th={h}\ti={i}\tj={j}\tinit_boxes={init_boxes_tensor}\tresults={result_target['boxes']}"
|
)
|
||||||
|
|
||||||
return result_img, result_target
|
return result_img, result_target
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ Transforms and data augmentation for both image + bbox.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import numbers
|
import numbers
|
||||||
import random
|
import random
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
@@ -17,9 +16,7 @@ import torch
|
|||||||
import torchvision.transforms as T
|
import torchvision.transforms as T
|
||||||
import torchvision.transforms.functional as F
|
import torchvision.transforms.functional as F
|
||||||
import torchvision.transforms.v2.functional as Fv2
|
import torchvision.transforms.v2.functional as Fv2
|
||||||
|
|
||||||
from PIL import Image as PILImage
|
from PIL import Image as PILImage
|
||||||
|
|
||||||
from sam3.model.box_ops import box_xyxy_to_cxcywh, masks_to_boxes
|
from sam3.model.box_ops import box_xyxy_to_cxcywh, masks_to_boxes
|
||||||
from sam3.train.data.sam3_image_dataset import Datapoint
|
from sam3.train.data.sam3_image_dataset import Datapoint
|
||||||
from torchvision.transforms import InterpolationMode
|
from torchvision.transforms import InterpolationMode
|
||||||
|
|||||||
@@ -4,12 +4,10 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sam3.train.data.sam3_image_dataset import Datapoint, FindQuery, Object
|
from sam3.train.data.sam3_image_dataset import Datapoint, FindQuery, Object
|
||||||
|
|
||||||
|
|
||||||
@@ -381,9 +379,9 @@ class FlexibleFilterFindGetQueries:
|
|||||||
if len(new_find_queries) == 0:
|
if len(new_find_queries) == 0:
|
||||||
start_with_zero_check = True
|
start_with_zero_check = True
|
||||||
|
|
||||||
assert (
|
assert start_with_zero_check, (
|
||||||
start_with_zero_check
|
"Invalid Find queries, they need to start at query_processing_order = 0"
|
||||||
), "Invalid Find queries, they need to start at query_processing_order = 0"
|
)
|
||||||
|
|
||||||
datapoint.find_queries = new_find_queries
|
datapoint.find_queries = new_find_queries
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from PIL import Image as PILImage
|
from PIL import Image as PILImage
|
||||||
from pycocotools import mask as mask_util
|
from pycocotools import mask as mask_util
|
||||||
|
|
||||||
from sam3.train.data.sam3_image_dataset import Datapoint
|
from sam3.train.data.sam3_image_dataset import Datapoint
|
||||||
from torchvision.ops import masks_to_boxes
|
from torchvision.ops import masks_to_boxes
|
||||||
|
|
||||||
@@ -250,9 +249,9 @@ class RandomGeometricInputsAPI:
|
|||||||
def _get_target_object(self, datapoint, query):
|
def _get_target_object(self, datapoint, query):
|
||||||
img = datapoint.images[query.image_id]
|
img = datapoint.images[query.image_id]
|
||||||
targets = query.object_ids_output
|
targets = query.object_ids_output
|
||||||
assert (
|
assert len(targets) == 1, (
|
||||||
len(targets) == 1
|
"Geometric queries only support a single target object."
|
||||||
), "Geometric queries only support a single target object."
|
)
|
||||||
target_idx = targets[0]
|
target_idx = targets[0]
|
||||||
return img.objects[target_idx]
|
return img.objects[target_idx]
|
||||||
|
|
||||||
|
|||||||
@@ -5,12 +5,9 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pycocotools.mask as mask_utils
|
import pycocotools.mask as mask_utils
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import torchvision.transforms.functional as F
|
import torchvision.transforms.functional as F
|
||||||
from PIL import Image as PILImage
|
from PIL import Image as PILImage
|
||||||
|
|
||||||
from sam3.model.box_ops import masks_to_boxes
|
from sam3.model.box_ops import masks_to_boxes
|
||||||
|
|
||||||
from sam3.train.data.sam3_image_dataset import Datapoint
|
from sam3.train.data.sam3_image_dataset import Datapoint
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -36,9 +36,9 @@ def unix_pattern_to_parameter_names(
|
|||||||
parameter_names = []
|
parameter_names = []
|
||||||
for param_name in constraints:
|
for param_name in constraints:
|
||||||
matching_parameters = set(fnmatch.filter(all_parameter_names, param_name))
|
matching_parameters = set(fnmatch.filter(all_parameter_names, param_name))
|
||||||
assert (
|
assert len(matching_parameters) > 0, (
|
||||||
len(matching_parameters) > 0
|
f"param_names {param_name} don't match any param in the given names."
|
||||||
), f"param_names {param_name} don't match any param in the given names."
|
)
|
||||||
parameter_names.append(matching_parameters)
|
parameter_names.append(matching_parameters)
|
||||||
return set.union(*parameter_names)
|
return set.union(*parameter_names)
|
||||||
|
|
||||||
|
|||||||
@@ -10,10 +10,8 @@ import uuid
|
|||||||
from typing import Any, Dict, Optional, Union
|
from typing import Any, Dict, Optional, Union
|
||||||
|
|
||||||
from hydra.utils import instantiate
|
from hydra.utils import instantiate
|
||||||
|
|
||||||
from iopath.common.file_io import g_pathmgr
|
from iopath.common.file_io import g_pathmgr
|
||||||
from numpy import ndarray
|
from numpy import ndarray
|
||||||
|
|
||||||
from sam3.train.utils.train_utils import get_machine_local_and_dist_rank, makedir
|
from sam3.train.utils.train_utils import get_machine_local_and_dist_rank, makedir
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from datetime import timedelta
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import hydra
|
import hydra
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import omegaconf
|
import omegaconf
|
||||||
import torch
|
import torch
|
||||||
@@ -83,9 +82,9 @@ def get_machine_local_and_dist_rank():
|
|||||||
"""
|
"""
|
||||||
local_rank = int(os.environ.get("LOCAL_RANK", None))
|
local_rank = int(os.environ.get("LOCAL_RANK", None))
|
||||||
distributed_rank = int(os.environ.get("RANK", None))
|
distributed_rank = int(os.environ.get("RANK", None))
|
||||||
assert (
|
assert local_rank is not None and distributed_rank is not None, (
|
||||||
local_rank is not None and distributed_rank is not None
|
"Please the set the RANK and LOCAL_RANK environment variables."
|
||||||
), "Please the set the RANK and LOCAL_RANK environment variables."
|
)
|
||||||
return local_rank, distributed_rank
|
return local_rank, distributed_rank
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -158,22 +158,22 @@ def plot_mask(mask, color="r", ax=None):
|
|||||||
def normalize_bbox(bbox_xywh, img_w, img_h):
|
def normalize_bbox(bbox_xywh, img_w, img_h):
|
||||||
# Assumes bbox_xywh is in XYWH format
|
# Assumes bbox_xywh is in XYWH format
|
||||||
if isinstance(bbox_xywh, list):
|
if isinstance(bbox_xywh, list):
|
||||||
assert (
|
assert len(bbox_xywh) == 4, (
|
||||||
len(bbox_xywh) == 4
|
"bbox_xywh list must have 4 elements. Batching not support except for torch tensors."
|
||||||
), "bbox_xywh list must have 4 elements. Batching not support except for torch tensors."
|
)
|
||||||
normalized_bbox = bbox_xywh.copy()
|
normalized_bbox = bbox_xywh.copy()
|
||||||
normalized_bbox[0] /= img_w
|
normalized_bbox[0] /= img_w
|
||||||
normalized_bbox[1] /= img_h
|
normalized_bbox[1] /= img_h
|
||||||
normalized_bbox[2] /= img_w
|
normalized_bbox[2] /= img_w
|
||||||
normalized_bbox[3] /= img_h
|
normalized_bbox[3] /= img_h
|
||||||
else:
|
else:
|
||||||
assert isinstance(
|
assert isinstance(bbox_xywh, torch.Tensor), (
|
||||||
bbox_xywh, torch.Tensor
|
"Only torch tensors are supported for batching."
|
||||||
), "Only torch tensors are supported for batching."
|
)
|
||||||
normalized_bbox = bbox_xywh.clone()
|
normalized_bbox = bbox_xywh.clone()
|
||||||
assert (
|
assert normalized_bbox.size(-1) == 4, (
|
||||||
normalized_bbox.size(-1) == 4
|
"bbox_xywh tensor must have last dimension of size 4."
|
||||||
), "bbox_xywh tensor must have last dimension of size 4."
|
)
|
||||||
normalized_bbox[..., 0] /= img_w
|
normalized_bbox[..., 0] /= img_w
|
||||||
normalized_bbox[..., 1] /= img_h
|
normalized_bbox[..., 1] /= img_h
|
||||||
normalized_bbox[..., 2] /= img_w
|
normalized_bbox[..., 2] /= img_w
|
||||||
@@ -244,10 +244,10 @@ def visualize_formatted_frame_output(
|
|||||||
|
|
||||||
num_outputs = len(outputs_list)
|
num_outputs = len(outputs_list)
|
||||||
if titles is None:
|
if titles is None:
|
||||||
titles = [f"Set {i+1}" for i in range(num_outputs)]
|
titles = [f"Set {i + 1}" for i in range(num_outputs)]
|
||||||
assert (
|
assert len(titles) == num_outputs, (
|
||||||
len(titles) == num_outputs
|
"length of `titles` should match that of `outputs_list` if not None."
|
||||||
), "length of `titles` should match that of `outputs_list` if not None."
|
)
|
||||||
|
|
||||||
_, axes = plt.subplots(1, num_outputs, figsize=figsize)
|
_, axes = plt.subplots(1, num_outputs, figsize=figsize)
|
||||||
if num_outputs == 1:
|
if num_outputs == 1:
|
||||||
@@ -703,9 +703,9 @@ def get_all_annotations_for_frame(
|
|||||||
|
|
||||||
# Get the frame
|
# Get the frame
|
||||||
video_df_current = video_df[video_df.id == video_id]
|
video_df_current = video_df[video_df.id == video_id]
|
||||||
assert (
|
assert len(video_df_current) == 1, (
|
||||||
len(video_df_current) == 1
|
f"Expected 1 video row, got {len(video_df_current)}"
|
||||||
), f"Expected 1 video row, got {len(video_df_current)}"
|
)
|
||||||
video_row = video_df_current.iloc[0]
|
video_row = video_df_current.iloc[0]
|
||||||
file_name = video_row.file_names[frame_idx]
|
file_name = video_row.file_names[frame_idx]
|
||||||
file_path = os.path.join(
|
file_path = os.path.join(
|
||||||
@@ -796,7 +796,7 @@ def visualize_prompt_overlay(
|
|||||||
ax.text(
|
ax.text(
|
||||||
x_img + 5,
|
x_img + 5,
|
||||||
y_img - 5,
|
y_img - 5,
|
||||||
f"P{i+1}",
|
f"P{i + 1}",
|
||||||
color=color,
|
color=color,
|
||||||
fontsize=10,
|
fontsize=10,
|
||||||
weight="bold",
|
weight="bold",
|
||||||
@@ -828,7 +828,7 @@ def visualize_prompt_overlay(
|
|||||||
ax.text(
|
ax.text(
|
||||||
x_img,
|
x_img,
|
||||||
y_img - 5,
|
y_img - 5,
|
||||||
f"B{i+1}",
|
f"B{i + 1}",
|
||||||
color=color,
|
color=color,
|
||||||
fontsize=10,
|
fontsize=10,
|
||||||
weight="bold",
|
weight="bold",
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from concurrent.futures import as_completed, ThreadPoolExecutor
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import yt_dlp
|
import yt_dlp
|
||||||
|
|
||||||
from utils import (
|
from utils import (
|
||||||
annotation_files,
|
annotation_files,
|
||||||
config,
|
config,
|
||||||
@@ -244,9 +243,9 @@ def download_sav():
|
|||||||
def main():
|
def main():
|
||||||
assert len(sys.argv) > 1, "You have to provide the name of the dataset"
|
assert len(sys.argv) > 1, "You have to provide the name of the dataset"
|
||||||
dataset_name = sys.argv[1]
|
dataset_name = sys.argv[1]
|
||||||
assert (
|
assert dataset_name in annotation_files, (
|
||||||
dataset_name in annotation_files
|
f"The dataset can be one of {list(annotation_files.keys())}"
|
||||||
), f"The dataset can be one of {list(annotation_files.keys())}"
|
)
|
||||||
|
|
||||||
if dataset_name == "yt1b":
|
if dataset_name == "yt1b":
|
||||||
download_youtube()
|
download_youtube()
|
||||||
|
|||||||
@@ -68,9 +68,9 @@ def process_image(args):
|
|||||||
def main():
|
def main():
|
||||||
assert len(sys.argv) > 1, "You have to provide the name of the dataset"
|
assert len(sys.argv) > 1, "You have to provide the name of the dataset"
|
||||||
dataset_name = sys.argv[1]
|
dataset_name = sys.argv[1]
|
||||||
assert (
|
assert dataset_name in annotation_files, (
|
||||||
dataset_name in annotation_files
|
f"The dataset can be one of {list(annotation_files.keys())}"
|
||||||
), f"The dataset can be one of {list(annotation_files.keys())}"
|
)
|
||||||
all_outputs = []
|
all_outputs = []
|
||||||
for file in annotation_files[dataset_name]:
|
for file in annotation_files[dataset_name]:
|
||||||
with open(os.path.join(config["path_annotations"], file), "r") as f:
|
with open(os.path.join(config["path_annotations"], file), "r") as f:
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ def main(args, n_workers=20):
|
|||||||
paths = [
|
paths = [
|
||||||
(
|
(
|
||||||
raw_folder_food_images
|
raw_folder_food_images
|
||||||
/ f'{Path(each).stem.split("_")[-1]}{Path(each).suffix}',
|
/ f"{Path(each).stem.split('_')[-1]}{Path(each).suffix}",
|
||||||
processed_folder / each,
|
processed_folder / each,
|
||||||
)
|
)
|
||||||
for each in img_filenames
|
for each in img_filenames
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
# pyre-unsafe
|
# pyre-unsafe
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
import os
|
import os
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|||||||
@@ -58,9 +58,9 @@ class YtVideoPrep:
|
|||||||
df = self.yt1b_start_end_time_df[
|
df = self.yt1b_start_end_time_df[
|
||||||
self.yt1b_start_end_time_df.saco_yt1b_id == self.saco_yt1b_id
|
self.yt1b_start_end_time_df.saco_yt1b_id == self.saco_yt1b_id
|
||||||
]
|
]
|
||||||
assert (
|
assert len(df) == 1, (
|
||||||
len(df) == 1
|
f"Expected exactly 1 row for saco_yt1b_id: {self.saco_yt1b_id}, found {len(df)}"
|
||||||
), f"Expected exactly 1 row for saco_yt1b_id: {self.saco_yt1b_id}, found {len(df)}"
|
)
|
||||||
id_and_frame_map_row = df.iloc[0]
|
id_and_frame_map_row = df.iloc[0]
|
||||||
|
|
||||||
yt_video_id = (
|
yt_video_id = (
|
||||||
@@ -82,9 +82,9 @@ class YtVideoPrep:
|
|||||||
def download_youtube_video(self):
|
def download_youtube_video(self):
|
||||||
video_url = f"https://youtube.com/watch?v={self.yt_video_id}"
|
video_url = f"https://youtube.com/watch?v={self.yt_video_id}"
|
||||||
|
|
||||||
assert os.path.exists(
|
assert os.path.exists(self.cookies_file), (
|
||||||
self.cookies_file
|
f"Cookies file '{self.cookies_file}' not found. Must have it to download videos."
|
||||||
), f"Cookies file '{self.cookies_file}' not found. Must have it to download videos."
|
)
|
||||||
|
|
||||||
outtmpl = self.raw_video_path
|
outtmpl = self.raw_video_path
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user