Files
sam3_local/sam3/agent/helpers/som_utils.py
generatedunixname89002005307016 7b89b8fc3f Add missing Pyre mode headers] [batch:11/N] [shard:17/N]
Differential Revision: D90237984

fbshipit-source-id: 526fd760f303bf31be4f743bdcd77760496de0de
2026-01-07 05:16:41 -08:00

409 lines
12 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
import colorsys
from dataclasses import dataclass
from typing import List, Tuple
import cv2
import matplotlib as mpl
import matplotlib.colors as mplc
import numpy as np
import pycocotools.mask as mask_utils
def rgb_to_hex(rgb_color):
"""
Convert a rgb color to hex color.
Args:
rgb_color (tuple/list of ints): RGB color in tuple or list format.
Returns:
str: Hex color.
Example:
```
>>> rgb_to_hex((255, 0, 244))
'#ff00ff'
```
"""
return "#" + "".join([hex(c)[2:].zfill(2) for c in rgb_color])
# DEFAULT_COLOR_HEX_TO_NAME = {
# rgb_to_hex((255, 0, 0)): "red",
# rgb_to_hex((0, 255, 0)): "lime",
# rgb_to_hex((0, 0, 255)): "blue",
# rgb_to_hex((255, 255, 0)): "yellow",
# rgb_to_hex((255, 0, 255)): "fuchsia",
# rgb_to_hex((0, 255, 255)): "aqua",
# rgb_to_hex((255, 165, 0)): "orange",
# rgb_to_hex((128, 0, 128)): "purple",
# rgb_to_hex((255, 215, 0)): "gold",
# }
# Assuming rgb_to_hex is a function that converts an (R, G, B) tuple to a hex string.
# For example: def rgb_to_hex(rgb): return '#%02x%02x%02x' % rgb
DEFAULT_COLOR_HEX_TO_NAME = {
# The top 20 approved colors
rgb_to_hex((255, 255, 0)): "yellow",
rgb_to_hex((0, 255, 0)): "lime",
rgb_to_hex((0, 255, 255)): "cyan",
rgb_to_hex((255, 0, 255)): "magenta",
rgb_to_hex((255, 0, 0)): "red",
rgb_to_hex((255, 127, 0)): "orange",
rgb_to_hex((127, 255, 0)): "chartreuse",
rgb_to_hex((0, 255, 127)): "spring green",
rgb_to_hex((255, 0, 127)): "rose",
rgb_to_hex((127, 0, 255)): "violet",
rgb_to_hex((192, 255, 0)): "electric lime",
rgb_to_hex((255, 192, 0)): "vivid orange",
rgb_to_hex((0, 255, 192)): "turquoise",
rgb_to_hex((192, 0, 255)): "bright violet",
rgb_to_hex((255, 0, 192)): "bright pink",
rgb_to_hex((255, 64, 0)): "fiery orange",
rgb_to_hex((64, 255, 0)): "bright chartreuse",
rgb_to_hex((0, 255, 64)): "malachite",
rgb_to_hex((64, 0, 255)): "deep violet",
rgb_to_hex((255, 0, 64)): "hot pink",
}
DEFAULT_COLOR_PALETTE = list(DEFAULT_COLOR_HEX_TO_NAME.keys())
def _validate_color_hex(color_hex: str):
color_hex = color_hex.lstrip("#")
if not all(c in "0123456789abcdefABCDEF" for c in color_hex):
raise ValueError("Invalid characters in color hash")
if len(color_hex) not in (3, 6):
raise ValueError("Invalid length of color hash")
# copied from https://github.com/roboflow/supervision/blob/c8f557af0c61b5c03392bad2cc36c8835598b1e1/supervision/draw/color.py
@dataclass
class Color:
"""
Represents a color in RGB format.
Attributes:
r (int): Red channel.
g (int): Green channel.
b (int): Blue channel.
"""
r: int
g: int
b: int
@classmethod
def from_hex(cls, color_hex: str):
"""
Create a Color instance from a hex string.
Args:
color_hex (str): Hex string of the color.
Returns:
Color: Instance representing the color.
Example:
```
>>> Color.from_hex('#ff00ff')
Color(r=255, g=0, b=255)
```
"""
_validate_color_hex(color_hex)
color_hex = color_hex.lstrip("#")
if len(color_hex) == 3:
color_hex = "".join(c * 2 for c in color_hex)
r, g, b = (int(color_hex[i : i + 2], 16) for i in range(0, 6, 2))
return cls(r, g, b)
@classmethod
def to_hex(cls, color):
"""
Convert a Color instance to a hex string.
Args:
color (Color): Color instance of color.
Returns:
Color: a hex string.
"""
return rgb_to_hex((color.r, color.g, color.b))
def as_rgb(self) -> Tuple[int, int, int]:
"""
Returns the color as an RGB tuple.
Returns:
Tuple[int, int, int]: RGB tuple.
Example:
```
>>> color.as_rgb()
(255, 0, 255)
```
"""
return self.r, self.g, self.b
def as_bgr(self) -> Tuple[int, int, int]:
"""
Returns the color as a BGR tuple.
Returns:
Tuple[int, int, int]: BGR tuple.
Example:
```
>>> color.as_bgr()
(255, 0, 255)
```
"""
return self.b, self.g, self.r
@classmethod
def white(cls):
return Color.from_hex(color_hex="#ffffff")
@classmethod
def black(cls):
return Color.from_hex(color_hex="#000000")
@classmethod
def red(cls):
return Color.from_hex(color_hex="#ff0000")
@classmethod
def green(cls):
return Color.from_hex(color_hex="#00ff00")
@classmethod
def blue(cls):
return Color.from_hex(color_hex="#0000ff")
@dataclass
class ColorPalette:
colors: List[Color]
@classmethod
def default(cls):
"""
Returns a default color palette.
Returns:
ColorPalette: A ColorPalette instance with default colors.
Example:
```
>>> ColorPalette.default()
ColorPalette(colors=[Color(r=255, g=0, b=0), Color(r=0, g=255, b=0), ...])
```
"""
return ColorPalette.from_hex(color_hex_list=DEFAULT_COLOR_PALETTE)
@classmethod
def from_hex(cls, color_hex_list: List[str]):
"""
Create a ColorPalette instance from a list of hex strings.
Args:
color_hex_list (List[str]): List of color hex strings.
Returns:
ColorPalette: A ColorPalette instance.
Example:
```
>>> ColorPalette.from_hex(['#ff0000', '#00ff00', '#0000ff'])
ColorPalette(colors=[Color(r=255, g=0, b=0), Color(r=0, g=255, b=0), ...])
```
"""
colors = [Color.from_hex(color_hex) for color_hex in color_hex_list]
return cls(colors)
def by_idx(self, idx: int) -> Color:
"""
Return the color at a given index in the palette.
Args:
idx (int): Index of the color in the palette.
Returns:
Color: Color at the given index.
Example:
```
>>> color_palette.by_idx(1)
Color(r=0, g=255, b=0)
```
"""
if idx < 0:
raise ValueError("idx argument should not be negative")
idx = idx % len(self.colors)
return self.colors[idx]
def find_farthest_color(self, img_array):
"""
Return the color that is the farthest from the given color.
Args:
img_array (np array): any *x3 np array, 3 is the RGB color channel.
Returns:
Color: Farthest color.
"""
# Reshape the image array for broadcasting
img_array = img_array.reshape((-1, 3))
# Convert colors dictionary to a NumPy array
color_values = np.array([[c.r, c.g, c.b] for c in self.colors])
# Calculate the Euclidean distance between the colors and each pixel in the image
# Broadcasting happens here: img_array shape is (num_pixels, 3), color_values shape is (num_colors, 3)
distances = np.sqrt(
np.sum((img_array[:, np.newaxis, :] - color_values) ** 2, axis=2)
)
# Average the distances for each color
mean_distances = np.mean(distances, axis=0)
# return the farthest color
farthest_idx = np.argmax(mean_distances)
farthest_color = self.colors[farthest_idx]
farthest_color_hex = Color.to_hex(farthest_color)
if farthest_color_hex in DEFAULT_COLOR_HEX_TO_NAME:
farthest_color_name = DEFAULT_COLOR_HEX_TO_NAME[farthest_color_hex]
else:
farthest_color_name = "unknown"
return farthest_color, farthest_color_name
def draw_box(ax, box_coord, alpha=0.8, edge_color="g", line_style="-", linewidth=2.0):
x0, y0, width, height = box_coord
ax.add_patch(
mpl.patches.Rectangle(
(x0, y0),
width,
height,
fill=False,
edgecolor=edge_color,
linewidth=linewidth,
alpha=alpha,
linestyle=line_style,
)
)
def draw_text(
ax,
text,
position,
font_size=None,
color="g",
horizontal_alignment="left",
rotation=0,
):
if not font_size:
font_size = mpl.rcParams["font.size"]
color = np.maximum(list(mplc.to_rgb(color)), 0.2)
color[np.argmax(color)] = max(0.8, np.max(color))
x, y = position
ax.text(
x,
y,
text,
size=font_size,
family="sans-serif",
bbox={"facecolor": "none", "alpha": 0.5, "pad": 0.7, "edgecolor": "none"},
verticalalignment="top",
horizontalalignment=horizontal_alignment,
color=color,
rotation=rotation,
)
def draw_mask(
ax, rle, color, show_holes=True, alpha=0.15, upsample_factor=1.0, rle_upsampled=None
):
if isinstance(rle, dict):
mask = mask_utils.decode(rle)
elif isinstance(rle, np.ndarray):
mask = rle
else:
raise ValueError(f"Unsupported type for rle: {type(rle)}")
mask_upsampled = None
if upsample_factor > 1.0 and show_holes:
assert rle_upsampled is not None
if isinstance(rle_upsampled, dict):
mask_upsampled = mask_utils.decode(rle_upsampled)
elif isinstance(rle_upsampled, np.ndarray):
mask_upsampled = rle_upsampled
else:
raise ValueError(f"Unsupported type for rle: {type(rle)}")
if show_holes:
if mask_upsampled is None:
mask_upsampled = mask
h, w = mask_upsampled.shape
mask_img = np.zeros((h, w, 4))
mask_img[:, :, :-1] = color[np.newaxis, np.newaxis, :]
mask_img[:, :, -1] = mask_upsampled * alpha
ax.imshow(mask_img)
*_, contours, _ = cv2.findContours(
mask.astype(np.uint8).copy(), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
)
upsampled_contours = [(cont + 0.5) * upsample_factor - 0.5 for cont in contours]
facecolor = (0, 0, 0, 0) if show_holes else color
if alpha > 0.8:
edge_color = _change_color_brightness(color, brightness_factor=-0.7)
else:
edge_color = color
for cont in upsampled_contours:
polygon = mpl.patches.Polygon(
[el[0] for el in cont],
edgecolor=edge_color,
linewidth=2.0,
facecolor=facecolor,
)
ax.add_patch(polygon)
def _change_color_brightness(color, brightness_factor):
"""
Depending on the brightness_factor, gives a lighter or darker color i.e. a color with
less or more saturation than the original color.
Args:
color: color of the polygon. Refer to `matplotlib.colors` for a full list of
formats that are accepted.
brightness_factor (float): a value in [-1.0, 1.0] range. A lightness factor of
0 will correspond to no change, a factor in [-1.0, 0) range will result in
a darker color and a factor in (0, 1.0] range will result in a lighter color.
Returns:
modified_color (tuple[double]): a tuple containing the RGB values of the
modified color. Each value in the tuple is in the [0.0, 1.0] range.
"""
assert brightness_factor >= -1.0 and brightness_factor <= 1.0
color = mplc.to_rgb(color)
polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color))
modified_lightness = polygon_color[1] + (brightness_factor * polygon_color[1])
modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness
modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness
modified_color = colorsys.hls_to_rgb(
polygon_color[0], modified_lightness, polygon_color[2]
)
return modified_color