kimi magic
This commit is contained in:
@@ -13,13 +13,17 @@ from pydantic import BaseModel, Field
|
||||
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
|
||||
|
||||
from lang_agent.config.db_config_manager import DBConfigManager
|
||||
from lang_agent.front_api.build_server_utils import GRAPH_BUILD_FNCS, update_pipeline_registry
|
||||
from lang_agent.front_api.build_server_utils import (
|
||||
GRAPH_BUILD_FNCS,
|
||||
update_pipeline_registry,
|
||||
)
|
||||
|
||||
_PROJECT_ROOT = osp.dirname(osp.dirname(osp.abspath(__file__)))
|
||||
_MCP_CONFIG_PATH = osp.join(_PROJECT_ROOT, "configs", "mcp_config.json")
|
||||
_MCP_CONFIG_DEFAULT_CONTENT = "{\n}\n"
|
||||
_PIPELINE_REGISTRY_PATH = osp.join(_PROJECT_ROOT, "configs", "pipeline_registry.json")
|
||||
|
||||
|
||||
class GraphConfigUpsertRequest(BaseModel):
|
||||
graph_id: str
|
||||
pipeline_id: str
|
||||
@@ -28,6 +32,7 @@ class GraphConfigUpsertRequest(BaseModel):
|
||||
prompt_dict: Dict[str, str] = Field(default_factory=dict)
|
||||
api_key: Optional[str] = Field(default=None)
|
||||
|
||||
|
||||
class GraphConfigUpsertResponse(BaseModel):
|
||||
graph_id: str
|
||||
pipeline_id: str
|
||||
@@ -36,6 +41,7 @@ class GraphConfigUpsertResponse(BaseModel):
|
||||
prompt_keys: List[str]
|
||||
api_key: str
|
||||
|
||||
|
||||
class GraphConfigReadResponse(BaseModel):
|
||||
graph_id: Optional[str] = Field(default=None)
|
||||
pipeline_id: str
|
||||
@@ -44,6 +50,7 @@ class GraphConfigReadResponse(BaseModel):
|
||||
prompt_dict: Dict[str, str]
|
||||
api_key: str = Field(default="")
|
||||
|
||||
|
||||
class GraphConfigListItem(BaseModel):
|
||||
graph_id: Optional[str] = Field(default=None)
|
||||
pipeline_id: str
|
||||
@@ -56,10 +63,12 @@ class GraphConfigListItem(BaseModel):
|
||||
created_at: Optional[str] = Field(default=None)
|
||||
updated_at: Optional[str] = Field(default=None)
|
||||
|
||||
|
||||
class GraphConfigListResponse(BaseModel):
|
||||
items: List[GraphConfigListItem]
|
||||
count: int
|
||||
|
||||
|
||||
class PipelineCreateRequest(BaseModel):
|
||||
graph_id: str = Field(
|
||||
description="Graph key from GRAPH_BUILD_FNCS, e.g. routing or react"
|
||||
@@ -71,9 +80,8 @@ class PipelineCreateRequest(BaseModel):
|
||||
api_key: str
|
||||
entry_point: str = Field(default="fastapi_server/server_dashscope.py")
|
||||
llm_name: str = Field(default="qwen-plus")
|
||||
route_id: Optional[str] = Field(default=None)
|
||||
enabled: bool = Field(default=True)
|
||||
prompt_pipeline_id: Optional[str] = Field(default=None)
|
||||
|
||||
|
||||
class PipelineCreateResponse(BaseModel):
|
||||
run_id: str
|
||||
@@ -87,12 +95,12 @@ class PipelineCreateResponse(BaseModel):
|
||||
auth_header_name: str
|
||||
auth_key_once: str
|
||||
auth_key_masked: str
|
||||
route_id: str
|
||||
enabled: bool
|
||||
config_file: str
|
||||
reload_required: bool
|
||||
registry_path: str
|
||||
|
||||
|
||||
class PipelineRunInfo(BaseModel):
|
||||
run_id: str
|
||||
pid: int
|
||||
@@ -104,29 +112,33 @@ class PipelineRunInfo(BaseModel):
|
||||
auth_type: str
|
||||
auth_header_name: str
|
||||
auth_key_masked: str
|
||||
route_id: str
|
||||
enabled: bool
|
||||
config_file: Optional[str] = Field(default=None)
|
||||
|
||||
|
||||
class PipelineListResponse(BaseModel):
|
||||
items: List[PipelineRunInfo]
|
||||
count: int
|
||||
|
||||
|
||||
class PipelineStopResponse(BaseModel):
|
||||
run_id: str
|
||||
status: str
|
||||
route_id: str
|
||||
pipeline_id: str
|
||||
enabled: bool
|
||||
reload_required: bool
|
||||
|
||||
|
||||
class McpConfigReadResponse(BaseModel):
|
||||
path: str
|
||||
raw_content: str
|
||||
tool_keys: List[str]
|
||||
|
||||
|
||||
class McpConfigUpdateRequest(BaseModel):
|
||||
raw_content: str
|
||||
|
||||
|
||||
class McpConfigUpdateResponse(BaseModel):
|
||||
status: str
|
||||
path: str
|
||||
@@ -196,12 +208,12 @@ def _read_pipeline_registry() -> Dict[str, Any]:
|
||||
if not osp.exists(_PIPELINE_REGISTRY_PATH):
|
||||
os.makedirs(osp.dirname(_PIPELINE_REGISTRY_PATH), exist_ok=True)
|
||||
with open(_PIPELINE_REGISTRY_PATH, "w", encoding="utf-8") as f:
|
||||
json.dump({"routes": {}, "api_keys": {}}, f, indent=2)
|
||||
json.dump({"pipelines": {}, "api_keys": {}}, f, indent=2)
|
||||
with open(_PIPELINE_REGISTRY_PATH, "r", encoding="utf-8") as f:
|
||||
registry = json.load(f)
|
||||
routes = registry.get("routes")
|
||||
if not isinstance(routes, dict):
|
||||
raise ValueError("`routes` in pipeline registry must be an object")
|
||||
pipelines = registry.get("pipelines")
|
||||
if not isinstance(pipelines, dict):
|
||||
raise ValueError("`pipelines` in pipeline registry must be an object")
|
||||
return registry
|
||||
|
||||
|
||||
@@ -237,8 +249,11 @@ async def upsert_graph_config(body: GraphConfigUpsertRequest):
|
||||
api_key=(body.api_key or "").strip(),
|
||||
)
|
||||
|
||||
|
||||
@app.get("/v1/graph-configs", response_model=GraphConfigListResponse)
|
||||
async def list_graph_configs(pipeline_id: Optional[str] = None, graph_id: Optional[str] = None):
|
||||
async def list_graph_configs(
|
||||
pipeline_id: Optional[str] = None, graph_id: Optional[str] = None
|
||||
):
|
||||
try:
|
||||
rows = _db.list_prompt_sets(pipeline_id=pipeline_id, graph_id=graph_id)
|
||||
except Exception as e:
|
||||
@@ -247,10 +262,15 @@ async def list_graph_configs(pipeline_id: Optional[str] = None, graph_id: Option
|
||||
items = [GraphConfigListItem(**row) for row in rows]
|
||||
return GraphConfigListResponse(items=items, count=len(items))
|
||||
|
||||
@app.get("/v1/graph-configs/default/{pipeline_id}", response_model=GraphConfigReadResponse)
|
||||
|
||||
@app.get(
|
||||
"/v1/graph-configs/default/{pipeline_id}", response_model=GraphConfigReadResponse
|
||||
)
|
||||
async def get_default_graph_config(pipeline_id: str):
|
||||
try:
|
||||
prompt_dict, tool_keys = _db.get_config(pipeline_id=pipeline_id, prompt_set_id=None)
|
||||
prompt_dict, tool_keys = _db.get_config(
|
||||
pipeline_id=pipeline_id, prompt_set_id=None
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
@@ -279,11 +299,16 @@ async def get_default_graph_config(pipeline_id: str):
|
||||
api_key=(active.get("api_key") or ""),
|
||||
)
|
||||
|
||||
|
||||
@app.get("/v1/graphs/{graph_id}/default-config", response_model=GraphConfigReadResponse)
|
||||
async def get_graph_default_config_by_graph(graph_id: str):
|
||||
return await get_default_graph_config(pipeline_id=graph_id)
|
||||
|
||||
@app.get("/v1/graph-configs/{pipeline_id}/{prompt_set_id}", response_model=GraphConfigReadResponse)
|
||||
|
||||
@app.get(
|
||||
"/v1/graph-configs/{pipeline_id}/{prompt_set_id}",
|
||||
response_model=GraphConfigReadResponse,
|
||||
)
|
||||
async def get_graph_config(pipeline_id: str, prompt_set_id: str):
|
||||
try:
|
||||
meta = _db.get_prompt_set(pipeline_id=pipeline_id, prompt_set_id=prompt_set_id)
|
||||
@@ -333,6 +358,7 @@ async def delete_graph_config(pipeline_id: str, prompt_set_id: str):
|
||||
async def available_graphs():
|
||||
return {"available_graphs": sorted(GRAPH_BUILD_FNCS.keys())}
|
||||
|
||||
|
||||
@app.get("/v1/tool-configs/mcp", response_model=McpConfigReadResponse)
|
||||
async def get_mcp_tool_config():
|
||||
try:
|
||||
@@ -367,6 +393,7 @@ async def update_mcp_tool_config(body: McpConfigUpdateRequest):
|
||||
tool_keys=tool_keys,
|
||||
)
|
||||
|
||||
|
||||
@app.get("/v1/pipelines", response_model=PipelineListResponse)
|
||||
async def list_running_pipelines():
|
||||
try:
|
||||
@@ -377,24 +404,23 @@ async def list_running_pipelines():
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
items: List[PipelineRunInfo] = []
|
||||
routes = registry.get("routes", {})
|
||||
for route_id, spec in sorted(routes.items()):
|
||||
pipelines = registry.get("pipelines", {})
|
||||
for pipeline_id, spec in sorted(pipelines.items()):
|
||||
if not isinstance(spec, dict):
|
||||
continue
|
||||
enabled = bool(spec.get("enabled", True))
|
||||
items.append(
|
||||
PipelineRunInfo(
|
||||
run_id=route_id,
|
||||
run_id=pipeline_id,
|
||||
pid=-1,
|
||||
graph_id=str(spec.get("graph_id") or route_id),
|
||||
pipeline_id=str(spec.get("prompt_pipeline_id") or route_id),
|
||||
graph_id=str(spec.get("graph_id") or pipeline_id),
|
||||
pipeline_id=pipeline_id,
|
||||
prompt_set_id="default",
|
||||
url=_DASHSCOPE_URL,
|
||||
port=-1,
|
||||
auth_type="bearer",
|
||||
auth_header_name="Authorization",
|
||||
auth_key_masked="",
|
||||
route_id=route_id,
|
||||
enabled=enabled,
|
||||
config_file=spec.get("config_file"),
|
||||
)
|
||||
@@ -411,50 +437,37 @@ async def create_pipeline(body: PipelineCreateRequest):
|
||||
detail=f"Unknown graph_id '{body.graph_id}'. Valid options: {sorted(GRAPH_BUILD_FNCS.keys())}",
|
||||
)
|
||||
|
||||
route_id = (body.route_id or body.pipeline_id).strip()
|
||||
if not route_id:
|
||||
raise HTTPException(status_code=400, detail="route_id or pipeline_id is required")
|
||||
prompt_pipeline_id = (body.prompt_pipeline_id or body.pipeline_id).strip()
|
||||
if not prompt_pipeline_id:
|
||||
raise HTTPException(status_code=400, detail="prompt_pipeline_id or pipeline_id is required")
|
||||
config_file = f"configs/pipelines/{route_id}.yml"
|
||||
pipeline_id = body.pipeline_id.strip()
|
||||
if not pipeline_id:
|
||||
raise HTTPException(status_code=400, detail="pipeline_id is required")
|
||||
config_file = f"configs/pipelines/{pipeline_id}.yml"
|
||||
config_abs_dir = osp.join(_PROJECT_ROOT, "configs", "pipelines")
|
||||
try:
|
||||
build_fn(
|
||||
pipeline_id=prompt_pipeline_id,
|
||||
pipeline_id=pipeline_id,
|
||||
prompt_set=body.prompt_set_id,
|
||||
tool_keys=body.tool_keys,
|
||||
api_key=body.api_key,
|
||||
llm_name=body.llm_name,
|
||||
pipeline_config_dir=config_abs_dir,
|
||||
)
|
||||
generated_config_file = f"configs/pipelines/{prompt_pipeline_id}.yml"
|
||||
if prompt_pipeline_id != route_id:
|
||||
# Keep runtime route_id and config_file aligned for lazy loading by route.
|
||||
src = osp.join(config_abs_dir, f"{prompt_pipeline_id}.yml")
|
||||
dst = osp.join(config_abs_dir, f"{route_id}.yml")
|
||||
if osp.exists(src):
|
||||
with open(src, "r", encoding="utf-8") as rf, open(dst, "w", encoding="utf-8") as wf:
|
||||
wf.write(rf.read())
|
||||
generated_config_file = config_file
|
||||
|
||||
update_pipeline_registry(
|
||||
pipeline_id=route_id,
|
||||
prompt_set=prompt_pipeline_id,
|
||||
pipeline_id=pipeline_id,
|
||||
graph_id=body.graph_id,
|
||||
config_file=generated_config_file,
|
||||
config_file=config_file,
|
||||
llm_name=body.llm_name,
|
||||
enabled=body.enabled,
|
||||
registry_f=_PIPELINE_REGISTRY_PATH,
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to register route: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to register pipeline: {e}")
|
||||
|
||||
return PipelineCreateResponse(
|
||||
run_id=route_id,
|
||||
run_id=pipeline_id,
|
||||
pid=-1,
|
||||
graph_id=body.graph_id,
|
||||
pipeline_id=prompt_pipeline_id,
|
||||
pipeline_id=pipeline_id,
|
||||
prompt_set_id=body.prompt_set_id,
|
||||
url=_DASHSCOPE_URL,
|
||||
port=-1,
|
||||
@@ -462,21 +475,23 @@ async def create_pipeline(body: PipelineCreateRequest):
|
||||
auth_header_name="Authorization",
|
||||
auth_key_once="",
|
||||
auth_key_masked="",
|
||||
route_id=route_id,
|
||||
enabled=body.enabled,
|
||||
config_file=config_file,
|
||||
reload_required=True,
|
||||
registry_path=_PIPELINE_REGISTRY_PATH,
|
||||
)
|
||||
|
||||
@app.delete("/v1/pipelines/{route_id}", response_model=PipelineStopResponse)
|
||||
async def stop_pipeline(route_id: str):
|
||||
|
||||
@app.delete("/v1/pipelines/{pipeline_id}", response_model=PipelineStopResponse)
|
||||
async def stop_pipeline(pipeline_id: str):
|
||||
try:
|
||||
registry = _read_pipeline_registry()
|
||||
routes = registry.get("routes", {})
|
||||
spec = routes.get(route_id)
|
||||
pipelines = registry.get("pipelines", {})
|
||||
spec = pipelines.get(pipeline_id)
|
||||
if not isinstance(spec, dict):
|
||||
raise HTTPException(status_code=404, detail=f"route_id '{route_id}' not found")
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"pipeline_id '{pipeline_id}' not found"
|
||||
)
|
||||
spec["enabled"] = False
|
||||
_write_pipeline_registry(registry)
|
||||
except HTTPException:
|
||||
@@ -487,9 +502,9 @@ async def stop_pipeline(route_id: str):
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
return PipelineStopResponse(
|
||||
run_id=route_id,
|
||||
run_id=pipeline_id,
|
||||
status="disabled",
|
||||
route_id=route_id,
|
||||
pipeline_id=pipeline_id,
|
||||
enabled=False,
|
||||
reload_required=True,
|
||||
)
|
||||
|
||||
@@ -28,12 +28,16 @@ API_KEY_HEADER = APIKeyHeader(name="Authorization", auto_error=True)
|
||||
VALID_API_KEYS = set(filter(None, os.environ.get("FAST_AUTH_KEYS", "").split(",")))
|
||||
REGISTRY_FILE = os.environ.get(
|
||||
"FAST_PIPELINE_REGISTRY_FILE",
|
||||
osp.join(osp.dirname(osp.dirname(osp.abspath(__file__))), "configs", "pipeline_registry.json"),
|
||||
osp.join(
|
||||
osp.dirname(osp.dirname(osp.abspath(__file__))),
|
||||
"configs",
|
||||
"pipeline_registry.json",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
PIPELINE_MANAGER = ServerPipelineManager(
|
||||
default_route_id=os.environ.get("FAST_DEFAULT_ROUTE_ID", os.environ.get("FAST_DEFAULT_PIPELINE_ID", "default")),
|
||||
default_pipeline_id=os.environ.get("FAST_DEFAULT_PIPELINE_ID", "default"),
|
||||
default_config=pipeline_config,
|
||||
)
|
||||
PIPELINE_MANAGER.load_registry(REGISTRY_FILE)
|
||||
@@ -62,8 +66,10 @@ class DSApplicationCallRequest(BaseModel):
|
||||
thread_id: Optional[str] = Field(default="3")
|
||||
|
||||
|
||||
app = FastAPI(title="DashScope-Compatible Application API",
|
||||
description="DashScope Application.call compatible endpoint backed by pipeline.chat")
|
||||
app = FastAPI(
|
||||
title="DashScope-Compatible Application API",
|
||||
description="DashScope Application.call compatible endpoint backed by pipeline.chat",
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
@@ -74,7 +80,9 @@ app.add_middleware(
|
||||
)
|
||||
|
||||
|
||||
def sse_chunks_from_stream(chunk_generator, response_id: str, model: str = "qwen-flash"):
|
||||
def sse_chunks_from_stream(
|
||||
chunk_generator, response_id: str, model: str = "qwen-flash"
|
||||
):
|
||||
"""
|
||||
Stream chunks from pipeline and format as SSE.
|
||||
Accumulates text and sends incremental updates.
|
||||
@@ -115,7 +123,9 @@ def sse_chunks_from_stream(chunk_generator, response_id: str, model: str = "qwen
|
||||
yield f"data: {json.dumps(final)}\n\n"
|
||||
|
||||
|
||||
async def sse_chunks_from_astream(chunk_generator, response_id: str, model: str = "qwen-flash"):
|
||||
async def sse_chunks_from_astream(
|
||||
chunk_generator, response_id: str, model: str = "qwen-flash"
|
||||
):
|
||||
"""
|
||||
Async version: Stream chunks from pipeline and format as SSE.
|
||||
Accumulates text and sends incremental updates.
|
||||
@@ -202,22 +212,30 @@ async def _process_dashscope_request(
|
||||
thread_id = body_input.get("session_id") or req_session_id or "3"
|
||||
user_msg = _extract_user_message(messages)
|
||||
|
||||
route_id = PIPELINE_MANAGER.resolve_route_id(body=body, app_id=req_app_id, api_key=api_key)
|
||||
selected_pipeline, selected_model = PIPELINE_MANAGER.get_pipeline(route_id)
|
||||
pipeline_id = PIPELINE_MANAGER.resolve_pipeline_id(
|
||||
body=body, app_id=req_app_id, api_key=api_key
|
||||
)
|
||||
selected_pipeline, selected_model = PIPELINE_MANAGER.get_pipeline(pipeline_id)
|
||||
|
||||
# Namespace thread ids to prevent memory collisions across pipelines.
|
||||
thread_id = f"{route_id}:{thread_id}"
|
||||
thread_id = f"{pipeline_id}:{thread_id}"
|
||||
|
||||
response_id = f"appcmpl-{os.urandom(12).hex()}"
|
||||
|
||||
if stream:
|
||||
chunk_generator = await selected_pipeline.achat(inp=user_msg, as_stream=True, thread_id=thread_id)
|
||||
chunk_generator = await selected_pipeline.achat(
|
||||
inp=user_msg, as_stream=True, thread_id=thread_id
|
||||
)
|
||||
return StreamingResponse(
|
||||
sse_chunks_from_astream(chunk_generator, response_id=response_id, model=selected_model),
|
||||
sse_chunks_from_astream(
|
||||
chunk_generator, response_id=response_id, model=selected_model
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
result_text = await selected_pipeline.achat(inp=user_msg, as_stream=False, thread_id=thread_id)
|
||||
result_text = await selected_pipeline.achat(
|
||||
inp=user_msg, as_stream=False, thread_id=thread_id
|
||||
)
|
||||
if not isinstance(result_text, str):
|
||||
result_text = str(result_text)
|
||||
|
||||
@@ -232,9 +250,7 @@ async def _process_dashscope_request(
|
||||
"created": int(time.time()),
|
||||
"model": selected_model,
|
||||
},
|
||||
"route_id": route_id,
|
||||
# Backward compatibility: keep pipeline_id in response as the route id selector.
|
||||
"pipeline_id": route_id,
|
||||
"pipeline_id": pipeline_id,
|
||||
"is_end": True,
|
||||
}
|
||||
return JSONResponse(content=data)
|
||||
@@ -292,10 +308,13 @@ async def application_completion(
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"message": "DashScope Application-compatible API", "endpoints": [
|
||||
"/v1/apps/{app_id}/sessions/{session_id}/responses",
|
||||
"/health",
|
||||
]}
|
||||
return {
|
||||
"message": "DashScope Application-compatible API",
|
||||
"endpoints": [
|
||||
"/v1/apps/{app_id}/sessions/{session_id}/responses",
|
||||
"/health",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
@@ -310,5 +329,3 @@ if __name__ == "__main__":
|
||||
port=pipeline_config.port,
|
||||
reload=True,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -11,12 +11,12 @@ from lang_agent.config.core_config import load_tyro_conf
|
||||
|
||||
|
||||
class ServerPipelineManager:
|
||||
"""Lazily load and cache multiple pipelines keyed by a client-facing route id."""
|
||||
"""Lazily load and cache multiple pipelines keyed by a client-facing pipeline id."""
|
||||
|
||||
def __init__(self, default_route_id: str, default_config: PipelineConfig):
|
||||
self.default_route_id = default_route_id
|
||||
def __init__(self, default_pipeline_id: str, default_config: PipelineConfig):
|
||||
self.default_pipeline_id = default_pipeline_id
|
||||
self.default_config = default_config
|
||||
self._route_specs: Dict[str, Dict[str, Any]] = {}
|
||||
self._pipeline_specs: Dict[str, Dict[str, Any]] = {}
|
||||
self._api_key_policy: Dict[str, Dict[str, Any]] = {}
|
||||
self._pipelines: Dict[str, Pipeline] = {}
|
||||
self._pipeline_llm: Dict[str, str] = {}
|
||||
@@ -36,34 +36,33 @@ class ServerPipelineManager:
|
||||
raise ValueError(f"pipeline registry file not found: {abs_path}")
|
||||
|
||||
with open(abs_path, "r", encoding="utf-8") as f:
|
||||
registry:dict = json.load(f)
|
||||
registry: dict = json.load(f)
|
||||
|
||||
routes = registry.get("routes")
|
||||
if routes is None:
|
||||
# Backward compatibility with initial schema.
|
||||
routes = registry.get("pipelines", {})
|
||||
if not isinstance(routes, dict):
|
||||
raise ValueError("`routes` in pipeline registry must be an object.")
|
||||
pipelines = registry.get("pipelines")
|
||||
if pipelines is None:
|
||||
raise ValueError("`pipelines` in pipeline registry must be an object.")
|
||||
|
||||
self._route_specs = {}
|
||||
for route_id, spec in routes.items():
|
||||
self._pipeline_specs = {}
|
||||
for pipeline_id, spec in pipelines.items():
|
||||
if not isinstance(spec, dict):
|
||||
raise ValueError(f"route spec for `{route_id}` must be an object.")
|
||||
self._route_specs[route_id] = {
|
||||
raise ValueError(
|
||||
f"pipeline spec for `{pipeline_id}` must be an object."
|
||||
)
|
||||
self._pipeline_specs[pipeline_id] = {
|
||||
"enabled": bool(spec.get("enabled", True)),
|
||||
"config_file": spec.get("config_file"),
|
||||
"overrides": spec.get("overrides", {}),
|
||||
# Explicitly separates routing id from prompt config pipeline_id.
|
||||
"prompt_pipeline_id": spec.get("prompt_pipeline_id"),
|
||||
}
|
||||
if not self._route_specs:
|
||||
raise ValueError("pipeline registry must define at least one route.")
|
||||
if not self._pipeline_specs:
|
||||
raise ValueError("pipeline registry must define at least one pipeline.")
|
||||
|
||||
api_key_policy = registry.get("api_keys", {})
|
||||
if api_key_policy and not isinstance(api_key_policy, dict):
|
||||
raise ValueError("`api_keys` in pipeline registry must be an object.")
|
||||
self._api_key_policy = api_key_policy
|
||||
logger.info(f"loaded pipeline registry: {abs_path}, routes={list(self._route_specs.keys())}")
|
||||
logger.info(
|
||||
f"loaded pipeline registry: {abs_path}, pipelines={list(self._pipeline_specs.keys())}"
|
||||
)
|
||||
|
||||
def _resolve_config_path(self, config_file: str) -> str:
|
||||
path = FsPath(config_file)
|
||||
@@ -74,48 +73,47 @@ class ServerPipelineManager:
|
||||
root = FsPath(__file__).resolve().parents[2]
|
||||
return str((root / path).resolve())
|
||||
|
||||
def _build_pipeline(self, route_id: str) -> Tuple[Pipeline, str]:
|
||||
spec = self._route_specs.get(route_id)
|
||||
def _build_pipeline(self, pipeline_id: str) -> Tuple[Pipeline, str]:
|
||||
spec = self._pipeline_specs.get(pipeline_id)
|
||||
if spec is None:
|
||||
raise HTTPException(status_code=404, detail=f"Unknown route_id: {route_id}")
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Unknown pipeline_id: {pipeline_id}"
|
||||
)
|
||||
if not spec.get("enabled", True):
|
||||
raise HTTPException(status_code=403, detail=f"Route disabled: {route_id}")
|
||||
raise HTTPException(
|
||||
status_code=403, detail=f"Pipeline disabled: {pipeline_id}"
|
||||
)
|
||||
|
||||
config_file = spec.get("config_file")
|
||||
overrides = spec.get("overrides", {})
|
||||
if config_file:
|
||||
loaded_cfg = load_tyro_conf(self._resolve_config_path(config_file))
|
||||
# Some legacy yaml configs deserialize to plain dicts instead of
|
||||
# InstantiateConfig dataclasses. Fall back to default config in that case.
|
||||
if hasattr(loaded_cfg, "setup"):
|
||||
cfg = loaded_cfg
|
||||
else:
|
||||
logger.warning(
|
||||
f"config_file for route `{route_id}` did not deserialize to config object; "
|
||||
"falling back to default config and applying route-level overrides."
|
||||
f"config_file for pipeline `{pipeline_id}` did not deserialize to config object; "
|
||||
"falling back to default config and applying pipeline-level overrides."
|
||||
)
|
||||
cfg = copy.deepcopy(self.default_config)
|
||||
else:
|
||||
# Build from default config + shallow overrides so new pipelines can be
|
||||
# added via registry without additional yaml files.
|
||||
cfg = copy.deepcopy(self.default_config)
|
||||
if not isinstance(overrides, dict):
|
||||
raise ValueError(f"route `overrides` for `{route_id}` must be an object.")
|
||||
raise ValueError(
|
||||
f"pipeline `overrides` for `{pipeline_id}` must be an object."
|
||||
)
|
||||
for key, value in overrides.items():
|
||||
if not hasattr(cfg, key):
|
||||
raise ValueError(f"unknown override field `{key}` for route `{route_id}`")
|
||||
raise ValueError(
|
||||
f"unknown override field `{key}` for pipeline `{pipeline_id}`"
|
||||
)
|
||||
setattr(cfg, key, value)
|
||||
|
||||
prompt_pipeline_id = spec.get("prompt_pipeline_id")
|
||||
if prompt_pipeline_id and (not isinstance(overrides, dict) or "pipeline_id" not in overrides):
|
||||
if hasattr(cfg, "pipeline_id"):
|
||||
cfg.pipeline_id = prompt_pipeline_id
|
||||
|
||||
p = cfg.setup()
|
||||
llm_name = getattr(cfg, "llm_name", "unknown-model")
|
||||
return p, llm_name
|
||||
|
||||
def _authorize(self, api_key: str, route_id: str) -> None:
|
||||
def _authorize(self, api_key: str, pipeline_id: str) -> None:
|
||||
if not self._api_key_policy:
|
||||
return
|
||||
|
||||
@@ -123,47 +121,46 @@ class ServerPipelineManager:
|
||||
if policy is None:
|
||||
return
|
||||
|
||||
allowed = policy.get("allowed_route_ids")
|
||||
if allowed is None:
|
||||
# Backward compatibility.
|
||||
allowed = policy.get("allowed_pipeline_ids")
|
||||
if allowed and route_id not in allowed:
|
||||
raise HTTPException(status_code=403, detail=f"route_id `{route_id}` is not allowed for this API key")
|
||||
allowed = policy.get("allowed_pipeline_ids")
|
||||
if allowed and pipeline_id not in allowed:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"pipeline_id `{pipeline_id}` is not allowed for this API key",
|
||||
)
|
||||
|
||||
def resolve_route_id(self, body: Dict[str, Any], app_id: Optional[str], api_key: str) -> str:
|
||||
def resolve_pipeline_id(
|
||||
self, body: Dict[str, Any], app_id: Optional[str], api_key: str
|
||||
) -> str:
|
||||
body_input = body.get("input", {})
|
||||
route_id = (
|
||||
body.get("route_id")
|
||||
or (body_input.get("route_id") if isinstance(body_input, dict) else None)
|
||||
or body.get("pipeline_key")
|
||||
or (body_input.get("pipeline_key") if isinstance(body_input, dict) else None)
|
||||
# Backward compatibility: pipeline_id still accepted as route selector.
|
||||
or body.get("pipeline_id")
|
||||
pipeline_id = (
|
||||
body.get("pipeline_id")
|
||||
or (body_input.get("pipeline_id") if isinstance(body_input, dict) else None)
|
||||
or app_id
|
||||
)
|
||||
|
||||
if not route_id:
|
||||
key_policy = self._api_key_policy.get(api_key, {}) if self._api_key_policy else {}
|
||||
route_id = key_policy.get("default_route_id")
|
||||
if not route_id:
|
||||
# Backward compatibility.
|
||||
route_id = key_policy.get("default_pipeline_id", self.default_route_id)
|
||||
if not pipeline_id:
|
||||
key_policy = (
|
||||
self._api_key_policy.get(api_key, {}) if self._api_key_policy else {}
|
||||
)
|
||||
pipeline_id = key_policy.get(
|
||||
"default_pipeline_id", self.default_pipeline_id
|
||||
)
|
||||
|
||||
if route_id not in self._route_specs:
|
||||
raise HTTPException(status_code=404, detail=f"Unknown route_id: {route_id}")
|
||||
if pipeline_id not in self._pipeline_specs:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Unknown pipeline_id: {pipeline_id}"
|
||||
)
|
||||
|
||||
self._authorize(api_key, route_id)
|
||||
return route_id
|
||||
self._authorize(api_key, pipeline_id)
|
||||
return pipeline_id
|
||||
|
||||
def get_pipeline(self, route_id: str) -> Tuple[Pipeline, str]:
|
||||
cached = self._pipelines.get(route_id)
|
||||
def get_pipeline(self, pipeline_id: str) -> Tuple[Pipeline, str]:
|
||||
cached = self._pipelines.get(pipeline_id)
|
||||
if cached is not None:
|
||||
return cached, self._pipeline_llm[route_id]
|
||||
return cached, self._pipeline_llm[pipeline_id]
|
||||
|
||||
pipeline_obj, llm_name = self._build_pipeline(route_id)
|
||||
self._pipelines[route_id] = pipeline_obj
|
||||
self._pipeline_llm[route_id] = llm_name
|
||||
logger.info(f"lazy-loaded route_id={route_id} model={llm_name}")
|
||||
pipeline_obj, llm_name = self._build_pipeline(pipeline_id)
|
||||
self._pipelines[pipeline_id] = pipeline_obj
|
||||
self._pipeline_llm[pipeline_id] = llm_name
|
||||
logger.info(f"lazy-loaded pipeline_id={pipeline_id} model={llm_name}")
|
||||
return pipeline_obj, llm_name
|
||||
|
||||
|
||||
Reference in New Issue
Block a user