"""REST-style agent API backed by ``Pipeline.achat``.

Resource-oriented endpoints for conversations, messages, and one-shot chat,
plus best-effort memory deletion.  No RAG/eval/tools surface is exposed.
"""

from fastapi import FastAPI, HTTPException, Request, Depends, Security, Path
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.security import APIKeyHeader
from pydantic import BaseModel, Field
from typing import Optional, Literal
import os
import sys
import time
import json
import uvicorn
from loguru import logger
import tyro

# Ensure we can import from project root
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from langgraph.checkpoint.memory import MemorySaver

from lang_agent.pipeline import Pipeline, PipelineConfig
from lang_agent.config.constants import API_KEY_HEADER, VALID_API_KEYS

# Initialize Pipeline once (matches existing server_* pattern)
pipeline_config = tyro.cli(PipelineConfig)
logger.info(f"starting agent with pipeline: \n{pipeline_config}")
pipeline: Pipeline = pipeline_config.setup()


# API Key Authentication
async def verify_api_key(api_key: Optional[str] = Security(API_KEY_HEADER)) -> str:
    """Verify the API key from Authorization header (Bearer token format).

    Returns the bare key with any ``Bearer `` prefix stripped.  Raises a 401
    when the header is missing or the key is not in ``VALID_API_KEYS``.
    Note: when ``VALID_API_KEYS`` is empty/falsy, any non-empty key passes.
    """
    if not api_key:
        # Tests expect 401 (not FastAPI's default 403) when auth header is missing.
        raise HTTPException(status_code=401, detail="Missing API key")
    key = api_key[7:] if api_key.startswith("Bearer ") else api_key
    if VALID_API_KEYS and key not in VALID_API_KEYS:
        raise HTTPException(status_code=401, detail="Invalid API key")
    return key


def _now_iso() -> str:
    """Return the current UTC time as an ISO-8601 'Z' string."""
    # Avoid extra deps; good enough for API metadata.
    return time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())


def _new_conversation_id() -> str:
    """Mint a fresh conversation id of the form ``c_{24-hex-chars}``."""
    return f"c_{os.urandom(12).hex()}"


def _normalize_thread_id(conversation_id: str) -> str:
    """
    Pipeline.achat supports a "{thread_id}_{device_id}" format.
    Memory is keyed by the base thread_id (before the device_id suffix).
    """
    # Conversation IDs we mint are "c_{hex}" (2 segments). Some clients append a device_id:
    # e.g. "c_test123_device456" -> base thread "c_test123".
    parts = conversation_id.split("_")
    if len(parts) >= 3:
        return conversation_id.rsplit("_", 1)[0]
    return conversation_id


def _try_clear_single_thread_memory(thread_id: str) -> bool:
    """
    Best-effort per-thread memory deletion.
    Returns True if we believe we cleared something, else False.
    """
    # Reach into the graph's checkpointer; only MemorySaver exposes delete_thread.
    g = getattr(pipeline, "graph", None)
    mem = getattr(g, "memory", None)
    if isinstance(mem, MemorySaver):
        try:
            mem.delete_thread(thread_id)
            return True
        except Exception as e:
            logger.warning(f"Failed to delete memory thread {thread_id}: {e}")
            return False
    return False


class ConversationCreateResponse(BaseModel):
    """Response body for POST /v1/conversations."""

    id: str
    created_at: str


class MessageCreateRequest(BaseModel):
    """Request body for POST /v1/conversations/{id}/messages."""

    # Keep this permissive so invalid roles get a 400 from endpoint logic (not 422 from validation).
    role: str = Field(default="user")
    content: str
    stream: bool = Field(default=False)


class MessageResponse(BaseModel):
    """Assistant reply payload."""

    role: Literal["assistant"] = Field(default="assistant")
    content: str


class ConversationMessageResponse(BaseModel):
    """Response body for a non-streaming message turn."""

    conversation_id: str
    message: MessageResponse


class ChatRequest(BaseModel):
    """Request body for POST /v1/chat."""

    input: str
    conversation_id: Optional[str] = Field(default=None)
    stream: bool = Field(default=False)


class ChatResponse(BaseModel):
    """Response body for a non-streaming chat turn."""

    conversation_id: str
    output: str


