use prompt store
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass, field, is_dataclass
|
||||
from typing import Type, TypedDict, Literal, Dict, List, Tuple, Any, AsyncIterator
|
||||
from typing import Type, TypedDict, Literal, Dict, List, Tuple, Any, AsyncIterator, Optional
|
||||
import tyro
|
||||
from pydantic import BaseModel, Field
|
||||
from loguru import logger
|
||||
@@ -10,6 +10,7 @@ import time
|
||||
|
||||
from lang_agent.config import LLMKeyConfig
|
||||
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
|
||||
from lang_agent.components.prompt_store import build_prompt_store
|
||||
from lang_agent.base import GraphBase, ToolNodeBase
|
||||
from lang_agent.graphs.graph_states import State
|
||||
from lang_agent.graphs.tool_nodes import AnnotatedToolNode, ToolNodeConfig
|
||||
@@ -41,6 +42,12 @@ class RoutingConfig(LLMKeyConfig):
|
||||
sys_promp_dir: str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts")
|
||||
"""path to directory or json contantaining system prompt for graphs; Will overwrite systemprompt from xiaozhi if 'chat_prompt' is provided"""
|
||||
|
||||
pipeline_id: Optional[str] = None
|
||||
"""If set, load prompts from database (with file fallback)"""
|
||||
|
||||
prompt_set_id: Optional[str] = None
|
||||
"""If set, load from this specific prompt set instead of the active one"""
|
||||
|
||||
tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)
|
||||
|
||||
tool_node_config: AnnotatedToolNode = field(default_factory=ToolNodeConfig)
|
||||
@@ -83,32 +90,28 @@ class RoutingGraph(GraphBase):
|
||||
|
||||
tool_manager:ToolManager = self.config.tool_manager_config.setup()
|
||||
self.chat_model = create_agent(self.chat_llm, [], checkpointer=self.memory)
|
||||
|
||||
# Propagate pipeline_id and prompt_set_id to tool node config
|
||||
if self.config.pipeline_id and hasattr(self.config.tool_node_config, 'pipeline_id'):
|
||||
self.config.tool_node_config.pipeline_id = self.config.pipeline_id
|
||||
if self.config.prompt_set_id and hasattr(self.config.tool_node_config, 'prompt_set_id'):
|
||||
self.config.tool_node_config.prompt_set_id = self.config.prompt_set_id
|
||||
|
||||
self.tool_node:ToolNodeBase = self.config.tool_node_config.setup(tool_manager=tool_manager,
|
||||
memory=self.memory)
|
||||
|
||||
self._load_sys_prompts()
|
||||
|
||||
def _load_sys_prompts(self):
|
||||
if "json" in self.config.sys_promp_dir[-5:]:
|
||||
logger.info("loading sys prompt from json")
|
||||
with open(self.config.sys_promp_dir , "r") as f:
|
||||
self.prompt_dict:Dict[str, str] = commentjson.load(f)
|
||||
self.prompt_store = build_prompt_store(
|
||||
pipeline_id=self.config.pipeline_id,
|
||||
prompt_set_id=self.config.prompt_set_id,
|
||||
file_path=self.config.sys_promp_dir,
|
||||
)
|
||||
self.prompt_dict: Dict[str, str] = self.prompt_store.get_all()
|
||||
|
||||
elif osp.isdir(self.config.sys_promp_dir):
|
||||
logger.info("loading sys_prompt from txt")
|
||||
sys_fs = glob.glob(osp.join(self.config.sys_promp_dir, "*.txt"))
|
||||
sys_fs = sorted([e for e in sys_fs if not ("optional" in e)])
|
||||
self.prompt_dict = {}
|
||||
for sys_f in sys_fs:
|
||||
key = osp.basename(sys_f).split(".")[0]
|
||||
with open(sys_f, "r") as f:
|
||||
self.prompt_dict[key] = f.read()
|
||||
else:
|
||||
err_msg = f"{self.config.sys_promp_dir} is not supported"
|
||||
assert 0, err_msg
|
||||
|
||||
for k, _ in self.prompt_dict.items():
|
||||
logger.info(f"loaded {k} system prompt")
|
||||
for k in self.prompt_dict:
|
||||
logger.info(f"loaded '{k}' system prompt")
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user