"""FastAPI service exposing SAM3 text-prompted segmentation.

Endpoints (both require the ``X-API-Key`` header):

* ``POST /segment``        - segment an image with a caller-supplied prompt and
  return the URL of a rendered visualization.
* ``POST /segment_tarot``  - detect tarot cards with a fixed prompt, check the
  detected count against ``expected_count`` and return one transparent PNG
  cut-out per card.

Result images are written below ``static/results`` and served via ``/static``.
"""

import io
import logging
import os
import secrets
import time
import uuid
from contextlib import asynccontextmanager
from typing import Optional

import numpy as np
import requests
import torch
import matplotlib
matplotlib.use('Agg')  # headless backend; must be selected before importing pyplot
import matplotlib.pyplot as plt

from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request, Depends, status
from fastapi.openapi.utils import get_openapi
from fastapi.security import APIKeyHeader
from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse
from PIL import Image

from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
from sam3.visualization_utils import plot_results

logger = logging.getLogger(__name__)

# ------------------- Configuration and paths -------------------
STATIC_DIR = "static"
RESULT_IMAGE_DIR = os.path.join(STATIC_DIR, "results")
os.makedirs(RESULT_IMAGE_DIR, exist_ok=True)

# ------------------- API key configuration -------------------
# NOTE(review): the key is hard-coded here, printed at startup and shown in the
# OpenAPI description below. Consider loading it from the environment and
# removing it from logs/docs before exposing this service publicly.
VALID_API_KEY = "123quant-speed"
API_KEY_HEADER_NAME = "X-API-Key"

# Header-based authentication. auto_error=False so that verify_api_key can
# produce its own 401/403 responses instead of FastAPI's default one.
api_key_header = APIKeyHeader(name=API_KEY_HEADER_NAME, auto_error=False)


async def verify_api_key(api_key: Optional[str] = Depends(api_key_header)):
    """Dependency that enforces the API key.

    Raises:
        HTTPException: 401 when the header is missing or empty, 403 when the
            supplied key does not match ``VALID_API_KEY``.

    Returns:
        ``True`` when the key is valid.
    """
    # 1. A key must be present.
    if not api_key:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing API Key. Please provide it in the header."
        )
    # 2. The key must match. Compare as bytes in constant time: compare_digest
    #    rejects non-ASCII str arguments, and a plain != leaks timing.
    if not secrets.compare_digest(api_key.encode("utf-8"), VALID_API_KEY.encode("utf-8")):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid API Key."
        )
    # 3. Verification passed.
    return True


