1495 lines
56 KiB
Python
1495 lines
56 KiB
Python
"""
|
||
SAM3 Segmentation API Service (SAM3 图像分割 API 服务)
|
||
--------------------------------------------------
|
||
本服务提供基于 SAM3 模型的图像分割能力。
|
||
包含通用分割、塔罗牌分析和人脸分析等专用接口。
|
||
|
||
This service provides image segmentation capabilities using the SAM3 model.
|
||
It includes specialized endpoints for Tarot card analysis and Face analysis.
|
||
"""
|
||
|
||
# ==========================================
|
||
# 1. Imports (导入)
|
||
# ==========================================
|
||
|
||
# Standard Library Imports (标准库)
|
||
import os
|
||
import uuid
|
||
import time
|
||
import json
|
||
import traceback
|
||
import re
|
||
import asyncio
|
||
import shutil
|
||
import subprocess
|
||
from datetime import datetime
|
||
from typing import Optional, List, Dict, Any
|
||
from contextlib import asynccontextmanager
|
||
|
||
# Third-Party Imports (第三方库)
|
||
import cv2
|
||
import torch
|
||
import numpy as np
|
||
import requests
|
||
import matplotlib
|
||
matplotlib.use('Agg') # 使用非交互式后端,防止在服务器上报错
|
||
import matplotlib.pyplot as plt
|
||
from PIL import Image
|
||
|
||
# FastAPI Imports
|
||
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request, Depends, status, APIRouter, Cookie
|
||
from fastapi.security import APIKeyHeader
|
||
from fastapi.staticfiles import StaticFiles
|
||
from fastapi.responses import JSONResponse, HTMLResponse, Response
|
||
|
||
# Dashscope (Aliyun Qwen) Imports
|
||
import dashscope
|
||
from dashscope import MultiModalConversation
|
||
|
||
# Local Imports (本地模块)
|
||
from sam3.model_builder import build_sam3_image_model
|
||
from sam3.model.sam3_image_processor import Sam3Processor
|
||
from sam3.visualization_utils import plot_results
|
||
import human_analysis_service # 引入新服务: 人脸分析
|
||
|
||
# ==========================================
|
||
# 2. Configuration & Constants (配置与常量)
|
||
# ==========================================
|
||
|
||
# Paths: generated images are written under static/results and served via the /static mount.
STATIC_DIR = "static"
RESULT_IMAGE_DIR = os.path.join(STATIC_DIR, "results")
os.makedirs(RESULT_IMAGE_DIR, exist_ok=True)

# Credentials. Each one can be overridden from the environment (same pattern as
# CLEANUP_CONFIG below), so secrets can be rotated without editing source.
# NOTE(security): the literal fallbacks are committed to source control and should be
# treated as compromised -- rotate them and drop the fallbacks once every deployment
# sets the environment variables.
API_KEY_HEADER_NAME = "X-API-Key"
VALID_API_KEY = os.getenv("SAM3_API_KEY", "123quant-speed")

# Admin config
ADMIN_PASSWORD = os.getenv("ADMIN_PASSWORD", "admin_secure_password")
HISTORY_FILE = "history.json"  # appended to one JSON record per line

# Dashscope (Qwen-VL) config. DASHSCOPE_API_KEY is the variable name the DashScope SDK itself uses.
dashscope.api_key = os.getenv("DASHSCOPE_API_KEY", 'sk-ce2404f55f744a1987d5ece61c6bac58')
QWEN_MODEL = 'qwen-vl-max'  # Default model
AVAILABLE_QWEN_MODELS = ["qwen-vl-max", "qwen-vl-plus", "qwen3.5-plus"]

# Cleanup config: read by the background cleanup task on every cycle.
CLEANUP_CONFIG = {
    "enabled": os.getenv("AUTO_CLEANUP_ENABLED", "True").lower() == "true",
    "lifetime": int(os.getenv("FILE_LIFETIME_SECONDS", "3600")),
    "interval": int(os.getenv("CLEANUP_INTERVAL_SECONDS", "600")),
}
|
||
|
||
# Prompt templates sent to Qwen-VL, keyed by use case.
# Only "translate" is used with str.format (placeholder: {text}); the other
# templates contain literal braces in their JSON examples and must NOT be formatted.
PROMPTS = dict(
    translate="请将以下描述翻译成简洁、精准的英文,用于图像分割模型(SAM)的提示词。直接返回英文,不要包含任何解释或其他文字。\n\n输入: {text}",
    tarot_card_dual="""这是一张塔罗牌的两个方向:
图1:原始方向
图2:旋转180度后的方向

请仔细对比两张图片的牌面内容(文字方向、人物站立方向、图案逻辑):
1. 识别这张牌的名字(中文)。
2. 判断哪一张图片展示了正确的“正位”(Upright)状态。
   - 如果图1是正位,说明原图就是正位。
   - 如果图2是正位,说明原图是逆位。

请以JSON格式返回,包含 'name' 和 'position' 两个字段。
例如:{'name': '愚者', 'position': '正位'} 或 {'name': '倒吊人', 'position': '逆位'}。
不要包含Markdown代码块标记。""",
    tarot_card_single="这是一张塔罗牌。请识别它的名字(中文),并判断它是正位还是逆位。请以JSON格式返回,包含 'name' 和 'position' 两个字段。例如:{'name': '愚者', 'position': '正位'}。不要包含Markdown代码块标记。",
    tarot_spread="这是一张包含多张塔罗牌的图片。请根据牌的排列方式识别这是什么牌阵(例如:圣三角、凯尔特十字、三张牌等)。如果看不出明显的正规牌阵,请返回“不是正规牌阵”。请以JSON格式返回,包含 'spread_name' 和 'description' 两个字段。例如:{'spread_name': '圣三角', 'description': '常见的时间流占卜法'}。不要包含Markdown代码块标记。",
    face_analysis="""请仔细观察这张图片中的人物头部/面部特写:
1. 识别性别 (Gender):男性/女性
2. 预估年龄 (Age):请给出一个合理的年龄范围,例如 "25-30岁"
3. 简要描述:发型、发色、是否有眼镜等显著特征。

请以 JSON 格式返回,包含 'gender', 'age', 'description' 字段。
不要包含 Markdown 标记。""",
)

# OpenAPI tag names used to group endpoints in the generated docs.
TAG_GENERAL, TAG_TAROT, TAG_FACE = (
    "General Segmentation (通用分割)",
    "Tarot Analysis (塔罗牌分析)",
    "Face Analysis (人脸分析)",
)
|
||
|
||
# ==========================================
|
||
# 3. Security & Middleware (安全与中间件)
|
||
# ==========================================
|
||
|
||
# Header-based auth scheme. auto_error=False so that verify_api_key can return its own
# 401/403 messages instead of FastAPI's default error.
api_key_header = APIKeyHeader(name=API_KEY_HEADER_NAME, auto_error=False)


async def verify_api_key(api_key: Optional[str] = Depends(api_key_header)):
    """
    Enforce API-key authentication (used as a route dependency).

    Raises:
        HTTPException(401): the X-API-Key header is missing or empty.
        HTTPException(403): the header is present but does not match VALID_API_KEY.

    Returns:
        True when the key is valid.
    """
    import secrets

    if not api_key:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing API Key. Please provide it in the header."
        )
    # Constant-time comparison: `!=` short-circuits on the first differing character,
    # which lets an attacker recover the key byte by byte from response timings.
    # Compare UTF-8 bytes because compare_digest raises TypeError for non-ASCII str,
    # and the header value is fully client-controlled.
    if not secrets.compare_digest(api_key.encode("utf-8"), VALID_API_KEY.encode("utf-8")):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid API Key."
        )
    return True
|
||
|
||
# ==========================================
|
||
# 4. Lifespan Management (生命周期管理)
|
||
# ==========================================
|
||
|
||
def _remove_expired_files(directory: str, lifetime: float, now: float) -> int:
    """Delete files anywhere under `directory` whose mtime is more than `lifetime` seconds before `now`; return the count."""
    removed = 0
    for root, _subdirs, files in os.walk(directory):
        for name in files:
            path = os.path.join(root, name)
            try:
                if now - os.path.getmtime(path) > lifetime:
                    os.remove(path)
                    removed += 1
            except OSError:
                pass  # already deleted or not accessible; it will be retried next cycle
    return removed


