more robust build_server_utils
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
from typing import Dict, List, Optional
|
from typing import Any, Dict, List
|
||||||
import os
|
import os
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
import subprocess
|
import subprocess
|
||||||
@@ -6,10 +6,17 @@ import json
|
|||||||
|
|
||||||
from lang_agent.config.core_config import load_tyro_conf
|
from lang_agent.config.core_config import load_tyro_conf
|
||||||
|
|
||||||
|
_PROJECT_ROOT = osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__))))
|
||||||
|
_TY_BUILD_SCRIPT = osp.join(_PROJECT_ROOT, "lang_agent", "config", "ty_build_config.py")
|
||||||
|
|
||||||
|
|
||||||
def opt_to_config(save_path: str, *nargs):
|
def opt_to_config(save_path: str, *nargs):
|
||||||
os.makedirs(osp.dirname(save_path), exist_ok=True)
|
os.makedirs(osp.dirname(save_path), exist_ok=True)
|
||||||
subprocess.run(["python", "lang_agent/config/ty_build_config.py",
|
subprocess.run(
|
||||||
"--save-path", save_path, *nargs])
|
["python", _TY_BUILD_SCRIPT, "--save-path", save_path, *nargs],
|
||||||
|
check=True,
|
||||||
|
cwd=_PROJECT_ROOT,
|
||||||
|
)
|
||||||
|
|
||||||
def _build_and_load_pipeline_config(pipeline_id: str,
|
def _build_and_load_pipeline_config(pipeline_id: str,
|
||||||
pipeline_config_dir: str,
|
pipeline_config_dir: str,
|
||||||
@@ -23,20 +30,30 @@ def _build_and_load_pipeline_config(pipeline_id: str,
|
|||||||
|
|
||||||
def update_pipeline_registry(pipeline_id:str,
|
def update_pipeline_registry(pipeline_id:str,
|
||||||
prompt_set:str,
|
prompt_set:str,
|
||||||
|
graph_id: str,
|
||||||
|
config_file: str,
|
||||||
|
llm_name: str,
|
||||||
|
enabled: bool = True,
|
||||||
registry_f:str="configs/pipeline_registry.json"):
|
registry_f:str="configs/pipeline_registry.json"):
|
||||||
|
if not osp.isabs(registry_f):
|
||||||
|
registry_f = osp.join(_PROJECT_ROOT, registry_f)
|
||||||
|
os.makedirs(osp.dirname(registry_f), exist_ok=True)
|
||||||
|
if not osp.exists(registry_f):
|
||||||
|
with open(registry_f, "w", encoding="utf-8") as f:
|
||||||
|
json.dump({"routes": {}, "api_keys": {}}, f, indent=4)
|
||||||
|
|
||||||
with open(registry_f, "r") as f:
|
with open(registry_f, "r") as f:
|
||||||
registry = json.load(f)
|
registry = json.load(f)
|
||||||
|
|
||||||
if pipeline_id not in registry["routes"]:
|
routes: Dict[str, Dict[str, Any]] = registry.setdefault("routes", {})
|
||||||
registry["routes"][pipeline_id] = {
|
route = routes.setdefault(pipeline_id, {})
|
||||||
"enabled": True,
|
route["enabled"] = bool(enabled)
|
||||||
"config_file": None,
|
route["config_file"] = config_file
|
||||||
"prompt_pipeline_id": prompt_set,
|
route["prompt_pipeline_id"] = prompt_set
|
||||||
}
|
route["graph_id"] = graph_id
|
||||||
else:
|
route["overrides"] = {"llm_name": llm_name}
|
||||||
registry["routes"][pipeline_id]["prompt_pipeline_id"] = prompt_set
|
|
||||||
|
|
||||||
with open(registry_f, "w") as f:
|
with open(registry_f, "w", encoding="utf-8") as f:
|
||||||
json.dump(registry, f, indent=4)
|
json.dump(registry, f, indent=4)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user