rest_api tests
This commit is contained in:
484
tests/test_server_rest.py
Normal file
484
tests/test_server_rest.py
Normal file
@@ -0,0 +1,484 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Tests for the REST API server (server_rest.py)
|
||||
|
||||
This test suite covers:
|
||||
- Health check endpoints (GET /, GET /health)
|
||||
- API key authentication (valid/invalid keys, Bearer format)
|
||||
- Conversation creation (POST /v1/conversations)
|
||||
- Chat endpoint (POST /v1/chat) - streaming and non-streaming
|
||||
- Message creation (POST /v1/conversations/{id}/messages) - streaming and non-streaming
|
||||
- Memory deletion (DELETE /v1/memory, DELETE /v1/conversations/{id}/memory)
|
||||
- Edge cases and error handling
|
||||
|
||||
Run with: pytest tests/test_server_rest.py -v
|
||||
"""
|
||||
import os
|
||||
import pytest
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from fastapi.testclient import TestClient
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
|
||||
# Set up test environment before importing the server
|
||||
os.environ["FAST_AUTH_KEYS"] = "test-key-1,test-key-2,test-key-3"
|
||||
|
||||
# Import after setting environment
|
||||
import sys
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from fastapi_server.server_rest import app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_pipeline():
|
||||
"""Create a mock Pipeline instance."""
|
||||
pipeline = MagicMock()
|
||||
|
||||
# Mock async generator for streaming
|
||||
async def mock_achat_stream(inp, as_stream=True, thread_id="test"):
|
||||
chunks = ["Hello", " ", "world", "!"]
|
||||
for chunk in chunks:
|
||||
yield chunk
|
||||
|
||||
# Mock async function that returns async generator for streaming or string for non-streaming
|
||||
async def mock_achat(inp, as_stream=False, thread_id="test"):
|
||||
if as_stream:
|
||||
# Return async generator for streaming
|
||||
async def gen():
|
||||
chunks = ["Hello", " ", "world", "!"]
|
||||
for chunk in chunks:
|
||||
yield chunk
|
||||
return gen()
|
||||
else:
|
||||
# Return string for non-streaming
|
||||
return "Hello world!"
|
||||
|
||||
async def mock_aclear_memory():
|
||||
return None
|
||||
|
||||
pipeline.achat = AsyncMock(side_effect=mock_achat)
|
||||
pipeline.aclear_memory = AsyncMock(return_value=None)
|
||||
|
||||
# Mock graph with memory
|
||||
mock_graph = MagicMock()
|
||||
mock_memory = MagicMock(spec=MemorySaver)
|
||||
mock_memory.delete_thread = MagicMock()
|
||||
mock_graph.memory = mock_memory
|
||||
pipeline.graph = mock_graph
|
||||
|
||||
return pipeline
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(mock_pipeline):
|
||||
"""Create a test client with mocked pipeline."""
|
||||
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
|
||||
with TestClient(app) as test_client:
|
||||
yield test_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_headers():
|
||||
"""Return valid authentication headers."""
|
||||
return {"Authorization": "Bearer test-key-1"}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def invalid_auth_headers():
|
||||
"""Return invalid authentication headers."""
|
||||
return {"Authorization": "Bearer invalid-key"}
|
||||
|
||||
|
||||
class TestHealthCheck:
|
||||
"""Tests for health check endpoint."""
|
||||
|
||||
def test_root_endpoint(self, client):
|
||||
"""Test root endpoint returns API information."""
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "message" in data
|
||||
assert "endpoints" in data
|
||||
assert isinstance(data["endpoints"], list)
|
||||
|
||||
def test_health_endpoint(self, client):
|
||||
"""Test health endpoint returns healthy status."""
|
||||
response = client.get("/health")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "healthy"
|
||||
|
||||
|
||||
class TestAuthentication:
|
||||
"""Tests for API key authentication."""
|
||||
|
||||
def test_missing_auth_header(self, client):
|
||||
"""Test that missing auth header returns 401."""
|
||||
response = client.post("/v1/conversations")
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_invalid_api_key(self, client, invalid_auth_headers):
|
||||
"""Test that invalid API key returns 401."""
|
||||
response = client.post(
|
||||
"/v1/conversations",
|
||||
headers=invalid_auth_headers
|
||||
)
|
||||
assert response.status_code == 401
|
||||
assert "Invalid API key" in response.json()["detail"]
|
||||
|
||||
def test_valid_api_key_bearer_format(self, client, auth_headers):
|
||||
"""Test that valid API key with Bearer prefix works."""
|
||||
response = client.post(
|
||||
"/v1/conversations",
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_valid_api_key_without_bearer(self, client):
|
||||
"""Test that valid API key without Bearer prefix works."""
|
||||
response = client.post(
|
||||
"/v1/conversations",
|
||||
headers={"Authorization": "test-key-1"}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_multiple_valid_keys(self, client):
|
||||
"""Test that any of the configured keys work."""
|
||||
for key in ["test-key-1", "test-key-2", "test-key-3"]:
|
||||
response = client.post(
|
||||
"/v1/conversations",
|
||||
headers={"Authorization": f"Bearer {key}"}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
class TestConversationCreation:
|
||||
"""Tests for conversation creation endpoint."""
|
||||
|
||||
def test_create_conversation(self, client, auth_headers):
|
||||
"""Test creating a new conversation."""
|
||||
response = client.post(
|
||||
"/v1/conversations",
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "id" in data
|
||||
assert "created_at" in data
|
||||
assert data["id"].startswith("c_")
|
||||
assert len(data["id"]) > 2
|
||||
|
||||
def test_conversation_id_format(self, client, auth_headers):
|
||||
"""Test that conversation IDs follow expected format."""
|
||||
response = client.post(
|
||||
"/v1/conversations",
|
||||
headers=auth_headers
|
||||
)
|
||||
data = response.json()
|
||||
conv_id = data["id"]
|
||||
# Should start with "c_" and have hex characters
|
||||
assert conv_id.startswith("c_")
|
||||
assert len(conv_id) > 2
|
||||
|
||||
|
||||
class TestChatEndpoint:
|
||||
"""Tests for the /v1/chat endpoint."""
|
||||
|
||||
def test_chat_non_streaming(self, client, auth_headers, mock_pipeline):
|
||||
"""Test non-streaming chat request."""
|
||||
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
|
||||
response = client.post(
|
||||
"/v1/chat",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"input": "Hello, how are you?",
|
||||
"stream": False
|
||||
}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "conversation_id" in data
|
||||
assert "output" in data
|
||||
assert data["output"] == "Hello world!"
|
||||
mock_pipeline.achat.assert_called_once()
|
||||
call_kwargs = mock_pipeline.achat.call_args.kwargs
|
||||
assert call_kwargs["inp"] == "Hello, how are you?"
|
||||
assert call_kwargs["as_stream"] is False
|
||||
|
||||
def test_chat_with_existing_conversation_id(self, client, auth_headers, mock_pipeline):
|
||||
"""Test chat with existing conversation ID."""
|
||||
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
|
||||
conv_id = "c_test123"
|
||||
response = client.post(
|
||||
"/v1/chat",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"input": "Hello",
|
||||
"conversation_id": conv_id,
|
||||
"stream": False
|
||||
}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["conversation_id"] == conv_id
|
||||
call_kwargs = mock_pipeline.achat.call_args.kwargs
|
||||
assert call_kwargs["thread_id"] == conv_id
|
||||
|
||||
def test_chat_creates_new_conversation_id(self, client, auth_headers, mock_pipeline):
|
||||
"""Test chat creates new conversation ID when not provided."""
|
||||
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
|
||||
response = client.post(
|
||||
"/v1/chat",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"input": "Hello",
|
||||
"stream": False
|
||||
}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "conversation_id" in data
|
||||
assert data["conversation_id"].startswith("c_")
|
||||
|
||||
def test_chat_streaming(self, client, auth_headers, mock_pipeline):
|
||||
"""Test streaming chat request."""
|
||||
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
|
||||
response = client.post(
|
||||
"/v1/chat",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"input": "Hello",
|
||||
"stream": True
|
||||
}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert "text/event-stream" in response.headers["content-type"]
|
||||
|
||||
# Read streaming response
|
||||
lines = response.text.split("\n")
|
||||
data_lines = [line for line in lines if line.startswith("data: ")]
|
||||
|
||||
# Should have delta events and a done event
|
||||
assert len(data_lines) > 0
|
||||
|
||||
# Parse first delta event
|
||||
first_data = json.loads(data_lines[0][6:]) # Remove "data: " prefix
|
||||
assert first_data["type"] == "delta"
|
||||
assert "conversation_id" in first_data
|
||||
assert "delta" in first_data
|
||||
|
||||
# Check that achat was called with as_stream=True
|
||||
mock_pipeline.achat.assert_called_once()
|
||||
call_kwargs = mock_pipeline.achat.call_args.kwargs
|
||||
assert call_kwargs["as_stream"] is True
|
||||
|
||||
|
||||
class TestMessageEndpoint:
|
||||
"""Tests for the /v1/conversations/{conversation_id}/messages endpoint."""
|
||||
|
||||
def test_create_message_non_streaming(self, client, auth_headers, mock_pipeline):
|
||||
"""Test creating a message (non-streaming)."""
|
||||
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
|
||||
conv_id = "c_test123"
|
||||
response = client.post(
|
||||
f"/v1/conversations/{conv_id}/messages",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"role": "user",
|
||||
"content": "Hello, how are you?",
|
||||
"stream": False
|
||||
}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["conversation_id"] == conv_id
|
||||
assert "message" in data
|
||||
assert data["message"]["role"] == "assistant"
|
||||
assert "content" in data["message"]
|
||||
mock_pipeline.achat.assert_called_once()
|
||||
call_kwargs = mock_pipeline.achat.call_args.kwargs
|
||||
assert call_kwargs["inp"] == "Hello, how are you?"
|
||||
assert call_kwargs["thread_id"] == conv_id
|
||||
|
||||
def test_create_message_streaming(self, client, auth_headers, mock_pipeline):
|
||||
"""Test creating a message (streaming)."""
|
||||
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
|
||||
conv_id = "c_test123"
|
||||
response = client.post(
|
||||
f"/v1/conversations/{conv_id}/messages",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"role": "user",
|
||||
"content": "Hello",
|
||||
"stream": True
|
||||
}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert "text/event-stream" in response.headers["content-type"]
|
||||
|
||||
# Verify achat was called with streaming
|
||||
mock_pipeline.achat.assert_called_once()
|
||||
call_kwargs = mock_pipeline.achat.call_args.kwargs
|
||||
assert call_kwargs["as_stream"] is True
|
||||
assert call_kwargs["thread_id"] == conv_id
|
||||
|
||||
def test_create_message_invalid_role(self, client, auth_headers):
|
||||
"""Test that only 'user' role is accepted."""
|
||||
conv_id = "c_test123"
|
||||
response = client.post(
|
||||
f"/v1/conversations/{conv_id}/messages",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"role": "assistant",
|
||||
"content": "Hello",
|
||||
"stream": False
|
||||
}
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert "Only role='user' is supported" in response.json()["detail"]
|
||||
|
||||
|
||||
class TestMemoryDeletion:
|
||||
"""Tests for memory deletion endpoints."""
|
||||
|
||||
def test_delete_all_memory(self, client, auth_headers, mock_pipeline):
|
||||
"""Test deleting all memory."""
|
||||
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
|
||||
response = client.delete(
|
||||
"/v1/memory",
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "success"
|
||||
assert data["scope"] == "all"
|
||||
mock_pipeline.aclear_memory.assert_called_once()
|
||||
|
||||
def test_delete_all_memory_error_handling(self, client, auth_headers, mock_pipeline):
|
||||
"""Test error handling when deleting all memory fails."""
|
||||
mock_pipeline.aclear_memory = AsyncMock(side_effect=Exception("Memory error"))
|
||||
|
||||
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
|
||||
response = client.delete(
|
||||
"/v1/memory",
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 500
|
||||
assert "Memory error" in response.json()["detail"]
|
||||
|
||||
def test_delete_conversation_memory(self, client, auth_headers, mock_pipeline):
|
||||
"""Test deleting memory for a specific conversation."""
|
||||
conv_id = "c_test123"
|
||||
|
||||
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
|
||||
response = client.delete(
|
||||
f"/v1/conversations/{conv_id}/memory",
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "success"
|
||||
assert data["scope"] == "conversation"
|
||||
assert data["conversation_id"] == conv_id
|
||||
# Verify delete_thread was called
|
||||
mock_pipeline.graph.memory.delete_thread.assert_called_once()
|
||||
|
||||
def test_delete_conversation_memory_with_device_id(self, client, auth_headers, mock_pipeline):
|
||||
"""Test deleting memory for conversation with device ID format."""
|
||||
conv_id = "c_test123_device456"
|
||||
|
||||
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
|
||||
response = client.delete(
|
||||
f"/v1/conversations/{conv_id}/memory",
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
# Should normalize to base thread_id
|
||||
mock_pipeline.graph.memory.delete_thread.assert_called_once_with("c_test123")
|
||||
|
||||
def test_delete_conversation_memory_no_memory_saver(self, client, auth_headers):
|
||||
"""Test deleting conversation memory when MemorySaver is not available."""
|
||||
# Create a mock pipeline without MemorySaver
|
||||
mock_pipeline_no_mem = MagicMock()
|
||||
mock_pipeline_no_mem.graph = MagicMock()
|
||||
mock_pipeline_no_mem.graph.memory = None
|
||||
|
||||
with patch("fastapi_server.server_rest.pipeline", mock_pipeline_no_mem):
|
||||
conv_id = "c_test123"
|
||||
response = client.delete(
|
||||
f"/v1/conversations/{conv_id}/memory",
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 501
|
||||
data = response.json()
|
||||
assert data["status"] == "unsupported"
|
||||
assert "not supported" in data["message"].lower()
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
"""Tests for edge cases and error handling."""
|
||||
|
||||
def test_chat_empty_input(self, client, auth_headers, mock_pipeline):
|
||||
"""Test chat with empty input."""
|
||||
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
|
||||
response = client.post(
|
||||
"/v1/chat",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"input": "",
|
||||
"stream": False
|
||||
}
|
||||
)
|
||||
# Should still process (validation would be in Pipeline)
|
||||
assert response.status_code in [200, 400]
|
||||
|
||||
def test_chat_missing_input(self, client, auth_headers):
|
||||
"""Test chat with missing input field."""
|
||||
response = client.post(
|
||||
"/v1/chat",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"stream": False
|
||||
}
|
||||
)
|
||||
assert response.status_code == 422 # Validation error
|
||||
|
||||
def test_message_missing_content(self, client, auth_headers):
|
||||
"""Test message creation with missing content."""
|
||||
conv_id = "c_test123"
|
||||
response = client.post(
|
||||
f"/v1/conversations/{conv_id}/messages",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"role": "user",
|
||||
"stream": False
|
||||
}
|
||||
)
|
||||
assert response.status_code == 422 # Validation error
|
||||
|
||||
def test_invalid_conversation_id_format(self, client, auth_headers, mock_pipeline):
|
||||
"""Test that various conversation ID formats are handled."""
|
||||
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
|
||||
# Test with underscore (device_id format)
|
||||
conv_id = "thread_123_device_456"
|
||||
response = client.post(
|
||||
f"/v1/conversations/{conv_id}/messages",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"role": "user",
|
||||
"content": "Hello",
|
||||
"stream": False
|
||||
}
|
||||
)
|
||||
# Should normalize thread_id (take first part before _)
|
||||
assert response.status_code == 200
|
||||
call_kwargs = mock_pipeline.achat.call_args.kwargs
|
||||
# The thread_id normalization happens in _normalize_thread_id
|
||||
# but achat receives the full conversation_id
|
||||
assert call_kwargs["thread_id"] == conv_id
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
|
||||
Reference in New Issue
Block a user