Files
sam3_local/fastAPI_tarot.py
2026-02-15 22:59:12 +08:00

936 lines
34 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
SAM3 Segmentation API Service (SAM3 图像分割 API 服务)
--------------------------------------------------
本服务提供基于 SAM3 模型的图像分割能力。
包含通用分割、塔罗牌分析和人脸分析等专用接口。
This service provides image segmentation capabilities using the SAM3 model.
It includes specialized endpoints for Tarot card analysis and Face analysis.
"""
# ==========================================
# 1. Imports (导入)
# ==========================================
# Standard Library Imports (标准库)
import os
import uuid
import time
import json
import traceback
import re
from typing import Optional, List, Dict, Any
from contextlib import asynccontextmanager
# Third-Party Imports (第三方库)
import cv2
import torch
import numpy as np
import requests
import matplotlib
matplotlib.use('Agg') # 使用非交互式后端,防止在服务器上报错
import matplotlib.pyplot as plt
from PIL import Image
# FastAPI Imports
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request, Depends, status, APIRouter
from fastapi.security import APIKeyHeader
from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse
# Dashscope (Aliyun Qwen) Imports
import dashscope
from dashscope import MultiModalConversation
# Local Imports (本地模块)
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
from sam3.visualization_utils import plot_results
import human_analysis_service # 引入新服务: 人脸分析
# ==========================================
# 2. Configuration & Constants (配置与常量)
# ==========================================
# 路径配置
# Filesystem layout: everything under STATIC_DIR is served over HTTP, and
# generated result images are written below RESULT_IMAGE_DIR.
STATIC_DIR = "static"
RESULT_IMAGE_DIR = os.path.join(STATIC_DIR, "results")
os.makedirs(RESULT_IMAGE_DIR, exist_ok=True)

# API key configuration.
# SECURITY: credentials should not live in source control. Each secret below is
# now read from the environment first; the literal that used to be hard-coded is
# kept only as a fallback so existing deployments keep working unchanged.
# `or` (rather than a .get() default) also covers a variable set to "".
VALID_API_KEY = os.environ.get("SAM3_API_KEY") or "123quant-speed"
API_KEY_HEADER_NAME = "X-API-Key"

# Dashscope (Qwen-VL) configuration.
dashscope.api_key = os.environ.get("DASHSCOPE_API_KEY") or 'sk-ce2404f55f744a1987d5ece61c6bac58'
QWEN_MODEL = 'qwen-vl-max'

# API tags (used to group endpoints in the generated documentation).
TAG_GENERAL = "General Segmentation (通用分割)"
TAG_TAROT = "Tarot Analysis (塔罗牌分析)"
TAG_FACE = "Face Analysis (人脸分析)"
# ==========================================
# 3. Security & Middleware (安全与中间件)
# ==========================================
# 定义 Header 认证 Scheme
# Header-based auth scheme. auto_error=False lets verify_api_key() produce its
# own 401/403 responses instead of FastAPI's default error.
api_key_header = APIKeyHeader(name=API_KEY_HEADER_NAME, auto_error=False)


async def verify_api_key(api_key: Optional[str] = Depends(api_key_header)):
    """
    Enforce API-key authentication for an endpoint.

    Args:
        api_key: Value of the API-key header, or None/empty when the header
            is absent.

    Returns:
        True when the supplied key matches VALID_API_KEY.

    Raises:
        HTTPException: 401 when no key was supplied, 403 when the key does not
            match.
    """
    # Function-scope import keeps this block self-contained.
    import hmac

    if not api_key:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing API Key. Please provide it in the header."
        )
    # Constant-time comparison so response timing does not reveal how many
    # leading characters of a guess were correct. Both sides are encoded to
    # bytes because compare_digest() rejects non-ASCII str arguments.
    supplied = api_key.encode("utf-8")
    expected = VALID_API_KEY.encode("utf-8")
    if not hmac.compare_digest(supplied, expected):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid API Key."
        )
    return True
# ==========================================
# 4. Lifespan Management (生命周期管理)
# ==========================================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    FastAPI lifespan handler.

    On startup: builds the SAM3 image model, moves it to CUDA when available
    (CPU otherwise), switches it to eval mode, wraps it in a Sam3Processor and
    stores model, processor and device on ``app.state``.
    On shutdown: only logs a message; no resources are released explicitly.
    """
    print("="*40)
    print("✅ API Key 保护已激活")
    # Never write the full secret to logs: show a short prefix only.
    if len(VALID_API_KEY) > 8:
        masked_key = VALID_API_KEY[:4] + "****"
    else:
        masked_key = "****"
    print(f"✅ 有效 Key: {masked_key}")
    print("="*40)
    print("正在加载 SAM3 模型到 GPU...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if not torch.cuda.is_available():
        print("警告: 未检测到 GPU将使用 CPU速度会较慢。")
    # Load the model and put it in inference mode.
    model = build_sam3_image_model()
    model = model.to(device)
    model.eval()
    # Build the processor that the endpoints use for inference.
    processor = Sam3Processor(model)
    # Stash on app.state so request handlers can reach them.
    app.state.model = model
    app.state.processor = processor
    app.state.device = device
    print(f"模型加载完成,设备: {device}")
    yield
    print("正在清理资源...")
    # GPU memory release logic could be added here if needed.
# ==========================================
# 5. Helper Functions (辅助函数)
# ==========================================
def is_english(text: str) -> bool:
    """
    Report whether ``text`` is free of Chinese characters.

    Despite the name, the only thing checked is the CJK Unified Ideographs
    range U+4E00..U+9FFF: any character inside that range yields False,
    everything else (including the empty string) yields True.
    """
    return not any('\u4e00' <= ch <= '\u9fff' for ch in text)
def translate_to_sam3_prompt(text: str) -> str:
    """
    Translate a prompt to English through the Qwen model.

    Sends ``text`` to MultiModalConversation with an instruction to return a
    concise English phrase, then strips whitespace and removes all single and
    double quote characters from the reply.

    Returns:
        The cleaned translation on a 200 response; the original ``text`` when
        the API reports a non-200 status or any exception occurs.
    """
    print(f"正在翻译提示词: {text}")
    try:
        instruction = f"请将以下描述翻译成简洁、精准的英文,用于图像分割模型(SAM)的提示词。直接返回英文,不要包含任何解释或其他文字。\n\n输入: {text}"
        request_messages = [{"role": "user", "content": [{"text": instruction}]}]
        reply = MultiModalConversation.call(model=QWEN_MODEL, messages=request_messages)

        # Guard clause: on failure fall back to the untouched input.
        if reply.status_code != 200:
            print(f"翻译失败: {reply.code} - {reply.message}")
            return text

        raw_answer = reply.output.choices[0].message.content[0]['text'].strip()
        # Drop quote characters the model may wrap around its answer.
        cleaned = raw_answer.replace('"', '').replace("'", "").strip()
        print(f"翻译结果: {cleaned}")
        return cleaned
    except Exception as e:
        print(f"翻译异常: {e}")
        return text
def order_points(pts):
    """
    Arrange corner points as top-left, top-right, bottom-right, bottom-left.

    The top-left / bottom-right corners are the points with the smallest /
    largest x+y; the top-right / bottom-left corners are the points with the
    smallest / largest y-x. The result is a (4, 2) float32 array, ready to be
    used as the source quadrilateral of a perspective transform.
    """
    coord_sum = pts.sum(axis=1)
    coord_diff = np.diff(pts, axis=1)  # y - x for every point

    corner_rows = [
        np.argmin(coord_sum),   # top-left
        np.argmin(coord_diff),  # top-right
        np.argmax(coord_sum),   # bottom-right
        np.argmax(coord_diff),  # bottom-left
    ]
    return pts[corner_rows].astype("float32")
def four_point_transform(image, pts):
    """
    Warp the quadrilateral described by ``pts`` into an upright rectangle.

    The four points are ordered with order_points(); the output width/height
    are the longer of the two opposing edge lengths (truncated to int), and
    the image is resampled with cv2.warpPerspective. Used to straighten a
    tilted card.
    """
    def _edge_length(p, q):
        # Euclidean distance between two (x, y) points.
        return np.sqrt(((p[0] - q[0]) ** 2) + ((p[1] - q[1]) ** 2))

    corners = order_points(pts)
    top_left, top_right, bottom_right, bottom_left = corners

    # Output width: longer of the bottom and top edges.
    out_width = max(
        int(_edge_length(bottom_right, bottom_left)),
        int(_edge_length(top_right, top_left)),
    )
    # Output height: longer of the right and left edges.
    out_height = max(
        int(_edge_length(top_right, bottom_right)),
        int(_edge_length(top_left, bottom_left)),
    )

    target = np.array(
        [
            [0, 0],
            [out_width - 1, 0],
            [out_width - 1, out_height - 1],
            [0, out_height - 1],
        ],
        dtype="float32",
    )
    matrix = cv2.getPerspectiveTransform(corners, target)
    return cv2.warpPerspective(image, matrix, (out_width, out_height))
def load_image_from_url(url: str) -> Image.Image:
    """
    Download an image over HTTP and return it as an RGB PIL image.

    Args:
        url: Address of the image to fetch.

    Returns:
        The decoded image converted to RGB.

    Raises:
        HTTPException: 400 when the download or the decoding fails for any
            reason (network error, non-2xx status, unreadable image data).
    """
    # Function-scope import keeps this block self-contained.
    import io

    # NOTE(review): `url` is caller-supplied and fetched without any host
    # allow-list, so this can reach internal addresses (SSRF) - consider
    # restricting the allowed targets.
    try:
        headers = {'User-Agent': 'Mozilla/5.0'}
        # Read the fully decoded body via `.content` rather than handing
        # `response.raw` to PIL: the raw stream is not content-decoded
        # (gzip/deflate) and was previously never closed. The context manager
        # releases the connection in every case.
        with requests.get(url, headers=headers, timeout=10) as response:
            response.raise_for_status()
            payload = response.content
        return Image.open(io.BytesIO(payload)).convert("RGB")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"无法下载图片: {str(e)}") from e