def _remove_empty_dirs(directory: str) -> None:
    """Remove empty sub-directories at any depth below `directory` (bottom-up); `directory` itself is kept."""
    for root, subdirs, _files in os.walk(directory, topdown=False):
        for name in subdirs:
            path = os.path.join(root, name)
            try:
                if not os.listdir(path):
                    os.rmdir(path)
            except OSError:
                pass


async def cleanup_old_files(directory: str):
    """
    Background task: periodically delete expired files under `directory`.

    CLEANUP_CONFIG is re-read every cycle. `enabled` and `lifetime` are read *after*
    the sleep, so a change made while the task is waiting applies to the very next
    sweep. The task runs until it is cancelled; the CancelledError is re-raised after
    logging so the task is correctly reported as cancelled.
    """
    print(f"🧹 自动清理任务已启动 | 目录: {directory}")
    while True:
        try:
            await asyncio.sleep(CLEANUP_CONFIG["interval"])

            if not CLEANUP_CONFIG["enabled"]:
                continue
            lifetime = CLEANUP_CONFIG["lifetime"]

            # The sweep deliberately runs synchronously on the event loop. Request handlers
            # create an empty per-request directory and fill it without awaiting in between.
            # A threaded sweep could rmdir that directory in the gap.
            count = _remove_expired_files(directory, lifetime, time.time())
            _remove_empty_dirs(directory)

            if count > 0:
                print(f"🧹 已清理 {count} 个过期文件 (Lifetime: {lifetime}s)")

        except asyncio.CancelledError:
            print("🛑 清理任务已停止")
            raise
        except Exception as e:
            print(f"⚠️ 清理任务出错: {e}")
            await asyncio.sleep(60)  # back off for a minute before retrying
|
||
|
||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    FastAPI lifespan manager.

    Startup: load the SAM3 model onto the GPU (CPU if CUDA is unavailable), wrap it in a
    Sam3Processor, store model/processor/device on `app.state`, and start the background
    cleanup task for RESULT_IMAGE_DIR.

    Shutdown: cancel the cleanup task. This runs in a `finally`, so it also happens
    when an exception is thrown into the context manager.
    """
    print("="*40)
    print("✅ API Key 保护已激活")
    # Log only a short prefix: the full key in stdout or aggregated logs would leak the credential.
    masked_key = f"{VALID_API_KEY[:4]}****" if VALID_API_KEY else "(empty)"
    print(f"✅ 有效 Key: {masked_key}")
    print("="*40)

    print("正在加载 SAM3 模型到 GPU...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if not torch.cuda.is_available():
        print("警告: 未检测到 GPU,将使用 CPU,速度会较慢。")

    # Load the model in inference mode.
    model = build_sam3_image_model()
    model = model.to(device)
    model.eval()

    processor = Sam3Processor(model)

    # Shared across all requests via request.app.state.
    app.state.model = model
    app.state.processor = processor
    app.state.device = device

    print(f"模型加载完成,设备: {device}")

    # --- Start the background cleanup task ---
    cleanup_task_handle = None
    try:
        cleanup_task_handle = asyncio.create_task(
            cleanup_old_files(RESULT_IMAGE_DIR)
        )
    except Exception as e:
        print(f"启动清理任务失败: {e}")

    try:
        yield
    finally:
        print("正在清理资源...")
        # --- Stop the background task ---
        if cleanup_task_handle:
            cleanup_task_handle.cancel()
            try:
                await cleanup_task_handle
            except asyncio.CancelledError:
                pass
        # GPU memory release logic could be added here if needed.
|
||
|
||
# ==========================================
|
||
# 5. Helper Functions (辅助函数)
|
||
# ==========================================
|
||
|
||
def is_english(text: str) -> bool:
    """
    Return True when `text` contains no CJK Unified Ideographs (U+4E00..U+9FFF).

    This is a narrow check, not language detection: kana, Hangul, full-width
    punctuation, etc. all count as "English" here. An empty string returns True.
    """
    return not any('\u4e00' <= ch <= '\u9fff' for ch in text)
|
||
|
||
def append_to_history(req_type: str, prompt: str, status: str, result_path: str = None, details: str = "", final_prompt: str = None, duration: float = 0.0):
    """
    Append a single request record to HISTORY_FILE as one JSON object per line.

    The record holds the current Unix timestamp plus every argument. Any failure
    (I/O or serialization) is printed and otherwise ignored, so history logging
    never breaks the request being served.
    """
    entry = dict(
        timestamp=time.time(),
        type=req_type,
        prompt=prompt,
        final_prompt=final_prompt,
        status=status,
        result_path=result_path,
        details=details,
        duration=duration,
    )
    try:
        with open(HISTORY_FILE, "a", encoding="utf-8") as history_fh:
            # print() appends the trailing "\n" that terminates the JSON line.
            print(json.dumps(entry, ensure_ascii=False), file=history_fh)
    except Exception as e:
        print(f"Failed to write history: {e}")
|
||
|
||
|
||
def translate_to_sam3_prompt(text: str) -> str:
    """
    Translate a (Chinese) prompt into concise English with Qwen, for use as a SAM3 text prompt.

    SAM3 handles English prompts better. On an API error, an exception, or an empty
    reply, the original `text` is returned unchanged.
    """
    print(f"正在翻译提示词: {text}")
    try:
        prompt_template = PROMPTS["translate"]
        messages = [
            {
                "role": "user",
                "content": [
                    {"text": prompt_template.format(text=text)}
                ]
            }
        ]
        response = MultiModalConversation.call(model=QWEN_MODEL, messages=messages)

        if response.status_code == 200:
            translated_text = response.output.choices[0].message.content[0]['text'].strip()
            # Strip only the backticks/quotes the model may wrap its answer in. Interior
            # apostrophes and quotes (e.g. "person's bag") are part of the prompt and must
            # survive; the old blanket replace() turned "person's" into "persons".
            translated_text = translated_text.strip("`").strip().strip("\"'“”‘’").strip()
            if not translated_text:
                print("翻译结果为空,使用原始文本")
                return text
            print(f"翻译结果: {translated_text}")
            return translated_text
        else:
            print(f"翻译失败: {response.code} - {response.message}")
            return text  # fall back to the original text on failure
    except Exception as e:
        print(f"翻译异常: {e}")
        return text
|
||
|
||
def order_points(pts):
    """
    Order four corner points as [top-left, top-right, bottom-right, bottom-left].

    Top-left has the smallest x+y, bottom-right the largest; top-right has the
    smallest y-x, bottom-left the largest. `pts` is expected to be a (4, 2) numpy
    array; the result is a new float32 array of the same shape.
    """
    coord_sum = pts.sum(axis=1)
    coord_diff = np.diff(pts, axis=1)  # y - x for each point
    corner_idx = [
        np.argmin(coord_sum),   # top-left
        np.argmin(coord_diff),  # top-right
        np.argmax(coord_sum),   # bottom-right
        np.argmax(coord_diff),  # bottom-left
    ]
    return pts[corner_idx].astype("float32")
|
||
|
||
def four_point_transform(image, pts):
    """
    Warp the quadrilateral `pts` of `image` onto an axis-aligned rectangle (perspective correction).

    Used to straighten tilted cards into a rectangle.

    Args:
        image: numpy image array passed to cv2.warpPerspective.
        pts: four corner points in any order (reordered via order_points).

    Returns:
        The warped image. Its width is the longer of the top/bottom edges and its height
        is the longer of the left/right edges; each is at least 1 pixel.
    """
    rect = order_points(pts)
    (tl, tr, br, bl) = rect

    # Width of the new image: the longer of the bottom and top edges.
    widthA = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2))
    widthB = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2))
    # Height of the new image: the longer of the right and left edges.
    heightA = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2))
    heightB = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2))

    # A degenerate quad (e.g. from a tiny or 1-pixel mask) truncates to a 0-length side.
    # That would put destination corners at -1 and hand OpenCV a zero-area output size,
    # so clamp each dimension to at least one pixel.
    maxWidth = max(int(widthA), int(widthB), 1)
    maxHeight = max(int(heightA), int(heightB), 1)

    dst = np.array([
        [0, 0],
        [maxWidth - 1, 0],
        [maxWidth - 1, maxHeight - 1],
        [0, maxHeight - 1]], dtype="float32")

    M = cv2.getPerspectiveTransform(rect, dst)
    warped = cv2.warpPerspective(image, M, (maxWidth, maxHeight))
    return warped
|
||
|
||
def load_image_from_url(url: str) -> Image.Image:
    """
    Download an image from `url` and return it as an RGB PIL image.

    NOTE(security): `url` comes straight from request form data, so this fetches
    arbitrary user-supplied addresses (SSRF risk, e.g. internal/metadata endpoints).
    Consider an allow-list or blocking private address ranges before public exposure.

    Raises:
        HTTPException(400): on any download or decode failure.
    """
    import io

    try:
        headers = {'User-Agent': 'Mozilla/5.0'}
        # `with` releases the connection even when raise_for_status() or decoding fails.
        with requests.get(url, headers=headers, timeout=10) as response:
            response.raise_for_status()
            # `.content` undoes gzip/deflate Content-Encoding; `response.raw` does not,
            # which made PIL fail on compressed responses.
            data = response.content
        image = Image.open(io.BytesIO(data)).convert("RGB")
        return image
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"无法下载图片: {str(e)}") from e
|
||
|
||
def _mask_to_uint8(mask) -> np.ndarray:
    """Convert one mask (torch tensor or numpy array) into a squeezed uint8 array with values 0/255."""
    mask_np = mask.cpu().numpy().squeeze() if isinstance(mask, torch.Tensor) else mask.squeeze()
    if mask_np.dtype == bool:
        return (mask_np * 255).astype(np.uint8)
    # Non-boolean masks are binarised with a 0.5 threshold.
    return (mask_np > 0.5).astype(np.uint8) * 255


def _clamp_box(box_np, width: int, height: int) -> tuple:
    """Truncate an (x1, y1, x2, y2) box to ints and clip it to a width x height image."""
    x1, y1, x2, y2 = map(int, box_np)
    return max(0, x1), max(0, y1), min(width, x2), min(height, y2)


def _find_card_quad(mask_uint8):
    """Return four corner points for the mask's largest external contour, or None if it has no contour."""
    contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return None
    largest = max(contours, key=cv2.contourArea)
    approx = cv2.approxPolyDP(largest, 0.04 * cv2.arcLength(largest, True), True)
    if len(approx) == 4:
        return approx.reshape(4, 2)
    # Not a clean quadrilateral: fall back to the minimum-area bounding rectangle.
    return cv2.boxPoints(cv2.minAreaRect(largest))


