159 lines
5.4 KiB
Python
159 lines
5.4 KiB
Python
import os
|
||
import uuid
|
||
import requests
|
||
from typing import Optional
|
||
from contextlib import asynccontextmanager
|
||
|
||
import torch
|
||
import matplotlib
|
||
# 关键:设置非交互式后端,避免服务器环境下报错
|
||
matplotlib.use('Agg')
|
||
import matplotlib.pyplot as plt
|
||
|
||
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request
|
||
from fastapi.staticfiles import StaticFiles
|
||
from fastapi.responses import JSONResponse
|
||
from PIL import Image
|
||
|
||
# SAM3 相关导入 (请确保你的环境中已正确安装 sam3)
|
||
from sam3.model_builder import build_sam3_image_model
|
||
from sam3.model.sam3_image_processor import Sam3Processor
|
||
from sam3.visualization_utils import plot_results
|
||
|
||
# ------------------- 配置与路径 -------------------
|
||
# Root folder for files that the application serves as static content.
STATIC_DIR = "static"

# Rendered segmentation images are written into this sub-folder.
RESULT_IMAGE_DIR = os.path.join(STATIC_DIR, "results")

# Create the whole folder tree at import time; an existing tree is not an error.
os.makedirs(name=RESULT_IMAGE_DIR, exist_ok=True)
|
||
|
||
# ------------------- 核心修改:图片压缩函数 -------------------
|
||
def compress_image(image: Image.Image, max_size: int = 1920, quality: int = 85) -> Image.Image:
    """
    Downscale *image* so that neither side is longer than ``max_size``.

    The aspect ratio is preserved. An image that already fits is returned
    as-is (the very same object, not a copy).

    :param image: PIL image to shrink.
    :param max_size: Maximum allowed width or height, in pixels.
    :param quality: Not used by this function; kept so existing callers keep
        working. Only the pixel dimensions are reduced here, nothing is
        re-encoded.
    :return: ``image`` itself when it already fits, otherwise the resized image.
    """
    width, height = image.size

    # Nothing to do when the image already fits inside the bounding square.
    if width <= max_size and height <= max_size:
        return image

    # Pin the longer side to max_size and scale the shorter side by the same
    # factor. The shorter side is clamped to 1 px: for extreme aspect ratios
    # (e.g. 10000x1) the truncated value would be 0, which resize() rejects.
    if width > height:
        new_width = max_size
        new_height = max(1, int(height * (max_size / width)))
    else:
        new_height = max_size
        new_width = max(1, int(width * (max_size / height)))

    print(f"压缩图片: {width}x{height} -> {new_width}x{new_height}")
    # LANCZOS is used for high-quality downsampling.
    return image.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
