Files
lang-agent/fastapi_server/server.py

315 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from fastapi import FastAPI, HTTPException, Depends, Security
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Any, Union
import os
import sys
import time
import uvicorn
import httpx
import openai
import json
from loguru import logger
# 添加父目录到系统路径以便导入lang_agent模块
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from lang_agent.pipeline import Pipeline, PipelineConfig
# 定义OpenAI格式的请求模型
class ChatMessage(BaseModel):
role: str = Field(..., description="消息角色,可以是 'system', 'user', 'assistant'")
content: str = Field(..., description="消息内容")
class ChatCompletionRequest(BaseModel):
model: str = Field(default="qwen-flash", description="模型名称")
messages: List[ChatMessage] = Field(..., description="对话消息列表")
temperature: Optional[float] = Field(default=0.7, description="采样温度")
max_tokens: Optional[int] = Field(default=500, description="最大生成token数")
stream: Optional[bool] = Field(default=False, description="是否流式返回")
thread_id: Optional[int] = Field(default=3, description="线程ID用于多轮对话")
llm_provider: Optional[str] = Field(default="openai", description="LLM提供商")
base_url: Optional[str] = Field(default="https://dashscope.aliyuncs.com/compatible-mode/v1", description="LLM API基础URL")
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
finish_reason: str
class ChatCompletionResponseUsage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
class ChatCompletionResponse(BaseModel):
id: str
object: str = "chat.completion"
created: int
model: str
choices: List[ChatCompletionResponseChoice]
usage: Optional[ChatCompletionResponseUsage] = None
# OpenAI客户端包装类
class OpenAIClientWrapper:
def __init__(
self,
api_key: Optional[str] = None,
base_url: Optional[str] = None,
timeout: float = 60.0,
model_name: str = "qwen-flash",
max_tokens: int = 500,
temperature: float = 0.7,
top_p: float = 1.0,
frequency_penalty: float = 0.0,
):
"""
初始化OpenAI客户端包装器
Args:
api_key: API密钥如果为None则从环境变量OPENAI_API_KEY获取
base_url: API基础URL如果为None则从环境变量OPENAI_BASE_URL获取
timeout: 请求超时时间(秒)
model_name: 默认模型名称
max_tokens: 默认最大token数
temperature: 默认采样温度
top_p: 默认top_p参数
frequency_penalty: 默认频率惩罚
"""
self.api_key = api_key or os.getenv("OPENAI_API_KEY", "")
self.base_url = base_url or os.getenv("OPENAI_BASE_URL", None)
self.timeout = timeout
self.model_name = model_name
self.max_tokens = max_tokens
self.temperature = temperature
self.top_p = top_p
self.frequency_penalty = frequency_penalty
self.client = openai.OpenAI(
api_key=self.api_key,
base_url=self.base_url,
timeout=httpx.Timeout(self.timeout)
)
def response(self, session_id: str, dialogue: List[Dict[str, str]], **kwargs):
"""
生成聊天响应(流式)
Args:
session_id: 会话ID
dialogue: 对话消息列表,格式为 [{"role": "user", "content": "..."}, ...]
**kwargs: 额外的参数可以覆盖默认的max_tokens, temperature, top_p, frequency_penalty
Returns:
OpenAI流式响应对象
"""
try:
responses = self.client.chat.completions.create(
model=self.model_name,
messages=dialogue,
stream=True,
max_tokens=kwargs.get("max_tokens", self.max_tokens),
temperature=kwargs.get("temperature", self.temperature),
top_p=kwargs.get("top_p", self.top_p),
frequency_penalty=kwargs.get("frequency_penalty", self.frequency_penalty),
)
return responses
except Exception as e:
logger.error(f"OpenAI客户端响应错误: {str(e)}")
raise
# 初始化FastAPI应用
app = FastAPI(title="Lang Agent Chat API", description="使用OpenAI格式调用pipeline.invoke的聊天API")
# 设置API密钥
API_KEY = "123tangledup-ai"
# 创建安全方案
security = HTTPBearer()
# 验证API密钥的依赖项
# async def verify_api_key(credentials: HTTPAuthorizationCredentials = Security(security)):
# if credentials.credentials != API_KEY:
# raise HTTPException(
# status_code=401,
# detail="无效的API密钥",
# headers={"WWW-Authenticate": "Bearer"},
# )
# return credentials
# 添加CORS中间件
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 初始化Pipeline
pipeline_config = PipelineConfig()
pipeline_config.llm_name = "qwen-flash"
pipeline_config.llm_provider = "openai"
pipeline_config.base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
pipeline = Pipeline(pipeline_config)
# 初始化OpenAI客户端包装器可选用于直接调用OpenAI API
openai_client = OpenAIClientWrapper(
api_key=os.getenv("OPENAI_API_KEY"),
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
timeout=60.0,
model_name="qwen-flash",
max_tokens=500,
temperature=0.7,
top_p=1.0,
frequency_penalty=0.0,
)
def generate_streaming_chunks(full_text: str, response_id: str, model: str, chunk_size: int = 10):
"""
Generate streaming chunks from non-streaming result
"""
created_time = int(time.time())
# Stream content chunks
for i in range(0, len(full_text), chunk_size):
chunk = full_text[i:i + chunk_size]
if chunk:
chunk_data = {
"id": response_id,
"object": "chat.completion.chunk",
"created": created_time,
"model": model,
"choices": [
{
"index": 0,
"delta": {"content": chunk},
"finish_reason": None
}
]
}
yield f"data: {json.dumps(chunk_data)}\n\n"
# Send final chunk with finish_reason
final_chunk = {
"id": response_id,
"object": "chat.completion.chunk",
"created": created_time,
"model": model,
"choices": [
{
"index": 0,
"delta": {},
"finish_reason": "stop"
}
]
}
yield f"data: {json.dumps(final_chunk)}\n\n"
yield "data: [DONE]\n\n"
@app.post("/v1/chat/completions")
async def chat_completions(
request: ChatCompletionRequest#,
# credentials: HTTPAuthorizationCredentials = Depends(verify_api_key)
):
"""
使用OpenAI格式的聊天完成API
"""
try:
# 提取用户消息
user_message = None
system_message = None
# TODO: wrap this sht as human and system message
for message in request.messages:
if message.role == "user":
user_message = message.content
elif message.role == "system" or message.role == "assistant":
system_message = message.content
if not user_message:
raise HTTPException(status_code=400, detail="缺少用户消息")
# 调用pipeline的chat方法 (always get non-streaming result)
response_content = pipeline.chat(
inp=user_message,
as_stream=False, # Always get full result, then chunk it if streaming
thread_id=request.thread_id
)
# Ensure response_content is a string
if not isinstance(response_content, str):
response_content = str(response_content)
logger.info(f"Pipeline response - Length: {len(response_content)}, Content: {repr(response_content[:200])}")
if len(response_content) == 0:
logger.warning("Pipeline returned empty response!")
response_id = f"chatcmpl-{os.urandom(12).hex()}"
# If streaming requested, return streaming response
if request.stream:
return StreamingResponse(
generate_streaming_chunks(
full_text=response_content,
response_id=response_id,
model=request.model,
chunk_size=10
),
media_type="text/event-stream"
)
# Otherwise return normal response
response = ChatCompletionResponse(
id=response_id,
created=int(time.time()),
model=request.model,
choices=[
ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=response_content),
finish_reason="stop"
)
]
)
return response
except Exception as e:
logger.error(f"处理聊天请求时出错: {str(e)}")
raise HTTPException(status_code=500, detail=f"内部服务器错误: {str(e)}")
@app.get("/")
async def root():
"""
根路径返回API信息
"""
return {
"message": "Lang Agent Chat API",
"version": "1.0.0",
"description": "使用OpenAI格式调用pipeline.invoke的聊天API",
"authentication": "Bearer Token (API Key)",
"endpoints": {
"/v1/chat/completions": "POST - 聊天完成接口兼容OpenAI格式需要API密钥验证",
"/": "GET - API信息",
"/health": "GET - 健康检查接口"
}
}
@app.get("/health")
async def health_check():
"""
健康检查接口
"""
return {"status": "healthy"}
if __name__ == "__main__":
uvicorn.run(
"server:app",
host="0.0.0.0",
port=8488,
reload=True
)