Summary: Formats the covered files with pyfmt. paintitblack Reviewed By: itamaro Differential Revision: D90476315 fbshipit-source-id: ee94c471788b8e7d067813d8b3e0311214d17f3f
157 lines
6.7 KiB
Python
157 lines
6.7 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
|
|
|
# pyre-unsafe
|
|
|
|
import numpy as np
|
|
import pycocotools.mask as mask_utils
|
|
import torch
|
|
import torchvision.transforms.functional as F
|
|
from PIL import Image as PILImage
|
|
from sam3.model.box_ops import masks_to_boxes
|
|
from sam3.train.data.sam3_image_dataset import Datapoint
|
|
|
|
|
|
class InstanceToSemantic(object):
    """Convert instance segmentation to semantic segmentation.

    For every find query, the masks of the query's output objects
    (``fquery.object_ids_output``) are merged by union into
    ``fquery.semantic_target``.

    Args:
        delete_instance: if True, once all semantic targets are built, the
            ``segment`` of every object of every image is set to None.
        use_rle: if True, instance segments are expected to be pycocotools RLE
            dicts and the semantic target is an RLE built with
            ``mask_utils.merge``. Otherwise segments are expected to be uint8
            tensors and the semantic target is an ``(h, w)`` uint8 tensor.

    Objects whose ``segment`` is None contribute nothing to the union, in both
    modes.
    """

    def __init__(self, delete_instance=True, use_rle=False):
        self.delete_instance = delete_instance
        self.use_rle = use_rle

    @staticmethod
    def _merge_rles(rles, h, w):
        """Union a list of RLE masks into a single RLE.

        An empty list yields an all-zero RLE of size ``(h, w)``.
        """
        if not rles:
            # There is no good way to create an empty RLE of the correct size.
            # We resort to converting an empty box to RLE.
            return mask_utils.frPyObjects(
                np.array([[0, 0, 0, 0]], dtype=np.float64), h, w
            )[0]
        # We need to double check that all RLEs have the same size.
        # Otherwise cocotools will fail silently to an empty [0, 0] mask.
        ref_size = rles[0]["size"]
        for seg in rles:
            assert seg["size"] == ref_size, (
                "Instance segments have inconsistent sizes. "
                f"Found sizes {seg['size']} and {ref_size}"
            )
        return mask_utils.merge(rles)

    @staticmethod
    def _merge_tensors(masks, h, w):
        """Bitwise-OR a list of uint8 mask tensors into a fresh ``(h, w)`` uint8 tensor."""
        # `semantic_target` is uint8 and remains uint8 throughout the transforms
        # (it contains binary 0 and 1 values just like `segment` for each object).
        target = torch.zeros((h, w), dtype=torch.uint8)
        for mask in masks:
            assert isinstance(mask, torch.Tensor) and mask.dtype == torch.uint8
            target |= mask
        return target

    def __call__(self, datapoint: Datapoint, **kwargs):
        for fquery in datapoint.find_queries:
            image = datapoint.images[fquery.image_id]
            # `size` is unpacked as (h, w), matching the argument order used by
            # frPyObjects and torch.zeros in the helpers above.
            h, w = image.size

            # Skip objects without a mask. The tensor path always did this; the
            # RLE path used to crash on `None["size"]` instead.
            segments = [
                image.objects[obj_id].segment
                for obj_id in fquery.object_ids_output
                if image.objects[obj_id].segment is not None
            ]

            if self.use_rle:
                fquery.semantic_target = self._merge_rles(segments, h, w)
            else:
                fquery.semantic_target = self._merge_tensors(segments, h, w)

        if self.delete_instance:
            # Drop every instance mask, not only those referenced by a query.
            for img in datapoint.images:
                for obj in img.objects:
                    obj.segment = None

        return datapoint
|
|
|
|
|
|
class RecomputeBoxesFromMasks:
    """Recompute bounding boxes from masks.

    Every object of every image gets its ``bbox`` overwritten with
    ``masks_to_boxes(segment)``. Its ``area`` is overwritten with the sum of the
    mask's values, as a Python number.
    """

    def __call__(self, datapoint: Datapoint, **kwargs):
        every_object = (
            instance for image in datapoint.images for instance in image.objects
        )
        for instance in every_object:
            mask = instance.segment
            # An empty mask leaves the box undefined; such targets are expected
            # to be filtered out by a later transform.
            instance.bbox = masks_to_boxes(mask)
            instance.area = mask.sum().item()
        return datapoint
