APIkey
This commit is contained in:
122
fastAPI_nocom.py
122
fastAPI_nocom.py
@@ -6,16 +6,15 @@ from contextlib import asynccontextmanager
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import matplotlib
|
import matplotlib
|
||||||
# 关键:设置非交互式后端,避免服务器环境下报错
|
|
||||||
matplotlib.use('Agg')
|
matplotlib.use('Agg')
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request
|
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request, Depends, status
|
||||||
|
from fastapi.security import APIKeyHeader
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
# SAM3 相关导入 (请确保你的环境中已正确安装 sam3)
|
|
||||||
from sam3.model_builder import build_sam3_image_model
|
from sam3.model_builder import build_sam3_image_model
|
||||||
from sam3.model.sam3_image_processor import Sam3Processor
|
from sam3.model.sam3_image_processor import Sam3Processor
|
||||||
from sam3.visualization_utils import plot_results
|
from sam3.visualization_utils import plot_results
|
||||||
@@ -25,48 +24,101 @@ STATIC_DIR = "static"
|
|||||||
RESULT_IMAGE_DIR = os.path.join(STATIC_DIR, "results")
|
RESULT_IMAGE_DIR = os.path.join(STATIC_DIR, "results")
|
||||||
os.makedirs(RESULT_IMAGE_DIR, exist_ok=True)
|
os.makedirs(RESULT_IMAGE_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
# ------------------- API Key 核心配置 (已加固) -------------------
|
||||||
|
VALID_API_KEY = "123quant-speed"
|
||||||
|
API_KEY_HEADER_NAME = "X-API-Key"
|
||||||
|
|
||||||
|
# 定义 Header 认证
|
||||||
|
api_key_header = APIKeyHeader(name=API_KEY_HEADER_NAME, auto_error=False)
|
||||||
|
|
||||||
|
async def verify_api_key(api_key: Optional[str] = Depends(api_key_header)):
|
||||||
|
"""
|
||||||
|
强制验证 API Key
|
||||||
|
"""
|
||||||
|
# 1. 检查是否有 Key
|
||||||
|
if not api_key:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Missing API Key. Please provide it in the header."
|
||||||
|
)
|
||||||
|
# 2. 检查 Key 是否正确
|
||||||
|
if api_key != VALID_API_KEY:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Invalid API Key."
|
||||||
|
)
|
||||||
|
# 3. 验证通过
|
||||||
|
return True
|
||||||
|
|
||||||
# ------------------- 生命周期管理 -------------------
|
# ------------------- 生命周期管理 -------------------
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
"""
|
print("="*40)
|
||||||
FastAPI 生命周期管理器:在服务启动时加载模型,关闭时清理资源
|
print("✅ API Key 保护已激活")
|
||||||
"""
|
print(f"✅ 有效 Key: {VALID_API_KEY}")
|
||||||
print("正在加载 SAM3 模型到 GPU...")
|
print("="*40)
|
||||||
|
|
||||||
# 1. 检测设备
|
print("正在加载 SAM3 模型到 GPU...")
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
if not torch.cuda.is_available():
|
if not torch.cuda.is_available():
|
||||||
print("警告: 未检测到 GPU,将使用 CPU,速度会较慢。")
|
print("警告: 未检测到 GPU,将使用 CPU,速度会较慢。")
|
||||||
|
|
||||||
# 2. 加载模型 (全局单例)
|
|
||||||
model = build_sam3_image_model()
|
model = build_sam3_image_model()
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
model.eval() # 切换到评估模式
|
model.eval()
|
||||||
|
|
||||||
# 3. 初始化 Processor
|
|
||||||
processor = Sam3Processor(model)
|
processor = Sam3Processor(model)
|
||||||
|
|
||||||
# 4. 存入 app.state 供全局访问
|
|
||||||
app.state.model = model
|
app.state.model = model
|
||||||
app.state.processor = processor
|
app.state.processor = processor
|
||||||
app.state.device = device
|
app.state.device = device
|
||||||
|
|
||||||
print(f"模型加载完成,设备: {device}")
|
print(f"模型加载完成,设备: {device}")
|
||||||
|
|
||||||
yield # 服务运行中...
|
yield
|
||||||
|
|
||||||
# 清理资源 (如果需要)
|
|
||||||
print("正在清理资源...")
|
print("正在清理资源...")
|
||||||
|
|
||||||
# ------------------- FastAPI 初始化 -------------------
|
# ------------------- FastAPI 初始化 -------------------
|
||||||
app = FastAPI(lifespan=lifespan, title="SAM3 Segmentation API")
|
app = FastAPI(
|
||||||
|
lifespan=lifespan,
|
||||||
|
title="SAM3 Segmentation API",
|
||||||
|
description="## 🔒 受 API Key 保护\n请点击右上角 **Authorize** 并输入: `123quant-speed`",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 手动添加 OpenAPI 安全配置,让 Docs 里的锁头生效
|
||||||
|
app.openapi_schema = None
|
||||||
|
def custom_openapi():
|
||||||
|
if app.openapi_schema:
|
||||||
|
return app.openapi_schema
|
||||||
|
from fastapi.openapi.utils import get_openapi
|
||||||
|
openapi_schema = get_openapi(
|
||||||
|
title=app.title,
|
||||||
|
version=app.version,
|
||||||
|
description=app.description,
|
||||||
|
routes=app.routes,
|
||||||
|
)
|
||||||
|
# 定义安全方案
|
||||||
|
openapi_schema["components"]["securitySchemes"] = {
|
||||||
|
"APIKeyHeader": {
|
||||||
|
"type": "apiKey",
|
||||||
|
"in": "header",
|
||||||
|
"name": API_KEY_HEADER_NAME,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
# 为所有路径应用安全要求
|
||||||
|
for path in openapi_schema["paths"]:
|
||||||
|
for method in openapi_schema["paths"][path]:
|
||||||
|
openapi_schema["paths"][path][method]["security"] = [{"APIKeyHeader": []}]
|
||||||
|
app.openapi_schema = openapi_schema
|
||||||
|
return app.openapi_schema
|
||||||
|
|
||||||
|
app.openapi = custom_openapi
|
||||||
|
|
||||||
# 挂载静态文件目录,用于通过 URL 访问生成的图片
|
|
||||||
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
|
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
|
||||||
|
|
||||||
# ------------------- 辅助函数 -------------------
|
# ------------------- 辅助函数 -------------------
|
||||||
def load_image_from_url(url: str) -> Image.Image:
|
def load_image_from_url(url: str) -> Image.Image:
|
||||||
"""从网络 URL 下载图片"""
|
|
||||||
try:
|
try:
|
||||||
headers = {'User-Agent': 'Mozilla/5.0'}
|
headers = {'User-Agent': 'Mozilla/5.0'}
|
||||||
response = requests.get(url, headers=headers, stream=True, timeout=10)
|
response = requests.get(url, headers=headers, stream=True, timeout=10)
|
||||||
@@ -77,37 +129,24 @@ def load_image_from_url(url: str) -> Image.Image:
|
|||||||
raise HTTPException(status_code=400, detail=f"无法下载图片: {str(e)}")
|
raise HTTPException(status_code=400, detail=f"无法下载图片: {str(e)}")
|
||||||
|
|
||||||
def generate_and_save_result(image: Image.Image, inference_state) -> str:
|
def generate_and_save_result(image: Image.Image, inference_state) -> str:
|
||||||
"""生成可视化结果图并保存,返回文件名"""
|
|
||||||
# 生成唯一文件名防止冲突
|
|
||||||
filename = f"seg_{uuid.uuid4().hex}.jpg"
|
filename = f"seg_{uuid.uuid4().hex}.jpg"
|
||||||
save_path = os.path.join(RESULT_IMAGE_DIR, filename)
|
save_path = os.path.join(RESULT_IMAGE_DIR, filename)
|
||||||
|
|
||||||
# 绘图 (复用你提供的逻辑)
|
|
||||||
plot_results(image, inference_state)
|
plot_results(image, inference_state)
|
||||||
|
|
||||||
# 保存
|
|
||||||
plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
||||||
plt.close() # 务必关闭,防止内存泄漏
|
plt.close()
|
||||||
|
|
||||||
return filename
|
return filename
|
||||||
|
|
||||||
# ------------------- API 接口 -------------------
|
# ------------------- API 接口 (强制依赖验证) -------------------
|
||||||
@app.post("/segment")
|
@app.post("/segment", dependencies=[Depends(verify_api_key)])
|
||||||
async def segment(
|
async def segment(
|
||||||
request: Request,
|
request: Request,
|
||||||
prompt: str = Form(...),
|
prompt: str = Form(...),
|
||||||
file: Optional[UploadFile] = File(None),
|
file: Optional[UploadFile] = File(None),
|
||||||
image_url: Optional[str] = Form(None)
|
image_url: Optional[str] = Form(None)
|
||||||
):
|
):
|
||||||
"""
|
|
||||||
接收图片 (文件上传 或 URL) 和 文本提示词,返回分割后的图片 URL。
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 1. 校验输入
|
|
||||||
if not file and not image_url:
|
if not file and not image_url:
|
||||||
raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")
|
raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")
|
||||||
|
|
||||||
# 2. 获取图片对象
|
|
||||||
try:
|
try:
|
||||||
if file:
|
if file:
|
||||||
image = Image.open(file.file).convert("RGB")
|
image = Image.open(file.file).convert("RGB")
|
||||||
@@ -116,27 +155,20 @@ async def segment(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")
|
raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")
|
||||||
|
|
||||||
# 3. 获取模型
|
|
||||||
processor = request.app.state.processor
|
processor = request.app.state.processor
|
||||||
|
|
||||||
# 4. 执行推理
|
|
||||||
try:
|
try:
|
||||||
# 这一步内部应该已经由 Sam3Processor 处理了 GPU 张量转移
|
|
||||||
inference_state = processor.set_image(image)
|
inference_state = processor.set_image(image)
|
||||||
output = processor.set_text_prompt(state=inference_state, prompt=prompt)
|
output = processor.set_text_prompt(state=inference_state, prompt=prompt)
|
||||||
|
|
||||||
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
|
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")
|
||||||
|
|
||||||
# 5. 生成可视化并保存
|
|
||||||
try:
|
try:
|
||||||
filename = generate_and_save_result(image, inference_state)
|
filename = generate_and_save_result(image, inference_state)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=f"绘图保存错误: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"绘图保存错误: {str(e)}")
|
||||||
|
|
||||||
# 6. 构建返回 URL
|
|
||||||
# request.url_for 会自动根据当前域名生成正确的访问链接
|
|
||||||
file_url = request.url_for("static", path=f"results/{filename}")
|
file_url = request.url_for("static", path=f"results/{filename}")
|
||||||
|
|
||||||
return JSONResponse(content={
|
return JSONResponse(content={
|
||||||
@@ -148,12 +180,12 @@ async def segment(
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
# 注意:如果你的文件名不是 fastAPI_nocom.py,请修改下面第一个参数
|
||||||
# 使用 Python 函数参数的方式传递配置
|
|
||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
"fastAPI_main:app", # 注意:这里要改成你的文件名:app对象名
|
"fastAPI_nocom:app",
|
||||||
host="127.0.0.1",
|
host="127.0.0.1",
|
||||||
port=55600,
|
port=55600,
|
||||||
proxy_headers=True, # 对应 --proxy-headers
|
proxy_headers=True,
|
||||||
forwarded_allow_ips="*" # 对应 --forwarded-allow-ips="*"
|
forwarded_allow_ips="*",
|
||||||
|
reload=False # 生产环境建议关闭 reload,确保代码完全重载
|
||||||
)
|
)
|
||||||
BIN
static/results/seg_7e8d2ca9238e4b5dbf1eb81f3342d456.jpg
Normal file
BIN
static/results/seg_7e8d2ca9238e4b5dbf1eb81f3342d456.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 83 KiB |
BIN
static/results/seg_c43586b1fc1e4517b9eb6e505024ee0c.jpg
Normal file
BIN
static/results/seg_c43586b1fc1e4517b9eb6e505024ee0c.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 85 KiB |
BIN
static/results/seg_e81c526975cb4c838b1c8e04a0b8ba22.jpg
Normal file
BIN
static/results/seg_e81c526975cb4c838b1c8e04a0b8ba22.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 83 KiB |
BIN
static/results/seg_ed9900e5bd014662a64c56868b8cd74a.jpg
Normal file
BIN
static/results/seg_ed9900e5bd014662a64c56868b8cd74a.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 83 KiB |
Reference in New Issue
Block a user