diff --git a/fastAPI_nocom.py b/fastAPI_nocom.py index 7d7d3ac..515105f 100644 --- a/fastAPI_nocom.py +++ b/fastAPI_nocom.py @@ -6,16 +6,15 @@ from contextlib import asynccontextmanager import torch import matplotlib -# 关键:设置非交互式后端,避免服务器环境下报错 matplotlib.use('Agg') import matplotlib.pyplot as plt -from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request +from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request, Depends, status +from fastapi.security import APIKeyHeader from fastapi.staticfiles import StaticFiles from fastapi.responses import JSONResponse from PIL import Image -# SAM3 相关导入 (请确保你的环境中已正确安装 sam3) from sam3.model_builder import build_sam3_image_model from sam3.model.sam3_image_processor import Sam3Processor from sam3.visualization_utils import plot_results @@ -25,48 +24,101 @@ STATIC_DIR = "static" RESULT_IMAGE_DIR = os.path.join(STATIC_DIR, "results") os.makedirs(RESULT_IMAGE_DIR, exist_ok=True) +# ------------------- API Key 核心配置 (已加固) ------------------- +VALID_API_KEY = "123quant-speed" +API_KEY_HEADER_NAME = "X-API-Key" + +# 定义 Header 认证 +api_key_header = APIKeyHeader(name=API_KEY_HEADER_NAME, auto_error=False) + +async def verify_api_key(api_key: Optional[str] = Depends(api_key_header)): + """ + 强制验证 API Key + """ + # 1. 检查是否有 Key + if not api_key: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Missing API Key. Please provide it in the header." + ) + # 2. 检查 Key 是否正确 + if api_key != VALID_API_KEY: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Invalid API Key." + ) + # 3. 验证通过 + return True + # ------------------- 生命周期管理 ------------------- @asynccontextmanager async def lifespan(app: FastAPI): - """ - FastAPI 生命周期管理器:在服务启动时加载模型,关闭时清理资源 - """ - print("正在加载 SAM3 模型到 GPU...") + print("="*40) + print("✅ API Key 保护已激活") + print(f"✅ 有效 Key: {VALID_API_KEY}") + print("="*40) - # 1. 检测设备 + print("正在加载 SAM3 模型到 GPU...") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if not torch.cuda.is_available(): print("警告: 未检测到 GPU,将使用 CPU,速度会较慢。") - # 2. 加载模型 (全局单例) model = build_sam3_image_model() model = model.to(device) - model.eval() # 切换到评估模式 + model.eval() - # 3. 初始化 Processor processor = Sam3Processor(model) - # 4. 存入 app.state 供全局访问 app.state.model = model app.state.processor = processor app.state.device = device print(f"模型加载完成,设备: {device}") - yield # 服务运行中... + yield - # 清理资源 (如果需要) print("正在清理资源...") # ------------------- FastAPI 初始化 ------------------- -app = FastAPI(lifespan=lifespan, title="SAM3 Segmentation API") +app = FastAPI( + lifespan=lifespan, + title="SAM3 Segmentation API", + description="## 🔒 受 API Key 保护\n请点击右上角 **Authorize** 并输入: `123quant-speed`", +) + +# 手动添加 OpenAPI 安全配置,让 Docs 里的锁头生效 +app.openapi_schema = None +def custom_openapi(): + if app.openapi_schema: + return app.openapi_schema + from fastapi.openapi.utils import get_openapi + openapi_schema = get_openapi( + title=app.title, + version=app.version, + description=app.description, + routes=app.routes, + ) + # 定义安全方案 + openapi_schema["components"]["securitySchemes"] = { + "APIKeyHeader": { + "type": "apiKey", + "in": "header", + "name": API_KEY_HEADER_NAME, + } + } + # 为所有路径应用安全要求 + for path in openapi_schema["paths"]: + for method in openapi_schema["paths"][path]: + openapi_schema["paths"][path][method]["security"] = [{"APIKeyHeader": []}] + app.openapi_schema = openapi_schema + return app.openapi_schema + +app.openapi = custom_openapi -# 挂载静态文件目录,用于通过 URL 访问生成的图片 app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static") # ------------------- 辅助函数 ------------------- def load_image_from_url(url: str) -> Image.Image: - """从网络 URL 下载图片""" try: headers = {'User-Agent': 'Mozilla/5.0'} response = requests.get(url, headers=headers, stream=True, timeout=10) @@ -77,37 +129,24 @@ def load_image_from_url(url: str) -> Image.Image: raise HTTPException(status_code=400, detail=f"无法下载图片: {str(e)}") def generate_and_save_result(image: Image.Image, inference_state) -> str: - """生成可视化结果图并保存,返回文件名""" - # 生成唯一文件名防止冲突 filename = f"seg_{uuid.uuid4().hex}.jpg" save_path = os.path.join(RESULT_IMAGE_DIR, filename) - - # 绘图 (复用你提供的逻辑) plot_results(image, inference_state) - - # 保存 plt.savefig(save_path, dpi=150, bbox_inches='tight') - plt.close() # 务必关闭,防止内存泄漏 - + plt.close() return filename -# ------------------- API 接口 ------------------- -@app.post("/segment") +# ------------------- API 接口 (强制依赖验证) ------------------- +@app.post("/segment", dependencies=[Depends(verify_api_key)]) async def segment( request: Request, prompt: str = Form(...), file: Optional[UploadFile] = File(None), image_url: Optional[str] = Form(None) ): - """ - 接收图片 (文件上传 或 URL) 和 文本提示词,返回分割后的图片 URL。 - """ - - # 1. 校验输入 if not file and not image_url: raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)") - # 2. 获取图片对象 try: if file: image = Image.open(file.file).convert("RGB") @@ -116,27 +155,20 @@ async def segment( except Exception as e: raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}") - # 3. 获取模型 processor = request.app.state.processor - # 4. 执行推理 try: - # 这一步内部应该已经由 Sam3Processor 处理了 GPU 张量转移 inference_state = processor.set_image(image) output = processor.set_text_prompt(state=inference_state, prompt=prompt) - masks, boxes, scores = output["masks"], output["boxes"], output["scores"] except Exception as e: raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}") - # 5. 生成可视化并保存 try: filename = generate_and_save_result(image, inference_state) except Exception as e: raise HTTPException(status_code=500, detail=f"绘图保存错误: {str(e)}") - # 6. 构建返回 URL - # request.url_for 会自动根据当前域名生成正确的访问链接 file_url = request.url_for("static", path=f"results/{filename}") return JSONResponse(content={ @@ -148,12 +180,12 @@ async def segment( if __name__ == "__main__": import uvicorn - - # 使用 Python 函数参数的方式传递配置 + # 注意:如果你的文件名不是 fastAPI_nocom.py,请修改下面第一个参数 uvicorn.run( - "fastAPI_main:app", # 注意:这里要改成你的文件名:app对象名 + "fastAPI_nocom:app", host="127.0.0.1", port=55600, - proxy_headers=True, # 对应 --proxy-headers - forwarded_allow_ips="*" # 对应 --forwarded-allow-ips="*" - ) + proxy_headers=True, + forwarded_allow_ips="*", + reload=False # 生产环境建议关闭 reload,确保代码完全重载 + ) \ No newline at end of file diff --git a/static/results/seg_7e8d2ca9238e4b5dbf1eb81f3342d456.jpg b/static/results/seg_7e8d2ca9238e4b5dbf1eb81f3342d456.jpg new file mode 100644 index 0000000..ae8e17e Binary files /dev/null and b/static/results/seg_7e8d2ca9238e4b5dbf1eb81f3342d456.jpg differ diff --git a/static/results/seg_c43586b1fc1e4517b9eb6e505024ee0c.jpg b/static/results/seg_c43586b1fc1e4517b9eb6e505024ee0c.jpg new file mode 100644 index 0000000..7ce7af3 Binary files /dev/null and b/static/results/seg_c43586b1fc1e4517b9eb6e505024ee0c.jpg differ diff --git a/static/results/seg_e81c526975cb4c838b1c8e04a0b8ba22.jpg b/static/results/seg_e81c526975cb4c838b1c8e04a0b8ba22.jpg new file mode 100644 index 0000000..ae8e17e Binary files /dev/null and b/static/results/seg_e81c526975cb4c838b1c8e04a0b8ba22.jpg differ diff --git a/static/results/seg_ed9900e5bd014662a64c56868b8cd74a.jpg b/static/results/seg_ed9900e5bd014662a64c56868b8cd74a.jpg new file mode 100644 index 0000000..ae8e17e Binary files /dev/null and b/static/results/seg_ed9900e5bd014662a64c56868b8cd74a.jpg differ