diff --git a/sam3/model/sam3_tracker_base.py b/sam3/model/sam3_tracker_base.py index 1591f32..90fbd69 100644 --- a/sam3/model/sam3_tracker_base.py +++ b/sam3/model/sam3_tracker_base.py @@ -900,8 +900,6 @@ class Sam3TrackerBase(torch.nn.Module): image=current_image, point_inputs=backbone_out["point_inputs_per_frame"].get(stage_id, None), mask_inputs=backbone_out["mask_inputs_per_frame"].get(stage_id, None), - gt_masks=backbone_out["gt_masks_per_frame"].get(stage_id, None), - frames_to_add_correction_pt=frames_to_add_correction_pt, output_dict=output_dict, num_frames=num_frames, ) diff --git a/sam3/model/sam3_tracking_predictor.py b/sam3/model/sam3_tracking_predictor.py index 28ab2bd..b2440ef 100644 --- a/sam3/model/sam3_tracking_predictor.py +++ b/sam3/model/sam3_tracking_predictor.py @@ -657,8 +657,6 @@ class Sam3TrackerPredictor(Sam3TrackerBase): image=image, point_inputs=None, mask_inputs=mask_inputs, - gt_masks=None, - frames_to_add_correction_pt=[], output_dict={ "cond_frame_outputs": {}, "non_cond_frame_outputs": {},