# ------------------- Lifespan management -------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the SAM3 model once at startup and publish it on ``app.state``.

    Sets ``app.state.model``, ``app.state.processor`` and ``app.state.device``.
    """
    print("="*40)
    print("✅ API Key 保护已激活")
    print(f"✅ 有效 Key: {VALID_API_KEY}")
    print("="*40)

    print("正在加载 SAM3 模型到 GPU...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if not torch.cuda.is_available():
        print("警告: 未检测到 GPU,将使用 CPU,速度会较慢。")

    model = build_sam3_image_model()
    model = model.to(device)
    model.eval()

    processor = Sam3Processor(model)

    app.state.model = model
    app.state.processor = processor
    app.state.device = device

    print(f"模型加载完成,设备: {device}")

    yield

    print("正在清理资源...")


# ------------------- FastAPI initialization -------------------
app = FastAPI(
    lifespan=lifespan,
    title="SAM3 Segmentation API",
    description="## 🔒 受 API Key 保护\n请点击右上角 **Authorize** 并输入: `123quant-speed`",
)

# Manually add the OpenAPI security configuration so that the lock icon and
# the "Authorize" button work in the interactive docs.
app.openapi_schema = None


def custom_openapi():
    """Return the OpenAPI schema with the API-key security scheme attached.

    The schema is generated once and cached on ``app.openapi_schema``.
    """
    if app.openapi_schema:
        return app.openapi_schema

    openapi_schema = get_openapi(
        title=app.title,
        version=app.version,
        description=app.description,
        routes=app.routes,
    )
    # Declare the security scheme. "components" is created on demand so this
    # does not raise KeyError when the generated schema has no components.
    openapi_schema.setdefault("components", {})["securitySchemes"] = {
        "APIKeyHeader": {
            "type": "apiKey",
            "in": "header",
            "name": API_KEY_HEADER_NAME,
        }
    }
    # Apply the security requirement to every operation of every path.
    for path_item in openapi_schema.get("paths", {}).values():
        for operation in path_item.values():
            if isinstance(operation, dict):
                operation["security"] = [{"APIKeyHeader": []}]

    app.openapi_schema = openapi_schema
    return app.openapi_schema


app.openapi = custom_openapi

app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")


# ------------------- Helper functions -------------------
def load_image_from_url(url: str) -> Image.Image:
    """Download ``url`` and return it as an RGB PIL image.

    Raises:
        HTTPException: 400 when the download or the image decoding fails.
    """
    # NOTE(review): ``url`` is caller-controlled and fetched server-side, so
    # this can reach internal hosts (SSRF). Consider restricting schemes/hosts.
    # NOTE(review): this is a blocking call made from async endpoints.
    try:
        headers = {'User-Agent': 'Mozilla/5.0'}
        # Read the decoded body (response.content) rather than response.raw:
        # the raw stream skips Content-Encoding decoding, and the context
        # manager guarantees the connection is released.
        with requests.get(url, headers=headers, timeout=10) as response:
            response.raise_for_status()
            payload = response.content
        return Image.open(io.BytesIO(payload)).convert("RGB")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"无法下载图片: {str(e)}") from e


def _to_numpy(value):
    """Convert a torch tensor (on any device) or array-like to a numpy array."""
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy()
    return np.asarray(value)


def crop_and_save_objects(image: Image.Image, masks, boxes, output_dir: str = RESULT_IMAGE_DIR) -> list[str]:
    """Cut each detected object out of ``image`` and save it as an RGBA PNG.

    The mask becomes the alpha channel (so the background is transparent) and
    the result is cropped to the object's bounding box.

    Args:
        image: Source RGB image.
        masks: Iterable of per-object masks (torch tensors or numpy arrays).
            Assumed to squeeze to the image's (H, W) with values that are
            boolean or in 0..1 - TODO confirm against the SAM3 processor.
        boxes: Iterable of ``x1, y1, x2, y2`` boxes in pixel coordinates,
            paired with ``masks`` by position.
        output_dir: Directory the PNG files are written to.

    Returns:
        File names (not full paths) of the saved crops. Objects whose clamped
        box is empty are skipped.

    Raises:
        ValueError: If a mask does not match the image's height and width.
    """
    saved_files = []
    img_arr = np.array(image)  # RGB, (H, W, 3)
    height, width = img_arr.shape[:2]

    for i, (mask, box) in enumerate(zip(masks, boxes)):
        mask_np = _to_numpy(mask).squeeze()
        x1, y1, x2, y2 = map(int, _to_numpy(box).reshape(-1))

        # Clamp the box to the image bounds.
        x1 = max(0, x1)
        y1 = max(0, y1)
        x2 = min(width, x2)
        y2 = min(height, y2)

        # Skip degenerate boxes.
        if x2 <= x1 or y2 <= y1:
            continue

        if mask_np.shape != (height, width):
            raise ValueError(
                f"mask shape {mask_np.shape} does not match image size {(height, width)}"
            )

        # Slice first, then build RGBA: same pixels as cropping a full-frame
        # RGBA image, without allocating a full-size copy per object.
        rgb_crop = img_arr[y1:y2, x1:x2]
        # Boolean or 0..1 mask scaled to a 0/255 alpha channel.
        alpha_crop = (mask_np[y1:y2, x1:x2] * 255).astype(np.uint8)
        cropped = Image.fromarray(np.dstack((rgb_crop, alpha_crop)))

        filename = f"tarot_{uuid.uuid4().hex}_{i}.png"  # PNG keeps transparency
        cropped.save(os.path.join(output_dir, filename))
        saved_files.append(filename)

    return saved_files


def generate_and_save_result(image: Image.Image, inference_state, output_dir: str = RESULT_IMAGE_DIR) -> str:
    """Render the segmentation overlay and save it as a JPEG.

    Returns:
        The generated file name (not the full path) inside ``output_dir``.
    """
    filename = f"seg_{uuid.uuid4().hex}.jpg"
    save_path = os.path.join(output_dir, filename)
    try:
        plot_results(image, inference_state)
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    finally:
        # Always release the current figure, even when plotting/saving fails,
        # so figures do not accumulate in pyplot's global state.
        plt.close()
    return filename


def _load_request_image(file: Optional[UploadFile], image_url: Optional[str]) -> Image.Image:
    """Load the request image from the upload if given, otherwise from the URL.

    Raises:
        HTTPException: 400 when neither source is provided or decoding fails.
    """
    if not file and not image_url:
        raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")

    try:
        if file:
            return Image.open(file.file).convert("RGB")
        return load_image_from_url(image_url)
    except HTTPException:
        # Already a well-formed client error (e.g. download failure); do not
        # wrap it in a second message.
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}") from e


def _run_inference(processor, image: Image.Image, prompt: str):
    """Run the processor on ``image`` with a text ``prompt``.

    Returns:
        ``(inference_state, masks, boxes, scores)`` taken from the processor
        output.

    Raises:
        HTTPException: 500 when the processor raises or a key is missing.
    """
    # NOTE(review): inference is synchronous and runs inside async endpoints.
    try:
        inference_state = processor.set_image(image)
        output = processor.set_text_prompt(state=inference_state, prompt=prompt)
        return inference_state, output["masks"], output["boxes"], output["scores"]
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}") from e


def _scores_to_json(scores):
    """Return ``scores`` as a plain list when it is a tensor, else unchanged."""
    return scores.tolist() if torch.is_tensor(scores) else scores


# ------------------- API endpoints (API key enforced) -------------------
@app.post("/segment", dependencies=[Depends(verify_api_key)])
async def segment(
    request: Request,
    prompt: str = Form(...),
    file: Optional[UploadFile] = File(None),
    image_url: Optional[str] = Form(None)
):
    """Segment an image with a free-text prompt.

    Accepts either an uploaded ``file`` or an ``image_url`` (the file wins when
    both are given). Responds with the visualization URL, the number of
    detected masks and their scores.
    """
    image = _load_request_image(file, image_url)

    processor = request.app.state.processor
    inference_state, masks, _boxes, scores = _run_inference(processor, image, prompt)

    try:
        filename = generate_and_save_result(image, inference_state)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"绘图保存错误: {str(e)}") from e

    file_url = request.url_for("static", path=f"results/{filename}")

    return JSONResponse(content={
        "status": "success",
        "result_image_url": str(file_url),
        "detected_count": len(masks),
        "scores": _scores_to_json(scores)
    })


@app.post("/segment_tarot", dependencies=[Depends(verify_api_key)])
async def segment_tarot(
    request: Request,
    file: Optional[UploadFile] = File(None),
    image_url: Optional[str] = Form(None),
    expected_count: int = Form(3)
):
    """Tarot-card segmentation endpoint.

    1. Detect cards with the fixed prompt ``"tarot card"`` and check that
       exactly ``expected_count`` (default 3) were found.
    2. If so, cut each card out and return the URLs of the cut-outs plus an
       overall visualization; otherwise respond 400 with a debug image URL.
    """
    image = _load_request_image(file, image_url)

    processor = request.app.state.processor
    # Fixed prompt: this endpoint only detects tarot cards.
    inference_state, masks, boxes, scores = _run_inference(processor, image, "tarot card")

    # Core check: number of detections.
    detected_count = len(masks)

    # Per-request output folder named "<unix timestamp>_<first 8 hex of a UUID>".
    request_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
    output_dir = os.path.join(RESULT_IMAGE_DIR, request_id)
    os.makedirs(output_dir, exist_ok=True)

    if detected_count != expected_count:
        # Best effort: save a visualization for debugging/feedback. A failure
        # here must not mask the 400 response, but it is logged.
        try:
            filename = generate_and_save_result(image, inference_state, output_dir=output_dir)
            file_url = request.url_for("static", path=f"results/{request_id}/{filename}")
        except Exception:
            logger.exception("Failed to render the debug image for request %s", request_id)
            file_url = None

        return JSONResponse(
            status_code=400,
            content={
                "status": "failed",
                "message": f"检测到 {detected_count} 个目标,需要严格的 {expected_count} 张塔罗牌。请调整拍摄角度或背景。",
                "detected_count": detected_count,
                "debug_image_url": str(file_url) if file_url else None
            }
        )

    # Count matches: cut the cards out.
    try:
        filenames = crop_and_save_objects(image, masks, boxes, output_dir=output_dir)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"抠图处理错误: {str(e)}") from e

    # Build the list of cut-out URLs.
    card_urls = [
        str(request.url_for("static", path=f"results/{request_id}/{fname}"))
        for fname in filenames
    ]

    # Best effort: overall visualization; the cut-outs are still returned if
    # rendering fails.
    try:
        main_filename = generate_and_save_result(image, inference_state, output_dir=output_dir)
        main_file_url = str(request.url_for("static", path=f"results/{request_id}/{main_filename}"))
    except Exception:
        logger.exception("Failed to render the overview image for request %s", request_id)
        main_file_url = None

    return JSONResponse(content={
        "status": "success",
        "message": f"成功识别并分割 {expected_count} 张塔罗牌",
        "tarot_cards": card_urls,
        "full_visualization": main_file_url,
        "scores": _scores_to_json(scores)
    })


if __name__ == "__main__":
    import uvicorn

    # NOTE: if this file is not named fastAPI_tarot.py, change the first
    # argument below to match the module name.
    uvicorn.run(
        "fastAPI_tarot:app",
        host="127.0.0.1",
        port=55600,
        proxy_headers=True,
        forwarded_allow_ips="*",
        reload=False  # keep auto-reload off in production
    )