def crop_and_save_objects(image: Image.Image, masks, boxes, output_dir: str = RESULT_IMAGE_DIR, is_tarot: bool = True, cutout: bool = False) -> list[dict]:
    """
    Save one PNG per detected object, built from its mask and box.

    Args:
        image: Source picture.
        masks: Segmentation masks, one per object (torch tensors or arrays).
        boxes: Bounding boxes, one per object, unpacked as (x1, y1, x2, y2).
        output_dir: Directory the PNG files are written to.
        is_tarot: In perspective-correction mode, selects the "tarot" filename
            prefix (instead of "segment") and rotates landscape results to
            portrait.
        cutout: When True, paste the masked pixels onto a transparent canvas
            and crop to the box; when False, perspective-correct the mask's
            largest contour into a rectangle.

    Returns:
        A list of dicts with "filename", "is_rotated_by_algorithm" and "note"
        for every object that was written. In perspective-correction mode,
        masks without any contour are skipped.
    """
    saved_objects = []
    # RGB pixel array consumed by the perspective-correction branch.
    img_arr = np.array(image)

    # Loop-invariant inputs of the cutout branch: prepared once here instead
    # of re-converting the whole picture for every mask.
    if cutout:
        rgb_source = image.convert("RGB")
        canvas_size = image.size

    for i, (mask, box) in enumerate(zip(masks, boxes)):
        # Normalise the mask to a numpy array without singleton dimensions.
        if isinstance(mask, torch.Tensor):
            mask_np = mask.cpu().numpy().squeeze()
        else:
            mask_np = mask.squeeze()

        # Normalise the box to a numpy array.
        if isinstance(box, torch.Tensor):
            box_np = box.cpu().numpy()
        else:
            box_np = np.array(box)

        # OpenCV / Pillow expect a 0/255 uint8 mask.
        if mask_np.dtype == bool:
            mask_uint8 = (mask_np * 255).astype(np.uint8)
        else:
            mask_uint8 = (mask_np > 0.5).astype(np.uint8) * 255

        if cutout:
            # --- Cutout mode (transparent background) ---
            mask_img = Image.fromarray(mask_uint8, mode='L')
            # Start fully transparent; only masked pixels are pasted in.
            cutout_img = Image.new("RGBA", canvas_size, (0, 0, 0, 0))
            cutout_img.paste(rgb_source, (0, 0), mask=mask_img)

            # Clamp the box to the canvas before cropping.
            x1, y1, x2, y2 = map(int, box_np)
            w, h = canvas_size
            x1 = max(0, x1)
            y1 = max(0, y1)
            x2 = min(w, x2)
            y2 = min(h, y2)
            if x2 > x1 and y2 > y1:
                final_img = cutout_img.crop((x1, y1, x2, y2))
            else:
                # Degenerate box: fall back to the uncropped canvas.
                final_img = cutout_img

            filename = f"cutout_{uuid.uuid4().hex}_{i}.png"
            final_img.save(os.path.join(output_dir, filename))
            saved_objects.append({
                "filename": filename,
                "is_rotated_by_algorithm": False,
                "note": "Mask cutout applied. Background removed."
            })
        else:
            # --- Perspective-correction mode (rectify to a rectangle) ---
            contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            if not contours:
                continue
            c = max(contours, key=cv2.contourArea)
            peri = cv2.arcLength(c, True)
            approx = cv2.approxPolyDP(c, 0.04 * peri, True)
            if len(approx) == 4:
                pts = approx.reshape(4, 2)
            else:
                # Not a clean quadrilateral: use the minimum-area bounding
                # rectangle's corners instead.
                pts = cv2.boxPoints(cv2.minAreaRect(c))
            warped = four_point_transform(img_arr, pts)

            # Force portrait orientation for tarot cards.
            h, w = warped.shape[:2]
            is_rotated = False
            if is_tarot and w > h:
                warped = cv2.rotate(warped, cv2.ROTATE_90_CLOCKWISE)
                is_rotated = True

            prefix = "tarot" if is_tarot else "segment"
            filename = f"{prefix}_{uuid.uuid4().hex}_{i}.png"
            Image.fromarray(warped).save(os.path.join(output_dir, filename))
            saved_objects.append({
                "filename": filename,
                "is_rotated_by_algorithm": is_rotated,
                "note": "Geometric correction applied."
            })
    return saved_objects
