diff --git a/fastAPI_main.py b/fastAPI_main.py new file mode 100644 index 0000000..866d7aa --- /dev/null +++ b/fastAPI_main.py @@ -0,0 +1,159 @@ +import os +import uuid +import requests +from typing import Optional +from contextlib import asynccontextmanager + +import torch +import matplotlib +# 关键:设置非交互式后端,避免服务器环境下报错 +matplotlib.use('Agg') +import matplotlib.pyplot as plt + +from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request +from fastapi.staticfiles import StaticFiles +from fastapi.responses import JSONResponse +from PIL import Image + +# SAM3 相关导入 (请确保你的环境中已正确安装 sam3) +from sam3.model_builder import build_sam3_image_model +from sam3.model.sam3_image_processor import Sam3Processor +from sam3.visualization_utils import plot_results + +# ------------------- 配置与路径 ------------------- +STATIC_DIR = "static" +RESULT_IMAGE_DIR = os.path.join(STATIC_DIR, "results") +os.makedirs(RESULT_IMAGE_DIR, exist_ok=True) + +# ------------------- 生命周期管理 ------------------- +@asynccontextmanager +async def lifespan(app: FastAPI): + """ + FastAPI 生命周期管理器:在服务启动时加载模型,关闭时清理资源 + """ + print("正在加载 SAM3 模型到 GPU...") + + # 1. 检测设备 + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if not torch.cuda.is_available(): + print("警告: 未检测到 GPU,将使用 CPU,速度会较慢。") + + # 2. 加载模型 (全局单例) + model = build_sam3_image_model() + model = model.to(device) + model.eval() # 切换到评估模式 + + # 3. 初始化 Processor + processor = Sam3Processor(model) + + # 4. 存入 app.state 供全局访问 + app.state.model = model + app.state.processor = processor + app.state.device = device + + print(f"模型加载完成,设备: {device}") + + yield # 服务运行中... 
+ + # 清理资源 (如果需要) + print("正在清理资源...") + +# ------------------- FastAPI 初始化 ------------------- +app = FastAPI(lifespan=lifespan, title="SAM3 Segmentation API") + +# 挂载静态文件目录,用于通过 URL 访问生成的图片 +app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static") + +# ------------------- 辅助函数 ------------------- +def load_image_from_url(url: str) -> Image.Image: + """从网络 URL 下载图片""" + try: + headers = {'User-Agent': 'Mozilla/5.0'} + response = requests.get(url, headers=headers, stream=True, timeout=10) + response.raise_for_status() + image = Image.open(response.raw).convert("RGB") + return image + except Exception as e: + raise HTTPException(status_code=400, detail=f"无法下载图片: {str(e)}") + +def generate_and_save_result(image: Image.Image, inference_state) -> str: + """生成可视化结果图并保存,返回文件名""" + # 生成唯一文件名防止冲突 + filename = f"seg_{uuid.uuid4().hex}.jpg" + save_path = os.path.join(RESULT_IMAGE_DIR, filename) + + # 绘图 (复用你提供的逻辑) + plot_results(image, inference_state) + + # 保存 + plt.savefig(save_path, dpi=150, bbox_inches='tight') + plt.close() # 务必关闭,防止内存泄漏 + + return filename + +# ------------------- API 接口 ------------------- +@app.post("/segment") +async def segment( + request: Request, + prompt: str = Form(...), + file: Optional[UploadFile] = File(None), + image_url: Optional[str] = Form(None) +): + """ + 接收图片 (文件上传 或 URL) 和 文本提示词,返回分割后的图片 URL。 + """ + + # 1. 校验输入 + if not file and not image_url: + raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)") + + # 2. 获取图片对象 + try: + if file: + image = Image.open(file.file).convert("RGB") + elif image_url: + image = load_image_from_url(image_url) + except Exception as e: + raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}") + + # 3. 获取模型 + processor = request.app.state.processor + + # 4. 
执行推理 + try: + # 这一步内部应该已经由 Sam3Processor 处理了 GPU 张量转移 + inference_state = processor.set_image(image) + output = processor.set_text_prompt(state=inference_state, prompt=prompt) + + masks, boxes, scores = output["masks"], output["boxes"], output["scores"] + except Exception as e: + raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}") + + # 5. 生成可视化并保存 + try: + filename = generate_and_save_result(image, inference_state) + except Exception as e: + raise HTTPException(status_code=500, detail=f"绘图保存错误: {str(e)}") + + # 6. 构建返回 URL + # request.url_for 会自动根据当前域名生成正确的访问链接 + file_url = request.url_for("static", path=f"results/{filename}") + + return JSONResponse(content={ + "status": "success", + "result_image_url": str(file_url), + "detected_count": len(masks), + "scores": scores.tolist() if torch.is_tensor(scores) else scores + }) + +if __name__ == "__main__": + import uvicorn + + # 使用 Python 函数参数的方式传递配置 + uvicorn.run( + "fastAPI_main:app", # 注意:这里要改成你的文件名:app对象名 + host="0.0.0.0", + port=55600, + proxy_headers=True, # 对应 --proxy-headers + forwarded_allow_ips="*" # 对应 --forwarded-allow-ips="*" + ) \ No newline at end of file diff --git a/requirement.txt b/requirement.txt new file mode 100644 index 0000000..9250748 --- /dev/null +++ b/requirement.txt @@ -0,0 +1,3 @@ +uvicorn +python-multipart +fastapi \ No newline at end of file diff --git a/sam3/visualization_utils.py b/sam3/visualization_utils.py index c007c38..fd89934 100644 --- a/sam3/visualization_utils.py +++ b/sam3/visualization_utils.py @@ -875,6 +875,8 @@ def plot_results(img, results): relative_coords=False, ) + plt.show() + def single_visualization(img, anns, title): """ diff --git a/sam3_image_food_result.jpg b/sam3_image_food_result.jpg new file mode 100644 index 0000000..5dda773 Binary files /dev/null and b/sam3_image_food_result.jpg differ diff --git a/test.py b/test.py index 0b39bb1..cea19eb 100644 --- a/test.py +++ b/test.py @@ -1,39 +1,165 @@ import torch -####################################
For Image #################################### +import matplotlib.pyplot as plt +import os +import cv2 +import numpy as np from PIL import Image -from sam3.model_builder import build_sam3_image_model +# 只保留SAM3实际存在的核心模块 +from sam3.model_builder import build_sam3_image_model, build_sam3_video_predictor from sam3.model.sam3_image_processor import Sam3Processor -# Load the model -model = build_sam3_image_model() -processor = Sam3Processor(model) -# Load an image -image = Image.open("/home/quant/data/dev/sam3-main/assets/player.gif") -inference_state = processor.set_image(image) -# Prompt the model with text -output = processor.set_text_prompt(state=inference_state, prompt="pepole") +from sam3.visualization_utils import draw_box_on_image, normalize_bbox, plot_results -# Get the masks, bounding boxes, and scores -masks, boxes, scores = output["masks"], output["boxes"], output["scores"] +# ==================== 显存优化配置 ==================== +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +os.environ["CUDA_VISIBLE_DEVICES"] = "0" +torch.cuda.empty_cache() -#################################### For Video #################################### +# 通用视频帧读取函数 +def get_video_frames(video_path): + frame_paths = [] + if os.path.isfile(video_path) and video_path.endswith(('.mp4', '.avi', '.mov')): + total_frames = int(cv2.VideoCapture(video_path).get(cv2.CAP_PROP_FRAME_COUNT)) + return frame_paths, total_frames + elif os.path.isdir(video_path): + frame_files = sorted([f for f in os.listdir(video_path) if f.endswith(('.jpg', '.png'))]) + frame_paths = [os.path.join(video_path, f) for f in frame_files] + total_frames = len(frame_paths) + return frame_paths, total_frames + else: + raise ValueError(f"不支持的视频路径:{video_path}") -# from sam3.model_builder import build_sam3_video_predictor +# 缩小视频帧分辨率(减少显存占用) +def resize_frame(frame, max_side=640): + h, w = frame.shape[:2] + scale = max_side / max(h, w) + if scale < 1: + new_h, new_w = int(h * scale), int(w * scale) + frame = 
cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_AREA) + return frame -# video_predictor = build_sam3_video_predictor() -# video_path = "" # a JPEG folder or an MP4 video file -# # Start a session -# response = video_predictor.handle_request( -# request=dict( -# type="start_session", -# resource_path=video_path, -# ) -# ) -# response = video_predictor.handle_request( -# request=dict( -# type="add_prompt", -# session_id=response["session_id"], -# frame_index=0, # Arbitrary frame index -# text="", -# ) -# ) -# output = response["outputs"] \ No newline at end of file +def sam3_video_inference_low_memory( + video_path, + text_prompt="", # 空提示词:让SAM3检测所有目标(最易检出) + save_dir="./sam3_video_results", + max_frames=5, + max_frame_side=640 +): + """ + 修复end_session + 提升检出率:适配你的SAM3版本 + """ + # 1. 初始化SAM3视频预测器 + video_predictor = build_sam3_video_predictor() + os.makedirs(save_dir, exist_ok=True) + + # 2. 启动SAM3会话(移除自定义model_config,避免和SAM3内部冲突) + session_response = video_predictor.handle_request( + request=dict( + type="start_session", + resource_path=video_path + # 移除model_config:你的SAM3版本会强制覆盖为max_num_objects=10000 + ) + ) + session_id = session_response["session_id"] + print(f"[低显存模式] 会话启动成功,ID: {session_id}") + + # 3. 读取视频帧信息 + frame_paths, total_frames = get_video_frames(video_path) + total_frames = min(total_frames, max_frames) + print(f"[低显存模式] 视频总帧数:{total_frames},本次处理前{total_frames}帧") + + # 4. 首帧推理(核心优化:降低阈值+空提示词,提升检出率) + first_frame_idx = 0 + prompt_response = video_predictor.handle_request( + request=dict( + type="add_prompt", + session_id=session_id, + frame_index=first_frame_idx, + text=text_prompt, + # 极低阈值:强制检出所有可能目标(从0.3→0.1) + prompt_config=dict( + box_threshold=0.1, + mask_threshold=0.1 + ) + ) + ) + first_output = prompt_response["outputs"] + first_boxes = first_output.get("boxes", []) + first_scores = first_output.get("scores", []) + print(f"[低显存模式] 首帧({first_frame_idx})检出框数:{len(first_boxes)},分数:{first_scores}") + + # 5. 
处理帧(只处理首帧) + frame_save_paths = [] + cap = cv2.VideoCapture(video_path) if os.path.isfile(video_path) else None + + frame_idx = first_frame_idx + if cap is not None: + cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx) + ret, frame = cap.read() + if not ret: + print(f"[警告] 无法读取帧{frame_idx}") + cap.release() + return session_id, first_boxes + else: + frame = cv2.imread(frame_paths[frame_idx]) + + # 缩小帧分辨率 + frame = resize_frame(frame, max_frame_side) + frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frame_pil = Image.fromarray(frame_rgb) + + # 画框(即使框数为0也保存,方便排查) + fig, ax = plt.subplots(1, 1, figsize=(8, 6)) + ax.imshow(frame_pil) + if len(first_boxes) > 0: + for box, score in zip(first_boxes, first_scores): + box_abs = normalize_bbox(box, frame_pil.width, frame_pil.height) + draw_box_on_image( + ax, box_abs, + label=f"SAM3 Score: {score:.2f}", + color="red", thickness=2 + ) + else: + ax.text(0.5, 0.5, "未检出目标", ha="center", va="center", fontsize=20, color="red") + ax.axis("off") + + # 保存带框帧 + frame_save_path = os.path.join(save_dir, f"sam3_frame_{frame_idx:04d}.jpg") + plt.savefig(frame_save_path, dpi=100, bbox_inches="tight") + plt.close(fig) + frame_save_paths.append(frame_save_path) + print(f"[低显存模式] 帧{frame_idx}处理完成,保存至:{frame_save_path}") + + # 6. 释放资源(移除end_session:你的SAM3版本不支持) + if cap is not None: + cap.release() + # 直接删除预测器+清理显存(替代end_session) + del video_predictor + torch.cuda.empty_cache() + print("[低显存模式] 推理完成,资源已释放(无需end_session)") + return session_id, first_boxes + +# ==================== 主函数 ==================== +if __name__ == "__main__": + # 1. 
图像推理 + image_path = "/home/quant/data/dev/sam3/assets/images/groceries.jpg" + image_model = build_sam3_image_model() + image_processor = Sam3Processor(image_model) + image = Image.open(image_path).convert("RGB") + image = image.resize((640, 480), Image.Resampling.LANCZOS) + inference_state = image_processor.set_image(image) + image_output = image_processor.set_text_prompt(state=inference_state, prompt="food") + plot_results(image, inference_state) + plt.savefig("./sam3_image_food_result.jpg", dpi=100, bbox_inches='tight') + plt.close() + del image_model, image_processor + torch.cuda.empty_cache() + print("✅ 图像推理完成(低显存模式)") + + # 2. 视频推理:修复end_session + 提升检出率 + video_path = "/home/quant/data/dev/sam3/assets/videos/bedroom.mp4" + sam3_video_inference_low_memory( + video_path=video_path, + text_prompt="", # 空提示词:检测所有目标(优先保证检出) + max_frames=5, + max_frame_side=640 + ) \ No newline at end of file diff --git a/test1.py b/test1.py new file mode 100644 index 0000000..597da81 --- /dev/null +++ b/test1.py @@ -0,0 +1,43 @@ +import torch +import matplotlib.pyplot as plt # 新增:导入matplotlib用于保存图片 +#################################### For Image #################################### +from PIL import Image +from sam3.model_builder import build_sam3_image_model +from sam3.model.sam3_image_processor import Sam3Processor +from sam3.visualization_utils import draw_box_on_image, normalize_bbox, plot_results + +# Load the model +model = build_sam3_image_model() +processor = Sam3Processor(model) + +# Load an image - 保留之前的RGB转换修复 +image = Image.open("/home/quant/data/dev/sam3/assets/images/groceries.jpg").convert("RGB") + +# 可选:打印图像信息,验证通道数 +print(f"图像模式: {image.mode}, 尺寸: {image.size}") + +# 处理图像 +inference_state = processor.set_image(image) + +# 文本提示推理 +output = processor.set_text_prompt(state=inference_state, prompt="food") + +# 获取推理结果 +masks, boxes, scores = output["masks"], output["boxes"], output["scores"] + +# 可视化并保存图片(核心修改部分) +# 1. 生成可视化结果 +plot_results(image, inference_state) +# 2. 
保存图片到当前目录,格式可选jpg/png,这里用jpg示例 +plt.savefig("./sam3_food_detection_result.jpg", # 保存路径:当前目录,文件名自定义 + dpi=150, # 图片分辨率,可选 + bbox_inches='tight') # 去除图片周围空白 +# 3. 关闭plt画布,避免内存占用 +plt.close() + +# 可选:打印输出信息 +print(f"检测到的mask数量: {len(masks)}") +print(f"检测到的box数量: {len(boxes)}") +print(f"置信度分数: {scores}") +print("图片已保存到当前目录:./sam3_food_detection_result.jpg") +