diff --git a/fastapi_server/server_openai.py b/fastapi_server/server_openai.py
new file mode 100644
index 0000000..b04c12b
--- /dev/null
+++ b/fastapi_server/server_openai.py
@@ -0,0 +1,188 @@
+from fastapi import FastAPI, HTTPException, Request
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse, JSONResponse
+from pydantic import BaseModel, Field
+from typing import List, Optional, Union, Literal
+import os
+import sys
+import time
+import json
+import uvicorn
+from loguru import logger
+import tyro
+
+# Ensure we can import from project root
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from lang_agent.pipeline import Pipeline, PipelineConfig
+
+# Initialize Pipeline once
+pipeline_config = tyro.cli(PipelineConfig)
+pipeline: Pipeline = pipeline_config.setup()
+
+
+class OpenAIMessage(BaseModel):
+    role: str
+    content: str
+
+
+class OpenAIChatCompletionRequest(BaseModel):
+    model: str = Field(default="gpt-3.5-turbo")
+    messages: List[OpenAIMessage]
+    stream: bool = Field(default=False)
+    temperature: Optional[float] = Field(default=1.0)
+    max_tokens: Optional[int] = Field(default=None)
+    # Optional overrides for pipeline behavior
+    thread_id: Optional[int] = Field(default=3)
+
+
+app = FastAPI(
+    title="OpenAI-Compatible Chat API",
+    description="OpenAI Chat Completions API compatible endpoint backed by pipeline.chat"
+)
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
+def sse_chunks_from_stream(chunk_generator, response_id: str, model: str, created_time: int):
+    """
+    Stream chunks from pipeline and format as OpenAI SSE.
+    """
+    for chunk in chunk_generator:
+        if chunk:
+            data = {
+                "id": response_id,
+                "object": "chat.completion.chunk",
+                "created": created_time,
+                "model": model,
+                "choices": [
+                    {
+                        "index": 0,
+                        "delta": {
+                            "content": chunk
+                        },
+                        "finish_reason": None
+                    }
+                ]
+            }
+            yield f"data: {json.dumps(data)}\n\n"
+
+    # Final message
+    final = {
+        "id": response_id,
+        "object": "chat.completion.chunk",
+        "created": created_time,
+        "model": model,
+        "choices": [
+            {
+                "index": 0,
+                "delta": {},
+                "finish_reason": "stop"
+            }
+        ]
+    }
+    yield f"data: {json.dumps(final)}\n\n"
+    yield "data: [DONE]\n\n"
+
+
+@app.post("/v1/chat/completions")
+async def chat_completions(request: Request):
+    try:
+        body = await request.json()
+
+        messages = body.get("messages")
+        if not messages:
+            raise HTTPException(status_code=400, detail="messages is required")
+
+        stream = body.get("stream", False)
+        model = body.get("model", "gpt-3.5-turbo")
+        thread_id = body.get("thread_id", 3)
+
+        # Extract latest user message
+        user_msg = None
+        for m in reversed(messages):
+            role = m.get("role") if isinstance(m, dict) else None
+            content = m.get("content") if isinstance(m, dict) else None
+            if role == "user" and content:
+                user_msg = content
+                break
+
+        if user_msg is None:
+            last = messages[-1]
+            user_msg = last.get("content") if isinstance(last, dict) else str(last)
+
+        response_id = f"chatcmpl-{os.urandom(12).hex()}"
+        created_time = int(time.time())
+
+        if stream:
+            # Use actual streaming from pipeline
+            chunk_generator = pipeline.chat(inp=user_msg, as_stream=True, thread_id=thread_id)
+            return StreamingResponse(
+                sse_chunks_from_stream(chunk_generator, response_id=response_id, model=model, created_time=created_time),
+                media_type="text/event-stream",
+            )
+
+        # Non-streaming: get full result
+        result_text = pipeline.chat(inp=user_msg, as_stream=False, thread_id=thread_id)
+        if not isinstance(result_text, str):
+            result_text = str(result_text)
+
+        data = {
+            "id": response_id,
+            "object": "chat.completion",
+            "created": created_time,
+            "model": model,
+            "choices": [
+                {
+                    "index": 0,
+                    "message": {
+                        "role": "assistant",
+                        "content": result_text
+                    },
+                    "finish_reason": "stop"
+                }
+            ],
+            "usage": {
+                "prompt_tokens": 0,
+                "completion_tokens": 0,
+                "total_tokens": 0
+            }
+        }
+        return JSONResponse(content=data)
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.error(f"OpenAI-compatible endpoint error: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.get("/")
+async def root():
+    return {
+        "message": "OpenAI-compatible Chat API",
+        "endpoints": [
+            "/v1/chat/completions",
+            "/health"
+        ]
+    }
+
+
+@app.get("/health")
+async def health():
+    return {"status": "healthy"}
+
+
+if __name__ == "__main__":
+    uvicorn.run(
+        "server_openai:app",
+        host="0.0.0.0",
+        port=8589,
+        reload=True,
+    )
diff --git a/fastapi_server/test_openai_client.py b/fastapi_server/test_openai_client.py
new file mode 100644
index 0000000..3e91c3c
--- /dev/null
+++ b/fastapi_server/test_openai_client.py
@@ -0,0 +1,126 @@
+#!/usr/bin/env python3
+"""
+Test for OpenAI-compatible API against server_openai.py
+
+Instructions:
+- Start the OpenAI-compatible server first, e.g.:
+  python fastapi_server/server_openai.py --llm_name qwen-plus --llm_provider openai --base_url https://dashscope.aliyuncs.com/compatible-mode/v1
+- Or with uvicorn:
+  uvicorn fastapi_server.server_openai:app --host 0.0.0.0 --port 8589 --reload
+- Set BASE_URL below to the server base URL you started.
+"""
+import os
+from dotenv import load_dotenv
+from loguru import logger
+
+TAG = __name__
+
+load_dotenv()
+
+try:
+    from openai import OpenAI
+except Exception as e:
+    print("openai package not found. Please install it: pip install openai")
+    raise
+
+
+# <<< Paste your running FastAPI base url here >>>
+BASE_URL = os.getenv("OPENAI_BASE_URL", "http://127.0.0.1:8589/v1")
+
+# Test configuration matching the server setup
+# llm_name: "qwen-plus"
+# llm_provider: "openai"
+# base_url: "https://dashscope.aliyuncs.com/compatible-mode/v1"
+
+# Test messages
+messages = [
+    {"role": "system", "content": "You are a helpful assistant."},
+    {"role": "user", "content": "use calculator to calculate 1234*5641"},
+]
+
+
+def test_streaming():
+    """Test streaming chat completion"""
+    print("\n" + "="*60)
+    print("Testing STREAMING chat completion...")
+    print("="*60 + "\n")
+
+    client = OpenAI(
+        base_url=BASE_URL,
+        api_key="test-key"  # Dummy key for testing
+    )
+
+    try:
+        stream = client.chat.completions.create(
+            model="qwen-plus",  # Using qwen-plus as configured
+            messages=messages,
+            stream=True
+        )
+
+        full_response = ""
+        for chunk in stream:
+            if chunk.choices[0].delta.content is not None:
+                content = chunk.choices[0].delta.content
+                full_response += content
+                print(content, end="", flush=True)
+
+        print("\n\n" + "-"*60)
+        print(f"Full streaming response length: {len(full_response)}")
+        print("-"*60)
+
+        return full_response
+
+    except Exception as e:
+        logger.error(f"Streaming test error: {e}")
+        raise
+
+
+def test_non_streaming():
+    """Test non-streaming chat completion"""
+    print("\n" + "="*60)
+    print("Testing NON-STREAMING chat completion...")
+    print("="*60 + "\n")
+
+    client = OpenAI(
+        base_url=BASE_URL,
+        api_key="test-key"  # Dummy key for testing
+    )
+
+    try:
+        response = client.chat.completions.create(
+            model="qwen-plus",  # Using qwen-plus as configured
+            messages=messages,
+            stream=False
+        )
+
+        content = response.choices[0].message.content
+        print(f"Response: {content}")
+        print("\n" + "-"*60)
+        print(f"Full non-streaming response length: {len(content)}")
+        print(f"Finish reason: {response.choices[0].finish_reason}")
+        print("-"*60)
+
+        return content
+
+    except Exception as e:
+        logger.error(f"Non-streaming test error: {e}")
+        raise
+
+
+def main():
+    print(f"\nUsing base_url = {BASE_URL}\n")
+
+    # Test both streaming and non-streaming
+    streaming_result = test_streaming()
+    non_streaming_result = test_non_streaming()
+
+    print("\n" + "="*60)
+    print("SUMMARY")
+    print("="*60)
+    print(f"Streaming response length: {len(streaming_result)}")
+    print(f"Non-streaming response length: {len(non_streaming_result)}")
+    print("\nBoth tests completed successfully!")
+
+
+if __name__ == "__main__":
+    main()