diff --git a/tests/test_server_rest.py b/tests/test_server_rest.py new file mode 100644 index 0000000..413a3ec --- /dev/null +++ b/tests/test_server_rest.py @@ -0,0 +1,484 @@ +#!/usr/bin/env python3 +""" +Tests for the REST API server (server_rest.py) + +This test suite covers: +- Health check endpoints (GET /, GET /health) +- API key authentication (valid/invalid keys, Bearer format) +- Conversation creation (POST /v1/conversations) +- Chat endpoint (POST /v1/chat) - streaming and non-streaming +- Message creation (POST /v1/conversations/{id}/messages) - streaming and non-streaming +- Memory deletion (DELETE /v1/memory, DELETE /v1/conversations/{id}/memory) +- Edge cases and error handling + +Run with: pytest tests/test_server_rest.py -v +""" +import os +import pytest +import json +from unittest.mock import AsyncMock, MagicMock, patch +from fastapi.testclient import TestClient +from langgraph.checkpoint.memory import MemorySaver + +# Set up test environment before importing the server +os.environ["FAST_AUTH_KEYS"] = "test-key-1,test-key-2,test-key-3" + +# Import after setting environment +import sys +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from fastapi_server.server_rest import app + + +@pytest.fixture +def mock_pipeline(): + """Create a mock Pipeline instance.""" + pipeline = MagicMock() + + # Mock async generator for streaming + async def mock_achat_stream(inp, as_stream=True, thread_id="test"): + chunks = ["Hello", " ", "world", "!"] + for chunk in chunks: + yield chunk + + # Mock async function that returns async generator for streaming or string for non-streaming + async def mock_achat(inp, as_stream=False, thread_id="test"): + if as_stream: + # Return async generator for streaming + async def gen(): + chunks = ["Hello", " ", "world", "!"] + for chunk in chunks: + yield chunk + return gen() + else: + # Return string for non-streaming + return "Hello world!" + + async def mock_aclear_memory(): + return None + + pipeline.achat = AsyncMock(side_effect=mock_achat) + pipeline.aclear_memory = AsyncMock(return_value=None) + + # Mock graph with memory + mock_graph = MagicMock() + mock_memory = MagicMock(spec=MemorySaver) + mock_memory.delete_thread = MagicMock() + mock_graph.memory = mock_memory + pipeline.graph = mock_graph + + return pipeline + + +@pytest.fixture +def client(mock_pipeline): + """Create a test client with mocked pipeline.""" + with patch("fastapi_server.server_rest.pipeline", mock_pipeline): + with TestClient(app) as test_client: + yield test_client + + +@pytest.fixture +def auth_headers(): + """Return valid authentication headers.""" + return {"Authorization": "Bearer test-key-1"} + + +@pytest.fixture +def invalid_auth_headers(): + """Return invalid authentication headers.""" + return {"Authorization": "Bearer invalid-key"} + + +class TestHealthCheck: + """Tests for health check endpoint.""" + + def test_root_endpoint(self, client): + """Test root endpoint returns API information.""" + response = client.get("/") + assert response.status_code == 200 + data = response.json() + assert "message" in data + assert "endpoints" in data + assert isinstance(data["endpoints"], list) + + def test_health_endpoint(self, client): + """Test health endpoint returns healthy status.""" + response = client.get("/health") + assert response.status_code == 200 + data = response.json() + assert data["status"] == "healthy" + + +class TestAuthentication: + """Tests for API key authentication.""" + + def test_missing_auth_header(self, client): + """Test that missing auth header returns 401.""" + response = client.post("/v1/conversations") + assert response.status_code == 401 + + def test_invalid_api_key(self, client, invalid_auth_headers): + """Test that invalid API key returns 401.""" + response = client.post( + "/v1/conversations", + headers=invalid_auth_headers + ) + assert response.status_code == 401 + assert "Invalid API key" in response.json()["detail"] + + def test_valid_api_key_bearer_format(self, client, auth_headers): + """Test that valid API key with Bearer prefix works.""" + response = client.post( + "/v1/conversations", + headers=auth_headers + ) + assert response.status_code == 200 + + def test_valid_api_key_without_bearer(self, client): + """Test that valid API key without Bearer prefix works.""" + response = client.post( + "/v1/conversations", + headers={"Authorization": "test-key-1"} + ) + assert response.status_code == 200 + + def test_multiple_valid_keys(self, client): + """Test that any of the configured keys work.""" + for key in ["test-key-1", "test-key-2", "test-key-3"]: + response = client.post( + "/v1/conversations", + headers={"Authorization": f"Bearer {key}"} + ) + assert response.status_code == 200 + + +class TestConversationCreation: + """Tests for conversation creation endpoint.""" + + def test_create_conversation(self, client, auth_headers): + """Test creating a new conversation.""" + response = client.post( + "/v1/conversations", + headers=auth_headers + ) + assert response.status_code == 200 + data = response.json() + assert "id" in data + assert "created_at" in data + assert data["id"].startswith("c_") + assert len(data["id"]) > 2 + + def test_conversation_id_format(self, client, auth_headers): + """Test that conversation IDs follow expected format.""" + response = client.post( + "/v1/conversations", + headers=auth_headers + ) + data = response.json() + conv_id = data["id"] + # Should start with "c_" and have hex characters + assert conv_id.startswith("c_") + assert len(conv_id) > 2 + + +class TestChatEndpoint: + """Tests for the /v1/chat endpoint.""" + + def test_chat_non_streaming(self, client, auth_headers, mock_pipeline): + """Test non-streaming chat request.""" + with patch("fastapi_server.server_rest.pipeline", mock_pipeline): + response = client.post( + "/v1/chat", + headers=auth_headers, + json={ + "input": "Hello, how are you?", + "stream": False + } + ) + assert response.status_code == 200 + data = response.json() + assert "conversation_id" in data + assert "output" in data + assert data["output"] == "Hello world!" + mock_pipeline.achat.assert_called_once() + call_kwargs = mock_pipeline.achat.call_args.kwargs + assert call_kwargs["inp"] == "Hello, how are you?" + assert call_kwargs["as_stream"] is False + + def test_chat_with_existing_conversation_id(self, client, auth_headers, mock_pipeline): + """Test chat with existing conversation ID.""" + with patch("fastapi_server.server_rest.pipeline", mock_pipeline): + conv_id = "c_test123" + response = client.post( + "/v1/chat", + headers=auth_headers, + json={ + "input": "Hello", + "conversation_id": conv_id, + "stream": False + } + ) + assert response.status_code == 200 + data = response.json() + assert data["conversation_id"] == conv_id + call_kwargs = mock_pipeline.achat.call_args.kwargs + assert call_kwargs["thread_id"] == conv_id + + def test_chat_creates_new_conversation_id(self, client, auth_headers, mock_pipeline): + """Test chat creates new conversation ID when not provided.""" + with patch("fastapi_server.server_rest.pipeline", mock_pipeline): + response = client.post( + "/v1/chat", + headers=auth_headers, + json={ + "input": "Hello", + "stream": False + } + ) + assert response.status_code == 200 + data = response.json() + assert "conversation_id" in data + assert data["conversation_id"].startswith("c_") + + def test_chat_streaming(self, client, auth_headers, mock_pipeline): + """Test streaming chat request.""" + with patch("fastapi_server.server_rest.pipeline", mock_pipeline): + response = client.post( + "/v1/chat", + headers=auth_headers, + json={ + "input": "Hello", + "stream": True + } + ) + assert response.status_code == 200 + assert "text/event-stream" in response.headers["content-type"] + + # Read streaming response + lines = response.text.split("\n") + data_lines = [line for line in lines if line.startswith("data: ")] + + # Should have delta events and a done event + assert len(data_lines) > 0 + + # Parse first delta event + first_data = json.loads(data_lines[0][6:]) # Remove "data: " prefix + assert first_data["type"] == "delta" + assert "conversation_id" in first_data + assert "delta" in first_data + + # Check that achat was called with as_stream=True + mock_pipeline.achat.assert_called_once() + call_kwargs = mock_pipeline.achat.call_args.kwargs + assert call_kwargs["as_stream"] is True + + +class TestMessageEndpoint: + """Tests for the /v1/conversations/{conversation_id}/messages endpoint.""" + + def test_create_message_non_streaming(self, client, auth_headers, mock_pipeline): + """Test creating a message (non-streaming).""" + with patch("fastapi_server.server_rest.pipeline", mock_pipeline): + conv_id = "c_test123" + response = client.post( + f"/v1/conversations/{conv_id}/messages", + headers=auth_headers, + json={ + "role": "user", + "content": "Hello, how are you?", + "stream": False + } + ) + assert response.status_code == 200 + data = response.json() + assert data["conversation_id"] == conv_id + assert "message" in data + assert data["message"]["role"] == "assistant" + assert "content" in data["message"] + mock_pipeline.achat.assert_called_once() + call_kwargs = mock_pipeline.achat.call_args.kwargs + assert call_kwargs["inp"] == "Hello, how are you?" + assert call_kwargs["thread_id"] == conv_id + + def test_create_message_streaming(self, client, auth_headers, mock_pipeline): + """Test creating a message (streaming).""" + with patch("fastapi_server.server_rest.pipeline", mock_pipeline): + conv_id = "c_test123" + response = client.post( + f"/v1/conversations/{conv_id}/messages", + headers=auth_headers, + json={ + "role": "user", + "content": "Hello", + "stream": True + } + ) + assert response.status_code == 200 + assert "text/event-stream" in response.headers["content-type"] + + # Verify achat was called with streaming + mock_pipeline.achat.assert_called_once() + call_kwargs = mock_pipeline.achat.call_args.kwargs + assert call_kwargs["as_stream"] is True + assert call_kwargs["thread_id"] == conv_id + + def test_create_message_invalid_role(self, client, auth_headers): + """Test that only 'user' role is accepted.""" + conv_id = "c_test123" + response = client.post( + f"/v1/conversations/{conv_id}/messages", + headers=auth_headers, + json={ + "role": "assistant", + "content": "Hello", + "stream": False + } + ) + assert response.status_code == 400 + assert "Only role='user' is supported" in response.json()["detail"] + + +class TestMemoryDeletion: + """Tests for memory deletion endpoints.""" + + def test_delete_all_memory(self, client, auth_headers, mock_pipeline): + """Test deleting all memory.""" + with patch("fastapi_server.server_rest.pipeline", mock_pipeline): + response = client.delete( + "/v1/memory", + headers=auth_headers + ) + assert response.status_code == 200 + data = response.json() + assert data["status"] == "success" + assert data["scope"] == "all" + mock_pipeline.aclear_memory.assert_called_once() + + def test_delete_all_memory_error_handling(self, client, auth_headers, mock_pipeline): + """Test error handling when deleting all memory fails.""" + mock_pipeline.aclear_memory = AsyncMock(side_effect=Exception("Memory error")) + + with patch("fastapi_server.server_rest.pipeline", mock_pipeline): + response = client.delete( + "/v1/memory", + headers=auth_headers + ) + assert response.status_code == 500 + assert "Memory error" in response.json()["detail"] + + def test_delete_conversation_memory(self, client, auth_headers, mock_pipeline): + """Test deleting memory for a specific conversation.""" + conv_id = "c_test123" + + with patch("fastapi_server.server_rest.pipeline", mock_pipeline): + response = client.delete( + f"/v1/conversations/{conv_id}/memory", + headers=auth_headers + ) + assert response.status_code == 200 + data = response.json() + assert data["status"] == "success" + assert data["scope"] == "conversation" + assert data["conversation_id"] == conv_id + # Verify delete_thread was called + mock_pipeline.graph.memory.delete_thread.assert_called_once() + + def test_delete_conversation_memory_with_device_id(self, client, auth_headers, mock_pipeline): + """Test deleting memory for conversation with device ID format.""" + conv_id = "c_test123_device456" + + with patch("fastapi_server.server_rest.pipeline", mock_pipeline): + response = client.delete( + f"/v1/conversations/{conv_id}/memory", + headers=auth_headers + ) + assert response.status_code == 200 + # Should normalize to base thread_id + mock_pipeline.graph.memory.delete_thread.assert_called_once_with("c_test123") + + def test_delete_conversation_memory_no_memory_saver(self, client, auth_headers): + """Test deleting conversation memory when MemorySaver is not available.""" + # Create a mock pipeline without MemorySaver + mock_pipeline_no_mem = MagicMock() + mock_pipeline_no_mem.graph = MagicMock() + mock_pipeline_no_mem.graph.memory = None + + with patch("fastapi_server.server_rest.pipeline", mock_pipeline_no_mem): + conv_id = "c_test123" + response = client.delete( + f"/v1/conversations/{conv_id}/memory", + headers=auth_headers + ) + assert response.status_code == 501 + data = response.json() + assert data["status"] == "unsupported" + assert "not supported" in data["message"].lower() + + +class TestEdgeCases: + """Tests for edge cases and error handling.""" + + def test_chat_empty_input(self, client, auth_headers, mock_pipeline): + """Test chat with empty input.""" + with patch("fastapi_server.server_rest.pipeline", mock_pipeline): + response = client.post( + "/v1/chat", + headers=auth_headers, + json={ + "input": "", + "stream": False + } + ) + # Should still process (validation would be in Pipeline) + assert response.status_code in [200, 400] + + def test_chat_missing_input(self, client, auth_headers): + """Test chat with missing input field.""" + response = client.post( + "/v1/chat", + headers=auth_headers, + json={ + "stream": False + } + ) + assert response.status_code == 422 # Validation error + + def test_message_missing_content(self, client, auth_headers): + """Test message creation with missing content.""" + conv_id = "c_test123" + response = client.post( + f"/v1/conversations/{conv_id}/messages", + headers=auth_headers, + json={ + "role": "user", + "stream": False + } + ) + assert response.status_code == 422 # Validation error + + def test_invalid_conversation_id_format(self, client, auth_headers, mock_pipeline): + """Test that various conversation ID formats are handled.""" + with patch("fastapi_server.server_rest.pipeline", mock_pipeline): + # Test with underscore (device_id format) + conv_id = "thread_123_device_456" + response = client.post( + f"/v1/conversations/{conv_id}/messages", + headers=auth_headers, + json={ + "role": "user", + "content": "Hello", + "stream": False + } + ) + # Should normalize thread_id (take first part before _) + assert response.status_code == 200 + call_kwargs = mock_pipeline.achat.call_args.kwargs + # The thread_id normalization happens in _normalize_thread_id + # but achat receives the full conversation_id + assert call_kwargs["thread_id"] == conv_id + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) +