This commit is contained in:
2026-02-17 12:03:18 +08:00
parent b13f6df90e
commit 0f72cf7917
3 changed files with 796 additions and 4 deletions

View File

@@ -20,6 +20,8 @@ import json
import traceback
import re
import asyncio
import shutil
from datetime import datetime
from typing import Optional, List, Dict, Any
from contextlib import asynccontextmanager
@@ -34,10 +36,10 @@ import matplotlib.pyplot as plt
from PIL import Image
# FastAPI Imports
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request, Depends, status, APIRouter
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request, Depends, status, APIRouter, Cookie
from fastapi.security import APIKeyHeader
from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse
from fastapi.responses import JSONResponse, HTMLResponse, Response
# Dashscope (Aliyun Qwen) Imports
import dashscope
@@ -62,6 +64,10 @@ os.makedirs(RESULT_IMAGE_DIR, exist_ok=True)
VALID_API_KEY = "123quant-speed"
API_KEY_HEADER_NAME = "X-API-Key"
# Admin Config
ADMIN_PASSWORD = "admin_secure_password" # 可以根据需求修改
HISTORY_FILE = "history.json"
# Dashscope (Qwen-VL) 配置
dashscope.api_key = 'sk-ce2404f55f744a1987d5ece61c6bac58'
QWEN_MODEL = 'qwen-vl-max'
@@ -219,6 +225,25 @@ def is_english(text: str) -> bool:
return False
return True
def append_to_history(req_type: str, prompt: str, status: str, result_path: str = None, details: str = ""):
"""
记录请求历史到 history.json
"""
record = {
"timestamp": time.time(),
"type": req_type,
"prompt": prompt,
"status": status,
"result_path": result_path,
"details": details
}
try:
with open(HISTORY_FILE, "a", encoding="utf-8") as f:
f.write(json.dumps(record, ensure_ascii=False) + "\n")
except Exception as e:
print(f"Failed to write history: {e}")
def translate_to_sam3_prompt(text: str) -> str:
"""
使用 Qwen 模型将中文提示词翻译为英文
@@ -640,6 +665,7 @@ async def segment(
elif image_url:
image = load_image_from_url(image_url)
except Exception as e:
append_to_history("general", prompt, "failed", details=f"Image Load Error: {str(e)}")
raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")
processor = request.app.state.processor
@@ -669,6 +695,7 @@ async def segment(
processor.confidence_threshold = original_confidence
except Exception as e:
append_to_history("general", prompt, "failed", details=f"Inference Error: {str(e)}")
raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")
# 4. 结果可视化与保存
@@ -681,6 +708,7 @@ async def segment(
else:
filename = generate_and_save_result(image, inference_state)
except Exception as e:
append_to_history("general", prompt, "failed", details=f"Save Error: {str(e)}")
raise HTTPException(status_code=500, detail=f"绘图保存错误: {str(e)}")
file_url = request.url_for("static", path=f"results/{filename}")
@@ -712,6 +740,8 @@ async def segment(
})
except Exception as e:
print(f"Error saving segments: {e}")
# Don't fail the whole request just for this part, but log it? Or fail? Usually fail.
append_to_history("general", prompt, "partial_success", result_path=f"results/{filename}", details="Segments save failed")
raise HTTPException(status_code=500, detail=f"保存分割图片失败: {str(e)}")
response_content = {
@@ -724,6 +754,7 @@ async def segment(
if save_segment_images:
response_content["segmented_images"] = saved_segments_info
append_to_history("general", prompt, "success", result_path=f"results/{filename}", details=f"Detected: {len(masks)}")
return JSONResponse(content=response_content)
# ------------------------------------------
@@ -753,6 +784,7 @@ async def segment_tarot(
elif image_url:
image = load_image_from_url(image_url)
except Exception as e:
append_to_history("tarot", f"expected: {expected_count}", "failed", details=f"Image Load Error: {str(e)}")
raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")
processor = request.app.state.processor
@@ -763,6 +795,7 @@ async def segment_tarot(
output = processor.set_text_prompt(state=inference_state, prompt="tarot card")
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
except Exception as e:
append_to_history("tarot", f"expected: {expected_count}", "failed", details=f"Inference Error: {str(e)}")
raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")
# 核心逻辑:判断数量
@@ -781,6 +814,7 @@ async def segment_tarot(
except:
file_url = None
append_to_history("tarot", f"expected: {expected_count}", "failed", result_path=f"results/{request_id}/{filename}" if file_url else None, details=f"Detected {detected_count} cards, expected {expected_count}")
return JSONResponse(
status_code=400,
content={
@@ -795,6 +829,7 @@ async def segment_tarot(
try:
saved_objects = crop_and_save_objects(image, masks, boxes, output_dir=output_dir)
except Exception as e:
append_to_history("tarot", f"expected: {expected_count}", "failed", details=f"Crop Error: {str(e)}")
raise HTTPException(status_code=500, detail=f"抠图处理错误: {str(e)}")
# 生成 URL 列表和元数据
@@ -816,6 +851,7 @@ async def segment_tarot(
except:
main_file_url = None
append_to_history("tarot", f"expected: {expected_count}", "success", result_path=f"results/{request_id}/{main_filename}" if main_file_url else None, details=f"Successfully segmented {expected_count} cards")
return JSONResponse(content={
"status": "success",
"message": f"成功识别并分割 {expected_count} 张塔罗牌 (已执行透视矫正)",
@@ -847,6 +883,7 @@ async def recognize_tarot(
elif image_url:
image = load_image_from_url(image_url)
except Exception as e:
append_to_history("tarot-recognize", f"expected: {expected_count}", "failed", details=f"Image Load Error: {str(e)}")
raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")
processor = request.app.state.processor
@@ -856,6 +893,7 @@ async def recognize_tarot(
output = processor.set_text_prompt(state=inference_state, prompt="tarot card")
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
except Exception as e:
append_to_history("tarot-recognize", f"expected: {expected_count}", "failed", details=f"Inference Error: {str(e)}")
raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")
detected_count = len(masks)
@@ -883,6 +921,7 @@ async def recognize_tarot(
spread_info = recognize_spread_with_qwen(temp_raw_path)
if detected_count != expected_count:
append_to_history("tarot-recognize", f"expected: {expected_count}", "failed", result_path=f"results/{request_id}/{main_filename}" if main_file_url else None, details=f"Detected {detected_count}, expected {expected_count}")
return JSONResponse(
status_code=400,
content={
@@ -898,6 +937,7 @@ async def recognize_tarot(
try:
saved_objects = crop_and_save_objects(image, masks, boxes, output_dir=output_dir)
except Exception as e:
append_to_history("tarot-recognize", f"expected: {expected_count}", "failed", details=f"Crop Error: {str(e)}")
raise HTTPException(status_code=500, detail=f"抠图处理错误: {str(e)}")
# 遍历每张卡片进行识别
@@ -918,6 +958,7 @@ async def recognize_tarot(
"note": obj["note"]
})
append_to_history("tarot-recognize", f"expected: {expected_count}", "success", result_path=f"results/{request_id}/{main_filename}" if main_file_url else None, details=f"Spread: {spread_info.get('spread_name', 'Unknown')}")
return JSONResponse(content={
"status": "success",
"message": f"成功识别并分割 {expected_count} 张塔罗牌 (含Qwen识别结果)",
@@ -967,6 +1008,7 @@ async def segment_face(
elif image_url:
image = load_image_from_url(image_url)
except Exception as e:
append_to_history("face", prompt, "failed", details=f"Image Load Error: {str(e)}")
raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")
processor = request.app.state.processor
@@ -982,6 +1024,7 @@ async def segment_face(
except Exception as e:
import traceback
traceback.print_exc()
append_to_history("face", prompt, "failed", details=f"Process Error: {str(e)}")
raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}")
# 补全 URL
@@ -993,11 +1036,197 @@ async def segment_face(
for item in result["results"]:
relative_path = item.pop("relative_path")
item["url"] = str(request.url_for("static", path=relative_path))
append_to_history("face", prompt, result["status"], details=f"Results: {len(result.get('results', []))}")
return JSONResponse(content=result)
# ==========================================
# 8. Main Entry Point (启动入口)
# 9. Admin Management APIs (管理后台接口)
# ==========================================
@app.get("/admin", include_in_schema=False)
async def admin_page():
"""
Serve the admin HTML page
"""
# 检查 static/admin.html 是否存在,不存在则返回错误或简易页面
admin_html_path = os.path.join(STATIC_DIR, "admin.html")
if os.path.exists(admin_html_path):
with open(admin_html_path, "r", encoding="utf-8") as f:
content = f.read()
return HTMLResponse(content=content)
else:
return HTMLResponse(content="<h1>Admin page not found</h1>", status_code=404)
@app.post("/admin/login", include_in_schema=False)
async def admin_login(password: str = Form(...)):
"""
Simple Admin Login
"""
if password == ADMIN_PASSWORD:
content = {"status": "success", "message": "Logged in"}
response = JSONResponse(content=content)
# Set a simple cookie
response.set_cookie(key="admin_token", value="logged_in", httponly=True)
return response
else:
return JSONResponse(status_code=401, content={"status": "error", "message": "Invalid password"})
async def verify_admin(request: Request):
# Check cookie or header
token = request.cookies.get("admin_token")
# Also allow local dev without strict check if needed, but here we enforce password
if token != "logged_in":
raise HTTPException(status_code=401, detail="Unauthorized")
@app.get("/admin/api/history", dependencies=[Depends(verify_admin)])
async def get_history():
"""
Get request history
"""
if not os.path.exists(HISTORY_FILE):
return []
records = []
try:
with open(HISTORY_FILE, "r", encoding="utf-8") as f:
for line in f:
if line.strip():
try:
records.append(json.loads(line))
except:
pass
# Limit to last 100 records
return records[-100:]
except Exception as e:
return {"error": str(e)}
@app.get("/admin/api/files", dependencies=[Depends(verify_admin)])
async def list_files(path: str = ""):
"""
List files in static/results
"""
# Security check: prevent directory traversal
if ".." in path or path.startswith("/"):
raise HTTPException(status_code=400, detail="Invalid path")
target_dir = os.path.join(RESULT_IMAGE_DIR, path)
if not os.path.exists(target_dir):
return []
items = []
try:
for entry in os.scandir(target_dir):
is_dir = entry.is_dir()
item = {
"name": entry.name,
"is_dir": is_dir,
"path": os.path.join(path, entry.name),
"mtime": entry.stat().st_mtime
}
if is_dir:
try:
item["count"] = len(os.listdir(entry.path))
except:
item["count"] = 0
else:
item["size"] = entry.stat().st_size
# Construct URL
# Assuming static mount is /static
# path is relative to results/
# so url is /static/results/path/name
rel_path = os.path.join("results", path, entry.name)
item["url"] = f"/static/{rel_path}"
items.append(item)
# Sort: Directories first, then by time (newest first)
items.sort(key=lambda x: (not x["is_dir"], -x["mtime"]))
return items
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.delete("/admin/api/files/{file_path:path}", dependencies=[Depends(verify_admin)])
async def delete_file(file_path: str):
"""
Delete file or directory
"""
# Security check
if ".." in file_path:
raise HTTPException(status_code=400, detail="Invalid path")
# file_path is relative to static/results/ (or passed as full relative path from API)
# The API is called with relative path from current view
target_path = os.path.join(RESULT_IMAGE_DIR, file_path)
if not os.path.exists(target_path):
raise HTTPException(status_code=404, detail="Not found")
try:
if os.path.isdir(target_path):
shutil.rmtree(target_path)
else:
os.remove(target_path)
return {"status": "success", "message": f"Deleted {file_path}"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/admin/api/cleanup", dependencies=[Depends(verify_admin)])
async def trigger_cleanup():
"""
Manually trigger cleanup
"""
try:
# Re-use logic from cleanup_old_files but force it for all files > 0 seconds if we want deep clean?
# Or just use the standard lifetime. Let's use standard lifetime but run it now.
lifetime = int(os.getenv("FILE_LIFETIME_SECONDS", "3600"))
count = 0
current_time = time.time()
for root, dirs, files in os.walk(RESULT_IMAGE_DIR):
for file in files:
file_path = os.path.join(root, file)
try:
file_mtime = os.path.getmtime(file_path)
if current_time - file_mtime > lifetime:
os.remove(file_path)
count += 1
except:
pass
# Cleanup empty dirs
for root, dirs, files in os.walk(RESULT_IMAGE_DIR, topdown=False):
for dir in dirs:
dir_path = os.path.join(root, dir)
try:
if not os.listdir(dir_path):
os.rmdir(dir_path)
except:
pass
return {"status": "success", "message": f"Cleaned {count} files"}
except Exception as e:
return {"status": "error", "message": str(e)}
@app.get("/admin/api/config", dependencies=[Depends(verify_admin)])
async def get_config(request: Request):
"""
Get system config info
"""
device = "Unknown"
if hasattr(request.app.state, "device"):
device = str(request.app.state.device)
return {
"device": device,
"cleanup_enabled": os.getenv("AUTO_CLEANUP_ENABLED"),
"file_lifetime": os.getenv("FILE_LIFETIME_SECONDS"),
"cleanup_interval": os.getenv("CLEANUP_INTERVAL_SECONDS")
}
# ==========================================
# 10. Main Entry Point (启动入口)
# ==========================================
if __name__ == "__main__":