diff --git a/lang_agent/fastapi_server/front_apis.py b/lang_agent/fastapi_server/front_apis.py index 1b56aeb..2269398 100644 --- a/lang_agent/fastapi_server/front_apis.py +++ b/lang_agent/fastapi_server/front_apis.py @@ -4,6 +4,7 @@ import os import os.path as osp import sys import json +import psycopg from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware @@ -118,6 +119,33 @@ class PipelineStopResponse(BaseModel): reload_required: bool +class ConversationListItem(BaseModel): + conversation_id: str + pipeline_id: str + message_count: int + last_updated: Optional[str] = Field(default=None) + + +class PipelineConversationListResponse(BaseModel): + pipeline_id: str + items: List[ConversationListItem] + count: int + + +class ConversationMessageItem(BaseModel): + message_type: str + content: str + sequence_number: int + created_at: str + + +class PipelineConversationMessagesResponse(BaseModel): + pipeline_id: str + conversation_id: str + items: List[ConversationMessageItem] + count: int + + class ApiKeyPolicyItem(BaseModel): api_key: str default_pipeline_id: Optional[str] = Field(default=None) @@ -200,6 +228,8 @@ async def root(): "/v1/pipelines (POST) - build config + upsert pipeline registry entry", "/v1/pipelines (GET) - list registry pipeline specs", "/v1/pipelines/{pipeline_id} (DELETE) - disable pipeline in registry", + "/v1/pipelines/{pipeline_id}/conversations (GET) - list pipeline conversations", + "/v1/pipelines/{pipeline_id}/conversations/{conversation_id}/messages (GET) - list messages in a conversation", "/v1/pipelines/api-keys (GET) - list API key routing policies", "/v1/pipelines/api-keys/{api_key} (PUT) - upsert API key routing policy", "/v1/pipelines/api-keys/{api_key} (DELETE) - delete API key routing policy", @@ -630,6 +660,123 @@ async def stop_pipeline(pipeline_id: str): ) +@app.get( + "/v1/pipelines/{pipeline_id}/conversations", + response_model=PipelineConversationListResponse, +) +async def list_pipeline_conversations(pipeline_id: str, limit: int = 100): + if limit < 1 or limit > 500: + raise HTTPException(status_code=400, detail="limit must be between 1 and 500") + + conn_str = os.environ.get("CONN_STR") + if not conn_str: + raise HTTPException(status_code=500, detail="CONN_STR not set") + + try: + with psycopg.connect(conn_str) as conn: + with conn.cursor(row_factory=psycopg.rows.dict_row) as cur: + cur.execute( + """ + SELECT + conversation_id, + pipeline_id, + COUNT(*) AS message_count, + MAX(created_at) AS last_updated + FROM messages + WHERE pipeline_id = %s + GROUP BY conversation_id, pipeline_id + ORDER BY last_updated DESC + LIMIT %s + """, + (pipeline_id, limit), + ) + rows = cur.fetchall() + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + items = [ + ConversationListItem( + conversation_id=str(row["conversation_id"]), + pipeline_id=str(row["pipeline_id"]), + message_count=int(row["message_count"]), + last_updated=( + row["last_updated"].isoformat() if row.get("last_updated") else None + ), + ) + for row in rows + ] + return PipelineConversationListResponse( + pipeline_id=pipeline_id, items=items, count=len(items) + ) + + +@app.get( + "/v1/pipelines/{pipeline_id}/conversations/{conversation_id}/messages", + response_model=PipelineConversationMessagesResponse, +) +async def get_pipeline_conversation_messages(pipeline_id: str, conversation_id: str): + conn_str = os.environ.get("CONN_STR") + if not conn_str: + raise HTTPException(status_code=500, detail="CONN_STR not set") + + try: + with psycopg.connect(conn_str) as conn: + with conn.cursor(row_factory=psycopg.rows.dict_row) as cur: + cur.execute( + """ + SELECT 1 + FROM messages + WHERE pipeline_id = %s AND conversation_id = %s + LIMIT 1 + """, + (pipeline_id, conversation_id), + ) + exists = cur.fetchone() + if exists is None: + raise HTTPException( + status_code=404, + detail=( + f"conversation_id '{conversation_id}' not found for " + f"pipeline '{pipeline_id}'" + ), + ) + + cur.execute( + """ + SELECT + message_type, + content, + sequence_number, + created_at + FROM messages + WHERE pipeline_id = %s AND conversation_id = %s + ORDER BY sequence_number ASC + """, + (pipeline_id, conversation_id), + ) + rows = cur.fetchall() + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + items = [ + ConversationMessageItem( + message_type=str(row["message_type"]), + content=str(row["content"]), + sequence_number=int(row["sequence_number"]), + created_at=row["created_at"].isoformat() if row.get("created_at") else "", + ) + for row in rows + ] + return PipelineConversationMessagesResponse( + pipeline_id=pipeline_id, + conversation_id=conversation_id, + items=items, + count=len(items), + ) + + @app.get("/v1/pipelines/api-keys", response_model=ApiKeyPolicyListResponse) async def list_pipeline_api_keys(): try: