moved files
This commit is contained in:
0
lang_agent/fastapi_server/__init__.py
Normal file
0
lang_agent/fastapi_server/__init__.py
Normal file
30
lang_agent/fastapi_server/combined.py
Normal file
30
lang_agent/fastapi_server/combined.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from fastapi_server.front_apis import app as front_app
from fastapi_server.server_dashscope import create_dashscope_router


# Single-process FastAPI app that merges the front_apis admin endpoints with
# the DashScope-compatible chat endpoints.
app = FastAPI(
    title="Combined Front + DashScope APIs",
    description=(
        "Single-process app exposing front_apis control endpoints and "
        "DashScope-compatible chat endpoints."
    ),
)

# Wide-open CORS (any origin/method/header, with credentials).
# NOTE(review): acceptable only behind a trusted boundary -- confirm for prod.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Keep existing /v1/... admin APIs unchanged.
app.include_router(front_app.router)

# Add DashScope endpoints at their existing URLs. We intentionally skip
# DashScope's root/health routes to avoid clashing with front_apis.
app.include_router(create_dashscope_router(include_meta_routes=False))
|
||||
|
||||
265
lang_agent/fastapi_server/fake_stream_server_dashscopy.py
Normal file
265
lang_agent/fastapi_server/fake_stream_server_dashscopy.py
Normal file
@@ -0,0 +1,265 @@
|
||||
from fastapi import FastAPI, HTTPException, Path, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel, Field
from typing import List, Optional
import os
import sys
import time
import json
import uvicorn
from loguru import logger
import tyro

# Ensure we can import from project root
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from lang_agent.pipeline import Pipeline, PipelineConfig

# Initialize Pipeline once at import time; configuration comes from CLI flags.
# NOTE(review): module-level tyro.cli() consumes sys.argv on import -- confirm
# this is intended if this module is ever imported by another app.
pipeline_config = tyro.cli(PipelineConfig)
logger.info(f"starting server with:\n{pipeline_config}")
pipeline: Pipeline = pipeline_config.setup()
|
||||
|
||||
|
||||
class DSMessage(BaseModel):
    """One chat turn in DashScope message format."""

    role: str  # e.g. "user"/"assistant" -- not validated here
    content: str


class DSApplicationCallRequest(BaseModel):
    """Request body mirroring DashScope's Application.call payload.

    NOTE(review): the endpoints below parse the raw request JSON directly;
    this model documents the expected shape rather than enforcing it.
    """

    api_key: Optional[str] = Field(default=None)
    app_id: Optional[str] = Field(default=None)
    session_id: Optional[str] = Field(default=None)
    messages: List[DSMessage]
    stream: bool = Field(default=True)
    # Optional overrides for pipeline behavior
    thread_id: Optional[str] = Field(default="3")
|
||||
|
||||
|
||||
# FastAPI app exposing DashScope Application.call-compatible routes backed by
# the locally-initialized pipeline.
app = FastAPI(title="DashScope-Compatible Application API",
              description="DashScope Application.call compatible endpoint backed by pipeline.chat")

# Wide-open CORS; NOTE(review): confirm acceptable outside dev environments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||||
|
||||
|
||||
def sse_chunks_from_text(full_text: str, response_id: str, model: str = "qwen-flash", chunk_size: int = 1000):
    """Yield DashScope-style SSE events covering *full_text*.

    One progress event is emitted per ``chunk_size`` slice of the text
    (each carrying an empty ``output.text`` -- many SDKs only read
    ``output_text`` from the final event), followed by a single terminal
    event with the complete text and ``is_end: True``.
    """
    stamp = int(time.time())

    def _event(text: str, done: bool) -> str:
        payload = {
            "request_id": response_id,
            "code": 200,
            "message": "OK",
            "output": {
                "text": text,
                "created": stamp,
                "model": model,
            },
            "is_end": done,
        }
        return f"data: {json.dumps(payload)}\n\n"

    # Progress events: intentionally empty text, one per chunk of input.
    offset = 0
    total = len(full_text)
    while offset < total:
        yield _event("", False)
        offset += chunk_size

    # Terminal event carries the full text.
    yield _event(full_text, True)
|
||||
|
||||
|
||||
@app.post("/v1/apps/{app_id}/sessions/{session_id}/responses")
@app.post("/api/v1/apps/{app_id}/sessions/{session_id}/responses")
async def application_responses(
    request: Request,
    app_id: str = Path(...),
    session_id: str = Path(...),
):
    """DashScope Application.call-compatible session endpoint.

    Accepts messages at the top level (``messages``), nested under
    ``input.messages``, or as a bare ``input.prompt`` string; invokes the
    pipeline non-streaming and either replays the result as SSE chunks
    (default) or returns a single JSON payload.

    Raises:
        HTTPException 400: no messages could be extracted from the body.
        HTTPException 500: any unexpected pipeline/parsing failure.
    """
    try:
        body = await request.json()

        # FIX: `input` is optional -- requests with top-level `messages` and
        # no `input` object previously raised KeyError on body['input'],
        # turning a valid request into a 500.
        input_obj = body.get("input") if isinstance(body.get("input"), dict) else {}

        # Prefer path params over body fields.
        req_app_id = app_id or body.get("app_id")
        req_session_id = session_id or input_obj.get("session_id")

        # Normalize messages: top-level, then input.messages, then input.prompt.
        messages = body.get("messages")
        if messages is None:
            messages = input_obj.get("messages")
        if messages is None:
            prompt = input_obj.get("prompt")
            if isinstance(prompt, str):
                messages = [{"role": "user", "content": prompt}]

        if not messages:
            raise HTTPException(status_code=400, detail="messages is required")

        # Determine stream flag (body.stream wins; defaults to True).
        stream = body.get("stream")
        if stream is None:
            stream = body.get("parameters", {}).get("stream", True)

        # The input session id doubles as the pipeline conversation thread id.
        # NOTE(review): the path session_id is deliberately NOT used here to
        # preserve existing behavior -- confirm whether it should be a fallback.
        thread_id = input_obj.get("session_id")

        # Extract the latest user message; fall back to the last message.
        user_msg = None
        for m in reversed(messages):
            role = m.get("role") if isinstance(m, dict) else None
            content = m.get("content") if isinstance(m, dict) else None
            if role == "user" and content:
                user_msg = content
                break
        if user_msg is None:
            last = messages[-1]
            user_msg = last.get("content") if isinstance(last, dict) else str(last)

        # Invoke pipeline (non-stream) then stream-chunk it to the client.
        result_text = await pipeline.achat(inp=user_msg, as_stream=False, thread_id=thread_id)
        if not isinstance(result_text, str):
            result_text = str(result_text)

        response_id = f"appcmpl-{os.urandom(12).hex()}"

        if stream:
            return StreamingResponse(
                sse_chunks_from_text(result_text, response_id=response_id, model=pipeline_config.llm_name, chunk_size=10),
                media_type="text/event-stream",
            )

        # Non-streaming response structure
        data = {
            "request_id": response_id,
            "code": 200,
            "message": "OK",
            "app_id": req_app_id,
            "session_id": req_session_id,
            "output": {
                "text": result_text,
                "created": int(time.time()),
                "model": pipeline_config.llm_name,
            },
            "is_end": True,
        }
        return JSONResponse(content=data)

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"DashScope-compatible endpoint error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# Compatibility: some SDKs call /apps/{app_id}/completion without /v1 and without session in path
@app.post("/apps/{app_id}/completion")
@app.post("/v1/apps/{app_id}/completion")
@app.post("/api/apps/{app_id}/completion")
@app.post("/api/v1/apps/{app_id}/completion")
async def application_completion(
    request: Request,
    app_id: str = Path(...),
):
    """Completion-style alias of the sessions endpoint (session id comes from
    ``input.session_id`` instead of the path).

    Raises:
        HTTPException 400: no messages could be extracted from the body.
        HTTPException 500: any unexpected pipeline/parsing failure.
    """
    try:
        body = await request.json()

        # FIX: body['input'] raised KeyError (-> 500) whenever `input` was
        # absent, even though messages may legitimately be top-level.
        input_obj = body.get("input") if isinstance(body.get("input"), dict) else {}

        req_session_id = input_obj.get("session_id")

        # Normalize messages: top-level, then input.messages, then input.prompt.
        messages = body.get("messages")
        if messages is None:
            messages = input_obj.get("messages")
        if messages is None:
            prompt = input_obj.get("prompt")
            if isinstance(prompt, str):
                messages = [{"role": "user", "content": prompt}]

        if not messages:
            raise HTTPException(status_code=400, detail="messages is required")

        # Stream flag: body.stream wins; defaults to True.
        stream = body.get("stream")
        if stream is None:
            stream = body.get("parameters", {}).get("stream", True)

        # The input session id doubles as the pipeline thread id.
        thread_id = input_obj.get("session_id")

        # Latest user message, falling back to the last message of any role.
        user_msg = None
        for m in reversed(messages):
            role = m.get("role") if isinstance(m, dict) else None
            content = m.get("content") if isinstance(m, dict) else None
            if role == "user" and content:
                user_msg = content
                break
        if user_msg is None:
            last = messages[-1]
            user_msg = last.get("content") if isinstance(last, dict) else str(last)

        result_text = await pipeline.achat(inp=user_msg, as_stream=False, thread_id=thread_id)
        if not isinstance(result_text, str):
            result_text = str(result_text)

        response_id = f"appcmpl-{os.urandom(12).hex()}"

        if stream:
            return StreamingResponse(
                sse_chunks_from_text(result_text, response_id=response_id, model=pipeline_config.llm_name, chunk_size=1000),
                media_type="text/event-stream",
            )

        data = {
            "request_id": response_id,
            "code": 200,
            "message": "OK",
            "app_id": app_id,
            "session_id": req_session_id,
            "output": {
                "text": result_text,
                "created": int(time.time()),
                "model": pipeline_config.llm_name,
            },
            "is_end": True,
        }
        return JSONResponse(content=data)

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"DashScope-compatible completion error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.get("/")
async def root():
    """Service banner listing the primary endpoints exposed here."""
    return {"message": "DashScope Application-compatible API", "endpoints": [
        "/v1/apps/{app_id}/sessions/{session_id}/responses",
        "/health",
    ]}


