Files
sam3_local/fastAPI_nocom.py
2026-02-15 13:48:11 +08:00

160 lines
5.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import uuid
import requests
from typing import Optional
from contextlib import asynccontextmanager
import torch
import matplotlib
# 关键:设置非交互式后端,避免服务器环境下报错
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request
from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse
from PIL import Image
# SAM3 相关导入 (请确保你的环境中已正确安装 sam3)
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
from sam3.visualization_utils import plot_results
# ------------------- 配置与路径 -------------------
STATIC_DIR = "static"
RESULT_IMAGE_DIR = os.path.join(STATIC_DIR, "results")
os.makedirs(RESULT_IMAGE_DIR, exist_ok=True)
# ------------------- 生命周期管理 -------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
"""
FastAPI 生命周期管理器:在服务启动时加载模型,关闭时清理资源
"""
print("正在加载 SAM3 模型到 GPU...")
# 1. 检测设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if not torch.cuda.is_available():
print("警告: 未检测到 GPU将使用 CPU速度会较慢。")
# 2. 加载模型 (全局单例)
model = build_sam3_image_model()
model = model.to(device)
model.eval() # 切换到评估模式
# 3. 初始化 Processor
processor = Sam3Processor(model)
# 4. 存入 app.state 供全局访问
app.state.model = model
app.state.processor = processor
app.state.device = device
print(f"模型加载完成,设备: {device}")
yield # 服务运行中...
# 清理资源 (如果需要)
print("正在清理资源...")
# ------------------- FastAPI 初始化 -------------------
app = FastAPI(lifespan=lifespan, title="SAM3 Segmentation API")
# 挂载静态文件目录,用于通过 URL 访问生成的图片
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
# ------------------- 辅助函数 -------------------
def load_image_from_url(url: str) -> Image.Image:
"""从网络 URL 下载图片"""
try:
headers = {'User-Agent': 'Mozilla/5.0'}
response = requests.get(url, headers=headers, stream=True, timeout=10)
response.raise_for_status()
image = Image.open(response.raw).convert("RGB")
return image
except Exception as e:
raise HTTPException(status_code=400, detail=f"无法下载图片: {str(e)}")
def generate_and_save_result(image: Image.Image, inference_state) -> str:
"""生成可视化结果图并保存,返回文件名"""
# 生成唯一文件名防止冲突
filename = f"seg_{uuid.uuid4().hex}.jpg"
save_path = os.path.join(RESULT_IMAGE_DIR, filename)
# 绘图 (复用你提供的逻辑)
plot_results(image, inference_state)
# 保存
plt.savefig(save_path, dpi=150, bbox_inches='tight')
plt.close() # 务必关闭,防止内存泄漏
return filename
# ------------------- API 接口 -------------------
@app.post("/segment")
async def segment(
request: Request,
prompt: str = Form(...),
file: Optional[UploadFile] = File(None),
image_url: Optional[str] = Form(None)
):
"""
接收图片 (文件上传 或 URL) 和 文本提示词,返回分割后的图片 URL。
"""
# 1. 校验输入
if not file and not image_url:
raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")
# 2. 获取图片对象
try:
if file:
image = Image.open(file.file).convert("RGB")
elif image_url:
image = load_image_from_url(image_url)
except Exception as e:
raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")
# 3. 获取模型
processor = request.app.state.processor
# 4. 执行推理
try:
# 这一步内部应该已经由 Sam3Processor 处理了 GPU 张量转移
inference_state = processor.set_image(image)
output = processor.set_text_prompt(state=inference_state, prompt=prompt)
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
except Exception as e:
raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")
# 5. 生成可视化并保存
try:
filename = generate_and_save_result(image, inference_state)
except Exception as e:
raise HTTPException(status_code=500, detail=f"绘图保存错误: {str(e)}")
# 6. 构建返回 URL
# request.url_for 会自动根据当前域名生成正确的访问链接
file_url = request.url_for("static", path=f"results/{filename}")
return JSONResponse(content={
"status": "success",
"result_image_url": str(file_url),
"detected_count": len(masks),
"scores": scores.tolist() if torch.is_tensor(scores) else scores
})
if __name__ == "__main__":
import uvicorn
# 使用 Python 函数参数的方式传递配置
uvicorn.run(
"fastAPI_main:app", # 注意:这里要改成你的文件名:app对象名
host="127.0.0.1",
port=55600,
proxy_headers=True, # 对应 --proxy-headers
forwarded_allow_ips="*" # 对应 --forwarded-allow-ips="*"
)