""" SAM3 Segmentation API Service (SAM3 图像分割 API 服务) -------------------------------------------------- 本服务提供基于 SAM3 模型的图像分割能力。 包含通用分割、塔罗牌分析和人脸分析等专用接口。 This service provides image segmentation capabilities using the SAM3 model. It includes specialized endpoints for Tarot card analysis and Face analysis. """ # ========================================== # 1. Imports (导入) # ========================================== # Standard Library Imports (标准库) import os import uuid import time import json import traceback import re import asyncio import shutil import subprocess from datetime import datetime from typing import Optional, List, Dict, Any from contextlib import asynccontextmanager # Third-Party Imports (第三方库) import cv2 import torch import numpy as np import requests import matplotlib matplotlib.use('Agg') # 使用非交互式后端,防止在服务器上报错 import matplotlib.pyplot as plt from PIL import Image # FastAPI Imports from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request, Depends, status, APIRouter, Cookie from fastapi.security import APIKeyHeader from fastapi.staticfiles import StaticFiles from fastapi.responses import JSONResponse, HTMLResponse, Response # Dashscope (Aliyun Qwen) Imports import dashscope from dashscope import MultiModalConversation # Local Imports (本地模块) from sam3.model_builder import build_sam3_image_model from sam3.model.sam3_image_processor import Sam3Processor from sam3.visualization_utils import plot_results import human_analysis_service # 引入新服务: 人脸分析 # ========================================== # 2. 
# ==========================================
# 2. Configuration & Constants
# ==========================================

# Path configuration: result images are written under static/results so they
# can be served back to clients through the /static mount.
STATIC_DIR = "static"
RESULT_IMAGE_DIR = os.path.join(STATIC_DIR, "results")
os.makedirs(RESULT_IMAGE_DIR, exist_ok=True)

# API key configuration.
# SECURITY: secrets are now read from the environment; the previous
# hard-coded literals are kept only as fallbacks for backward compatibility.
# Any key that was ever committed to version control should be rotated.
VALID_API_KEY = os.getenv("SAM3_API_KEY", "123quant-speed")
API_KEY_HEADER_NAME = "X-API-Key"

# Admin config
ADMIN_PASSWORD = os.getenv("ADMIN_PASSWORD", "admin_secure_password")
HISTORY_FILE = "history.json"

# Dashscope (Qwen-VL) configuration.
dashscope.api_key = os.getenv("DASHSCOPE_API_KEY", "sk-ce2404f55f744a1987d5ece61c6bac58")
QWEN_MODEL = 'qwen-vl-max'  # Default model
AVAILABLE_QWEN_MODELS = ["qwen-vl-max", "qwen-vl-plus"]

# Cleanup config: all three knobs are environment-overridable.
CLEANUP_CONFIG = {
    "enabled": os.getenv("AUTO_CLEANUP_ENABLED", "True").lower() == "true",
    "lifetime": int(os.getenv("FILE_LIFETIME_SECONDS", "3600")),   # max file age, seconds
    "interval": int(os.getenv("CLEANUP_INTERVAL_SECONDS", "600")), # sweep period, seconds
}

# Prompt configuration. NOTE: these strings are sent verbatim to the Qwen-VL
# model at runtime — do not edit them casually; the downstream JSON parsing
# depends on the response format they request.
PROMPTS = {
    "translate": "请将以下描述翻译成简洁、精准的英文,用于图像分割模型(SAM)的提示词。直接返回英文,不要包含任何解释或其他文字。\n\n输入: {text}",
    "tarot_card_dual": """这是一张塔罗牌的两个方向:
图1:原始方向
图2:旋转180度后的方向
请仔细对比两张图片的牌面内容(文字方向、人物站立方向、图案逻辑):
1. 识别这张牌的名字(中文)。
2. 判断哪一张图片展示了正确的“正位”(Upright)状态。
- 如果图1是正位,说明原图就是正位。
- 如果图2是正位,说明原图是逆位。
请以JSON格式返回,包含 'name' 和 'position' 两个字段。
例如:{'name': '愚者', 'position': '正位'} 或 {'name': '倒吊人', 'position': '逆位'}。
不要包含Markdown代码块标记。""",
    "tarot_card_single": "这是一张塔罗牌。请识别它的名字(中文),并判断它是正位还是逆位。请以JSON格式返回,包含 'name' 和 'position' 两个字段。例如:{'name': '愚者', 'position': '正位'}。不要包含Markdown代码块标记。",
    "tarot_spread": "这是一张包含多张塔罗牌的图片。请根据牌的排列方式识别这是什么牌阵(例如:圣三角、凯尔特十字、三张牌等)。如果看不出明显的正规牌阵,请返回“不是正规牌阵”。请以JSON格式返回,包含 'spread_name' 和 'description' 两个字段。例如:{'spread_name': '圣三角', 'description': '常见的时间流占卜法'}。不要包含Markdown代码块标记。",
    "face_analysis": """请仔细观察这张图片中的人物头部/面部特写:
1. 识别性别 (Gender):男性/女性
2. 预估年龄 (Age):请给出一个合理的年龄范围,例如 "25-30岁"
3. 简要描述:发型、发色、是否有眼镜等显著特征。
请以 JSON 格式返回,包含 'gender', 'age', 'description' 字段。
不要包含 Markdown 标记。"""
}