def crop_and_save_objects(image: Image.Image, masks, boxes, output_dir: str = RESULT_IMAGE_DIR, is_tarot: bool = True, cutout: bool = False, perspective_correction: bool = False) -> list[dict]:
    """
    Crop each detected object (one per mask/box pair) and save it as a PNG in `output_dir`.

    Args:
        image: the original image.
        masks: per-object segmentation masks (torch tensors or numpy arrays).
        boxes: per-object (x1, y1, x2, y2) boxes, paired with `masks` by position.
        output_dir: directory the PNGs are written to.
        is_tarot: selects the "tarot" filename prefix and, in perspective mode, rotates
            landscape results 90° clockwise so cards come out portrait.
        cutout: if True, pixels outside the mask become transparent (RGBA); otherwise
            the RGB pixels are kept as-is.
        perspective_correction: if True, warp the mask's quadrilateral to a rectangle;
            otherwise crop the bounding box.

    Returns:
        One dict per saved object: {"filename", "is_rotated_by_algorithm", "note"}.
    """
    saved_objects = []
    # Derived only from `image`, so compute once instead of once per mask.
    img_arr = np.array(image)
    rgb_image = image.convert("RGB")

    for i, (mask, box) in enumerate(zip(masks, boxes)):
        mask_uint8 = _mask_to_uint8(mask)
        box_np = box.cpu().numpy() if isinstance(box, torch.Tensor) else np.array(box)

        # --- Base image: transparent cut-out or plain RGB ---
        if cutout:
            mask_img = Image.fromarray(mask_uint8, mode='L')
            base_img_pil = Image.new("RGBA", image.size, (0, 0, 0, 0))
            # Only pixels where the mask is set are copied; the rest stay transparent.
            base_img_pil.paste(rgb_image, (0, 0), mask=mask_img)
            base_img_arr = np.array(base_img_pil)
        else:
            base_img_pil = rgb_image
            base_img_arr = img_arr

        is_rotated = False

        if perspective_correction:
            # --- Perspective (rectangle) correction ---
            pts = _find_card_quad(mask_uint8)
            if pts is not None:
                warped = four_point_transform(base_img_arr, pts)
                note = "Geometric correction applied."
            else:
                h, w = base_img_arr.shape[:2]
                x1, y1, x2, y2 = _clamp_box(box_np, w, h)
                # Same degenerate-box guard as the simple-crop path: an empty slice
                # cannot be turned into an image and saved, so keep the full frame instead.
                warped = base_img_arr[y1:y2, x1:x2] if (x2 > x1 and y2 > y1) else base_img_arr
                note = "Correction failed, fallback to crop."

            # Force portrait orientation for tarot cards.
            h, w = warped.shape[:2]
            if is_tarot and w > h:
                warped = cv2.rotate(warped, cv2.ROTATE_90_CLOCKWISE)
                is_rotated = True

            final_img_pil = Image.fromarray(warped)
        else:
            # --- Simple bounding-box crop ---
            w, h = base_img_pil.size
            x1, y1, x2, y2 = _clamp_box(box_np, w, h)
            if x2 > x1 and y2 > y1:
                final_img_pil = base_img_pil.crop((x1, y1, x2, y2))
            else:
                final_img_pil = base_img_pil  # degenerate box: keep the full frame
            note = "Simple crop applied."

        # --- Save ---
        prefix = "cutout" if cutout else ("tarot" if is_tarot else "segment")
        filename = f"{prefix}_{uuid.uuid4().hex}_{i}.png"
        final_img_pil.save(os.path.join(output_dir, filename))

        saved_objects.append({
            "filename": filename,
            "is_rotated_by_algorithm": is_rotated,
            "note": note
        })

    return saved_objects
|
||
|
||
def generate_and_save_result(image: Image.Image, inference_state, output_dir: str = RESULT_IMAGE_DIR) -> str:
    """
    Render the mask overlay for `inference_state` on `image` and save it as a JPEG.

    Returns:
        The generated file name (relative to `output_dir`).
    """
    filename = f"seg_{uuid.uuid4().hex}.jpg"
    save_path = os.path.join(output_dir, filename)
    try:
        # Assumes plot_results draws on pyplot's current figure, which savefig/close act on.
        plot_results(image, inference_state)
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    finally:
        # Close even on failure. Otherwise the figure stays registered in pyplot and its
        # memory accumulates across requests in this long-running server.
        plt.close()
    return filename
|
||
|
||
def parse_model_json(content):
    """
    Best-effort parse of a JSON object from an LLM text reply.

    Tries, in order:
      1. the reply with markdown code fences removed;
      2. the outermost {...} span inside it;
      3. that span as a Python literal. The prompts show single-quoted examples such as
         {'name': '愚者'}, which the model may copy but json.loads rejects.

    Returns:
        The parsed value, or None if nothing could be parsed (or `content` is not a str).
    """
    import ast

    if not isinstance(content, str):
        return None
    cleaned = content.replace("```json", "").replace("```", "").strip()
    try:
        return json.loads(cleaned)
    except ValueError:
        pass

    match = re.search(r"\{.*\}", cleaned, re.DOTALL)
    candidate = match.group(0) if match else cleaned
    try:
        return json.loads(candidate)
    except ValueError:
        pass

    try:
        value = ast.literal_eval(candidate)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        return None
    return value if isinstance(value, dict) else None


