# Source: sam3_local/test.py
# Snapshot: 2026-02-15 13:22:38 +08:00 (165 lines, 6.4 KiB, Python)
#
# Note: the hosting page flagged this file as containing ambiguous Unicode
# characters, i.e. characters that might be confused with other characters
# (the log strings below intentionally use full-width CJK punctuation).
import torch
import matplotlib.pyplot as plt
import os
import cv2
import numpy as np
from PIL import Image
# 只保留SAM3实际存在的核心模块
from sam3.model_builder import build_sam3_image_model, build_sam3_video_predictor
from sam3.model.sam3_image_processor import Sam3Processor
from sam3.visualization_utils import draw_box_on_image, normalize_bbox, plot_results
# ==================== GPU memory tuning ====================
# Both variables are written unconditionally, so any value already present in
# the environment is overridden: the allocator option string is passed to
# PyTorch via PYTORCH_CUDA_ALLOC_CONF and only CUDA device "0" is exposed.
os.environ.update(
    {
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        "CUDA_VISIBLE_DEVICES": "0",
    }
)
# Drop any cached CUDA blocks before the models are built.
torch.cuda.empty_cache()
# Generic helper that describes the frames available in a video source.
def get_video_frames(video_path):
    """Return the frame image paths and the frame count of a video source.

    Args:
        video_path: Either a video file whose name ends in ``.mp4``, ``.avi``
            or ``.mov``, or a directory containing ``.jpg`` / ``.png`` frames.

    Returns:
        A ``(frame_paths, total_frames)`` tuple. For a video file,
        ``frame_paths`` is an empty list and ``total_frames`` is the value
        OpenCV reports for ``CAP_PROP_FRAME_COUNT``. For a directory,
        ``frame_paths`` holds the matching image paths sorted by file name and
        ``total_frames`` is the number of those paths.

    Raises:
        ValueError: If ``video_path`` is neither a supported video file nor a
            directory.
    """
    if os.path.isfile(video_path) and video_path.endswith(('.mp4', '.avi', '.mov')):
        capture = cv2.VideoCapture(video_path)
        try:
            total_frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        finally:
            # Release the handle deterministically instead of relying on
            # garbage collection of a temporary capture object.
            capture.release()
        return [], total_frames

    if os.path.isdir(video_path):
        frame_paths = [
            os.path.join(video_path, name)
            for name in sorted(os.listdir(video_path))
            if name.endswith(('.jpg', '.png'))
        ]
        return frame_paths, len(frame_paths)

    raise ValueError(f"不支持的视频路径:{video_path}")
# Shrink a video frame (intended to reduce GPU memory usage).
def resize_frame(frame, max_side=640):
    """Downscale ``frame`` so that its longest side is at most ``max_side``.

    Frames that already fit (and empty frames, which have nothing to shrink)
    are returned unchanged; the function never upscales.

    Args:
        frame: Image array whose first two ``shape`` entries are height and
            width.
        max_side: Maximum allowed length of the longer side, in pixels. Must
            be positive.

    Returns:
        The original ``frame`` object when no resizing is needed, otherwise a
        new array produced by ``cv2.resize`` with ``INTER_AREA``
        interpolation, keeping the aspect ratio.

    Raises:
        ValueError: If ``max_side`` is not positive.
    """
    if max_side <= 0:
        raise ValueError(f"max_side must be positive, got {max_side}")

    h, w = frame.shape[:2]
    longest = max(h, w)
    # Covers both "already small enough" and the empty-frame case, which
    # previously divided by zero.
    if longest <= max_side:
        return frame

    scale = max_side / longest
    # Clamp to 1 pixel: for extreme aspect ratios the short side could round
    # down to 0, which cv2.resize rejects.
    new_h = max(1, int(h * scale))
    new_w = max(1, int(w * scale))
    return cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_AREA)
def _read_single_frame(video_path, frame_paths, frame_idx):
    """Read one frame from a video file or from a list of frame image paths.

    Returns the frame as loaded by OpenCV, or ``None`` when it cannot be read
    (failed video read, index outside ``frame_paths``, or unreadable image).
    """
    if os.path.isfile(video_path):
        capture = cv2.VideoCapture(video_path)
        try:
            capture.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
            ok, frame = capture.read()
        finally:
            # Always release the capture, even if seeking/reading raises.
            capture.release()
        return frame if ok else None

    if frame_idx >= len(frame_paths):
        return None
    # cv2.imread signals failure by returning None rather than raising.
    return cv2.imread(frame_paths[frame_idx])


