diff --git a/lang_agent/front_api/build_server_utils.py b/lang_agent/front_api/build_server_utils.py index d7ce585..5ca7e4b 100644 --- a/lang_agent/front_api/build_server_utils.py +++ b/lang_agent/front_api/build_server_utils.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional +from typing import Any, Dict, List import os import os.path as osp import subprocess @@ -6,10 +6,17 @@ import json from lang_agent.config.core_config import load_tyro_conf -def opt_to_config(save_path:str, *nargs): +_PROJECT_ROOT = osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__)))) +_TY_BUILD_SCRIPT = osp.join(_PROJECT_ROOT, "lang_agent", "config", "ty_build_config.py") + + +def opt_to_config(save_path: str, *nargs): os.makedirs(osp.dirname(save_path), exist_ok=True) - subprocess.run(["python", "lang_agent/config/ty_build_config.py", - "--save-path", save_path, *nargs]) + subprocess.run( + ["python", _TY_BUILD_SCRIPT, "--save-path", save_path, *nargs], + check=True, + cwd=_PROJECT_ROOT, + ) def _build_and_load_pipeline_config(pipeline_id: str, pipeline_config_dir: str, @@ -23,20 +30,30 @@ def _build_and_load_pipeline_config(pipeline_id: str, def update_pipeline_registry(pipeline_id:str, prompt_set:str, + graph_id: str, + config_file: str, + llm_name: str, + enabled: bool = True, registry_f:str="configs/pipeline_registry.json"): + if not osp.isabs(registry_f): + registry_f = osp.join(_PROJECT_ROOT, registry_f) + os.makedirs(osp.dirname(registry_f), exist_ok=True) + if not osp.exists(registry_f): + with open(registry_f, "w", encoding="utf-8") as f: + json.dump({"routes": {}, "api_keys": {}}, f, indent=4) + with open(registry_f, "r") as f: registry = json.load(f) - - if pipeline_id not in registry["routes"]: - registry["routes"][pipeline_id] = { - "enabled": True, - "config_file": None, - "prompt_pipeline_id": prompt_set, - } - else: - registry["routes"][pipeline_id]["prompt_pipeline_id"] = prompt_set - with open(registry_f, "w") as f: + routes: Dict[str, Dict[str, Any]] = registry.setdefault("routes", {}) + route = routes.setdefault(pipeline_id, {}) + route["enabled"] = bool(enabled) + route["config_file"] = config_file + route["prompt_pipeline_id"] = prompt_set + route["graph_id"] = graph_id + route["overrides"] = {"llm_name": llm_name} + + with open(registry_f, "w", encoding="utf-8") as f: json.dump(registry, f, indent=4)