def recognize_card_with_qwen(image_path: str) -> dict:
    """
    Identify a single tarot card with Qwen-VL, comparing upright vs. reversed.

    The card image and a 180°-rotated copy (saved next to it as "rotated_<name>") are
    sent together so the model can pick the upright one. If the rotated copy cannot
    be produced, only the original is sent with the single-image prompt.

    Returns:
        The parsed model reply (expected keys 'name' and 'position'),
        {"raw_response": ...} if the reply is not parseable, or {"error": ...}.
    """
    try:
        abs_path = os.path.abspath(image_path)
        file_url = f"file://{abs_path}"

        try:
            # `with` closes the source file handle; rotate() returns an independent image.
            with Image.open(abs_path) as img:
                rotated_img = img.rotate(180)
            dir_name, file_name = os.path.split(abs_path)
            rotated_path = os.path.join(dir_name, f"rotated_{file_name}")
            rotated_img.save(rotated_path)
            rotated_file_url = f"file://{rotated_path}"

            messages = [
                {
                    "role": "user",
                    "content": [
                        {"image": file_url},          # image 1: original
                        {"image": rotated_file_url},  # image 2: rotated 180°
                        {"text": PROMPTS["tarot_card_dual"]}
                    ]
                }
            ]
        except Exception as e:
            print(f"对比图生成失败,回退到单图模式: {e}")
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"image": file_url},
                        {"text": PROMPTS["tarot_card_single"]}
                    ]
                }
            ]

        response = MultiModalConversation.call(model=QWEN_MODEL, messages=messages)

        if response.status_code == 200:
            content = response.output.choices[0].message.content[0]['text']
            parsed = parse_model_json(content)
            if parsed is None:
                return {"raw_response": content}
            return parsed
        else:
            return {"error": f"API Error: {response.code} - {response.message}"}

    except Exception as e:
        return {"error": f"识别失败: {str(e)}"}


def recognize_spread_with_qwen(image_path: str) -> dict:
    """
    Identify the tarot spread (layout) shown in an image with Qwen-VL.

    Returns:
        The parsed model reply (expected keys 'spread_name' and 'description'),
        {"raw_response": ..., "spread_name": "Unknown"} if the reply is not parseable,
        or {"error": ...}.
    """
    try:
        abs_path = os.path.abspath(image_path)
        file_url = f"file://{abs_path}"

        messages = [
            {
                "role": "user",
                "content": [
                    {"image": file_url},
                    {"text": PROMPTS["tarot_spread"]}
                ]
            }
        ]

        response = MultiModalConversation.call(model=QWEN_MODEL, messages=messages)

        if response.status_code == 200:
            content = response.output.choices[0].message.content[0]['text']
            parsed = parse_model_json(content)
            if parsed is None:
                return {"raw_response": content, "spread_name": "Unknown"}
            return parsed
        else:
            return {"error": f"API Error: {response.code} - {response.message}"}

    except Exception as e:
        return {"error": f"牌阵识别失败: {str(e)}"}
|
||
|
||
# ==========================================
|
||
# 6. FastAPI App Setup (应用初始化)
|
||
# ==========================================
|
||
|
||
# Application instance. Tag metadata is listed as (tag, description) pairs in display order.
app = FastAPI(
    title="SAM3 Segmentation API",
    description="## 🔒 受 API Key 保护\n请点击右上角 **Authorize** 并输入: `123quant-*****`",
    openapi_tags=[
        {"name": tag_name, "description": tag_description}
        for tag_name, tag_description in (
            (TAG_GENERAL, "General purpose image segmentation"),
            (TAG_TAROT, "Specialized endpoints for Tarot card recognition"),
            (TAG_FACE, "Face detection and analysis"),
        )
    ],
    lifespan=lifespan,
)
|
||
|
||
# Manually add the API-key security scheme to the OpenAPI schema.
app.openapi_schema = None


def custom_openapi():
    """
    Build (once) and cache the OpenAPI schema with the X-API-Key header scheme
    attached to every operation, so Swagger UI's "Authorize" button applies the key.
    """
    if app.openapi_schema:
        return app.openapi_schema
    from fastapi.openapi.utils import get_openapi

    openapi_schema = get_openapi(
        title=app.title,
        version=app.version,
        openapi_version=app.openapi_version,
        description=app.description,
        routes=app.routes,
        # Without this, the tag descriptions given to FastAPI(openapi_tags=...) are dropped.
        tags=app.openapi_tags,
    )
    # Create `components` if absent, and merge into securitySchemes instead of replacing
    # it, so any scheme FastAPI generated on its own is preserved.
    components = openapi_schema.setdefault("components", {})
    components.setdefault("securitySchemes", {})["APIKeyHeader"] = {
        "type": "apiKey",
        "in": "header",
        "name": API_KEY_HEADER_NAME,
    }
    for path_item in openapi_schema.get("paths", {}).values():
        for operation in path_item.values():
            operation["security"] = [{"APIKeyHeader": []}]
    app.openapi_schema = openapi_schema
    return app.openapi_schema


app.openapi = custom_openapi

app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
|
||
|
||
# ==========================================
|
||
# 7. API Endpoints (接口定义)
|
||
# ==========================================
|
||
|
||
# ------------------------------------------
|
||
# Group 1: General Segmentation (通用分割)
|
||
# ------------------------------------------
|
||
|
||
@app.post("/segment", tags=[TAG_GENERAL], dependencies=[Depends(verify_api_key)])
async def segment(
    request: Request,
    prompt: str = Form(..., description="Text prompt for segmentation (e.g., 'cat', 'person')"),
    file: Optional[UploadFile] = File(None, description="Image file to upload"),
    image_url: Optional[str] = Form(None, description="URL of the image"),
    save_segment_images: bool = Form(False, description="Whether to save and return individual segmented objects"),
    cutout: bool = Form(False, description="If True, returns transparent background PNGs; otherwise returns original crops"),
    perspective_correction: bool = Form(False, description="If True, applies perspective correction (warping) to the segmented object."),
    highlight: bool = Form(False, description="If True, darkens the background to highlight the subject (周边变黑放大)."),
    confidence: float = Form(0.7, description="Confidence threshold (0.0-1.0). Default is 0.7.")
):
    """
    **General image segmentation endpoint**

    - Accepts an uploaded image or an image URL
    - Non-English (Chinese) prompts are translated to English automatically
    - Optional highlight mode (background darkened around the subject)
    - Adjustable confidence threshold (0.0-1.0)
    - Optional perspective correction of the saved segments
    """
    if not file and not image_url:
        raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")

    start_time = time.time()

    # Validate before doing any work. This check used to sit inside the inference
    # try-block, whose generic `except Exception` re-wrapped this 400 as a 500 "模型推理错误".
    if confidence is not None and not (0.0 <= confidence <= 1.0):
        duration = time.time() - start_time
        append_to_history("general", prompt, "failed", details=f"Invalid confidence: {confidence}", duration=duration)
        raise HTTPException(status_code=400, detail="Confidence must be between 0.0 and 1.0")

    # 1. Prompt: translate non-English prompts; keep the original if translation fails.
    final_prompt = prompt
    if not is_english(prompt):
        try:
            translated = translate_to_sam3_prompt(prompt)
            if translated:
                final_prompt = translated
        except Exception as e:
            print(f"Prompt翻译失败,使用原始Prompt: {e}")

    print(f"最终使用的 Prompt: {final_prompt}")

    # 2. Load the image
    try:
        if file:
            image = Image.open(file.file).convert("RGB")
        elif image_url:
            image = load_image_from_url(image_url)
    except Exception as e:
        duration = time.time() - start_time
        append_to_history("general", prompt, "failed", details=f"Image Load Error: {str(e)}", final_prompt=final_prompt, duration=duration)
        raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")

    processor = request.app.state.processor

    # 3. Inference. `processor` is shared app state, so the per-request threshold is
    # restored in `finally`. There is no `await` between set and restore, so other
    # requests on this event loop never observe the temporary value.
    try:
        inference_state = processor.set_image(image)

        original_confidence = processor.confidence_threshold
        if confidence is not None:
            processor.confidence_threshold = confidence
            print(f"Using manual confidence threshold: {confidence}")

        try:
            output = processor.set_text_prompt(state=inference_state, prompt=final_prompt)
            masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
        finally:
            if confidence is not None:
                processor.confidence_threshold = original_confidence

    except Exception as e:
        duration = time.time() - start_time
        append_to_history("general", prompt, "failed", details=f"Inference Error: {str(e)}", final_prompt=final_prompt, duration=duration)
        raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")

    # 4. Visualize and save the full result image
    try:
        if highlight:
            filename = f"seg_highlight_{uuid.uuid4().hex}.jpg"
            save_path = os.path.join(RESULT_IMAGE_DIR, filename)
            # Background-darkening visualization from human_analysis_service.
            human_analysis_service.create_highlighted_visualization(image, masks, save_path)
        else:
            filename = generate_and_save_result(image, inference_state)
    except Exception as e:
        duration = time.time() - start_time
        append_to_history("general", prompt, "failed", details=f"Save Error: {str(e)}", final_prompt=final_prompt, duration=duration)
        raise HTTPException(status_code=500, detail=f"绘图保存错误: {str(e)}")

    file_url = request.url_for("static", path=f"results/{filename}")

    # 5. Optionally save each segmented object into a per-request folder
    saved_segments_info = []
    if save_segment_images:
        try:
            request_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
            output_dir = os.path.join(RESULT_IMAGE_DIR, request_id)
            os.makedirs(output_dir, exist_ok=True)

            saved_objects = crop_and_save_objects(
                image,
                masks,
                boxes,
                output_dir=output_dir,
                is_tarot=False,
                cutout=cutout,
                perspective_correction=perspective_correction
            )

            for obj in saved_objects:
                fname = obj["filename"]
                seg_url = str(request.url_for("static", path=f"results/{request_id}/{fname}"))
                saved_segments_info.append({
                    "url": seg_url,
                    "filename": fname,
                    "note": obj.get("note", "")
                })
        except Exception as e:
            print(f"Error saving segments: {e}")
            # The caller explicitly asked for the segments, so fail the request; the main
            # result image was already written and is recorded as a partial success.
            duration = time.time() - start_time
            append_to_history("general", prompt, "partial_success", result_path=f"results/{filename}", details="Segments save failed", final_prompt=final_prompt, duration=duration)
            raise HTTPException(status_code=500, detail=f"保存分割图片失败: {str(e)}")

    response_content = {
        "status": "success",
        "result_image_url": str(file_url),
        "detected_count": len(masks),
        "scores": scores.tolist() if torch.is_tensor(scores) else scores
    }

    if save_segment_images:
        response_content["segmented_images"] = saved_segments_info

    duration = time.time() - start_time
    append_to_history("general", prompt, "success", result_path=f"results/{filename}", details=f"Detected: {len(masks)}", final_prompt=final_prompt, duration=duration)

    return JSONResponse(content=response_content)
