""" SAM3 Segmentation API Service (SAM3 图像分割 API 服务) -------------------------------------------------- 本服务提供基于 SAM3 模型的图像分割能力。 包含通用分割、塔罗牌分析和人脸分析等专用接口。 This service provides image segmentation capabilities using the SAM3 model. It includes specialized endpoints for Tarot card analysis and Face analysis. """ # ========================================== # 1. Imports (导入) # ========================================== # Standard Library Imports (标准库) import os import uuid import time import json import traceback import re import asyncio from typing import Optional, List, Dict, Any from contextlib import asynccontextmanager # Third-Party Imports (第三方库) import cv2 import torch import numpy as np import requests import matplotlib matplotlib.use('Agg') # 使用非交互式后端,防止在服务器上报错 import matplotlib.pyplot as plt from PIL import Image # FastAPI Imports from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request, Depends, status, APIRouter from fastapi.security import APIKeyHeader from fastapi.staticfiles import StaticFiles from fastapi.responses import JSONResponse # Dashscope (Aliyun Qwen) Imports import dashscope from dashscope import MultiModalConversation # Local Imports (本地模块) from sam3.model_builder import build_sam3_image_model from sam3.model.sam3_image_processor import Sam3Processor from sam3.visualization_utils import plot_results import human_analysis_service # 引入新服务: 人脸分析 # ========================================== # 2. 
# ==========================================
# 2. Configuration & Constants (配置与常量)
# ==========================================

# Path configuration: result images are written under static/results and
# served back to clients through the /static mount.
STATIC_DIR = "static"
RESULT_IMAGE_DIR = os.path.join(STATIC_DIR, "results")
os.makedirs(RESULT_IMAGE_DIR, exist_ok=True)

# API key configuration.
# SECURITY: these secrets were previously hard-coded. They can now be
# overridden via environment variables; the original literals remain as
# defaults so existing deployments keep working unchanged.
VALID_API_KEY = os.getenv("SAM3_API_KEY", "123quant-speed")
API_KEY_HEADER_NAME = "X-API-Key"

# Dashscope (Qwen-VL) configuration.
dashscope.api_key = os.getenv("DASHSCOPE_API_KEY", "sk-ce2404f55f744a1987d5ece61c6bac58")
QWEN_MODEL = 'qwen-vl-max'

# API tags (used to group endpoints in the OpenAPI docs).
TAG_GENERAL = "General Segmentation (通用分割)"
TAG_TAROT = "Tarot Analysis (塔罗牌分析)"
TAG_FACE = "Face Analysis (人脸分析)"

# ==========================================
# 3. Security & Middleware (安全与中间件)
# ==========================================

# Header-based auth scheme. auto_error=False lets us distinguish a missing
# key (401) from an invalid one (403) instead of FastAPI's generic error.
api_key_header = APIKeyHeader(name=API_KEY_HEADER_NAME, auto_error=False)


async def verify_api_key(api_key: Optional[str] = Depends(api_key_header)) -> bool:
    """Enforce API Key verification (强制验证 API Key).

    1. Check that the API key header is present.
    2. Check that the key matches the configured value.

    Returns:
        True when the key is valid (consumed as a route dependency).

    Raises:
        HTTPException: 401 when the key is missing, 403 when it is invalid.
    """
    if not api_key:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing API Key. Please provide it in the header."
        )
    if api_key != VALID_API_KEY:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid API Key."
        )
    return True

# ==========================================
# 4.
# ==========================================
# 4. Lifespan Management (生命周期管理)
# ==========================================

async def cleanup_old_files(directory: str, lifetime_seconds: int, interval_seconds: int):
    """Background task: periodically delete expired image files.

    Args:
        directory: Root directory to scan recursively.
        lifetime_seconds: Files whose mtime is older than this are removed.
        interval_seconds: Delay between scans.
    """
    print(f"🧹 自动清理任务已启动 | 目录: {directory} | 生命周期: {lifetime_seconds}s | 检查间隔: {interval_seconds}s")
    while True:
        try:
            await asyncio.sleep(interval_seconds)
            current_time = time.time()
            count = 0
            # Walk every file, including subdirectories.
            for root, dirs, files in os.walk(directory):
                for file in files:
                    file_path = os.path.join(root, file)
                    try:
                        file_mtime = os.path.getmtime(file_path)
                        if current_time - file_mtime > lifetime_seconds:
                            os.remove(file_path)
                            count += 1
                    except OSError:
                        pass  # File may have been deleted concurrently.
            # Best-effort removal of now-empty directories (bottom-up).
            for root, dirs, files in os.walk(directory, topdown=False):
                for sub_dir in dirs:  # renamed from `dir` to avoid shadowing the builtin
                    dir_path = os.path.join(root, sub_dir)
                    try:
                        if not os.listdir(dir_path):  # directory is empty
                            os.rmdir(dir_path)
                    except OSError:
                        pass
            if count > 0:
                print(f"🧹 已清理 {count} 个过期文件")
        except asyncio.CancelledError:
            print("🛑 清理任务已停止")
            break
        except Exception as e:
            print(f"⚠️ 清理任务出错: {e}")
            await asyncio.sleep(60)  # Back off for a minute after an error.


