Files
sam3_local/fastAPI_tarot.py
2026-02-15 22:44:25 +08:00

873 lines
33 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import uuid
import time
import requests
import numpy as np
import cv2
from typing import Optional
from contextlib import asynccontextmanager
import dashscope
from dashscope import MultiModalConversation
import torch
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request, Depends, status
from fastapi.security import APIKeyHeader
from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse
from PIL import Image
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
from sam3.visualization_utils import plot_results
import human_analysis_service # 引入新服务
# ------------------- 配置与路径 -------------------
STATIC_DIR = "static"
RESULT_IMAGE_DIR = os.path.join(STATIC_DIR, "results")
os.makedirs(RESULT_IMAGE_DIR, exist_ok=True)
# ------------------- API Key 核心配置 (已加固) -------------------
VALID_API_KEY = "123quant-speed"
API_KEY_HEADER_NAME = "X-API-Key"
# Dashscope 配置 (Qwen-VL)
dashscope.api_key = 'sk-ce2404f55f744a1987d5ece61c6bac58'
QWEN_MODEL = 'qwen-vl-max'
# 定义 Header 认证
api_key_header = APIKeyHeader(name=API_KEY_HEADER_NAME, auto_error=False)
async def verify_api_key(api_key: Optional[str] = Depends(api_key_header)):
"""
强制验证 API Key
"""
# 1. 检查是否有 Key
if not api_key:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing API Key. Please provide it in the header."
)
# 2. 检查 Key 是否正确
if api_key != VALID_API_KEY:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid API Key."
)
# 3. 验证通过
return True
# ------------------- 生命周期管理 -------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
print("="*40)
print("✅ API Key 保护已激活")
print(f"✅ 有效 Key: {VALID_API_KEY}")
print("="*40)
print("正在加载 SAM3 模型到 GPU...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if not torch.cuda.is_available():
print("警告: 未检测到 GPU将使用 CPU速度会较慢。")
model = build_sam3_image_model()
model = model.to(device)
model.eval()
processor = Sam3Processor(model)
app.state.model = model
app.state.processor = processor
app.state.device = device
print(f"模型加载完成,设备: {device}")
yield
print("正在清理资源...")
# ------------------- FastAPI 初始化 -------------------
app = FastAPI(
lifespan=lifespan,
title="SAM3 Segmentation API",
description="## 🔒 受 API Key 保护\n请点击右上角 **Authorize** 并输入: `123quant-*****`",
)
# 手动添加 OpenAPI 安全配置,让 Docs 里的锁头生效
app.openapi_schema = None
def custom_openapi():
if app.openapi_schema:
return app.openapi_schema
from fastapi.openapi.utils import get_openapi
openapi_schema = get_openapi(
title=app.title,
version=app.version,
description=app.description,
routes=app.routes,
)
# 定义安全方案
openapi_schema["components"]["securitySchemes"] = {
"APIKeyHeader": {
"type": "apiKey",
"in": "header",
"name": API_KEY_HEADER_NAME,
}
}
# 为所有路径应用安全要求
for path in openapi_schema["paths"]:
for method in openapi_schema["paths"][path]:
openapi_schema["paths"][path][method]["security"] = [{"APIKeyHeader": []}]
app.openapi_schema = openapi_schema
return app.openapi_schema
app.openapi = custom_openapi
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
import re
# ------------------- 辅助函数 -------------------
def is_english(text: str) -> bool:
"""
简单判断是否为英文 prompt。
如果有中文字符则认为不是英文。
"""
for char in text:
if '\u4e00' <= char <= '\u9fff':
return False
return True
def translate_to_sam3_prompt(text: str) -> str:
"""
使用大模型将非英文提示词翻译为适合 SAM3 的英文提示词
"""
print(f"正在翻译提示词: {text}")
try:
messages = [
{
"role": "user",
"content": [
{"text": f"请将以下描述翻译成简洁、精准的英文,用于图像分割模型(SAM)的提示词。直接返回英文,不要包含任何解释或其他文字。\n\n输入: {text}"}
]
}
]
response = MultiModalConversation.call(model=QWEN_MODEL, messages=messages)
if response.status_code == 200:
translated_text = response.output.choices[0].message.content[0]['text'].strip()
# 去除可能的 markdown 标记或引号
translated_text = translated_text.replace('"', '').replace("'", "").strip()
print(f"翻译结果: {translated_text}")
return translated_text
else:
print(f"翻译失败: {response.code} - {response.message}")
return text # Fallback to original
except Exception as e:
print(f"翻译异常: {e}")
return text
def order_points(pts):
"""
对四个坐标点进行排序:左上,右上,右下,左下
"""
rect = np.zeros((4, 2), dtype="float32")
s = pts.sum(axis=1)
rect[0] = pts[np.argmin(s)]
rect[2] = pts[np.argmax(s)]
diff = np.diff(pts, axis=1)
rect[1] = pts[np.argmin(diff)]
rect[3] = pts[np.argmax(diff)]
return rect
def four_point_transform(image, pts):
"""
根据四个点进行透视变换
"""
rect = order_points(pts)
(tl, tr, br, bl) = rect
# 计算新图像的宽度
widthA = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2))
widthB = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2))
maxWidth = max(int(widthA), int(widthB))
# 计算新图像的高度
heightA = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2))
heightB = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2))
maxHeight = max(int(heightA), int(heightB))
dst = np.array([
[0, 0],
[maxWidth - 1, 0],
[maxWidth - 1, maxHeight - 1],
[0, maxHeight - 1]], dtype="float32")
M = cv2.getPerspectiveTransform(rect, dst)
warped = cv2.warpPerspective(image, M, (maxWidth, maxHeight))
return warped
def load_image_from_url(url: str) -> Image.Image:
try:
headers = {'User-Agent': 'Mozilla/5.0'}
response = requests.get(url, headers=headers, stream=True, timeout=10)
response.raise_for_status()
image = Image.open(response.raw).convert("RGB")
return image
except Exception as e:
raise HTTPException(status_code=400, detail=f"无法下载图片: {str(e)}")
def crop_and_save_objects(image: Image.Image, masks, boxes, output_dir: str = RESULT_IMAGE_DIR, is_tarot: bool = True, cutout: bool = False) -> list[dict]:
"""
根据 mask 和 box 进行处理并保存独立的对象图片
:param cutout: 如果为 True则进行轮廓抠图透明背景否则进行透视矫正主要用于卡片
"""
saved_objects = []
# Convert image to numpy array (RGB)
img_arr = np.array(image)
for i, (mask, box) in enumerate(zip(masks, boxes)):
# Handle tensor/numpy conversions
if isinstance(mask, torch.Tensor):
mask_np = mask.cpu().numpy().squeeze()
else:
mask_np = mask.squeeze()
# Handle box conversion
if isinstance(box, torch.Tensor):
box_np = box.cpu().numpy()
else:
box_np = np.array(box)
# Ensure mask is uint8 binary for OpenCV/Pillow
if mask_np.dtype == bool:
mask_uint8 = (mask_np * 255).astype(np.uint8)
else:
mask_uint8 = (mask_np > 0.5).astype(np.uint8) * 255
if cutout:
# --- 轮廓抠图模式 (透明背景) ---
# 1. 准备 RGBA 原图
if image.mode != "RGBA":
img_rgba = image.convert("RGBA")
else:
img_rgba = image.copy()
# 2. 准备 Alpha Mask
# mask_uint8 已经是 0-255可以直接作为 Alpha 通道的基础
# 将 Mask 缩放到和图片一样大 (一般是一样的,但以防万一)
if mask_uint8.shape != img_rgba.size[::-1]: # size is (w, h), shape is (h, w)
# 如果尺寸不一致可能需要 resize这里假设尺寸一致
pass
mask_img = Image.fromarray(mask_uint8, mode='L')
# 3. 将 Mask 应用到 Alpha 通道
# 创建一个新的空白透明图
cutout_img = Image.new("RGBA", img_rgba.size, (0, 0, 0, 0))
# 将原图粘贴上去,使用 mask 作为 mask
cutout_img.paste(image.convert("RGB"), (0, 0), mask=mask_img)
# 4. Crop to Box
# box is [x1, y1, x2, y2]
x1, y1, x2, y2 = map(int, box_np)
# 边界检查
w, h = cutout_img.size
x1 = max(0, x1); y1 = max(0, y1)
x2 = min(w, x2); y2 = min(h, y2)
# 避免无效 crop
if x2 > x1 and y2 > y1:
final_img = cutout_img.crop((x1, y1, x2, y2))
else:
final_img = cutout_img # Fallback
# Save
prefix = "cutout"
is_rotated = False # 抠图模式下不进行自动旋转
filename = f"{prefix}_{uuid.uuid4().hex}_{i}.png"
save_path = os.path.join(output_dir, filename)
final_img.save(save_path)
saved_objects.append({
"filename": filename,
"is_rotated_by_algorithm": is_rotated,
"note": "Mask cutout applied. Background removed."
})
else:
# --- 透视矫正模式 (原有逻辑) ---
# Find contours
contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
if not contours:
continue
# Get largest contour
c = max(contours, key=cv2.contourArea)
# Approximate contour to polygon
peri = cv2.arcLength(c, True)
approx = cv2.approxPolyDP(c, 0.04 * peri, True)
# If we have 4 points, use them. If not, fallback to minAreaRect
if len(approx) == 4:
pts = approx.reshape(4, 2)
else:
rect = cv2.minAreaRect(c)
pts = cv2.boxPoints(rect)
# Apply perspective transform
# Check if original image has Alpha
if img_arr.shape[2] == 4:
warped = four_point_transform(img_arr, pts)
else:
warped = four_point_transform(img_arr, pts)
# Check orientation (Portrait vs Landscape)
h, w = warped.shape[:2]
is_rotated = False
# Enforce Portrait for Tarot cards (Standard 7x12 cm ratio approx)
if is_tarot and w > h:
# Rotate 90 degrees clockwise
warped = cv2.rotate(warped, cv2.ROTATE_90_CLOCKWISE)
is_rotated = True
# Convert back to PIL
pil_warped = Image.fromarray(warped)
# Save
prefix = "tarot" if is_tarot else "segment"
filename = f"{prefix}_{uuid.uuid4().hex}_{i}.png"
save_path = os.path.join(output_dir, filename)
pil_warped.save(save_path)
# 正逆位判断逻辑 (基于几何只能做到这一步,无法区分上下颠倒)
# 这里我们假设长边垂直为正位,如果做了旋转则标记
# 真正的正逆位需要OCR或图像识别
saved_objects.append({
"filename": filename,
"is_rotated_by_algorithm": is_rotated,
"note": "Geometric correction applied. True upright/reversed requires content analysis."
})
return saved_objects
def generate_and_save_result(image: Image.Image, inference_state, output_dir: str = RESULT_IMAGE_DIR) -> str:
filename = f"seg_{uuid.uuid4().hex}.jpg"
save_path = os.path.join(output_dir, filename)
plot_results(image, inference_state)
plt.savefig(save_path, dpi=150, bbox_inches='tight')
plt.close()
return filename
def recognize_card_with_qwen(image_path: str) -> dict:
"""
调用 Qwen-VL 识别塔罗牌 (采用正逆位对比策略)
"""
try:
# 确保路径是绝对路径
abs_path = os.path.abspath(image_path)
file_url = f"file://{abs_path}"
# -------------------------------------------------
# 优化策略生成一张旋转180度的对比图
# 让 AI 做选择题而不是判断题,大幅提高准确率
# -------------------------------------------------
try:
# 1. 打开原图
img = Image.open(abs_path)
# 2. 生成旋转图 (180度)
rotated_img = img.rotate(180)
# 3. 保存旋转图
dir_name = os.path.dirname(abs_path)
file_name = os.path.basename(abs_path)
rotated_name = f"rotated_{file_name}"
rotated_path = os.path.join(dir_name, rotated_name)
rotated_img.save(rotated_path)
rotated_file_url = f"file://{rotated_path}"
# 4. 构建对比 Prompt
# 发送两张图图1=原图, 图2=旋转图
# 询问 AI 哪一张是“正位”
messages = [
{
"role": "user",
"content": [
{"image": file_url}, # 图1 (原图)
{"image": rotated_file_url}, # 图2 (旋转180度)
{"text": """这是一张塔罗牌的两个方向:
图1原始方向
图2旋转180度后的方向
请仔细对比两张图片的牌面内容(文字方向、人物站立方向、图案逻辑):
1. 识别这张牌的名字(中文)。
2. 判断哪一张图片展示了正确的“正位”Upright状态。
- 如果图1是正位说明原图就是正位。
- 如果图2是正位说明原图是逆位。
请以JSON格式返回包含 'name''position' 两个字段。
例如:{'name': '愚者', 'position': '正位'} 或 {'name': '倒吊人', 'position': '逆位'}。
不要包含Markdown代码块标记。"""}
]
}
]
except Exception as e:
print(f"对比图生成失败,回退到单图模式: {e}")
# 回退到旧的单图模式
messages = [
{
"role": "user",
"content": [
{"image": file_url},
{"text": "这是一张塔罗牌。请识别它的名字中文并判断它是正位还是逆位。请以JSON格式返回包含 'name''position' 两个字段。例如:{'name': '愚者', 'position': '正位'}。不要包含Markdown代码块标记。"}
]
}
]
# 调用模型
response = MultiModalConversation.call(model=QWEN_MODEL, messages=messages)
if response.status_code == 200:
content = response.output.choices[0].message.content[0]['text']
# 尝试解析简单的 JSON
import json
try:
# 清理可能存在的 markdown 标记
clean_content = content.replace("```json", "").replace("```", "").strip()
result = json.loads(clean_content)
return result
except:
return {"raw_response": content}
else:
return {"error": f"API Error: {response.code} - {response.message}"}
except Exception as e:
return {"error": f"识别失败: {str(e)}"}
def recognize_spread_with_qwen(image_path: str) -> dict:
"""
调用 Qwen-VL 识别塔罗牌牌阵
"""
try:
# 确保路径是绝对路径并加上 file:// 前缀
abs_path = os.path.abspath(image_path)
file_url = f"file://{abs_path}"
messages = [
{
"role": "user",
"content": [
{"image": file_url},
{"text": "这是一张包含多张塔罗牌的图片。请根据牌的排列方式识别这是什么牌阵例如圣三角、凯尔特十字、三张牌等。如果看不出明显的正规牌阵请返回“不是正规牌阵”。请以JSON格式返回包含 'spread_name''description' 两个字段。例如:{'spread_name': '圣三角', 'description': '常见的时间流占卜法'}。不要包含Markdown代码块标记。"}
]
}
]
response = MultiModalConversation.call(model=QWEN_MODEL, messages=messages)
if response.status_code == 200:
content = response.output.choices[0].message.content[0]['text']
# 尝试解析简单的 JSON
import json
try:
# 清理可能存在的 markdown 标记
clean_content = content.replace("```json", "").replace("```", "").strip()
result = json.loads(clean_content)
return result
except:
return {"raw_response": content, "spread_name": "Unknown"}
else:
return {"error": f"API Error: {response.code} - {response.message}"}
except Exception as e:
return {"error": f"牌阵识别失败: {str(e)}"}
# ------------------- API 接口 (强制依赖验证) -------------------
@app.post("/segment", dependencies=[Depends(verify_api_key)])
async def segment(
request: Request,
prompt: str = Form(...),
file: Optional[UploadFile] = File(None),
image_url: Optional[str] = Form(None),
save_segment_images: bool = Form(False),
cutout: bool = Form(False)
):
if not file and not image_url:
raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")
# ------------------- Prompt 翻译/优化 -------------------
final_prompt = prompt
if not is_english(prompt):
# 使用 Qwen 进行翻译
try:
# 构建消息请求翻译
# 注意translate_to_sam3_prompt 是一个阻塞调用,可能会增加耗时
# 但对于分割任务来说是可以接受的
translated = translate_to_sam3_prompt(prompt)
if translated:
final_prompt = translated
except Exception as e:
print(f"Prompt翻译失败使用原始Prompt: {e}")
print(f"最终使用的 Prompt: {final_prompt}")
try:
if file:
image = Image.open(file.file).convert("RGB")
elif image_url:
image = load_image_from_url(image_url)
except Exception as e:
raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")
processor = request.app.state.processor
try:
inference_state = processor.set_image(image)
output = processor.set_text_prompt(state=inference_state, prompt=final_prompt)
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
except Exception as e:
raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")
try:
filename = generate_and_save_result(image, inference_state)
except Exception as e:
raise HTTPException(status_code=500, detail=f"绘图保存错误: {str(e)}")
file_url = request.url_for("static", path=f"results/{filename}")
# New logic for saving segments
saved_segments_info = []
if save_segment_images:
try:
request_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
output_dir = os.path.join(RESULT_IMAGE_DIR, request_id)
os.makedirs(output_dir, exist_ok=True)
# 传递 cutout 参数
saved_objects = crop_and_save_objects(
image,
masks,
boxes,
output_dir=output_dir,
is_tarot=False,
cutout=cutout
)
for obj in saved_objects:
fname = obj["filename"]
seg_url = str(request.url_for("static", path=f"results/{request_id}/{fname}"))
saved_segments_info.append({
"url": seg_url,
"filename": fname,
"note": obj.get("note", "")
})
except Exception as e:
print(f"Error saving segments: {e}")
raise HTTPException(status_code=500, detail=f"保存分割图片失败: {str(e)}")
response_content = {
"status": "success",
"result_image_url": str(file_url),
"detected_count": len(masks),
"scores": scores.tolist() if torch.is_tensor(scores) else scores
}
if save_segment_images:
response_content["segmented_images"] = saved_segments_info
return JSONResponse(content=response_content)
@app.post("/segment_tarot", dependencies=[Depends(verify_api_key)])
async def segment_tarot(
request: Request,
file: Optional[UploadFile] = File(None),
image_url: Optional[str] = Form(None),
expected_count: int = Form(3)
):
"""
塔罗牌分割专用接口
1. 检测是否包含指定数量的塔罗牌 (默认为 3)
2. 如果是,分别抠出这些牌并返回
"""
if not file and not image_url:
raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")
try:
if file:
image = Image.open(file.file).convert("RGB")
elif image_url:
image = load_image_from_url(image_url)
except Exception as e:
raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")
processor = request.app.state.processor
try:
inference_state = processor.set_image(image)
# 固定 Prompt 检测塔罗牌
output = processor.set_text_prompt(state=inference_state, prompt="tarot card")
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
except Exception as e:
raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")
# 核心逻辑:判断数量
detected_count = len(masks)
# 创建本次请求的独立文件夹 (时间戳_UUID前8位)
request_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
output_dir = os.path.join(RESULT_IMAGE_DIR, request_id)
os.makedirs(output_dir, exist_ok=True)
if detected_count != expected_count:
# 保存一张图用于调试/反馈
try:
filename = generate_and_save_result(image, inference_state, output_dir=output_dir)
file_url = request.url_for("static", path=f"results/{request_id}/{filename}")
except:
file_url = None
return JSONResponse(
status_code=400,
content={
"status": "failed",
"message": f"检测到 {detected_count} 个目标,需要严格的 {expected_count} 张塔罗牌。请调整拍摄角度或背景。",
"detected_count": detected_count,
"debug_image_url": str(file_url) if file_url else None
}
)
# 数量正确,执行抠图
try:
saved_objects = crop_and_save_objects(image, masks, boxes, output_dir=output_dir)
except Exception as e:
raise HTTPException(status_code=500, detail=f"抠图处理错误: {str(e)}")
# 生成 URL 列表和元数据
tarot_cards = []
for obj in saved_objects:
fname = obj["filename"]
file_url = str(request.url_for("static", path=f"results/{request_id}/{fname}"))
tarot_cards.append({
"url": file_url,
"is_rotated": obj["is_rotated_by_algorithm"],
"orientation_status": "corrected_to_portrait" if obj["is_rotated_by_algorithm"] else "original_portrait",
"note": obj["note"]
})
# 生成整体效果图
try:
main_filename = generate_and_save_result(image, inference_state, output_dir=output_dir)
main_file_url = str(request.url_for("static", path=f"results/{request_id}/{main_filename}"))
except:
main_file_url = None
return JSONResponse(content={
"status": "success",
"message": f"成功识别并分割 {expected_count} 张塔罗牌 (已执行透视矫正)",
"tarot_cards": tarot_cards,
"full_visualization": main_file_url,
"scores": scores.tolist() if torch.is_tensor(scores) else scores
})
@app.post("/recognize_tarot", dependencies=[Depends(verify_api_key)])
async def recognize_tarot(
request: Request,
file: Optional[UploadFile] = File(None),
image_url: Optional[str] = Form(None),
expected_count: int = Form(3)
):
"""
塔罗牌全流程接口: 分割 + 矫正 + 识别
1. 检测是否包含指定数量的塔罗牌 (SAM3)
2. 分割并透视矫正
3. 调用 Qwen-VL 识别每张牌的名称和正逆位
"""
if not file and not image_url:
raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")
try:
if file:
image = Image.open(file.file).convert("RGB")
elif image_url:
image = load_image_from_url(image_url)
except Exception as e:
raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")
processor = request.app.state.processor
try:
inference_state = processor.set_image(image)
# 固定 Prompt 检测塔罗牌
output = processor.set_text_prompt(state=inference_state, prompt="tarot card")
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
except Exception as e:
raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")
# 核心逻辑:判断数量
detected_count = len(masks)
# 创建本次请求的独立文件夹
request_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
output_dir = os.path.join(RESULT_IMAGE_DIR, request_id)
os.makedirs(output_dir, exist_ok=True)
# 保存整体效果图 (无论是成功还是失败,都先保存一张主图)
try:
main_filename = generate_and_save_result(image, inference_state, output_dir=output_dir)
main_file_path = os.path.join(output_dir, main_filename)
main_file_url = str(request.url_for("static", path=f"results/{request_id}/{main_filename}"))
except:
main_filename = None
main_file_path = None
main_file_url = None
# Step 0: 牌阵识别 (在判断数量之前或之后都可以,这里放在前面作为全局判断)
spread_info = {"spread_name": "Unknown"}
if main_file_path:
# 使用带有mask绘制的主图或者原始图
# 使用原始图可能更好不受mask遮挡干扰但是main_filename是带mask的。
# 我们这里暂时用原始图保存一份临时文件给Qwen看
temp_raw_path = os.path.join(output_dir, "raw_for_spread.jpg")
image.save(temp_raw_path)
spread_info = recognize_spread_with_qwen(temp_raw_path)
# 如果识别结果明确说是“不是正规牌阵”,是否要继续?
# 用户需求:“如果没有正确的牌阵则返回‘不是正规牌阵’”
# 我们将其放在返回结果中,由客户端决定是否展示警告
if detected_count != expected_count:
return JSONResponse(
status_code=400,
content={
"status": "failed",
"message": f"检测到 {detected_count} 个目标,需要严格的 {expected_count} 张塔罗牌。请调整拍摄角度或背景。",
"detected_count": detected_count,
"spread_info": spread_info,
"debug_image_url": str(main_file_url) if main_file_url else None
}
)
# 数量正确,执行抠图 + 矫正
try:
saved_objects = crop_and_save_objects(image, masks, boxes, output_dir=output_dir)
except Exception as e:
raise HTTPException(status_code=500, detail=f"抠图处理错误: {str(e)}")
# 遍历每张卡片进行识别
tarot_cards = []
for obj in saved_objects:
fname = obj["filename"]
file_path = os.path.join(output_dir, fname)
# 调用 Qwen-VL 识别
# 注意:这里会串行调用,速度可能较慢。
recognition_res = recognize_card_with_qwen(file_path)
file_url = str(request.url_for("static", path=f"results/{request_id}/{fname}"))
tarot_cards.append({
"url": file_url,
"is_rotated": obj["is_rotated_by_algorithm"],
"orientation_status": "corrected_to_portrait" if obj["is_rotated_by_algorithm"] else "original_portrait",
"recognition": recognition_res,
"note": obj["note"]
})
return JSONResponse(content={
"status": "success",
"message": f"成功识别并分割 {expected_count} 张塔罗牌 (含Qwen识别结果)",
"spread_info": spread_info,
"tarot_cards": tarot_cards,
"full_visualization": main_file_url,
"scores": scores.tolist() if torch.is_tensor(scores) else scores
})
@app.post("/segment_face", dependencies=[Depends(verify_api_key)])
async def segment_face(
request: Request,
file: Optional[UploadFile] = File(None),
image_url: Optional[str] = Form(None),
prompt: str = Form("face and hair") # 默认提示词包含头发
):
"""
人脸/头部检测与属性分析接口 (新功能)
1. 调用 SAM3 分割出头部区域 (含头发)
2. 裁剪并保存
3. 调用 Qwen-VL 识别性别和年龄
"""
if not file and not image_url:
raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")
# ------------------- Prompt 翻译/优化 -------------------
final_prompt = prompt
if not is_english(prompt):
try:
translated = translate_to_sam3_prompt(prompt)
if translated:
final_prompt = translated
except Exception as e:
print(f"Prompt翻译失败使用原始Prompt: {e}")
print(f"Face Segment 最终使用的 Prompt: {final_prompt}")
# 1. 加载图片
try:
if file:
image = Image.open(file.file).convert("RGB")
elif image_url:
image = load_image_from_url(image_url)
except Exception as e:
raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")
processor = request.app.state.processor
# 2. 调用独立服务进行处理
try:
# 传入 processor 和 image
# 注意Result Image Dir 我们直接复用 RESULT_IMAGE_DIR
result = human_analysis_service.process_face_segmentation_and_analysis(
processor=processor,
image=image,
prompt=final_prompt,
output_base_dir=RESULT_IMAGE_DIR
)
except Exception as e:
# 打印详细错误堆栈以便调试
import traceback
traceback.print_exc()
raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}")
# 3. 补全 URL (因为 service 层不知道 request context)
if result["status"] == "success":
# 处理全图可视化的 URL
if result.get("full_visualization"):
full_vis_rel_path = result["full_visualization"]
result["full_visualization"] = str(request.url_for("static", path=full_vis_rel_path))
for item in result["results"]:
# item["relative_path"] 是相对路径,如 results/xxx/xxx.jpg
# 我们需要将其转换为完整 URL
relative_path = item.pop("relative_path") # 移除相对路径字段,只返回 URL
item["url"] = str(request.url_for("static", path=relative_path))
return JSONResponse(content=result)
if __name__ == "__main__":
import uvicorn
# 注意:如果你的文件名不是 fastAPI_tarot.py请修改下面第一个参数
uvicorn.run(
"fastAPI_tarot:app",
host="127.0.0.1",
port=55600,
proxy_headers=True,
forwarded_allow_ips="*",
reload=False # 生产环境建议关闭 reload确保代码完全重载
)