|
|
|
|
|
|
class DecodeRle:
    """This transform decodes RLEs into binary segments.

    Implementing it as a transform allows lazy loading. Some transforms (e.g.
    query filters) may delete masks, so decoding them from the beginning is
    wasteful.

    This transform needs to be called before any kind of geometric manipulation.

    Behavior:
      - Every object ``segment`` that is neither None nor already a tensor is
        decoded with pycocotools into a uint8 tensor.
      - An RLE with zero area is instead approximated by filling the object's
        ``bbox``, read as (x1, y1, x2, y2).
      - Every query ``semantic_target`` that is neither None nor already a
        tensor is decoded the same way.
      - Decoded masks whose shape differs from their image's (height, width)
        are resized. A warning is printed at most once per datapoint for
        objects and once for queries.
    """

    @staticmethod
    def _image_hw(data):
        """Return ``(height, width)`` of an image given as a PIL image or a tensor."""
        if isinstance(data, PILImage.Image):
            # PIL reports its size as (width, height).
            img_w, img_h = data.size
        elif isinstance(data, torch.Tensor):
            # Tensors keep the spatial dims last, i.e. (..., height, width). This
            # is the same layout the F.resize calls below assume.
            img_h, img_w = data.shape[-2:]
        else:
            raise RuntimeError(f"Unexpected image type {type(data)}")
        return int(img_h), int(img_w)

    @staticmethod
    def _decode(rle):
        """Decode a pycocotools RLE into a uint8 tensor."""
        # segment is uint8 and remains uint8 throughout the transforms
        return torch.tensor(mask_utils.decode(rle)).to(torch.uint8)

    @staticmethod
    def _box_mask(bbox, img_h, img_w):
        """Return an (img_h, img_w) uint8 mask set to 1 inside ``bbox`` (x1, y1, x2, y2).

        The filled region is always at least one pixel tall and wide.
        """
        mask = torch.zeros(img_h, img_w, dtype=torch.uint8)
        x1, y1, x2, y2 = bbox.int().tolist()
        # Clamp the start corner at 0: a negative start index would otherwise
        # wrap around to the opposite border of the image.
        x1, y1 = max(x1, 0), max(y1, 0)
        mask[y1 : max(y2, y1 + 1), x1 : max(x2, x1 + 1)] = 1
        return mask

    @staticmethod
    def _resize(mask, size):
        """Resize a 2D mask tensor to ``size`` = (height, width)."""
        return F.resize(mask[None], size).squeeze(0)

    def __call__(self, datapoint: Datapoint, **kwargs):
        imgId2size = {}
        warning_shown = False
        for imgId, img in enumerate(datapoint.images):
            img_h, img_w = self._image_hw(img.data)
            imgId2size[imgId] = (img_h, img_w)

            for obj in img.objects:
                if obj.segment is None or isinstance(obj.segment, torch.Tensor):
                    continue  # nothing to decode

                if mask_utils.area(obj.segment) == 0:
                    print("Warning, empty mask found, approximating from box")
                    obj.segment = self._box_mask(obj.bbox, img_h, img_w)
                else:
                    obj.segment = self._decode(obj.segment)

                if list(obj.segment.shape) != [img_h, img_w]:
                    # Should not happen often, but adding for security
                    if not warning_shown:
                        print(
                            f"Warning expected instance segmentation size to be {[img_h, img_w]} but found {list(obj.segment.shape)}"
                        )
                        # Printing only once per datapoint to avoid spam
                        warning_shown = True
                    obj.segment = self._resize(obj.segment, (img_h, img_w))

                assert list(obj.segment.shape) == [img_h, img_w]

        warning_shown = False
        for query in datapoint.find_queries:
            target = query.semantic_target
            if target is None or isinstance(target, torch.Tensor):
                continue  # nothing to decode

            expected = imgId2size[query.image_id]
            target = self._decode(target)
            if tuple(target.shape) != expected:
                if not warning_shown:
                    print(
                        f"Warning expected semantic segmentation size to be {expected} but found {tuple(target.shape)}"
                    )
                    # Printing only once per datapoint to avoid spam
                    warning_shown = True
                target = self._resize(target, expected)

            query.semantic_target = target
            assert tuple(query.semantic_target.shape) == expected

        return datapoint
|