remove external test
This commit is contained in:
408
test_rest_api.py
408
test_rest_api.py
@@ -1,408 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Integration tests for the REST API server (server_rest.py)
|
|
||||||
|
|
||||||
This script tests the REST API endpoints using FastAPI's TestClient.
|
|
||||||
|
|
||||||
To run:
|
|
||||||
conda activate lang
|
|
||||||
python test_rest_api.py
|
|
||||||
|
|
||||||
Or with pytest:
|
|
||||||
conda activate lang
|
|
||||||
pytest test_rest_api.py -v
|
|
||||||
|
|
||||||
Tests cover:
|
|
||||||
- Health check endpoints
|
|
||||||
- API key authentication
|
|
||||||
- Conversation creation
|
|
||||||
- Chat endpoint (streaming and non-streaming)
|
|
||||||
- Message creation (streaming and non-streaming)
|
|
||||||
- Memory deletion (global and per-conversation)
|
|
||||||
- Error handling
|
|
||||||
|
|
||||||
Requirements:
|
|
||||||
- pytest (for structured testing)
|
|
||||||
- Or run directly as a script
|
|
||||||
"""
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import json
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
# Set up test environment before importing the server
|
|
||||||
os.environ["FAST_AUTH_KEYS"] = "test-key-1,test-key-2,test-key-3"
|
|
||||||
|
|
||||||
# Add project root to path
|
|
||||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
|
||||||
|
|
||||||
try:
|
|
||||||
from fastapi.testclient import TestClient
|
|
||||||
from langgraph.checkpoint.memory import MemorySaver
|
|
||||||
HAS_TEST_CLIENT = True
|
|
||||||
except ImportError:
|
|
||||||
HAS_TEST_CLIENT = False
|
|
||||||
print("Warning: fastapi.testclient not available. Install with: pip install pytest httpx")
|
|
||||||
print("Falling back to basic import test only.")
|
|
||||||
|
|
||||||
from fastapi_server.server_rest import app
|
|
||||||
|
|
||||||
|
|
||||||
def create_mock_pipeline():
|
|
||||||
"""Create a mock Pipeline instance for testing."""
|
|
||||||
pipeline = MagicMock()
|
|
||||||
|
|
||||||
# Mock async function that returns async generator for streaming or string for non-streaming
|
|
||||||
async def mock_achat(inp, as_stream=False, thread_id="test"):
|
|
||||||
if as_stream:
|
|
||||||
# Return async generator for streaming
|
|
||||||
async def gen():
|
|
||||||
chunks = ["Hello", " ", "world", "!"]
|
|
||||||
for chunk in chunks:
|
|
||||||
yield chunk
|
|
||||||
return gen()
|
|
||||||
else:
|
|
||||||
# Return string for non-streaming
|
|
||||||
return "Hello world!"
|
|
||||||
|
|
||||||
async def mock_aclear_memory():
|
|
||||||
return None
|
|
||||||
|
|
||||||
pipeline.achat = AsyncMock(side_effect=mock_achat)
|
|
||||||
pipeline.aclear_memory = AsyncMock(return_value=None)
|
|
||||||
|
|
||||||
# Mock graph with memory
|
|
||||||
mock_graph = MagicMock()
|
|
||||||
mock_memory = MagicMock(spec=MemorySaver)
|
|
||||||
mock_memory.delete_thread = MagicMock()
|
|
||||||
mock_graph.memory = mock_memory
|
|
||||||
pipeline.graph = mock_graph
|
|
||||||
|
|
||||||
return pipeline
|
|
||||||
|
|
||||||
|
|
||||||
def test_health_endpoints():
|
|
||||||
"""Test health check endpoints."""
|
|
||||||
print("\n=== Testing Health Endpoints ===")
|
|
||||||
|
|
||||||
with patch("fastapi_server.server_rest.pipeline", create_mock_pipeline()):
|
|
||||||
client = TestClient(app)
|
|
||||||
|
|
||||||
# Test root endpoint
|
|
||||||
response = client.get("/")
|
|
||||||
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
|
|
||||||
data = response.json()
|
|
||||||
assert "message" in data, "Root endpoint should return message"
|
|
||||||
assert "endpoints" in data, "Root endpoint should return endpoints list"
|
|
||||||
print("✓ Root endpoint works")
|
|
||||||
|
|
||||||
# Test health endpoint
|
|
||||||
response = client.get("/health")
|
|
||||||
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
|
|
||||||
data = response.json()
|
|
||||||
assert data["status"] == "healthy", "Health endpoint should return healthy status"
|
|
||||||
print("✓ Health endpoint works")
|
|
||||||
|
|
||||||
|
|
||||||
def test_authentication():
|
|
||||||
"""Test API key authentication."""
|
|
||||||
print("\n=== Testing Authentication ===")
|
|
||||||
|
|
||||||
with patch("fastapi_server.server_rest.pipeline", create_mock_pipeline()):
|
|
||||||
client = TestClient(app)
|
|
||||||
|
|
||||||
# Test missing auth header
|
|
||||||
response = client.post("/v1/conversations")
|
|
||||||
assert response.status_code == 401, f"Expected 401 for missing auth, got {response.status_code}"
|
|
||||||
print("✓ Missing auth header returns 401")
|
|
||||||
|
|
||||||
# Test invalid API key
|
|
||||||
response = client.post(
|
|
||||||
"/v1/conversations",
|
|
||||||
headers={"Authorization": "Bearer invalid-key"}
|
|
||||||
)
|
|
||||||
assert response.status_code == 401, f"Expected 401 for invalid key, got {response.status_code}"
|
|
||||||
print("✓ Invalid API key returns 401")
|
|
||||||
|
|
||||||
# Test valid API key
|
|
||||||
response = client.post(
|
|
||||||
"/v1/conversations",
|
|
||||||
headers={"Authorization": "Bearer test-key-1"}
|
|
||||||
)
|
|
||||||
assert response.status_code == 200, f"Expected 200 for valid key, got {response.status_code}"
|
|
||||||
print("✓ Valid API key works")
|
|
||||||
|
|
||||||
# Test API key without Bearer prefix
|
|
||||||
response = client.post(
|
|
||||||
"/v1/conversations",
|
|
||||||
headers={"Authorization": "test-key-1"}
|
|
||||||
)
|
|
||||||
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
|
|
||||||
print("✓ API key without Bearer prefix works")
|
|
||||||
|
|
||||||
|
|
||||||
def test_conversation_creation():
|
|
||||||
"""Test conversation creation."""
|
|
||||||
print("\n=== Testing Conversation Creation ===")
|
|
||||||
|
|
||||||
with patch("fastapi_server.server_rest.pipeline", create_mock_pipeline()):
|
|
||||||
client = TestClient(app)
|
|
||||||
auth_headers = {"Authorization": "Bearer test-key-1"}
|
|
||||||
|
|
||||||
response = client.post("/v1/conversations", headers=auth_headers)
|
|
||||||
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
|
|
||||||
data = response.json()
|
|
||||||
assert "id" in data, "Response should contain id"
|
|
||||||
assert "created_at" in data, "Response should contain created_at"
|
|
||||||
assert data["id"].startswith("c_"), "Conversation ID should start with 'c_'"
|
|
||||||
print(f"✓ Created conversation: {data['id']}")
|
|
||||||
|
|
||||||
|
|
||||||
def test_chat_endpoint():
|
|
||||||
"""Test chat endpoint."""
|
|
||||||
print("\n=== Testing Chat Endpoint ===")
|
|
||||||
|
|
||||||
mock_pipeline = create_mock_pipeline()
|
|
||||||
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
|
|
||||||
client = TestClient(app)
|
|
||||||
auth_headers = {"Authorization": "Bearer test-key-1"}
|
|
||||||
|
|
||||||
# Test non-streaming chat
|
|
||||||
response = client.post(
|
|
||||||
"/v1/chat",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={
|
|
||||||
"input": "Hello, how are you?",
|
|
||||||
"stream": False
|
|
||||||
}
|
|
||||||
)
|
|
||||||
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
|
|
||||||
data = response.json()
|
|
||||||
assert "conversation_id" in data, "Response should contain conversation_id"
|
|
||||||
assert "output" in data, "Response should contain output"
|
|
||||||
assert data["output"] == "Hello world!", "Output should match expected"
|
|
||||||
print(f"✓ Non-streaming chat works: {data['conversation_id']}")
|
|
||||||
|
|
||||||
# Test chat with existing conversation_id
|
|
||||||
conv_id = "c_test123"
|
|
||||||
response = client.post(
|
|
||||||
"/v1/chat",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={
|
|
||||||
"input": "Hello",
|
|
||||||
"conversation_id": conv_id,
|
|
||||||
"stream": False
|
|
||||||
}
|
|
||||||
)
|
|
||||||
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
|
|
||||||
data = response.json()
|
|
||||||
assert data["conversation_id"] == conv_id, "Should use provided conversation_id"
|
|
||||||
print(f"✓ Chat with existing conversation_id works")
|
|
||||||
|
|
||||||
# Test streaming chat
|
|
||||||
response = client.post(
|
|
||||||
"/v1/chat",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={
|
|
||||||
"input": "Hello",
|
|
||||||
"stream": True
|
|
||||||
}
|
|
||||||
)
|
|
||||||
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
|
|
||||||
assert "text/event-stream" in response.headers["content-type"], "Should return event-stream"
|
|
||||||
|
|
||||||
# Parse streaming response
|
|
||||||
lines = response.text.split("\n")
|
|
||||||
data_lines = [line for line in lines if line.startswith("data: ")]
|
|
||||||
assert len(data_lines) > 0, "Should have streaming data"
|
|
||||||
|
|
||||||
# Parse first delta event
|
|
||||||
first_data = json.loads(data_lines[0][6:]) # Remove "data: " prefix
|
|
||||||
assert first_data["type"] == "delta", "First event should be delta"
|
|
||||||
assert "conversation_id" in first_data, "Delta should contain conversation_id"
|
|
||||||
assert "delta" in first_data, "Delta should contain delta field"
|
|
||||||
print("✓ Streaming chat works")
|
|
||||||
|
|
||||||
|
|
||||||
def test_message_endpoint():
|
|
||||||
"""Test message creation endpoint."""
|
|
||||||
print("\n=== Testing Message Endpoint ===")
|
|
||||||
|
|
||||||
mock_pipeline = create_mock_pipeline()
|
|
||||||
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
|
|
||||||
client = TestClient(app)
|
|
||||||
auth_headers = {"Authorization": "Bearer test-key-1"}
|
|
||||||
conv_id = "c_test123"
|
|
||||||
|
|
||||||
# Test non-streaming message
|
|
||||||
response = client.post(
|
|
||||||
f"/v1/conversations/{conv_id}/messages",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={
|
|
||||||
"role": "user",
|
|
||||||
"content": "Hello, how are you?",
|
|
||||||
"stream": False
|
|
||||||
}
|
|
||||||
)
|
|
||||||
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
|
|
||||||
data = response.json()
|
|
||||||
assert data["conversation_id"] == conv_id, "Should return correct conversation_id"
|
|
||||||
assert "message" in data, "Response should contain message"
|
|
||||||
assert data["message"]["role"] == "assistant", "Message role should be assistant"
|
|
||||||
assert "content" in data["message"], "Message should contain content"
|
|
||||||
print("✓ Non-streaming message creation works")
|
|
||||||
|
|
||||||
# Test streaming message
|
|
||||||
response = client.post(
|
|
||||||
f"/v1/conversations/{conv_id}/messages",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={
|
|
||||||
"role": "user",
|
|
||||||
"content": "Hello",
|
|
||||||
"stream": True
|
|
||||||
}
|
|
||||||
)
|
|
||||||
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
|
|
||||||
assert "text/event-stream" in response.headers["content-type"], "Should return event-stream"
|
|
||||||
print("✓ Streaming message creation works")
|
|
||||||
|
|
||||||
# Test invalid role
|
|
||||||
response = client.post(
|
|
||||||
f"/v1/conversations/{conv_id}/messages",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={
|
|
||||||
"role": "assistant",
|
|
||||||
"content": "Hello",
|
|
||||||
"stream": False
|
|
||||||
}
|
|
||||||
)
|
|
||||||
assert response.status_code == 400, f"Expected 400 for invalid role, got {response.status_code}"
|
|
||||||
assert "Only role='user' is supported" in response.json()["detail"], "Should reject non-user role"
|
|
||||||
print("✓ Invalid role rejection works")
|
|
||||||
|
|
||||||
|
|
||||||
def test_memory_deletion():
|
|
||||||
"""Test memory deletion endpoints."""
|
|
||||||
print("\n=== Testing Memory Deletion ===")
|
|
||||||
|
|
||||||
mock_pipeline = create_mock_pipeline()
|
|
||||||
with patch("fastapi_server.server_rest.pipeline", mock_pipeline):
|
|
||||||
client = TestClient(app)
|
|
||||||
auth_headers = {"Authorization": "Bearer test-key-1"}
|
|
||||||
|
|
||||||
# Test delete all memory
|
|
||||||
response = client.delete("/v1/memory", headers=auth_headers)
|
|
||||||
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
|
|
||||||
data = response.json()
|
|
||||||
assert data["status"] == "success", "Should return success status"
|
|
||||||
assert data["scope"] == "all", "Should indicate all scope"
|
|
||||||
mock_pipeline.aclear_memory.assert_called_once()
|
|
||||||
print("✓ Delete all memory works")
|
|
||||||
|
|
||||||
# Test delete conversation memory
|
|
||||||
conv_id = "c_test123"
|
|
||||||
response = client.delete(
|
|
||||||
f"/v1/conversations/{conv_id}/memory",
|
|
||||||
headers=auth_headers
|
|
||||||
)
|
|
||||||
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
|
|
||||||
data = response.json()
|
|
||||||
assert data["status"] == "success", "Should return success status"
|
|
||||||
assert data["scope"] == "conversation", "Should indicate conversation scope"
|
|
||||||
assert data["conversation_id"] == conv_id, "Should return conversation_id"
|
|
||||||
mock_pipeline.graph.memory.delete_thread.assert_called()
|
|
||||||
print("✓ Delete conversation memory works")
|
|
||||||
|
|
||||||
# Test delete conversation memory with device_id format
|
|
||||||
conv_id_with_device = "c_test123_device456"
|
|
||||||
response = client.delete(
|
|
||||||
f"/v1/conversations/{conv_id_with_device}/memory",
|
|
||||||
headers=auth_headers
|
|
||||||
)
|
|
||||||
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
|
|
||||||
# Should normalize to base thread_id
|
|
||||||
mock_pipeline.graph.memory.delete_thread.assert_called_with("c_test123")
|
|
||||||
print("✓ Delete conversation memory with device_id format works")
|
|
||||||
|
|
||||||
# Test delete conversation memory when MemorySaver is not available
|
|
||||||
mock_pipeline_no_mem = create_mock_pipeline()
|
|
||||||
mock_pipeline_no_mem.graph.memory = None
|
|
||||||
|
|
||||||
with patch("fastapi_server.server_rest.pipeline", mock_pipeline_no_mem):
|
|
||||||
response = client.delete(
|
|
||||||
f"/v1/conversations/{conv_id}/memory",
|
|
||||||
headers=auth_headers
|
|
||||||
)
|
|
||||||
assert response.status_code == 501, f"Expected 501, got {response.status_code}"
|
|
||||||
data = response.json()
|
|
||||||
assert data["status"] == "unsupported", "Should return unsupported status"
|
|
||||||
print("✓ Unsupported memory deletion handling works")
|
|
||||||
|
|
||||||
|
|
||||||
def test_error_handling():
|
|
||||||
"""Test error handling."""
|
|
||||||
print("\n=== Testing Error Handling ===")
|
|
||||||
|
|
||||||
with patch("fastapi_server.server_rest.pipeline", create_mock_pipeline()):
|
|
||||||
client = TestClient(app)
|
|
||||||
auth_headers = {"Authorization": "Bearer test-key-1"}
|
|
||||||
|
|
||||||
# Test chat with missing input
|
|
||||||
response = client.post(
|
|
||||||
"/v1/chat",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={"stream": False}
|
|
||||||
)
|
|
||||||
assert response.status_code == 422, f"Expected 422 for validation error, got {response.status_code}"
|
|
||||||
print("✓ Missing input validation works")
|
|
||||||
|
|
||||||
# Test message with missing content
|
|
||||||
response = client.post(
|
|
||||||
"/v1/conversations/c_test123/messages",
|
|
||||||
headers=auth_headers,
|
|
||||||
json={"role": "user", "stream": False}
|
|
||||||
)
|
|
||||||
assert response.status_code == 422, f"Expected 422 for validation error, got {response.status_code}"
|
|
||||||
print("✓ Missing content validation works")
|
|
||||||
|
|
||||||
|
|
||||||
def run_all_tests():
|
|
||||||
"""Run all tests."""
|
|
||||||
if not HAS_TEST_CLIENT:
|
|
||||||
print("Cannot run tests: fastapi.testclient not available")
|
|
||||||
print("Install with: pip install pytest httpx")
|
|
||||||
return False
|
|
||||||
|
|
||||||
print("=" * 60)
|
|
||||||
print("Running REST API Tests")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
try:
|
|
||||||
test_health_endpoints()
|
|
||||||
test_authentication()
|
|
||||||
test_conversation_creation()
|
|
||||||
test_chat_endpoint()
|
|
||||||
test_message_endpoint()
|
|
||||||
test_memory_deletion()
|
|
||||||
test_error_handling()
|
|
||||||
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("✓ All tests passed!")
|
|
||||||
print("=" * 60)
|
|
||||||
return True
|
|
||||||
except AssertionError as e:
|
|
||||||
print(f"\n✗ Test failed: {e}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
return False
|
|
||||||
except Exception as e:
|
|
||||||
print(f"\n✗ Unexpected error: {e}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
success = run_all_tests()
|
|
||||||
sys.exit(0 if success else 1)
|
|
||||||
|
|
||||||
Reference in New Issue
Block a user