This commit is contained in:
2026-02-15 13:22:38 +08:00
parent 8bb00ac928
commit 3b5461371c
6 changed files with 365 additions and 32 deletions

190
test.py
View File

@@ -1,39 +1,165 @@
import torch
#################################### For Image ####################################
import matplotlib.pyplot as plt
import os
import cv2
import numpy as np
from PIL import Image
from sam3.model_builder import build_sam3_image_model
# 只保留SAM3实际存在的核心模块
from sam3.model_builder import build_sam3_image_model, build_sam3_video_predictor
from sam3.model.sam3_image_processor import Sam3Processor
# NOTE(review): the statements below run at import time and load a full model;
# they look like the pre-change version of the image demo that now lives under
# the ``__main__`` guard — confirm whether they should still be here.
# Load the model
model = build_sam3_image_model()
processor = Sam3Processor(model)
# Load an image
image = Image.open("/home/quant/data/dev/sam3-main/assets/player.gif")
inference_state = processor.set_image(image)
# Prompt the model with text
# NOTE(review): "pepole" looks like a typo for "people" — confirm before changing,
# since the prompt text is passed to the model.
output = processor.set_text_prompt(state=inference_state, prompt="pepole")
from sam3.visualization_utils import draw_box_on_image, normalize_bbox, plot_results
# Get the masks, bounding boxes, and scores
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
# ==================== GPU memory optimisation settings ====================
# NOTE(review): these variables are set after `import torch`; presumably they
# must be in place before CUDA is first initialised — confirm the ordering.
os.environ.update(
    PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True",
    CUDA_VISIBLE_DEVICES="0",
)
# Release any cached allocator blocks before the heavy work starts.
torch.cuda.empty_cache()
#################################### For Video ####################################
# 通用视频帧读取函数
def get_video_frames(video_path):
    """Return the frame image paths and the frame count of a video source.

    Args:
        video_path: Either a video file (``.mp4``, ``.avi`` or ``.mov``) or a
            directory of ``.jpg`` / ``.png`` frame images. Extensions are
            matched case-insensitively.

    Returns:
        A ``(frame_paths, total_frames)`` tuple. For a video file
        ``frame_paths`` is an empty list and ``total_frames`` is the frame
        count reported by OpenCV. For a directory ``frame_paths`` holds the
        image paths sorted by file name and ``total_frames`` is their count.

    Raises:
        ValueError: If ``video_path`` is neither a supported video file nor a
            directory.
    """
    video_exts = ('.mp4', '.avi', '.mov')
    image_exts = ('.jpg', '.png')

    if os.path.isfile(video_path) and video_path.lower().endswith(video_exts):
        capture = cv2.VideoCapture(video_path)
        try:
            total_frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        finally:
            # Always release the handle; the capture is only needed for the count.
            capture.release()
        return [], total_frames

    if os.path.isdir(video_path):
        frame_paths = [
            os.path.join(video_path, name)
            for name in sorted(os.listdir(video_path))
            if name.lower().endswith(image_exts)
        ]
        return frame_paths, len(frame_paths)

    raise ValueError(f"不支持的视频路径:{video_path}")
# from sam3.model_builder import build_sam3_video_predictor
# 缩小视频帧分辨率(减少显存占用)
def resize_frame(frame, max_side=640):
    """Downscale ``frame`` so that its longer side is at most ``max_side``.

    Frames that already fit, and frames with a zero-sized dimension, are
    returned unchanged (the same object). The aspect ratio is preserved up to
    integer truncation.

    Args:
        frame: Image array whose first two ``shape`` entries are height and
            width.
        max_side: Maximum allowed length of the longer side, in pixels.

    Returns:
        The original frame, or a resized copy produced by ``cv2.resize`` with
        ``INTER_AREA`` interpolation.
    """
    h, w = frame.shape[:2]
    if min(h, w) == 0:
        # Nothing to scale, and max(h, w) may be 0, which would divide by zero.
        return frame

    scale = max_side / max(h, w)
    if scale >= 1:
        return frame

    # Clamp to 1 px: truncation can reach 0 for very elongated frames, which
    # cv2.resize rejects.
    new_h = max(1, int(h * scale))
    new_w = max(1, int(w * scale))
    return cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_AREA)
