From dd842fca42c4b979a5168a9713d51e2e6358cf18 Mon Sep 17 00:00:00 2001 From: goulustis Date: Fri, 6 Mar 2026 13:19:26 +0800 Subject: [PATCH] update tests --- tests/test_front_apis_registry.py | 190 +++++++++++++++++- tests/test_server_pipeline_manager_refresh.py | 1 + 2 files changed, 190 insertions(+), 1 deletion(-) diff --git a/tests/test_front_apis_registry.py b/tests/test_front_apis_registry.py index a081ee9..fab7555 100644 --- a/tests/test_front_apis_registry.py +++ b/tests/test_front_apis_registry.py @@ -1,13 +1,18 @@ import json import os from pathlib import Path +from datetime import datetime, timedelta, timezone +import importlib from fastapi.testclient import TestClient os.environ.setdefault("CONN_STR", "postgresql://dummy:dummy@localhost/dummy") -import fastapi_server.front_apis as front_apis +try: + front_apis = importlib.import_module("lang_agent.fastapi_server.front_apis") +except ModuleNotFoundError: + front_apis = importlib.import_module("fastapi_server.front_apis") def _fake_build_fn( @@ -36,6 +41,98 @@ def _fake_build_fn( return {"path": str(out_file)} +class _FakeCursor: + def __init__(self, rows): + self._rows = rows + self._result = [] + self._last_sql = "" + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def execute(self, sql, params=None): + self._last_sql = sql + query = " ".join(sql.split()).lower() + params = params or () + + if "group by conversation_id, pipeline_id" in query: + pipeline_id = params[0] + limit = int(params[1]) + grouped = {} + for row in self._rows: + if row["pipeline_id"] != pipeline_id: + continue + conv_id = row["conversation_id"] + if conv_id not in grouped: + grouped[conv_id] = { + "conversation_id": conv_id, + "pipeline_id": row["pipeline_id"], + "message_count": 0, + "last_updated": row["created_at"], + } + grouped[conv_id]["message_count"] += 1 + if row["created_at"] > grouped[conv_id]["last_updated"]: + grouped[conv_id]["last_updated"] = row["created_at"] + values = sorted(grouped.values(), key=lambda x: x["last_updated"], reverse=True) + self._result = values[:limit] + return + + if "select 1 from messages" in query: + pipeline_id, conversation_id = params + found = any( + row["pipeline_id"] == pipeline_id + and row["conversation_id"] == conversation_id + for row in self._rows + ) + self._result = [{"exists": 1}] if found else [] + return + + if "order by sequence_number asc" in query: + pipeline_id, conversation_id = params + self._result = sorted( + [ + { + "message_type": row["message_type"], + "content": row["content"], + "sequence_number": row["sequence_number"], + "created_at": row["created_at"], + } + for row in self._rows + if row["pipeline_id"] == pipeline_id + and row["conversation_id"] == conversation_id + ], + key=lambda x: x["sequence_number"], + ) + return + + raise AssertionError(f"Unsupported SQL in test fake: {self._last_sql}") + + def fetchall(self): + return self._result + + def fetchone(self): + if not self._result: + return None + return self._result[0] + + +class _FakeConnection: + def __init__(self, rows): + self._rows = rows + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def cursor(self, row_factory=None): + return _FakeCursor(self._rows) + + def test_registry_route_lifecycle(monkeypatch, tmp_path): registry_path = tmp_path / "pipeline_registry.json" monkeypatch.setattr(front_apis, "PIPELINE_REGISTRY_PATH", str(registry_path)) @@ -137,3 +234,94 @@ def test_registry_api_key_policy_lifecycle(monkeypatch, tmp_path): assert delete_data["api_key"] == "sk-test-key" assert delete_data["status"] == "deleted" assert delete_data["reload_required"] is False + + +def test_pipeline_conversation_routes(monkeypatch): + now = datetime.now(timezone.utc) + rows = [ + { + "conversation_id": "agent-a:conv-1", + "pipeline_id": "agent-a", + "message_type": "human", + "content": "hello", + "sequence_number": 1, + "created_at": now - timedelta(seconds=30), + }, + { + "conversation_id": "agent-a:conv-1", + "pipeline_id": "agent-a", + "message_type": "ai", + "content": "hi there", + "sequence_number": 2, + "created_at": now - timedelta(seconds=20), + }, + { + "conversation_id": "agent-a:conv-2", + "pipeline_id": "agent-a", + "message_type": "human", + "content": "second thread", + "sequence_number": 1, + "created_at": now - timedelta(seconds=10), + }, + { + "conversation_id": "agent-b:conv-9", + "pipeline_id": "agent-b", + "message_type": "human", + "content": "other pipeline", + "sequence_number": 1, + "created_at": now - timedelta(seconds=5), + }, + ] + + monkeypatch.setenv("CONN_STR", "postgresql://dummy:dummy@localhost/dummy") + monkeypatch.setattr( + front_apis.psycopg, + "connect", + lambda _conn_str: _FakeConnection(rows), + ) + + client = TestClient(front_apis.app) + + list_resp = client.get("/v1/pipelines/agent-a/conversations") + assert list_resp.status_code == 200, list_resp.text + list_data = list_resp.json() + assert list_data["pipeline_id"] == "agent-a" + assert list_data["count"] == 2 + assert [item["conversation_id"] for item in list_data["items"]] == [ + "agent-a:conv-2", + "agent-a:conv-1", + ] + assert all(item["pipeline_id"] == "agent-a" for item in list_data["items"]) + + msg_resp = client.get("/v1/pipelines/agent-a/conversations/agent-a:conv-1/messages") + assert msg_resp.status_code == 200, msg_resp.text + msg_data = msg_resp.json() + assert msg_data["pipeline_id"] == "agent-a" + assert msg_data["conversation_id"] == "agent-a:conv-1" + assert msg_data["count"] == 2 + assert [item["message_type"] for item in msg_data["items"]] == ["human", "ai"] + assert [item["sequence_number"] for item in msg_data["items"]] == [1, 2] + + +def test_pipeline_conversation_messages_404(monkeypatch): + rows = [ + { + "conversation_id": "agent-b:conv-9", + "pipeline_id": "agent-b", + "message_type": "human", + "content": "other pipeline", + "sequence_number": 1, + "created_at": datetime.now(timezone.utc), + }, + ] + monkeypatch.setenv("CONN_STR", "postgresql://dummy:dummy@localhost/dummy") + monkeypatch.setattr( + front_apis.psycopg, + "connect", + lambda _conn_str: _FakeConnection(rows), + ) + + client = TestClient(front_apis.app) + resp = client.get("/v1/pipelines/agent-a/conversations/agent-b:conv-9/messages") + assert resp.status_code == 404, resp.text + assert "not found for pipeline 'agent-a'" in resp.json()["detail"] diff --git a/tests/test_server_pipeline_manager_refresh.py b/tests/test_server_pipeline_manager_refresh.py index 4003579..19b42ae 100644 --- a/tests/test_server_pipeline_manager_refresh.py +++ b/tests/test_server_pipeline_manager_refresh.py @@ -151,3 +151,4 @@ def test_refresh_registry_applies_disabled_state_immediately(tmp_path): assert exc_info.value.status_code == 403 +