# API tags (used to group endpoints in the OpenAPI docs).
TAG_GENERAL = "General Segmentation (通用分割)"
TAG_TAROT = "Tarot Analysis (塔罗牌分析)"
TAG_FACE = "Face Analysis (人脸分析)"

# ==========================================
# 3. Security & Middleware
# ==========================================

# Header-based authentication scheme. auto_error=False so we can raise our
# own 401/403 with explicit messages in verify_api_key below.
api_key_header = APIKeyHeader(name=API_KEY_HEADER_NAME, auto_error=False)


async def verify_api_key(api_key: Optional[str] = Depends(api_key_header)) -> bool:
    """Enforce API-key verification for protected endpoints.

    Raises:
        HTTPException: 401 when the header is missing, 403 when it does not
            match VALID_API_KEY.

    Returns:
        True when the key is valid (used as a FastAPI dependency).
    """
    if not api_key:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing API Key. Please provide it in the header."
        )
    if api_key != VALID_API_KEY:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid API Key."
        )
    return True


# ==========================================
# 4. Lifespan Management
# ==========================================

async def cleanup_old_files(directory: str) -> None:
    """Background task: periodically delete expired result images.

    Re-reads CLEANUP_CONFIG on every cycle so the knobs can be changed at
    runtime. Sleeps first, then sweeps, so a disabled config still keeps the
    loop alive cheaply. Exits only on task cancellation.
    """
    print(f"🧹 自动清理任务已启动 | 目录: {directory}")
    while True:
        try:
            # Read config dynamically each cycle.
            interval = CLEANUP_CONFIG["interval"]
            lifetime = CLEANUP_CONFIG["lifetime"]
            enabled = CLEANUP_CONFIG["enabled"]

            await asyncio.sleep(interval)
            if not enabled:
                continue

            current_time = time.time()
            count = 0

            # Walk every file (including subdirectories) and remove the stale ones.
            for root, dirs, files in os.walk(directory):
                for file in files:
                    file_path = os.path.join(root, file)
                    try:
                        file_mtime = os.path.getmtime(file_path)
                        if current_time - file_mtime > lifetime:
                            os.remove(file_path)
                            count += 1
                    except OSError:
                        pass  # file may already have been removed

            # Best-effort removal of now-empty directories (bottom-up walk).
            # Renamed loop variable: `dir` shadowed the builtin.
            for root, dirs, files in os.walk(directory, topdown=False):
                for subdir in dirs:
                    dir_path = os.path.join(root, subdir)
                    try:
                        if not os.listdir(dir_path):
                            os.rmdir(dir_path)
                    except OSError:
                        pass

            if count > 0:
                print(f"🧹 已清理 {count} 个过期文件 (Lifetime: {lifetime}s)")
        except asyncio.CancelledError:
            print("🛑 清理任务已停止")
            break
        except Exception as e:
            print(f"⚠️ 清理任务出错: {e}")
            await asyncio.sleep(60)  # back off a minute after an unexpected error


