diff --git a/fastAPI_tarot.py b/fastAPI_tarot.py index 8f6e7c7..6271c4b 100644 --- a/fastAPI_tarot.py +++ b/fastAPI_tarot.py @@ -1,68 +1,112 @@ +""" +SAM3 Segmentation API Service (SAM3 图像分割 API 服务) +-------------------------------------------------- +本服务提供基于 SAM3 模型的图像分割能力。 +包含通用分割、塔罗牌分析和人脸分析等专用接口。 + +This service provides image segmentation capabilities using the SAM3 model. +It includes specialized endpoints for Tarot card analysis and Face analysis. +""" + +# ========================================== +# 1. Imports (导入) +# ========================================== + +# Standard Library Imports (标准库) import os import uuid import time -import requests -import numpy as np -import cv2 -from typing import Optional +import json +import traceback +import re +from typing import Optional, List, Dict, Any from contextlib import asynccontextmanager -import dashscope -from dashscope import MultiModalConversation +# Third-Party Imports (第三方库) +import cv2 import torch +import numpy as np +import requests import matplotlib -matplotlib.use('Agg') +matplotlib.use('Agg') # 使用非交互式后端,防止在服务器上报错 import matplotlib.pyplot as plt +from PIL import Image -from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request, Depends, status +# FastAPI Imports +from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request, Depends, status, APIRouter from fastapi.security import APIKeyHeader from fastapi.staticfiles import StaticFiles from fastapi.responses import JSONResponse -from PIL import Image +# Dashscope (Aliyun Qwen) Imports +import dashscope +from dashscope import MultiModalConversation + +# Local Imports (本地模块) from sam3.model_builder import build_sam3_image_model from sam3.model.sam3_image_processor import Sam3Processor from sam3.visualization_utils import plot_results -import human_analysis_service # 引入新服务 +import human_analysis_service # 引入新服务: 人脸分析 -# ------------------- 配置与路径 ------------------- +# ========================================== +# 2. 
# Configuration & Constants
# ==========================================

# Paths: result images are written below STATIC_DIR/results.
STATIC_DIR = "static"
RESULT_IMAGE_DIR = os.path.join(STATIC_DIR, "results")
os.makedirs(RESULT_IMAGE_DIR, exist_ok=True)

# API key configuration.
# NOTE(review): the literal fallbacks below are credentials committed to the
# repository. They are kept only so existing deployments keep working;
# supply the real values through the environment and rotate the committed
# ones. `or` (rather than a .get default) also falls back when the variable
# is set but empty, so an empty variable cannot lock every caller out.
VALID_API_KEY = os.environ.get("SAM3_API_KEY") or "123quant-speed"
API_KEY_HEADER_NAME = "X-API-Key"

# Dashscope (Qwen-VL) configuration.
# NOTE(review): DASHSCOPE_API_KEY is presumed to be the variable the SDK
# itself looks for -- confirm against the DashScope documentation.
dashscope.api_key = os.environ.get("DASHSCOPE_API_KEY") or 'sk-ce2404f55f744a1987d5ece61c6bac58'
QWEN_MODEL = 'qwen-vl-max'

# API tags (used to group endpoints in the generated docs)
TAG_GENERAL = "General Segmentation (通用分割)"
TAG_TAROT = "Tarot Analysis (塔罗牌分析)"
TAG_FACE = "Face Analysis (人脸分析)"

# ==========================================
# 3. Security & Middleware
# ==========================================

# Header-based auth scheme. auto_error=False makes a missing header arrive
# as None so verify_api_key can answer with its own 401 message.
api_key_header = APIKeyHeader(name=API_KEY_HEADER_NAME, auto_error=False)


