Files
sam3_local/fastAPI_tarot.py
2026-02-15 16:37:24 +08:00

411 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import io
import os
import secrets
import time
import uuid
from contextlib import asynccontextmanager
from typing import Optional

import cv2
import numpy as np
import requests
import torch
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request, Depends, status
from fastapi.security import APIKeyHeader
from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse
from PIL import Image

from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
from sam3.visualization_utils import plot_results
# ------------------- Configuration and paths -------------------
# Generated result images are served from static/results via the /static mount.
STATIC_DIR = "static"
RESULT_IMAGE_DIR = os.path.join(STATIC_DIR, "results")
os.makedirs(RESULT_IMAGE_DIR, exist_ok=True)
# ------------------- API key core configuration (hardened) -------------------
# NOTE(review): the key is hardcoded in source (and echoed in the docs
# description below) — consider loading it from an environment variable;
# confirm against the deployment's secret-management practice.
VALID_API_KEY = "123quant-speed"
API_KEY_HEADER_NAME = "X-API-Key"
# Header-based authentication; auto_error=False so verify_api_key can
# return its own 401 instead of FastAPI's default error.
api_key_header = APIKeyHeader(name=API_KEY_HEADER_NAME, auto_error=False)
async def verify_api_key(api_key: Optional[str] = Depends(api_key_header)):
    """
    Strictly validate the API key supplied in the X-API-Key header.

    Raises:
        HTTPException 401: no key was provided at all.
        HTTPException 403: a key was provided but does not match.

    Returns:
        True when validation succeeds.
    """
    # 1. Reject requests that carry no key.
    if not api_key:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing API Key. Please provide it in the header."
        )
    # 2. Constant-time comparison: a plain `!=` short-circuits on the first
    #    mismatching character and can leak key contents via timing.
    if not secrets.compare_digest(api_key, VALID_API_KEY):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid API Key."
        )
    # 3. Validation passed.
    return True
# ------------------- Application lifecycle -------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the SAM3 model onto the best available device at startup.

    The model, processor and device are published on `app.state` so request
    handlers can reach them; shutdown just logs (nothing to free explicitly).
    """
    banner = "=" * 40
    print(banner)
    print("✅ API Key 保护已激活")
    print(f"✅ 有效 Key: {VALID_API_KEY}")
    print(banner)
    print("正在加载 SAM3 模型到 GPU...")
    gpu_available = torch.cuda.is_available()
    device = torch.device("cuda" if gpu_available else "cpu")
    if not gpu_available:
        print("警告: 未检测到 GPU将使用 CPU速度会较慢。")
    model = build_sam3_image_model().to(device)
    model.eval()
    # Expose shared inference objects to the request handlers.
    app.state.model = model
    app.state.processor = Sam3Processor(model)
    app.state.device = device
    print(f"模型加载完成,设备: {device}")
    yield
    print("正在清理资源...")
# ------------------- FastAPI initialization -------------------
app = FastAPI(
lifespan=lifespan,
title="SAM3 Segmentation API",
description="## 🔒 受 API Key 保护\n请点击右上角 **Authorize** 并输入: `123quant-speed`",
)
# Manually wire the OpenAPI security configuration so the padlock in /docs works.
# Clear any cached schema so custom_openapi() regenerates it on first access.
app.openapi_schema = None
def custom_openapi():
    """Build (and cache) an OpenAPI schema that requires the API key everywhere."""
    # Serve the cached schema if it was already generated.
    if app.openapi_schema:
        return app.openapi_schema
    from fastapi.openapi.utils import get_openapi
    schema = get_openapi(
        title=app.title,
        version=app.version,
        description=app.description,
        routes=app.routes,
    )
    # Declare the header-based API-key security scheme.
    schema["components"]["securitySchemes"] = {
        "APIKeyHeader": {
            "type": "apiKey",
            "in": "header",
            "name": API_KEY_HEADER_NAME,
        }
    }
    # Attach the security requirement to every operation on every path.
    for operations in schema["paths"].values():
        for operation in operations.values():
            operation["security"] = [{"APIKeyHeader": []}]
    app.openapi_schema = schema
    return app.openapi_schema
# Replace FastAPI's default schema generator with the secured variant above.
app.openapi = custom_openapi
# Serve generated result images from ./static under the /static URL prefix.
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
# ------------------- Helper functions -------------------
def order_points(pts):
    """
    Order four corner points as: top-left, top-right, bottom-right, bottom-left.

    The top-left corner has the smallest x+y sum and the bottom-right the
    largest; the top-right has the smallest y-x difference and the
    bottom-left the largest.
    """
    ordered = np.zeros((4, 2), dtype="float32")
    coord_sums = pts.sum(axis=1)
    coord_diffs = np.diff(pts, axis=1)
    ordered[0] = pts[np.argmin(coord_sums)]   # top-left
    ordered[2] = pts[np.argmax(coord_sums)]   # bottom-right
    ordered[1] = pts[np.argmin(coord_diffs)]  # top-right
    ordered[3] = pts[np.argmax(coord_diffs)]  # bottom-left
    return ordered
def four_point_transform(image, pts):
    """
    Warp the quadrilateral described by `pts` into an axis-aligned rectangle.

    The output size is taken from the longest of each pair of opposing edges
    of the ordered quadrilateral, so the warped patch keeps its natural
    proportions.
    """
    rect = order_points(pts)
    (tl, tr, br, bl) = rect
    # Output width: longest of the bottom and top edges.
    bottom_w = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2))
    top_w = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2))
    max_width = max(int(bottom_w), int(top_w))
    # Output height: longest of the right and left edges.
    right_h = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2))
    left_h = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2))
    max_height = max(int(right_h), int(left_h))
    # Destination corners in the same TL, TR, BR, BL order as `rect`.
    dst = np.array([
        [0, 0],
        [max_width - 1, 0],
        [max_width - 1, max_height - 1],
        [0, max_height - 1],
    ], dtype="float32")
    transform = cv2.getPerspectiveTransform(rect, dst)
    return cv2.warpPerspective(image, transform, (max_width, max_height))
def load_image_from_url(url: str) -> Image.Image:
    """
    Download an image over HTTP and return it as an RGB PIL image.

    Raises:
        HTTPException 400: the download or decoding failed.
    """
    try:
        headers = {'User-Agent': 'Mozilla/5.0'}
        response = requests.get(url, headers=headers, timeout=10)
        response.raise_for_status()
        # Decode from the fully-downloaded, content-decoded body.
        # The previous `Image.open(response.raw)` read the raw socket stream,
        # which bypasses requests' gzip/deflate decoding, so compressed
        # responses handed PIL undecodable bytes.
        image = Image.open(io.BytesIO(response.content)).convert("RGB")
        return image
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"无法下载图片: {str(e)}")
def crop_and_save_objects(image: Image.Image, masks, boxes, output_dir: str = RESULT_IMAGE_DIR) -> list[dict]:
    """
    Perspective-correct and crop each detected object into its own PNG file.

    For every (mask, box) pair: the largest contour of the binarised mask is
    approximated by a quadrilateral (falling back to the minimum-area
    rectangle when the approximation is not 4-sided), warped to an
    axis-aligned rectangle, rotated to portrait orientation if needed, and
    saved to `output_dir` under a unique name.

    Returns:
        A list of dicts with the saved filename and orientation metadata.
    """
    saved_objects = []
    # Pixel data of the source image (PIL -> numpy, channel-last).
    img_arr = np.array(image)
    for i, (mask, box) in enumerate(zip(masks, boxes)):
        # Normalise the mask to a 2-D numpy array, moving off the GPU if needed.
        if isinstance(mask, torch.Tensor):
            mask_np = mask.cpu().numpy().squeeze()
        else:
            mask_np = mask.squeeze()
        # Binarise to uint8 {0, 255} as required by OpenCV.
        if mask_np.dtype == bool:
            mask_uint8 = (mask_np * 255).astype(np.uint8)
        else:
            mask_uint8 = (mask_np > 0.5).astype(np.uint8) * 255
        contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            continue
        # Work with the largest contour only (the card itself).
        c = max(contours, key=cv2.contourArea)
        # Try to reduce the contour to exactly four corners.
        peri = cv2.arcLength(c, True)
        approx = cv2.approxPolyDP(c, 0.04 * peri, True)
        if len(approx) == 4:
            pts = approx.reshape(4, 2)
        else:
            # Fallback: corners of the minimum-area bounding rectangle.
            pts = cv2.boxPoints(cv2.minAreaRect(c))
        # Warp the quadrilateral to an upright rectangle. The warped rectangle
        # excludes the background by construction, so no mask/alpha compositing
        # is needed here. (The original code branched on an alpha channel but
        # both branches executed identical code — the dead conditional was
        # removed; `image` is expected to arrive as RGB from the callers.)
        warped = four_point_transform(img_arr, pts)
        # Tarot cards are portrait (approx. 7x12 cm); rotate landscape crops
        # 90 degrees clockwise to enforce portrait orientation.
        h, w = warped.shape[:2]
        is_rotated = False
        if w > h:
            warped = cv2.rotate(warped, cv2.ROTATE_90_CLOCKWISE)
            is_rotated = True
        pil_warped = Image.fromarray(warped)
        filename = f"tarot_{uuid.uuid4().hex}_{i}.png"
        save_path = os.path.join(output_dir, filename)
        pil_warped.save(save_path)
        # NOTE: geometry alone cannot distinguish upright from reversed
        # (180°-flipped) cards; that requires content analysis (e.g. OCR).
        saved_objects.append({
            "filename": filename,
            "is_rotated_by_algorithm": is_rotated,
            "note": "Geometric correction applied. True upright/reversed requires content analysis."
        })
    return saved_objects
def generate_and_save_result(image: Image.Image, inference_state, output_dir: str = RESULT_IMAGE_DIR) -> str:
    """
    Render the segmentation overlay via plot_results and save it as a JPEG.

    Returns:
        The generated filename (relative to `output_dir`).
    """
    filename = f"seg_{uuid.uuid4().hex}.jpg"
    save_path = os.path.join(output_dir, filename)
    plot_results(image, inference_state)
    try:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    finally:
        # Always release the figure, even when savefig fails — otherwise
        # matplotlib accumulates open figures across requests (memory leak).
        plt.close()
    return filename
# ------------------- API endpoints (API-key required) -------------------
@app.post("/segment", dependencies=[Depends(verify_api_key)])
async def segment(
    request: Request,
    prompt: str = Form(...),
    file: Optional[UploadFile] = File(None),
    image_url: Optional[str] = Form(None)
):
    """
    Segment an image with a free-text prompt.

    Accepts either an uploaded file or an image URL, runs SAM3 with the
    given prompt, saves a visualisation image and returns its URL together
    with the detection count and scores.
    """
    if not file and not image_url:
        raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")
    try:
        if file:
            image = Image.open(file.file).convert("RGB")
        elif image_url:
            image = load_image_from_url(image_url)
    except HTTPException:
        # load_image_from_url already raises a descriptive 400; don't re-wrap it.
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")
    processor = request.app.state.processor
    try:
        inference_state = processor.set_image(image)
        output = processor.set_text_prompt(state=inference_state, prompt=prompt)
        masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")
    try:
        filename = generate_and_save_result(image, inference_state)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"绘图保存错误: {str(e)}")
    # BUG FIX: the URL previously embedded a literal placeholder instead of
    # the generated file's name; interpolate the real filename.
    file_url = request.url_for("static", path=f"results/{filename}")
    return JSONResponse(content={
        "status": "success",
        "result_image_url": str(file_url),
        "detected_count": len(masks),
        "scores": scores.tolist() if torch.is_tensor(scores) else scores
    })
@app.post("/segment_tarot", dependencies=[Depends(verify_api_key)])
async def segment_tarot(
    request: Request,
    file: Optional[UploadFile] = File(None),
    image_url: Optional[str] = Form(None),
    expected_count: int = Form(3)
):
    """
    Tarot-card segmentation endpoint.

    1. Detects whether the image contains exactly `expected_count` tarot
       cards (default 3) using a fixed "tarot card" prompt.
    2. If so, perspective-corrects and crops each card and returns the
       individual card URLs plus an overall visualisation.
    """
    if not file and not image_url:
        raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")
    try:
        if file:
            image = Image.open(file.file).convert("RGB")
        elif image_url:
            image = load_image_from_url(image_url)
    except HTTPException:
        # load_image_from_url already raises a descriptive 400; don't re-wrap it.
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")
    processor = request.app.state.processor
    try:
        inference_state = processor.set_image(image)
        # Fixed prompt: this endpoint only ever looks for tarot cards.
        output = processor.set_text_prompt(state=inference_state, prompt="tarot card")
        masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")
    # Core logic: the detection count must match exactly.
    detected_count = len(masks)
    # Per-request output folder: <timestamp>_<first 8 hex chars of a UUID>.
    request_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
    output_dir = os.path.join(RESULT_IMAGE_DIR, request_id)
    os.makedirs(output_dir, exist_ok=True)
    if detected_count != expected_count:
        # Save a visualisation for debugging/feedback (best-effort).
        try:
            filename = generate_and_save_result(image, inference_state, output_dir=output_dir)
            # BUG FIX: interpolate the real filename — the URL previously
            # embedded a literal placeholder string.
            file_url = request.url_for("static", path=f"results/{request_id}/{filename}")
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # are no longer swallowed; the debug image stays optional.
            file_url = None
        return JSONResponse(
            status_code=400,
            content={
                "status": "failed",
                "message": f"检测到 {detected_count} 个目标,需要严格的 {expected_count} 张塔罗牌。请调整拍摄角度或背景。",
                "detected_count": detected_count,
                "debug_image_url": str(file_url) if file_url else None
            }
        )
    # Count is correct: crop each card out of the image.
    try:
        saved_objects = crop_and_save_objects(image, masks, boxes, output_dir=output_dir)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"抠图处理错误: {str(e)}")
    # Build the URL + metadata entry for each cropped card.
    tarot_cards = []
    for obj in saved_objects:
        fname = obj["filename"]
        file_url = str(request.url_for("static", path=f"results/{request_id}/{fname}"))
        tarot_cards.append({
            "url": file_url,
            "is_rotated": obj["is_rotated_by_algorithm"],
            "orientation_status": "corrected_to_portrait" if obj["is_rotated_by_algorithm"] else "original_portrait",
            "note": obj["note"]
        })
    # Overall visualisation image (best-effort).
    try:
        main_filename = generate_and_save_result(image, inference_state, output_dir=output_dir)
        main_file_url = str(request.url_for("static", path=f"results/{request_id}/{main_filename}"))
    except Exception:
        # Narrowed from a bare `except:`; the overview image is optional.
        main_file_url = None
    return JSONResponse(content={
        "status": "success",
        "message": f"成功识别并分割 {expected_count} 张塔罗牌 (已执行透视矫正)",
        "tarot_cards": tarot_cards,
        "full_visualization": main_file_url,
        "scores": scores.tolist() if torch.is_tensor(scores) else scores
    })
if __name__ == "__main__":
    import uvicorn

    # NOTE: if this file is renamed, update the "module:app" string below.
    uvicorn.run(
        "fastAPI_tarot:app",
        host="127.0.0.1",
        port=55600,
        proxy_headers=True,
        forwarded_allow_ips="*",
        reload=False,  # keep auto-reload off in production
    )