tarot
330
fastAPI_tarot.py
Normal file
@@ -0,0 +1,330 @@
|
|||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
import time
|
||||||
|
import requests
|
||||||
|
import numpy as np
|
||||||
|
from typing import Optional
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import matplotlib
|
||||||
|
matplotlib.use('Agg')
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request, Depends, status
|
||||||
|
from fastapi.security import APIKeyHeader
|
||||||
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from sam3.model_builder import build_sam3_image_model
|
||||||
|
from sam3.model.sam3_image_processor import Sam3Processor
|
||||||
|
from sam3.visualization_utils import plot_results
|
||||||
|
|
||||||
|
# ------------------- 配置与路径 -------------------
|
||||||
|
STATIC_DIR = "static"
|
||||||
|
RESULT_IMAGE_DIR = os.path.join(STATIC_DIR, "results")
|
||||||
|
os.makedirs(RESULT_IMAGE_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
# ------------------- API Key 核心配置 (已加固) -------------------
|
||||||
|
VALID_API_KEY = "123quant-speed"
|
||||||
|
API_KEY_HEADER_NAME = "X-API-Key"
|
||||||
|
|
||||||
|
# 定义 Header 认证
|
||||||
|
api_key_header = APIKeyHeader(name=API_KEY_HEADER_NAME, auto_error=False)
|
||||||
|
|
||||||
|
async def verify_api_key(api_key: Optional[str] = Depends(api_key_header)):
|
||||||
|
"""
|
||||||
|
强制验证 API Key
|
||||||
|
"""
|
||||||
|
# 1. 检查是否有 Key
|
||||||
|
if not api_key:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Missing API Key. Please provide it in the header."
|
||||||
|
)
|
||||||
|
# 2. 检查 Key 是否正确
|
||||||
|
if api_key != VALID_API_KEY:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Invalid API Key."
|
||||||
|
)
|
||||||
|
# 3. 验证通过
|
||||||
|
return True
|
||||||
|
|
||||||
|
# ------------------- 生命周期管理 -------------------
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
print("="*40)
|
||||||
|
print("✅ API Key 保护已激活")
|
||||||
|
print(f"✅ 有效 Key: {VALID_API_KEY}")
|
||||||
|
print("="*40)
|
||||||
|
|
||||||
|
print("正在加载 SAM3 模型到 GPU...")
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
print("警告: 未检测到 GPU,将使用 CPU,速度会较慢。")
|
||||||
|
|
||||||
|
model = build_sam3_image_model()
|
||||||
|
model = model.to(device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
processor = Sam3Processor(model)
|
||||||
|
|
||||||
|
app.state.model = model
|
||||||
|
app.state.processor = processor
|
||||||
|
app.state.device = device
|
||||||
|
|
||||||
|
print(f"模型加载完成,设备: {device}")
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
print("正在清理资源...")
|
||||||
|
|
||||||
|
# ------------------- FastAPI 初始化 -------------------
|
||||||
|
app = FastAPI(
|
||||||
|
lifespan=lifespan,
|
||||||
|
title="SAM3 Segmentation API",
|
||||||
|
description="## 🔒 受 API Key 保护\n请点击右上角 **Authorize** 并输入: `123quant-speed`",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 手动添加 OpenAPI 安全配置,让 Docs 里的锁头生效
|
||||||
|
app.openapi_schema = None
|
||||||
|
def custom_openapi():
|
||||||
|
if app.openapi_schema:
|
||||||
|
return app.openapi_schema
|
||||||
|
from fastapi.openapi.utils import get_openapi
|
||||||
|
openapi_schema = get_openapi(
|
||||||
|
title=app.title,
|
||||||
|
version=app.version,
|
||||||
|
description=app.description,
|
||||||
|
routes=app.routes,
|
||||||
|
)
|
||||||
|
# 定义安全方案
|
||||||
|
openapi_schema["components"]["securitySchemes"] = {
|
||||||
|
"APIKeyHeader": {
|
||||||
|
"type": "apiKey",
|
||||||
|
"in": "header",
|
||||||
|
"name": API_KEY_HEADER_NAME,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
# 为所有路径应用安全要求
|
||||||
|
for path in openapi_schema["paths"]:
|
||||||
|
for method in openapi_schema["paths"][path]:
|
||||||
|
openapi_schema["paths"][path][method]["security"] = [{"APIKeyHeader": []}]
|
||||||
|
app.openapi_schema = openapi_schema
|
||||||
|
return app.openapi_schema
|
||||||
|
|
||||||
|
app.openapi = custom_openapi
|
||||||
|
|
||||||
|
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
|
||||||
|
|
||||||
|
# ------------------- 辅助函数 -------------------
|
||||||
|
def load_image_from_url(url: str) -> Image.Image:
|
||||||
|
try:
|
||||||
|
headers = {'User-Agent': 'Mozilla/5.0'}
|
||||||
|
response = requests.get(url, headers=headers, stream=True, timeout=10)
|
||||||
|
response.raise_for_status()
|
||||||
|
image = Image.open(response.raw).convert("RGB")
|
||||||
|
return image
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=400, detail=f"无法下载图片: {str(e)}")
|
||||||
|
|
||||||
|
def crop_and_save_objects(image: Image.Image, masks, boxes, output_dir: str = RESULT_IMAGE_DIR) -> list[str]:
|
||||||
|
"""
|
||||||
|
根据 mask 和 box 裁剪出独立的对象图片 (保留透明背景)
|
||||||
|
"""
|
||||||
|
saved_files = []
|
||||||
|
# Convert image to numpy array
|
||||||
|
img_arr = np.array(image) # RGB (H, W, 3)
|
||||||
|
|
||||||
|
for i, (mask, box) in enumerate(zip(masks, boxes)):
|
||||||
|
# Handle tensor/numpy conversions
|
||||||
|
if isinstance(mask, torch.Tensor):
|
||||||
|
mask_np = mask.cpu().numpy().squeeze()
|
||||||
|
else:
|
||||||
|
mask_np = mask.squeeze()
|
||||||
|
|
||||||
|
if isinstance(box, torch.Tensor):
|
||||||
|
box_np = box.cpu().numpy()
|
||||||
|
else:
|
||||||
|
box_np = box
|
||||||
|
|
||||||
|
# Get coordinates
|
||||||
|
x1, y1, x2, y2 = map(int, box_np)
|
||||||
|
|
||||||
|
# Ensure coordinates are within bounds
|
||||||
|
x1 = max(0, x1)
|
||||||
|
y1 = max(0, y1)
|
||||||
|
x2 = min(image.width, x2)
|
||||||
|
y2 = min(image.height, y2)
|
||||||
|
|
||||||
|
# Check valid crop
|
||||||
|
if x2 <= x1 or y2 <= y1:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Create Alpha channel from mask (0 or 255)
|
||||||
|
# mask_np is boolean or float 0..1. If boolean, *255 -> 0/255.
|
||||||
|
alpha = (mask_np * 255).astype(np.uint8)
|
||||||
|
|
||||||
|
# Combine RGB and Alpha
|
||||||
|
rgba = np.dstack((img_arr, alpha))
|
||||||
|
|
||||||
|
# Convert back to PIL for cropping
|
||||||
|
pil_rgba = Image.fromarray(rgba)
|
||||||
|
|
||||||
|
# Crop to bounding box
|
||||||
|
cropped = pil_rgba.crop((x1, y1, x2, y2))
|
||||||
|
|
||||||
|
# Save
|
||||||
|
filename = f"tarot_{uuid.uuid4().hex}_{i}.png" # Use png for transparency
|
||||||
|
save_path = os.path.join(output_dir, filename)
|
||||||
|
cropped.save(save_path)
|
||||||
|
saved_files.append(filename)
|
||||||
|
|
||||||
|
return saved_files
|
||||||
|
|
||||||
|
def generate_and_save_result(image: Image.Image, inference_state, output_dir: str = RESULT_IMAGE_DIR) -> str:
|
||||||
|
filename = f"seg_{uuid.uuid4().hex}.jpg"
|
||||||
|
save_path = os.path.join(output_dir, filename)
|
||||||
|
plot_results(image, inference_state)
|
||||||
|
plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
||||||
|
plt.close()
|
||||||
|
return filename
|
||||||
|
|
||||||
|
# ------------------- API 接口 (强制依赖验证) -------------------
|
||||||
|
@app.post("/segment", dependencies=[Depends(verify_api_key)])
|
||||||
|
async def segment(
|
||||||
|
request: Request,
|
||||||
|
prompt: str = Form(...),
|
||||||
|
file: Optional[UploadFile] = File(None),
|
||||||
|
image_url: Optional[str] = Form(None)
|
||||||
|
):
|
||||||
|
if not file and not image_url:
|
||||||
|
raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")
|
||||||
|
|
||||||
|
try:
|
||||||
|
if file:
|
||||||
|
image = Image.open(file.file).convert("RGB")
|
||||||
|
elif image_url:
|
||||||
|
image = load_image_from_url(image_url)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")
|
||||||
|
|
||||||
|
processor = request.app.state.processor
|
||||||
|
|
||||||
|
try:
|
||||||
|
inference_state = processor.set_image(image)
|
||||||
|
output = processor.set_text_prompt(state=inference_state, prompt=prompt)
|
||||||
|
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
filename = generate_and_save_result(image, inference_state)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"绘图保存错误: {str(e)}")
|
||||||
|
|
||||||
|
file_url = request.url_for("static", path=f"results/{filename}")
|
||||||
|
|
||||||
|
return JSONResponse(content={
|
||||||
|
"status": "success",
|
||||||
|
"result_image_url": str(file_url),
|
||||||
|
"detected_count": len(masks),
|
||||||
|
"scores": scores.tolist() if torch.is_tensor(scores) else scores
|
||||||
|
})
|
||||||
|
|
||||||
|
@app.post("/segment_tarot", dependencies=[Depends(verify_api_key)])
|
||||||
|
async def segment_tarot(
|
||||||
|
request: Request,
|
||||||
|
file: Optional[UploadFile] = File(None),
|
||||||
|
image_url: Optional[str] = Form(None),
|
||||||
|
expected_count: int = Form(3)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
塔罗牌分割专用接口
|
||||||
|
1. 检测是否包含指定数量的塔罗牌 (默认为 3)
|
||||||
|
2. 如果是,分别抠出这些牌并返回
|
||||||
|
"""
|
||||||
|
if not file and not image_url:
|
||||||
|
raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")
|
||||||
|
|
||||||
|
try:
|
||||||
|
if file:
|
||||||
|
image = Image.open(file.file).convert("RGB")
|
||||||
|
elif image_url:
|
||||||
|
image = load_image_from_url(image_url)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")
|
||||||
|
|
||||||
|
processor = request.app.state.processor
|
||||||
|
|
||||||
|
try:
|
||||||
|
inference_state = processor.set_image(image)
|
||||||
|
# 固定 Prompt 检测塔罗牌
|
||||||
|
output = processor.set_text_prompt(state=inference_state, prompt="tarot card")
|
||||||
|
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")
|
||||||
|
|
||||||
|
# 核心逻辑:判断数量
|
||||||
|
detected_count = len(masks)
|
||||||
|
|
||||||
|
# 创建本次请求的独立文件夹 (时间戳_UUID前8位)
|
||||||
|
request_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
|
||||||
|
output_dir = os.path.join(RESULT_IMAGE_DIR, request_id)
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
if detected_count != expected_count:
|
||||||
|
# 保存一张图用于调试/反馈
|
||||||
|
try:
|
||||||
|
filename = generate_and_save_result(image, inference_state, output_dir=output_dir)
|
||||||
|
file_url = request.url_for("static", path=f"results/{request_id}/{filename}")
|
||||||
|
except:
|
||||||
|
file_url = None
|
||||||
|
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=400,
|
||||||
|
content={
|
||||||
|
"status": "failed",
|
||||||
|
"message": f"检测到 {detected_count} 个目标,需要严格的 {expected_count} 张塔罗牌。请调整拍摄角度或背景。",
|
||||||
|
"detected_count": detected_count,
|
||||||
|
"debug_image_url": str(file_url) if file_url else None
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 数量正确,执行抠图
|
||||||
|
try:
|
||||||
|
filenames = crop_and_save_objects(image, masks, boxes, output_dir=output_dir)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"抠图处理错误: {str(e)}")
|
||||||
|
|
||||||
|
# 生成 URL 列表
|
||||||
|
card_urls = [str(request.url_for("static", path=f"results/{request_id}/{fname}")) for fname in filenames]
|
||||||
|
|
||||||
|
# 生成整体效果图
|
||||||
|
try:
|
||||||
|
main_filename = generate_and_save_result(image, inference_state, output_dir=output_dir)
|
||||||
|
main_file_url = str(request.url_for("static", path=f"results/{request_id}/{main_filename}"))
|
||||||
|
except:
|
||||||
|
main_file_url = None
|
||||||
|
|
||||||
|
return JSONResponse(content={
|
||||||
|
"status": "success",
|
||||||
|
"message": f"成功识别并分割 {expected_count} 张塔罗牌",
|
||||||
|
"tarot_cards": card_urls,
|
||||||
|
"full_visualization": main_file_url,
|
||||||
|
"scores": scores.tolist() if torch.is_tensor(scores) else scores
|
||||||
|
})
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
# 注意:如果你的文件名不是 fastAPI_tarot.py,请修改下面第一个参数
|
||||||
|
uvicorn.run(
|
||||||
|
"fastAPI_tarot:app",
|
||||||
|
host="127.0.0.1",
|
||||||
|
port=55600,
|
||||||
|
proxy_headers=True,
|
||||||
|
forwarded_allow_ips="*",
|
||||||
|
reload=False # 生产环境建议关闭 reload,确保代码完全重载
|
||||||
|
)
|
||||||
|
Before Width: | Height: | Size: 111 KiB |
|
Before Width: | Height: | Size: 118 KiB |
|
Before Width: | Height: | Size: 83 KiB |
BIN
static/results/seg_9061c56e4b284f60a109e405c20af31b.jpg
Normal file
|
After Width: | Height: | Size: 92 KiB |
|
Before Width: | Height: | Size: 144 KiB |
|
Before Width: | Height: | Size: 85 KiB |
|
Before Width: | Height: | Size: 83 KiB |
|
Before Width: | Height: | Size: 83 KiB |
BIN
static/results/tarot_c0033d32490548b99a6ddbcf721f2d9a_0.png
Normal file
|
After Width: | Height: | Size: 516 KiB |