make this importable without tyro fking around

This commit is contained in:
2026-03-05 11:43:16 +08:00
parent 55b37cc611
commit a2890148f9

View File

@@ -1,9 +1,8 @@
from fastapi import FastAPI, HTTPException, Path, Request, Depends, Security
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path, Request, Security
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.security import APIKeyHeader
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional
import os
import os.path as osp
import sys
@@ -20,16 +19,28 @@ from lang_agent.pipeline import PipelineConfig
from lang_agent.components.server_pipeline_manager import ServerPipelineManager
from lang_agent.config.constants import PIPELINE_REGISTRY_PATH, API_KEY_HEADER, VALID_API_KEYS
# Load base config for route-level overrides (pipelines are lazy-loaded from registry)
pipeline_config = tyro.cli(PipelineConfig)
logger.info(f"starting agent with base pipeline config: \n{pipeline_config}")
def _build_default_pipeline_config() -> PipelineConfig:
"""
Build import-time defaults without parsing CLI args.
This keeps module import safe for reuse by combined apps and tests.
"""
pipeline_config = PipelineConfig()
logger.info(f"starting agent with base pipeline config: \n{pipeline_config}")
return pipeline_config
PIPELINE_MANAGER = ServerPipelineManager(
def _build_pipeline_manager(base_config: PipelineConfig) -> ServerPipelineManager:
pipeline_manager = ServerPipelineManager(
default_pipeline_id=os.environ.get("FAST_DEFAULT_PIPELINE_ID", "default"),
default_config=pipeline_config,
)
PIPELINE_MANAGER.load_registry(PIPELINE_REGISTRY_PATH)
default_config=base_config,
)
pipeline_manager.load_registry(PIPELINE_REGISTRY_PATH)
return pipeline_manager
pipeline_config = _build_default_pipeline_config()
PIPELINE_MANAGER = _build_pipeline_manager(pipeline_config)
async def verify_api_key(api_key: str = Security(API_KEY_HEADER)):
@@ -55,20 +66,6 @@ class DSApplicationCallRequest(BaseModel):
thread_id: Optional[str] = Field(default="3")
app = FastAPI(
title="DashScope-Compatible Application API",
description="DashScope Application.call compatible endpoint backed by pipeline.chat",
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
def sse_chunks_from_stream(
chunk_generator, response_id: str, model: str = "qwen-flash"
):
@@ -188,9 +185,10 @@ async def _process_dashscope_request(
app_id: Optional[str],
session_id: Optional[str],
api_key: str,
pipeline_manager: ServerPipelineManager,
):
try:
PIPELINE_MANAGER.refresh_registry_if_needed()
pipeline_manager.refresh_registry_if_needed()
except Exception as e:
logger.error(f"failed to refresh pipeline registry: {e}")
raise HTTPException(status_code=500, detail=f"Failed to refresh pipeline registry: {e}")
@@ -207,10 +205,10 @@ async def _process_dashscope_request(
thread_id = body_input.get("session_id") or req_session_id or "3"
user_msg = _extract_user_message(messages)
pipeline_id = PIPELINE_MANAGER.resolve_pipeline_id(
pipeline_id = pipeline_manager.resolve_pipeline_id(
body=body, app_id=req_app_id, api_key=api_key
)
selected_pipeline, selected_model = PIPELINE_MANAGER.get_pipeline(pipeline_id)
selected_pipeline, selected_model = pipeline_manager.get_pipeline(pipeline_id)
# Namespace thread ids to prevent memory collisions across pipelines.
thread_id = f"{pipeline_id}:{thread_id}"
@@ -251,14 +249,21 @@ async def _process_dashscope_request(
return JSONResponse(content=data)
@app.post("/v1/apps/{app_id}/sessions/{session_id}/responses")
@app.post("/api/v1/apps/{app_id}/sessions/{session_id}/responses")
async def application_responses(
def create_dashscope_router(
pipeline_manager: Optional[ServerPipelineManager] = None,
include_meta_routes: bool = True,
) -> APIRouter:
manager = pipeline_manager or PIPELINE_MANAGER
router = APIRouter()
@router.post("/v1/apps/{app_id}/sessions/{session_id}/responses")
@router.post("/api/v1/apps/{app_id}/sessions/{session_id}/responses")
async def application_responses(
request: Request,
app_id: str = Path(...),
session_id: str = Path(...),
api_key: str = Depends(verify_api_key),
):
):
try:
body = await request.json()
return await _process_dashscope_request(
@@ -266,6 +271,7 @@ async def application_responses(
app_id=app_id,
session_id=session_id,
api_key=api_key,
pipeline_manager=manager,
)
except HTTPException:
@@ -274,17 +280,17 @@ async def application_responses(
logger.error(f"DashScope-compatible endpoint error: {e}")
raise HTTPException(status_code=500, detail=str(e))
# Compatibility: some SDKs call /apps/{app_id}/completion without /v1 and without session in path
@app.post("/apps/{app_id}/completion")
@app.post("/v1/apps/{app_id}/completion")
@app.post("/api/apps/{app_id}/completion")
@app.post("/api/v1/apps/{app_id}/completion")
async def application_completion(
# Compatibility: some SDKs call /apps/{app_id}/completion without /v1 and
# without session in path.
@router.post("/apps/{app_id}/completion")
@router.post("/v1/apps/{app_id}/completion")
@router.post("/api/apps/{app_id}/completion")
@router.post("/api/v1/apps/{app_id}/completion")
async def application_completion(
request: Request,
app_id: str = Path(...),
api_key: str = Depends(verify_api_key),
):
):
try:
body = await request.json()
return await _process_dashscope_request(
@@ -292,6 +298,7 @@ async def application_completion(
app_id=app_id,
session_id=None,
api_key=api_key,
pipeline_manager=manager,
)
except HTTPException:
@@ -300,9 +307,9 @@ async def application_completion(
logger.error(f"DashScope-compatible completion error: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/")
async def root():
if include_meta_routes:
@router.get("/")
async def root():
return {
"message": "DashScope Application-compatible API",
"endpoints": [
@@ -311,16 +318,48 @@ async def root():
],
}
@app.get("/health")
async def health():
@router.get("/health")
async def health():
return {"status": "healthy"}
return router
def create_dashscope_app(
pipeline_manager: Optional[ServerPipelineManager] = None,
) -> FastAPI:
dashscope_app = FastAPI(
title="DashScope-Compatible Application API",
description="DashScope Application.call compatible endpoint backed by pipeline.chat",
)
dashscope_app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
dashscope_app.include_router(
create_dashscope_router(
pipeline_manager=pipeline_manager,
include_meta_routes=True,
)
)
return dashscope_app
dashscope_router = create_dashscope_router(include_meta_routes=False)
app = create_dashscope_app()
if __name__ == "__main__":
# CLI parsing is intentionally only in script mode to keep module import safe.
cli_pipeline_config = tyro.cli(PipelineConfig)
logger.info(f"starting agent with CLI pipeline config: \n{cli_pipeline_config}")
cli_pipeline_manager = _build_pipeline_manager(cli_pipeline_config)
uvicorn.run(
"server_dashscope:app",
host="0.0.0.0",
port=pipeline_config.port,
reload=True,
create_dashscope_app(pipeline_manager=cli_pipeline_manager),
host=cli_pipeline_config.host,
port=cli_pipeline_config.port,
reload=False,
)