@app.get("/health")
async def health():
    """Liveness probe; always reports healthy."""
    return {"status": "healthy"}
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Dev server entry point. The import string must match this module's
    # filename ("fake_stream_server_dashscopy") so --reload workers can
    # re-import the app.
    uvicorn.run(
        "fake_stream_server_dashscopy:app",
        host="0.0.0.0",
        port=8588,
        reload=True,
    )
|
||||
|
||||
|
||||
702
lang_agent/fastapi_server/front_apis.py
Normal file
702
lang_agent/fastapi_server/front_apis.py
Normal file
@@ -0,0 +1,702 @@
|
||||
from typing import Dict, List, Optional, Any
|
||||
import commentjson
|
||||
import os
|
||||
import os.path as osp
|
||||
import sys
|
||||
import json
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Ensure we can import from project root.
|
||||
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
|
||||
|
||||
from lang_agent.config.db_config_manager import DBConfigManager
|
||||
from lang_agent.config.constants import (
|
||||
_PROJECT_ROOT,
|
||||
MCP_CONFIG_PATH,
|
||||
MCP_CONFIG_DEFAULT_CONTENT,
|
||||
PIPELINE_REGISTRY_PATH,
|
||||
)
|
||||
from lang_agent.front_api.build_server_utils import (
|
||||
GRAPH_BUILD_FNCS,
|
||||
update_pipeline_registry,
|
||||
)
|
||||
|
||||
|
||||
class GraphConfigUpsertRequest(BaseModel):
    """Payload for creating/updating a graph config (POST /v1/graph-configs)."""

    graph_id: str
    pipeline_id: str
    prompt_set_id: Optional[str] = Field(default=None)  # resolved by DB when omitted
    tool_keys: List[str] = Field(default_factory=list)
    prompt_dict: Dict[str, str] = Field(default_factory=dict)
    api_key: Optional[str] = Field(default=None)


class GraphConfigUpsertResponse(BaseModel):
    """Echo of an upsert; prompt_keys lists the keys of the stored prompts."""

    graph_id: str
    pipeline_id: str
    prompt_set_id: str
    tool_keys: List[str]
    prompt_keys: List[str]
    api_key: str


class GraphConfigReadResponse(BaseModel):
    """Full config read-back, including the prompt texts themselves."""

    graph_id: Optional[str] = Field(default=None)
    pipeline_id: str
    prompt_set_id: str
    tool_keys: List[str]
    prompt_dict: Dict[str, str]
    api_key: str = Field(default="")


class GraphConfigListItem(BaseModel):
    """One row in a graph-config listing (metadata only, no prompt bodies)."""

    graph_id: Optional[str] = Field(default=None)
    pipeline_id: str
    prompt_set_id: str
    name: str
    description: str
    is_active: bool
    tool_keys: List[str]
    api_key: str = Field(default="")
    created_at: Optional[str] = Field(default=None)
    updated_at: Optional[str] = Field(default=None)


class GraphConfigListResponse(BaseModel):
    """Listing envelope: items plus their count."""

    items: List[GraphConfigListItem]
    count: int
|
||||
|
||||
|
||||
class PipelineCreateRequest(BaseModel):
    """Payload for POST /v1/pipelines: build a config file and register it."""

    graph_id: str = Field(
        description="Graph key from GRAPH_BUILD_FNCS, e.g. routing or react"
    )
    pipeline_id: str
    prompt_set_id: str
    tool_keys: List[str] = Field(default_factory=list)
    api_key: Optional[str] = Field(default=None)  # falls back to prompt-set metadata
    llm_name: str = Field(default="qwen-plus")
    enabled: bool = Field(default=True)


class PipelineSpec(BaseModel):
    """Normalized view of one entry in the pipeline registry file."""

    pipeline_id: str
    graph_id: str
    enabled: bool
    config_file: str
    llm_name: str
    overrides: Dict[str, Any] = Field(default_factory=dict)


class PipelineCreateResponse(BaseModel):
    """Result of a pipeline registration, echoing the stored registry entry."""

    pipeline_id: str
    prompt_set_id: str
    graph_id: str
    config_file: str
    llm_name: str
    enabled: bool
    reload_required: bool
    registry_path: str


class PipelineListResponse(BaseModel):
    """Listing envelope for registry pipeline specs."""

    items: List[PipelineSpec]
    count: int


class PipelineStopResponse(BaseModel):
    """Result of disabling a pipeline in the registry."""

    pipeline_id: str
    status: str
    enabled: bool
    reload_required: bool
|
||||
|
||||
|
||||
class ApiKeyPolicyItem(BaseModel):
    """Routing policy for one API key: which pipelines it may use."""

    api_key: str
    default_pipeline_id: Optional[str] = Field(default=None)
    allowed_pipeline_ids: List[str] = Field(default_factory=list)
    app_id: Optional[str] = Field(default=None)


class ApiKeyPolicyListResponse(BaseModel):
    """Listing envelope for API key policies."""

    items: List[ApiKeyPolicyItem]
    count: int


class ApiKeyPolicyUpsertRequest(BaseModel):
    """Payload for PUT /v1/pipelines/api-keys/{api_key} (key comes from path)."""

    default_pipeline_id: Optional[str] = Field(default=None)
    allowed_pipeline_ids: List[str] = Field(default_factory=list)
    app_id: Optional[str] = Field(default=None)


class ApiKeyPolicyDeleteResponse(BaseModel):
    """Result of deleting an API key policy."""

    api_key: str
    status: str
    reload_required: bool


class McpConfigReadResponse(BaseModel):
    """MCP tool config read-back: file path, raw text, and top-level keys."""

    path: str
    raw_content: str
    tool_keys: List[str]


class McpConfigUpdateRequest(BaseModel):
    """Payload for PUT /v1/tool-configs/mcp: the new raw file content."""

    raw_content: str


class McpConfigUpdateResponse(BaseModel):
    """Result of an MCP config update."""

    status: str
    path: str
    tool_keys: List[str]
|
||||
|
||||
|
||||
# Admin-facing FastAPI app for managing graph configs and pipeline registry.
app = FastAPI(
    title="Front APIs",
    description="Manage graph configs and launch graph pipelines.",
)

# Wide-open CORS; NOTE(review): confirm acceptable outside dev environments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Shared DB-backed config manager used by every endpoint below.
_db = DBConfigManager()
|
||||
|
||||
|
||||
@app.get("/health")
async def health():
    """Liveness probe; always reports the service as healthy."""
    payload = {"status": "healthy"}
    return payload
|
||||
|
||||
|
||||
@app.get("/")
async def root():
    """Service banner: name plus the endpoints this app exposes."""
    endpoint_catalog = [
        "/v1/graph-configs (POST)",
        "/v1/graph-configs (GET)",
        "/v1/graph-configs/default/{pipeline_id} (GET)",
        "/v1/graphs/{graph_id}/default-config (GET)",
        "/v1/graph-configs/{pipeline_id}/{prompt_set_id} (GET)",
        "/v1/graph-configs/{pipeline_id}/{prompt_set_id} (DELETE)",
        "/v1/pipelines/graphs (GET)",
        "/v1/pipelines (POST) - build config + upsert pipeline registry entry",
        "/v1/pipelines (GET) - list registry pipeline specs",
        "/v1/pipelines/{pipeline_id} (DELETE) - disable pipeline in registry",
        "/v1/pipelines/api-keys (GET) - list API key routing policies",
        "/v1/pipelines/api-keys/{api_key} (PUT) - upsert API key routing policy",
        "/v1/pipelines/api-keys/{api_key} (DELETE) - delete API key routing policy",
        "/v1/tool-configs/mcp (GET)",
        "/v1/tool-configs/mcp (PUT)",
    ]
    return {"message": "Front APIs", "endpoints": endpoint_catalog}
|
||||
|
||||
|
||||
def _parse_mcp_tool_keys(raw_content: str) -> List[str]:
    """Parse MCP config text (JSON with comments) and return its sorted
    top-level keys; empty input is treated as an empty object."""
    config = commentjson.loads(raw_content or "{}")
    if not isinstance(config, dict):
        raise ValueError("mcp_config must be a JSON object at top level")
    names = [str(name) for name in config]
    names.sort()
    return names
|
||||
|
||||
|
||||
def _read_mcp_config_raw() -> str:
    """Return the MCP config file's raw text.

    On first access the file (and its parent directories) is created with
    MCP_CONFIG_DEFAULT_CONTENT so a fresh checkout works out of the box.
    """
    if not osp.exists(MCP_CONFIG_PATH):
        os.makedirs(osp.dirname(MCP_CONFIG_PATH), exist_ok=True)
        with open(MCP_CONFIG_PATH, "w", encoding="utf-8") as f:
            f.write(MCP_CONFIG_DEFAULT_CONTENT)
    with open(MCP_CONFIG_PATH, "r", encoding="utf-8") as f:
        return f.read()
