From 7e23d5c0562f6883cc168c3f3f2b6992254a67ce Mon Sep 17 00:00:00 2001 From: goulustis Date: Thu, 5 Mar 2026 17:17:10 +0800 Subject: [PATCH] yaml to sql migration script --- .../py_scripts/migrate_yaml_prompts_to_db.py | 364 ++++++++++++++++++ tests/test_migrate_yaml_prompts_to_db.py | 113 ++++++ 2 files changed, 477 insertions(+) create mode 100644 scripts/py_scripts/migrate_yaml_prompts_to_db.py create mode 100644 tests/test_migrate_yaml_prompts_to_db.py diff --git a/scripts/py_scripts/migrate_yaml_prompts_to_db.py b/scripts/py_scripts/migrate_yaml_prompts_to_db.py new file mode 100644 index 0000000..0f5d6e3 --- /dev/null +++ b/scripts/py_scripts/migrate_yaml_prompts_to_db.py @@ -0,0 +1,364 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import datetime as dt +import glob +import os +import os.path as osp +import sys +from dataclasses import dataclass +from typing import Dict, Iterable, List, Optional + +import commentjson +import psycopg + + +PROJECT_ROOT = osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__)))) +if PROJECT_ROOT not in sys.path: + sys.path.append(PROJECT_ROOT) + +from lang_agent.config import load_tyro_conf # noqa: E402 +from lang_agent.config.db_config_manager import DBConfigManager # noqa: E402 + + +@dataclass +class MigrationPayload: + config_path: str + pipeline_id: str + graph_id: str + prompt_dict: Dict[str, str] + tool_keys: List[str] + api_key: Optional[str] + + +def _infer_pipeline_id(pipeline_conf, config_path: str) -> str: + candidates = [ + getattr(pipeline_conf, "pipeline_id", None), + getattr(getattr(pipeline_conf, "graph_config", None), "pipeline_id", None), + ] + for candidate in candidates: + if candidate is None: + continue + value = str(candidate).strip() + if value and value.lower() != "null": + return value + return osp.splitext(osp.basename(config_path))[0] + + +def _infer_graph_id(graph_conf) -> str: + if graph_conf is None: + return "unknown" + class_name = graph_conf.__class__.__name__.lower() + if "routing" in class_name or class_name == "routeconfig": + return "routing" + if "react" in class_name: + return "react" + + target = getattr(graph_conf, "_target", None) + if target is not None: + target_name = getattr(target, "__name__", str(target)).lower() + if "routing" in target_name: + return "routing" + if "react" in target_name: + return "react" + return "unknown" + + +def _extract_tool_keys(graph_conf) -> List[str]: + if graph_conf is None: + return [] + tool_cfg = getattr(graph_conf, "tool_manager_config", None) + client_cfg = getattr(tool_cfg, "client_tool_manager", None) + keys = getattr(client_cfg, "tool_keys", None) + if not keys: + return [] + out: List[str] = [] + seen = set() + for key in keys: + cleaned = str(key).strip() + if not cleaned or cleaned in seen: + continue + seen.add(cleaned) + out.append(cleaned) + return out + + +def _load_prompt_dict(prompt_path: str, default_key: str = "sys_prompt") -> Dict[str, str]: + if not prompt_path: + return {} + if not osp.exists(prompt_path): + return {} + + if osp.isdir(prompt_path): + prompt_files = sorted( + p for p in glob.glob(osp.join(prompt_path, "*.txt")) if "optional" not in p + ) + out = {} + for prompt_f in prompt_files: + key = osp.splitext(osp.basename(prompt_f))[0] + with open(prompt_f, "r", encoding="utf-8") as f: + out[key] = f.read() + return out + + if prompt_path.endswith(".json"): + with open(prompt_path, "r", encoding="utf-8") as f: + obj = commentjson.load(f) + if not isinstance(obj, dict): + return {} + return {str(k): v if isinstance(v, str) else str(v) for k, v in obj.items()} + + if prompt_path.endswith(".txt"): + with open(prompt_path, "r", encoding="utf-8") as f: + return {default_key: f.read()} + + return {} + + +def _extract_prompt_dict(graph_conf) -> Dict[str, str]: + if graph_conf is None: + return {} + if hasattr(graph_conf, "sys_prompt_f"): + return _load_prompt_dict(str(getattr(graph_conf, "sys_prompt_f")), "sys_prompt") + if hasattr(graph_conf, "sys_promp_dir"): + return _load_prompt_dict(str(getattr(graph_conf, "sys_promp_dir"))) + return {} + + +def _extract_tool_node_prompt_dict(graph_conf) -> Dict[str, str]: + tool_node_conf = getattr(graph_conf, "tool_node_config", None) + if tool_node_conf is None: + return {} + + out: Dict[str, str] = {} + if hasattr(tool_node_conf, "tool_prompt_f"): + out.update( + _load_prompt_dict(str(getattr(tool_node_conf, "tool_prompt_f")), "tool_prompt") + ) + if hasattr(tool_node_conf, "chatty_sys_prompt_f"): + out.update( + _load_prompt_dict( + str(getattr(tool_node_conf, "chatty_sys_prompt_f")), "chatty_prompt" + ) + ) + return out + + +def _prompt_key_whitelist(graph_conf, graph_id: str) -> Optional[set]: + if graph_id == "react": + return {"sys_prompt"} + if graph_id != "routing": + return None + + allowed = {"route_prompt", "chat_prompt", "tool_prompt"} + tool_node_conf = getattr(graph_conf, "tool_node_config", None) + if tool_node_conf is None: + return allowed + + cls_name = tool_node_conf.__class__.__name__.lower() + target = getattr(tool_node_conf, "_target", None) + target_name = getattr(target, "__name__", str(target)).lower() if target else "" + if "chatty" in cls_name or "chatty" in target_name: + allowed.add("chatty_prompt") + return allowed + + +def _collect_payload(config_path: str) -> MigrationPayload: + conf = load_tyro_conf(config_path) + graph_conf = getattr(conf, "graph_config", None) + graph_id = _infer_graph_id(graph_conf) + prompt_dict = _extract_prompt_dict(graph_conf) + prompt_dict.update(_extract_tool_node_prompt_dict(graph_conf)) + whitelist = _prompt_key_whitelist(graph_conf, graph_id) + if whitelist is not None: + prompt_dict = {k: v for k, v in prompt_dict.items() if k in whitelist} + return MigrationPayload( + config_path=config_path, + pipeline_id=_infer_pipeline_id(conf, config_path), + graph_id=graph_id, + prompt_dict=prompt_dict, + tool_keys=_extract_tool_keys(graph_conf), + api_key=getattr(conf, "api_key", None), + ) + + +def _resolve_config_paths(config_dir: str, config_paths: Optional[Iterable[str]]) -> List[str]: + if config_paths: + resolved = [osp.abspath(path) for path in config_paths] + else: + pattern = osp.join(osp.abspath(config_dir), "*.yaml") + resolved = sorted(glob.glob(pattern)) + return [path for path in resolved if osp.exists(path)] + + +def _ensure_prompt_set( + conn: psycopg.Connection, + pipeline_id: str, + graph_id: str, + set_name: str, + description: str, +) -> str: + with conn.cursor() as cur: + cur.execute( + """ + SELECT id FROM prompt_sets + WHERE pipeline_id = %s AND name = %s + ORDER BY updated_at DESC, created_at DESC + LIMIT 1 + """, + (pipeline_id, set_name), + ) + row = cur.fetchone() + if row is not None: + return str(row[0]) + + cur.execute( + """ + INSERT INTO prompt_sets (pipeline_id, graph_id, name, description, is_active, list) + VALUES (%s, %s, %s, %s, false, '') + RETURNING id + """, + (pipeline_id, graph_id, set_name, description), + ) + created = cur.fetchone() + return str(created[0]) + + +def _activate_prompt_set(conn: psycopg.Connection, pipeline_id: str, prompt_set_id: str) -> None: + with conn.cursor() as cur: + cur.execute( + "UPDATE prompt_sets SET is_active = false, updated_at = now() WHERE pipeline_id = %s", + (pipeline_id,), + ) + cur.execute( + "UPDATE prompt_sets SET is_active = true, updated_at = now() WHERE id = %s", + (prompt_set_id,), + ) + + +def _run_migration( + payloads: List[MigrationPayload], + set_name: str, + description: str, + dry_run: bool, + activate: bool, +) -> None: + for payload in payloads: + print( + f"[PLAN] pipeline={payload.pipeline_id} graph={payload.graph_id} " + f"prompts={len(payload.prompt_dict)} tools={len(payload.tool_keys)} " + f"config={payload.config_path}" + ) + if dry_run: + continue + + manager = DBConfigManager() + with psycopg.connect(manager.conn_str) as conn: + prompt_set_id = _ensure_prompt_set( + conn=conn, + pipeline_id=payload.pipeline_id, + graph_id=payload.graph_id, + set_name=set_name, + description=description, + ) + conn.commit() + + manager.set_config( + pipeline_id=payload.pipeline_id, + graph_id=payload.graph_id, + prompt_set_id=prompt_set_id, + tool_list=payload.tool_keys, + prompt_dict=payload.prompt_dict, + api_key=payload.api_key, + ) + + if activate: + _activate_prompt_set( + conn=conn, + pipeline_id=payload.pipeline_id, + prompt_set_id=prompt_set_id, + ) + conn.commit() + + print( + f"[DONE] pipeline={payload.pipeline_id} " + f"prompt_set={prompt_set_id} activate={activate}" + ) + + +def main() -> None: + date_str = dt.date.today().isoformat() + parser = argparse.ArgumentParser( + description="Import prompt definitions from pipeline YAML files into DB prompt_sets." + ) + parser.add_argument( + "--config-dir", + default=osp.join(PROJECT_ROOT, "configs", "pipelines"), + help="Directory containing pipeline YAML files.", + ) + parser.add_argument( + "--config", + action="append", + default=[], + help="Specific pipeline config yaml path. Can be passed multiple times.", + ) + parser.add_argument( + "--pipeline-id", + action="append", + default=[], + help="If provided, only migrate these pipeline IDs (repeatable).", + ) + parser.add_argument( + "--set-name", + # default=f"migrated-{date_str}", + default="default", + help="Prompt set name to create/reuse under each pipeline.", + ) + parser.add_argument( + "--description", + default="Migrated from pipeline YAML prompt files", + help="Prompt set description.", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Print what would be migrated without writing to DB.", + ) + parser.add_argument( + "--activate", + action="store_true", + help="Mark imported set active for each migrated pipeline.", + ) + args = parser.parse_args() + + config_paths = _resolve_config_paths(args.config_dir, args.config) + if not config_paths: + raise SystemExit("No config files found. Provide --config or --config-dir.") + + requested_pipelines = {p.strip() for p in args.pipeline_id if p.strip()} + + payloads: List[MigrationPayload] = [] + for config_path in config_paths: + payload = _collect_payload(config_path) + if requested_pipelines and payload.pipeline_id not in requested_pipelines: + continue + if not payload.prompt_dict: + print(f"[SKIP] no prompts found for config={config_path}") + continue + payloads.append(payload) + + if not payloads: + raise SystemExit("No pipelines matched with prompt content to migrate.") + + _run_migration( + payloads=payloads, + set_name=args.set_name, + description=args.description, + dry_run=args.dry_run, + activate=args.activate, + ) + + +if __name__ == "__main__": + main() + diff --git a/tests/test_migrate_yaml_prompts_to_db.py b/tests/test_migrate_yaml_prompts_to_db.py new file mode 100644 index 0000000..8208094 --- /dev/null +++ b/tests/test_migrate_yaml_prompts_to_db.py @@ -0,0 +1,113 @@ +import importlib.util +import sys +from pathlib import Path +from types import SimpleNamespace + + +def _load_module(): + project_root = Path(__file__).resolve().parents[1] + script_path = project_root / "scripts" / "py_scripts" / "migrate_yaml_prompts_to_db.py" + spec = importlib.util.spec_from_file_location("migrate_yaml_prompts_to_db", script_path) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module + + +def test_infer_pipeline_id_falls_back_to_filename(): + module = _load_module() + conf = SimpleNamespace( + pipeline_id=None, + graph_config=SimpleNamespace(pipeline_id=None), + ) + out = module._infer_pipeline_id(conf, "/tmp/blueberry.yaml") + assert out == "blueberry" + + +def test_extract_prompt_dict_for_react_txt(tmp_path): + module = _load_module() + prompt_f = tmp_path / "sys.txt" + prompt_f.write_text("hello react", encoding="utf-8") + graph_conf = SimpleNamespace(sys_prompt_f=str(prompt_f)) + prompt_dict = module._extract_prompt_dict(graph_conf) + assert prompt_dict == {"sys_prompt": "hello react"} + + +def test_extract_prompt_dict_for_routing_dir(tmp_path): + module = _load_module() + (tmp_path / "route_prompt.txt").write_text("route", encoding="utf-8") + (tmp_path / "chat_prompt.txt").write_text("chat", encoding="utf-8") + graph_conf = SimpleNamespace(sys_promp_dir=str(tmp_path)) + prompt_dict = module._extract_prompt_dict(graph_conf) + assert prompt_dict["route_prompt"] == "route" + assert prompt_dict["chat_prompt"] == "chat" + + +def test_collect_payload_routing_ignores_chatty_prompt_for_tool_node(tmp_path): + module = _load_module() + prompt_dir = tmp_path / "prompts" + prompt_dir.mkdir() + (prompt_dir / "route_prompt.txt").write_text("route", encoding="utf-8") + (prompt_dir / "chat_prompt.txt").write_text("chat", encoding="utf-8") + (prompt_dir / "tool_prompt.txt").write_text("tool", encoding="utf-8") + (prompt_dir / "chatty_prompt.txt").write_text("chatty", encoding="utf-8") + + class RoutingConfig: + pass + + class ToolNodeConfig: + pass + + graph_conf = RoutingConfig() + graph_conf.sys_promp_dir = str(prompt_dir) + graph_conf.tool_node_config = ToolNodeConfig() + graph_conf.tool_node_config.tool_prompt_f = str(prompt_dir / "tool_prompt.txt") + + conf = SimpleNamespace( + pipeline_id=None, + api_key="sk", + graph_config=graph_conf, + ) + + module.load_tyro_conf = lambda _: conf + payload = module._collect_payload(str(tmp_path / "xiaozhan.yaml")) + assert payload.pipeline_id == "xiaozhan" + assert set(payload.prompt_dict.keys()) == {"route_prompt", "chat_prompt", "tool_prompt"} + assert "chatty_prompt" not in payload.prompt_dict + + +def test_collect_payload_routing_includes_chatty_prompt_for_chatty_node(tmp_path): + module = _load_module() + prompt_dir = tmp_path / "prompts" + prompt_dir.mkdir() + (prompt_dir / "route_prompt.txt").write_text("route", encoding="utf-8") + (prompt_dir / "chat_prompt.txt").write_text("chat", encoding="utf-8") + (prompt_dir / "tool_prompt.txt").write_text("tool", encoding="utf-8") + (prompt_dir / "chatty_prompt.txt").write_text("chatty", encoding="utf-8") + + class RoutingConfig: + pass + + class ChattyToolNodeConfig: + pass + + graph_conf = RoutingConfig() + graph_conf.sys_promp_dir = str(prompt_dir) + graph_conf.tool_node_config = ChattyToolNodeConfig() + graph_conf.tool_node_config.tool_prompt_f = str(prompt_dir / "tool_prompt.txt") + graph_conf.tool_node_config.chatty_sys_prompt_f = str( + prompt_dir / "chatty_prompt.txt" + ) + + conf = SimpleNamespace( + pipeline_id="xiaozhan", + api_key="sk", + graph_config=graph_conf, + ) + + module.load_tyro_conf = lambda _: conf + payload = module._collect_payload(str(tmp_path / "xiaozhan.yaml")) + assert payload.pipeline_id == "xiaozhan" + assert "chatty_prompt" in payload.prompt_dict +