def generate_and_save_result(image: Image.Image, inference_state, output_dir: str = RESULT_IMAGE_DIR) -> str:
    """
    Render the mask overlay for ``inference_state`` and save it as a JPEG.

    Args:
        image: Picture the inference was run on.
        inference_state: State object handed to plot_results().
        output_dir: Directory the JPEG is written to.

    Returns:
        The generated file name (``seg_<uuid>.jpg``), without directory.
    """
    filename = f"seg_{uuid.uuid4().hex}.jpg"
    save_path = os.path.join(output_dir, filename)
    try:
        plot_results(image, inference_state)
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    finally:
        # Always close the current figure: without this, a failure while
        # plotting or saving would leave the figure open.
        plt.close()
    return filename
def recognize_card_with_qwen(image_path: str) -> dict:
    """
    Ask Qwen-VL to identify a tarot card and whether it is upright or reversed.

    The card image and a copy rotated by 180 degrees (saved next to it as
    ``rotated_<name>``) are sent together so the model can compare both
    orientations. If the rotated copy cannot be produced, a single-image
    prompt is used instead.

    Args:
        image_path: Path of the card image on disk.

    Returns:
        The parsed JSON object from the model on success;
        ``{"raw_response": ...}`` when the reply is not valid JSON;
        ``{"error": ...}`` on a non-200 API status or any other failure.
    """
    try:
        abs_path = os.path.abspath(image_path)
        file_url = f"file://{abs_path}"
        try:
            # 1-2. Open the original (closing the file handle afterwards) and
            #      build the 180-degree rotated copy.
            with Image.open(abs_path) as img:
                rotated_img = img.rotate(180)
            # 3. Save the rotated copy next to the original.
            dir_name = os.path.dirname(abs_path)
            file_name = os.path.basename(abs_path)
            rotated_name = f"rotated_{file_name}"
            rotated_path = os.path.join(dir_name, rotated_name)
            rotated_img.save(rotated_path)
            rotated_file_url = f"file://{rotated_path}"
            # 4. Build the comparison prompt.
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"image": file_url},  # image 1 (original)
                        {"image": rotated_file_url},  # image 2 (rotated 180 degrees)
                        {"text": """这是一张塔罗牌的两个方向:
图1原始方向
图2旋转180度后的方向
请仔细对比两张图片的牌面内容(文字方向、人物站立方向、图案逻辑):
1. 识别这张牌的名字(中文)。
2. 判断哪一张图片展示了正确的“正位”Upright状态。
- 如果图1是正位说明原图就是正位。
- 如果图2是正位说明原图是逆位。
请以JSON格式返回包含 'name''position' 两个字段。
例如:{'name': '愚者', 'position': '正位'} 或 {'name': '倒吊人', 'position': '逆位'}。
不要包含Markdown代码块标记。"""}
                    ]
                }
            ]
        except Exception as e:
            # Best effort: fall back to a single-image prompt.
            print(f"对比图生成失败,回退到单图模式: {e}")
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"image": file_url},
                        {"text": "这是一张塔罗牌。请识别它的名字中文并判断它是正位还是逆位。请以JSON格式返回包含 'name''position' 两个字段。例如:{'name': '愚者', 'position': '正位'}。不要包含Markdown代码块标记。"}
                    ]
                }
            ]
        response = MultiModalConversation.call(model=QWEN_MODEL, messages=messages)
        if response.status_code == 200:
            content = response.output.choices[0].message.content[0]['text']
            try:
                # Strip Markdown code fences the model may add anyway.
                clean_content = content.replace("```json", "").replace("```", "").strip()
                return json.loads(clean_content)
            except (ValueError, TypeError):
                # Only parse failures are expected here (JSONDecodeError is a
                # ValueError); a bare `except:` would also have swallowed
                # KeyboardInterrupt / SystemExit.
                return {"raw_response": content}
        else:
            return {"error": f"API Error: {response.code} - {response.message}"}
    except Exception as e:
        return {"error": f"识别失败: {str(e)}"}
