diff --git a/fastAPI_nocom.py b/fastAPI_nocom.py deleted file mode 100644 index 515105f..0000000 --- a/fastAPI_nocom.py +++ /dev/null @@ -1,191 +0,0 @@ -import os -import uuid -import requests -from typing import Optional -from contextlib import asynccontextmanager - -import torch -import matplotlib -matplotlib.use('Agg') -import matplotlib.pyplot as plt - -from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request, Depends, status -from fastapi.security import APIKeyHeader -from fastapi.staticfiles import StaticFiles -from fastapi.responses import JSONResponse -from PIL import Image - -from sam3.model_builder import build_sam3_image_model -from sam3.model.sam3_image_processor import Sam3Processor -from sam3.visualization_utils import plot_results - -# ------------------- 配置与路径 ------------------- -STATIC_DIR = "static" -RESULT_IMAGE_DIR = os.path.join(STATIC_DIR, "results") -os.makedirs(RESULT_IMAGE_DIR, exist_ok=True) - -# ------------------- API Key 核心配置 (已加固) ------------------- -VALID_API_KEY = "123quant-speed" -API_KEY_HEADER_NAME = "X-API-Key" - -# 定义 Header 认证 -api_key_header = APIKeyHeader(name=API_KEY_HEADER_NAME, auto_error=False) - -async def verify_api_key(api_key: Optional[str] = Depends(api_key_header)): - """ - 强制验证 API Key - """ - # 1. 检查是否有 Key - if not api_key: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Missing API Key. Please provide it in the header." - ) - # 2. 检查 Key 是否正确 - if api_key != VALID_API_KEY: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Invalid API Key." - ) - # 3. 验证通过 - return True - -# ------------------- 生命周期管理 ------------------- -@asynccontextmanager -async def lifespan(app: FastAPI): - print("="*40) - print("✅ API Key 保护已激活") - print(f"✅ 有效 Key: {VALID_API_KEY}") - print("="*40) - - print("正在加载 SAM3 模型到 GPU...") - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - if not torch.cuda.is_available(): - print("警告: 未检测到 GPU,将使用 CPU,速度会较慢。") - - model = build_sam3_image_model() - model = model.to(device) - model.eval() - - processor = Sam3Processor(model) - - app.state.model = model - app.state.processor = processor - app.state.device = device - - print(f"模型加载完成,设备: {device}") - - yield - - print("正在清理资源...") - -# ------------------- FastAPI 初始化 ------------------- -app = FastAPI( - lifespan=lifespan, - title="SAM3 Segmentation API", - description="## 🔒 受 API Key 保护\n请点击右上角 **Authorize** 并输入: `123quant-speed`", -) - -# 手动添加 OpenAPI 安全配置,让 Docs 里的锁头生效 -app.openapi_schema = None -def custom_openapi(): - if app.openapi_schema: - return app.openapi_schema - from fastapi.openapi.utils import get_openapi - openapi_schema = get_openapi( - title=app.title, - version=app.version, - description=app.description, - routes=app.routes, - ) - # 定义安全方案 - openapi_schema["components"]["securitySchemes"] = { - "APIKeyHeader": { - "type": "apiKey", - "in": "header", - "name": API_KEY_HEADER_NAME, - } - } - # 为所有路径应用安全要求 - for path in openapi_schema["paths"]: - for method in openapi_schema["paths"][path]: - openapi_schema["paths"][path][method]["security"] = [{"APIKeyHeader": []}] - app.openapi_schema = openapi_schema - return app.openapi_schema - -app.openapi = custom_openapi - -app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static") - -# ------------------- 辅助函数 ------------------- -def load_image_from_url(url: str) -> Image.Image: - try: - headers = {'User-Agent': 'Mozilla/5.0'} - response = requests.get(url, headers=headers, stream=True, timeout=10) - response.raise_for_status() - image = Image.open(response.raw).convert("RGB") - return image - except Exception as e: - raise HTTPException(status_code=400, detail=f"无法下载图片: {str(e)}") - -def generate_and_save_result(image: Image.Image, inference_state) -> str: - filename = f"seg_{uuid.uuid4().hex}.jpg" - save_path = os.path.join(RESULT_IMAGE_DIR, filename) - plot_results(image, inference_state) - plt.savefig(save_path, dpi=150, bbox_inches='tight') - plt.close() - return filename - -# ------------------- API 接口 (强制依赖验证) ------------------- -@app.post("/segment", dependencies=[Depends(verify_api_key)]) -async def segment( - request: Request, - prompt: str = Form(...), - file: Optional[UploadFile] = File(None), - image_url: Optional[str] = Form(None) -): - if not file and not image_url: - raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)") - - try: - if file: - image = Image.open(file.file).convert("RGB") - elif image_url: - image = load_image_from_url(image_url) - except Exception as e: - raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}") - - processor = request.app.state.processor - - try: - inference_state = processor.set_image(image) - output = processor.set_text_prompt(state=inference_state, prompt=prompt) - masks, boxes, scores = output["masks"], output["boxes"], output["scores"] - except Exception as e: - raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}") - - try: - filename = generate_and_save_result(image, inference_state) - except Exception as e: - raise HTTPException(status_code=500, detail=f"绘图保存错误: {str(e)}") - - file_url = request.url_for("static", path=f"results/{filename}") - - return JSONResponse(content={ - "status": "success", - "result_image_url": str(file_url), - "detected_count": len(masks), - "scores": scores.tolist() if torch.is_tensor(scores) else scores - }) - -if __name__ == "__main__": - import uvicorn - # 注意:如果你的文件名不是 fastAPI_nocom.py,请修改下面第一个参数 - uvicorn.run( - "fastAPI_nocom:app", - host="127.0.0.1", - port=55600, - proxy_headers=True, - forwarded_allow_ips="*", - reload=False # 生产环境建议关闭 reload,确保代码完全重载 - ) \ No newline at end of file diff --git a/fastAPI_tarot.py b/fastAPI_tarot.py index da7d319..3d3f52b 100644 --- a/fastAPI_tarot.py +++ b/fastAPI_tarot.py @@ -73,6 +73,13 @@ dashscope.api_key = 'sk-ce2404f55f744a1987d5ece61c6bac58' QWEN_MODEL = 'qwen-vl-max' # Default model AVAILABLE_QWEN_MODELS = ["qwen-vl-max", "qwen-vl-plus"] +# 清理配置 (Cleanup Config) +CLEANUP_CONFIG = { + "enabled": os.getenv("AUTO_CLEANUP_ENABLED", "True").lower() == "true", + "lifetime": int(os.getenv("FILE_LIFETIME_SECONDS", "3600")), + "interval": int(os.getenv("CLEANUP_INTERVAL_SECONDS", "600")) +} + # API Tags (用于文档分类) TAG_GENERAL = "General Segmentation (通用分割)" TAG_TAROT = "Tarot Analysis (塔罗牌分析)" @@ -108,14 +115,24 @@ async def verify_api_key(api_key: Optional[str] = Depends(api_key_header)): # 4. Lifespan Management (生命周期管理) # ========================================== -async def cleanup_old_files(directory: str, lifetime_seconds: int, interval_seconds: int): +async def cleanup_old_files(directory: str): """ 后台任务:定期清理过期的图片文件 + - 动态读取 CLEANUP_CONFIG 配置 """ - print(f"🧹 自动清理任务已启动 | 目录: {directory} | 生命周期: {lifetime_seconds}s | 检查间隔: {interval_seconds}s") + print(f"🧹 自动清理任务已启动 | 目录: {directory}") while True: try: - await asyncio.sleep(interval_seconds) + # 动态读取配置 + interval = CLEANUP_CONFIG["interval"] + lifetime = CLEANUP_CONFIG["lifetime"] + enabled = CLEANUP_CONFIG["enabled"] + + await asyncio.sleep(interval) + + if not enabled: + continue + current_time = time.time() count = 0 # 遍历所有文件(包括子目录) @@ -125,7 +142,7 @@ async def cleanup_old_files(directory: str, lifetime_seconds: int, interval_seco # 获取文件修改时间 try: file_mtime = os.path.getmtime(file_path) - if current_time - file_mtime > lifetime_seconds: + if current_time - file_mtime > lifetime: os.remove(file_path) count += 1 except OSError: @@ -142,7 +159,7 @@ async def cleanup_old_files(directory: str, lifetime_seconds: int, interval_seco pass if count > 0: - print(f"🧹 已清理 {count} 个过期文件") + print(f"🧹 已清理 {count} 个过期文件 (Lifetime: {lifetime}s)") except asyncio.CancelledError: print("🛑 清理任务已停止") @@ -185,16 +202,12 @@ async def lifespan(app: FastAPI): # --- 启动后台清理任务 --- cleanup_task_handle = None - # 优先读取环境变量,否则使用默认值 - if os.getenv("AUTO_CLEANUP_ENABLED", "False").lower() == "true": - try: - lifetime = int(os.getenv("FILE_LIFETIME_SECONDS", "3600")) - interval = int(os.getenv("CLEANUP_INTERVAL_SECONDS", "600")) - cleanup_task_handle = asyncio.create_task( - cleanup_old_files(RESULT_IMAGE_DIR, lifetime, interval) - ) - except Exception as e: - print(f"启动清理任务失败: {e}") + try: + cleanup_task_handle = asyncio.create_task( + cleanup_old_files(RESULT_IMAGE_DIR) + ) + except Exception as e: + print(f"启动清理任务失败: {e}") # ----------------------- yield @@ -1180,9 +1193,8 @@ async def trigger_cleanup(): Manually trigger cleanup """ try: - # Re-use logic from cleanup_old_files but force it for all files > 0 seconds if we want deep clean? - # Or just use the standard lifetime. Let's use standard lifetime but run it now. - lifetime = int(os.getenv("FILE_LIFETIME_SECONDS", "3600")) + # Use global config + lifetime = CLEANUP_CONFIG["lifetime"] count = 0 current_time = time.time() @@ -1207,7 +1219,7 @@ async def trigger_cleanup(): except: pass - return {"status": "success", "message": f"Cleaned {count} files"} + return {"status": "success", "message": f"Cleaned {count} files (Lifetime: {lifetime}s)"} except Exception as e: return {"status": "error", "message": str(e)} @@ -1222,13 +1234,28 @@ async def get_config(request: Request): return { "device": device, - "cleanup_enabled": os.getenv("AUTO_CLEANUP_ENABLED"), - "file_lifetime": os.getenv("FILE_LIFETIME_SECONDS"), - "cleanup_interval": os.getenv("CLEANUP_INTERVAL_SECONDS"), + "cleanup_config": CLEANUP_CONFIG, "current_qwen_model": QWEN_MODEL, "available_qwen_models": AVAILABLE_QWEN_MODELS } +@app.post("/admin/api/config/cleanup", dependencies=[Depends(verify_admin)]) +async def update_cleanup_config( + enabled: bool = Form(...), + lifetime: int = Form(...), + interval: int = Form(...) +): + """ + Update cleanup configuration + """ + global CLEANUP_CONFIG + CLEANUP_CONFIG["enabled"] = enabled + CLEANUP_CONFIG["lifetime"] = lifetime + CLEANUP_CONFIG["interval"] = interval + + print(f"Updated Cleanup Config: {CLEANUP_CONFIG}") + return {"status": "success", "message": "Cleanup configuration updated", "config": CLEANUP_CONFIG} + @app.post("/admin/api/config/model", dependencies=[Depends(verify_admin)]) async def set_model(model: str = Form(...)): """ diff --git a/static/admin.html b/static/admin.html index cfcca57..aec4f5a 100644 --- a/static/admin.html +++ b/static/admin.html @@ -171,24 +171,38 @@

