rest_api tests
This commit is contained in:
484
tests/test_server_rest.py
Normal file
484
tests/test_server_rest.py
Normal file
@@ -0,0 +1,484 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Tests for the REST API server (server_rest.py)
|
||||||
|
|
||||||
|
This test suite covers:
|
||||||
|
- Health check endpoints (GET /, GET /health)
|
||||||
|
- API key authentication (valid/invalid keys, Bearer format)
|
||||||
|
- Conversation creation (POST /v1/conversations)
|
||||||
|
- Chat endpoint (POST /v1/chat) - streaming and non-streaming
|
||||||
|
- Message creation (POST /v1/conversations/{id}/messages) - streaming and non-streaming
|
||||||
|
- Memory deletion (DELETE /v1/memory, DELETE /v1/conversations/{id}/memory)
|
||||||
|
- Edge cases and error handling
|
||||||
|
|
||||||
|
Run with: pytest tests/test_server_rest.py -v
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
import json
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from langgraph.checkpoint.memory import MemorySaver
|
||||||
|
|
||||||
|
# Set up test environment before importing the server
|
||||||
|
os.environ["FAST_AUTH_KEYS"] = "test-key-1,test-key-2,test-key-3"
|
||||||
|
|
||||||
|
# Import after setting environment
|
||||||
|
import sys
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
|
from fastapi_server.server_rest import app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_pipeline():
|
||||||
|
"""Create a mock Pipeline instance."""
|
||||||
|
pipeline = MagicMock()
|
||||||
|
|
||||||
|
# Mock async generator for streaming
|
||||||
|
async def mock_achat_stream(inp, as_stream=True, thread_id="test"):
|
||||||
|
chunks = ["Hello", " ", "world", "!"]
|
||||||
|
for chunk in chunks:
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
# Mock async function that returns async generator for streaming or string for non-streaming
|
||||||
|
async def mock_achat(inp, as_stream=False, thread_id="test"):
|
||||||
|
if as_stream:
|
||||||
|
# Return async generator for streaming
|
||||||
|
async def gen():
|
||||||
|
chunks = ["Hello", " ", "world", "!"]
|
||||||
|
for chunk in chunks:
|
||||||
|
yield chunk
|
||||||
|
return gen()
|
||||||
|
else:
|
||||||
|
# Return string for non-streaming
|
||||||
|
return "Hello world!"
|
||||||
|
|
||||||
|
async def mock_aclear_memory():
|
||||||
|
return None
|
||||||
|
|
||||||
|
pipeline.achat = AsyncMock(side_effect=mock_achat)
|
||||||
|
pipeline.aclear_memory = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
# Mock graph with memory
|
||||||
|
mock_graph = MagicMock()
|
||||||
|
mock_memory = MagicMock(spec=MemorySaver)
|
||||||
|
mock_memory.delete_thread = MagicMock()
|
||||||
|
mock_graph.memory = mock_memory
|
||||||
|
pipeline.graph = mock_graph
|
||||||
|
|
||||||
|
return pipeline
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client(mock_pipeline):
|
||||||
|
"""Create a test client with mocked pipeline."""
|
||||||
|
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
|
||||||
|
with TestClient(app) as test_client:
|
||||||
|
yield test_client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def auth_headers():
|
||||||
|
"""Return valid authentication headers."""
|
||||||
|
return {"Authorization": "Bearer test-key-1"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def invalid_auth_headers():
|
||||||
|
"""Return invalid authentication headers."""
|
||||||
|
return {"Authorization": "Bearer invalid-key"}
|
||||||
|
|
||||||
|
|
||||||
|
class TestHealthCheck:
|
||||||
|
"""Tests for health check endpoint."""
|
||||||
|
|
||||||
|
def test_root_endpoint(self, client):
|
||||||
|
"""Test root endpoint returns API information."""
|
||||||
|
response = client.get("/")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "message" in data
|
||||||
|
assert "endpoints" in data
|
||||||
|
assert isinstance(data["endpoints"], list)
|
||||||
|
|
||||||
|
def test_health_endpoint(self, client):
|
||||||
|
"""Test health endpoint returns healthy status."""
|
||||||
|
response = client.get("/health")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["status"] == "healthy"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuthentication:
|
||||||
|
"""Tests for API key authentication."""
|
||||||
|
|
||||||
|
def test_missing_auth_header(self, client):
|
||||||
|
"""Test that missing auth header returns 401."""
|
||||||
|
response = client.post("/v1/conversations")
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
def test_invalid_api_key(self, client, invalid_auth_headers):
|
||||||
|
"""Test that invalid API key returns 401."""
|
||||||
|
response = client.post(
|
||||||
|
"/v1/conversations",
|
||||||
|
headers=invalid_auth_headers
|
||||||
|
)
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert "Invalid API key" in response.json()["detail"]
|
||||||
|
|
||||||
|
def test_valid_api_key_bearer_format(self, client, auth_headers):
|
||||||
|
"""Test that valid API key with Bearer prefix works."""
|
||||||
|
response = client.post(
|
||||||
|
"/v1/conversations",
|
||||||
|
headers=auth_headers
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
def test_valid_api_key_without_bearer(self, client):
|
||||||
|
"""Test that valid API key without Bearer prefix works."""
|
||||||
|
response = client.post(
|
||||||
|
"/v1/conversations",
|
||||||
|
headers={"Authorization": "test-key-1"}
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
def test_multiple_valid_keys(self, client):
|
||||||
|
"""Test that any of the configured keys work."""
|
||||||
|
for key in ["test-key-1", "test-key-2", "test-key-3"]:
|
||||||
|
response = client.post(
|
||||||
|
"/v1/conversations",
|
||||||
|
headers={"Authorization": f"Bearer {key}"}
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
class TestConversationCreation:
|
||||||
|
"""Tests for conversation creation endpoint."""
|
||||||
|
|
||||||
|
def test_create_conversation(self, client, auth_headers):
|
||||||
|
"""Test creating a new conversation."""
|
||||||
|
response = client.post(
|
||||||
|
"/v1/conversations",
|
||||||
|
headers=auth_headers
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "id" in data
|
||||||
|
assert "created_at" in data
|
||||||
|
assert data["id"].startswith("c_")
|
||||||
|
assert len(data["id"]) > 2
|
||||||
|
|
||||||
|
def test_conversation_id_format(self, client, auth_headers):
|
||||||
|
"""Test that conversation IDs follow expected format."""
|
||||||
|
response = client.post(
|
||||||
|
"/v1/conversations",
|
||||||
|
headers=auth_headers
|
||||||
|
)
|
||||||
|
data = response.json()
|
||||||
|
conv_id = data["id"]
|
||||||
|
# Should start with "c_" and have hex characters
|
||||||
|
assert conv_id.startswith("c_")
|
||||||
|
assert len(conv_id) > 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestChatEndpoint:
|
||||||
|
"""Tests for the /v1/chat endpoint."""
|
||||||
|
|
||||||
|
def test_chat_non_streaming(self, client, auth_headers, mock_pipeline):
|
||||||
|
"""Test non-streaming chat request."""
|
||||||
|
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
|
||||||
|
response = client.post(
|
||||||
|
"/v1/chat",
|
||||||
|
headers=auth_headers,
|
||||||
|
json={
|
||||||
|
"input": "Hello, how are you?",
|
||||||
|
"stream": False
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "conversation_id" in data
|
||||||
|
assert "output" in data
|
||||||
|
assert data["output"] == "Hello world!"
|
||||||
|
mock_pipeline.achat.assert_called_once()
|
||||||
|
call_kwargs = mock_pipeline.achat.call_args.kwargs
|
||||||
|
assert call_kwargs["inp"] == "Hello, how are you?"
|
||||||
|
assert call_kwargs["as_stream"] is False
|
||||||
|
|
||||||
|
def test_chat_with_existing_conversation_id(self, client, auth_headers, mock_pipeline):
|
||||||
|
"""Test chat with existing conversation ID."""
|
||||||
|
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
|
||||||
|
conv_id = "c_test123"
|
||||||
|
response = client.post(
|
||||||
|
"/v1/chat",
|
||||||
|
headers=auth_headers,
|
||||||
|
json={
|
||||||
|
"input": "Hello",
|
||||||
|
"conversation_id": conv_id,
|
||||||
|
"stream": False
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["conversation_id"] == conv_id
|
||||||
|
call_kwargs = mock_pipeline.achat.call_args.kwargs
|
||||||
|
assert call_kwargs["thread_id"] == conv_id
|
||||||
|
|
||||||
|
def test_chat_creates_new_conversation_id(self, client, auth_headers, mock_pipeline):
|
||||||
|
"""Test chat creates new conversation ID when not provided."""
|
||||||
|
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
|
||||||
|
response = client.post(
|
||||||
|
"/v1/chat",
|
||||||
|
headers=auth_headers,
|
||||||
|
json={
|
||||||
|
"input": "Hello",
|
||||||
|
"stream": False
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "conversation_id" in data
|
||||||
|
assert data["conversation_id"].startswith("c_")
|
||||||
|
|
||||||
|
def test_chat_streaming(self, client, auth_headers, mock_pipeline):
|
||||||
|
"""Test streaming chat request."""
|
||||||
|
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
|
||||||
|
response = client.post(
|
||||||
|
"/v1/chat",
|
||||||
|
headers=auth_headers,
|
||||||
|
json={
|
||||||
|
"input": "Hello",
|
||||||
|
"stream": True
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert "text/event-stream" in response.headers["content-type"]
|
||||||
|
|
||||||
|
# Read streaming response
|
||||||
|
lines = response.text.split("\n")
|
||||||
|
data_lines = [line for line in lines if line.startswith("data: ")]
|
||||||
|
|
||||||
|
# Should have delta events and a done event
|
||||||
|
assert len(data_lines) > 0
|
||||||
|
|
||||||
|
# Parse first delta event
|
||||||
|
first_data = json.loads(data_lines[0][6:]) # Remove "data: " prefix
|
||||||
|
assert first_data["type"] == "delta"
|
||||||
|
assert "conversation_id" in first_data
|
||||||
|
assert "delta" in first_data
|
||||||
|
|
||||||
|
# Check that achat was called with as_stream=True
|
||||||
|
mock_pipeline.achat.assert_called_once()
|
||||||
|
call_kwargs = mock_pipeline.achat.call_args.kwargs
|
||||||
|
assert call_kwargs["as_stream"] is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestMessageEndpoint:
|
||||||
|
"""Tests for the /v1/conversations/{conversation_id}/messages endpoint."""
|
||||||
|
|
||||||
|
def test_create_message_non_streaming(self, client, auth_headers, mock_pipeline):
|
||||||
|
"""Test creating a message (non-streaming)."""
|
||||||
|
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
|
||||||
|
conv_id = "c_test123"
|
||||||
|
response = client.post(
|
||||||
|
f"/v1/conversations/{conv_id}/messages",
|
||||||
|
headers=auth_headers,
|
||||||
|
json={
|
||||||
|
"role": "user",
|
||||||
|
"content": "Hello, how are you?",
|
||||||
|
"stream": False
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["conversation_id"] == conv_id
|
||||||
|
assert "message" in data
|
||||||
|
assert data["message"]["role"] == "assistant"
|
||||||
|
assert "content" in data["message"]
|
||||||
|
mock_pipeline.achat.assert_called_once()
|
||||||
|
call_kwargs = mock_pipeline.achat.call_args.kwargs
|
||||||
|
assert call_kwargs["inp"] == "Hello, how are you?"
|
||||||
|
assert call_kwargs["thread_id"] == conv_id
|
||||||
|
|
||||||
|
def test_create_message_streaming(self, client, auth_headers, mock_pipeline):
|
||||||
|
"""Test creating a message (streaming)."""
|
||||||
|
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
|
||||||
|
conv_id = "c_test123"
|
||||||
|
response = client.post(
|
||||||
|
f"/v1/conversations/{conv_id}/messages",
|
||||||
|
headers=auth_headers,
|
||||||
|
json={
|
||||||
|
"role": "user",
|
||||||
|
"content": "Hello",
|
||||||
|
"stream": True
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert "text/event-stream" in response.headers["content-type"]
|
||||||
|
|
||||||
|
# Verify achat was called with streaming
|
||||||
|
mock_pipeline.achat.assert_called_once()
|
||||||
|
call_kwargs = mock_pipeline.achat.call_args.kwargs
|
||||||
|
assert call_kwargs["as_stream"] is True
|
||||||
|
assert call_kwargs["thread_id"] == conv_id
|
||||||
|
|
||||||
|
def test_create_message_invalid_role(self, client, auth_headers):
|
||||||
|
"""Test that only 'user' role is accepted."""
|
||||||
|
conv_id = "c_test123"
|
||||||
|
response = client.post(
|
||||||
|
f"/v1/conversations/{conv_id}/messages",
|
||||||
|
headers=auth_headers,
|
||||||
|
json={
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "Hello",
|
||||||
|
"stream": False
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert "Only role='user' is supported" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestMemoryDeletion:
|
||||||
|
"""Tests for memory deletion endpoints."""
|
||||||
|
|
||||||
|
def test_delete_all_memory(self, client, auth_headers, mock_pipeline):
|
||||||
|
"""Test deleting all memory."""
|
||||||
|
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
|
||||||
|
response = client.delete(
|
||||||
|
"/v1/memory",
|
||||||
|
headers=auth_headers
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["status"] == "success"
|
||||||
|
assert data["scope"] == "all"
|
||||||
|
mock_pipeline.aclear_memory.assert_called_once()
|
||||||
|
|
||||||
|
def test_delete_all_memory_error_handling(self, client, auth_headers, mock_pipeline):
|
||||||
|
"""Test error handling when deleting all memory fails."""
|
||||||
|
mock_pipeline.aclear_memory = AsyncMock(side_effect=Exception("Memory error"))
|
||||||
|
|
||||||
|
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
|
||||||
|
response = client.delete(
|
||||||
|
"/v1/memory",
|
||||||
|
headers=auth_headers
|
||||||
|
)
|
||||||
|
assert response.status_code == 500
|
||||||
|
assert "Memory error" in response.json()["detail"]
|
||||||
|
|
||||||
|
def test_delete_conversation_memory(self, client, auth_headers, mock_pipeline):
|
||||||
|
"""Test deleting memory for a specific conversation."""
|
||||||
|
conv_id = "c_test123"
|
||||||
|
|
||||||
|
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
|
||||||
|
response = client.delete(
|
||||||
|
f"/v1/conversations/{conv_id}/memory",
|
||||||
|
headers=auth_headers
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["status"] == "success"
|
||||||
|
assert data["scope"] == "conversation"
|
||||||
|
assert data["conversation_id"] == conv_id
|
||||||
|
# Verify delete_thread was called
|
||||||
|
mock_pipeline.graph.memory.delete_thread.assert_called_once()
|
||||||
|
|
||||||
|
def test_delete_conversation_memory_with_device_id(self, client, auth_headers, mock_pipeline):
|
||||||
|
"""Test deleting memory for conversation with device ID format."""
|
||||||
|
conv_id = "c_test123_device456"
|
||||||
|
|
||||||
|
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
|
||||||
|
response = client.delete(
|
||||||
|
f"/v1/conversations/{conv_id}/memory",
|
||||||
|
headers=auth_headers
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
# Should normalize to base thread_id
|
||||||
|
mock_pipeline.graph.memory.delete_thread.assert_called_once_with("c_test123")
|
||||||
|
|
||||||
|
def test_delete_conversation_memory_no_memory_saver(self, client, auth_headers):
|
||||||
|
"""Test deleting conversation memory when MemorySaver is not available."""
|
||||||
|
# Create a mock pipeline without MemorySaver
|
||||||
|
mock_pipeline_no_mem = MagicMock()
|
||||||
|
mock_pipeline_no_mem.graph = MagicMock()
|
||||||
|
mock_pipeline_no_mem.graph.memory = None
|
||||||
|
|
||||||
|
with patch("fastapi_server.server_rest.pipeline", mock_pipeline_no_mem):
|
||||||
|
conv_id = "c_test123"
|
||||||
|
response = client.delete(
|
||||||
|
f"/v1/conversations/{conv_id}/memory",
|
||||||
|
headers=auth_headers
|
||||||
|
)
|
||||||
|
assert response.status_code == 501
|
||||||
|
data = response.json()
|
||||||
|
assert data["status"] == "unsupported"
|
||||||
|
assert "not supported" in data["message"].lower()
|
||||||
|
|
||||||
|
|
||||||
|
class TestEdgeCases:
|
||||||
|
"""Tests for edge cases and error handling."""
|
||||||
|
|
||||||
|
def test_chat_empty_input(self, client, auth_headers, mock_pipeline):
|
||||||
|
"""Test chat with empty input."""
|
||||||
|
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
|
||||||
|
response = client.post(
|
||||||
|
"/v1/chat",
|
||||||
|
headers=auth_headers,
|
||||||
|
json={
|
||||||
|
"input": "",
|
||||||
|
"stream": False
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# Should still process (validation would be in Pipeline)
|
||||||
|
assert response.status_code in [200, 400]
|
||||||
|
|
||||||
|
def test_chat_missing_input(self, client, auth_headers):
|
||||||
|
"""Test chat with missing input field."""
|
||||||
|
response = client.post(
|
||||||
|
"/v1/chat",
|
||||||
|
headers=auth_headers,
|
||||||
|
json={
|
||||||
|
"stream": False
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert response.status_code == 422 # Validation error
|
||||||
|
|
||||||
|
def test_message_missing_content(self, client, auth_headers):
|
||||||
|
"""Test message creation with missing content."""
|
||||||
|
conv_id = "c_test123"
|
||||||
|
response = client.post(
|
||||||
|
f"/v1/conversations/{conv_id}/messages",
|
||||||
|
headers=auth_headers,
|
||||||
|
json={
|
||||||
|
"role": "user",
|
||||||
|
"stream": False
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert response.status_code == 422 # Validation error
|
||||||
|
|
||||||
|
def test_invalid_conversation_id_format(self, client, auth_headers, mock_pipeline):
|
||||||
|
"""Test that various conversation ID formats are handled."""
|
||||||
|
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
|
||||||
|
# Test with underscore (device_id format)
|
||||||
|
conv_id = "thread_123_device_456"
|
||||||
|
response = client.post(
|
||||||
|
f"/v1/conversations/{conv_id}/messages",
|
||||||
|
headers=auth_headers,
|
||||||
|
json={
|
||||||
|
"role": "user",
|
||||||
|
"content": "Hello",
|
||||||
|
"stream": False
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# Should normalize thread_id (take first part before _)
|
||||||
|
assert response.status_code == 200
|
||||||
|
call_kwargs = mock_pipeline.achat.call_args.kwargs
|
||||||
|
# The thread_id normalization happens in _normalize_thread_id
|
||||||
|
# but achat receives the full conversation_id
|
||||||
|
assert call_kwargs["thread_id"] == conv_id
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__, "-v"])
|
||||||
|
|
||||||
Reference in New Issue
Block a user