def recognize_spread_with_qwen(image_path: str) -> dict:
    """
    Ask Qwen-VL which tarot spread a multi-card picture shows.

    Args:
        image_path: Path of the picture on disk.

    Returns:
        The parsed JSON object from the model on success;
        ``{"raw_response": ..., "spread_name": "Unknown"}`` when the reply is
        not valid JSON; ``{"error": ...}`` on a non-200 API status or any
        other failure.
    """
    try:
        abs_path = os.path.abspath(image_path)
        file_url = f"file://{abs_path}"
        messages = [
            {
                "role": "user",
                "content": [
                    {"image": file_url},
                    {"text": "这是一张包含多张塔罗牌的图片。请根据牌的排列方式识别这是什么牌阵例如圣三角、凯尔特十字、三张牌等。如果看不出明显的正规牌阵请返回“不是正规牌阵”。请以JSON格式返回包含 'spread_name''description' 两个字段。例如:{'spread_name': '圣三角', 'description': '常见的时间流占卜法'}。不要包含Markdown代码块标记。"}
                ]
            }
        ]
        response = MultiModalConversation.call(model=QWEN_MODEL, messages=messages)
        if response.status_code == 200:
            content = response.output.choices[0].message.content[0]['text']
            try:
                # Strip Markdown code fences the model may add anyway.
                clean_content = content.replace("```json", "").replace("```", "").strip()
                return json.loads(clean_content)
            except (ValueError, TypeError):
                # Only parse failures are expected here (JSONDecodeError is a
                # ValueError); a bare `except:` would also have swallowed
                # KeyboardInterrupt / SystemExit.
                return {"raw_response": content, "spread_name": "Unknown"}
        else:
            return {"error": f"API Error: {response.code} - {response.message}"}
    except Exception as e:
        return {"error": f"牌阵识别失败: {str(e)}"}
