291 lines
8.8 KiB
Python
291 lines
8.8 KiB
Python
from fastapi import FastAPI, HTTPException, Request, Depends, Security, Path
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import StreamingResponse, JSONResponse
|
|
from fastapi.security import APIKeyHeader
|
|
from pydantic import BaseModel, Field
|
|
from typing import Optional, Literal
|
|
import os
|
|
import sys
|
|
import time
|
|
import json
|
|
import uvicorn
|
|
from loguru import logger
|
|
import tyro
|
|
|
|
# Ensure we can import from project root
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
from langgraph.checkpoint.memory import MemorySaver
|
|
from lang_agent.pipeline import Pipeline, PipelineConfig
|
|
from lang_agent.config.constants import API_KEY_HEADER, VALID_API_KEYS
|
|
|
|
# Initialize Pipeline once (matches existing server_* pattern)
|
|
pipeline_config = tyro.cli(PipelineConfig)
|
|
logger.info(f"starting agent with pipeline: \n{pipeline_config}")
|
|
pipeline: Pipeline = pipeline_config.setup()
|
|
|
|
# API Key Authentication
|
|
|
|
async def verify_api_key(api_key: Optional[str] = Security(API_KEY_HEADER)):
|
|
"""Verify the API key from Authorization header (Bearer token format)."""
|
|
if not api_key:
|
|
# Tests expect 401 (not FastAPI's default 403) when auth header is missing.
|
|
raise HTTPException(status_code=401, detail="Missing API key")
|
|
key = api_key[7:] if api_key.startswith("Bearer ") else api_key
|
|
if VALID_API_KEYS and key not in VALID_API_KEYS:
|
|
raise HTTPException(status_code=401, detail="Invalid API key")
|
|
return key
|
|
|
|
|
|
def _now_iso() -> str:
|
|
# Avoid extra deps; good enough for API metadata.
|
|
return time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
|
|
|
|
|
|
def _new_conversation_id() -> str:
|
|
return f"c_{os.urandom(12).hex()}"
|
|
|
|
|
|
def _normalize_thread_id(conversation_id: str) -> str:
|
|
"""
|
|
Pipeline.achat supports a "{thread_id}_{device_id}" format.
|
|
Memory is keyed by the base thread_id (before the device_id suffix).
|
|
"""
|
|
# Conversation IDs we mint are "c_{hex}" (2 segments). Some clients append a device_id:
|
|
# e.g. "c_test123_device456" -> base thread "c_test123".
|
|
parts = conversation_id.split("_")
|
|
if len(parts) >= 3:
|
|
return conversation_id.rsplit("_", 1)[0]
|
|
return conversation_id
|
|
|
|
|
|
def _try_clear_single_thread_memory(thread_id: str) -> bool:
|
|
"""
|
|
Best-effort per-thread memory deletion.
|
|
Returns True if we believe we cleared something, else False.
|
|
"""
|
|
g = getattr(pipeline, "graph", None)
|
|
mem = getattr(g, "memory", None)
|
|
if isinstance(mem, MemorySaver):
|
|
try:
|
|
mem.delete_thread(thread_id)
|
|
return True
|
|
except Exception as e:
|
|
logger.warning(f"Failed to delete memory thread {thread_id}: {e}")
|
|
return False
|
|
return False
|
|
|
|
|
|
class ConversationCreateResponse(BaseModel):
|
|
id: str
|
|
created_at: str
|
|
|
|
|
|
class MessageCreateRequest(BaseModel):
|
|
# Keep this permissive so invalid roles get a 400 from endpoint logic (not 422 from validation).
|
|
role: str = Field(default="user")
|
|
content: str
|
|
stream: bool = Field(default=False)
|
|
|
|
|
|
class MessageResponse(BaseModel):
|
|
role: Literal["assistant"] = Field(default="assistant")
|
|
content: str
|
|
|
|
|
|
class ConversationMessageResponse(BaseModel):
|
|
conversation_id: str
|
|
message: MessageResponse
|
|
|
|
|
|
class ChatRequest(BaseModel):
|
|
input: str
|
|
conversation_id: Optional[str] = Field(default=None)
|
|
stream: bool = Field(default=False)
|
|
|
|
|
|
class ChatResponse(BaseModel):
|
|
conversation_id: str
|
|
output: str
|
|
|
|
|
|
app = FastAPI(
|
|
title="REST Agent API",
|
|
description="Resource-oriented REST API backed by Pipeline.achat (no RAG/eval/tools exposure).",
|
|
)
|
|
|
|
app.add_middleware(
|
|
CORSMiddleware,
|
|
allow_origins=["*"],
|
|
allow_credentials=True,
|
|
allow_methods=["*"],
|
|
allow_headers=["*"],
|
|
)
|
|
|
|
|
|
async def rest_sse_from_astream(
|
|
chunk_generator, response_id: str, conversation_id: str
|
|
):
|
|
"""
|
|
Stream chunks as SSE events.
|
|
|
|
Format:
|
|
- data: {"type":"delta","id":...,"conversation_id":...,"delta":"..."}
|
|
- data: {"type":"done","id":...,"conversation_id":...}
|
|
- data: [DONE]
|
|
"""
|
|
async for chunk in chunk_generator:
|
|
if chunk:
|
|
data = {
|
|
"type": "delta",
|
|
"id": response_id,
|
|
"conversation_id": conversation_id,
|
|
"delta": chunk,
|
|
}
|
|
yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
|
|
|
|
done = {"type": "done", "id": response_id, "conversation_id": conversation_id}
|
|
yield f"data: {json.dumps(done, ensure_ascii=False)}\n\n"
|
|
yield "data: [DONE]\n\n"
|
|
|
|
|
|
@app.get("/")
|
|
async def root():
|
|
return {
|
|
"message": "REST Agent API",
|
|
"endpoints": [
|
|
"/v1/conversations (POST)",
|
|
"/v1/chat (POST)",
|
|
"/v1/conversations/{conversation_id}/messages (POST)",
|
|
"/v1/conversations/{conversation_id}/memory (DELETE)",
|
|
"/v1/memory (DELETE)",
|
|
"/health (GET)",
|
|
],
|
|
}
|
|
|
|
|
|
@app.get("/health")
|
|
async def health():
|
|
return {"status": "healthy"}
|
|
|
|
|
|
@app.post("/v1/conversations", response_model=ConversationCreateResponse)
|
|
async def create_conversation(_: str = Depends(verify_api_key)):
|
|
return ConversationCreateResponse(id=_new_conversation_id(), created_at=_now_iso())
|
|
|
|
|
|
@app.post("/v1/chat")
|
|
async def chat(body: ChatRequest, _: str = Depends(verify_api_key)):
|
|
conversation_id = body.conversation_id or _new_conversation_id()
|
|
response_id = f"restcmpl-{os.urandom(12).hex()}"
|
|
|
|
if body.stream:
|
|
chunk_generator = await pipeline.achat(
|
|
inp=body.input, as_stream=True, thread_id=conversation_id
|
|
)
|
|
return StreamingResponse(
|
|
rest_sse_from_astream(
|
|
chunk_generator,
|
|
response_id=response_id,
|
|
conversation_id=conversation_id,
|
|
),
|
|
media_type="text/event-stream",
|
|
)
|
|
|
|
result_text = await pipeline.achat(
|
|
inp=body.input, as_stream=False, thread_id=conversation_id
|
|
)
|
|
if not isinstance(result_text, str):
|
|
result_text = str(result_text)
|
|
return JSONResponse(
|
|
content=ChatResponse(
|
|
conversation_id=conversation_id, output=result_text
|
|
).model_dump()
|
|
)
|
|
|
|
|
|
@app.post("/v1/conversations/{conversation_id}/messages")
|
|
async def create_message(
|
|
body: MessageCreateRequest,
|
|
conversation_id: str = Path(...),
|
|
_: str = Depends(verify_api_key),
|
|
):
|
|
if body.role != "user":
|
|
raise HTTPException(status_code=400, detail="Only role='user' is supported")
|
|
|
|
response_id = f"restmsg-{os.urandom(12).hex()}"
|
|
|
|
if body.stream:
|
|
chunk_generator = await pipeline.achat(
|
|
inp=body.content, as_stream=True, thread_id=conversation_id
|
|
)
|
|
return StreamingResponse(
|
|
rest_sse_from_astream(
|
|
chunk_generator,
|
|
response_id=response_id,
|
|
conversation_id=conversation_id,
|
|
),
|
|
media_type="text/event-stream",
|
|
)
|
|
|
|
result_text = await pipeline.achat(
|
|
inp=body.content, as_stream=False, thread_id=conversation_id
|
|
)
|
|
if not isinstance(result_text, str):
|
|
result_text = str(result_text)
|
|
out = ConversationMessageResponse(
|
|
conversation_id=conversation_id, message=MessageResponse(content=result_text)
|
|
)
|
|
return JSONResponse(content=out.model_dump())
|
|
|
|
|
|
@app.delete("/v1/memory")
|
|
async def delete_all_memory(_: str = Depends(verify_api_key)):
|
|
"""Delete all conversation memory/history."""
|
|
try:
|
|
await pipeline.aclear_memory()
|
|
return JSONResponse(content={"status": "success", "scope": "all"})
|
|
except Exception as e:
|
|
logger.error(f"Memory deletion error: {e}")
|
|
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@app.delete("/v1/conversations/{conversation_id}/memory")
|
|
async def delete_conversation_memory(
|
|
conversation_id: str = Path(...),
|
|
_: str = Depends(verify_api_key),
|
|
):
|
|
"""
|
|
Best-effort per-conversation memory deletion.
|
|
|
|
Note: Pipeline exposes only global clear; per-thread delete is done by directly
|
|
deleting the thread in the underlying MemorySaver if present.
|
|
"""
|
|
thread_id = _normalize_thread_id(conversation_id)
|
|
cleared = _try_clear_single_thread_memory(thread_id)
|
|
if cleared:
|
|
return JSONResponse(
|
|
content={
|
|
"status": "success",
|
|
"scope": "conversation",
|
|
"conversation_id": conversation_id,
|
|
}
|
|
)
|
|
return JSONResponse(
|
|
content={
|
|
"status": "unsupported",
|
|
"message": "Per-conversation memory clearing not supported by current graph; use DELETE /v1/memory instead.",
|
|
"conversation_id": conversation_id,
|
|
},
|
|
status_code=501,
|
|
)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
uvicorn.run(
|
|
"server_rest:app",
|
|
host="0.0.0.0",
|
|
port=8589,
|
|
reload=True,
|
|
)
|