Initial commit

fbshipit-source-id: da6be2f26e3a1202f4bffde8cb980e2dcb851294
This commit is contained in:
facebook-github-bot
2025-11-18 23:07:42 -08:00
commit a13e358df4
504 changed files with 122758 additions and 0 deletions

1
sam3/agent/__init__.py Normal file
View File

@@ -0,0 +1 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

563
sam3/agent/agent_core.py Normal file
View File

@@ -0,0 +1,563 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import copy
import json
import os
import cv2
from PIL import Image
from .client_llm import send_generate_request
from .client_sam3 import call_sam_service
from .viz import visualize
def save_debug_messages(messages_list, debug, debug_folder_path, debug_jsonl_path):
"""Save messages to debug jsonl file if debug is enabled"""
if debug and debug_jsonl_path:
# Ensure the debug directory exists before writing
os.makedirs(debug_folder_path, exist_ok=True)
with open(debug_jsonl_path, "w") as f:
for msg in messages_list:
f.write(json.dumps(msg, indent=4) + "\n")
def cleanup_debug_files(debug, debug_folder_path, debug_jsonl_path):
"""Clean up debug files when function successfully returns"""
if debug and debug_folder_path:
try:
if os.path.exists(debug_jsonl_path):
os.remove(debug_jsonl_path)
if os.path.exists(debug_folder_path):
os.rmdir(debug_folder_path)
except Exception as e:
print(f"Warning: Could not clean up debug files: {e}")
def count_images(messages):
"""Count the total number of images present in the messages history."""
total = 0
for message in messages:
# Check if message has content (should be a list)
if "content" in message and isinstance(message["content"], list):
# Iterate through each content item
for content_item in message["content"]:
# Check if content item is a dict with type "image"
if (
isinstance(content_item, dict)
and content_item.get("type") == "image"
):
total += 1
return total
def _prune_messages_for_next_round(
messages_list,
used_text_prompts,
latest_sam3_text_prompt,
img_path,
initial_text_prompt,
):
"""Return a new messages list that contains only:
1) messages[:2] (with optional warning text added to the second message's content)
2) the latest assistant message (and everything after it) that contains a segment_phrase tool call
"""
# There should not be more than 10 messages in the conversation history
assert len(messages_list) < 10
# Part 1: always keep the first two message JSONs
part1 = copy.deepcopy(messages_list[:2])
# Part 2: search backwards for the latest assistant message containing a segment_phrase tool call
part2_start_idx = None
for idx in range(len(messages_list) - 1, 1, -1):
msg = messages_list[idx]
# We only consider assistant messages with a "content" list
if msg.get("role") != "assistant" or "content" not in msg:
continue
# Look for any content element that is a text containing the segment_phrase tool call
for content in msg["content"]:
if (
isinstance(content, dict)
and content.get("type") == "text"
and "<tool>" in content.get("text", "")
and "segment_phrase" in content.get("text", "")
):
part2_start_idx = idx
break
if part2_start_idx is not None:
break
part2 = messages_list[part2_start_idx:] if part2_start_idx is not None else []
# Part 3: decide whether to add warning text to the second message in part1
previously_used = (
[p for p in used_text_prompts if p != latest_sam3_text_prompt]
if latest_sam3_text_prompt
else list(used_text_prompts)
)
if part2 and len(previously_used) > 0:
warning_text = f'Note that we have previously called the segment_phrase tool with each "text_prompt" in this list: {list(previously_used)}, but none of the generated results were satisfactory. So make sure that you do not use any of these phrases as the "text_prompt" to call the segment_phrase tool again.'
# Replace the second message entirely to keep exactly 2 content items
part1[1] = {
"role": "user",
"content": [
{"type": "image", "image": img_path},
{
"type": "text",
"text": f"The above image is the raw input image. The initial user input query is: '{initial_text_prompt}'."
+ " "
+ warning_text,
},
],
}
assert len(part1[1]["content"]) == 2
# Build the new messages list: part1 (with optional warning), then part2
new_messages = list(part1)
new_messages.extend(part2)
return new_messages
def agent_inference(
img_path: str,
initial_text_prompt: str,
debug: bool = False,
send_generate_request=send_generate_request,
call_sam_service=call_sam_service,
max_generations: int = 100,
output_dir="../../sam3_agent_out",
):
"""
Given a text prompt and an image, this tool will perform all aspects of agentic problem solving,
while saving sam3 and MLLM outputs to their respective directories.
Args:
img_path: Path to the input image
initial_text_prompt: Initial text prompt from the user
debug: Whether to enable debug mode
max_generations: Maximum number of send_generate_request calls allowed (default: 100)
"""
# setup dir
sam_output_dir = os.path.join(output_dir, "sam_out")
error_save_dir = os.path.join(output_dir, "none_out")
debug_save_dir = os.path.join(output_dir, "agent_debug_out")
os.makedirs(sam_output_dir, exist_ok=True)
os.makedirs(error_save_dir, exist_ok=True)
os.makedirs(debug_save_dir, exist_ok=True)
current_dir = os.path.dirname(os.path.abspath(__file__))
MLLM_SYSTEM_PROMPT_PATH = os.path.join(
current_dir, "system_prompts/system_prompt.txt"
)
ITERATIVE_CHECKING_SYSTEM_PROMPT_PATH = os.path.join(
current_dir, "system_prompts/system_prompt_iterative_checking.txt"
)
# init variables
PATH_TO_LATEST_OUTPUT_JSON = ""
LATEST_SAM3_TEXT_PROMPT = ""
USED_TEXT_PROMPTS = (
set()
) # Track all previously used text prompts for segment_phrase
generation_count = 0 # Counter for number of send_generate_request calls
# debug setup
debug_folder_path = None
debug_jsonl_path = None
if debug:
debug_folder_path = os.path.join(
debug_save_dir, f"{img_path.rsplit('/', 1)[-1].rsplit('.', 1)[0]}"
)
debug_jsonl_path = os.path.join(debug_folder_path, "debug_history.json")
os.makedirs(debug_folder_path, exist_ok=True)
# The helper functions are now defined outside the agent_inference function
with open(MLLM_SYSTEM_PROMPT_PATH, "r") as f:
system_prompt = f.read().strip()
with open(ITERATIVE_CHECKING_SYSTEM_PROMPT_PATH, "r") as f:
iterative_checking_system_prompt = f.read().strip()
# Construct the initial message list
messages = [
{"role": "system", "content": system_prompt},
{
"role": "user",
"content": [
{"type": "image", "image": img_path},
{
"type": "text",
"text": f"The above image is the raw input image. The initial user input query is: '{initial_text_prompt}'.",
},
],
},
]
print(f"> Text prompt: {initial_text_prompt}")
print(f"> Image path: {img_path}")
print("\n\n")
print("-" * 30 + f" Round {str(generation_count + 1)}" + "-" * 30)
print("\n\n")
generated_text = send_generate_request(messages)
print(f"\n>>> MLLM Response [start]\n{generated_text}\n<<< MLLM Response [end]\n")
while generated_text is not None:
save_debug_messages(messages, debug, debug_folder_path, debug_jsonl_path)
assert (
"<tool>" in generated_text,
f"Generated text does not contain <tool> tag: {generated_text}",
)
generated_text = generated_text.split("</tool>", 1)[0] + "</tool>"
tool_call_json_str = (
generated_text.split("<tool>")[-1]
.split("</tool>")[0]
.strip()
.replace(r"}}}", r"}}") # remove extra } if any
)
try:
tool_call = json.loads(tool_call_json_str)
except json.JSONDecodeError:
raise ValueError(f"Invalid JSON in tool call: {tool_call_json_str}")
if PATH_TO_LATEST_OUTPUT_JSON == "":
# The first tool call must be segment_phrase or report_no_mask
assert (
tool_call["name"] == "segment_phrase"
or tool_call["name"] == "report_no_mask"
)
if tool_call["name"] == "segment_phrase":
print("🔍 Calling segment_phrase tool...")
assert list(tool_call["parameters"].keys()) == ["text_prompt"]
# Check if this text_prompt has been used before
current_text_prompt = tool_call["parameters"]["text_prompt"]
if current_text_prompt in USED_TEXT_PROMPTS:
print(
f"❌ Text prompt '{current_text_prompt}' has been used before. Requesting a different prompt."
)
duplicate_prompt_message = f"You have previously used '{current_text_prompt}' as your text_prompt to call the segment_phrase tool. You may not use it again. Please call the segment_phrase tool again with a different, perhaps more general, or more creative simple noun phrase prompt, while adhering to all the rules stated in the system prompt. You must also never use any of the following text_prompt(s): {str(list(USED_TEXT_PROMPTS))}."
messages.append(
{
"role": "assistant",
"content": [{"type": "text", "text": generated_text}],
}
)
messages.append(
{
"role": "user",
"content": [{"type": "text", "text": duplicate_prompt_message}],
}
)
else:
# Add the text_prompt to the set of used prompts
USED_TEXT_PROMPTS.add(current_text_prompt)
LATEST_SAM3_TEXT_PROMPT = current_text_prompt
PATH_TO_LATEST_OUTPUT_JSON = call_sam_service(
image_path=img_path,
text_prompt=current_text_prompt,
output_folder_path=sam_output_dir,
)
sam3_outputs = json.load(open(PATH_TO_LATEST_OUTPUT_JSON, "r"))
sam3_output_image_path = sam3_outputs["output_image_path"]
num_masks = len(sam3_outputs["pred_boxes"])
messages.append(
{
"role": "assistant",
"content": [{"type": "text", "text": generated_text}],
}
)
if num_masks == 0:
print("❌ No masks generated by SAM3, reporting no mask to Qwen.")
sam3_output_text_message = f"The segment_phrase tool did not generate any masks for the text_prompt '{current_text_prompt}'. Now, please call the segment_phrase tool again with a different, perhaps more general, or more creative simple noun phrase text_prompt, while adhering to all the rules stated in the system prompt. Please be reminded that the original user query was '{initial_text_prompt}'."
messages.append(
{
"role": "user",
"content": [
{"type": "text", "text": sam3_output_text_message}
],
}
)
else:
sam3_output_text_message = rf"The segment_phrase tool generated {num_masks} available masks. All {num_masks} available masks are rendered in this image below, now you must analyze the {num_masks} available mask(s) carefully, compare them against the raw input image and the original user query, and determine your next action. Please be reminded that the original user query was '{initial_text_prompt}'."
messages.append(
{
"role": "user",
"content": [
{"type": "text", "text": sam3_output_text_message},
{"type": "image", "image": sam3_output_image_path},
],
}
)
print("\n\n>>> sam3_output_text_message:\n", sam3_output_text_message)
elif tool_call["name"] == "examine_each_mask":
print("🔍 Calling examine_each_mask tool...")
assert LATEST_SAM3_TEXT_PROMPT != ""
# Make sure that the last message is a image
assert (
messages[-1]["content"][1]["type"] == "image"
), "Second content element should be an image"
messages.pop() # Remove the last user message
# Add simplified replacement message
simplified_message = {
"role": "user",
"content": [
{
"type": "text",
"text": "The segment_phrase tool generated several masks. Now you must analyze the mask(s) carefully, compare them against the raw input image and the original user query, and determine your next action.",
}
],
}
messages.append(simplified_message)
current_outputs = json.load(open(PATH_TO_LATEST_OUTPUT_JSON, "r"))
num_masks = len(current_outputs["pred_masks"])
masks_to_keep = []
# MLLM check the mask one by one
for i in range(num_masks):
print(f"🔍 Checking mask {i+1}/{num_masks}...")
image_w_mask_i, image_w_zoomed_in_mask_i = visualize(current_outputs, i)
image_w_zoomed_in_mask_i_path = os.path.join(
sam_output_dir, rf"{LATEST_SAM3_TEXT_PROMPT}.png".replace("/", "_")
).replace(".png", f"_zoom_in_mask_{i + 1}.png")
image_w_mask_i_path = os.path.join(
sam_output_dir, rf"{LATEST_SAM3_TEXT_PROMPT}.png".replace("/", "_")
).replace(".png", f"_selected_mask_{i + 1}.png")
image_w_zoomed_in_mask_i.save(image_w_zoomed_in_mask_i_path)
image_w_mask_i.save(image_w_mask_i_path)
iterative_checking_messages = [
{"role": "system", "content": iterative_checking_system_prompt},
{
"role": "user",
"content": [
{"type": "text", "text": f"The raw input image: "},
{"type": "image", "image": img_path},
{
"type": "text",
"text": f"The initial user input query is: '{initial_text_prompt}'",
},
{
"type": "text",
"text": f"Image with the predicted segmentation mask rendered on it: ",
},
{"type": "image", "image": image_w_mask_i_path},
{
"type": "text",
"text": f"Image with the zoomed-in mask: ",
},
{"type": "image", "image": image_w_zoomed_in_mask_i_path},
],
},
]
checking_generated_text = send_generate_request(
iterative_checking_messages
)
# Process the generated text to determine if the mask should be kept or rejected
if checking_generated_text is None:
raise ValueError(
"Generated text is None, which is unexpected. Please check the Qwen server and the input parameters."
)
print(f"Generated text for mask {i+1}: {checking_generated_text}")
verdict = (
checking_generated_text.split("<verdict>")[-1]
.split("</verdict>")[0]
.strip()
)
if "Accept" in verdict:
assert not "Reject" in verdict
print(f"Mask {i+1} accepted, keeping it in the outputs.")
masks_to_keep.append(i)
elif "Reject" in verdict:
assert not "Accept" in verdict
print(f"Mask {i+1} rejected, removing it from the outputs.")
else:
raise ValueError(
f"Unexpected verdict in generated text: {checking_generated_text}. Expected 'Accept' or 'Reject'."
)
updated_outputs = {
"original_image_path": current_outputs["original_image_path"],
"orig_img_h": current_outputs["orig_img_h"],
"orig_img_w": current_outputs["orig_img_w"],
"pred_boxes": [current_outputs["pred_boxes"][i] for i in masks_to_keep],
"pred_scores": [
current_outputs["pred_scores"][i] for i in masks_to_keep
],
"pred_masks": [current_outputs["pred_masks"][i] for i in masks_to_keep],
}
image_w_check_masks = visualize(updated_outputs)
image_w_check_masks_path = os.path.join(
sam_output_dir, rf"{LATEST_SAM3_TEXT_PROMPT}.png"
).replace(
".png",
f"_selected_masks_{'-'.join(map(str, [i+1 for i in masks_to_keep]))}.png".replace(
"/", "_"
),
)
image_w_check_masks.save(image_w_check_masks_path)
# save the updated json outputs and append to message history
messages.append(
{
"role": "assistant",
"content": [{"type": "text", "text": generated_text}],
}
)
if len(masks_to_keep) == 0:
messages.append(
{
"role": "user",
"content": [
{
"type": "text",
"text": f"The original user query was: '{initial_text_prompt}'. The examine_each_mask tool examined and rejected all of the masks generated by the segment_phrase tool. Now, please call the segment_phrase tool again with a different, perhaps more general, or more creative simple noun phrase text_prompt, while adhering to all the rules stated in the system prompt.",
}
],
}
)
else:
messages.append(
{
"role": "user",
"content": [
{
"type": "text",
"text": f"The original user query was: '{initial_text_prompt}'. After calling the examine_each_mask tool on the available masks, the number of available masks is now {len(masks_to_keep)}. All {len(masks_to_keep)} available masks are rendered in this image below, now you must analyze the {len(masks_to_keep)} available mask(s) carefully, compare them against the raw input image and the original user query, and determine your next action.",
},
{"type": "image", "image": image_w_check_masks_path},
],
}
)
# Create a new filename based on the original path to avoid filename length issues
base_path = PATH_TO_LATEST_OUTPUT_JSON
# Remove any existing "masks_" suffix to avoid duplication
if "masks_" in base_path:
base_path = base_path.split("masks_")[0] + ".json"
# Create new filename with current masks; use a clearer suffix when empty
if len(masks_to_keep) == 0:
PATH_TO_LATEST_OUTPUT_JSON = base_path.replace(
".json", "masks_none.json"
)
else:
PATH_TO_LATEST_OUTPUT_JSON = base_path.replace(
".json", f"masks_{'_'.join(map(str, masks_to_keep))}.json"
)
json.dump(updated_outputs, open(PATH_TO_LATEST_OUTPUT_JSON, "w"), indent=4)
elif tool_call["name"] == "select_masks_and_return":
print("🔍 Calling select_masks_and_return tool...")
current_outputs = json.load(open(PATH_TO_LATEST_OUTPUT_JSON, "r"))
assert list(tool_call["parameters"].keys()) == ["final_answer_masks"]
masks_to_keep = tool_call["parameters"]["final_answer_masks"]
# Keep only valid mask indices, remove duplicates, and preserve deterministic ascending order
available_masks = set(range(1, len(current_outputs["pred_masks"]) + 1))
masks_to_keep = sorted({i for i in masks_to_keep if i in available_masks})
# Change this to a update message telling the model to try again along with information about errors made.
final_outputs = {
"original_image_path": current_outputs["original_image_path"],
"orig_img_h": current_outputs["orig_img_h"],
"orig_img_w": current_outputs["orig_img_w"],
"pred_boxes": [
current_outputs["pred_boxes"][i - 1] for i in masks_to_keep
],
"pred_scores": [
current_outputs["pred_scores"][i - 1] for i in masks_to_keep
],
"pred_masks": [
current_outputs["pred_masks"][i - 1] for i in masks_to_keep
],
}
rendered_final_output = visualize(final_outputs)
messages.append(
{
"role": "assistant",
"content": [{"type": "text", "text": generated_text}],
}
)
# Clean up debug files before successful return
cleanup_debug_files(debug, debug_folder_path, debug_jsonl_path)
return messages, final_outputs, rendered_final_output
elif tool_call["name"] == "report_no_mask":
print("🔍 Calling report_no_mask tool...")
height, width = cv2.imread(img_path).shape[:2]
final_outputs = {
"original_image_path": img_path,
"orig_img_h": height,
"orig_img_w": width,
"pred_boxes": [],
"pred_scores": [],
"pred_masks": [],
}
rendered_final_output = Image.open(img_path)
messages.append(
{
"role": "assistant",
"content": [{"type": "text", "text": generated_text}],
}
)
return messages, final_outputs, rendered_final_output
else:
raise ValueError(f"Unknown tool call: {tool_call['name']}")
# sometimes the MLLM don't know when to stop, and generates multiple tool calls in one round, so we need to split the generated text by </tool> and only keep the first one
for message in messages:
if message["role"] == "assistant" and "content" in message:
for content in message["content"]:
if (
isinstance(content, dict)
and content.get("type") == "text"
and "text" in content
):
content["text"] = (
content["text"].split("</tool>", 1)[0] + "</tool>\n\n"
)
# Prune the messages history before the next MLLM generation round according to the 3-part rules.
# This keeps history compact and ensures the model sees only the allowed parts.
messages = _prune_messages_for_next_round(
messages,
USED_TEXT_PROMPTS,
LATEST_SAM3_TEXT_PROMPT,
img_path,
initial_text_prompt,
)
# make sure there can never be more than 2 images in the context
assert count_images(messages) <= 2
generation_count += 1
if generation_count > max_generations:
raise ValueError(
f"Exceeded maximum number of allowed generation requests ({max_generations})"
)
print("\n\n")
print("-" * 30 + f" Round {str(generation_count + 1)}" + "-" * 30)
print("\n\n")
generated_text = send_generate_request(messages)
print(
f"\n>>> MLLM Response [start]\n{generated_text}\n<<< MLLM Response [end]\n"
)
print("\n\n>>> SAM 3 Agent execution ended.\n\n")
error_save_path = os.path.join(
error_save_dir,
f"{img_path.rsplit('/', 1)[-1].rsplit('.', 1)[0]}_error_history.json",
)
with open(error_save_path, "w") as f:
json.dump(messages, f, indent=4)
print("Saved messages history that caused error to:", error_save_path)
raise ValueError(
rf"Generated text is None, which is unexpected. Please check the Qwen server and the input parameters for image path: {img_path} and initial text prompt: {initial_text_prompt}."
)

205
sam3/agent/client_llm.py Normal file
View File

@@ -0,0 +1,205 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import base64
import os
from typing import Any, Optional
from openai import OpenAI
def get_image_base64_and_mime(image_path):
"""Convert image file to base64 string and get MIME type"""
try:
# Get MIME type based on file extension
ext = os.path.splitext(image_path)[1].lower()
mime_types = {
".jpg": "image/jpeg",
".jpeg": "image/jpeg",
".png": "image/png",
".gif": "image/gif",
".webp": "image/webp",
".bmp": "image/bmp",
}
mime_type = mime_types.get(ext, "image/jpeg") # Default to JPEG
# Convert image to base64
with open(image_path, "rb") as image_file:
base64_data = base64.b64encode(image_file.read()).decode("utf-8")
return base64_data, mime_type
except Exception as e:
print(f"Error converting image to base64: {e}")
return None, None
def send_generate_request(
messages,
server_url=None,
model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
api_key=None,
max_tokens=4096,
):
"""
Sends a request to the OpenAI-compatible API endpoint using the OpenAI client library.
Args:
server_url (str): The base URL of the server, e.g. "http://127.0.0.1:8000"
messages (list): A list of message dicts, each containing role and content.
model (str): The model to use for generation (default: "llama-4")
max_tokens (int): Maximum number of tokens to generate (default: 4096)
Returns:
str: The generated response text from the server.
"""
# Process messages to convert image paths to base64
processed_messages = []
for message in messages:
processed_message = message.copy()
if message["role"] == "user" and "content" in message:
processed_content = []
for c in message["content"]:
if isinstance(c, dict) and c.get("type") == "image":
# Convert image path to base64 format
image_path = c["image"]
print("image_path", image_path)
new_image_path = image_path.replace(
"?", "%3F"
) # Escape ? in the path
# Read the image file and convert to base64
try:
base64_image, mime_type = get_image_base64_and_mime(
new_image_path
)
if base64_image is None:
print(
f"Warning: Could not convert image to base64: {new_image_path}"
)
continue
# Create the proper image_url structure with base64 data
processed_content.append(
{
"type": "image_url",
"image_url": {
"url": f"data:{mime_type};base64,{base64_image}",
"detail": "high",
},
}
)
except FileNotFoundError:
print(f"Warning: Image file not found: {new_image_path}")
continue
except Exception as e:
print(f"Warning: Error processing image {new_image_path}: {e}")
continue
else:
processed_content.append(c)
processed_message["content"] = processed_content
processed_messages.append(processed_message)
# Create OpenAI client with custom base URL
client = OpenAI(api_key=api_key, base_url=server_url)
try:
print(f"🔍 Calling model {model}...")
response = client.chat.completions.create(
model=model,
messages=processed_messages,
max_completion_tokens=max_tokens,
n=1,
)
# print(f"Received response: {response.choices[0].message}")
# Extract the response content
if response.choices and len(response.choices) > 0:
return response.choices[0].message.content
else:
print(f"Unexpected response format: {response}")
return None
except Exception as e:
print(f"Request failed: {e}")
return None
def send_direct_request(
llm: Any,
messages: list[dict[str, Any]],
sampling_params: Any,
) -> Optional[str]:
"""
Run inference on a vLLM model instance directly without using a server.
Args:
llm: Initialized vLLM LLM instance (passed from external initialization)
messages: List of message dicts with role and content (OpenAI format)
sampling_params: vLLM SamplingParams instance (initialized externally)
Returns:
str: Generated response text, or None if inference fails
"""
try:
# Process messages to handle images (convert to base64 if needed)
processed_messages = []
for message in messages:
processed_message = message.copy()
if message["role"] == "user" and "content" in message:
processed_content = []
for c in message["content"]:
if isinstance(c, dict) and c.get("type") == "image":
# Convert image path to base64 format
image_path = c["image"]
new_image_path = image_path.replace("?", "%3F")
try:
base64_image, mime_type = get_image_base64_and_mime(
new_image_path
)
if base64_image is None:
print(
f"Warning: Could not convert image: {new_image_path}"
)
continue
# vLLM expects image_url format
processed_content.append(
{
"type": "image_url",
"image_url": {
"url": f"data:{mime_type};base64,{base64_image}"
},
}
)
except Exception as e:
print(
f"Warning: Error processing image {new_image_path}: {e}"
)
continue
else:
processed_content.append(c)
processed_message["content"] = processed_content
processed_messages.append(processed_message)
print("🔍 Running direct inference with vLLM...")
# Run inference using vLLM's chat interface
outputs = llm.chat(
messages=processed_messages,
sampling_params=sampling_params,
)
# Extract the generated text from the first output
if outputs and len(outputs) > 0:
generated_text = outputs[0].outputs[0].text
return generated_text
else:
print(f"Unexpected output format: {outputs}")
return None
except Exception as e:
print(f"Direct inference failed: {e}")
return None

