"""FastAPI service exposing SAM3 text-prompted image segmentation.

Endpoints (both protected by an ``X-API-Key`` header):

* ``POST /segment``       — segment arbitrary objects matching a text prompt.
* ``POST /segment_tarot`` — detect exactly N tarot cards ("tarot card" prompt),
  crop each card out with a transparent background, and return per-card URLs.

Results are written under ``static/results`` and served via the mounted
``/static`` route.

NOTE(review): reconstructed from a whitespace-mangled patch; verify against
the deployed copy before shipping.
"""
import io
import os
import time
import uuid
from contextlib import asynccontextmanager
from typing import Optional

import numpy as np
import requests
import torch
import matplotlib
matplotlib.use('Agg')  # headless backend: figures are rendered to files only
import matplotlib.pyplot as plt

from fastapi import (
    Depends,
    FastAPI,
    File,
    Form,
    HTTPException,
    Request,
    UploadFile,
    status,
)
from fastapi.responses import JSONResponse
from fastapi.security import APIKeyHeader
from fastapi.staticfiles import StaticFiles
from PIL import Image

from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
from sam3.visualization_utils import plot_results

# ------------------- Paths -------------------
STATIC_DIR = "static"
RESULT_IMAGE_DIR = os.path.join(STATIC_DIR, "results")
os.makedirs(RESULT_IMAGE_DIR, exist_ok=True)

# ------------------- API-key configuration -------------------
# SECURITY: the key was previously hard-coded, printed to the startup log,
# and embedded in the public OpenAPI description. It is now overridable via
# the TAROT_API_KEY environment variable; the original literal remains the
# default for backward compatibility — rotate it in production.
VALID_API_KEY = os.environ.get("TAROT_API_KEY", "123quant-speed")
API_KEY_HEADER_NAME = "X-API-Key"

# Header scheme; auto_error=False so we can return our own 401/403 details.
api_key_header = APIKeyHeader(name=API_KEY_HEADER_NAME, auto_error=False)


async def verify_api_key(api_key: Optional[str] = Depends(api_key_header)) -> bool:
    """Dependency enforcing a valid API key.

    Raises:
        HTTPException: 401 when the header is missing, 403 when it is wrong.

    Returns:
        True when the key is valid (value unused by callers).
    """
    if not api_key:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing API Key. Please provide it in the header."
        )
    if api_key != VALID_API_KEY:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid API Key."
        )
    return True


# ------------------- Lifespan management -------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the SAM3 model once at startup and stash it on ``app.state``."""
    print("=" * 40)
    print("✅ API Key protection active")  # key value deliberately NOT logged
    print("=" * 40)

    print("Loading SAM3 model...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if not torch.cuda.is_available():
        print("Warning: no GPU detected; falling back to CPU (slow).")

    model = build_sam3_image_model()
    model = model.to(device)
    model.eval()

    processor = Sam3Processor(model)

    app.state.model = model
    app.state.processor = processor
    app.state.device = device

    print(f"Model loaded on device: {device}")

    yield

    print("Cleaning up resources...")


# ------------------- FastAPI initialization -------------------
app = FastAPI(
    lifespan=lifespan,
    title="SAM3 Segmentation API",
    # SECURITY: do not embed the actual key in the public docs.
    description="## 🔒 API-key protected\nClick **Authorize** (top right) and enter your key.",
)

# Manually patch the OpenAPI schema so the Authorize lock appears in /docs.
app.openapi_schema = None


def custom_openapi():
    """Build the OpenAPI schema once and cache it, adding the API-key scheme."""
    if app.openapi_schema:
        return app.openapi_schema
    from fastapi.openapi.utils import get_openapi
    openapi_schema = get_openapi(
        title=app.title,
        version=app.version,
        description=app.description,
        routes=app.routes,
    )
    # Declare the header-based API-key security scheme.
    openapi_schema["components"]["securitySchemes"] = {
        "APIKeyHeader": {
            "type": "apiKey",
            "in": "header",
            "name": API_KEY_HEADER_NAME,
        }
    }
    # Apply the security requirement to every path/method.
    for path in openapi_schema["paths"]:
        for method in openapi_schema["paths"][path]:
            openapi_schema["paths"][path][method]["security"] = [{"APIKeyHeader": []}]
    app.openapi_schema = openapi_schema
    return app.openapi_schema


app.openapi = custom_openapi

app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")


# ------------------- Helpers -------------------
def load_image_from_url(url: str) -> Image.Image:
    """Download *url* and return it as an RGB PIL image.

    Raises:
        HTTPException: 400 when the download or decode fails.
    """
    try:
        headers = {'User-Agent': 'Mozilla/5.0'}
        response = requests.get(url, headers=headers, timeout=10)
        response.raise_for_status()
        # BUGFIX: the original read ``response.raw``, which bypasses requests'
        # transparent content-decoding (gzip/deflate bodies would fail to
        # parse). ``response.content`` is the decoded payload.
        image = Image.open(io.BytesIO(response.content)).convert("RGB")
        return image
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"无法下载图片: {str(e)}")