||
|
||
# ------------------- 生命周期管理 -------------------
|
||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Application lifespan hook: load the SAM3 model once and share it.

    The code before ``yield`` runs before the application starts serving and
    stores the model, its processor and the selected device on ``app.state``.
    The code after ``yield`` runs on shutdown and currently only logs.

    :param app: The FastAPI application whose ``state`` receives the objects.
    """
    print("正在加载 SAM3 模型到 GPU...")

    # Prefer CUDA; fall back to CPU with a warning when no GPU is visible.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
        print("警告: 未检测到 GPU,将使用 CPU,速度会较慢。")

    # Build the model, move it to the chosen device and switch to eval mode.
    sam3_model = build_sam3_image_model().to(device)
    sam3_model.eval()

    # Publish the shared objects so request handlers can reach them.
    app.state.model = sam3_model
    app.state.processor = Sam3Processor(sam3_model)
    app.state.device = device

    print(f"模型加载完成,设备: {device}")
    yield
    # Shutdown phase: only a log line, nothing is released explicitly.
    print("正在清理资源...")
|
||
|
||
# ------------------- FastAPI 初始化 -------------------
|
||
# Create the application; `lifespan` takes care of loading the model at startup.
app = FastAPI(title="SAM3 Segmentation API", lifespan=lifespan)

# Serve everything under STATIC_DIR at the /static URL prefix. The route is
# registered under the name "static" so URLs to it can be built by name.
app.mount(path="/static", app=StaticFiles(directory=STATIC_DIR), name="static")
|
||
|
||
# ------------------- 辅助函数 -------------------
|
||
def load_image_from_url(url: str) -> Image.Image:
    """
    Download an image over HTTP(S) and return it as an RGB PIL image.

    :param url: Address of the image to fetch.
    :return: The decoded image, converted to RGB.
    :raises HTTPException: 400 when the download or the decoding fails for any
        reason (network error, non-2xx status, unsupported image data, ...).
    """
    # NOTE(review): `url` comes straight from the caller and is fetched without
    # any host restriction; consider an allow-list if this service can reach
    # internal networks (SSRF).
    try:
        headers = {'User-Agent': 'Mozilla/5.0'}
        # The context manager releases the connection even when decoding fails;
        # a streamed response is otherwise left open.
        with requests.get(url, headers=headers, stream=True, timeout=10) as response:
            response.raise_for_status()
            # The raw stream is not content-decoded by default, so gzip/deflate
            # responses would reach PIL still compressed.
            response.raw.decode_content = True
            # convert() reads the full pixel data, so the image is completely
            # loaded before the connection is closed.
            return Image.open(response.raw).convert("RGB")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"无法下载图片: {str(e)}") from e
|
||
|
||
def generate_and_save_result(image: Image.Image, inference_state) -> str:
    """
    Render the segmentation result and save it as a JPEG in RESULT_IMAGE_DIR.

    :param image: The image that was segmented.
    :param inference_state: State object obtained from the SAM3 processor; it
        is handed to ``plot_results`` unchanged.
    :return: The generated file name (without directory) of the saved image.
    """
    # A random hex name avoids collisions between requests.
    filename = f"seg_{uuid.uuid4().hex}.jpg"
    save_path = os.path.join(RESULT_IMAGE_DIR, filename)
    try:
        # Presumably plot_results draws onto the current pyplot figure, which
        # savefig then writes out -- TODO confirm against sam3.
        plot_results(image, inference_state)
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    finally:
        # Always close the current figure, including when plotting or saving
        # raises; otherwise figures pile up in a long-running process.
        plt.close()
    return filename
|
||
|
||
# ------------------- API 接口 -------------------
|
||
@app.post("/segment")
async def segment(
    request: Request,
    prompt: str = Form(...),
    file: Optional[UploadFile] = File(None),
    image_url: Optional[str] = Form(None)
):
    """
    Segment an image with a text prompt and return a link to the rendered result.

    The image is taken from the uploaded ``file`` when present, otherwise it is
    downloaded from ``image_url``. It is downscaled to at most 1920 px per side
    before being passed to the model.

    :param request: Incoming request; gives access to ``app.state`` and URL building.
    :param prompt: Text prompt passed to the SAM3 processor.
    :param file: Optional uploaded image; takes precedence over ``image_url``.
    :param image_url: Optional URL of an image to download.
    :return: JSON with ``status``, ``result_image_url``, ``detected_count`` and ``scores``.
    :raises HTTPException: 400 when no image is supplied or it cannot be
        downloaded/decoded; 500 when inference or rendering fails.
    """
    if not file and not image_url:
        raise HTTPException(status_code=400, detail="必须提供 file 或 image_url")

    # 1. Obtain the image.
    try:
        if file:
            image = Image.open(file.file).convert("RGB")
        else:
            image = load_image_from_url(image_url)

        # Shrink oversized images before they reach the model.
        image = compress_image(image, max_size=1920)
    except HTTPException:
        # load_image_from_url already raised a client-facing error; wrapping it
        # a second time would garble or lose its message.
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}") from e

    # 2. Fetch the shared processor created at startup.
    processor = request.app.state.processor

    # 3. Run inference.
    try:
        inference_state = processor.set_image(image)
        output = processor.set_text_prompt(state=inference_state, prompt=prompt)
        masks, scores = output["masks"], output["scores"]
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}") from e

    # 4. Render the visualisation and save it under the static directory.
    try:
        filename = generate_and_save_result(image, inference_state)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"绘图保存错误: {str(e)}") from e

    # Build a URL to the saved file through the route named "static".
    file_url = request.url_for("static", path=f"results/{filename}")

    return JSONResponse(content={
        "status": "success",
        "result_image_url": str(file_url),
        "detected_count": len(masks),
        # Tensors are not JSON-serialisable; convert them to plain lists.
        "scores": scores.tolist() if torch.is_tensor(scores) else scores
    })
|
||
|
||
if __name__ == "__main__":
    import uvicorn

    # NOTE(review): the import string below only resolves when this file is
    # saved as fastAPI_main.py -- confirm before renaming the module.
    # The server listens on the loopback interface only; proxy headers are
    # enabled and accepted from any address ("*"), which presumably relies on
    # a local reverse proxy being the only client -- verify the deployment.
    server_options = {
        "host": "127.0.0.1",
        "port": 55600,
        "proxy_headers": True,
        "forwarded_allow_ips": "*",
    }
    uvicorn.run("fastAPI_main:app", **server_options)