use prompt store

This commit is contained in:
2026-02-10 10:54:58 +08:00
parent ede7199dfc
commit cb5b3afd05
6 changed files with 130 additions and 38 deletions

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Type, TypedDict, Literal, Dict, List
from typing import Type, TypedDict, Literal, Dict, List, Optional
import tyro
from pydantic import BaseModel, Field
from loguru import logger
@@ -9,6 +9,7 @@ from langchain.chat_models import init_chat_model
from lang_agent.config import LLMKeyConfig
from lang_agent.base import GraphBase
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
from lang_agent.components.prompt_store import build_prompt_store
from lang_agent.graphs.graph_states import State
from langchain.agents import create_agent
@@ -50,6 +51,12 @@ TOOL_SYS_PROMPT = """根据用户的心情使用self_led_control改变灯的颜
class DualConfig(LLMKeyConfig):
_target: Type = field(default_factory=lambda:Dual)
pipeline_id: Optional[str] = None
"""If set, load prompts from database (with hardcoded fallback)"""
prompt_set_id: Optional[str] = None
"""If set, load from this specific prompt set instead of the active one"""
tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)
from langchain.tools import tool
@@ -96,14 +103,23 @@ class Dual(GraphBase):
self.tool_agent = create_agent(self.tool_llm, self.tool_manager.get_langchain_tools())
# self.tool_agent = create_agent(self.tool_llm, [turn_lights])
self.prompt_store = build_prompt_store(
pipeline_id=self.config.pipeline_id,
prompt_set_id=self.config.prompt_set_id,
hardcoded={
"sys_prompt": SYS_PROMPT,
"tool_sys_prompt": TOOL_SYS_PROMPT,
},
)
self.streamable_tags = [["dual_chat_llm"]]
def _chat_call(self, state:State):
return self._agent_call_template(SYS_PROMPT, self.chat_agent, state)
return self._agent_call_template(self.prompt_store.get("sys_prompt"), self.chat_agent, state)
def _tool_call(self, state:State):
self._agent_call_template(TOOL_SYS_PROMPT, self.tool_agent, state)
self._agent_call_template(self.prompt_store.get("tool_sys_prompt"), self.tool_agent, state)
return {}
def _join(self, state:State):

View File

@@ -1,11 +1,12 @@
from dataclasses import dataclass, field
from typing import Type
from typing import Type, Optional
import tyro
import os.path as osp
from loguru import logger
from lang_agent.config import KeyConfig
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
from lang_agent.components.prompt_store import build_prompt_store
from lang_agent.base import GraphBase
from lang_agent.utils import tree_leaves
from lang_agent.graphs.graph_states import State
@@ -34,6 +35,12 @@ class ReactGraphConfig(KeyConfig):
base_url:str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
"""base url; could be used to overwrite the baseurl in llm provider"""
pipeline_id: Optional[str] = None
"""If set, load prompts from database (with file fallback)"""
prompt_set_id: Optional[str] = None
"""If set, load from this specific prompt set instead of the active one"""
tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)
def __post_init__(self):
@@ -65,8 +72,13 @@ class ReactGraph(GraphBase):
tools = self.tool_manager.get_langchain_tools()
self.agent = create_agent(self.llm, tools, checkpointer=self.memory)
with open(self.config.sys_prompt_f, "r") as f:
self.sys_prompt = f.read()
self.prompt_store = build_prompt_store(
pipeline_id=self.config.pipeline_id,
prompt_set_id=self.config.prompt_set_id,
file_path=self.config.sys_prompt_f,
default_key="sys_prompt",
)
self.sys_prompt = self.prompt_store.get("sys_prompt")
def _agent_call(self, state:State):
if state.get("messages") is not None:

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass, field, is_dataclass
from typing import Type, TypedDict, Literal, Dict, List, Tuple, Any, AsyncIterator
from typing import Type, TypedDict, Literal, Dict, List, Tuple, Any, AsyncIterator, Optional
import tyro
from pydantic import BaseModel, Field
from loguru import logger
@@ -10,6 +10,7 @@ import time
from lang_agent.config import LLMKeyConfig
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
from lang_agent.components.prompt_store import build_prompt_store
from lang_agent.base import GraphBase, ToolNodeBase
from lang_agent.graphs.graph_states import State
from lang_agent.graphs.tool_nodes import AnnotatedToolNode, ToolNodeConfig
@@ -41,6 +42,12 @@ class RoutingConfig(LLMKeyConfig):
sys_promp_dir: str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts")
"""path to directory or json contantaining system prompt for graphs; Will overwrite systemprompt from xiaozhi if 'chat_prompt' is provided"""
pipeline_id: Optional[str] = None
"""If set, load prompts from database (with file fallback)"""
prompt_set_id: Optional[str] = None
"""If set, load from this specific prompt set instead of the active one"""
tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)
tool_node_config: AnnotatedToolNode = field(default_factory=ToolNodeConfig)
@@ -83,32 +90,28 @@ class RoutingGraph(GraphBase):
tool_manager:ToolManager = self.config.tool_manager_config.setup()
self.chat_model = create_agent(self.chat_llm, [], checkpointer=self.memory)
# Propagate pipeline_id and prompt_set_id to tool node config
if self.config.pipeline_id and hasattr(self.config.tool_node_config, 'pipeline_id'):
self.config.tool_node_config.pipeline_id = self.config.pipeline_id
if self.config.prompt_set_id and hasattr(self.config.tool_node_config, 'prompt_set_id'):
self.config.tool_node_config.prompt_set_id = self.config.prompt_set_id
self.tool_node:ToolNodeBase = self.config.tool_node_config.setup(tool_manager=tool_manager,
memory=self.memory)
self._load_sys_prompts()
def _load_sys_prompts(self):
if "json" in self.config.sys_promp_dir[-5:]:
logger.info("loading sys prompt from json")
with open(self.config.sys_promp_dir , "r") as f:
self.prompt_dict:Dict[str, str] = commentjson.load(f)
self.prompt_store = build_prompt_store(
pipeline_id=self.config.pipeline_id,
prompt_set_id=self.config.prompt_set_id,
file_path=self.config.sys_promp_dir,
)
self.prompt_dict: Dict[str, str] = self.prompt_store.get_all()
elif osp.isdir(self.config.sys_promp_dir):
logger.info("loading sys_prompt from txt")
sys_fs = glob.glob(osp.join(self.config.sys_promp_dir, "*.txt"))
sys_fs = sorted([e for e in sys_fs if not ("optional" in e)])
self.prompt_dict = {}
for sys_f in sys_fs:
key = osp.basename(sys_f).split(".")[0]
with open(sys_f, "r") as f:
self.prompt_dict[key] = f.read()
else:
err_msg = f"{self.config.sys_promp_dir} is not supported"
assert 0, err_msg
for k, _ in self.prompt_dict.items():
logger.info(f"loaded {k} system prompt")
for k in self.prompt_dict:
logger.info(f"loaded '{k}' system prompt")

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass, field, is_dataclass
from typing import Type, TypedDict, Literal, Dict, List, Tuple
from typing import Type, TypedDict, Literal, Dict, List, Tuple, Optional
import tyro
import os.path as osp
import time
@@ -8,6 +8,7 @@ from loguru import logger
from lang_agent.config import InstantiateConfig, KeyConfig
from lang_agent.components.tool_manager import ToolManager
from lang_agent.components.prompt_store import build_prompt_store
from lang_agent.components.reit_llm import ReitLLM
from lang_agent.base import ToolNodeBase
from lang_agent.graphs.graph_states import State, ChattyToolState
@@ -27,6 +28,12 @@ class ToolNodeConfig(InstantiateConfig):
tool_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "tool_prompt.txt")
pipeline_id: Optional[str] = None
"""If set, load prompts from database (with file fallback)"""
prompt_set_id: Optional[str] = None
"""If set, load from this specific prompt set instead of the active one"""
class ToolNode(ToolNodeBase):
def __init__(self, config: ToolNodeConfig,
@@ -42,8 +49,13 @@ class ToolNode(ToolNodeBase):
self.llm = make_llm(tags=["tool_llm"])
self.tool_agent = create_agent(self.llm, self.tool_manager.get_langchain_tools(), checkpointer=self.mem)
with open(self.config.tool_prompt_f, "r") as f:
self.sys_prompt = f.read()
self.prompt_store = build_prompt_store(
pipeline_id=self.config.pipeline_id,
prompt_set_id=self.config.prompt_set_id,
file_path=self.config.tool_prompt_f,
default_key="tool_prompt",
)
self.sys_prompt = self.prompt_store.get("tool_prompt")
def invoke(self, state:State):
inp = {"messages":[
@@ -88,6 +100,8 @@ class ChattyToolNodeConfig(KeyConfig, ToolNodeConfig):
chatty_sys_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "chatty_prompt.txt")
"""path to chatty system prompt"""
# pipeline_id and prompt_set_id are inherited from ToolNodeConfig
tool_node_conf:ToolNodeConfig = field(default_factory=ToolNodeConfig)
@@ -124,14 +138,31 @@ class ChattyToolNode(ToolNodeBase):
self.chatty_agent = create_agent(self.chatty_llm, [], checkpointer=self.mem)
# self.tool_agent = create_agent(self.tool_llm, self.tool_manager.get_list_langchain_tools(), checkpointer=self.mem)
# Propagate pipeline_id and prompt_set_id to inner tool_node_conf
if self.config.pipeline_id and hasattr(self.config.tool_node_conf, 'pipeline_id'):
self.config.tool_node_conf.pipeline_id = self.config.pipeline_id
if self.config.prompt_set_id and hasattr(self.config.tool_node_conf, 'prompt_set_id'):
self.config.tool_node_conf.prompt_set_id = self.config.prompt_set_id
self.tool_agent = self.config.tool_node_conf.setup(tool_manager=self.tool_manager,
memory=self.mem)
with open(self.config.chatty_sys_prompt_f, "r") as f:
self.chatty_sys_prompt = f.read()
with open(self.config.tool_prompt_f, "r") as f:
self.tool_sys_prompt = f.read()
self.chatty_prompt_store = build_prompt_store(
pipeline_id=self.config.pipeline_id,
prompt_set_id=self.config.prompt_set_id,
file_path=self.config.chatty_sys_prompt_f,
default_key="chatty_prompt",
)
self.chatty_sys_prompt = self.chatty_prompt_store.get("chatty_prompt")
self.tool_prompt_store = build_prompt_store(
pipeline_id=self.config.pipeline_id,
prompt_set_id=self.config.prompt_set_id,
file_path=self.config.tool_prompt_f,
default_key="tool_prompt",
)
self.tool_sys_prompt = self.tool_prompt_store.get("tool_prompt")
def get_streamable_tags(self):
return [["chatty_llm"], ["reit_llm"]]

View File

@@ -6,7 +6,7 @@ Vision-enabled routing graph that:
"""
from dataclasses import dataclass, field
from typing import Type, TypedDict, List, Dict, Any, Tuple
from typing import Type, TypedDict, List, Dict, Any, Tuple, Optional
import tyro
import base64
import json
@@ -14,6 +14,7 @@ from loguru import logger
from lang_agent.config import LLMKeyConfig
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
from lang_agent.components.prompt_store import build_prompt_store
from lang_agent.base import GraphBase, ToolNodeBase
from lang_agent.components.client_tool_manager import ClientToolManagerConfig
@@ -104,6 +105,12 @@ class VisionRoutingConfig(LLMKeyConfig):
base_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
"""base url for API"""
pipeline_id: Optional[str] = None
"""If set, load prompts from database (with hardcoded fallback)"""
prompt_set_id: Optional[str] = None
"""If set, load from this specific prompt set instead of the active one"""
tool_manager_config: ToolManagerConfig = field(default_factory=ClientToolManagerConfig)
@@ -167,6 +174,17 @@ class VisionRoutingGraph(GraphBase):
# Create tool node for executing tools
self.tool_node = ToolNode(self.camera_tools)
# Build prompt store: DB (if pipeline_id) > hardcoded defaults
self.prompt_store = build_prompt_store(
pipeline_id=self.config.pipeline_id,
prompt_set_id=self.config.prompt_set_id,
hardcoded={
"camera_decision_prompt": CAMERA_DECISION_PROMPT,
"vision_description_prompt": VISION_DESCRIPTION_PROMPT,
"conversation_prompt": CONVERSATION_PROMPT,
},
)
def _get_human_msg(self, state: VisionRoutingState) -> HumanMessage:
"""Get user message from current invocation"""
msgs = state["inp"][0]["messages"]
@@ -180,7 +198,7 @@ class VisionRoutingGraph(GraphBase):
human_msg = self._get_human_msg(state)
messages = [
SystemMessage(content=CAMERA_DECISION_PROMPT),
SystemMessage(content=self.prompt_store.get("camera_decision_prompt")),
human_msg
]
@@ -279,7 +297,7 @@ class VisionRoutingGraph(GraphBase):
)
messages = [
SystemMessage(content=VISION_DESCRIPTION_PROMPT),
SystemMessage(content=self.prompt_store.get("vision_description_prompt")),
vision_message
]
@@ -292,7 +310,7 @@ class VisionRoutingGraph(GraphBase):
human_msg = self._get_human_msg(state)
messages = [
SystemMessage(content=CONVERSATION_PROMPT),
SystemMessage(content=self.prompt_store.get("conversation_prompt")),
human_msg
]