Files
lang-agent/tests/test_migrate_yaml_prompts_to_db.py

114 lines
3.9 KiB
Python

import importlib.util
import sys
from pathlib import Path
from types import SimpleNamespace
def _load_module():
    """Import ``scripts/py_scripts/migrate_yaml_prompts_to_db.py`` by file path.

    The script lives outside any importable package, so it is loaded through
    ``importlib`` and registered in ``sys.modules`` as
    ``migrate_yaml_prompts_to_db``. Every call builds and executes a brand-new
    module object, so a test may monkeypatch attributes on the result without
    affecting other tests that call this helper.

    Returns:
        The freshly executed module object.

    Raises:
        ImportError: if no spec/loader can be created for the script path.
        Any exception raised while executing the script (for example
        ``FileNotFoundError`` when the file is missing). In that case the
        partially initialised module is removed from ``sys.modules`` first.
    """
    project_root = Path(__file__).resolve().parents[1]
    script_path = project_root / "scripts" / "py_scripts" / "migrate_yaml_prompts_to_db.py"
    spec = importlib.util.spec_from_file_location("migrate_yaml_prompts_to_db", script_path)
    # Validate before use: module_from_spec() would otherwise fail with an
    # opaque AttributeError on a None spec.
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load migration script from {script_path}")
    module = importlib.util.module_from_spec(spec)
    # Register before executing (as in the importlib "import a source file
    # directly" recipe) so code that looks the module up by name while it runs
    # can find it.
    sys.modules[spec.name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a broken, half-executed module behind for later lookups.
        sys.modules.pop(spec.name, None)
        raise
    return module


def test_infer_pipeline_id_falls_back_to_filename():
    """With ``pipeline_id`` unset on both the config and its graph config,
    the id is taken from the YAML file name without its extension."""
    migrate = _load_module()
    graph_config = SimpleNamespace(pipeline_id=None)
    config = SimpleNamespace(pipeline_id=None, graph_config=graph_config)

    inferred = migrate._infer_pipeline_id(config, "/tmp/blueberry.yaml")

    assert inferred == "blueberry"


def test_extract_prompt_dict_for_react_txt(tmp_path):
    """A graph config with a single ``sys_prompt_f`` file yields exactly one
    entry, ``sys_prompt``, holding that file's text."""
    migrate = _load_module()
    sys_prompt_path = tmp_path / "sys.txt"
    sys_prompt_path.write_text("hello react", encoding="utf-8")

    extracted = migrate._extract_prompt_dict(
        SimpleNamespace(sys_prompt_f=str(sys_prompt_path))
    )

    assert extracted == {"sys_prompt": "hello react"}


def test_extract_prompt_dict_for_routing_dir(tmp_path):
    """Prompt files in ``sys_promp_dir`` come back keyed by file stem
    (``route_prompt.txt`` -> ``route_prompt``).

    NOTE(review): ``sys_promp_dir`` is spelled the same way in every test here,
    presumably matching the config field the script reads; rename only together
    with the script.
    """
    migrate = _load_module()
    expected = {"route_prompt": "route", "chat_prompt": "chat"}
    for stem, text in expected.items():
        (tmp_path / f"{stem}.txt").write_text(text, encoding="utf-8")

    extracted = migrate._extract_prompt_dict(SimpleNamespace(sys_promp_dir=str(tmp_path)))

    for key, text in expected.items():
        assert extracted[key] == text


def _write_routing_prompts(prompt_dir):
    """Create ``prompt_dir`` containing route/chat/tool/chatty prompt files.

    All four files are always written, so the routing tests can check which of
    them ``_collect_payload`` actually includes for a given tool-node type.
    Returns ``prompt_dir`` for convenience.
    """
    prompt_dir.mkdir()
    for stem, text in (
        ("route_prompt", "route"),
        ("chat_prompt", "chat"),
        ("tool_prompt", "tool"),
        ("chatty_prompt", "chatty"),
    ):
        (prompt_dir / f"{stem}.txt").write_text(text, encoding="utf-8")
    return prompt_dir


def test_collect_payload_routing_ignores_chatty_prompt_for_tool_node(tmp_path):
    """A plain tool node must not pull ``chatty_prompt`` into the payload,
    even though ``chatty_prompt.txt`` sits in the prompt directory."""
    module = _load_module()
    prompt_dir = _write_routing_prompts(tmp_path / "prompts")

    # Plain local classes rather than SimpleNamespace: the class names are kept
    # deliberately, presumably because the script tells node kinds apart by
    # type name. NOTE(review): confirm against _collect_payload before renaming.
    class RoutingConfig:
        pass

    class ToolNodeConfig:
        pass

    graph_conf = RoutingConfig()
    graph_conf.sys_promp_dir = str(prompt_dir)
    graph_conf.tool_node_config = ToolNodeConfig()
    graph_conf.tool_node_config.tool_prompt_f = str(prompt_dir / "tool_prompt.txt")
    conf = SimpleNamespace(
        pipeline_id=None,
        api_key="sk",
        graph_config=graph_conf,
    )
    # Skip real YAML parsing. _load_module() returns a fresh module per call,
    # so this patch does not affect other tests.
    module.load_tyro_conf = lambda _: conf

    payload = module._collect_payload(str(tmp_path / "xiaozhan.yaml"))

    # pipeline_id is None, so the id comes from the YAML file name.
    assert payload.pipeline_id == "xiaozhan"
    assert set(payload.prompt_dict.keys()) == {"route_prompt", "chat_prompt", "tool_prompt"}
    assert "chatty_prompt" not in payload.prompt_dict


def test_collect_payload_routing_includes_chatty_prompt_for_chatty_node(tmp_path):
    """A chatty tool node (with ``chatty_sys_prompt_f``) adds ``chatty_prompt``
    to the payload, carrying the chatty prompt file's text."""
    module = _load_module()
    prompt_dir = _write_routing_prompts(tmp_path / "prompts")

    # Class names kept deliberately; see the tool-node test above for why.
    class RoutingConfig:
        pass

    class ChattyToolNodeConfig:
        pass

    graph_conf = RoutingConfig()
    graph_conf.sys_promp_dir = str(prompt_dir)
    graph_conf.tool_node_config = ChattyToolNodeConfig()
    graph_conf.tool_node_config.tool_prompt_f = str(prompt_dir / "tool_prompt.txt")
    graph_conf.tool_node_config.chatty_sys_prompt_f = str(prompt_dir / "chatty_prompt.txt")
    conf = SimpleNamespace(
        pipeline_id="xiaozhan",
        api_key="sk",
        graph_config=graph_conf,
    )
    module.load_tyro_conf = lambda _: conf

    payload = module._collect_payload(str(tmp_path / "xiaozhan.yaml"))

    assert payload.pipeline_id == "xiaozhan"
    assert "chatty_prompt" in payload.prompt_dict
    # Check the content too, not just the key, so a wrong or empty prompt is caught.
    assert payload.prompt_dict["chatty_prompt"] == "chatty"