Files
lang-agent/lang_agent/components/server_pipeline_manager.py
2026-03-04 15:21:07 +08:00

167 lines
6.4 KiB
Python

from fastapi import HTTPException
from typing import Any, Dict, Optional, Tuple
from pathlib import Path as FsPath
import os.path as osp
import json
import copy
from loguru import logger
from lang_agent.pipeline import Pipeline, PipelineConfig
from lang_agent.config.core_config import load_tyro_conf
class ServerPipelineManager:
    """Lazily load and cache multiple pipelines keyed by a client-facing pipeline id.

    The manager is fed by a JSON registry (see ``load_registry``) that maps
    pipeline ids to a spec (``enabled``, ``config_file``, ``overrides``) and
    optionally maps API keys to a policy (``allowed_pipeline_ids``,
    ``default_pipeline_id``). Pipelines are only built on first use and are
    then served from an in-memory cache.
    """

    def __init__(self, default_pipeline_id: str, default_config: "PipelineConfig"):
        """Store the fallback pipeline id/config and start with empty state.

        Args:
            default_pipeline_id: Pipeline id used when a request names none and
                the API key policy supplies no default either.
            default_config: Config that is deep-copied as the base for any
                pipeline whose spec has no usable ``config_file``.
        """
        self.default_pipeline_id = default_pipeline_id
        self.default_config = default_config
        # pipeline_id -> {"enabled", "config_file", "overrides"} from the registry.
        self._pipeline_specs: Dict[str, Dict[str, Any]] = {}
        # api_key -> policy dict ("allowed_pipeline_ids", "default_pipeline_id").
        self._api_key_policy: Dict[str, Dict[str, Any]] = {}
        # Pipelines that have already been built, and the model name of each.
        self._pipelines: Dict[str, Pipeline] = {}
        self._pipeline_llm: Dict[str, str] = {}

    @staticmethod
    def _resolve_from_repo_root(raw_path: str) -> str:
        """Return ``raw_path`` as is when absolute, else resolved from the repo root."""
        path = FsPath(raw_path)
        if path.is_absolute():
            return str(path)
        # server_pipeline_manager.py is under <repo>/lang_agent/components/,
        # so parents[2] is the repository root.
        root = FsPath(__file__).resolve().parents[2]
        return str((root / path).resolve())

    def _resolve_registry_path(self, registry_path: str) -> str:
        """Resolve the registry file path (relative paths start at the repo root)."""
        return self._resolve_from_repo_root(registry_path)

    def load_registry(self, registry_path: str) -> None:
        """Read the pipeline registry and replace the current specs and key policy.

        The registry is validated completely before any state is touched, so a
        rejected file leaves the previously loaded registry in effect.

        Args:
            registry_path: Absolute path, or path relative to the repository root.

        Raises:
            ValueError: If the file is missing, is not a JSON object, has a
                missing/non-object/empty ``pipelines`` section, contains a
                non-object pipeline spec, or has a non-object ``api_keys``.
        """
        abs_path = self._resolve_registry_path(registry_path)
        # isfile (not exists): a directory must be reported as a missing
        # registry file instead of failing later inside open().
        if not osp.isfile(abs_path):
            raise ValueError(f"pipeline registry file not found: {abs_path}")
        with open(abs_path, "r", encoding="utf-8") as f:
            registry = json.load(f)
        if not isinstance(registry, dict):
            raise ValueError("pipeline registry must be a JSON object.")

        pipelines = registry.get("pipelines")
        # Covers both a missing key and a wrong type (list, string, ...).
        if not isinstance(pipelines, dict):
            raise ValueError("`pipelines` in pipeline registry must be an object.")

        specs: Dict[str, Dict[str, Any]] = {}
        for pipeline_id, spec in pipelines.items():
            if not isinstance(spec, dict):
                raise ValueError(
                    f"pipeline spec for `{pipeline_id}` must be an object."
                )
            specs[pipeline_id] = {
                "enabled": bool(spec.get("enabled", True)),
                "config_file": spec.get("config_file"),
                "overrides": spec.get("overrides", {}),
            }
        if not specs:
            raise ValueError("pipeline registry must define at least one pipeline.")

        # Any falsy value (missing, null, empty) means "no key policy".
        api_key_policy = registry.get("api_keys") or {}
        if not isinstance(api_key_policy, dict):
            raise ValueError("`api_keys` in pipeline registry must be an object.")

        # Commit only after everything validated.
        self._pipeline_specs = specs
        self._api_key_policy = api_key_policy
        # Pipelines built from the previous registry may no longer match their
        # spec (changed, disabled or removed), so they must be rebuilt lazily.
        self._pipelines.clear()
        self._pipeline_llm.clear()
        logger.info(
            f"loaded pipeline registry: {abs_path}, pipelines={list(self._pipeline_specs.keys())}"
        )

    def _resolve_config_path(self, config_file: str) -> str:
        """Resolve a pipeline config path.

        Relative config paths are resolved from the repository root for
        consistency with docker-compose and tests.
        """
        return self._resolve_from_repo_root(config_file)

    def _build_pipeline(self, pipeline_id: str) -> "Tuple[Pipeline, str]":
        """Construct the pipeline described by the registry spec of ``pipeline_id``.

        The base config is the object loaded from the spec's ``config_file``
        when that object exposes ``setup``; otherwise a deep copy of the default
        config. The spec's ``overrides`` are then set as attributes on it.

        Returns:
            The result of ``cfg.setup()`` and the config's ``llm_name``
            (``"unknown-model"`` when the config has no such attribute).

        Raises:
            HTTPException: 404 for an unknown id, 403 for a disabled pipeline.
            ValueError: If ``overrides`` is not an object or names a field the
                config does not have.
        """
        spec = self._pipeline_specs.get(pipeline_id)
        if spec is None:
            raise HTTPException(
                status_code=404, detail=f"Unknown pipeline_id: {pipeline_id}"
            )
        if not spec.get("enabled", True):
            raise HTTPException(
                status_code=403, detail=f"Pipeline disabled: {pipeline_id}"
            )

        overrides = spec.get("overrides")
        if overrides is None:
            # An explicit JSON null is treated the same as an absent key.
            overrides = {}
        if not isinstance(overrides, dict):
            raise ValueError(
                f"pipeline `overrides` for `{pipeline_id}` must be an object."
            )

        cfg = None
        config_file = spec.get("config_file")
        if config_file:
            loaded_cfg = load_tyro_conf(self._resolve_config_path(config_file))
            if hasattr(loaded_cfg, "setup"):
                cfg = loaded_cfg
            else:
                logger.warning(
                    f"config_file for pipeline `{pipeline_id}` did not deserialize to config object; "
                    "falling back to default config and applying pipeline-level overrides."
                )
        if cfg is None:
            # Copy so that overrides never leak into the shared default config.
            cfg = copy.deepcopy(self.default_config)

        for key, value in overrides.items():
            if not hasattr(cfg, key):
                raise ValueError(
                    f"unknown override field `{key}` for pipeline `{pipeline_id}`"
                )
            setattr(cfg, key, value)

        p = cfg.setup()
        llm_name = getattr(cfg, "llm_name", "unknown-model")
        return p, llm_name

    def _authorize(self, api_key: str, pipeline_id: str) -> None:
        """Reject ``pipeline_id`` when the policy of ``api_key`` does not list it.

        Nothing is enforced when no key policy is loaded, when the key has no
        policy entry, or when its ``allowed_pipeline_ids`` is empty/missing.

        Raises:
            HTTPException: 403 when the key's allow-list excludes the pipeline.
        """
        if not self._api_key_policy:
            return
        policy = self._api_key_policy.get(api_key)
        if policy is None:
            # NOTE(review): keys without a policy entry are unrestricted here;
            # presumably key validity is checked by the caller - confirm.
            return
        allowed = policy.get("allowed_pipeline_ids")
        if isinstance(allowed, str):
            # A bare string would otherwise be matched by substring.
            allowed = [allowed]
        if allowed and pipeline_id not in allowed:
            raise HTTPException(
                status_code=403,
                detail=f"pipeline_id `{pipeline_id}` is not allowed for this API key",
            )

    def resolve_pipeline_id(
        self, body: Dict[str, Any], app_id: Optional[str], api_key: str
    ) -> str:
        """Pick the pipeline id for a request and check it against the key policy.

        Lookup order: ``body["pipeline_id"]``, ``body["input"]["pipeline_id"]``,
        ``app_id``, the API key's ``default_pipeline_id``, and finally the
        manager-wide default.

        Raises:
            HTTPException: 404 when the id is not in the registry, 403 when the
                API key is not allowed to use it.
        """
        body_input = body.get("input", {})
        nested_id = (
            body_input.get("pipeline_id") if isinstance(body_input, dict) else None
        )
        pipeline_id = body.get("pipeline_id") or nested_id or app_id
        if not pipeline_id:
            key_policy = (self._api_key_policy or {}).get(api_key) or {}
            pipeline_id = (
                key_policy.get("default_pipeline_id") or self.default_pipeline_id
            )
        # The isinstance guard keeps unhashable client values (list/dict) from
        # blowing up the membership test; registry ids are always JSON strings.
        if not isinstance(pipeline_id, str) or pipeline_id not in self._pipeline_specs:
            raise HTTPException(
                status_code=404, detail=f"Unknown pipeline_id: {pipeline_id}"
            )
        self._authorize(api_key, pipeline_id)
        return pipeline_id

    def get_pipeline(self, pipeline_id: str) -> "Tuple[Pipeline, str]":
        """Return ``(pipeline, llm_name)`` for ``pipeline_id``, building it on first use.

        Raises:
            HTTPException: Propagated from ``_build_pipeline`` (404/403).
            ValueError: Propagated from ``_build_pipeline`` for bad overrides.
        """
        cached = self._pipelines.get(pipeline_id)
        if cached is not None:
            return cached, self._pipeline_llm[pipeline_id]
        pipeline_obj, llm_name = self._build_pipeline(pipeline_id)
        self._pipelines[pipeline_id] = pipeline_obj
        self._pipeline_llm[pipeline_id] = llm_name
        logger.info(f"lazy-loaded pipeline_id={pipeline_id} model={llm_name}")
        return pipeline_obj, llm_name