diff --git a/lang_agent/graphs/dual_path.py b/lang_agent/graphs/dual_path.py index 62642c3..2c415ad 100644 --- a/lang_agent/graphs/dual_path.py +++ b/lang_agent/graphs/dual_path.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Type, TypedDict, Literal, Dict, List +from typing import Type, TypedDict, Literal, Dict, List, Optional import tyro from pydantic import BaseModel, Field from loguru import logger @@ -9,6 +9,7 @@ from langchain.chat_models import init_chat_model from lang_agent.config import LLMKeyConfig from lang_agent.base import GraphBase from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig +from lang_agent.components.prompt_store import build_prompt_store from lang_agent.graphs.graph_states import State from langchain.agents import create_agent @@ -50,6 +51,12 @@ TOOL_SYS_PROMPT = """根据用户的心情使用self_led_control改变灯的颜 class DualConfig(LLMKeyConfig): _target: Type = field(default_factory=lambda:Dual) + pipeline_id: Optional[str] = None + """If set, load prompts from database (with hardcoded fallback)""" + + prompt_set_id: Optional[str] = None + """If set, load from this specific prompt set instead of the active one""" + tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig) from langchain.tools import tool @@ -96,14 +103,23 @@ class Dual(GraphBase): self.tool_agent = create_agent(self.tool_llm, self.tool_manager.get_langchain_tools()) # self.tool_agent = create_agent(self.tool_llm, [turn_lights]) + self.prompt_store = build_prompt_store( + pipeline_id=self.config.pipeline_id, + prompt_set_id=self.config.prompt_set_id, + hardcoded={ + "sys_prompt": SYS_PROMPT, + "tool_sys_prompt": TOOL_SYS_PROMPT, + }, + ) + self.streamable_tags = [["dual_chat_llm"]] def _chat_call(self, state:State): - return self._agent_call_template(SYS_PROMPT, self.chat_agent, state) + return self._agent_call_template(self.prompt_store.get("sys_prompt"), self.chat_agent, state) def _tool_call(self, state:State): - self._agent_call_template(TOOL_SYS_PROMPT, self.tool_agent, state) + self._agent_call_template(self.prompt_store.get("tool_sys_prompt"), self.tool_agent, state) return {} def _join(self, state:State): diff --git a/lang_agent/graphs/react.py b/lang_agent/graphs/react.py index c4e2a6b..5efb7a8 100644 --- a/lang_agent/graphs/react.py +++ b/lang_agent/graphs/react.py @@ -1,11 +1,12 @@ from dataclasses import dataclass, field -from typing import Type +from typing import Type, Optional import tyro import os.path as osp from loguru import logger from lang_agent.config import KeyConfig from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig +from lang_agent.components.prompt_store import build_prompt_store from lang_agent.base import GraphBase from lang_agent.utils import tree_leaves from lang_agent.graphs.graph_states import State @@ -34,6 +35,12 @@ class ReactGraphConfig(KeyConfig): base_url:str = "https://dashscope.aliyuncs.com/compatible-mode/v1" """base url; could be used to overwrite the baseurl in llm provider""" + pipeline_id: Optional[str] = None + """If set, load prompts from database (with file fallback)""" + + prompt_set_id: Optional[str] = None + """If set, load from this specific prompt set instead of the active one""" + tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig) def __post_init__(self): @@ -65,8 +72,13 @@ class ReactGraph(GraphBase): tools = self.tool_manager.get_langchain_tools() self.agent = create_agent(self.llm, tools, checkpointer=self.memory) - with open(self.config.sys_prompt_f, "r") as f: - self.sys_prompt = f.read() + self.prompt_store = build_prompt_store( + pipeline_id=self.config.pipeline_id, + prompt_set_id=self.config.prompt_set_id, + file_path=self.config.sys_prompt_f, + default_key="sys_prompt", + ) + self.sys_prompt = self.prompt_store.get("sys_prompt") def _agent_call(self, state:State): if state.get("messages") is not None: diff --git a/lang_agent/graphs/routing.py b/lang_agent/graphs/routing.py index 93ffd4b..f3bfa74 100644 --- a/lang_agent/graphs/routing.py +++ b/lang_agent/graphs/routing.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field, is_dataclass -from typing import Type, TypedDict, Literal, Dict, List, Tuple, Any, AsyncIterator +from typing import Type, TypedDict, Literal, Dict, List, Tuple, Any, AsyncIterator, Optional import tyro from pydantic import BaseModel, Field from loguru import logger @@ -10,6 +10,7 @@ import time from lang_agent.config import LLMKeyConfig from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig +from lang_agent.components.prompt_store import build_prompt_store from lang_agent.base import GraphBase, ToolNodeBase from lang_agent.graphs.graph_states import State from lang_agent.graphs.tool_nodes import AnnotatedToolNode, ToolNodeConfig @@ -41,6 +42,12 @@ class RoutingConfig(LLMKeyConfig): sys_promp_dir: str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts") """path to directory or json contantaining system prompt for graphs; Will overwrite systemprompt from xiaozhi if 'chat_prompt' is provided""" + pipeline_id: Optional[str] = None + """If set, load prompts from database (with file fallback)""" + + prompt_set_id: Optional[str] = None + """If set, load from this specific prompt set instead of the active one""" + tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig) tool_node_config: AnnotatedToolNode = field(default_factory=ToolNodeConfig) @@ -83,32 +90,28 @@ class RoutingGraph(GraphBase): tool_manager:ToolManager = self.config.tool_manager_config.setup() self.chat_model = create_agent(self.chat_llm, [], checkpointer=self.memory) + + # Propagate pipeline_id and prompt_set_id to tool node config + if self.config.pipeline_id and hasattr(self.config.tool_node_config, 'pipeline_id'): + self.config.tool_node_config.pipeline_id = self.config.pipeline_id + if self.config.prompt_set_id and hasattr(self.config.tool_node_config, 'prompt_set_id'): + self.config.tool_node_config.prompt_set_id = self.config.prompt_set_id + self.tool_node:ToolNodeBase = self.config.tool_node_config.setup(tool_manager=tool_manager, memory=self.memory) self._load_sys_prompts() def _load_sys_prompts(self): - if "json" in self.config.sys_promp_dir[-5:]: - logger.info("loading sys prompt from json") - with open(self.config.sys_promp_dir , "r") as f: - self.prompt_dict:Dict[str, str] = commentjson.load(f) + self.prompt_store = build_prompt_store( + pipeline_id=self.config.pipeline_id, + prompt_set_id=self.config.prompt_set_id, + file_path=self.config.sys_promp_dir, + ) + self.prompt_dict: Dict[str, str] = self.prompt_store.get_all() - elif osp.isdir(self.config.sys_promp_dir): - logger.info("loading sys_prompt from txt") - sys_fs = glob.glob(osp.join(self.config.sys_promp_dir, "*.txt")) - sys_fs = sorted([e for e in sys_fs if not ("optional" in e)]) - self.prompt_dict = {} - for sys_f in sys_fs: - key = osp.basename(sys_f).split(".")[0] - with open(sys_f, "r") as f: - self.prompt_dict[key] = f.read() - else: - err_msg = f"{self.config.sys_promp_dir} is not supported" - assert 0, err_msg - - for k, _ in self.prompt_dict.items(): - logger.info(f"loaded {k} system prompt") + for k in self.prompt_dict: + logger.info(f"loaded '{k}' system prompt") diff --git a/lang_agent/graphs/tool_nodes.py b/lang_agent/graphs/tool_nodes.py index 4f2248a..8a5ac30 100644 --- a/lang_agent/graphs/tool_nodes.py +++ b/lang_agent/graphs/tool_nodes.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field, is_dataclass -from typing import Type, TypedDict, Literal, Dict, List, Tuple +from typing import Type, TypedDict, Literal, Dict, List, Tuple, Optional import tyro import os.path as osp import time @@ -8,6 +8,7 @@ from loguru import logger from lang_agent.config import InstantiateConfig, KeyConfig from lang_agent.components.tool_manager import ToolManager +from lang_agent.components.prompt_store import build_prompt_store from lang_agent.components.reit_llm import ReitLLM from lang_agent.base import ToolNodeBase from lang_agent.graphs.graph_states import State, ChattyToolState @@ -27,6 +28,12 @@ class ToolNodeConfig(InstantiateConfig): tool_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "tool_prompt.txt") + pipeline_id: Optional[str] = None + """If set, load prompts from database (with file fallback)""" + + prompt_set_id: Optional[str] = None + """If set, load from this specific prompt set instead of the active one""" + class ToolNode(ToolNodeBase): def __init__(self, config: ToolNodeConfig, @@ -42,8 +49,13 @@ class ToolNode(ToolNodeBase): self.llm = make_llm(tags=["tool_llm"]) self.tool_agent = create_agent(self.llm, self.tool_manager.get_langchain_tools(), checkpointer=self.mem) - with open(self.config.tool_prompt_f, "r") as f: - self.sys_prompt = f.read() + self.prompt_store = build_prompt_store( + pipeline_id=self.config.pipeline_id, + prompt_set_id=self.config.prompt_set_id, + file_path=self.config.tool_prompt_f, + default_key="tool_prompt", + ) + self.sys_prompt = self.prompt_store.get("tool_prompt") def invoke(self, state:State): inp = {"messages":[ @@ -88,6 +100,8 @@ class ChattyToolNodeConfig(KeyConfig, ToolNodeConfig): chatty_sys_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "chatty_prompt.txt") """path to chatty system prompt""" + # pipeline_id and prompt_set_id are inherited from ToolNodeConfig + tool_node_conf:ToolNodeConfig = field(default_factory=ToolNodeConfig) @@ -124,14 +138,31 @@ class ChattyToolNode(ToolNodeBase): self.chatty_agent = create_agent(self.chatty_llm, [], checkpointer=self.mem) # self.tool_agent = create_agent(self.tool_llm, self.tool_manager.get_list_langchain_tools(), checkpointer=self.mem) + + # Propagate pipeline_id and prompt_set_id to inner tool_node_conf + if self.config.pipeline_id and hasattr(self.config.tool_node_conf, 'pipeline_id'): + self.config.tool_node_conf.pipeline_id = self.config.pipeline_id + if self.config.prompt_set_id and hasattr(self.config.tool_node_conf, 'prompt_set_id'): + self.config.tool_node_conf.prompt_set_id = self.config.prompt_set_id + self.tool_agent = self.config.tool_node_conf.setup(tool_manager=self.tool_manager, memory=self.mem) - with open(self.config.chatty_sys_prompt_f, "r") as f: - self.chatty_sys_prompt = f.read() - - with open(self.config.tool_prompt_f, "r") as f: - self.tool_sys_prompt = f.read() + self.chatty_prompt_store = build_prompt_store( + pipeline_id=self.config.pipeline_id, + prompt_set_id=self.config.prompt_set_id, + file_path=self.config.chatty_sys_prompt_f, + default_key="chatty_prompt", + ) + self.chatty_sys_prompt = self.chatty_prompt_store.get("chatty_prompt") + + self.tool_prompt_store = build_prompt_store( + pipeline_id=self.config.pipeline_id, + prompt_set_id=self.config.prompt_set_id, + file_path=self.config.tool_prompt_f, + default_key="tool_prompt", + ) + self.tool_sys_prompt = self.tool_prompt_store.get("tool_prompt") def get_streamable_tags(self): return [["chatty_llm"], ["reit_llm"]] diff --git a/lang_agent/graphs/vision_routing.py b/lang_agent/graphs/vision_routing.py index 33d4e0e..9c0ff17 100644 --- a/lang_agent/graphs/vision_routing.py +++ b/lang_agent/graphs/vision_routing.py @@ -6,7 +6,7 @@ Vision-enabled routing graph that: """ from dataclasses import dataclass, field -from typing import Type, TypedDict, List, Dict, Any, Tuple +from typing import Type, TypedDict, List, Dict, Any, Tuple, Optional import tyro import base64 import json @@ -14,6 +14,7 @@ from loguru import logger from lang_agent.config import LLMKeyConfig from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig +from lang_agent.components.prompt_store import build_prompt_store from lang_agent.base import GraphBase, ToolNodeBase from lang_agent.components.client_tool_manager import ClientToolManagerConfig @@ -104,6 +105,12 @@ class VisionRoutingConfig(LLMKeyConfig): base_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1" """base url for API""" + pipeline_id: Optional[str] = None + """If set, load prompts from database (with hardcoded fallback)""" + + prompt_set_id: Optional[str] = None + """If set, load from this specific prompt set instead of the active one""" + tool_manager_config: ToolManagerConfig = field(default_factory=ClientToolManagerConfig) @@ -167,6 +174,17 @@ class VisionRoutingGraph(GraphBase): # Create tool node for executing tools self.tool_node = ToolNode(self.camera_tools) + # Build prompt store: DB (if pipeline_id) > hardcoded defaults + self.prompt_store = build_prompt_store( + pipeline_id=self.config.pipeline_id, + prompt_set_id=self.config.prompt_set_id, + hardcoded={ + "camera_decision_prompt": CAMERA_DECISION_PROMPT, + "vision_description_prompt": VISION_DESCRIPTION_PROMPT, + "conversation_prompt": CONVERSATION_PROMPT, + }, + ) + def _get_human_msg(self, state: VisionRoutingState) -> HumanMessage: """Get user message from current invocation""" msgs = state["inp"][0]["messages"] @@ -180,7 +198,7 @@ class VisionRoutingGraph(GraphBase): human_msg = self._get_human_msg(state) messages = [ - SystemMessage(content=CAMERA_DECISION_PROMPT), + SystemMessage(content=self.prompt_store.get("camera_decision_prompt")), human_msg ] @@ -279,7 +297,7 @@ class VisionRoutingGraph(GraphBase): ) messages = [ - SystemMessage(content=VISION_DESCRIPTION_PROMPT), + SystemMessage(content=self.prompt_store.get("vision_description_prompt")), vision_message ] @@ -292,7 +310,7 @@ class VisionRoutingGraph(GraphBase): human_msg = self._get_human_msg(state) messages = [ - SystemMessage(content=CONVERSATION_PROMPT), + SystemMessage(content=self.prompt_store.get("conversation_prompt")), human_msg ] diff --git a/lang_agent/pipeline.py b/lang_agent/pipeline.py index fd64f07..d5a00b5 100644 --- a/lang_agent/pipeline.py +++ b/lang_agent/pipeline.py @@ -73,6 +73,12 @@ class PipelineConfig(KeyConfig): port:int = 23 """what is my port""" + pipeline_id: str = None + """If set, load prompts from database (with file fallback)""" + + prompt_set_id: str = None + """If set, load from this specific prompt set instead of the active one""" + # graph_config: AnnotatedGraph = field(default_factory=ReactGraphConfig) graph_config: AnnotatedGraph = field(default_factory=RoutingConfig) @@ -97,6 +103,12 @@ class Pipeline: self.config.graph_config.base_url = self.config.base_url if self.config.base_url is not None else self.config.graph_config.base_url self.config.graph_config.api_key = self.config.api_key + # Propagate pipeline_id and prompt_set_id to graph config for DB prompt loading + if self.config.pipeline_id is not None and hasattr(self.config.graph_config, 'pipeline_id'): + self.config.graph_config.pipeline_id = self.config.pipeline_id + if self.config.prompt_set_id is not None and hasattr(self.config.graph_config, 'prompt_set_id'): + self.config.graph_config.prompt_set_id = self.config.prompt_set_id + self.graph:GraphBase = self.config.graph_config.setup() def show_graph(self):