|
||
|
||
# ------------------------------------------
|
||
# Group 2: Tarot Analysis (塔罗牌分析)
|
||
# ------------------------------------------
|
||
|
||
@app.post("/segment_tarot", tags=[TAG_TAROT], dependencies=[Depends(verify_api_key)])
async def segment_tarot(
    request: Request,
    file: Optional[UploadFile] = File(None),
    image_url: Optional[str] = Form(None),
    expected_count: int = Form(3)
):
    """
    **Tarot card segmentation endpoint**

    1. Detect tarot cards and require exactly `expected_count` of them (default 3)
    2. Perspective-correct and crop each detected card
    3. Return the URLs of the corrected card images
    """
    if not file and not image_url:
        raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")

    start_time = time.time()
    try:
        if file:
            image = Image.open(file.file).convert("RGB")
        elif image_url:
            image = load_image_from_url(image_url)
    except Exception as e:
        duration = time.time() - start_time
        append_to_history("tarot", f"expected: {expected_count}", "failed", details=f"Image Load Error: {str(e)}", duration=duration)
        raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")

    processor = request.app.state.processor

    try:
        inference_state = processor.set_image(image)
        # Fixed prompt: this endpoint only looks for tarot cards.
        output = processor.set_text_prompt(state=inference_state, prompt="tarot card")
        masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
    except Exception as e:
        duration = time.time() - start_time
        append_to_history("tarot", f"expected: {expected_count}", "failed", details=f"Inference Error: {str(e)}", duration=duration)
        raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")

    detected_count = len(masks)

    # Per-request output folder.
    request_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
    output_dir = os.path.join(RESULT_IMAGE_DIR, request_id)
    os.makedirs(output_dir, exist_ok=True)

    if detected_count != expected_count:
        # Save a visualization for debugging/feedback. This is best-effort: a failure
        # here must not hide the count mismatch, but it is logged rather than swallowed.
        file_url = None
        debug_result_path = None
        try:
            filename = generate_and_save_result(image, inference_state, output_dir=output_dir)
            file_url = request.url_for("static", path=f"results/{request_id}/{filename}")
            debug_result_path = f"results/{request_id}/{filename}"
        except Exception as e:
            print(f"Debug visualization failed: {e}")

        duration = time.time() - start_time
        append_to_history("tarot", f"expected: {expected_count}", "failed", result_path=debug_result_path, details=f"Detected {detected_count} cards, expected {expected_count}", duration=duration)
        return JSONResponse(
            status_code=400,
            content={
                "status": "failed",
                "message": f"检测到 {detected_count} 个目标,需要严格的 {expected_count} 张塔罗牌。请调整拍摄角度或背景。",
                "detected_count": detected_count,
                "debug_image_url": str(file_url) if file_url else None
            }
        )

    # Count matches: crop each card with perspective correction.
    try:
        saved_objects = crop_and_save_objects(image, masks, boxes, output_dir=output_dir, is_tarot=True, perspective_correction=True)
    except Exception as e:
        duration = time.time() - start_time
        append_to_history("tarot", f"expected: {expected_count}", "failed", details=f"Crop Error: {str(e)}", duration=duration)
        raise HTTPException(status_code=500, detail=f"抠图处理错误: {str(e)}")

    # URL list and per-card metadata.
    tarot_cards = []
    for obj in saved_objects:
        fname = obj["filename"]
        file_url = str(request.url_for("static", path=f"results/{request_id}/{fname}"))
        tarot_cards.append({
            "url": file_url,
            "is_rotated": obj["is_rotated_by_algorithm"],
            "orientation_status": "corrected_to_portrait" if obj["is_rotated_by_algorithm"] else "original_portrait",
            "note": obj["note"]
        })

    # Full visualization (best-effort, failures are logged).
    main_file_url = None
    main_result_path = None
    try:
        main_filename = generate_and_save_result(image, inference_state, output_dir=output_dir)
        main_file_url = str(request.url_for("static", path=f"results/{request_id}/{main_filename}"))
        main_result_path = f"results/{request_id}/{main_filename}"
    except Exception as e:
        print(f"Full visualization failed: {e}")

    duration = time.time() - start_time
    append_to_history("tarot", f"expected: {expected_count}", "success", result_path=main_result_path, details=f"Successfully segmented {expected_count} cards", duration=duration)

    return JSONResponse(content={
        "status": "success",
        "message": f"成功识别并分割 {expected_count} 张塔罗牌 (已执行透视矫正)",
        "tarot_cards": tarot_cards,
        "full_visualization": main_file_url,
        "scores": scores.tolist() if torch.is_tensor(scores) else scores
    })