|
||||
|
||||
|
||||
def _read_pipeline_registry() -> Dict[str, Any]:
    """Load and validate the pipeline registry JSON file.

    Creates an empty registry ({"pipelines": {}, "api_keys": {}}) on first
    access. A missing `api_keys` section is tolerated and filled in; both
    sections must otherwise be JSON objects.

    Raises:
        ValueError: if `pipelines` or `api_keys` is present but not an object.
    """
    if not osp.exists(PIPELINE_REGISTRY_PATH):
        os.makedirs(osp.dirname(PIPELINE_REGISTRY_PATH), exist_ok=True)
        with open(PIPELINE_REGISTRY_PATH, "w", encoding="utf-8") as f:
            json.dump({"pipelines": {}, "api_keys": {}}, f, indent=2)
    with open(PIPELINE_REGISTRY_PATH, "r", encoding="utf-8") as f:
        registry = json.load(f)
    pipelines = registry.get("pipelines")
    if not isinstance(pipelines, dict):
        raise ValueError("`pipelines` in pipeline registry must be an object")
    api_keys = registry.get("api_keys")
    if api_keys is None:
        # Older registry files may predate the api_keys section.
        registry["api_keys"] = {}
    elif not isinstance(api_keys, dict):
        raise ValueError("`api_keys` in pipeline registry must be an object")
    return registry
|
||||
|
||||
|
||||
def _write_pipeline_registry(registry: Dict[str, Any]) -> None:
    """Serialize *registry* to PIPELINE_REGISTRY_PATH (2-space indent,
    trailing newline), creating parent directories as needed."""
    registry_dir = osp.dirname(PIPELINE_REGISTRY_PATH)
    os.makedirs(registry_dir, exist_ok=True)
    serialized = json.dumps(registry, indent=2) + "\n"
    with open(PIPELINE_REGISTRY_PATH, "w", encoding="utf-8") as fh:
        fh.write(serialized)
|
||||
|
||||
|
||||
def _normalize_pipeline_spec(pipeline_id: str, spec: Dict[str, Any]) -> PipelineSpec:
    """Validate a raw registry entry and coerce it into a PipelineSpec.

    Raises:
        ValueError: if the spec or its `overrides` is not a JSON object.
    """
    if not isinstance(spec, dict):
        raise ValueError(f"pipeline spec for '{pipeline_id}' must be an object")

    raw_overrides = spec.get("overrides", {})
    if raw_overrides is None:
        raw_overrides = {}
    if not isinstance(raw_overrides, dict):
        raise ValueError(f"`overrides` for pipeline '{pipeline_id}' must be an object")

    # Model name lives inside overrides; "unknown" when absent/empty.
    model_name = str(raw_overrides.get("llm_name") or "unknown")

    return PipelineSpec(
        pipeline_id=pipeline_id,
        graph_id=str(spec.get("graph_id") or pipeline_id),
        enabled=bool(spec.get("enabled", True)),
        config_file=str(spec.get("config_file") or ""),
        llm_name=model_name,
        overrides=raw_overrides,
    )
|
||||
|
||||
|
||||
def _normalize_api_key_policy(api_key: str, policy: Dict[str, Any]) -> ApiKeyPolicyItem:
    """Validate a raw registry policy and coerce it into an ApiKeyPolicyItem.

    Pipeline ids are stripped; empties and duplicates are dropped while
    preserving first-seen order. Blank default_pipeline_id/app_id normalize
    to None.

    Raises:
        ValueError: if the policy is not an object or allowed_pipeline_ids
            is not a list.
    """
    if not isinstance(policy, dict):
        raise ValueError(f"api key policy for '{api_key}' must be an object")

    raw_allowed = policy.get("allowed_pipeline_ids") or []
    if not isinstance(raw_allowed, list):
        raise ValueError(
            f"`allowed_pipeline_ids` for api key '{api_key}' must be a list"
        )
    # dict.fromkeys dedupes while keeping insertion order.
    stripped_ids = (str(entry).strip() for entry in raw_allowed)
    cleaned_allowed = list(dict.fromkeys(pid for pid in stripped_ids if pid))

    default_id = policy.get("default_pipeline_id")
    if default_id is not None:
        default_id = str(default_id).strip() or None

    application_id = policy.get("app_id")
    if application_id is not None:
        application_id = str(application_id).strip() or None

    return ApiKeyPolicyItem(
        api_key=api_key,
        default_pipeline_id=default_id,
        allowed_pipeline_ids=cleaned_allowed,
        app_id=application_id,
    )
|
||||
|
||||
|
||||
@app.post("/v1/graph-configs", response_model=GraphConfigUpsertResponse)
async def upsert_graph_config(body: GraphConfigUpsertRequest):
    """Create or update a graph config in the DB and echo what was stored.

    Validation errors from the DB layer map to 400; anything else to 500.
    """
    try:
        stored_prompt_set = _db.set_config(
            graph_id=body.graph_id,
            pipeline_id=body.pipeline_id,
            prompt_set_id=body.prompt_set_id,
            tool_list=body.tool_keys,
            prompt_dict=body.prompt_dict,
            api_key=body.api_key,
        )
    except ValueError as err:
        raise HTTPException(status_code=400, detail=str(err))
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))

    cleaned_key = (body.api_key or "").strip()
    return GraphConfigUpsertResponse(
        graph_id=body.graph_id,
        pipeline_id=body.pipeline_id,
        prompt_set_id=stored_prompt_set,
        tool_keys=body.tool_keys,
        prompt_keys=[*body.prompt_dict],
        api_key=cleaned_key,
    )
|
||||
|
||||
|
||||
@app.get("/v1/graph-configs", response_model=GraphConfigListResponse)
async def list_graph_configs(
    pipeline_id: Optional[str] = None, graph_id: Optional[str] = None
):
    """List prompt-set metadata, optionally filtered by pipeline/graph id."""
    try:
        records = _db.list_prompt_sets(pipeline_id=pipeline_id, graph_id=graph_id)
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))

    entries = [GraphConfigListItem(**record) for record in records]
    return GraphConfigListResponse(items=entries, count=len(entries))
|
||||
|
||||
|
||||
@app.get(
    "/v1/graph-configs/default/{pipeline_id}", response_model=GraphConfigReadResponse
)
async def get_default_graph_config(pipeline_id: str):
    """Return the active (default) graph config for a pipeline.

    Fetches the config with prompt_set_id=None (DB layer resolves the active
    set), then looks up the matching active row for its metadata.

    Raises:
        HTTPException 404: no active prompt set exists for the pipeline.
        HTTPException 400/500: DB validation / unexpected errors.
    """
    try:
        prompt_dict, tool_keys = _db.get_config(
            pipeline_id=pipeline_id, prompt_set_id=None
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    # Both empty means the DB had nothing for this pipeline at all.
    if not prompt_dict and not tool_keys:
        raise HTTPException(
            status_code=404,
            detail=f"No active prompt set found for pipeline '{pipeline_id}'",
        )

    # Second lookup to recover the active row's metadata (prompt_set_id,
    # graph_id, api_key) -- get_config only returns the content.
    rows = _db.list_prompt_sets(pipeline_id=pipeline_id)
    active = next((row for row in rows if row["is_active"]), None)
    if active is None:
        raise HTTPException(
            status_code=404,
            detail=f"No active prompt set found for pipeline '{pipeline_id}'",
        )

    return GraphConfigReadResponse(
        graph_id=active.get("graph_id"),
        pipeline_id=pipeline_id,
        prompt_set_id=active["prompt_set_id"],
        tool_keys=tool_keys,
        prompt_dict=prompt_dict,
        api_key=(active.get("api_key") or ""),
    )
|
||||
|
||||
|
||||
@app.get("/v1/graphs/{graph_id}/default-config", response_model=GraphConfigReadResponse)
async def get_graph_default_config_by_graph(graph_id: str):
    """Alias route that forwards graph_id as the pipeline_id.

    NOTE(review): this only works when graph ids and pipeline ids coincide --
    confirm that assumption holds for all callers of this route.
    """
    return await get_default_graph_config(pipeline_id=graph_id)
|
||||
|
||||
|
||||
@app.get(
    "/v1/graph-configs/{pipeline_id}/{prompt_set_id}",
    response_model=GraphConfigReadResponse,
)
async def get_graph_config(pipeline_id: str, prompt_set_id: str):
    """Return a specific prompt set's config (prompts + tool keys + metadata).

    Raises:
        HTTPException 404: prompt_set_id does not exist for the pipeline.
        HTTPException 400/500: DB validation / unexpected errors.
    """
    try:
        # Metadata lookup first so a missing set maps to 404 (not a DB error).
        meta = _db.get_prompt_set(pipeline_id=pipeline_id, prompt_set_id=prompt_set_id)
        if meta is None:
            raise HTTPException(
                status_code=404,
                detail=f"prompt_set_id '{prompt_set_id}' not found for pipeline '{pipeline_id}'",
            )
        prompt_dict, tool_keys = _db.get_config(
            pipeline_id=pipeline_id,
            prompt_set_id=prompt_set_id,
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except HTTPException:
        # Re-raise our own 404 instead of wrapping it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    return GraphConfigReadResponse(
        graph_id=meta.get("graph_id"),
        pipeline_id=pipeline_id,
        prompt_set_id=prompt_set_id,
        tool_keys=tool_keys,
        prompt_dict=prompt_dict,
        api_key=(meta.get("api_key") or ""),
    )
|
||||
|
||||
|
||||
@app.delete("/v1/graph-configs/{pipeline_id}/{prompt_set_id}")
async def delete_graph_config(pipeline_id: str, prompt_set_id: str):
    """Remove a prompt set's config from the DB and confirm the deletion."""
    try:
        _db.remove_config(pipeline_id=pipeline_id, prompt_set_id=prompt_set_id)
    except ValueError as err:
        raise HTTPException(status_code=400, detail=str(err))
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))

    confirmation = {
        "status": "deleted",
        "pipeline_id": pipeline_id,
        "prompt_set_id": prompt_set_id,
    }
    return confirmation
|
||||
|
||||
|
||||
@app.get("/v1/pipelines/graphs")
async def available_graphs():
    """Enumerate the graph builder keys registered in GRAPH_BUILD_FNCS."""
    graph_keys = sorted(GRAPH_BUILD_FNCS)
    return {"available_graphs": graph_keys}
