apply Black 25.11.0 style in fbcode/deeplearning/projects (21/92)
Summary: Formats the covered files with pyfmt. paintitblack Reviewed By: itamaro Differential Revision: D90476315 fbshipit-source-id: ee94c471788b8e7d067813d8b3e0311214d17f3f
This commit is contained in:
committed by
meta-codesync[bot]
parent
7b89b8fc3f
commit
11dec2936d
@@ -296,9 +296,9 @@ def agent_inference(
|
||||
assert LATEST_SAM3_TEXT_PROMPT != ""
|
||||
|
||||
# Make sure that the last message is a image
|
||||
assert (
|
||||
messages[-1]["content"][1]["type"] == "image"
|
||||
), "Second content element should be an image"
|
||||
assert messages[-1]["content"][1]["type"] == "image", (
|
||||
"Second content element should be an image"
|
||||
)
|
||||
messages.pop() # Remove the last user message
|
||||
# Add simplified replacement message
|
||||
simplified_message = {
|
||||
@@ -318,7 +318,7 @@ def agent_inference(
|
||||
|
||||
# MLLM check the mask one by one
|
||||
for i in range(num_masks):
|
||||
print(f"🔍 Checking mask {i+1}/{num_masks}...")
|
||||
print(f"🔍 Checking mask {i + 1}/{num_masks}...")
|
||||
image_w_mask_i, image_w_zoomed_in_mask_i = visualize(current_outputs, i)
|
||||
|
||||
image_w_zoomed_in_mask_i_path = os.path.join(
|
||||
@@ -363,7 +363,7 @@ def agent_inference(
|
||||
raise ValueError(
|
||||
"Generated text is None, which is unexpected. Please check the Qwen server and the input parameters."
|
||||
)
|
||||
print(f"Generated text for mask {i+1}: {checking_generated_text}")
|
||||
print(f"Generated text for mask {i + 1}: {checking_generated_text}")
|
||||
verdict = (
|
||||
checking_generated_text.split("<verdict>")[-1]
|
||||
.split("</verdict>")[0]
|
||||
@@ -371,11 +371,11 @@ def agent_inference(
|
||||
)
|
||||
if "Accept" in verdict:
|
||||
assert not "Reject" in verdict
|
||||
print(f"Mask {i+1} accepted, keeping it in the outputs.")
|
||||
print(f"Mask {i + 1} accepted, keeping it in the outputs.")
|
||||
masks_to_keep.append(i)
|
||||
elif "Reject" in verdict:
|
||||
assert not "Accept" in verdict
|
||||
print(f"Mask {i+1} rejected, removing it from the outputs.")
|
||||
print(f"Mask {i + 1} rejected, removing it from the outputs.")
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected verdict in generated text: {checking_generated_text}. Expected 'Accept' or 'Reject'."
|
||||
@@ -397,7 +397,7 @@ def agent_inference(
|
||||
sam_output_dir, rf"{LATEST_SAM3_TEXT_PROMPT}.png"
|
||||
).replace(
|
||||
".png",
|
||||
f"_selected_masks_{'-'.join(map(str, [i+1 for i in masks_to_keep]))}.png".replace(
|
||||
f"_selected_masks_{'-'.join(map(str, [i + 1 for i in masks_to_keep]))}.png".replace(
|
||||
"/", "_"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -7,7 +7,6 @@ import os
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from sam3.model.box_ops import box_xyxy_to_xywh
|
||||
from sam3.train.masks_ops import rle_encode
|
||||
|
||||
|
||||
@@ -84,9 +84,9 @@ class BoxMode(IntEnum):
|
||||
], "Relative mode not yet supported!"
|
||||
|
||||
if from_mode == BoxMode.XYWHA_ABS and to_mode == BoxMode.XYXY_ABS:
|
||||
assert (
|
||||
arr.shape[-1] == 5
|
||||
), "The last dimension of input shape must be 5 for XYWHA format"
|
||||
assert arr.shape[-1] == 5, (
|
||||
"The last dimension of input shape must be 5 for XYWHA format"
|
||||
)
|
||||
original_dtype = arr.dtype
|
||||
arr = arr.double()
|
||||
|
||||
@@ -244,9 +244,9 @@ class Boxes:
|
||||
if isinstance(item, int):
|
||||
return Boxes(self.tensor[item].view(1, -1))
|
||||
b = self.tensor[item]
|
||||
assert (
|
||||
b.dim() == 2
|
||||
), "Indexing on Boxes with {} failed to return a matrix!".format(item)
|
||||
assert b.dim() == 2, (
|
||||
"Indexing on Boxes with {} failed to return a matrix!".format(item)
|
||||
)
|
||||
return Boxes(b)
|
||||
|
||||
def __len__(self) -> int:
|
||||
@@ -425,7 +425,7 @@ def matched_pairwise_iou(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor:
|
||||
Tensor: iou, sized [N].
|
||||
"""
|
||||
assert len(boxes1) == len(boxes2), (
|
||||
"boxlists should have the same" "number of entries, got {}, {}".format(
|
||||
"boxlists should have the samenumber of entries, got {}, {}".format(
|
||||
len(boxes1), len(boxes2)
|
||||
)
|
||||
)
|
||||
|
||||
@@ -13,7 +13,6 @@ from torch import device
|
||||
|
||||
from .boxes import Boxes
|
||||
from .memory import retry_if_cuda_oom
|
||||
|
||||
from .roi_align import ROIAlign
|
||||
|
||||
|
||||
@@ -142,10 +141,10 @@ class BitMasks:
|
||||
if isinstance(item, int):
|
||||
return BitMasks(self.tensor[item].unsqueeze(0))
|
||||
m = self.tensor[item]
|
||||
assert (
|
||||
m.dim() == 3
|
||||
), "Indexing on BitMasks with {} returns a tensor with shape {}!".format(
|
||||
item, m.shape
|
||||
assert m.dim() == 3, (
|
||||
"Indexing on BitMasks with {} returns a tensor with shape {}!".format(
|
||||
item, m.shape
|
||||
)
|
||||
)
|
||||
return BitMasks(m)
|
||||
|
||||
|
||||
@@ -363,9 +363,9 @@ class RotatedBoxes(Boxes):
|
||||
if isinstance(item, int):
|
||||
return RotatedBoxes(self.tensor[item].view(1, -1))
|
||||
b = self.tensor[item]
|
||||
assert (
|
||||
b.dim() == 2
|
||||
), "Indexing on RotatedBoxes with {} failed to return a matrix!".format(item)
|
||||
assert b.dim() == 2, (
|
||||
"Indexing on RotatedBoxes with {} failed to return a matrix!".format(item)
|
||||
)
|
||||
return RotatedBoxes(b)
|
||||
|
||||
def __len__(self) -> int:
|
||||
|
||||
@@ -20,7 +20,6 @@ from matplotlib.backends.backend_agg import FigureCanvasAgg
|
||||
from PIL import Image
|
||||
|
||||
from .boxes import Boxes, BoxMode
|
||||
|
||||
from .color_map import random_color
|
||||
from .keypoints import Keypoints
|
||||
from .masks import BitMasks, PolygonMasks
|
||||
@@ -222,9 +221,9 @@ class _PanopticPrediction:
|
||||
empty_ids.append(id)
|
||||
if len(empty_ids) == 0:
|
||||
return np.zeros(self._seg.shape, dtype=np.uint8)
|
||||
assert (
|
||||
len(empty_ids) == 1
|
||||
), ">1 ids corresponds to no labels. This is currently not supported"
|
||||
assert len(empty_ids) == 1, (
|
||||
">1 ids corresponds to no labels. This is currently not supported"
|
||||
)
|
||||
return (self._seg != empty_ids[0]).numpy().astype(np.bool)
|
||||
|
||||
def semantic_masks(self):
|
||||
|
||||
@@ -41,7 +41,7 @@ def run_single_image_inference(
|
||||
print(f"Output JSON {output_json_path} already exists. Skipping.")
|
||||
return
|
||||
|
||||
print(f"{'-'*30} Starting SAM 3 Agent Session... {'-'*30} ")
|
||||
print(f"{'-' * 30} Starting SAM 3 Agent Session... {'-' * 30} ")
|
||||
agent_history, final_output_dict, rendered_final_output = agent_inference(
|
||||
image_path,
|
||||
text_prompt,
|
||||
@@ -50,7 +50,7 @@ def run_single_image_inference(
|
||||
output_dir=output_dir,
|
||||
debug=debug,
|
||||
)
|
||||
print(f"{'-'*30} End of SAM 3 Agent Session... {'-'*30} ")
|
||||
print(f"{'-' * 30} End of SAM 3 Agent Session... {'-' * 30} ")
|
||||
|
||||
final_output_dict["text_prompt"] = text_prompt
|
||||
final_output_dict["image_path"] = image_path
|
||||
|
||||
@@ -73,7 +73,9 @@ def visualize(
|
||||
idx = int(zoom_in_index)
|
||||
num_masks = len(input_json.get("pred_masks", []))
|
||||
if idx < 0 or idx >= num_masks:
|
||||
raise ValueError(f"zoom_in_index {idx} is out of range (0..{num_masks-1}).")
|
||||
raise ValueError(
|
||||
f"zoom_in_index {idx} is out of range (0..{num_masks - 1})."
|
||||
)
|
||||
|
||||
# (1) Replicate zoom_in_and_visualize
|
||||
object_data = {
|
||||
|
||||
@@ -126,9 +126,9 @@ class COCOCustom(COCO):
|
||||
# MODIFICATION: faster and cached subset check
|
||||
if not hasattr(self, "img_id_set"):
|
||||
self.img_id_set = set(self.getImgIds())
|
||||
assert set(annsImgIds).issubset(
|
||||
self.img_id_set
|
||||
), "Results do not correspond to current coco set"
|
||||
assert set(annsImgIds).issubset(self.img_id_set), (
|
||||
"Results do not correspond to current coco set"
|
||||
)
|
||||
# END MODIFICATION
|
||||
if "caption" in anns[0]:
|
||||
imgIds = set([img["id"] for img in res.dataset["images"]]) & set(
|
||||
@@ -301,9 +301,9 @@ class CGF1Eval(COCOeval):
|
||||
TP = (match_scores >= thresh).sum()
|
||||
FP = len(dt) - TP
|
||||
FN = len(gt) - TP
|
||||
assert (
|
||||
FP >= 0 and FN >= 0
|
||||
), f"FP: {FP}, FN: {FN}, TP: {TP}, match_scores: {match_scores}, len(dt): {len(dt)}, len(gt): {len(gt)}, ious: {ious}"
|
||||
assert FP >= 0 and FN >= 0, (
|
||||
f"FP: {FP}, FN: {FN}, TP: {TP}, match_scores: {match_scores}, len(dt): {len(dt)}, len(gt): {len(gt)}, ious: {ious}"
|
||||
)
|
||||
TPs.append(TP)
|
||||
FPs.append(FP)
|
||||
FNs.append(FN)
|
||||
@@ -599,9 +599,9 @@ class CGF1Evaluator:
|
||||
|
||||
"""
|
||||
assert len(self.coco_gts) > 0, "No ground truth provided for evaluation."
|
||||
assert len(self.coco_gts) == len(
|
||||
self.coco_evals
|
||||
), "Mismatch in number of ground truths and evaluators."
|
||||
assert len(self.coco_gts) == len(self.coco_evals), (
|
||||
"Mismatch in number of ground truths and evaluators."
|
||||
)
|
||||
|
||||
if self.verbose:
|
||||
print(f"Loading predictions from {pred_file}")
|
||||
@@ -668,17 +668,17 @@ class CGF1Evaluator:
|
||||
if len(scorings) == 1:
|
||||
return scorings[0]
|
||||
|
||||
assert (
|
||||
scorings[0].ndim == 3
|
||||
), f"Expecting results in [numCats, numAreas, numImgs] format, got {scorings[0].shape}"
|
||||
assert (
|
||||
scorings[0].shape[0] == 1
|
||||
), f"Expecting a single category, got {scorings[0].shape[0]}"
|
||||
assert scorings[0].ndim == 3, (
|
||||
f"Expecting results in [numCats, numAreas, numImgs] format, got {scorings[0].shape}"
|
||||
)
|
||||
assert scorings[0].shape[0] == 1, (
|
||||
f"Expecting a single category, got {scorings[0].shape[0]}"
|
||||
)
|
||||
|
||||
for scoring in scorings:
|
||||
assert (
|
||||
scoring.shape == scorings[0].shape
|
||||
), f"Shape mismatch: {scoring.shape}, {scorings[0].shape}"
|
||||
assert scoring.shape == scorings[0].shape, (
|
||||
f"Shape mismatch: {scoring.shape}, {scorings[0].shape}"
|
||||
)
|
||||
|
||||
selected_imgs = []
|
||||
for img_id in range(scorings[0].shape[-1]):
|
||||
|
||||
@@ -18,19 +18,15 @@ import os
|
||||
import pickle
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
import pycocotools.mask as mask_utils
|
||||
import torch
|
||||
from iopath.common.file_io import g_pathmgr
|
||||
from pycocotools.coco import COCO
|
||||
from pycocotools.cocoeval import COCOeval
|
||||
|
||||
from sam3.train.masks_ops import rle_encode
|
||||
|
||||
from sam3.train.utils.distributed import (
|
||||
all_gather,
|
||||
gather_to_rank_0_via_filesys,
|
||||
@@ -755,9 +751,9 @@ def loadRes(self, resFile):
|
||||
anns = resFile
|
||||
assert type(anns) == list, "results in not an array of objects"
|
||||
annsImgIds = [ann["image_id"] for ann in anns]
|
||||
assert set(annsImgIds) == (
|
||||
set(annsImgIds) & set(self.getImgIds())
|
||||
), "Results do not correspond to current coco set"
|
||||
assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), (
|
||||
"Results do not correspond to current coco set"
|
||||
)
|
||||
if "caption" in anns[0]:
|
||||
imgIds = set([img["id"] for img in res.dataset["images"]]) & set(
|
||||
[ann["image_id"] for ann in anns]
|
||||
|
||||
@@ -83,9 +83,9 @@ class PredictionDumper:
|
||||
self.merge_predictions = merge_predictions
|
||||
self.pred_file_evaluators = pred_file_evaluators
|
||||
if self.pred_file_evaluators is not None:
|
||||
assert (
|
||||
merge_predictions
|
||||
), "merge_predictions must be True if pred_file_evaluators are provided"
|
||||
assert merge_predictions, (
|
||||
"merge_predictions must be True if pred_file_evaluators are provided"
|
||||
)
|
||||
assert self.dump_dir is not None, "dump_dir must be provided"
|
||||
|
||||
if is_main_process():
|
||||
|
||||
@@ -13,11 +13,9 @@ from typing import Optional
|
||||
import numpy as np
|
||||
import pycocotools.mask as maskUtils
|
||||
from pycocotools.cocoeval import COCOeval
|
||||
|
||||
from sam3.eval.coco_eval import CocoEvaluator
|
||||
from sam3.train.masks_ops import compute_F_measure
|
||||
from sam3.train.utils.distributed import is_main_process
|
||||
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
|
||||
|
||||
@@ -156,9 +154,9 @@ class DemoEval(COCOeval):
|
||||
TP = (match_scores >= thresh).sum()
|
||||
FP = len(dt) - TP
|
||||
FN = len(gt) - TP
|
||||
assert (
|
||||
FP >= 0 and FN >= 0
|
||||
), f"FP: {FP}, FN: {FN}, TP: {TP}, match_scores: {match_scores}, len(dt): {len(dt)}, len(gt): {len(gt)}, ious: {ious}"
|
||||
assert FP >= 0 and FN >= 0, (
|
||||
f"FP: {FP}, FN: {FN}, TP: {TP}, match_scores: {match_scores}, len(dt): {len(dt)}, len(gt): {len(gt)}, ious: {ious}"
|
||||
)
|
||||
TPs.append(TP)
|
||||
FPs.append(FP)
|
||||
FNs.append(FN)
|
||||
@@ -528,17 +526,17 @@ class DemoEvaluator(CocoEvaluator):
|
||||
if len(scorings) == 1:
|
||||
return scorings[0]
|
||||
|
||||
assert (
|
||||
scorings[0].ndim == 3
|
||||
), f"Expecting results in [numCats, numAreas, numImgs] format, got {scorings[0].shape}"
|
||||
assert (
|
||||
scorings[0].shape[0] == 1
|
||||
), f"Expecting a single category, got {scorings[0].shape[0]}"
|
||||
assert scorings[0].ndim == 3, (
|
||||
f"Expecting results in [numCats, numAreas, numImgs] format, got {scorings[0].shape}"
|
||||
)
|
||||
assert scorings[0].shape[0] == 1, (
|
||||
f"Expecting a single category, got {scorings[0].shape[0]}"
|
||||
)
|
||||
|
||||
for scoring in scorings:
|
||||
assert (
|
||||
scoring.shape == scorings[0].shape
|
||||
), f"Shape mismatch: {scoring.shape}, {scorings[0].shape}"
|
||||
assert scoring.shape == scorings[0].shape, (
|
||||
f"Shape mismatch: {scoring.shape}, {scorings[0].shape}"
|
||||
)
|
||||
|
||||
selected_imgs = []
|
||||
for img_id in range(scorings[0].shape[-1]):
|
||||
|
||||
@@ -255,9 +255,10 @@ class Evaluator:
|
||||
if show_progressbar and TQDM_IMPORTED:
|
||||
seq_list_sorted = sorted(seq_list)
|
||||
|
||||
with Pool(config["NUM_PARALLEL_CORES"]) as pool, tqdm.tqdm(
|
||||
total=len(seq_list)
|
||||
) as pbar:
|
||||
with (
|
||||
Pool(config["NUM_PARALLEL_CORES"]) as pool,
|
||||
tqdm.tqdm(total=len(seq_list)) as pbar,
|
||||
):
|
||||
_eval_sequence = partial(
|
||||
eval_sequence,
|
||||
dataset=dataset,
|
||||
|
||||
@@ -83,9 +83,9 @@ class PostProcessImage(nn.Module):
|
||||
ret_tensordict: Experimental argument. If true, return a tensordict.TensorDict instead of a list of dictionaries for easier manipulation.
|
||||
"""
|
||||
if ret_tensordict:
|
||||
assert (
|
||||
consistent is True
|
||||
), "We don't support returning TensorDict if the outputs have different shapes" # NOTE: It's possible but we don't support it.
|
||||
assert consistent is True, (
|
||||
"We don't support returning TensorDict if the outputs have different shapes"
|
||||
) # NOTE: It's possible but we don't support it.
|
||||
assert self.detection_threshold <= 0.0, "TODO: implement?"
|
||||
try:
|
||||
from tensordict import TensorDict
|
||||
@@ -118,7 +118,9 @@ class PostProcessImage(nn.Module):
|
||||
|
||||
if boxes is None:
|
||||
assert out_masks is not None
|
||||
assert not ret_tensordict, "We don't support returning TensorDict if the output does not contain boxes"
|
||||
assert not ret_tensordict, (
|
||||
"We don't support returning TensorDict if the output does not contain boxes"
|
||||
)
|
||||
B = len(out_masks)
|
||||
boxes = [None] * B
|
||||
scores = [None] * B
|
||||
@@ -418,9 +420,9 @@ class PostProcessAPIVideo(PostProcessImage):
|
||||
if video_id == -1:
|
||||
video_id = unique_vid_id.item()
|
||||
else:
|
||||
assert (
|
||||
video_id == unique_vid_id.item()
|
||||
), "We can only postprocess one video per datapoint"
|
||||
assert video_id == unique_vid_id.item(), (
|
||||
"We can only postprocess one video per datapoint"
|
||||
)
|
||||
# keeping track of which objects appear in the current frame
|
||||
obj_ids_per_frame = frame_outs["pred_object_ids"]
|
||||
assert obj_ids_per_frame.size(-1) == frame_outs["pred_logits"].size(-2)
|
||||
|
||||
@@ -95,9 +95,9 @@ class YTVIS(COCO):
|
||||
anns = resFile
|
||||
assert type(anns) == list, "results is not an array of objects"
|
||||
annsImgIds = [ann["image_id"] for ann in anns]
|
||||
assert set(annsImgIds) == (
|
||||
set(annsImgIds) & set(self.getImgIds())
|
||||
), "Results do not correspond to current coco set"
|
||||
assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), (
|
||||
"Results do not correspond to current coco set"
|
||||
)
|
||||
if "bboxes" in anns[0] and not anns[0]["bboxes"] == []:
|
||||
res.dataset["categories"] = copy.deepcopy(self.dataset["categories"])
|
||||
for id, ann in enumerate(anns):
|
||||
|
||||
@@ -109,9 +109,7 @@ class YTVISevalMixin:
|
||||
) # Num preds x Num GTS x Num frames
|
||||
inter = inter.sum(-1)
|
||||
union = union.sum(-1)
|
||||
assert (
|
||||
union > 0
|
||||
).all(), (
|
||||
assert (union > 0).all(), (
|
||||
"There exists a tracklet with zero GTs across time. This is suspicious"
|
||||
)
|
||||
return inter / union
|
||||
@@ -136,9 +134,9 @@ class YTVISevalMixin:
|
||||
iou = inter / union
|
||||
assert iou >= 0 and iou <= 1, "Encountered an error in IoU computation"
|
||||
else:
|
||||
assert np.isclose(inter, 0) and np.isclose(
|
||||
union, 0
|
||||
), "Encountered an error in IoU computation"
|
||||
assert np.isclose(inter, 0) and np.isclose(union, 0), (
|
||||
"Encountered an error in IoU computation"
|
||||
)
|
||||
iou = 1
|
||||
return iou
|
||||
|
||||
@@ -206,16 +204,16 @@ class YTVISResultsWriter:
|
||||
if len(prediction) == 0:
|
||||
continue
|
||||
for k in ["boxes", "scores", "labels"]:
|
||||
assert (
|
||||
k in prediction
|
||||
), f"Expected predictions to have `{k}` key, available keys are {prediction.keys()}"
|
||||
assert k in prediction, (
|
||||
f"Expected predictions to have `{k}` key, available keys are {prediction.keys()}"
|
||||
)
|
||||
if self.save_per_frame_scores:
|
||||
assert (
|
||||
"per_frame_scores" in prediction
|
||||
), f"Expected predictions to have `per_frame_scores` key, available keys are {prediction.keys()}"
|
||||
assert xor(
|
||||
"masks" in prediction, "masks_rle" in prediction
|
||||
), f"Expected predictions to have either `masks` key or `masks_rle` key, available keys are {prediction.keys()}"
|
||||
assert "per_frame_scores" in prediction, (
|
||||
f"Expected predictions to have `per_frame_scores` key, available keys are {prediction.keys()}"
|
||||
)
|
||||
assert xor("masks" in prediction, "masks_rle" in prediction), (
|
||||
f"Expected predictions to have either `masks` key or `masks_rle` key, available keys are {prediction.keys()}"
|
||||
)
|
||||
|
||||
boxes = prediction["boxes"]
|
||||
boxes = convert_to_xywh(boxes).tolist()
|
||||
@@ -223,9 +221,9 @@ class YTVISResultsWriter:
|
||||
labels = prediction["labels"].tolist()
|
||||
if "masks" in prediction:
|
||||
masks = prediction["masks"].squeeze(2)
|
||||
assert (
|
||||
masks.ndim == 4
|
||||
), "Expected masks to be of shape(N_preds,T_frames,H,W)"
|
||||
assert masks.ndim == 4, (
|
||||
"Expected masks to be of shape(N_preds,T_frames,H,W)"
|
||||
)
|
||||
|
||||
areas = [mask.flatten(1).sum(1).tolist() for mask in masks]
|
||||
rles = [rle_encode(masklet) for masklet in masks]
|
||||
|
||||
@@ -42,9 +42,9 @@ def get_logger(name, level=logging.INFO):
|
||||
"""A command line logger."""
|
||||
if "LOG_LEVEL" in os.environ:
|
||||
level = os.environ["LOG_LEVEL"].upper()
|
||||
assert (
|
||||
level in LOG_LEVELS
|
||||
), f"Invalid LOG_LEVEL: {level}, must be one of {list(LOG_LEVELS.keys())}"
|
||||
assert level in LOG_LEVELS, (
|
||||
f"Invalid LOG_LEVEL: {level}, must be one of {list(LOG_LEVELS.keys())}"
|
||||
)
|
||||
level = LOG_LEVELS[level]
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(level)
|
||||
|
||||
@@ -7,7 +7,6 @@ Misc functions, including distributed helpers.
|
||||
|
||||
import collections
|
||||
import re
|
||||
|
||||
from dataclasses import dataclass, field as field_ptr_behaviour, fields, is_dataclass
|
||||
from typing import Any, get_args, get_origin, List, Mapping, Optional, Sequence, Union
|
||||
|
||||
@@ -29,9 +28,9 @@ def interpolate(
|
||||
input, size, scale_factor, mode, align_corners
|
||||
)
|
||||
|
||||
assert (
|
||||
input.shape[0] != 0 or input.shape[1] != 0
|
||||
), "At least one of the two first dimensions must be non zero"
|
||||
assert input.shape[0] != 0 or input.shape[1] != 0, (
|
||||
"At least one of the two first dimensions must be non zero"
|
||||
)
|
||||
|
||||
if input.shape[1] == 0:
|
||||
# Pytorch doesn't support null dimension on the channel dimension, so we transpose to fake a null batch dim
|
||||
|
||||
@@ -9,18 +9,13 @@ Inspired from Pytorch's version, adds the pre-norm variant
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
|
||||
from sam3.sam.transformer import RoPEAttention
|
||||
|
||||
from torch import nn, Tensor
|
||||
from torchvision.ops.roi_align import RoIAlign
|
||||
|
||||
from .act_ckpt_utils import activation_ckpt_wrapper
|
||||
|
||||
from .box_ops import box_cxcywh_to_xyxy
|
||||
|
||||
from .model_misc import (
|
||||
gen_sineembed_for_position,
|
||||
get_activation_fn,
|
||||
@@ -444,9 +439,9 @@ class TransformerDecoder(nn.Module):
|
||||
- valid_ratios/spatial_shapes: bs, nlevel, 2
|
||||
"""
|
||||
if memory_mask is not None:
|
||||
assert (
|
||||
self.boxRPB == "none"
|
||||
), "inputting a memory_mask in the presence of boxRPB is unexpected/not implemented"
|
||||
assert self.boxRPB == "none", (
|
||||
"inputting a memory_mask in the presence of boxRPB is unexpected/not implemented"
|
||||
)
|
||||
|
||||
apply_dac = apply_dac if apply_dac is not None else self.dac
|
||||
if apply_dac:
|
||||
@@ -516,18 +511,18 @@ class TransformerDecoder(nn.Module):
|
||||
query_pos = self.ref_point_head(query_sine_embed) # nq, bs, d_model
|
||||
|
||||
if self.boxRPB != "none" and reference_boxes is not None:
|
||||
assert (
|
||||
spatial_shapes.shape[0] == 1
|
||||
), "only single scale support implemented"
|
||||
assert spatial_shapes.shape[0] == 1, (
|
||||
"only single scale support implemented"
|
||||
)
|
||||
memory_mask = self._get_rpb_matrix(
|
||||
reference_boxes,
|
||||
(spatial_shapes[0, 0], spatial_shapes[0, 1]),
|
||||
)
|
||||
memory_mask = memory_mask.flatten(0, 1) # (bs*n_heads, nq, H*W)
|
||||
if self.training:
|
||||
assert (
|
||||
self.use_act_checkpoint
|
||||
), "Activation checkpointing not enabled in the decoder"
|
||||
assert self.use_act_checkpoint, (
|
||||
"Activation checkpointing not enabled in the decoder"
|
||||
)
|
||||
output, presence_out = activation_ckpt_wrapper(layer)(
|
||||
tgt=output,
|
||||
tgt_query_pos=query_pos,
|
||||
@@ -676,9 +671,9 @@ class TransformerEncoderCrossAttention(nn.Module):
|
||||
src_pos[0],
|
||||
)
|
||||
|
||||
assert (
|
||||
src.shape[1] == prompt.shape[1]
|
||||
), "Batch size must be the same for src and prompt"
|
||||
assert src.shape[1] == prompt.shape[1], (
|
||||
"Batch size must be the same for src and prompt"
|
||||
)
|
||||
|
||||
output = src
|
||||
|
||||
|
||||
@@ -322,9 +322,9 @@ class TransformerEncoder(nn.Module):
|
||||
return reference_points
|
||||
|
||||
def _prepare_multilevel_features(self, srcs, masks, pos_embeds):
|
||||
assert (
|
||||
len(srcs) == self.num_feature_levels
|
||||
), "mismatch between expected and received # of feature levels"
|
||||
assert len(srcs) == self.num_feature_levels, (
|
||||
"mismatch between expected and received # of feature levels"
|
||||
)
|
||||
|
||||
src_flatten = []
|
||||
mask_flatten = []
|
||||
@@ -406,9 +406,9 @@ class TransformerEncoder(nn.Module):
|
||||
- spatial_shapes: Spatial dimensions of each feature level
|
||||
- valid_ratios: Valid ratios for each feature level
|
||||
"""
|
||||
assert (
|
||||
len(src) == self.num_feature_levels
|
||||
), "must be equal to num_feature_levels"
|
||||
assert len(src) == self.num_feature_levels, (
|
||||
"must be equal to num_feature_levels"
|
||||
)
|
||||
if src_key_padding_masks is not None:
|
||||
assert len(src_key_padding_masks) == self.num_feature_levels
|
||||
if pos is not None:
|
||||
@@ -538,9 +538,9 @@ class TransformerEncoderFusion(TransformerEncoder):
|
||||
else None
|
||||
)
|
||||
else:
|
||||
assert all(
|
||||
x.dim == 4 for x in src
|
||||
), "expected list of (bs, c, h, w) tensors"
|
||||
assert all(x.dim == 4 for x in src), (
|
||||
"expected list of (bs, c, h, w) tensors"
|
||||
)
|
||||
|
||||
if self.add_pooled_text_to_img_feat:
|
||||
# Fusion: Add mean pooled text to image features
|
||||
|
||||
@@ -11,7 +11,6 @@ from typing_extensions import override
|
||||
|
||||
from .act_ckpt_utils import activation_ckpt_wrapper
|
||||
from .box_ops import box_cxcywh_to_xyxy
|
||||
|
||||
from .model_misc import get_clones
|
||||
|
||||
|
||||
@@ -148,54 +147,42 @@ class Prompt:
|
||||
)
|
||||
|
||||
# Dimension checks
|
||||
assert (
|
||||
box_embeddings is not None
|
||||
and list(box_embeddings.shape[:2])
|
||||
== [
|
||||
box_seq_len,
|
||||
bs,
|
||||
]
|
||||
), f"Wrong dimension for box embeddings. Expected [{box_seq_len}, {bs}, *] got {box_embeddings.shape}"
|
||||
assert (
|
||||
box_mask is not None
|
||||
and list(box_mask.shape)
|
||||
== [
|
||||
bs,
|
||||
box_seq_len,
|
||||
]
|
||||
), f"Wrong dimension for box mask. Expected [{bs}, {box_seq_len}] got {box_mask.shape}"
|
||||
assert (
|
||||
point_embeddings is not None
|
||||
and list(point_embeddings.shape[:2])
|
||||
== [
|
||||
point_seq_len,
|
||||
bs,
|
||||
]
|
||||
), f"Wrong dimension for point embeddings. Expected [{point_seq_len}, {bs}, *] got {point_embeddings.shape}"
|
||||
assert (
|
||||
point_mask is not None
|
||||
and list(point_mask.shape)
|
||||
== [
|
||||
bs,
|
||||
point_seq_len,
|
||||
]
|
||||
), f"Wrong dimension for point mask. Expected [{bs}, {point_seq_len}] got {point_mask.shape}"
|
||||
assert (
|
||||
box_labels is not None
|
||||
and list(box_labels.shape)
|
||||
== [
|
||||
box_seq_len,
|
||||
bs,
|
||||
]
|
||||
), f"Wrong dimension for box labels. Expected [{box_seq_len}, {bs}] got {box_labels.shape}"
|
||||
assert (
|
||||
point_labels is not None
|
||||
and list(point_labels.shape)
|
||||
== [
|
||||
point_seq_len,
|
||||
bs,
|
||||
]
|
||||
), f"Wrong dimension for point labels. Expected [{point_seq_len}, {bs}] got {point_labels.shape}"
|
||||
assert box_embeddings is not None and list(box_embeddings.shape[:2]) == [
|
||||
box_seq_len,
|
||||
bs,
|
||||
], (
|
||||
f"Wrong dimension for box embeddings. Expected [{box_seq_len}, {bs}, *] got {box_embeddings.shape}"
|
||||
)
|
||||
assert box_mask is not None and list(box_mask.shape) == [
|
||||
bs,
|
||||
box_seq_len,
|
||||
], (
|
||||
f"Wrong dimension for box mask. Expected [{bs}, {box_seq_len}] got {box_mask.shape}"
|
||||
)
|
||||
assert point_embeddings is not None and list(point_embeddings.shape[:2]) == [
|
||||
point_seq_len,
|
||||
bs,
|
||||
], (
|
||||
f"Wrong dimension for point embeddings. Expected [{point_seq_len}, {bs}, *] got {point_embeddings.shape}"
|
||||
)
|
||||
assert point_mask is not None and list(point_mask.shape) == [
|
||||
bs,
|
||||
point_seq_len,
|
||||
], (
|
||||
f"Wrong dimension for point mask. Expected [{bs}, {point_seq_len}] got {point_mask.shape}"
|
||||
)
|
||||
assert box_labels is not None and list(box_labels.shape) == [
|
||||
box_seq_len,
|
||||
bs,
|
||||
], (
|
||||
f"Wrong dimension for box labels. Expected [{box_seq_len}, {bs}] got {box_labels.shape}"
|
||||
)
|
||||
assert point_labels is not None and list(point_labels.shape) == [
|
||||
point_seq_len,
|
||||
bs,
|
||||
], (
|
||||
f"Wrong dimension for point labels. Expected [{point_seq_len}, {bs}] got {point_labels.shape}"
|
||||
)
|
||||
assert (
|
||||
# Allowed to be None, we leave it to the encoder to check for validity before encoding.
|
||||
mask_embeddings is None
|
||||
@@ -204,41 +191,41 @@ class Prompt:
|
||||
mask_seq_len,
|
||||
bs,
|
||||
]
|
||||
), f"Wrong dimension for mask embeddings. Expected [{mask_seq_len}, {bs}, *] got {mask_embeddings.shape}"
|
||||
assert (
|
||||
mask_mask is None
|
||||
or list(mask_mask.shape)
|
||||
== [
|
||||
bs,
|
||||
mask_seq_len,
|
||||
]
|
||||
), f"Wrong dimension for mask attn. mask. Expected [{bs}, {mask_seq_len}] got {mask_mask.shape}"
|
||||
), (
|
||||
f"Wrong dimension for mask embeddings. Expected [{mask_seq_len}, {bs}, *] got {mask_embeddings.shape}"
|
||||
)
|
||||
assert mask_mask is None or list(mask_mask.shape) == [
|
||||
bs,
|
||||
mask_seq_len,
|
||||
], (
|
||||
f"Wrong dimension for mask attn. mask. Expected [{bs}, {mask_seq_len}] got {mask_mask.shape}"
|
||||
)
|
||||
|
||||
# Device checks
|
||||
assert (
|
||||
box_embeddings is not None and box_embeddings.device == device
|
||||
), f"Expected box embeddings to be on device {device}, got {box_embeddings.device}"
|
||||
assert (
|
||||
box_mask is not None and box_mask.device == device
|
||||
), f"Expected box mask to be on device {device}, got {box_mask.device}"
|
||||
assert (
|
||||
box_labels is not None and box_labels.device == device
|
||||
), f"Expected box labels to be on device {device}, got {box_labels.device}"
|
||||
assert (
|
||||
point_embeddings is not None and point_embeddings.device == device
|
||||
), f"Expected point embeddings to be on device {device}, got {point_embeddings.device}"
|
||||
assert (
|
||||
point_mask is not None and point_mask.device == device
|
||||
), f"Expected point mask to be on device {device}, got {point_mask.device}"
|
||||
assert (
|
||||
point_labels is not None and point_labels.device == device
|
||||
), f"Expected point labels to be on device {device}, got {point_labels.device}"
|
||||
assert (
|
||||
mask_embeddings is None or mask_embeddings.device == device
|
||||
), f"Expected mask embeddings to be on device {device}, got {mask_embeddings.device}"
|
||||
assert (
|
||||
mask_mask is None or mask_mask.device == device
|
||||
), f"Expected mask attn. mask to be on device {device}, got {mask_mask.device}"
|
||||
assert box_embeddings is not None and box_embeddings.device == device, (
|
||||
f"Expected box embeddings to be on device {device}, got {box_embeddings.device}"
|
||||
)
|
||||
assert box_mask is not None and box_mask.device == device, (
|
||||
f"Expected box mask to be on device {device}, got {box_mask.device}"
|
||||
)
|
||||
assert box_labels is not None and box_labels.device == device, (
|
||||
f"Expected box labels to be on device {device}, got {box_labels.device}"
|
||||
)
|
||||
assert point_embeddings is not None and point_embeddings.device == device, (
|
||||
f"Expected point embeddings to be on device {device}, got {point_embeddings.device}"
|
||||
)
|
||||
assert point_mask is not None and point_mask.device == device, (
|
||||
f"Expected point mask to be on device {device}, got {point_mask.device}"
|
||||
)
|
||||
assert point_labels is not None and point_labels.device == device, (
|
||||
f"Expected point labels to be on device {device}, got {point_labels.device}"
|
||||
)
|
||||
assert mask_embeddings is None or mask_embeddings.device == device, (
|
||||
f"Expected mask embeddings to be on device {device}, got {mask_embeddings.device}"
|
||||
)
|
||||
assert mask_mask is None or mask_mask.device == device, (
|
||||
f"Expected mask attn. mask to be on device {device}, got {mask_mask.device}"
|
||||
)
|
||||
|
||||
self.box_embeddings = box_embeddings
|
||||
self.point_embeddings = point_embeddings
|
||||
@@ -264,30 +251,30 @@ class Prompt:
|
||||
if point_embeddings is not None:
|
||||
point_seq_len = point_embeddings.shape[0]
|
||||
if bs is not None:
|
||||
assert (
|
||||
bs == point_embeddings.shape[1]
|
||||
), f"Batch size mismatch between box and point embeddings. Got {bs} and {point_embeddings.shape[1]}."
|
||||
assert bs == point_embeddings.shape[1], (
|
||||
f"Batch size mismatch between box and point embeddings. Got {bs} and {point_embeddings.shape[1]}."
|
||||
)
|
||||
else:
|
||||
bs = point_embeddings.shape[1]
|
||||
if device is not None:
|
||||
assert (
|
||||
device == point_embeddings.device
|
||||
), "Device mismatch between box and point embeddings"
|
||||
assert device == point_embeddings.device, (
|
||||
"Device mismatch between box and point embeddings"
|
||||
)
|
||||
else:
|
||||
device = point_embeddings.device
|
||||
|
||||
if mask_embeddings is not None:
|
||||
mask_seq_len = mask_embeddings.shape[0]
|
||||
if bs is not None:
|
||||
assert (
|
||||
bs == mask_embeddings.shape[1]
|
||||
), f"Batch size mismatch between box/point and mask embedding. Got {bs} and {mask_embeddings.shape[1]}"
|
||||
assert bs == mask_embeddings.shape[1], (
|
||||
f"Batch size mismatch between box/point and mask embedding. Got {bs} and {mask_embeddings.shape[1]}"
|
||||
)
|
||||
else:
|
||||
bs = mask_embeddings.shape[1]
|
||||
if device is not None:
|
||||
assert (
|
||||
device == mask_embeddings.device
|
||||
), "Device mismatch between box/point and mask embeddings."
|
||||
assert device == mask_embeddings.device, (
|
||||
"Device mismatch between box/point and mask embeddings."
|
||||
)
|
||||
else:
|
||||
device = mask_embeddings.device
|
||||
|
||||
@@ -539,9 +526,9 @@ class SequenceGeometryEncoder(nn.Module):
|
||||
if add_cls:
|
||||
self.cls_embed = torch.nn.Embedding(1, self.d_model)
|
||||
|
||||
assert (
|
||||
points_direct_project or points_pos_enc or points_pool
|
||||
), "Error: need at least one way to encode points"
|
||||
assert points_direct_project or points_pos_enc or points_pool, (
|
||||
"Error: need at least one way to encode points"
|
||||
)
|
||||
assert (
|
||||
encode_boxes_as_points
|
||||
or boxes_direct_project
|
||||
@@ -583,16 +570,16 @@ class SequenceGeometryEncoder(nn.Module):
|
||||
|
||||
self.encode = None
|
||||
if num_layers > 0:
|
||||
assert (
|
||||
add_cls
|
||||
), "It's currently highly recommended to add a CLS when using a transformer"
|
||||
assert add_cls, (
|
||||
"It's currently highly recommended to add a CLS when using a transformer"
|
||||
)
|
||||
self.encode = get_clones(layer, num_layers)
|
||||
self.encode_norm = nn.LayerNorm(self.d_model)
|
||||
|
||||
if mask_encoder is not None:
|
||||
assert isinstance(
|
||||
mask_encoder, MaskEncoder
|
||||
), f"Expected mask_encoder of type MaskEncoder. Got {type(mask_encoder)}."
|
||||
assert isinstance(mask_encoder, MaskEncoder), (
|
||||
f"Expected mask_encoder of type MaskEncoder. Got {type(mask_encoder)}."
|
||||
)
|
||||
if add_mask_label:
|
||||
self.mask_label_embed = torch.nn.Embedding(2, self.d_model)
|
||||
self.add_mask_label = add_mask_label
|
||||
@@ -701,16 +688,15 @@ class SequenceGeometryEncoder(nn.Module):
|
||||
img_feats: torch.Tensor = None,
|
||||
):
|
||||
n_masks, bs = masks.shape[:2]
|
||||
assert (
|
||||
n_masks == 1
|
||||
), "We assume one mask per prompt for now. Code should still be functional if this assertion is removed."
|
||||
assert (
|
||||
list(attn_mask.shape)
|
||||
== [
|
||||
bs,
|
||||
n_masks,
|
||||
]
|
||||
), f"Expected attn_mask to be of shape {bs}x{n_masks}. Got {list(attn_mask.shape)}."
|
||||
assert n_masks == 1, (
|
||||
"We assume one mask per prompt for now. Code should still be functional if this assertion is removed."
|
||||
)
|
||||
assert list(attn_mask.shape) == [
|
||||
bs,
|
||||
n_masks,
|
||||
], (
|
||||
f"Expected attn_mask to be of shape {bs}x{n_masks}. Got {list(attn_mask.shape)}."
|
||||
)
|
||||
masks, pos = self.mask_encoder(
|
||||
masks=masks.flatten(0, 1).float(),
|
||||
pix_feat=img_feats,
|
||||
|
||||
@@ -13,9 +13,7 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchvision.transforms.functional as TF
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from sam3.logger import get_logger
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
@@ -248,7 +248,9 @@ class UniversalSegmentationHead(SegmentationHead):
|
||||
self.d_model = hidden_dim
|
||||
|
||||
if dot_product_scorer is not None:
|
||||
assert presence_head, "Specifying a dot product scorer without a presence head is likely a mistake"
|
||||
assert presence_head, (
|
||||
"Specifying a dot product scorer without a presence head is likely a mistake"
|
||||
)
|
||||
|
||||
self.presence_head = None
|
||||
if presence_head:
|
||||
|
||||
@@ -62,9 +62,9 @@ class SimpleMaskDownSampler(nn.Module):
|
||||
self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
|
||||
self.interpol_size = interpol_size
|
||||
if self.interpol_size is not None:
|
||||
assert isinstance(
|
||||
self.interpol_size, (list, tuple)
|
||||
), f"Unsupported type {type(self.interpol_size)}. Should be a list or tuple."
|
||||
assert isinstance(self.interpol_size, (list, tuple)), (
|
||||
f"Unsupported type {type(self.interpol_size)}. Should be a list or tuple."
|
||||
)
|
||||
self.interpol_size = list(interpol_size)
|
||||
assert len(self.interpol_size) == 2
|
||||
|
||||
|
||||
@@ -330,9 +330,9 @@ class SAM3Output(list):
|
||||
self.output = output
|
||||
else:
|
||||
self.output = []
|
||||
assert isinstance(
|
||||
iter_mode, SAM3Output.IterMode
|
||||
), f"iter_mode shoulf be of enum type 'SAM3Output.IterMode'. Got {type(iter_mode)}"
|
||||
assert isinstance(iter_mode, SAM3Output.IterMode), (
|
||||
f"iter_mode shoulf be of enum type 'SAM3Output.IterMode'. Got {type(iter_mode)}"
|
||||
)
|
||||
|
||||
self.iter_mode = iter_mode
|
||||
# We create a weak reference to self to be used in the lambda functions.
|
||||
@@ -411,9 +411,9 @@ class SAM3Output(list):
|
||||
return SAM3Output._IterationMode(model_output=model_output, iter_mode=iter_mode)
|
||||
|
||||
def append(self, item: list):
|
||||
assert isinstance(
|
||||
item, list
|
||||
), f"Only list items are supported. Got {type(item)}"
|
||||
assert isinstance(item, list), (
|
||||
f"Only list items are supported. Got {type(item)}"
|
||||
)
|
||||
self.output.append(item)
|
||||
|
||||
def __repr__(self):
|
||||
|
||||
@@ -8,7 +8,6 @@ from copy import deepcopy
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
|
||||
@@ -7,15 +7,12 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import torch.nn as nn
|
||||
from PIL.Image import Image
|
||||
|
||||
from sam3.model.sam3_tracker_base import Sam3TrackerBase
|
||||
from sam3.model.utils.sam1_utils import SAM2Transforms
|
||||
|
||||
@@ -97,9 +94,9 @@ class SAM3InteractiveImagePredictor(nn.Module):
|
||||
input_image = self._transforms(image)
|
||||
input_image = input_image[None, ...].to(self.device)
|
||||
|
||||
assert (
|
||||
len(input_image.shape) == 4 and input_image.shape[1] == 3
|
||||
), f"input_image must be of size 1x3xHxW, got {input_image.shape}"
|
||||
assert len(input_image.shape) == 4 and input_image.shape[1] == 3, (
|
||||
f"input_image must be of size 1x3xHxW, got {input_image.shape}"
|
||||
)
|
||||
logging.info("Computing image embeddings for the provided image...")
|
||||
backbone_out = self.model.forward_image(input_image)
|
||||
(
|
||||
@@ -136,17 +133,17 @@ class SAM3InteractiveImagePredictor(nn.Module):
|
||||
assert isinstance(image_list, list)
|
||||
self._orig_hw = []
|
||||
for image in image_list:
|
||||
assert isinstance(
|
||||
image, np.ndarray
|
||||
), "Images are expected to be an np.ndarray in RGB format, and of shape HWC"
|
||||
assert isinstance(image, np.ndarray), (
|
||||
"Images are expected to be an np.ndarray in RGB format, and of shape HWC"
|
||||
)
|
||||
self._orig_hw.append(image.shape[:2])
|
||||
# Transform the image to the form expected by the model
|
||||
img_batch = self._transforms.forward_batch(image_list)
|
||||
img_batch = img_batch.to(self.device)
|
||||
batch_size = img_batch.shape[0]
|
||||
assert (
|
||||
len(img_batch.shape) == 4 and img_batch.shape[1] == 3
|
||||
), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}"
|
||||
assert len(img_batch.shape) == 4 and img_batch.shape[1] == 3, (
|
||||
f"img_batch must be of size Bx3xHxW, got {img_batch.shape}"
|
||||
)
|
||||
logging.info("Computing image embeddings for the provided images...")
|
||||
backbone_out = self.model.forward_image(img_batch)
|
||||
(
|
||||
@@ -302,9 +299,9 @@ class SAM3InteractiveImagePredictor(nn.Module):
|
||||
):
|
||||
unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None
|
||||
if point_coords is not None:
|
||||
assert (
|
||||
point_labels is not None
|
||||
), "point_labels must be supplied if point_coords is supplied."
|
||||
assert point_labels is not None, (
|
||||
"point_labels must be supplied if point_coords is supplied."
|
||||
)
|
||||
point_coords = torch.as_tensor(
|
||||
point_coords, dtype=torch.float, device=self.device
|
||||
)
|
||||
@@ -441,9 +438,9 @@ class SAM3InteractiveImagePredictor(nn.Module):
|
||||
raise RuntimeError(
|
||||
"An image must be set with .set_image(...) to generate an embedding."
|
||||
)
|
||||
assert (
|
||||
self._features is not None
|
||||
), "Features must exist if an image has been set."
|
||||
assert self._features is not None, (
|
||||
"Features must exist if an image has been set."
|
||||
)
|
||||
return self._features["image_embed"]
|
||||
|
||||
@property
|
||||
|
||||
@@ -8,19 +8,14 @@ from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from sam3.model.model_misc import SAM3Output
|
||||
|
||||
from sam3.model.sam1_task_predictor import SAM3InteractiveImagePredictor
|
||||
from sam3.model.vl_combiner import SAM3VLBackbone
|
||||
from sam3.perflib.nms import nms_masks
|
||||
|
||||
from sam3.train.data.collator import BatchedDatapoint
|
||||
|
||||
from .act_ckpt_utils import activation_ckpt_wrapper
|
||||
|
||||
from .box_ops import box_cxcywh_to_xyxy
|
||||
|
||||
from .geometry_encoders import Prompt
|
||||
from .model_misc import inverse_sigmoid
|
||||
|
||||
@@ -661,9 +656,9 @@ class Sam3Image(torch.nn.Module):
|
||||
inference_state["original_heights"],
|
||||
inference_state["original_widths"],
|
||||
)
|
||||
assert (
|
||||
batch_size == len(orig_heights) == len(orig_widths)
|
||||
), f"Batch size mismatch in predict_inst_batch. Got {batch_size}, {len(orig_heights)}, {len(orig_widths)}"
|
||||
assert batch_size == len(orig_heights) == len(orig_widths), (
|
||||
f"Batch size mismatch in predict_inst_batch. Got {batch_size}, {len(orig_heights)}, {len(orig_widths)}"
|
||||
)
|
||||
feats = [
|
||||
feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
|
||||
for feat, feat_size in zip(
|
||||
|
||||
@@ -6,9 +6,7 @@ from typing import Dict, List
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
|
||||
from sam3.model import box_ops
|
||||
|
||||
from sam3.model.data_misc import FindStage, interpolate
|
||||
from torchvision.transforms import v2
|
||||
|
||||
@@ -83,9 +81,9 @@ class Sam3Processor:
|
||||
if not isinstance(images, list):
|
||||
raise ValueError("Images must be a list of PIL images or tensors")
|
||||
assert len(images) > 0, "Images list must not be empty"
|
||||
assert isinstance(
|
||||
images[0], PIL.Image.Image
|
||||
), "Images must be a list of PIL images"
|
||||
assert isinstance(images[0], PIL.Image.Image), (
|
||||
"Images must be a list of PIL images"
|
||||
)
|
||||
|
||||
state["original_heights"] = [image.height for image in images]
|
||||
state["original_widths"] = [image.width for image in images]
|
||||
|
||||
@@ -6,11 +6,8 @@ import logging
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from sam3.model.memory import SimpleMaskEncoder
|
||||
|
||||
from sam3.model.sam3_tracker_utils import get_1d_sine_pe, select_closest_cond_frames
|
||||
|
||||
from sam3.sam.mask_decoder import MaskDecoder, MLP
|
||||
from sam3.sam.prompt_encoder import PromptEncoder
|
||||
from sam3.sam.transformer import TwoWayTransformer
|
||||
|
||||
@@ -6,7 +6,6 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from sam3.model.edt import edt_triton
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ import logging
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
|
||||
from sam3.model.sam3_tracker_base import concat_points, NO_OBJ_SCORE, Sam3TrackerBase
|
||||
from sam3.model.sam3_tracker_utils import fill_holes_in_mask_scores
|
||||
from sam3.model.utils.sam2_utils import load_video_frames
|
||||
|
||||
@@ -16,7 +16,6 @@ import numpy.typing as npt
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
|
||||
from sam3 import perflib
|
||||
from sam3.logger import get_logger
|
||||
from sam3.model.box_ops import fast_diag_box_iou
|
||||
@@ -620,9 +619,9 @@ class Sam3VideoBase(nn.Module):
|
||||
num_obj_dropped_due_to_limit,
|
||||
trk_id_to_max_iou_high_conf_det,
|
||||
]
|
||||
assert (
|
||||
len(update_plan) == NUM_BROADCAST_ITEMS
|
||||
), f"Manually update NUM_BROADCAST_ITEMS to be: {len(update_plan)}"
|
||||
assert len(update_plan) == NUM_BROADCAST_ITEMS, (
|
||||
f"Manually update NUM_BROADCAST_ITEMS to be: {len(update_plan)}"
|
||||
)
|
||||
self.broadcast_python_obj_cpu(update_plan, src=0)
|
||||
elif self.rank > 0 and self.world_size > 1:
|
||||
update_plan = [
|
||||
@@ -842,9 +841,9 @@ class Sam3VideoBase(nn.Module):
|
||||
binary_tracker_low_res_masks_global = tracker_low_res_masks_global > 0
|
||||
batch_size = tracker_low_res_masks_global.size(0)
|
||||
if batch_size > 0:
|
||||
assert (
|
||||
len(obj_ids_global) == batch_size
|
||||
), f"Mismatch in number of objects: {len(obj_ids_global)} vs {batch_size}"
|
||||
assert len(obj_ids_global) == batch_size, (
|
||||
f"Mismatch in number of objects: {len(obj_ids_global)} vs {batch_size}"
|
||||
)
|
||||
NEVER_OCCLUDED = -1
|
||||
ALWAYS_OCCLUDED = 100000 # This value should be larger than any possible frame index, indicates that the object was removed by hotstart logic
|
||||
last_occluded_prev = torch.cat(
|
||||
@@ -1023,9 +1022,9 @@ class Sam3VideoBase(nn.Module):
|
||||
reverse: bool = False,
|
||||
):
|
||||
# Suppress overlapping masks for objects that were most recently occluded
|
||||
assert (
|
||||
binary_low_res_masks.dtype == torch.bool
|
||||
), f"Expected boolean tensor, got {binary_low_res_masks.dtype}"
|
||||
assert binary_low_res_masks.dtype == torch.bool, (
|
||||
f"Expected boolean tensor, got {binary_low_res_masks.dtype}"
|
||||
)
|
||||
to_suppress = torch.zeros(
|
||||
binary_low_res_masks.size(0),
|
||||
device=binary_low_res_masks.device,
|
||||
@@ -1130,9 +1129,9 @@ class Sam3VideoBase(nn.Module):
|
||||
num_frames_propagated += 1
|
||||
|
||||
# only 1 frames should be propagated
|
||||
assert (
|
||||
num_frames_propagated == 1 and out_frame_idx == frame_idx
|
||||
), f"num_frames_propagated: {num_frames_propagated}, out_frame_idx: {out_frame_idx}, frame_idx: {frame_idx}"
|
||||
assert num_frames_propagated == 1 and out_frame_idx == frame_idx, (
|
||||
f"num_frames_propagated: {num_frames_propagated}, out_frame_idx: {out_frame_idx}, frame_idx: {frame_idx}"
|
||||
)
|
||||
assert isinstance(out_obj_ids, list)
|
||||
obj_ids_local.extend(out_obj_ids)
|
||||
low_res_masks_list.append(out_low_res_masks.squeeze(1))
|
||||
@@ -1189,9 +1188,9 @@ class Sam3VideoBase(nn.Module):
|
||||
|
||||
assert det_masks.is_floating_point(), "float tensor expected (do not binarize)"
|
||||
assert trk_masks.is_floating_point(), "float tensor expected (do not binarize)"
|
||||
assert (
|
||||
trk_masks.size(0) == len(trk_obj_ids)
|
||||
), f"trk_masks and trk_obj_ids should have the same length, {trk_masks.size(0)} vs {len(trk_obj_ids)}"
|
||||
assert trk_masks.size(0) == len(trk_obj_ids), (
|
||||
f"trk_masks and trk_obj_ids should have the same length, {trk_masks.size(0)} vs {len(trk_obj_ids)}"
|
||||
)
|
||||
if trk_masks.size(0) == 0:
|
||||
# all detections are new
|
||||
new_det_fa_inds = np.arange(det_masks.size(0))
|
||||
@@ -1655,9 +1654,9 @@ class Sam3VideoBase(nn.Module):
|
||||
# a) first, expand "confirmation_data" to include new masklets added in this frame
|
||||
status_prev = confirmation_data["status"]
|
||||
consecutive_det_num_prev = confirmation_data["consecutive_det_num"]
|
||||
assert (
|
||||
status_prev.shape == obj_ids_all_gpu_prev.shape
|
||||
), f"Got {status_prev.shape} vs {obj_ids_all_gpu_prev.shape}"
|
||||
assert status_prev.shape == obj_ids_all_gpu_prev.shape, (
|
||||
f"Got {status_prev.shape} vs {obj_ids_all_gpu_prev.shape}"
|
||||
)
|
||||
|
||||
obj_id_to_updated_idx = {
|
||||
obj_id: idx for idx, obj_id in enumerate(obj_ids_all_gpu_updated)
|
||||
|
||||
@@ -9,7 +9,6 @@ import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
|
||||
from sam3 import perflib
|
||||
from sam3.logger import get_logger
|
||||
from sam3.model.act_ckpt_utils import clone_output_wrapper
|
||||
@@ -555,7 +554,9 @@ class Sam3VideoInference(Sam3VideoBase):
|
||||
assert (
|
||||
"cached_frame_outputs" in inference_state
|
||||
and frame_idx in inference_state["cached_frame_outputs"]
|
||||
), "No cached outputs found. Ensure normal propagation has run first to populate the cache."
|
||||
), (
|
||||
"No cached outputs found. Ensure normal propagation has run first to populate the cache."
|
||||
)
|
||||
cached_outputs = inference_state["cached_frame_outputs"][frame_idx]
|
||||
|
||||
obj_id_to_mask = cached_outputs.copy()
|
||||
@@ -563,9 +564,9 @@ class Sam3VideoInference(Sam3VideoBase):
|
||||
# Update with refined masks if provided
|
||||
if refined_obj_id_to_mask is not None:
|
||||
for obj_id, refined_mask in refined_obj_id_to_mask.items():
|
||||
assert (
|
||||
refined_mask is not None
|
||||
), f"Refined mask data must be provided for obj_id {obj_id}"
|
||||
assert refined_mask is not None, (
|
||||
f"Refined mask data must be provided for obj_id {obj_id}"
|
||||
)
|
||||
obj_id_to_mask[obj_id] = refined_mask
|
||||
|
||||
return obj_id_to_mask
|
||||
@@ -660,12 +661,12 @@ class Sam3VideoInference(Sam3VideoBase):
|
||||
for i, thresh in enumerate(new_det_score_thresh_list):
|
||||
self.new_det_thresh = thresh
|
||||
for num_objects in num_objects_list:
|
||||
logger.info(f"{i+1}/{num_rounds} warming up model compilation")
|
||||
logger.info(f"{i + 1}/{num_rounds} warming up model compilation")
|
||||
self.add_prompt(
|
||||
inference_state, frame_idx=start_frame_idx, text_str="cat"
|
||||
)
|
||||
logger.info(
|
||||
f"{i+1}/{num_rounds} warming up model compilation -- simulating {num_objects}/{self.num_obj_for_compile} objects"
|
||||
f"{i + 1}/{num_rounds} warming up model compilation -- simulating {num_objects}/{self.num_obj_for_compile} objects"
|
||||
)
|
||||
inference_state = self.add_fake_objects_to_inference_state(
|
||||
inference_state, num_objects, frame_idx=start_frame_idx
|
||||
@@ -690,7 +691,7 @@ class Sam3VideoInference(Sam3VideoBase):
|
||||
pass
|
||||
self.reset_state(inference_state)
|
||||
logger.info(
|
||||
f"{i+1}/{num_rounds} warming up model compilation -- completed round {i+1} out of {num_rounds}"
|
||||
f"{i + 1}/{num_rounds} warming up model compilation -- completed round {i + 1} out of {num_rounds}"
|
||||
)
|
||||
|
||||
# Warm up Tracker memory encoder with varying input shapes
|
||||
@@ -854,12 +855,12 @@ class Sam3VideoInference(Sam3VideoBase):
|
||||
logger.debug("Running add_prompt on frame %d", frame_idx)
|
||||
|
||||
num_frames = inference_state["num_frames"]
|
||||
assert (
|
||||
text_str is not None or boxes_xywh is not None
|
||||
), "at least one type of prompt (text, boxes) must be provided"
|
||||
assert (
|
||||
0 <= frame_idx < num_frames
|
||||
), f"{frame_idx=} is out of range for a total of {num_frames} frames"
|
||||
assert text_str is not None or boxes_xywh is not None, (
|
||||
"at least one type of prompt (text, boxes) must be provided"
|
||||
)
|
||||
assert 0 <= frame_idx < num_frames, (
|
||||
f"{frame_idx=} is out of range for a total of {num_frames} frames"
|
||||
)
|
||||
|
||||
# since it's a semantic prompt, we start over
|
||||
self.reset_state(inference_state)
|
||||
@@ -1200,9 +1201,9 @@ class Sam3VideoInferenceWithInstanceInteractivity(Sam3VideoInference):
|
||||
"propagation_partial",
|
||||
"propagation_fetch",
|
||||
]
|
||||
assert (
|
||||
action_type in instance_actions + propagation_actions
|
||||
), f"Invalid action type: {action_type}, must be one of {instance_actions + propagation_actions}"
|
||||
assert action_type in instance_actions + propagation_actions, (
|
||||
f"Invalid action type: {action_type}, must be one of {instance_actions + propagation_actions}"
|
||||
)
|
||||
action = {
|
||||
"type": action_type,
|
||||
"frame_idx": frame_idx,
|
||||
@@ -1370,12 +1371,12 @@ class Sam3VideoInferenceWithInstanceInteractivity(Sam3VideoInference):
|
||||
):
|
||||
if points is not None:
|
||||
# Tracker instance prompts
|
||||
assert (
|
||||
text_str is None and boxes_xywh is None
|
||||
), "When points are provided, text_str and boxes_xywh must be None."
|
||||
assert (
|
||||
obj_id is not None
|
||||
), "When points are provided, obj_id must be provided."
|
||||
assert text_str is None and boxes_xywh is None, (
|
||||
"When points are provided, text_str and boxes_xywh must be None."
|
||||
)
|
||||
assert obj_id is not None, (
|
||||
"When points are provided, obj_id must be provided."
|
||||
)
|
||||
return self.add_tracker_new_points(
|
||||
inference_state,
|
||||
frame_idx,
|
||||
@@ -1491,9 +1492,9 @@ class Sam3VideoInferenceWithInstanceInteractivity(Sam3VideoInference):
|
||||
tracker_states = self._get_tracker_inference_states_by_obj_ids(
|
||||
inference_state, [obj_id]
|
||||
)
|
||||
assert (
|
||||
len(tracker_states) == 1
|
||||
), f"[rank={self.rank}] Multiple Tracker inference states found for the same object id."
|
||||
assert len(tracker_states) == 1, (
|
||||
f"[rank={self.rank}] Multiple Tracker inference states found for the same object id."
|
||||
)
|
||||
tracker_state = tracker_states[0]
|
||||
|
||||
# log
|
||||
|
||||
@@ -16,7 +16,6 @@ from typing import List, Optional
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
|
||||
from sam3.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -170,7 +169,7 @@ class Sam3VideoPredictor:
|
||||
):
|
||||
"""Remove an object from tracking."""
|
||||
logger.debug(
|
||||
f"remove object {obj_id} in session {session_id}: " f"{is_user_action=}"
|
||||
f"remove object {obj_id} in session {session_id}: {is_user_action=}"
|
||||
)
|
||||
session = self._get_session(session_id)
|
||||
inference_state = session["state"]
|
||||
|
||||
@@ -318,9 +318,9 @@ class VETextEncoder(nn.Module):
|
||||
# The text is already encoded, use as is.
|
||||
text_attention_mask, text_memory_resized, tokenized = text
|
||||
inputs_embeds = tokenized["inputs_embeds"]
|
||||
assert (
|
||||
input_boxes is None or len(input_boxes) == 0
|
||||
), "Can't replace boxes in text if it's already encoded"
|
||||
assert input_boxes is None or len(input_boxes) == 0, (
|
||||
"Can't replace boxes in text if it's already encoded"
|
||||
)
|
||||
|
||||
# Note that the input_embeds are returned in pytorch's convention (sequence first)
|
||||
return (
|
||||
|
||||
@@ -708,9 +708,9 @@ class ViT(nn.Module):
|
||||
self.retain_cls_token = retain_cls_token
|
||||
if self.retain_cls_token:
|
||||
assert pretrain_use_cls_token
|
||||
assert (
|
||||
len(window_block_indexes) == 0
|
||||
), "windowing not supported with cls token"
|
||||
assert len(window_block_indexes) == 0, (
|
||||
"windowing not supported with cls token"
|
||||
)
|
||||
|
||||
assert sum(self.rel_pos_blocks) == 0, "rel pos not supported with cls token"
|
||||
|
||||
|
||||
@@ -9,7 +9,6 @@ from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from torch.nn.attention import sdpa_kernel, SDPBackend
|
||||
|
||||
from .act_ckpt_utils import activation_ckpt_wrapper
|
||||
|
||||
@@ -6,7 +6,6 @@ import os
|
||||
from typing import Optional
|
||||
|
||||
import pkg_resources
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
@@ -36,9 +36,9 @@ def connected_components_cpu(input_tensor: torch.Tensor):
|
||||
if input_tensor.dim() == 4 and input_tensor.shape[1] == 1:
|
||||
input_tensor = input_tensor.squeeze(1)
|
||||
else:
|
||||
assert (
|
||||
input_tensor.dim() == 3
|
||||
), "Input tensor must be (B, H, W) or (B, 1, H, W)."
|
||||
assert input_tensor.dim() == 3, (
|
||||
"Input tensor must be (B, H, W) or (B, 1, H, W)."
|
||||
)
|
||||
|
||||
batch_size = input_tensor.shape[0]
|
||||
labels_list = []
|
||||
@@ -67,9 +67,9 @@ def connected_components(input_tensor: torch.Tensor):
|
||||
if input_tensor.dim() == 3:
|
||||
input_tensor = input_tensor.unsqueeze(1)
|
||||
|
||||
assert (
|
||||
input_tensor.dim() == 4 and input_tensor.shape[1] == 1
|
||||
), "Input tensor must be (B, H, W) or (B, 1, H, W)."
|
||||
assert input_tensor.dim() == 4 and input_tensor.shape[1] == 1, (
|
||||
"Input tensor must be (B, H, W) or (B, 1, H, W)."
|
||||
)
|
||||
|
||||
if input_tensor.is_cuda:
|
||||
if HAS_CC_TORCH:
|
||||
|
||||
@@ -6,7 +6,6 @@ import logging
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from sam3.perflib.masks_ops import mask_iou
|
||||
|
||||
|
||||
|
||||
@@ -407,16 +407,16 @@ def connected_components_triton(input_tensor: torch.Tensor):
|
||||
- A BxHxW output tensor with dense labels. Background is 0.
|
||||
- A BxHxW tensor with the size of the connected component for each pixel.
|
||||
"""
|
||||
assert (
|
||||
input_tensor.is_cuda and input_tensor.is_contiguous()
|
||||
), "Input tensor must be a contiguous CUDA tensor."
|
||||
assert input_tensor.is_cuda and input_tensor.is_contiguous(), (
|
||||
"Input tensor must be a contiguous CUDA tensor."
|
||||
)
|
||||
out_shape = input_tensor.shape
|
||||
if input_tensor.dim() == 4 and input_tensor.shape[1] == 1:
|
||||
input_tensor = input_tensor.squeeze(1)
|
||||
else:
|
||||
assert (
|
||||
input_tensor.dim() == 3
|
||||
), "Input tensor must be (B, H, W) or (B, 1, H, W)."
|
||||
assert input_tensor.dim() == 3, (
|
||||
"Input tensor must be (B, H, W) or (B, 1, H, W)."
|
||||
)
|
||||
|
||||
B, H, W = input_tensor.shape
|
||||
numel = B * H * W
|
||||
|
||||
@@ -202,9 +202,9 @@ class MaskDecoder(nn.Module):
|
||||
assert image_embeddings.shape[0] == tokens.shape[0]
|
||||
src = image_embeddings
|
||||
src = src + dense_prompt_embeddings
|
||||
assert (
|
||||
image_pe.size(0) == 1
|
||||
), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
|
||||
assert image_pe.size(0) == 1, (
|
||||
"image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
|
||||
)
|
||||
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
|
||||
b, c, h, w = src.shape
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ from typing import Tuple, Type
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from sam3.sam.rope import apply_rotary_enc, apply_rotary_enc_real, compute_axial_cis
|
||||
from torch import nn, Tensor
|
||||
|
||||
@@ -205,9 +204,9 @@ class Attention(nn.Module):
|
||||
self.internal_dim = embedding_dim // downsample_rate
|
||||
self.num_heads = num_heads
|
||||
self.use_fa3 = use_fa3
|
||||
assert (
|
||||
self.internal_dim % num_heads == 0
|
||||
), "num_heads must divide embedding_dim."
|
||||
assert self.internal_dim % num_heads == 0, (
|
||||
"num_heads must divide embedding_dim."
|
||||
)
|
||||
|
||||
self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
|
||||
self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
|
||||
|
||||
@@ -142,9 +142,9 @@ class COCO_FROM_JSON:
|
||||
self.prompts = {}
|
||||
for loc_dict in prompts:
|
||||
self.prompts[int(loc_dict["id"])] = loc_dict["name"]
|
||||
assert len(self.prompts) == len(
|
||||
self._sorted_cat_ids
|
||||
), "Number of prompts must match number of categories"
|
||||
assert len(self.prompts) == len(self._sorted_cat_ids), (
|
||||
"Number of prompts must match number of categories"
|
||||
)
|
||||
|
||||
def getDatapointIds(self):
|
||||
"""Return all datapoint indices for training."""
|
||||
|
||||
@@ -6,7 +6,6 @@ from dataclasses import dataclass, field as field_ptr_behaviour, fields, is_data
|
||||
from typing import Any, get_args, get_origin, List, Union
|
||||
|
||||
import torch
|
||||
|
||||
from sam3.model.data_misc import (
|
||||
BatchedDatapoint,
|
||||
BatchedFindTarget,
|
||||
@@ -217,9 +216,9 @@ def collate_fn_api(
|
||||
text_batch.append(q.query_text)
|
||||
stages[stage_id].text_ids.append(text_batch.index(q.query_text))
|
||||
|
||||
assert (
|
||||
q.inference_metadata is not None
|
||||
), "inference_metadata must be provided when FindQueryLoaded is created."
|
||||
assert q.inference_metadata is not None, (
|
||||
"inference_metadata must be provided when FindQueryLoaded is created."
|
||||
)
|
||||
for f in fields(q.inference_metadata):
|
||||
getattr(find_metadatas[stage_id], f.name).append(
|
||||
getattr(q.inference_metadata, f.name)
|
||||
|
||||
@@ -19,10 +19,8 @@ import torch.utils.data
|
||||
import torchvision
|
||||
from decord import cpu, VideoReader
|
||||
from iopath.common.file_io import g_pathmgr
|
||||
|
||||
from PIL import Image as PILImage
|
||||
from PIL.Image import DecompressionBombError
|
||||
|
||||
from sam3.model.box_ops import box_xywh_to_xyxy
|
||||
from torchvision.datasets.vision import VisionDataset
|
||||
|
||||
@@ -234,9 +232,9 @@ class CustomCocoDetectionAPI(VisionDataset):
|
||||
if self.coco is not None:
|
||||
return
|
||||
|
||||
assert g_pathmgr.isfile(
|
||||
self.annFile
|
||||
), f"please provide valid annotation file. Missing: {self.annFile}"
|
||||
assert g_pathmgr.isfile(self.annFile), (
|
||||
f"please provide valid annotation file. Missing: {self.annFile}"
|
||||
)
|
||||
annFile = g_pathmgr.get_local_path(self.annFile)
|
||||
|
||||
if self.coco is not None:
|
||||
@@ -326,9 +324,9 @@ class CustomCocoDetectionAPI(VisionDataset):
|
||||
else:
|
||||
num_queries_per_stage = stage2num_queries.most_common(1)[0][1]
|
||||
for stage, num_queries in stage2num_queries.items():
|
||||
assert (
|
||||
num_queries == num_queries_per_stage
|
||||
), f"Number of queries in stage {stage} is {num_queries}, expected {num_queries_per_stage}"
|
||||
assert num_queries == num_queries_per_stage, (
|
||||
f"Number of queries in stage {stage} is {num_queries}, expected {num_queries_per_stage}"
|
||||
)
|
||||
|
||||
for query_id, query in enumerate(queries):
|
||||
h, w = id2imsize[query["image_id"]]
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
# pyre-unsafe
|
||||
|
||||
import copy
|
||||
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
@@ -16,7 +15,6 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
# from decord import cpu, VideoReader
|
||||
|
||||
from iopath.common.file_io import PathManager
|
||||
@@ -220,9 +218,9 @@ class VideoGroundingDataset(Sam3ImageDataset):
|
||||
for query in filtered_queries:
|
||||
ptr_x_is_empty = query["ptr_x_query_id"] in [None, -1]
|
||||
ptr_y_is_empty = query["ptr_y_query_id"] in [None, -1]
|
||||
assert (
|
||||
ptr_x_is_empty and ptr_y_is_empty
|
||||
), "Remapping stage ids is not supported for queries with non-empty ptr_x or ptr_y pointers"
|
||||
assert ptr_x_is_empty and ptr_y_is_empty, (
|
||||
"Remapping stage ids is not supported for queries with non-empty ptr_x or ptr_y pointers"
|
||||
)
|
||||
query["query_processing_order"] = stage_id_old2new[
|
||||
query["query_processing_order"]
|
||||
]
|
||||
|
||||
@@ -9,11 +9,8 @@ import torch
|
||||
import torch.distributed
|
||||
import torch.nn.functional as F
|
||||
import torchmetrics
|
||||
|
||||
from sam3.model import box_ops
|
||||
|
||||
from sam3.model.data_misc import interpolate
|
||||
|
||||
from sam3.train.loss.sigmoid_focal_loss import (
|
||||
triton_sigmoid_focal_loss,
|
||||
triton_sigmoid_focal_loss_reduce,
|
||||
@@ -327,7 +324,9 @@ class IABCEMdetr(LossWithWeights):
|
||||
if num_det_queries is not None:
|
||||
logging.warning("note: it's not needed to set num_det_queries anymore")
|
||||
if self.use_separate_loss_for_det_and_trk:
|
||||
assert not self.weak_loss, "Do not use weak_loss in this case -- set separate loss for detection and tracking queries instead"
|
||||
assert not self.weak_loss, (
|
||||
"Do not use weak_loss in this case -- set separate loss for detection and tracking queries instead"
|
||||
)
|
||||
self.det_exhaustive_loss_scale_pos = det_exhaustive_loss_scale_pos
|
||||
self.det_exhaustive_loss_scale_neg = det_exhaustive_loss_scale_neg
|
||||
self.det_non_exhaustive_loss_scale_pos = det_non_exhaustive_loss_scale_pos
|
||||
@@ -342,7 +341,9 @@ class IABCEMdetr(LossWithWeights):
|
||||
and det_non_exhaustive_loss_scale_neg == 1.0
|
||||
and trk_loss_scale_pos == 1.0
|
||||
and trk_loss_scale_neg == 1.0
|
||||
), "If not using separate loss for detection and tracking queries, separate detection and tracking loss scales should all be 1.0"
|
||||
), (
|
||||
"If not using separate loss for detection and tracking queries, separate detection and tracking loss scales should all be 1.0"
|
||||
)
|
||||
|
||||
def get_loss(self, outputs, targets, indices, num_boxes):
|
||||
assert len(outputs["pred_logits"].shape) > 2, "Incorrect predicted logits shape"
|
||||
@@ -443,7 +444,9 @@ class IABCEMdetr(LossWithWeights):
|
||||
pass
|
||||
|
||||
if self.weak_loss:
|
||||
assert not self.use_separate_loss_for_det_and_trk, "Do not use weak_loss in this case -- set separate loss for detection and tracking queries instead"
|
||||
assert not self.use_separate_loss_for_det_and_trk, (
|
||||
"Do not use weak_loss in this case -- set separate loss for detection and tracking queries instead"
|
||||
)
|
||||
|
||||
# nullify the negative loss for the non-exhaustive classes
|
||||
assert loss_bce.shape[0] == targets["is_exhaustive"].shape[0]
|
||||
@@ -497,9 +500,9 @@ class IABCEMdetr(LossWithWeights):
|
||||
loss_bce = loss_bce.mean()
|
||||
else:
|
||||
assert isinstance(self.pad_n_queries, int)
|
||||
assert (
|
||||
loss_bce.size(1) < self.pad_n_queries
|
||||
), f"The number of predictions is more than the expected total after padding. Got {loss_bce.size(1)} predictions."
|
||||
assert loss_bce.size(1) < self.pad_n_queries, (
|
||||
f"The number of predictions is more than the expected total after padding. Got {loss_bce.size(1)} predictions."
|
||||
)
|
||||
loss_bce = loss_bce.sum() / (self.pad_n_queries * loss_bce.size(0))
|
||||
|
||||
bce_f1 = torchmetrics.functional.f1_score(
|
||||
|
||||
@@ -3,9 +3,7 @@
|
||||
# pyre-unsafe
|
||||
|
||||
import torch
|
||||
|
||||
from sam3.model.model_misc import SAM3Output
|
||||
|
||||
from sam3.train.utils.distributed import get_world_size
|
||||
|
||||
from .loss_fns import CORE_LOSS_KEY, Det2TrkAssoc, Masks
|
||||
|
||||
@@ -103,9 +103,9 @@ def dilation(mask, kernel_size):
|
||||
|
||||
assert mask.ndim == 3
|
||||
kernel_size = int(kernel_size)
|
||||
assert (
|
||||
kernel_size % 2 == 1
|
||||
), f"Dilation expects a odd kernel size, got {kernel_size}"
|
||||
assert kernel_size % 2 == 1, (
|
||||
f"Dilation expects a odd kernel size, got {kernel_size}"
|
||||
)
|
||||
|
||||
if mask.is_cuda:
|
||||
m = mask.unsqueeze(1).to(torch.float16)
|
||||
|
||||
@@ -8,7 +8,6 @@ Modules to compute the matching cost and solve the corresponding LSAP.
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from sam3.model.box_ops import box_cxcywh_to_xyxy, box_iou, generalized_box_iou
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
from torch import nn
|
||||
@@ -60,9 +59,9 @@ class HungarianMatcher(nn.Module):
|
||||
self.cost_bbox = cost_bbox
|
||||
self.cost_giou = cost_giou
|
||||
self.norm = nn.Sigmoid() if focal_loss else nn.Softmax(-1)
|
||||
assert (
|
||||
cost_class != 0 or cost_bbox != 0 or cost_giou != 0
|
||||
), "all costs cant be 0"
|
||||
assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, (
|
||||
"all costs cant be 0"
|
||||
)
|
||||
self.focal_loss = focal_loss
|
||||
self.focal_alpha = focal_alpha
|
||||
self.focal_gamma = focal_gamma
|
||||
@@ -197,9 +196,9 @@ class BinaryHungarianMatcher(nn.Module):
|
||||
self.cost_bbox = cost_bbox
|
||||
self.cost_giou = cost_giou
|
||||
self.norm = nn.Sigmoid()
|
||||
assert (
|
||||
cost_class != 0 or cost_bbox != 0 or cost_giou != 0
|
||||
), "all costs cant be 0"
|
||||
assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, (
|
||||
"all costs cant be 0"
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, outputs, batched_targets, repeats=0, repeat_batch=1):
|
||||
@@ -322,9 +321,9 @@ class BinaryFocalHungarianMatcher(nn.Module):
|
||||
self.alpha = alpha
|
||||
self.gamma = gamma
|
||||
self.stable = stable
|
||||
assert (
|
||||
cost_class != 0 or cost_bbox != 0 or cost_giou != 0
|
||||
), "all costs cant be 0"
|
||||
assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, (
|
||||
"all costs cant be 0"
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, outputs, batched_targets, repeats=1, repeat_batch=1):
|
||||
@@ -470,9 +469,9 @@ class BinaryHungarianMatcherV2(nn.Module):
|
||||
self.cost_bbox = cost_bbox
|
||||
self.cost_giou = cost_giou
|
||||
self.norm = nn.Sigmoid()
|
||||
assert (
|
||||
cost_class != 0 or cost_bbox != 0 or cost_giou != 0
|
||||
), "all costs cant be 0"
|
||||
assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, (
|
||||
"all costs cant be 0"
|
||||
)
|
||||
self.focal = focal
|
||||
if focal:
|
||||
self.alpha = alpha
|
||||
|
||||
@@ -22,7 +22,6 @@ from typing import (
|
||||
)
|
||||
|
||||
import hydra
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from omegaconf import DictConfig
|
||||
@@ -212,9 +211,9 @@ def unix_module_cls_pattern_to_parameter_names(
|
||||
"match any classes in the model"
|
||||
)
|
||||
matching_parameters = module_cls_to_param_names[module_cls]
|
||||
assert (
|
||||
len(matching_parameters) > 0
|
||||
), f"module_cls_name {module_cls_name} does not contain any parameters in the model"
|
||||
assert len(matching_parameters) > 0, (
|
||||
f"module_cls_name {module_cls_name} does not contain any parameters in the model"
|
||||
)
|
||||
logging.info(
|
||||
f"Matches for module_cls_name [{module_cls_name}]: {matching_parameters} "
|
||||
)
|
||||
@@ -240,9 +239,9 @@ def unix_param_pattern_to_parameter_names(
|
||||
allowed_parameter_names = []
|
||||
for param_name in filter_param_names:
|
||||
matching_parameters = set(fnmatch.filter(parameter_names, param_name))
|
||||
assert (
|
||||
len(matching_parameters) >= 1
|
||||
), f"param_name {param_name} does not match any parameters in the model"
|
||||
assert len(matching_parameters) >= 1, (
|
||||
f"param_name {param_name} does not match any parameters in the model"
|
||||
)
|
||||
logging.info(f"Matches for param_name [{param_name}]: {matching_parameters}")
|
||||
allowed_parameter_names.append(matching_parameters)
|
||||
return set.union(*allowed_parameter_names)
|
||||
|
||||
@@ -12,13 +12,10 @@ from copy import deepcopy
|
||||
|
||||
import submitit
|
||||
import torch
|
||||
|
||||
from hydra import compose, initialize_config_module
|
||||
from hydra.utils import instantiate
|
||||
|
||||
from iopath.common.file_io import g_pathmgr
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from sam3.train.utils.train_utils import makedir, register_omegaconf_resolvers
|
||||
from tqdm import tqdm
|
||||
|
||||
@@ -212,9 +209,9 @@ def main(args) -> None:
|
||||
},
|
||||
}
|
||||
if "include_nodes" in submitit_conf:
|
||||
assert (
|
||||
len(submitit_conf["include_nodes"]) >= cfg.launcher.num_nodes
|
||||
), "Not enough nodes"
|
||||
assert len(submitit_conf["include_nodes"]) >= cfg.launcher.num_nodes, (
|
||||
"Not enough nodes"
|
||||
)
|
||||
job_kwargs["slurm_additional_parameters"]["nodelist"] = " ".join(
|
||||
submitit_conf["include_nodes"]
|
||||
)
|
||||
|
||||
@@ -15,28 +15,22 @@ from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Mapping, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from hydra.utils import instantiate
|
||||
from iopath.common.file_io import g_pathmgr
|
||||
|
||||
from sam3.model.data_misc import BatchedDatapoint
|
||||
from sam3.model.model_misc import SAM3Output
|
||||
from sam3.model.utils.misc import copy_data_to_device
|
||||
|
||||
from sam3.train.optim.optimizer import construct_optimizer
|
||||
|
||||
from sam3.train.utils.checkpoint_utils import (
|
||||
assert_skipped_parameters_are_frozen,
|
||||
exclude_params_matching_unix_pattern,
|
||||
load_state_dict_into_model,
|
||||
with_check_parameter_frozen,
|
||||
)
|
||||
|
||||
from sam3.train.utils.distributed import all_reduce_max, barrier, get_rank
|
||||
|
||||
from sam3.train.utils.logger import Logger, setup_logging
|
||||
from sam3.train.utils.train_utils import (
|
||||
AverageMeter,
|
||||
@@ -215,9 +209,9 @@ class Trainer:
|
||||
set_seeds(seed_value, self.max_epochs, self.distributed_rank)
|
||||
log_env_variables()
|
||||
|
||||
assert (
|
||||
is_dist_avail_and_initialized()
|
||||
), "Torch distributed needs to be initialized before calling the trainer."
|
||||
assert is_dist_avail_and_initialized(), (
|
||||
"Torch distributed needs to be initialized before calling the trainer."
|
||||
)
|
||||
|
||||
self._setup_components() # Except Optimizer everything is setup here.
|
||||
self._move_to_device()
|
||||
@@ -227,9 +221,9 @@ class Trainer:
|
||||
self.time_elapsed_meter = DurationMeter("Time Elapsed", self.device, ":.2f")
|
||||
|
||||
if self.checkpoint_conf.resume_from is not None:
|
||||
assert os.path.exists(
|
||||
self.checkpoint_conf.resume_from
|
||||
), f"The 'resume_from' checkpoint {self.checkpoint_conf.resume_from} does not exist!"
|
||||
assert os.path.exists(self.checkpoint_conf.resume_from), (
|
||||
f"The 'resume_from' checkpoint {self.checkpoint_conf.resume_from} does not exist!"
|
||||
)
|
||||
dst = os.path.join(self.checkpoint_conf.save_dir, "checkpoint.pt")
|
||||
if self.distributed_rank == 0 and not os.path.exists(dst):
|
||||
# Copy the "resume_from" checkpoint to the checkpoint folder
|
||||
@@ -477,9 +471,9 @@ class Trainer:
|
||||
return self.loss[key]
|
||||
|
||||
assert key != "all", "Loss must be specified for key='all'"
|
||||
assert (
|
||||
"default" in self.loss
|
||||
), f"Key {key} not found in losss, and no default provided"
|
||||
assert "default" in self.loss, (
|
||||
f"Key {key} not found in losss, and no default provided"
|
||||
)
|
||||
return self.loss["default"]
|
||||
|
||||
def _find_meter(self, phase: str, key: str):
|
||||
@@ -922,12 +916,12 @@ class Trainer:
|
||||
self.optim.zero_grad(set_to_none=True)
|
||||
|
||||
if self.gradient_accumulation_steps > 1:
|
||||
assert isinstance(
|
||||
batch, list
|
||||
), f"Expected a list of batches, got {type(batch)}"
|
||||
assert (
|
||||
len(batch) == self.gradient_accumulation_steps
|
||||
), f"Expected {self.gradient_accumulation_steps} batches, got {len(batch)}"
|
||||
assert isinstance(batch, list), (
|
||||
f"Expected a list of batches, got {type(batch)}"
|
||||
)
|
||||
assert len(batch) == self.gradient_accumulation_steps, (
|
||||
f"Expected {self.gradient_accumulation_steps} batches, got {len(batch)}"
|
||||
)
|
||||
accum_steps = len(batch)
|
||||
else:
|
||||
accum_steps = 1
|
||||
@@ -1039,9 +1033,9 @@ class Trainer:
|
||||
def _check_val_key_match(self, val_keys, phase):
|
||||
if val_keys is not None:
|
||||
# Check if there are any duplicates
|
||||
assert len(val_keys) == len(
|
||||
set(val_keys)
|
||||
), f"Duplicate keys in val datasets, keys: {val_keys}"
|
||||
assert len(val_keys) == len(set(val_keys)), (
|
||||
f"Duplicate keys in val datasets, keys: {val_keys}"
|
||||
)
|
||||
|
||||
# Check that the keys match the meter keys
|
||||
if self.meters_conf is not None and phase in self.meters_conf:
|
||||
@@ -1055,9 +1049,9 @@ class Trainer:
|
||||
loss_keys = set(self.loss_conf.keys()) - set(["all"])
|
||||
if "default" not in loss_keys:
|
||||
for k in val_keys:
|
||||
assert (
|
||||
k in loss_keys
|
||||
), f"Error: key {k} is not defined in the losses, and no default is set"
|
||||
assert k in loss_keys, (
|
||||
f"Error: key {k} is not defined in the losses, and no default is set"
|
||||
)
|
||||
|
||||
def _setup_components(self):
|
||||
# Get the keys for all the val datasets, if any
|
||||
|
||||
@@ -14,7 +14,6 @@ import PIL
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
import torchvision.transforms.functional as F
|
||||
|
||||
from sam3.model.box_ops import box_xyxy_to_cxcywh
|
||||
from sam3.model.data_misc import interpolate
|
||||
|
||||
@@ -277,9 +276,9 @@ class RandomSizeCrop:
|
||||
max(0, minY - h + 1), max(maxY - 1, max(0, minY - h + 1))
|
||||
)
|
||||
result_img, result_target = crop(img, target, [j, i, h, w])
|
||||
assert (
|
||||
len(result_target["boxes"]) == init_boxes
|
||||
), f"img_w={img.width}\timg_h={img.height}\tminX={minX}\tminY={minY}\tmaxX={maxX}\tmaxY={maxY}\tminW={minW}\tminH={minH}\tmaxW={maxW}\tmaxH={maxH}\tw={w}\th={h}\ti={i}\tj={j}\tinit_boxes={init_boxes_tensor}\tresults={result_target['boxes']}"
|
||||
assert len(result_target["boxes"]) == init_boxes, (
|
||||
f"img_w={img.width}\timg_h={img.height}\tminX={minX}\tminY={minY}\tmaxX={maxX}\tmaxY={maxY}\tminW={minW}\tminH={minH}\tmaxW={maxW}\tmaxH={maxH}\tw={w}\th={h}\ti={i}\tj={j}\tinit_boxes={init_boxes_tensor}\tresults={result_target['boxes']}"
|
||||
)
|
||||
|
||||
return result_img, result_target
|
||||
else:
|
||||
|
||||
@@ -7,7 +7,6 @@ Transforms and data augmentation for both image + bbox.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
import numbers
|
||||
import random
|
||||
from collections.abc import Sequence
|
||||
@@ -17,9 +16,7 @@ import torch
|
||||
import torchvision.transforms as T
|
||||
import torchvision.transforms.functional as F
|
||||
import torchvision.transforms.v2.functional as Fv2
|
||||
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from sam3.model.box_ops import box_xyxy_to_cxcywh, masks_to_boxes
|
||||
from sam3.train.data.sam3_image_dataset import Datapoint
|
||||
from torchvision.transforms import InterpolationMode
|
||||
|
||||
@@ -4,12 +4,10 @@
|
||||
|
||||
import logging
|
||||
import random
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from sam3.train.data.sam3_image_dataset import Datapoint, FindQuery, Object
|
||||
|
||||
|
||||
@@ -381,9 +379,9 @@ class FlexibleFilterFindGetQueries:
|
||||
if len(new_find_queries) == 0:
|
||||
start_with_zero_check = True
|
||||
|
||||
assert (
|
||||
start_with_zero_check
|
||||
), "Invalid Find queries, they need to start at query_processing_order = 0"
|
||||
assert start_with_zero_check, (
|
||||
"Invalid Find queries, they need to start at query_processing_order = 0"
|
||||
)
|
||||
|
||||
datapoint.find_queries = new_find_queries
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ import numpy as np
|
||||
import torch
|
||||
from PIL import Image as PILImage
|
||||
from pycocotools import mask as mask_util
|
||||
|
||||
from sam3.train.data.sam3_image_dataset import Datapoint
|
||||
from torchvision.ops import masks_to_boxes
|
||||
|
||||
@@ -250,9 +249,9 @@ class RandomGeometricInputsAPI:
|
||||
def _get_target_object(self, datapoint, query):
|
||||
img = datapoint.images[query.image_id]
|
||||
targets = query.object_ids_output
|
||||
assert (
|
||||
len(targets) == 1
|
||||
), "Geometric queries only support a single target object."
|
||||
assert len(targets) == 1, (
|
||||
"Geometric queries only support a single target object."
|
||||
)
|
||||
target_idx = targets[0]
|
||||
return img.objects[target_idx]
|
||||
|
||||
|
||||
@@ -5,12 +5,9 @@
|
||||
import numpy as np
|
||||
import pycocotools.mask as mask_utils
|
||||
import torch
|
||||
|
||||
import torchvision.transforms.functional as F
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from sam3.model.box_ops import masks_to_boxes
|
||||
|
||||
from sam3.train.data.sam3_image_dataset import Datapoint
|
||||
|
||||
|
||||
|
||||
@@ -36,9 +36,9 @@ def unix_pattern_to_parameter_names(
|
||||
parameter_names = []
|
||||
for param_name in constraints:
|
||||
matching_parameters = set(fnmatch.filter(all_parameter_names, param_name))
|
||||
assert (
|
||||
len(matching_parameters) > 0
|
||||
), f"param_names {param_name} don't match any param in the given names."
|
||||
assert len(matching_parameters) > 0, (
|
||||
f"param_names {param_name} don't match any param in the given names."
|
||||
)
|
||||
parameter_names.append(matching_parameters)
|
||||
return set.union(*parameter_names)
|
||||
|
||||
|
||||
@@ -10,10 +10,8 @@ import uuid
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from hydra.utils import instantiate
|
||||
|
||||
from iopath.common.file_io import g_pathmgr
|
||||
from numpy import ndarray
|
||||
|
||||
from sam3.train.utils.train_utils import get_machine_local_and_dist_rank, makedir
|
||||
from torch import Tensor
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
@@ -11,7 +11,6 @@ from datetime import timedelta
|
||||
from typing import Optional
|
||||
|
||||
import hydra
|
||||
|
||||
import numpy as np
|
||||
import omegaconf
|
||||
import torch
|
||||
@@ -83,9 +82,9 @@ def get_machine_local_and_dist_rank():
|
||||
"""
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", None))
|
||||
distributed_rank = int(os.environ.get("RANK", None))
|
||||
assert (
|
||||
local_rank is not None and distributed_rank is not None
|
||||
), "Please the set the RANK and LOCAL_RANK environment variables."
|
||||
assert local_rank is not None and distributed_rank is not None, (
|
||||
"Please the set the RANK and LOCAL_RANK environment variables."
|
||||
)
|
||||
return local_rank, distributed_rank
|
||||
|
||||
|
||||
|
||||
@@ -158,22 +158,22 @@ def plot_mask(mask, color="r", ax=None):
|
||||
def normalize_bbox(bbox_xywh, img_w, img_h):
|
||||
# Assumes bbox_xywh is in XYWH format
|
||||
if isinstance(bbox_xywh, list):
|
||||
assert (
|
||||
len(bbox_xywh) == 4
|
||||
), "bbox_xywh list must have 4 elements. Batching not support except for torch tensors."
|
||||
assert len(bbox_xywh) == 4, (
|
||||
"bbox_xywh list must have 4 elements. Batching not support except for torch tensors."
|
||||
)
|
||||
normalized_bbox = bbox_xywh.copy()
|
||||
normalized_bbox[0] /= img_w
|
||||
normalized_bbox[1] /= img_h
|
||||
normalized_bbox[2] /= img_w
|
||||
normalized_bbox[3] /= img_h
|
||||
else:
|
||||
assert isinstance(
|
||||
bbox_xywh, torch.Tensor
|
||||
), "Only torch tensors are supported for batching."
|
||||
assert isinstance(bbox_xywh, torch.Tensor), (
|
||||
"Only torch tensors are supported for batching."
|
||||
)
|
||||
normalized_bbox = bbox_xywh.clone()
|
||||
assert (
|
||||
normalized_bbox.size(-1) == 4
|
||||
), "bbox_xywh tensor must have last dimension of size 4."
|
||||
assert normalized_bbox.size(-1) == 4, (
|
||||
"bbox_xywh tensor must have last dimension of size 4."
|
||||
)
|
||||
normalized_bbox[..., 0] /= img_w
|
||||
normalized_bbox[..., 1] /= img_h
|
||||
normalized_bbox[..., 2] /= img_w
|
||||
@@ -244,10 +244,10 @@ def visualize_formatted_frame_output(
|
||||
|
||||
num_outputs = len(outputs_list)
|
||||
if titles is None:
|
||||
titles = [f"Set {i+1}" for i in range(num_outputs)]
|
||||
assert (
|
||||
len(titles) == num_outputs
|
||||
), "length of `titles` should match that of `outputs_list` if not None."
|
||||
titles = [f"Set {i + 1}" for i in range(num_outputs)]
|
||||
assert len(titles) == num_outputs, (
|
||||
"length of `titles` should match that of `outputs_list` if not None."
|
||||
)
|
||||
|
||||
_, axes = plt.subplots(1, num_outputs, figsize=figsize)
|
||||
if num_outputs == 1:
|
||||
@@ -703,9 +703,9 @@ def get_all_annotations_for_frame(
|
||||
|
||||
# Get the frame
|
||||
video_df_current = video_df[video_df.id == video_id]
|
||||
assert (
|
||||
len(video_df_current) == 1
|
||||
), f"Expected 1 video row, got {len(video_df_current)}"
|
||||
assert len(video_df_current) == 1, (
|
||||
f"Expected 1 video row, got {len(video_df_current)}"
|
||||
)
|
||||
video_row = video_df_current.iloc[0]
|
||||
file_name = video_row.file_names[frame_idx]
|
||||
file_path = os.path.join(
|
||||
@@ -796,7 +796,7 @@ def visualize_prompt_overlay(
|
||||
ax.text(
|
||||
x_img + 5,
|
||||
y_img - 5,
|
||||
f"P{i+1}",
|
||||
f"P{i + 1}",
|
||||
color=color,
|
||||
fontsize=10,
|
||||
weight="bold",
|
||||
@@ -828,7 +828,7 @@ def visualize_prompt_overlay(
|
||||
ax.text(
|
||||
x_img,
|
||||
y_img - 5,
|
||||
f"B{i+1}",
|
||||
f"B{i + 1}",
|
||||
color=color,
|
||||
fontsize=10,
|
||||
weight="bold",
|
||||
|
||||
Reference in New Issue
Block a user