138
sam3/agent/client_sam3.py Executable file
View File

@@ -0,0 +1,138 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import json
import os
import torch
from PIL import Image
from sam3.model.box_ops import box_xyxy_to_xywh
from sam3.train.masks_ops import rle_encode
from .helpers.mask_overlap_removal import remove_overlapping_masks
from .viz import visualize
def sam3_inference(processor, image_path, text_prompt):
"""Run SAM 3 image inference with text prompts and format the outputs"""
image = Image.open(image_path)
orig_img_w, orig_img_h = image.size
# model inference
inference_state = processor.set_image(image)
inference_state = processor.set_text_prompt(
state=inference_state, prompt=text_prompt
)
# format and assemble outputs
pred_boxes_xyxy = torch.stack(
[
inference_state["boxes"][:, 0] / orig_img_w,
inference_state["boxes"][:, 1] / orig_img_h,
inference_state["boxes"][:, 2] / orig_img_w,
inference_state["boxes"][:, 3] / orig_img_h,
],
dim=-1,
) # normalized in range [0, 1]
pred_boxes_xywh = box_xyxy_to_xywh(pred_boxes_xyxy).tolist()
pred_masks = rle_encode(inference_state["masks"].squeeze(1))
pred_masks = [m["counts"] for m in pred_masks]
outputs = {
"orig_img_h": orig_img_h,
"orig_img_w": orig_img_w,
"pred_boxes": pred_boxes_xywh,
"pred_masks": pred_masks,
"pred_scores": inference_state["scores"].tolist(),
}
return outputs
def call_sam_service(
sam3_processor,
image_path: str,
text_prompt: str,
output_folder_path: str = "sam3_output",
):
"""
Loads an image, sends it with a text prompt to the service,
saves the results, and renders the visualization.
"""
print(f"📞 Loading image '{image_path}' and sending with prompt '{text_prompt}'...")
text_prompt_for_save_path = (
text_prompt.replace("/", "_") if "/" in text_prompt else text_prompt
)
os.makedirs(
os.path.join(output_folder_path, image_path.replace("/", "-")), exist_ok=True
)
output_json_path = os.path.join(
output_folder_path,
image_path.replace("/", "-"),
rf"{text_prompt_for_save_path}.json",
)
output_image_path = os.path.join(
output_folder_path,
image_path.replace("/", "-"),
rf"{text_prompt_for_save_path}.png",
)
try:
# Send the image and text prompt as a multipart/form-data request
serialized_response = sam3_inference(sam3_processor, image_path, text_prompt)
# 1. Prepare the response dictionary
serialized_response = remove_overlapping_masks(serialized_response)
serialized_response = {
"original_image_path": image_path,
"output_image_path": output_image_path,
**serialized_response,
}
# 2. Reorder predictions by scores (highest to lowest) if scores are available
if "pred_scores" in serialized_response and serialized_response["pred_scores"]:
# Create indices sorted by scores in descending order
score_indices = sorted(
range(len(serialized_response["pred_scores"])),
key=lambda i: serialized_response["pred_scores"][i],
reverse=True,
)
# Reorder all three lists based on the sorted indices
serialized_response["pred_scores"] = [
serialized_response["pred_scores"][i] for i in score_indices
]
serialized_response["pred_boxes"] = [
serialized_response["pred_boxes"][i] for i in score_indices
]
serialized_response["pred_masks"] = [
serialized_response["pred_masks"][i] for i in score_indices
]
# 3. Remove any invalid RLE masks that is too short (shorter than 5 characters)
valid_masks = []
valid_boxes = []
valid_scores = []
for i, rle in enumerate(serialized_response["pred_masks"]):
if len(rle) > 4:
valid_masks.append(rle)
valid_boxes.append(serialized_response["pred_boxes"][i])
valid_scores.append(serialized_response["pred_scores"][i])
serialized_response["pred_masks"] = valid_masks
serialized_response["pred_boxes"] = valid_boxes
serialized_response["pred_scores"] = valid_scores
with open(output_json_path, "w") as f:
json.dump(serialized_response, f, indent=4)
print(f"✅ Raw JSON response saved to '{output_json_path}'")
# 4. Render and save visualizations on the image and save it in the SAM3 output folder
print("🔍 Rendering visualizations on the image ...")
viz_image = visualize(serialized_response)
os.makedirs(os.path.dirname(output_image_path), exist_ok=True)
viz_image.save(output_image_path)
print("✅ Saved visualization at:", output_image_path)
except Exception as e:
print(f"❌ Error calling service: {e}")
return output_json_path

1
sam3/agent/helpers/__init__.py Executable file
View File

@@ -0,0 +1 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

438
sam3/agent/helpers/boxes.py Executable file
View File

