update tests

This commit is contained in:
2026-03-06 13:19:26 +08:00
parent fc9f0f929d
commit dd842fca42
2 changed files with 190 additions and 1 deletions

View File

@@ -1,13 +1,18 @@
import json import json
import os import os
from pathlib import Path from pathlib import Path
from datetime import datetime, timedelta, timezone
import importlib
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
os.environ.setdefault("CONN_STR", "postgresql://dummy:dummy@localhost/dummy") os.environ.setdefault("CONN_STR", "postgresql://dummy:dummy@localhost/dummy")
import fastapi_server.front_apis as front_apis try:
front_apis = importlib.import_module("lang_agent.fastapi_server.front_apis")
except ModuleNotFoundError:
front_apis = importlib.import_module("fastapi_server.front_apis")
def _fake_build_fn( def _fake_build_fn(
@@ -36,6 +41,98 @@ def _fake_build_fn(
return {"path": str(out_file)} return {"path": str(out_file)}
class _FakeCursor:
def __init__(self, rows):
self._rows = rows
self._result = []
self._last_sql = ""
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def execute(self, sql, params=None):
self._last_sql = sql
query = " ".join(sql.split()).lower()
params = params or ()
if "group by conversation_id, pipeline_id" in query:
pipeline_id = params[0]
limit = int(params[1])
grouped = {}
for row in self._rows:
if row["pipeline_id"] != pipeline_id:
continue
conv_id = row["conversation_id"]
if conv_id not in grouped:
grouped[conv_id] = {
"conversation_id": conv_id,
"pipeline_id": row["pipeline_id"],
"message_count": 0,
"last_updated": row["created_at"],
}
grouped[conv_id]["message_count"] += 1
if row["created_at"] > grouped[conv_id]["last_updated"]:
grouped[conv_id]["last_updated"] = row["created_at"]
values = sorted(grouped.values(), key=lambda x: x["last_updated"], reverse=True)
self._result = values[:limit]
return
if "select 1 from messages" in query:
pipeline_id, conversation_id = params
found = any(
row["pipeline_id"] == pipeline_id
and row["conversation_id"] == conversation_id
for row in self._rows
)
self._result = [{"exists": 1}] if found else []
return
if "order by sequence_number asc" in query:
pipeline_id, conversation_id = params
self._result = sorted(
[
{
"message_type": row["message_type"],
"content": row["content"],
"sequence_number": row["sequence_number"],
"created_at": row["created_at"],
}
for row in self._rows
if row["pipeline_id"] == pipeline_id
and row["conversation_id"] == conversation_id
],
key=lambda x: x["sequence_number"],
)
return
raise AssertionError(f"Unsupported SQL in test fake: {self._last_sql}")
def fetchall(self):
return self._result
def fetchone(self):
if not self._result:
return None
return self._result[0]
class _FakeConnection:
def __init__(self, rows):
self._rows = rows
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def cursor(self, row_factory=None):
return _FakeCursor(self._rows)
def test_registry_route_lifecycle(monkeypatch, tmp_path): def test_registry_route_lifecycle(monkeypatch, tmp_path):
registry_path = tmp_path / "pipeline_registry.json" registry_path = tmp_path / "pipeline_registry.json"
monkeypatch.setattr(front_apis, "PIPELINE_REGISTRY_PATH", str(registry_path)) monkeypatch.setattr(front_apis, "PIPELINE_REGISTRY_PATH", str(registry_path))
@@ -137,3 +234,94 @@ def test_registry_api_key_policy_lifecycle(monkeypatch, tmp_path):
assert delete_data["api_key"] == "sk-test-key" assert delete_data["api_key"] == "sk-test-key"
assert delete_data["status"] == "deleted" assert delete_data["status"] == "deleted"
assert delete_data["reload_required"] is False assert delete_data["reload_required"] is False
def test_pipeline_conversation_routes(monkeypatch):
now = datetime.now(timezone.utc)
rows = [
{
"conversation_id": "agent-a:conv-1",
"pipeline_id": "agent-a",
"message_type": "human",
"content": "hello",
"sequence_number": 1,
"created_at": now - timedelta(seconds=30),
},
{
"conversation_id": "agent-a:conv-1",
"pipeline_id": "agent-a",
"message_type": "ai",
"content": "hi there",
"sequence_number": 2,
"created_at": now - timedelta(seconds=20),
},
{
"conversation_id": "agent-a:conv-2",
"pipeline_id": "agent-a",
"message_type": "human",
"content": "second thread",
"sequence_number": 1,
"created_at": now - timedelta(seconds=10),
},
{
"conversation_id": "agent-b:conv-9",
"pipeline_id": "agent-b",
"message_type": "human",
"content": "other pipeline",
"sequence_number": 1,
"created_at": now - timedelta(seconds=5),
},
]
monkeypatch.setenv("CONN_STR", "postgresql://dummy:dummy@localhost/dummy")
monkeypatch.setattr(
front_apis.psycopg,
"connect",
lambda _conn_str: _FakeConnection(rows),
)
client = TestClient(front_apis.app)
list_resp = client.get("/v1/pipelines/agent-a/conversations")
assert list_resp.status_code == 200, list_resp.text
list_data = list_resp.json()
assert list_data["pipeline_id"] == "agent-a"
assert list_data["count"] == 2
assert [item["conversation_id"] for item in list_data["items"]] == [
"agent-a:conv-2",
"agent-a:conv-1",
]
assert all(item["pipeline_id"] == "agent-a" for item in list_data["items"])
msg_resp = client.get("/v1/pipelines/agent-a/conversations/agent-a:conv-1/messages")
assert msg_resp.status_code == 200, msg_resp.text
msg_data = msg_resp.json()
assert msg_data["pipeline_id"] == "agent-a"
assert msg_data["conversation_id"] == "agent-a:conv-1"
assert msg_data["count"] == 2
assert [item["message_type"] for item in msg_data["items"]] == ["human", "ai"]
assert [item["sequence_number"] for item in msg_data["items"]] == [1, 2]
def test_pipeline_conversation_messages_404(monkeypatch):
rows = [
{
"conversation_id": "agent-b:conv-9",
"pipeline_id": "agent-b",
"message_type": "human",
"content": "other pipeline",
"sequence_number": 1,
"created_at": datetime.now(timezone.utc),
},
]
monkeypatch.setenv("CONN_STR", "postgresql://dummy:dummy@localhost/dummy")
monkeypatch.setattr(
front_apis.psycopg,
"connect",
lambda _conn_str: _FakeConnection(rows),
)
client = TestClient(front_apis.app)
resp = client.get("/v1/pipelines/agent-a/conversations/agent-b:conv-9/messages")
assert resp.status_code == 404, resp.text
assert "not found for pipeline 'agent-a'" in resp.json()["detail"]

View File

@@ -151,3 +151,4 @@ def test_refresh_registry_applies_disabled_state_immediately(tmp_path):
assert exc_info.value.status_code == 403 assert exc_info.value.status_code == 403