import io
import os
import uuid
import requests
from typing import Optional
from contextlib import asynccontextmanager

import torch
import matplotlib

# Important: select the non-interactive backend before pyplot is imported,
# so plotting does not fail on a server without a display.
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request
from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse
from PIL import Image

# SAM3 imports (make sure the sam3 package is installed in your environment)
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
from sam3.visualization_utils import plot_results

# ------------------- Configuration and paths -------------------
STATIC_DIR = "static"
RESULT_IMAGE_DIR = os.path.join(STATIC_DIR, "results")
os.makedirs(RESULT_IMAGE_DIR, exist_ok=True)


# ------------------- Image downscaling -------------------
def compress_image(image: Image.Image, max_size: int = 1920, quality: int = 85) -> Image.Image:
    """Downscale *image* proportionally so that neither side exceeds *max_size*.

    Images that already fit are returned unchanged (the same object).

    :param image: PIL image to shrink.
    :param max_size: Maximum allowed length, in pixels, of the longer side.
    :param quality: Currently unused by this function; kept for interface
        compatibility (only the dimensions are reduced here).
    :return: The original image if it already fits, otherwise a resized copy.
    :raises ValueError: If *max_size* is smaller than 1.
    """
    if max_size < 1:
        raise ValueError(f"max_size must be >= 1, got {max_size}")

    width, height = image.size

    # Small enough already: nothing to do.
    if width <= max_size and height <= max_size:
        return image

    # Pin the longer side to max_size and scale the other side by the same
    # ratio. The shorter side is clamped to at least 1 pixel: for extreme
    # aspect ratios int() would otherwise truncate to 0 and resize() fails.
    if width > height:
        new_width = max_size
        new_height = max(1, int(height * (max_size / width)))
    else:
        new_height = max_size
        new_width = max(1, int(width * (max_size / height)))

    # LANCZOS gives high-quality results when downsampling.
    print(f"压缩图片: {width}x{height} -> {new_width}x{new_height}")
    return image.resize((new_width, new_height), Image.Resampling.LANCZOS)


# ------------------- Lifespan management -------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the SAM3 model and processor at startup and store them on ``app.state``.

    The model is moved to CUDA when available, otherwise to the CPU. The
    model, processor and device are exposed as ``app.state.model``,
    ``app.state.processor`` and ``app.state.device``.
    """
    print("正在加载 SAM3 模型到 GPU...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if not torch.cuda.is_available():
        print("警告: 未检测到 GPU,将使用 CPU,速度会较慢。")

    model = build_sam3_image_model()
    model = model.to(device)
    model.eval()

    processor = Sam3Processor(model)

    app.state.model = model
    app.state.processor = processor
    app.state.device = device

    print(f"模型加载完成,设备: {device}")
    yield
    print("正在清理资源...")


# ------------------- FastAPI initialisation -------------------
app = FastAPI(lifespan=lifespan, title="SAM3 Segmentation API")
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")


# ------------------- Helpers -------------------
def load_image_from_url(url: str) -> Image.Image:
    """Download the image at *url* and return it as an RGB PIL image.

    :param url: Address of the image to fetch.
    :return: The decoded image converted to RGB mode.
    :raises HTTPException: With status 400 if the download or decoding fails.
    """
    headers = {'User-Agent': 'Mozilla/5.0'}
    try:
        # NOTE(security): the URL is supplied by the client and fetched as-is,
        # so this endpoint can be used to reach internal hosts (SSRF). Consider
        # restricting allowed hosts if the service is exposed publicly.
        response = requests.get(url, headers=headers, timeout=10)
        response.raise_for_status()
        # Use the decoded body rather than response.raw: the raw stream is not
        # content-decoded (gzip/deflate), and reading the body here also lets
        # the connection be released instead of being left open.
        return Image.open(io.BytesIO(response.content)).convert("RGB")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"无法下载图片: {str(e)}") from e


def generate_and_save_result(image: Image.Image, inference_state) -> str:
    """Render the segmentation result and save it under ``RESULT_IMAGE_DIR``.

    :param image: The image that was passed to the model.
    :param inference_state: State object handed to ``plot_results``.
    :return: The generated file name (not the full path) of the saved JPEG.
    """
    filename = f"seg_{uuid.uuid4().hex}.jpg"
    save_path = os.path.join(RESULT_IMAGE_DIR, filename)

    try:
        plot_results(image, inference_state)
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    finally:
        # Always close the current figure, even when plotting or saving fails,
        # so figures do not accumulate in a long-running process.
        plt.close()

    return filename


# ------------------- API endpoints -------------------
@app.post("/segment")
async def segment(
    request: Request,
    prompt: str = Form(...),
    file: Optional[UploadFile] = File(None),
    image_url: Optional[str] = Form(None)
):
    """Segment an uploaded or remote image according to a text prompt.

    Exactly one image source is used: the uploaded ``file`` takes precedence
    over ``image_url``. The image is downscaled to at most 1920 px per side
    before inference.

    :return: JSON with the URL of the rendered result image, the number of
        detected masks and their scores.
    :raises HTTPException: 400 for a missing or unreadable image, 500 when
        inference or rendering fails.
    """
    if not file and not image_url:
        raise HTTPException(status_code=400, detail="必须提供 file 或 image_url")

    # 1. Obtain the image.
    try:
        if file:
            image = Image.open(file.file).convert("RGB")
        else:
            image = load_image_from_url(image_url)

        # Force a downscale before the image is handed to the model.
        image = compress_image(image, max_size=1920)
    except HTTPException:
        # load_image_from_url already produced a proper HTTP error; pass it
        # through instead of wrapping it a second time.
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}") from e

    # 2. Fetch the processor created at startup.
    processor = request.app.state.processor

    # 3. Run inference.
    try:
        inference_state = processor.set_image(image)
        output = processor.set_text_prompt(state=inference_state, prompt=prompt)
        masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}") from e

    # 4. Render the visualisation and save it.
    try:
        filename = generate_and_save_result(image, inference_state)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"绘图保存错误: {str(e)}") from e

    file_url = request.url_for("static", path=f"results/{filename}")

    return JSONResponse(content={
        "status": "success",
        "result_image_url": str(file_url),
        "detected_count": len(masks),
        # Tensors are not JSON-serialisable; convert them to plain lists.
        "scores": scores.tolist() if torch.is_tensor(scores) else scores
    })


if __name__ == "__main__":
    import uvicorn

    # NOTE(review): the import string assumes this module is saved as
    # fastAPI_main.py - confirm it matches the actual file name.
    uvicorn.run(
        "fastAPI_main:app",
        host="127.0.0.1",
        port=55600,
        proxy_headers=True,
        forwarded_allow_ips="*"
    )