update tests

This commit is contained in:
2026-03-06 13:19:26 +08:00
parent fc9f0f929d
commit dd842fca42
2 changed files with 190 additions and 1 deletion

View File

@@ -1,13 +1,18 @@
import json
import os
from pathlib import Path
from datetime import datetime, timedelta, timezone
import importlib
from fastapi.testclient import TestClient
# Provide a DB connection string before importing the app module, which
# reads CONN_STR at import time.
os.environ.setdefault("CONN_STR", "postgresql://dummy:dummy@localhost/dummy")

# Import the app module whether the package is installed as
# ``lang_agent.fastapi_server`` or laid out as a bare ``fastapi_server``.
# NOTE: the previous unconditional ``import fastapi_server.front_apis``
# executed *before* this fallback and raised ModuleNotFoundError in the
# ``lang_agent.*`` layout, defeating the try/except entirely.
try:
    front_apis = importlib.import_module("lang_agent.fastapi_server.front_apis")
except ModuleNotFoundError:
    front_apis = importlib.import_module("fastapi_server.front_apis")
def _fake_build_fn(
@@ -36,6 +41,98 @@ def _fake_build_fn(
return {"path": str(out_file)}
class _FakeCursor:
def __init__(self, rows):
self._rows = rows
self._result = []
self._last_sql = ""
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def execute(self, sql, params=None):
self._last_sql = sql
query = " ".join(sql.split()).lower()
params = params or ()
if "group by conversation_id, pipeline_id" in query:
pipeline_id = params[0]
limit = int(params[1])
grouped = {}
for row in self._rows:
if row["pipeline_id"] != pipeline_id:
continue
conv_id = row["conversation_id"]
if conv_id not in grouped:
grouped[conv_id] = {
"conversation_id": conv_id,
"pipeline_id": row["pipeline_id"],
"message_count": 0,
"last_updated": row["created_at"],
}
grouped[conv_id]["message_count"] += 1
if row["created_at"] > grouped[conv_id]["last_updated"]:
grouped[conv_id]["last_updated"] = row["created_at"]
values = sorted(grouped.values(), key=lambda x: x["last_updated"], reverse=True)
self._result = values[:limit]
return
if "select 1 from messages" in query:
pipeline_id, conversation_id = params
found = any(
row["pipeline_id"] == pipeline_id
and row["conversation_id"] == conversation_id
for row in self._rows
)
self._result = [{"exists": 1}] if found else []
return
if "order by sequence_number asc" in query:
pipeline_id, conversation_id = params
self._result = sorted(
[
{
"message_type": row["message_type"],
"content": row["content"],
"sequence_number": row["sequence_number"],
"created_at": row["created_at"],
}
for row in self._rows
if row["pipeline_id"] == pipeline_id
and row["conversation_id"] == conversation_id
],
key=lambda x: x["sequence_number"],
)
return
raise AssertionError(f"Unsupported SQL in test fake: {self._last_sql}")
def fetchall(self):
return self._result
def fetchone(self):
if not self._result:
return None
return self._result[0]
class _FakeConnection:
def __init__(self, rows):
self._rows = rows
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def cursor(self, row_factory=None):
return _FakeCursor(self._rows)
def test_registry_route_lifecycle(monkeypatch, tmp_path):
registry_path = tmp_path / "pipeline_registry.json"
monkeypatch.setattr(front_apis, "PIPELINE_REGISTRY_PATH", str(registry_path))
@@ -137,3 +234,94 @@ def test_registry_api_key_policy_lifecycle(monkeypatch, tmp_path):
assert delete_data["api_key"] == "sk-test-key"
assert delete_data["status"] == "deleted"
assert delete_data["reload_required"] is False
def test_pipeline_conversation_routes(monkeypatch):
    """List a pipeline's conversations, then fetch one conversation's messages."""
    now = datetime.now(timezone.utc)

    def _row(conv, pipeline, mtype, content, seq, age_s):
        # Build one fake messages-table row, ``age_s`` seconds in the past.
        return {
            "conversation_id": conv,
            "pipeline_id": pipeline,
            "message_type": mtype,
            "content": content,
            "sequence_number": seq,
            "created_at": now - timedelta(seconds=age_s),
        }

    rows = [
        _row("agent-a:conv-1", "agent-a", "human", "hello", 1, 30),
        _row("agent-a:conv-1", "agent-a", "ai", "hi there", 2, 20),
        _row("agent-a:conv-2", "agent-a", "human", "second thread", 1, 10),
        _row("agent-b:conv-9", "agent-b", "human", "other pipeline", 1, 5),
    ]

    monkeypatch.setenv("CONN_STR", "postgresql://dummy:dummy@localhost/dummy")
    monkeypatch.setattr(
        front_apis.psycopg,
        "connect",
        lambda _conn_str: _FakeConnection(rows),
    )

    client = TestClient(front_apis.app)

    # Conversations are scoped to the pipeline and ordered newest-first.
    list_resp = client.get("/v1/pipelines/agent-a/conversations")
    assert list_resp.status_code == 200, list_resp.text
    payload = list_resp.json()
    assert payload["pipeline_id"] == "agent-a"
    assert payload["count"] == 2
    assert [item["conversation_id"] for item in payload["items"]] == [
        "agent-a:conv-2",
        "agent-a:conv-1",
    ]
    assert all(item["pipeline_id"] == "agent-a" for item in payload["items"])

    # Messages come back in ascending sequence order.
    msg_resp = client.get("/v1/pipelines/agent-a/conversations/agent-a:conv-1/messages")
    assert msg_resp.status_code == 200, msg_resp.text
    body = msg_resp.json()
    assert body["pipeline_id"] == "agent-a"
    assert body["conversation_id"] == "agent-a:conv-1"
    assert body["count"] == 2
    assert [item["message_type"] for item in body["items"]] == ["human", "ai"]
    assert [item["sequence_number"] for item in body["items"]] == [1, 2]
def test_pipeline_conversation_messages_404(monkeypatch):
    """A conversation that exists only under another pipeline yields 404."""
    rows = [
        {
            "conversation_id": "agent-b:conv-9",
            "pipeline_id": "agent-b",
            "message_type": "human",
            "content": "other pipeline",
            "sequence_number": 1,
            "created_at": datetime.now(timezone.utc),
        },
    ]
    monkeypatch.setenv("CONN_STR", "postgresql://dummy:dummy@localhost/dummy")
    monkeypatch.setattr(
        front_apis.psycopg, "connect", lambda _conn_str: _FakeConnection(rows)
    )

    # Request agent-b's conversation through agent-a's route: must not leak.
    resp = TestClient(front_apis.app).get(
        "/v1/pipelines/agent-a/conversations/agent-b:conv-9/messages"
    )
    assert resp.status_code == 404, resp.text
    assert "not found for pipeline 'agent-a'" in resp.json()["detail"]