async def verify_api_key(api_key: Optional[str] = Depends(api_key_header)) -> bool:
    """Reject the request unless it carries the expected API key.

    Args:
        api_key: Value extracted by ``api_key_header`` from the header named
            by ``API_KEY_HEADER_NAME``; ``None`` when the header is absent.

    Returns:
        ``True`` when the supplied key matches ``VALID_API_KEY``.

    Raises:
        HTTPException: 401 when the key is missing or empty, 403 when it
            does not match.
    """
    # Function-scope import so this block does not depend on the file's
    # top-level import list.
    import secrets

    if not api_key:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing API Key. Please provide it in the header."
        )

    # Constant-time comparison: a plain `!=` stops at the first differing
    # character, which leaks how much of the key was guessed correctly.
    # Compare bytes because compare_digest rejects non-ASCII str values,
    # and a header value is attacker-controlled.
    supplied = api_key.encode("utf-8")
    expected = VALID_API_KEY.encode("utf-8")
    if not secrets.compare_digest(supplied, expected):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid API Key."
        )

    return True

# ==========================================
# 4.
Lifespan Management (生命周期管理) +# ========================================== + @asynccontextmanager async def lifespan(app: FastAPI): + """ + FastAPI 生命周期管理器 + - 启动时: 加载模型到 GPU/CPU + - 关闭时: 清理资源 + """ print("="*40) print("✅ API Key 保护已激活") print(f"✅ 有效 Key: {VALID_API_KEY}") @@ -73,12 +117,15 @@ async def lifespan(app: FastAPI): if not torch.cuda.is_available(): print("警告: 未检测到 GPU,将使用 CPU,速度会较慢。") + # 加载模型 model = build_sam3_image_model() model = model.to(device) model.eval() + # 初始化处理器 processor = Sam3Processor(model) + # 存储到 app.state 以便全局访问 app.state.model = model app.state.processor = processor app.state.device = device @@ -88,52 +135,16 @@ async def lifespan(app: FastAPI): yield print("正在清理资源...") + # 这里可以添加释放显存的逻辑,如果需要 -# ------------------- FastAPI 初始化 ------------------- -app = FastAPI( - lifespan=lifespan, - title="SAM3 Segmentation API", - description="## 🔒 受 API Key 保护\n请点击右上角 **Authorize** 并输入: `123quant-*****`", -) +# ========================================== +# 5. Helper Functions (辅助函数) +# ========================================== -# 手动添加 OpenAPI 安全配置,让 Docs 里的锁头生效 -app.openapi_schema = None -def custom_openapi(): - if app.openapi_schema: - return app.openapi_schema - from fastapi.openapi.utils import get_openapi - openapi_schema = get_openapi( - title=app.title, - version=app.version, - description=app.description, - routes=app.routes, - ) - # 定义安全方案 - openapi_schema["components"]["securitySchemes"] = { - "APIKeyHeader": { - "type": "apiKey", - "in": "header", - "name": API_KEY_HEADER_NAME, - } - } - # 为所有路径应用安全要求 - for path in openapi_schema["paths"]: - for method in openapi_schema["paths"][path]: - openapi_schema["paths"][path][method]["security"] = [{"APIKeyHeader": []}] - app.openapi_schema = openapi_schema - return app.openapi_schema - -app.openapi = custom_openapi - -app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static") - -import re - -# ------------------- 辅助函数 ------------------- def is_english(text: str) -> bool: """ - 
简单判断是否为英文 prompt。 - 如果有中文字符则认为不是英文。 + 判断文本是否为纯英文 + - 如果包含中文字符范围 (\u4e00-\u9fff),则返回 False """ for char in text: if '\u4e00' <= char <= '\u9fff': @@ -142,7 +153,8 @@ def is_english(text: str) -> bool: def translate_to_sam3_prompt(text: str) -> str: """ - 使用大模型将非英文提示词翻译为适合 SAM3 的英文提示词 + 使用 Qwen 模型将中文提示词翻译为英文 + - SAM3 模型对英文 Prompt 支持更好 """ print(f"正在翻译提示词: {text}") try: @@ -164,7 +176,7 @@ def translate_to_sam3_prompt(text: str) -> str: return translated_text else: print(f"翻译失败: {response.code} - {response.message}") - return text # Fallback to original + return text # 失败则回退到原始文本 except Exception as e: print(f"翻译异常: {e}") return text @@ -172,6 +184,7 @@ def translate_to_sam3_prompt(text: str) -> str: def order_points(pts): """ 对四个坐标点进行排序:左上,右上,右下,左下 + 用于透视变换前的点位整理 """ rect = np.zeros((4, 2), dtype="float32") s = pts.sum(axis=1) @@ -184,7 +197,8 @@ def order_points(pts): def four_point_transform(image, pts): """ - 根据四个点进行透视变换 + 根据四个点进行透视变换 (Perspective Transform) + 用于将倾斜的卡片矫正为矩形 """ rect = order_points(pts) (tl, tr, br, bl) = rect @@ -210,6 +224,9 @@ def four_point_transform(image, pts): return warped def load_image_from_url(url: str) -> Image.Image: + """ + 从 URL 下载图片并转换为 PIL Image 对象 + """ try: headers = {'User-Agent': 'Mozilla/5.0'} response = requests.get(url, headers=headers, stream=True, timeout=10) @@ -222,7 +239,17 @@ def load_image_from_url(url: str) -> Image.Image: def crop_and_save_objects(image: Image.Image, masks, boxes, output_dir: str = RESULT_IMAGE_DIR, is_tarot: bool = True, cutout: bool = False) -> list[dict]: """ 根据 mask 和 box 进行处理并保存独立的对象图片 - :param cutout: 如果为 True,则进行轮廓抠图(透明背景);否则进行透视矫正(主要用于卡片) + + 参数: + - image: 原始图片 + - masks: 分割掩码列表 + - boxes: 边界框列表 + - output_dir: 输出目录 + - is_tarot: 是否为塔罗牌模式 (会影响文件名前缀和旋转逻辑) + - cutout: 如果为 True,则进行轮廓抠图(透明背景);否则进行透视矫正(主要用于卡片) + + 返回: + - 保存的对象信息列表 """ saved_objects = [] # Convert image to numpy array (RGB) @@ -256,29 +283,18 @@ def crop_and_save_objects(image: Image.Image, masks, boxes, output_dir: str = RE 
img_rgba = image.copy() # 2. 准备 Alpha Mask - # mask_uint8 已经是 0-255,可以直接作为 Alpha 通道的基础 - # 将 Mask 缩放到和图片一样大 (一般是一样的,但以防万一) - if mask_uint8.shape != img_rgba.size[::-1]: # size is (w, h), shape is (h, w) - # 如果尺寸不一致可能需要 resize,这里假设尺寸一致 - pass - mask_img = Image.fromarray(mask_uint8, mode='L') # 3. 将 Mask 应用到 Alpha 通道 - # 创建一个新的空白透明图 cutout_img = Image.new("RGBA", img_rgba.size, (0, 0, 0, 0)) - # 将原图粘贴上去,使用 mask 作为 mask cutout_img.paste(image.convert("RGB"), (0, 0), mask=mask_img) # 4. Crop to Box - # box is [x1, y1, x2, y2] x1, y1, x2, y2 = map(int, box_np) - # 边界检查 w, h = cutout_img.size x1 = max(0, x1); y1 = max(0, y1) x2 = min(w, x2); y2 = min(h, y2) - # 避免无效 crop if x2 > x1 and y2 > y1: final_img = cutout_img.crop((x1, y1, x2, y2)) else: @@ -286,7 +302,7 @@ def crop_and_save_objects(image: Image.Image, masks, boxes, output_dir: str = RE # Save prefix = "cutout" - is_rotated = False # 抠图模式下不进行自动旋转 + is_rotated = False filename = f"{prefix}_{uuid.uuid4().hex}_{i}.png" save_path = os.path.join(output_dir, filename) @@ -299,65 +315,51 @@ def crop_and_save_objects(image: Image.Image, masks, boxes, output_dir: str = RE }) else: - # --- 透视矫正模式 (原有逻辑) --- - # Find contours + # --- 透视矫正模式 (矩形矫正) --- contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) if not contours: continue - # Get largest contour c = max(contours, key=cv2.contourArea) - - # Approximate contour to polygon peri = cv2.arcLength(c, True) approx = cv2.approxPolyDP(c, 0.04 * peri, True) - # If we have 4 points, use them. 
If not, fallback to minAreaRect if len(approx) == 4: pts = approx.reshape(4, 2) else: rect = cv2.minAreaRect(c) pts = cv2.boxPoints(rect) - # Apply perspective transform - # Check if original image has Alpha - if img_arr.shape[2] == 4: - warped = four_point_transform(img_arr, pts) - else: - warped = four_point_transform(img_arr, pts) + warped = four_point_transform(img_arr, pts) # Check orientation (Portrait vs Landscape) h, w = warped.shape[:2] is_rotated = False - # Enforce Portrait for Tarot cards (Standard 7x12 cm ratio approx) + # 强制竖屏逻辑 (塔罗牌通常是竖屏) if is_tarot and w > h: - # Rotate 90 degrees clockwise warped = cv2.rotate(warped, cv2.ROTATE_90_CLOCKWISE) is_rotated = True - # Convert back to PIL pil_warped = Image.fromarray(warped) - # Save prefix = "tarot" if is_tarot else "segment" filename = f"{prefix}_{uuid.uuid4().hex}_{i}.png" save_path = os.path.join(output_dir, filename) pil_warped.save(save_path) - # 正逆位判断逻辑 (基于几何只能做到这一步,无法区分上下颠倒) - # 这里我们假设长边垂直为正位,如果做了旋转则标记 - # 真正的正逆位需要OCR或图像识别 - saved_objects.append({ "filename": filename, "is_rotated_by_algorithm": is_rotated, - "note": "Geometric correction applied. True upright/reversed requires content analysis." + "note": "Geometric correction applied." }) return saved_objects def generate_and_save_result(image: Image.Image, inference_state, output_dir: str = RESULT_IMAGE_DIR) -> str: + """ + 生成并保存包含 Mask 叠加效果的完整结果图 + """ filename = f"seg_{uuid.uuid4().hex}.jpg" save_path = os.path.join(output_dir, filename) plot_results(image, inference_state) @@ -370,21 +372,14 @@ def recognize_card_with_qwen(image_path: str) -> dict: 调用 Qwen-VL 识别塔罗牌 (采用正逆位对比策略) """ try: - # 确保路径是绝对路径 abs_path = os.path.abspath(image_path) file_url = f"file://{abs_path}" - # ------------------------------------------------- - # 优化策略:生成一张旋转180度的对比图 - # 让 AI 做选择题而不是判断题,大幅提高准确率 - # ------------------------------------------------- try: # 1. 打开原图 img = Image.open(abs_path) - # 2. 生成旋转图 (180度) rotated_img = img.rotate(180) - # 3. 
保存旋转图 dir_name = os.path.dirname(abs_path) file_name = os.path.basename(abs_path) @@ -395,8 +390,6 @@ def recognize_card_with_qwen(image_path: str) -> dict: rotated_file_url = f"file://{rotated_path}" # 4. 构建对比 Prompt - # 发送两张图:图1=原图, 图2=旋转图 - # 询问 AI 哪一张是“正位” messages = [ { "role": "user", @@ -422,7 +415,6 @@ def recognize_card_with_qwen(image_path: str) -> dict: except Exception as e: print(f"对比图生成失败,回退到单图模式: {e}") - # 回退到旧的单图模式 messages = [ { "role": "user", @@ -433,15 +425,12 @@ def recognize_card_with_qwen(image_path: str) -> dict: } ] - # 调用模型 response = MultiModalConversation.call(model=QWEN_MODEL, messages=messages) if response.status_code == 200: content = response.output.choices[0].message.content[0]['text'] - # 尝试解析简单的 JSON import json try: - # 清理可能存在的 markdown 标记 clean_content = content.replace("```json", "").replace("```", "").strip() result = json.loads(clean_content) return result @@ -458,7 +447,6 @@ def recognize_spread_with_qwen(image_path: str) -> dict: 调用 Qwen-VL 识别塔罗牌牌阵 """ try: - # 确保路径是绝对路径并加上 file:// 前缀 abs_path = os.path.abspath(image_path) file_url = f"file://{abs_path}" @@ -476,10 +464,8 @@ def recognize_spread_with_qwen(image_path: str) -> dict: if response.status_code == 200: content = response.output.choices[0].message.content[0]['text'] - # 尝试解析简单的 JSON import json try: - # 清理可能存在的 markdown 标记 clean_content = content.replace("```json", "").replace("```", "").strip() result = json.loads(clean_content) return result @@ -491,27 +477,82 @@ def recognize_spread_with_qwen(image_path: str) -> dict: except Exception as e: return {"error": f"牌阵识别失败: {str(e)}"} -# ------------------- API 接口 (强制依赖验证) ------------------- -@app.post("/segment", dependencies=[Depends(verify_api_key)]) +# ========================================== +# 6. 
# FastAPI App Setup
# ==========================================

