zxd
This commit is contained in:
398
fastAPI_tarot.py
398
fastAPI_tarot.py
@@ -1,68 +1,112 @@
|
|||||||
|
"""
|
||||||
|
SAM3 Segmentation API Service (SAM3 图像分割 API 服务)
|
||||||
|
--------------------------------------------------
|
||||||
|
本服务提供基于 SAM3 模型的图像分割能力。
|
||||||
|
包含通用分割、塔罗牌分析和人脸分析等专用接口。
|
||||||
|
|
||||||
|
This service provides image segmentation capabilities using the SAM3 model.
|
||||||
|
It includes specialized endpoints for Tarot card analysis and Face analysis.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# ==========================================
|
||||||
|
# 1. Imports (导入)
|
||||||
|
# ==========================================
|
||||||
|
|
||||||
|
# Standard Library Imports (标准库)
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
import time
|
import time
|
||||||
import requests
|
import json
|
||||||
import numpy as np
|
import traceback
|
||||||
import cv2
|
import re
|
||||||
from typing import Optional
|
from typing import Optional, List, Dict, Any
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
import dashscope
|
|
||||||
from dashscope import MultiModalConversation
|
|
||||||
|
|
||||||
|
# Third-Party Imports (第三方库)
|
||||||
|
import cv2
|
||||||
import torch
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import requests
|
||||||
import matplotlib
|
import matplotlib
|
||||||
matplotlib.use('Agg')
|
matplotlib.use('Agg') # 使用非交互式后端,防止在服务器上报错
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request, Depends, status
|
# FastAPI Imports
|
||||||
|
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request, Depends, status, APIRouter
|
||||||
from fastapi.security import APIKeyHeader
|
from fastapi.security import APIKeyHeader
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
|
# Dashscope (Aliyun Qwen) Imports
|
||||||
|
import dashscope
|
||||||
|
from dashscope import MultiModalConversation
|
||||||
|
|
||||||
|
# Local Imports (本地模块)
|
||||||
from sam3.model_builder import build_sam3_image_model
|
from sam3.model_builder import build_sam3_image_model
|
||||||
from sam3.model.sam3_image_processor import Sam3Processor
|
from sam3.model.sam3_image_processor import Sam3Processor
|
||||||
from sam3.visualization_utils import plot_results
|
from sam3.visualization_utils import plot_results
|
||||||
import human_analysis_service # 引入新服务
|
import human_analysis_service # 引入新服务: 人脸分析
|
||||||
|
|
||||||
# ------------------- 配置与路径 -------------------
|
# ==========================================
|
||||||
|
# 2. Configuration & Constants (配置与常量)
|
||||||
|
# ==========================================
|
||||||
|
|
||||||
|
# 路径配置
|
||||||
# Path configuration: result images are written below STATIC_DIR/results.
STATIC_DIR = "static"
RESULT_IMAGE_DIR = os.path.join(STATIC_DIR, "results")
os.makedirs(RESULT_IMAGE_DIR, exist_ok=True)

# API key configuration.
# SECURITY: credentials should not live in source control. The environment
# variable takes precedence; the literal is kept only as a backward-compatible
# fallback and should be rotated and removed once deployments set SAM3_API_KEY.
VALID_API_KEY = os.environ.get("SAM3_API_KEY") or "123quant-speed"
API_KEY_HEADER_NAME = "X-API-Key"

# Dashscope (Qwen-VL) configuration.
# SECURITY: same rule as above - prefer DASHSCOPE_API_KEY from the environment;
# the embedded key is a fallback only and should be rotated.
dashscope.api_key = (
    os.environ.get("DASHSCOPE_API_KEY") or 'sk-ce2404f55f744a1987d5ece61c6bac58'
)
QWEN_MODEL = 'qwen-vl-max'
|
||||||
|
|
||||||
# 定义 Header 认证
|
# API Tags (用于文档分类)
|
||||||
|
TAG_GENERAL = "General Segmentation (通用分割)"
|
||||||
|
TAG_TAROT = "Tarot Analysis (塔罗牌分析)"
|
||||||
|
TAG_FACE = "Face Analysis (人脸分析)"
|
||||||
|
|
||||||
|
# ==========================================
|
||||||
|
# 3. Security & Middleware (安全与中间件)
|
||||||
|
# ==========================================
|
||||||
|
|
||||||
|
# 定义 Header 认证 Scheme
|
||||||
api_key_header = APIKeyHeader(name=API_KEY_HEADER_NAME, auto_error=False)
|
api_key_header = APIKeyHeader(name=API_KEY_HEADER_NAME, auto_error=False)
|
||||||
|
|
||||||
async def verify_api_key(api_key: Optional[str] = Depends(api_key_header)):
    """
    Enforce API key verification for the current request.

    Args:
        api_key: Value resolved by the ``api_key_header`` dependency. It is
            treated as missing when it is ``None`` or empty.

    Returns:
        True when the supplied key matches ``VALID_API_KEY``.

    Raises:
        HTTPException: 401 when no key was supplied, 403 when the key does
            not match.
    """
    # Function-scope import so the block stays self-contained.
    import secrets

    # 1. Reject requests that carry no key at all.
    if not api_key:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing API Key. Please provide it in the header."
        )

    # 2. Compare in constant time so response timing does not reveal how much
    #    of the key was correct. Bytes are used because compare_digest rejects
    #    non-ASCII str arguments with a TypeError.
    supplied = api_key.encode("utf-8")
    expected = VALID_API_KEY.encode("utf-8")
    if not secrets.compare_digest(supplied, expected):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid API Key."
        )

    # 3. Verification passed.
    return True
|
||||||
|
|
||||||
# ------------------- 生命周期管理 -------------------
|
# ==========================================
|
||||||
|
# 4. Lifespan Management (生命周期管理)
|
||||||
|
# ==========================================
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
|
"""
|
||||||
|
FastAPI 生命周期管理器
|
||||||
|
- 启动时: 加载模型到 GPU/CPU
|
||||||
|
- 关闭时: 清理资源
|
||||||
|
"""
|
||||||
print("="*40)
|
print("="*40)
|
||||||
print("✅ API Key 保护已激活")
|
print("✅ API Key 保护已激活")
|
||||||
print(f"✅ 有效 Key: {VALID_API_KEY}")
|
print(f"✅ 有效 Key: {VALID_API_KEY}")
|
||||||
@@ -73,12 +117,15 @@ async def lifespan(app: FastAPI):
|
|||||||
if not torch.cuda.is_available():
|
if not torch.cuda.is_available():
|
||||||
print("警告: 未检测到 GPU,将使用 CPU,速度会较慢。")
|
print("警告: 未检测到 GPU,将使用 CPU,速度会较慢。")
|
||||||
|
|
||||||
|
# 加载模型
|
||||||
model = build_sam3_image_model()
|
model = build_sam3_image_model()
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
|
# 初始化处理器
|
||||||
processor = Sam3Processor(model)
|
processor = Sam3Processor(model)
|
||||||
|
|
||||||
|
# 存储到 app.state 以便全局访问
|
||||||
app.state.model = model
|
app.state.model = model
|
||||||
app.state.processor = processor
|
app.state.processor = processor
|
||||||
app.state.device = device
|
app.state.device = device
|
||||||
@@ -88,52 +135,16 @@ async def lifespan(app: FastAPI):
|
|||||||
yield
|
yield
|
||||||
|
|
||||||
print("正在清理资源...")
|
print("正在清理资源...")
|
||||||
|
# 这里可以添加释放显存的逻辑,如果需要
|
||||||
|
|
||||||
# ------------------- FastAPI 初始化 -------------------
|
# ==========================================
|
||||||
app = FastAPI(
|
# 5. Helper Functions (辅助函数)
|
||||||
lifespan=lifespan,
|
# ==========================================
|
||||||
title="SAM3 Segmentation API",
|
|
||||||
description="## 🔒 受 API Key 保护\n请点击右上角 **Authorize** 并输入: `123quant-*****`",
|
|
||||||
)
|
|
||||||
|
|
||||||
# 手动添加 OpenAPI 安全配置,让 Docs 里的锁头生效
|
|
||||||
app.openapi_schema = None
|
|
||||||
def custom_openapi():
    """
    Build the OpenAPI schema once and attach the API-key security scheme.

    The generated schema is cached on ``app.openapi_schema``; later calls
    return the cached object. Every operation under every path is marked as
    requiring the ``APIKeyHeader`` scheme so the docs UI shows the lock icon.

    Returns:
        The OpenAPI schema dictionary stored on ``app.openapi_schema``.
    """
    if app.openapi_schema:
        return app.openapi_schema

    from fastapi.openapi.utils import get_openapi

    openapi_schema = get_openapi(
        title=app.title,
        version=app.version,
        description=app.description,
        routes=app.routes,
    )

    # Define the security scheme. "components" may be absent from the
    # generated schema, so create it on demand instead of raising KeyError.
    openapi_schema.setdefault("components", {})["securitySchemes"] = {
        "APIKeyHeader": {
            "type": "apiKey",
            "in": "header",
            "name": API_KEY_HEADER_NAME,
        }
    }

    # Apply the security requirement to every operation of every path.
    for path_item in openapi_schema.get("paths", {}).values():
        for operation in path_item.values():
            operation["security"] = [{"APIKeyHeader": []}]

    app.openapi_schema = openapi_schema
    return app.openapi_schema
|
|
||||||
|
|
||||||
app.openapi = custom_openapi
|
|
||||||
|
|
||||||
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
|
|
||||||
|
|
||||||
import re
|
|
||||||
|
|
||||||
# ------------------- 辅助函数 -------------------
|
|
||||||
def is_english(text: str) -> bool:
|
def is_english(text: str) -> bool:
|
||||||
"""
|
"""
|
||||||
简单判断是否为英文 prompt。
|
判断文本是否为纯英文
|
||||||
如果有中文字符则认为不是英文。
|
- 如果包含中文字符范围 (\u4e00-\u9fff),则返回 False
|
||||||
"""
|
"""
|
||||||
for char in text:
|
for char in text:
|
||||||
if '\u4e00' <= char <= '\u9fff':
|
if '\u4e00' <= char <= '\u9fff':
|
||||||
@@ -142,7 +153,8 @@ def is_english(text: str) -> bool:
|
|||||||
|
|
||||||
def translate_to_sam3_prompt(text: str) -> str:
|
def translate_to_sam3_prompt(text: str) -> str:
|
||||||
"""
|
"""
|
||||||
使用大模型将非英文提示词翻译为适合 SAM3 的英文提示词
|
使用 Qwen 模型将中文提示词翻译为英文
|
||||||
|
- SAM3 模型对英文 Prompt 支持更好
|
||||||
"""
|
"""
|
||||||
print(f"正在翻译提示词: {text}")
|
print(f"正在翻译提示词: {text}")
|
||||||
try:
|
try:
|
||||||
@@ -164,7 +176,7 @@ def translate_to_sam3_prompt(text: str) -> str:
|
|||||||
return translated_text
|
return translated_text
|
||||||
else:
|
else:
|
||||||
print(f"翻译失败: {response.code} - {response.message}")
|
print(f"翻译失败: {response.code} - {response.message}")
|
||||||
return text # Fallback to original
|
return text # 失败则回退到原始文本
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"翻译异常: {e}")
|
print(f"翻译异常: {e}")
|
||||||
return text
|
return text
|
||||||
@@ -172,6 +184,7 @@ def translate_to_sam3_prompt(text: str) -> str:
|
|||||||
def order_points(pts):
|
def order_points(pts):
|
||||||
"""
|
"""
|
||||||
对四个坐标点进行排序:左上,右上,右下,左下
|
对四个坐标点进行排序:左上,右上,右下,左下
|
||||||
|
用于透视变换前的点位整理
|
||||||
"""
|
"""
|
||||||
rect = np.zeros((4, 2), dtype="float32")
|
rect = np.zeros((4, 2), dtype="float32")
|
||||||
s = pts.sum(axis=1)
|
s = pts.sum(axis=1)
|
||||||
@@ -184,7 +197,8 @@ def order_points(pts):
|
|||||||
|
|
||||||
def four_point_transform(image, pts):
|
def four_point_transform(image, pts):
|
||||||
"""
|
"""
|
||||||
根据四个点进行透视变换
|
根据四个点进行透视变换 (Perspective Transform)
|
||||||
|
用于将倾斜的卡片矫正为矩形
|
||||||
"""
|
"""
|
||||||
rect = order_points(pts)
|
rect = order_points(pts)
|
||||||
(tl, tr, br, bl) = rect
|
(tl, tr, br, bl) = rect
|
||||||
@@ -210,6 +224,9 @@ def four_point_transform(image, pts):
|
|||||||
return warped
|
return warped
|
||||||
|
|
||||||
def load_image_from_url(url: str) -> Image.Image:
|
def load_image_from_url(url: str) -> Image.Image:
|
||||||
|
"""
|
||||||
|
从 URL 下载图片并转换为 PIL Image 对象
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
headers = {'User-Agent': 'Mozilla/5.0'}
|
headers = {'User-Agent': 'Mozilla/5.0'}
|
||||||
response = requests.get(url, headers=headers, stream=True, timeout=10)
|
response = requests.get(url, headers=headers, stream=True, timeout=10)
|
||||||
@@ -222,7 +239,17 @@ def load_image_from_url(url: str) -> Image.Image:
|
|||||||
def crop_and_save_objects(image: Image.Image, masks, boxes, output_dir: str = RESULT_IMAGE_DIR, is_tarot: bool = True, cutout: bool = False) -> list[dict]:
|
def crop_and_save_objects(image: Image.Image, masks, boxes, output_dir: str = RESULT_IMAGE_DIR, is_tarot: bool = True, cutout: bool = False) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
根据 mask 和 box 进行处理并保存独立的对象图片
|
根据 mask 和 box 进行处理并保存独立的对象图片
|
||||||
:param cutout: 如果为 True,则进行轮廓抠图(透明背景);否则进行透视矫正(主要用于卡片)
|
|
||||||
|
参数:
|
||||||
|
- image: 原始图片
|
||||||
|
- masks: 分割掩码列表
|
||||||
|
- boxes: 边界框列表
|
||||||
|
- output_dir: 输出目录
|
||||||
|
- is_tarot: 是否为塔罗牌模式 (会影响文件名前缀和旋转逻辑)
|
||||||
|
- cutout: 如果为 True,则进行轮廓抠图(透明背景);否则进行透视矫正(主要用于卡片)
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- 保存的对象信息列表
|
||||||
"""
|
"""
|
||||||
saved_objects = []
|
saved_objects = []
|
||||||
# Convert image to numpy array (RGB)
|
# Convert image to numpy array (RGB)
|
||||||
@@ -256,29 +283,18 @@ def crop_and_save_objects(image: Image.Image, masks, boxes, output_dir: str = RE
|
|||||||
img_rgba = image.copy()
|
img_rgba = image.copy()
|
||||||
|
|
||||||
# 2. 准备 Alpha Mask
|
# 2. 准备 Alpha Mask
|
||||||
# mask_uint8 已经是 0-255,可以直接作为 Alpha 通道的基础
|
|
||||||
# 将 Mask 缩放到和图片一样大 (一般是一样的,但以防万一)
|
|
||||||
if mask_uint8.shape != img_rgba.size[::-1]: # size is (w, h), shape is (h, w)
|
|
||||||
# 如果尺寸不一致可能需要 resize,这里假设尺寸一致
|
|
||||||
pass
|
|
||||||
|
|
||||||
mask_img = Image.fromarray(mask_uint8, mode='L')
|
mask_img = Image.fromarray(mask_uint8, mode='L')
|
||||||
|
|
||||||
# 3. 将 Mask 应用到 Alpha 通道
|
# 3. 将 Mask 应用到 Alpha 通道
|
||||||
# 创建一个新的空白透明图
|
|
||||||
cutout_img = Image.new("RGBA", img_rgba.size, (0, 0, 0, 0))
|
cutout_img = Image.new("RGBA", img_rgba.size, (0, 0, 0, 0))
|
||||||
# 将原图粘贴上去,使用 mask 作为 mask
|
|
||||||
cutout_img.paste(image.convert("RGB"), (0, 0), mask=mask_img)
|
cutout_img.paste(image.convert("RGB"), (0, 0), mask=mask_img)
|
||||||
|
|
||||||
# 4. Crop to Box
|
# 4. Crop to Box
|
||||||
# box is [x1, y1, x2, y2]
|
|
||||||
x1, y1, x2, y2 = map(int, box_np)
|
x1, y1, x2, y2 = map(int, box_np)
|
||||||
# 边界检查
|
|
||||||
w, h = cutout_img.size
|
w, h = cutout_img.size
|
||||||
x1 = max(0, x1); y1 = max(0, y1)
|
x1 = max(0, x1); y1 = max(0, y1)
|
||||||
x2 = min(w, x2); y2 = min(h, y2)
|
x2 = min(w, x2); y2 = min(h, y2)
|
||||||
|
|
||||||
# 避免无效 crop
|
|
||||||
if x2 > x1 and y2 > y1:
|
if x2 > x1 and y2 > y1:
|
||||||
final_img = cutout_img.crop((x1, y1, x2, y2))
|
final_img = cutout_img.crop((x1, y1, x2, y2))
|
||||||
else:
|
else:
|
||||||
@@ -286,7 +302,7 @@ def crop_and_save_objects(image: Image.Image, masks, boxes, output_dir: str = RE
|
|||||||
|
|
||||||
# Save
|
# Save
|
||||||
prefix = "cutout"
|
prefix = "cutout"
|
||||||
is_rotated = False # 抠图模式下不进行自动旋转
|
is_rotated = False
|
||||||
|
|
||||||
filename = f"{prefix}_{uuid.uuid4().hex}_{i}.png"
|
filename = f"{prefix}_{uuid.uuid4().hex}_{i}.png"
|
||||||
save_path = os.path.join(output_dir, filename)
|
save_path = os.path.join(output_dir, filename)
|
||||||
@@ -299,65 +315,51 @@ def crop_and_save_objects(image: Image.Image, masks, boxes, output_dir: str = RE
|
|||||||
})
|
})
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# --- 透视矫正模式 (原有逻辑) ---
|
# --- 透视矫正模式 (矩形矫正) ---
|
||||||
# Find contours
|
|
||||||
contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||||
if not contours:
|
if not contours:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Get largest contour
|
|
||||||
c = max(contours, key=cv2.contourArea)
|
c = max(contours, key=cv2.contourArea)
|
||||||
|
|
||||||
# Approximate contour to polygon
|
|
||||||
peri = cv2.arcLength(c, True)
|
peri = cv2.arcLength(c, True)
|
||||||
approx = cv2.approxPolyDP(c, 0.04 * peri, True)
|
approx = cv2.approxPolyDP(c, 0.04 * peri, True)
|
||||||
|
|
||||||
# If we have 4 points, use them. If not, fallback to minAreaRect
|
|
||||||
if len(approx) == 4:
|
if len(approx) == 4:
|
||||||
pts = approx.reshape(4, 2)
|
pts = approx.reshape(4, 2)
|
||||||
else:
|
else:
|
||||||
rect = cv2.minAreaRect(c)
|
rect = cv2.minAreaRect(c)
|
||||||
pts = cv2.boxPoints(rect)
|
pts = cv2.boxPoints(rect)
|
||||||
|
|
||||||
# Apply perspective transform
|
warped = four_point_transform(img_arr, pts)
|
||||||
# Check if original image has Alpha
|
|
||||||
if img_arr.shape[2] == 4:
|
|
||||||
warped = four_point_transform(img_arr, pts)
|
|
||||||
else:
|
|
||||||
warped = four_point_transform(img_arr, pts)
|
|
||||||
|
|
||||||
# Check orientation (Portrait vs Landscape)
|
# Check orientation (Portrait vs Landscape)
|
||||||
h, w = warped.shape[:2]
|
h, w = warped.shape[:2]
|
||||||
is_rotated = False
|
is_rotated = False
|
||||||
|
|
||||||
# Enforce Portrait for Tarot cards (Standard 7x12 cm ratio approx)
|
# 强制竖屏逻辑 (塔罗牌通常是竖屏)
|
||||||
if is_tarot and w > h:
|
if is_tarot and w > h:
|
||||||
# Rotate 90 degrees clockwise
|
|
||||||
warped = cv2.rotate(warped, cv2.ROTATE_90_CLOCKWISE)
|
warped = cv2.rotate(warped, cv2.ROTATE_90_CLOCKWISE)
|
||||||
is_rotated = True
|
is_rotated = True
|
||||||
|
|
||||||
# Convert back to PIL
|
|
||||||
pil_warped = Image.fromarray(warped)
|
pil_warped = Image.fromarray(warped)
|
||||||
|
|
||||||
# Save
|
|
||||||
prefix = "tarot" if is_tarot else "segment"
|
prefix = "tarot" if is_tarot else "segment"
|
||||||
filename = f"{prefix}_{uuid.uuid4().hex}_{i}.png"
|
filename = f"{prefix}_{uuid.uuid4().hex}_{i}.png"
|
||||||
save_path = os.path.join(output_dir, filename)
|
save_path = os.path.join(output_dir, filename)
|
||||||
pil_warped.save(save_path)
|
pil_warped.save(save_path)
|
||||||
|
|
||||||
# 正逆位判断逻辑 (基于几何只能做到这一步,无法区分上下颠倒)
|
|
||||||
# 这里我们假设长边垂直为正位,如果做了旋转则标记
|
|
||||||
# 真正的正逆位需要OCR或图像识别
|
|
||||||
|
|
||||||
saved_objects.append({
|
saved_objects.append({
|
||||||
"filename": filename,
|
"filename": filename,
|
||||||
"is_rotated_by_algorithm": is_rotated,
|
"is_rotated_by_algorithm": is_rotated,
|
||||||
"note": "Geometric correction applied. True upright/reversed requires content analysis."
|
"note": "Geometric correction applied."
|
||||||
})
|
})
|
||||||
|
|
||||||
return saved_objects
|
return saved_objects
|
||||||
|
|
||||||
def generate_and_save_result(image: Image.Image, inference_state, output_dir: str = RESULT_IMAGE_DIR) -> str:
|
def generate_and_save_result(image: Image.Image, inference_state, output_dir: str = RESULT_IMAGE_DIR) -> str:
|
||||||
|
"""
|
||||||
|
生成并保存包含 Mask 叠加效果的完整结果图
|
||||||
|
"""
|
||||||
filename = f"seg_{uuid.uuid4().hex}.jpg"
|
filename = f"seg_{uuid.uuid4().hex}.jpg"
|
||||||
save_path = os.path.join(output_dir, filename)
|
save_path = os.path.join(output_dir, filename)
|
||||||
plot_results(image, inference_state)
|
plot_results(image, inference_state)
|
||||||
@@ -370,21 +372,14 @@ def recognize_card_with_qwen(image_path: str) -> dict:
|
|||||||
调用 Qwen-VL 识别塔罗牌 (采用正逆位对比策略)
|
调用 Qwen-VL 识别塔罗牌 (采用正逆位对比策略)
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 确保路径是绝对路径
|
|
||||||
abs_path = os.path.abspath(image_path)
|
abs_path = os.path.abspath(image_path)
|
||||||
file_url = f"file://{abs_path}"
|
file_url = f"file://{abs_path}"
|
||||||
|
|
||||||
# -------------------------------------------------
|
|
||||||
# 优化策略:生成一张旋转180度的对比图
|
|
||||||
# 让 AI 做选择题而不是判断题,大幅提高准确率
|
|
||||||
# -------------------------------------------------
|
|
||||||
try:
|
try:
|
||||||
# 1. 打开原图
|
# 1. 打开原图
|
||||||
img = Image.open(abs_path)
|
img = Image.open(abs_path)
|
||||||
|
|
||||||
# 2. 生成旋转图 (180度)
|
# 2. 生成旋转图 (180度)
|
||||||
rotated_img = img.rotate(180)
|
rotated_img = img.rotate(180)
|
||||||
|
|
||||||
# 3. 保存旋转图
|
# 3. 保存旋转图
|
||||||
dir_name = os.path.dirname(abs_path)
|
dir_name = os.path.dirname(abs_path)
|
||||||
file_name = os.path.basename(abs_path)
|
file_name = os.path.basename(abs_path)
|
||||||
@@ -395,8 +390,6 @@ def recognize_card_with_qwen(image_path: str) -> dict:
|
|||||||
rotated_file_url = f"file://{rotated_path}"
|
rotated_file_url = f"file://{rotated_path}"
|
||||||
|
|
||||||
# 4. 构建对比 Prompt
|
# 4. 构建对比 Prompt
|
||||||
# 发送两张图:图1=原图, 图2=旋转图
|
|
||||||
# 询问 AI 哪一张是“正位”
|
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
@@ -422,7 +415,6 @@ def recognize_card_with_qwen(image_path: str) -> dict:
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"对比图生成失败,回退到单图模式: {e}")
|
print(f"对比图生成失败,回退到单图模式: {e}")
|
||||||
# 回退到旧的单图模式
|
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
@@ -433,15 +425,12 @@ def recognize_card_with_qwen(image_path: str) -> dict:
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
# 调用模型
|
|
||||||
response = MultiModalConversation.call(model=QWEN_MODEL, messages=messages)
|
response = MultiModalConversation.call(model=QWEN_MODEL, messages=messages)
|
||||||
|
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
content = response.output.choices[0].message.content[0]['text']
|
content = response.output.choices[0].message.content[0]['text']
|
||||||
# 尝试解析简单的 JSON
|
|
||||||
import json
|
import json
|
||||||
try:
|
try:
|
||||||
# 清理可能存在的 markdown 标记
|
|
||||||
clean_content = content.replace("```json", "").replace("```", "").strip()
|
clean_content = content.replace("```json", "").replace("```", "").strip()
|
||||||
result = json.loads(clean_content)
|
result = json.loads(clean_content)
|
||||||
return result
|
return result
|
||||||
@@ -458,7 +447,6 @@ def recognize_spread_with_qwen(image_path: str) -> dict:
|
|||||||
调用 Qwen-VL 识别塔罗牌牌阵
|
调用 Qwen-VL 识别塔罗牌牌阵
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 确保路径是绝对路径并加上 file:// 前缀
|
|
||||||
abs_path = os.path.abspath(image_path)
|
abs_path = os.path.abspath(image_path)
|
||||||
file_url = f"file://{abs_path}"
|
file_url = f"file://{abs_path}"
|
||||||
|
|
||||||
@@ -476,10 +464,8 @@ def recognize_spread_with_qwen(image_path: str) -> dict:
|
|||||||
|
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
content = response.output.choices[0].message.content[0]['text']
|
content = response.output.choices[0].message.content[0]['text']
|
||||||
# 尝试解析简单的 JSON
|
|
||||||
import json
|
import json
|
||||||
try:
|
try:
|
||||||
# 清理可能存在的 markdown 标记
|
|
||||||
clean_content = content.replace("```json", "").replace("```", "").strip()
|
clean_content = content.replace("```json", "").replace("```", "").strip()
|
||||||
result = json.loads(clean_content)
|
result = json.loads(clean_content)
|
||||||
return result
|
return result
|
||||||
@@ -491,27 +477,82 @@ def recognize_spread_with_qwen(image_path: str) -> dict:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {"error": f"牌阵识别失败: {str(e)}"}
|
return {"error": f"牌阵识别失败: {str(e)}"}
|
||||||
|
|
||||||
# ------------------- API 接口 (强制依赖验证) -------------------
|
# ==========================================
|
||||||
@app.post("/segment", dependencies=[Depends(verify_api_key)])
|
# 6. FastAPI App Setup (应用初始化)
|
||||||
|
# ==========================================
|
||||||
|
|
||||||
|
# Application instance: titled and described for the docs UI, grouped by the
# three endpoint tags, with model loading handled by the lifespan manager.
app = FastAPI(
    title="SAM3 Segmentation API",
    description="## 🔒 受 API Key 保护\n请点击右上角 **Authorize** 并输入: `123quant-*****`",
    openapi_tags=[
        {
            "name": TAG_GENERAL,
            "description": "General purpose image segmentation",
        },
        {
            "name": TAG_TAROT,
            "description": "Specialized endpoints for Tarot card recognition",
        },
        {
            "name": TAG_FACE,
            "description": "Face detection and analysis",
        },
    ],
    lifespan=lifespan,
)
|
||||||
|
|
||||||
|
# 手动添加 OpenAPI 安全配置
|
||||||
|
app.openapi_schema = None
|
||||||
|
def custom_openapi():
    """
    Build the OpenAPI schema once and attach the API-key security scheme.

    The generated schema is cached on ``app.openapi_schema``; later calls
    return the cached object. Every operation under every path is marked as
    requiring the ``APIKeyHeader`` scheme so the docs UI shows the lock icon.

    Returns:
        The OpenAPI schema dictionary stored on ``app.openapi_schema``.
    """
    if app.openapi_schema:
        return app.openapi_schema

    from fastapi.openapi.utils import get_openapi

    openapi_schema = get_openapi(
        title=app.title,
        version=app.version,
        description=app.description,
        routes=app.routes,
        # Forward the tag metadata given to FastAPI(openapi_tags=...);
        # without it the tag descriptions are missing from the schema.
        tags=app.openapi_tags,
    )

    # "components" may be absent from the generated schema, so create it on
    # demand instead of raising KeyError.
    openapi_schema.setdefault("components", {})["securitySchemes"] = {
        "APIKeyHeader": {
            "type": "apiKey",
            "in": "header",
            "name": API_KEY_HEADER_NAME,
        }
    }

    # Apply the security requirement to every operation of every path.
    for path_item in openapi_schema.get("paths", {}).values():
        for operation in path_item.values():
            operation["security"] = [{"APIKeyHeader": []}]

    app.openapi_schema = openapi_schema
    return app.openapi_schema
