class PipelineManager:
    """Lazily load and cache multiple pipelines keyed by a client-facing id.

    The manager starts with one default pipeline (constructed at server
    startup) and can register additional pipelines from a JSON registry
    file.  Registry pipelines are built on first use and cached afterwards.
    """

    def __init__(self, default_pipeline_id: str, default_config: PipelineConfig, default_pipeline: Pipeline):
        self.default_pipeline_id = default_pipeline_id
        self.default_config = default_config
        # pipeline_id -> {"enabled": bool, "config_file": str|None, "overrides": dict}
        self._pipeline_specs: Dict[str, Dict[str, Any]] = {}
        # api_key -> policy dict ("allowed_pipeline_ids", "default_pipeline_id", ...)
        self._api_key_policy: Dict[str, Dict[str, Any]] = {}
        # Caches: constructed pipelines and the model name each one serves.
        self._pipelines: Dict[str, Pipeline] = {default_pipeline_id: default_pipeline}
        self._pipeline_llm: Dict[str, str] = {default_pipeline_id: default_config.llm_name}
        self._pipeline_specs[default_pipeline_id] = {"enabled": True, "config_file": None, "overrides": {}}

    @staticmethod
    def _resolve_path(path_str: str) -> str:
        """Resolve *path_str* against the project root unless it is absolute."""
        path = FsPath(path_str)
        if path.is_absolute():
            return str(path)
        root = FsPath(osp.dirname(osp.dirname(osp.abspath(__file__))))
        return str((root / path).resolve())

    # Thin aliases kept for backward compatibility with existing callers.
    def _resolve_registry_path(self, registry_path: str) -> str:
        return self._resolve_path(registry_path)

    def _resolve_config_path(self, config_file: str) -> str:
        return self._resolve_path(config_file)

    def load_registry(self, registry_path: str) -> None:
        """Load pipeline specs and API-key policy from a JSON registry file.

        Missing files are tolerated (default pipeline only); malformed
        structures raise ``ValueError`` so startup fails loudly.
        """
        abs_path = self._resolve_registry_path(registry_path)
        if not osp.exists(abs_path):
            logger.warning(f"pipeline registry file not found: {abs_path}. Using default pipeline only.")
            return

        with open(abs_path, "r", encoding="utf-8") as f:
            registry: dict = json.load(f)

        pipelines = registry.get("pipelines", {})
        if not isinstance(pipelines, dict):
            raise ValueError("`pipelines` in pipeline registry must be an object.")

        for pipeline_id, spec in pipelines.items():
            if not isinstance(spec, dict):
                raise ValueError(f"pipeline spec for `{pipeline_id}` must be an object.")
            self._pipeline_specs[pipeline_id] = {
                "enabled": bool(spec.get("enabled", True)),
                "config_file": spec.get("config_file"),
                "overrides": spec.get("overrides", {}),
            }

        api_key_policy = registry.get("api_keys", {})
        if api_key_policy and not isinstance(api_key_policy, dict):
            raise ValueError("`api_keys` in pipeline registry must be an object.")
        self._api_key_policy = api_key_policy
        logger.info(f"loaded pipeline registry: {abs_path}, pipelines={list(self._pipeline_specs.keys())}")

    def _build_pipeline(self, pipeline_id: str) -> Tuple[Pipeline, str]:
        """Construct (pipeline, llm_name) for *pipeline_id* from its spec.

        Raises:
            HTTPException: 404 for unknown ids, 403 for disabled pipelines.
            ValueError: malformed spec (non-dict overrides, unknown fields).
        """
        spec = self._pipeline_specs.get(pipeline_id)
        if spec is None:
            raise HTTPException(status_code=404, detail=f"Unknown pipeline_id: {pipeline_id}")
        if not spec.get("enabled", True):
            raise HTTPException(status_code=403, detail=f"Pipeline disabled: {pipeline_id}")

        config_file = spec.get("config_file")
        overrides = spec.get("overrides", {})
        if overrides and not isinstance(overrides, dict):
            raise ValueError(f"pipeline `overrides` for `{pipeline_id}` must be an object.")

        if not config_file and not overrides:
            # Spec adds nothing on top of the default -> reuse the default pipeline.
            p = self._pipelines[self.default_pipeline_id]
            llm_name = self._pipeline_llm[self.default_pipeline_id]
            return p, llm_name

        if config_file:
            cfg = load_tyro_conf(self._resolve_config_path(config_file))
        else:
            # Build from the default config so new pipelines can be added via
            # the registry without additional yaml files.
            cfg = copy.deepcopy(self.default_config)

        # BUGFIX: overrides were previously ignored whenever `config_file` was
        # also set; now they apply on top of whichever base config was chosen.
        for key, value in (overrides or {}).items():
            if not hasattr(cfg, key):
                raise ValueError(f"unknown override field `{key}` for pipeline `{pipeline_id}`")
            setattr(cfg, key, value)

        p = cfg.setup()
        llm_name = getattr(cfg, "llm_name", "unknown-model")
        return p, llm_name

    def _authorize(self, api_key: str, pipeline_id: str) -> None:
        """Reject pipeline ids outside the key's allow-list (403).

        Keys without a policy entry, or with no allow-list, are unrestricted.
        """
        if not self._api_key_policy:
            return

        policy = self._api_key_policy.get(api_key)
        if policy is None:
            return

        allowed = policy.get("allowed_pipeline_ids")
        if allowed and pipeline_id not in allowed:
            raise HTTPException(status_code=403, detail=f"pipeline_id `{pipeline_id}` is not allowed for this API key")

    def resolve_pipeline_id(self, body: Dict[str, Any], app_id: Optional[str], api_key: str) -> str:
        """Pick the pipeline id for a request and authorize it.

        Precedence: body.pipeline_id > input.pipeline_id > path app_id >
        the key's default_pipeline_id > the server default.
        """
        body_input = body.get("input", {})
        pipeline_id = (
            body.get("pipeline_id")
            or (body_input.get("pipeline_id") if isinstance(body_input, dict) else None)
            or app_id
        )

        if not pipeline_id:
            key_policy = self._api_key_policy.get(api_key, {}) if self._api_key_policy else {}
            pipeline_id = key_policy.get("default_pipeline_id", self.default_pipeline_id)

        if pipeline_id not in self._pipeline_specs:
            raise HTTPException(status_code=404, detail=f"Unknown pipeline_id: {pipeline_id}")

        self._authorize(api_key, pipeline_id)
        return pipeline_id

    def get_pipeline(self, pipeline_id: str) -> Tuple[Pipeline, str]:
        """Return (pipeline, llm_name), building and caching on first use."""
        cached = self._pipelines.get(pipeline_id)
        if cached is not None:
            return cached, self._pipeline_llm[pipeline_id]

        pipeline_obj, llm_name = self._build_pipeline(pipeline_id)
        self._pipelines[pipeline_id] = pipeline_obj
        self._pipeline_llm[pipeline_id] = llm_name
        logger.info(f"lazy-loaded pipeline_id={pipeline_id} model={llm_name}")
        return pipeline_obj, llm_name