|
||||
|
||||
|
||||
@app.get("/v1/tool-configs/mcp", response_model=McpConfigReadResponse)
async def get_mcp_tool_config():
    """Return the MCP config file's path, raw text, and top-level tool keys."""
    try:
        content = _read_mcp_config_raw()
        keys = _parse_mcp_tool_keys(content)
    except ValueError as err:
        raise HTTPException(status_code=400, detail=str(err))
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))

    return McpConfigReadResponse(
        path=MCP_CONFIG_PATH,
        raw_content=content,
        tool_keys=keys,
    )
|
||||
|
||||
|
||||
@app.put("/v1/tool-configs/mcp", response_model=McpConfigUpdateResponse)
async def update_mcp_tool_config(body: McpConfigUpdateRequest):
    """Replace the MCP config file with the submitted raw content.

    The content is parsed first, so invalid JSON is rejected (400) before
    anything is written to disk.
    """
    try:
        # Validate before writing: parse failure must not clobber the file.
        tool_keys = _parse_mcp_tool_keys(body.raw_content)
        os.makedirs(osp.dirname(MCP_CONFIG_PATH), exist_ok=True)
        with open(MCP_CONFIG_PATH, "w", encoding="utf-8") as f:
            # Keep user formatting/comments as entered while ensuring trailing newline.
            f.write(body.raw_content.rstrip() + "\n")
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return McpConfigUpdateResponse(
        status="updated",
        path=MCP_CONFIG_PATH,
        tool_keys=tool_keys,
    )
|
||||
|
||||
|
||||
@app.get("/v1/pipelines", response_model=PipelineListResponse)
async def list_running_pipelines():
    """List every pipeline spec in the registry, sorted by pipeline id."""
    try:
        registry = _read_pipeline_registry()
    except ValueError as err:
        raise HTTPException(status_code=400, detail=str(err))
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))

    registered = registry.get("pipelines", {})
    specs: List[PipelineSpec] = [
        _normalize_pipeline_spec(pid, raw_spec)
        for pid, raw_spec in sorted(registered.items())
    ]
    return PipelineListResponse(items=specs, count=len(specs))
|
||||
|
||||
|
||||
@app.post("/v1/pipelines", response_model=PipelineCreateResponse)
async def create_pipeline(body: PipelineCreateRequest):
    """Build a pipeline config file and upsert its registry entry.

    Steps: validate graph/ids, resolve the api_key (request body first, then
    the prompt set's stored key), call the graph's build function to write
    configs/pipelines/<pipeline_id>.yml, update the registry, and read the
    entry back as confirmation.

    Raises:
        HTTPException 400: unknown graph_id, blank ids, or unresolvable api_key.
        HTTPException 500: build/registry failures.
    """
    build_fn = GRAPH_BUILD_FNCS.get(body.graph_id)
    if build_fn is None:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown graph_id '{body.graph_id}'. Valid options: {sorted(GRAPH_BUILD_FNCS.keys())}",
        )

    pipeline_id = body.pipeline_id.strip()
    if not pipeline_id:
        raise HTTPException(status_code=400, detail="pipeline_id is required")
    prompt_set_id = body.prompt_set_id.strip()
    if not prompt_set_id:
        raise HTTPException(status_code=400, detail="prompt_set_id is required")

    # api_key resolution: explicit request value wins; otherwise fall back to
    # the key stored with the prompt set.
    resolved_api_key = (body.api_key or "").strip()
    if not resolved_api_key:
        meta = _db.get_prompt_set(pipeline_id=pipeline_id, prompt_set_id=prompt_set_id)
        if meta is None:
            raise HTTPException(
                status_code=400,
                detail=(
                    f"prompt_set_id '{prompt_set_id}' not found for pipeline '{pipeline_id}', "
                    "and request api_key is empty"
                ),
            )
        resolved_api_key = str(meta.get("api_key") or "").strip()
        if not resolved_api_key:
            raise HTTPException(
                status_code=400,
                detail=(
                    "api_key is required either in request body or in prompt set metadata"
                ),
            )

    # config_file is the registry-relative path; build_fn writes to the
    # absolute directory.
    config_file = f"configs/pipelines/{pipeline_id}.yml"
    config_abs_dir = osp.join(_PROJECT_ROOT, "configs", "pipelines")
    try:
        build_fn(
            pipeline_id=pipeline_id,
            prompt_set=prompt_set_id,
            tool_keys=body.tool_keys,
            api_key=resolved_api_key,
            llm_name=body.llm_name,
            pipeline_config_dir=config_abs_dir,
        )

        update_pipeline_registry(
            pipeline_id=pipeline_id,
            graph_id=body.graph_id,
            config_file=config_file,
            llm_name=body.llm_name,
            enabled=body.enabled,
            registry_f=PIPELINE_REGISTRY_PATH,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to register pipeline: {e}")

    # Read-back: confirm the registry actually contains what we just wrote.
    try:
        registry = _read_pipeline_registry()
        pipeline_spec = registry.get("pipelines", {}).get(pipeline_id)
        if pipeline_spec is None:
            raise ValueError(
                f"pipeline '{pipeline_id}' missing from registry after update"
            )
        normalized = _normalize_pipeline_spec(pipeline_id, pipeline_spec)
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to read pipeline registry after update: {e}",
        )

    return PipelineCreateResponse(
        pipeline_id=pipeline_id,
        prompt_set_id=prompt_set_id,
        graph_id=normalized.graph_id,
        config_file=normalized.config_file,
        llm_name=normalized.llm_name,
        enabled=normalized.enabled,
        reload_required=False,
        registry_path=PIPELINE_REGISTRY_PATH,
    )
|
||||
|
||||
|
||||
@app.delete("/v1/pipelines/{pipeline_id}", response_model=PipelineStopResponse)
async def stop_pipeline(pipeline_id: str):
    """Disable a pipeline in the registry (the entry is kept, not removed)."""
    try:
        registry = _read_pipeline_registry()
        entry = registry.get("pipelines", {}).get(pipeline_id)
        if not isinstance(entry, dict):
            raise HTTPException(
                status_code=404, detail=f"pipeline_id '{pipeline_id}' not found"
            )
        entry["enabled"] = False
        _write_pipeline_registry(registry)
    except HTTPException:
        raise
    except ValueError as err:
        raise HTTPException(status_code=400, detail=str(err))
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))

    return PipelineStopResponse(
        pipeline_id=pipeline_id,
        status="disabled",
        enabled=False,
        reload_required=False,
    )
|
||||
|
||||
|
||||
@app.get("/v1/pipelines/api-keys", response_model=ApiKeyPolicyListResponse)
async def list_pipeline_api_keys():
    """List every API key routing policy in the registry, sorted by key."""
    try:
        registry = _read_pipeline_registry()
        key_policies = registry.get("api_keys", {})
    except ValueError as err:
        raise HTTPException(status_code=400, detail=str(err))
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))

    policies: List[ApiKeyPolicyItem] = [
        _normalize_api_key_policy(str(key), raw_policy)
        for key, raw_policy in sorted(key_policies.items())
    ]
    return ApiKeyPolicyListResponse(items=policies, count=len(policies))
