kimi magic

This commit is contained in:
2026-03-04 15:21:07 +08:00
parent 6ed33f3185
commit 91685d5bf7
3 changed files with 171 additions and 142 deletions

View File

@@ -13,13 +13,17 @@ from pydantic import BaseModel, Field
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__)))) sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
from lang_agent.config.db_config_manager import DBConfigManager from lang_agent.config.db_config_manager import DBConfigManager
from lang_agent.front_api.build_server_utils import GRAPH_BUILD_FNCS, update_pipeline_registry from lang_agent.front_api.build_server_utils import (
GRAPH_BUILD_FNCS,
update_pipeline_registry,
)
_PROJECT_ROOT = osp.dirname(osp.dirname(osp.abspath(__file__))) _PROJECT_ROOT = osp.dirname(osp.dirname(osp.abspath(__file__)))
_MCP_CONFIG_PATH = osp.join(_PROJECT_ROOT, "configs", "mcp_config.json") _MCP_CONFIG_PATH = osp.join(_PROJECT_ROOT, "configs", "mcp_config.json")
_MCP_CONFIG_DEFAULT_CONTENT = "{\n}\n" _MCP_CONFIG_DEFAULT_CONTENT = "{\n}\n"
_PIPELINE_REGISTRY_PATH = osp.join(_PROJECT_ROOT, "configs", "pipeline_registry.json") _PIPELINE_REGISTRY_PATH = osp.join(_PROJECT_ROOT, "configs", "pipeline_registry.json")
class GraphConfigUpsertRequest(BaseModel): class GraphConfigUpsertRequest(BaseModel):
graph_id: str graph_id: str
pipeline_id: str pipeline_id: str
@@ -28,6 +32,7 @@ class GraphConfigUpsertRequest(BaseModel):
prompt_dict: Dict[str, str] = Field(default_factory=dict) prompt_dict: Dict[str, str] = Field(default_factory=dict)
api_key: Optional[str] = Field(default=None) api_key: Optional[str] = Field(default=None)
class GraphConfigUpsertResponse(BaseModel): class GraphConfigUpsertResponse(BaseModel):
graph_id: str graph_id: str
pipeline_id: str pipeline_id: str
@@ -36,6 +41,7 @@ class GraphConfigUpsertResponse(BaseModel):
prompt_keys: List[str] prompt_keys: List[str]
api_key: str api_key: str
class GraphConfigReadResponse(BaseModel): class GraphConfigReadResponse(BaseModel):
graph_id: Optional[str] = Field(default=None) graph_id: Optional[str] = Field(default=None)
pipeline_id: str pipeline_id: str
@@ -44,6 +50,7 @@ class GraphConfigReadResponse(BaseModel):
prompt_dict: Dict[str, str] prompt_dict: Dict[str, str]
api_key: str = Field(default="") api_key: str = Field(default="")
class GraphConfigListItem(BaseModel): class GraphConfigListItem(BaseModel):
graph_id: Optional[str] = Field(default=None) graph_id: Optional[str] = Field(default=None)
pipeline_id: str pipeline_id: str
@@ -56,10 +63,12 @@ class GraphConfigListItem(BaseModel):
created_at: Optional[str] = Field(default=None) created_at: Optional[str] = Field(default=None)
updated_at: Optional[str] = Field(default=None) updated_at: Optional[str] = Field(default=None)
class GraphConfigListResponse(BaseModel): class GraphConfigListResponse(BaseModel):
items: List[GraphConfigListItem] items: List[GraphConfigListItem]
count: int count: int
class PipelineCreateRequest(BaseModel): class PipelineCreateRequest(BaseModel):
graph_id: str = Field( graph_id: str = Field(
description="Graph key from GRAPH_BUILD_FNCS, e.g. routing or react" description="Graph key from GRAPH_BUILD_FNCS, e.g. routing or react"
@@ -71,9 +80,8 @@ class PipelineCreateRequest(BaseModel):
api_key: str api_key: str
entry_point: str = Field(default="fastapi_server/server_dashscope.py") entry_point: str = Field(default="fastapi_server/server_dashscope.py")
llm_name: str = Field(default="qwen-plus") llm_name: str = Field(default="qwen-plus")
route_id: Optional[str] = Field(default=None)
enabled: bool = Field(default=True) enabled: bool = Field(default=True)
prompt_pipeline_id: Optional[str] = Field(default=None)
class PipelineCreateResponse(BaseModel): class PipelineCreateResponse(BaseModel):
run_id: str run_id: str
@@ -87,12 +95,12 @@ class PipelineCreateResponse(BaseModel):
auth_header_name: str auth_header_name: str
auth_key_once: str auth_key_once: str
auth_key_masked: str auth_key_masked: str
route_id: str
enabled: bool enabled: bool
config_file: str config_file: str
reload_required: bool reload_required: bool
registry_path: str registry_path: str
class PipelineRunInfo(BaseModel): class PipelineRunInfo(BaseModel):
run_id: str run_id: str
pid: int pid: int
@@ -104,29 +112,33 @@ class PipelineRunInfo(BaseModel):
auth_type: str auth_type: str
auth_header_name: str auth_header_name: str
auth_key_masked: str auth_key_masked: str
route_id: str
enabled: bool enabled: bool
config_file: Optional[str] = Field(default=None) config_file: Optional[str] = Field(default=None)
class PipelineListResponse(BaseModel): class PipelineListResponse(BaseModel):
items: List[PipelineRunInfo] items: List[PipelineRunInfo]
count: int count: int
class PipelineStopResponse(BaseModel): class PipelineStopResponse(BaseModel):
run_id: str run_id: str
status: str status: str
route_id: str pipeline_id: str
enabled: bool enabled: bool
reload_required: bool reload_required: bool
class McpConfigReadResponse(BaseModel): class McpConfigReadResponse(BaseModel):
path: str path: str
raw_content: str raw_content: str
tool_keys: List[str] tool_keys: List[str]
class McpConfigUpdateRequest(BaseModel): class McpConfigUpdateRequest(BaseModel):
raw_content: str raw_content: str
class McpConfigUpdateResponse(BaseModel): class McpConfigUpdateResponse(BaseModel):
status: str status: str
path: str path: str
@@ -196,12 +208,12 @@ def _read_pipeline_registry() -> Dict[str, Any]:
if not osp.exists(_PIPELINE_REGISTRY_PATH): if not osp.exists(_PIPELINE_REGISTRY_PATH):
os.makedirs(osp.dirname(_PIPELINE_REGISTRY_PATH), exist_ok=True) os.makedirs(osp.dirname(_PIPELINE_REGISTRY_PATH), exist_ok=True)
with open(_PIPELINE_REGISTRY_PATH, "w", encoding="utf-8") as f: with open(_PIPELINE_REGISTRY_PATH, "w", encoding="utf-8") as f:
json.dump({"routes": {}, "api_keys": {}}, f, indent=2) json.dump({"pipelines": {}, "api_keys": {}}, f, indent=2)
with open(_PIPELINE_REGISTRY_PATH, "r", encoding="utf-8") as f: with open(_PIPELINE_REGISTRY_PATH, "r", encoding="utf-8") as f:
registry = json.load(f) registry = json.load(f)
routes = registry.get("routes") pipelines = registry.get("pipelines")
if not isinstance(routes, dict): if not isinstance(pipelines, dict):
raise ValueError("`routes` in pipeline registry must be an object") raise ValueError("`pipelines` in pipeline registry must be an object")
return registry return registry
@@ -237,8 +249,11 @@ async def upsert_graph_config(body: GraphConfigUpsertRequest):
api_key=(body.api_key or "").strip(), api_key=(body.api_key or "").strip(),
) )
@app.get("/v1/graph-configs", response_model=GraphConfigListResponse) @app.get("/v1/graph-configs", response_model=GraphConfigListResponse)
async def list_graph_configs(pipeline_id: Optional[str] = None, graph_id: Optional[str] = None): async def list_graph_configs(
pipeline_id: Optional[str] = None, graph_id: Optional[str] = None
):
try: try:
rows = _db.list_prompt_sets(pipeline_id=pipeline_id, graph_id=graph_id) rows = _db.list_prompt_sets(pipeline_id=pipeline_id, graph_id=graph_id)
except Exception as e: except Exception as e:
@@ -247,10 +262,15 @@ async def list_graph_configs(pipeline_id: Optional[str] = None, graph_id: Option
items = [GraphConfigListItem(**row) for row in rows] items = [GraphConfigListItem(**row) for row in rows]
return GraphConfigListResponse(items=items, count=len(items)) return GraphConfigListResponse(items=items, count=len(items))
@app.get("/v1/graph-configs/default/{pipeline_id}", response_model=GraphConfigReadResponse)
@app.get(
"/v1/graph-configs/default/{pipeline_id}", response_model=GraphConfigReadResponse
)
async def get_default_graph_config(pipeline_id: str): async def get_default_graph_config(pipeline_id: str):
try: try:
prompt_dict, tool_keys = _db.get_config(pipeline_id=pipeline_id, prompt_set_id=None) prompt_dict, tool_keys = _db.get_config(
pipeline_id=pipeline_id, prompt_set_id=None
)
except ValueError as e: except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))
except Exception as e: except Exception as e:
@@ -279,11 +299,16 @@ async def get_default_graph_config(pipeline_id: str):
api_key=(active.get("api_key") or ""), api_key=(active.get("api_key") or ""),
) )
@app.get("/v1/graphs/{graph_id}/default-config", response_model=GraphConfigReadResponse) @app.get("/v1/graphs/{graph_id}/default-config", response_model=GraphConfigReadResponse)
async def get_graph_default_config_by_graph(graph_id: str): async def get_graph_default_config_by_graph(graph_id: str):
return await get_default_graph_config(pipeline_id=graph_id) return await get_default_graph_config(pipeline_id=graph_id)
@app.get("/v1/graph-configs/{pipeline_id}/{prompt_set_id}", response_model=GraphConfigReadResponse)
@app.get(
"/v1/graph-configs/{pipeline_id}/{prompt_set_id}",
response_model=GraphConfigReadResponse,
)
async def get_graph_config(pipeline_id: str, prompt_set_id: str): async def get_graph_config(pipeline_id: str, prompt_set_id: str):
try: try:
meta = _db.get_prompt_set(pipeline_id=pipeline_id, prompt_set_id=prompt_set_id) meta = _db.get_prompt_set(pipeline_id=pipeline_id, prompt_set_id=prompt_set_id)
@@ -333,6 +358,7 @@ async def delete_graph_config(pipeline_id: str, prompt_set_id: str):
async def available_graphs(): async def available_graphs():
return {"available_graphs": sorted(GRAPH_BUILD_FNCS.keys())} return {"available_graphs": sorted(GRAPH_BUILD_FNCS.keys())}
@app.get("/v1/tool-configs/mcp", response_model=McpConfigReadResponse) @app.get("/v1/tool-configs/mcp", response_model=McpConfigReadResponse)
async def get_mcp_tool_config(): async def get_mcp_tool_config():
try: try:
@@ -367,6 +393,7 @@ async def update_mcp_tool_config(body: McpConfigUpdateRequest):
tool_keys=tool_keys, tool_keys=tool_keys,
) )
@app.get("/v1/pipelines", response_model=PipelineListResponse) @app.get("/v1/pipelines", response_model=PipelineListResponse)
async def list_running_pipelines(): async def list_running_pipelines():
try: try:
@@ -377,24 +404,23 @@ async def list_running_pipelines():
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
items: List[PipelineRunInfo] = [] items: List[PipelineRunInfo] = []
routes = registry.get("routes", {}) pipelines = registry.get("pipelines", {})
for route_id, spec in sorted(routes.items()): for pipeline_id, spec in sorted(pipelines.items()):
if not isinstance(spec, dict): if not isinstance(spec, dict):
continue continue
enabled = bool(spec.get("enabled", True)) enabled = bool(spec.get("enabled", True))
items.append( items.append(
PipelineRunInfo( PipelineRunInfo(
run_id=route_id, run_id=pipeline_id,
pid=-1, pid=-1,
graph_id=str(spec.get("graph_id") or route_id), graph_id=str(spec.get("graph_id") or pipeline_id),
pipeline_id=str(spec.get("prompt_pipeline_id") or route_id), pipeline_id=pipeline_id,
prompt_set_id="default", prompt_set_id="default",
url=_DASHSCOPE_URL, url=_DASHSCOPE_URL,
port=-1, port=-1,
auth_type="bearer", auth_type="bearer",
auth_header_name="Authorization", auth_header_name="Authorization",
auth_key_masked="", auth_key_masked="",
route_id=route_id,
enabled=enabled, enabled=enabled,
config_file=spec.get("config_file"), config_file=spec.get("config_file"),
) )
@@ -411,50 +437,37 @@ async def create_pipeline(body: PipelineCreateRequest):
detail=f"Unknown graph_id '{body.graph_id}'. Valid options: {sorted(GRAPH_BUILD_FNCS.keys())}", detail=f"Unknown graph_id '{body.graph_id}'. Valid options: {sorted(GRAPH_BUILD_FNCS.keys())}",
) )
route_id = (body.route_id or body.pipeline_id).strip() pipeline_id = body.pipeline_id.strip()
if not route_id: if not pipeline_id:
raise HTTPException(status_code=400, detail="route_id or pipeline_id is required") raise HTTPException(status_code=400, detail="pipeline_id is required")
prompt_pipeline_id = (body.prompt_pipeline_id or body.pipeline_id).strip() config_file = f"configs/pipelines/{pipeline_id}.yml"
if not prompt_pipeline_id:
raise HTTPException(status_code=400, detail="prompt_pipeline_id or pipeline_id is required")
config_file = f"configs/pipelines/{route_id}.yml"
config_abs_dir = osp.join(_PROJECT_ROOT, "configs", "pipelines") config_abs_dir = osp.join(_PROJECT_ROOT, "configs", "pipelines")
try: try:
build_fn( build_fn(
pipeline_id=prompt_pipeline_id, pipeline_id=pipeline_id,
prompt_set=body.prompt_set_id, prompt_set=body.prompt_set_id,
tool_keys=body.tool_keys, tool_keys=body.tool_keys,
api_key=body.api_key, api_key=body.api_key,
llm_name=body.llm_name, llm_name=body.llm_name,
pipeline_config_dir=config_abs_dir, pipeline_config_dir=config_abs_dir,
) )
generated_config_file = f"configs/pipelines/{prompt_pipeline_id}.yml"
if prompt_pipeline_id != route_id:
# Keep runtime route_id and config_file aligned for lazy loading by route.
src = osp.join(config_abs_dir, f"{prompt_pipeline_id}.yml")
dst = osp.join(config_abs_dir, f"{route_id}.yml")
if osp.exists(src):
with open(src, "r", encoding="utf-8") as rf, open(dst, "w", encoding="utf-8") as wf:
wf.write(rf.read())
generated_config_file = config_file
update_pipeline_registry( update_pipeline_registry(
pipeline_id=route_id, pipeline_id=pipeline_id,
prompt_set=prompt_pipeline_id,
graph_id=body.graph_id, graph_id=body.graph_id,
config_file=generated_config_file, config_file=config_file,
llm_name=body.llm_name, llm_name=body.llm_name,
enabled=body.enabled, enabled=body.enabled,
registry_f=_PIPELINE_REGISTRY_PATH, registry_f=_PIPELINE_REGISTRY_PATH,
) )
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to register route: {e}") raise HTTPException(status_code=500, detail=f"Failed to register pipeline: {e}")
return PipelineCreateResponse( return PipelineCreateResponse(
run_id=route_id, run_id=pipeline_id,
pid=-1, pid=-1,
graph_id=body.graph_id, graph_id=body.graph_id,
pipeline_id=prompt_pipeline_id, pipeline_id=pipeline_id,
prompt_set_id=body.prompt_set_id, prompt_set_id=body.prompt_set_id,
url=_DASHSCOPE_URL, url=_DASHSCOPE_URL,
port=-1, port=-1,
@@ -462,21 +475,23 @@ async def create_pipeline(body: PipelineCreateRequest):
auth_header_name="Authorization", auth_header_name="Authorization",
auth_key_once="", auth_key_once="",
auth_key_masked="", auth_key_masked="",
route_id=route_id,
enabled=body.enabled, enabled=body.enabled,
config_file=config_file, config_file=config_file,
reload_required=True, reload_required=True,
registry_path=_PIPELINE_REGISTRY_PATH, registry_path=_PIPELINE_REGISTRY_PATH,
) )
@app.delete("/v1/pipelines/{route_id}", response_model=PipelineStopResponse)
async def stop_pipeline(route_id: str): @app.delete("/v1/pipelines/{pipeline_id}", response_model=PipelineStopResponse)
async def stop_pipeline(pipeline_id: str):
try: try:
registry = _read_pipeline_registry() registry = _read_pipeline_registry()
routes = registry.get("routes", {}) pipelines = registry.get("pipelines", {})
spec = routes.get(route_id) spec = pipelines.get(pipeline_id)
if not isinstance(spec, dict): if not isinstance(spec, dict):
raise HTTPException(status_code=404, detail=f"route_id '{route_id}' not found") raise HTTPException(
status_code=404, detail=f"pipeline_id '{pipeline_id}' not found"
)
spec["enabled"] = False spec["enabled"] = False
_write_pipeline_registry(registry) _write_pipeline_registry(registry)
except HTTPException: except HTTPException:
@@ -487,9 +502,9 @@ async def stop_pipeline(route_id: str):
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
return PipelineStopResponse( return PipelineStopResponse(
run_id=route_id, run_id=pipeline_id,
status="disabled", status="disabled",
route_id=route_id, pipeline_id=pipeline_id,
enabled=False, enabled=False,
reload_required=True, reload_required=True,
) )