# ==========================================
# 6. FastAPI App Setup (应用初始化)
# ==========================================
# Application object: wires in the lifespan handler and declares one
# documentation tag (name + description) per endpoint group.
app = FastAPI(
    title="SAM3 Segmentation API",
    description="## 🔒 受 API Key 保护\n请点击右上角 **Authorize** 并输入: `123quant-*****`",
    lifespan=lifespan,
    openapi_tags=[
        {"name": tag_name, "description": tag_description}
        for tag_name, tag_description in (
            (TAG_GENERAL, "General purpose image segmentation"),
            (TAG_TAROT, "Specialized endpoints for Tarot card recognition"),
            (TAG_FACE, "Face detection and analysis"),
        )
    ],
)
# 手动添加 OpenAPI 安全配置
app.openapi_schema = None


def custom_openapi():
    """
    Build (once) and return the OpenAPI schema with API-key security attached.

    The schema is generated from the app's routes, an ``APIKeyHeader``
    security scheme is registered, and every operation is marked as requiring
    it. The result is cached on ``app.openapi_schema`` and reused afterwards.
    """
    if app.openapi_schema:
        return app.openapi_schema
    from fastapi.openapi.utils import get_openapi
    openapi_schema = get_openapi(
        title=app.title,
        version=app.version,
        description=app.description,
        routes=app.routes,
        # Forward the tag metadata declared on the app; without this the tag
        # descriptions are missing from the generated schema.
        tags=app.openapi_tags,
    )
    # "components" is created on demand instead of being assumed present.
    openapi_schema.setdefault("components", {})["securitySchemes"] = {
        "APIKeyHeader": {
            "type": "apiKey",
            "in": "header",
            "name": API_KEY_HEADER_NAME,
        }
    }
    # A path item may carry non-operation keys (e.g. "parameters"), so only
    # HTTP-method entries are given the security requirement.
    http_methods = {"get", "put", "post", "delete", "options", "head", "patch", "trace"}
    for path_item in openapi_schema.get("paths", {}).values():
        for method, operation in path_item.items():
            if method in http_methods:
                operation["security"] = [{"APIKeyHeader": []}]
    app.openapi_schema = openapi_schema
    return app.openapi_schema