default_pipeline_id=os.environ.get("FAST_DEFAULT_PIPELINE_ID", "default"), + default_config=pipeline_config, + default_pipeline=pipeline, +) +PIPELINE_MANAGER.load_registry(REGISTRY_FILE) async def verify_api_key(api_key: str = Security(API_KEY_HEADER)): @@ -143,81 +287,105 @@ async def sse_chunks_from_astream(chunk_generator, response_id: str, model: str yield f"data: {json.dumps(final)}\n\n" +def _normalize_messages(body: Dict[str, Any]) -> List[Dict[str, Any]]: + messages = body.get("messages") + body_input = body.get("input", {}) + if messages is None and isinstance(body_input, dict): + messages = body_input.get("messages") + if messages is None and isinstance(body_input, dict): + prompt = body_input.get("prompt") + if isinstance(prompt, str): + messages = [{"role": "user", "content": prompt}] + + if not messages: + raise HTTPException(status_code=400, detail="messages is required") + return messages + + +def _extract_user_message(messages: List[Dict[str, Any]]) -> str: + user_msg = None + for m in reversed(messages): + role = m.get("role") if isinstance(m, dict) else None + content = m.get("content") if isinstance(m, dict) else None + if role == "user" and content: + user_msg = content + break + if user_msg is None: + last = messages[-1] + user_msg = last.get("content") if isinstance(last, dict) else str(last) + return user_msg + + +async def _process_dashscope_request( + body: Dict[str, Any], + app_id: Optional[str], + session_id: Optional[str], + api_key: str, +): + req_app_id = app_id or body.get("app_id") + body_input = body.get("input", {}) if isinstance(body.get("input"), dict) else {} + req_session_id = session_id or body_input.get("session_id") + messages = _normalize_messages(body) + + stream = body.get("stream") + if stream is None: + stream = body.get("parameters", {}).get("stream", True) + + thread_id = body_input.get("session_id") or req_session_id or "3" + user_msg = _extract_user_message(messages) + + pipeline_id = 
PIPELINE_MANAGER.resolve_pipeline_id(body=body, app_id=req_app_id, api_key=api_key) + selected_pipeline, selected_model = PIPELINE_MANAGER.get_pipeline(pipeline_id) + + # Namespace thread ids to prevent memory collisions across pipelines. + thread_id = f"{pipeline_id}:{thread_id}" + + response_id = f"appcmpl-{os.urandom(12).hex()}" + + if stream: + chunk_generator = await selected_pipeline.achat(inp=user_msg, as_stream=True, thread_id=thread_id) + return StreamingResponse( + sse_chunks_from_astream(chunk_generator, response_id=response_id, model=selected_model), + media_type="text/event-stream", + ) + + result_text = await selected_pipeline.achat(inp=user_msg, as_stream=False, thread_id=thread_id) + if not isinstance(result_text, str): + result_text = str(result_text) + + data = { + "request_id": response_id, + "code": 200, + "message": "OK", + "app_id": req_app_id, + "session_id": req_session_id, + "output": { + "text": result_text, + "created": int(time.time()), + "model": selected_model, + }, + "pipeline_id": pipeline_id, + "is_end": True, + } + return JSONResponse(content=data) + + @app.post("/v1/apps/{app_id}/sessions/{session_id}/responses") @app.post("/api/v1/apps/{app_id}/sessions/{session_id}/responses") async def application_responses( request: Request, app_id: str = Path(...), session_id: str = Path(...), - _: str = Depends(verify_api_key), + api_key: str = Depends(verify_api_key), ): try: body = await request.json() - - # Prefer path params - req_app_id = app_id or body.get("app_id") - req_session_id = session_id or body['input'].get("session_id") - - # Normalize messages - messages = body.get("messages") - if messages is None and isinstance(body.get("input"), dict): - messages = body.get("input", {}).get("messages") - if messages is None and isinstance(body.get("input"), dict): - prompt = body.get("input", {}).get("prompt") - if isinstance(prompt, str): - messages = [{"role": "user", "content": prompt}] - - if not messages: - raise 
HTTPException(status_code=400, detail="messages is required") - - # Determine stream flag - stream = body.get("stream") - if stream is None: - stream = body.get("parameters", {}).get("stream", True) - - thread_id = body['input'].get("session_id") - - # Extract latest user message - user_msg = None - for m in reversed(messages): - role = m.get("role") if isinstance(m, dict) else None - content = m.get("content") if isinstance(m, dict) else None - if role == "user" and content: - user_msg = content - break - if user_msg is None: - last = messages[-1] - user_msg = last.get("content") if isinstance(last, dict) else str(last) - - response_id = f"appcmpl-{os.urandom(12).hex()}" - - if stream: - # Use async streaming from pipeline - chunk_generator = await pipeline.achat(inp=user_msg, as_stream=True, thread_id=thread_id) - return StreamingResponse( - sse_chunks_from_astream(chunk_generator, response_id=response_id, model=pipeline_config.llm_name), - media_type="text/event-stream", - ) - - # Non-streaming: get full result using async - result_text = await pipeline.achat(inp=user_msg, as_stream=False, thread_id=thread_id) - if not isinstance(result_text, str): - result_text = str(result_text) - - data = { - "request_id": response_id, - "code": 200, - "message": "OK", - "app_id": req_app_id, - "session_id": req_session_id, - "output": { - "text": result_text, - "created": int(time.time()), - "model": pipeline_config.llm_name, - }, - "is_end": True, - } - return JSONResponse(content=data) + return await _process_dashscope_request( + body=body, + app_id=app_id, + session_id=session_id, + api_key=api_key, + ) except HTTPException: raise @@ -234,71 +402,16 @@ async def application_responses( async def application_completion( request: Request, app_id: str = Path(...), - _: str = Depends(verify_api_key), + api_key: str = Depends(verify_api_key), ): try: body = await request.json() - - req_session_id = body['input'].get("session_id") - - # Normalize messages - messages = 
body.get("messages") - if messages is None and isinstance(body.get("input"), dict): - messages = body.get("input", {}).get("messages") - if messages is None and isinstance(body.get("input"), dict): - prompt = body.get("input", {}).get("prompt") - if isinstance(prompt, str): - messages = [{"role": "user", "content": prompt}] - - if not messages: - raise HTTPException(status_code=400, detail="messages is required") - - stream = body.get("stream") - if stream is None: - stream = body.get("parameters", {}).get("stream", True) - - thread_id = body['input'].get("session_id") - - user_msg = None - for m in reversed(messages): - role = m.get("role") if isinstance(m, dict) else None - content = m.get("content") if isinstance(m, dict) else None - if role == "user" and content: - user_msg = content - break - if user_msg is None: - last = messages[-1] - user_msg = last.get("content") if isinstance(last, dict) else str(last) - - response_id = f"appcmpl-{os.urandom(12).hex()}" - - if stream: - # Use async streaming from pipeline - chunk_generator = await pipeline.achat(inp=user_msg, as_stream=True, thread_id=thread_id) - return StreamingResponse( - sse_chunks_from_astream(chunk_generator, response_id=response_id, model=pipeline_config.llm_name), - media_type="text/event-stream", - ) - - # Non-streaming: get full result using async - result_text = await pipeline.achat(inp=user_msg, as_stream=False, thread_id=thread_id) - if not isinstance(result_text, str): - result_text = str(result_text) - - data = { - "request_id": response_id, - "code": 200, - "message": "OK", - "app_id": app_id, - "session_id": req_session_id, - "output": { - "text": result_text, - "created": int(time.time()), - "model": pipeline_config.llm_name, - }, - "is_end": True, - } - return JSONResponse(content=data) + return await _process_dashscope_request( + body=body, + app_id=app_id, + session_id=None, + api_key=api_key, + ) except HTTPException: raise