diff --git a/fastapi_server/server_openai.py b/fastapi_server/server_openai.py index 31e96b8..d82a457 100644 --- a/fastapi_server/server_openai.py +++ b/fastapi_server/server_openai.py @@ -1,6 +1,7 @@ -from fastapi import FastAPI, HTTPException, Request +from fastapi import FastAPI, HTTPException, Request, Depends, Security from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse, JSONResponse +from fastapi.security import APIKeyHeader from pydantic import BaseModel, Field from typing import List, Optional, Union, Literal import os @@ -20,6 +21,18 @@ from lang_agent.pipeline import Pipeline, PipelineConfig pipeline_config = tyro.cli(PipelineConfig) pipeline: Pipeline = pipeline_config.setup() +# API Key Authentication +API_KEY_HEADER = APIKeyHeader(name="Authorization", auto_error=True) +VALID_API_KEYS = set(filter(None, os.environ.get("FAST_AUTH_KEYS", "").split(","))) + + +async def verify_api_key(api_key: str = Security(API_KEY_HEADER)): + """Verify the API key from Authorization header (Bearer token format).""" + key = api_key[7:] if api_key.startswith("Bearer ") else api_key + if VALID_API_KEYS and key not in VALID_API_KEYS: + raise HTTPException(status_code=401, detail="Invalid API key") + return key + class OpenAIMessage(BaseModel): role: str @@ -133,7 +146,7 @@ async def sse_chunks_from_astream(chunk_generator, response_id: str, model: str, @app.post("/v1/chat/completions") -async def chat_completions(request: Request): +async def chat_completions(request: Request, _: str = Depends(verify_api_key)): try: body = await request.json() @@ -222,7 +235,7 @@ async def health(): @app.delete("/v1/memory") -async def delete_memory(): +async def delete_memory(_: str = Depends(verify_api_key)): """Delete all conversation memory/history.""" try: await pipeline.aclear_memory()