zxd
This commit is contained in:
394
fastAPI_tarot.py
394
fastAPI_tarot.py
@@ -1,68 +1,112 @@
|
||||
"""
|
||||
SAM3 Segmentation API Service (SAM3 图像分割 API 服务)
|
||||
--------------------------------------------------
|
||||
本服务提供基于 SAM3 模型的图像分割能力。
|
||||
包含通用分割、塔罗牌分析和人脸分析等专用接口。
|
||||
|
||||
This service provides image segmentation capabilities using the SAM3 model.
|
||||
It includes specialized endpoints for Tarot card analysis and Face analysis.
|
||||
"""
|
||||
|
||||
# ==========================================
|
||||
# 1. Imports (导入)
|
||||
# ==========================================
|
||||
|
||||
# Standard Library Imports (标准库)
|
||||
import os
|
||||
import uuid
|
||||
import time
|
||||
import requests
|
||||
import numpy as np
|
||||
import cv2
|
||||
from typing import Optional
|
||||
import json
|
||||
import traceback
|
||||
import re
|
||||
from typing import Optional, List, Dict, Any
|
||||
from contextlib import asynccontextmanager
|
||||
import dashscope
|
||||
from dashscope import MultiModalConversation
|
||||
|
||||
# Third-Party Imports (第三方库)
|
||||
import cv2
|
||||
import torch
|
||||
import numpy as np
|
||||
import requests
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
matplotlib.use('Agg') # 使用非交互式后端,防止在服务器上报错
|
||||
import matplotlib.pyplot as plt
|
||||
from PIL import Image
|
||||
|
||||
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request, Depends, status
|
||||
# FastAPI Imports
|
||||
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request, Depends, status, APIRouter
|
||||
from fastapi.security import APIKeyHeader
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.responses import JSONResponse
|
||||
from PIL import Image
|
||||
|
||||
# Dashscope (Aliyun Qwen) Imports
|
||||
import dashscope
|
||||
from dashscope import MultiModalConversation
|
||||
|
||||
# Local Imports (本地模块)
|
||||
from sam3.model_builder import build_sam3_image_model
|
||||
from sam3.model.sam3_image_processor import Sam3Processor
|
||||
from sam3.visualization_utils import plot_results
|
||||
import human_analysis_service # 引入新服务
|
||||
import human_analysis_service # 引入新服务: 人脸分析
|
||||
|
||||
# ------------------- 配置与路径 -------------------
|
||||
# ==========================================
|
||||
# 2. Configuration & Constants (配置与常量)
|
||||
# ==========================================
|
||||
|
||||
# Path configuration: result images are written under static/results,
# and the directory is created at import time if it does not exist yet.
STATIC_DIR = "static"
RESULT_IMAGE_DIR = os.path.join(STATIC_DIR, "results")
os.makedirs(RESULT_IMAGE_DIR, exist_ok=True)

# API key configuration.
# The key can be overridden through the SAM3_API_KEY environment variable;
# the literal is kept only as a backward-compatible fallback.
# NOTE(review): a secret committed to source control should be rotated and
# the fallback removed once every deployment sets the environment variable.
VALID_API_KEY = os.environ.get("SAM3_API_KEY") or "123quant-speed"
API_KEY_HEADER_NAME = "X-API-Key"

# Dashscope (Qwen-VL) configuration.
# DASHSCOPE_API_KEY from the environment takes precedence over the literal.
# NOTE(review): the fallback key below is exposed in the repository history
# and should be rotated.
dashscope.api_key = os.environ.get("DASHSCOPE_API_KEY") or 'sk-ce2404f55f744a1987d5ece61c6bac58'
QWEN_MODEL = 'qwen-vl-max'

# API tags (used to group endpoints in the generated documentation)
TAG_GENERAL = "General Segmentation (通用分割)"
TAG_TAROT = "Tarot Analysis (塔罗牌分析)"
TAG_FACE = "Face Analysis (人脸分析)"
|
||||
|
||||
# ==========================================
|
||||
# 3. Security & Middleware (安全与中间件)
|
||||
# ==========================================
|
||||
|
||||
# 定义 Header 认证 Scheme
|
||||
api_key_header = APIKeyHeader(name=API_KEY_HEADER_NAME, auto_error=False)
|
||||
|
||||
async def verify_api_key(api_key: Optional[str] = Depends(api_key_header)):
    """
    Enforce API key verification for an endpoint.

    1. Reject the request with 401 when the API key header is missing or empty.
    2. Reject the request with 403 when the supplied key does not match
       VALID_API_KEY.

    Returns:
        True when the key is valid.

    Raises:
        HTTPException: 401 if no key was supplied, 403 if the key is wrong.
    """
    # Local import keeps this block self-contained without touching the
    # file-level import section.
    import secrets

    # 1. A key must be present at all.
    if not api_key:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing API Key. Please provide it in the header."
        )

    # 2. Compare in constant time so response timing does not reveal how many
    #    leading characters of a guess were correct. Both sides are encoded to
    #    bytes because compare_digest rejects non-ASCII str arguments.
    supplied = api_key.encode("utf-8")
    expected = VALID_API_KEY.encode("utf-8")
    if not secrets.compare_digest(supplied, expected):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid API Key."
        )

    # 3. Verification passed.
    return True
