From cbefea447b9be73429ddf120eb102da5792f5b6b Mon Sep 17 00:00:00 2001 From: goulustis Date: Thu, 29 Jan 2026 10:05:08 +0800 Subject: [PATCH] test rest api --- test_rest_api.py | 408 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 408 insertions(+) create mode 100644 test_rest_api.py diff --git a/test_rest_api.py b/test_rest_api.py new file mode 100644 index 0000000..b15115b --- /dev/null +++ b/test_rest_api.py @@ -0,0 +1,408 @@ +#!/usr/bin/env python3 +""" +Integration tests for the REST API server (server_rest.py) + +This script tests the REST API endpoints using FastAPI's TestClient. + +To run: + conda activate lang + python test_rest_api.py + +Or with pytest: + conda activate lang + pytest test_rest_api.py -v + +Tests cover: +- Health check endpoints +- API key authentication +- Conversation creation +- Chat endpoint (streaming and non-streaming) +- Message creation (streaming and non-streaming) +- Memory deletion (global and per-conversation) +- Error handling + +Requirements: +- pytest (for structured testing) +- Or run directly as a script +""" +import os +import sys +import json +from unittest.mock import AsyncMock, MagicMock, patch + +# Set up test environment before importing the server +os.environ["FAST_AUTH_KEYS"] = "test-key-1,test-key-2,test-key-3" + +# Add project root to path +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +try: + from fastapi.testclient import TestClient + from langgraph.checkpoint.memory import MemorySaver + HAS_TEST_CLIENT = True +except ImportError: + HAS_TEST_CLIENT = False + print("Warning: fastapi.testclient not available. Install with: pip install pytest httpx") + print("Falling back to basic import test only.") + +from fastapi_server.server_rest import app + + +def create_mock_pipeline(): + """Create a mock Pipeline instance for testing.""" + pipeline = MagicMock() + + # Mock async function that returns async generator for streaming or string for non-streaming + async def mock_achat(inp, as_stream=False, thread_id="test"): + if as_stream: + # Return async generator for streaming + async def gen(): + chunks = ["Hello", " ", "world", "!"] + for chunk in chunks: + yield chunk + return gen() + else: + # Return string for non-streaming + return "Hello world!" + + async def mock_aclear_memory(): + return None + + pipeline.achat = AsyncMock(side_effect=mock_achat) + pipeline.aclear_memory = AsyncMock(return_value=None) + + # Mock graph with memory + mock_graph = MagicMock() + mock_memory = MagicMock(spec=MemorySaver) + mock_memory.delete_thread = MagicMock() + mock_graph.memory = mock_memory + pipeline.graph = mock_graph + + return pipeline + + +def test_health_endpoints(): + """Test health check endpoints.""" + print("\n=== Testing Health Endpoints ===") + + with patch("fastapi_server.server_rest.pipeline", create_mock_pipeline()): + client = TestClient(app) + + # Test root endpoint + response = client.get("/") + assert response.status_code == 200, f"Expected 200, got {response.status_code}" + data = response.json() + assert "message" in data, "Root endpoint should return message" + assert "endpoints" in data, "Root endpoint should return endpoints list" + print("✓ Root endpoint works") + + # Test health endpoint + response = client.get("/health") + assert response.status_code == 200, f"Expected 200, got {response.status_code}" + data = response.json() + assert data["status"] == "healthy", "Health endpoint should return healthy status" + print("✓ Health endpoint works") + + +def test_authentication(): + """Test API key authentication.""" + print("\n=== Testing Authentication ===") + + with patch("fastapi_server.server_rest.pipeline", create_mock_pipeline()): + client = TestClient(app) + + # Test missing auth header + response = client.post("/v1/conversations") + assert response.status_code == 401, f"Expected 401 for missing auth, got {response.status_code}" + print("✓ Missing auth header returns 401") + + # Test invalid API key + response = client.post( + "/v1/conversations", + headers={"Authorization": "Bearer invalid-key"} + ) + assert response.status_code == 401, f"Expected 401 for invalid key, got {response.status_code}" + print("✓ Invalid API key returns 401") + + # Test valid API key + response = client.post( + "/v1/conversations", + headers={"Authorization": "Bearer test-key-1"} + ) + assert response.status_code == 200, f"Expected 200 for valid key, got {response.status_code}" + print("✓ Valid API key works") + + # Test API key without Bearer prefix + response = client.post( + "/v1/conversations", + headers={"Authorization": "test-key-1"} + ) + assert response.status_code == 200, f"Expected 200, got {response.status_code}" + print("✓ API key without Bearer prefix works") + + +def test_conversation_creation(): + """Test conversation creation.""" + print("\n=== Testing Conversation Creation ===") + + with patch("fastapi_server.server_rest.pipeline", create_mock_pipeline()): + client = TestClient(app) + auth_headers = {"Authorization": "Bearer test-key-1"} + + response = client.post("/v1/conversations", headers=auth_headers) + assert response.status_code == 200, f"Expected 200, got {response.status_code}" + data = response.json() + assert "id" in data, "Response should contain id" + assert "created_at" in data, "Response should contain created_at" + assert data["id"].startswith("c_"), "Conversation ID should start with 'c_'" + print(f"✓ Created conversation: {data['id']}") + + +def test_chat_endpoint(): + """Test chat endpoint.""" + print("\n=== Testing Chat Endpoint ===") + + mock_pipeline = create_mock_pipeline() + with patch("fastapi_server.server_rest.pipeline", mock_pipeline): + client = TestClient(app) + auth_headers = {"Authorization": "Bearer test-key-1"} + + # Test non-streaming chat + response = client.post( + "/v1/chat", + headers=auth_headers, + json={ + "input": "Hello, how are you?", + "stream": False + } + ) + assert response.status_code == 200, f"Expected 200, got {response.status_code}" + data = response.json() + assert "conversation_id" in data, "Response should contain conversation_id" + assert "output" in data, "Response should contain output" + assert data["output"] == "Hello world!", "Output should match expected" + print(f"✓ Non-streaming chat works: {data['conversation_id']}") + + # Test chat with existing conversation_id + conv_id = "c_test123" + response = client.post( + "/v1/chat", + headers=auth_headers, + json={ + "input": "Hello", + "conversation_id": conv_id, + "stream": False + } + ) + assert response.status_code == 200, f"Expected 200, got {response.status_code}" + data = response.json() + assert data["conversation_id"] == conv_id, "Should use provided conversation_id" + print(f"✓ Chat with existing conversation_id works") + + # Test streaming chat + response = client.post( + "/v1/chat", + headers=auth_headers, + json={ + "input": "Hello", + "stream": True + } + ) + assert response.status_code == 200, f"Expected 200, got {response.status_code}" + assert "text/event-stream" in response.headers["content-type"], "Should return event-stream" + + # Parse streaming response + lines = response.text.split("\n") + data_lines = [line for line in lines if line.startswith("data: ")] + assert len(data_lines) > 0, "Should have streaming data" + + # Parse first delta event + first_data = json.loads(data_lines[0][6:]) # Remove "data: " prefix + assert first_data["type"] == "delta", "First event should be delta" + assert "conversation_id" in first_data, "Delta should contain conversation_id" + assert "delta" in first_data, "Delta should contain delta field" + print("✓ Streaming chat works") + + +def test_message_endpoint(): + """Test message creation endpoint.""" + print("\n=== Testing Message Endpoint ===") + + mock_pipeline = create_mock_pipeline() + with patch("fastapi_server.server_rest.pipeline", mock_pipeline): + client = TestClient(app) + auth_headers = {"Authorization": "Bearer test-key-1"} + conv_id = "c_test123" + + # Test non-streaming message + response = client.post( + f"/v1/conversations/{conv_id}/messages", + headers=auth_headers, + json={ + "role": "user", + "content": "Hello, how are you?", + "stream": False + } + ) + assert response.status_code == 200, f"Expected 200, got {response.status_code}" + data = response.json() + assert data["conversation_id"] == conv_id, "Should return correct conversation_id" + assert "message" in data, "Response should contain message" + assert data["message"]["role"] == "assistant", "Message role should be assistant" + assert "content" in data["message"], "Message should contain content" + print("✓ Non-streaming message creation works") + + # Test streaming message + response = client.post( + f"/v1/conversations/{conv_id}/messages", + headers=auth_headers, + json={ + "role": "user", + "content": "Hello", + "stream": True + } + ) + assert response.status_code == 200, f"Expected 200, got {response.status_code}" + assert "text/event-stream" in response.headers["content-type"], "Should return event-stream" + print("✓ Streaming message creation works") + + # Test invalid role + response = client.post( + f"/v1/conversations/{conv_id}/messages", + headers=auth_headers, + json={ + "role": "assistant", + "content": "Hello", + "stream": False + } + ) + assert response.status_code == 400, f"Expected 400 for invalid role, got {response.status_code}" + assert "Only role='user' is supported" in response.json()["detail"], "Should reject non-user role" + print("✓ Invalid role rejection works") + + +def test_memory_deletion(): + """Test memory deletion endpoints.""" + print("\n=== Testing Memory Deletion ===") + + mock_pipeline = create_mock_pipeline() + with patch("fastapi_server.server_rest.pipeline", mock_pipeline): + client = TestClient(app) + auth_headers = {"Authorization": "Bearer test-key-1"} + + # Test delete all memory + response = client.delete("/v1/memory", headers=auth_headers) + assert response.status_code == 200, f"Expected 200, got {response.status_code}" + data = response.json() + assert data["status"] == "success", "Should return success status" + assert data["scope"] == "all", "Should indicate all scope" + mock_pipeline.aclear_memory.assert_called_once() + print("✓ Delete all memory works") + + # Test delete conversation memory + conv_id = "c_test123" + response = client.delete( + f"/v1/conversations/{conv_id}/memory", + headers=auth_headers + ) + assert response.status_code == 200, f"Expected 200, got {response.status_code}" + data = response.json() + assert data["status"] == "success", "Should return success status" + assert data["scope"] == "conversation", "Should indicate conversation scope" + assert data["conversation_id"] == conv_id, "Should return conversation_id" + mock_pipeline.graph.memory.delete_thread.assert_called() + print("✓ Delete conversation memory works") + + # Test delete conversation memory with device_id format + conv_id_with_device = "c_test123_device456" + response = client.delete( + f"/v1/conversations/{conv_id_with_device}/memory", + headers=auth_headers + ) + assert response.status_code == 200, f"Expected 200, got {response.status_code}" + # Should normalize to base thread_id + mock_pipeline.graph.memory.delete_thread.assert_called_with("c_test123") + print("✓ Delete conversation memory with device_id format works") + + # Test delete conversation memory when MemorySaver is not available + mock_pipeline_no_mem = create_mock_pipeline() + mock_pipeline_no_mem.graph.memory = None + + with patch("fastapi_server.server_rest.pipeline", mock_pipeline_no_mem): + response = client.delete( + f"/v1/conversations/{conv_id}/memory", + headers=auth_headers + ) + assert response.status_code == 501, f"Expected 501, got {response.status_code}" + data = response.json() + assert data["status"] == "unsupported", "Should return unsupported status" + print("✓ Unsupported memory deletion handling works") + + +def test_error_handling(): + """Test error handling.""" + print("\n=== Testing Error Handling ===") + + with patch("fastapi_server.server_rest.pipeline", create_mock_pipeline()): + client = TestClient(app) + auth_headers = {"Authorization": "Bearer test-key-1"} + + # Test chat with missing input + response = client.post( + "/v1/chat", + headers=auth_headers, + json={"stream": False} + ) + assert response.status_code == 422, f"Expected 422 for validation error, got {response.status_code}" + print("✓ Missing input validation works") + + # Test message with missing content + response = client.post( + "/v1/conversations/c_test123/messages", + headers=auth_headers, + json={"role": "user", "stream": False} + ) + assert response.status_code == 422, f"Expected 422 for validation error, got {response.status_code}" + print("✓ Missing content validation works") + + +def run_all_tests(): + """Run all tests.""" + if not HAS_TEST_CLIENT: + print("Cannot run tests: fastapi.testclient not available") + print("Install with: pip install pytest httpx") + return False + + print("=" * 60) + print("Running REST API Tests") + print("=" * 60) + + try: + test_health_endpoints() + test_authentication() + test_conversation_creation() + test_chat_endpoint() + test_message_endpoint() + test_memory_deletion() + test_error_handling() + + print("\n" + "=" * 60) + print("✓ All tests passed!") + print("=" * 60) + return True + except AssertionError as e: + print(f"\n✗ Test failed: {e}") + import traceback + traceback.print_exc() + return False + except Exception as e: + print(f"\n✗ Unexpected error: {e}") + import traceback + traceback.print_exc() + return False + + +if __name__ == "__main__": + success = run_all_tests() + sys.exit(0 if success else 1) +