import torch
import matplotlib.pyplot as plt
import os
import cv2
import numpy as np
from PIL import Image

# Only the core modules that this script relies on from SAM3 are imported.
from sam3.model_builder import build_sam3_image_model, build_sam3_video_predictor
from sam3.model.sam3_image_processor import Sam3Processor
from sam3.visualization_utils import draw_box_on_image, normalize_bbox, plot_results

# ==================== GPU memory configuration ====================
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
torch.cuda.empty_cache()


def get_video_frames(video_path: str):
    """Return ``(frame_paths, total_frames)`` for a video file or a frame directory.

    Args:
        video_path: Path to a ``.mp4``/``.avi``/``.mov`` file, or to a
            directory containing ``.jpg``/``.png`` frames.

    Returns:
        A tuple ``(frame_paths, total_frames)``. For a video file
        ``frame_paths`` is an empty list and ``total_frames`` is the frame
        count reported by OpenCV. For a directory ``frame_paths`` is the
        name-sorted list of frame image paths and ``total_frames`` is its
        length.

    Raises:
        ValueError: If ``video_path`` is neither a supported video file nor
            a directory.
    """
    if os.path.isfile(video_path) and video_path.endswith(('.mp4', '.avi', '.mov')):
        cap = cv2.VideoCapture(video_path)
        try:
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        finally:
            # Release the capture handle; the original left it open.
            cap.release()
        return [], total_frames

    if os.path.isdir(video_path):
        frame_files = sorted(
            f for f in os.listdir(video_path) if f.endswith(('.jpg', '.png'))
        )
        frame_paths = [os.path.join(video_path, f) for f in frame_files]
        return frame_paths, len(frame_paths)

    raise ValueError(f"不支持的视频路径:{video_path}")


def resize_frame(frame, max_side: int = 640):
    """Downscale ``frame`` so that its longer side is at most ``max_side``.

    Frames that already fit are returned unchanged (never upscaled).

    Args:
        frame: Image array whose first two dimensions are height and width.
        max_side: Maximum allowed length of the longer side, in pixels.

    Returns:
        The (possibly resized) frame.
    """
    h, w = frame.shape[:2]
    scale = max_side / max(h, w)
    if scale < 1:
        # Clamp to 1 pixel so extreme aspect ratios cannot produce a zero
        # dimension, which cv2.resize rejects.
        new_h = max(1, int(h * scale))
        new_w = max(1, int(w * scale))
        frame = cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_AREA)
    return frame


def _read_frame(video_path, frame_paths, frame_idx):
    """Read one frame (BGR) from a video file or frame list; ``None`` on failure."""
    if os.path.isfile(video_path):
        cap = cv2.VideoCapture(video_path)
        try:
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
            ret, frame = cap.read()
        finally:
            cap.release()
        return frame if ret else None

    # Frame-directory input: guard against an empty directory / bad index.
    if frame_idx >= len(frame_paths):
        return None
    # cv2.imread returns None when the file cannot be decoded.
    return cv2.imread(frame_paths[frame_idx])


def _save_annotated_frame(frame, boxes, scores, save_path):
    """Draw ``boxes`` with their ``scores`` on a BGR ``frame`` and save it to ``save_path``."""
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frame_pil = Image.fromarray(frame_rgb)

    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    try:
        ax.imshow(frame_pil)
        # len() rather than truthiness: boxes may be an array/tensor.
        if len(boxes) > 0:
            for box, score in zip(boxes, scores):
                # NOTE(review): argument conventions of normalize_bbox and
                # draw_box_on_image are taken as-is from the original code —
                # confirm against the installed sam3.visualization_utils.
                box_abs = normalize_bbox(box, frame_pil.width, frame_pil.height)
                draw_box_on_image(
                    ax,
                    box_abs,
                    label=f"SAM3 Score: {score:.2f}",
                    color="red",
                    thickness=2,
                )
        else:
            # Axes-fraction coordinates so the message is centred; in data
            # coordinates (0.5, 0.5) is the top-left pixel of the image.
            ax.text(
                0.5,
                0.5,
                "未检出目标",
                ha="center",
                va="center",
                fontsize=20,
                color="red",
                transform=ax.transAxes,
            )
        ax.axis("off")
        fig.savefig(save_path, dpi=100, bbox_inches="tight")
    finally:
        # Always close the figure so it is not leaked on errors.
        plt.close(fig)


