# fastAPI_tarot.py — SAM3 tarot-card segmentation + Qwen-VL recognition service
import os
import time
import uuid
from contextlib import asynccontextmanager
from typing import Optional

import cv2
import dashscope
import numpy as np
import requests
import torch
import matplotlib
matplotlib.use('Agg')  # headless backend; must be selected before importing pyplot
import matplotlib.pyplot as plt
from dashscope import MultiModalConversation
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request, Depends, status
from fastapi.security import APIKeyHeader
from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse
from PIL import Image

from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
from sam3.visualization_utils import plot_results
# ------------------- Configuration & paths -------------------
STATIC_DIR = "static"
RESULT_IMAGE_DIR = os.path.join(STATIC_DIR, "results")
os.makedirs(RESULT_IMAGE_DIR, exist_ok=True)

# ------------------- API Key configuration -------------------
# SECURITY FIX: secrets are now read from the environment first; the literal
# fallbacks preserve previous behavior but should be rotated and removed
# from source control (they are already exposed in the repo history).
VALID_API_KEY = os.getenv("SAM3_API_KEY", "123quant-speed")
API_KEY_HEADER_NAME = "X-API-Key"

# Dashscope configuration (Qwen-VL)
dashscope.api_key = os.getenv("DASHSCOPE_API_KEY", 'sk-ce2404f55f744a1987d5ece61c6bac58')
QWEN_MODEL = 'qwen-vl-max'

# Header-based auth scheme; auto_error=False lets verify_api_key raise its
# own 401/403 instead of FastAPI's generic error.
api_key_header = APIKeyHeader(name=API_KEY_HEADER_NAME, auto_error=False)
||
async def verify_api_key(api_key: Optional[str] = Depends(api_key_header)):
    """Dependency that enforces a valid API key on protected routes.

    Raises 401 when the header is absent, 403 when the key does not match,
    and returns True once the caller is authenticated.
    """
    # Missing / empty header -> 401
    if not api_key:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing API Key. Please provide it in the header."
        )

    # Correct key -> authenticated
    if api_key == VALID_API_KEY:
        return True

    # Present but wrong -> 403
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Invalid API Key."
    )