View File

@@ -28,12 +28,16 @@ API_KEY_HEADER = APIKeyHeader(name="Authorization", auto_error=True)
VALID_API_KEYS = set(filter(None, os.environ.get("FAST_AUTH_KEYS", "").split(","))) VALID_API_KEYS = set(filter(None, os.environ.get("FAST_AUTH_KEYS", "").split(",")))
REGISTRY_FILE = os.environ.get( REGISTRY_FILE = os.environ.get(
"FAST_PIPELINE_REGISTRY_FILE", "FAST_PIPELINE_REGISTRY_FILE",
osp.join(osp.dirname(osp.dirname(osp.abspath(__file__))), "configs", "pipeline_registry.json"), osp.join(
osp.dirname(osp.dirname(osp.abspath(__file__))),
"configs",
"pipeline_registry.json",
),
) )
PIPELINE_MANAGER = ServerPipelineManager( PIPELINE_MANAGER = ServerPipelineManager(
default_route_id=os.environ.get("FAST_DEFAULT_ROUTE_ID", os.environ.get("FAST_DEFAULT_PIPELINE_ID", "default")), default_pipeline_id=os.environ.get("FAST_DEFAULT_PIPELINE_ID", "default"),
default_config=pipeline_config, default_config=pipeline_config,
) )
PIPELINE_MANAGER.load_registry(REGISTRY_FILE) PIPELINE_MANAGER.load_registry(REGISTRY_FILE)
@@ -62,8 +66,10 @@ class DSApplicationCallRequest(BaseModel):
thread_id: Optional[str] = Field(default="3") thread_id: Optional[str] = Field(default="3")
app = FastAPI(title="DashScope-Compatible Application API", app = FastAPI(
description="DashScope Application.call compatible endpoint backed by pipeline.chat") title="DashScope-Compatible Application API",
description="DashScope Application.call compatible endpoint backed by pipeline.chat",
)
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
@@ -74,7 +80,9 @@ app.add_middleware(
) )
def sse_chunks_from_stream(chunk_generator, response_id: str, model: str = "qwen-flash"): def sse_chunks_from_stream(
chunk_generator, response_id: str, model: str = "qwen-flash"
):
""" """
Stream chunks from pipeline and format as SSE. Stream chunks from pipeline and format as SSE.
Accumulates text and sends incremental updates. Accumulates text and sends incremental updates.
@@ -115,7 +123,9 @@ def sse_chunks_from_stream(chunk_generator, response_id: str, model: str = "qwen
yield f"data: {json.dumps(final)}\n\n" yield f"data: {json.dumps(final)}\n\n"
async def sse_chunks_from_astream(chunk_generator, response_id: str, model: str = "qwen-flash"): async def sse_chunks_from_astream(
chunk_generator, response_id: str, model: str = "qwen-flash"
):
""" """
Async version: Stream chunks from pipeline and format as SSE. Async version: Stream chunks from pipeline and format as SSE.
Accumulates text and sends incremental updates. Accumulates text and sends incremental updates.
@@ -202,22 +212,30 @@ async def _process_dashscope_request(
thread_id = body_input.get("session_id") or req_session_id or "3" thread_id = body_input.get("session_id") or req_session_id or "3"
user_msg = _extract_user_message(messages) user_msg = _extract_user_message(messages)
route_id = PIPELINE_MANAGER.resolve_route_id(body=body, app_id=req_app_id, api_key=api_key) pipeline_id = PIPELINE_MANAGER.resolve_pipeline_id(
selected_pipeline, selected_model = PIPELINE_MANAGER.get_pipeline(route_id) body=body, app_id=req_app_id, api_key=api_key
)
selected_pipeline, selected_model = PIPELINE_MANAGER.get_pipeline(pipeline_id)
# Namespace thread ids to prevent memory collisions across pipelines. # Namespace thread ids to prevent memory collisions across pipelines.
thread_id = f"{route_id}:{thread_id}" thread_id = f"{pipeline_id}:{thread_id}"
response_id = f"appcmpl-{os.urandom(12).hex()}" response_id = f"appcmpl-{os.urandom(12).hex()}"
if stream: if stream:
chunk_generator = await selected_pipeline.achat(inp=user_msg, as_stream=True, thread_id=thread_id) chunk_generator = await selected_pipeline.achat(
inp=user_msg, as_stream=True, thread_id=thread_id
)
return StreamingResponse( return StreamingResponse(
sse_chunks_from_astream(chunk_generator, response_id=response_id, model=selected_model), sse_chunks_from_astream(
chunk_generator, response_id=response_id, model=selected_model
),
media_type="text/event-stream", media_type="text/event-stream",
) )
result_text = await selected_pipeline.achat(inp=user_msg, as_stream=False, thread_id=thread_id) result_text = await selected_pipeline.achat(
inp=user_msg, as_stream=False, thread_id=thread_id
)
if not isinstance(result_text, str): if not isinstance(result_text, str):
result_text = str(result_text) result_text = str(result_text)
@@ -232,9 +250,7 @@ async def _process_dashscope_request(
"created": int(time.time()), "created": int(time.time()),
"model": selected_model, "model": selected_model,
}, },
"route_id": route_id, "pipeline_id": pipeline_id,
# Backward compatibility: keep pipeline_id in response as the route id selector.
"pipeline_id": route_id,
"is_end": True, "is_end": True,
} }
return JSONResponse(content=data) return JSONResponse(content=data)
@@ -292,10 +308,13 @@ async def application_completion(
@app.get("/") @app.get("/")
async def root(): async def root():
return {"message": "DashScope Application-compatible API", "endpoints": [ return {
"message": "DashScope Application-compatible API",
"endpoints": [
"/v1/apps/{app_id}/sessions/{session_id}/responses", "/v1/apps/{app_id}/sessions/{session_id}/responses",
"/health", "/health",
]} ],
}
@app.get("/health") @app.get("/health")
@@ -310,5 +329,3 @@ if __name__ == "__main__":
port=pipeline_config.port, port=pipeline_config.port,
reload=True, reload=True,
) )