def sam3_video_inference_low_memory(
    video_path,
    text_prompt="",  # Empty prompt: intended to let SAM3 detect all objects.
    save_dir="./sam3_video_results",
    max_frames=5,
    max_frame_side=640
):
    """Run SAM3 text-prompted detection on the first frame of a video.

    Starts a SAM3 video session, adds a text prompt on frame 0, draws the
    returned boxes on a downscaled copy of that frame and saves the image
    into ``save_dir``. The session is not closed with an ``end_session``
    request; instead the predictor is deleted and the CUDA cache emptied.

    Args:
        video_path: Video file or directory of frames.
        text_prompt: Text prompt forwarded to the ``add_prompt`` request.
        save_dir: Output directory for the annotated frame (created if needed).
        max_frames: Upper bound used only for the reported frame count;
            only the first frame is actually processed.
        max_frame_side: Longer-side limit applied to the saved frame.

    Returns:
        ``(session_id, first_boxes)`` — the session id and the boxes reported
        for the first frame. Also returned (after a printed warning) when
        the first frame cannot be read.

    Raises:
        ValueError: If ``video_path`` is not a supported video path.
    """
    # 1. Build the SAM3 video predictor.
    video_predictor = build_sam3_video_predictor()
    try:
        os.makedirs(save_dir, exist_ok=True)

        # 2. Start the session. No custom model_config is passed, to avoid
        #    conflicting with SAM3's internal configuration.
        session_response = video_predictor.handle_request(
            request=dict(
                type="start_session",
                resource_path=video_path
            )
        )
        session_id = session_response["session_id"]
        print(f"[低显存模式] 会话启动成功,ID: {session_id}")

        # 3. Read frame information; report the real total separately from
        #    the clamped count.
        frame_paths, video_total_frames = get_video_frames(video_path)
        total_frames = min(video_total_frames, max_frames)
        print(f"[低显存模式] 视频总帧数:{video_total_frames},本次处理前{total_frames}帧")

        # 4. First-frame inference with low thresholds.
        # NOTE(review): whether prompt_config and the "boxes"/"scores" output
        # keys are honoured depends on the installed SAM3 version — confirm.
        first_frame_idx = 0
        prompt_response = video_predictor.handle_request(
            request=dict(
                type="add_prompt",
                session_id=session_id,
                frame_index=first_frame_idx,
                text=text_prompt,
                prompt_config=dict(
                    box_threshold=0.1,
                    mask_threshold=0.1
                )
            )
        )
        first_output = prompt_response["outputs"]
        first_boxes = first_output.get("boxes", [])
        first_scores = first_output.get("scores", [])
        print(f"[低显存模式] 首帧({first_frame_idx})检出框数:{len(first_boxes)},分数:{first_scores}")

        # 5. Process the first frame only.
        frame_idx = first_frame_idx
        frame = _read_frame(video_path, frame_paths, frame_idx)
        if frame is None:
            print(f"[警告] 无法读取帧{frame_idx}")
            return session_id, first_boxes

        # Downscale before plotting to reduce memory use.
        frame = resize_frame(frame, max_frame_side)

        # The frame is saved even with zero boxes, to ease debugging.
        frame_save_path = os.path.join(save_dir, f"sam3_frame_{frame_idx:04d}.jpg")
        _save_annotated_frame(frame, first_boxes, first_scores, frame_save_path)
        print(f"[低显存模式] 帧{frame_idx}处理完成,保存至:{frame_save_path}")

        print("[低显存模式] 推理完成,资源已释放(无需end_session)")
        return session_id, first_boxes
    finally:
        # 6. Release resources on every exit path: drop the predictor and
        #    empty the CUDA cache (used here instead of end_session).
        del video_predictor
        torch.cuda.empty_cache()


def main():
    """Run the image demo followed by the video demo."""
    # 1. Image inference.
    image_path = "/home/quant/data/dev/sam3/assets/images/groceries.jpg"
    image_model = build_sam3_image_model()
    image_processor = Sam3Processor(image_model)

    image = Image.open(image_path).convert("RGB")
    image = image.resize((640, 480), Image.Resampling.LANCZOS)
    inference_state = image_processor.set_image(image)
    image_output = image_processor.set_text_prompt(state=inference_state, prompt="food")

    plot_results(image, inference_state)
    plt.savefig("./sam3_image_food_result.jpg", dpi=100, bbox_inches='tight')
    plt.close()

    # Drop every reference created by the image stage before the video
    # stage; inference_state/image_output presumably hold model tensors —
    # releasing them lets empty_cache() actually return that memory.
    del image_model, image_processor, inference_state, image_output
    torch.cuda.empty_cache()
    print("✅ 图像推理完成(低显存模式)")

    # 2. Video inference.
    video_path = "/home/quant/data/dev/sam3/assets/videos/bedroom.mp4"
    sam3_video_inference_low_memory(
        video_path=video_path,
        text_prompt="",  # Empty prompt: intended to detect all objects.
        max_frames=5,
        max_frame_side=640
    )


# ==================== Entry point ====================
if __name__ == "__main__":
    main()