app.openapi = custom_openapi
# Serve the contents of STATIC_DIR (which contains the generated result
# images) under the /static URL prefix. The route name "static" is what the
# endpoints pass to request.url_for() when building result URLs.
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
# ==========================================
# 7. API Endpoints (接口定义)
# ==========================================
# ------------------------------------------
# Group 1: General Segmentation (通用分割)
# ------------------------------------------
@app.post("/segment", tags=[TAG_GENERAL], dependencies=[Depends(verify_api_key)])
async def segment(
    request: Request,
    prompt: str = Form(..., description="Text prompt for segmentation (e.g., 'cat', 'person')"),
    file: Optional[UploadFile] = File(None, description="Image file to upload"),
    image_url: Optional[str] = Form(None, description="URL of the image"),
    save_segment_images: bool = Form(False, description="Whether to save and return individual segmented objects"),
    cutout: bool = Form(False, description="If True, returns transparent background PNGs; otherwise returns original crops"),
    confidence: float = Form(0.7, description="Confidence threshold (0.0-1.0). Default is 0.7.")
):
    """
    **General image segmentation endpoint**
    - Accepts an uploaded image or an image URL
    - Prompts containing Chinese characters are translated to English first
    - The confidence threshold can be set per request

    Responds with the URL of the overlay image, the number of detections and
    their scores; when `save_segment_images` is set, also with one image per
    detected object. Errors: 400 for missing input, an out-of-range
    confidence or an unreadable image; 500 for inference, rendering or saving
    failures.
    """
    if not file and not image_url:
        raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")
    # Validate the threshold up front. This check used to sit inside the
    # inference try-block, whose `except Exception` turned the intended 400
    # into a 500.
    if confidence is not None and not (0.0 <= confidence <= 1.0):
        raise HTTPException(status_code=400, detail="Confidence must be between 0.0 and 1.0")

    # 1. Prompt handling: translate when the prompt contains Chinese.
    final_prompt = prompt
    if not is_english(prompt):
        try:
            translated = translate_to_sam3_prompt(prompt)
            if translated:
                final_prompt = translated
        except Exception as e:
            print(f"Prompt翻译失败使用原始Prompt: {e}")
    print(f"最终使用的 Prompt: {final_prompt}")

    # 2. Image loading.
    try:
        if file:
            image = Image.open(file.file).convert("RGB")
        elif image_url:
            image = load_image_from_url(image_url)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")

    processor = request.app.state.processor

    # 3. Model inference.
    try:
        inference_state = processor.set_image(image)
        # The processor is shared across requests, so the per-request
        # threshold is applied temporarily and restored afterwards.
        original_confidence = processor.confidence_threshold
        if confidence is not None:
            processor.confidence_threshold = confidence
            print(f"Using manual confidence threshold: {confidence}")
        try:
            output = processor.set_text_prompt(state=inference_state, prompt=final_prompt)
            masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
        finally:
            if confidence is not None:
                processor.confidence_threshold = original_confidence
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")

    # 4. Render and save the overlay image.
    try:
        filename = generate_and_save_result(image, inference_state)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"绘图保存错误: {str(e)}")
    file_url = request.url_for("static", path=f"results/{filename}")

    # 5. Optionally save one image per detected object.
    saved_segments_info = []
    if save_segment_images:
        try:
            # Per-request sub-folder keeps the object images together.
            request_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
            output_dir = os.path.join(RESULT_IMAGE_DIR, request_id)
            os.makedirs(output_dir, exist_ok=True)
            saved_objects = crop_and_save_objects(
                image,
                masks,
                boxes,
                output_dir=output_dir,
                is_tarot=False,
                cutout=cutout
            )
            for obj in saved_objects:
                fname = obj["filename"]
                seg_url = str(request.url_for("static", path=f"results/{request_id}/{fname}"))
                saved_segments_info.append({
                    "url": seg_url,
                    "filename": fname,
                    "note": obj.get("note", "")
                })
        except Exception as e:
            print(f"Error saving segments: {e}")
            raise HTTPException(status_code=500, detail=f"保存分割图片失败: {str(e)}")

    response_content = {
        "status": "success",
        "result_image_url": str(file_url),
        "detected_count": len(masks),
        "scores": scores.tolist() if torch.is_tensor(scores) else scores
    }
    if save_segment_images:
        response_content["segmented_images"] = saved_segments_info
    return JSONResponse(content=response_content)
