use prompt store

This commit is contained in:
2026-02-10 10:54:58 +08:00
parent ede7199dfc
commit cb5b3afd05
6 changed files with 130 additions and 38 deletions

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass, field, is_dataclass
from typing import Type, TypedDict, Literal, Dict, List, Tuple, Any, AsyncIterator
from typing import Type, TypedDict, Literal, Dict, List, Tuple, Any, AsyncIterator, Optional
import tyro
from pydantic import BaseModel, Field
from loguru import logger
@@ -10,6 +10,7 @@ import time
from lang_agent.config import LLMKeyConfig
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
from lang_agent.components.prompt_store import build_prompt_store
from lang_agent.base import GraphBase, ToolNodeBase
from lang_agent.graphs.graph_states import State
from lang_agent.graphs.tool_nodes import AnnotatedToolNode, ToolNodeConfig
@@ -41,6 +42,12 @@ class RoutingConfig(LLMKeyConfig):
sys_promp_dir: str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts")
"""path to directory or json containing system prompt for graphs; Will overwrite system prompt from xiaozhi if 'chat_prompt' is provided"""
pipeline_id: Optional[str] = None
"""If set, load prompts from database (with file fallback)"""
prompt_set_id: Optional[str] = None
"""If set, load from this specific prompt set instead of the active one"""
tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)
tool_node_config: AnnotatedToolNode = field(default_factory=ToolNodeConfig)
@@ -83,32 +90,28 @@ class RoutingGraph(GraphBase):
tool_manager:ToolManager = self.config.tool_manager_config.setup()
self.chat_model = create_agent(self.chat_llm, [], checkpointer=self.memory)
# Propagate pipeline_id and prompt_set_id to tool node config
if self.config.pipeline_id and hasattr(self.config.tool_node_config, 'pipeline_id'):
self.config.tool_node_config.pipeline_id = self.config.pipeline_id
if self.config.prompt_set_id and hasattr(self.config.tool_node_config, 'prompt_set_id'):
self.config.tool_node_config.prompt_set_id = self.config.prompt_set_id
self.tool_node:ToolNodeBase = self.config.tool_node_config.setup(tool_manager=tool_manager,
memory=self.memory)
self._load_sys_prompts()
def _load_sys_prompts(self):
if "json" in self.config.sys_promp_dir[-5:]:
logger.info("loading sys prompt from json")
with open(self.config.sys_promp_dir , "r") as f:
self.prompt_dict:Dict[str, str] = commentjson.load(f)
self.prompt_store = build_prompt_store(
pipeline_id=self.config.pipeline_id,
prompt_set_id=self.config.prompt_set_id,
file_path=self.config.sys_promp_dir,
)
self.prompt_dict: Dict[str, str] = self.prompt_store.get_all()
elif osp.isdir(self.config.sys_promp_dir):
logger.info("loading sys_prompt from txt")
sys_fs = glob.glob(osp.join(self.config.sys_promp_dir, "*.txt"))
sys_fs = sorted([e for e in sys_fs if not ("optional" in e)])
self.prompt_dict = {}
for sys_f in sys_fs:
key = osp.basename(sys_f).split(".")[0]
with open(sys_f, "r") as f:
self.prompt_dict[key] = f.read()
else:
err_msg = f"{self.config.sys_promp_dir} is not supported"
assert 0, err_msg
for k, _ in self.prompt_dict.items():
logger.info(f"loaded {k} system prompt")
for k in self.prompt_dict:
logger.info(f"loaded '{k}' system prompt")