View File

@@ -11,12 +11,12 @@ from lang_agent.config.core_config import load_tyro_conf
class ServerPipelineManager: class ServerPipelineManager:
"""Lazily load and cache multiple pipelines keyed by a client-facing route id.""" """Lazily load and cache multiple pipelines keyed by a client-facing pipeline id."""
def __init__(self, default_route_id: str, default_config: PipelineConfig): def __init__(self, default_pipeline_id: str, default_config: PipelineConfig):
self.default_route_id = default_route_id self.default_pipeline_id = default_pipeline_id
self.default_config = default_config self.default_config = default_config
self._route_specs: Dict[str, Dict[str, Any]] = {} self._pipeline_specs: Dict[str, Dict[str, Any]] = {}
self._api_key_policy: Dict[str, Dict[str, Any]] = {} self._api_key_policy: Dict[str, Dict[str, Any]] = {}
self._pipelines: Dict[str, Pipeline] = {} self._pipelines: Dict[str, Pipeline] = {}
self._pipeline_llm: Dict[str, str] = {} self._pipeline_llm: Dict[str, str] = {}
@@ -38,32 +38,31 @@ class ServerPipelineManager:
with open(abs_path, "r", encoding="utf-8") as f: with open(abs_path, "r", encoding="utf-8") as f:
registry: dict = json.load(f) registry: dict = json.load(f)
routes = registry.get("routes") pipelines = registry.get("pipelines")
if routes is None: if pipelines is None:
# Backward compatibility with initial schema. raise ValueError("`pipelines` in pipeline registry must be an object.")
routes = registry.get("pipelines", {})
if not isinstance(routes, dict):
raise ValueError("`routes` in pipeline registry must be an object.")
self._route_specs = {} self._pipeline_specs = {}
for route_id, spec in routes.items(): for pipeline_id, spec in pipelines.items():
if not isinstance(spec, dict): if not isinstance(spec, dict):
raise ValueError(f"route spec for `{route_id}` must be an object.") raise ValueError(
self._route_specs[route_id] = { f"pipeline spec for `{pipeline_id}` must be an object."
)
self._pipeline_specs[pipeline_id] = {
"enabled": bool(spec.get("enabled", True)), "enabled": bool(spec.get("enabled", True)),
"config_file": spec.get("config_file"), "config_file": spec.get("config_file"),
"overrides": spec.get("overrides", {}), "overrides": spec.get("overrides", {}),
# Explicitly separates routing id from prompt config pipeline_id.
"prompt_pipeline_id": spec.get("prompt_pipeline_id"),
} }
if not self._route_specs: if not self._pipeline_specs:
raise ValueError("pipeline registry must define at least one route.") raise ValueError("pipeline registry must define at least one pipeline.")
api_key_policy = registry.get("api_keys", {}) api_key_policy = registry.get("api_keys", {})
if api_key_policy and not isinstance(api_key_policy, dict): if api_key_policy and not isinstance(api_key_policy, dict):
raise ValueError("`api_keys` in pipeline registry must be an object.") raise ValueError("`api_keys` in pipeline registry must be an object.")
self._api_key_policy = api_key_policy self._api_key_policy = api_key_policy
logger.info(f"loaded pipeline registry: {abs_path}, routes={list(self._route_specs.keys())}") logger.info(
f"loaded pipeline registry: {abs_path}, pipelines={list(self._pipeline_specs.keys())}"
)
def _resolve_config_path(self, config_file: str) -> str: def _resolve_config_path(self, config_file: str) -> str:
path = FsPath(config_file) path = FsPath(config_file)
@@ -74,48 +73,47 @@ class ServerPipelineManager:
root = FsPath(__file__).resolve().parents[2] root = FsPath(__file__).resolve().parents[2]
return str((root / path).resolve()) return str((root / path).resolve())
def _build_pipeline(self, route_id: str) -> Tuple[Pipeline, str]: def _build_pipeline(self, pipeline_id: str) -> Tuple[Pipeline, str]:
spec = self._route_specs.get(route_id) spec = self._pipeline_specs.get(pipeline_id)
if spec is None: if spec is None:
raise HTTPException(status_code=404, detail=f"Unknown route_id: {route_id}") raise HTTPException(
status_code=404, detail=f"Unknown pipeline_id: {pipeline_id}"
)
if not spec.get("enabled", True): if not spec.get("enabled", True):
raise HTTPException(status_code=403, detail=f"Route disabled: {route_id}") raise HTTPException(
status_code=403, detail=f"Pipeline disabled: {pipeline_id}"
)
config_file = spec.get("config_file") config_file = spec.get("config_file")
overrides = spec.get("overrides", {}) overrides = spec.get("overrides", {})
if config_file: if config_file:
loaded_cfg = load_tyro_conf(self._resolve_config_path(config_file)) loaded_cfg = load_tyro_conf(self._resolve_config_path(config_file))
# Some legacy yaml configs deserialize to plain dicts instead of
# InstantiateConfig dataclasses. Fall back to default config in that case.
if hasattr(loaded_cfg, "setup"): if hasattr(loaded_cfg, "setup"):
cfg = loaded_cfg cfg = loaded_cfg
else: else:
logger.warning( logger.warning(
f"config_file for route `{route_id}` did not deserialize to config object; " f"config_file for pipeline `{pipeline_id}` did not deserialize to config object; "
"falling back to default config and applying route-level overrides." "falling back to default config and applying pipeline-level overrides."
) )
cfg = copy.deepcopy(self.default_config) cfg = copy.deepcopy(self.default_config)
else: else:
# Build from default config + shallow overrides so new pipelines can be
# added via registry without additional yaml files.
cfg = copy.deepcopy(self.default_config) cfg = copy.deepcopy(self.default_config)
if not isinstance(overrides, dict): if not isinstance(overrides, dict):
raise ValueError(f"route `overrides` for `{route_id}` must be an object.") raise ValueError(
f"pipeline `overrides` for `{pipeline_id}` must be an object."
)
for key, value in overrides.items(): for key, value in overrides.items():
if not hasattr(cfg, key): if not hasattr(cfg, key):
raise ValueError(f"unknown override field `{key}` for route `{route_id}`") raise ValueError(
f"unknown override field `{key}` for pipeline `{pipeline_id}`"
)
setattr(cfg, key, value) setattr(cfg, key, value)
prompt_pipeline_id = spec.get("prompt_pipeline_id")
if prompt_pipeline_id and (not isinstance(overrides, dict) or "pipeline_id" not in overrides):
if hasattr(cfg, "pipeline_id"):
cfg.pipeline_id = prompt_pipeline_id
p = cfg.setup() p = cfg.setup()
llm_name = getattr(cfg, "llm_name", "unknown-model") llm_name = getattr(cfg, "llm_name", "unknown-model")
return p, llm_name return p, llm_name
def _authorize(self, api_key: str, route_id: str) -> None: def _authorize(self, api_key: str, pipeline_id: str) -> None:
if not self._api_key_policy: if not self._api_key_policy:
return return
@@ -123,47 +121,46 @@ class ServerPipelineManager:
if policy is None: if policy is None:
return return
allowed = policy.get("allowed_route_ids")
if allowed is None:
# Backward compatibility.
allowed = policy.get("allowed_pipeline_ids") allowed = policy.get("allowed_pipeline_ids")
if allowed and route_id not in allowed: if allowed and pipeline_id not in allowed:
raise HTTPException(status_code=403, detail=f"route_id `{route_id}` is not allowed for this API key") raise HTTPException(
status_code=403,
detail=f"pipeline_id `{pipeline_id}` is not allowed for this API key",
)
def resolve_route_id(self, body: Dict[str, Any], app_id: Optional[str], api_key: str) -> str: def resolve_pipeline_id(
self, body: Dict[str, Any], app_id: Optional[str], api_key: str
) -> str:
body_input = body.get("input", {}) body_input = body.get("input", {})
route_id = ( pipeline_id = (
body.get("route_id") body.get("pipeline_id")
or (body_input.get("route_id") if isinstance(body_input, dict) else None)
or body.get("pipeline_key")
or (body_input.get("pipeline_key") if isinstance(body_input, dict) else None)
# Backward compatibility: pipeline_id still accepted as route selector.
or body.get("pipeline_id")
or (body_input.get("pipeline_id") if isinstance(body_input, dict) else None) or (body_input.get("pipeline_id") if isinstance(body_input, dict) else None)
or app_id or app_id
) )
if not route_id: if not pipeline_id:
key_policy = self._api_key_policy.get(api_key, {}) if self._api_key_policy else {} key_policy = (
route_id = key_policy.get("default_route_id") self._api_key_policy.get(api_key, {}) if self._api_key_policy else {}
if not route_id: )
# Backward compatibility. pipeline_id = key_policy.get(
route_id = key_policy.get("default_pipeline_id", self.default_route_id) "default_pipeline_id", self.default_pipeline_id
)
if route_id not in self._route_specs: if pipeline_id not in self._pipeline_specs:
raise HTTPException(status_code=404, detail=f"Unknown route_id: {route_id}") raise HTTPException(
status_code=404, detail=f"Unknown pipeline_id: {pipeline_id}"
)
self._authorize(api_key, route_id) self._authorize(api_key, pipeline_id)
return route_id return pipeline_id
def get_pipeline(self, route_id: str) -> Tuple[Pipeline, str]: def get_pipeline(self, pipeline_id: str) -> Tuple[Pipeline, str]:
cached = self._pipelines.get(route_id) cached = self._pipelines.get(pipeline_id)
if cached is not None: if cached is not None:
return cached, self._pipeline_llm[route_id] return cached, self._pipeline_llm[pipeline_id]
pipeline_obj, llm_name = self._build_pipeline(route_id) pipeline_obj, llm_name = self._build_pipeline(pipeline_id)
self._pipelines[route_id] = pipeline_obj self._pipelines[pipeline_id] = pipeline_obj
self._pipeline_llm[route_id] = llm_name self._pipeline_llm[pipeline_id] = llm_name
logger.info(f"lazy-loaded route_id={route_id} model={llm_name}") logger.info(f"lazy-loaded pipeline_id={pipeline_id} model={llm_name}")
return pipeline_obj, llm_name return pipeline_obj, llm_name