Files
sam3_local/sam3/agent/helpers/zoom_in.py
generatedunixname89002005307016 7b89b8fc3f [Add missing Pyre mode headers] [batch:11/N] [shard:17/N]
Differential Revision: D90237984

fbshipit-source-id: 526fd760f303bf31be4f743bdcd77760496de0de
2026-01-07 05:16:41 -08:00

198 lines
6.3 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
import io
import math
import matplotlib.pyplot as plt
import numpy as np
import pycocotools.mask as mask_utils
from PIL import Image
from .som_utils import ColorPalette, draw_box, draw_mask, draw_text
def _xywh_to_xyxy(box):
    """Convert an [x, y, w, h] box to [x0, y0, x1, y1]."""
    return [box[0], box[1], box[0] + box[2], box[1] + box[3]]


def _get_shift(x, w, w_new, w_img):
    """
    Compute how far to shift a box's origin so that a widened box of size
    ``w_new`` (grown from a box of size ``w`` at position ``x``) stays inside
    an image of size ``w_img``.

    The widened box is centered on the original by default; the shift is then
    clamped so the box neither overruns the right/bottom edge nor moves its
    origin past the left/top edge (``min(x, shift)`` keeps ``x - shift >= 0``).
    """
    assert 0 <= w_new <= w_img
    # Center the enlarged box on the original one.
    shift = (w_new - w) / 2
    # If centering pushes the box past the far edge, pin it to that edge.
    if x - shift + w_new > w_img:
        shift = x + w_new - w_img
    # Never shift past the near (0) edge.
    return min(x, shift)


def _get_zoom_in_box(
    mask_box_xywh,
    img_h,
    img_w,
    mask_area,
    area_large: float = 0.25,
    area_medium: float = 0.05,
):
    """
    Derive two crop boxes around a mask's bounding box.

    Parameters
    ----------
    mask_box_xywh : sequence of 4 numbers
        Mask bounding box as [x, y, w, h] in image coordinates.
    img_h, img_w : int
        Full image height and width.
    mask_area : number
        Pixel area of the mask.
    area_large, area_medium : float
        Target upper bounds on the mask's relative area inside the zoom-in
        box and the crop box, respectively. Defaults match the original
        hard-coded thresholds (0.25 / 0.05).

    Returns
    -------
    (zoom_in_box, img_crop_box) : tuple of [x, y, w, h] lists
        ``zoom_in_box`` is the tight view for the mask-overlay panel;
        ``img_crop_box`` is the wider view for the cropped-original panel.
    """
    box_w, box_h = mask_box_xywh[2], mask_box_xywh[3]
    # Pad the bbox by 20% (at least 16 px) on each axis, clamped to the image.
    w_new = min(box_w + max(0.2 * box_w, 16), img_w)
    h_new = min(box_h + max(0.2 * box_h, 16), img_h)
    mask_relative_area = mask_area / (w_new * h_new)

    def _enlarged(threshold):
        # Grow the padded box so the mask occupies at most `threshold` of it.
        w_out, h_out = w_new, h_new
        if mask_relative_area > threshold:
            ratio = math.sqrt(mask_relative_area / threshold)
            w_out = min(w_new * ratio, img_w)
            h_out = min(h_new * ratio, img_h)
        w_shift = _get_shift(mask_box_xywh[0], mask_box_xywh[2], w_out, img_w)
        h_shift = _get_shift(mask_box_xywh[1], mask_box_xywh[3], h_out, img_h)
        return [
            mask_box_xywh[0] - w_shift,
            mask_box_xywh[1] - h_shift,
            w_out,
            h_out,
        ]

    zoom_in_box = _enlarged(area_large)
    img_crop_box = _enlarged(area_medium)
    return zoom_in_box, img_crop_box


def _draw_crop_panel(ax, img, crop_box, bbox_xywh, label, color, show_box, show_text):
    """Render the cropped-original panel, with optional bbox and text label."""
    img1 = img.crop(_xywh_to_xyxy(crop_box))
    # Bounding box expressed relative to the crop's origin.
    bbox_xywh_rel = [
        bbox_xywh[0] - crop_box[0],
        bbox_xywh[1] - crop_box[1],
        bbox_xywh[2],
        bbox_xywh[3],
    ]
    ax.imshow(img1)
    ax.axis("off")
    if show_box:
        draw_box(ax, bbox_xywh_rel, edge_color=color)
    if show_text:
        # Small offset so the text does not sit exactly on the box corner.
        x0, y0 = bbox_xywh_rel[0] + 2, bbox_xywh_rel[1] + 2
        draw_text(ax, label, [x0, y0], color=color)