app = FastAPI(
    lifespan=lifespan,
    title="SAM3 Segmentation API",
    description="## 🔒 受 API Key 保护\n请点击右上角 **Authorize** 并输入: `123quant-*****`",
    openapi_tags=[
        {"name": TAG_GENERAL, "description": "General purpose image segmentation"},
        {"name": TAG_TAROT, "description": "Specialized endpoints for Tarot card recognition"},
        {"name": TAG_FACE, "description": "Face detection and analysis"},
    ]
)

# Reset any cached schema so the first call to custom_openapi builds it.
app.openapi_schema = None


def custom_openapi():
    """Build, cache and return the OpenAPI schema with API-key security.

    The schema is generated once from the app's routes, then the
    ``APIKeyHeader`` security scheme is registered and attached to every
    operation so the docs UI shows the authorize control for each endpoint.

    Returns:
        The OpenAPI schema dict, cached on ``app.openapi_schema``.
    """
    if app.openapi_schema:
        return app.openapi_schema

    from fastapi.openapi.utils import get_openapi

    openapi_schema = get_openapi(
        title=app.title,
        version=app.version,
        description=app.description,
        routes=app.routes,
    )

    # setdefault instead of direct indexing: "components" is not guaranteed
    # to be present in the generated schema, and schemes that are already
    # registered must be kept rather than replaced wholesale.
    components = openapi_schema.setdefault("components", {})
    security_schemes = components.setdefault("securitySchemes", {})
    security_schemes["APIKeyHeader"] = {
        "type": "apiKey",
        "in": "header",
        "name": API_KEY_HEADER_NAME,
    }

    # Require the API key on every operation of every path.
    for path_item in openapi_schema["paths"].values():
        for operation in path_item.values():
            # Only operation objects can carry a security requirement.
            if isinstance(operation, dict):
                operation["security"] = [{"APIKeyHeader": []}]

    app.openapi_schema = openapi_schema
    return app.openapi_schema


