apply Black 25.11.0 style in fbcode/deeplearning/projects (21/92)
Summary: Formats the covered files with pyfmt. paintitblack Reviewed By: itamaro Differential Revision: D90476315 fbshipit-source-id: ee94c471788b8e7d067813d8b3e0311214d17f3f
This commit is contained in:
committed by
meta-codesync[bot]
parent
7b89b8fc3f
commit
11dec2936d
@@ -9,7 +9,6 @@ import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
|
||||
from sam3 import perflib
|
||||
from sam3.logger import get_logger
|
||||
from sam3.model.act_ckpt_utils import clone_output_wrapper
|
||||
@@ -555,7 +554,9 @@ class Sam3VideoInference(Sam3VideoBase):
|
||||
assert (
|
||||
"cached_frame_outputs" in inference_state
|
||||
and frame_idx in inference_state["cached_frame_outputs"]
|
||||
), "No cached outputs found. Ensure normal propagation has run first to populate the cache."
|
||||
), (
|
||||
"No cached outputs found. Ensure normal propagation has run first to populate the cache."
|
||||
)
|
||||
cached_outputs = inference_state["cached_frame_outputs"][frame_idx]
|
||||
|
||||
obj_id_to_mask = cached_outputs.copy()
|
||||
@@ -563,9 +564,9 @@ class Sam3VideoInference(Sam3VideoBase):
|
||||
# Update with refined masks if provided
|
||||
if refined_obj_id_to_mask is not None:
|
||||
for obj_id, refined_mask in refined_obj_id_to_mask.items():
|
||||
assert (
|
||||
refined_mask is not None
|
||||
), f"Refined mask data must be provided for obj_id {obj_id}"
|
||||
assert refined_mask is not None, (
|
||||
f"Refined mask data must be provided for obj_id {obj_id}"
|
||||
)
|
||||
obj_id_to_mask[obj_id] = refined_mask
|
||||
|
||||
return obj_id_to_mask
|
||||
@@ -660,12 +661,12 @@ class Sam3VideoInference(Sam3VideoBase):
|
||||
for i, thresh in enumerate(new_det_score_thresh_list):
|
||||
self.new_det_thresh = thresh
|
||||
for num_objects in num_objects_list:
|
||||
logger.info(f"{i+1}/{num_rounds} warming up model compilation")
|
||||
logger.info(f"{i + 1}/{num_rounds} warming up model compilation")
|
||||
self.add_prompt(
|
||||
inference_state, frame_idx=start_frame_idx, text_str="cat"
|
||||
)
|
||||
logger.info(
|
||||
f"{i+1}/{num_rounds} warming up model compilation -- simulating {num_objects}/{self.num_obj_for_compile} objects"
|
||||
f"{i + 1}/{num_rounds} warming up model compilation -- simulating {num_objects}/{self.num_obj_for_compile} objects"
|
||||
)
|
||||
inference_state = self.add_fake_objects_to_inference_state(
|
||||
inference_state, num_objects, frame_idx=start_frame_idx
|
||||
@@ -690,7 +691,7 @@ class Sam3VideoInference(Sam3VideoBase):
|
||||
pass
|
||||
self.reset_state(inference_state)
|
||||
logger.info(
|
||||
f"{i+1}/{num_rounds} warming up model compilation -- completed round {i+1} out of {num_rounds}"
|
||||
f"{i + 1}/{num_rounds} warming up model compilation -- completed round {i + 1} out of {num_rounds}"
|
||||
)
|
||||
|
||||
# Warm up Tracker memory encoder with varying input shapes
|
||||
@@ -854,12 +855,12 @@ class Sam3VideoInference(Sam3VideoBase):
|
||||
logger.debug("Running add_prompt on frame %d", frame_idx)
|
||||
|
||||
num_frames = inference_state["num_frames"]
|
||||
assert (
|
||||
text_str is not None or boxes_xywh is not None
|
||||
), "at least one type of prompt (text, boxes) must be provided"
|
||||
assert (
|
||||
0 <= frame_idx < num_frames
|
||||
), f"{frame_idx=} is out of range for a total of {num_frames} frames"
|
||||
assert text_str is not None or boxes_xywh is not None, (
|
||||
"at least one type of prompt (text, boxes) must be provided"
|
||||
)
|
||||
assert 0 <= frame_idx < num_frames, (
|
||||
f"{frame_idx=} is out of range for a total of {num_frames} frames"
|
||||
)
|
||||
|
||||
# since it's a semantic prompt, we start over
|
||||
self.reset_state(inference_state)
|
||||
@@ -1200,9 +1201,9 @@ class Sam3VideoInferenceWithInstanceInteractivity(Sam3VideoInference):
|
||||
"propagation_partial",
|
||||
"propagation_fetch",
|
||||
]
|
||||
assert (
|
||||
action_type in instance_actions + propagation_actions
|
||||
), f"Invalid action type: {action_type}, must be one of {instance_actions + propagation_actions}"
|
||||
assert action_type in instance_actions + propagation_actions, (
|
||||
f"Invalid action type: {action_type}, must be one of {instance_actions + propagation_actions}"
|
||||
)
|
||||
action = {
|
||||
"type": action_type,
|
||||
"frame_idx": frame_idx,
|
||||
@@ -1370,12 +1371,12 @@ class Sam3VideoInferenceWithInstanceInteractivity(Sam3VideoInference):
|
||||
):
|
||||
if points is not None:
|
||||
# Tracker instance prompts
|
||||
assert (
|
||||
text_str is None and boxes_xywh is None
|
||||
), "When points are provided, text_str and boxes_xywh must be None."
|
||||
assert (
|
||||
obj_id is not None
|
||||
), "When points are provided, obj_id must be provided."
|
||||
assert text_str is None and boxes_xywh is None, (
|
||||
"When points are provided, text_str and boxes_xywh must be None."
|
||||
)
|
||||
assert obj_id is not None, (
|
||||
"When points are provided, obj_id must be provided."
|
||||
)
|
||||
return self.add_tracker_new_points(
|
||||
inference_state,
|
||||
frame_idx,
|
||||
@@ -1491,9 +1492,9 @@ class Sam3VideoInferenceWithInstanceInteractivity(Sam3VideoInference):
|
||||
tracker_states = self._get_tracker_inference_states_by_obj_ids(
|
||||
inference_state, [obj_id]
|
||||
)
|
||||
assert (
|
||||
len(tracker_states) == 1
|
||||
), f"[rank={self.rank}] Multiple Tracker inference states found for the same object id."
|
||||
assert len(tracker_states) == 1, (
|
||||
f"[rank={self.rank}] Multiple Tracker inference states found for the same object id."
|
||||
)
|
||||
tracker_state = tracker_states[0]
|
||||
|
||||
# log
|
||||
|
||||
Reference in New Issue
Block a user