Initial commit
fbshipit-source-id: da6be2f26e3a1202f4bffde8cb980e2dcb851294
This commit is contained in:
138
sam3/agent/client_sam3.py
Executable file
138
sam3/agent/client_sam3.py
Executable file
@@ -0,0 +1,138 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from sam3.model.box_ops import box_xyxy_to_xywh
|
||||
from sam3.train.masks_ops import rle_encode
|
||||
|
||||
from .helpers.mask_overlap_removal import remove_overlapping_masks
|
||||
from .viz import visualize
|
||||
|
||||
|
||||
def sam3_inference(processor, image_path, text_prompt):
|
||||
"""Run SAM 3 image inference with text prompts and format the outputs"""
|
||||
image = Image.open(image_path)
|
||||
orig_img_w, orig_img_h = image.size
|
||||
|
||||
# model inference
|
||||
inference_state = processor.set_image(image)
|
||||
inference_state = processor.set_text_prompt(
|
||||
state=inference_state, prompt=text_prompt
|
||||
)
|
||||
|
||||
# format and assemble outputs
|
||||
pred_boxes_xyxy = torch.stack(
|
||||
[
|
||||
inference_state["boxes"][:, 0] / orig_img_w,
|
||||
inference_state["boxes"][:, 1] / orig_img_h,
|
||||
inference_state["boxes"][:, 2] / orig_img_w,
|
||||
inference_state["boxes"][:, 3] / orig_img_h,
|
||||
],
|
||||
dim=-1,
|
||||
) # normalized in range [0, 1]
|
||||
pred_boxes_xywh = box_xyxy_to_xywh(pred_boxes_xyxy).tolist()
|
||||
pred_masks = rle_encode(inference_state["masks"].squeeze(1))
|
||||
pred_masks = [m["counts"] for m in pred_masks]
|
||||
outputs = {
|
||||
"orig_img_h": orig_img_h,
|
||||
"orig_img_w": orig_img_w,
|
||||
"pred_boxes": pred_boxes_xywh,
|
||||
"pred_masks": pred_masks,
|
||||
"pred_scores": inference_state["scores"].tolist(),
|
||||
}
|
||||
return outputs
|
||||
|
||||
|
||||
def call_sam_service(
|
||||
sam3_processor,
|
||||
image_path: str,
|
||||
text_prompt: str,
|
||||
output_folder_path: str = "sam3_output",
|
||||
):
|
||||
"""
|
||||
Loads an image, sends it with a text prompt to the service,
|
||||
saves the results, and renders the visualization.
|
||||
"""
|
||||
print(f"📞 Loading image '{image_path}' and sending with prompt '{text_prompt}'...")
|
||||
|
||||
text_prompt_for_save_path = (
|
||||
text_prompt.replace("/", "_") if "/" in text_prompt else text_prompt
|
||||
)
|
||||
|
||||
os.makedirs(
|
||||
os.path.join(output_folder_path, image_path.replace("/", "-")), exist_ok=True
|
||||
)
|
||||
output_json_path = os.path.join(
|
||||
output_folder_path,
|
||||
image_path.replace("/", "-"),
|
||||
rf"{text_prompt_for_save_path}.json",
|
||||
)
|
||||
output_image_path = os.path.join(
|
||||
output_folder_path,
|
||||
image_path.replace("/", "-"),
|
||||
rf"{text_prompt_for_save_path}.png",
|
||||
)
|
||||
|
||||
try:
|
||||
# Send the image and text prompt as a multipart/form-data request
|
||||
serialized_response = sam3_inference(sam3_processor, image_path, text_prompt)
|
||||
|
||||
# 1. Prepare the response dictionary
|
||||
serialized_response = remove_overlapping_masks(serialized_response)
|
||||
serialized_response = {
|
||||
"original_image_path": image_path,
|
||||
"output_image_path": output_image_path,
|
||||
**serialized_response,
|
||||
}
|
||||
|
||||
# 2. Reorder predictions by scores (highest to lowest) if scores are available
|
||||
if "pred_scores" in serialized_response and serialized_response["pred_scores"]:
|
||||
# Create indices sorted by scores in descending order
|
||||
score_indices = sorted(
|
||||
range(len(serialized_response["pred_scores"])),
|
||||
key=lambda i: serialized_response["pred_scores"][i],
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
# Reorder all three lists based on the sorted indices
|
||||
serialized_response["pred_scores"] = [
|
||||
serialized_response["pred_scores"][i] for i in score_indices
|
||||
]
|
||||
serialized_response["pred_boxes"] = [
|
||||
serialized_response["pred_boxes"][i] for i in score_indices
|
||||
]
|
||||
serialized_response["pred_masks"] = [
|
||||
serialized_response["pred_masks"][i] for i in score_indices
|
||||
]
|
||||
|
||||
# 3. Remove any invalid RLE masks that is too short (shorter than 5 characters)
|
||||
valid_masks = []
|
||||
valid_boxes = []
|
||||
valid_scores = []
|
||||
for i, rle in enumerate(serialized_response["pred_masks"]):
|
||||
if len(rle) > 4:
|
||||
valid_masks.append(rle)
|
||||
valid_boxes.append(serialized_response["pred_boxes"][i])
|
||||
valid_scores.append(serialized_response["pred_scores"][i])
|
||||
serialized_response["pred_masks"] = valid_masks
|
||||
serialized_response["pred_boxes"] = valid_boxes
|
||||
serialized_response["pred_scores"] = valid_scores
|
||||
|
||||
with open(output_json_path, "w") as f:
|
||||
json.dump(serialized_response, f, indent=4)
|
||||
print(f"✅ Raw JSON response saved to '{output_json_path}'")
|
||||
|
||||
# 4. Render and save visualizations on the image and save it in the SAM3 output folder
|
||||
print("🔍 Rendering visualizations on the image ...")
|
||||
viz_image = visualize(serialized_response)
|
||||
os.makedirs(os.path.dirname(output_image_path), exist_ok=True)
|
||||
viz_image.save(output_image_path)
|
||||
print("✅ Saved visualization at:", output_image_path)
|
||||
except Exception as e:
|
||||
print(f"❌ Error calling service: {e}")
|
||||
|
||||
return output_json_path
|
||||
Reference in New Issue
Block a user