def sam3_video_inference_low_memory(
    video_path,
    text_prompt="",  # an empty prompt is meant to let SAM3 detect every object
    save_dir="./sam3_video_results",
    max_frames=5,
    max_frame_side=640
):
    """Run a SAM3 text prompt on the first frame of a video and save a preview.

    The function builds a SAM3 video predictor, starts a session for
    ``video_path``, adds ``text_prompt`` on frame 0 with box/mask thresholds of
    0.1, then draws the returned boxes on a downscaled copy of that frame and
    writes the figure to ``save_dir``. Only the first frame is processed;
    ``max_frames`` currently affects the progress message only.

    No ``end_session`` request is sent (the original author notes that their
    SAM3 version does not support it); instead the predictor is deleted and
    the CUDA cache is emptied.

    Args:
        video_path: Video file or frame directory, as accepted by
            ``get_video_frames``.
        text_prompt: Text prompt forwarded to the ``add_prompt`` request.
        save_dir: Directory for the annotated frame; created if missing.
        max_frames: Upper bound on frames reported as "to be processed".
        max_frame_side: Longest side, in pixels, of the saved preview frame.

    Returns:
        ``(session_id, first_boxes)``: the session id from ``start_session``
        and the ``"boxes"`` entry of the first-frame output (``[]`` if absent).
        If the first frame cannot be read, a warning is printed and the same
        tuple is returned without saving an image.
    """
    # 1. Build the SAM3 video predictor.
    video_predictor = build_sam3_video_predictor()
    os.makedirs(save_dir, exist_ok=True)

    # 2. Start a session. No custom model_config is passed; per the original
    #    author's notes it conflicts with / is overridden by SAM3 internally.
    session_response = video_predictor.handle_request(
        request=dict(
            type="start_session",
            resource_path=video_path
        )
    )
    session_id = session_response["session_id"]
    print(f"[低显存模式] 会话启动成功ID: {session_id}")

    # 3. Gather frame information. The real total is kept separate from the
    #    capped count so the message reports both (previously the total was
    #    overwritten and the capped number was printed twice).
    frame_paths, total_frames = get_video_frames(video_path)
    frames_to_process = min(total_frames, max_frames)
    print(f"[低显存模式] 视频总帧数:{total_frames},本次处理前{frames_to_process}")

    # 4. Prompt on the first frame with low thresholds to favour recall.
    first_frame_idx = 0
    prompt_response = video_predictor.handle_request(
        request=dict(
            type="add_prompt",
            session_id=session_id,
            frame_index=first_frame_idx,
            text=text_prompt,
            prompt_config=dict(
                box_threshold=0.1,
                mask_threshold=0.1
            )
        )
    )
    first_output = prompt_response["outputs"]
    first_boxes = first_output.get("boxes", [])
    first_scores = first_output.get("scores", [])
    print(f"[低显存模式] 首帧({first_frame_idx})检出框数:{len(first_boxes)},分数:{first_scores}")

    # 5. Load the first frame (the only frame that is visualised).
    frame_idx = first_frame_idx
    frame = _read_single_frame(video_path, frame_paths, frame_idx)
    if frame is None:
        print(f"[警告] 无法读取帧{frame_idx}")
        return session_id, first_boxes

    # Downscale the preview frame and convert OpenCV's BGR order to RGB.
    frame = resize_frame(frame, max_frame_side)
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frame_pil = Image.fromarray(frame_rgb)

    # Draw the detections; a frame is saved even with zero boxes so that an
    # empty result can be inspected.
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    try:
        ax.imshow(frame_pil)
        # len() is used instead of truthiness because the boxes may be an
        # array-like whose truth value is ambiguous.
        if len(first_boxes) > 0:
            for box, score in zip(first_boxes, first_scores):
                # NOTE(review): boxes come from the predictor while the frame
                # was resized locally; this presumes the boxes are in a
                # resolution-independent format - confirm.
                box_abs = normalize_bbox(box, frame_pil.width, frame_pil.height)
                draw_box_on_image(
                    ax, box_abs,
                    label=f"SAM3 Score: {score:.2f}",
                    color="red", thickness=2
                )
        else:
            # Axes coordinates keep the message centred; in data coordinates
            # (0.5, 0.5) is the top-left pixel of the image.
            ax.text(
                0.5, 0.5, "未检出目标",
                ha="center", va="center", fontsize=20, color="red",
                transform=ax.transAxes
            )
        ax.axis("off")

        # Save the annotated frame.
        frame_save_path = os.path.join(save_dir, f"sam3_frame_{frame_idx:04d}.jpg")
        plt.savefig(frame_save_path, dpi=100, bbox_inches="tight")
    finally:
        # Close the figure even if drawing or saving fails.
        plt.close(fig)
    print(f"[低显存模式] 帧{frame_idx}处理完成,保存至:{frame_save_path}")

    # 6. Release resources: drop the predictor and empty the CUDA cache in
    #    place of an end_session request.
    del video_predictor
    torch.cuda.empty_cache()
    print("[低显存模式] 推理完成资源已释放无需end_session")
    return session_id, first_boxes
# ==================== Entry point ====================
if __name__ == "__main__":
    # 1. Image inference: segment "food" in a sample image and save the plot.
    image_path = "/home/quant/data/dev/sam3/assets/images/groceries.jpg"
    image_model = build_sam3_image_model()
    image_processor = Sam3Processor(image_model)
    image = Image.open(image_path).convert("RGB")
    # Downscale the input before inference to limit memory use.
    image = image.resize((640, 480), Image.Resampling.LANCZOS)
    inference_state = image_processor.set_image(image)
    # NOTE(review): the return value is not passed to plot_results; presumably
    # set_text_prompt also updates `inference_state` - confirm.
    image_output = image_processor.set_text_prompt(state=inference_state, prompt="food")
    plot_results(image, inference_state)
    plt.savefig("./sam3_image_food_result.jpg", dpi=100, bbox_inches='tight')
    plt.close()
    # Drop every reference created by the image stage, including the
    # inference state and output, so that empty_cache() can reclaim whatever
    # they hold before the video stage starts.
    del image_model, image_processor, inference_state, image_output
    torch.cuda.empty_cache()
    print("✅ 图像推理完成(低显存模式)")

    # 2. Video inference on the first frame, without an end_session request.
    video_path = "/home/quant/data/dev/sam3/assets/videos/bedroom.mp4"
    sam3_video_inference_low_memory(
        video_path=video_path,
        text_prompt="",  # empty prompt: detect all objects (favour recall)
        max_frames=5,
        max_frame_side=640
    )