# video_predictor = build_sam3_video_predictor()
# video_path = "<YOUR_VIDEO_PATH>" # a JPEG folder or an MP4 video file
# # Start a session
# response = video_predictor.handle_request(
# request=dict(
# type="start_session",
# resource_path=video_path,
# )
# )
# response = video_predictor.handle_request(
# request=dict(
# type="add_prompt",
# session_id=response["session_id"],
# frame_index=0, # Arbitrary frame index
# text="<YOUR_TEXT_PROMPT>",
# )
# )
# output = response["outputs"]
def sam3_video_inference_low_memory(
    video_path,
    text_prompt="",  # an empty prompt is intended to let SAM3 detect every object
    save_dir="./sam3_video_results",
    max_frames=5,
    max_frame_side=640
):
    """Run a SAM3 text prompt on the first video frame and save the result.

    A video predictor session is started for ``video_path``, the text prompt
    is added on frame 0, and that frame is saved with the returned boxes drawn
    on it. When nothing is detected the frame is still saved, with a notice,
    to help debugging. Only the first frame is processed.

    Args:
        video_path: Video file or frame directory accepted by
            ``get_video_frames``.
        text_prompt: Text prompt sent with the ``add_prompt`` request.
        save_dir: Directory for the annotated frame; created if missing.
        max_frames: Upper bound used only in the frame-count log message.
        max_frame_side: Longest side, in pixels, of the saved frame.

    Returns:
        A ``(session_id, first_boxes)`` tuple, where ``first_boxes`` is the
        ``"boxes"`` entry of the prompt response (``[]`` if absent).

    Raises:
        ValueError: If ``video_path`` is not supported by
            ``get_video_frames``.
    """
    # Validate the path before building the predictor so that a bad path
    # fails fast instead of after an expensive model load.
    frame_paths, video_frame_count = get_video_frames(video_path)
    frames_to_process = min(video_frame_count, max_frames)

    # 1. Build the SAM3 video predictor.
    video_predictor = build_sam3_video_predictor()
    cap = None
    try:
        os.makedirs(save_dir, exist_ok=True)

        # 2. Start a session. No custom model_config is passed; per the
        # original author's note this SAM3 version overrides it anyway.
        session_response = video_predictor.handle_request(
            request=dict(
                type="start_session",
                resource_path=video_path
            )
        )
        session_id = session_response["session_id"]
        print(f"[低显存模式] 会话启动成功ID: {session_id}")

        # 3. Report the real total and the capped count separately.
        print(f"[低显存模式] 视频总帧数:{video_frame_count},本次处理前{frames_to_process}")

        # 4. Prompt on the first frame with very low thresholds (0.3 -> 0.1 per
        # the original note) to favour recall.
        # NOTE(review): whether the predictor honours ``prompt_config`` is not
        # visible from this file — confirm against the SAM3 API.
        first_frame_idx = 0
        prompt_response = video_predictor.handle_request(
            request=dict(
                type="add_prompt",
                session_id=session_id,
                frame_index=first_frame_idx,
                text=text_prompt,
                prompt_config=dict(
                    box_threshold=0.1,
                    mask_threshold=0.1
                )
            )
        )
        first_output = prompt_response["outputs"]
        first_boxes = first_output.get("boxes", [])
        first_scores = first_output.get("scores", [])
        print(f"[低显存模式] 首帧({first_frame_idx})检出框数:{len(first_boxes)},分数:{first_scores}")

        # 5. Read the first frame, from the video file or the frame folder.
        frame_idx = first_frame_idx
        frame = None
        if os.path.isfile(video_path):
            cap = cv2.VideoCapture(video_path)
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
            ret, frame = cap.read()
            if not ret:
                frame = None
        elif frame_idx < len(frame_paths):
            # cv2.imread returns None when the image cannot be decoded.
            frame = cv2.imread(frame_paths[frame_idx])
        if frame is None:
            print(f"[警告] 无法读取帧{frame_idx}")
            return session_id, first_boxes

        # Shrink the frame before plotting to keep memory use low.
        frame = resize_frame(frame, max_frame_side)
        frame_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

        fig, ax = plt.subplots(1, 1, figsize=(8, 6))
        try:
            ax.imshow(frame_pil)
            if len(first_boxes) > 0:
                # NOTE(review): boxes come from the predictor, which saw the
                # unresized source; confirm normalize_bbox yields coordinates
                # valid for the resized frame.
                for box, score in zip(first_boxes, first_scores):
                    box_abs = normalize_bbox(box, frame_pil.width, frame_pil.height)
                    draw_box_on_image(
                        ax, box_abs,
                        label=f"SAM3 Score: {score:.2f}",
                        color="red", thickness=2
                    )
            else:
                # Axes coordinates centre the notice; data coordinates would
                # place it at pixel (0.5, 0.5) in the image corner.
                ax.text(
                    0.5, 0.5, "未检出目标",
                    ha="center", va="center", fontsize=20, color="red",
                    transform=ax.transAxes,
                )
            ax.axis("off")

            # Save the annotated frame.
            frame_save_path = os.path.join(save_dir, f"sam3_frame_{frame_idx:04d}.jpg")
            fig.savefig(frame_save_path, dpi=100, bbox_inches="tight")
        finally:
            plt.close(fig)
        print(f"[低显存模式] 帧{frame_idx}处理完成,保存至:{frame_save_path}")
    finally:
        # 6. Release resources on every exit path. No end_session request is
        # sent (unsupported by this SAM3 version per the original note); the
        # predictor is dropped and the CUDA cache emptied instead.
        if cap is not None:
            cap.release()
        del video_predictor
        torch.cuda.empty_cache()

    print("[低显存模式] 推理完成资源已释放无需end_session")
    return session_id, first_boxes
# ==================== 主函数 ====================
def main():
    """Run the image demo, then the low-memory video demo."""
    # 1. Image inference.
    image_path = "/home/quant/data/dev/sam3/assets/images/groceries.jpg"
    image_model = build_sam3_image_model()
    image_processor = Sam3Processor(image_model)
    # convert() returns an independent image, so the file can be closed at once.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    image = image.resize((640, 480), Image.Resampling.LANCZOS)
    inference_state = image_processor.set_image(image)
    image_output = image_processor.set_text_prompt(state=inference_state, prompt="food")
    plot_results(image, inference_state)
    plt.savefig("./sam3_image_food_result.jpg", dpi=100, bbox_inches='tight')
    plt.close()
    # Drop every reference to the inference results as well as the model.
    # NOTE(review): presumably the state/output hold GPU tensors that would
    # otherwise survive empty_cache() — confirm.
    del image_model, image_processor, inference_state, image_output
    torch.cuda.empty_cache()
    print("✅ 图像推理完成(低显存模式)")

    # 2. Video inference (no end_session; tuned for higher recall).
    video_path = "/home/quant/data/dev/sam3/assets/videos/bedroom.mp4"
    sam3_video_inference_low_memory(
        video_path=video_path,
        text_prompt="",  # empty prompt: detect every object (recall first)
        max_frames=5,
        max_frame_side=640
    )


if __name__ == "__main__":
    main()