@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan manager.

    - On startup: load the SAM3 model onto GPU/CPU and optionally start the
      background cleanup task (controlled by environment variables).
    - On shutdown: cancel the cleanup task.
    """
    print("="*40)
    print("✅ API Key 保护已激活")
    print(f"✅ 有效 Key: {VALID_API_KEY}")
    print("="*40)
    print("正在加载 SAM3 模型到 GPU...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if not torch.cuda.is_available():
        print("警告: 未检测到 GPU,将使用 CPU,速度会较慢。")
    # Load the model.
    model = build_sam3_image_model()
    model = model.to(device)
    model.eval()
    # Initialize the processor.
    processor = Sam3Processor(model)
    # Store on app.state so request handlers can reach the shared model.
    app.state.model = model
    app.state.processor = processor
    app.state.device = device
    print(f"模型加载完成,设备: {device}")

    # --- Start the background cleanup task ---
    cleanup_task_handle = None
    # Environment variables take precedence; otherwise defaults are used.
    if os.getenv("AUTO_CLEANUP_ENABLED", "False").lower() == "true":
        try:
            lifetime = int(os.getenv("FILE_LIFETIME_SECONDS", "3600"))
            interval = int(os.getenv("CLEANUP_INTERVAL_SECONDS", "600"))
            cleanup_task_handle = asyncio.create_task(
                cleanup_old_files(RESULT_IMAGE_DIR, lifetime, interval)
            )
        except Exception as e:
            print(f"启动清理任务失败: {e}")
    # -----------------------

    yield

    print("正在清理资源...")
    # --- Stop the background task ---
    if cleanup_task_handle:
        cleanup_task_handle.cancel()
        try:
            await cleanup_task_handle
        except asyncio.CancelledError:
            pass
    # ------------------
    # GPU memory release could be added here if ever needed.

# ==========================================
# 5. Helper Functions (辅助函数)
# ==========================================

def is_english(text: str) -> bool:
    """Return True when *text* contains no CJK characters (\\u4e00-\\u9fff)."""
    for char in text:
        if '\u4e00' <= char <= '\u9fff':
            return False
    return True


def translate_to_sam3_prompt(text: str) -> str:
    """Translate a Chinese prompt into English using the Qwen model.

    SAM3 handles English prompts better. Falls back to the original text
    on any API failure.
    """
    print(f"正在翻译提示词: {text}")
    try:
        messages = [
            {
                "role": "user",
                "content": [
                    {"text": f"请将以下描述翻译成简洁、精准的英文,用于图像分割模型(SAM)的提示词。直接返回英文,不要包含任何解释或其他文字。\n\n输入: {text}"}
                ]
            }
        ]
        response = MultiModalConversation.call(model=QWEN_MODEL, messages=messages)
        if response.status_code == 200:
            translated_text = response.output.choices[0].message.content[0]['text'].strip()
            # Strip stray markdown markers or quotes from the reply.
            translated_text = translated_text.replace('"', '').replace("'", "").strip()
            print(f"翻译结果: {translated_text}")
            return translated_text
        else:
            print(f"翻译失败: {response.code} - {response.message}")
            return text  # Fall back to the original text.
    except Exception as e:
        print(f"翻译异常: {e}")
        return text


def order_points(pts):
    """Order four points as: top-left, top-right, bottom-right, bottom-left.

    Used to normalize corner order before a perspective transform.
    """
    rect = np.zeros((4, 2), dtype="float32")
    # Top-left has the smallest x+y sum, bottom-right the largest.
    s = pts.sum(axis=1)
    rect[0] = pts[np.argmin(s)]
    rect[2] = pts[np.argmax(s)]
    # Top-right has the smallest y-x difference, bottom-left the largest.
    diff = np.diff(pts, axis=1)
    rect[1] = pts[np.argmin(diff)]
    rect[3] = pts[np.argmax(diff)]
    return rect


def four_point_transform(image, pts):
    """Apply a perspective transform defined by four corner points.

    Used to rectify a tilted card into an axis-aligned rectangle.
    """
    rect = order_points(pts)
    (tl, tr, br, bl) = rect
    # Width of the output image: the longer of the two horizontal edges.
    widthA = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2))
    widthB = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2))
    maxWidth = max(int(widthA), int(widthB))
    # Height of the output image: the longer of the two vertical edges.
    heightA = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2))
    heightB = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2))
    maxHeight = max(int(heightA), int(heightB))
    dst = np.array([
        [0, 0],
        [maxWidth - 1, 0],
        [maxWidth - 1, maxHeight - 1],
        [0, maxHeight - 1]], dtype="float32")
    M = cv2.getPerspectiveTransform(rect, dst)
    warped = cv2.warpPerspective(image, M, (maxWidth, maxHeight))
    return warped


def load_image_from_url(url: str) -> Image.Image:
    """Download an image from *url* and return it as an RGB PIL Image.

    Raises:
        HTTPException: 400 when the download or decode fails.
    """
    try:
        headers = {'User-Agent': 'Mozilla/5.0'}
        response = requests.get(url, headers=headers, stream=True, timeout=10)
        response.raise_for_status()
        # BUG FIX: with stream=True, `response.raw` yields the raw transport
        # bytes. Enable content decoding so gzip/deflate-encoded responses
        # don't hand corrupted bytes to PIL.
        response.raw.decode_content = True
        image = Image.open(response.raw).convert("RGB")
        return image
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"无法下载图片: {str(e)}")


def crop_and_save_objects(image: Image.Image, masks, boxes, output_dir: str = RESULT_IMAGE_DIR, is_tarot: bool = True, cutout: bool = False) -> list[dict]:
    """Save one image per detected object, cut out or rectified.

    Args:
        image: Original PIL image.
        masks: Segmentation masks (torch tensors or numpy arrays).
        boxes: Bounding boxes aligned with *masks*.
        output_dir: Directory to write results into.
        is_tarot: Tarot mode — affects filename prefix and forces portrait
            orientation after rectification.
        cutout: When True, produce transparent-background PNG cutouts;
            otherwise apply perspective rectification (for cards).

    Returns:
        A list of dicts describing each saved object.
    """
    saved_objects = []
    # Work on an RGB numpy copy for the OpenCV path.
    img_arr = np.array(image)
    for i, (mask, box) in enumerate(zip(masks, boxes)):
        # Normalize mask to a 2-D numpy array.
        if isinstance(mask, torch.Tensor):
            mask_np = mask.cpu().numpy().squeeze()
        else:
            mask_np = mask.squeeze()
        # Normalize box to a numpy array.
        if isinstance(box, torch.Tensor):
            box_np = box.cpu().numpy()
        else:
            box_np = np.array(box)
        # Binary uint8 mask for OpenCV/Pillow.
        if mask_np.dtype == bool:
            mask_uint8 = (mask_np * 255).astype(np.uint8)
        else:
            mask_uint8 = (mask_np > 0.5).astype(np.uint8) * 255
        if cutout:
            # --- Cutout mode (transparent background) ---
            # 1. Ensure an RGBA base (used for its size below).
            if image.mode != "RGBA":
                img_rgba = image.convert("RGBA")
            else:
                img_rgba = image.copy()
            # 2. Build the alpha mask.
            mask_img = Image.fromarray(mask_uint8, mode='L')
            # 3. Paste the RGB image through the mask onto a fully
            #    transparent canvas.
            cutout_img = Image.new("RGBA", img_rgba.size, (0, 0, 0, 0))
            cutout_img.paste(image.convert("RGB"), (0, 0), mask=mask_img)
            # 4. Crop to the (clamped) bounding box.
            x1, y1, x2, y2 = map(int, box_np)
            w, h = cutout_img.size
            x1 = max(0, x1); y1 = max(0, y1)
            x2 = min(w, x2); y2 = min(h, y2)
            if x2 > x1 and y2 > y1:
                final_img = cutout_img.crop((x1, y1, x2, y2))
            else:
                final_img = cutout_img  # Fallback: degenerate box.
            prefix = "cutout"
            is_rotated = False
            filename = f"{prefix}_{uuid.uuid4().hex}_{i}.png"
            save_path = os.path.join(output_dir, filename)
            final_img.save(save_path)
            saved_objects.append({
                "filename": filename,
                "is_rotated_by_algorithm": is_rotated,
                "note": "Mask cutout applied. Background removed."
            })
        else:
            # --- Rectification mode (perspective correction) ---
            contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            if not contours:
                continue
            c = max(contours, key=cv2.contourArea)
            peri = cv2.arcLength(c, True)
            approx = cv2.approxPolyDP(c, 0.04 * peri, True)
            if len(approx) == 4:
                # Clean quadrilateral — use its corners directly.
                pts = approx.reshape(4, 2)
            else:
                # Otherwise fall back to the minimum-area rotated rect.
                rect = cv2.minAreaRect(c)
                pts = cv2.boxPoints(rect)
            warped = four_point_transform(img_arr, pts)
            # Orientation check (portrait vs landscape).
            h, w = warped.shape[:2]
            is_rotated = False
            # Force portrait in tarot mode (tarot cards are portrait).
            if is_tarot and w > h:
                warped = cv2.rotate(warped, cv2.ROTATE_90_CLOCKWISE)
                is_rotated = True
            pil_warped = Image.fromarray(warped)
            prefix = "tarot" if is_tarot else "segment"
            filename = f"{prefix}_{uuid.uuid4().hex}_{i}.png"
            save_path = os.path.join(output_dir, filename)
            pil_warped.save(save_path)
            saved_objects.append({
                "filename": filename,
                "is_rotated_by_algorithm": is_rotated,
                "note": "Geometric correction applied."
            })
    return saved_objects


def generate_and_save_result(image: Image.Image, inference_state, output_dir: str = RESULT_IMAGE_DIR) -> str:
    """Render and save the full mask-overlay visualization.

    Returns the generated filename (relative to *output_dir*).
    NOTE(review): relies on matplotlib's implicit current figure, which is
    not safe under concurrent requests — confirm single-worker deployment.
    """
    filename = f"seg_{uuid.uuid4().hex}.jpg"
    save_path = os.path.join(output_dir, filename)
    plot_results(image, inference_state)
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    return filename


def recognize_card_with_qwen(image_path: str) -> dict:
    """Recognize a tarot card with Qwen-VL (upright-vs-reversed comparison).

    Sends the original image and a 180°-rotated copy so the model can
    decide which orientation is upright. Returns a dict with 'name' and
    'position' on success, 'raw_response' when the reply is not valid
    JSON, or 'error' on API failure.
    """
    try:
        abs_path = os.path.abspath(image_path)
        file_url = f"file://{abs_path}"
        try:
            # 1. Open the original image.
            img = Image.open(abs_path)
            # 2. Create the 180°-rotated copy.
            rotated_img = img.rotate(180)
            # 3. Save the rotated copy next to the original.
            dir_name = os.path.dirname(abs_path)
            file_name = os.path.basename(abs_path)
            rotated_name = f"rotated_{file_name}"
            rotated_path = os.path.join(dir_name, rotated_name)
            rotated_img.save(rotated_path)
            rotated_file_url = f"file://{rotated_path}"
            # 4. Build the comparison prompt.
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"image": file_url},          # image 1 (original)
                        {"image": rotated_file_url},  # image 2 (rotated 180°)
                        {"text": """这是一张塔罗牌的两个方向:
