Initial commit
fbshipit-source-id: da6be2f26e3a1202f4bffde8cb980e2dcb851294
This commit is contained in:
203
sam3/train/loss/sam3_loss.py
Normal file
203
sam3/train/loss/sam3_loss.py
Normal file
@@ -0,0 +1,203 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
|
||||
|
||||
import torch
|
||||
|
||||
from sam3.model.model_misc import SAM3Output
|
||||
|
||||
from sam3.train.utils.distributed import get_world_size
|
||||
|
||||
from .loss_fns import CORE_LOSS_KEY, Det2TrkAssoc, Masks
|
||||
|
||||
|
||||
class DummyLoss(torch.nn.Module):
|
||||
"""A dummy loss that always returns 0 (as a placeholder for eval)"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
core_loss_key: str = CORE_LOSS_KEY,
|
||||
device: str = "cuda",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.core_loss_key = core_loss_key
|
||||
self.device = torch.device(device)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return {self.core_loss_key: torch.tensor(0.0, device=self.device)}
|
||||
|
||||
def accumulate(self, out_dict):
|
||||
"""
|
||||
Called by iterative losses.
|
||||
"""
|
||||
if self.core_loss_key not in out_dict:
|
||||
out_dict[self.core_loss_key] = torch.tensor(0.0, device=self.device)
|
||||
return out_dict
|
||||
|
||||
|
||||
class Sam3LossWrapper(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
loss_fns_find,
|
||||
normalization="global",
|
||||
matcher=None,
|
||||
o2m_matcher=None,
|
||||
o2m_weight=1.0,
|
||||
use_o2m_matcher_on_o2m_aux=True,
|
||||
loss_fn_semantic_seg=None,
|
||||
normalize_by_valid_object_num=False,
|
||||
normalize_by_stage_num=False,
|
||||
scale_by_find_batch_size=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.loss_fns_find = loss_fns_find
|
||||
assert normalization in ["global", "local", "none"]
|
||||
self.normalization = normalization
|
||||
self.normalize_by_valid_object_num = normalize_by_valid_object_num
|
||||
self.normalize_by_stage_num = normalize_by_stage_num
|
||||
self.matcher = matcher
|
||||
self.o2m_matcher = o2m_matcher
|
||||
self.o2m_weight = o2m_weight
|
||||
# whether to use the o2m matcher on the o2m queries in auxiliary outputs
|
||||
self.use_o2m_matcher_on_o2m_aux = use_o2m_matcher_on_o2m_aux
|
||||
self.loss_fn_semantic_seg = loss_fn_semantic_seg
|
||||
self.scale_by_find_batch_size = scale_by_find_batch_size
|
||||
|
||||
def _get_num_boxes(self, targets):
|
||||
# the average number of target boxes for loss normalization
|
||||
if self.normalize_by_valid_object_num:
|
||||
# valid boxes are those with non-zero height and width
|
||||
# (while padded invisible boxes are )
|
||||
boxes_hw = targets["boxes"].view(-1, 4) # cx, cy, w, h
|
||||
num_boxes = (boxes_hw[:, 2:] > 0).all(dim=-1).sum().float()
|
||||
else:
|
||||
num_boxes = targets["num_boxes"].sum().float()
|
||||
if self.normalization == "global":
|
||||
torch.distributed.all_reduce(num_boxes)
|
||||
num_boxes = torch.clamp(num_boxes / get_world_size(), min=1)
|
||||
elif self.normalization == "local":
|
||||
num_boxes = torch.clamp(num_boxes, min=1)
|
||||
elif self.normalization == "none":
|
||||
num_boxes = 1
|
||||
return num_boxes
|
||||
|
||||
def compute_loss(self, nested_out, targets):
|
||||
num_boxes = self._get_num_boxes(targets)
|
||||
o2m_out_is_valid = nested_out.get("o2m_out_is_valid", None)
|
||||
o2m_target_is_valid_padded = nested_out.get("o2m_target_is_valid_padded", None)
|
||||
|
||||
# Get a list of outputs, including auxiliary and first stage outputs
|
||||
output_list = [(nested_out, "", False)] # (out, suffix, is_aux)
|
||||
if "aux_outputs" in nested_out:
|
||||
output_list.extend(
|
||||
(aux_out, f"_aux_{i}", True)
|
||||
for i, aux_out in enumerate(nested_out["aux_outputs"])
|
||||
)
|
||||
if "first_stage" in nested_out:
|
||||
output_list.append((nested_out["first_stage"], "_fs", True))
|
||||
|
||||
# Compute all the requested losses
|
||||
losses = {}
|
||||
total_core_loss = 0.0
|
||||
for out, suffix, is_aux in output_list:
|
||||
# o2o matcher indices need to be computed by the model (as the video model requires
|
||||
# a specific way of matching free and locked indices beyond just calling the matcher)
|
||||
indices = out["indices"]
|
||||
has_o2m_out = "pred_logits_o2m" in out
|
||||
if has_o2m_out:
|
||||
o2m_out = {
|
||||
k[: -len("_o2m")]: v for k, v in out.items() if k.endswith("_o2m")
|
||||
}
|
||||
# o2m targets are the same as the o2o targets (assuming repeat=1)
|
||||
o2m_targets = targets
|
||||
if self.use_o2m_matcher_on_o2m_aux or not is_aux:
|
||||
o2m_indices = self.o2m_matcher(
|
||||
o2m_out,
|
||||
o2m_targets,
|
||||
out_is_valid=o2m_out_is_valid,
|
||||
target_is_valid_padded=o2m_target_is_valid_padded,
|
||||
)
|
||||
else:
|
||||
o2m_indices = self.matcher(
|
||||
o2m_out,
|
||||
o2m_targets,
|
||||
out_is_valid=o2m_out_is_valid,
|
||||
target_is_valid_padded=o2m_target_is_valid_padded,
|
||||
)
|
||||
|
||||
for loss_fn in self.loss_fns_find:
|
||||
l_dict = loss_fn(
|
||||
outputs=out,
|
||||
targets=targets,
|
||||
indices=indices,
|
||||
num_boxes=num_boxes,
|
||||
is_aux=is_aux,
|
||||
)
|
||||
total_core_loss += l_dict.pop(CORE_LOSS_KEY)
|
||||
losses.update({f"{k}{suffix}": v for k, v in l_dict.items()})
|
||||
|
||||
compute_o2m_loss = has_o2m_out
|
||||
# a special handling to allow turning off mask loss in o2m
|
||||
# (to be compatible with the original implementation)
|
||||
if isinstance(loss_fn, Masks):
|
||||
compute_o2m_loss = compute_o2m_loss and "pred_masks" in o2m_out
|
||||
if isinstance(loss_fn, Det2TrkAssoc):
|
||||
compute_o2m_loss = False # Det2TrkAssoc does not support o2m
|
||||
if compute_o2m_loss:
|
||||
l_dict = loss_fn(
|
||||
outputs=o2m_out,
|
||||
targets=o2m_targets,
|
||||
indices=o2m_indices,
|
||||
num_boxes=num_boxes,
|
||||
is_aux=is_aux,
|
||||
)
|
||||
for k in l_dict:
|
||||
l_dict[k] *= self.o2m_weight
|
||||
total_core_loss += l_dict.pop(CORE_LOSS_KEY)
|
||||
losses.update({f"{k}{suffix}_o2m": v for k, v in l_dict.items()})
|
||||
|
||||
losses[CORE_LOSS_KEY] = total_core_loss
|
||||
return losses
|
||||
|
||||
def forward(self, find_stages: SAM3Output, find_targets):
|
||||
if find_stages.loss_stages is not None:
|
||||
find_targets = [find_targets[i] for i in find_stages.loss_stages]
|
||||
with SAM3Output.iteration_mode(
|
||||
find_stages, iter_mode=SAM3Output.IterMode.ALL_STEPS_PER_STAGE
|
||||
) as find_stages:
|
||||
assert len(find_stages) == len(find_targets)
|
||||
total_losses = {}
|
||||
for stage_outputs, stage_targets in zip(find_stages, find_targets):
|
||||
stage_targets = [stage_targets] * len(stage_outputs)
|
||||
# If there are multiple steps within a stage, compute the loss for all of them (e.g. interactivity)
|
||||
for outputs, targets in zip(stage_outputs, stage_targets):
|
||||
cur_losses = self.compute_loss(outputs, targets)
|
||||
|
||||
if self.loss_fn_semantic_seg is not None:
|
||||
cur_losses_semantic = self.loss_fn_semantic_seg(
|
||||
outputs, targets
|
||||
)
|
||||
cur_losses[CORE_LOSS_KEY] += cur_losses_semantic.pop(
|
||||
CORE_LOSS_KEY
|
||||
)
|
||||
# make sure the semantic losses don't overlap with the find losses
|
||||
assert set(cur_losses).isdisjoint(set(cur_losses_semantic))
|
||||
cur_losses.update(cur_losses_semantic)
|
||||
|
||||
# Optionally, normalize the loss by the number of find stages (training video frames) so that
|
||||
# image batches and video batches have similar loss scales. (Otherwise video batches would
|
||||
# have a much higher loss scale due to summing the losses over all the find stages.)
|
||||
if self.normalize_by_stage_num:
|
||||
cur_losses[CORE_LOSS_KEY] /= len(find_stages)
|
||||
|
||||
if self.scale_by_find_batch_size:
|
||||
bs = targets["num_boxes"].shape[0]
|
||||
# sqrt scaling based on the "effective" batch size
|
||||
cur_losses[CORE_LOSS_KEY] *= bs**0.5
|
||||
|
||||
for k, v in cur_losses.items():
|
||||
if k not in total_losses:
|
||||
total_losses[k] = v
|
||||
else:
|
||||
total_losses[k] += v
|
||||
|
||||
return total_losses
|
||||
Reference in New Issue
Block a user