load only pipeline_registry.json pipelines only

This commit is contained in:
2026-03-04 11:25:16 +08:00
parent 021b4d6ffb
commit 501f9954ce
2 changed files with 10 additions and 25 deletions

View File

@@ -4,13 +4,11 @@ from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.security import APIKeyHeader
from pydantic import BaseModel, Field
from typing import Any, Dict, List, Optional, Tuple
from pathlib import Path as FsPath
import os
import os.path as osp
import sys
import time
import json
import copy
import uvicorn
from loguru import logger
import tyro
@@ -18,14 +16,12 @@ import tyro
# Ensure we can import from project root
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
from lang_agent.pipeline import Pipeline, PipelineConfig
from lang_agent.config.core_config import load_tyro_conf
from lang_agent.pipeline import PipelineConfig
from lang_agent.components.server_pipeline_manager import ServerPipelineManager
# Initialize default pipeline once (used when no explicit pipeline id is provided)
# Load base config for route-level overrides (pipelines are lazy-loaded from registry)
pipeline_config = tyro.cli(PipelineConfig)
logger.info(f"starting agent with default pipeline: \n{pipeline_config}")
pipeline: Pipeline = pipeline_config.setup()
logger.info(f"starting agent with base pipeline config: \n{pipeline_config}")
# API Key Authentication
API_KEY_HEADER = APIKeyHeader(name="Authorization", auto_error=True)
@@ -39,7 +35,6 @@ REGISTRY_FILE = os.environ.get(
PIPELINE_MANAGER = ServerPipelineManager(
default_route_id=os.environ.get("FAST_DEFAULT_ROUTE_ID", os.environ.get("FAST_DEFAULT_PIPELINE_ID", "default")),
default_config=pipeline_config,
default_pipeline=pipeline,
)
PIPELINE_MANAGER.load_registry(REGISTRY_FILE)

View File

@@ -13,19 +13,13 @@ from lang_agent.config.core_config import load_tyro_conf
class ServerPipelineManager:
"""Lazily load and cache multiple pipelines keyed by a client-facing route id."""
def __init__(self, default_route_id: str, default_config: PipelineConfig, default_pipeline: Pipeline):
def __init__(self, default_route_id: str, default_config: PipelineConfig):
self.default_route_id = default_route_id
self.default_config = default_config
self._route_specs: Dict[str, Dict[str, Any]] = {}
self._api_key_policy: Dict[str, Dict[str, Any]] = {}
self._pipelines: Dict[str, Pipeline] = {default_route_id: default_pipeline}
self._pipeline_llm: Dict[str, str] = {default_route_id: default_config.llm_name}
self._route_specs[default_route_id] = {
"enabled": True,
"config_file": None,
"overrides": {},
"prompt_pipeline_id": None,
}
self._pipelines: Dict[str, Pipeline] = {}
self._pipeline_llm: Dict[str, str] = {}
def _resolve_registry_path(self, registry_path: str) -> str:
path = FsPath(registry_path)
@@ -39,8 +33,7 @@ class ServerPipelineManager:
def load_registry(self, registry_path: str) -> None:
abs_path = self._resolve_registry_path(registry_path)
if not osp.exists(abs_path):
logger.warning(f"pipeline registry file not found: {abs_path}. Using default pipeline only.")
return
raise ValueError(f"pipeline registry file not found: {abs_path}")
with open(abs_path, "r", encoding="utf-8") as f:
registry:dict = json.load(f)
@@ -52,6 +45,7 @@ class ServerPipelineManager:
if not isinstance(routes, dict):
raise ValueError("`routes` in pipeline registry must be an object.")
self._route_specs = {}
for route_id, spec in routes.items():
if not isinstance(spec, dict):
raise ValueError(f"route spec for `{route_id}` must be an object.")
@@ -62,6 +56,8 @@ class ServerPipelineManager:
# Explicitly separates routing id from prompt config pipeline_id.
"prompt_pipeline_id": spec.get("prompt_pipeline_id"),
}
if not self._route_specs:
raise ValueError("pipeline registry must define at least one route.")
api_key_policy = registry.get("api_keys", {})
if api_key_policy and not isinstance(api_key_policy, dict):
@@ -87,12 +83,6 @@ class ServerPipelineManager:
config_file = spec.get("config_file")
overrides = spec.get("overrides", {})
if not config_file and not overrides:
# default pipeline
p = self._pipelines[self.default_route_id]
llm_name = self._pipeline_llm[self.default_route_id]
return p, llm_name
if config_file:
loaded_cfg = load_tyro_conf(self._resolve_config_path(config_file))
# Some legacy yaml configs deserialize to plain dicts instead of