||
# ------------------- Lifecycle management -------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: load the SAM3 model once at startup.

    The model, processor, and device are stored on ``app.state`` so request
    handlers reuse a single loaded model instead of re-loading per request.
    """
    print("="*40)
    print("✅ API Key 保护已激活")
    # SECURITY FIX: never log the full secret — show a masked prefix only.
    print(f"✅ 有效 Key: {VALID_API_KEY[:4]}**** (masked)")
    print("="*40)

    print("正在加载 SAM3 模型到 GPU...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if not torch.cuda.is_available():
        print("警告: 未检测到 GPU,将使用 CPU,速度会较慢。")

    model = build_sam3_image_model()
    model = model.to(device)
    model.eval()  # inference-only: disable dropout/batchnorm updates

    processor = Sam3Processor(model)

    # Share loaded resources with all request handlers.
    app.state.model = model
    app.state.processor = processor
    app.state.device = device

    print(f"模型加载完成,设备: {device}")

    yield  # application runs while suspended here

    print("正在清理资源...")
||
# ------------------- FastAPI initialization -------------------
app = FastAPI(
    lifespan=lifespan,
    title="SAM3 Segmentation API",
    description="## 🔒 受 API Key 保护\n请点击右上角 **Authorize** 并输入: `123quant-speed`",
)

# Reset the cached OpenAPI schema so custom_openapi() (defined below)
# rebuilds it with the API-key security scheme — this makes the padlock
# in /docs functional.
app.openapi_schema = None
||
def custom_openapi():
    """Build (and cache) an OpenAPI schema declaring the API-key scheme.

    Registering the security scheme manually is what makes the Authorize
    button / padlock in the interactive docs work for every route.
    """
    # Serve the cached schema when it has already been built.
    if app.openapi_schema:
        return app.openapi_schema

    from fastapi.openapi.utils import get_openapi

    schema = get_openapi(
        title=app.title,
        version=app.version,
        description=app.description,
        routes=app.routes,
    )

    # Declare the header-based API-key security scheme.
    schema["components"]["securitySchemes"] = {
        "APIKeyHeader": {
            "type": "apiKey",
            "in": "header",
            "name": API_KEY_HEADER_NAME,
        }
    }

    # Require the API key on every operation of every path.
    for operations in schema["paths"].values():
        for operation in operations.values():
            operation["security"] = [{"APIKeyHeader": []}]

    app.openapi_schema = schema
    return app.openapi_schema
|
||
# Install the custom schema builder defined above.
app.openapi = custom_openapi

# Serve generated result images under /static.
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
||
|
||
# ------------------- 辅助函数 -------------------
|
||
def order_points(pts):
    """Order four corner points as: top-left, top-right, bottom-right, bottom-left.

    The top-left corner has the smallest x+y sum and the bottom-right the
    largest; the top-right has the smallest y-x difference and the
    bottom-left the largest.
    """
    ordered = np.zeros((4, 2), dtype="float32")

    sums = pts.sum(axis=1)
    diffs = np.diff(pts, axis=1)

    ordered[0] = pts[np.argmin(sums)]   # top-left
    ordered[1] = pts[np.argmin(diffs)]  # top-right
    ordered[2] = pts[np.argmax(sums)]   # bottom-right
    ordered[3] = pts[np.argmax(diffs)]  # bottom-left

    return ordered
|
||
def four_point_transform(image, pts):
    """Warp the quadrilateral given by ``pts`` into an upright rectangle.

    Applies a perspective transform so the four (unordered) corner points
    map to an axis-aligned rectangle sized by the longest opposing edges.
    """
    quad = order_points(pts)
    (tl, tr, br, bl) = quad

    # Output width: longer of the bottom and top edges.
    bottom_w = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2))
    top_w = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2))
    out_w = max(int(bottom_w), int(top_w))

    # Output height: longer of the right and left edges.
    right_h = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2))
    left_h = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2))
    out_h = max(int(right_h), int(left_h))

    # Destination rectangle in the same tl/tr/br/bl order.
    target = np.array([
        [0, 0],
        [out_w - 1, 0],
        [out_w - 1, out_h - 1],
        [0, out_h - 1]], dtype="float32")

    matrix = cv2.getPerspectiveTransform(quad, target)
    return cv2.warpPerspective(image, matrix, (out_w, out_h))
|
||
def load_image_from_url(url: str) -> Image.Image:
    """Download an image over HTTP and return it as an RGB PIL image.

    Raises HTTPException(400) on any download or decode failure.
    """
    try:
        from io import BytesIO

        headers = {'User-Agent': 'Mozilla/5.0'}
        response = requests.get(url, headers=headers, timeout=10)
        response.raise_for_status()
        # BUG FIX: reading from ``response.raw`` bypasses requests'
        # transparent content decoding (gzip/deflate), corrupting many
        # downloads. ``response.content`` is the fully decoded body.
        image = Image.open(BytesIO(response.content)).convert("RGB")
        return image
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"无法下载图片: {str(e)}")
|
||
def crop_and_save_objects(image: Image.Image, masks, boxes, output_dir: str = RESULT_IMAGE_DIR) -> list[dict]:
    """
    Perspective-correct and crop each detected object out of ``image``
    using its segmentation mask, writing one PNG per object.

    Args:
        image: Source PIL image.
        masks: Per-object masks (torch tensors or numpy arrays; squeezed
            to 2-D before use).
        boxes: Per-object bounding boxes. NOTE(review): currently unused
            by the crop logic — only masks drive the geometry.
        output_dir: Directory the crops are written into.

    Returns:
        A list of dicts with the saved filename and orientation metadata.
    """
    saved_objects = []
    # Convert image to numpy array (RGB)
    img_arr = np.array(image)

    for i, (mask, box) in enumerate(zip(masks, boxes)):
        # Handle tensor/numpy conversions
        if isinstance(mask, torch.Tensor):
            mask_np = mask.cpu().numpy().squeeze()
        else:
            mask_np = mask.squeeze()

        # Ensure mask is uint8 binary for OpenCV
        if mask_np.dtype == bool:
            mask_uint8 = (mask_np * 255).astype(np.uint8)
        else:
            # Threshold soft (float) masks at 0.5 before scaling to 0/255.
            mask_uint8 = (mask_np > 0.5).astype(np.uint8) * 255

        # Find contours of the binary mask; skip objects with no contour.
        contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            continue

        # Get largest contour (the object itself, ignoring mask noise).
        c = max(contours, key=cv2.contourArea)

        # Approximate contour to a polygon (4% of perimeter tolerance).
        peri = cv2.arcLength(c, True)
        approx = cv2.approxPolyDP(c, 0.04 * peri, True)

        # If we got exactly 4 points, use them as the card's corners.
        # Otherwise fall back to the minimum-area rotated rectangle.
        if len(approx) == 4:
            pts = approx.reshape(4, 2)
        else:
            rect = cv2.minAreaRect(c)
            pts = cv2.boxPoints(rect)

        # Apply the perspective transform.
        # NOTE: only the pixel data is warped; a separate alpha channel
        # would need the same transform applied. To keep a transparent
        # background the source would have to be converted to RGBA first.

        # NOTE(review): both branches below currently perform the identical
        # transform — the RGBA case is not yet handled specially.
        if img_arr.shape[2] == 4:
            warped = four_point_transform(img_arr, pts)
        else:
            # The warped rectangle already excludes the background, so a
            # mask-derived alpha channel is not strictly required here.
            warped = four_point_transform(img_arr, pts)

        # Check orientation (Portrait vs Landscape)
        h, w = warped.shape[:2]
        is_rotated = False

        # Enforce Portrait for Tarot cards (Standard 7x12 cm ratio approx)
        if w > h:
            # Rotate 90 degrees clockwise
            warped = cv2.rotate(warped, cv2.ROTATE_90_CLOCKWISE)
            is_rotated = True

        # Convert back to PIL
        pil_warped = Image.fromarray(warped)

        # Save with a collision-free name; keep the detection index suffix.
        filename = f"tarot_{uuid.uuid4().hex}_{i}.png"
        save_path = os.path.join(output_dir, filename)
        pil_warped.save(save_path)

        # Upright/reversed detection: geometry alone cannot distinguish a
        # 180-degree flip; true orientation needs OCR / content analysis.

        saved_objects.append({
            "filename": filename,
            "is_rotated_by_algorithm": is_rotated,
            "note": "Geometric correction applied. True upright/reversed requires content analysis."
        })

    return saved_objects
|
||
def generate_and_save_result(image: Image.Image, inference_state, output_dir: str = RESULT_IMAGE_DIR) -> str:
    """Render the segmentation overlay and save it as a JPEG.

    Returns the generated filename (not the full path); the caller builds
    the public URL from it.
    """
    name = f"seg_{uuid.uuid4().hex}.jpg"
    destination = os.path.join(output_dir, name)

    plot_results(image, inference_state)  # draws onto the current pyplot figure
    plt.savefig(destination, dpi=150, bbox_inches='tight')
    plt.close()  # release the figure to avoid leaking memory across requests

    return name
|
||
def recognize_card_with_qwen(image_path: str) -> dict:
    """
    Ask Qwen-VL to identify a single tarot card (name + upright/reversed).

    Returns the parsed JSON dict on success; if the model's reply is not
    valid JSON, returns it under ``raw_response``; on API or transport
    failure returns a dict with an ``error`` key. Never raises.
    """
    import json

    try:
        # Dashscope expects a file:// URL with an absolute path.
        abs_path = os.path.abspath(image_path)
        file_url = f"file://{abs_path}"

        messages = [
            {
                "role": "user",
                "content": [
                    {"image": file_url},
                    {"text": "这是一张塔罗牌。请识别它的名字(中文),并判断它是正位还是逆位。请以JSON格式返回,包含 'name' 和 'position' 两个字段。例如:{'name': '愚者', 'position': '正位'}。不要包含Markdown代码块标记。"}
                ]
            }
        ]

        response = MultiModalConversation.call(model=QWEN_MODEL, messages=messages)

        if response.status_code == 200:
            content = response.output.choices[0].message.content[0]['text']
            # Strip any markdown code fences the model may have added.
            clean_content = content.replace("```json", "").replace("```", "").strip()
            try:
                return json.loads(clean_content)
            # BUG FIX: the original bare ``except:`` also swallowed
            # KeyboardInterrupt/SystemExit; only JSON parse errors belong here.
            except json.JSONDecodeError:
                return {"raw_response": content}
        else:
            return {"error": f"API Error: {response.code} - {response.message}"}

    except Exception as e:
        return {"error": f"识别失败: {str(e)}"}
|
||
def recognize_spread_with_qwen(image_path: str) -> dict:
    """
    Ask Qwen-VL to identify the tarot spread layout in a multi-card photo.

    Returns the parsed JSON dict on success; if the model's reply is not
    valid JSON, returns it under ``raw_response`` with ``spread_name`` set
    to "Unknown"; on API or transport failure returns a dict with an
    ``error`` key. Never raises.
    """
    import json

    try:
        # Dashscope expects a file:// URL with an absolute path.
        abs_path = os.path.abspath(image_path)
        file_url = f"file://{abs_path}"

        messages = [
            {
                "role": "user",
                "content": [
                    {"image": file_url},
                    {"text": "这是一张包含多张塔罗牌的图片。请根据牌的排列方式识别这是什么牌阵(例如:圣三角、凯尔特十字、三张牌等)。如果看不出明显的正规牌阵,请返回“不是正规牌阵”。请以JSON格式返回,包含 'spread_name' 和 'description' 两个字段。例如:{'spread_name': '圣三角', 'description': '常见的时间流占卜法'}。不要包含Markdown代码块标记。"}
                ]
            }
        ]

        response = MultiModalConversation.call(model=QWEN_MODEL, messages=messages)

        if response.status_code == 200:
            content = response.output.choices[0].message.content[0]['text']
            # Strip any markdown code fences the model may have added.
            clean_content = content.replace("```json", "").replace("```", "").strip()
            try:
                return json.loads(clean_content)
            # BUG FIX: the original bare ``except:`` also swallowed
            # KeyboardInterrupt/SystemExit; only JSON parse errors belong here.
            except json.JSONDecodeError:
                return {"raw_response": content, "spread_name": "Unknown"}
        else:
            return {"error": f"API Error: {response.code} - {response.message}"}

    except Exception as e:
        return {"error": f"牌阵识别失败: {str(e)}"}
|
||
# ------------------- API 接口 (强制依赖验证) -------------------
|
||
@app.post("/segment", dependencies=[Depends(verify_api_key)])
async def segment(
    request: Request,
    prompt: str = Form(...),
    file: Optional[UploadFile] = File(None),
    image_url: Optional[str] = Form(None)
):
    """
    Generic segmentation endpoint: run SAM3 with a free-form text prompt.

    Accepts either an uploaded file or an image URL, returns a URL to the
    rendered visualization plus per-detection scores.
    """
    if not file and not image_url:
        raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")

    try:
        if file:
            image = Image.open(file.file).convert("RGB")
        else:
            image = load_image_from_url(image_url)
    except HTTPException:
        raise  # preserve the specific 400 raised by load_image_from_url
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")

    processor = request.app.state.processor

    try:
        inference_state = processor.set_image(image)
        output = processor.set_text_prompt(state=inference_state, prompt=prompt)
        masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")

    try:
        filename = generate_and_save_result(image, inference_state)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"绘图保存错误: {str(e)}")

    # BUG FIX: the URL previously interpolated a literal placeholder
    # instead of the generated filename, so the returned link was dead.
    file_url = request.url_for("static", path=f"results/{filename}")

    return JSONResponse(content={
        "status": "success",
        "result_image_url": str(file_url),
        "detected_count": len(masks),
        "scores": scores.tolist() if torch.is_tensor(scores) else scores
    })
|
||
@app.post("/segment_tarot", dependencies=[Depends(verify_api_key)])
async def segment_tarot(
    request: Request,
    file: Optional[UploadFile] = File(None),
    image_url: Optional[str] = Form(None),
    expected_count: int = Form(3)
):
    """
    Tarot-specific segmentation endpoint.

    1. Detect whether the image contains exactly ``expected_count`` tarot
       cards (default 3) using a fixed SAM3 text prompt.
    2. If so, crop each card out with perspective correction and return
       per-card URLs plus an overall visualization.
    """
    if not file and not image_url:
        raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")

    try:
        if file:
            image = Image.open(file.file).convert("RGB")
        else:
            image = load_image_from_url(image_url)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")

    processor = request.app.state.processor

    try:
        inference_state = processor.set_image(image)
        # Fixed prompt: this endpoint only ever looks for tarot cards.
        output = processor.set_text_prompt(state=inference_state, prompt="tarot card")
        masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")

    # Core gate: the detection count must match exactly.
    detected_count = len(masks)

    # Per-request output folder: <timestamp>_<uuid-prefix>
    request_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
    output_dir = os.path.join(RESULT_IMAGE_DIR, request_id)
    os.makedirs(output_dir, exist_ok=True)

    if detected_count != expected_count:
        # Save a debug visualization so the client can see what was detected.
        try:
            filename = generate_and_save_result(image, inference_state, output_dir=output_dir)
            # BUG FIX: the URL previously interpolated a literal placeholder
            # instead of the saved filename.
            file_url = request.url_for("static", path=f"results/{request_id}/{filename}")
        # BUG FIX: bare ``except:`` also caught KeyboardInterrupt/SystemExit.
        except Exception:
            file_url = None

        return JSONResponse(
            status_code=400,
            content={
                "status": "failed",
                "message": f"检测到 {detected_count} 个目标,需要严格的 {expected_count} 张塔罗牌。请调整拍摄角度或背景。",
                "detected_count": detected_count,
                "debug_image_url": str(file_url) if file_url else None
            }
        )

    # Count matches: crop + perspective-correct each card.
    try:
        saved_objects = crop_and_save_objects(image, masks, boxes, output_dir=output_dir)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"抠图处理错误: {str(e)}")

    # Build the per-card URL list and metadata.
    tarot_cards = []
    for obj in saved_objects:
        fname = obj["filename"]
        file_url = str(request.url_for("static", path=f"results/{request_id}/{fname}"))
        tarot_cards.append({
            "url": file_url,
            "is_rotated": obj["is_rotated_by_algorithm"],
            "orientation_status": "corrected_to_portrait" if obj["is_rotated_by_algorithm"] else "original_portrait",
            "note": obj["note"]
        })

    # Overall visualization (best-effort; failure only drops the preview).
    try:
        main_filename = generate_and_save_result(image, inference_state, output_dir=output_dir)
        main_file_url = str(request.url_for("static", path=f"results/{request_id}/{main_filename}"))
    # BUG FIX: bare ``except:`` also caught KeyboardInterrupt/SystemExit.
    except Exception:
        main_file_url = None

    return JSONResponse(content={
        "status": "success",
        "message": f"成功识别并分割 {expected_count} 张塔罗牌 (已执行透视矫正)",
        "tarot_cards": tarot_cards,
        "full_visualization": main_file_url,
        "scores": scores.tolist() if torch.is_tensor(scores) else scores
    })
|
||
@app.post("/recognize_tarot", dependencies=[Depends(verify_api_key)])
async def recognize_tarot(
    request: Request,
    file: Optional[UploadFile] = File(None),
    image_url: Optional[str] = Form(None),
    expected_count: int = Form(3)
):
    """
    Full tarot pipeline: segment + perspective-correct + recognize.

    1. Detect whether the image contains ``expected_count`` tarot cards (SAM3).
    2. Crop each card with perspective correction.
    3. Ask Qwen-VL for the spread layout and each card's name and
       upright/reversed position.
    """
    if not file and not image_url:
        raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")

    try:
        if file:
            image = Image.open(file.file).convert("RGB")
        else:
            image = load_image_from_url(image_url)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")

    processor = request.app.state.processor

    try:
        inference_state = processor.set_image(image)
        # Fixed prompt: this endpoint only ever looks for tarot cards.
        output = processor.set_text_prompt(state=inference_state, prompt="tarot card")
        masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")

    # Core gate: the detection count must match exactly.
    detected_count = len(masks)

    # Per-request output folder: <timestamp>_<uuid-prefix>
    request_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
    output_dir = os.path.join(RESULT_IMAGE_DIR, request_id)
    os.makedirs(output_dir, exist_ok=True)

    # Save the overall visualization up front — it is used in both the
    # success and the failure response.
    try:
        main_filename = generate_and_save_result(image, inference_state, output_dir=output_dir)
        main_file_path = os.path.join(output_dir, main_filename)
        main_file_url = str(request.url_for("static", path=f"results/{request_id}/{main_filename}"))
    # BUG FIX: bare ``except:`` also caught KeyboardInterrupt/SystemExit.
    except Exception:
        main_filename = None
        main_file_path = None
        main_file_url = None

    # Step 0: spread recognition, done globally before the count check.
    spread_info = {"spread_name": "Unknown"}
    if main_file_path:
        # Feed Qwen the raw photo rather than the mask overlay so the
        # drawn masks cannot interfere with spread recognition.
        temp_raw_path = os.path.join(output_dir, "raw_for_spread.jpg")
        image.save(temp_raw_path)
        spread_info = recognize_spread_with_qwen(temp_raw_path)

    # If the model reports "不是正规牌阵" we still return results; the
    # client decides whether to surface a warning to the user.

    if detected_count != expected_count:
        return JSONResponse(
            status_code=400,
            content={
                "status": "failed",
                "message": f"检测到 {detected_count} 个目标,需要严格的 {expected_count} 张塔罗牌。请调整拍摄角度或背景。",
                "detected_count": detected_count,
                "spread_info": spread_info,
                "debug_image_url": str(main_file_url) if main_file_url else None
            }
        )

    # Count matches: crop + perspective-correct each card.
    try:
        saved_objects = crop_and_save_objects(image, masks, boxes, output_dir=output_dir)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"抠图处理错误: {str(e)}")

    # Recognize each card. NOTE: Qwen-VL calls run sequentially, so this
    # loop can be slow for larger spreads.
    tarot_cards = []
    for obj in saved_objects:
        fname = obj["filename"]
        file_path = os.path.join(output_dir, fname)

        recognition_res = recognize_card_with_qwen(file_path)

        file_url = str(request.url_for("static", path=f"results/{request_id}/{fname}"))
        tarot_cards.append({
            "url": file_url,
            "is_rotated": obj["is_rotated_by_algorithm"],
            "orientation_status": "corrected_to_portrait" if obj["is_rotated_by_algorithm"] else "original_portrait",
            "recognition": recognition_res,
            "note": obj["note"]
        })

    return JSONResponse(content={
        "status": "success",
        "message": f"成功识别并分割 {expected_count} 张塔罗牌 (含Qwen识别结果)",
        "spread_info": spread_info,
        "tarot_cards": tarot_cards,
        "full_visualization": main_file_url,
        "scores": scores.tolist() if torch.is_tensor(scores) else scores
    })
|
||
if __name__ == "__main__":
    import uvicorn

    # NOTE: the app string must match this module's filename — change it
    # if the file is not named fastAPI_tarot.py.
    server_options = dict(
        host="127.0.0.1",
        port=55600,
        proxy_headers=True,
        forwarded_allow_ips="*",
        reload=False,  # keep reload off in production for a full, clean load
    )
    uvicorn.run("fastAPI_tarot:app", **server_options)