|
||||||
|
|
||||||
|
app.openapi = custom_openapi
|
||||||
|
|
||||||
|
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
|
||||||
|
|
||||||
|
# ==========================================
|
||||||
|
# 7. API Endpoints (接口定义)
|
||||||
|
# ==========================================
|
||||||
|
|
||||||
|
# ------------------------------------------
|
||||||
|
# Group 1: General Segmentation (通用分割)
|
||||||
|
# ------------------------------------------
|
||||||
|
|
||||||
|
@app.post("/segment", tags=[TAG_GENERAL], dependencies=[Depends(verify_api_key)])
|
||||||
async def segment(
|
async def segment(
|
||||||
request: Request,
|
request: Request,
|
||||||
prompt: str = Form(...),
|
prompt: str = Form(..., description="Text prompt for segmentation (e.g., 'cat', 'person')"),
|
||||||
file: Optional[UploadFile] = File(None),
|
file: Optional[UploadFile] = File(None, description="Image file to upload"),
|
||||||
image_url: Optional[str] = Form(None),
|
image_url: Optional[str] = Form(None, description="URL of the image"),
|
||||||
save_segment_images: bool = Form(False),
|
save_segment_images: bool = Form(False, description="Whether to save and return individual segmented objects"),
|
||||||
cutout: bool = Form(False)
|
cutout: bool = Form(False, description="If True, returns transparent background PNGs; otherwise returns original crops"),
|
||||||
|
confidence: float = Form(0.7, description="Confidence threshold (0.0-1.0). Default is 0.7.")
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
**通用图像分割接口**
|
||||||
|
|
||||||
|
- 支持上传图片或提供图片 URL
|
||||||
|
- 支持自动将中文 Prompt 翻译为英文
|
||||||
|
- 支持手动设置置信度 (Confidence Threshold)
|
||||||
|
"""
|
||||||
if not file and not image_url:
|
if not file and not image_url:
|
||||||
raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")
|
raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")
|
||||||
|
|
||||||
# ------------------- Prompt 翻译/优化 -------------------
|
# 1. Prompt 处理
|
||||||
final_prompt = prompt
|
final_prompt = prompt
|
||||||
if not is_english(prompt):
|
if not is_english(prompt):
|
||||||
# 使用 Qwen 进行翻译
|
|
||||||
try:
|
try:
|
||||||
# 构建消息请求翻译
|
|
||||||
# 注意:translate_to_sam3_prompt 是一个阻塞调用,可能会增加耗时
|
|
||||||
# 但对于分割任务来说是可以接受的
|
|
||||||
translated = translate_to_sam3_prompt(prompt)
|
translated = translate_to_sam3_prompt(prompt)
|
||||||
if translated:
|
if translated:
|
||||||
final_prompt = translated
|
final_prompt = translated
|
||||||
@@ -520,6 +561,7 @@ async def segment(
|
|||||||
|
|
||||||
print(f"最终使用的 Prompt: {final_prompt}")
|
print(f"最终使用的 Prompt: {final_prompt}")
|
||||||
|
|
||||||
|
# 2. 图片加载
|
||||||
try:
|
try:
|
||||||
if file:
|
if file:
|
||||||
image = Image.open(file.file).convert("RGB")
|
image = Image.open(file.file).convert("RGB")
|
||||||
@@ -529,14 +571,35 @@ async def segment(
|
|||||||
raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")
|
raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")
|
||||||
|
|
||||||
processor = request.app.state.processor
|
processor = request.app.state.processor
|
||||||
|
|
||||||
|
# 3. 模型推理
|
||||||
try:
|
try:
|
||||||
|
# 设置图片
|
||||||
inference_state = processor.set_image(image)
|
inference_state = processor.set_image(image)
|
||||||
output = processor.set_text_prompt(state=inference_state, prompt=final_prompt)
|
|
||||||
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
|
# 处理 Confidence Threshold
|
||||||
|
# 如果用户提供了 confidence,临时修改 processor 的阈值
|
||||||
|
original_confidence = processor.confidence_threshold
|
||||||
|
if confidence is not None:
|
||||||
|
# 简单校验
|
||||||
|
if not (0.0 <= confidence <= 1.0):
|
||||||
|
raise HTTPException(status_code=400, detail="Confidence must be between 0.0 and 1.0")
|
||||||
|
processor.confidence_threshold = confidence
|
||||||
|
print(f"Using manual confidence threshold: {confidence}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 执行推理
|
||||||
|
output = processor.set_text_prompt(state=inference_state, prompt=final_prompt)
|
||||||
|
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
|
||||||
|
finally:
|
||||||
|
# 恢复默认阈值,防止影响其他请求 (虽然在单进程模型下可能不重要,但保持状态一致性是个好习惯)
|
||||||
|
if confidence is not None:
|
||||||
|
processor.confidence_threshold = original_confidence
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")
|
||||||
|
|
||||||
|
# 4. 结果可视化与保存
|
||||||
try:
|
try:
|
||||||
filename = generate_and_save_result(image, inference_state)
|
filename = generate_and_save_result(image, inference_state)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -544,7 +607,7 @@ async def segment(
|
|||||||
|
|
||||||
file_url = request.url_for("static", path=f"results/{filename}")
|
file_url = request.url_for("static", path=f"results/{filename}")
|
||||||
|
|
||||||
# New logic for saving segments
|
# 5. 保存分割子图 (Optional)
|
||||||
saved_segments_info = []
|
saved_segments_info = []
|
||||||
if save_segment_images:
|
if save_segment_images:
|
||||||
try:
|
try:
|
||||||
@@ -552,7 +615,6 @@ async def segment(
|
|||||||
output_dir = os.path.join(RESULT_IMAGE_DIR, request_id)
|
output_dir = os.path.join(RESULT_IMAGE_DIR, request_id)
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
# 传递 cutout 参数
|
|
||||||
saved_objects = crop_and_save_objects(
|
saved_objects = crop_and_save_objects(
|
||||||
image,
|
image,
|
||||||
masks,
|
masks,
|
||||||
@@ -586,7 +648,11 @@ async def segment(
|
|||||||
|
|
||||||
return JSONResponse(content=response_content)
|
return JSONResponse(content=response_content)
|
||||||
|
|
||||||
@app.post("/segment_tarot", dependencies=[Depends(verify_api_key)])
|
# ------------------------------------------
|
||||||
|
# Group 2: Tarot Analysis (塔罗牌分析)
|
||||||
|
# ------------------------------------------
|
||||||
|
|
||||||
|
@app.post("/segment_tarot", tags=[TAG_TAROT], dependencies=[Depends(verify_api_key)])
|
||||||
async def segment_tarot(
|
async def segment_tarot(
|
||||||
request: Request,
|
request: Request,
|
||||||
file: Optional[UploadFile] = File(None),
|
file: Optional[UploadFile] = File(None),
|
||||||
@@ -594,9 +660,11 @@ async def segment_tarot(
|
|||||||
expected_count: int = Form(3)
|
expected_count: int = Form(3)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
塔罗牌分割专用接口
|
**塔罗牌分割接口**
|
||||||
|
|
||||||
1. 检测是否包含指定数量的塔罗牌 (默认为 3)
|
1. 检测是否包含指定数量的塔罗牌 (默认为 3)
|
||||||
2. 如果是,分别抠出这些牌并返回
|
2. 对检测到的卡片进行透视矫正和裁剪
|
||||||
|
3. 返回矫正后的图片 URL
|
||||||
"""
|
"""
|
||||||
if not file and not image_url:
|
if not file and not image_url:
|
||||||
raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")
|
raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")
|
||||||
@@ -622,7 +690,7 @@ async def segment_tarot(
|
|||||||
# 核心逻辑:判断数量
|
# 核心逻辑:判断数量
|
||||||
detected_count = len(masks)
|
detected_count = len(masks)
|
||||||
|
|
||||||
# 创建本次请求的独立文件夹 (时间戳_UUID前8位)
|
# 创建本次请求的独立文件夹
|
||||||
request_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
|
request_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
|
||||||
output_dir = os.path.join(RESULT_IMAGE_DIR, request_id)
|
output_dir = os.path.join(RESULT_IMAGE_DIR, request_id)
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
@@ -678,7 +746,7 @@ async def segment_tarot(
|
|||||||
"scores": scores.tolist() if torch.is_tensor(scores) else scores
|
"scores": scores.tolist() if torch.is_tensor(scores) else scores
|
||||||
})
|
})
|
||||||
|
|
||||||
@app.post("/recognize_tarot", dependencies=[Depends(verify_api_key)])
|
@app.post("/recognize_tarot", tags=[TAG_TAROT], dependencies=[Depends(verify_api_key)])
|
||||||
async def recognize_tarot(
|
async def recognize_tarot(
|
||||||
request: Request,
|
request: Request,
|
||||||
file: Optional[UploadFile] = File(None),
|
file: Optional[UploadFile] = File(None),
|
||||||
@@ -686,7 +754,8 @@ async def recognize_tarot(
|
|||||||
expected_count: int = Form(3)
|
expected_count: int = Form(3)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
塔罗牌全流程接口: 分割 + 矫正 + 识别
|
**塔罗牌全流程接口: 分割 + 矫正 + 识别**
|
||||||
|
|
||||||
1. 检测是否包含指定数量的塔罗牌 (SAM3)
|
1. 检测是否包含指定数量的塔罗牌 (SAM3)
|
||||||
2. 分割并透视矫正
|
2. 分割并透视矫正
|
||||||
3. 调用 Qwen-VL 识别每张牌的名称和正逆位
|
3. 调用 Qwen-VL 识别每张牌的名称和正逆位
|
||||||
@@ -706,21 +775,18 @@ async def recognize_tarot(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
inference_state = processor.set_image(image)
|
inference_state = processor.set_image(image)
|
||||||
# 固定 Prompt 检测塔罗牌
|
|
||||||
output = processor.set_text_prompt(state=inference_state, prompt="tarot card")
|
output = processor.set_text_prompt(state=inference_state, prompt="tarot card")
|
||||||
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
|
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")
|
||||||
|
|
||||||
# 核心逻辑:判断数量
|
|
||||||
detected_count = len(masks)
|
detected_count = len(masks)
|
||||||
|
|
||||||
# 创建本次请求的独立文件夹
|
|
||||||
request_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
|
request_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
|
||||||
output_dir = os.path.join(RESULT_IMAGE_DIR, request_id)
|
output_dir = os.path.join(RESULT_IMAGE_DIR, request_id)
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
# 保存整体效果图 (无论是成功还是失败,都先保存一张主图)
|
# 保存整体效果图
|
||||||
try:
|
try:
|
||||||
main_filename = generate_and_save_result(image, inference_state, output_dir=output_dir)
|
main_filename = generate_and_save_result(image, inference_state, output_dir=output_dir)
|
||||||
main_file_path = os.path.join(output_dir, main_filename)
|
main_file_path = os.path.join(output_dir, main_filename)
|
||||||
@@ -730,20 +796,14 @@ async def recognize_tarot(
|
|||||||
main_file_path = None
|
main_file_path = None
|
||||||
main_file_url = None
|
main_file_url = None
|
||||||
|
|
||||||
# Step 0: 牌阵识别 (在判断数量之前或之后都可以,这里放在前面作为全局判断)
|
# Step 0: 牌阵识别
|
||||||
spread_info = {"spread_name": "Unknown"}
|
spread_info = {"spread_name": "Unknown"}
|
||||||
if main_file_path:
|
if main_file_path:
|
||||||
# 使用带有mask绘制的主图或者原始图?
|
# 使用原始图的一份拷贝给 Qwen 识别牌阵
|
||||||
# 使用原始图可能更好,不受mask遮挡干扰,但是main_filename是带mask的。
|
|
||||||
# 我们这里暂时用原始图保存一份临时文件给Qwen看
|
|
||||||
temp_raw_path = os.path.join(output_dir, "raw_for_spread.jpg")
|
temp_raw_path = os.path.join(output_dir, "raw_for_spread.jpg")
|
||||||
image.save(temp_raw_path)
|
image.save(temp_raw_path)
|
||||||
spread_info = recognize_spread_with_qwen(temp_raw_path)
|
spread_info = recognize_spread_with_qwen(temp_raw_path)
|
||||||
|
|
||||||
# 如果识别结果明确说是“不是正规牌阵”,是否要继续?
|
|
||||||
# 用户需求:“如果没有正确的牌阵则返回‘不是正规牌阵’”
|
|
||||||
# 我们将其放在返回结果中,由客户端决定是否展示警告
|
|
||||||
|
|
||||||
if detected_count != expected_count:
|
if detected_count != expected_count:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
@@ -768,8 +828,7 @@ async def recognize_tarot(
|
|||||||
fname = obj["filename"]
|
fname = obj["filename"]
|
||||||
file_path = os.path.join(output_dir, fname)
|
file_path = os.path.join(output_dir, fname)
|
||||||
|
|
||||||
# 调用 Qwen-VL 识别
|
# 调用 Qwen-VL 识别 (串行)
|
||||||
# 注意:这里会串行调用,速度可能较慢。
|
|
||||||
recognition_res = recognize_card_with_qwen(file_path)
|
recognition_res = recognize_card_with_qwen(file_path)
|
||||||
|
|
||||||
file_url = str(request.url_for("static", path=f"results/{request_id}/{fname}"))
|
file_url = str(request.url_for("static", path=f"results/{request_id}/{fname}"))
|
||||||
@@ -790,15 +849,20 @@ async def recognize_tarot(
|
|||||||
"scores": scores.tolist() if torch.is_tensor(scores) else scores
|
"scores": scores.tolist() if torch.is_tensor(scores) else scores
|
||||||
})
|
})
|
||||||
|
|
||||||
@app.post("/segment_face", dependencies=[Depends(verify_api_key)])
|
# ------------------------------------------
|
||||||
|
# Group 3: Face Analysis (人脸分析)
|
||||||
|
# ------------------------------------------
|
||||||
|
|
||||||
|
@app.post("/segment_face", tags=[TAG_FACE], dependencies=[Depends(verify_api_key)])
|
||||||
async def segment_face(
|
async def segment_face(
|
||||||
request: Request,
|
request: Request,
|
||||||
file: Optional[UploadFile] = File(None),
|
file: Optional[UploadFile] = File(None),
|
||||||
image_url: Optional[str] = Form(None),
|
image_url: Optional[str] = Form(None),
|
||||||
prompt: str = Form("face and hair") # 默认提示词包含头发
|
prompt: str = Form("face and hair", description="Prompt for face detection")
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
人脸/头部检测与属性分析接口 (新功能)
|
**人脸/头部检测与属性分析接口**
|
||||||
|
|
||||||
1. 调用 SAM3 分割出头部区域 (含头发)
|
1. 调用 SAM3 分割出头部区域 (含头发)
|
||||||
2. 裁剪并保存
|
2. 裁剪并保存
|
||||||
3. 调用 Qwen-VL 识别性别和年龄
|
3. 调用 Qwen-VL 识别性别和年龄
|
||||||
@@ -806,7 +870,7 @@ async def segment_face(
|
|||||||
if not file and not image_url:
|
if not file and not image_url:
|
||||||
raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")
|
raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")
|
||||||
|
|
||||||
# ------------------- Prompt 翻译/优化 -------------------
|
# Prompt 翻译/优化
|
||||||
final_prompt = prompt
|
final_prompt = prompt
|
||||||
if not is_english(prompt):
|
if not is_english(prompt):
|
||||||
try:
|
try:
|
||||||
@@ -818,7 +882,7 @@ async def segment_face(
|
|||||||
|
|
||||||
print(f"Face Segment 最终使用的 Prompt: {final_prompt}")
|
print(f"Face Segment 最终使用的 Prompt: {final_prompt}")
|
||||||
|
|
||||||
# 1. 加载图片
|
# 加载图片
|
||||||
try:
|
try:
|
||||||
if file:
|
if file:
|
||||||
image = Image.open(file.file).convert("RGB")
|
image = Image.open(file.file).convert("RGB")
|
||||||
@@ -829,10 +893,8 @@ async def segment_face(
|
|||||||
|
|
||||||
processor = request.app.state.processor
|
processor = request.app.state.processor
|
||||||
|
|
||||||
# 2. 调用独立服务进行处理
|
# 调用独立服务进行处理
|
||||||
try:
|
try:
|
||||||
# 传入 processor 和 image
|
|
||||||
# 注意:Result Image Dir 我们直接复用 RESULT_IMAGE_DIR
|
|
||||||
result = human_analysis_service.process_face_segmentation_and_analysis(
|
result = human_analysis_service.process_face_segmentation_and_analysis(
|
||||||
processor=processor,
|
processor=processor,
|
||||||
image=image,
|
image=image,
|
||||||
@@ -840,34 +902,34 @@ async def segment_face(
|
|||||||
output_base_dir=RESULT_IMAGE_DIR
|
output_base_dir=RESULT_IMAGE_DIR
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# 打印详细错误堆栈以便调试
|
|
||||||
import traceback
|
import traceback
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}")
|
||||||
|
|
||||||
# 3. 补全 URL (因为 service 层不知道 request context)
|
# 补全 URL
|
||||||
if result["status"] == "success":
|
if result["status"] == "success":
|
||||||
# 处理全图可视化的 URL
|
|
||||||
if result.get("full_visualization"):
|
if result.get("full_visualization"):
|
||||||
full_vis_rel_path = result["full_visualization"]
|
full_vis_rel_path = result["full_visualization"]
|
||||||
result["full_visualization"] = str(request.url_for("static", path=full_vis_rel_path))
|
result["full_visualization"] = str(request.url_for("static", path=full_vis_rel_path))
|
||||||
|
|
||||||
for item in result["results"]:
|
for item in result["results"]:
|
||||||
# item["relative_path"] 是相对路径,如 results/xxx/xxx.jpg
|
relative_path = item.pop("relative_path")
|
||||||
# 我们需要将其转换为完整 URL
|
|
||||||
relative_path = item.pop("relative_path") # 移除相对路径字段,只返回 URL
|
|
||||||
item["url"] = str(request.url_for("static", path=relative_path))
|
item["url"] = str(request.url_for("static", path=relative_path))
|
||||||
|
|
||||||
return JSONResponse(content=result)
|
return JSONResponse(content=result)
|
||||||
|
|
||||||
|
# ==========================================
|
||||||
|
# 8. Main Entry Point (启动入口)
|
||||||
|
# ==========================================
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn
|
import uvicorn
|
||||||
# 注意:如果你的文件名不是 fastAPI_tarot.py,请修改下面第一个参数
|
# 启动服务器
|
||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
"fastAPI_tarot:app",
|
"fastAPI_tarot:app",
|
||||||
host="127.0.0.1",
|
host="127.0.0.1",
|
||||||
port=55600,
|
port=55600,
|
||||||
proxy_headers=True,
|
proxy_headers=True,
|
||||||
forwarded_allow_ips="*",
|
forwarded_allow_ips="*",
|
||||||
reload=False # 生产环境建议关闭 reload,确保代码完全重载
|
reload=False
|
||||||
)
|
)
|
||||||
|
|||||||
BIN
static/results/seg_153993616fb74472a9287c68e304e8eb.jpg
Normal file
BIN
static/results/seg_153993616fb74472a9287c68e304e8eb.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 76 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 132 KiB |
BIN
static/results/seg_7a044baf66b1480a8c343c2b501db944.jpg
Normal file
BIN
static/results/seg_7a044baf66b1480a8c343c2b501db944.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 78 KiB |
BIN
static/results/seg_863aa7fac0cd4c5f946b7dd37f3b4f36.jpg
Normal file
BIN
static/results/seg_863aa7fac0cd4c5f946b7dd37f3b4f36.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 101 KiB |
BIN
static/results/seg_c47b6e3658bb452a8778c5d8f2d8a6df.jpg
Normal file
BIN
static/results/seg_c47b6e3658bb452a8778c5d8f2d8a6df.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 82 KiB |
BIN
static/results/seg_f6cd00ded6b54337af288a844f8890aa.jpg
Normal file
BIN
static/results/seg_f6cd00ded6b54337af288a844f8890aa.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 78 KiB |
Reference in New Issue
Block a user