local
This commit is contained in:
@@ -561,8 +561,8 @@ def build_sam3_image_model(
|
|||||||
bpe_path=None,
|
bpe_path=None,
|
||||||
device="cuda" if torch.cuda.is_available() else "cpu",
|
device="cuda" if torch.cuda.is_available() else "cpu",
|
||||||
eval_mode=True,
|
eval_mode=True,
|
||||||
checkpoint_path=None,
|
checkpoint_path="/home/quant/data/dev/sam3/sam3.pt",
|
||||||
load_from_HF=True,
|
load_from_HF=False,
|
||||||
enable_segmentation=True,
|
enable_segmentation=True,
|
||||||
enable_inst_interactivity=False,
|
enable_inst_interactivity=False,
|
||||||
compile=False,
|
compile=False,
|
||||||
@@ -651,8 +651,8 @@ def download_ckpt_from_hf():
|
|||||||
|
|
||||||
|
|
||||||
def build_sam3_video_model(
|
def build_sam3_video_model(
|
||||||
checkpoint_path: Optional[str] = None,
|
checkpoint_path: Optional[str] = "/home/quant/data/dev/sam3/sam3.pt",
|
||||||
load_from_HF=True,
|
load_from_HF=False,
|
||||||
bpe_path: Optional[str] = None,
|
bpe_path: Optional[str] = None,
|
||||||
has_presence_token: bool = True,
|
has_presence_token: bool = True,
|
||||||
geo_encoder_use_img_cross_attn: bool = True,
|
geo_encoder_use_img_cross_attn: bool = True,
|
||||||
|
|||||||
39
test.py
Normal file
39
test.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
import torch
|
||||||
|
#################################### For Image ####################################
|
||||||
|
from PIL import Image
|
||||||
|
from sam3.model_builder import build_sam3_image_model
|
||||||
|
from sam3.model.sam3_image_processor import Sam3Processor
|
||||||
|
# Load the model
|
||||||
|
model = build_sam3_image_model()
|
||||||
|
processor = Sam3Processor(model)
|
||||||
|
# Load an image
|
||||||
|
image = Image.open("/home/quant/data/dev/sam3-main/assets/player.gif")
|
||||||
|
inference_state = processor.set_image(image)
|
||||||
|
# Prompt the model with text
|
||||||
|
output = processor.set_text_prompt(state=inference_state, prompt="pepole")
|
||||||
|
|
||||||
|
# Get the masks, bounding boxes, and scores
|
||||||
|
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
|
||||||
|
|
||||||
|
#################################### For Video ####################################
|
||||||
|
|
||||||
|
# from sam3.model_builder import build_sam3_video_predictor
|
||||||
|
|
||||||
|
# video_predictor = build_sam3_video_predictor()
|
||||||
|
# video_path = "<YOUR_VIDEO_PATH>" # a JPEG folder or an MP4 video file
|
||||||
|
# # Start a session
|
||||||
|
# response = video_predictor.handle_request(
|
||||||
|
# request=dict(
|
||||||
|
# type="start_session",
|
||||||
|
# resource_path=video_path,
|
||||||
|
# )
|
||||||
|
# )
|
||||||
|
# response = video_predictor.handle_request(
|
||||||
|
# request=dict(
|
||||||
|
# type="add_prompt",
|
||||||
|
# session_id=response["session_id"],
|
||||||
|
# frame_index=0, # Arbitrary frame index
|
||||||
|
# text="<YOUR_TEXT_PROMPT>",
|
||||||
|
# )
|
||||||
|
# )
|
||||||
|
# output = response["outputs"]
|
||||||
Reference in New Issue
Block a user