This commit is contained in:
2026-02-15 22:59:12 +08:00
parent 9e6e9f98b6
commit 99968f25ae
7 changed files with 230 additions and 168 deletions

View File

@@ -1,68 +1,112 @@
"""
SAM3 Segmentation API Service (SAM3 图像分割 API 服务)
--------------------------------------------------
本服务提供基于 SAM3 模型的图像分割能力。
包含通用分割、塔罗牌分析和人脸分析等专用接口。
This service provides image segmentation capabilities using the SAM3 model.
It includes specialized endpoints for Tarot card analysis and Face analysis.
"""
# ==========================================
# 1. Imports (导入)
# ==========================================
# Standard Library Imports (标准库)
import os
import uuid
import time
import requests
import numpy as np
import cv2
from typing import Optional
import json
import traceback
import re
from typing import Optional, List, Dict, Any
from contextlib import asynccontextmanager
import dashscope
from dashscope import MultiModalConversation
# Third-Party Imports (第三方库)
import cv2
import torch
import numpy as np
import requests
import matplotlib
matplotlib.use('Agg')
matplotlib.use('Agg') # 使用非交互式后端,防止在服务器上报错
import matplotlib.pyplot as plt
from PIL import Image
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request, Depends, status
# FastAPI Imports
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request, Depends, status, APIRouter
from fastapi.security import APIKeyHeader
from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse
from PIL import Image
# Dashscope (Aliyun Qwen) Imports
import dashscope
from dashscope import MultiModalConversation
# Local Imports (本地模块)
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
from sam3.visualization_utils import plot_results
import human_analysis_service # 引入新服务
import human_analysis_service # 引入新服务: 人脸分析
# ------------------- 配置与路径 -------------------
# ==========================================
# 2. Configuration & Constants (配置与常量)
# ==========================================
# 路径配置
STATIC_DIR = "static"
RESULT_IMAGE_DIR = os.path.join(STATIC_DIR, "results")
os.makedirs(RESULT_IMAGE_DIR, exist_ok=True)
# ------------------- API Key 核心配置 (已加固) -------------------
# API Key 配置
VALID_API_KEY = "123quant-speed"
API_KEY_HEADER_NAME = "X-API-Key"
# Dashscope 配置 (Qwen-VL)
# Dashscope (Qwen-VL) 配置
dashscope.api_key = 'sk-ce2404f55f744a1987d5ece61c6bac58'
QWEN_MODEL = 'qwen-vl-max'
# 定义 Header 认证
# API Tags (用于文档分类)
TAG_GENERAL = "General Segmentation (通用分割)"
TAG_TAROT = "Tarot Analysis (塔罗牌分析)"
TAG_FACE = "Face Analysis (人脸分析)"
# ==========================================
# 3. Security & Middleware (安全与中间件)
# ==========================================
# 定义 Header 认证 Scheme
api_key_header = APIKeyHeader(name=API_KEY_HEADER_NAME, auto_error=False)
async def verify_api_key(api_key: Optional[str] = Depends(api_key_header)):
"""
强制验证 API Key
强制验证 API Key (Enforce API Key Verification)
1. 检查 Header 中是否存在 API Key
2. 验证 API Key 是否匹配
"""
# 1. 检查是否有 Key
if not api_key:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing API Key. Please provide it in the header."
)
# 2. 检查 Key 是否正确
if api_key != VALID_API_KEY:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid API Key."
)
# 3. 验证通过
return True
# ------------------- 生命周期管理 -------------------
# ==========================================
# 4. Lifespan Management (生命周期管理)
# ==========================================
@asynccontextmanager
async def lifespan(app: FastAPI):
"""
FastAPI 生命周期管理器
- 启动时: 加载模型到 GPU/CPU
- 关闭时: 清理资源
"""
print("="*40)
print("✅ API Key 保护已激活")
print(f"✅ 有效 Key: {VALID_API_KEY}")
@@ -73,12 +117,15 @@ async def lifespan(app: FastAPI):
if not torch.cuda.is_available():
print("警告: 未检测到 GPU将使用 CPU速度会较慢。")
# 加载模型
model = build_sam3_image_model()
model = model.to(device)
model.eval()
# 初始化处理器
processor = Sam3Processor(model)
# 存储到 app.state 以便全局访问
app.state.model = model
app.state.processor = processor
app.state.device = device
@@ -88,52 +135,16 @@ async def lifespan(app: FastAPI):
yield
print("正在清理资源...")
# 这里可以添加释放显存的逻辑,如果需要
# ------------------- FastAPI 初始化 -------------------
app = FastAPI(
lifespan=lifespan,
title="SAM3 Segmentation API",
description="## 🔒 受 API Key 保护\n请点击右上角 **Authorize** 并输入: `123quant-*****`",
)
# ==========================================
# 5. Helper Functions (辅助函数)
# ==========================================
# 手动添加 OpenAPI 安全配置,让 Docs 里的锁头生效
app.openapi_schema = None
def custom_openapi():
if app.openapi_schema:
return app.openapi_schema
from fastapi.openapi.utils import get_openapi
openapi_schema = get_openapi(
title=app.title,
version=app.version,
description=app.description,
routes=app.routes,
)
# 定义安全方案
openapi_schema["components"]["securitySchemes"] = {
"APIKeyHeader": {
"type": "apiKey",
"in": "header",
"name": API_KEY_HEADER_NAME,
}
}
# 为所有路径应用安全要求
for path in openapi_schema["paths"]:
for method in openapi_schema["paths"][path]:
openapi_schema["paths"][path][method]["security"] = [{"APIKeyHeader": []}]
app.openapi_schema = openapi_schema
return app.openapi_schema
app.openapi = custom_openapi
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
import re
# ------------------- 辅助函数 -------------------
def is_english(text: str) -> bool:
"""
简单判断是否为英文 prompt。
如果有中文字符则认为不是英文。
判断文本是否为英文
- 如果包含中文字符范围 (\u4e00-\u9fff),则返回 False
"""
for char in text:
if '\u4e00' <= char <= '\u9fff':
@@ -142,7 +153,8 @@ def is_english(text: str) -> bool:
def translate_to_sam3_prompt(text: str) -> str:
"""
使用模型将非英文提示词翻译为适合 SAM3 的英文提示词
使用 Qwen 模型将文提示词翻译为英文
- SAM3 模型对英文 Prompt 支持更好
"""
print(f"正在翻译提示词: {text}")
try:
@@ -164,7 +176,7 @@ def translate_to_sam3_prompt(text: str) -> str:
return translated_text
else:
print(f"翻译失败: {response.code} - {response.message}")
return text # Fallback to original
return text # 失败则回退到原始文本
except Exception as e:
print(f"翻译异常: {e}")
return text
@@ -172,6 +184,7 @@ def translate_to_sam3_prompt(text: str) -> str:
def order_points(pts):
"""
对四个坐标点进行排序:左上,右上,右下,左下
用于透视变换前的点位整理
"""
rect = np.zeros((4, 2), dtype="float32")
s = pts.sum(axis=1)
@@ -184,7 +197,8 @@ def order_points(pts):
def four_point_transform(image, pts):
"""
根据四个点进行透视变换
根据四个点进行透视变换 (Perspective Transform)
用于将倾斜的卡片矫正为矩形
"""
rect = order_points(pts)
(tl, tr, br, bl) = rect
@@ -210,6 +224,9 @@ def four_point_transform(image, pts):
return warped
def load_image_from_url(url: str) -> Image.Image:
"""
从 URL 下载图片并转换为 PIL Image 对象
"""
try:
headers = {'User-Agent': 'Mozilla/5.0'}
response = requests.get(url, headers=headers, stream=True, timeout=10)
@@ -222,7 +239,17 @@ def load_image_from_url(url: str) -> Image.Image:
def crop_and_save_objects(image: Image.Image, masks, boxes, output_dir: str = RESULT_IMAGE_DIR, is_tarot: bool = True, cutout: bool = False) -> list[dict]:
"""
根据 mask 和 box 进行处理并保存独立的对象图片
:param cutout: 如果为 True则进行轮廓抠图透明背景否则进行透视矫正主要用于卡片
参数:
- image: 原始图片
- masks: 分割掩码列表
- boxes: 边界框列表
- output_dir: 输出目录
- is_tarot: 是否为塔罗牌模式 (会影响文件名前缀和旋转逻辑)
- cutout: 如果为 True则进行轮廓抠图透明背景否则进行透视矫正主要用于卡片
返回:
- 保存的对象信息列表
"""
saved_objects = []
# Convert image to numpy array (RGB)
@@ -256,29 +283,18 @@ def crop_and_save_objects(image: Image.Image, masks, boxes, output_dir: str = RE
img_rgba = image.copy()
# 2. 准备 Alpha Mask
# mask_uint8 已经是 0-255可以直接作为 Alpha 通道的基础
# 将 Mask 缩放到和图片一样大 (一般是一样的,但以防万一)
if mask_uint8.shape != img_rgba.size[::-1]: # size is (w, h), shape is (h, w)
# 如果尺寸不一致可能需要 resize这里假设尺寸一致
pass
mask_img = Image.fromarray(mask_uint8, mode='L')
# 3. 将 Mask 应用到 Alpha 通道
# 创建一个新的空白透明图
cutout_img = Image.new("RGBA", img_rgba.size, (0, 0, 0, 0))
# 将原图粘贴上去,使用 mask 作为 mask
cutout_img.paste(image.convert("RGB"), (0, 0), mask=mask_img)
# 4. Crop to Box
# box is [x1, y1, x2, y2]
x1, y1, x2, y2 = map(int, box_np)
# 边界检查
w, h = cutout_img.size
x1 = max(0, x1); y1 = max(0, y1)
x2 = min(w, x2); y2 = min(h, y2)
# 避免无效 crop
if x2 > x1 and y2 > y1:
final_img = cutout_img.crop((x1, y1, x2, y2))
else:
@@ -286,7 +302,7 @@ def crop_and_save_objects(image: Image.Image, masks, boxes, output_dir: str = RE
# Save
prefix = "cutout"
is_rotated = False # 抠图模式下不进行自动旋转
is_rotated = False
filename = f"{prefix}_{uuid.uuid4().hex}_{i}.png"
save_path = os.path.join(output_dir, filename)
@@ -299,65 +315,51 @@ def crop_and_save_objects(image: Image.Image, masks, boxes, output_dir: str = RE
})
else:
# --- 透视矫正模式 (原有逻辑) ---
# Find contours
# --- 透视矫正模式 (矩形矫正) ---
contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
if not contours:
continue
# Get largest contour
c = max(contours, key=cv2.contourArea)
# Approximate contour to polygon
peri = cv2.arcLength(c, True)
approx = cv2.approxPolyDP(c, 0.04 * peri, True)
# If we have 4 points, use them. If not, fallback to minAreaRect
if len(approx) == 4:
pts = approx.reshape(4, 2)
else:
rect = cv2.minAreaRect(c)
pts = cv2.boxPoints(rect)
# Apply perspective transform
# Check if original image has Alpha
if img_arr.shape[2] == 4:
warped = four_point_transform(img_arr, pts)
else:
warped = four_point_transform(img_arr, pts)
warped = four_point_transform(img_arr, pts)
# Check orientation (Portrait vs Landscape)
h, w = warped.shape[:2]
is_rotated = False
# Enforce Portrait for Tarot cards (Standard 7x12 cm ratio approx)
# 强制竖屏逻辑 (塔罗牌通常是竖屏)
if is_tarot and w > h:
# Rotate 90 degrees clockwise
warped = cv2.rotate(warped, cv2.ROTATE_90_CLOCKWISE)
is_rotated = True
# Convert back to PIL
pil_warped = Image.fromarray(warped)
# Save
prefix = "tarot" if is_tarot else "segment"
filename = f"{prefix}_{uuid.uuid4().hex}_{i}.png"
save_path = os.path.join(output_dir, filename)
pil_warped.save(save_path)
# 正逆位判断逻辑 (基于几何只能做到这一步,无法区分上下颠倒)
# 这里我们假设长边垂直为正位,如果做了旋转则标记
# 真正的正逆位需要OCR或图像识别
saved_objects.append({
"filename": filename,
"is_rotated_by_algorithm": is_rotated,
"note": "Geometric correction applied. True upright/reversed requires content analysis."
"note": "Geometric correction applied."
})
return saved_objects
def generate_and_save_result(image: Image.Image, inference_state, output_dir: str = RESULT_IMAGE_DIR) -> str:
"""
生成并保存包含 Mask 叠加效果的完整结果图
"""
filename = f"seg_{uuid.uuid4().hex}.jpg"
save_path = os.path.join(output_dir, filename)
plot_results(image, inference_state)
@@ -370,21 +372,14 @@ def recognize_card_with_qwen(image_path: str) -> dict:
调用 Qwen-VL 识别塔罗牌 (采用正逆位对比策略)
"""
try:
# 确保路径是绝对路径
abs_path = os.path.abspath(image_path)
file_url = f"file://{abs_path}"
# -------------------------------------------------
# 优化策略生成一张旋转180度的对比图
# 让 AI 做选择题而不是判断题,大幅提高准确率
# -------------------------------------------------
try:
# 1. 打开原图
img = Image.open(abs_path)
# 2. 生成旋转图 (180度)
rotated_img = img.rotate(180)
# 3. 保存旋转图
dir_name = os.path.dirname(abs_path)
file_name = os.path.basename(abs_path)
@@ -395,8 +390,6 @@ def recognize_card_with_qwen(image_path: str) -> dict:
rotated_file_url = f"file://{rotated_path}"
# 4. 构建对比 Prompt
# 发送两张图图1=原图, 图2=旋转图
# 询问 AI 哪一张是“正位”
messages = [
{
"role": "user",
@@ -422,7 +415,6 @@ def recognize_card_with_qwen(image_path: str) -> dict:
except Exception as e:
print(f"对比图生成失败,回退到单图模式: {e}")
# 回退到旧的单图模式
messages = [
{
"role": "user",
@@ -433,15 +425,12 @@ def recognize_card_with_qwen(image_path: str) -> dict:
}
]
# 调用模型
response = MultiModalConversation.call(model=QWEN_MODEL, messages=messages)
if response.status_code == 200:
content = response.output.choices[0].message.content[0]['text']
# 尝试解析简单的 JSON
import json
try:
# 清理可能存在的 markdown 标记
clean_content = content.replace("```json", "").replace("```", "").strip()
result = json.loads(clean_content)
return result
@@ -458,7 +447,6 @@ def recognize_spread_with_qwen(image_path: str) -> dict:
调用 Qwen-VL 识别塔罗牌牌阵
"""
try:
# 确保路径是绝对路径并加上 file:// 前缀
abs_path = os.path.abspath(image_path)
file_url = f"file://{abs_path}"
@@ -476,10 +464,8 @@ def recognize_spread_with_qwen(image_path: str) -> dict:
if response.status_code == 200:
content = response.output.choices[0].message.content[0]['text']
# 尝试解析简单的 JSON
import json
try:
# 清理可能存在的 markdown 标记
clean_content = content.replace("```json", "").replace("```", "").strip()
result = json.loads(clean_content)
return result
@@ -491,27 +477,82 @@ def recognize_spread_with_qwen(image_path: str) -> dict:
except Exception as e:
return {"error": f"牌阵识别失败: {str(e)}"}
# ------------------- API 接口 (强制依赖验证) -------------------
@app.post("/segment", dependencies=[Depends(verify_api_key)])
# ==========================================
# 6. FastAPI App Setup (应用初始化)
# ==========================================
app = FastAPI(
lifespan=lifespan,
title="SAM3 Segmentation API",
description="## 🔒 受 API Key 保护\n请点击右上角 **Authorize** 并输入: `123quant-*****`",
openapi_tags=[
{"name": TAG_GENERAL, "description": "General purpose image segmentation"},
{"name": TAG_TAROT, "description": "Specialized endpoints for Tarot card recognition"},
{"name": TAG_FACE, "description": "Face detection and analysis"},
]
)
# 手动添加 OpenAPI 安全配置
app.openapi_schema = None
def custom_openapi():
if app.openapi_schema:
return app.openapi_schema
from fastapi.openapi.utils import get_openapi
openapi_schema = get_openapi(
title=app.title,
version=app.version,
description=app.description,
routes=app.routes,
)
openapi_schema["components"]["securitySchemes"] = {
"APIKeyHeader": {
"type": "apiKey",
"in": "header",
"name": API_KEY_HEADER_NAME,
}
}
for path in openapi_schema["paths"]:
for method in openapi_schema["paths"][path]:
openapi_schema["paths"][path][method]["security"] = [{"APIKeyHeader": []}]
app.openapi_schema = openapi_schema
return app.openapi_schema
app.openapi = custom_openapi
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
# ==========================================
# 7. API Endpoints (接口定义)
# ==========================================
# ------------------------------------------
# Group 1: General Segmentation (通用分割)
# ------------------------------------------
@app.post("/segment", tags=[TAG_GENERAL], dependencies=[Depends(verify_api_key)])
async def segment(
request: Request,
prompt: str = Form(...),
file: Optional[UploadFile] = File(None),
image_url: Optional[str] = Form(None),
save_segment_images: bool = Form(False),
cutout: bool = Form(False)
prompt: str = Form(..., description="Text prompt for segmentation (e.g., 'cat', 'person')"),
file: Optional[UploadFile] = File(None, description="Image file to upload"),
image_url: Optional[str] = Form(None, description="URL of the image"),
save_segment_images: bool = Form(False, description="Whether to save and return individual segmented objects"),
cutout: bool = Form(False, description="If True, returns transparent background PNGs; otherwise returns original crops"),
confidence: float = Form(0.7, description="Confidence threshold (0.0-1.0). Default is 0.7.")
):
"""
**通用图像分割接口**
- 支持上传图片或提供图片 URL
- 支持自动将中文 Prompt 翻译为英文
- 支持手动设置置信度 (Confidence Threshold)
"""
if not file and not image_url:
raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")
# ------------------- Prompt 翻译/优化 -------------------
# 1. Prompt 处理
final_prompt = prompt
if not is_english(prompt):
# 使用 Qwen 进行翻译
try:
# 构建消息请求翻译
# 注意translate_to_sam3_prompt 是一个阻塞调用,可能会增加耗时
# 但对于分割任务来说是可以接受的
translated = translate_to_sam3_prompt(prompt)
if translated:
final_prompt = translated
@@ -520,6 +561,7 @@ async def segment(
print(f"最终使用的 Prompt: {final_prompt}")
# 2. 图片加载
try:
if file:
image = Image.open(file.file).convert("RGB")
@@ -530,13 +572,34 @@ async def segment(
processor = request.app.state.processor
# 3. 模型推理
try:
# 设置图片
inference_state = processor.set_image(image)
output = processor.set_text_prompt(state=inference_state, prompt=final_prompt)
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
# 处理 Confidence Threshold
# 如果用户提供了 confidence临时修改 processor 的阈值
original_confidence = processor.confidence_threshold
if confidence is not None:
# 简单校验
if not (0.0 <= confidence <= 1.0):
raise HTTPException(status_code=400, detail="Confidence must be between 0.0 and 1.0")
processor.confidence_threshold = confidence
print(f"Using manual confidence threshold: {confidence}")
try:
# 执行推理
output = processor.set_text_prompt(state=inference_state, prompt=final_prompt)
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
finally:
# 恢复默认阈值,防止影响其他请求 (虽然在单进程模型下可能不重要,但保持状态一致性是个好习惯)
if confidence is not None:
processor.confidence_threshold = original_confidence
except Exception as e:
raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")
# 4. 结果可视化与保存
try:
filename = generate_and_save_result(image, inference_state)
except Exception as e:
@@ -544,7 +607,7 @@ async def segment(
file_url = request.url_for("static", path=f"results/{filename}")
# New logic for saving segments
# 5. 保存分割子图 (Optional)
saved_segments_info = []
if save_segment_images:
try:
@@ -552,7 +615,6 @@ async def segment(
output_dir = os.path.join(RESULT_IMAGE_DIR, request_id)
os.makedirs(output_dir, exist_ok=True)
# 传递 cutout 参数
saved_objects = crop_and_save_objects(
image,
masks,
@@ -586,7 +648,11 @@ async def segment(
return JSONResponse(content=response_content)
@app.post("/segment_tarot", dependencies=[Depends(verify_api_key)])
# ------------------------------------------
# Group 2: Tarot Analysis (塔罗牌分析)
# ------------------------------------------
@app.post("/segment_tarot", tags=[TAG_TAROT], dependencies=[Depends(verify_api_key)])
async def segment_tarot(
request: Request,
file: Optional[UploadFile] = File(None),
@@ -594,9 +660,11 @@ async def segment_tarot(
expected_count: int = Form(3)
):
"""
塔罗牌分割专用接口
**塔罗牌分割接口**
1. 检测是否包含指定数量的塔罗牌 (默认为 3)
2. 如果是,分别抠出这些牌并返回
2. 对检测到的卡片进行透视矫正和裁剪
3. 返回矫正后的图片 URL
"""
if not file and not image_url:
raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")
@@ -622,7 +690,7 @@ async def segment_tarot(
# 核心逻辑:判断数量
detected_count = len(masks)
# 创建本次请求的独立文件夹 (时间戳_UUID前8位)
# 创建本次请求的独立文件夹
request_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
output_dir = os.path.join(RESULT_IMAGE_DIR, request_id)
os.makedirs(output_dir, exist_ok=True)
@@ -678,7 +746,7 @@ async def segment_tarot(
"scores": scores.tolist() if torch.is_tensor(scores) else scores
})
@app.post("/recognize_tarot", dependencies=[Depends(verify_api_key)])
@app.post("/recognize_tarot", tags=[TAG_TAROT], dependencies=[Depends(verify_api_key)])
async def recognize_tarot(
request: Request,
file: Optional[UploadFile] = File(None),
@@ -686,7 +754,8 @@ async def recognize_tarot(
expected_count: int = Form(3)
):
"""
塔罗牌全流程接口: 分割 + 矫正 + 识别
**塔罗牌全流程接口: 分割 + 矫正 + 识别**
1. 检测是否包含指定数量的塔罗牌 (SAM3)
2. 分割并透视矫正
3. 调用 Qwen-VL 识别每张牌的名称和正逆位
@@ -706,21 +775,18 @@ async def recognize_tarot(
try:
inference_state = processor.set_image(image)
# 固定 Prompt 检测塔罗牌
output = processor.set_text_prompt(state=inference_state, prompt="tarot card")
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
except Exception as e:
raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")
# 核心逻辑:判断数量
detected_count = len(masks)
# 创建本次请求的独立文件夹
request_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
output_dir = os.path.join(RESULT_IMAGE_DIR, request_id)
os.makedirs(output_dir, exist_ok=True)
# 保存整体效果图 (无论是成功还是失败,都先保存一张主图)
# 保存整体效果图
try:
main_filename = generate_and_save_result(image, inference_state, output_dir=output_dir)
main_file_path = os.path.join(output_dir, main_filename)
@@ -730,20 +796,14 @@ async def recognize_tarot(
main_file_path = None
main_file_url = None
# Step 0: 牌阵识别 (在判断数量之前或之后都可以,这里放在前面作为全局判断)
# Step 0: 牌阵识别
spread_info = {"spread_name": "Unknown"}
if main_file_path:
# 使用带有mask绘制的主图或者原始图
# 使用原始图可能更好不受mask遮挡干扰但是main_filename是带mask的。
# 我们这里暂时用原始图保存一份临时文件给Qwen看
# 使用原始图的一份拷贝给 Qwen 识别牌阵
temp_raw_path = os.path.join(output_dir, "raw_for_spread.jpg")
image.save(temp_raw_path)
spread_info = recognize_spread_with_qwen(temp_raw_path)
# 如果识别结果明确说是“不是正规牌阵”,是否要继续?
# 用户需求:“如果没有正确的牌阵则返回‘不是正规牌阵’”
# 我们将其放在返回结果中,由客户端决定是否展示警告
if detected_count != expected_count:
return JSONResponse(
status_code=400,
@@ -768,8 +828,7 @@ async def recognize_tarot(
fname = obj["filename"]
file_path = os.path.join(output_dir, fname)
# 调用 Qwen-VL 识别
# 注意:这里会串行调用,速度可能较慢。
# 调用 Qwen-VL 识别 (串行)
recognition_res = recognize_card_with_qwen(file_path)
file_url = str(request.url_for("static", path=f"results/{request_id}/{fname}"))
@@ -790,15 +849,20 @@ async def recognize_tarot(
"scores": scores.tolist() if torch.is_tensor(scores) else scores
})
@app.post("/segment_face", dependencies=[Depends(verify_api_key)])
# ------------------------------------------
# Group 3: Face Analysis (人脸分析)
# ------------------------------------------
@app.post("/segment_face", tags=[TAG_FACE], dependencies=[Depends(verify_api_key)])
async def segment_face(
request: Request,
file: Optional[UploadFile] = File(None),
image_url: Optional[str] = Form(None),
prompt: str = Form("face and hair") # 默认提示词包含头发
prompt: str = Form("face and hair", description="Prompt for face detection")
):
"""
人脸/头部检测与属性分析接口 (新功能)
**人脸/头部检测与属性分析接口**
1. 调用 SAM3 分割出头部区域 (含头发)
2. 裁剪并保存
3. 调用 Qwen-VL 识别性别和年龄
@@ -806,7 +870,7 @@ async def segment_face(
if not file and not image_url:
raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")
# ------------------- Prompt 翻译/优化 -------------------
# Prompt 翻译/优化
final_prompt = prompt
if not is_english(prompt):
try:
@@ -818,7 +882,7 @@ async def segment_face(
print(f"Face Segment 最终使用的 Prompt: {final_prompt}")
# 1. 加载图片
# 加载图片
try:
if file:
image = Image.open(file.file).convert("RGB")
@@ -829,10 +893,8 @@ async def segment_face(
processor = request.app.state.processor
# 2. 调用独立服务进行处理
# 调用独立服务进行处理
try:
# 传入 processor 和 image
# 注意Result Image Dir 我们直接复用 RESULT_IMAGE_DIR
result = human_analysis_service.process_face_segmentation_and_analysis(
processor=processor,
image=image,
@@ -840,34 +902,34 @@ async def segment_face(
output_base_dir=RESULT_IMAGE_DIR
)
except Exception as e:
# 打印详细错误堆栈以便调试
import traceback
traceback.print_exc()
raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}")
# 3. 补全 URL (因为 service 层不知道 request context)
# 补全 URL
if result["status"] == "success":
# 处理全图可视化的 URL
if result.get("full_visualization"):
full_vis_rel_path = result["full_visualization"]
result["full_visualization"] = str(request.url_for("static", path=full_vis_rel_path))
for item in result["results"]:
# item["relative_path"] 是相对路径,如 results/xxx/xxx.jpg
# 我们需要将其转换为完整 URL
relative_path = item.pop("relative_path") # 移除相对路径字段,只返回 URL
relative_path = item.pop("relative_path")
item["url"] = str(request.url_for("static", path=relative_path))
return JSONResponse(content=result)
# ==========================================
# 8. Main Entry Point (启动入口)
# ==========================================
if __name__ == "__main__":
import uvicorn
# 注意:如果你的文件名不是 fastAPI_tarot.py请修改下面第一个参数
# 启动服务器
uvicorn.run(
"fastAPI_tarot:app",
host="127.0.0.1",
port=55600,
proxy_headers=True,
forwarded_allow_ips="*",
reload=False # 生产环境建议关闭 reload确保代码完全重载
reload=False
)

Binary file not shown.

After

Width:  |  Height:  |  Size: 76 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 132 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 78 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 101 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 82 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 78 KiB