kimi magic: rename route_id to pipeline_id and reformat long lines

This commit is contained in:
2026-03-04 15:21:07 +08:00
parent 6ed33f3185
commit 91685d5bf7
3 changed files with 171 additions and 142 deletions

View File

@@ -28,12 +28,16 @@ API_KEY_HEADER = APIKeyHeader(name="Authorization", auto_error=True)
VALID_API_KEYS = set(filter(None, os.environ.get("FAST_AUTH_KEYS", "").split(",")))
REGISTRY_FILE = os.environ.get(
"FAST_PIPELINE_REGISTRY_FILE",
osp.join(osp.dirname(osp.dirname(osp.abspath(__file__))), "configs", "pipeline_registry.json"),
osp.join(
osp.dirname(osp.dirname(osp.abspath(__file__))),
"configs",
"pipeline_registry.json",
),
)
PIPELINE_MANAGER = ServerPipelineManager(
default_route_id=os.environ.get("FAST_DEFAULT_ROUTE_ID", os.environ.get("FAST_DEFAULT_PIPELINE_ID", "default")),
default_pipeline_id=os.environ.get("FAST_DEFAULT_PIPELINE_ID", "default"),
default_config=pipeline_config,
)
PIPELINE_MANAGER.load_registry(REGISTRY_FILE)
@@ -62,8 +66,10 @@ class DSApplicationCallRequest(BaseModel):
thread_id: Optional[str] = Field(default="3")
app = FastAPI(title="DashScope-Compatible Application API",
description="DashScope Application.call compatible endpoint backed by pipeline.chat")
app = FastAPI(
title="DashScope-Compatible Application API",
description="DashScope Application.call compatible endpoint backed by pipeline.chat",
)
app.add_middleware(
CORSMiddleware,
@@ -74,7 +80,9 @@ app.add_middleware(
)
def sse_chunks_from_stream(chunk_generator, response_id: str, model: str = "qwen-flash"):
def sse_chunks_from_stream(
chunk_generator, response_id: str, model: str = "qwen-flash"
):
"""
Stream chunks from pipeline and format as SSE.
Accumulates text and sends incremental updates.
@@ -115,7 +123,9 @@ def sse_chunks_from_stream(chunk_generator, response_id: str, model: str = "qwen
yield f"data: {json.dumps(final)}\n\n"
async def sse_chunks_from_astream(chunk_generator, response_id: str, model: str = "qwen-flash"):
async def sse_chunks_from_astream(
chunk_generator, response_id: str, model: str = "qwen-flash"
):
"""
Async version: Stream chunks from pipeline and format as SSE.
Accumulates text and sends incremental updates.
@@ -202,22 +212,30 @@ async def _process_dashscope_request(
thread_id = body_input.get("session_id") or req_session_id or "3"
user_msg = _extract_user_message(messages)
route_id = PIPELINE_MANAGER.resolve_route_id(body=body, app_id=req_app_id, api_key=api_key)
selected_pipeline, selected_model = PIPELINE_MANAGER.get_pipeline(route_id)
pipeline_id = PIPELINE_MANAGER.resolve_pipeline_id(
body=body, app_id=req_app_id, api_key=api_key
)
selected_pipeline, selected_model = PIPELINE_MANAGER.get_pipeline(pipeline_id)
# Namespace thread ids to prevent memory collisions across pipelines.
thread_id = f"{route_id}:{thread_id}"
thread_id = f"{pipeline_id}:{thread_id}"
response_id = f"appcmpl-{os.urandom(12).hex()}"
if stream:
chunk_generator = await selected_pipeline.achat(inp=user_msg, as_stream=True, thread_id=thread_id)
chunk_generator = await selected_pipeline.achat(
inp=user_msg, as_stream=True, thread_id=thread_id
)
return StreamingResponse(
sse_chunks_from_astream(chunk_generator, response_id=response_id, model=selected_model),
sse_chunks_from_astream(
chunk_generator, response_id=response_id, model=selected_model
),
media_type="text/event-stream",
)
result_text = await selected_pipeline.achat(inp=user_msg, as_stream=False, thread_id=thread_id)
result_text = await selected_pipeline.achat(
inp=user_msg, as_stream=False, thread_id=thread_id
)
if not isinstance(result_text, str):
result_text = str(result_text)
@@ -232,9 +250,7 @@ async def _process_dashscope_request(
"created": int(time.time()),
"model": selected_model,
},
"route_id": route_id,
# Backward compatibility: keep pipeline_id in response as the route id selector.
"pipeline_id": route_id,
"pipeline_id": pipeline_id,
"is_end": True,
}
return JSONResponse(content=data)
@@ -292,10 +308,13 @@ async def application_completion(
@app.get("/")
async def root():
return {"message": "DashScope Application-compatible API", "endpoints": [
"/v1/apps/{app_id}/sessions/{session_id}/responses",
"/health",
]}
return {
"message": "DashScope Application-compatible API",
"endpoints": [
"/v1/apps/{app_id}/sessions/{session_id}/responses",
"/health",
],
}
@app.get("/health")
@@ -310,5 +329,3 @@ if __name__ == "__main__":
port=pipeline_config.port,
reload=True,
)