|
||||
|
||||
|
||||
@app.put(
    "/v1/pipelines/api-keys/{api_key}",
    response_model=ApiKeyPolicyItem,
)
async def upsert_pipeline_api_key_policy(api_key: str, body: ApiKeyPolicyUpsertRequest):
    """Create or replace the routing policy for an API key.

    Every referenced pipeline id (allowed or default) must already exist in
    the registry. Empty/duplicate allowed ids are dropped; empty fields are
    omitted from the stored policy entirely.

    Raises:
        HTTPException 400: blank api_key, unknown pipeline ids, or registry
            validation errors.
        HTTPException 500: unexpected read/write failures.
    """
    normalized_key = api_key.strip()
    if not normalized_key:
        raise HTTPException(
            status_code=400, detail="api_key path parameter is required"
        )
    try:
        registry = _read_pipeline_registry()
        pipelines = registry.get("pipelines", {})
        if not isinstance(pipelines, dict):
            raise ValueError("`pipelines` in pipeline registry must be an object")
        known_pipeline_ids = set(pipelines.keys())

        # Clean allowed ids: strip, drop empties/dups, reject unknown ids.
        allowed = []
        seen = set()
        for pipeline_id in body.allowed_pipeline_ids:
            cleaned = str(pipeline_id).strip()
            if not cleaned or cleaned in seen:
                continue
            if cleaned not in known_pipeline_ids:
                raise ValueError(
                    f"unknown pipeline_id '{cleaned}' in allowed_pipeline_ids"
                )
            seen.add(cleaned)
            allowed.append(cleaned)

        default_pipeline_id = body.default_pipeline_id
        if default_pipeline_id is not None:
            default_pipeline_id = default_pipeline_id.strip() or None
        if default_pipeline_id and default_pipeline_id not in known_pipeline_ids:
            raise ValueError(f"unknown default_pipeline_id '{default_pipeline_id}'")

        app_id = body.app_id.strip() if body.app_id else None
        # Only non-empty fields are persisted; the stored policy is sparse.
        policy: Dict[str, Any] = {}
        if default_pipeline_id:
            policy["default_pipeline_id"] = default_pipeline_id
        if allowed:
            policy["allowed_pipeline_ids"] = allowed
        if app_id:
            policy["app_id"] = app_id

        registry.setdefault("api_keys", {})[normalized_key] = policy
        _write_pipeline_registry(registry)
        return _normalize_api_key_policy(normalized_key, policy)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.delete(
    "/v1/pipelines/api-keys/{api_key}",
    response_model=ApiKeyPolicyDeleteResponse,
)
async def delete_pipeline_api_key_policy(api_key: str):
    """Remove the routing policy stored under `api_key`, if any."""
    normalized_key = api_key.strip()
    if not normalized_key:
        raise HTTPException(
            status_code=400, detail="api_key path parameter is required"
        )
    try:
        registry = _read_pipeline_registry()
        api_keys = registry.get("api_keys", {})
        if normalized_key not in api_keys:
            raise HTTPException(
                status_code=404, detail=f"api_key '{normalized_key}' not found"
            )
        api_keys.pop(normalized_key)
        _write_pipeline_registry(registry)
    except HTTPException:
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    return ApiKeyPolicyDeleteResponse(
        api_key=normalized_key,
        status="deleted",
        reload_required=False,
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import uvicorn

    # Dev entrypoint: an "module:app" import string (not the app object) is
    # required for uvicorn's --reload to work.
    uvicorn.run(
        "front_apis:app",
        host="0.0.0.0",
        port=8500,
        reload=True,
    )
|
||||
365
lang_agent/fastapi_server/server_dashscope.py
Normal file
365
lang_agent/fastapi_server/server_dashscope.py
Normal file
@@ -0,0 +1,365 @@
|
||||
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path, Request, Security
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Any, Dict, List, Optional
|
||||
import os
|
||||
import os.path as osp
|
||||
import sys
|
||||
import time
|
||||
import json
|
||||
import uvicorn
|
||||
from loguru import logger
|
||||
import tyro
|
||||
|
||||
# Ensure we can import from project root
|
||||
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
|
||||
|
||||
from lang_agent.pipeline import PipelineConfig
|
||||
from lang_agent.components.server_pipeline_manager import ServerPipelineManager
|
||||
from lang_agent.config.constants import PIPELINE_REGISTRY_PATH, API_KEY_HEADER, VALID_API_KEYS
|
||||
|
||||
def _build_default_pipeline_config() -> PipelineConfig:
    """
    Build import-time defaults without parsing CLI args.

    This keeps module import safe for reuse by combined apps and tests.
    """
    config = PipelineConfig()
    logger.info(f"starting agent with base pipeline config: \n{config}")
    return config
|
||||
|
||||
|
||||
def _build_pipeline_manager(base_config: PipelineConfig) -> ServerPipelineManager:
    """Create a ServerPipelineManager seeded from env + the on-disk registry."""
    default_id = os.environ.get("FAST_DEFAULT_PIPELINE_ID", "default")
    manager = ServerPipelineManager(
        default_pipeline_id=default_id,
        default_config=base_config,
    )
    manager.load_registry(PIPELINE_REGISTRY_PATH)
    return manager
|
||||
|
||||
|
||||
# Import-time singletons: a base config (no CLI parsing) and the pipeline
# manager that routes requests; shared by create_dashscope_router/app.
pipeline_config = _build_default_pipeline_config()
PIPELINE_MANAGER = _build_pipeline_manager(pipeline_config)
|
||||
|
||||
|
||||
async def verify_api_key(api_key: str = Security(API_KEY_HEADER)):
    """Verify the API key from Authorization header (Bearer token format).

    Returns the bare key (without any "Bearer " prefix). Raises 401 when the
    header is missing or the key is not in VALID_API_KEYS; an empty
    VALID_API_KEYS disables the membership check.
    """
    if not api_key:
        # Match server_rest.verify_api_key: a missing header must be a clean
        # 401, not an AttributeError-driven 500 from .startswith below.
        raise HTTPException(status_code=401, detail="Missing API key")
    key = api_key[7:] if api_key.startswith("Bearer ") else api_key
    if VALID_API_KEYS and key not in VALID_API_KEYS:
        raise HTTPException(status_code=401, detail="Invalid API key")
    return key
|
||||
|
||||
|
||||
class DSMessage(BaseModel):
    """A single chat message in DashScope request format."""

    # Speaker role, e.g. "user" / "assistant" / "system".
    role: str
    # Plain-text message content.
    content: str
|
||||
|
||||
|
||||
class DSApplicationCallRequest(BaseModel):
    """DashScope Application.call-style request body.

    NOTE(review): the route handlers parse the raw JSON body themselves;
    this model documents the expected shape rather than enforcing it.
    """

    # Optional per-request key; the Authorization header is authoritative.
    api_key: Optional[str] = Field(default=None)
    # DashScope application id; may also arrive via the URL path.
    app_id: Optional[str] = Field(default=None)
    # Conversation/session identifier used for memory threading.
    session_id: Optional[str] = Field(default=None)
    messages: List[DSMessage]
    stream: bool = Field(default=True)
    # Optional overrides for pipeline behavior
    thread_id: Optional[str] = Field(default="3")
|
||||
|
||||
|
||||
def sse_chunks_from_stream(
    chunk_generator, response_id: str, model: str = "qwen-flash"
):
    """
    Format pipeline chunks as DashScope-style SSE events.

    Each event carries the full accumulated text so far (the DashScope SDK
    expects cumulative text, not deltas); the final event sets is_end=True.
    """
    started_at = int(time.time())
    text_so_far = ""

    def _event(is_end):
        payload = {
            "request_id": response_id,
            "code": 200,
            "message": "OK",
            "output": {
                "text": text_so_far,
                "created": started_at,
                "model": model,
            },
            "is_end": is_end,
        }
        return f"data: {json.dumps(payload)}\n\n"

    for piece in chunk_generator:
        if not piece:
            continue
        text_so_far += piece
        yield _event(False)

    # Terminal event repeats the complete text and flags end-of-stream.
    yield _event(True)
|
||||
|
||||
|
||||
async def sse_chunks_from_astream(
    chunk_generator, response_id: str, model: str = "qwen-flash"
):
    """
    Async version: format pipeline chunks as DashScope-style SSE events.

    Each event carries the full accumulated text so far (the DashScope SDK
    expects cumulative text, not deltas); the final event sets is_end=True.
    """
    started_at = int(time.time())
    text_so_far = ""

    def _event(is_end):
        payload = {
            "request_id": response_id,
            "code": 200,
            "message": "OK",
            "output": {
                "text": text_so_far,
                "created": started_at,
                "model": model,
            },
            "is_end": is_end,
        }
        return f"data: {json.dumps(payload)}\n\n"

    async for piece in chunk_generator:
        if not piece:
            continue
        text_so_far += piece
        yield _event(False)

    # Terminal event repeats the complete text and flags end-of-stream.
    yield _event(True)
|
||||
|
||||
|
||||
def _normalize_messages(body: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Extract the messages list from a DashScope-style request body.

    Accepts top-level `messages`, `input.messages`, or a bare `input.prompt`
    string (wrapped as a single user message). Raises 400 when none exist.
    """
    messages = body.get("messages")
    payload = body.get("input", {})
    if messages is None and isinstance(payload, dict):
        messages = payload.get("messages")
        if messages is None:
            prompt = payload.get("prompt")
            if isinstance(prompt, str):
                messages = [{"role": "user", "content": prompt}]

    if not messages:
        raise HTTPException(status_code=400, detail="messages is required")
    return messages
|
||||
|
||||
|
||||
def _extract_user_message(messages: List[Dict[str, Any]]) -> str:
|
||||
user_msg = None
|
||||
for m in reversed(messages):
|
||||
role = m.get("role") if isinstance(m, dict) else None
|
||||
content = m.get("content") if isinstance(m, dict) else None
|
||||
if role == "user" and content:
|
||||
user_msg = content
|
||||
break
|
||||
if user_msg is None:
|
||||
last = messages[-1]
|
||||
user_msg = last.get("content") if isinstance(last, dict) else str(last)
|
||||
return user_msg
|
||||
|
||||
|
||||
async def _process_dashscope_request(
    body: Dict[str, Any],
    app_id: Optional[str],
    session_id: Optional[str],
    api_key: str,
    pipeline_manager: ServerPipelineManager,
):
    """Shared handler behind all DashScope-compatible routes.

    Resolves the target pipeline from (body, app_id, api_key), then either
    streams SSE chunks or returns one DashScope-shaped JSON response.
    """
    try:
        # Pick up registry edits made by the admin APIs without a restart.
        pipeline_manager.refresh_registry_if_needed()
    except Exception as e:
        logger.error(f"failed to refresh pipeline registry: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to refresh pipeline registry: {e}")

    # Path/argument values win over body fields when both are present.
    req_app_id = app_id or body.get("app_id")
    body_input = body.get("input", {}) if isinstance(body.get("input"), dict) else {}
    req_session_id = session_id or body_input.get("session_id")
    messages = _normalize_messages(body)

    # `stream` may appear at top level or under `parameters`; defaults True.
    stream = body.get("stream")
    if stream is None:
        stream = body.get("parameters", {}).get("stream", True)

    # "3" mirrors the DSApplicationCallRequest.thread_id default.
    thread_id = body_input.get("session_id") or req_session_id or "3"
    user_msg = _extract_user_message(messages)

    pipeline_id = pipeline_manager.resolve_pipeline_id(
        body=body, app_id=req_app_id, api_key=api_key
    )
    selected_pipeline, selected_model = pipeline_manager.get_pipeline(pipeline_id)

    # Namespace thread ids to prevent memory collisions across pipelines.
    thread_id = f"{pipeline_id}:{thread_id}"

    response_id = f"appcmpl-{os.urandom(12).hex()}"

    if stream:
        chunk_generator = await selected_pipeline.achat(
            inp=user_msg, as_stream=True, thread_id=thread_id
        )
        return StreamingResponse(
            sse_chunks_from_astream(
                chunk_generator, response_id=response_id, model=selected_model
            ),
            media_type="text/event-stream",
        )

    result_text = await selected_pipeline.achat(
        inp=user_msg, as_stream=False, thread_id=thread_id
    )
    if not isinstance(result_text, str):
        result_text = str(result_text)

    # Non-streaming: single DashScope-shaped JSON envelope.
    data = {
        "request_id": response_id,
        "code": 200,
        "message": "OK",
        "app_id": req_app_id,
        "session_id": req_session_id,
        "output": {
            "text": result_text,
            "created": int(time.time()),
            "model": selected_model,
        },
        "pipeline_id": pipeline_id,
        "is_end": True,
    }
    return JSONResponse(content=data)
|
||||
|
||||
|
||||
def create_dashscope_router(
    pipeline_manager: Optional[ServerPipelineManager] = None,
    include_meta_routes: bool = True,
) -> APIRouter:
    """Build the DashScope-compatible APIRouter.

    Args:
        pipeline_manager: manager used to resolve/serve pipelines; defaults
            to the module-level PIPELINE_MANAGER singleton.
        include_meta_routes: when False, skip "/" and "/health" so the router
            can be mounted into an app that already defines them.
    """
    manager = pipeline_manager or PIPELINE_MANAGER
    router = APIRouter()

    # Session-scoped variant: app_id and session_id both come from the path.
    @router.post("/v1/apps/{app_id}/sessions/{session_id}/responses")
    @router.post("/api/v1/apps/{app_id}/sessions/{session_id}/responses")
    async def application_responses(
        request: Request,
        app_id: str = Path(...),
        session_id: str = Path(...),
        api_key: str = Depends(verify_api_key),
    ):
        try:
            body = await request.json()
            return await _process_dashscope_request(
                body=body,
                app_id=app_id,
                session_id=session_id,
                api_key=api_key,
                pipeline_manager=manager,
            )

        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"DashScope-compatible endpoint error: {e}")
            raise HTTPException(status_code=500, detail=str(e))

    # Compatibility: some SDKs call /apps/{app_id}/completion without /v1 and
    # without session in path.
    @router.post("/apps/{app_id}/completion")
    @router.post("/v1/apps/{app_id}/completion")
    @router.post("/api/apps/{app_id}/completion")
    @router.post("/api/v1/apps/{app_id}/completion")
    async def application_completion(
        request: Request,
        app_id: str = Path(...),
        api_key: str = Depends(verify_api_key),
    ):
        try:
            body = await request.json()
            return await _process_dashscope_request(
                body=body,
                app_id=app_id,
                session_id=None,
                api_key=api_key,
                pipeline_manager=manager,
            )

        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"DashScope-compatible completion error: {e}")
            raise HTTPException(status_code=500, detail=str(e))

    if include_meta_routes:
        @router.get("/")
        async def root():
            # Service index for standalone deployments.
            return {
                "message": "DashScope Application-compatible API",
                "endpoints": [
                    "/v1/apps/{app_id}/sessions/{session_id}/responses",
                    "/health",
                ],
            }

        @router.get("/health")
        async def health():
            # Liveness probe.
            return {"status": "healthy"}

    return router
|
||||
|
||||
|
||||
def create_dashscope_app(
    pipeline_manager: Optional[ServerPipelineManager] = None,
) -> FastAPI:
    """Build a standalone FastAPI app exposing the DashScope-compatible routes."""
    application = FastAPI(
        title="DashScope-Compatible Application API",
        description="DashScope Application.call compatible endpoint backed by pipeline.chat",
    )
    # Wide-open CORS: this server is expected to sit behind a trusted proxy.
    application.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    router = create_dashscope_router(
        pipeline_manager=pipeline_manager,
        include_meta_routes=True,
    )
    application.include_router(router)
    return application
|
||||
|
||||
|
||||
# Module-level singletons for `uvicorn server_dashscope:app` style usage; the
# router variant omits "/" and "/health" so it can be mounted in other apps.
dashscope_router = create_dashscope_router(include_meta_routes=False)
app = create_dashscope_app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI parsing is intentionally only in script mode to keep module import safe.
    cli_pipeline_config = tyro.cli(PipelineConfig)
    logger.info(f"starting agent with CLI pipeline config: \n{cli_pipeline_config}")
    cli_pipeline_manager = _build_pipeline_manager(cli_pipeline_config)
    # reload must stay False because an app object (not an import string) is
    # passed to uvicorn.
    uvicorn.run(
        create_dashscope_app(pipeline_manager=cli_pipeline_manager),
        host=cli_pipeline_config.host,
        port=cli_pipeline_config.port,
        reload=False,
    )
|
||||
229
lang_agent/fastapi_server/server_openai.py
Normal file
229
lang_agent/fastapi_server/server_openai.py
Normal file
@@ -0,0 +1,229 @@
|
||||
from fastapi import FastAPI, HTTPException, Request, Depends, Security
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
from fastapi.security import APIKeyHeader
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional, Union, Literal
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import json
|
||||
import uvicorn
|
||||
from loguru import logger
|
||||
import tyro
|
||||
|
||||
# Ensure we can import from project root
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from lang_agent.pipeline import Pipeline, PipelineConfig
|
||||
from lang_agent.config.constants import API_KEY_HEADER, VALID_API_KEYS
|
||||
|
||||
# Initialize Pipeline once
# NOTE: tyro.cli runs at import time, so importing this module parses argv.
pipeline_config = tyro.cli(PipelineConfig)
pipeline: Pipeline = pipeline_config.setup()
|
||||
|
||||
|
||||
async def verify_api_key(api_key: str = Security(API_KEY_HEADER)):
    """Verify the API key from Authorization header (Bearer token format).

    Returns the bare key (without any "Bearer " prefix). Raises 401 when the
    header is missing or the key is not in VALID_API_KEYS; an empty
    VALID_API_KEYS disables the membership check.
    """
    if not api_key:
        # Match server_rest.verify_api_key: a missing header must be a clean
        # 401, not an AttributeError-driven 500 from .startswith below.
        raise HTTPException(status_code=401, detail="Missing API key")
    key = api_key[7:] if api_key.startswith("Bearer ") else api_key
    if VALID_API_KEYS and key not in VALID_API_KEYS:
        raise HTTPException(status_code=401, detail="Invalid API key")
    return key
|
||||
|
||||
|
||||
class OpenAIMessage(BaseModel):
    """A single chat message in OpenAI format."""

    # Speaker role, e.g. "user" / "assistant" / "system".
    role: str
    # Plain-text message content.
    content: str
|
||||
|
||||
|
||||
class OpenAIChatCompletionRequest(BaseModel):
    """OpenAI /v1/chat/completions-style request body.

    NOTE(review): the endpoint reads the raw JSON body directly; this model
    documents the expected shape (defaults mirror the endpoint's fallbacks).
    """

    model: str = Field(default="gpt-3.5-turbo")
    messages: List[OpenAIMessage]
    stream: bool = Field(default=False)
    # Accepted for client compatibility but not forwarded to the pipeline.
    temperature: Optional[float] = Field(default=1.0)
    max_tokens: Optional[int] = Field(default=None)
    # Optional overrides for pipeline behavior
    thread_id: Optional[str] = Field(default="3")
|
||||
|
||||
|
||||
app = FastAPI(
    title="OpenAI-Compatible Chat API",
    description="OpenAI Chat Completions API compatible endpoint backed by pipeline.chat",
)

# Wide-open CORS: this server is expected to sit behind a trusted proxy.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||||
|
||||
|
||||
def sse_chunks_from_stream(
    chunk_generator, response_id: str, model: str, created_time: int
):
    """
    Format pipeline chunks as OpenAI chat.completion.chunk SSE events,
    ending with a finish_reason="stop" frame and the [DONE] sentinel.
    """

    def _frame(delta, finish):
        return {
            "id": response_id,
            "object": "chat.completion.chunk",
            "created": created_time,
            "model": model,
            "choices": [{"index": 0, "delta": delta, "finish_reason": finish}],
        }

    for piece in chunk_generator:
        if not piece:
            continue
        yield f"data: {json.dumps(_frame({'content': piece}, None))}\n\n"

    yield f"data: {json.dumps(_frame({}, 'stop'))}\n\n"
    yield "data: [DONE]\n\n"
|
||||
|
||||
|
||||
async def sse_chunks_from_astream(
    chunk_generator, response_id: str, model: str, created_time: int
):
    """
    Async version: format pipeline chunks as OpenAI chat.completion.chunk
    SSE events, ending with finish_reason="stop" and the [DONE] sentinel.
    """

    def _frame(delta, finish):
        return {
            "id": response_id,
            "object": "chat.completion.chunk",
            "created": created_time,
            "model": model,
            "choices": [{"index": 0, "delta": delta, "finish_reason": finish}],
        }

    async for piece in chunk_generator:
        if not piece:
            continue
        yield f"data: {json.dumps(_frame({'content': piece}, None))}\n\n"

    yield f"data: {json.dumps(_frame({}, 'stop'))}\n\n"
    yield "data: [DONE]\n\n"
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
async def chat_completions(request: Request, _: str = Depends(verify_api_key)):
    """OpenAI Chat Completions-compatible endpoint backed by pipeline.achat.

    Streams OpenAI-style SSE chunks when `stream` is true, otherwise returns
    one chat.completion JSON body. Raises 400 when `messages` is missing,
    500 for unexpected failures.
    """
    try:
        body = await request.json()

        messages = body.get("messages")
        if not messages:
            raise HTTPException(status_code=400, detail="messages is required")

        stream = body.get("stream", False)
        model = body.get("model", "gpt-3.5-turbo")
        # Default "3" (string) to match OpenAIChatCompletionRequest.thread_id;
        # the previous int default `3` keyed a different memory thread than
        # the documented default.
        thread_id = body.get("thread_id", "3")

        # Extract latest user message; fall back to the last message.
        user_msg = None
        for m in reversed(messages):
            role = m.get("role") if isinstance(m, dict) else None
            content = m.get("content") if isinstance(m, dict) else None
            if role == "user" and content:
                user_msg = content
                break

        if user_msg is None:
            last = messages[-1]
            user_msg = last.get("content") if isinstance(last, dict) else str(last)

        response_id = f"chatcmpl-{os.urandom(12).hex()}"
        created_time = int(time.time())

        if stream:
            # Use async streaming from pipeline
            chunk_generator = await pipeline.achat(
                inp=user_msg, as_stream=True, thread_id=thread_id
            )
            return StreamingResponse(
                sse_chunks_from_astream(
                    chunk_generator,
                    response_id=response_id,
                    model=model,
                    created_time=created_time,
                ),
                media_type="text/event-stream",
            )

        # Non-streaming: get full result using async
        result_text = await pipeline.achat(
            inp=user_msg, as_stream=False, thread_id=thread_id
        )
        if not isinstance(result_text, str):
            result_text = str(result_text)

        data = {
            "id": response_id,
            "object": "chat.completion",
            "created": created_time,
            "model": model,
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": result_text},
                    "finish_reason": "stop",
                }
            ],
            # Token accounting is not wired through the pipeline; report zeros.
            "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
        }
        return JSONResponse(content=data)

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"OpenAI-compatible endpoint error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.get("/")
async def root():
    """Service index: lists the endpoints this server exposes."""
    endpoints = ["/v1/chat/completions", "/v1/memory (DELETE)", "/health"]
    return {"message": "OpenAI-compatible Chat API", "endpoints": endpoints}
|
||||
|
||||
|
||||
@app.get("/health")
async def health():
    """Liveness probe."""
    return {"status": "healthy"}
|
||||
|
||||
|
||||
@app.delete("/v1/memory")
async def delete_memory(_: str = Depends(verify_api_key)):
    """Delete all conversation memory/history."""
    try:
        await pipeline.aclear_memory()
    except Exception as e:
        logger.error(f"Memory deletion error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
    return JSONResponse(content={"status": "success", "message": "Memory cleared"})
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Dev entrypoint: an "module:app" import string (not the app object) is
    # required for uvicorn's --reload to work.
    uvicorn.run(
        "server_openai:app",
        host="0.0.0.0",
        port=8588,
        reload=True,
    )
|
||||
290
lang_agent/fastapi_server/server_rest.py
Normal file
290
lang_agent/fastapi_server/server_rest.py
Normal file
@@ -0,0 +1,290 @@
|
||||
from fastapi import FastAPI, HTTPException, Request, Depends, Security, Path
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
from fastapi.security import APIKeyHeader
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional, Literal
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import json
|
||||
import uvicorn
|
||||
from loguru import logger
|
||||
import tyro
|
||||
|
||||
# Ensure we can import from project root
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
from lang_agent.pipeline import Pipeline, PipelineConfig
|
||||
from lang_agent.config.constants import API_KEY_HEADER, VALID_API_KEYS
|
||||
|
||||
# Initialize Pipeline once (matches existing server_* pattern)
# NOTE: tyro.cli runs at import time, so importing this module parses argv.
pipeline_config = tyro.cli(PipelineConfig)
logger.info(f"starting agent with pipeline: \n{pipeline_config}")
pipeline: Pipeline = pipeline_config.setup()
|
||||
|
||||
# API Key Authentication
|
||||
|
||||
async def verify_api_key(api_key: Optional[str] = Security(API_KEY_HEADER)):
    """Verify the API key from Authorization header (Bearer token format)."""
    if not api_key:
        # Tests expect 401 (not FastAPI's default 403) when auth header is missing.
        raise HTTPException(status_code=401, detail="Missing API key")
    if api_key.startswith("Bearer "):
        bare_key = api_key[len("Bearer "):]
    else:
        bare_key = api_key
    if VALID_API_KEYS and bare_key not in VALID_API_KEYS:
        raise HTTPException(status_code=401, detail="Invalid API key")
    return bare_key
|
||||
|
||||
|
||||
def _now_iso() -> str:
|
||||
# Avoid extra deps; good enough for API metadata.
|
||||
return time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
|
||||
|
||||
|
||||
def _new_conversation_id() -> str:
|
||||
return f"c_{os.urandom(12).hex()}"
|
||||
|
||||
|
||||
def _normalize_thread_id(conversation_id: str) -> str:
|
||||
"""
|
||||
Pipeline.achat supports a "{thread_id}_{device_id}" format.
|
||||
Memory is keyed by the base thread_id (before the device_id suffix).
|
||||
"""
|
||||
# Conversation IDs we mint are "c_{hex}" (2 segments). Some clients append a device_id:
|
||||
# e.g. "c_test123_device456" -> base thread "c_test123".
|
||||
parts = conversation_id.split("_")
|
||||
if len(parts) >= 3:
|
||||
return conversation_id.rsplit("_", 1)[0]
|
||||
return conversation_id
|
||||
|
||||
|
||||
def _try_clear_single_thread_memory(thread_id: str) -> bool:
    """
    Best-effort per-thread memory deletion.
    Returns True if we believe we cleared something, else False.
    """
    graph = getattr(pipeline, "graph", None)
    memory = getattr(graph, "memory", None)
    if not isinstance(memory, MemorySaver):
        # Only the in-process MemorySaver checkpointer supports this.
        return False
    try:
        memory.delete_thread(thread_id)
    except Exception as e:
        logger.warning(f"Failed to delete memory thread {thread_id}: {e}")
        return False
    return True
|
||||
|
||||
|
||||
class ConversationCreateResponse(BaseModel):
    """Response for POST /v1/conversations: new id plus creation timestamp."""

    # Conversation id of the form "c_<hex>".
    id: str
    # ISO-8601 UTC timestamp string.
    created_at: str
|
||||
|
||||
|
||||
class MessageCreateRequest(BaseModel):
    """Request body for POST /v1/conversations/{conversation_id}/messages."""

    # Keep this permissive so invalid roles get a 400 from endpoint logic (not 422 from validation).
    role: str = Field(default="user")
    # User message text forwarded to the pipeline.
    content: str
    # When True, respond with SSE deltas instead of one JSON body.
    stream: bool = Field(default=False)
|
||||
|
||||
|
||||
class MessageResponse(BaseModel):
    """Assistant reply payload; role is fixed to "assistant"."""

    role: Literal["assistant"] = Field(default="assistant")
    # Assistant reply text.
    content: str
|
||||
|
||||
|
||||
class ConversationMessageResponse(BaseModel):
    """Envelope pairing an assistant reply with its conversation id."""

    conversation_id: str
    message: MessageResponse
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
    """Request body for POST /v1/chat."""

    # User input text.
    input: str
    # Existing conversation to continue; a new one is minted when omitted.
    conversation_id: Optional[str] = Field(default=None)
    # When True, respond with SSE deltas instead of one JSON body.
    stream: bool = Field(default=False)
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
    """Response body for non-streaming POST /v1/chat."""

    conversation_id: str
    # Full assistant reply text.
    output: str
|
||||
|
||||
|
||||
app = FastAPI(
    title="REST Agent API",
    description="Resource-oriented REST API backed by Pipeline.achat (no RAG/eval/tools exposure).",
)

# Wide-open CORS: this server is expected to sit behind a trusted proxy.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||||
|
||||
|
||||
async def rest_sse_from_astream(
    chunk_generator, response_id: str, conversation_id: str
):
    """
    Stream chunks as SSE events.

    Format:
    - data: {"type":"delta","id":...,"conversation_id":...,"delta":"..."}
    - data: {"type":"done","id":...,"conversation_id":...}
    - data: [DONE]
    """

    def _sse(payload):
        return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"

    async for piece in chunk_generator:
        if not piece:
            continue
        yield _sse(
            {
                "type": "delta",
                "id": response_id,
                "conversation_id": conversation_id,
                "delta": piece,
            }
        )

    yield _sse({"type": "done", "id": response_id, "conversation_id": conversation_id})
    yield "data: [DONE]\n\n"
|
||||
|
||||
|
||||
@app.get("/")
async def root():
    """Service index: lists the REST endpoints this server exposes."""
    endpoints = [
        "/v1/conversations (POST)",
        "/v1/chat (POST)",
        "/v1/conversations/{conversation_id}/messages (POST)",
        "/v1/conversations/{conversation_id}/memory (DELETE)",
        "/v1/memory (DELETE)",
        "/health (GET)",
    ]
    return {"message": "REST Agent API", "endpoints": endpoints}
|
||||
|
||||
|
||||
@app.get("/health")
async def health():
    """Liveness probe."""
    return {"status": "healthy"}
|
||||
|
||||
|
||||
@app.post("/v1/conversations", response_model=ConversationCreateResponse)
async def create_conversation(_: str = Depends(verify_api_key)):
    """Mint a new conversation id; no server-side state is created yet."""
    new_id = _new_conversation_id()
    return ConversationCreateResponse(id=new_id, created_at=_now_iso())
|
||||
|
||||
|
||||
@app.post("/v1/chat")
async def chat(body: ChatRequest, _: str = Depends(verify_api_key)):
    """One-shot chat endpoint; mints a conversation id when none is given."""
    conversation_id = body.conversation_id or _new_conversation_id()
    response_id = f"restcmpl-{os.urandom(12).hex()}"

    if body.stream:
        stream_gen = await pipeline.achat(
            inp=body.input, as_stream=True, thread_id=conversation_id
        )
        sse_gen = rest_sse_from_astream(
            stream_gen,
            response_id=response_id,
            conversation_id=conversation_id,
        )
        return StreamingResponse(sse_gen, media_type="text/event-stream")

    answer = await pipeline.achat(
        inp=body.input, as_stream=False, thread_id=conversation_id
    )
    if not isinstance(answer, str):
        answer = str(answer)
    payload = ChatResponse(conversation_id=conversation_id, output=answer)
    return JSONResponse(content=payload.model_dump())
|
||||
|
||||
|
||||
@app.post("/v1/conversations/{conversation_id}/messages")
async def create_message(
    body: MessageCreateRequest,
    conversation_id: str = Path(...),
    _: str = Depends(verify_api_key),
):
    """Append a user message to a conversation and return/stream the reply."""
    if body.role != "user":
        raise HTTPException(status_code=400, detail="Only role='user' is supported")

    response_id = f"restmsg-{os.urandom(12).hex()}"

    if body.stream:
        stream_gen = await pipeline.achat(
            inp=body.content, as_stream=True, thread_id=conversation_id
        )
        sse_gen = rest_sse_from_astream(
            stream_gen,
            response_id=response_id,
            conversation_id=conversation_id,
        )
        return StreamingResponse(sse_gen, media_type="text/event-stream")

    answer = await pipeline.achat(
        inp=body.content, as_stream=False, thread_id=conversation_id
    )
    if not isinstance(answer, str):
        answer = str(answer)
    result = ConversationMessageResponse(
        conversation_id=conversation_id, message=MessageResponse(content=answer)
    )
    return JSONResponse(content=result.model_dump())
|
||||
|
||||
|
||||
@app.delete("/v1/memory")
async def delete_all_memory(_: str = Depends(verify_api_key)):
    """Delete all conversation memory/history."""
    try:
        await pipeline.aclear_memory()
    except Exception as e:
        logger.error(f"Memory deletion error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
    return JSONResponse(content={"status": "success", "scope": "all"})
|
||||
|
||||
|
||||
@app.delete("/v1/conversations/{conversation_id}/memory")
async def delete_conversation_memory(
    conversation_id: str = Path(...),
    _: str = Depends(verify_api_key),
):
    """
    Best-effort per-conversation memory deletion.

    Note: Pipeline exposes only global clear; per-thread delete is done by directly
    deleting the thread in the underlying MemorySaver if present.
    """
    # Strip any client-appended device suffix before touching memory.
    thread_id = _normalize_thread_id(conversation_id)
    cleared = _try_clear_single_thread_memory(thread_id)
    if cleared:
        return JSONResponse(
            content={
                "status": "success",
                "scope": "conversation",
                "conversation_id": conversation_id,
            }
        )
    # 501: the current graph's checkpointer does not support per-thread delete.
    return JSONResponse(
        content={
            "status": "unsupported",
            "message": "Per-conversation memory clearing not supported by current graph; use DELETE /v1/memory instead.",
            "conversation_id": conversation_id,
        },
        status_code=501,
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Dev entry point: run with auto-reload on port 8589.
    uvicorn.run("server_rest:app", host="0.0.0.0", port=8589, reload=True)
|
||||
208
lang_agent/fastapi_server/server_viewer.py
Normal file
208
lang_agent/fastapi_server/server_viewer.py
Normal file
@@ -0,0 +1,208 @@
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Dict, Optional
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
import time
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# Ensure we can import from project root
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from lang_agent.components.conv_store import ConversationStore
|
||||
|
||||
# FastAPI application for the conversation viewer UI and its JSON API.
app = FastAPI(
    title="Conversation Viewer",
    description="Web UI to view conversations from the database",
)

# Permit cross-origin requests so the viewer page can be served from a
# different host/port than this API.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is acceptable for the deployment environment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||||
|
||||
# Initialize the conversation store; fall back to None when the connection
# string is missing so endpoints can report a configuration error instead of
# crashing at import time.
conv_store = None
try:
    conv_store = ConversationStore()
except ValueError as e:
    print(f"Warning: {e}. Make sure CONN_STR environment variable is set.")
|
||||
|
||||
|
||||
class MessageResponse(BaseModel):
    """A single stored message as returned by the viewer API."""

    # Sender/type label of the message — presumably the message_type column
    # from the messages table (e.g. "human"/"ai"); confirm against ConversationStore.
    message_type: str
    # Message body text.
    content: str
    # Position of the message within its conversation.
    sequence_number: int
    # ISO-8601 creation timestamp (empty string when unset — see
    # get_conversation_messages).
    created_at: str
|
||||
|
||||
|
||||
class ConversationListItem(BaseModel):
    """Summary row for one conversation in the sidebar list."""

    # Unique identifier of the conversation (thread).
    conversation_id: str
    # Total number of messages stored for this conversation.
    message_count: int
    # ISO-8601 timestamp of the most recent message, or None when unknown.
    last_updated: Optional[str] = None
|
||||
|
||||
|
||||
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the viewer's single-page HTML UI from the static directory."""
    html_path = Path(__file__).parent.parent / "static" / "viewer.html"
    if not html_path.exists():
        # Fall back to a hint page when the static asset is missing.
        return HTMLResponse(content="<h1>Viewer HTML not found. Please create static/viewer.html</h1>")
    return HTMLResponse(content=html_path.read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
@app.get("/api/conversations", response_model=List[ConversationListItem])
async def list_conversations():
    """Return a summary (id, message count, last activity) for every conversation."""
    if conv_store is None:
        raise HTTPException(status_code=500, detail="Database connection not configured")

    import psycopg

    conn_str = os.environ.get("CONN_STR")
    if not conn_str:
        raise HTTPException(status_code=500, detail="CONN_STR not set")

    with psycopg.connect(conn_str) as conn:
        with conn.cursor(row_factory=psycopg.rows.dict_row) as cur:
            # One aggregate row per conversation, most recently active first.
            cur.execute("""
                SELECT
                    conversation_id,
                    COUNT(*) as message_count,
                    MAX(created_at) as last_updated
                FROM messages
                GROUP BY conversation_id
                ORDER BY last_updated DESC
            """)
            rows = cur.fetchall()

    items = []
    for row in rows:
        ts = row["last_updated"]
        items.append(
            ConversationListItem(
                conversation_id=row["conversation_id"],
                message_count=row["message_count"],
                last_updated=ts.isoformat() if ts else None,
            )
        )
    return items
|
||||
|
||||
|
||||
@app.get("/api/conversations/{conversation_id}/messages", response_model=List[MessageResponse])
async def get_conversation_messages(conversation_id: str):
    """Return every stored message of one conversation, in stored order."""
    if conv_store is None:
        raise HTTPException(status_code=500, detail="Database connection not configured")

    records = conv_store.get_conversation(conversation_id)

    responses = []
    for msg in records:
        created = msg["created_at"]
        responses.append(
            MessageResponse(
                message_type=msg["message_type"],
                content=msg["content"],
                sequence_number=msg["sequence_number"],
                created_at=created.isoformat() if created else "",
            )
        )
    return responses
|
||||
|
||||
|
||||
@app.get("/health")
async def health():
    """Liveness probe; also reports whether the conversation store initialized."""
    db_connected = conv_store is not None
    return {"status": "healthy", "db_connected": db_connected}
|
||||
|
||||
|
||||
@app.get("/api/events")
async def stream_events():
    """Server-Sent Events endpoint for live updates.

    Polls the messages table every ``check_interval`` seconds and emits
    'conversation_new', 'conversation_updated', and 'conversation_deleted'
    events as SSE ``data:`` frames by diffing against the previous poll.

    NOTE(review): on the first poll every existing conversation is emitted as
    'conversation_new' because ``last_check`` starts empty — confirm the UI
    expects this initial snapshot behavior.
    """
    if conv_store is None:
        raise HTTPException(status_code=500, detail="Database connection not configured")

    import psycopg
    conn_str = os.environ.get("CONN_STR")
    if not conn_str:
        raise HTTPException(status_code=500, detail="CONN_STR not set")

    async def event_generator():
        # conversation_id -> {"message_count", "last_updated", "timestamp"} from
        # the previous poll; used to detect new/updated/deleted conversations.
        last_check = {}
        check_interval = 2.0  # Check every 2 seconds

        while True:
            try:
                # NOTE(review): psycopg.connect and the query here are synchronous
                # calls inside an async generator, so each poll blocks the event
                # loop for its duration — consider an async driver or a thread.
                with psycopg.connect(conn_str) as conn:
                    with conn.cursor(row_factory=psycopg.rows.dict_row) as cur:
                        # Get current state of all conversations
                        cur.execute("""
                            SELECT
                                conversation_id,
                                COUNT(*) as message_count,
                                MAX(created_at) as last_updated
                            FROM messages
                            GROUP BY conversation_id
                            ORDER BY last_updated DESC
                        """)
                        results = cur.fetchall()

                        current_state = {}
                        for row in results:
                            conv_id = row["conversation_id"]
                            last_updated = row["last_updated"]
                            message_count = row["message_count"]

                            # "timestamp" keeps a numeric copy for cheap ordering
                            # comparisons against the previous poll.
                            current_state[conv_id] = {
                                "message_count": message_count,
                                "last_updated": last_updated.isoformat() if last_updated else None,
                                "timestamp": last_updated.timestamp() if last_updated else 0
                            }

                            # Check if this conversation is new or updated
                            if conv_id not in last_check:
                                # New conversation
                                yield f"data: {json.dumps({'type': 'conversation_new', 'conversation': {'conversation_id': conv_id, 'message_count': message_count, 'last_updated': current_state[conv_id]['last_updated']}})}\n\n"
                            elif last_check[conv_id]["timestamp"] < current_state[conv_id]["timestamp"]:
                                # Updated conversation (new messages)
                                yield f"data: {json.dumps({'type': 'conversation_updated', 'conversation': {'conversation_id': conv_id, 'message_count': message_count, 'last_updated': current_state[conv_id]['last_updated']}})}\n\n"

                        # Check for deleted conversations
                        for conv_id in last_check:
                            if conv_id not in current_state:
                                yield f"data: {json.dumps({'type': 'conversation_deleted', 'conversation_id': conv_id})}\n\n"

                last_check = current_state

                await asyncio.sleep(check_interval)

            except Exception as e:
                # Report the polling failure to the client and keep the stream
                # alive rather than terminating the SSE connection.
                yield f"data: {json.dumps({'type': 'error', 'message': str(e)})}\n\n"
                await asyncio.sleep(check_interval)

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",  # Disable buffering for nginx
        }
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import uvicorn

    # Dev entry point: run the viewer with auto-reload on port 8590.
    uvicorn.run("server_viewer:app", host="0.0.0.0", port=8590, reload=True)
|
||||
|
||||
Reference in New Issue
Block a user