|
||
|
||
@app.post("/recognize_tarot", tags=[TAG_TAROT], dependencies=[Depends(verify_api_key)])
async def recognize_tarot(
    request: Request,
    file: Optional[UploadFile] = File(None),
    image_url: Optional[str] = Form(None),
    expected_count: int = Form(3)
):
    """
    **塔罗牌全流程接口: 分割 + 矫正 + 识别**
    (Full tarot pipeline: segmentation + perspective correction + recognition)

    1. 检测是否包含指定数量的塔罗牌 (SAM3) - detect exactly `expected_count` cards with SAM3
    2. 分割并透视矫正 - crop every card and apply perspective correction
    3. 调用 Qwen-VL 识别每张牌的名称和正逆位 - ask Qwen-VL for each card's name and upright/reversed state
    """
    # Response contract:
    #   * 200 JSON with spread_info, per-card recognition results and scores on success.
    #   * 400 JSON body (returned, not raised) when the detected count != expected_count.
    #   * HTTPException 400 for missing/unreadable input, 500 for inference or crop errors.
    # NOTE(review): SAM3 inference and the Qwen calls below are synchronous and run on
    # the event-loop thread, so this handler blocks other requests while it works.
    if not file and not image_url:
        raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")

    start_time = time.time()
    history_input = f"expected: {expected_count}"

    # --- Load the input image (an uploaded file takes precedence over the URL) ---
    try:
        if file:
            image = Image.open(file.file).convert("RGB")
        else:
            image = load_image_from_url(image_url)
    except Exception as e:
        duration = time.time() - start_time
        append_to_history("tarot-recognize", history_input, "failed", details=f"Image Load Error: {str(e)}", duration=duration)
        raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")

    processor = request.app.state.processor

    # --- SAM3 text-prompted segmentation ---
    try:
        inference_state = processor.set_image(image)
        output = processor.set_text_prompt(state=inference_state, prompt="tarot card")
        masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
    except Exception as e:
        duration = time.time() - start_time
        append_to_history("tarot-recognize", history_input, "failed", details=f"Inference Error: {str(e)}", duration=duration)
        raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")

    detected_count = len(masks)

    # Every request writes into its own directory under RESULT_IMAGE_DIR.
    request_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
    output_dir = os.path.join(RESULT_IMAGE_DIR, request_id)
    os.makedirs(output_dir, exist_ok=True)

    # --- Full visualization (best effort) ---
    try:
        main_filename = generate_and_save_result(image, inference_state, output_dir=output_dir)
        main_file_path = os.path.join(output_dir, main_filename)
        main_file_url = str(request.url_for("static", path=f"results/{request_id}/{main_filename}"))
    except Exception:
        # The visualization is optional: continue without it, but keep the cause visible.
        traceback.print_exc()
        main_filename = None
        main_file_path = None
        main_file_url = None

    history_result_path = f"results/{request_id}/{main_filename}" if main_file_url else None

    # --- Step 0: spread (layout) recognition on a copy of the raw image (best effort) ---
    spread_info = {"spread_name": "Unknown"}
    if main_file_path:
        try:
            temp_raw_path = os.path.join(output_dir, "raw_for_spread.jpg")
            image.save(temp_raw_path)
            spread_info = recognize_spread_with_qwen(temp_raw_path)
        except Exception as e:
            # A failed spread lookup must not fail the whole request; keep "Unknown".
            print(f"Spread recognition failed, falling back to 'Unknown': {e}")

    if detected_count != expected_count:
        duration = time.time() - start_time
        append_to_history("tarot-recognize", history_input, "failed", result_path=history_result_path, details=f"Detected {detected_count}, expected {expected_count}", duration=duration)
        return JSONResponse(
            status_code=400,
            content={
                "status": "failed",
                "message": f"检测到 {detected_count} 个目标,需要严格的 {expected_count} 张塔罗牌。请调整拍摄角度或背景。",
                "detected_count": detected_count,
                "spread_info": spread_info,
                "debug_image_url": str(main_file_url) if main_file_url else None
            }
        )

    # --- Count matches: crop + perspective-correct each card ---
    try:
        saved_objects = crop_and_save_objects(image, masks, boxes, output_dir=output_dir, is_tarot=True, perspective_correction=True)
    except Exception as e:
        duration = time.time() - start_time
        append_to_history("tarot-recognize", history_input, "failed", details=f"Crop Error: {str(e)}", duration=duration)
        raise HTTPException(status_code=500, detail=f"抠图处理错误: {str(e)}")

    # --- Recognize every cropped card with Qwen-VL (sequentially) ---
    tarot_cards = []
    for obj in saved_objects:
        fname = obj["filename"]
        file_path = os.path.join(output_dir, fname)
        recognition_res = recognize_card_with_qwen(file_path)

        file_url = str(request.url_for("static", path=f"results/{request_id}/{fname}"))
        tarot_cards.append({
            "url": file_url,
            "is_rotated": obj["is_rotated_by_algorithm"],
            "orientation_status": "corrected_to_portrait" if obj["is_rotated_by_algorithm"] else "original_portrait",
            "recognition": recognition_res,
            "note": obj["note"]
        })

    duration = time.time() - start_time
    append_to_history("tarot-recognize", history_input, "success", result_path=history_result_path, details=f"Spread: {spread_info.get('spread_name', 'Unknown')}", duration=duration)
    return JSONResponse(content={
        "status": "success",
        "message": f"成功识别并分割 {expected_count} 张塔罗牌 (含Qwen识别结果)",
        "spread_info": spread_info,
        "tarot_cards": tarot_cards,
        "full_visualization": main_file_url,
        "scores": scores.tolist() if torch.is_tensor(scores) else scores
    })
|
||
|
||
# ------------------------------------------
|
||
# Group 3: Face Analysis (人脸分析)
|
||
# ------------------------------------------
|
||
|
||
@app.post("/segment_face", tags=[TAG_FACE], dependencies=[Depends(verify_api_key)])
async def segment_face(
    request: Request,
    file: Optional[UploadFile] = File(None),
    image_url: Optional[str] = Form(None),
    prompt: str = Form("face and hair", description="Prompt for face detection")
):
    """
    **人脸/头部检测与属性分析接口**
    (Face / head detection and attribute analysis)

    1. 调用 SAM3 分割出头部区域 (含头发) - segment the head region (incl. hair) with SAM3
    2. 裁剪并保存 - crop and save each region
    3. 调用 Qwen-VL 识别性别和年龄 - estimate gender and age with Qwen-VL
    """
    # Processing is delegated to human_analysis_service; this handler only loads
    # the image, normalizes the prompt, turns relative result paths into absolute
    # /static URLs and records history.
    if not file and not image_url:
        raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")

    start_time = time.time()

    # Non-English prompts are translated for SAM3; on any failure the raw prompt is kept.
    final_prompt = prompt
    if not is_english(prompt):
        try:
            translated = translate_to_sam3_prompt(prompt)
            if translated:
                final_prompt = translated
        except Exception as e:
            print(f"Prompt翻译失败,使用原始Prompt: {e}")

    print(f"Face Segment 最终使用的 Prompt: {final_prompt}")

    # Load the image (an uploaded file takes precedence over the URL).
    try:
        if file:
            image = Image.open(file.file).convert("RGB")
        else:
            image = load_image_from_url(image_url)
    except Exception as e:
        duration = time.time() - start_time
        append_to_history("face", prompt, "failed", details=f"Image Load Error: {str(e)}", final_prompt=final_prompt, duration=duration)
        raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")

    processor = request.app.state.processor

    try:
        result = human_analysis_service.process_face_segmentation_and_analysis(
            processor=processor,
            image=image,
            prompt=final_prompt,
            output_base_dir=RESULT_IMAGE_DIR,
            qwen_model=QWEN_MODEL,
            analysis_prompt=PROMPTS["face_analysis"]
        )
    except Exception as e:
        traceback.print_exc()
        duration = time.time() - start_time
        append_to_history("face", prompt, "failed", details=f"Process Error: {str(e)}", final_prompt=final_prompt, duration=duration)
        raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}")

    # Use .get() so a partial result from the service does not turn into a 500.
    status_value = result.get("status", "failed")
    if status_value == "success":
        if result.get("full_visualization"):
            result["full_visualization"] = str(request.url_for("static", path=result["full_visualization"]))

        # Replace each item's internal relative_path with a public URL.
        for item in result.get("results", []):
            relative_path = item.pop("relative_path", None)
            if relative_path:
                item["url"] = str(request.url_for("static", path=relative_path))

    duration = time.time() - start_time
    append_to_history("face", prompt, status_value, details=f"Results: {len(result.get('results', []))}", final_prompt=final_prompt, duration=duration)
    return JSONResponse(content=result)
|
||
|
||
# ==========================================
|
||
# 9. Admin Management APIs (管理后台接口)
|
||
# ==========================================
|
||
|
||
@app.get("/admin", include_in_schema=False)
async def admin_page():
    """
    Return static/admin.html as an HTML response, or a minimal 404 page when
    the file is missing.
    """
    page_path = os.path.join(STATIC_DIR, "admin.html")

    # Guard clause: no admin page shipped -> simple 404 placeholder.
    if not os.path.exists(page_path):
        return HTMLResponse(content="<h1>Admin page not found</h1>", status_code=404)

    with open(page_path, "r", encoding="utf-8") as handle:
        html_text = handle.read()
    return HTMLResponse(content=html_text)
|
||
|
||
def _admin_session_token():
    """
    Return the expected value of the admin_token cookie, or None when no admin
    password is configured.

    The token is HMAC-SHA256 keyed with ADMIN_PASSWORD over a fixed label, so:
    it cannot be forged without knowing the password (the previous constant
    "logged_in" could be set by anyone), it is identical in every worker
    process, and changing the password invalidates existing sessions.
    """
    import hashlib
    import hmac

    if not ADMIN_PASSWORD:
        return None
    return hmac.new(str(ADMIN_PASSWORD).encode("utf-8"), b"sam3-admin-session", hashlib.sha256).hexdigest()