@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan manager.

    Startup: load the SAM3 model onto GPU (or CPU fallback), stash the model,
    processor and device on app.state, and spawn the cleanup task.
    Shutdown: cancel and await the cleanup task.
    """
    print("="*40)
    print("✅ API Key 保护已激活")
    # NOTE(review): printing the valid key to stdout leaks the credential into
    # logs; kept for behavioral compatibility but consider removing.
    print(f"✅ 有效 Key: {VALID_API_KEY}")
    print("="*40)

    print("正在加载 SAM3 模型到 GPU...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if not torch.cuda.is_available():
        print("警告: 未检测到 GPU,将使用 CPU,速度会较慢。")

    # Load the model once and share it across requests via app.state.
    model = build_sam3_image_model()
    model = model.to(device)
    model.eval()
    processor = Sam3Processor(model)

    app.state.model = model
    app.state.processor = processor
    app.state.device = device
    print(f"模型加载完成,设备: {device}")

    # Start the background cleanup task.
    cleanup_task_handle = None
    try:
        cleanup_task_handle = asyncio.create_task(
            cleanup_old_files(RESULT_IMAGE_DIR)
        )
    except Exception as e:
        print(f"启动清理任务失败: {e}")

    yield

    print("正在清理资源...")
    # Stop the background task and wait for it to acknowledge cancellation.
    if cleanup_task_handle:
        cleanup_task_handle.cancel()
        try:
            await cleanup_task_handle
        except asyncio.CancelledError:
            pass
    # GPU memory release could be added here if needed.
Helper Functions (辅助函数) # ========================================== def is_english(text: str) -> bool: """ 判断文本是否为纯英文 - 如果包含中文字符范围 (\u4e00-\u9fff),则返回 False """ for char in text: if '\u4e00' <= char <= '\u9fff': return False return True def append_to_history(req_type: str, prompt: str, status: str, result_path: str = None, details: str = "", final_prompt: str = None, duration: float = 0.0): """ 记录请求历史到 history.json """ record = { "timestamp": time.time(), "type": req_type, "prompt": prompt, "final_prompt": final_prompt, "status": status, "result_path": result_path, "details": details, "duration": duration } try: with open(HISTORY_FILE, "a", encoding="utf-8") as f: f.write(json.dumps(record, ensure_ascii=False) + "\n") except Exception as e: print(f"Failed to write history: {e}") def translate_to_sam3_prompt(text: str) -> str: """ 使用 Qwen 模型将中文提示词翻译为英文 - SAM3 模型对英文 Prompt 支持更好 """ print(f"正在翻译提示词: {text}") try: prompt_template = PROMPTS["translate"] messages = [ { "role": "user", "content": [ {"text": prompt_template.format(text=text)} ] } ] response = MultiModalConversation.call(model=QWEN_MODEL, messages=messages) if response.status_code == 200: translated_text = response.output.choices[0].message.content[0]['text'].strip() # 去除可能的 markdown 标记或引号 translated_text = translated_text.replace('"', '').replace("'", "").strip() print(f"翻译结果: {translated_text}") return translated_text else: print(f"翻译失败: {response.code} - {response.message}") return text # 失败则回退到原始文本 except Exception as e: print(f"翻译异常: {e}") return text def order_points(pts): """ 对四个坐标点进行排序:左上,右上,右下,左下 用于透视变换前的点位整理 """ rect = np.zeros((4, 2), dtype="float32") s = pts.sum(axis=1) rect[0] = pts[np.argmin(s)] rect[2] = pts[np.argmax(s)] diff = np.diff(pts, axis=1) rect[1] = pts[np.argmin(diff)] rect[3] = pts[np.argmax(diff)] return rect def four_point_transform(image, pts): """ 根据四个点进行透视变换 (Perspective Transform) 用于将倾斜的卡片矫正为矩形 """ rect = order_points(pts) (tl, tr, br, bl) = rect # 计算新图像的宽度 widthA = 
np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2)) widthB = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2)) maxWidth = max(int(widthA), int(widthB)) # 计算新图像的高度 heightA = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2)) heightB = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2)) maxHeight = max(int(heightA), int(heightB)) dst = np.array([ [0, 0], [maxWidth - 1, 0], [maxWidth - 1, maxHeight - 1], [0, maxHeight - 1]], dtype="float32") M = cv2.getPerspectiveTransform(rect, dst) warped = cv2.warpPerspective(image, M, (maxWidth, maxHeight)) return warped def load_image_from_url(url: str) -> Image.Image: """ 从 URL 下载图片并转换为 PIL Image 对象 """ try: headers = {'User-Agent': 'Mozilla/5.0'} response = requests.get(url, headers=headers, stream=True, timeout=10) response.raise_for_status() image = Image.open(response.raw).convert("RGB") return image except Exception as e: raise HTTPException(status_code=400, detail=f"无法下载图片: {str(e)}") def crop_and_save_objects(image: Image.Image, masks, boxes, output_dir: str = RESULT_IMAGE_DIR, is_tarot: bool = True, cutout: bool = False, perspective_correction: bool = False) -> list[dict]: """ 根据 mask 和 box 进行处理并保存独立的对象图片 参数: - image: 原始图片 - masks: 分割掩码列表 - boxes: 边界框列表 - output_dir: 输出目录 - is_tarot: 是否为塔罗牌模式 (会影响文件名前缀和旋转逻辑) - cutout: 如果为 True,则进行轮廓抠图(透明背景);否则进行透视矫正(主要用于卡片) - perspective_correction: 是否进行梯度透视矫正 返回: - 保存的对象信息列表 """ saved_objects = [] # Convert image to numpy array (RGB) img_arr = np.array(image) for i, (mask, box) in enumerate(zip(masks, boxes)): # Handle tensor/numpy conversions if isinstance(mask, torch.Tensor): mask_np = mask.cpu().numpy().squeeze() else: mask_np = mask.squeeze() # Handle box conversion if isinstance(box, torch.Tensor): box_np = box.cpu().numpy() else: box_np = np.array(box) # Ensure mask is uint8 binary for OpenCV/Pillow if mask_np.dtype == bool: mask_uint8 = (mask_np * 255).astype(np.uint8) else: mask_uint8 = (mask_np > 0.5).astype(np.uint8) * 255 # --- 准备基础图像 --- if cutout: # 
1. 准备 RGBA 原图 if image.mode != "RGBA": img_rgba = image.convert("RGBA") else: img_rgba = image.copy() # 2. 准备 Alpha Mask mask_img = Image.fromarray(mask_uint8, mode='L') # 3. 将 Mask 应用到 Alpha 通道 base_img_pil = Image.new("RGBA", img_rgba.size, (0, 0, 0, 0)) base_img_pil.paste(image.convert("RGB"), (0, 0), mask=mask_img) # Convert to numpy for potential warping base_img_arr = np.array(base_img_pil) else: base_img_pil = image.convert("RGB") base_img_arr = img_arr # RGB numpy array # --- 透视矫正 vs 简单裁剪 --- final_img_pil = None is_rotated = False note = "" if perspective_correction: # --- 透视矫正模式 (矩形矫正) --- contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) pts = None if contours: c = max(contours, key=cv2.contourArea) peri = cv2.arcLength(c, True) approx = cv2.approxPolyDP(c, 0.04 * peri, True) if len(approx) == 4: pts = approx.reshape(4, 2) else: rect = cv2.minAreaRect(c) pts = cv2.boxPoints(rect) if pts is not None: warped = four_point_transform(base_img_arr, pts) note = "Geometric correction applied." else: # Fallback to simple crop if no contours found x1, y1, x2, y2 = map(int, box_np) # Ensure bounds h, w = base_img_arr.shape[:2] x1 = max(0, x1); y1 = max(0, y1) x2 = min(w, x2); y2 = min(h, y2) warped = base_img_arr[y1:y2, x1:x2] note = "Correction failed, fallback to crop." # Check orientation (Portrait vs Landscape) - Only for Tarot usually h, w = warped.shape[:2] # 强制竖屏逻辑 (塔罗牌通常是竖屏) if is_tarot and w > h: warped = cv2.rotate(warped, cv2.ROTATE_90_CLOCKWISE) is_rotated = True final_img_pil = Image.fromarray(warped) else: # --- 简单裁剪模式 (Simple Crop) --- x1, y1, x2, y2 = map(int, box_np) w, h = base_img_pil.size x1 = max(0, x1); y1 = max(0, y1) x2 = min(w, x2); y2 = min(h, y2) if x2 > x1 and y2 > y1: final_img_pil = base_img_pil.crop((x1, y1, x2, y2)) else: final_img_pil = base_img_pil # Fallback note = "Simple crop applied." 
# --- 保存图片 --- prefix = "cutout" if cutout else ("tarot" if is_tarot else "segment") filename = f"{prefix}_{uuid.uuid4().hex}_{i}.png" save_path = os.path.join(output_dir, filename) final_img_pil.save(save_path) saved_objects.append({ "filename": filename, "is_rotated_by_algorithm": is_rotated, "note": note }) return saved_objects def generate_and_save_result(image: Image.Image, inference_state, output_dir: str = RESULT_IMAGE_DIR) -> str: """ 生成并保存包含 Mask 叠加效果的完整结果图 """ filename = f"seg_{uuid.uuid4().hex}.jpg" save_path = os.path.join(output_dir, filename) plot_results(image, inference_state) plt.savefig(save_path, dpi=150, bbox_inches='tight') plt.close() return filename def recognize_card_with_qwen(image_path: str) -> dict: """ 调用 Qwen-VL 识别塔罗牌 (采用正逆位对比策略) """ try: abs_path = os.path.abspath(image_path) file_url = f"file://{abs_path}" try: # 1. 打开原图 img = Image.open(abs_path) # 2. 生成旋转图 (180度) rotated_img = img.rotate(180) # 3. 保存旋转图 dir_name = os.path.dirname(abs_path) file_name = os.path.basename(abs_path) rotated_name = f"rotated_{file_name}" rotated_path = os.path.join(dir_name, rotated_name) rotated_img.save(rotated_path) rotated_file_url = f"file://{rotated_path}" # 4. 
构建对比 Prompt messages = [ { "role": "user", "content": [ {"image": file_url}, # 图1 (原图) {"image": rotated_file_url}, # 图2 (旋转180度) {"text": PROMPTS["tarot_card_dual"]} ] } ] except Exception as e: print(f"对比图生成失败,回退到单图模式: {e}") messages = [ { "role": "user", "content": [ {"image": file_url}, {"text": PROMPTS["tarot_card_single"]} ] } ] response = MultiModalConversation.call(model=QWEN_MODEL, messages=messages) if response.status_code == 200: content = response.output.choices[0].message.content[0]['text'] import json try: clean_content = content.replace("```json", "").replace("```", "").strip() result = json.loads(clean_content) return result except: return {"raw_response": content} else: return {"error": f"API Error: {response.code} - {response.message}"} except Exception as e: return {"error": f"识别失败: {str(e)}"} def recognize_spread_with_qwen(image_path: str) -> dict: """ 调用 Qwen-VL 识别塔罗牌牌阵 """ try: abs_path = os.path.abspath(image_path) file_url = f"file://{abs_path}" messages = [ { "role": "user", "content": [ {"image": file_url}, {"text": PROMPTS["tarot_spread"]} ] } ] response = MultiModalConversation.call(model=QWEN_MODEL, messages=messages) if response.status_code == 200: content = response.output.choices[0].message.content[0]['text'] import json try: clean_content = content.replace("```json", "").replace("```", "").strip() result = json.loads(clean_content) return result except: return {"raw_response": content, "spread_name": "Unknown"} else: return {"error": f"API Error: {response.code} - {response.message}"} except Exception as e: return {"error": f"牌阵识别失败: {str(e)}"} # ========================================== # 6. 
# ==========================================
# 6. FastAPI App Setup
# ==========================================
# Note: the description string below is user-facing text rendered in the
# Swagger UI and is kept verbatim.
app = FastAPI(
    lifespan=lifespan,
    title="SAM3 Segmentation API",
    description="## 🔒 受 API Key 保护\n请点击右上角 **Authorize** 并输入: `123quant-*****`",
    openapi_tags=[
        {"name": TAG_GENERAL, "description": "General purpose image segmentation"},
        {"name": TAG_TAROT, "description": "Specialized endpoints for Tarot card recognition"},
        {"name": TAG_FACE, "description": "Face detection and analysis"},
    ]
)

# Manually augment the OpenAPI schema with API-key security so the Swagger UI
# shows an "Authorize" button and marks every operation as protected.
app.openapi_schema = None
def custom_openapi():
    # Cached after first build; FastAPI calls this on every /openapi.json hit.
    if app.openapi_schema:
        return app.openapi_schema
    from fastapi.openapi.utils import get_openapi
    openapi_schema = get_openapi(
        title=app.title,
        version=app.version,
        description=app.description,
        routes=app.routes,
    )
    # Declare the header-based API-key scheme.
    openapi_schema["components"]["securitySchemes"] = {
        "APIKeyHeader": {
            "type": "apiKey",
            "in": "header",
            "name": API_KEY_HEADER_NAME,
        }
    }
    # Attach the scheme to every path/method so docs prompt for the key.
    for path in openapi_schema["paths"]:
        for method in openapi_schema["paths"][path]:
            openapi_schema["paths"][path][method]["security"] = [{"APIKeyHeader": []}]
    app.openapi_schema = openapi_schema
    return app.openapi_schema
app.openapi = custom_openapi

# Serve result images (and the admin page) from ./static at /static.
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")

# ==========================================
# 7.
# API Endpoints
# ==========================================

# ------------------------------------------
# Group 1: General Segmentation
# ------------------------------------------
@app.post("/segment", tags=[TAG_GENERAL], dependencies=[Depends(verify_api_key)])
async def segment(
    request: Request,
    prompt: str = Form(..., description="Text prompt for segmentation (e.g., 'cat', 'person')"),
    file: Optional[UploadFile] = File(None, description="Image file to upload"),
    image_url: Optional[str] = Form(None, description="URL of the image"),
    save_segment_images: bool = Form(False, description="Whether to save and return individual segmented objects"),
    cutout: bool = Form(False, description="If True, returns transparent background PNGs; otherwise returns original crops"),
    perspective_correction: bool = Form(False, description="If True, applies perspective correction (warping) to the segmented object."),
    highlight: bool = Form(False, description="If True, darkens the background to highlight the subject (周边变黑放大)."),
    confidence: float = Form(0.7, description="Confidence threshold (0.0-1.0). Default is 0.7.")
):
    """
    **General image segmentation endpoint**

    - Accepts an uploaded image or an image URL
    - Auto-translates Chinese prompts to English
    - Optional background-darkening highlight mode
    - Optional manual confidence threshold
    - Optional perspective correction
    """
    if not file and not image_url:
        raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")
    start_time = time.time()

    # 1. Prompt handling: translate non-English prompts for SAM3.
    final_prompt = prompt
    if not is_english(prompt):
        try:
            translated = translate_to_sam3_prompt(prompt)
            if translated:
                final_prompt = translated
        except Exception as e:
            print(f"Prompt翻译失败,使用原始Prompt: {e}")
    print(f"最终使用的 Prompt: {final_prompt}")

    # 2. Load the image.
    try:
        if file:
            image = Image.open(file.file).convert("RGB")
        elif image_url:
            image = load_image_from_url(image_url)
    except Exception as e:
        duration = time.time() - start_time
        append_to_history("general", prompt, "failed", details=f"Image Load Error: {str(e)}", final_prompt=final_prompt, duration=duration)
        raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")

    processor = request.app.state.processor

    # 3. Model inference.
    try:
        inference_state = processor.set_image(image)

        # Temporarily override the processor's confidence threshold for this request.
        original_confidence = processor.confidence_threshold
        if confidence is not None:
            if not (0.0 <= confidence <= 1.0):
                raise HTTPException(status_code=400, detail="Confidence must be between 0.0 and 1.0")
            processor.confidence_threshold = confidence
            print(f"Using manual confidence threshold: {confidence}")
        try:
            output = processor.set_text_prompt(state=inference_state, prompt=final_prompt)
            masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
        finally:
            # Restore the default threshold so other requests are unaffected.
            if confidence is not None:
                processor.confidence_threshold = original_confidence
    except HTTPException:
        # FIX: let the 400 validation error through instead of re-wrapping it as a 500.
        raise
    except Exception as e:
        duration = time.time() - start_time
        append_to_history("general", prompt, "failed", details=f"Inference Error: {str(e)}", final_prompt=final_prompt, duration=duration)
        raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")

    # 4. Visualize and save the overall result image.
    try:
        if highlight:
            filename = f"seg_highlight_{uuid.uuid4().hex}.jpg"
            save_path = os.path.join(RESULT_IMAGE_DIR, filename)
            # Darkened-background visualization from the human analysis service.
            human_analysis_service.create_highlighted_visualization(image, masks, save_path)
        else:
            filename = generate_and_save_result(image, inference_state)
    except Exception as e:
        duration = time.time() - start_time
        append_to_history("general", prompt, "failed", details=f"Save Error: {str(e)}", final_prompt=final_prompt, duration=duration)
        raise HTTPException(status_code=500, detail=f"绘图保存错误: {str(e)}")

    # FIX: the URL must reference the file actually written above.
    file_url = request.url_for("static", path=f"results/{filename}")

    # 5. Optionally save per-object crops.
    saved_segments_info = []
    if save_segment_images:
        try:
            request_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
            output_dir = os.path.join(RESULT_IMAGE_DIR, request_id)
            os.makedirs(output_dir, exist_ok=True)
            saved_objects = crop_and_save_objects(
                image, masks, boxes,
                output_dir=output_dir,
                is_tarot=False,
                cutout=cutout,
                perspective_correction=perspective_correction
            )
            for obj in saved_objects:
                fname = obj["filename"]
                seg_url = str(request.url_for("static", path=f"results/{request_id}/{fname}"))
                saved_segments_info.append({
                    "url": seg_url,
                    "filename": fname,
                    "note": obj.get("note", "")
                })
        except Exception as e:
            print(f"Error saving segments: {e}")
            duration = time.time() - start_time
            append_to_history("general", prompt, "partial_success", result_path=f"results/{filename}", details="Segments save failed", final_prompt=final_prompt, duration=duration)
            raise HTTPException(status_code=500, detail=f"保存分割图片失败: {str(e)}")

    response_content = {
        "status": "success",
        "result_image_url": str(file_url),
        "detected_count": len(masks),
        "scores": scores.tolist() if torch.is_tensor(scores) else scores
    }
    if save_segment_images:
        response_content["segmented_images"] = saved_segments_info

    duration = time.time() - start_time
    append_to_history("general", prompt, "success", result_path=f"results/{filename}", details=f"Detected: {len(masks)}", final_prompt=final_prompt, duration=duration)
    return JSONResponse(content=response_content)


# ------------------------------------------
# Group 2: Tarot Analysis
# ------------------------------------------
@app.post("/segment_tarot", tags=[TAG_TAROT], dependencies=[Depends(verify_api_key)])
async def segment_tarot(
    request: Request,
    file: Optional[UploadFile] = File(None),
    image_url: Optional[str] = Form(None),
    expected_count: int = Form(3)
):
    """
    **Tarot card segmentation endpoint**

    1. Detects whether the image contains exactly *expected_count* tarot cards (default 3)
    2. Applies perspective correction and cropping to each detected card
    3. Returns the URLs of the corrected card images
    """
    if not file and not image_url:
        raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")
    start_time = time.time()

    try:
        if file:
            image = Image.open(file.file).convert("RGB")
        elif image_url:
            image = load_image_from_url(image_url)
    except Exception as e:
        duration = time.time() - start_time
        append_to_history("tarot", f"expected: {expected_count}", "failed", details=f"Image Load Error: {str(e)}", duration=duration)
        raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")

    processor = request.app.state.processor
    try:
        inference_state = processor.set_image(image)
        # Fixed prompt: we are always looking for tarot cards here.
        output = processor.set_text_prompt(state=inference_state, prompt="tarot card")
        masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
    except Exception as e:
        duration = time.time() - start_time
        append_to_history("tarot", f"expected: {expected_count}", "failed", details=f"Inference Error: {str(e)}", duration=duration)
        raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")

    # Core rule: the detected card count must match exactly.
    detected_count = len(masks)

    # Dedicated folder per request so results never collide.
    request_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
    output_dir = os.path.join(RESULT_IMAGE_DIR, request_id)
    os.makedirs(output_dir, exist_ok=True)

    if detected_count != expected_count:
        # Save a debug visualization for user feedback; best-effort.
        filename = None
        try:
            filename = generate_and_save_result(image, inference_state, output_dir=output_dir)
            # FIX: URL must reference the saved debug image, not a placeholder.
            file_url = request.url_for("static", path=f"results/{request_id}/{filename}")
        except Exception:
            file_url = None
        duration = time.time() - start_time
        append_to_history("tarot", f"expected: {expected_count}", "failed",
                          result_path=f"results/{request_id}/{filename}" if file_url else None,
                          details=f"Detected {detected_count} cards, expected {expected_count}", duration=duration)
        return JSONResponse(
            status_code=400,
            content={
                "status": "failed",
                "message": f"检测到 {detected_count} 个目标,需要严格的 {expected_count} 张塔罗牌。请调整拍摄角度或背景。",
                "detected_count": detected_count,
                "debug_image_url": str(file_url) if file_url else None
            }
        )

    # Count matched: crop each card with perspective correction.
    try:
        saved_objects = crop_and_save_objects(image, masks, boxes, output_dir=output_dir, is_tarot=True, perspective_correction=True)
    except Exception as e:
        duration = time.time() - start_time
        append_to_history("tarot", f"expected: {expected_count}", "failed", details=f"Crop Error: {str(e)}", duration=duration)
        raise HTTPException(status_code=500, detail=f"抠图处理错误: {str(e)}")

    # Build per-card URLs and metadata.
    tarot_cards = []
    for obj in saved_objects:
        fname = obj["filename"]
        file_url = str(request.url_for("static", path=f"results/{request_id}/{fname}"))
        tarot_cards.append({
            "url": file_url,
            "is_rotated": obj["is_rotated_by_algorithm"],
            "orientation_status": "corrected_to_portrait" if obj["is_rotated_by_algorithm"] else "original_portrait",
            "note": obj["note"]
        })

    # Overall visualization; best-effort.
    main_filename = None
    try:
        main_filename = generate_and_save_result(image, inference_state, output_dir=output_dir)
        main_file_url = str(request.url_for("static", path=f"results/{request_id}/{main_filename}"))
    except Exception:
        main_file_url = None

    duration = time.time() - start_time
    append_to_history("tarot", f"expected: {expected_count}", "success",
                      result_path=f"results/{request_id}/{main_filename}" if main_file_url else None,
                      details=f"Successfully segmented {expected_count} cards", duration=duration)
    return JSONResponse(content={
        "status": "success",
        "message": f"成功识别并分割 {expected_count} 张塔罗牌 (已执行透视矫正)",
        "tarot_cards": tarot_cards,
        "full_visualization": main_file_url,
        "scores": scores.tolist() if torch.is_tensor(scores) else scores
    })


@app.post("/recognize_tarot", tags=[TAG_TAROT], dependencies=[Depends(verify_api_key)])
async def recognize_tarot(
    request: Request,
    file: Optional[UploadFile] = File(None),
    image_url: Optional[str] = Form(None),
    expected_count: int = Form(3)
):
    """
    **Full tarot pipeline: segment + correct + recognize**

    1. Detects whether the image contains exactly *expected_count* tarot cards (SAM3)
    2. Segments and perspective-corrects each card
    3. Calls Qwen-VL to identify each card's name and upright/reversed position
    """
    if not file and not image_url:
        raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")
    start_time = time.time()

    try:
        if file:
            image = Image.open(file.file).convert("RGB")
        elif image_url:
            image = load_image_from_url(image_url)
    except Exception as e:
        duration = time.time() - start_time
        append_to_history("tarot-recognize", f"expected: {expected_count}", "failed", details=f"Image Load Error: {str(e)}", duration=duration)
        raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")

    processor = request.app.state.processor
    try:
        inference_state = processor.set_image(image)
        output = processor.set_text_prompt(state=inference_state, prompt="tarot card")
        masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
    except Exception as e:
        duration = time.time() - start_time
        append_to_history("tarot-recognize", f"expected: {expected_count}", "failed", details=f"Inference Error: {str(e)}", duration=duration)
        raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")

    detected_count = len(masks)
    request_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
    output_dir = os.path.join(RESULT_IMAGE_DIR, request_id)
    os.makedirs(output_dir, exist_ok=True)

    # Save the overall visualization; best-effort.
    try:
        main_filename = generate_and_save_result(image, inference_state, output_dir=output_dir)
        main_file_path = os.path.join(output_dir, main_filename)
        main_file_url = str(request.url_for("static", path=f"results/{request_id}/{main_filename}"))
    except Exception:
        main_filename = None
        main_file_path = None
        main_file_url = None

    # Step 0: spread (layout) recognition on a raw copy of the input image.
    spread_info = {"spread_name": "Unknown"}
    if main_file_path:
        temp_raw_path = os.path.join(output_dir, "raw_for_spread.jpg")
        image.save(temp_raw_path)
        spread_info = recognize_spread_with_qwen(temp_raw_path)

    if detected_count != expected_count:
        duration = time.time() - start_time
        append_to_history("tarot-recognize", f"expected: {expected_count}", "failed",
                          result_path=f"results/{request_id}/{main_filename}" if main_file_url else None,
                          details=f"Detected {detected_count}, expected {expected_count}", duration=duration)
        return JSONResponse(
            status_code=400,
            content={
                "status": "failed",
                "message": f"检测到 {detected_count} 个目标,需要严格的 {expected_count} 张塔罗牌。请调整拍摄角度或背景。",
                "detected_count": detected_count,
                "spread_info": spread_info,
                "debug_image_url": str(main_file_url) if main_file_url else None
            }
        )

    # Count matched: crop + perspective-correct each card.
    try:
        saved_objects = crop_and_save_objects(image, masks, boxes, output_dir=output_dir, is_tarot=True, perspective_correction=True)
    except Exception as e:
        duration = time.time() - start_time
        append_to_history("tarot-recognize", f"expected: {expected_count}", "failed", details=f"Crop Error: {str(e)}", duration=duration)
        raise HTTPException(status_code=500, detail=f"抠图处理错误: {str(e)}")

    # Recognize each card sequentially with Qwen-VL.
    tarot_cards = []
    for obj in saved_objects:
        fname = obj["filename"]
        file_path = os.path.join(output_dir, fname)
        recognition_res = recognize_card_with_qwen(file_path)
        file_url = str(request.url_for("static", path=f"results/{request_id}/{fname}"))
        tarot_cards.append({
            "url": file_url,
            "is_rotated": obj["is_rotated_by_algorithm"],
            "orientation_status": "corrected_to_portrait" if obj["is_rotated_by_algorithm"] else "original_portrait",
            "recognition": recognition_res,
            "note": obj["note"]
        })

    duration = time.time() - start_time
    append_to_history("tarot-recognize", f"expected: {expected_count}", "success",
                      result_path=f"results/{request_id}/{main_filename}" if main_file_url else None,
                      details=f"Spread: {spread_info.get('spread_name', 'Unknown')}", duration=duration)
    return JSONResponse(content={
        "status": "success",
        "message": f"成功识别并分割 {expected_count} 张塔罗牌 (含Qwen识别结果)",
        "spread_info": spread_info,
        "tarot_cards": tarot_cards,
        "full_visualization": main_file_url,
        "scores": scores.tolist() if torch.is_tensor(scores) else scores
    })


# ------------------------------------------
# Group 3: Face Analysis
# ------------------------------------------
@app.post("/segment_face", tags=[TAG_FACE], dependencies=[Depends(verify_api_key)])
async def segment_face(
    request: Request,
    file: Optional[UploadFile] = File(None),
    image_url: Optional[str] = Form(None),
    prompt: str = Form("face and hair", description="Prompt for face detection")
):
    """
    **Face/head detection and attribute analysis endpoint**

    1. Uses SAM3 to segment the head region (including hair)
    2. Crops and saves each region
    3. Calls Qwen-VL to estimate gender and age
    """
    if not file and not image_url:
        raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")
    start_time = time.time()

    # Translate/normalize the prompt (SAM3 works best with English prompts).
    final_prompt = prompt
    if not is_english(prompt):
        try:
            translated = translate_to_sam3_prompt(prompt)
            if translated:
                final_prompt = translated
        except Exception as e:
            print(f"Prompt翻译失败,使用原始Prompt: {e}")
    print(f"Face Segment 最终使用的 Prompt: {final_prompt}")

    # Load the image from upload or URL.
    try:
        if file:
            image = Image.open(file.file).convert("RGB")
        elif image_url:
            image = load_image_from_url(image_url)
    except Exception as e:
        duration = time.time() - start_time
        append_to_history("face", prompt, "failed", details=f"Image Load Error: {str(e)}", final_prompt=final_prompt, duration=duration)
        raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")

    processor = request.app.state.processor

    # Delegate segmentation + Qwen analysis to the dedicated service module.
    try:
        result = human_analysis_service.process_face_segmentation_and_analysis(
            processor=processor,
            image=image,
            prompt=final_prompt,
            output_base_dir=RESULT_IMAGE_DIR,
            qwen_model=QWEN_MODEL,
            analysis_prompt=PROMPTS["face_analysis"]
        )
    except Exception as e:
        # Redundant with the module-level import, but kept byte-identical.
        import traceback
        traceback.print_exc()
        duration = time.time() - start_time
        append_to_history("face", prompt, "failed", details=f"Process Error: {str(e)}", final_prompt=final_prompt, duration=duration)
        raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}")

    # Rewrite the service's relative paths into absolute static URLs.
    if result["status"] == "success":
        if result.get("full_visualization"):
            full_vis_rel_path = result["full_visualization"]
            result["full_visualization"] = str(request.url_for("static", path=full_vis_rel_path))
        for item in result["results"]:
            relative_path = item.pop("relative_path")
            item["url"] = str(request.url_for("static", path=relative_path))

    duration = time.time() - start_time
    append_to_history("face", prompt, result["status"], details=f"Results: {len(result.get('results', []))}", final_prompt=final_prompt, duration=duration)
    return JSONResponse(content=result)

# ==========================================
# 9. Admin Management APIs
# ==========================================
# NOTE(review): this definition continues beyond the end of this chunk; the
# trailing text is reproduced verbatim and must not be edited here.
@app.get("/admin", include_in_schema=False)
async def admin_page():
    """ Serve the admin HTML page """
    # Serve static/admin.html if present; otherwise fall back to the inline page below.
    admin_html_path = os.path.join(STATIC_DIR, "admin.html")
    if os.path.exists(admin_html_path):
        with open(admin_html_path, "r", encoding="utf-8") as f:
            content = f.read()
        return HTMLResponse(content=content)
    else:
        return HTMLResponse(content="