Files
sam3_local/fastAPI_main.py
2026-02-15 13:48:11 +08:00

159 lines
5.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import uuid
import requests
from typing import Optional
from contextlib import asynccontextmanager
import torch
import matplotlib
# 关键:设置非交互式后端,避免服务器环境下报错
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request
from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse
from PIL import Image
# SAM3 相关导入 (请确保你的环境中已正确安装 sam3)
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
from sam3.visualization_utils import plot_results
# ------------------- Configuration & paths -------------------
# Static assets root, served via app.mount("/static", ...) below.
STATIC_DIR = "static"
# Rendered segmentation images are written here and exposed as
# /static/results/<filename>.
RESULT_IMAGE_DIR = os.path.join(STATIC_DIR, "results")
os.makedirs(RESULT_IMAGE_DIR, exist_ok=True)
# ------------------- 核心修改:图片压缩函数 -------------------
def compress_image(image: Image.Image, max_size: int = 1920, quality: int = 85) -> Image.Image:
"""
如果图片边长超过 max_size则按比例压缩。
:param image: PIL Image 对象
:param max_size: 图片最大边长 (宽或高)
:param quality: 仅用于保存时的参考,这里主要做尺寸压缩
:return: 压缩后的 PIL Image 对象
"""
width, height = image.size
# 如果图片本身就很小,直接返回
if width <= max_size and height <= max_size:
return image
# 计算缩放比例
if width > height:
new_width = max_size
new_height = int(height * (max_size / width))
else:
new_height = max_size
new_width = int(width * (max_size / height))
# 使用 LANCZOS 滤镜进行高质量下采样
print(f"压缩图片: {width}x{height} -> {new_width}x{new_height}")
return image.resize((new_width, new_height), Image.Resampling.LANCZOS)
# ------------------- 生命周期管理 -------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the SAM3 model once at startup and publish it on app.state.

    Runs before the server accepts requests; everything after ``yield``
    executes at shutdown.
    """
    print("正在加载 SAM3 模型到 GPU...")
    has_cuda = torch.cuda.is_available()
    if not has_cuda:
        print("警告: 未检测到 GPU将使用 CPU速度会较慢。")
    device = torch.device("cuda" if has_cuda else "cpu")
    model = build_sam3_image_model().to(device)
    model.eval()  # inference-only: disable dropout/batch-norm updates
    # Handlers reach these via request.app.state.*
    app.state.model = model
    app.state.processor = Sam3Processor(model)
    app.state.device = device
    print(f"模型加载完成,设备: {device}")
    yield
    print("正在清理资源...")
# ------------------- FastAPI initialisation -------------------
app = FastAPI(lifespan=lifespan, title="SAM3 Segmentation API")
# Serve result images at /static/... ; url_for("static", ...) below
# resolves against this mount.
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
# ------------------- 辅助函数 -------------------
def load_image_from_url(url: str) -> Image.Image:
    """Download an image from *url* and return it as an RGB PIL image.

    :param url: HTTP(S) URL of the image
    :return: decoded image, converted to RGB
    :raises HTTPException: 400 with the underlying error message when the
        download or decoding fails
    """
    import io  # local import so the fix needs no top-of-file change

    try:
        headers = {'User-Agent': 'Mozilla/5.0'}
        # Use .content instead of .raw: requests transparently decodes
        # Content-Encoding (gzip/deflate) for .content, while .raw would
        # hand PIL the still-compressed byte stream. The with-block also
        # guarantees the connection is released.
        with requests.get(url, headers=headers, timeout=10) as response:
            response.raise_for_status()
            image = Image.open(io.BytesIO(response.content)).convert("RGB")
        return image
    except Exception as e:
        # Chain the cause so the original traceback survives in logs.
        raise HTTPException(status_code=400, detail=f"无法下载图片: {str(e)}") from e
def generate_and_save_result(image: Image.Image, inference_state) -> str:
    """Render the segmentation overlay and save it as a JPEG.

    :param image: the (possibly compressed) input image
    :param inference_state: SAM3 inference state from the processor
    :return: the generated file name (relative to RESULT_IMAGE_DIR)
    """
    filename = f"seg_{uuid.uuid4().hex}.jpg"
    save_path = os.path.join(RESULT_IMAGE_DIR, filename)
    plot_results(image, inference_state)  # draws onto the current pyplot figure
    try:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    finally:
        # Always release the figure — even when savefig raises — so
        # failed requests don't leak matplotlib state/memory.
        plt.close()
    return filename
# ------------------- API 接口 -------------------
@app.post("/segment")
async def segment(
    request: Request,
    prompt: str = Form(...),
    file: Optional[UploadFile] = File(None),
    image_url: Optional[str] = Form(None)
):
    """Segment objects matching *prompt* in an uploaded or remote image.

    Exactly one of ``file`` / ``image_url`` must be provided. Returns a
    JSON payload with the URL of the rendered result image, the number of
    detected instances, and their confidence scores.

    :raises HTTPException: 400 for missing/invalid input, 500 for
        inference or rendering failures
    """
    if not file and not image_url:
        raise HTTPException(status_code=400, detail="必须提供 file 或 image_url")
    # 1. Load the image from whichever source was supplied.
    try:
        if file:
            image = Image.open(file.file).convert("RGB")
        else:
            image = load_image_from_url(image_url)
        # Cap the resolution before inference to bound GPU memory/time;
        # applies to both the upload and the URL path.
        image = compress_image(image, max_size=1920)
    except HTTPException:
        # Preserve the specific 400 raised by load_image_from_url
        # instead of re-wrapping it as a parse error.
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")
    # 2. Shared processor created once at startup (see lifespan).
    processor = request.app.state.processor
    # 3. Run inference.
    try:
        inference_state = processor.set_image(image)
        output = processor.set_text_prompt(state=inference_state, prompt=prompt)
        masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")
    # 4. Render the visualisation and persist it under static/results.
    try:
        filename = generate_and_save_result(image, inference_state)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"绘图保存错误: {str(e)}")
    # BUG FIX: the URL previously interpolated the literal placeholder
    # "(unknown)" instead of the generated file name, so every response
    # pointed at a non-existent file.
    file_url = request.url_for("static", path=f"results/{filename}")
    return JSONResponse(content={
        "status": "success",
        "result_image_url": str(file_url),
        "detected_count": len(masks),
        "scores": scores.tolist() if torch.is_tensor(scores) else scores
    })
if __name__ == "__main__":
    import uvicorn
    # Intended to run behind a reverse proxy: proxy_headers +
    # forwarded_allow_ips make uvicorn trust X-Forwarded-* headers so
    # request.url_for builds correct external URLs.
    uvicorn.run(
        "fastAPI_main:app",
        host="127.0.0.1",
        port=55600,
        proxy_headers=True,
        forwarded_allow_ips="*"
    )