#!/usr/bin/env python3 from __future__ import annotations import argparse import datetime as dt import glob import os import os.path as osp import sys from dataclasses import dataclass from typing import Dict, Iterable, List, Optional import commentjson import psycopg PROJECT_ROOT = osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__)))) if PROJECT_ROOT not in sys.path: sys.path.append(PROJECT_ROOT) from lang_agent.config import load_tyro_conf # noqa: E402 from lang_agent.config.db_config_manager import DBConfigManager # noqa: E402 @dataclass class MigrationPayload: config_path: str pipeline_id: str graph_id: str prompt_dict: Dict[str, str] tool_keys: List[str] api_key: Optional[str] def _infer_pipeline_id(pipeline_conf, config_path: str) -> str: candidates = [ getattr(pipeline_conf, "pipeline_id", None), getattr(getattr(pipeline_conf, "graph_config", None), "pipeline_id", None), ] for candidate in candidates: if candidate is None: continue value = str(candidate).strip() if value and value.lower() != "null": return value return osp.splitext(osp.basename(config_path))[0] def _infer_graph_id(graph_conf) -> str: if graph_conf is None: return "unknown" class_name = graph_conf.__class__.__name__.lower() if "routing" in class_name or class_name == "routeconfig": return "routing" if "react" in class_name: return "react" target = getattr(graph_conf, "_target", None) if target is not None: target_name = getattr(target, "__name__", str(target)).lower() if "routing" in target_name: return "routing" if "react" in target_name: return "react" return "unknown" def _extract_tool_keys(graph_conf) -> List[str]: if graph_conf is None: return [] tool_cfg = getattr(graph_conf, "tool_manager_config", None) client_cfg = getattr(tool_cfg, "client_tool_manager", None) keys = getattr(client_cfg, "tool_keys", None) if not keys: return [] out: List[str] = [] seen = set() for key in keys: cleaned = str(key).strip() if not cleaned or cleaned in seen: continue seen.add(cleaned) out.append(cleaned) return out def _load_prompt_dict(prompt_path: str, default_key: str = "sys_prompt") -> Dict[str, str]: if not prompt_path: return {} if not osp.exists(prompt_path): return {} if osp.isdir(prompt_path): prompt_files = sorted( p for p in glob.glob(osp.join(prompt_path, "*.txt")) if "optional" not in p ) out = {} for prompt_f in prompt_files: key = osp.splitext(osp.basename(prompt_f))[0] with open(prompt_f, "r", encoding="utf-8") as f: out[key] = f.read() return out if prompt_path.endswith(".json"): with open(prompt_path, "r", encoding="utf-8") as f: obj = commentjson.load(f) if not isinstance(obj, dict): return {} return {str(k): v if isinstance(v, str) else str(v) for k, v in obj.items()} if prompt_path.endswith(".txt"): with open(prompt_path, "r", encoding="utf-8") as f: return {default_key: f.read()} return {} def _extract_prompt_dict(graph_conf) -> Dict[str, str]: if graph_conf is None: return {} if hasattr(graph_conf, "sys_prompt_f"): return _load_prompt_dict(str(getattr(graph_conf, "sys_prompt_f")), "sys_prompt") if hasattr(graph_conf, "sys_promp_dir"): return _load_prompt_dict(str(getattr(graph_conf, "sys_promp_dir"))) return {} def _extract_tool_node_prompt_dict(graph_conf) -> Dict[str, str]: tool_node_conf = getattr(graph_conf, "tool_node_config", None) if tool_node_conf is None: return {} out: Dict[str, str] = {} if hasattr(tool_node_conf, "tool_prompt_f"): out.update( _load_prompt_dict(str(getattr(tool_node_conf, "tool_prompt_f")), "tool_prompt") ) if hasattr(tool_node_conf, "chatty_sys_prompt_f"): out.update( _load_prompt_dict( str(getattr(tool_node_conf, "chatty_sys_prompt_f")), "chatty_prompt" ) ) return out def _prompt_key_whitelist(graph_conf, graph_id: str) -> Optional[set]: if graph_id == "react": return {"sys_prompt"} if graph_id != "routing": return None allowed = {"route_prompt", "chat_prompt", "tool_prompt"} tool_node_conf = getattr(graph_conf, "tool_node_config", None) if tool_node_conf is None: return allowed cls_name = tool_node_conf.__class__.__name__.lower() target = getattr(tool_node_conf, "_target", None) target_name = getattr(target, "__name__", str(target)).lower() if target else "" if "chatty" in cls_name or "chatty" in target_name: allowed.add("chatty_prompt") return allowed def _collect_payload(config_path: str) -> MigrationPayload: conf = load_tyro_conf(config_path) graph_conf = getattr(conf, "graph_config", None) graph_id = _infer_graph_id(graph_conf) prompt_dict = _extract_prompt_dict(graph_conf) prompt_dict.update(_extract_tool_node_prompt_dict(graph_conf)) whitelist = _prompt_key_whitelist(graph_conf, graph_id) if whitelist is not None: prompt_dict = {k: v for k, v in prompt_dict.items() if k in whitelist} return MigrationPayload( config_path=config_path, pipeline_id=_infer_pipeline_id(conf, config_path), graph_id=graph_id, prompt_dict=prompt_dict, tool_keys=_extract_tool_keys(graph_conf), api_key=getattr(conf, "api_key", None), ) def _resolve_config_paths(config_dir: str, config_paths: Optional[Iterable[str]]) -> List[str]: if config_paths: resolved = [osp.abspath(path) for path in config_paths] else: pattern = osp.join(osp.abspath(config_dir), "*.yaml") resolved = sorted(glob.glob(pattern)) return [path for path in resolved if osp.exists(path)] def _ensure_prompt_set( conn: psycopg.Connection, pipeline_id: str, graph_id: str, set_name: str, description: str, ) -> str: with conn.cursor() as cur: cur.execute( """ SELECT id FROM prompt_sets WHERE pipeline_id = %s AND name = %s ORDER BY updated_at DESC, created_at DESC LIMIT 1 """, (pipeline_id, set_name), ) row = cur.fetchone() if row is not None: return str(row[0]) cur.execute( """ INSERT INTO prompt_sets (pipeline_id, graph_id, name, description, is_active, list) VALUES (%s, %s, %s, %s, false, '') RETURNING id """, (pipeline_id, graph_id, set_name, description), ) created = cur.fetchone() return str(created[0]) def _activate_prompt_set(conn: psycopg.Connection, pipeline_id: str, prompt_set_id: str) -> None: with conn.cursor() as cur: cur.execute( "UPDATE prompt_sets SET is_active = false, updated_at = now() WHERE pipeline_id = %s", (pipeline_id,), ) cur.execute( "UPDATE prompt_sets SET is_active = true, updated_at = now() WHERE id = %s", (prompt_set_id,), ) def _run_migration( payloads: List[MigrationPayload], set_name: str, description: str, dry_run: bool, activate: bool, ) -> None: for payload in payloads: print( f"[PLAN] pipeline={payload.pipeline_id} graph={payload.graph_id} " f"prompts={len(payload.prompt_dict)} tools={len(payload.tool_keys)} " f"config={payload.config_path}" ) if dry_run: continue manager = DBConfigManager() with psycopg.connect(manager.conn_str) as conn: prompt_set_id = _ensure_prompt_set( conn=conn, pipeline_id=payload.pipeline_id, graph_id=payload.graph_id, set_name=set_name, description=description, ) conn.commit() manager.set_config( pipeline_id=payload.pipeline_id, graph_id=payload.graph_id, prompt_set_id=prompt_set_id, tool_list=payload.tool_keys, prompt_dict=payload.prompt_dict, api_key=payload.api_key, ) if activate: _activate_prompt_set( conn=conn, pipeline_id=payload.pipeline_id, prompt_set_id=prompt_set_id, ) conn.commit() print( f"[DONE] pipeline={payload.pipeline_id} " f"prompt_set={prompt_set_id} activate={activate}" ) def main() -> None: date_str = dt.date.today().isoformat() parser = argparse.ArgumentParser( description="Import prompt definitions from pipeline YAML files into DB prompt_sets." ) parser.add_argument( "--config-dir", default=osp.join(PROJECT_ROOT, "configs", "pipelines"), help="Directory containing pipeline YAML files.", ) parser.add_argument( "--config", action="append", default=[], help="Specific pipeline config yaml path. Can be passed multiple times.", ) parser.add_argument( "--pipeline-id", action="append", default=[], help="If provided, only migrate these pipeline IDs (repeatable).", ) parser.add_argument( "--set-name", # default=f"migrated-{date_str}", default="default", help="Prompt set name to create/reuse under each pipeline.", ) parser.add_argument( "--description", default="Migrated from pipeline YAML prompt files", help="Prompt set description.", ) parser.add_argument( "--dry-run", action="store_true", help="Print what would be migrated without writing to DB.", ) parser.add_argument( "--activate", action="store_true", help="Mark imported set active for each migrated pipeline.", ) args = parser.parse_args() config_paths = _resolve_config_paths(args.config_dir, args.config) if not config_paths: raise SystemExit("No config files found. Provide --config or --config-dir.") requested_pipelines = {p.strip() for p in args.pipeline_id if p.strip()} payloads: List[MigrationPayload] = [] for config_path in config_paths: payload = _collect_payload(config_path) if requested_pipelines and payload.pipeline_id not in requested_pipelines: continue if not payload.prompt_dict: print(f"[SKIP] no prompts found for config={config_path}") continue payloads.append(payload) if not payloads: raise SystemExit("No pipelines matched with prompt content to migrate.") _run_migration( payloads=payloads, set_name=args.set_name, description=args.description, dry_run=args.dry_run, activate=args.activate, ) if __name__ == "__main__": main()