def crop_and_save_objects(image: Image.Image, masks, boxes, output_dir: str = RESULT_IMAGE_DIR) -> list[str]:
    """Crop each detected object out of *image* with a transparent background.

    Args:
        image: source RGB image.
        masks: per-object masks (torch tensors or numpy arrays); assumed to be
            full-image H×W masks matching *image* — TODO confirm against
            ``Sam3Processor`` output.
        boxes: per-object ``(x1, y1, x2, y2)`` bounding boxes.
        output_dir: directory the PNG crops are written to.

    Returns:
        List of saved filenames (basenames, not full paths).
    """
    saved_files = []
    img_arr = np.array(image)  # RGB (H, W, 3)

    for i, (mask, box) in enumerate(zip(masks, boxes)):
        # Normalize tensor/numpy representations.
        if isinstance(mask, torch.Tensor):
            mask_np = mask.cpu().numpy().squeeze()
        else:
            mask_np = mask.squeeze()

        if isinstance(box, torch.Tensor):
            box_np = box.cpu().numpy()
        else:
            box_np = box

        x1, y1, x2, y2 = map(int, box_np)

        # Clamp the box to the image bounds.
        x1 = max(0, x1)
        y1 = max(0, y1)
        x2 = min(image.width, x2)
        y2 = min(image.height, y2)

        # Skip degenerate (empty) crops.
        if x2 <= x1 or y2 <= y1:
            continue

        # Alpha channel from the mask: bool or float 0..1 both scale to 0/255.
        alpha = (mask_np * 255).astype(np.uint8)

        # Stack RGB + alpha into an RGBA array.
        rgba = np.dstack((img_arr, alpha))

        pil_rgba = Image.fromarray(rgba)

        cropped = pil_rgba.crop((x1, y1, x2, y2))

        # PNG preserves the transparency.
        filename = f"tarot_{uuid.uuid4().hex}_{i}.png"
        save_path = os.path.join(output_dir, filename)
        cropped.save(save_path)
        saved_files.append(filename)

    return saved_files


def generate_and_save_result(image: Image.Image, inference_state, output_dir: str = RESULT_IMAGE_DIR) -> str:
    """Render the full visualization via ``plot_results`` and save it as JPEG.

    Returns:
        The saved filename (basename, not full path).
    """
    filename = f"seg_{uuid.uuid4().hex}.jpg"
    save_path = os.path.join(output_dir, filename)
    plot_results(image, inference_state)
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()  # release the matplotlib figure to avoid leaking memory
    return filename