# ------------------------------------------
# Group 2: Tarot Analysis (塔罗牌分析)
# ------------------------------------------
@app.post("/segment_tarot", tags=[TAG_TAROT], dependencies=[Depends(verify_api_key)])
async def segment_tarot(
    request: Request,
    file: Optional[UploadFile] = File(None),
    image_url: Optional[str] = Form(None),
    expected_count: int = Form(3)
):
    """
    **Tarot card segmentation endpoint**
    1. Checks that exactly `expected_count` tarot cards are detected (default 3)
    2. Perspective-corrects and crops every detected card
    3. Returns the URLs of the corrected images

    Responds with 400 (plus a debug image URL when one could be rendered) when
    the number of detections differs from `expected_count`.
    """
    if not file and not image_url:
        raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")
    try:
        if file:
            image = Image.open(file.file).convert("RGB")
        elif image_url:
            image = load_image_from_url(image_url)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")

    processor = request.app.state.processor
    try:
        inference_state = processor.set_image(image)
        # Fixed prompt: this endpoint only looks for tarot cards.
        output = processor.set_text_prompt(state=inference_state, prompt="tarot card")
        masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")

    detected_count = len(masks)

    # Every request writes into its own sub-folder.
    request_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
    output_dir = os.path.join(RESULT_IMAGE_DIR, request_id)
    os.makedirs(output_dir, exist_ok=True)

    if detected_count != expected_count:
        # Best effort: render an overlay so the caller can see what was found.
        try:
            filename = generate_and_save_result(image, inference_state, output_dir=output_dir)
            file_url = request.url_for("static", path=f"results/{request_id}/{filename}")
        except Exception as e:
            # Still optional, but no longer silent (and no longer a bare
            # `except:` that would also trap KeyboardInterrupt / SystemExit).
            print(f"Failed to render debug visualization: {e}")
            file_url = None
        return JSONResponse(
            status_code=400,
            content={
                "status": "failed",
                "message": f"检测到 {detected_count} 个目标,需要严格的 {expected_count} 张塔罗牌。请调整拍摄角度或背景。",
                "detected_count": detected_count,
                "debug_image_url": str(file_url) if file_url else None
            }
        )

    # Count matches: crop and rectify every card.
    try:
        saved_objects = crop_and_save_objects(image, masks, boxes, output_dir=output_dir)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"抠图处理错误: {str(e)}")

    # Build the URL list with per-card metadata.
    tarot_cards = []
    for obj in saved_objects:
        fname = obj["filename"]
        file_url = str(request.url_for("static", path=f"results/{request_id}/{fname}"))
        tarot_cards.append({
            "url": file_url,
            "is_rotated": obj["is_rotated_by_algorithm"],
            "orientation_status": "corrected_to_portrait" if obj["is_rotated_by_algorithm"] else "original_portrait",
            "note": obj["note"]
        })

    # Best effort: overall overlay image.
    try:
        main_filename = generate_and_save_result(image, inference_state, output_dir=output_dir)
        main_file_url = str(request.url_for("static", path=f"results/{request_id}/{main_filename}"))
    except Exception as e:
        print(f"Failed to render full visualization: {e}")
        main_file_url = None

    return JSONResponse(content={
        "status": "success",
        "message": f"成功识别并分割 {expected_count} 张塔罗牌 (已执行透视矫正)",
        "tarot_cards": tarot_cards,
        "full_visualization": main_file_url,
        "scores": scores.tolist() if torch.is_tensor(scores) else scores
    })