@app.post("/admin/login", include_in_schema=False)
async def admin_login(password: str = Form(...)):
    """
    Simple Admin Login

    On a correct password, set an httponly admin_token cookie and return a
    success JSON; otherwise return 401. Login is refused when no admin
    password is configured.
    """
    import hmac

    expected_token = _admin_session_token()
    # compare_digest on bytes: constant-time and safe for non-ASCII input.
    if expected_token is not None and hmac.compare_digest(
        password.encode("utf-8"), str(ADMIN_PASSWORD).encode("utf-8")
    ):
        content = {"status": "success", "message": "Logged in"}
        response = JSONResponse(content=content)
        response.set_cookie(key="admin_token", value=expected_token, httponly=True)
        return response
    return JSONResponse(status_code=401, content={"status": "error", "message": "Invalid password"})


async def verify_admin(request: Request):
    """
    FastAPI dependency guarding /admin/api/*: raise 401 unless the request
    carries the admin_token cookie issued by admin_login.
    """
    import hmac

    token = request.cookies.get("admin_token")
    expected_token = _admin_session_token()
    if (
        not token
        or expected_token is None
        or not hmac.compare_digest(token.encode("utf-8"), expected_token.encode("utf-8"))
    ):
        raise HTTPException(status_code=401, detail="Unauthorized")
