This commit is contained in:
2026-02-15 13:48:11 +08:00
parent 8065619f5e
commit 6dc5d17a43
5 changed files with 203 additions and 44 deletions

View File

@@ -25,48 +25,60 @@ STATIC_DIR = "static"
RESULT_IMAGE_DIR = os.path.join(STATIC_DIR, "results")
os.makedirs(RESULT_IMAGE_DIR, exist_ok=True)
# ------------------- 核心修改:图片压缩函数 -------------------
def compress_image(image: Image.Image, max_size: int = 1920, quality: int = 85) -> Image.Image:
"""
如果图片边长超过 max_size则按比例压缩。
:param image: PIL Image 对象
:param max_size: 图片最大边长 (宽或高)
:param quality: 仅用于保存时的参考,这里主要做尺寸压缩
:return: 压缩后的 PIL Image 对象
"""
width, height = image.size
# 如果图片本身就很小,直接返回
if width <= max_size and height <= max_size:
return image
# 计算缩放比例
if width > height:
new_width = max_size
new_height = int(height * (max_size / width))
else:
new_height = max_size
new_width = int(width * (max_size / height))
# 使用 LANCZOS 滤镜进行高质量下采样
print(f"压缩图片: {width}x{height} -> {new_width}x{new_height}")
return image.resize((new_width, new_height), Image.Resampling.LANCZOS)
# ------------------- 生命周期管理 -------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
"""
FastAPI 生命周期管理器:在服务启动时加载模型,关闭时清理资源
"""
print("正在加载 SAM3 模型到 GPU...")
# 1. 检测设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if not torch.cuda.is_available():
print("警告: 未检测到 GPU将使用 CPU速度会较慢。")
# 2. 加载模型 (全局单例)
model = build_sam3_image_model()
model = model.to(device)
model.eval() # 切换到评估模式
# 3. 初始化 Processor
model.eval()
processor = Sam3Processor(model)
# 4. 存入 app.state 供全局访问
app.state.model = model
app.state.processor = processor
app.state.device = device
print(f"模型加载完成,设备: {device}")
yield # 服务运行中...
# 清理资源 (如果需要)
yield
print("正在清理资源...")
# ------------------- FastAPI 初始化 -------------------
app = FastAPI(lifespan=lifespan, title="SAM3 Segmentation API")
# 挂载静态文件目录,用于通过 URL 访问生成的图片
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
# ------------------- 辅助函数 -------------------
def load_image_from_url(url: str) -> Image.Image:
"""从网络 URL 下载图片"""
try:
headers = {'User-Agent': 'Mozilla/5.0'}
response = requests.get(url, headers=headers, stream=True, timeout=10)
@@ -77,18 +89,11 @@ def load_image_from_url(url: str) -> Image.Image:
raise HTTPException(status_code=400, detail=f"无法下载图片: {str(e)}")
def generate_and_save_result(image: Image.Image, inference_state) -> str:
"""生成可视化结果图并保存,返回文件名"""
# 生成唯一文件名防止冲突
filename = f"seg_{uuid.uuid4().hex}.jpg"
save_path = os.path.join(RESULT_IMAGE_DIR, filename)
# 绘图 (复用你提供的逻辑)
plot_results(image, inference_state)
# 保存
plt.savefig(save_path, dpi=150, bbox_inches='tight')
plt.close() # 务必关闭,防止内存泄漏
plt.close()
return filename
# ------------------- API 接口 -------------------
@@ -99,44 +104,41 @@ async def segment(
file: Optional[UploadFile] = File(None),
image_url: Optional[str] = Form(None)
):
"""
接收图片 (文件上传 或 URL) 和 文本提示词,返回分割后的图片 URL。
"""
# 1. 校验输入
if not file and not image_url:
raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")
raise HTTPException(status_code=400, detail="必须提供 file 或 image_url")
# 2. 获取图片对象
# 1. 获取图片对象
try:
if file:
image = Image.open(file.file).convert("RGB")
elif image_url:
image = load_image_from_url(image_url)
# ========== 关键修改位置 ==========
# 在送入模型前,强制压缩图片
image = compress_image(image, max_size=1920)
# ===================================
except Exception as e:
raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")
# 3. 获取模型
# 2. 获取模型
processor = request.app.state.processor
# 4. 执行推理
# 3. 执行推理
try:
# 这一步内部应该已经由 Sam3Processor 处理了 GPU 张量转移
inference_state = processor.set_image(image)
output = processor.set_text_prompt(state=inference_state, prompt=prompt)
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
except Exception as e:
raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")
# 5. 生成可视化并保存
# 4. 生成可视化并保存
try:
filename = generate_and_save_result(image, inference_state)
except Exception as e:
raise HTTPException(status_code=500, detail=f"绘图保存错误: {str(e)}")
# 6. 构建返回 URL
# request.url_for 会自动根据当前域名生成正确的访问链接
file_url = request.url_for("static", path=f"results/{filename}")
return JSONResponse(content={
@@ -148,12 +150,10 @@ async def segment(
if __name__ == "__main__":
import uvicorn
# 使用 Python 函数参数的方式传递配置
uvicorn.run(
"fastAPI_main:app", # 注意:这里要改成你的文件名:app对象名
"fastAPI_main:app",
host="127.0.0.1",
port=55600,
proxy_headers=True, # 对应 --proxy-headers
forwarded_allow_ips="*" # 对应 --forwarded-allow-ips="*"
proxy_headers=True,
forwarded_allow_ips="*"
)

159
fastAPI_nocom.py Normal file
View File

@@ -0,0 +1,159 @@
import os
import uuid
import requests
from typing import Optional
from contextlib import asynccontextmanager
import torch
import matplotlib
# 关键:设置非交互式后端,避免服务器环境下报错
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request
from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse
from PIL import Image
# SAM3 相关导入 (请确保你的环境中已正确安装 sam3)
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
from sam3.visualization_utils import plot_results
# ------------------- 配置与路径 -------------------
STATIC_DIR = "static"
RESULT_IMAGE_DIR = os.path.join(STATIC_DIR, "results")
os.makedirs(RESULT_IMAGE_DIR, exist_ok=True)
# ------------------- 生命周期管理 -------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
"""
FastAPI 生命周期管理器:在服务启动时加载模型,关闭时清理资源
"""
print("正在加载 SAM3 模型到 GPU...")
# 1. 检测设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if not torch.cuda.is_available():
print("警告: 未检测到 GPU将使用 CPU速度会较慢。")
# 2. 加载模型 (全局单例)
model = build_sam3_image_model()
model = model.to(device)
model.eval() # 切换到评估模式
# 3. 初始化 Processor
processor = Sam3Processor(model)
# 4. 存入 app.state 供全局访问
app.state.model = model
app.state.processor = processor
app.state.device = device
print(f"模型加载完成,设备: {device}")
yield # 服务运行中...
# 清理资源 (如果需要)
print("正在清理资源...")
# ------------------- FastAPI 初始化 -------------------
app = FastAPI(lifespan=lifespan, title="SAM3 Segmentation API")
# 挂载静态文件目录,用于通过 URL 访问生成的图片
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
# ------------------- 辅助函数 -------------------
def load_image_from_url(url: str) -> Image.Image:
"""从网络 URL 下载图片"""
try:
headers = {'User-Agent': 'Mozilla/5.0'}
response = requests.get(url, headers=headers, stream=True, timeout=10)
response.raise_for_status()
image = Image.open(response.raw).convert("RGB")
return image
except Exception as e:
raise HTTPException(status_code=400, detail=f"无法下载图片: {str(e)}")
def generate_and_save_result(image: Image.Image, inference_state) -> str:
"""生成可视化结果图并保存,返回文件名"""
# 生成唯一文件名防止冲突
filename = f"seg_{uuid.uuid4().hex}.jpg"
save_path = os.path.join(RESULT_IMAGE_DIR, filename)
# 绘图 (复用你提供的逻辑)
plot_results(image, inference_state)
# 保存
plt.savefig(save_path, dpi=150, bbox_inches='tight')
plt.close() # 务必关闭,防止内存泄漏
return filename
# ------------------- API 接口 -------------------
@app.post("/segment")
async def segment(
request: Request,
prompt: str = Form(...),
file: Optional[UploadFile] = File(None),
image_url: Optional[str] = Form(None)
):
"""
接收图片 (文件上传 或 URL) 和 文本提示词,返回分割后的图片 URL。
"""
# 1. 校验输入
if not file and not image_url:
raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")
# 2. 获取图片对象
try:
if file:
image = Image.open(file.file).convert("RGB")
elif image_url:
image = load_image_from_url(image_url)
except Exception as e:
raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")
# 3. 获取模型
processor = request.app.state.processor
# 4. 执行推理
try:
# 这一步内部应该已经由 Sam3Processor 处理了 GPU 张量转移
inference_state = processor.set_image(image)
output = processor.set_text_prompt(state=inference_state, prompt=prompt)
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
except Exception as e:
raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")
# 5. 生成可视化并保存
try:
filename = generate_and_save_result(image, inference_state)
except Exception as e:
raise HTTPException(status_code=500, detail=f"绘图保存错误: {str(e)}")
# 6. 构建返回 URL
# request.url_for 会自动根据当前域名生成正确的访问链接
file_url = request.url_for("static", path=f"results/{filename}")
return JSONResponse(content={
"status": "success",
"result_image_url": str(file_url),
"detected_count": len(masks),
"scores": scores.tolist() if torch.is_tensor(scores) else scores
})
if __name__ == "__main__":
import uvicorn
# 使用 Python 函数参数的方式传递配置
uvicorn.run(
"fastAPI_main:app", # 注意:这里要改成你的文件名:app对象名
host="127.0.0.1",
port=55600,
proxy_headers=True, # 对应 --proxy-headers
forwarded_allow_ips="*" # 对应 --forwarded-allow-ips="*"
)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 84 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 118 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 144 KiB