make this importable without tyro fking around

This commit is contained in:
2026-03-05 11:43:16 +08:00
parent 55b37cc611
commit a2890148f9

View File

@@ -1,9 +1,8 @@
from fastapi import FastAPI, HTTPException, Path, Request, Depends, Security from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path, Request, Security
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security import APIKeyHeader
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional
import os import os
import os.path as osp import os.path as osp
import sys import sys
@@ -20,16 +19,28 @@ from lang_agent.pipeline import PipelineConfig
from lang_agent.components.server_pipeline_manager import ServerPipelineManager from lang_agent.components.server_pipeline_manager import ServerPipelineManager
from lang_agent.config.constants import PIPELINE_REGISTRY_PATH, API_KEY_HEADER, VALID_API_KEYS from lang_agent.config.constants import PIPELINE_REGISTRY_PATH, API_KEY_HEADER, VALID_API_KEYS
# Load base config for route-level overrides (pipelines are lazy-loaded from registry) def _build_default_pipeline_config() -> PipelineConfig:
pipeline_config = tyro.cli(PipelineConfig) """
logger.info(f"starting agent with base pipeline config: \n{pipeline_config}") Build import-time defaults without parsing CLI args.
This keeps module import safe for reuse by combined apps and tests.
"""
pipeline_config = PipelineConfig()
logger.info(f"starting agent with base pipeline config: \n{pipeline_config}")
return pipeline_config
PIPELINE_MANAGER = ServerPipelineManager( def _build_pipeline_manager(base_config: PipelineConfig) -> ServerPipelineManager:
pipeline_manager = ServerPipelineManager(
default_pipeline_id=os.environ.get("FAST_DEFAULT_PIPELINE_ID", "default"), default_pipeline_id=os.environ.get("FAST_DEFAULT_PIPELINE_ID", "default"),
default_config=pipeline_config, default_config=base_config,
) )
PIPELINE_MANAGER.load_registry(PIPELINE_REGISTRY_PATH) pipeline_manager.load_registry(PIPELINE_REGISTRY_PATH)
return pipeline_manager
pipeline_config = _build_default_pipeline_config()
PIPELINE_MANAGER = _build_pipeline_manager(pipeline_config)
async def verify_api_key(api_key: str = Security(API_KEY_HEADER)): async def verify_api_key(api_key: str = Security(API_KEY_HEADER)):
@@ -55,20 +66,6 @@ class DSApplicationCallRequest(BaseModel):
thread_id: Optional[str] = Field(default="3") thread_id: Optional[str] = Field(default="3")
app = FastAPI(
title="DashScope-Compatible Application API",
description="DashScope Application.call compatible endpoint backed by pipeline.chat",
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
def sse_chunks_from_stream( def sse_chunks_from_stream(
chunk_generator, response_id: str, model: str = "qwen-flash" chunk_generator, response_id: str, model: str = "qwen-flash"
): ):
@@ -188,9 +185,10 @@ async def _process_dashscope_request(
app_id: Optional[str], app_id: Optional[str],
session_id: Optional[str], session_id: Optional[str],
api_key: str, api_key: str,
pipeline_manager: ServerPipelineManager,
): ):
try: try:
PIPELINE_MANAGER.refresh_registry_if_needed() pipeline_manager.refresh_registry_if_needed()
except Exception as e: except Exception as e:
logger.error(f"failed to refresh pipeline registry: {e}") logger.error(f"failed to refresh pipeline registry: {e}")
raise HTTPException(status_code=500, detail=f"Failed to refresh pipeline registry: {e}") raise HTTPException(status_code=500, detail=f"Failed to refresh pipeline registry: {e}")
@@ -207,10 +205,10 @@ async def _process_dashscope_request(
thread_id = body_input.get("session_id") or req_session_id or "3" thread_id = body_input.get("session_id") or req_session_id or "3"
user_msg = _extract_user_message(messages) user_msg = _extract_user_message(messages)
pipeline_id = PIPELINE_MANAGER.resolve_pipeline_id( pipeline_id = pipeline_manager.resolve_pipeline_id(
body=body, app_id=req_app_id, api_key=api_key body=body, app_id=req_app_id, api_key=api_key
) )
selected_pipeline, selected_model = PIPELINE_MANAGER.get_pipeline(pipeline_id) selected_pipeline, selected_model = pipeline_manager.get_pipeline(pipeline_id)
# Namespace thread ids to prevent memory collisions across pipelines. # Namespace thread ids to prevent memory collisions across pipelines.
thread_id = f"{pipeline_id}:{thread_id}" thread_id = f"{pipeline_id}:{thread_id}"
@@ -251,14 +249,21 @@ async def _process_dashscope_request(
return JSONResponse(content=data) return JSONResponse(content=data)
@app.post("/v1/apps/{app_id}/sessions/{session_id}/responses") def create_dashscope_router(
@app.post("/api/v1/apps/{app_id}/sessions/{session_id}/responses") pipeline_manager: Optional[ServerPipelineManager] = None,
async def application_responses( include_meta_routes: bool = True,
) -> APIRouter:
manager = pipeline_manager or PIPELINE_MANAGER
router = APIRouter()
@router.post("/v1/apps/{app_id}/sessions/{session_id}/responses")
@router.post("/api/v1/apps/{app_id}/sessions/{session_id}/responses")
async def application_responses(
request: Request, request: Request,
app_id: str = Path(...), app_id: str = Path(...),
session_id: str = Path(...), session_id: str = Path(...),
api_key: str = Depends(verify_api_key), api_key: str = Depends(verify_api_key),
): ):
try: try:
body = await request.json() body = await request.json()
return await _process_dashscope_request( return await _process_dashscope_request(
@@ -266,6 +271,7 @@ async def application_responses(
app_id=app_id, app_id=app_id,
session_id=session_id, session_id=session_id,
api_key=api_key, api_key=api_key,
pipeline_manager=manager,
) )
except HTTPException: except HTTPException:
@@ -274,17 +280,17 @@ async def application_responses(
logger.error(f"DashScope-compatible endpoint error: {e}") logger.error(f"DashScope-compatible endpoint error: {e}")
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
# Compatibility: some SDKs call /apps/{app_id}/completion without /v1 and
# Compatibility: some SDKs call /apps/{app_id}/completion without /v1 and without session in path # without session in path.
@app.post("/apps/{app_id}/completion") @router.post("/apps/{app_id}/completion")
@app.post("/v1/apps/{app_id}/completion") @router.post("/v1/apps/{app_id}/completion")
@app.post("/api/apps/{app_id}/completion") @router.post("/api/apps/{app_id}/completion")
@app.post("/api/v1/apps/{app_id}/completion") @router.post("/api/v1/apps/{app_id}/completion")
async def application_completion( async def application_completion(
request: Request, request: Request,
app_id: str = Path(...), app_id: str = Path(...),
api_key: str = Depends(verify_api_key), api_key: str = Depends(verify_api_key),
): ):
try: try:
body = await request.json() body = await request.json()
return await _process_dashscope_request( return await _process_dashscope_request(
@@ -292,6 +298,7 @@ async def application_completion(
app_id=app_id, app_id=app_id,
session_id=None, session_id=None,
api_key=api_key, api_key=api_key,
pipeline_manager=manager,
) )
except HTTPException: except HTTPException:
@@ -300,9 +307,9 @@ async def application_completion(
logger.error(f"DashScope-compatible completion error: {e}") logger.error(f"DashScope-compatible completion error: {e}")
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
if include_meta_routes:
@app.get("/") @router.get("/")
async def root(): async def root():
return { return {
"message": "DashScope Application-compatible API", "message": "DashScope Application-compatible API",
"endpoints": [ "endpoints": [
@@ -311,16 +318,48 @@ async def root():
], ],
} }
@router.get("/health")
@app.get("/health") async def health():
async def health():
return {"status": "healthy"} return {"status": "healthy"}
return router
def create_dashscope_app(
pipeline_manager: Optional[ServerPipelineManager] = None,
) -> FastAPI:
dashscope_app = FastAPI(
title="DashScope-Compatible Application API",
description="DashScope Application.call compatible endpoint backed by pipeline.chat",
)
dashscope_app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
dashscope_app.include_router(
create_dashscope_router(
pipeline_manager=pipeline_manager,
include_meta_routes=True,
)
)
return dashscope_app
dashscope_router = create_dashscope_router(include_meta_routes=False)
app = create_dashscope_app()
if __name__ == "__main__": if __name__ == "__main__":
# CLI parsing is intentionally only in script mode to keep module import safe.
cli_pipeline_config = tyro.cli(PipelineConfig)
logger.info(f"starting agent with CLI pipeline config: \n{cli_pipeline_config}")
cli_pipeline_manager = _build_pipeline_manager(cli_pipeline_config)
uvicorn.run( uvicorn.run(
"server_dashscope:app", create_dashscope_app(pipeline_manager=cli_pipeline_manager),
host="0.0.0.0", host=cli_pipeline_config.host,
port=pipeline_config.port, port=cli_pipeline_config.port,
reload=True, reload=False,
) )