|
||
|
||
@app.get("/admin/api/history", dependencies=[Depends(verify_admin)])
async def get_history():
    """
    Get request history

    Return at most the last 100 JSON records from HISTORY_FILE (one JSON
    object per line). Blank or unparsable lines are skipped. If the file
    is missing, return []. If reading fails, return {"error": <message>}.
    """
    from collections import deque

    if not os.path.exists(HISTORY_FILE):
        return []

    # A bounded deque keeps only the newest 100 records while streaming, so
    # the whole history file is never held in memory.
    recent_records = deque(maxlen=100)
    try:
        with open(HISTORY_FILE, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                try:
                    recent_records.append(json.loads(line))
                except ValueError:
                    # Corrupted or partially written line (JSONDecodeError is a ValueError).
                    continue
    except Exception as e:
        return {"error": str(e)}
    return list(recent_records)
|
||
|
||
@app.get("/admin/api/files", dependencies=[Depends(verify_admin)])
async def list_files(path: str = ""):
    """
    List files in static/results

    `path` is relative to RESULT_IMAGE_DIR. Each entry carries name, is_dir,
    path, mtime, plus `count` (directories), `size` (files) and a /static URL.
    Directories are listed first, then entries are sorted newest first.
    Returns [] for a missing directory. Raises 400 for paths escaping
    RESULT_IMAGE_DIR and 500 for unexpected listing errors.
    """
    # Cheap syntactic guard (kept from the original contract).
    if ".." in path or path.startswith("/"):
        raise HTTPException(status_code=400, detail="Invalid path")

    # Real containment check: also catches Windows drive / backslash-rooted
    # paths and symlinks that point outside the results directory.
    base_dir = os.path.realpath(RESULT_IMAGE_DIR)
    target_dir = os.path.realpath(os.path.join(base_dir, path))
    try:
        inside_base = os.path.commonpath([base_dir, target_dir]) == base_dir
    except ValueError:
        # Different drives on Windows.
        inside_base = False
    if not inside_base:
        raise HTTPException(status_code=400, detail="Invalid path")

    if not os.path.exists(target_dir):
        return []

    items = []
    try:
        with os.scandir(target_dir) as entries:
            for entry in entries:
                try:
                    entry_stat = entry.stat()
                except OSError:
                    # Broken symlink, or entry removed concurrently (e.g. by cleanup).
                    continue
                is_dir = entry.is_dir()
                item = {
                    "name": entry.name,
                    "is_dir": is_dir,
                    "path": os.path.join(path, entry.name),
                    "mtime": entry_stat.st_mtime
                }
                if is_dir:
                    try:
                        item["count"] = len(os.listdir(entry.path))
                    except OSError:
                        item["count"] = 0
                else:
                    item["size"] = entry_stat.st_size

                # URL under the static mount (assumed to be /static): /static/results/<path>/<name>
                rel_path = os.path.join("results", path, entry.name)
                if os.sep != "/":
                    rel_path = rel_path.replace(os.sep, "/")
                item["url"] = f"/static/{rel_path}"

                items.append(item)

        # Directories first, then newest first.
        items.sort(key=lambda x: (not x["is_dir"], -x["mtime"]))
        return items
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
@app.delete("/admin/api/files/{file_path:path}", dependencies=[Depends(verify_admin)])
async def delete_file(file_path: str):
    """
    Delete file or directory

    `file_path` is relative to RESULT_IMAGE_DIR. Directories are removed
    recursively. Symlinks are unlinked, never followed. Raises 400 for paths
    outside RESULT_IMAGE_DIR or the results root itself. Raises 404 if the
    target does not exist and 500 if deletion fails.
    """
    if ".." in file_path:
        raise HTTPException(status_code=400, detail="Invalid path")

    # os.path.join discards the base for absolute inputs (e.g. "/etc/x" via
    # "//etc/x" in the URL), so containment must be checked on the result.
    # abspath (not realpath) is used so a symlink itself is deleted, not its target.
    base_dir = os.path.abspath(RESULT_IMAGE_DIR)
    target_path = os.path.abspath(os.path.join(base_dir, file_path))
    try:
        inside_base = os.path.commonpath([base_dir, target_path]) == base_dir
    except ValueError:
        # Different drives on Windows.
        inside_base = False
    # An empty file_path would otherwise resolve to the root and rmtree everything.
    if not inside_base or target_path == base_dir:
        raise HTTPException(status_code=400, detail="Invalid path")

    # lexists: a dangling symlink is still something that can be deleted.
    if not os.path.lexists(target_path):
        raise HTTPException(status_code=404, detail="Not found")

    try:
        if os.path.isdir(target_path) and not os.path.islink(target_path):
            shutil.rmtree(target_path)
        else:
            os.remove(target_path)
        return {"status": "success", "message": f"Deleted {file_path}"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
def _sweep_expired_results(root_dir, lifetime):
    """
    Delete files under root_dir whose mtime is older than `lifetime` seconds,
    then remove directories left empty. Returns the number of files removed.

    Per-entry OSErrors (file vanished, permission denied) are skipped so one
    bad entry does not abort the sweep.
    """
    removed = 0
    now = time.time()
    for dirpath, _subdirs, filenames in os.walk(root_dir):
        for filename in filenames:
            file_path = os.path.join(dirpath, filename)
            try:
                if now - os.path.getmtime(file_path) > lifetime:
                    os.remove(file_path)
                    removed += 1
            except OSError:
                continue

    # Bottom-up walk so children are emptied/removed before their parents are checked.
    for dirpath, subdirs, _filenames in os.walk(root_dir, topdown=False):
        for dirname in subdirs:
            dir_path = os.path.join(dirpath, dirname)
            try:
                if not os.listdir(dir_path):
                    os.rmdir(dir_path)
            except OSError:
                continue
    return removed


@app.post("/admin/api/cleanup", dependencies=[Depends(verify_admin)])
async def trigger_cleanup():
    """
    Manually trigger cleanup

    Uses CLEANUP_CONFIG["lifetime"] (seconds). The filesystem sweep runs in a
    worker thread so a large results tree does not block the event loop.
    """
    from fastapi.concurrency import run_in_threadpool

    try:
        lifetime = CLEANUP_CONFIG["lifetime"]
        count = await run_in_threadpool(_sweep_expired_results, RESULT_IMAGE_DIR, lifetime)
        return {"status": "success", "message": f"Cleaned {count} files (Lifetime: {lifetime}s)"}
    except Exception as e:
        return {"status": "error", "message": str(e)}
|
||
|
||
@app.get("/admin/api/config", dependencies=[Depends(verify_admin)])
async def get_config(request: Request):
    """
    Report runtime configuration for the admin dashboard.

    The result includes the model device (or "Unknown" if app.state.device is
    unset), a CUDA memory snapshot for the current device, the cleanup
    settings, and the current and available Qwen models.
    """
    app_state = request.app.state
    device_str = str(app_state.device) if hasattr(app_state, "device") else "Unknown"

    def to_gb_label(num_bytes):
        # Bytes -> "X.XX GB" using 1024**3.
        return f"{num_bytes / 1024**3:.2f} GB"

    if not torch.cuda.is_available():
        gpu_status = {"available": False, "reason": "No CUDA device detected"}
    else:
        try:
            cuda_index = torch.cuda.current_device()
            device_props = torch.cuda.get_device_properties(cuda_index)
            total_bytes = device_props.total_memory
            reserved_bytes = torch.cuda.memory_reserved(cuda_index)
            allocated_bytes = torch.cuda.memory_allocated(cuda_index)
            gpu_status = {
                "available": True,
                "name": device_props.name,
                "total_memory": to_gb_label(total_bytes),
                "reserved_memory": to_gb_label(reserved_bytes),
                "allocated_memory": to_gb_label(allocated_bytes),
                # Usage is measured against memory reserved by the caching allocator.
                "memory_usage_percent": round((reserved_bytes / total_bytes) * 100, 1)
            }
        except Exception as e:
            gpu_status = {"available": True, "error": str(e)}

    return {
        "device": device_str,
        "gpu_status": gpu_status,
        "cleanup_config": CLEANUP_CONFIG,
        "current_qwen_model": QWEN_MODEL,
        "available_qwen_models": AVAILABLE_QWEN_MODELS
    }
|
||
|
||
@app.post("/admin/api/config/cleanup", dependencies=[Depends(verify_admin)])
async def update_cleanup_config(
    enabled: bool = Form(...),
    lifetime: int = Form(...),
    interval: int = Form(...)
):
    """
    Update cleanup configuration

    `lifetime` is the file age (seconds) after which results are deleted.
    `interval` is the sweep period (seconds). Both must be positive: a
    non-positive interval would make a periodic sweep spin, and a
    non-positive lifetime would delete results of requests still in flight.
    Raises 400 on invalid values.
    """
    if lifetime <= 0 or interval <= 0:
        raise HTTPException(status_code=400, detail="lifetime and interval must be positive integers (seconds)")

    # Mutated in place (no rebinding), so every reference to this dict sees the update.
    CLEANUP_CONFIG["enabled"] = enabled
    CLEANUP_CONFIG["lifetime"] = lifetime
    CLEANUP_CONFIG["interval"] = interval

    print(f"Updated Cleanup Config: {CLEANUP_CONFIG}")
    return {"status": "success", "message": "Cleanup configuration updated", "config": CLEANUP_CONFIG}
|
||
|
||
@app.post("/admin/api/config/model", dependencies=[Depends(verify_admin)])
async def set_model(model: str = Form(...)):
    """
    Switch the module-level QWEN_MODEL to `model`.

    Only names listed in AVAILABLE_QWEN_MODELS are accepted. Others raise
    HTTP 400.
    """
    global QWEN_MODEL

    if model in AVAILABLE_QWEN_MODELS:
        QWEN_MODEL = model
        return {"status": "success", "message": f"Model switched to {model}", "current_model": QWEN_MODEL}

    raise HTTPException(status_code=400, detail="Invalid model")
|
||
|
||
@app.get("/admin/api/prompts", dependencies=[Depends(verify_admin)])
async def get_prompts():
    """
    Return the module-level PROMPTS mapping (prompt key -> prompt text) as-is.
    """
    prompt_table = PROMPTS
    return prompt_table
|
||
|
||
@app.post("/admin/api/prompts", dependencies=[Depends(verify_admin)])
async def update_prompts(
    key: str = Form(...),
    content: str = Form(...)
):
    """
    Replace the text of one existing prompt in PROMPTS.

    Only keys already present can be updated. Unknown keys raise HTTP 400.
    """
    if key in PROMPTS:
        PROMPTS[key] = content
        return {"status": "success", "message": f"Prompt '{key}' updated"}

    raise HTTPException(status_code=400, detail="Invalid prompt key")
|
||
|
||
# ------------------------------------------
|
||
# GPU Status Helper & API
|
||
# ------------------------------------------
|
||
|
||
def get_gpu_status_smi():
|
||
"""
|
||
Get detailed GPU status using nvidia-smi
|
||
Returns: dict with utilization, memory, temp, power, etc.
|
||
"""
|
||
cuda_version = "Unknown"
|
||
try:
|
||
import torch
|
||
if torch.version.cuda:
|
||
cuda_version = torch.version.cuda
|
||
except:
|
||
pass
|
||
|
||
try:
|
||
# Check if nvidia-smi is available
|
||
# Fields: utilization.gpu, utilization.memory, temperature.gpu, power.draw, power.limit, memory.total, memory.used, memory.free, name, driver_version
|
||
result = subprocess.run(
|
||
['nvidia-smi', '--query-gpu=utilization.gpu,utilization.memory,temperature.gpu,power.draw,power.limit,memory.total,memory.used,memory.free,name,driver_version', '--format=csv,noheader,nounits'],
|
||
stdout=subprocess.PIPE,
|
||
stderr=subprocess.PIPE,
|
||
encoding='utf-8'
|
||
)
|
||
|
||
if result.returncode != 0:
|
||
raise Exception("nvidia-smi failed")
|
||
|
||
# Parse the first line (assuming single GPU for now, or take the first one)
|
||
line = result.stdout.strip().split('\n')[0]
|
||
vals = [x.strip() for x in line.split(',')]
|
||
|
||
return {
|
||
"available": True,
|
||
"gpu_util": float(vals[0]), # %
|
||
"mem_util": float(vals[1]), # % (controller utilization)
|
||
"temperature": float(vals[2]), # C
|
||
"power_draw": float(vals[3]), # W
|
||
"power_limit": float(vals[4]), # W
|
||
"mem_total": float(vals[5]), # MB
|
||
"mem_used": float(vals[6]), # MB
|
||
"mem_free": float(vals[7]), # MB
|
||
"name": vals[8],
|
||
"driver_version": vals[9],
|
||
"cuda_version": cuda_version,
|
||
"source": "nvidia-smi",
|
||
"timestamp": time.time()
|
||
}
|
||
except Exception as e:
|
||
# Fallback to torch if available
|
||
if torch.cuda.is_available():
|
||
try:
|
||
device_id = torch.cuda.current_device()
|
||
props = torch.cuda.get_device_properties(device_id)
|
||
mem_reserved = torch.cuda.memory_reserved(device_id) / 1024**2 # MB
|
||
mem_total = props.total_memory / 1024**2 # MB
|
||
|
||
return {
|
||
"available": True,
|
||
"gpu_util": 0, # Torch can't get this easily
|
||
"mem_util": (mem_reserved / mem_total) * 100,
|
||
"temperature": 0,
|
||
"power_draw": 0,
|
||
"power_limit": 0,
|
||
"mem_total": mem_total,
|
||
"mem_used": mem_reserved,
|
||
"mem_free": mem_total - mem_reserved,
|
||
"name": props.name,
|
||
"driver_version": "Unknown",
|
||
"cuda_version": cuda_version,
|
||
"source": "torch",
|
||
"timestamp": time.time()
|
||
}
|
||
except:
|
||
pass
|
||
|
||
return {"available": False, "error": str(e)}
|
||
|
||
@app.get("/admin/api/gpu/status", dependencies=[Depends(verify_admin)])
async def get_gpu_status_api():
    """
    Get real-time GPU status

    get_gpu_status_smi spawns nvidia-smi and waits for it synchronously. It runs
    in the threadpool so the event loop keeps serving other requests meanwhile.
    """
    from fastapi.concurrency import run_in_threadpool

    return await run_in_threadpool(get_gpu_status_smi)
|
||
|
||
# ==========================================
|
||
# 10. Main Entry Point (启动入口)
|
||
# ==========================================
|
||
|
||
if __name__ == "__main__":
    import uvicorn

    # ==========================================
    # Auto cleanup defaults (自动清理图片配置)
    # ==========================================
    # setdefault(): values already exported in the environment win instead of
    # being silently overwritten by these defaults.
    # NOTE(review): these only take effect if the module imported by uvicorn below
    # reads them at import time - confirm against the config section at the top.

    # Enable/disable automatic cleanup ("True"/"False").
    os.environ.setdefault("AUTO_CLEANUP_ENABLED", "True")

    # Result lifetime in seconds; older images are deleted (3600 = 1 hour, 86400 = 1 day).
    os.environ.setdefault("FILE_LIFETIME_SECONDS", "3600")

    # How often (seconds) the cleanup check runs.
    os.environ.setdefault("CLEANUP_INTERVAL_SECONDS", "600")

    # Start the server. The app is given as an import string, so uvicorn
    # re-imports it as module "fastAPI_tarot" - presumably this file's name; verify.
    uvicorn.run(
        "fastAPI_tarot:app",
        host="127.0.0.1",
        port=55600,
        proxy_headers=True,
        forwarded_allow_ips="*",
        reload=False
    )
|