unified constants
This commit is contained in:
@@ -13,16 +13,17 @@ from pydantic import BaseModel, Field
|
||||
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
|
||||
|
||||
from lang_agent.config.db_config_manager import DBConfigManager
|
||||
from lang_agent.config.constants import (
|
||||
_PROJECT_ROOT,
|
||||
MCP_CONFIG_PATH,
|
||||
MCP_CONFIG_DEFAULT_CONTENT,
|
||||
PIPELINE_REGISTRY_PATH,
|
||||
)
|
||||
from lang_agent.front_api.build_server_utils import (
|
||||
GRAPH_BUILD_FNCS,
|
||||
update_pipeline_registry,
|
||||
)
|
||||
|
||||
_PROJECT_ROOT = osp.dirname(osp.dirname(osp.abspath(__file__)))
|
||||
_MCP_CONFIG_PATH = osp.join(_PROJECT_ROOT, "configs", "mcp_config.json")
|
||||
_MCP_CONFIG_DEFAULT_CONTENT = "{\n}\n"
|
||||
_PIPELINE_REGISTRY_PATH = osp.join(_PROJECT_ROOT, "configs", "pipeline_registry.json")
|
||||
|
||||
|
||||
class GraphConfigUpsertRequest(BaseModel):
|
||||
graph_id: str
|
||||
@@ -206,20 +207,20 @@ def _parse_mcp_tool_keys(raw_content: str) -> List[str]:
|
||||
|
||||
|
||||
def _read_mcp_config_raw() -> str:
|
||||
if not osp.exists(_MCP_CONFIG_PATH):
|
||||
os.makedirs(osp.dirname(_MCP_CONFIG_PATH), exist_ok=True)
|
||||
with open(_MCP_CONFIG_PATH, "w", encoding="utf-8") as f:
|
||||
f.write(_MCP_CONFIG_DEFAULT_CONTENT)
|
||||
with open(_MCP_CONFIG_PATH, "r", encoding="utf-8") as f:
|
||||
if not osp.exists(MCP_CONFIG_PATH):
|
||||
os.makedirs(osp.dirname(MCP_CONFIG_PATH), exist_ok=True)
|
||||
with open(MCP_CONFIG_PATH, "w", encoding="utf-8") as f:
|
||||
f.write(MCP_CONFIG_DEFAULT_CONTENT)
|
||||
with open(MCP_CONFIG_PATH, "r", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def _read_pipeline_registry() -> Dict[str, Any]:
|
||||
if not osp.exists(_PIPELINE_REGISTRY_PATH):
|
||||
os.makedirs(osp.dirname(_PIPELINE_REGISTRY_PATH), exist_ok=True)
|
||||
with open(_PIPELINE_REGISTRY_PATH, "w", encoding="utf-8") as f:
|
||||
if not osp.exists(PIPELINE_REGISTRY_PATH):
|
||||
os.makedirs(osp.dirname(PIPELINE_REGISTRY_PATH), exist_ok=True)
|
||||
with open(PIPELINE_REGISTRY_PATH, "w", encoding="utf-8") as f:
|
||||
json.dump({"pipelines": {}, "api_keys": {}}, f, indent=2)
|
||||
with open(_PIPELINE_REGISTRY_PATH, "r", encoding="utf-8") as f:
|
||||
with open(PIPELINE_REGISTRY_PATH, "r", encoding="utf-8") as f:
|
||||
registry = json.load(f)
|
||||
pipelines = registry.get("pipelines")
|
||||
if not isinstance(pipelines, dict):
|
||||
@@ -233,8 +234,8 @@ def _read_pipeline_registry() -> Dict[str, Any]:
|
||||
|
||||
|
||||
def _write_pipeline_registry(registry: Dict[str, Any]) -> None:
|
||||
os.makedirs(osp.dirname(_PIPELINE_REGISTRY_PATH), exist_ok=True)
|
||||
with open(_PIPELINE_REGISTRY_PATH, "w", encoding="utf-8") as f:
|
||||
os.makedirs(osp.dirname(PIPELINE_REGISTRY_PATH), exist_ok=True)
|
||||
with open(PIPELINE_REGISTRY_PATH, "w", encoding="utf-8") as f:
|
||||
json.dump(registry, f, indent=2)
|
||||
f.write("\n")
|
||||
|
||||
@@ -433,7 +434,7 @@ async def get_mcp_tool_config():
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
return McpConfigReadResponse(
|
||||
path=_MCP_CONFIG_PATH,
|
||||
path=MCP_CONFIG_PATH,
|
||||
raw_content=raw_content,
|
||||
tool_keys=tool_keys,
|
||||
)
|
||||
@@ -443,8 +444,8 @@ async def get_mcp_tool_config():
|
||||
async def update_mcp_tool_config(body: McpConfigUpdateRequest):
|
||||
try:
|
||||
tool_keys = _parse_mcp_tool_keys(body.raw_content)
|
||||
os.makedirs(osp.dirname(_MCP_CONFIG_PATH), exist_ok=True)
|
||||
with open(_MCP_CONFIG_PATH, "w", encoding="utf-8") as f:
|
||||
os.makedirs(osp.dirname(MCP_CONFIG_PATH), exist_ok=True)
|
||||
with open(MCP_CONFIG_PATH, "w", encoding="utf-8") as f:
|
||||
# Keep user formatting/comments as entered while ensuring trailing newline.
|
||||
f.write(body.raw_content.rstrip() + "\n")
|
||||
except ValueError as e:
|
||||
@@ -453,7 +454,7 @@ async def update_mcp_tool_config(body: McpConfigUpdateRequest):
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
return McpConfigUpdateResponse(
|
||||
status="updated",
|
||||
path=_MCP_CONFIG_PATH,
|
||||
path=MCP_CONFIG_PATH,
|
||||
tool_keys=tool_keys,
|
||||
)
|
||||
|
||||
@@ -528,7 +529,7 @@ async def create_pipeline(body: PipelineCreateRequest):
|
||||
config_file=config_file,
|
||||
llm_name=body.llm_name,
|
||||
enabled=body.enabled,
|
||||
registry_f=_PIPELINE_REGISTRY_PATH,
|
||||
registry_f=PIPELINE_REGISTRY_PATH,
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to register pipeline: {e}")
|
||||
@@ -543,7 +544,8 @@ async def create_pipeline(body: PipelineCreateRequest):
|
||||
normalized = _normalize_pipeline_spec(pipeline_id, pipeline_spec)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to read pipeline registry after update: {e}"
|
||||
status_code=500,
|
||||
detail=f"Failed to read pipeline registry after update: {e}",
|
||||
)
|
||||
|
||||
return PipelineCreateResponse(
|
||||
@@ -554,7 +556,7 @@ async def create_pipeline(body: PipelineCreateRequest):
|
||||
llm_name=normalized.llm_name,
|
||||
enabled=normalized.enabled,
|
||||
reload_required=True,
|
||||
registry_path=_PIPELINE_REGISTRY_PATH,
|
||||
registry_path=PIPELINE_REGISTRY_PATH,
|
||||
)
|
||||
|
||||
|
||||
@@ -608,7 +610,9 @@ async def list_pipeline_api_keys():
|
||||
async def upsert_pipeline_api_key_policy(api_key: str, body: ApiKeyPolicyUpsertRequest):
|
||||
normalized_key = api_key.strip()
|
||||
if not normalized_key:
|
||||
raise HTTPException(status_code=400, detail="api_key path parameter is required")
|
||||
raise HTTPException(
|
||||
status_code=400, detail="api_key path parameter is required"
|
||||
)
|
||||
try:
|
||||
registry = _read_pipeline_registry()
|
||||
pipelines = registry.get("pipelines", {})
|
||||
@@ -662,7 +666,9 @@ async def upsert_pipeline_api_key_policy(api_key: str, body: ApiKeyPolicyUpsertR
|
||||
async def delete_pipeline_api_key_policy(api_key: str):
|
||||
normalized_key = api_key.strip()
|
||||
if not normalized_key:
|
||||
raise HTTPException(status_code=400, detail="api_key path parameter is required")
|
||||
raise HTTPException(
|
||||
status_code=400, detail="api_key path parameter is required"
|
||||
)
|
||||
try:
|
||||
registry = _read_pipeline_registry()
|
||||
api_keys = registry.get("api_keys", {})
|
||||
|
||||
@@ -18,29 +18,18 @@ sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
|
||||
|
||||
from lang_agent.pipeline import PipelineConfig
|
||||
from lang_agent.components.server_pipeline_manager import ServerPipelineManager
|
||||
from lang_agent.config.constants import PIPELINE_REGISTRY_PATH, API_KEY_HEADER, VALID_API_KEYS
|
||||
|
||||
# Load base config for route-level overrides (pipelines are lazy-loaded from registry)
|
||||
pipeline_config = tyro.cli(PipelineConfig)
|
||||
logger.info(f"starting agent with base pipeline config: \n{pipeline_config}")
|
||||
|
||||
# API Key Authentication
|
||||
API_KEY_HEADER = APIKeyHeader(name="Authorization", auto_error=True)
|
||||
VALID_API_KEYS = set(filter(None, os.environ.get("FAST_AUTH_KEYS", "").split(",")))
|
||||
REGISTRY_FILE = os.environ.get(
|
||||
"FAST_PIPELINE_REGISTRY_FILE",
|
||||
osp.join(
|
||||
osp.dirname(osp.dirname(osp.abspath(__file__))),
|
||||
"configs",
|
||||
"pipeline_registry.json",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
PIPELINE_MANAGER = ServerPipelineManager(
|
||||
default_pipeline_id=os.environ.get("FAST_DEFAULT_PIPELINE_ID", "default"),
|
||||
default_config=pipeline_config,
|
||||
)
|
||||
PIPELINE_MANAGER.load_registry(REGISTRY_FILE)
|
||||
PIPELINE_MANAGER.load_registry(PIPELINE_REGISTRY_PATH)
|
||||
|
||||
|
||||
async def verify_api_key(api_key: str = Security(API_KEY_HEADER)):
|
||||
|
||||
@@ -16,15 +16,12 @@ import tyro
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from lang_agent.pipeline import Pipeline, PipelineConfig
|
||||
from lang_agent.config.constants import API_KEY_HEADER, VALID_API_KEYS
|
||||
|
||||
# Initialize Pipeline once
|
||||
pipeline_config = tyro.cli(PipelineConfig)
|
||||
pipeline: Pipeline = pipeline_config.setup()
|
||||
|
||||
# API Key Authentication
|
||||
API_KEY_HEADER = APIKeyHeader(name="Authorization", auto_error=True)
|
||||
VALID_API_KEYS = set(filter(None, os.environ.get("FAST_AUTH_KEYS", "").split(",")))
|
||||
|
||||
|
||||
async def verify_api_key(api_key: str = Security(API_KEY_HEADER)):
|
||||
"""Verify the API key from Authorization header (Bearer token format)."""
|
||||
@@ -46,12 +43,12 @@ class OpenAIChatCompletionRequest(BaseModel):
|
||||
temperature: Optional[float] = Field(default=1.0)
|
||||
max_tokens: Optional[int] = Field(default=None)
|
||||
# Optional overrides for pipeline behavior
|
||||
thread_id: Optional[str] = Field(default='3')
|
||||
thread_id: Optional[str] = Field(default="3")
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="OpenAI-Compatible Chat API",
|
||||
description="OpenAI Chat Completions API compatible endpoint backed by pipeline.chat"
|
||||
description="OpenAI Chat Completions API compatible endpoint backed by pipeline.chat",
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
@@ -63,7 +60,9 @@ app.add_middleware(
|
||||
)
|
||||
|
||||
|
||||
def sse_chunks_from_stream(chunk_generator, response_id: str, model: str, created_time: int):
|
||||
def sse_chunks_from_stream(
|
||||
chunk_generator, response_id: str, model: str, created_time: int
|
||||
):
|
||||
"""
|
||||
Stream chunks from pipeline and format as OpenAI SSE.
|
||||
"""
|
||||
@@ -75,14 +74,8 @@ def sse_chunks_from_stream(chunk_generator, response_id: str, model: str, create
|
||||
"created": created_time,
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"content": chunk
|
||||
},
|
||||
"finish_reason": None
|
||||
}
|
||||
]
|
||||
{"index": 0, "delta": {"content": chunk}, "finish_reason": None}
|
||||
],
|
||||
}
|
||||
yield f"data: {json.dumps(data)}\n\n"
|
||||
|
||||
@@ -92,19 +85,15 @@ def sse_chunks_from_stream(chunk_generator, response_id: str, model: str, create
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created_time,
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
]
|
||||
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
||||
}
|
||||
yield f"data: {json.dumps(final)}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
|
||||
async def sse_chunks_from_astream(chunk_generator, response_id: str, model: str, created_time: int):
|
||||
async def sse_chunks_from_astream(
|
||||
chunk_generator, response_id: str, model: str, created_time: int
|
||||
):
|
||||
"""
|
||||
Async version: Stream chunks from pipeline and format as OpenAI SSE.
|
||||
"""
|
||||
@@ -116,14 +105,8 @@ async def sse_chunks_from_astream(chunk_generator, response_id: str, model: str,
|
||||
"created": created_time,
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"content": chunk
|
||||
},
|
||||
"finish_reason": None
|
||||
}
|
||||
]
|
||||
{"index": 0, "delta": {"content": chunk}, "finish_reason": None}
|
||||
],
|
||||
}
|
||||
yield f"data: {json.dumps(data)}\n\n"
|
||||
|
||||
@@ -133,13 +116,7 @@ async def sse_chunks_from_astream(chunk_generator, response_id: str, model: str,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created_time,
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
]
|
||||
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
||||
}
|
||||
yield f"data: {json.dumps(final)}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
@@ -149,15 +126,15 @@ async def sse_chunks_from_astream(chunk_generator, response_id: str, model: str,
|
||||
async def chat_completions(request: Request, _: str = Depends(verify_api_key)):
|
||||
try:
|
||||
body = await request.json()
|
||||
|
||||
|
||||
messages = body.get("messages")
|
||||
if not messages:
|
||||
raise HTTPException(status_code=400, detail="messages is required")
|
||||
|
||||
|
||||
stream = body.get("stream", False)
|
||||
model = body.get("model", "gpt-3.5-turbo")
|
||||
thread_id = body.get("thread_id", 3)
|
||||
|
||||
|
||||
# Extract latest user message
|
||||
user_msg = None
|
||||
for m in reversed(messages):
|
||||
@@ -166,27 +143,36 @@ async def chat_completions(request: Request, _: str = Depends(verify_api_key)):
|
||||
if role == "user" and content:
|
||||
user_msg = content
|
||||
break
|
||||
|
||||
|
||||
if user_msg is None:
|
||||
last = messages[-1]
|
||||
user_msg = last.get("content") if isinstance(last, dict) else str(last)
|
||||
|
||||
|
||||
response_id = f"chatcmpl-{os.urandom(12).hex()}"
|
||||
created_time = int(time.time())
|
||||
|
||||
|
||||
if stream:
|
||||
# Use async streaming from pipeline
|
||||
chunk_generator = await pipeline.achat(inp=user_msg, as_stream=True, thread_id=thread_id)
|
||||
chunk_generator = await pipeline.achat(
|
||||
inp=user_msg, as_stream=True, thread_id=thread_id
|
||||
)
|
||||
return StreamingResponse(
|
||||
sse_chunks_from_astream(chunk_generator, response_id=response_id, model=model, created_time=created_time),
|
||||
sse_chunks_from_astream(
|
||||
chunk_generator,
|
||||
response_id=response_id,
|
||||
model=model,
|
||||
created_time=created_time,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
|
||||
# Non-streaming: get full result using async
|
||||
result_text = await pipeline.achat(inp=user_msg, as_stream=False, thread_id=thread_id)
|
||||
result_text = await pipeline.achat(
|
||||
inp=user_msg, as_stream=False, thread_id=thread_id
|
||||
)
|
||||
if not isinstance(result_text, str):
|
||||
result_text = str(result_text)
|
||||
|
||||
|
||||
data = {
|
||||
"id": response_id,
|
||||
"object": "chat.completion",
|
||||
@@ -195,21 +181,14 @@ async def chat_completions(request: Request, _: str = Depends(verify_api_key)):
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": result_text
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
"message": {"role": "assistant", "content": result_text},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
}
|
||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||
}
|
||||
return JSONResponse(content=data)
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -221,11 +200,7 @@ async def chat_completions(request: Request, _: str = Depends(verify_api_key)):
|
||||
async def root():
|
||||
return {
|
||||
"message": "OpenAI-compatible Chat API",
|
||||
"endpoints": [
|
||||
"/v1/chat/completions",
|
||||
"/v1/memory (DELETE)",
|
||||
"/health"
|
||||
]
|
||||
"endpoints": ["/v1/chat/completions", "/v1/memory (DELETE)", "/health"],
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
from lang_agent.pipeline import Pipeline, PipelineConfig
|
||||
from lang_agent.config.constants import API_KEY_HEADER, VALID_API_KEYS
|
||||
|
||||
# Initialize Pipeline once (matches existing server_* pattern)
|
||||
pipeline_config = tyro.cli(PipelineConfig)
|
||||
@@ -24,9 +25,6 @@ logger.info(f"starting agent with pipeline: \n{pipeline_config}")
|
||||
pipeline: Pipeline = pipeline_config.setup()
|
||||
|
||||
# API Key Authentication
|
||||
API_KEY_HEADER = APIKeyHeader(name="Authorization", auto_error=False)
|
||||
VALID_API_KEYS = set(filter(None, os.environ.get("FAST_AUTH_KEYS", "").split(",")))
|
||||
|
||||
|
||||
async def verify_api_key(api_key: Optional[str] = Security(API_KEY_HEADER)):
|
||||
"""Verify the API key from Authorization header (Bearer token format)."""
|
||||
@@ -125,7 +123,9 @@ app.add_middleware(
|
||||
)
|
||||
|
||||
|
||||
async def rest_sse_from_astream(chunk_generator, response_id: str, conversation_id: str):
|
||||
async def rest_sse_from_astream(
|
||||
chunk_generator, response_id: str, conversation_id: str
|
||||
):
|
||||
"""
|
||||
Stream chunks as SSE events.
|
||||
|
||||
@@ -185,7 +185,9 @@ async def chat(body: ChatRequest, _: str = Depends(verify_api_key)):
|
||||
)
|
||||
return StreamingResponse(
|
||||
rest_sse_from_astream(
|
||||
chunk_generator, response_id=response_id, conversation_id=conversation_id
|
||||
chunk_generator,
|
||||
response_id=response_id,
|
||||
conversation_id=conversation_id,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
@@ -195,7 +197,11 @@ async def chat(body: ChatRequest, _: str = Depends(verify_api_key)):
|
||||
)
|
||||
if not isinstance(result_text, str):
|
||||
result_text = str(result_text)
|
||||
return JSONResponse(content=ChatResponse(conversation_id=conversation_id, output=result_text).model_dump())
|
||||
return JSONResponse(
|
||||
content=ChatResponse(
|
||||
conversation_id=conversation_id, output=result_text
|
||||
).model_dump()
|
||||
)
|
||||
|
||||
|
||||
@app.post("/v1/conversations/{conversation_id}/messages")
|
||||
@@ -215,7 +221,9 @@ async def create_message(
|
||||
)
|
||||
return StreamingResponse(
|
||||
rest_sse_from_astream(
|
||||
chunk_generator, response_id=response_id, conversation_id=conversation_id
|
||||
chunk_generator,
|
||||
response_id=response_id,
|
||||
conversation_id=conversation_id,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
@@ -257,7 +265,11 @@ async def delete_conversation_memory(
|
||||
cleared = _try_clear_single_thread_memory(thread_id)
|
||||
if cleared:
|
||||
return JSONResponse(
|
||||
content={"status": "success", "scope": "conversation", "conversation_id": conversation_id}
|
||||
content={
|
||||
"status": "success",
|
||||
"scope": "conversation",
|
||||
"conversation_id": conversation_id,
|
||||
}
|
||||
)
|
||||
return JSONResponse(
|
||||
content={
|
||||
@@ -276,5 +288,3 @@ if __name__ == "__main__":
|
||||
port=8589,
|
||||
reload=True,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user