自动清理配置

-
-
- 状态 - 运行中 +
+
+ 启用自动清理 +
-
- 文件保留时长 - 3600 秒 (1小时) + +
+
+ 文件保留时长 (秒) + {{ (cleanupConfig.lifetime / 3600).toFixed(1) }} 小时 +
+
-
- 检查间隔 - 600 秒 (10分钟) + +
+
+ 检查间隔 (秒) + {{ (cleanupConfig.interval / 60).toFixed(1) }} 分钟 +
+
-
- + -

将删除所有超过保留时长的文件

@@ -247,6 +261,11 @@ const deviceInfo = ref('Loading...'); const currentModel = ref(''); const availableModels = ref([]); + const cleanupConfig = ref({ + enabled: true, + lifetime: 3600, + interval: 600 + }); // 检查登录状态 const checkLogin = () => { @@ -306,11 +325,28 @@ deviceInfo.value = res.data.device; currentModel.value = res.data.current_qwen_model; availableModels.value = res.data.available_qwen_models; + if (res.data.cleanup_config) { + cleanupConfig.value = res.data.cleanup_config; + } } catch (e) { console.error(e); } }; + const saveCleanupConfig = async () => { + try { + const formData = new FormData(); + formData.append('enabled', cleanupConfig.value.enabled); + formData.append('lifetime', cleanupConfig.value.lifetime); + formData.append('interval', cleanupConfig.value.interval); + + const res = await axios.post('/admin/api/config/cleanup', formData); + alert(res.data.message); + } catch (e) { + alert('保存配置失败: ' + (e.response?.data?.detail || e.message)); + } + }; + const updateModel = async () => { try { const formData = new FormData(); @@ -415,7 +451,8 @@ enterDir, navigateUp, deleteFile, triggerCleanup, viewResult, previewImage, isImage, previewUrl, formatDate, getTypeBadgeClass, cleaning, deviceInfo, - currentModel, availableModels, updateModel + currentModel, availableModels, updateModel, + cleanupConfig, saveCleanupConfig }; } }).mount('#app');