def _draw_zoom_panel(ax, img, segmentation, zoom_in_box, color, show_holes, mask_alpha):
    """Render the zoomed-in panel with the mask drawn as a colored overlay."""
    binary_mask = mask_utils.decode(segmentation)
    # Carry the mask through the crop as the image's alpha channel so the
    # cropped mask stays aligned with the cropped pixels.
    alpha = Image.fromarray((binary_mask * 255).astype("uint8"))
    img_rgba = img.convert("RGBA")
    img_rgba.putalpha(alpha)
    img_with_alpha_zoomin = img_rgba.crop(_xywh_to_xyxy(zoom_in_box))
    alpha_zoomin = img_with_alpha_zoomin.split()[3]
    binary_mask_zoomin = np.array(alpha_zoomin).astype(bool)
    ax.imshow(img_with_alpha_zoomin.convert("RGB"))
    ax.axis("off")
    draw_mask(
        ax, binary_mask_zoomin, color=color, show_holes=show_holes, alpha=mask_alpha
    )


def _fig_to_pil(fig):
    """Rasterize a matplotlib figure to an in-memory PIL image and close it."""
    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0, dpi=100)
    plt.close(fig)
    buf.seek(0)
    pil_img = Image.open(buf)
    # Force the lazy PNG read now so the returned image does not depend on
    # the local buffer object staying alive.
    pil_img.load()
    return pil_img


def render_zoom_in(
    object_data,
    image_file,
    show_box: bool = True,
    show_text: bool = False,
    show_holes: bool = True,
    mask_alpha: float = 0.15,
):
    """
    Render a two-panel visualization with a cropped original view (left/upper)
    and a zoomed-in mask overlay (right/lower), then return it as a PIL.Image
    along with the chosen mask color (hex).

    Parameters
    ----------
    object_data : dict
        Dict containing "labels" and COCO RLE "segmentation".
        Expected:
            object_data["labels"][0]["noun_phrase"] : str
            object_data["segmentation"] : COCO RLE (with "size": [H, W])
    image_file : PIL.Image.Image
        Source image (PIL).
    show_box : bool
        Whether to draw the bbox on the cropped original panel.
    show_text : bool
        Whether to draw the noun phrase label near the bbox.
    show_holes : bool
        Whether to render mask holes (passed through to draw_mask).
    mask_alpha : float
        Alpha for the mask overlay.

    Returns
    -------
    pil_img : PIL.Image.Image
        The composed visualization image.
    color_hex : str
        Hex string of the chosen mask color.
    """
    # ---- input parsing ----
    object_label = object_data["labels"][0]["noun_phrase"]
    img = image_file.convert("RGB")
    bbox_xywh = mask_utils.toBbox(object_data["segmentation"])  # [x, y, w, h]

    # ---- color choice: pick a color visually distant from the object crop ----
    crop_img = img.crop(_xywh_to_xyxy(bbox_xywh))
    color_palette = ColorPalette.default()
    color_obj, _ = color_palette.find_farthest_color(np.array(crop_img))
    color = np.array([color_obj.r / 255, color_obj.g / 255, color_obj.b / 255])
    color_hex = f"#{color_obj.r:02x}{color_obj.g:02x}{color_obj.b:02x}"

    # ---- geometry: zoom-in box (panel 2) and crop box (panel 1) ----
    img_h, img_w = object_data["segmentation"]["size"]
    mask_area = mask_utils.area(object_data["segmentation"])
    zoom_in_box, img_crop_box = _get_zoom_in_box(bbox_xywh, img_h, img_w, mask_area)

    # ---- layout: tall crops go side by side, wide crops stack vertically ----
    if img_crop_box[2] < img_crop_box[3]:
        fig, (ax1, ax2) = plt.subplots(1, 2)
    else:
        fig, (ax1, ax2) = plt.subplots(2, 1)

    _draw_crop_panel(
        ax1, img, img_crop_box, bbox_xywh, object_label, color, show_box, show_text
    )
    _draw_zoom_panel(
        ax2,
        img,
        object_data["segmentation"],
        zoom_in_box,
        color,
        show_holes,
        mask_alpha,
    )
    plt.tight_layout()

    pil_img = _fig_to_pil(fig)
    return pil_img, color_hex