diff --git a/sam3/model_builder.py b/sam3/model_builder.py index 103b324..7662b68 100644 --- a/sam3/model_builder.py +++ b/sam3/model_builder.py @@ -561,8 +561,8 @@ def build_sam3_image_model( bpe_path=None, device="cuda" if torch.cuda.is_available() else "cpu", eval_mode=True, - checkpoint_path=None, - load_from_HF=True, + checkpoint_path="/home/quant/data/dev/sam3/sam3.pt", + load_from_HF=False, enable_segmentation=True, enable_inst_interactivity=False, compile=False, @@ -651,8 +651,8 @@ def download_ckpt_from_hf(): def build_sam3_video_model( - checkpoint_path: Optional[str] = None, - load_from_HF=True, + checkpoint_path: Optional[str] = "/home/quant/data/dev/sam3/sam3.pt", + load_from_HF=False, bpe_path: Optional[str] = None, has_presence_token: bool = True, geo_encoder_use_img_cross_attn: bool = True, diff --git a/test.py b/test.py new file mode 100644 index 0000000..0b39bb1 --- /dev/null +++ b/test.py @@ -0,0 +1,39 @@ +import torch +#################################### For Image #################################### +from PIL import Image +from sam3.model_builder import build_sam3_image_model +from sam3.model.sam3_image_processor import Sam3Processor +# Load the model +model = build_sam3_image_model() +processor = Sam3Processor(model) +# Load an image +image = Image.open("/home/quant/data/dev/sam3-main/assets/player.gif") +inference_state = processor.set_image(image) +# Prompt the model with text +output = processor.set_text_prompt(state=inference_state, prompt="people") 

+# Get the masks, bounding boxes, and scores +masks, boxes, scores = output["masks"], output["boxes"], output["scores"] 

+#################################### For Video #################################### 

+# from sam3.model_builder import build_sam3_video_predictor 

+# video_predictor = build_sam3_video_predictor() +# video_path = "" # a JPEG folder or an MP4 video file +# # Start a session +# response = video_predictor.handle_request( +# request=dict( +# type="start_session", +# 
resource_path=video_path, +# ) +# ) +# response = video_predictor.handle_request( +# request=dict( +# type="add_prompt", +# session_id=response["session_id"], +# frame_index=0, # Arbitrary frame index +# text="", +# ) +# ) +# output = response["outputs"] \ No newline at end of file