#!/usr/bin/env python3 """ Integration tests for the REST API server (server_rest.py) This script tests the REST API endpoints using FastAPI's TestClient. To run: conda activate lang python test_rest_api.py Or with pytest: conda activate lang pytest test_rest_api.py -v Tests cover: - Health check endpoints - API key authentication - Conversation creation - Chat endpoint (streaming and non-streaming) - Message creation (streaming and non-streaming) - Memory deletion (global and per-conversation) - Error handling Requirements: - pytest (for structured testing) - Or run directly as a script """ import os import sys import json from unittest.mock import AsyncMock, MagicMock, patch # Set up test environment before importing the server os.environ["FAST_AUTH_KEYS"] = "test-key-1,test-key-2,test-key-3" # Add project root to path sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) try: from fastapi.testclient import TestClient from langgraph.checkpoint.memory import MemorySaver HAS_TEST_CLIENT = True except ImportError: HAS_TEST_CLIENT = False print("Warning: fastapi.testclient not available. Install with: pip install pytest httpx") print("Falling back to basic import test only.") from fastapi_server.server_rest import app def create_mock_pipeline(): """Create a mock Pipeline instance for testing.""" pipeline = MagicMock() # Mock async function that returns async generator for streaming or string for non-streaming async def mock_achat(inp, as_stream=False, thread_id="test"): if as_stream: # Return async generator for streaming async def gen(): chunks = ["Hello", " ", "world", "!"] for chunk in chunks: yield chunk return gen() else: # Return string for non-streaming return "Hello world!" async def mock_aclear_memory(): return None pipeline.achat = AsyncMock(side_effect=mock_achat) pipeline.aclear_memory = AsyncMock(return_value=None) # Mock graph with memory mock_graph = MagicMock() mock_memory = MagicMock(spec=MemorySaver) mock_memory.delete_thread = MagicMock() mock_graph.memory = mock_memory pipeline.graph = mock_graph return pipeline def test_health_endpoints(): """Test health check endpoints.""" print("\n=== Testing Health Endpoints ===") with patch("fastapi_server.server_rest.pipeline", create_mock_pipeline()): client = TestClient(app) # Test root endpoint response = client.get("/") assert response.status_code == 200, f"Expected 200, got {response.status_code}" data = response.json() assert "message" in data, "Root endpoint should return message" assert "endpoints" in data, "Root endpoint should return endpoints list" print("✓ Root endpoint works") # Test health endpoint response = client.get("/health") assert response.status_code == 200, f"Expected 200, got {response.status_code}" data = response.json() assert data["status"] == "healthy", "Health endpoint should return healthy status" print("✓ Health endpoint works") def test_authentication(): """Test API key authentication.""" print("\n=== Testing Authentication ===") with patch("fastapi_server.server_rest.pipeline", create_mock_pipeline()): client = TestClient(app) # Test missing auth header response = client.post("/v1/conversations") assert response.status_code == 401, f"Expected 401 for missing auth, got {response.status_code}" print("✓ Missing auth header returns 401") # Test invalid API key response = client.post( "/v1/conversations", headers={"Authorization": "Bearer invalid-key"} ) assert response.status_code == 401, f"Expected 401 for invalid key, got {response.status_code}" print("✓ Invalid API key returns 401") # Test valid API key response = client.post( "/v1/conversations", headers={"Authorization": "Bearer test-key-1"} ) assert response.status_code == 200, f"Expected 200 for valid key, got {response.status_code}" print("✓ Valid API key works") # Test API key without Bearer prefix response = client.post( "/v1/conversations", headers={"Authorization": "test-key-1"} ) assert response.status_code == 200, f"Expected 200, got {response.status_code}" print("✓ API key without Bearer prefix works") def test_conversation_creation(): """Test conversation creation.""" print("\n=== Testing Conversation Creation ===") with patch("fastapi_server.server_rest.pipeline", create_mock_pipeline()): client = TestClient(app) auth_headers = {"Authorization": "Bearer test-key-1"} response = client.post("/v1/conversations", headers=auth_headers) assert response.status_code == 200, f"Expected 200, got {response.status_code}" data = response.json() assert "id" in data, "Response should contain id" assert "created_at" in data, "Response should contain created_at" assert data["id"].startswith("c_"), "Conversation ID should start with 'c_'" print(f"✓ Created conversation: {data['id']}") def test_chat_endpoint(): """Test chat endpoint.""" print("\n=== Testing Chat Endpoint ===") mock_pipeline = create_mock_pipeline() with patch("fastapi_server.server_rest.pipeline", mock_pipeline): client = TestClient(app) auth_headers = {"Authorization": "Bearer test-key-1"} # Test non-streaming chat response = client.post( "/v1/chat", headers=auth_headers, json={ "input": "Hello, how are you?", "stream": False } ) assert response.status_code == 200, f"Expected 200, got {response.status_code}" data = response.json() assert "conversation_id" in data, "Response should contain conversation_id" assert "output" in data, "Response should contain output" assert data["output"] == "Hello world!", "Output should match expected" print(f"✓ Non-streaming chat works: {data['conversation_id']}") # Test chat with existing conversation_id conv_id = "c_test123" response = client.post( "/v1/chat", headers=auth_headers, json={ "input": "Hello", "conversation_id": conv_id, "stream": False } ) assert response.status_code == 200, f"Expected 200, got {response.status_code}" data = response.json() assert data["conversation_id"] == conv_id, "Should use provided conversation_id" print(f"✓ Chat with existing conversation_id works") # Test streaming chat response = client.post( "/v1/chat", headers=auth_headers, json={ "input": "Hello", "stream": True } ) assert response.status_code == 200, f"Expected 200, got {response.status_code}" assert "text/event-stream" in response.headers["content-type"], "Should return event-stream" # Parse streaming response lines = response.text.split("\n") data_lines = [line for line in lines if line.startswith("data: ")] assert len(data_lines) > 0, "Should have streaming data" # Parse first delta event first_data = json.loads(data_lines[0][6:]) # Remove "data: " prefix assert first_data["type"] == "delta", "First event should be delta" assert "conversation_id" in first_data, "Delta should contain conversation_id" assert "delta" in first_data, "Delta should contain delta field" print("✓ Streaming chat works") def test_message_endpoint(): """Test message creation endpoint.""" print("\n=== Testing Message Endpoint ===") mock_pipeline = create_mock_pipeline() with patch("fastapi_server.server_rest.pipeline", mock_pipeline): client = TestClient(app) auth_headers = {"Authorization": "Bearer test-key-1"} conv_id = "c_test123" # Test non-streaming message response = client.post( f"/v1/conversations/{conv_id}/messages", headers=auth_headers, json={ "role": "user", "content": "Hello, how are you?", "stream": False } ) assert response.status_code == 200, f"Expected 200, got {response.status_code}" data = response.json() assert data["conversation_id"] == conv_id, "Should return correct conversation_id" assert "message" in data, "Response should contain message" assert data["message"]["role"] == "assistant", "Message role should be assistant" assert "content" in data["message"], "Message should contain content" print("✓ Non-streaming message creation works") # Test streaming message response = client.post( f"/v1/conversations/{conv_id}/messages", headers=auth_headers, json={ "role": "user", "content": "Hello", "stream": True } ) assert response.status_code == 200, f"Expected 200, got {response.status_code}" assert "text/event-stream" in response.headers["content-type"], "Should return event-stream" print("✓ Streaming message creation works") # Test invalid role response = client.post( f"/v1/conversations/{conv_id}/messages", headers=auth_headers, json={ "role": "assistant", "content": "Hello", "stream": False } ) assert response.status_code == 400, f"Expected 400 for invalid role, got {response.status_code}" assert "Only role='user' is supported" in response.json()["detail"], "Should reject non-user role" print("✓ Invalid role rejection works") def test_memory_deletion(): """Test memory deletion endpoints.""" print("\n=== Testing Memory Deletion ===") mock_pipeline = create_mock_pipeline() with patch("fastapi_server.server_rest.pipeline", mock_pipeline): client = TestClient(app) auth_headers = {"Authorization": "Bearer test-key-1"} # Test delete all memory response = client.delete("/v1/memory", headers=auth_headers) assert response.status_code == 200, f"Expected 200, got {response.status_code}" data = response.json() assert data["status"] == "success", "Should return success status" assert data["scope"] == "all", "Should indicate all scope" mock_pipeline.aclear_memory.assert_called_once() print("✓ Delete all memory works") # Test delete conversation memory conv_id = "c_test123" response = client.delete( f"/v1/conversations/{conv_id}/memory", headers=auth_headers ) assert response.status_code == 200, f"Expected 200, got {response.status_code}" data = response.json() assert data["status"] == "success", "Should return success status" assert data["scope"] == "conversation", "Should indicate conversation scope" assert data["conversation_id"] == conv_id, "Should return conversation_id" mock_pipeline.graph.memory.delete_thread.assert_called() print("✓ Delete conversation memory works") # Test delete conversation memory with device_id format conv_id_with_device = "c_test123_device456" response = client.delete( f"/v1/conversations/{conv_id_with_device}/memory", headers=auth_headers ) assert response.status_code == 200, f"Expected 200, got {response.status_code}" # Should normalize to base thread_id mock_pipeline.graph.memory.delete_thread.assert_called_with("c_test123") print("✓ Delete conversation memory with device_id format works") # Test delete conversation memory when MemorySaver is not available mock_pipeline_no_mem = create_mock_pipeline() mock_pipeline_no_mem.graph.memory = None with patch("fastapi_server.server_rest.pipeline", mock_pipeline_no_mem): response = client.delete( f"/v1/conversations/{conv_id}/memory", headers=auth_headers ) assert response.status_code == 501, f"Expected 501, got {response.status_code}" data = response.json() assert data["status"] == "unsupported", "Should return unsupported status" print("✓ Unsupported memory deletion handling works") def test_error_handling(): """Test error handling.""" print("\n=== Testing Error Handling ===") with patch("fastapi_server.server_rest.pipeline", create_mock_pipeline()): client = TestClient(app) auth_headers = {"Authorization": "Bearer test-key-1"} # Test chat with missing input response = client.post( "/v1/chat", headers=auth_headers, json={"stream": False} ) assert response.status_code == 422, f"Expected 422 for validation error, got {response.status_code}" print("✓ Missing input validation works") # Test message with missing content response = client.post( "/v1/conversations/c_test123/messages", headers=auth_headers, json={"role": "user", "stream": False} ) assert response.status_code == 422, f"Expected 422 for validation error, got {response.status_code}" print("✓ Missing content validation works") def run_all_tests(): """Run all tests.""" if not HAS_TEST_CLIENT: print("Cannot run tests: fastapi.testclient not available") print("Install with: pip install pytest httpx") return False print("=" * 60) print("Running REST API Tests") print("=" * 60) try: test_health_endpoints() test_authentication() test_conversation_creation() test_chat_endpoint() test_message_endpoint() test_memory_deletion() test_error_handling() print("\n" + "=" * 60) print("✓ All tests passed!") print("=" * 60) return True except AssertionError as e: print(f"\n✗ Test failed: {e}") import traceback traceback.print_exc() return False except Exception as e: print(f"\n✗ Unexpected error: {e}") import traceback traceback.print_exc() return False if __name__ == "__main__": success = run_all_tests() sys.exit(0 if success else 1)