@app.post("/recognize_tarot", tags=[TAG_TAROT], dependencies=[Depends(verify_api_key)])
async def recognize_tarot(
    request: Request,
    file: Optional[UploadFile] = File(None),
    image_url: Optional[str] = Form(None),
    expected_count: int = Form(3)
):
    """
    **Full tarot pipeline: segmentation + correction + recognition**
    1. Checks that exactly `expected_count` tarot cards are detected (SAM3)
    2. Segments and perspective-corrects every card
    3. Calls Qwen-VL to identify each card's name and upright/reversed position

    The spread of the whole picture is also sent to Qwen-VL and reported as
    `spread_info`, including in the 400 response returned when the number of
    detections differs from `expected_count`.
    """
    if not file and not image_url:
        raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")
    try:
        if file:
            image = Image.open(file.file).convert("RGB")
        elif image_url:
            image = load_image_from_url(image_url)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")

    processor = request.app.state.processor
    try:
        inference_state = processor.set_image(image)
        output = processor.set_text_prompt(state=inference_state, prompt="tarot card")
        masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")

    detected_count = len(masks)

    # Every request writes into its own sub-folder.
    request_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
    output_dir = os.path.join(RESULT_IMAGE_DIR, request_id)
    os.makedirs(output_dir, exist_ok=True)

    # Best effort: overall overlay image.
    try:
        main_filename = generate_and_save_result(image, inference_state, output_dir=output_dir)
        main_file_path = os.path.join(output_dir, main_filename)
        main_file_url = str(request.url_for("static", path=f"results/{request_id}/{main_filename}"))
    except Exception as e:
        # Still optional, but no longer silent (and no longer a bare
        # `except:` that would also trap KeyboardInterrupt / SystemExit).
        print(f"Failed to render full visualization: {e}")
        main_file_path = None
        main_file_url = None

    # Step 0: spread recognition.
    # NOTE(review): this only runs when the overlay above was rendered, even
    # though it uses a copy of the raw picture - confirm the coupling is
    # intended.
    spread_info = {"spread_name": "Unknown"}
    if main_file_path:
        # Qwen gets a copy of the original picture, not the overlay.
        temp_raw_path = os.path.join(output_dir, "raw_for_spread.jpg")
        image.save(temp_raw_path)
        spread_info = recognize_spread_with_qwen(temp_raw_path)

    if detected_count != expected_count:
        return JSONResponse(
            status_code=400,
            content={
                "status": "failed",
                "message": f"检测到 {detected_count} 个目标,需要严格的 {expected_count} 张塔罗牌。请调整拍摄角度或背景。",
                "detected_count": detected_count,
                "spread_info": spread_info,
                "debug_image_url": str(main_file_url) if main_file_url else None
            }
        )

    # Count matches: crop and rectify every card.
    try:
        saved_objects = crop_and_save_objects(image, masks, boxes, output_dir=output_dir)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"抠图处理错误: {str(e)}")

    # Recognise the cards one after another (serial Qwen-VL calls).
    tarot_cards = []
    for obj in saved_objects:
        fname = obj["filename"]
        file_path = os.path.join(output_dir, fname)
        recognition_res = recognize_card_with_qwen(file_path)
        file_url = str(request.url_for("static", path=f"results/{request_id}/{fname}"))
        tarot_cards.append({
            "url": file_url,
            "is_rotated": obj["is_rotated_by_algorithm"],
            "orientation_status": "corrected_to_portrait" if obj["is_rotated_by_algorithm"] else "original_portrait",
            "recognition": recognition_res,
            "note": obj["note"]
        })

    return JSONResponse(content={
        "status": "success",
        "message": f"成功识别并分割 {expected_count} 张塔罗牌 (含Qwen识别结果)",
        "spread_info": spread_info,
        "tarot_cards": tarot_cards,
        "full_visualization": main_file_url,
        "scores": scores.tolist() if torch.is_tensor(scores) else scores
    })
# ------------------------------------------
# Group 3: Face Analysis (人脸分析)
# ------------------------------------------
@app.post("/segment_face", tags=[TAG_FACE], dependencies=[Depends(verify_api_key)])
async def segment_face(
    request: Request,
    file: Optional[UploadFile] = File(None),
    image_url: Optional[str] = Form(None),
    prompt: str = Form("face and hair", description="Prompt for face detection")
):
    """
    **Face / head detection and attribute analysis endpoint**

    Loads the image, translates the prompt when it contains Chinese, and
    delegates the work to
    `human_analysis_service.process_face_segmentation_and_analysis`. Relative
    paths in a successful result are turned into absolute static URLs before
    the result is returned.
    """
    if not (file or image_url):
        raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")

    # Prompt translation: keep the original when translation fails or is empty.
    final_prompt = prompt
    if not is_english(prompt):
        try:
            final_prompt = translate_to_sam3_prompt(prompt) or prompt
        except Exception as e:
            print(f"Prompt翻译失败使用原始Prompt: {e}")
    print(f"Face Segment 最终使用的 Prompt: {final_prompt}")

    # Image loading: an upload takes precedence over the URL.
    try:
        image = Image.open(file.file).convert("RGB") if file else load_image_from_url(image_url)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")

    # Hand over to the dedicated analysis service.
    try:
        result = human_analysis_service.process_face_segmentation_and_analysis(
            processor=request.app.state.processor,
            image=image,
            prompt=final_prompt,
            output_base_dir=RESULT_IMAGE_DIR
        )
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}")

    # Anything but a success result is passed through untouched.
    if result["status"] != "success":
        return JSONResponse(content=result)

    # Replace relative paths with absolute static URLs.
    overview_path = result.get("full_visualization")
    if overview_path:
        result["full_visualization"] = str(request.url_for("static", path=overview_path))
    for entry in result["results"]:
        entry["url"] = str(request.url_for("static", path=entry.pop("relative_path")))
    return JSONResponse(content=result)
# ==========================================
# 8. Main Entry Point (启动入口)
# ==========================================
if __name__ == "__main__":
    import uvicorn

    # Server settings collected in one place; passed to uvicorn unchanged.
    server_options = {
        "host": "127.0.0.1",
        "port": 55600,
        "proxy_headers": True,
        "forwarded_allow_ips": "*",
        "reload": False,
    }
    # Start the server.
    uvicorn.run("fastAPI_tarot:app", **server_options)