|
||||
|
||||
# ------------------- 生命周期管理 -------------------
|
||||
# ==========================================
|
||||
# 4. Lifespan Management (生命周期管理)
|
||||
# ==========================================
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""
|
||||
FastAPI 生命周期管理器
|
||||
- 启动时: 加载模型到 GPU/CPU
|
||||
- 关闭时: 清理资源
|
||||
"""
|
||||
print("="*40)
|
||||
print("✅ API Key 保护已激活")
|
||||
print(f"✅ 有效 Key: {VALID_API_KEY}")
|
||||
@@ -73,12 +117,15 @@ async def lifespan(app: FastAPI):
|
||||
if not torch.cuda.is_available():
|
||||
print("警告: 未检测到 GPU,将使用 CPU,速度会较慢。")
|
||||
|
||||
# 加载模型
|
||||
model = build_sam3_image_model()
|
||||
model = model.to(device)
|
||||
model.eval()
|
||||
|
||||
# 初始化处理器
|
||||
processor = Sam3Processor(model)
|
||||
|
||||
# 存储到 app.state 以便全局访问
|
||||
app.state.model = model
|
||||
app.state.processor = processor
|
||||
app.state.device = device
|
||||
@@ -88,52 +135,16 @@ async def lifespan(app: FastAPI):
|
||||
yield
|
||||
|
||||
print("正在清理资源...")
|
||||
# 这里可以添加释放显存的逻辑,如果需要
|
||||
|
||||
# ------------------- FastAPI 初始化 -------------------
|
||||
app = FastAPI(
|
||||
lifespan=lifespan,
|
||||
title="SAM3 Segmentation API",
|
||||
description="## 🔒 受 API Key 保护\n请点击右上角 **Authorize** 并输入: `123quant-*****`",
|
||||
)
|
||||
# ==========================================
|
||||
# 5. Helper Functions (辅助函数)
|
||||
# ==========================================
|
||||
|
||||
# 手动添加 OpenAPI 安全配置,让 Docs 里的锁头生效
|
||||
app.openapi_schema = None
|
||||
def custom_openapi():
|
||||
if app.openapi_schema:
|
||||
return app.openapi_schema
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
openapi_schema = get_openapi(
|
||||
title=app.title,
|
||||
version=app.version,
|
||||
description=app.description,
|
||||
routes=app.routes,
|
||||
)
|
||||
# 定义安全方案
|
||||
openapi_schema["components"]["securitySchemes"] = {
|
||||
"APIKeyHeader": {
|
||||
"type": "apiKey",
|
||||
"in": "header",
|
||||
"name": API_KEY_HEADER_NAME,
|
||||
}
|
||||
}
|
||||
# 为所有路径应用安全要求
|
||||
for path in openapi_schema["paths"]:
|
||||
for method in openapi_schema["paths"][path]:
|
||||
openapi_schema["paths"][path][method]["security"] = [{"APIKeyHeader": []}]
|
||||
app.openapi_schema = openapi_schema
|
||||
return app.openapi_schema
|
||||
|
||||
app.openapi = custom_openapi
|
||||
|
||||
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
|
||||
|
||||
import re
|
||||
|
||||
# ------------------- 辅助函数 -------------------
|
||||
def is_english(text: str) -> bool:
|
||||
"""
|
||||
简单判断是否为英文 prompt。
|
||||
如果有中文字符则认为不是英文。
|
||||
判断文本是否为纯英文
|
||||
- 如果包含中文字符范围 (\u4e00-\u9fff),则返回 False
|
||||
"""
|
||||
for char in text:
|
||||
if '\u4e00' <= char <= '\u9fff':
|
||||
@@ -142,7 +153,8 @@ def is_english(text: str) -> bool:
|
||||
|
||||
def translate_to_sam3_prompt(text: str) -> str:
|
||||
"""
|
||||
使用大模型将非英文提示词翻译为适合 SAM3 的英文提示词
|
||||
使用 Qwen 模型将中文提示词翻译为英文
|
||||
- SAM3 模型对英文 Prompt 支持更好
|
||||
"""
|
||||
print(f"正在翻译提示词: {text}")
|
||||
try:
|
||||
@@ -164,7 +176,7 @@ def translate_to_sam3_prompt(text: str) -> str:
|
||||
return translated_text
|
||||
else:
|
||||
print(f"翻译失败: {response.code} - {response.message}")
|
||||
return text # Fallback to original
|
||||
return text # 失败则回退到原始文本
|
||||
except Exception as e:
|
||||
print(f"翻译异常: {e}")
|
||||
return text
|
||||
@@ -172,6 +184,7 @@ def translate_to_sam3_prompt(text: str) -> str:
|
||||
def order_points(pts):
|
||||
"""
|
||||
对四个坐标点进行排序:左上,右上,右下,左下
|
||||
用于透视变换前的点位整理
|
||||
"""
|
||||
rect = np.zeros((4, 2), dtype="float32")
|
||||
s = pts.sum(axis=1)
|
||||
@@ -184,7 +197,8 @@ def order_points(pts):
|
||||
|
||||
def four_point_transform(image, pts):
|
||||
"""
|
||||
根据四个点进行透视变换
|
||||
根据四个点进行透视变换 (Perspective Transform)
|
||||
用于将倾斜的卡片矫正为矩形
|
||||
"""
|
||||
rect = order_points(pts)
|
||||
(tl, tr, br, bl) = rect
|
||||
@@ -210,6 +224,9 @@ def four_point_transform(image, pts):
|
||||
return warped
|
||||
|
||||
def load_image_from_url(url: str) -> Image.Image:
|
||||
"""
|
||||
从 URL 下载图片并转换为 PIL Image 对象
|
||||
"""
|
||||
try:
|
||||
headers = {'User-Agent': 'Mozilla/5.0'}
|
||||
response = requests.get(url, headers=headers, stream=True, timeout=10)
|
||||
@@ -222,7 +239,17 @@ def load_image_from_url(url: str) -> Image.Image:
|
||||
def crop_and_save_objects(image: Image.Image, masks, boxes, output_dir: str = RESULT_IMAGE_DIR, is_tarot: bool = True, cutout: bool = False) -> list[dict]:
|
||||
"""
|
||||
根据 mask 和 box 进行处理并保存独立的对象图片
|
||||
:param cutout: 如果为 True,则进行轮廓抠图(透明背景);否则进行透视矫正(主要用于卡片)
|
||||
|
||||
参数:
|
||||
- image: 原始图片
|
||||
- masks: 分割掩码列表
|
||||
- boxes: 边界框列表
|
||||
- output_dir: 输出目录
|
||||
- is_tarot: 是否为塔罗牌模式 (会影响文件名前缀和旋转逻辑)
|
||||
- cutout: 如果为 True,则进行轮廓抠图(透明背景);否则进行透视矫正(主要用于卡片)
|
||||
|
||||
返回:
|
||||
- 保存的对象信息列表
|
||||
"""
|
||||
saved_objects = []
|
||||
# Convert image to numpy array (RGB)
|
||||
@@ -256,29 +283,18 @@ def crop_and_save_objects(image: Image.Image, masks, boxes, output_dir: str = RE
|
||||
img_rgba = image.copy()
|
||||
|
||||
# 2. 准备 Alpha Mask
|
||||
# mask_uint8 已经是 0-255,可以直接作为 Alpha 通道的基础
|
||||
# 将 Mask 缩放到和图片一样大 (一般是一样的,但以防万一)
|
||||
if mask_uint8.shape != img_rgba.size[::-1]: # size is (w, h), shape is (h, w)
|
||||
# 如果尺寸不一致可能需要 resize,这里假设尺寸一致
|
||||
pass
|
||||
|
||||
mask_img = Image.fromarray(mask_uint8, mode='L')
|
||||
|
||||
# 3. 将 Mask 应用到 Alpha 通道
|
||||
# 创建一个新的空白透明图
|
||||
cutout_img = Image.new("RGBA", img_rgba.size, (0, 0, 0, 0))
|
||||
# 将原图粘贴上去,使用 mask 作为 mask
|
||||
cutout_img.paste(image.convert("RGB"), (0, 0), mask=mask_img)
|
||||
|
||||
# 4. Crop to Box
|
||||
# box is [x1, y1, x2, y2]
|
||||
x1, y1, x2, y2 = map(int, box_np)
|
||||
# 边界检查
|
||||
w, h = cutout_img.size
|
||||
x1 = max(0, x1); y1 = max(0, y1)
|
||||
x2 = min(w, x2); y2 = min(h, y2)
|
||||
|
||||
# 避免无效 crop
|
||||
if x2 > x1 and y2 > y1:
|
||||
final_img = cutout_img.crop((x1, y1, x2, y2))
|
||||
else:
|
||||
@@ -286,7 +302,7 @@ def crop_and_save_objects(image: Image.Image, masks, boxes, output_dir: str = RE
|
||||
|
||||
# Save
|
||||
prefix = "cutout"
|
||||
is_rotated = False # 抠图模式下不进行自动旋转
|
||||
is_rotated = False
|
||||
|
||||
filename = f"{prefix}_{uuid.uuid4().hex}_{i}.png"
|
||||
save_path = os.path.join(output_dir, filename)
|
||||
@@ -299,65 +315,51 @@ def crop_and_save_objects(image: Image.Image, masks, boxes, output_dir: str = RE
|
||||
})
|
||||
|
||||
else:
|
||||
# --- 透视矫正模式 (原有逻辑) ---
|
||||
# Find contours
|
||||
# --- 透视矫正模式 (矩形矫正) ---
|
||||
contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
if not contours:
|
||||
continue
|
||||
|
||||
# Get largest contour
|
||||
c = max(contours, key=cv2.contourArea)
|
||||
|
||||
# Approximate contour to polygon
|
||||
peri = cv2.arcLength(c, True)
|
||||
approx = cv2.approxPolyDP(c, 0.04 * peri, True)
|
||||
|
||||
# If we have 4 points, use them. If not, fallback to minAreaRect
|
||||
if len(approx) == 4:
|
||||
pts = approx.reshape(4, 2)
|
||||
else:
|
||||
rect = cv2.minAreaRect(c)
|
||||
pts = cv2.boxPoints(rect)
|
||||
|
||||
# Apply perspective transform
|
||||
# Check if original image has Alpha
|
||||
if img_arr.shape[2] == 4:
|
||||
warped = four_point_transform(img_arr, pts)
|
||||
else:
|
||||
warped = four_point_transform(img_arr, pts)
|
||||
warped = four_point_transform(img_arr, pts)
|
||||
|
||||
# Check orientation (Portrait vs Landscape)
|
||||
h, w = warped.shape[:2]
|
||||
is_rotated = False
|
||||
|
||||
# Enforce Portrait for Tarot cards (Standard 7x12 cm ratio approx)
|
||||
# 强制竖屏逻辑 (塔罗牌通常是竖屏)
|
||||
if is_tarot and w > h:
|
||||
# Rotate 90 degrees clockwise
|
||||
warped = cv2.rotate(warped, cv2.ROTATE_90_CLOCKWISE)
|
||||
is_rotated = True
|
||||
|
||||
# Convert back to PIL
|
||||
pil_warped = Image.fromarray(warped)
|
||||
|
||||
# Save
|
||||
prefix = "tarot" if is_tarot else "segment"
|
||||
filename = f"{prefix}_{uuid.uuid4().hex}_{i}.png"
|
||||
save_path = os.path.join(output_dir, filename)
|
||||
pil_warped.save(save_path)
|
||||
|
||||
# 正逆位判断逻辑 (基于几何只能做到这一步,无法区分上下颠倒)
|
||||
# 这里我们假设长边垂直为正位,如果做了旋转则标记
|
||||
# 真正的正逆位需要OCR或图像识别
|
||||
|
||||
saved_objects.append({
|
||||
"filename": filename,
|
||||
"is_rotated_by_algorithm": is_rotated,
|
||||
"note": "Geometric correction applied. True upright/reversed requires content analysis."
|
||||
"note": "Geometric correction applied."
|
||||
})
|
||||
|
||||
return saved_objects
|
||||
|
||||
def generate_and_save_result(image: Image.Image, inference_state, output_dir: str = RESULT_IMAGE_DIR) -> str:
|
||||
"""
|
||||
生成并保存包含 Mask 叠加效果的完整结果图
|
||||
"""
|
||||
filename = f"seg_{uuid.uuid4().hex}.jpg"
|
||||
save_path = os.path.join(output_dir, filename)
|
||||
plot_results(image, inference_state)
|
||||
@@ -370,21 +372,14 @@ def recognize_card_with_qwen(image_path: str) -> dict:
|
||||
调用 Qwen-VL 识别塔罗牌 (采用正逆位对比策略)
|
||||
"""
|
||||
try:
|
||||
# 确保路径是绝对路径
|
||||
abs_path = os.path.abspath(image_path)
|
||||
file_url = f"file://{abs_path}"
|
||||
|
||||
# -------------------------------------------------
|
||||
# 优化策略:生成一张旋转180度的对比图
|
||||
# 让 AI 做选择题而不是判断题,大幅提高准确率
|
||||
# -------------------------------------------------
|
||||
try:
|
||||
# 1. 打开原图
|
||||
img = Image.open(abs_path)
|
||||
|
||||
# 2. 生成旋转图 (180度)
|
||||
rotated_img = img.rotate(180)
|
||||
|
||||
# 3. 保存旋转图
|
||||
dir_name = os.path.dirname(abs_path)
|
||||
file_name = os.path.basename(abs_path)
|
||||
@@ -395,8 +390,6 @@ def recognize_card_with_qwen(image_path: str) -> dict:
|
||||
rotated_file_url = f"file://{rotated_path}"
|
||||
|
||||
# 4. 构建对比 Prompt
|
||||
# 发送两张图:图1=原图, 图2=旋转图
|
||||
# 询问 AI 哪一张是“正位”
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
@@ -422,7 +415,6 @@ def recognize_card_with_qwen(image_path: str) -> dict:
|
||||
|
||||
except Exception as e:
|
||||
print(f"对比图生成失败,回退到单图模式: {e}")
|
||||
# 回退到旧的单图模式
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
@@ -433,15 +425,12 @@ def recognize_card_with_qwen(image_path: str) -> dict:
|
||||
}
|
||||
]
|
||||
|
||||
# 调用模型
|
||||
response = MultiModalConversation.call(model=QWEN_MODEL, messages=messages)
|
||||
|
||||
if response.status_code == 200:
|
||||
content = response.output.choices[0].message.content[0]['text']
|
||||
# 尝试解析简单的 JSON
|
||||
import json
|
||||
try:
|
||||
# 清理可能存在的 markdown 标记
|
||||
clean_content = content.replace("```json", "").replace("```", "").strip()
|
||||
result = json.loads(clean_content)
|
||||
return result
|
||||
@@ -458,7 +447,6 @@ def recognize_spread_with_qwen(image_path: str) -> dict:
|
||||
调用 Qwen-VL 识别塔罗牌牌阵
|
||||
"""
|
||||
try:
|
||||
# 确保路径是绝对路径并加上 file:// 前缀
|
||||
abs_path = os.path.abspath(image_path)
|
||||
file_url = f"file://{abs_path}"
|
||||
|
||||
@@ -476,10 +464,8 @@ def recognize_spread_with_qwen(image_path: str) -> dict:
|
||||
|
||||
if response.status_code == 200:
|
||||
content = response.output.choices[0].message.content[0]['text']
|
||||
# 尝试解析简单的 JSON
|
||||
import json
|
||||
try:
|
||||
# 清理可能存在的 markdown 标记
|
||||
clean_content = content.replace("```json", "").replace("```", "").strip()
|
||||
result = json.loads(clean_content)
|
||||
return result
|
||||
@@ -491,27 +477,82 @@ def recognize_spread_with_qwen(image_path: str) -> dict:
|
||||
except Exception as e:
|
||||
return {"error": f"牌阵识别失败: {str(e)}"}
|
||||
|
||||
# ------------------- API 接口 (强制依赖验证) -------------------
|
||||
@app.post("/segment", dependencies=[Depends(verify_api_key)])
|
||||
# ==========================================
|
||||
# 6. FastAPI App Setup (应用初始化)
|
||||
# ==========================================
|
||||
|
||||
app = FastAPI(
|
||||
lifespan=lifespan,
|
||||
title="SAM3 Segmentation API",
|
||||
description="## 🔒 受 API Key 保护\n请点击右上角 **Authorize** 并输入: `123quant-*****`",
|
||||
openapi_tags=[
|
||||
{"name": TAG_GENERAL, "description": "General purpose image segmentation"},
|
||||
{"name": TAG_TAROT, "description": "Specialized endpoints for Tarot card recognition"},
|
||||
{"name": TAG_FACE, "description": "Face detection and analysis"},
|
||||
]
|
||||
)
|
||||
|
||||
# 手动添加 OpenAPI 安全配置
|
||||
app.openapi_schema = None
|
||||
def custom_openapi():
    """
    Build the OpenAPI schema once, add the API-key security scheme, and cache it.

    The schema is generated from the app's title, version, description, routes
    and tag metadata. An "APIKeyHeader" security scheme (header name taken from
    API_KEY_HEADER_NAME) is registered and required on every operation so that
    the docs UI shows the Authorize lock.

    Returns:
        The OpenAPI schema dict, also stored on ``app.openapi_schema``.
    """
    # Serve the cached schema after the first successful build.
    if app.openapi_schema:
        return app.openapi_schema

    from fastapi.openapi.utils import get_openapi

    openapi_schema = get_openapi(
        title=app.title,
        version=app.version,
        description=app.description,
        routes=app.routes,
        # Forward the tag metadata declared on the app; without it the tag
        # descriptions never reach the generated schema.
        tags=app.openapi_tags,
    )

    # "components" may be absent when no schema definitions were generated,
    # so create it on demand instead of indexing it directly.
    components = openapi_schema.setdefault("components", {})
    components["securitySchemes"] = {
        "APIKeyHeader": {
            "type": "apiKey",
            "in": "header",
            "name": API_KEY_HEADER_NAME,
        }
    }

    # Require the API key on every operation. Path items can also carry
    # non-operation entries (e.g. a "parameters" list), which are skipped.
    for path_item in openapi_schema["paths"].values():
        for operation in path_item.values():
            if isinstance(operation, dict):
                operation["security"] = [{"APIKeyHeader": []}]

    app.openapi_schema = openapi_schema
    return app.openapi_schema
|
||||
|
||||
app.openapi = custom_openapi
|
||||
|
||||
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
|
||||
|
||||
# ==========================================
|
||||
# 7. API Endpoints (接口定义)
|
||||
# ==========================================
|
||||
|
||||
# ------------------------------------------
|
||||
# Group 1: General Segmentation (通用分割)
|
||||
# ------------------------------------------
|
||||
|
||||
@app.post("/segment", tags=[TAG_GENERAL], dependencies=[Depends(verify_api_key)])
|
||||
async def segment(
|
||||
request: Request,
|
||||
prompt: str = Form(...),
|
||||
file: Optional[UploadFile] = File(None),
|
||||
image_url: Optional[str] = Form(None),
|
||||
save_segment_images: bool = Form(False),
|
||||
cutout: bool = Form(False)
|
||||
prompt: str = Form(..., description="Text prompt for segmentation (e.g., 'cat', 'person')"),
|
||||
file: Optional[UploadFile] = File(None, description="Image file to upload"),
|
||||
image_url: Optional[str] = Form(None, description="URL of the image"),
|
||||
save_segment_images: bool = Form(False, description="Whether to save and return individual segmented objects"),
|
||||
cutout: bool = Form(False, description="If True, returns transparent background PNGs; otherwise returns original crops"),
|
||||
confidence: float = Form(0.7, description="Confidence threshold (0.0-1.0). Default is 0.7.")
|
||||
):
|
||||
"""
|
||||
**通用图像分割接口**
|
||||
|
||||
- 支持上传图片或提供图片 URL
|
||||
- 支持自动将中文 Prompt 翻译为英文
|
||||
- 支持手动设置置信度 (Confidence Threshold)
|
||||
"""
|
||||
if not file and not image_url:
|
||||
raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")
|
||||
|
||||
# ------------------- Prompt 翻译/优化 -------------------
|
||||
# 1. Prompt 处理
|
||||
final_prompt = prompt
|
||||
if not is_english(prompt):
|
||||
# 使用 Qwen 进行翻译
|
||||
try:
|
||||
# 构建消息请求翻译
|
||||
# 注意:translate_to_sam3_prompt 是一个阻塞调用,可能会增加耗时
|
||||
# 但对于分割任务来说是可以接受的
|
||||
translated = translate_to_sam3_prompt(prompt)
|
||||
if translated:
|
||||
final_prompt = translated
|
||||
@@ -520,6 +561,7 @@ async def segment(
|
||||
|
||||
print(f"最终使用的 Prompt: {final_prompt}")
|
||||
|
||||
# 2. 图片加载
|
||||
try:
|
||||
if file:
|
||||
image = Image.open(file.file).convert("RGB")
|
||||
@@ -530,13 +572,34 @@ async def segment(
|
||||
|
||||
processor = request.app.state.processor
|
||||
|
||||
# 3. 模型推理
|
||||
try:
|
||||
# 设置图片
|
||||
inference_state = processor.set_image(image)
|
||||
output = processor.set_text_prompt(state=inference_state, prompt=final_prompt)
|
||||
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
|
||||
|
||||
# 处理 Confidence Threshold
|
||||
# 如果用户提供了 confidence,临时修改 processor 的阈值
|
||||
original_confidence = processor.confidence_threshold
|
||||
if confidence is not None:
|
||||
# 简单校验
|
||||
if not (0.0 <= confidence <= 1.0):
|
||||
raise HTTPException(status_code=400, detail="Confidence must be between 0.0 and 1.0")
|
||||
processor.confidence_threshold = confidence
|
||||
print(f"Using manual confidence threshold: {confidence}")
|
||||
|
||||
try:
|
||||
# 执行推理
|
||||
output = processor.set_text_prompt(state=inference_state, prompt=final_prompt)
|
||||
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
|
||||
finally:
|
||||
# 恢复默认阈值,防止影响其他请求 (虽然在单进程模型下可能不重要,但保持状态一致性是个好习惯)
|
||||
if confidence is not None:
|
||||
processor.confidence_threshold = original_confidence
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")
|
||||
|
||||
# 4. 结果可视化与保存
|
||||
try:
|
||||
filename = generate_and_save_result(image, inference_state)
|
||||
except Exception as e:
|
||||
@@ -544,7 +607,7 @@ async def segment(
|
||||
|
||||
file_url = request.url_for("static", path=f"results/{filename}")
|
||||
|
||||
# New logic for saving segments
|
||||
# 5. 保存分割子图 (Optional)
|
||||
saved_segments_info = []
|
||||
if save_segment_images:
|
||||
try:
|
||||
@@ -552,7 +615,6 @@ async def segment(
|
||||
output_dir = os.path.join(RESULT_IMAGE_DIR, request_id)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# 传递 cutout 参数
|
||||
saved_objects = crop_and_save_objects(
|
||||
image,
|
||||
masks,
|
||||
@@ -586,7 +648,11 @@ async def segment(
|
||||
|
||||
return JSONResponse(content=response_content)
|
||||
|
||||
@app.post("/segment_tarot", dependencies=[Depends(verify_api_key)])
|
||||
# ------------------------------------------
|
||||
# Group 2: Tarot Analysis (塔罗牌分析)
|
||||
# ------------------------------------------
|
||||
|
||||
@app.post("/segment_tarot", tags=[TAG_TAROT], dependencies=[Depends(verify_api_key)])
|
||||
async def segment_tarot(
|
||||
request: Request,
|
||||
file: Optional[UploadFile] = File(None),
|
||||
@@ -594,9 +660,11 @@ async def segment_tarot(
|
||||
expected_count: int = Form(3)
|
||||
):
|
||||
"""
|
||||
塔罗牌分割专用接口
|
||||
**塔罗牌分割接口**
|
||||
|
||||
1. 检测是否包含指定数量的塔罗牌 (默认为 3)
|
||||
2. 如果是,分别抠出这些牌并返回
|
||||
2. 对检测到的卡片进行透视矫正和裁剪
|
||||
3. 返回矫正后的图片 URL
|
||||
"""
|
||||
if not file and not image_url:
|
||||
raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")
|
||||
@@ -622,7 +690,7 @@ async def segment_tarot(
|
||||
# 核心逻辑:判断数量
|
||||
detected_count = len(masks)
|
||||
|
||||
# 创建本次请求的独立文件夹 (时间戳_UUID前8位)
|
||||
# 创建本次请求的独立文件夹
|
||||
request_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
|
||||
output_dir = os.path.join(RESULT_IMAGE_DIR, request_id)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
@@ -678,7 +746,7 @@ async def segment_tarot(
|
||||
"scores": scores.tolist() if torch.is_tensor(scores) else scores
|
||||
})
|
||||
|
||||
@app.post("/recognize_tarot", dependencies=[Depends(verify_api_key)])
|
||||
@app.post("/recognize_tarot", tags=[TAG_TAROT], dependencies=[Depends(verify_api_key)])
|
||||
async def recognize_tarot(
|
||||
request: Request,
|
||||
file: Optional[UploadFile] = File(None),
|
||||
@@ -686,7 +754,8 @@ async def recognize_tarot(
|
||||
expected_count: int = Form(3)
|
||||
):
|
||||
"""
|
||||
塔罗牌全流程接口: 分割 + 矫正 + 识别
|
||||
**塔罗牌全流程接口: 分割 + 矫正 + 识别**
|
||||
|
||||
1. 检测是否包含指定数量的塔罗牌 (SAM3)
|
||||
2. 分割并透视矫正
|
||||
3. 调用 Qwen-VL 识别每张牌的名称和正逆位
|
||||
@@ -706,21 +775,18 @@ async def recognize_tarot(
|
||||
|
||||
try:
|
||||
inference_state = processor.set_image(image)
|
||||
# 固定 Prompt 检测塔罗牌
|
||||
output = processor.set_text_prompt(state=inference_state, prompt="tarot card")
|
||||
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")
|
||||
|
||||
# 核心逻辑:判断数量
|
||||
detected_count = len(masks)
|
||||
|
||||
# 创建本次请求的独立文件夹
|
||||
request_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
|
||||
output_dir = os.path.join(RESULT_IMAGE_DIR, request_id)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# 保存整体效果图 (无论是成功还是失败,都先保存一张主图)
|
||||
# 保存整体效果图
|
||||
try:
|
||||
main_filename = generate_and_save_result(image, inference_state, output_dir=output_dir)
|
||||
main_file_path = os.path.join(output_dir, main_filename)
|
||||
@@ -730,20 +796,14 @@ async def recognize_tarot(
|
||||
main_file_path = None
|
||||
main_file_url = None
|
||||
|
||||
# Step 0: 牌阵识别 (在判断数量之前或之后都可以,这里放在前面作为全局判断)
|
||||
# Step 0: 牌阵识别
|
||||
spread_info = {"spread_name": "Unknown"}
|
||||
if main_file_path:
|
||||
# 使用带有mask绘制的主图或者原始图?
|
||||
# 使用原始图可能更好,不受mask遮挡干扰,但是main_filename是带mask的。
|
||||
# 我们这里暂时用原始图保存一份临时文件给Qwen看
|
||||
# 使用原始图的一份拷贝给 Qwen 识别牌阵
|
||||
temp_raw_path = os.path.join(output_dir, "raw_for_spread.jpg")
|
||||
image.save(temp_raw_path)
|
||||
spread_info = recognize_spread_with_qwen(temp_raw_path)
|
||||
|
||||
# 如果识别结果明确说是“不是正规牌阵”,是否要继续?
|
||||
# 用户需求:“如果没有正确的牌阵则返回‘不是正规牌阵’”
|
||||
# 我们将其放在返回结果中,由客户端决定是否展示警告
|
||||
|
||||
if detected_count != expected_count:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
@@ -768,8 +828,7 @@ async def recognize_tarot(
|
||||
fname = obj["filename"]
|
||||
file_path = os.path.join(output_dir, fname)
|
||||
|
||||
# 调用 Qwen-VL 识别
|
||||
# 注意:这里会串行调用,速度可能较慢。
|
||||
# 调用 Qwen-VL 识别 (串行)
|
||||
recognition_res = recognize_card_with_qwen(file_path)
|
||||
|
||||
file_url = str(request.url_for("static", path=f"results/{request_id}/{fname}"))
|
||||
@@ -790,15 +849,20 @@ async def recognize_tarot(
|
||||
"scores": scores.tolist() if torch.is_tensor(scores) else scores
|
||||
})
|
||||
|
||||
@app.post("/segment_face", dependencies=[Depends(verify_api_key)])
|
||||
# ------------------------------------------
|
||||
# Group 3: Face Analysis (人脸分析)
|
||||
# ------------------------------------------
|
||||
|
||||
@app.post("/segment_face", tags=[TAG_FACE], dependencies=[Depends(verify_api_key)])
|
||||
async def segment_face(
|
||||
request: Request,
|
||||
file: Optional[UploadFile] = File(None),
|
||||
image_url: Optional[str] = Form(None),
|
||||
prompt: str = Form("face and hair") # 默认提示词包含头发
|
||||
prompt: str = Form("face and hair", description="Prompt for face detection")
|
||||
):
|
||||
"""
|
||||
人脸/头部检测与属性分析接口 (新功能)
|
||||
**人脸/头部检测与属性分析接口**
|
||||
|
||||
1. 调用 SAM3 分割出头部区域 (含头发)
|
||||
2. 裁剪并保存
|
||||
3. 调用 Qwen-VL 识别性别和年龄
|
||||
@@ -806,7 +870,7 @@ async def segment_face(
|
||||
if not file and not image_url:
|
||||
raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")
|
||||
|
||||
# ------------------- Prompt 翻译/优化 -------------------
|
||||
# Prompt 翻译/优化
|
||||
final_prompt = prompt
|
||||
if not is_english(prompt):
|
||||
try:
|
||||
@@ -818,7 +882,7 @@ async def segment_face(
|
||||
|
||||
print(f"Face Segment 最终使用的 Prompt: {final_prompt}")
|
||||
|
||||
# 1. 加载图片
|
||||
# 加载图片
|
||||
try:
|
||||
if file:
|
||||
image = Image.open(file.file).convert("RGB")
|
||||
@@ -829,10 +893,8 @@ async def segment_face(
|
||||
|
||||
processor = request.app.state.processor
|
||||
|
||||
# 2. 调用独立服务进行处理
|
||||
# 调用独立服务进行处理
|
||||
try:
|
||||
# 传入 processor 和 image
|
||||
# 注意:Result Image Dir 我们直接复用 RESULT_IMAGE_DIR
|
||||
result = human_analysis_service.process_face_segmentation_and_analysis(
|
||||
processor=processor,
|
||||
image=image,
|
||||
@@ -840,34 +902,34 @@ async def segment_face(
|
||||
output_base_dir=RESULT_IMAGE_DIR
|
||||
)
|
||||
except Exception as e:
|
||||
# 打印详细错误堆栈以便调试
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}")
|
||||
|
||||
# 3. 补全 URL (因为 service 层不知道 request context)
|
||||
# 补全 URL
|
||||
if result["status"] == "success":
|
||||
# 处理全图可视化的 URL
|
||||
if result.get("full_visualization"):
|
||||
full_vis_rel_path = result["full_visualization"]
|
||||
result["full_visualization"] = str(request.url_for("static", path=full_vis_rel_path))
|
||||
|
||||
for item in result["results"]:
|
||||
# item["relative_path"] 是相对路径,如 results/xxx/xxx.jpg
|
||||
# 我们需要将其转换为完整 URL
|
||||
relative_path = item.pop("relative_path") # 移除相对路径字段,只返回 URL
|
||||
relative_path = item.pop("relative_path")
|
||||
item["url"] = str(request.url_for("static", path=relative_path))
|
||||
|
||||
return JSONResponse(content=result)
|
||||
|
||||
# ==========================================
|
||||
# 8. Main Entry Point (启动入口)
|
||||
# ==========================================
|
||||
|
||||
if __name__ == "__main__":
    import uvicorn

    # Start the server. The import string passed to uvicorn must match this
    # module's file name (fastAPI_tarot.py); update it if the file is renamed.
    server_options = {
        "host": "127.0.0.1",
        "port": 55600,
        "proxy_headers": True,
        "forwarded_allow_ips": "*",
        "reload": False,
    }
    uvicorn.run("fastAPI_tarot:app", **server_options)
|
||||
BIN
static/results/seg_153993616fb74472a9287c68e304e8eb.jpg
Normal file
BIN
static/results/seg_153993616fb74472a9287c68e304e8eb.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 76 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 132 KiB |
BIN
static/results/seg_7a044baf66b1480a8c343c2b501db944.jpg
Normal file
BIN
static/results/seg_7a044baf66b1480a8c343c2b501db944.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 78 KiB |
BIN
static/results/seg_863aa7fac0cd4c5f946b7dd37f3b4f36.jpg
Normal file
BIN
static/results/seg_863aa7fac0cd4c5f946b7dd37f3b4f36.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 101 KiB |
BIN
static/results/seg_c47b6e3658bb452a8778c5d8f2d8a6df.jpg
Normal file
BIN
static/results/seg_c47b6e3658bb452a8778c5d8f2d8a6df.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 82 KiB |
BIN
static/results/seg_f6cd00ded6b54337af288a844f8890aa.jpg
Normal file
BIN
static/results/seg_f6cd00ded6b54337af288a844f8890aa.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 78 KiB |
Reference in New Issue
Block a user