rest_api tests

This commit is contained in:
2026-01-28 15:03:36 +08:00
parent 122016efba
commit b7fecae3f4

484
tests/test_server_rest.py Normal file
View File

@@ -0,0 +1,484 @@
#!/usr/bin/env python3
"""
Tests for the REST API server (server_rest.py)
This test suite covers:
- Health check endpoints (GET /, GET /health)
- API key authentication (valid/invalid keys, Bearer format)
- Conversation creation (POST /v1/conversations)
- Chat endpoint (POST /v1/chat) - streaming and non-streaming
- Message creation (POST /v1/conversations/{id}/messages) - streaming and non-streaming
- Memory deletion (DELETE /v1/memory, DELETE /v1/conversations/{id}/memory)
- Edge cases and error handling
Run with: pytest tests/test_server_rest.py -v
"""
import os
import pytest
import json
from unittest.mock import AsyncMock, MagicMock, patch
from fastapi.testclient import TestClient
from langgraph.checkpoint.memory import MemorySaver
# Set up test environment before importing the server
os.environ["FAST_AUTH_KEYS"] = "test-key-1,test-key-2,test-key-3"
# Import after setting environment
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from fastapi_server.server_rest import app
@pytest.fixture
def mock_pipeline():
"""Create a mock Pipeline instance."""
pipeline = MagicMock()
# Mock async generator for streaming
async def mock_achat_stream(inp, as_stream=True, thread_id="test"):
chunks = ["Hello", " ", "world", "!"]
for chunk in chunks:
yield chunk
# Mock async function that returns async generator for streaming or string for non-streaming
async def mock_achat(inp, as_stream=False, thread_id="test"):
if as_stream:
# Return async generator for streaming
async def gen():
chunks = ["Hello", " ", "world", "!"]
for chunk in chunks:
yield chunk
return gen()
else:
# Return string for non-streaming
return "Hello world!"
async def mock_aclear_memory():
return None
pipeline.achat = AsyncMock(side_effect=mock_achat)
pipeline.aclear_memory = AsyncMock(return_value=None)
# Mock graph with memory
mock_graph = MagicMock()
mock_memory = MagicMock(spec=MemorySaver)
mock_memory.delete_thread = MagicMock()
mock_graph.memory = mock_memory
pipeline.graph = mock_graph
return pipeline
@pytest.fixture
def client(mock_pipeline):
"""Create a test client with mocked pipeline."""
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
with TestClient(app) as test_client:
yield test_client
@pytest.fixture
def auth_headers():
"""Return valid authentication headers."""
return {"Authorization": "Bearer test-key-1"}
@pytest.fixture
def invalid_auth_headers():
"""Return invalid authentication headers."""
return {"Authorization": "Bearer invalid-key"}
class TestHealthCheck:
"""Tests for health check endpoint."""
def test_root_endpoint(self, client):
"""Test root endpoint returns API information."""
response = client.get("/")
assert response.status_code == 200
data = response.json()
assert "message" in data
assert "endpoints" in data
assert isinstance(data["endpoints"], list)
def test_health_endpoint(self, client):
"""Test health endpoint returns healthy status."""
response = client.get("/health")
assert response.status_code == 200
data = response.json()
assert data["status"] == "healthy"
class TestAuthentication:
"""Tests for API key authentication."""
def test_missing_auth_header(self, client):
"""Test that missing auth header returns 401."""
response = client.post("/v1/conversations")
assert response.status_code == 401
def test_invalid_api_key(self, client, invalid_auth_headers):
"""Test that invalid API key returns 401."""
response = client.post(
"/v1/conversations",
headers=invalid_auth_headers
)
assert response.status_code == 401
assert "Invalid API key" in response.json()["detail"]
def test_valid_api_key_bearer_format(self, client, auth_headers):
"""Test that valid API key with Bearer prefix works."""
response = client.post(
"/v1/conversations",
headers=auth_headers
)
assert response.status_code == 200
def test_valid_api_key_without_bearer(self, client):
"""Test that valid API key without Bearer prefix works."""
response = client.post(
"/v1/conversations",
headers={"Authorization": "test-key-1"}
)
assert response.status_code == 200
def test_multiple_valid_keys(self, client):
"""Test that any of the configured keys work."""
for key in ["test-key-1", "test-key-2", "test-key-3"]:
response = client.post(
"/v1/conversations",
headers={"Authorization": f"Bearer {key}"}
)
assert response.status_code == 200
class TestConversationCreation:
"""Tests for conversation creation endpoint."""
def test_create_conversation(self, client, auth_headers):
"""Test creating a new conversation."""
response = client.post(
"/v1/conversations",
headers=auth_headers
)
assert response.status_code == 200
data = response.json()
assert "id" in data
assert "created_at" in data
assert data["id"].startswith("c_")
assert len(data["id"]) > 2
def test_conversation_id_format(self, client, auth_headers):
"""Test that conversation IDs follow expected format."""
response = client.post(
"/v1/conversations",
headers=auth_headers
)
data = response.json()
conv_id = data["id"]
# Should start with "c_" and have hex characters
assert conv_id.startswith("c_")
assert len(conv_id) > 2
class TestChatEndpoint:
"""Tests for the /v1/chat endpoint."""
def test_chat_non_streaming(self, client, auth_headers, mock_pipeline):
"""Test non-streaming chat request."""
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
response = client.post(
"/v1/chat",
headers=auth_headers,
json={
"input": "Hello, how are you?",
"stream": False
}
)
assert response.status_code == 200
data = response.json()
assert "conversation_id" in data
assert "output" in data
assert data["output"] == "Hello world!"
mock_pipeline.achat.assert_called_once()
call_kwargs = mock_pipeline.achat.call_args.kwargs
assert call_kwargs["inp"] == "Hello, how are you?"
assert call_kwargs["as_stream"] is False
def test_chat_with_existing_conversation_id(self, client, auth_headers, mock_pipeline):
"""Test chat with existing conversation ID."""
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
conv_id = "c_test123"
response = client.post(
"/v1/chat",
headers=auth_headers,
json={
"input": "Hello",
"conversation_id": conv_id,
"stream": False
}
)
assert response.status_code == 200
data = response.json()
assert data["conversation_id"] == conv_id
call_kwargs = mock_pipeline.achat.call_args.kwargs
assert call_kwargs["thread_id"] == conv_id
def test_chat_creates_new_conversation_id(self, client, auth_headers, mock_pipeline):
"""Test chat creates new conversation ID when not provided."""
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
response = client.post(
"/v1/chat",
headers=auth_headers,
json={
"input": "Hello",
"stream": False
}
)
assert response.status_code == 200
data = response.json()
assert "conversation_id" in data
assert data["conversation_id"].startswith("c_")
def test_chat_streaming(self, client, auth_headers, mock_pipeline):
"""Test streaming chat request."""
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
response = client.post(
"/v1/chat",
headers=auth_headers,
json={
"input": "Hello",
"stream": True
}
)
assert response.status_code == 200
assert "text/event-stream" in response.headers["content-type"]
# Read streaming response
lines = response.text.split("\n")
data_lines = [line for line in lines if line.startswith("data: ")]
# Should have delta events and a done event
assert len(data_lines) > 0
# Parse first delta event
first_data = json.loads(data_lines[0][6:]) # Remove "data: " prefix
assert first_data["type"] == "delta"
assert "conversation_id" in first_data
assert "delta" in first_data
# Check that achat was called with as_stream=True
mock_pipeline.achat.assert_called_once()
call_kwargs = mock_pipeline.achat.call_args.kwargs
assert call_kwargs["as_stream"] is True
class TestMessageEndpoint:
"""Tests for the /v1/conversations/{conversation_id}/messages endpoint."""
def test_create_message_non_streaming(self, client, auth_headers, mock_pipeline):
"""Test creating a message (non-streaming)."""
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
conv_id = "c_test123"
response = client.post(
f"/v1/conversations/{conv_id}/messages",
headers=auth_headers,
json={
"role": "user",
"content": "Hello, how are you?",
"stream": False
}
)
assert response.status_code == 200
data = response.json()
assert data["conversation_id"] == conv_id
assert "message" in data
assert data["message"]["role"] == "assistant"
assert "content" in data["message"]
mock_pipeline.achat.assert_called_once()
call_kwargs = mock_pipeline.achat.call_args.kwargs
assert call_kwargs["inp"] == "Hello, how are you?"
assert call_kwargs["thread_id"] == conv_id
def test_create_message_streaming(self, client, auth_headers, mock_pipeline):
"""Test creating a message (streaming)."""
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
conv_id = "c_test123"
response = client.post(
f"/v1/conversations/{conv_id}/messages",
headers=auth_headers,
json={
"role": "user",
"content": "Hello",
"stream": True
}
)
assert response.status_code == 200
assert "text/event-stream" in response.headers["content-type"]
# Verify achat was called with streaming
mock_pipeline.achat.assert_called_once()
call_kwargs = mock_pipeline.achat.call_args.kwargs
assert call_kwargs["as_stream"] is True
assert call_kwargs["thread_id"] == conv_id
def test_create_message_invalid_role(self, client, auth_headers):
"""Test that only 'user' role is accepted."""
conv_id = "c_test123"
response = client.post(
f"/v1/conversations/{conv_id}/messages",
headers=auth_headers,
json={
"role": "assistant",
"content": "Hello",
"stream": False
}
)
assert response.status_code == 400
assert "Only role='user' is supported" in response.json()["detail"]
class TestMemoryDeletion:
"""Tests for memory deletion endpoints."""
def test_delete_all_memory(self, client, auth_headers, mock_pipeline):
"""Test deleting all memory."""
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
response = client.delete(
"/v1/memory",
headers=auth_headers
)
assert response.status_code == 200
data = response.json()
assert data["status"] == "success"
assert data["scope"] == "all"
mock_pipeline.aclear_memory.assert_called_once()
def test_delete_all_memory_error_handling(self, client, auth_headers, mock_pipeline):
"""Test error handling when deleting all memory fails."""
mock_pipeline.aclear_memory = AsyncMock(side_effect=Exception("Memory error"))
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
response = client.delete(
"/v1/memory",
headers=auth_headers
)
assert response.status_code == 500
assert "Memory error" in response.json()["detail"]
def test_delete_conversation_memory(self, client, auth_headers, mock_pipeline):
"""Test deleting memory for a specific conversation."""
conv_id = "c_test123"
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
response = client.delete(
f"/v1/conversations/{conv_id}/memory",
headers=auth_headers
)
assert response.status_code == 200
data = response.json()
assert data["status"] == "success"
assert data["scope"] == "conversation"
assert data["conversation_id"] == conv_id
# Verify delete_thread was called
mock_pipeline.graph.memory.delete_thread.assert_called_once()
def test_delete_conversation_memory_with_device_id(self, client, auth_headers, mock_pipeline):
"""Test deleting memory for conversation with device ID format."""
conv_id = "c_test123_device456"
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
response = client.delete(
f"/v1/conversations/{conv_id}/memory",
headers=auth_headers
)
assert response.status_code == 200
# Should normalize to base thread_id
mock_pipeline.graph.memory.delete_thread.assert_called_once_with("c_test123")
def test_delete_conversation_memory_no_memory_saver(self, client, auth_headers):
"""Test deleting conversation memory when MemorySaver is not available."""
# Create a mock pipeline without MemorySaver
mock_pipeline_no_mem = MagicMock()
mock_pipeline_no_mem.graph = MagicMock()
mock_pipeline_no_mem.graph.memory = None
with patch("fastapi_server.server_rest.pipeline", mock_pipeline_no_mem):
conv_id = "c_test123"
response = client.delete(
f"/v1/conversations/{conv_id}/memory",
headers=auth_headers
)
assert response.status_code == 501
data = response.json()
assert data["status"] == "unsupported"
assert "not supported" in data["message"].lower()
class TestEdgeCases:
"""Tests for edge cases and error handling."""
def test_chat_empty_input(self, client, auth_headers, mock_pipeline):
"""Test chat with empty input."""
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
response = client.post(
"/v1/chat",
headers=auth_headers,
json={
"input": "",
"stream": False
}
)
# Should still process (validation would be in Pipeline)
assert response.status_code in [200, 400]
def test_chat_missing_input(self, client, auth_headers):
"""Test chat with missing input field."""
response = client.post(
"/v1/chat",
headers=auth_headers,
json={
"stream": False
}
)
assert response.status_code == 422 # Validation error
def test_message_missing_content(self, client, auth_headers):
"""Test message creation with missing content."""
conv_id = "c_test123"
response = client.post(
f"/v1/conversations/{conv_id}/messages",
headers=auth_headers,
json={
"role": "user",
"stream": False
}
)
assert response.status_code == 422 # Validation error
def test_invalid_conversation_id_format(self, client, auth_headers, mock_pipeline):
"""Test that various conversation ID formats are handled."""
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
# Test with underscore (device_id format)
conv_id = "thread_123_device_456"
response = client.post(
f"/v1/conversations/{conv_id}/messages",
headers=auth_headers,
json={
"role": "user",
"content": "Hello",
"stream": False
}
)
# Should normalize thread_id (take first part before _)
assert response.status_code == 200
call_kwargs = mock_pipeline.achat.call_args.kwargs
# The thread_id normalization happens in _normalize_thread_id
# but achat receives the full conversation_id
assert call_kwargs["thread_id"] == conv_id
if __name__ == "__main__":
pytest.main([__file__, "-v"])