make this importable without tyro fking around

This commit is contained in:
2026-03-05 11:43:16 +08:00
parent 55b37cc611
commit a2890148f9

View File

@@ -1,9 +1,8 @@
from fastapi import FastAPI, HTTPException, Path, Request, Depends, Security from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path, Request, Security
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security import APIKeyHeader
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional
import os import os
import os.path as osp import os.path as osp
import sys import sys
@@ -20,16 +19,28 @@ from lang_agent.pipeline import PipelineConfig
from lang_agent.components.server_pipeline_manager import ServerPipelineManager from lang_agent.components.server_pipeline_manager import ServerPipelineManager
from lang_agent.config.constants import PIPELINE_REGISTRY_PATH, API_KEY_HEADER, VALID_API_KEYS from lang_agent.config.constants import PIPELINE_REGISTRY_PATH, API_KEY_HEADER, VALID_API_KEYS
# Load base config for route-level overrides (pipelines are lazy-loaded from registry) def _build_default_pipeline_config() -> PipelineConfig:
pipeline_config = tyro.cli(PipelineConfig) """
logger.info(f"starting agent with base pipeline config: \n{pipeline_config}") Build import-time defaults without parsing CLI args.
This keeps module import safe for reuse by combined apps and tests.
"""
pipeline_config = PipelineConfig()
logger.info(f"starting agent with base pipeline config: \n{pipeline_config}")
return pipeline_config
PIPELINE_MANAGER = ServerPipelineManager( def _build_pipeline_manager(base_config: PipelineConfig) -> ServerPipelineManager:
default_pipeline_id=os.environ.get("FAST_DEFAULT_PIPELINE_ID", "default"), pipeline_manager = ServerPipelineManager(
default_config=pipeline_config, default_pipeline_id=os.environ.get("FAST_DEFAULT_PIPELINE_ID", "default"),
) default_config=base_config,
PIPELINE_MANAGER.load_registry(PIPELINE_REGISTRY_PATH) )
pipeline_manager.load_registry(PIPELINE_REGISTRY_PATH)
return pipeline_manager
pipeline_config = _build_default_pipeline_config()
PIPELINE_MANAGER = _build_pipeline_manager(pipeline_config)
async def verify_api_key(api_key: str = Security(API_KEY_HEADER)): async def verify_api_key(api_key: str = Security(API_KEY_HEADER)):
@@ -55,20 +66,6 @@ class DSApplicationCallRequest(BaseModel):
thread_id: Optional[str] = Field(default="3") thread_id: Optional[str] = Field(default="3")
app = FastAPI(
title="DashScope-Compatible Application API",
description="DashScope Application.call compatible endpoint backed by pipeline.chat",
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
def sse_chunks_from_stream( def sse_chunks_from_stream(
chunk_generator, response_id: str, model: str = "qwen-flash" chunk_generator, response_id: str, model: str = "qwen-flash"
): ):
@@ -188,9 +185,10 @@ async def _process_dashscope_request(
app_id: Optional[str], app_id: Optional[str],
session_id: Optional[str], session_id: Optional[str],
api_key: str, api_key: str,
pipeline_manager: ServerPipelineManager,
): ):
try: try:
PIPELINE_MANAGER.refresh_registry_if_needed() pipeline_manager.refresh_registry_if_needed()
except Exception as e: except Exception as e:
logger.error(f"failed to refresh pipeline registry: {e}") logger.error(f"failed to refresh pipeline registry: {e}")
raise HTTPException(status_code=500, detail=f"Failed to refresh pipeline registry: {e}") raise HTTPException(status_code=500, detail=f"Failed to refresh pipeline registry: {e}")
@@ -207,10 +205,10 @@ async def _process_dashscope_request(
thread_id = body_input.get("session_id") or req_session_id or "3" thread_id = body_input.get("session_id") or req_session_id or "3"
user_msg = _extract_user_message(messages) user_msg = _extract_user_message(messages)
pipeline_id = PIPELINE_MANAGER.resolve_pipeline_id( pipeline_id = pipeline_manager.resolve_pipeline_id(
body=body, app_id=req_app_id, api_key=api_key body=body, app_id=req_app_id, api_key=api_key
) )
selected_pipeline, selected_model = PIPELINE_MANAGER.get_pipeline(pipeline_id) selected_pipeline, selected_model = pipeline_manager.get_pipeline(pipeline_id)
# Namespace thread ids to prevent memory collisions across pipelines. # Namespace thread ids to prevent memory collisions across pipelines.
thread_id = f"{pipeline_id}:{thread_id}" thread_id = f"{pipeline_id}:{thread_id}"
@@ -251,76 +249,117 @@ async def _process_dashscope_request(
return JSONResponse(content=data) return JSONResponse(content=data)
@app.post("/v1/apps/{app_id}/sessions/{session_id}/responses") def create_dashscope_router(
@app.post("/api/v1/apps/{app_id}/sessions/{session_id}/responses") pipeline_manager: Optional[ServerPipelineManager] = None,
async def application_responses( include_meta_routes: bool = True,
request: Request, ) -> APIRouter:
app_id: str = Path(...), manager = pipeline_manager or PIPELINE_MANAGER
session_id: str = Path(...), router = APIRouter()
api_key: str = Depends(verify_api_key),
): @router.post("/v1/apps/{app_id}/sessions/{session_id}/responses")
try: @router.post("/api/v1/apps/{app_id}/sessions/{session_id}/responses")
body = await request.json() async def application_responses(
return await _process_dashscope_request( request: Request,
body=body, app_id: str = Path(...),
app_id=app_id, session_id: str = Path(...),
session_id=session_id, api_key: str = Depends(verify_api_key),
api_key=api_key, ):
try:
body = await request.json()
return await _process_dashscope_request(
body=body,
app_id=app_id,
session_id=session_id,
api_key=api_key,
pipeline_manager=manager,
)
except HTTPException:
raise
except Exception as e:
logger.error(f"DashScope-compatible endpoint error: {e}")
raise HTTPException(status_code=500, detail=str(e))
# Compatibility: some SDKs call /apps/{app_id}/completion without /v1 and
# without session in path.
@router.post("/apps/{app_id}/completion")
@router.post("/v1/apps/{app_id}/completion")
@router.post("/api/apps/{app_id}/completion")
@router.post("/api/v1/apps/{app_id}/completion")
async def application_completion(
request: Request,
app_id: str = Path(...),
api_key: str = Depends(verify_api_key),
):
try:
body = await request.json()
return await _process_dashscope_request(
body=body,
app_id=app_id,
session_id=None,
api_key=api_key,
pipeline_manager=manager,
)
except HTTPException:
raise
except Exception as e:
logger.error(f"DashScope-compatible completion error: {e}")
raise HTTPException(status_code=500, detail=str(e))
if include_meta_routes:
@router.get("/")
async def root():
return {
"message": "DashScope Application-compatible API",
"endpoints": [
"/v1/apps/{app_id}/sessions/{session_id}/responses",
"/health",
],
}
@router.get("/health")
async def health():
return {"status": "healthy"}
return router
def create_dashscope_app(
pipeline_manager: Optional[ServerPipelineManager] = None,
) -> FastAPI:
dashscope_app = FastAPI(
title="DashScope-Compatible Application API",
description="DashScope Application.call compatible endpoint backed by pipeline.chat",
)
dashscope_app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
dashscope_app.include_router(
create_dashscope_router(
pipeline_manager=pipeline_manager,
include_meta_routes=True,
) )
)
except HTTPException: return dashscope_app
raise
except Exception as e:
logger.error(f"DashScope-compatible endpoint error: {e}")
raise HTTPException(status_code=500, detail=str(e))
# Compatibility: some SDKs call /apps/{app_id}/completion without /v1 and without session in path dashscope_router = create_dashscope_router(include_meta_routes=False)
@app.post("/apps/{app_id}/completion") app = create_dashscope_app()
@app.post("/v1/apps/{app_id}/completion")
@app.post("/api/apps/{app_id}/completion")
@app.post("/api/v1/apps/{app_id}/completion")
async def application_completion(
request: Request,
app_id: str = Path(...),
api_key: str = Depends(verify_api_key),
):
try:
body = await request.json()
return await _process_dashscope_request(
body=body,
app_id=app_id,
session_id=None,
api_key=api_key,
)
except HTTPException:
raise
except Exception as e:
logger.error(f"DashScope-compatible completion error: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/")
async def root():
return {
"message": "DashScope Application-compatible API",
"endpoints": [
"/v1/apps/{app_id}/sessions/{session_id}/responses",
"/health",
],
}
@app.get("/health")
async def health():
return {"status": "healthy"}
if __name__ == "__main__": if __name__ == "__main__":
# CLI parsing is intentionally only in script mode to keep module import safe.
cli_pipeline_config = tyro.cli(PipelineConfig)
logger.info(f"starting agent with CLI pipeline config: \n{cli_pipeline_config}")
cli_pipeline_manager = _build_pipeline_manager(cli_pipeline_config)
uvicorn.run( uvicorn.run(
"server_dashscope:app", create_dashscope_app(pipeline_manager=cli_pipeline_manager),
host="0.0.0.0", host=cli_pipeline_config.host,
port=pipeline_config.port, port=cli_pipeline_config.port,
reload=True, reload=False,
) )