# ------------------- API endpoints (API key enforced) -------------------
@app.post("/segment", dependencies=[Depends(verify_api_key)])
async def segment(
    request: Request,
    prompt: str = Form(...),
    file: Optional[UploadFile] = File(None),
    image_url: Optional[str] = Form(None)
):
    """Segment objects matching *prompt* in an uploaded image or URL."""
    if not file and not image_url:
        raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")

    try:
        if file:
            image = Image.open(file.file).convert("RGB")
        elif image_url:
            image = load_image_from_url(image_url)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")

    processor = request.app.state.processor

    try:
        inference_state = processor.set_image(image)
        output = processor.set_text_prompt(state=inference_state, prompt=prompt)
        masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")

    try:
        filename = generate_and_save_result(image, inference_state)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"绘图保存错误: {str(e)}")

    # BUGFIX: the path previously contained a literal placeholder instead of
    # the generated filename, so the returned URL could never resolve.
    file_url = request.url_for("static", path=f"results/{filename}")

    return JSONResponse(content={
        "status": "success",
        "result_image_url": str(file_url),
        "detected_count": len(masks),
        "scores": scores.tolist() if torch.is_tensor(scores) else scores
    })


@app.post("/segment_tarot", dependencies=[Depends(verify_api_key)])
async def segment_tarot(
    request: Request,
    file: Optional[UploadFile] = File(None),
    image_url: Optional[str] = Form(None),
    expected_count: int = Form(3)
):
    """Tarot-card specific segmentation.

    1. Detect whether the image contains exactly *expected_count* tarot cards
       (default 3).
    2. If so, crop each card out individually and return their URLs.
    """
    if not file and not image_url:
        raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")

    try:
        if file:
            image = Image.open(file.file).convert("RGB")
        elif image_url:
            image = load_image_from_url(image_url)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")

    processor = request.app.state.processor

    try:
        inference_state = processor.set_image(image)
        # Fixed prompt: this endpoint only looks for tarot cards.
        output = processor.set_text_prompt(state=inference_state, prompt="tarot card")
        masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")

    # Core gate: the detected count must match exactly.
    detected_count = len(masks)

    # Per-request output folder: <timestamp>_<uuid prefix>.
    request_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
    output_dir = os.path.join(RESULT_IMAGE_DIR, request_id)
    os.makedirs(output_dir, exist_ok=True)

    if detected_count != expected_count:
        # Save one visualization for debugging/feedback; best-effort only.
        try:
            filename = generate_and_save_result(image, inference_state, output_dir=output_dir)
            # BUGFIX: path previously used a literal placeholder, not the
            # generated filename.
            file_url = request.url_for("static", path=f"results/{request_id}/{filename}")
        except Exception:  # narrowed from bare except: don't swallow SystemExit/KeyboardInterrupt
            file_url = None

        return JSONResponse(
            status_code=400,
            content={
                "status": "failed",
                "message": f"检测到 {detected_count} 个目标,需要严格的 {expected_count} 张塔罗牌。请调整拍摄角度或背景。",
                "detected_count": detected_count,
                "debug_image_url": str(file_url) if file_url else None
            }
        )

    # Count is correct: cut out each card.
    try:
        filenames = crop_and_save_objects(image, masks, boxes, output_dir=output_dir)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"抠图处理错误: {str(e)}")

    card_urls = [str(request.url_for("static", path=f"results/{request_id}/{fname}")) for fname in filenames]

    # Full visualization of all detections; best-effort only.
    try:
        main_filename = generate_and_save_result(image, inference_state, output_dir=output_dir)
        main_file_url = str(request.url_for("static", path=f"results/{request_id}/{main_filename}"))
    except Exception:  # narrowed from bare except
        main_file_url = None

    return JSONResponse(content={
        "status": "success",
        "message": f"成功识别并分割 {expected_count} 张塔罗牌",
        "tarot_cards": card_urls,
        "full_visualization": main_file_url,
        "scores": scores.tolist() if torch.is_tensor(scores) else scores
    })


if __name__ == "__main__":
    import uvicorn
    # NOTE: if this file is renamed, update the app-string argument below.
    uvicorn.run(
        "fastAPI_tarot:app",
        host="127.0.0.1",
        port=55600,
        proxy_headers=True,
        forwarded_allow_ips="*",
        reload=False  # keep reload off in production
    )