apply Black 25.11.0 style in fbcode/deeplearning/projects (21/92)
Summary: Formats the covered files with pyfmt. paintitblack Reviewed By: itamaro Differential Revision: D90476315 fbshipit-source-id: ee94c471788b8e7d067813d8b3e0311214d17f3f
This commit is contained in:
committed by
meta-codesync[bot]
parent
7b89b8fc3f
commit
11dec2936d
@@ -9,11 +9,8 @@ import torch
|
||||
import torch.distributed
|
||||
import torch.nn.functional as F
|
||||
import torchmetrics
|
||||
|
||||
from sam3.model import box_ops
|
||||
|
||||
from sam3.model.data_misc import interpolate
|
||||
|
||||
from sam3.train.loss.sigmoid_focal_loss import (
|
||||
triton_sigmoid_focal_loss,
|
||||
triton_sigmoid_focal_loss_reduce,
|
||||
@@ -327,7 +324,9 @@ class IABCEMdetr(LossWithWeights):
|
||||
if num_det_queries is not None:
|
||||
logging.warning("note: it's not needed to set num_det_queries anymore")
|
||||
if self.use_separate_loss_for_det_and_trk:
|
||||
assert not self.weak_loss, "Do not use weak_loss in this case -- set separate loss for detection and tracking queries instead"
|
||||
assert not self.weak_loss, (
|
||||
"Do not use weak_loss in this case -- set separate loss for detection and tracking queries instead"
|
||||
)
|
||||
self.det_exhaustive_loss_scale_pos = det_exhaustive_loss_scale_pos
|
||||
self.det_exhaustive_loss_scale_neg = det_exhaustive_loss_scale_neg
|
||||
self.det_non_exhaustive_loss_scale_pos = det_non_exhaustive_loss_scale_pos
|
||||
@@ -342,7 +341,9 @@ class IABCEMdetr(LossWithWeights):
|
||||
and det_non_exhaustive_loss_scale_neg == 1.0
|
||||
and trk_loss_scale_pos == 1.0
|
||||
and trk_loss_scale_neg == 1.0
|
||||
), "If not using separate loss for detection and tracking queries, separate detection and tracking loss scales should all be 1.0"
|
||||
), (
|
||||
"If not using separate loss for detection and tracking queries, separate detection and tracking loss scales should all be 1.0"
|
||||
)
|
||||
|
||||
def get_loss(self, outputs, targets, indices, num_boxes):
|
||||
assert len(outputs["pred_logits"].shape) > 2, "Incorrect predicted logits shape"
|
||||
@@ -443,7 +444,9 @@ class IABCEMdetr(LossWithWeights):
|
||||
pass
|
||||
|
||||
if self.weak_loss:
|
||||
assert not self.use_separate_loss_for_det_and_trk, "Do not use weak_loss in this case -- set separate loss for detection and tracking queries instead"
|
||||
assert not self.use_separate_loss_for_det_and_trk, (
|
||||
"Do not use weak_loss in this case -- set separate loss for detection and tracking queries instead"
|
||||
)
|
||||
|
||||
# nullify the negative loss for the non-exhaustive classes
|
||||
assert loss_bce.shape[0] == targets["is_exhaustive"].shape[0]
|
||||
@@ -497,9 +500,9 @@ class IABCEMdetr(LossWithWeights):
|
||||
loss_bce = loss_bce.mean()
|
||||
else:
|
||||
assert isinstance(self.pad_n_queries, int)
|
||||
assert (
|
||||
loss_bce.size(1) < self.pad_n_queries
|
||||
), f"The number of predictions is more than the expected total after padding. Got {loss_bce.size(1)} predictions."
|
||||
assert loss_bce.size(1) < self.pad_n_queries, (
|
||||
f"The number of predictions is more than the expected total after padding. Got {loss_bce.size(1)} predictions."
|
||||
)
|
||||
loss_bce = loss_bce.sum() / (self.pad_n_queries * loss_bce.size(0))
|
||||
|
||||
bce_f1 = torchmetrics.functional.f1_score(
|
||||
|
||||
Reference in New Issue
Block a user