app.openapi = custom_openapi

# Serve generated result images from STATIC_DIR under /static.
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")

# ==========================================
# 7.
API Endpoints (接口定义) +# ========================================== + +# ------------------------------------------ +# Group 1: General Segmentation (通用分割) +# ------------------------------------------ + +@app.post("/segment", tags=[TAG_GENERAL], dependencies=[Depends(verify_api_key)]) async def segment( request: Request, - prompt: str = Form(...), - file: Optional[UploadFile] = File(None), - image_url: Optional[str] = Form(None), - save_segment_images: bool = Form(False), - cutout: bool = Form(False) + prompt: str = Form(..., description="Text prompt for segmentation (e.g., 'cat', 'person')"), + file: Optional[UploadFile] = File(None, description="Image file to upload"), + image_url: Optional[str] = Form(None, description="URL of the image"), + save_segment_images: bool = Form(False, description="Whether to save and return individual segmented objects"), + cutout: bool = Form(False, description="If True, returns transparent background PNGs; otherwise returns original crops"), + confidence: float = Form(0.7, description="Confidence threshold (0.0-1.0). Default is 0.7.") ): + """ + **通用图像分割接口** + + - 支持上传图片或提供图片 URL + - 支持自动将中文 Prompt 翻译为英文 + - 支持手动设置置信度 (Confidence Threshold) + """ if not file and not image_url: raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)") - # ------------------- Prompt 翻译/优化 ------------------- + # 1. Prompt 处理 final_prompt = prompt if not is_english(prompt): - # 使用 Qwen 进行翻译 try: - # 构建消息请求翻译 - # 注意:translate_to_sam3_prompt 是一个阻塞调用,可能会增加耗时 - # 但对于分割任务来说是可以接受的 translated = translate_to_sam3_prompt(prompt) if translated: final_prompt = translated @@ -520,6 +561,7 @@ async def segment( print(f"最终使用的 Prompt: {final_prompt}") + # 2. 图片加载 try: if file: image = Image.open(file.file).convert("RGB") @@ -529,14 +571,35 @@ async def segment( raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}") processor = request.app.state.processor - + + # 3. 
模型推理 try: + # 设置图片 inference_state = processor.set_image(image) - output = processor.set_text_prompt(state=inference_state, prompt=final_prompt) - masks, boxes, scores = output["masks"], output["boxes"], output["scores"] + + # 处理 Confidence Threshold + # 如果用户提供了 confidence,临时修改 processor 的阈值 + original_confidence = processor.confidence_threshold + if confidence is not None: + # 简单校验 + if not (0.0 <= confidence <= 1.0): + raise HTTPException(status_code=400, detail="Confidence must be between 0.0 and 1.0") + processor.confidence_threshold = confidence + print(f"Using manual confidence threshold: {confidence}") + + try: + # 执行推理 + output = processor.set_text_prompt(state=inference_state, prompt=final_prompt) + masks, boxes, scores = output["masks"], output["boxes"], output["scores"] + finally: + # 恢复默认阈值,防止影响其他请求 (虽然在单进程模型下可能不重要,但保持状态一致性是个好习惯) + if confidence is not None: + processor.confidence_threshold = original_confidence + except Exception as e: raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}") + # 4. 结果可视化与保存 try: filename = generate_and_save_result(image, inference_state) except Exception as e: @@ -544,7 +607,7 @@ async def segment( file_url = request.url_for("static", path=f"results/{filename}") - # New logic for saving segments + # 5. 
保存分割子图 (Optional) saved_segments_info = [] if save_segment_images: try: @@ -552,7 +615,6 @@ async def segment( output_dir = os.path.join(RESULT_IMAGE_DIR, request_id) os.makedirs(output_dir, exist_ok=True) - # 传递 cutout 参数 saved_objects = crop_and_save_objects( image, masks, @@ -586,7 +648,11 @@ async def segment( return JSONResponse(content=response_content) -@app.post("/segment_tarot", dependencies=[Depends(verify_api_key)]) +# ------------------------------------------ +# Group 2: Tarot Analysis (塔罗牌分析) +# ------------------------------------------ + +@app.post("/segment_tarot", tags=[TAG_TAROT], dependencies=[Depends(verify_api_key)]) async def segment_tarot( request: Request, file: Optional[UploadFile] = File(None), @@ -594,9 +660,11 @@ async def segment_tarot( expected_count: int = Form(3) ): """ - 塔罗牌分割专用接口 + **塔罗牌分割接口** + 1. 检测是否包含指定数量的塔罗牌 (默认为 3) - 2. 如果是,分别抠出这些牌并返回 + 2. 对检测到的卡片进行透视矫正和裁剪 + 3. 返回矫正后的图片 URL """ if not file and not image_url: raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)") @@ -622,7 +690,7 @@ async def segment_tarot( # 核心逻辑:判断数量 detected_count = len(masks) - # 创建本次请求的独立文件夹 (时间戳_UUID前8位) + # 创建本次请求的独立文件夹 request_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}" output_dir = os.path.join(RESULT_IMAGE_DIR, request_id) os.makedirs(output_dir, exist_ok=True) @@ -678,7 +746,7 @@ async def segment_tarot( "scores": scores.tolist() if torch.is_tensor(scores) else scores }) -@app.post("/recognize_tarot", dependencies=[Depends(verify_api_key)]) +@app.post("/recognize_tarot", tags=[TAG_TAROT], dependencies=[Depends(verify_api_key)]) async def recognize_tarot( request: Request, file: Optional[UploadFile] = File(None), @@ -686,7 +754,8 @@ async def recognize_tarot( expected_count: int = Form(3) ): """ - 塔罗牌全流程接口: 分割 + 矫正 + 识别 + **塔罗牌全流程接口: 分割 + 矫正 + 识别** + 1. 检测是否包含指定数量的塔罗牌 (SAM3) 2. 分割并透视矫正 3. 
调用 Qwen-VL 识别每张牌的名称和正逆位 @@ -706,21 +775,18 @@ async def recognize_tarot( try: inference_state = processor.set_image(image) - # 固定 Prompt 检测塔罗牌 output = processor.set_text_prompt(state=inference_state, prompt="tarot card") masks, boxes, scores = output["masks"], output["boxes"], output["scores"] except Exception as e: raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}") - # 核心逻辑:判断数量 detected_count = len(masks) - # 创建本次请求的独立文件夹 request_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}" output_dir = os.path.join(RESULT_IMAGE_DIR, request_id) os.makedirs(output_dir, exist_ok=True) - # 保存整体效果图 (无论是成功还是失败,都先保存一张主图) + # 保存整体效果图 try: main_filename = generate_and_save_result(image, inference_state, output_dir=output_dir) main_file_path = os.path.join(output_dir, main_filename) @@ -730,20 +796,14 @@ async def recognize_tarot( main_file_path = None main_file_url = None - # Step 0: 牌阵识别 (在判断数量之前或之后都可以,这里放在前面作为全局判断) + # Step 0: 牌阵识别 spread_info = {"spread_name": "Unknown"} if main_file_path: - # 使用带有mask绘制的主图或者原始图? - # 使用原始图可能更好,不受mask遮挡干扰,但是main_filename是带mask的。 - # 我们这里暂时用原始图保存一份临时文件给Qwen看 + # 使用原始图的一份拷贝给 Qwen 识别牌阵 temp_raw_path = os.path.join(output_dir, "raw_for_spread.jpg") image.save(temp_raw_path) spread_info = recognize_spread_with_qwen(temp_raw_path) - # 如果识别结果明确说是“不是正规牌阵”,是否要继续? 
- # 用户需求:“如果没有正确的牌阵则返回‘不是正规牌阵’” - # 我们将其放在返回结果中,由客户端决定是否展示警告 - if detected_count != expected_count: return JSONResponse( status_code=400, @@ -768,8 +828,7 @@ async def recognize_tarot( fname = obj["filename"] file_path = os.path.join(output_dir, fname) - # 调用 Qwen-VL 识别 - # 注意:这里会串行调用,速度可能较慢。 + # 调用 Qwen-VL 识别 (串行) recognition_res = recognize_card_with_qwen(file_path) file_url = str(request.url_for("static", path=f"results/{request_id}/{fname}")) @@ -790,15 +849,20 @@ async def recognize_tarot( "scores": scores.tolist() if torch.is_tensor(scores) else scores }) -@app.post("/segment_face", dependencies=[Depends(verify_api_key)]) +# ------------------------------------------ +# Group 3: Face Analysis (人脸分析) +# ------------------------------------------ + +@app.post("/segment_face", tags=[TAG_FACE], dependencies=[Depends(verify_api_key)]) async def segment_face( request: Request, file: Optional[UploadFile] = File(None), image_url: Optional[str] = Form(None), - prompt: str = Form("face and hair") # 默认提示词包含头发 + prompt: str = Form("face and hair", description="Prompt for face detection") ): """ - 人脸/头部检测与属性分析接口 (新功能) + **人脸/头部检测与属性分析接口** + 1. 调用 SAM3 分割出头部区域 (含头发) 2. 裁剪并保存 3. 调用 Qwen-VL 识别性别和年龄 @@ -806,7 +870,7 @@ async def segment_face( if not file and not image_url: raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)") - # ------------------- Prompt 翻译/优化 ------------------- + # Prompt 翻译/优化 final_prompt = prompt if not is_english(prompt): try: @@ -818,7 +882,7 @@ async def segment_face( print(f"Face Segment 最终使用的 Prompt: {final_prompt}") - # 1. 加载图片 + # 加载图片 try: if file: image = Image.open(file.file).convert("RGB") @@ -829,10 +893,8 @@ async def segment_face( processor = request.app.state.processor - # 2. 
调用独立服务进行处理 + # 调用独立服务进行处理 try: - # 传入 processor 和 image - # 注意:Result Image Dir 我们直接复用 RESULT_IMAGE_DIR result = human_analysis_service.process_face_segmentation_and_analysis( processor=processor, image=image, @@ -840,34 +902,34 @@ async def segment_face( output_base_dir=RESULT_IMAGE_DIR ) except Exception as e: - # 打印详细错误堆栈以便调试 import traceback traceback.print_exc() raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}") - # 3. 补全 URL (因为 service 层不知道 request context) + # 补全 URL if result["status"] == "success": - # 处理全图可视化的 URL if result.get("full_visualization"): full_vis_rel_path = result["full_visualization"] result["full_visualization"] = str(request.url_for("static", path=full_vis_rel_path)) for item in result["results"]: - # item["relative_path"] 是相对路径,如 results/xxx/xxx.jpg - # 我们需要将其转换为完整 URL - relative_path = item.pop("relative_path") # 移除相对路径字段,只返回 URL + relative_path = item.pop("relative_path") item["url"] = str(request.url_for("static", path=relative_path)) return JSONResponse(content=result) +# ========================================== +# 8. 
Main Entry Point (启动入口) +# ========================================== + if __name__ == "__main__": import uvicorn - # 注意:如果你的文件名不是 fastAPI_tarot.py,请修改下面第一个参数 + # 启动服务器 uvicorn.run( "fastAPI_tarot:app", host="127.0.0.1", port=55600, proxy_headers=True, forwarded_allow_ips="*", - reload=False # 生产环境建议关闭 reload,确保代码完全重载 - ) \ No newline at end of file + reload=False + ) diff --git a/static/results/seg_153993616fb74472a9287c68e304e8eb.jpg b/static/results/seg_153993616fb74472a9287c68e304e8eb.jpg new file mode 100644 index 0000000..30fae08 Binary files /dev/null and b/static/results/seg_153993616fb74472a9287c68e304e8eb.jpg differ diff --git a/static/results/seg_37fe4609bc984323a0284b404456f966.jpg b/static/results/seg_37fe4609bc984323a0284b404456f966.jpg deleted file mode 100644 index 9e75b44..0000000 Binary files a/static/results/seg_37fe4609bc984323a0284b404456f966.jpg and /dev/null differ diff --git a/static/results/seg_7a044baf66b1480a8c343c2b501db944.jpg b/static/results/seg_7a044baf66b1480a8c343c2b501db944.jpg new file mode 100644 index 0000000..d58cda7 Binary files /dev/null and b/static/results/seg_7a044baf66b1480a8c343c2b501db944.jpg differ diff --git a/static/results/seg_863aa7fac0cd4c5f946b7dd37f3b4f36.jpg b/static/results/seg_863aa7fac0cd4c5f946b7dd37f3b4f36.jpg new file mode 100644 index 0000000..368798f Binary files /dev/null and b/static/results/seg_863aa7fac0cd4c5f946b7dd37f3b4f36.jpg differ diff --git a/static/results/seg_c47b6e3658bb452a8778c5d8f2d8a6df.jpg b/static/results/seg_c47b6e3658bb452a8778c5d8f2d8a6df.jpg new file mode 100644 index 0000000..8d2cf1c Binary files /dev/null and b/static/results/seg_c47b6e3658bb452a8778c5d8f2d8a6df.jpg differ diff --git a/static/results/seg_f6cd00ded6b54337af288a844f8890aa.jpg b/static/results/seg_f6cd00ded6b54337af288a844f8890aa.jpg new file mode 100644 index 0000000..d58cda7 Binary files /dev/null and b/static/results/seg_f6cd00ded6b54337af288a844f8890aa.jpg differ