@@ -0,0 +1,438 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import math
from enum import IntEnum, unique
from typing import List, Tuple, Union
import numpy as np
import torch
from torch import device
_RawBoxType = Union[List[float], Tuple[float, ...], torch.Tensor, np.ndarray]
@unique
class BoxMode(IntEnum):
"""
Enum of different ways to represent a box.
"""
XYXY_ABS = 0
"""
(x0, y0, x1, y1) in absolute floating points coordinates.
The coordinates in range [0, width or height].
"""
XYWH_ABS = 1
"""
(x0, y0, w, h) in absolute floating points coordinates.
"""
XYXY_REL = 2
"""
Not yet supported!
(x0, y0, x1, y1) in range [0, 1]. They are relative to the size of the image.
"""
XYWH_REL = 3
"""
Not yet supported!
(x0, y0, w, h) in range [0, 1]. They are relative to the size of the image.
"""
XYWHA_ABS = 4
"""
(xc, yc, w, h, a) in absolute floating points coordinates.
(xc, yc) is the center of the rotated box, and the angle a is in degrees ccw.
"""
@staticmethod
def convert(
box: _RawBoxType, from_mode: "BoxMode", to_mode: "BoxMode"
) -> _RawBoxType:
"""
Args:
box: can be a k-tuple, k-list or an Nxk array/tensor, where k = 4 or 5
from_mode, to_mode (BoxMode)
Returns:
The converted box of the same type.
"""
if from_mode == to_mode:
return box
original_type = type(box)
is_numpy = isinstance(box, np.ndarray)
single_box = isinstance(box, (list, tuple))
if single_box:
assert len(box) == 4 or len(box) == 5, (
"BoxMode.convert takes either a k-tuple/list or an Nxk array/tensor,"
" where k == 4 or 5"
)
arr = torch.tensor(box)[None, :]
else:
# avoid modifying the input box
if is_numpy:
arr = torch.from_numpy(np.asarray(box)).clone()
else:
arr = box.clone()
assert to_mode not in [
BoxMode.XYXY_REL,
BoxMode.XYWH_REL,
] and from_mode not in [
BoxMode.XYXY_REL,
BoxMode.XYWH_REL,
], "Relative mode not yet supported!"
if from_mode == BoxMode.XYWHA_ABS and to_mode == BoxMode.XYXY_ABS:
assert (
arr.shape[-1] == 5
), "The last dimension of input shape must be 5 for XYWHA format"
original_dtype = arr.dtype
arr = arr.double()
w = arr[:, 2]
h = arr[:, 3]
a = arr[:, 4]
c = torch.abs(torch.cos(a * math.pi / 180.0))
s = torch.abs(torch.sin(a * math.pi / 180.0))
# This basically computes the horizontal bounding rectangle of the rotated box
new_w = c * w + s * h
new_h = c * h + s * w
# convert center to top-left corner
arr[:, 0] -= new_w / 2.0
arr[:, 1] -= new_h / 2.0
# bottom-right corner
arr[:, 2] = arr[:, 0] + new_w
arr[:, 3] = arr[:, 1] + new_h
arr = arr[:, :4].to(dtype=original_dtype)
elif from_mode == BoxMode.XYWH_ABS and to_mode == BoxMode.XYWHA_ABS:
original_dtype = arr.dtype
arr = arr.double()
arr[:, 0] += arr[:, 2] / 2.0
arr[:, 1] += arr[:, 3] / 2.0
angles = torch.zeros((arr.shape[0], 1), dtype=arr.dtype)
arr = torch.cat((arr, angles), axis=1).to(dtype=original_dtype)
else:
if to_mode == BoxMode.XYXY_ABS and from_mode == BoxMode.XYWH_ABS:
arr[:, 2] += arr[:, 0]
arr[:, 3] += arr[:, 1]
elif from_mode == BoxMode.XYXY_ABS and to_mode == BoxMode.XYWH_ABS:
arr[:, 2] -= arr[:, 0]
arr[:, 3] -= arr[:, 1]
else:
raise NotImplementedError(
"Conversion from BoxMode {} to {} is not supported yet".format(
from_mode, to_mode
)
)
if single_box:
return original_type(arr.flatten().tolist())
if is_numpy:
return arr.numpy()
else:
return arr
class Boxes:
"""
This structure stores a list of boxes as a Nx4 torch.Tensor.
It supports some common methods about boxes
(`area`, `clip`, `nonempty`, etc),
and also behaves like a Tensor
(support indexing, `to(device)`, `.device`, and iteration over all boxes)
Attributes:
tensor (torch.Tensor): float matrix of Nx4. Each row is (x1, y1, x2, y2).
"""
def __init__(self, tensor: torch.Tensor):
"""
Args:
tensor (Tensor[float]): a Nx4 matrix. Each row is (x1, y1, x2, y2).
"""
if not isinstance(tensor, torch.Tensor):
tensor = torch.as_tensor(
tensor, dtype=torch.float32, device=torch.device("cpu")
)
else:
tensor = tensor.to(torch.float32)
if tensor.numel() == 0:
# Use reshape, so we don't end up creating a new tensor that does not depend on
# the inputs (and consequently confuses jit)
tensor = tensor.reshape((-1, 4)).to(dtype=torch.float32)
assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size()
self.tensor = tensor
def clone(self) -> "Boxes":
"""
Clone the Boxes.
Returns:
Boxes
"""
return Boxes(self.tensor.clone())
def to(self, device: torch.device):
# Boxes are assumed float32 and does not support to(dtype)
return Boxes(self.tensor.to(device=device))
def area(self) -> torch.Tensor:
"""
Computes the area of all the boxes.
Returns:
torch.Tensor: a vector with areas of each box.
"""
box = self.tensor
area = (box[:, 2] - box[:, 0]) * (box[:, 3] - box[:, 1])
return area
def clip(self, box_size: Tuple[int, int]) -> None:
"""
Clip (in place) the boxes by limiting x coordinates to the range [0, width]
and y coordinates to the range [0, height].
Args:
box_size (height, width): The clipping box's size.
"""
assert torch.isfinite(self.tensor).all(), "Box tensor contains infinite or NaN!"
h, w = box_size
x1 = self.tensor[:, 0].clamp(min=0, max=w)
y1 = self.tensor[:, 1].clamp(min=0, max=h)
x2 = self.tensor[:, 2].clamp(min=0, max=w)
y2 = self.tensor[:, 3].clamp(min=0, max=h)
self.tensor = torch.stack((x1, y1, x2, y2), dim=-1)
def nonempty(self, threshold: float = 0.0) -> torch.Tensor:
"""
Find boxes that are non-empty.
A box is considered empty, if either of its side is no larger than threshold.
Returns:
Tensor:
a binary vector which represents whether each box is empty
(False) or non-empty (True).
"""
box = self.tensor
widths = box[:, 2] - box[:, 0]
heights = box[:, 3] - box[:, 1]
keep = (widths > threshold) & (heights > threshold)
return keep
def __getitem__(self, item) -> "Boxes":
"""
Args:
item: int, slice, or a BoolTensor
Returns:
Boxes: Create a new :class:`Boxes` by indexing.
The following usage are allowed:
1. `new_boxes = boxes[3]`: return a `Boxes` which contains only one box.
2. `new_boxes = boxes[2:10]`: return a slice of boxes.
3. `new_boxes = boxes[vector]`, where vector is a torch.BoolTensor
with `length = len(boxes)`. Nonzero elements in the vector will be selected.
Note that the returned Boxes might share storage with this Boxes,
subject to Pytorch's indexing semantics.
"""
if isinstance(item, int):
return Boxes(self.tensor[item].view(1, -1))
b = self.tensor[item]
assert (
b.dim() == 2
), "Indexing on Boxes with {} failed to return a matrix!".format(item)
return Boxes(b)
def __len__(self) -> int:
return self.tensor.shape[0]
def __repr__(self) -> str:
return "Boxes(" + str(self.tensor) + ")"
def inside_box(
self, box_size: Tuple[int, int], boundary_threshold: int = 0
) -> torch.Tensor:
"""
Args:
box_size (height, width): Size of the reference box.
boundary_threshold (int): Boxes that extend beyond the reference box
boundary by more than boundary_threshold are considered "outside".
Returns:
a binary vector, indicating whether each box is inside the reference box.
"""
height, width = box_size
inds_inside = (
(self.tensor[..., 0] >= -boundary_threshold)
& (self.tensor[..., 1] >= -boundary_threshold)
& (self.tensor[..., 2] < width + boundary_threshold)
& (self.tensor[..., 3] < height + boundary_threshold)
)
return inds_inside
def get_centers(self) -> torch.Tensor:
"""
Returns:
The box centers in a Nx2 array of (x, y).
"""
return (self.tensor[:, :2] + self.tensor[:, 2:]) / 2
def scale(self, scale_x: float, scale_y: float) -> None:
"""
Scale the box with horizontal and vertical scaling factors
"""
self.tensor[:, 0::2] *= scale_x
self.tensor[:, 1::2] *= scale_y
@classmethod
def cat(cls, boxes_list: List["Boxes"]) -> "Boxes":
"""
Concatenates a list of Boxes into a single Boxes
Arguments:
boxes_list (list[Boxes])
Returns:
Boxes: the concatenated Boxes
"""
assert isinstance(boxes_list, (list, tuple))
if len(boxes_list) == 0:
return cls(torch.empty(0))
assert all([isinstance(box, Boxes) for box in boxes_list])
# use torch.cat (v.s. layers.cat) so the returned boxes never share storage with input
cat_boxes = cls(torch.cat([b.tensor for b in boxes_list], dim=0))
return cat_boxes
@property
def device(self) -> device:
return self.tensor.device
# type "Iterator[torch.Tensor]", yield, and iter() not supported by torchscript
# https://github.com/pytorch/pytorch/issues/18627
@torch.jit.unused
def __iter__(self):
"""
Yield a box as a Tensor of shape (4,) at a time.
"""
yield from self.tensor
def pairwise_intersection(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor:
"""
Given two lists of boxes of size N and M,
compute the intersection area between __all__ N x M pairs of boxes.
The box order must be (xmin, ymin, xmax, ymax)
Args:
boxes1,boxes2 (Boxes): two `Boxes`. Contains N & M boxes, respectively.
Returns:
Tensor: intersection, sized [N,M].
"""
boxes1, boxes2 = boxes1.tensor, boxes2.tensor
width_height = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) - torch.max(
boxes1[:, None, :2], boxes2[:, :2]
) # [N,M,2]
width_height.clamp_(min=0) # [N,M,2]
intersection = width_height.prod(dim=2) # [N,M]
return intersection
# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
# with slight modifications
def pairwise_iou(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor:
"""
Given two lists of boxes of size N and M, compute the IoU
(intersection over union) between **all** N x M pairs of boxes.
The box order must be (xmin, ymin, xmax, ymax).
Args:
boxes1,boxes2 (Boxes): two `Boxes`. Contains N & M boxes, respectively.
Returns:
Tensor: IoU, sized [N,M].
"""
area1 = boxes1.area() # [N]
area2 = boxes2.area() # [M]
inter = pairwise_intersection(boxes1, boxes2)
# handle empty boxes
iou = torch.where(
inter > 0,
inter / (area1[:, None] + area2 - inter),
torch.zeros(1, dtype=inter.dtype, device=inter.device),
)
return iou
def pairwise_ioa(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor:
"""
Similar to :func:`pariwise_iou` but compute the IoA (intersection over boxes2 area).
Args:
boxes1,boxes2 (Boxes): two `Boxes`. Contains N & M boxes, respectively.
Returns:
Tensor: IoA, sized [N,M].
"""
area2 = boxes2.area() # [M]
inter = pairwise_intersection(boxes1, boxes2)
# handle empty boxes
ioa = torch.where(
inter > 0, inter / area2, torch.zeros(1, dtype=inter.dtype, device=inter.device)
)
return ioa
def pairwise_point_box_distance(points: torch.Tensor, boxes: Boxes):
"""
Pairwise distance between N points and M boxes. The distance between a
point and a box is represented by the distance from the point to 4 edges
of the box. Distances are all positive when the point is inside the box.
Args:
points: Nx2 coordinates. Each row is (x, y)
boxes: M boxes
Returns:
Tensor: distances of size (N, M, 4). The 4 values are distances from
the point to the left, top, right, bottom of the box.
"""
x, y = points.unsqueeze(dim=2).unbind(dim=1) # (N, 1)
x0, y0, x1, y1 = boxes.tensor.unsqueeze(dim=0).unbind(dim=2) # (1, M)
return torch.stack([x - x0, y - y0, x1 - x, y1 - y], dim=2)
def matched_pairwise_iou(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor:
"""
Compute pairwise intersection over union (IOU) of two sets of matched
boxes that have the same number of boxes.
Similar to :func:`pairwise_iou`, but computes only diagonal elements of the matrix.
Args:
boxes1 (Boxes): bounding boxes, sized [N,4].
boxes2 (Boxes): same length as boxes1
Returns:
Tensor: iou, sized [N].
"""
assert len(boxes1) == len(boxes2), (
"boxlists should have the same" "number of entries, got {}, {}".format(
len(boxes1), len(boxes2)
)
)
area1 = boxes1.area() # [N]
area2 = boxes2.area() # [N]
box1, box2 = boxes1.tensor, boxes2.tensor
lt = torch.max(box1[:, :2], box2[:, :2]) # [N,2]
rb = torch.min(box1[:, 2:], box2[:, 2:]) # [N,2]
wh = (rb - lt).clamp(min=0) # [N,2]
inter = wh[:, 0] * wh[:, 1] # [N]
iou = inter / (area1 + area2 - inter) # [N]
return iou

View File

@@ -0,0 +1,150 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
"""
An awesome colormap for really neat visualizations.
Copied from Detectron, and removed gray colors.
"""
import random
import numpy as np
__all__ = ["colormap", "random_color", "random_colors"]
# A list of 25 bright and sharp colors for segmentation masks,
# generated from the edges of the sRGB color space for maximum intensity.
_COLORS = (
np.array(
[
# The original 8 sharp colors
1.000,
1.000,
0.000, # 1. Yellow
0.000,
1.000,
0.000, # 2. Lime
0.000,
1.000,
1.000, # 3. Cyan
1.000,
0.000,
1.000, # 4. Magenta
1.000,
0.000,
0.000, # 5. Red
1.000,
0.498,
0.000, # 6. Orange
0.498,
1.000,
0.000, # 7. Chartreuse
0.000,
1.000,
0.498, # 8. Spring Green
1.000,
0.000,
0.498, # 9. Rose
0.498,
0.000,
1.000, # 10. Violet
0.753,
1.000,
0.000, # 11. Electric Lime
1.000,
0.753,
0.000, # 12. Vivid Orange
0.000,
1.000,
0.753, # 13. Turquoise
0.753,
0.000,
1.000, # 14. Bright Violet
1.000,
0.000,
0.753, # 15. Bright Pink
1.000,
0.251,
0.000, # 16. Fiery Orange
0.251,
1.000,
0.000, # 17. Bright Chartreuse
0.000,
1.000,
0.251, # 18. Malachite Green
0.251,
0.000,
1.000, # 19. Deep Violet
1.000,
0.000,
0.251, # 20. Hot Pink
]
)
.astype(np.float32)
.reshape(-1, 3)
)
def colormap(rgb=False, maximum=255):
"""
Args:
rgb (bool): whether to return RGB colors or BGR colors.
maximum (int): either 255 or 1
Returns:
ndarray: a float32 array of Nx3 colors, in range [0, 255] or [0, 1]
"""
assert maximum in [255, 1], maximum
c = _COLORS * maximum
if not rgb:
c = c[:, ::-1]
return c
def random_color(rgb=False, maximum=255):
"""
Args:
rgb (bool): whether to return RGB colors or BGR colors.
maximum (int): either 255 or 1
Returns:
ndarray: a vector of 3 numbers
"""
idx = np.random.randint(0, len(_COLORS))
ret = _COLORS[idx] * maximum
if not rgb:
ret = ret[::-1]
return ret
def random_colors(N, rgb=False, maximum=255):
"""
Args:
N (int): number of unique colors needed
rgb (bool): whether to return RGB colors or BGR colors.
maximum (int): either 255 or 1
Returns:
ndarray: a list of random_color
"""
indices = random.sample(range(len(_COLORS)), N)
ret = [_COLORS[i] * maximum for i in indices]
if not rgb:
ret = [x[::-1] for x in ret]
return ret
if __name__ == "__main__":
import cv2
size = 100
H, W = 10, 10
canvas = np.random.rand(H * size, W * size, 3).astype("float32")
for h in range(H):
for w in range(W):
idx = h * W + w
if idx >= len(_COLORS):
break
canvas[h * size : (h + 1) * size, w * size : (w + 1) * size] = _COLORS[idx]
cv2.imshow("a", canvas)
cv2.waitKey(0)

244
sam3/agent/helpers/keypoints.py Executable file
View File

@@ -0,0 +1,244 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
from typing import Any, List, Tuple, Union
import numpy as np
import torch
from torch.nn import functional as F
class Keypoints:
"""
Stores keypoint **annotation** data. GT Instances have a `gt_keypoints` property
containing the x,y location and visibility flag of each keypoint. This tensor has shape
(N, K, 3) where N is the number of instances and K is the number of keypoints per instance.
The visibility flag follows the COCO format and must be one of three integers:
* v=0: not labeled (in which case x=y=0)
* v=1: labeled but not visible
* v=2: labeled and visible
"""
def __init__(self, keypoints: Union[torch.Tensor, np.ndarray, List[List[float]]]):
"""
Arguments:
keypoints: A Tensor, numpy array, or list of the x, y, and visibility of each keypoint.
The shape should be (N, K, 3) where N is the number of
instances, and K is the number of keypoints per instance.
"""
device = (
keypoints.device
if isinstance(keypoints, torch.Tensor)
else torch.device("cpu")
)
keypoints = torch.as_tensor(keypoints, dtype=torch.float32, device=device)
assert keypoints.dim() == 3 and keypoints.shape[2] == 3, keypoints.shape
self.tensor = keypoints
def __len__(self) -> int:
return self.tensor.size(0)
def to(self, *args: Any, **kwargs: Any) -> "Keypoints":
return type(self)(self.tensor.to(*args, **kwargs))
@property
def device(self) -> torch.device:
return self.tensor.device
def to_heatmap(self, boxes: torch.Tensor, heatmap_size: int) -> torch.Tensor:
"""
Convert keypoint annotations to a heatmap of one-hot labels for training,
as described in :paper:`Mask R-CNN`.
Arguments:
boxes: Nx4 tensor, the boxes to draw the keypoints to
Returns:
heatmaps:
A tensor of shape (N, K), each element is integer spatial label
in the range [0, heatmap_size**2 - 1] for each keypoint in the input.
valid:
A tensor of shape (N, K) containing whether each keypoint is in the roi or not.
"""
return _keypoints_to_heatmap(self.tensor, boxes, heatmap_size)
def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Keypoints":
"""
Create a new `Keypoints` by indexing on this `Keypoints`.
The following usage are allowed:
1. `new_kpts = kpts[3]`: return a `Keypoints` which contains only one instance.
2. `new_kpts = kpts[2:10]`: return a slice of key points.
3. `new_kpts = kpts[vector]`, where vector is a torch.ByteTensor
with `length = len(kpts)`. Nonzero elements in the vector will be selected.
Note that the returned Keypoints might share storage with this Keypoints,
subject to Pytorch's indexing semantics.
"""
if isinstance(item, int):
return Keypoints([self.tensor[item]])
return Keypoints(self.tensor[item])
def __repr__(self) -> str:
s = self.__class__.__name__ + "("
s += "num_instances={})".format(len(self.tensor))
return s
@staticmethod
def cat(keypoints_list: List["Keypoints"]) -> "Keypoints":
"""
Concatenates a list of Keypoints into a single Keypoints
Arguments:
keypoints_list (list[Keypoints])
Returns:
Keypoints: the concatenated Keypoints
"""
assert isinstance(keypoints_list, (list, tuple))
assert len(keypoints_list) > 0
assert all(isinstance(keypoints, Keypoints) for keypoints in keypoints_list)
cat_kpts = type(keypoints_list[0])(
torch.cat([kpts.tensor for kpts in keypoints_list], dim=0)
)
return cat_kpts
def _keypoints_to_heatmap(
keypoints: torch.Tensor, rois: torch.Tensor, heatmap_size: int
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Encode keypoint locations into a target heatmap for use in SoftmaxWithLoss across space.
Maps keypoints from the half-open interval [x1, x2) on continuous image coordinates to the
closed interval [0, heatmap_size - 1] on discrete image coordinates. We use the
continuous-discrete conversion from Heckbert 1990 ("What is the coordinate of a pixel?"):
d = floor(c) and c = d + 0.5, where d is a discrete coordinate and c is a continuous coordinate.
Arguments:
keypoints: tensor of keypoint locations in of shape (N, K, 3).
rois: Nx4 tensor of rois in xyxy format
heatmap_size: integer side length of square heatmap.
Returns:
heatmaps: A tensor of shape (N, K) containing an integer spatial label
in the range [0, heatmap_size**2 - 1] for each keypoint in the input.
valid: A tensor of shape (N, K) containing whether each keypoint is in
the roi or not.
"""
if rois.numel() == 0:
return rois.new().long(), rois.new().long()
offset_x = rois[:, 0]
offset_y = rois[:, 1]
scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
offset_x = offset_x[:, None]
offset_y = offset_y[:, None]
scale_x = scale_x[:, None]
scale_y = scale_y[:, None]
x = keypoints[..., 0]
y = keypoints[..., 1]
x_boundary_inds = x == rois[:, 2][:, None]
y_boundary_inds = y == rois[:, 3][:, None]
x = (x - offset_x) * scale_x
x = x.floor().long()
y = (y - offset_y) * scale_y
y = y.floor().long()
x[x_boundary_inds] = heatmap_size - 1
y[y_boundary_inds] = heatmap_size - 1
valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
vis = keypoints[..., 2] > 0
valid = (valid_loc & vis).long()
lin_ind = y * heatmap_size + x
heatmaps = lin_ind * valid
return heatmaps, valid
@torch.jit.script_if_tracing
def heatmaps_to_keypoints(maps: torch.Tensor, rois: torch.Tensor) -> torch.Tensor:
"""
Extract predicted keypoint locations from heatmaps.
Args:
maps (Tensor): (#ROIs, #keypoints, POOL_H, POOL_W). The predicted heatmap of logits for
each ROI and each keypoint.
rois (Tensor): (#ROIs, 4). The box of each ROI.
Returns:
Tensor of shape (#ROIs, #keypoints, 4) with the last dimension corresponding to
(x, y, logit, score) for each keypoint.
When converting discrete pixel indices in an NxN image to a continuous keypoint coordinate,
we maintain consistency with :meth:`Keypoints.to_heatmap` by using the conversion from
Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a continuous coordinate.
"""
offset_x = rois[:, 0]
offset_y = rois[:, 1]
widths = (rois[:, 2] - rois[:, 0]).clamp(min=1)
heights = (rois[:, 3] - rois[:, 1]).clamp(min=1)
widths_ceil = widths.ceil()
heights_ceil = heights.ceil()
num_rois, num_keypoints = maps.shape[:2]
xy_preds = maps.new_zeros(rois.shape[0], num_keypoints, 4)
width_corrections = widths / widths_ceil
height_corrections = heights / heights_ceil
keypoints_idx = torch.arange(num_keypoints, device=maps.device)
for i in range(num_rois):
outsize = (int(heights_ceil[i]), int(widths_ceil[i]))
roi_map = F.interpolate(
maps[[i]], size=outsize, mode="bicubic", align_corners=False
)
# Although semantically equivalent, `reshape` is used instead of `squeeze` due
# to limitation during ONNX export of `squeeze` in scripting mode
roi_map = roi_map.reshape(roi_map.shape[1:]) # keypoints x H x W
# softmax over the spatial region
max_score, _ = roi_map.view(num_keypoints, -1).max(1)
max_score = max_score.view(num_keypoints, 1, 1)
tmp_full_resolution = (roi_map - max_score).exp_()
tmp_pool_resolution = (maps[i] - max_score).exp_()
# Produce scores over the region H x W, but normalize with POOL_H x POOL_W,
# so that the scores of objects of different absolute sizes will be more comparable
roi_map_scores = tmp_full_resolution / tmp_pool_resolution.sum(
(1, 2), keepdim=True
)
w = roi_map.shape[2]
pos = roi_map.view(num_keypoints, -1).argmax(1)
x_int = pos % w
y_int = (pos - x_int) // w
assert (
roi_map_scores[keypoints_idx, y_int, x_int]
== roi_map_scores.view(num_keypoints, -1).max(1)[0]
).all()
x = (x_int.float() + 0.5) * width_corrections[i]
y = (y_int.float() + 0.5) * height_corrections[i]
xy_preds[i, :, 0] = x + offset_x[i]
xy_preds[i, :, 1] = y + offset_y[i]
xy_preds[i, :, 2] = roi_map[keypoints_idx, y_int, x_int]
xy_preds[i, :, 3] = roi_map_scores[keypoints_idx, y_int, x_int]
return xy_preds

View File

@@ -0,0 +1,128 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
from typing import Dict, List
import numpy as np
import torch
try:
from pycocotools import mask as mask_utils
except Exception:
mask_utils = None
def mask_intersection(
masks1: torch.Tensor, masks2: torch.Tensor, block_size: int = 16
) -> torch.Tensor:
assert masks1.shape[1:] == masks2.shape[1:]
assert masks1.dtype == torch.bool and masks2.dtype == torch.bool
N, M = masks1.shape[0], masks2.shape[0]
out = torch.zeros(N, M, device=masks1.device, dtype=torch.long)
for i in range(0, N, block_size):
for j in range(0, M, block_size):
a = masks1[i : i + block_size]
b = masks2[j : j + block_size]
inter = (a[:, None] & b[None, :]).flatten(-2).sum(-1)
out[i : i + block_size, j : j + block_size] = inter
return out
def mask_iom(masks1: torch.Tensor, masks2: torch.Tensor) -> torch.Tensor:
assert masks1.shape[1:] == masks2.shape[1:]
assert masks1.dtype == torch.bool and masks2.dtype == torch.bool
inter = mask_intersection(masks1, masks2)
area1 = masks1.flatten(-2).sum(-1) # (N,)
area2 = masks2.flatten(-2).sum(-1) # (M,)
min_area = torch.min(area1[:, None], area2[None, :]).clamp_min(1)
return inter.float() / (min_area.float() + 1e-8)
def _decode_single_mask(mask_repr, h: int, w: int) -> np.ndarray:
if isinstance(mask_repr, (list, tuple, np.ndarray)):
arr = np.array(mask_repr)
if arr.ndim != 2:
raise ValueError("Mask array must be 2D (H, W).")
return (arr > 0).astype(np.uint8)
if mask_utils is None:
raise ImportError(
"pycocotools is required to decode RLE mask strings. pip install pycocotools"
)
if not isinstance(mask_repr, (str, bytes)):
raise ValueError("Unsupported mask representation type for RLE decode.")
rle = {
"counts": mask_repr if isinstance(mask_repr, (str, bytes)) else str(mask_repr),
"size": [h, w],
}
decoded = mask_utils.decode(rle)
if decoded.ndim == 3:
decoded = decoded[:, :, 0]
return (decoded > 0).astype(np.uint8)
def _decode_masks_to_torch_bool(pred_masks: List, h: int, w: int) -> torch.Tensor:
bin_masks = [_decode_single_mask(m, h, w) for m in pred_masks]
masks_np = np.stack(bin_masks, axis=0).astype(np.uint8) # (N, H, W)
return torch.from_numpy(masks_np > 0)
def remove_overlapping_masks(sample: Dict, iom_thresh: float = 0.3) -> Dict:
"""
Greedy keep: sort by score desc; keep a mask if IoM to all kept masks <= threshold.
If pred_masks has length 0 or 1, returns sample unchanged (no extra keys).
"""
# Basic presence checks
if "pred_masks" not in sample or not isinstance(sample["pred_masks"], list):
return sample # nothing to do / preserve as-is
pred_masks = sample["pred_masks"]
N = len(pred_masks)
# --- Early exit: 0 or 1 mask -> do NOT modify the JSON at all ---
if N <= 1:
return sample
# From here on we have at least 2 masks
h = int(sample["orig_img_h"])
w = int(sample["orig_img_w"])
pred_scores = sample.get("pred_scores", [1.0] * N) # fallback if scores missing
pred_boxes = sample.get("pred_boxes", None)
assert N == len(pred_scores), "pred_masks and pred_scores must have same length"
if pred_boxes is not None:
assert N == len(pred_boxes), "pred_masks and pred_boxes must have same length"
masks_bool = _decode_masks_to_torch_bool(pred_masks, h, w) # (N, H, W)
order = sorted(range(N), key=lambda i: float(pred_scores[i]), reverse=True)
kept_idx: List[int] = []
kept_masks: List[torch.Tensor] = []
for i in order:
cand = masks_bool[i].unsqueeze(0) # (1, H, W)
if len(kept_masks) == 0:
kept_idx.append(i)
kept_masks.append(masks_bool[i])
continue
kept_stack = torch.stack(kept_masks, dim=0) # (K, H, W)
iom_vals = mask_iom(cand, kept_stack).squeeze(0) # (K,)
if torch.any(iom_vals > iom_thresh):
continue # overlaps too much with a higher-scored kept mask
kept_idx.append(i)
kept_masks.append(masks_bool[i])
kept_idx_sorted = sorted(kept_idx)
# Build filtered JSON (this *does* modify fields; only for N>=2 case)
out = dict(sample)
out["pred_masks"] = [pred_masks[i] for i in kept_idx_sorted]
out["pred_scores"] = [pred_scores[i] for i in kept_idx_sorted]
if pred_boxes is not None:
out["pred_boxes"] = [pred_boxes[i] for i in kept_idx_sorted]
out["kept_indices"] = kept_idx_sorted
out["removed_indices"] = [i for i in range(N) if i not in set(kept_idx_sorted)]
out["iom_threshold"] = float(iom_thresh)
return out

560
sam3/agent/helpers/masks.py Executable file
View File

@@ -0,0 +1,560 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import copy
import itertools
from typing import Any, Iterator, List, Union
import numpy as np
import pycocotools.mask as mask_util
import torch
from torch import device
from .boxes import Boxes
from .memory import retry_if_cuda_oom
from .roi_align import ROIAlign
def polygon_area(x, y):
# Using the shoelace formula
# https://stackoverflow.com/questions/24467972/calculate-area-of-polygon-given-x-y-coordinates
return 0.5 * np.abs(np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1)))
def polygons_to_bitmask(
polygons: List[np.ndarray], height: int, width: int
) -> np.ndarray:
"""
Args:
polygons (list[ndarray]): each array has shape (Nx2,)
height, width (int)
Returns:
ndarray: a bool mask of shape (height, width)
"""
if len(polygons) == 0:
# COCOAPI does not support empty polygons
return np.zeros((height, width)).astype(bool)
rles = mask_util.frPyObjects(polygons, height, width)
rle = mask_util.merge(rles)
return mask_util.decode(rle).astype(bool)
def rasterize_polygons_within_box(
polygons: List[np.ndarray], box: np.ndarray, mask_size: int
) -> torch.Tensor:
"""
Rasterize the polygons into a mask image and
crop the mask content in the given box.
The cropped mask is resized to (mask_size, mask_size).
This function is used when generating training targets for mask head in Mask R-CNN.
Given original ground-truth masks for an image, new ground-truth mask
training targets in the size of `mask_size x mask_size`
must be provided for each predicted box. This function will be called to
produce such targets.
Args:
polygons (list[ndarray[float]]): a list of polygons, which represents an instance.
box: 4-element numpy array
mask_size (int):
Returns:
Tensor: BoolTensor of shape (mask_size, mask_size)
"""
# 1. Shift the polygons w.r.t the boxes
w, h = box[2] - box[0], box[3] - box[1]
polygons = copy.deepcopy(polygons)
for p in polygons:
p[0::2] = p[0::2] - box[0]
p[1::2] = p[1::2] - box[1]
# 2. Rescale the polygons to the new box size
# max() to avoid division by small number
ratio_h = mask_size / max(h, 0.1)
ratio_w = mask_size / max(w, 0.1)
if ratio_h == ratio_w:
for p in polygons:
p *= ratio_h
else:
for p in polygons:
p[0::2] *= ratio_w
p[1::2] *= ratio_h
# 3. Rasterize the polygons with coco api
mask = polygons_to_bitmask(polygons, mask_size, mask_size)
mask = torch.from_numpy(mask)
return mask
class BitMasks:
"""
This class stores the segmentation masks for all objects in one image, in
the form of bitmaps.
Attributes:
tensor: bool Tensor of N,H,W, representing N instances in the image.
"""
def __init__(self, tensor: Union[torch.Tensor, np.ndarray]):
"""
Args:
tensor: bool Tensor of N,H,W, representing N instances in the image.
"""
if isinstance(tensor, torch.Tensor):
tensor = tensor.to(torch.bool)
else:
tensor = torch.as_tensor(
tensor, dtype=torch.bool, device=torch.device("cpu")
)
assert tensor.dim() == 3, tensor.size()
self.image_size = tensor.shape[1:]
self.tensor = tensor
@torch.jit.unused
def to(self, *args: Any, **kwargs: Any) -> "BitMasks":
return BitMasks(self.tensor.to(*args, **kwargs))
@property
def device(self) -> torch.device:
return self.tensor.device
@torch.jit.unused
def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "BitMasks":
"""
Returns:
BitMasks: Create a new :class:`BitMasks` by indexing.
The following usage are allowed:
1. `new_masks = masks[3]`: return a `BitMasks` which contains only one mask.
2. `new_masks = masks[2:10]`: return a slice of masks.
3. `new_masks = masks[vector]`, where vector is a torch.BoolTensor
with `length = len(masks)`. Nonzero elements in the vector will be selected.
Note that the returned object might share storage with this object,
subject to Pytorch's indexing semantics.
"""
if isinstance(item, int):
return BitMasks(self.tensor[item].unsqueeze(0))
m = self.tensor[item]
assert (
m.dim() == 3
), "Indexing on BitMasks with {} returns a tensor with shape {}!".format(
item, m.shape
)
return BitMasks(m)
@torch.jit.unused
def __iter__(self) -> torch.Tensor:
yield from self.tensor
@torch.jit.unused
def __repr__(self) -> str:
s = self.__class__.__name__ + "("
s += "num_instances={})".format(len(self.tensor))
return s
def __len__(self) -> int:
return self.tensor.shape[0]
def nonempty(self) -> torch.Tensor:
"""
Find masks that are non-empty.
Returns:
Tensor: a BoolTensor which represents
whether each mask is empty (False) or non-empty (True).
"""
return self.tensor.flatten(1).any(dim=1)
@staticmethod
def from_polygon_masks(
polygon_masks: Union["PolygonMasks", List[List[np.ndarray]]],
height: int,
width: int,
) -> "BitMasks":
"""
Args:
polygon_masks (list[list[ndarray]] or PolygonMasks)
height, width (int)
"""
if isinstance(polygon_masks, PolygonMasks):
polygon_masks = polygon_masks.polygons
masks = [polygons_to_bitmask(p, height, width) for p in polygon_masks]
if len(masks):
return BitMasks(torch.stack([torch.from_numpy(x) for x in masks]))
else:
return BitMasks(torch.empty(0, height, width, dtype=torch.bool))
@staticmethod
def from_roi_masks(roi_masks: "ROIMasks", height: int, width: int) -> "BitMasks":
"""
Args:
roi_masks:
height, width (int):
"""
return roi_masks.to_bitmasks(height, width)
def crop_and_resize(self, boxes: torch.Tensor, mask_size: int) -> torch.Tensor:
"""
Crop each bitmask by the given box, and resize results to (mask_size, mask_size).
This can be used to prepare training targets for Mask R-CNN.
It has less reconstruction error compared to rasterization with polygons.
However we observe no difference in accuracy,
but BitMasks requires more memory to store all the masks.
Args:
boxes (Tensor): Nx4 tensor storing the boxes for each mask
mask_size (int): the size of the rasterized mask.
Returns:
Tensor:
A bool tensor of shape (N, mask_size, mask_size), where
N is the number of predicted boxes for this image.
"""
assert len(boxes) == len(self), "{} != {}".format(len(boxes), len(self))
device = self.tensor.device
batch_inds = torch.arange(len(boxes), device=device).to(dtype=boxes.dtype)[
:, None
]
rois = torch.cat([batch_inds, boxes], dim=1) # Nx5
bit_masks = self.tensor.to(dtype=torch.float32)
rois = rois.to(device=device)
output = (
ROIAlign((mask_size, mask_size), 1.0, 0, aligned=True)
.forward(bit_masks[:, None, :, :], rois)
.squeeze(1)
)
output = output >= 0.5
return output
def get_bounding_boxes(self) -> Boxes:
"""
Returns:
Boxes: tight bounding boxes around bitmasks.
If a mask is empty, it's bounding box will be all zero.
"""
boxes = torch.zeros(self.tensor.shape[0], 4, dtype=torch.float32)
x_any = torch.any(self.tensor, dim=1)
y_any = torch.any(self.tensor, dim=2)
for idx in range(self.tensor.shape[0]):
x = torch.where(x_any[idx, :])[0]
y = torch.where(y_any[idx, :])[0]
if len(x) > 0 and len(y) > 0:
boxes[idx, :] = torch.as_tensor(
[x[0], y[0], x[-1] + 1, y[-1] + 1], dtype=torch.float32
)
return Boxes(boxes)
@staticmethod
def cat(bitmasks_list: List["BitMasks"]) -> "BitMasks":
"""
Concatenates a list of BitMasks into a single BitMasks
Arguments:
bitmasks_list (list[BitMasks])
Returns:
BitMasks: the concatenated BitMasks
"""
assert isinstance(bitmasks_list, (list, tuple))
assert len(bitmasks_list) > 0
assert all(isinstance(bitmask, BitMasks) for bitmask in bitmasks_list)
cat_bitmasks = type(bitmasks_list[0])(
torch.cat([bm.tensor for bm in bitmasks_list], dim=0)
)
return cat_bitmasks
class PolygonMasks:
"""
This class stores the segmentation masks for all objects in one image, in the form of polygons.
Attributes:
polygons: list[list[ndarray]]. Each ndarray is a float64 vector representing a polygon.
"""
def __init__(self, polygons: List[List[Union[torch.Tensor, np.ndarray]]]):
"""
Arguments:
polygons (list[list[np.ndarray]]): The first
level of the list correspond to individual instances,
the second level to all the polygons that compose the
instance, and the third level to the polygon coordinates.
The third level array should have the format of
[x0, y0, x1, y1, ..., xn, yn] (n >= 3).
"""
if not isinstance(polygons, list):
raise ValueError(
"Cannot create PolygonMasks: Expect a list of list of polygons per image. "
"Got '{}' instead.".format(type(polygons))
)
def _make_array(t: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
# Use float64 for higher precision, because why not?
# Always put polygons on CPU (self.to is a no-op) since they
# are supposed to be small tensors.
# May need to change this assumption if GPU placement becomes useful
if isinstance(t, torch.Tensor):
t = t.cpu().numpy()
return np.asarray(t).astype("float64")
def process_polygons(
polygons_per_instance: List[Union[torch.Tensor, np.ndarray]],
) -> List[np.ndarray]:
if not isinstance(polygons_per_instance, list):
raise ValueError(
"Cannot create polygons: Expect a list of polygons per instance. "
"Got '{}' instead.".format(type(polygons_per_instance))
)
# transform each polygon to a numpy array
polygons_per_instance = [_make_array(p) for p in polygons_per_instance]
for polygon in polygons_per_instance:
if len(polygon) % 2 != 0 or len(polygon) < 6:
raise ValueError(
f"Cannot create a polygon from {len(polygon)} coordinates."
)
return polygons_per_instance
self.polygons: List[List[np.ndarray]] = [
process_polygons(polygons_per_instance)
for polygons_per_instance in polygons
]
def to(self, *args: Any, **kwargs: Any) -> "PolygonMasks":
return self
@property
def device(self) -> torch.device:
return torch.device("cpu")
def get_bounding_boxes(self) -> Boxes:
"""
Returns:
Boxes: tight bounding boxes around polygon masks.
"""
boxes = torch.zeros(len(self.polygons), 4, dtype=torch.float32)
for idx, polygons_per_instance in enumerate(self.polygons):
minxy = torch.as_tensor([float("inf"), float("inf")], dtype=torch.float32)
maxxy = torch.zeros(2, dtype=torch.float32)
for polygon in polygons_per_instance:
coords = torch.from_numpy(polygon).view(-1, 2).to(dtype=torch.float32)
minxy = torch.min(minxy, torch.min(coords, dim=0).values)
maxxy = torch.max(maxxy, torch.max(coords, dim=0).values)
boxes[idx, :2] = minxy
boxes[idx, 2:] = maxxy
return Boxes(boxes)
def nonempty(self) -> torch.Tensor:
"""
Find masks that are non-empty.
Returns:
Tensor:
a BoolTensor which represents whether each mask is empty (False) or not (True).
"""
keep = [1 if len(polygon) > 0 else 0 for polygon in self.polygons]
return torch.from_numpy(np.asarray(keep, dtype=bool))
def __getitem__(
self, item: Union[int, slice, List[int], torch.BoolTensor]
) -> "PolygonMasks":
"""
Support indexing over the instances and return a `PolygonMasks` object.
`item` can be:
1. An integer. It will return an object with only one instance.
2. A slice. It will return an object with the selected instances.
3. A list[int]. It will return an object with the selected instances,
correpsonding to the indices in the list.
4. A vector mask of type BoolTensor, whose length is num_instances.
It will return an object with the instances whose mask is nonzero.
"""
if isinstance(item, int):
selected_polygons = [self.polygons[item]]
elif isinstance(item, slice):
selected_polygons = self.polygons[item]
elif isinstance(item, list):
selected_polygons = [self.polygons[i] for i in item]
elif isinstance(item, torch.Tensor):
# Polygons is a list, so we have to move the indices back to CPU.
if item.dtype == torch.bool:
assert item.dim() == 1, item.shape
item = item.nonzero().squeeze(1).cpu().numpy().tolist()
elif item.dtype in [torch.int32, torch.int64]:
item = item.cpu().numpy().tolist()
else:
raise ValueError(
"Unsupported tensor dtype={} for indexing!".format(item.dtype)
)
selected_polygons = [self.polygons[i] for i in item]
return PolygonMasks(selected_polygons)
def __iter__(self) -> Iterator[List[np.ndarray]]:
"""
Yields:
list[ndarray]: the polygons for one instance.
Each Tensor is a float64 vector representing a polygon.
"""
return iter(self.polygons)
def __repr__(self) -> str:
s = self.__class__.__name__ + "("
s += "num_instances={})".format(len(self.polygons))
return s
def __len__(self) -> int:
return len(self.polygons)
def crop_and_resize(self, boxes: torch.Tensor, mask_size: int) -> torch.Tensor:
"""
Crop each mask by the given box, and resize results to (mask_size, mask_size).
This can be used to prepare training targets for Mask R-CNN.
Args:
boxes (Tensor): Nx4 tensor storing the boxes for each mask
mask_size (int): the size of the rasterized mask.
Returns:
Tensor: A bool tensor of shape (N, mask_size, mask_size), where
N is the number of predicted boxes for this image.
"""
assert len(boxes) == len(self), "{} != {}".format(len(boxes), len(self))
device = boxes.device
# Put boxes on the CPU, as the polygon representation is not efficient GPU-wise
# (several small tensors for representing a single instance mask)
boxes = boxes.to(torch.device("cpu"))
results = [
rasterize_polygons_within_box(poly, box.numpy(), mask_size)
for poly, box in zip(self.polygons, boxes)
]
"""
poly: list[list[float]], the polygons for one instance
box: a tensor of shape (4,)
"""
if len(results) == 0:
return torch.empty(0, mask_size, mask_size, dtype=torch.bool, device=device)
return torch.stack(results, dim=0).to(device=device)
def area(self):
"""
Computes area of the mask.
Only works with Polygons, using the shoelace formula:
https://stackoverflow.com/questions/24467972/calculate-area-of-polygon-given-x-y-coordinates
Returns:
Tensor: a vector, area for each instance
"""
area = []
for polygons_per_instance in self.polygons:
area_per_instance = 0
for p in polygons_per_instance:
area_per_instance += polygon_area(p[0::2], p[1::2])
area.append(area_per_instance)
return torch.tensor(area)
@staticmethod
def cat(polymasks_list: List["PolygonMasks"]) -> "PolygonMasks":
"""
Concatenates a list of PolygonMasks into a single PolygonMasks
Arguments:
polymasks_list (list[PolygonMasks])
Returns:
PolygonMasks: the concatenated PolygonMasks
"""
assert isinstance(polymasks_list, (list, tuple))
assert len(polymasks_list) > 0
assert all(isinstance(polymask, PolygonMasks) for polymask in polymasks_list)
cat_polymasks = type(polymasks_list[0])(
list(itertools.chain.from_iterable(pm.polygons for pm in polymasks_list))
)
return cat_polymasks
class ROIMasks:
"""
Represent masks by N smaller masks defined in some ROIs. Once ROI boxes are given,
full-image bitmask can be obtained by "pasting" the mask on the region defined
by the corresponding ROI box.
"""
def __init__(self, tensor: torch.Tensor):
"""
Args:
tensor: (N, M, M) mask tensor that defines the mask within each ROI.
"""
if tensor.dim() != 3:
raise ValueError("ROIMasks must take a masks of 3 dimension.")
self.tensor = tensor
def to(self, device: torch.device) -> "ROIMasks":
return ROIMasks(self.tensor.to(device))
@property
def device(self) -> device:
return self.tensor.device
def __len__(self):
return self.tensor.shape[0]
def __getitem__(self, item) -> "ROIMasks":
"""
Returns:
ROIMasks: Create a new :class:`ROIMasks` by indexing.
The following usage are allowed:
1. `new_masks = masks[2:10]`: return a slice of masks.
2. `new_masks = masks[vector]`, where vector is a torch.BoolTensor
with `length = len(masks)`. Nonzero elements in the vector will be selected.
Note that the returned object might share storage with this object,
subject to Pytorch's indexing semantics.
"""
t = self.tensor[item]
if t.dim() != 3:
raise ValueError(
f"Indexing on ROIMasks with {item} returns a tensor with shape {t.shape}!"
)
return ROIMasks(t)
@torch.jit.unused
def __repr__(self) -> str:
s = self.__class__.__name__ + "("
s += "num_instances={})".format(len(self.tensor))
return s
@torch.jit.unused
def to_bitmasks(self, boxes: torch.Tensor, height, width, threshold=0.5):
"""
Args: see documentation of :func:`paste_masks_in_image`.
"""
from detectron2.layers.mask_ops import (
_paste_masks_tensor_shape,
paste_masks_in_image,
)
if torch.jit.is_tracing():
if isinstance(height, torch.Tensor):
paste_func = _paste_masks_tensor_shape
else:
paste_func = paste_masks_in_image
else:
paste_func = retry_if_cuda_oom(paste_masks_in_image)
bitmasks = paste_func(
self.tensor, boxes.tensor, (height, width), threshold=threshold
)
return BitMasks(bitmasks)

87
sam3/agent/helpers/memory.py Executable file
View File

@@ -0,0 +1,87 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import logging
from contextlib import contextmanager
from functools import wraps
import torch
__all__ = ["retry_if_cuda_oom"]
@contextmanager
def _ignore_torch_cuda_oom():
"""
A context which ignores CUDA OOM exception from pytorch.
"""
try:
yield
except RuntimeError as e:
# NOTE: the string may change?
if "CUDA out of memory. " in str(e):
pass
else:
raise
def retry_if_cuda_oom(func):
"""
Makes a function retry itself after encountering
pytorch's CUDA OOM error.
It will first retry after calling `torch.cuda.empty_cache()`.
If that still fails, it will then retry by trying to convert inputs to CPUs.
In this case, it expects the function to dispatch to CPU implementation.
The return values may become CPU tensors as well and it's user's
responsibility to convert it back to CUDA tensor if needed.
Args:
func: a stateless callable that takes tensor-like objects as arguments
Returns:
a callable which retries `func` if OOM is encountered.
Examples:
::
output = retry_if_cuda_oom(some_torch_function)(input1, input2)
# output may be on CPU even if inputs are on GPU
Note:
1. When converting inputs to CPU, it will only look at each argument and check
if it has `.device` and `.to` for conversion. Nested structures of tensors
are not supported.
2. Since the function might be called more than once, it has to be
stateless.
"""
def maybe_to_cpu(x):
try:
like_gpu_tensor = x.device.type == "cuda" and hasattr(x, "to")
except AttributeError:
like_gpu_tensor = False
if like_gpu_tensor:
return x.to(device="cpu")
else:
return x
@wraps(func)
def wrapped(*args, **kwargs):
with _ignore_torch_cuda_oom():
return func(*args, **kwargs)
# Clear cache and retry
torch.cuda.empty_cache()
with _ignore_torch_cuda_oom():
return func(*args, **kwargs)
# Try on CPU. This slows down the code significantly, therefore print a notice.
logger = logging.getLogger(__name__)
logger.info(
"Attempting to copy inputs of {} to CPU due to CUDA OOM".format(str(func))
)
new_args = (maybe_to_cpu(x) for x in args)
new_kwargs = {k: maybe_to_cpu(v) for k, v in kwargs.items()}
return func(*new_args, **new_kwargs)
return wrapped

122
sam3/agent/helpers/rle.py Executable file
View File

@@ -0,0 +1,122 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
"""Some utilities for RLE encoding that doesn't require downloading the masks to the cpu"""
import numpy as np
import torch
from pycocotools import mask as mask_util
@torch.no_grad()
def rle_encode(orig_mask, return_areas=False):
"""Encodes a collection of masks in RLE format
This function emulates the behavior of the COCO API's encode function, but
is executed partially on the GPU for faster execution.
Args:
mask (torch.Tensor): A mask of shape (N, H, W) with dtype=torch.bool
return_areas (bool): If True, add the areas of the masks as a part of
the RLE output dict under the "area" key. Default is False.
Returns:
str: The RLE encoded masks
"""
assert orig_mask.ndim == 3, "Mask must be of shape (N, H, W)"
assert orig_mask.dtype == torch.bool, "Mask must have dtype=torch.bool"
if orig_mask.numel() == 0:
return []
# First, transpose the spatial dimensions.
# This is necessary because the COCO API uses Fortran order
mask = orig_mask.transpose(1, 2)
# Flatten the mask
flat_mask = mask.reshape(mask.shape[0], -1)
if return_areas:
mask_areas = flat_mask.sum(-1).tolist()
# Find the indices where the mask changes
differences = torch.ones(
mask.shape[0], flat_mask.shape[1] + 1, device=mask.device, dtype=torch.bool
)
differences[:, 1:-1] = flat_mask[:, :-1] != flat_mask[:, 1:]
differences[:, 0] = flat_mask[:, 0]
_, change_indices = torch.where(differences)
try:
boundaries = torch.cumsum(differences.sum(-1), 0).cpu()
except RuntimeError as _:
boundaries = torch.cumsum(differences.cpu().sum(-1), 0)
change_indices_clone = change_indices.clone()
# First pass computes the RLEs on GPU, in a flatten format
for i in range(mask.shape[0]):
# Get the change indices for this batch item
beg = 0 if i == 0 else boundaries[i - 1].item()
end = boundaries[i].item()
change_indices[beg + 1 : end] -= change_indices_clone[beg : end - 1]
# Now we can split the RLES of each batch item, and convert them to strings
# No more gpu at this point
change_indices = change_indices.tolist()
batch_rles = []
# Process each mask in the batch separately
for i in range(mask.shape[0]):
beg = 0 if i == 0 else boundaries[i - 1].item()
end = boundaries[i].item()
run_lengths = change_indices[beg:end]
uncompressed_rle = {"counts": run_lengths, "size": list(orig_mask.shape[1:])}
h, w = uncompressed_rle["size"]
rle = mask_util.frPyObjects(uncompressed_rle, h, w)
rle["counts"] = rle["counts"].decode("utf-8")
if return_areas:
rle["area"] = mask_areas[i]
batch_rles.append(rle)
return batch_rles
def robust_rle_encode(masks):
"""Encodes a collection of masks in RLE format. Uses the gpu version fist, falls back to the cpu version if it fails"""
assert masks.ndim == 3, "Mask must be of shape (N, H, W)"
assert masks.dtype == torch.bool, "Mask must have dtype=torch.bool"
try:
return rle_encode(masks)
except RuntimeError as _:
masks = masks.cpu().numpy()
rles = [
mask_util.encode(
np.array(mask[:, :, np.newaxis], dtype=np.uint8, order="F")
)[0]
for mask in masks
]
for rle in rles:
rle["counts"] = rle["counts"].decode("utf-8")
return rles
def ann_to_rle(segm, im_info):
"""Convert annotation which can be polygons, uncompressed RLE to RLE.
Args:
ann (dict) : annotation object
Returns:
ann (rle)
"""
h, w = im_info["height"], im_info["width"]
if isinstance(segm, list):
# polygon -- a single object might consist of multiple parts
# we merge all parts into one mask rle code
rles = mask_util.frPyObjects(segm, h, w)
rle = mask_util.merge(rles)
elif isinstance(segm["counts"], list):
# uncompressed RLE
rle = mask_util.frPyObjects(segm, h, w)
else:
# rle
rle = segm
return rle

75
sam3/agent/helpers/roi_align.py Executable file
View File

@@ -0,0 +1,75 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
from torch import nn
from torchvision.ops import roi_align
# NOTE: torchvision's RoIAlign has a different default aligned=False
class ROIAlign(nn.Module):
def __init__(self, output_size, spatial_scale, sampling_ratio, aligned=True):
"""
Args:
output_size (tuple): h, w
spatial_scale (float): scale the input boxes by this number
sampling_ratio (int): number of inputs samples to take for each output
sample. 0 to take samples densely.
aligned (bool): if False, use the legacy implementation in
Detectron. If True, align the results more perfectly.
Note:
The meaning of aligned=True:
Given a continuous coordinate c, its two neighboring pixel indices (in our
pixel model) are computed by floor(c - 0.5) and ceil(c - 0.5). For example,
c=1.3 has pixel neighbors with discrete indices [0] and [1] (which are sampled
from the underlying signal at continuous coordinates 0.5 and 1.5). But the original
roi_align (aligned=False) does not subtract the 0.5 when computing neighboring
pixel indices and therefore it uses pixels with a slightly incorrect alignment
(relative to our pixel model) when performing bilinear interpolation.
With `aligned=True`,
we first appropriately scale the ROI and then shift it by -0.5
prior to calling roi_align. This produces the correct neighbors; see
detectron2/tests/test_roi_align.py for verification.
The difference does not make a difference to the model's performance if
ROIAlign is used together with conv layers.
"""
super().__init__()
self.output_size = output_size
self.spatial_scale = spatial_scale
self.sampling_ratio = sampling_ratio
self.aligned = aligned
from torchvision import __version__
version = tuple(int(x) for x in __version__.split(".")[:2])
# https://github.com/pytorch/vision/pull/2438
assert version >= (0, 7), "Require torchvision >= 0.7"
def forward(self, input, rois):
"""
Args:
input: NCHW images
rois: Bx5 boxes. First column is the index into N. The other 4 columns are xyxy.
"""
assert rois.dim() == 2 and rois.size(1) == 5
if input.is_quantized:
input = input.dequantize()
return roi_align(
input,
rois.to(dtype=input.dtype),
self.output_size,
self.spatial_scale,
self.sampling_ratio,
self.aligned,
)
def __repr__(self):
tmpstr = self.__class__.__name__ + "("
tmpstr += "output_size=" + str(self.output_size)
tmpstr += ", spatial_scale=" + str(self.spatial_scale)
tmpstr += ", sampling_ratio=" + str(self.sampling_ratio)
tmpstr += ", aligned=" + str(self.aligned)
tmpstr += ")"
return tmpstr

View File

@@ -0,0 +1,533 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
from __future__ import absolute_import, division, print_function, unicode_literals
import math
from typing import List, Tuple
import torch
# from detectron2.layers.rotated_boxes import pairwise_iou_rotated
from .boxes import Boxes
def pairwise_iou_rotated(boxes1, boxes2):
"""
Return intersection-over-union (Jaccard index) of boxes.
Both sets of boxes are expected to be in
(x_center, y_center, width, height, angle) format.
Arguments:
boxes1 (Tensor[N, 5])
boxes2 (Tensor[M, 5])
Returns:
iou (Tensor[N, M]): the NxM matrix containing the pairwise
IoU values for every element in boxes1 and boxes2
"""
return torch.ops.detectron2.box_iou_rotated(boxes1, boxes2)
class RotatedBoxes(Boxes):
"""
This structure stores a list of rotated boxes as a Nx5 torch.Tensor.
It supports some common methods about boxes
(`area`, `clip`, `nonempty`, etc),
and also behaves like a Tensor
(support indexing, `to(device)`, `.device`, and iteration over all boxes)
"""
def __init__(self, tensor: torch.Tensor):
"""
Args:
tensor (Tensor[float]): a Nx5 matrix. Each row is
(x_center, y_center, width, height, angle),
in which angle is represented in degrees.
While there's no strict range restriction for it,
the recommended principal range is between [-180, 180) degrees.
Assume we have a horizontal box B = (x_center, y_center, width, height),
where width is along the x-axis and height is along the y-axis.
The rotated box B_rot (x_center, y_center, width, height, angle)
can be seen as:
1. When angle == 0:
B_rot == B
2. When angle > 0:
B_rot is obtained by rotating B w.r.t its center by :math:`|angle|` degrees CCW;
3. When angle < 0:
B_rot is obtained by rotating B w.r.t its center by :math:`|angle|` degrees CW.
Mathematically, since the right-handed coordinate system for image space
is (y, x), where y is top->down and x is left->right, the 4 vertices of the
rotated rectangle :math:`(yr_i, xr_i)` (i = 1, 2, 3, 4) can be obtained from
the vertices of the horizontal rectangle :math:`(y_i, x_i)` (i = 1, 2, 3, 4)
in the following way (:math:`\\theta = angle*\\pi/180` is the angle in radians,
:math:`(y_c, x_c)` is the center of the rectangle):
.. math::
yr_i = \\cos(\\theta) (y_i - y_c) - \\sin(\\theta) (x_i - x_c) + y_c,
xr_i = \\sin(\\theta) (y_i - y_c) + \\cos(\\theta) (x_i - x_c) + x_c,
which is the standard rigid-body rotation transformation.
Intuitively, the angle is
(1) the rotation angle from y-axis in image space
to the height vector (top->down in the box's local coordinate system)
of the box in CCW, and
(2) the rotation angle from x-axis in image space
to the width vector (left->right in the box's local coordinate system)
of the box in CCW.
More intuitively, consider the following horizontal box ABCD represented
in (x1, y1, x2, y2): (3, 2, 7, 4),
covering the [3, 7] x [2, 4] region of the continuous coordinate system
which looks like this:
.. code:: none
O--------> x
|
| A---B
| | |
| D---C
|
v y
Note that each capital letter represents one 0-dimensional geometric point
instead of a 'square pixel' here.
In the example above, using (x, y) to represent a point we have:
.. math::
O = (0, 0), A = (3, 2), B = (7, 2), C = (7, 4), D = (3, 4)
We name vector AB = vector DC as the width vector in box's local coordinate system, and
vector AD = vector BC as the height vector in box's local coordinate system. Initially,
when angle = 0 degree, they're aligned with the positive directions of x-axis and y-axis
in the image space, respectively.
For better illustration, we denote the center of the box as E,
.. code:: none
O--------> x
|
| A---B
| | E |
| D---C
|
v y
where the center E = ((3+7)/2, (2+4)/2) = (5, 3).
Also,
.. math::
width = |AB| = |CD| = 7 - 3 = 4,
height = |AD| = |BC| = 4 - 2 = 2.
Therefore, the corresponding representation for the same shape in rotated box in
(x_center, y_center, width, height, angle) format is:
(5, 3, 4, 2, 0),
Now, let's consider (5, 3, 4, 2, 90), which is rotated by 90 degrees
CCW (counter-clockwise) by definition. It looks like this:
.. code:: none
O--------> x
| B-C
| | |
| |E|
| | |
| A-D
v y
The center E is still located at the same point (5, 3), while the vertices
ABCD are rotated by 90 degrees CCW with regard to E:
A = (4, 5), B = (4, 1), C = (6, 1), D = (6, 5)
Here, 90 degrees can be seen as the CCW angle to rotate from y-axis to
vector AD or vector BC (the top->down height vector in box's local coordinate system),
or the CCW angle to rotate from x-axis to vector AB or vector DC (the left->right
width vector in box's local coordinate system).
.. math::
width = |AB| = |CD| = 5 - 1 = 4,
height = |AD| = |BC| = 6 - 4 = 2.
Next, how about (5, 3, 4, 2, -90), which is rotated by 90 degrees CW (clockwise)
by definition? It looks like this:
.. code:: none
O--------> x
| D-A
| | |
| |E|
| | |
| C-B
v y
The center E is still located at the same point (5, 3), while the vertices
ABCD are rotated by 90 degrees CW with regard to E:
A = (6, 1), B = (6, 5), C = (4, 5), D = (4, 1)
.. math::
width = |AB| = |CD| = 5 - 1 = 4,
height = |AD| = |BC| = 6 - 4 = 2.
This covers exactly the same region as (5, 3, 4, 2, 90) does, and their IoU
will be 1. However, these two will generate different RoI Pooling results and
should not be treated as an identical box.
On the other hand, it's easy to see that (X, Y, W, H, A) is identical to
(X, Y, W, H, A+360N), for any integer N. For example (5, 3, 4, 2, 270) would be
identical to (5, 3, 4, 2, -90), because rotating the shape 270 degrees CCW is
equivalent to rotating the same shape 90 degrees CW.
We could rotate further to get (5, 3, 4, 2, 180), or (5, 3, 4, 2, -180):
.. code:: none
O--------> x
|
| C---D
| | E |
| B---A
|
v y
.. math::
A = (7, 4), B = (3, 4), C = (3, 2), D = (7, 2),
width = |AB| = |CD| = 7 - 3 = 4,
height = |AD| = |BC| = 4 - 2 = 2.
Finally, this is a very inaccurate (heavily quantized) illustration of
how (5, 3, 4, 2, 60) looks like in case anyone wonders:
.. code:: none
O--------> x
| B\
| / C
| /E /
| A /
| `D
v y
It's still a rectangle with center of (5, 3), width of 4 and height of 2,
but its angle (and thus orientation) is somewhere between
(5, 3, 4, 2, 0) and (5, 3, 4, 2, 90).
"""
device = (
tensor.device if isinstance(tensor, torch.Tensor) else torch.device("cpu")
)
tensor = torch.as_tensor(tensor, dtype=torch.float32, device=device)
if tensor.numel() == 0:
# Use reshape, so we don't end up creating a new tensor that does not depend on
# the inputs (and consequently confuses jit)
tensor = tensor.reshape((0, 5)).to(dtype=torch.float32, device=device)
assert tensor.dim() == 2 and tensor.size(-1) == 5, tensor.size()
self.tensor = tensor
def clone(self) -> "RotatedBoxes":
"""
Clone the RotatedBoxes.
Returns:
RotatedBoxes
"""
return RotatedBoxes(self.tensor.clone())
def to(self, device: torch.device, non_blocking: bool = False):
# Boxes are assumed float32 and does not support to(dtype)
return RotatedBoxes(self.tensor.to(device=device, non_blocking=non_blocking))
def area(self) -> torch.Tensor:
"""
Computes the area of all the boxes.
Returns:
torch.Tensor: a vector with areas of each box.
"""
box = self.tensor
area = box[:, 2] * box[:, 3]
return area
# Avoid in-place operations so that we can torchscript; NOTE: this creates a new tensor
def normalize_angles(self) -> None:
"""
Restrict angles to the range of [-180, 180) degrees
"""
angle_tensor = (self.tensor[:, 4] + 180.0) % 360.0 - 180.0
self.tensor = torch.cat((self.tensor[:, :4], angle_tensor[:, None]), dim=1)
def clip(
self, box_size: Tuple[int, int], clip_angle_threshold: float = 1.0
) -> None:
"""
Clip (in place) the boxes by limiting x coordinates to the range [0, width]
and y coordinates to the range [0, height].
For RRPN:
Only clip boxes that are almost horizontal with a tolerance of
clip_angle_threshold to maintain backward compatibility.
Rotated boxes beyond this threshold are not clipped for two reasons:
1. There are potentially multiple ways to clip a rotated box to make it
fit within the image.
2. It's tricky to make the entire rectangular box fit within the image
and still be able to not leave out pixels of interest.
Therefore we rely on ops like RoIAlignRotated to safely handle this.
Args:
box_size (height, width): The clipping box's size.
clip_angle_threshold:
Iff. abs(normalized(angle)) <= clip_angle_threshold (in degrees),
we do the clipping as horizontal boxes.
"""
h, w = box_size
# normalize angles to be within (-180, 180] degrees
self.normalize_angles()
idx = torch.where(torch.abs(self.tensor[:, 4]) <= clip_angle_threshold)[0]
# convert to (x1, y1, x2, y2)
x1 = self.tensor[idx, 0] - self.tensor[idx, 2] / 2.0
y1 = self.tensor[idx, 1] - self.tensor[idx, 3] / 2.0
x2 = self.tensor[idx, 0] + self.tensor[idx, 2] / 2.0
y2 = self.tensor[idx, 1] + self.tensor[idx, 3] / 2.0
# clip
x1.clamp_(min=0, max=w)
y1.clamp_(min=0, max=h)
x2.clamp_(min=0, max=w)
y2.clamp_(min=0, max=h)
# convert back to (xc, yc, w, h)
self.tensor[idx, 0] = (x1 + x2) / 2.0
self.tensor[idx, 1] = (y1 + y2) / 2.0
# make sure widths and heights do not increase due to numerical errors
self.tensor[idx, 2] = torch.min(self.tensor[idx, 2], x2 - x1)
self.tensor[idx, 3] = torch.min(self.tensor[idx, 3], y2 - y1)
def nonempty(self, threshold: float = 0.0) -> torch.Tensor:
"""
Find boxes that are non-empty.
A box is considered empty, if either of its side is no larger than threshold.
Returns:
Tensor: a binary vector which represents
whether each box is empty (False) or non-empty (True).
"""
box = self.tensor
widths = box[:, 2]
heights = box[:, 3]
keep = (widths > threshold) & (heights > threshold)
return keep
def __getitem__(self, item) -> "RotatedBoxes":
"""
Returns:
RotatedBoxes: Create a new :class:`RotatedBoxes` by indexing.
The following usage are allowed:
1. `new_boxes = boxes[3]`: return a `RotatedBoxes` which contains only one box.
2. `new_boxes = boxes[2:10]`: return a slice of boxes.
3. `new_boxes = boxes[vector]`, where vector is a torch.ByteTensor
with `length = len(boxes)`. Nonzero elements in the vector will be selected.
Note that the returned RotatedBoxes might share storage with this RotatedBoxes,
subject to Pytorch's indexing semantics.
"""
if isinstance(item, int):
return RotatedBoxes(self.tensor[item].view(1, -1))
b = self.tensor[item]
assert (
b.dim() == 2
), "Indexing on RotatedBoxes with {} failed to return a matrix!".format(item)
return RotatedBoxes(b)
def __len__(self) -> int:
return self.tensor.shape[0]
def __repr__(self) -> str:
return "RotatedBoxes(" + str(self.tensor) + ")"
def inside_box(
self, box_size: Tuple[int, int], boundary_threshold: int = 0
) -> torch.Tensor:
"""
Args:
box_size (height, width): Size of the reference box covering
[0, width] x [0, height]
boundary_threshold (int): Boxes that extend beyond the reference box
boundary by more than boundary_threshold are considered "outside".
For RRPN, it might not be necessary to call this function since it's common
for rotated box to extend to outside of the image boundaries
(the clip function only clips the near-horizontal boxes)
Returns:
a binary vector, indicating whether each box is inside the reference box.
"""
height, width = box_size
cnt_x = self.tensor[..., 0]
cnt_y = self.tensor[..., 1]
half_w = self.tensor[..., 2] / 2.0
half_h = self.tensor[..., 3] / 2.0
a = self.tensor[..., 4]
c = torch.abs(torch.cos(a * math.pi / 180.0))
s = torch.abs(torch.sin(a * math.pi / 180.0))
# This basically computes the horizontal bounding rectangle of the rotated box
max_rect_dx = c * half_w + s * half_h
max_rect_dy = c * half_h + s * half_w
inds_inside = (
(cnt_x - max_rect_dx >= -boundary_threshold)
& (cnt_y - max_rect_dy >= -boundary_threshold)
& (cnt_x + max_rect_dx < width + boundary_threshold)
& (cnt_y + max_rect_dy < height + boundary_threshold)
)
return inds_inside
def get_centers(self) -> torch.Tensor:
"""
Returns:
The box centers in a Nx2 array of (x, y).
"""
return self.tensor[:, :2]
def scale(self, scale_x: float, scale_y: float) -> None:
"""
Scale the rotated box with horizontal and vertical scaling factors
Note: when scale_factor_x != scale_factor_y,
the rotated box does not preserve the rectangular shape when the angle
is not a multiple of 90 degrees under resize transformation.
Instead, the shape is a parallelogram (that has skew)
Here we make an approximation by fitting a rotated rectangle to the parallelogram.
"""
self.tensor[:, 0] *= scale_x
self.tensor[:, 1] *= scale_y
theta = self.tensor[:, 4] * math.pi / 180.0
c = torch.cos(theta)
s = torch.sin(theta)
# In image space, y is top->down and x is left->right
# Consider the local coordintate system for the rotated box,
# where the box center is located at (0, 0), and the four vertices ABCD are
# A(-w / 2, -h / 2), B(w / 2, -h / 2), C(w / 2, h / 2), D(-w / 2, h / 2)
# the midpoint of the left edge AD of the rotated box E is:
# E = (A+D)/2 = (-w / 2, 0)
# the midpoint of the top edge AB of the rotated box F is:
# F(0, -h / 2)
# To get the old coordinates in the global system, apply the rotation transformation
# (Note: the right-handed coordinate system for image space is yOx):
# (old_x, old_y) = (s * y + c * x, c * y - s * x)
# E(old) = (s * 0 + c * (-w/2), c * 0 - s * (-w/2)) = (-c * w / 2, s * w / 2)
# F(old) = (s * (-h / 2) + c * 0, c * (-h / 2) - s * 0) = (-s * h / 2, -c * h / 2)
# After applying the scaling factor (sfx, sfy):
# E(new) = (-sfx * c * w / 2, sfy * s * w / 2)
# F(new) = (-sfx * s * h / 2, -sfy * c * h / 2)
# The new width after scaling tranformation becomes:
# w(new) = |E(new) - O| * 2
# = sqrt[(sfx * c * w / 2)^2 + (sfy * s * w / 2)^2] * 2
# = sqrt[(sfx * c)^2 + (sfy * s)^2] * w
# i.e., scale_factor_w = sqrt[(sfx * c)^2 + (sfy * s)^2]
#
# For example,
# when angle = 0 or 180, |c| = 1, s = 0, scale_factor_w == scale_factor_x;
# when |angle| = 90, c = 0, |s| = 1, scale_factor_w == scale_factor_y
self.tensor[:, 2] *= torch.sqrt((scale_x * c) ** 2 + (scale_y * s) ** 2)
# h(new) = |F(new) - O| * 2
# = sqrt[(sfx * s * h / 2)^2 + (sfy * c * h / 2)^2] * 2
# = sqrt[(sfx * s)^2 + (sfy * c)^2] * h
# i.e., scale_factor_h = sqrt[(sfx * s)^2 + (sfy * c)^2]
#
# For example,
# when angle = 0 or 180, |c| = 1, s = 0, scale_factor_h == scale_factor_y;
# when |angle| = 90, c = 0, |s| = 1, scale_factor_h == scale_factor_x
self.tensor[:, 3] *= torch.sqrt((scale_x * s) ** 2 + (scale_y * c) ** 2)
# The angle is the rotation angle from y-axis in image space to the height
# vector (top->down in the box's local coordinate system) of the box in CCW.
#
# angle(new) = angle_yOx(O - F(new))
# = angle_yOx( (sfx * s * h / 2, sfy * c * h / 2) )
# = atan2(sfx * s * h / 2, sfy * c * h / 2)
# = atan2(sfx * s, sfy * c)
#
# For example,
# when sfx == sfy, angle(new) == atan2(s, c) == angle(old)
self.tensor[:, 4] = torch.atan2(scale_x * s, scale_y * c) * 180 / math.pi
@classmethod
def cat(cls, boxes_list: List["RotatedBoxes"]) -> "RotatedBoxes":
"""
Concatenates a list of RotatedBoxes into a single RotatedBoxes
Arguments:
boxes_list (list[RotatedBoxes])
Returns:
RotatedBoxes: the concatenated RotatedBoxes
"""
assert isinstance(boxes_list, (list, tuple))
if len(boxes_list) == 0:
return cls(torch.empty(0))
assert all([isinstance(box, RotatedBoxes) for box in boxes_list])
# use torch.cat (v.s. layers.cat) so the returned boxes never share storage with input
cat_boxes = cls(torch.cat([b.tensor for b in boxes_list], dim=0))
return cat_boxes
@property
def device(self) -> torch.device:
return self.tensor.device
@torch.jit.unused
def __iter__(self):
"""
Yield a box as a Tensor of shape (5,) at a time.
"""
yield from self.tensor
def pairwise_iou(boxes1: RotatedBoxes, boxes2: RotatedBoxes) -> None:
"""
Given two lists of rotated boxes of size N and M,
compute the IoU (intersection over union)
between **all** N x M pairs of boxes.
The box order must be (x_center, y_center, width, height, angle).
Args:
boxes1, boxes2 (RotatedBoxes):
two `RotatedBoxes`. Contains N & M rotated boxes, respectively.
Returns:
Tensor: IoU, sized [N,M].
"""
return pairwise_iou_rotated(boxes1.tensor, boxes2.tensor)

View File

@@ -0,0 +1,406 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import colorsys
from dataclasses import dataclass
from typing import List, Tuple
import cv2
import matplotlib as mpl
import matplotlib.colors as mplc
import numpy as np
import pycocotools.mask as mask_utils
def rgb_to_hex(rgb_color):
"""
Convert a rgb color to hex color.
Args:
rgb_color (tuple/list of ints): RGB color in tuple or list format.
Returns:
str: Hex color.
Example:
```
>>> rgb_to_hex((255, 0, 244))
'#ff00ff'
```
"""
return "#" + "".join([hex(c)[2:].zfill(2) for c in rgb_color])
# DEFAULT_COLOR_HEX_TO_NAME = {
# rgb_to_hex((255, 0, 0)): "red",
# rgb_to_hex((0, 255, 0)): "lime",
# rgb_to_hex((0, 0, 255)): "blue",
# rgb_to_hex((255, 255, 0)): "yellow",
# rgb_to_hex((255, 0, 255)): "fuchsia",
# rgb_to_hex((0, 255, 255)): "aqua",
# rgb_to_hex((255, 165, 0)): "orange",
# rgb_to_hex((128, 0, 128)): "purple",
# rgb_to_hex((255, 215, 0)): "gold",
# }
# Assuming rgb_to_hex is a function that converts an (R, G, B) tuple to a hex string.
# For example: def rgb_to_hex(rgb): return '#%02x%02x%02x' % rgb
DEFAULT_COLOR_HEX_TO_NAME = {
# The top 20 approved colors
rgb_to_hex((255, 255, 0)): "yellow",
rgb_to_hex((0, 255, 0)): "lime",
rgb_to_hex((0, 255, 255)): "cyan",
rgb_to_hex((255, 0, 255)): "magenta",
rgb_to_hex((255, 0, 0)): "red",
rgb_to_hex((255, 127, 0)): "orange",
rgb_to_hex((127, 255, 0)): "chartreuse",
rgb_to_hex((0, 255, 127)): "spring green",
rgb_to_hex((255, 0, 127)): "rose",
rgb_to_hex((127, 0, 255)): "violet",
rgb_to_hex((192, 255, 0)): "electric lime",
rgb_to_hex((255, 192, 0)): "vivid orange",
rgb_to_hex((0, 255, 192)): "turquoise",
rgb_to_hex((192, 0, 255)): "bright violet",
rgb_to_hex((255, 0, 192)): "bright pink",
rgb_to_hex((255, 64, 0)): "fiery orange",
rgb_to_hex((64, 255, 0)): "bright chartreuse",
rgb_to_hex((0, 255, 64)): "malachite",
rgb_to_hex((64, 0, 255)): "deep violet",
rgb_to_hex((255, 0, 64)): "hot pink",
}
DEFAULT_COLOR_PALETTE = list(DEFAULT_COLOR_HEX_TO_NAME.keys())
def _validate_color_hex(color_hex: str):
color_hex = color_hex.lstrip("#")
if not all(c in "0123456789abcdefABCDEF" for c in color_hex):
raise ValueError("Invalid characters in color hash")
if len(color_hex) not in (3, 6):
raise ValueError("Invalid length of color hash")
# copied from https://github.com/roboflow/supervision/blob/c8f557af0c61b5c03392bad2cc36c8835598b1e1/supervision/draw/color.py
@dataclass
class Color:
"""
Represents a color in RGB format.
Attributes:
r (int): Red channel.
g (int): Green channel.
b (int): Blue channel.
"""
r: int
g: int
b: int
@classmethod
def from_hex(cls, color_hex: str):
"""
Create a Color instance from a hex string.
Args:
color_hex (str): Hex string of the color.
Returns:
Color: Instance representing the color.
Example:
```
>>> Color.from_hex('#ff00ff')
Color(r=255, g=0, b=255)
```
"""
_validate_color_hex(color_hex)
color_hex = color_hex.lstrip("#")
if len(color_hex) == 3:
color_hex = "".join(c * 2 for c in color_hex)
r, g, b = (int(color_hex[i : i + 2], 16) for i in range(0, 6, 2))
return cls(r, g, b)
@classmethod
def to_hex(cls, color):
"""
Convert a Color instance to a hex string.
Args:
color (Color): Color instance of color.
Returns:
Color: a hex string.
"""
return rgb_to_hex((color.r, color.g, color.b))
def as_rgb(self) -> Tuple[int, int, int]:
"""
Returns the color as an RGB tuple.
Returns:
Tuple[int, int, int]: RGB tuple.
Example:
```
>>> color.as_rgb()
(255, 0, 255)
```
"""
return self.r, self.g, self.b
def as_bgr(self) -> Tuple[int, int, int]:
"""
Returns the color as a BGR tuple.
Returns:
Tuple[int, int, int]: BGR tuple.
Example:
```
>>> color.as_bgr()
(255, 0, 255)
```
"""
return self.b, self.g, self.r
@classmethod
def white(cls):
return Color.from_hex(color_hex="#ffffff")
@classmethod
def black(cls):
return Color.from_hex(color_hex="#000000")
@classmethod
def red(cls):
return Color.from_hex(color_hex="#ff0000")
@classmethod
def green(cls):
return Color.from_hex(color_hex="#00ff00")
@classmethod
def blue(cls):
return Color.from_hex(color_hex="#0000ff")
@dataclass
class ColorPalette:
colors: List[Color]
@classmethod
def default(cls):
"""
Returns a default color palette.
Returns:
ColorPalette: A ColorPalette instance with default colors.
Example:
```
>>> ColorPalette.default()
ColorPalette(colors=[Color(r=255, g=0, b=0), Color(r=0, g=255, b=0), ...])
```
"""
return ColorPalette.from_hex(color_hex_list=DEFAULT_COLOR_PALETTE)
@classmethod
def from_hex(cls, color_hex_list: List[str]):
"""
Create a ColorPalette instance from a list of hex strings.
Args:
color_hex_list (List[str]): List of color hex strings.
Returns:
ColorPalette: A ColorPalette instance.
Example:
```
>>> ColorPalette.from_hex(['#ff0000', '#00ff00', '#0000ff'])
ColorPalette(colors=[Color(r=255, g=0, b=0), Color(r=0, g=255, b=0), ...])
```
"""
colors = [Color.from_hex(color_hex) for color_hex in color_hex_list]
return cls(colors)
def by_idx(self, idx: int) -> Color:
"""
Return the color at a given index in the palette.
Args:
idx (int): Index of the color in the palette.
Returns:
Color: Color at the given index.
Example:
```
>>> color_palette.by_idx(1)
Color(r=0, g=255, b=0)
```
"""
if idx < 0:
raise ValueError("idx argument should not be negative")
idx = idx % len(self.colors)
return self.colors[idx]
def find_farthest_color(self, img_array):
"""
Return the color that is the farthest from the given color.
Args:
img_array (np array): any *x3 np array, 3 is the RGB color channel.
Returns:
Color: Farthest color.
"""
# Reshape the image array for broadcasting
img_array = img_array.reshape((-1, 3))
# Convert colors dictionary to a NumPy array
color_values = np.array([[c.r, c.g, c.b] for c in self.colors])
# Calculate the Euclidean distance between the colors and each pixel in the image
# Broadcasting happens here: img_array shape is (num_pixels, 3), color_values shape is (num_colors, 3)
distances = np.sqrt(
np.sum((img_array[:, np.newaxis, :] - color_values) ** 2, axis=2)
)
# Average the distances for each color
mean_distances = np.mean(distances, axis=0)
# return the farthest color
farthest_idx = np.argmax(mean_distances)
farthest_color = self.colors[farthest_idx]
farthest_color_hex = Color.to_hex(farthest_color)
if farthest_color_hex in DEFAULT_COLOR_HEX_TO_NAME:
farthest_color_name = DEFAULT_COLOR_HEX_TO_NAME[farthest_color_hex]
else:
farthest_color_name = "unknown"
return farthest_color, farthest_color_name
def draw_box(ax, box_coord, alpha=0.8, edge_color="g", line_style="-", linewidth=2.0):
x0, y0, width, height = box_coord
ax.add_patch(
mpl.patches.Rectangle(
(x0, y0),
width,
height,
fill=False,
edgecolor=edge_color,
linewidth=linewidth,
alpha=alpha,
linestyle=line_style,
)
)
def draw_text(
ax,
text,
position,
font_size=None,
color="g",
horizontal_alignment="left",
rotation=0,
):
if not font_size:
font_size = mpl.rcParams["font.size"]
color = np.maximum(list(mplc.to_rgb(color)), 0.2)
color[np.argmax(color)] = max(0.8, np.max(color))
x, y = position
ax.text(
x,
y,
text,
size=font_size,
family="sans-serif",
bbox={"facecolor": "none", "alpha": 0.5, "pad": 0.7, "edgecolor": "none"},
verticalalignment="top",
horizontalalignment=horizontal_alignment,
color=color,
rotation=rotation,
)
def draw_mask(
ax, rle, color, show_holes=True, alpha=0.15, upsample_factor=1.0, rle_upsampled=None
):
if isinstance(rle, dict):
mask = mask_utils.decode(rle)
elif isinstance(rle, np.ndarray):
mask = rle
else:
raise ValueError(f"Unsupported type for rle: {type(rle)}")
mask_upsampled = None
if upsample_factor > 1.0 and show_holes:
assert rle_upsampled is not None
if isinstance(rle_upsampled, dict):
mask_upsampled = mask_utils.decode(rle_upsampled)
elif isinstance(rle_upsampled, np.ndarray):
mask_upsampled = rle_upsampled
else:
raise ValueError(f"Unsupported type for rle: {type(rle)}")
if show_holes:
if mask_upsampled is None:
mask_upsampled = mask
h, w = mask_upsampled.shape
mask_img = np.zeros((h, w, 4))
mask_img[:, :, :-1] = color[np.newaxis, np.newaxis, :]
mask_img[:, :, -1] = mask_upsampled * alpha
ax.imshow(mask_img)
*_, contours, _ = cv2.findContours(
mask.astype(np.uint8).copy(), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
)
upsampled_contours = [(cont + 0.5) * upsample_factor - 0.5 for cont in contours]
facecolor = (0, 0, 0, 0) if show_holes else color
if alpha > 0.8:
edge_color = _change_color_brightness(color, brightness_factor=-0.7)
else:
edge_color = color
for cont in upsampled_contours:
polygon = mpl.patches.Polygon(
[el[0] for el in cont],
edgecolor=edge_color,
linewidth=2.0,
facecolor=facecolor,
)
ax.add_patch(polygon)
def _change_color_brightness(color, brightness_factor):
"""
Depending on the brightness_factor, gives a lighter or darker color i.e. a color with
less or more saturation than the original color.
Args:
color: color of the polygon. Refer to `matplotlib.colors` for a full list of
formats that are accepted.
brightness_factor (float): a value in [-1.0, 1.0] range. A lightness factor of
0 will correspond to no change, a factor in [-1.0, 0) range will result in
a darker color and a factor in (0, 1.0] range will result in a lighter color.
Returns:
modified_color (tuple[double]): a tuple containing the RGB values of the
modified color. Each value in the tuple is in the [0.0, 1.0] range.
"""
assert brightness_factor >= -1.0 and brightness_factor <= 1.0
color = mplc.to_rgb(color)
polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color))
modified_lightness = polygon_color[1] + (brightness_factor * polygon_color[1])
modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness
modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness
modified_color = colorsys.hls_to_rgb(
polygon_color[0], modified_lightness, polygon_color[2]
)
return modified_color

1662
sam3/agent/helpers/visualizer.py Executable file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,195 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import io
import math
import matplotlib.pyplot as plt
import numpy as np
import pycocotools.mask as mask_utils
from PIL import Image
from .som_utils import ColorPalette, draw_box, draw_mask, draw_text
def render_zoom_in(
object_data,
image_file,
show_box: bool = True,
show_text: bool = False,
show_holes: bool = True,
mask_alpha: float = 0.15,
):
"""
Render a two-panel visualization with a cropped original view (left/upper) and a zoomed-in
mask overlay (right/lower), then return it as a PIL.Image along with the chosen mask color (hex).
Parameters
----------
object_data : dict
Dict containing "labels" and COCO RLE "segmentation".
Expected:
object_data["labels"][0]["noun_phrase"] : str
object_data["segmentation"] : COCO RLE (with "size": [H, W])
image_file : PIL.Image.Image
Source image (PIL).
show_box : bool
Whether to draw the bbox on the cropped original panel.
show_text : bool
Whether to draw the noun phrase label near the bbox.
show_holes : bool
Whether to render mask holes (passed through to draw_mask).
mask_alpha : float
Alpha for the mask overlay.
Returns
-------
pil_img : PIL.Image.Image
The composed visualization image.
color_hex : str
Hex string of the chosen mask color.
"""
# ---- local constants (avoid module-level globals) ----
_AREA_LARGE = 0.25
_AREA_MEDIUM = 0.05
# ---- local helpers (avoid name collisions in a larger class) ----
def _get_shift(x, w, w_new, w_img):
assert 0 <= w_new <= w_img
shift = (w_new - w) / 2
if x - shift + w_new > w_img:
shift = x + w_new - w_img
return min(x, shift)
def _get_zoom_in_box(mask_box_xywh, img_h, img_w, mask_area):
box_w, box_h = mask_box_xywh[2], mask_box_xywh[3]
w_new = min(box_w + max(0.2 * box_w, 16), img_w)
h_new = min(box_h + max(0.2 * box_h, 16), img_h)
mask_relative_area = mask_area / (w_new * h_new)
# zoom-in (larger box if mask is relatively big)
w_new_large, h_new_large = w_new, h_new
if mask_relative_area > _AREA_LARGE:
ratio_large = math.sqrt(mask_relative_area / _AREA_LARGE)
w_new_large = min(w_new * ratio_large, img_w)
h_new_large = min(h_new * ratio_large, img_h)
w_shift_large = _get_shift(
mask_box_xywh[0], mask_box_xywh[2], w_new_large, img_w
)
h_shift_large = _get_shift(
mask_box_xywh[1], mask_box_xywh[3], h_new_large, img_h
)
zoom_in_box = [
mask_box_xywh[0] - w_shift_large,
mask_box_xywh[1] - h_shift_large,
w_new_large,
h_new_large,
]
# crop box for the original/cropped image
w_new_medium, h_new_medium = w_new, h_new
if mask_relative_area > _AREA_MEDIUM:
ratio_med = math.sqrt(mask_relative_area / _AREA_MEDIUM)
w_new_medium = min(w_new * ratio_med, img_w)
h_new_medium = min(h_new * ratio_med, img_h)
w_shift_medium = _get_shift(
mask_box_xywh[0], mask_box_xywh[2], w_new_medium, img_w
)
h_shift_medium = _get_shift(
mask_box_xywh[1], mask_box_xywh[3], h_new_medium, img_h
)
img_crop_box = [
mask_box_xywh[0] - w_shift_medium,
mask_box_xywh[1] - h_shift_medium,
w_new_medium,
h_new_medium,
]
return zoom_in_box, img_crop_box
# ---- main body ----
# Input parsing
object_label = object_data["labels"][0]["noun_phrase"]
img = image_file.convert("RGB")
bbox_xywh = mask_utils.toBbox(object_data["segmentation"]) # [x, y, w, h]
# Choose a stable, visually distant color based on crop
bbox_xyxy = [
bbox_xywh[0],
bbox_xywh[1],
bbox_xywh[0] + bbox_xywh[2],
bbox_xywh[1] + bbox_xywh[3],
]
crop_img = img.crop(bbox_xyxy)
color_palette = ColorPalette.default()
color_obj, _ = color_palette.find_farthest_color(np.array(crop_img))
color = np.array([color_obj.r / 255, color_obj.g / 255, color_obj.b / 255])
color_hex = f"#{color_obj.r:02x}{color_obj.g:02x}{color_obj.b:02x}"
# Compute zoom-in / crop boxes
img_h, img_w = object_data["segmentation"]["size"]
mask_area = mask_utils.area(object_data["segmentation"])
zoom_in_box, img_crop_box = _get_zoom_in_box(bbox_xywh, img_h, img_w, mask_area)
# Layout choice
w, h = img_crop_box[2], img_crop_box[3]
if w < h:
fig, (ax1, ax2) = plt.subplots(1, 2)
else:
fig, (ax1, ax2) = plt.subplots(2, 1)
# Panel 1: cropped original with optional box/text
img_crop_box_xyxy = [
img_crop_box[0],
img_crop_box[1],
img_crop_box[0] + img_crop_box[2],
img_crop_box[1] + img_crop_box[3],
]
img1 = img.crop(img_crop_box_xyxy)
bbox_xywh_rel = [
bbox_xywh[0] - img_crop_box[0],
bbox_xywh[1] - img_crop_box[1],
bbox_xywh[2],
bbox_xywh[3],
]
ax1.imshow(img1)
ax1.axis("off")
if show_box:
draw_box(ax1, bbox_xywh_rel, edge_color=color)
if show_text:
x0, y0 = bbox_xywh_rel[0] + 2, bbox_xywh_rel[1] + 2
draw_text(ax1, object_label, [x0, y0], color=color)
# Panel 2: zoomed-in mask overlay
binary_mask = mask_utils.decode(object_data["segmentation"])
alpha = Image.fromarray((binary_mask * 255).astype("uint8"))
img_rgba = img.convert("RGBA")
img_rgba.putalpha(alpha)
zoom_in_box_xyxy = [
zoom_in_box[0],
zoom_in_box[1],
zoom_in_box[0] + zoom_in_box[2],
zoom_in_box[1] + zoom_in_box[3],
]
img_with_alpha_zoomin = img_rgba.crop(zoom_in_box_xyxy)
alpha_zoomin = img_with_alpha_zoomin.split()[3]
binary_mask_zoomin = np.array(alpha_zoomin).astype(bool)
ax2.imshow(img_with_alpha_zoomin.convert("RGB"))
ax2.axis("off")
draw_mask(
ax2, binary_mask_zoomin, color=color, show_holes=show_holes, alpha=mask_alpha
)
plt.tight_layout()
# Buffer -> PIL.Image
buf = io.BytesIO()
fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0, dpi=100)
plt.close(fig)
buf.seek(0)
pil_img = Image.open(buf)
return pil_img, color_hex

65
sam3/agent/inference.py Normal file
View File

@@ -0,0 +1,65 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import json
import os
from sam3.agent.agent_core import agent_inference
def run_single_image_inference(
image_path,
text_prompt,
llm_config,
send_generate_request,
call_sam_service,
output_dir="agent_output",
debug=False,
):
"""Run inference on a single image with provided prompt"""
llm_name = llm_config["name"]
if not os.path.exists(image_path):
raise FileNotFoundError(f"Image file not found: {image_path}")
# Create output directory
os.makedirs(output_dir, exist_ok=True)
# Generate output file names
image_basename = os.path.splitext(os.path.basename(image_path))[0]
prompt_for_filename = text_prompt.replace("/", "_").replace(" ", "_")
base_filename = f"{image_basename}_{prompt_for_filename}_agent_{llm_name}"
output_json_path = os.path.join(output_dir, f"{base_filename}_pred.json")
output_image_path = os.path.join(output_dir, f"{base_filename}_pred.png")
agent_history_path = os.path.join(output_dir, f"{base_filename}_history.json")
# Check if output already exists and skip
if os.path.exists(output_json_path):
print(f"Output JSON {output_json_path} already exists. Skipping.")
return
print(f"{'-'*30} Starting SAM 3 Agent Session... {'-'*30} ")
agent_history, final_output_dict, rendered_final_output = agent_inference(
image_path,
text_prompt,
send_generate_request=send_generate_request,
call_sam_service=call_sam_service,
output_dir=output_dir,
debug=debug,
)
print(f"{'-'*30} End of SAM 3 Agent Session... {'-'*30} ")
final_output_dict["text_prompt"] = text_prompt
final_output_dict["image_path"] = image_path
# Save outputs
json.dump(final_output_dict, open(output_json_path, "w"), indent=4)
json.dump(agent_history, open(agent_history_path, "w"), indent=4)
rendered_final_output.save(output_image_path)
print(f"\n✅ Successfully processed single image!")
print(f"Output JSON: {output_json_path}")
print(f"Output Image: {output_image_path}")
print(f"Agent History: {agent_history_path}")
return output_image_path

View File

@@ -0,0 +1,242 @@
You are a helpful visual-concept grounding assistant capable of leveraging tool calls to ground concepts the user refers to, and providing structured JSON outputs and tool calls.
The user may provide you with a referring expression that matches some part(s) of the image, or a question whose answer points to some part(s) of the image.
You should observe and analyze the image along with the initial user input query very carefully, note all details in the image, think about what the user is actually referring to, how to leverage existing tools below to ground the target(s), and then call exactly one tool per turn.
At each turn, all available mask(s) will be renumbered and re-rendered on the most recent image provided to you. The numbering and coloring can be different from previous turns. You should only refer to mask(s) rendered on the most recent image using their currently assigned number.
If a tool call does not produce the intended output, do not give up; be creative and try calling the segment_phrase tool again with different parameters, or try a different tool. You may take as many turns as needed, but you must call exactly one tool per turn and then immediately stop. There is no need to rush to find a solution in the current turn, so take your time!
How you should understand the initial user input query and the raw input image:
1. If there are multiple instances of the target object class in the image, you should read the initial user input query very carefully and think about whether the initial user input query applies broadly to all the instances or just one specific instance, and ground accordingly.
2. You should think carefully and find the actual target object(s) the user is asking you to ground. Never call the segment_phrase tool to ground secondary object(s) in the initial user input query that only exist to help you identify the actual target. For example, given the initial user input query 'a giraffe with its head up', you should ground the whole 'giraffe' and not 'the head of the giraffe'. Given the initial user input query 'a person holding a blender with their left hand', you should ground 'person' instead of 'blender' or 'left hand'. Given the initial user input query 'two lovely ladies conversing while walking a dog, behind a bicycle', you should ground 'woman' instead of 'dog' or 'bicycle'. Given the initial user input query "guy with white hat", you should ground the "guy" and not the "white hat".
3. Sometimes the user will mention or use non-target object(s) in their description to help identify the target object(s), you must make sure not to include mask(s) for those object(s) that are only used for identification purposes. For example, given the initial user input query "a man carrying a young girl", you should only ground the main target the "man" and not include the "young girl" in your final predicted mask(s). Given the initial user input query "a small girl staring at something, along with her older sister", you should only ground the "small girl" and not include her "older sister" in your final predicted mask(s).
4. Sometimes the target object(s) are not directly named in the description but are clearly referenced, in which case you should focus only on grounding the clearly referenced target object(s). For example, given the initial user input query "something that shows the man is playing golf" and an image of a man holding a golf club, you should ground the phrase "golf club" and not the phrase "man" even though "golf club" is not directly named in the initial user input query.
5. You must carefully examine all details in the raw input image and note them in your thinking, and reason step-by-step to determine if anything in the image could potentially match the initial user input query. You should not give up the grounding process and call the report_no_mask tool due to very small technicalities or small literal discrepancies. For example, if the user asks you to find a dry space, relatively dry areas like land would satisfy the constraint. If the user asks you to find object(s) that help you focus, headphones and even window shades could potentially serve the purpose. If the user asks you to find containers that can be used for holding hot water, cups or kettles can both work. You should only call the report_no_mask tool if there are very direct contradictions and/or hard constraints in the initial user input query that cause all objects in the raw input image to be invalid matches for the initial user input query.
6. Sometimes the initial user input query can be slightly wrong but still very much related to the image. For example, the user may ask you to ground "the red laptop" when the laptop computer in the image is purple (in this case you should call segment_phrase on the "text_prompt" "purple laptop computer"); or the user may ask you to ground "girl left" when there is no girl on the left of the image but rather a woman on the left of the image (in this case you should call segment_phrase to ground the phrase "left woman"). In these cases, you should accommodate the user errors and still ground the object(s) in the image that best match the initial user input query. You may slightly modify the initial user input query based on your observation of the original image to better match the users intent.
7. Sometimes the initial user input query may be grammatically incorrect, contain typos, or contain irrelevant information. In these cases, you should not blindly try to ground part(s) of the initial user input query using segment_phrase. Instead, you should reason step by step to think about what the user is actually referring to, and then modify the initial user input query based on your understanding and careful analysis of the raw input image. For example, you may see an initial user input query like "left back to us guy", which you can interpret as the man on the left who is facing the other direction (if you can see such a man exists in the image), and then call segment_phrase on "man" and then select the correct mask. You may also see an initial user input query like "big maybe hotdog middle back taste good", and there are just nine sandwiches in the image placed in three rows, then you can probably infer that the user is trying to ground the sandwich in the middle of the back row. You can then call segment_phrase to ground the phrase "sandwich" and use the select_masks_and_return tool to accurately choose only the sandwich in the middle of the back row in your "final_answer_masks" array.
8. The correct "final_answer_masks" array should never contain any mask(s) whose number is greater than 100. For example, you may never select mask 102 or mask 114 in your "final_answer_masks" array. This also means that you are never allowed to select more than 100 masks in your "final_answer_masks" array.
9. Please note that if the raw input image is composed of two individual sub-images concatenated visually; it still counts as only one image. If you find that there are "two" images in the chat context but the "second image" is not the same as the first image overlaid with numbered segmentation masks, this means that the "second image" is actually just a sub-image of the raw input image concatenated with the "first image" to serve as a combined raw input image. In this case, there is actually only one image in the chat context and you should follow the Scenario 1 instructions. This is very important!
You should always follow the response format defined below and complete the Steps for Each Turn as specified below. Never break the specified format for any reason.
Available tools:
segment_phrase: Use the experimental Segment Anything 3 model to ground all instances of a simple noun phrase by generating segmentation mask(s) that cover those instances on the raw input image. At the same time, all previously generated mask(s) will be deleted and cannot be referred to in future messages.
Use cases: "Given a simple, direct, and singular noun phrase (not a referring expression that requires additional understanding/reasoning), segment_phrase will try to locate all object instance(s) on the raw input image that match the simple noun phrase you provided. The tool will also render all of the generated segmentation mask(s) onto the image for you to examine and decide the next step."
Parameters for segment_phrase: {"type": "object", "properties": {"text_prompt": {"type": "string", "description": "A short and simple noun phrase, e.g., rope, bird beak, speed monitor, brown handbag, person torso"}}, "required": ["text_prompt"]}
Return type: A new image with differently colored segmentation mask(s) rendered on it, and a text message indicating the number of mask(s) generated by the experimental Segment Anything 3 model for this "text_prompt" only.
Important rules for using the segment_phrase tool:
1. You may use visual adjectives such as color to help identify the concept you want to ground, but do not use complicated descriptors like numbers or mention text that is written on the image as the segment_phrase tool does not have OCR capabilities. For example, use "black ball" instead of "8-ball" to ground a black ball with the number "8" written on it. If the user asks you to ground an object that can only be identified by the text or number written on it, you should generate mask(s) for all object(s) of that category and then cross-examine the original image against the masked image carefully to locate the exact mask(s) that match or answer the initial user input query and select only those mask(s).
2. Do not try to directly ground words, letters, or numbers in written text on the image. For example, if there is text on a sign to ground, you should use "sign" as your "text_prompt" instead of using the actual text itself as your "text_prompt".
3. If your call to segment_phrase does not generate any useful mask(s) or if the mask(s) are incomplete, you may want to try calling the segment_phrase tool again using a more general noun phrase. For example, if the "text_prompt" "elementary school teacher" does not give you any mask(s), you can call segment_phrase again with the "text_prompt": "person".
4. You should avoid identifying concepts using actions, relationships, or comparatives; instead, call segment_phrase on a more general phrase and let the segment_phrase tool generate more mask(s) than you need. Then, in the next turn, you can use the select_masks_and_return tool to remove some mask(s). For example, use "vase" instead of "the bigger vase", use "dog" instead of "the dog lying down", and use "brown pillow" instead of "the pillow on the chair".
5. If the results of segment_phrase are not what you expected, you can always call segment_phrase again using a different "text_prompt". For example, when grounding a dog's nose, you can try "dog nose" and "black marking" after "nose" does not work.
6. Sometimes when the target object(s) are too niche and the segment_phrase tool does not provide any mask(s), you may want to try grounding a more general version of the object. For example, when "sundial" does not produce any mask(s), you can try grounding "statue".
7. Be concise and get the right keywords; don't make your "text_prompt" long.
8. Do not ever use the exact same "text_prompt" more than once. This is very important!
9. Sometimes you may find that the user is referring to a person or some people as the main grounding target. In this case, you should absolutely avoid grounding identifying part(s) or attribute(s) of the person or people, even if these part(s) or component(s) are explicitly mentioned in the initial user input query. Instead, you should only call segment_phrase with general "text_prompt"s like "person", "man", "girl", "firefighter", etc. that refer to the person as a whole. Later you can refer back to these identifying part(s) or attribute(s) and look closely at the original image to help you select the correct mask(s).
10. If a previously used "text_prompt" does not work, avoid using it again and think of a new, creative "text_prompt" that may be indirect but can achieve the target result. For example, when grounding the center of the cake with text written on it, try grounding "birthday greeting" instead.
11. You should always call segment_phrase with a "text_prompt" that represents the entire grounding target to generate mask(s) that you can choose from (sometimes along with other entities of the same category if it is hard to avoid). Do not call segment_phrase with a "text_prompt" that refers to subpart(s) of the grounding target to narrow down your search, because your "final_answer_masks" array can only be composed of of mask(s) generated by segment_phrase. For example, when the grounding target is an adult, use the "text_prompt" "adult person" instead of "adult hand".
12. If the initial user input query refers only to one specific object instance of a category, while there are other object instance(s) of the same category in the image that are not being referred to, you should call segment_phrase with a "text_prompt" that is the singular form of the category of object(s), and then use the select_masks_and_return and/or examine_each_mask tool to narrow down your "final_answer_masks".
13. Every time you call the segment_phrase tool, all previously generated mask(s) will be deleted. You are forbidden from referring to mask(s) that exist only in previous images in the message history but have been deleted in the most recent turn (not rendered on the most recent image).
14. You should only ground object(s) that fully match or answer the initial user input query, and ignore object(s) that only partially match the initial user input query. For example, if the user is asking for object(s) used for inputting data and controlling the computer, you should only ground the keyboard and not the mouse, since the mouse is only used for controlling the computer but not for inputting data.
15. You should never propose a "text_prompt" that covers more area than the initial user input query, for example, if the initial user input query asks specifically for areas of the jeans that are broken, you should never propose the "text_prompt" "jeans" because it will definitely cover more area than the ground truth target.
16. You should never propose a "text_prompt" that covers less area than the initial user input query, for example, if the initial user input query asks for the person holding a microphone, you should never propose the "text_prompt" "microphone" because it will definitely cover less area than the ground truth target.
17. You should first try your best to propose a "text_prompt" that covers the exact same object(s) as referred to by the initial user input query, no more, no less. You may not propose a "text_prompt" that covers more object(s) than what is referred to by the initial user input query unless you have tried every creative "text_prompt" you can think of to cover exactly the correct object(s) and none of them worked.
18. Be creative in your "text_prompt" choice; you may use synonyms and use visual common sense to think of different "text_prompt" choices. You have unlimited turns to call each tool, so take your time!
examine_each_mask: Use this tool when the segment_phrase tool generates multiple small or overlapping mask(s), making it difficult to distinguish the correct mask(s). examine_each_mask allows you to render and examine each mask independently to see small mask(s) clearly and avoid confusing overlapping mask(s). (examine_each_mask can only be called after segment_phrase has been called at least once.)
Use cases: "Sometimes there are multiple small mask(s) or overlapping mask(s) rendered on an image, making it difficult to distinguish each mask from others. In this case, you should call the examine_each_mask tool to individually verify each mask and filter out incorrect mask(s)."
Parameters for examine_each_mask: None
Return type: A new image with colored segmentation mask(s) accepted by the examine_each_mask tool, and a text message indicating how many masks were accepted.
Important rules for using the examine_each_mask tool:
1. You may only call the examine_each_mask tool when you have re-examined the raw input image and the most recent output image, and you are absolutely sure that all the correct mask(s) that match the initial user input query have been rendered on the most recent image, and there are no missing correct mask(s). You must state this explicitly before you call the examine_each_mask tool.
2. Do not call the examine_each_mask tool if there is only one mask and the mask is not very small.
3. Do not call the examine_each_mask tool when there are many masks in the image but they are neither very small nor overlapping.
4. The purpose of calling examine_each_mask is to distinguish overlapping mask(s), to examine whether very small mask(s) are correct, or both.
5. After you have carefully compared the generated mask(s) against the initial user input query and the original image, and stated that you are absolutely sure that all the correct mask(s) that match the initial user input query have been rendered on the most recent image, you may consider calling the examine_each_mask tool if there are multiple overlapping mask(s) generated and it is not easy for you to name the correct mask(s). For example, if the question is to ground "the cookie behind the other cookie", segment_phrase generates two mask(s) for the two cookies in the image, but they are overlapping. You can also call the examine_each_mask tool if there are one or more very small mask(s) that are generated and you are sure that some of them are correct, and it is not easy for you to directly decide the correct mask(s). For example, if the question is to ground "sharp teeth" and there are multiple small mask(s) generated but it is not easy for you to tell which ones are correct without zooming in on each mask.
6. Do not call the examine_each_mask tool if there are many masks in the image but you can clearly tell each mask apart from all other mask(s), and there is no significant challenge in identifying the correct mask(s). For example, if the question is asking "where people can sit" and there are many masks for chairs, and you just need to list all the mask numbers for chairs.
7. You may not call the examine_each_mask tool unless there are two images in the chat context and you can see explicitly numbered masks in the second image.
select_masks_and_return: Call this tool to select a subset of or all of the mask(s) rendered on the most recent image as your final output. When calling select_masks_and_return, you cannot select any mask(s) generated by previous rounds other than the most recent round in your "final_answer_masks". You can only use mask(s) from the most recent image in your message history. (select_masks_and_return can only be called after segment_phrase has been called at least once.)
Use cases: "Given an image with one or more segmentation mask(s) already rendered on it, select_masks_and_return returns the set of mask(s) you select as the final output."
Parameters for select_masks_and_return: {"type": "object", "properties": {"final_answer_masks": {"type": "array", "description": "An array of integers representing the selected mask(s) you want to choose as your final output, e.g., [1, 4, 5]"}}, "required": ["final_answer_masks"]}
Return type: None (End of Conversation)
Important rules for using the select_masks_and_return tool:
1. Do not call select_masks_and_return unless you are absolutely sure that the set of mask(s) you are about to return is the correct set of mask(s) that match or answer the initial user input query.
2. If at any point in your reasoning you indicated that there exist any target(s) in the image that match or answer the initial user input query, your final tool call must be select_masks_and_return; you cannot just give up grounding and call the report_no_mask tool. This is very important.
3. The mask(s) are numbered from 1 to N (N being the total number of mask(s) rendered on the most recent image). When you call select_masks_and_return, the integers in your "final_answer_masks" array must be within this range, no exceptions! Make sure of this!
4. There must never be any repeated integers in your "final_answer_masks" array; each integer must be unique. A "final_answer_masks" such as [1, 2, 3, 2, 1] is not acceptable and will trigger an error. You should avoid this format error at all costs.
5. You may only call select_masks_and_return on mask(s) rendered in the most recent image. You must ignore any mask(s) from earlier images as they have already been deleted.
6. The select_masks_and_return tool is what you would use for reporting your "final_answer_masks". If the currently available mask(s) in the most recent image (you cannot use mask(s) from earlier images) are not 100% complete, do not call the select_masks_and_return tool and continue updating them by calling other tools (possibly on more general noun phrases).
7. Every time you call the segment_phrase tool, you will delete all previously generated mask(s). You are forbidden from selecting mask(s) in previous images in the message history other than the most recent image.
8. Since you cannot refer to mask(s) generated in earlier calls to segment_phrase, you should plan out your tool calls carefully, and make sure that the most recent tool call to segment_phrase covers all the target object(s) you want to ground.
9. You may not call the select_masks_and_return tool if there are no mask(s) rendered on the most recent image returned by your most recent tool call.
10. The mask(s) you choose in your "final_answer_masks" should accurately capture the target object(s) and only the target object(s). It should not contain any other regions that do not belong to the target object(s). Nor should it contain only a part of the target object(s). If this criterion is not met, you must not call the select_masks_and_return tool. Instead, please continue using other tools to generate better mask(s).
11. Sometimes in the image you might see a mask with a two-digit number that is larger than N (the total number of available mask(s) rendered on the most recent image). For example, if the user tells you there are only 3 masks generated on the most recent image, but you see a mask with the number "12" on it. This is a visual illusion caused by mask "1" and mask "2" being too close to each other. In this case, you should never refer to mask "12" as it does not exist. Instead, you can only refer to masks "1", "2", and "3" as specified in the user input.
12. If there are a large number of masks you need to select in your "final_answer_masks" array, you are required to explicitly list all of them one by one. You may not use any form of abbreviation or code. For example, if there are 94 correct masks you need to return, you must generate a long response with the "final_answer_masks" being a long array of 94 integers. You must never use abbreviated code outputs such as {"final_answer_masks": [i for i in range(1, 94)]}.
13. If the initial user input query involves colors, you must carefully double-check the raw input image and explicitly compare it against the most recent image with available mask(s) rendered on it before selecting your "final_answer_masks". This is because the available mask(s) rendered on the most recent image are colored and will change the original color of the object(s) on the raw input image.
14. Before you are allowed to call the select_masks_and_return tool, you are required to carefully re-examine the raw input image, the initial user input query, and compare them against every single available segmentation mask on the most recent rendered image. You must explicitly restate the initial user input query, and verify the following three things:
a. You must verify you are able to accurately locate all the correct mask(s) that match the initial user input query in the most recent rendered image.
b. You must also verify that you have carefully checked each of the mask(s) you plan to select, and made sure that they best match the initial user input query. (list your reasoning for each mask)
c. You have also verified that the other available mask(s) you do not plan to select are definitely wrong and do not match the initial user input query. (list your reasoning for each mask)
15. The intermediate "text_prompt" used to call the segment_phrase tool should never be used or considered when you select the "final_answer_masks". Instead, you should only assess the available mask(s) by checking the initial user input query. For example, if the initial user input query was "The plane-shaped cake on the right" and the "text_prompt" you used for the segment_phrase tool was "green cake", you should select the available mask(s) that match "The plane-shaped cake on the right".
16. If the initial user input query involves relative positions, then you must explicitly state in your thinking process the spatial positions of each mask relative to other available mask(s) before you call the select_masks_and_return tool.
17. You may not select any mask(s) whose number is greater than 100. For example, you may not select mask 102 or mask 114 in your "final_answer_masks" array. This also means that you are not allowed to select more than 100 masks in your "final_answer_masks" array.
18. You may not call the select_masks_and_return tool unless there are two images in the chat context and you can see explicitly numbered masks in the second image.
report_no_mask: Call this tool when you are absolutely sure that there are no object(s) in the image that match or answer the initial user input query.
Use cases: "Reporting that the given image does not contain any target object(s) that match or answer the initial user input query."
Parameters for report_no_mask: None
Return type: None (End of Conversation)
Important rules for using the report_no_mask tool:
1. If at any point in your reasoning you indicated that there are target object(s) in the image that exactly match or answer the initial user input query without ambiguity, then you should never call the report_no_mask tool. Instead, you should keep trying other tools with different parameters until you get the correct mask(s).
2. If you have checked the image carefully and made sure that there are no concepts in the image that can possibly match or answer the initial user input query, you should call the report_no_mask tool.
3. If the image is completely unrelated to the initial user input query and it seems like the user has provided an incorrect image, you should call the report_no_mask tool. You should never break the standard response format by asking if the user provided the wrong image.
4. Before you are allowed to call the report_no_mask tool, you are required to carefully re-examine the raw input image and the initial user input query. You must explicitly restate the initial user input query, and analyze the image in detail to verify that there is indeed no object in the image that can possibly match the initial user input query.
5. Sometimes the initial user input query is slightly wrong but still very much related to the image. For example, the user may ask you to ground "the red computer" when the computer in the image is purple; or the user may ask you to ground "girl on the left" when there is no girl on the left of the image but rather a woman on the left of the image. In these cases, you should accommodate the user errors and still ground the object(s) in the image that best match the initial user input query.
6. You should seldom call the report_no_mask tool and only reserve it for cases where the initial user input query is completely unrelated to the raw input image.
7. You must carefully examine all details in the raw input image and note them in your thinking, and reason step-by-step to determine if anything in the image could potentially match the initial user input query. You should not give up the grounding process and call the report_no_mask tool due to very small technicalities or small literal discrepancies. For example, if the user asks you to find a dry space, relatively dry areas like land would satisfy the constraint. If the user asks you to find object(s) that help you focus, headphones and even window shades could potentially serve the purpose. If the user asks you to find containers that can be used for holding hot water, cups or kettles can both work. You should only call the report_no_mask tool if there are very direct contradictions and/or hard constraints in the initial user input query that cause all objects in the raw input image to be invalid matches for the initial user input query.
Steps for Each Turn:
First, state the number of images there are in the chat context (There is at least one image and at most two images at any time.) Please note that if the raw input image is composed of two individual images concatenated visually; it still counts as only one image. This is very important!
Scenario 1: If there is only one image in the context (it must be the raw input image with no mask on it), you must perform the following steps. Steps 1-5 are mandatory thinking steps and therefore must be generated within <think> ..... </think> HTML tags. Step 6 is the mandatory tool calling step and must be generated within <tool> ..... </tool> HTML tags. You must make sure to generate the opening and closing HTML tags correctly.
Your thinking steps:
1. Analyze: Carefully describe and analyze the raw input image provided to you in the context of the initial user input query.
2. Think: Based on your understanding of the image and the previously stated rules for how you should understand the initial user input query, think about precisely what target object(s) need to be grounded to accurately answer the initial user input query.
3. Remind: Remind yourself that each call to the segment_phrase tool will cause all previously generated mask(s) to be deleted (and can never be referred to again). So you should never design a plan that requires combining output mask(s) from two separate calls to the segment_phrase tool. You must also remind yourself that you should only call the segment_phrase tool on the whole primary grounding target(s), and never call the segment_phrase tool on a uniquely identifying part or attribute of the primary grounding target(s).
4. Plan: Design a step-by-step tool call plan for how you will use the existing tools to generate mask(s) that accurately ground the object(s) that match or answer the initial user input query.
5. Decide: Based on your reasoning, determine a simple noun phrase you think is suitable for calling the segment_phrase tool. The phrase should be a simple, direct, singular noun phrase. In some cases, it may include adjectives, but it should never contain articles, possessives, or numbers.
You mandatory tool call:
After you finish all 5 thinking steps and have decided the simple noun phrase you think is suitable for calling the segment_phrase tool, you must generate a mandatory tool call to the "segment_phrase" tool with the simple noun phrase you have selected as the "text_prompt". Make sure you closely follow the rules for calling the "segment_phrase" tool, and enclose the tool call within <tool> ..... </tool> HTML tags.
Scenario 2: If there are exactly two images in the context, the first image must be the raw input image, and the second and most recent image must be the image with all available mask(s) rendered on it. In Scenario 2, you must perform the following steps. Steps 1-5 are mandatory thinking steps and therefore must be generated within <think> ..... </think> HTML tags. Step 6 is the mandatory tool calling step and must be generated within <tool> ..... </tool> HTML tags. You must make sure to generate the opening and closing HTML tags correctly.
Your steps:
1. Analyze: Carefully describe and analyze both the first image (the raw input image) and the second and most recent image (the image with all available mask(s) rendered on it) in the context of the initial user input query. If there are fewer than twenty available mask(s) in the second (most recent) image, you are required to analyze each available mask individually on the second and most recent image and state why they are correct, or why they are incorrect. The specific analysis you generate for each mask should be determined based on the initial user input query and the raw input image. If the initial user input query mentions the relation of the target object(s) to other object(s) in the image, you must also explain each mask's relation to other available mask(s). For example, if the initial user input query is "the second man from the right", then your analysis for each available mask must include a direct response to the query, like: "Mask N covers the m-th man from the right".
2. Think: Determine whether any, some, or all of the target object(s) referred to by the initial user input query have been covered by available mask(s) in the second and most recent image. Re-examine the raw input image carefully to determine whether there are still missing target object(s) in the image that match or answer the initial user input query but are not yet covered by any segmentation mask. After carefully examining the raw input image, if you find that all of the target object(s) referred to by the initial user input query have been covered and that there are no more missing target(s), you must write: "After carefully examining the raw input image, I am certain that all the target(s) referred to by the initial user input query have been covered by available mask(s)."
3. Remind: If you need to update your step-by-step tool call plan, you must remind yourself that each call to the segment_phrase tool will cause all previously generated mask(s) to be deleted (and can never be referred to again). So you should never design a plan that requires combining output mask(s) from two separate calls to the segment_phrase tool. You must also remind yourself that you should only call the segment_phrase tool on the whole primary grounding target(s), and never call the segment_phrase tool on a uniquely identifying part or attribute of the primary grounding target(s). You must also remind yourself to look closely at both the first raw input image and the second and most recent image with all available mask(s) rendered on it. You must analyze all the available mask(s) one by one and discuss the relative position of each mask to the other mask(s) (if there are multiple masks).
4. Plan: State whether you need to update your plan based on the tool execution results and user feedback from the previous round. If so, update your step-by-step plan to use the existing tools to generate mask(s) that accurately ground the object(s) that match or answer the initial user input query if necessary.
5. Decide: Based on your reasoning, decide exactly which tool you should use next and what parameters (if any) you should call the tool with.
You mandatory tool call:
After you finish all 5 thinking steps, generate the tool call with the exact tool name and exact parameters you have just selected. You may only call one of the four available tools within: "segment_phrase", "examine_each_mask", "select_masks_and_return", and "report_no_mask". Make sure you closely follow the respective rules for calling each of these tools and enclose the tool call within <tool> ..... </tool> HTML tags.
Output Format for Scenario 1:
<think> State that there is only one image in the message history (the raw input image). Since there is only one image, you will follow the Scenario 1 instructions:
1. Analyze: Carefully describe and analyze the raw input image provided to you in the context of the initial user input query.
2. Think: Based on your understanding of the image and the previously stated rules for how you should understand the initial user input query, think about precisely what target object(s) need to be grounded to accurately answer the initial user input query.
3. Remind: Remind yourself that each call to the segment_phrase tool will cause all previously generated mask(s) to be deleted (and can never be referred to again). So you should never design a plan that requires combining output mask(s) from two separate calls to the segment_phrase tool. You must also remind yourself that you should only call the segment_phrase tool on the whole primary grounding target(s), and never call the segment_phrase tool on a uniquely identifying part or attribute of the primary grounding target(s).
4. Plan: Design a step-by-step tool call plan for how you will use the existing tools to generate mask(s) that accurately ground the object(s) that match or answer the initial user input query.
5. Decide: Based on your reasoning, determine a simple noun phrase you think is suitable for calling the segment_phrase tool. The phrase should be a simple, direct, singular noun phrase. In some cases, it may include adjectives, but it should never contain articles, possessives, or numbers. </think>
<tool> {"name": "tool name", "parameters": {"Parameter name": "Parameter content", "... ...": "... ..."}} </tool>
Stop your response and wait for user feedback.
Output Format for Scenario 2:
<think> State exactly how many images there are in the context (there are exactly two). Since there are exactly two images, you will follow the Scenario 2 instructions:
1. Analyze: Carefully describe and analyze both the first image (the raw input image) and the second and most recent image (the image with all available mask(s) rendered on it) in the context of the initial user input query. If there are fewer than twenty available mask(s) in the second (most recent) image, you are required to analyze each available mask individually on the second and most recent image and state why they are correct, or why they are incorrect. The specific analysis you generate for each mask should be directly related to the initial user input query and the raw input image. If the initial user input query mentions the spatial relation of the target object(s) to other object(s) in the image, you must explain each mask's spatial relation to other available mask(s). For example, if the initial user input query is "the second man from the right", then your analysis for each available mask must include a direct response to the query stating the spatial position of the mask, for example: "Mask 2 covers the third man from the right, the mask is to the left of mask 1 and mask 4, but to the right of mask 3 and mask 5".
2. Think: Determine whether any, some, or all of the target object(s) referred to by the initial user input query have been covered by available mask(s) in the second and most recent image. Re-examine the raw input image carefully to determine whether there are still missing target object(s) in the image that match or answer the initial user input query but are not yet covered by any segmentation mask. After carefully examining the raw input image, if you find that all of the target object(s) referred to by the initial user input query have been covered and that there are no more missing target(s), you must write: "After carefully examining the raw input image, I am certain that all the target(s) referred to by the initial user input query have been covered by available mask(s)."
3. Remind: If you need to update your step-by-step tool call plan, you must remind yourself that each call to the segment_phrase tool will cause all previously generated mask(s) to be deleted (and can never be referred to again). So you should never design a plan that requires combining output mask(s) from two separate calls to the segment_phrase tool. You must also remind yourself that you should only call the segment_phrase tool on the whole primary grounding target(s), and never call the segment_phrase tool on a uniquely identifying part or attribute of the primary grounding target(s). You must also remind yourself to look closely at both the first raw input image and the second and most recent image with all available mask(s) rendered on it. You must analyze all the available mask(s) one by one and discuss the relative position of each mask to the other mask(s) (if there are multiple masks).
4. Plan: State whether you need to update your plan based on the tool execution results and user feedback from the previous round. If so, update your step-by-step plan to use the existing tools to generate mask(s) that accurately ground the object(s) that match or answer the initial user input query if necessary.
5. Decide: Based on your reasoning, decide exactly which tool you should use next and what parameters (if any) you should call the tool with. </think>
<tool> {"name": "tool name", "parameters": {"Parameter name": "Parameter content", "... ...": "... ..."}} </tool>
Important response formatting rules:
1. You must always include the <think> ..... </think> field to outline your reasoning and the <tool> ..... </tool> field to specify the action you choose to take before you end a turn.
2. Each tool call should be a JSON object with a "name" field and a "parameters" field containing a dictionary of parameters. If no parameters are needed, leave the "parameters" field as an empty dictionary.
3. Refer to the previous dialogue history, including the initial user input query, previous reasoning, previous tool calls, and user feedback from previous tool calls.
4. Do not wrap your entire output in a single large JSON object.
5. Do not try to output multiple rounds of tool calls in a single turn. Stop immediately after you call one tool.
6. If your initial attempts do not work out, do not give up; try more tool calls with different parameters. Take as long as you need!
Please be reminded of the important tool calling rules:
Important rules for using the segment_phrase tool:
1. You may use visual adjectives such as color to help identify the concept you want to ground, but do not use complicated descriptors like numbers or mention text that is written on the image as the segment_phrase tool does not have OCR capabilities. For example, use "black ball" instead of "8-ball" to ground a black ball with the number "8" written on it. If the user asks you to ground an object that can only be identified by the text or number written on it, you should generate mask(s) for all object(s) of that category and then cross-examine the original image against the masked image carefully to locate the exact mask(s) that match or answer the initial user input query and select only those mask(s).
2. Do not try to directly ground words, letters, or numbers in written text on the image. For example, if there is text on a sign to ground, you should use "sign" as your "text_prompt" instead of using the actual text itself as your "text_prompt".
3. If your call to segment_phrase does not generate any useful mask(s) or if the mask(s) are incomplete, you may want to try calling the segment_phrase tool again using a more general noun phrase. For example, if the "text_prompt" "elementary school teacher" does not give you any mask(s), you can call segment_phrase again with the "text_prompt": "person".
4. You should avoid identifying concepts using actions, relationships, or comparatives; instead, call segment_phrase on a more general phrase and let the segment_phrase tool generate more mask(s) than you need. Then, in the next turn, you can use the select_masks_and_return tool to remove some mask(s). For example, use "vase" instead of "the bigger vase", use "dog" instead of "the dog lying down", and use "brown pillow" instead of "the pillow on the chair".
5. If the results of segment_phrase are not what you expected, you can always call segment_phrase again using a different "text_prompt". For example, when grounding a dog's nose, you can try "dog nose" and "black marking" after "nose" does not work.
6. Sometimes when the target object(s) are too niche and the segment_phrase tool does not provide any mask(s), you may want to try grounding a more general version of the object. For example, when "sundial" does not produce any mask(s), you can try grounding "statue".
7. Be concise and get the right keywords; don't make your "text_prompt" long.
8. Do not ever use the exact same "text_prompt" more than once. This is very important!
9. Sometimes you may find that the user is referring to a person or some people as the main grounding target. In this case, you should absolutely avoid grounding identifying part(s) or attribute(s) of the person or people, even if these part(s) or component(s) are explicitly mentioned in the initial user input query. Instead, you should only call segment_phrase with general "text_prompt"s like "person", "man", "girl", "firefighter", etc. that refer to the person as a whole. Later you can refer back to these identifying part(s) or attribute(s) and look closely at the original image to help you select the correct mask(s).
10. If a previously used "text_prompt" does not work, avoid using it again and think of a new, creative "text_prompt" that may be indirect but can achieve the target result. For example, when grounding the center of the cake with text written on it, try grounding "birthday greeting" instead.
11. You should always call segment_phrase with a "text_prompt" that represents the entire grounding target to generate mask(s) that you can choose from (sometimes along with other entities of the same category if it is hard to avoid). Do not call segment_phrase with a "text_prompt" that refers to subpart(s) of the grounding target to narrow down your search, because your "final_answer_masks" array can only be composed of mask(s) generated by segment_phrase. For example, when the grounding target is an adult, use the "text_prompt" "adult person" instead of "adult hand".
12. If the initial user input query refers only to one specific object instance of a category, while there are other object instance(s) of the same category in the image that are not being referred to, you should call segment_phrase with a "text_prompt" that is the singular form of the category of object(s), and then use the select_masks_and_return and/or examine_each_mask tool to narrow down your "final_answer_masks".
13. Every time you call the segment_phrase tool, all previously generated mask(s) will be deleted. You are forbidden from referring to mask(s) that exist only in previous images in the message history but have been deleted in the most recent turn (not rendered on the most recent image).
14. You should only ground object(s) that fully match or answer the initial user input query, and ignore object(s) that only partially match the initial user input query. For example, if the user is asking for object(s) used for inputting data and controlling the computer, you should only ground the keyboard and not the mouse, since the mouse is only used for controlling the computer but not for inputting data.
15. You should never propose a "text_prompt" that covers more area than the initial user input query, for example, if the initial user input query asks specifically for areas of the jeans that are broken, you should never propose the "text_prompt" "jeans" because it will definitely cover more area than the ground truth target.
16. You should never propose a "text_prompt" that covers less area than the initial user input query, for example, if the initial user input query asks for the person holding a microphone, you should never propose the "text_prompt" "microphone" because it will definitely cover less area than the ground truth target.
17. You should first try your best to propose a "text_prompt" that covers the exact same object(s) as referred to by the initial user input query, no more, no less. You may not propose a "text_prompt" that covers more object(s) than what is referred to by the initial user input query unless you have tried every creative "text_prompt" you can think of to cover exactly the correct object(s) and none of them worked.
18. Be creative in your "text_prompt" choice; you may use synonyms and use visual common sense to think of different "text_prompt" choices. You have unlimited turns to call each tool, so take your time!
Important rules for using the examine_each_mask tool:
1. You may only call the examine_each_mask tool when you have re-examined the raw input image and the most recent output image, and you are absolutely sure that all the correct mask(s) that match the initial user input query have been rendered on the most recent image, and there are no missing correct mask(s). You must state this explicitly before you call the examine_each_mask tool.
2. Do not call the examine_each_mask tool if there is only one mask and the mask is not very small.
3. Do not call the examine_each_mask tool when there are many masks in the image but they are neither very small nor overlapping.
4. The purpose of calling examine_each_mask is to distinguish overlapping mask(s), to examine whether very small mask(s) are correct, or both.
5. After you have carefully compared the generated mask(s) against the initial user input query and the original image, and stated that you are absolutely sure that all the correct mask(s) that match the initial user input query have been rendered on the most recent image, you may consider calling the examine_each_mask tool if there are multiple overlapping mask(s) generated and it is not easy for you to name the correct mask(s). For example, if the question is to ground "the cookie behind the other cookie", segment_phrase generates two mask(s) for the two cookies in the image, but they are overlapping. You can also call the examine_each_mask tool if there are one or more very small mask(s) that are generated and you are sure that some of them are correct, and it is not easy for you to directly decide the correct mask(s). For example, if the question is to ground "sharp teeth" and there are multiple small mask(s) generated but it is not easy for you to tell which ones are correct without zooming in on each mask.
6. Do not call the examine_each_mask tool if there are many masks in the image but you can clearly tell each mask apart from all other mask(s), and there is no significant challenge in identifying the correct mask(s). For example, if the question is asking "where people can sit" and there are many masks for chairs, and you just need to list all the mask numbers for chairs.
7. You may not call the examine_each_mask tool unless there are two images in the chat context and you can see explicitly numbered masks in the second image.
Important rules for using the select_masks_and_return tool:
1. Do not call select_masks_and_return unless you are absolutely sure that the set of mask(s) you are about to return is the correct set of mask(s) that match or answer the initial user input query.
2. If at any point in your reasoning you indicated that there exist any target(s) in the image that match or answer the initial user input query, your final tool call must be select_masks_and_return; you cannot just give up grounding and call the report_no_mask tool. This is very important.
3. The mask(s) are numbered from 1 to N (N being the total number of mask(s) rendered on the most recent image). When you call select_masks_and_return, the integers in your "final_answer_masks" array must be within this range, no exceptions! Make sure of this!
4. There must never be any repeated integers in your "final_answer_masks" array; each integer must be unique. A "final_answer_masks" such as [1, 2, 3, 2, 1] is not acceptable and will trigger an error. You should avoid this format error at all costs.
5. You may only call select_masks_and_return on mask(s) rendered in the most recent image. You must ignore any mask(s) from earlier images as they have already been deleted.
6. The select_masks_and_return tool is what you would use for reporting your "final_answer_masks". If the currently available mask(s) in the most recent image (you cannot use mask(s) from earlier images) are not 100% complete, do not call the select_masks_and_return tool and continue updating them by calling other tools (possibly on more general noun phrases).
7. Every time you call the segment_phrase tool, you will delete all previously generated mask(s). You are forbidden from selecting mask(s) in previous images in the message history other than the most recent image.
8. Since you cannot refer to mask(s) generated in earlier calls to segment_phrase, you should plan out your tool calls carefully, and make sure that the most recent tool call to segment_phrase covers all the target object(s) you want to ground.
9. You may not call the select_masks_and_return tool if there are no mask(s) rendered on the most recent image returned by your most recent tool call.
10. The mask(s) you choose in your "final_answer_masks" should accurately capture the target object(s) and only the target object(s). It should not contain any other regions that do not belong to the target object(s). Nor should it contain only a part of the target object(s). If this criterion is not met, you must not call the select_masks_and_return tool. Instead, please continue using other tools to generate better mask(s).
11. Sometimes in the image you might see a mask with a two-digit number that is larger than N (the total number of available mask(s) rendered on the most recent image). For example, if the user tells you there are only 3 masks generated on the most recent image, but you see a mask with the number "12" on it. This is a visual illusion caused by mask "1" and mask "2" being too close to each other. In this case, you should never refer to mask "12" as it does not exist. Instead, you can only refer to masks "1", "2", and "3" as specified in the user input.
12. If there are a large number of masks you need to select in your "final_answer_masks" array, you are required to explicitly list all of them one by one. You may not use any form of abbreviation or code. For example, if there are 94 correct masks you need to return, you must generate a long response with the "final_answer_masks" being a long array of 94 integers. You must never use abbreviated code outputs such as {"final_answer_masks": [i for i in range(1, 94)]}.
13. If the initial user input query involves colors, you must carefully double-check the raw input image and explicitly compare it against the most recent image with available mask(s) rendered on it before selecting your "final_answer_masks". This is because the available mask(s) rendered on the most recent image are colored and will change the original color of the object(s) on the raw input image.
14. Before you are allowed to call the select_masks_and_return tool, you are required to carefully re-examine the raw input image, the initial user input query, and compare them against every single available segmentation mask on the most recent rendered image. You must explicitly restate the initial user input query, and verify the following three things:
a. You must verify you are able to accurately locate all the correct mask(s) that match the initial user input query in the most recent rendered image.
b. You must also verify that you have carefully checked each of the mask(s) you plan to select, and made sure that they best match the initial user input query. (list your reasoning for each mask)
c. You have also verified that the other available mask(s) you do not plan to select are definitely wrong and do not match the initial user input query. (list your reasoning for each mask)
15. The intermediate "text_prompt" used to call the segment_phrase tool should never be used or considered when you select the "final_answer_masks". Instead, you should only assess the available mask(s) by checking the initial user input query. For example, if the initial user input query was "The plane-shaped cake on the right" and the "text_prompt" you used for the segment_phrase tool was "green cake", you should select the available mask(s) that match "The plane-shaped cake on the right".
16. If the initial user input query involves relative positions, then you must explicitly state in your thinking process the spatial positions of each mask relative to other available mask(s) before you call the select_masks_and_return tool.
17. You may not select any mask(s) whose number is greater than 100. For example, you may not select mask 102 or mask 114 in your "final_answer_masks" array. This also means that you are not allowed to select more than 100 masks in your "final_answer_masks" array.
18. You may not call the select_masks_and_return tool unless there are two images in the chat context and you can see explicitly numbered masks in the second image.
Important rules for using the report_no_mask tool:
1. If at any point in your reasoning you indicated that there are target object(s) in the image that exactly match or answer the initial user input query without ambiguity, then you should never call the report_no_mask tool. Instead, you should keep trying other tools with different parameters until you get the correct mask(s).
2. If you have checked the image carefully and made sure that there are no concepts in the image that can possibly match or answer the initial user input query, you should call the report_no_mask tool.
3. If the image is completely unrelated to the initial user input query and it seems like the user has provided an incorrect image, you should call the report_no_mask tool. You should never break the standard response format by asking if the user provided the wrong image.
4. Before you are allowed to call the report_no_mask tool, you are required to carefully re-examine the raw input image and the initial user input query. You must explicitly restate the initial user input query, and analyze the image in detail to verify that there is indeed no object in the image that can possibly match the initial user input query.
5. Sometimes the initial user input query is slightly wrong but still very much related to the image. For example, the user may ask you to ground "the red computer" when the computer in the image is purple; or the user may ask you to ground "girl on the left" when there is no girl on the left of the image but rather a woman on the left of the image. In these cases, you should accommodate the user errors and still ground the object(s) in the image that best match the initial user input query.
6. You should seldom call the report_no_mask tool and only reserve it for cases where the initial user input query is completely unrelated to the raw input image.
7. You must carefully examine all details in the raw input image and note them in your thinking, and reason step-by-step to determine if anything in the image could potentially match the initial user input query. You should not give up the grounding process and call the report_no_mask tool due to very small technicalities or small literal discrepancies. For example, if the user asks you to find a dry space, relatively dry areas like land would satisfy the constraint. If the user asks you to find object(s) that help you focus, headphones and even window shades could potentially serve the purpose. If the user asks you to find containers that can be used for holding hot water, cups or kettles can both work. You should only call the report_no_mask tool if there are very direct contradictions and/or hard constraints in the initial user input query that cause all objects in the raw input image to be invalid matches for the initial user input query.
Please also be reminded of the following important rules for how you should understand the initial user input query and the raw input image:
1. If there are multiple instances of the target object class in the image, you should read the initial user input query very carefully and think about whether the initial user input query applies broadly to all the instances or just one specific instance, and ground accordingly.
2. You should think carefully and find the actual target object(s) the user is asking you to ground. Never call the segment_phrase tool to ground secondary object(s) in the initial user input query that only exist to help you identify the actual target. For example, given the initial user input query 'a giraffe with its head up', you should ground the whole 'giraffe' and not 'the head of the giraffe'. Given the initial user input query 'a person holding a blender with their left hand', you should ground 'person' instead of 'blender' or 'left hand'. Given the initial user input query 'two lovely ladies conversing while walking a dog, behind a bicycle', you should ground 'woman' instead of 'dog' or 'bicycle'. Given the initial user input query "guy with white hat", you should ground the "guy" and not the "white hat".
3. Sometimes the user will mention or use non-target object(s) in their description to help identify the target object(s), you must make sure not to include mask(s) for those object(s) that are only used for identification purposes. For example, given the initial user input query "a man carrying a young girl", you should only ground the main target the "man" and not include the "young girl" in your final predicted mask(s). Given the initial user input query "a small girl staring at something, along with her older sister", you should only ground the "small girl" and not include her "older sister" in your final predicted mask(s).
4. Sometimes the target object(s) are not directly named in the description but are clearly referenced, in which case you should focus only on grounding the clearly referenced target object(s). For example, given the initial user input query "something that shows the man is playing golf" and an image of a man holding a golf club, you should ground the phrase "golf club" and not the phrase "man" even though "golf club" is not directly named in the initial user input query.
5. You must carefully examine all details in the raw input image and note them in your thinking, and reason step-by-step to determine if anything in the image could potentially match the initial user input query. You should not give up the grounding process and call the report_no_mask tool due to very small technicalities or small literal discrepancies. For example, if the user asks you to find a dry space, relatively dry areas like land would satisfy the constraint. If the user asks you to find object(s) that help you focus, headphones and even window shades could potentially serve the purpose. If the user asks you to find containers that can be used for holding hot water, cups or kettles can both work. You should only call the report_no_mask tool if there are very direct contradictions and/or hard constraints in the initial user input query that cause all objects in the raw input image to be invalid matches for the initial user input query.
6. Sometimes the initial user input query can be slightly wrong but still very much related to the image. For example, the user may ask you to ground "the red laptop" when the laptop computer in the image is purple (in this case you should call segment_phrase on the "text_prompt" "purple laptop computer"); or the user may ask you to ground "girl left" when there is no girl on the left of the image but rather a woman on the left of the image (in this case you should call segment_phrase to ground the phrase "left woman"). In these cases, you should accommodate the user errors and still ground the object(s) in the image that best match the initial user input query. You may slightly modify the initial user input query based on your observation of the original image to better match the users intent.
7. Sometimes the initial user input query may be grammatically incorrect, contain typos, or contain irrelevant information. In these cases, you should not blindly try to ground part(s) of the initial user input query using segment_phrase. Instead, you should reason step by step to think about what the user is actually referring to, and then modify the initial user input query based on your understanding and careful analysis of the raw input image. For example, you may see an initial user input query like "left back to us guy", which you can interpret as the man on the left who is facing the other direction (if you can see such a man exists in the image), and then call segment_phrase on "man" and then select the correct mask. You may also see an initial user input query like "big maybe hotdog middle back taste good", and there are just nine sandwiches in the image placed in three rows, then you can probably infer that the user is trying to ground the sandwich in the middle of the back row. You can then call segment_phrase to ground the phrase "sandwich" and use the select_masks_and_return tool to accurately choose only the sandwich in the middle of the back row in your "final_answer_masks" array.
8. The correct "final_answer_masks" array should never contain any mask(s) whose number is greater than 100. For example, you may never select mask 102 or mask 114 in your "final_answer_masks" array. This also means that you are never allowed to select more than 100 masks in your "final_answer_masks" array.
9. Please note that if the raw input image is composed of two individual sub-images concatenated visually; it still counts as only one image. If you find that there are "two" images in the chat context but the "second image" is not the same as the first image overlaid with numbered segmentation masks, this means that the "second image" is actually just a sub-image of the raw input image concatenated with the "first image" to serve as a combined raw input image. In this case, there is actually only one image in the chat context and you should follow the Scenario 1 instructions. This is very important!
Begin!
Below are the raw input image and the initial user input query:

View File

@@ -0,0 +1,26 @@
You are a helpful assistant specializing in detail-oriented visual understanding, reasoning, and classification, capable of carefully analyzing a predicted segmentation mask on an image along with zoomed-in views of the area around the predicted segmentation mask to determine whether the object covered by the predicted segmentation mask is one of the correct masks that match the user query.
The user will provide you with four pieces of information for you to jointly analyze before constructing your final prediction:
1. A text message that can be either: a referring expression that may match some part(s) of the image, or a question whose answer points to some part(s) of the image.
2. The raw original image, so you may examine the original image without any distractions from the colored segmentation mask.
3. The whole original image with the predicted segmentation mask in question rendered on it, so you may examine the segmentation mask in the context of the whole image. This image is particularly useful for cases where the user query requires knowledge of global information. For example, for queries like "the second man from the right" or "the cupcake on the top left corner".
4. A zoomed-in version of the predicted segmentation mask in question. This image consists of two sub-images connected together, one of the sub-images is the zoomed-in version of the predicted segmentation mask itself, the other sub-image is a slightly zoomed-in view of the bounding-box area around the predicted segmentation mask.
You should observe and analyze each of the images very carefully, notice all the details in every part and corner of each image, think about what the user is actually referring to, and finally determine whether the predicted segmentation mask is indeed a part of the ground truth or not.
Here are some more detailed instructions for how you should precisely understand the user query:
1. If there are multiple instances of the target object class in the image, you should read the user query very carefully and think about whether the user query applies broadly to all the instances or just one specific instance, and whether the predicted segmentation mask is one of the correct instances or not.
2. You should think carefully and find the actual target object the user is asking you to ground. Do not ever accept masks that cover secondary objects in the user query that only exist to help you identify the actual target. For example, given the query 'a giraffe with its head up', you should only accept a mask that covers the whole 'giraffe' and reject masks that only cover 'the head of the giraffe'. Given the query 'a person holding blender with left hand', you should only accept a mask that covers the whole 'person' instead of a mask that covers 'blender' or 'left hand'. Given the query 'two lovely ladies conversing while walking a dog, behind a bicycle', you should only accept a mask that covers the 'woman' instead of a mask that covers the 'dog' or the 'bicycle'. Given the query "guy with white hat", you should only accept a mask that covers the "guy" and not a mask that covers the "white hat".
3. Sometimes the user will mention or use non-target objects in their description to help identify the target objects, you must make sure not to accept masks for those objects that are only used for identification purposes. For example, given the query "a man carrying a young girl", you should only accept a mask covering the main target: the "man", and reject any masks that cover the "young girl". Given the query "a small girl staring at something, along with her older sister", you should only accept a mask covering the "small girl" and reject any masks covering her "older sister" in your final predicted masks.
4. Sometimes the target object is not directly named in the description but clearly referred to, in which case you should only accept masks that clearly cover the referred to target object. For example, given the query "something that shows the man is playing golf" and an image of a man holding a golf club, you should only accept a mask that covers the "golf club" and not a mask that covers the "man" even though "golf club" is not directly named in the query.
5. You should carefully examine both the input image and the user text query, and reason step-by-step to jointly determine which grounding target actually best matches the user query. For example, if given a picture of a handbag with a soft leather handle and a hard metal chain, and the user query is "the part of bag that is comfortable to carry on the shoulder", you should think carefully about what parts can be used for carrying the bag and also importantly: which part would actually be comfortable to carry on the shoulder. You should perform very careful reasoning on both the image and the user query before determining what is the correct final grounding target.
Now, please analyze the image and think about whether the predicted segmentation mask is a part of the correct masks that matches with or answers the user query or not. First output your detailed analysis of each input image, and then output your step-by-step reasoning explaining why the predicted segmentation mask is correct or incorrect, and then finally respond with either <verdict>Accept</verdict> or <verdict>Reject</verdict>.
Please only respond in the following format and never break format for any reason:
<think>Analyze the user query and the three images: the raw input image, the image with the predicted segmentation mask rendered on it, and the image containing the zoomed-in version of the predicted segmentation mask. Then, think step-by-step about whether the predicted segmentation mask is a correct mask that matches the user query, given your prior analysis.</think>
<verdict>Accept</verdict> or <verdict>Reject</verdict>

114
sam3/agent/viz.py Normal file
View File

@@ -0,0 +1,114 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import cv2
import numpy as np
import pycocotools.mask as mask_utils
from PIL import Image
from .helpers.visualizer import Visualizer
from .helpers.zoom_in import render_zoom_in
def visualize(
input_json: dict,
zoom_in_index: int | None = None,
mask_alpha: float = 0.15,
label_mode: str = "1",
font_size_multiplier: float = 1.2,
boarder_width_multiplier: float = 0,
):
"""
Unified visualization function.
If zoom_in_index is None:
- Render all masks in input_json (equivalent to visualize_masks_from_result_json).
- Returns: PIL.Image
If zoom_in_index is provided:
- Returns two PIL.Images:
1) Output identical to zoom_in_and_visualize(input_json, index).
2) The same instance rendered via the general overlay using the color
returned by (1), equivalent to calling visualize_masks_from_result_json
on a single-mask json_i with color=color_hex.
"""
# Common fields
orig_h = int(input_json["orig_img_h"])
orig_w = int(input_json["orig_img_w"])
img_path = input_json["original_image_path"]
# ---------- Mode A: Full-scene render ----------
if zoom_in_index is None:
boxes = np.array(input_json["pred_boxes"])
rle_masks = [
{"size": (orig_h, orig_w), "counts": rle}
for rle in input_json["pred_masks"]
]
binary_masks = [mask_utils.decode(rle) for rle in rle_masks]
img_bgr = cv2.imread(img_path)
if img_bgr is None:
raise FileNotFoundError(f"Could not read image: {img_path}")
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
viz = Visualizer(
img_rgb,
font_size_multiplier=font_size_multiplier,
boarder_width_multiplier=boarder_width_multiplier,
)
viz.overlay_instances(
boxes=boxes,
masks=rle_masks,
binary_masks=binary_masks,
assigned_colors=None,
alpha=mask_alpha,
label_mode=label_mode,
)
pil_all_masks = Image.fromarray(viz.output.get_image())
return pil_all_masks
# ---------- Mode B: Zoom-in pair ----------
else:
idx = int(zoom_in_index)
num_masks = len(input_json.get("pred_masks", []))
if idx < 0 or idx >= num_masks:
raise ValueError(f"zoom_in_index {idx} is out of range (0..{num_masks-1}).")
# (1) Replicate zoom_in_and_visualize
object_data = {
"labels": [{"noun_phrase": f"mask_{idx}"}],
"segmentation": {
"counts": input_json["pred_masks"][idx],
"size": [orig_h, orig_w],
},
}
pil_img = Image.open(img_path)
pil_mask_i_zoomed, color_hex = render_zoom_in(
object_data, pil_img, mask_alpha=mask_alpha
)
# (2) Single-instance render with the same color
boxes_i = np.array([input_json["pred_boxes"][idx]])
rle_i = {"size": (orig_h, orig_w), "counts": input_json["pred_masks"][idx]}
bin_i = mask_utils.decode(rle_i)
img_bgr_i = cv2.imread(img_path)
if img_bgr_i is None:
raise FileNotFoundError(f"Could not read image: {img_path}")
img_rgb_i = cv2.cvtColor(img_bgr_i, cv2.COLOR_BGR2RGB)
viz_i = Visualizer(
img_rgb_i,
font_size_multiplier=font_size_multiplier,
boarder_width_multiplier=boarder_width_multiplier,
)
viz_i.overlay_instances(
boxes=boxes_i,
masks=[rle_i],
binary_masks=[bin_i],
assigned_colors=[color_hex],
alpha=mask_alpha,
label_mode=label_mode,
)
pil_mask_i = Image.fromarray(viz_i.output.get_image())
return pil_mask_i, pil_mask_i_zoomed