图1:原始方向
图2:旋转180度后的方向
请仔细对比两张图片的牌面内容(文字方向、人物站立方向、图案逻辑):
1. 识别这张牌的名字(中文)。
2. 判断哪一张图片展示了正确的“正位”(Upright)状态。
- 如果图1是正位,说明原图就是正位。
- 如果图2是正位,说明原图是逆位。
请以JSON格式返回,包含 'name' 和 'position' 两个字段。
例如:{'name': '愚者', 'position': '正位'} 或 {'name': '倒吊人', 'position': '逆位'}。
不要包含Markdown代码块标记。"""}
                    ]
                }
            ]
        except Exception as e:
            print(f"对比图生成失败,回退到单图模式: {e}")
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"image": file_url},
                        {"text": "这是一张塔罗牌。请识别它的名字(中文),并判断它是正位还是逆位。请以JSON格式返回,包含 'name' 和 'position' 两个字段。例如:{'name': '愚者', 'position': '正位'}。不要包含Markdown代码块标记。"}
                    ]
                }
            ]
        response = MultiModalConversation.call(model=QWEN_MODEL, messages=messages)
        if response.status_code == 200:
            content = response.output.choices[0].message.content[0]['text']
            # `json` is imported at module level; the redundant local import
            # was removed. The model sometimes wraps its reply in markdown
            # fences — strip them before parsing.
            try:
                clean_content = content.replace("```json", "").replace("```", "").strip()
                result = json.loads(clean_content)
                return result
            except (ValueError, TypeError):  # narrowed from a bare `except:`
                return {"raw_response": content}
        else:
            return {"error": f"API Error: {response.code} - {response.message}"}
    except Exception as e:
        return {"error": f"识别失败: {str(e)}"}


def recognize_spread_with_qwen(image_path: str) -> dict:
    """Recognize the tarot spread layout in an image with Qwen-VL.

    Returns a dict with 'spread_name' and 'description' on success,
    'raw_response' when the reply is not valid JSON, or 'error' on
    API failure.
    """
    try:
        abs_path = os.path.abspath(image_path)
        file_url = f"file://{abs_path}"
        messages = [
            {
                "role": "user",
                "content": [
                    {"image": file_url},
                    {"text": "这是一张包含多张塔罗牌的图片。请根据牌的排列方式识别这是什么牌阵(例如:圣三角、凯尔特十字、三张牌等)。如果看不出明显的正规牌阵,请返回“不是正规牌阵”。请以JSON格式返回,包含 'spread_name' 和 'description' 两个字段。例如:{'spread_name': '圣三角', 'description': '常见的时间流占卜法'}。不要包含Markdown代码块标记。"}
                ]
            }
        ]
        response = MultiModalConversation.call(model=QWEN_MODEL, messages=messages)
        if response.status_code == 200:
            content = response.output.choices[0].message.content[0]['text']
            # `json` is imported at module level; redundant local import removed.
            try:
                clean_content = content.replace("```json", "").replace("```", "").strip()
                result = json.loads(clean_content)
                return result
            except (ValueError, TypeError):  # narrowed from a bare `except:`
                return {"raw_response": content, "spread_name": "Unknown"}
        else:
            return {"error": f"API Error: {response.code} - {response.message}"}
    except Exception as e:
        return {"error": f"牌阵识别失败: {str(e)}"}

# ==========================================
# 6.
# ==========================================
# 6. FastAPI App Setup (应用初始化)
# ==========================================

# Application instance; `lifespan` loads the SAM3 model on startup.
app = FastAPI(
    lifespan=lifespan,
    title="SAM3 Segmentation API",
    description="## 🔒 受 API Key 保护\n请点击右上角 **Authorize** 并输入: `123quant-*****`",
    openapi_tags=[
        {"name": TAG_GENERAL, "description": "General purpose image segmentation"},
        {"name": TAG_TAROT, "description": "Specialized endpoints for Tarot card recognition"},
        {"name": TAG_FACE, "description": "Face detection and analysis"},
    ]
)

# Manually inject the OpenAPI security configuration.
# Reset the cached schema so our custom builder below runs on first access.
app.openapi_schema = None

def custom_openapi():
    """Build and cache an OpenAPI schema that declares the API-key header.

    Adds an `APIKeyHeader` security scheme and marks every operation as
    requiring it, so Swagger UI's Authorize button sends the key.
    """
    if app.openapi_schema:
        return app.openapi_schema
    from fastapi.openapi.utils import get_openapi
    openapi_schema = get_openapi(
        title=app.title,
        version=app.version,
        description=app.description,
        routes=app.routes,
    )
    openapi_schema["components"]["securitySchemes"] = {
        "APIKeyHeader": {
            "type": "apiKey",
            "in": "header",
            "name": API_KEY_HEADER_NAME,
        }
    }
    # Require the API key on every path/method in the schema.
    for path in openapi_schema["paths"]:
        for method in openapi_schema["paths"][path]:
            openapi_schema["paths"][path][method]["security"] = [{"APIKeyHeader": []}]
    app.openapi_schema = openapi_schema
    return app.openapi_schema

app.openapi = custom_openapi

# Serve generated result images from the static directory.
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")

# ==========================================
# 7.
# ==========================================
# 7. API Endpoints (接口定义)
# ==========================================

# ------------------------------------------
# Group 1: General Segmentation (通用分割)
# ------------------------------------------

@app.post("/segment", tags=[TAG_GENERAL], dependencies=[Depends(verify_api_key)])
async def segment(
    request: Request,
    prompt: str = Form(..., description="Text prompt for segmentation (e.g., 'cat', 'person')"),
    file: Optional[UploadFile] = File(None, description="Image file to upload"),
    image_url: Optional[str] = Form(None, description="URL of the image"),
    save_segment_images: bool = Form(False, description="Whether to save and return individual segmented objects"),
    cutout: bool = Form(False, description="If True, returns transparent background PNGs; otherwise returns original crops"),
    highlight: bool = Form(False, description="If True, darkens the background to highlight the subject (周边变黑放大)."),
    confidence: float = Form(0.7, description="Confidence threshold (0.0-1.0). Default is 0.7.")
):
    """General-purpose image segmentation endpoint.

    - Accepts either an uploaded file or an image URL.
    - Automatically translates non-English prompts to English.
    - Optional highlight mode (darkens the background around the subject).
    - Optional manual confidence threshold.
    """
    if not file and not image_url:
        raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")
    # 1. Prompt handling: translate to English when needed.
    final_prompt = prompt
    if not is_english(prompt):
        try:
            translated = translate_to_sam3_prompt(prompt)
            if translated:
                final_prompt = translated
        except Exception as e:
            print(f"Prompt翻译失败,使用原始Prompt: {e}")
    print(f"最终使用的 Prompt: {final_prompt}")
    # 2. Load the image.
    try:
        if file:
            image = Image.open(file.file).convert("RGB")
        elif image_url:
            image = load_image_from_url(image_url)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")
    processor = request.app.state.processor
    # 3. Model inference.
    try:
        inference_state = processor.set_image(image)
        # Temporarily override the processor's confidence threshold.
        original_confidence = processor.confidence_threshold
        if confidence is not None:
            if not (0.0 <= confidence <= 1.0):
                raise HTTPException(status_code=400, detail="Confidence must be between 0.0 and 1.0")
            processor.confidence_threshold = confidence
            print(f"Using manual confidence threshold: {confidence}")
        try:
            output = processor.set_text_prompt(state=inference_state, prompt=final_prompt)
            masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
        finally:
            # Restore the default threshold so other requests are unaffected.
            if confidence is not None:
                processor.confidence_threshold = original_confidence
    except HTTPException:
        # BUG FIX: the validation 400 above was previously caught by the
        # broad handler below and re-raised as a 500 — re-raise it as-is.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")
    # 4. Visualization and saving.
    try:
        if highlight:
            filename = f"seg_highlight_{uuid.uuid4().hex}.jpg"
            save_path = os.path.join(RESULT_IMAGE_DIR, filename)
            # Use the darkened-background visualization from the face service.
            human_analysis_service.create_highlighted_visualization(image, masks, save_path)
        else:
            filename = generate_and_save_result(image, inference_state)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"绘图保存错误: {str(e)}")
    # BUG FIX: the URL previously contained a literal "(unknown)" placeholder
    # instead of the generated filename.
    file_url = request.url_for("static", path=f"results/{filename}")
    # 5. Optionally save individual segment crops.
    saved_segments_info = []
    if save_segment_images:
        try:
            request_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
            output_dir = os.path.join(RESULT_IMAGE_DIR, request_id)
            os.makedirs(output_dir, exist_ok=True)
            saved_objects = crop_and_save_objects(
                image, masks, boxes,
                output_dir=output_dir,
                is_tarot=False,
                cutout=cutout
            )
            for obj in saved_objects:
                fname = obj["filename"]
                seg_url = str(request.url_for("static", path=f"results/{request_id}/{fname}"))
                saved_segments_info.append({
                    "url": seg_url,
                    "filename": fname,
                    "note": obj.get("note", "")
                })
        except Exception as e:
            print(f"Error saving segments: {e}")
            raise HTTPException(status_code=500, detail=f"保存分割图片失败: {str(e)}")
    response_content = {
        "status": "success",
        "result_image_url": str(file_url),
        "detected_count": len(masks),
        "scores": scores.tolist() if torch.is_tensor(scores) else scores
    }
    if save_segment_images:
        response_content["segmented_images"] = saved_segments_info
    return JSONResponse(content=response_content)

# ------------------------------------------
# Group 2: Tarot Analysis (塔罗牌分析)
# ------------------------------------------

@app.post("/segment_tarot", tags=[TAG_TAROT], dependencies=[Depends(verify_api_key)])
async def segment_tarot(
    request: Request,
    file: Optional[UploadFile] = File(None),
    image_url: Optional[str] = Form(None),
    expected_count: int = Form(3)
):
    """Tarot card segmentation endpoint.

    1. Detect whether the image contains exactly *expected_count* cards.
    2. Perspective-correct and crop each detected card.
    3. Return the URLs of the corrected images.
    """
    if not file and not image_url:
        raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")
    try:
        if file:
            image = Image.open(file.file).convert("RGB")
        elif image_url:
            image = load_image_from_url(image_url)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")
    processor = request.app.state.processor
    try:
        inference_state = processor.set_image(image)
        # Fixed prompt for tarot detection.
        output = processor.set_text_prompt(state=inference_state, prompt="tarot card")
        masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")
    # Core rule: the detected count must match exactly.
    detected_count = len(masks)
    # Per-request output folder.
    request_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
    output_dir = os.path.join(RESULT_IMAGE_DIR, request_id)
    os.makedirs(output_dir, exist_ok=True)
    if detected_count != expected_count:
        # Save one debug image for feedback; best-effort.
        try:
            filename = generate_and_save_result(image, inference_state, output_dir=output_dir)
            # BUG FIX: the URL previously contained a literal "(unknown)"
            # placeholder instead of the debug filename.
            file_url = request.url_for("static", path=f"results/{request_id}/{filename}")
        except Exception:  # narrowed from a bare `except:`
            file_url = None
        return JSONResponse(
            status_code=400,
            content={
                "status": "failed",
                "message": f"检测到 {detected_count} 个目标,需要严格的 {expected_count} 张塔罗牌。请调整拍摄角度或背景。",
                "detected_count": detected_count,
                "debug_image_url": str(file_url) if file_url else None
            }
        )
    # Count matches — crop and rectify each card.
    try:
        saved_objects = crop_and_save_objects(image, masks, boxes, output_dir=output_dir)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"抠图处理错误: {str(e)}")
    # Build the URL list plus per-card metadata.
    tarot_cards = []
    for obj in saved_objects:
        fname = obj["filename"]
        file_url = str(request.url_for("static", path=f"results/{request_id}/{fname}"))
        tarot_cards.append({
            "url": file_url,
            "is_rotated": obj["is_rotated_by_algorithm"],
            "orientation_status": "corrected_to_portrait" if obj["is_rotated_by_algorithm"] else "original_portrait",
            "note": obj["note"]
        })
    # Full overview visualization; best-effort.
    try:
        main_filename = generate_and_save_result(image, inference_state, output_dir=output_dir)
        main_file_url = str(request.url_for("static", path=f"results/{request_id}/{main_filename}"))
    except Exception:  # narrowed from a bare `except:`
        main_file_url = None
    return JSONResponse(content={
        "status": "success",
        "message": f"成功识别并分割 {expected_count} 张塔罗牌 (已执行透视矫正)",
        "tarot_cards": tarot_cards,
        "full_visualization": main_file_url,
        "scores": scores.tolist() if torch.is_tensor(scores) else scores
    })


@app.post("/recognize_tarot", tags=[TAG_TAROT], dependencies=[Depends(verify_api_key)])
async def recognize_tarot(
    request: Request,
    file: Optional[UploadFile] = File(None),
    image_url: Optional[str] = Form(None),
    expected_count: int = Form(3)
):
    """Full tarot pipeline: segmentation + rectification + recognition.

    1. Detect whether the image contains exactly *expected_count* cards (SAM3).
    2. Crop and perspective-correct each card.
    3. Call Qwen-VL to recognize each card's name and orientation.
    """
    if not file and not image_url:
        raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")
    try:
        if file:
            image = Image.open(file.file).convert("RGB")
        elif image_url:
            image = load_image_from_url(image_url)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")
    processor = request.app.state.processor
    try:
        inference_state = processor.set_image(image)
        output = processor.set_text_prompt(state=inference_state, prompt="tarot card")
        masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")
    detected_count = len(masks)
    request_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
    output_dir = os.path.join(RESULT_IMAGE_DIR, request_id)
    os.makedirs(output_dir, exist_ok=True)
    # Save the full overview visualization; best-effort.
    try:
        main_filename = generate_and_save_result(image, inference_state, output_dir=output_dir)
        main_file_path = os.path.join(output_dir, main_filename)
        main_file_url = str(request.url_for("static", path=f"results/{request_id}/{main_filename}"))
    except Exception:  # narrowed from a bare `except:`
        main_filename = None
        main_file_path = None
        main_file_url = None
    # Step 0: spread recognition on a raw copy of the input image.
    spread_info = {"spread_name": "Unknown"}
    if main_file_path:
        temp_raw_path = os.path.join(output_dir, "raw_for_spread.jpg")
        image.save(temp_raw_path)
        spread_info = recognize_spread_with_qwen(temp_raw_path)
    if detected_count != expected_count:
        return JSONResponse(
            status_code=400,
            content={
                "status": "failed",
                "message": f"检测到 {detected_count} 个目标,需要严格的 {expected_count} 张塔罗牌。请调整拍摄角度或背景。",
                "detected_count": detected_count,
                "spread_info": spread_info,
                "debug_image_url": str(main_file_url) if main_file_url else None
            }
        )
    # Count matches — crop + rectify.
    try:
        saved_objects = crop_and_save_objects(image, masks, boxes, output_dir=output_dir)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"抠图处理错误: {str(e)}")
    # Recognize each card with Qwen-VL (serially).
    tarot_cards = []
    for obj in saved_objects:
        fname = obj["filename"]
        file_path = os.path.join(output_dir, fname)
        recognition_res = recognize_card_with_qwen(file_path)
        file_url = str(request.url_for("static", path=f"results/{request_id}/{fname}"))
        tarot_cards.append({
            "url": file_url,
            "is_rotated": obj["is_rotated_by_algorithm"],
            "orientation_status": "corrected_to_portrait" if obj["is_rotated_by_algorithm"] else "original_portrait",
            "recognition": recognition_res,
            "note": obj["note"]
        })
    return JSONResponse(content={
        "status": "success",
        "message": f"成功识别并分割 {expected_count} 张塔罗牌 (含Qwen识别结果)",
        "spread_info": spread_info,
        "tarot_cards": tarot_cards,
        "full_visualization": main_file_url,
        "scores": scores.tolist() if torch.is_tensor(scores) else scores
    })

# ------------------------------------------
# Group 3: Face Analysis (人脸分析)
# ------------------------------------------

@app.post("/segment_face", tags=[TAG_FACE], dependencies=[Depends(verify_api_key)])
async def segment_face(
    request: Request,
    file: Optional[UploadFile] = File(None),
    image_url: Optional[str] = Form(None),
    prompt: str = Form("face and hair", description="Prompt for face detection")
):
    """Face/head detection and attribute analysis endpoint.

    1. Segment the head region (including hair) with SAM3.
    2. Crop and save each region.
    3. Call Qwen-VL to estimate gender and age.
    """
    if not file and not image_url:
        raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")
    # Prompt translation / normalization.
    final_prompt = prompt
    if not is_english(prompt):
        try:
            translated = translate_to_sam3_prompt(prompt)
            if translated:
                final_prompt = translated
        except Exception as e:
            print(f"Prompt翻译失败,使用原始Prompt: {e}")
    print(f"Face Segment 最终使用的 Prompt: {final_prompt}")
    # Load the image.
    try:
        if file:
            image = Image.open(file.file).convert("RGB")
        elif image_url:
            image = load_image_from_url(image_url)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")
    processor = request.app.state.processor
    # Delegate the heavy lifting to the dedicated face-analysis service.
    try:
        result = human_analysis_service.process_face_segmentation_and_analysis(
            processor=processor,
            image=image,
            prompt=final_prompt,
            output_base_dir=RESULT_IMAGE_DIR
        )
    except Exception as e:
        # `traceback` is imported at module level; the redundant local
        # import was removed.
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}")
    # Turn the service's relative paths into absolute URLs.
    if result["status"] == "success":
        if result.get("full_visualization"):
            full_vis_rel_path = result["full_visualization"]
            result["full_visualization"] = str(request.url_for("static", path=full_vis_rel_path))
        for item in result["results"]:
            relative_path = item.pop("relative_path")
            item["url"] = str(request.url_for("static", path=relative_path))
    return JSONResponse(content=result)

# ==========================================
# 8.
# ==========================================
# 8. Main Entry Point (启动入口)
# ==========================================

if __name__ == "__main__":
    import uvicorn

    # ==========================================
    # Auto Cleanup Config (自动清理图片配置)
    # ==========================================
    # Enable/disable automatic cleanup of result images (True/False).
    os.environ["AUTO_CLEANUP_ENABLED"] = "True"
    # Image lifetime in seconds; files older than this are deleted.
    # e.g. 3600 = 1 hour, 86400 = 1 day.
    os.environ["FILE_LIFETIME_SECONDS"] = "3600"
    # Interval in seconds between cleanup scans.
    os.environ["CLEANUP_INTERVAL_SECONDS"] = "600"
    # ==========================================

    # Start the server.
    # NOTE(review): the import string assumes this module is named
    # fastAPI_tarot.py — confirm against the actual filename.
    uvicorn.run(
        "fastAPI_tarot:app",
        host="127.0.0.1",
        port=55600,
        proxy_headers=True,
        forwarded_allow_ips="*",
        reload=False
    )