app = FastAPI(
    title="REST Agent API",
    description="Resource-oriented REST API backed by Pipeline.achat (no RAG/eval/tools exposure).",
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


async def rest_sse_from_astream(
    chunk_generator, response_id: str, conversation_id: str
):
    """
    Stream chunks as SSE events.

    Format:
      - data: {"type":"delta","id":...,"conversation_id":...,"delta":"..."}
      - data: {"type":"done","id":...,"conversation_id":...}
      - data: [DONE]
    """
    async for chunk in chunk_generator:
        if chunk:
            data = {
                "type": "delta",
                "id": response_id,
                "conversation_id": conversation_id,
                "delta": chunk,
            }
            yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
    done = {"type": "done", "id": response_id, "conversation_id": conversation_id}
    yield f"data: {json.dumps(done, ensure_ascii=False)}\n\n"
    yield "data: [DONE]\n\n"


async def _achat_text(inp: str, conversation_id: str) -> str:
    """Run one non-streaming pipeline turn and coerce the result to ``str``.

    Shared by the /v1/chat and /v1/conversations/{id}/messages endpoints.
    """
    result = await pipeline.achat(
        inp=inp, as_stream=False, thread_id=conversation_id
    )
    return result if isinstance(result, str) else str(result)


async def _achat_stream_response(
    inp: str, conversation_id: str, response_id: str
) -> StreamingResponse:
    """Run one streaming pipeline turn, wrapped as an SSE StreamingResponse.

    Shared by the /v1/chat and /v1/conversations/{id}/messages endpoints.
    """
    chunk_generator = await pipeline.achat(
        inp=inp, as_stream=True, thread_id=conversation_id
    )
    return StreamingResponse(
        rest_sse_from_astream(
            chunk_generator,
            response_id=response_id,
            conversation_id=conversation_id,
        ),
        media_type="text/event-stream",
    )


@app.get("/")
async def root():
    """Index: list the available endpoints."""
    return {
        "message": "REST Agent API",
        "endpoints": [
            "/v1/conversations (POST)",
            "/v1/chat (POST)",
            "/v1/conversations/{conversation_id}/messages (POST)",
            "/v1/conversations/{conversation_id}/memory (DELETE)",
            "/v1/memory (DELETE)",
            "/health (GET)",
        ],
    }


@app.get("/health")
async def health():
    """Liveness probe; no auth required."""
    return {"status": "healthy"}


@app.post("/v1/conversations", response_model=ConversationCreateResponse)
async def create_conversation(_: str = Depends(verify_api_key)):
    """Mint a new conversation id; no server-side state is created yet."""
    return ConversationCreateResponse(id=_new_conversation_id(), created_at=_now_iso())


@app.post("/v1/chat")
async def chat(body: ChatRequest, _: str = Depends(verify_api_key)):
    """One-shot chat; mints a conversation id when none is supplied."""
    conversation_id = body.conversation_id or _new_conversation_id()
    response_id = f"restcmpl-{os.urandom(12).hex()}"

    if body.stream:
        return await _achat_stream_response(body.input, conversation_id, response_id)

    result_text = await _achat_text(body.input, conversation_id)
    return JSONResponse(
        content=ChatResponse(
            conversation_id=conversation_id, output=result_text
        ).model_dump()
    )


@app.post("/v1/conversations/{conversation_id}/messages")
async def create_message(
    body: MessageCreateRequest,
    conversation_id: str = Path(...),
    _: str = Depends(verify_api_key),
):
    """Append a user message to a conversation and return the assistant reply."""
    if body.role != "user":
        raise HTTPException(status_code=400, detail="Only role='user' is supported")

    response_id = f"restmsg-{os.urandom(12).hex()}"

    if body.stream:
        return await _achat_stream_response(body.content, conversation_id, response_id)

    result_text = await _achat_text(body.content, conversation_id)
    out = ConversationMessageResponse(
        conversation_id=conversation_id, message=MessageResponse(content=result_text)
    )
    return JSONResponse(content=out.model_dump())


@app.delete("/v1/memory")
async def delete_all_memory(_: str = Depends(verify_api_key)):
    """Delete all conversation memory/history."""
    try:
        await pipeline.aclear_memory()
        return JSONResponse(content={"status": "success", "scope": "all"})
    except Exception as e:
        logger.error(f"Memory deletion error: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.delete("/v1/conversations/{conversation_id}/memory")
async def delete_conversation_memory(
    conversation_id: str = Path(...),
    _: str = Depends(verify_api_key),
):
    """
    Best-effort per-conversation memory deletion.
    Note: Pipeline exposes only global clear; per-thread delete is done by
    directly deleting the thread in the underlying MemorySaver if present.
    """
    thread_id = _normalize_thread_id(conversation_id)
    cleared = _try_clear_single_thread_memory(thread_id)
    if cleared:
        return JSONResponse(
            content={
                "status": "success",
                "scope": "conversation",
                "conversation_id": conversation_id,
            }
        )
    # 501 signals the capability is absent for this graph, not a client error.
    return JSONResponse(
        content={
            "status": "unsupported",
            "message": "Per-conversation memory clearing not supported by current graph; use DELETE /v1/memory instead.",
            "conversation_id": conversation_id,
        },
        status_code=501,
    )


if __name__ == "__main__":
    uvicorn.run(
        "server_rest:app",
        host="0.0.0.0",
        port=8589,
        reload=True,
    )