yaml to sql migration script
This commit is contained in:
113
tests/test_migrate_yaml_prompts_to_db.py
Normal file
113
tests/test_migrate_yaml_prompts_to_db.py
Normal file
@@ -0,0 +1,113 @@
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
|
||||
def _load_module():
|
||||
project_root = Path(__file__).resolve().parents[1]
|
||||
script_path = project_root / "scripts" / "py_scripts" / "migrate_yaml_prompts_to_db.py"
|
||||
spec = importlib.util.spec_from_file_location("migrate_yaml_prompts_to_db", script_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
assert spec.loader is not None
|
||||
sys.modules[spec.name] = module
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
def test_infer_pipeline_id_falls_back_to_filename():
|
||||
module = _load_module()
|
||||
conf = SimpleNamespace(
|
||||
pipeline_id=None,
|
||||
graph_config=SimpleNamespace(pipeline_id=None),
|
||||
)
|
||||
out = module._infer_pipeline_id(conf, "/tmp/blueberry.yaml")
|
||||
assert out == "blueberry"
|
||||
|
||||
|
||||
def test_extract_prompt_dict_for_react_txt(tmp_path):
|
||||
module = _load_module()
|
||||
prompt_f = tmp_path / "sys.txt"
|
||||
prompt_f.write_text("hello react", encoding="utf-8")
|
||||
graph_conf = SimpleNamespace(sys_prompt_f=str(prompt_f))
|
||||
prompt_dict = module._extract_prompt_dict(graph_conf)
|
||||
assert prompt_dict == {"sys_prompt": "hello react"}
|
||||
|
||||
|
||||
def test_extract_prompt_dict_for_routing_dir(tmp_path):
|
||||
module = _load_module()
|
||||
(tmp_path / "route_prompt.txt").write_text("route", encoding="utf-8")
|
||||
(tmp_path / "chat_prompt.txt").write_text("chat", encoding="utf-8")
|
||||
graph_conf = SimpleNamespace(sys_promp_dir=str(tmp_path))
|
||||
prompt_dict = module._extract_prompt_dict(graph_conf)
|
||||
assert prompt_dict["route_prompt"] == "route"
|
||||
assert prompt_dict["chat_prompt"] == "chat"
|
||||
|
||||
|
||||
def test_collect_payload_routing_ignores_chatty_prompt_for_tool_node(tmp_path):
|
||||
module = _load_module()
|
||||
prompt_dir = tmp_path / "prompts"
|
||||
prompt_dir.mkdir()
|
||||
(prompt_dir / "route_prompt.txt").write_text("route", encoding="utf-8")
|
||||
(prompt_dir / "chat_prompt.txt").write_text("chat", encoding="utf-8")
|
||||
(prompt_dir / "tool_prompt.txt").write_text("tool", encoding="utf-8")
|
||||
(prompt_dir / "chatty_prompt.txt").write_text("chatty", encoding="utf-8")
|
||||
|
||||
class RoutingConfig:
|
||||
pass
|
||||
|
||||
class ToolNodeConfig:
|
||||
pass
|
||||
|
||||
graph_conf = RoutingConfig()
|
||||
graph_conf.sys_promp_dir = str(prompt_dir)
|
||||
graph_conf.tool_node_config = ToolNodeConfig()
|
||||
graph_conf.tool_node_config.tool_prompt_f = str(prompt_dir / "tool_prompt.txt")
|
||||
|
||||
conf = SimpleNamespace(
|
||||
pipeline_id=None,
|
||||
api_key="sk",
|
||||
graph_config=graph_conf,
|
||||
)
|
||||
|
||||
module.load_tyro_conf = lambda _: conf
|
||||
payload = module._collect_payload(str(tmp_path / "xiaozhan.yaml"))
|
||||
assert payload.pipeline_id == "xiaozhan"
|
||||
assert set(payload.prompt_dict.keys()) == {"route_prompt", "chat_prompt", "tool_prompt"}
|
||||
assert "chatty_prompt" not in payload.prompt_dict
|
||||
|
||||
|
||||
def test_collect_payload_routing_includes_chatty_prompt_for_chatty_node(tmp_path):
|
||||
module = _load_module()
|
||||
prompt_dir = tmp_path / "prompts"
|
||||
prompt_dir.mkdir()
|
||||
(prompt_dir / "route_prompt.txt").write_text("route", encoding="utf-8")
|
||||
(prompt_dir / "chat_prompt.txt").write_text("chat", encoding="utf-8")
|
||||
(prompt_dir / "tool_prompt.txt").write_text("tool", encoding="utf-8")
|
||||
(prompt_dir / "chatty_prompt.txt").write_text("chatty", encoding="utf-8")
|
||||
|
||||
class RoutingConfig:
|
||||
pass
|
||||
|
||||
class ChattyToolNodeConfig:
|
||||
pass
|
||||
|
||||
graph_conf = RoutingConfig()
|
||||
graph_conf.sys_promp_dir = str(prompt_dir)
|
||||
graph_conf.tool_node_config = ChattyToolNodeConfig()
|
||||
graph_conf.tool_node_config.tool_prompt_f = str(prompt_dir / "tool_prompt.txt")
|
||||
graph_conf.tool_node_config.chatty_sys_prompt_f = str(
|
||||
prompt_dir / "chatty_prompt.txt"
|
||||
)
|
||||
|
||||
conf = SimpleNamespace(
|
||||
pipeline_id="xiaozhan",
|
||||
api_key="sk",
|
||||
graph_config=graph_conf,
|
||||
)
|
||||
|
||||
module.load_tyro_conf = lambda _: conf
|
||||
payload = module._collect_payload(str(tmp_path / "xiaozhan.yaml"))
|
||||
assert payload.pipeline_id == "xiaozhan"
|
||||
assert "chatty_prompt" in payload.prompt_dict
|
||||
|
||||
Reference in New Issue
Block a user