use prompt store

This commit is contained in:
2026-02-10 10:54:58 +08:00
parent ede7199dfc
commit cb5b3afd05
6 changed files with 130 additions and 38 deletions

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Type, TypedDict, Literal, Dict, List from typing import Type, TypedDict, Literal, Dict, List, Optional
import tyro import tyro
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from loguru import logger from loguru import logger
@@ -9,6 +9,7 @@ from langchain.chat_models import init_chat_model
from lang_agent.config import LLMKeyConfig from lang_agent.config import LLMKeyConfig
from lang_agent.base import GraphBase from lang_agent.base import GraphBase
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
from lang_agent.components.prompt_store import build_prompt_store
from lang_agent.graphs.graph_states import State from lang_agent.graphs.graph_states import State
from langchain.agents import create_agent from langchain.agents import create_agent
@@ -50,6 +51,12 @@ TOOL_SYS_PROMPT = """根据用户的心情使用self_led_control改变灯的颜
class DualConfig(LLMKeyConfig): class DualConfig(LLMKeyConfig):
_target: Type = field(default_factory=lambda:Dual) _target: Type = field(default_factory=lambda:Dual)
pipeline_id: Optional[str] = None
"""If set, load prompts from database (with hardcoded fallback)"""
prompt_set_id: Optional[str] = None
"""If set, load from this specific prompt set instead of the active one"""
tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig) tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)
from langchain.tools import tool from langchain.tools import tool
@@ -96,14 +103,23 @@ class Dual(GraphBase):
self.tool_agent = create_agent(self.tool_llm, self.tool_manager.get_langchain_tools()) self.tool_agent = create_agent(self.tool_llm, self.tool_manager.get_langchain_tools())
# self.tool_agent = create_agent(self.tool_llm, [turn_lights]) # self.tool_agent = create_agent(self.tool_llm, [turn_lights])
self.prompt_store = build_prompt_store(
pipeline_id=self.config.pipeline_id,
prompt_set_id=self.config.prompt_set_id,
hardcoded={
"sys_prompt": SYS_PROMPT,
"tool_sys_prompt": TOOL_SYS_PROMPT,
},
)
self.streamable_tags = [["dual_chat_llm"]] self.streamable_tags = [["dual_chat_llm"]]
def _chat_call(self, state:State): def _chat_call(self, state:State):
return self._agent_call_template(SYS_PROMPT, self.chat_agent, state) return self._agent_call_template(self.prompt_store.get("sys_prompt"), self.chat_agent, state)
def _tool_call(self, state:State): def _tool_call(self, state:State):
self._agent_call_template(TOOL_SYS_PROMPT, self.tool_agent, state) self._agent_call_template(self.prompt_store.get("tool_sys_prompt"), self.tool_agent, state)
return {} return {}
def _join(self, state:State): def _join(self, state:State):

View File

@@ -1,11 +1,12 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Type from typing import Type, Optional
import tyro import tyro
import os.path as osp import os.path as osp
from loguru import logger from loguru import logger
from lang_agent.config import KeyConfig from lang_agent.config import KeyConfig
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
from lang_agent.components.prompt_store import build_prompt_store
from lang_agent.base import GraphBase from lang_agent.base import GraphBase
from lang_agent.utils import tree_leaves from lang_agent.utils import tree_leaves
from lang_agent.graphs.graph_states import State from lang_agent.graphs.graph_states import State
@@ -34,6 +35,12 @@ class ReactGraphConfig(KeyConfig):
base_url:str = "https://dashscope.aliyuncs.com/compatible-mode/v1" base_url:str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
"""base url; could be used to overwrite the baseurl in llm provider""" """base url; could be used to overwrite the baseurl in llm provider"""
pipeline_id: Optional[str] = None
"""If set, load prompts from database (with file fallback)"""
prompt_set_id: Optional[str] = None
"""If set, load from this specific prompt set instead of the active one"""
tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig) tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)
def __post_init__(self): def __post_init__(self):
@@ -65,8 +72,13 @@ class ReactGraph(GraphBase):
tools = self.tool_manager.get_langchain_tools() tools = self.tool_manager.get_langchain_tools()
self.agent = create_agent(self.llm, tools, checkpointer=self.memory) self.agent = create_agent(self.llm, tools, checkpointer=self.memory)
with open(self.config.sys_prompt_f, "r") as f: self.prompt_store = build_prompt_store(
self.sys_prompt = f.read() pipeline_id=self.config.pipeline_id,
prompt_set_id=self.config.prompt_set_id,
file_path=self.config.sys_prompt_f,
default_key="sys_prompt",
)
self.sys_prompt = self.prompt_store.get("sys_prompt")
def _agent_call(self, state:State): def _agent_call(self, state:State):
if state.get("messages") is not None: if state.get("messages") is not None:

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass, field, is_dataclass from dataclasses import dataclass, field, is_dataclass
from typing import Type, TypedDict, Literal, Dict, List, Tuple, Any, AsyncIterator from typing import Type, TypedDict, Literal, Dict, List, Tuple, Any, AsyncIterator, Optional
import tyro import tyro
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from loguru import logger from loguru import logger
@@ -10,6 +10,7 @@ import time
from lang_agent.config import LLMKeyConfig from lang_agent.config import LLMKeyConfig
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
from lang_agent.components.prompt_store import build_prompt_store
from lang_agent.base import GraphBase, ToolNodeBase from lang_agent.base import GraphBase, ToolNodeBase
from lang_agent.graphs.graph_states import State from lang_agent.graphs.graph_states import State
from lang_agent.graphs.tool_nodes import AnnotatedToolNode, ToolNodeConfig from lang_agent.graphs.tool_nodes import AnnotatedToolNode, ToolNodeConfig
@@ -41,6 +42,12 @@ class RoutingConfig(LLMKeyConfig):
sys_promp_dir: str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts") sys_promp_dir: str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts")
"""path to directory or json contantaining system prompt for graphs; Will overwrite systemprompt from xiaozhi if 'chat_prompt' is provided""" """path to directory or json contantaining system prompt for graphs; Will overwrite systemprompt from xiaozhi if 'chat_prompt' is provided"""
pipeline_id: Optional[str] = None
"""If set, load prompts from database (with file fallback)"""
prompt_set_id: Optional[str] = None
"""If set, load from this specific prompt set instead of the active one"""
tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig) tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)
tool_node_config: AnnotatedToolNode = field(default_factory=ToolNodeConfig) tool_node_config: AnnotatedToolNode = field(default_factory=ToolNodeConfig)
@@ -83,32 +90,28 @@ class RoutingGraph(GraphBase):
tool_manager:ToolManager = self.config.tool_manager_config.setup() tool_manager:ToolManager = self.config.tool_manager_config.setup()
self.chat_model = create_agent(self.chat_llm, [], checkpointer=self.memory) self.chat_model = create_agent(self.chat_llm, [], checkpointer=self.memory)
# Propagate pipeline_id and prompt_set_id to tool node config
if self.config.pipeline_id and hasattr(self.config.tool_node_config, 'pipeline_id'):
self.config.tool_node_config.pipeline_id = self.config.pipeline_id
if self.config.prompt_set_id and hasattr(self.config.tool_node_config, 'prompt_set_id'):
self.config.tool_node_config.prompt_set_id = self.config.prompt_set_id
self.tool_node:ToolNodeBase = self.config.tool_node_config.setup(tool_manager=tool_manager, self.tool_node:ToolNodeBase = self.config.tool_node_config.setup(tool_manager=tool_manager,
memory=self.memory) memory=self.memory)
self._load_sys_prompts() self._load_sys_prompts()
def _load_sys_prompts(self): def _load_sys_prompts(self):
if "json" in self.config.sys_promp_dir[-5:]: self.prompt_store = build_prompt_store(
logger.info("loading sys prompt from json") pipeline_id=self.config.pipeline_id,
with open(self.config.sys_promp_dir , "r") as f: prompt_set_id=self.config.prompt_set_id,
self.prompt_dict:Dict[str, str] = commentjson.load(f) file_path=self.config.sys_promp_dir,
)
self.prompt_dict: Dict[str, str] = self.prompt_store.get_all()
elif osp.isdir(self.config.sys_promp_dir): for k in self.prompt_dict:
logger.info("loading sys_prompt from txt") logger.info(f"loaded '{k}' system prompt")
sys_fs = glob.glob(osp.join(self.config.sys_promp_dir, "*.txt"))
sys_fs = sorted([e for e in sys_fs if not ("optional" in e)])
self.prompt_dict = {}
for sys_f in sys_fs:
key = osp.basename(sys_f).split(".")[0]
with open(sys_f, "r") as f:
self.prompt_dict[key] = f.read()
else:
err_msg = f"{self.config.sys_promp_dir} is not supported"
assert 0, err_msg
for k, _ in self.prompt_dict.items():
logger.info(f"loaded {k} system prompt")

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass, field, is_dataclass from dataclasses import dataclass, field, is_dataclass
from typing import Type, TypedDict, Literal, Dict, List, Tuple from typing import Type, TypedDict, Literal, Dict, List, Tuple, Optional
import tyro import tyro
import os.path as osp import os.path as osp
import time import time
@@ -8,6 +8,7 @@ from loguru import logger
from lang_agent.config import InstantiateConfig, KeyConfig from lang_agent.config import InstantiateConfig, KeyConfig
from lang_agent.components.tool_manager import ToolManager from lang_agent.components.tool_manager import ToolManager
from lang_agent.components.prompt_store import build_prompt_store
from lang_agent.components.reit_llm import ReitLLM from lang_agent.components.reit_llm import ReitLLM
from lang_agent.base import ToolNodeBase from lang_agent.base import ToolNodeBase
from lang_agent.graphs.graph_states import State, ChattyToolState from lang_agent.graphs.graph_states import State, ChattyToolState
@@ -27,6 +28,12 @@ class ToolNodeConfig(InstantiateConfig):
tool_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "tool_prompt.txt") tool_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "tool_prompt.txt")
pipeline_id: Optional[str] = None
"""If set, load prompts from database (with file fallback)"""
prompt_set_id: Optional[str] = None
"""If set, load from this specific prompt set instead of the active one"""
class ToolNode(ToolNodeBase): class ToolNode(ToolNodeBase):
def __init__(self, config: ToolNodeConfig, def __init__(self, config: ToolNodeConfig,
@@ -42,8 +49,13 @@ class ToolNode(ToolNodeBase):
self.llm = make_llm(tags=["tool_llm"]) self.llm = make_llm(tags=["tool_llm"])
self.tool_agent = create_agent(self.llm, self.tool_manager.get_langchain_tools(), checkpointer=self.mem) self.tool_agent = create_agent(self.llm, self.tool_manager.get_langchain_tools(), checkpointer=self.mem)
with open(self.config.tool_prompt_f, "r") as f: self.prompt_store = build_prompt_store(
self.sys_prompt = f.read() pipeline_id=self.config.pipeline_id,
prompt_set_id=self.config.prompt_set_id,
file_path=self.config.tool_prompt_f,
default_key="tool_prompt",
)
self.sys_prompt = self.prompt_store.get("tool_prompt")
def invoke(self, state:State): def invoke(self, state:State):
inp = {"messages":[ inp = {"messages":[
@@ -88,6 +100,8 @@ class ChattyToolNodeConfig(KeyConfig, ToolNodeConfig):
chatty_sys_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "chatty_prompt.txt") chatty_sys_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "chatty_prompt.txt")
"""path to chatty system prompt""" """path to chatty system prompt"""
# pipeline_id and prompt_set_id are inherited from ToolNodeConfig
tool_node_conf:ToolNodeConfig = field(default_factory=ToolNodeConfig) tool_node_conf:ToolNodeConfig = field(default_factory=ToolNodeConfig)
@@ -124,14 +138,31 @@ class ChattyToolNode(ToolNodeBase):
self.chatty_agent = create_agent(self.chatty_llm, [], checkpointer=self.mem) self.chatty_agent = create_agent(self.chatty_llm, [], checkpointer=self.mem)
# self.tool_agent = create_agent(self.tool_llm, self.tool_manager.get_list_langchain_tools(), checkpointer=self.mem) # self.tool_agent = create_agent(self.tool_llm, self.tool_manager.get_list_langchain_tools(), checkpointer=self.mem)
# Propagate pipeline_id and prompt_set_id to inner tool_node_conf
if self.config.pipeline_id and hasattr(self.config.tool_node_conf, 'pipeline_id'):
self.config.tool_node_conf.pipeline_id = self.config.pipeline_id
if self.config.prompt_set_id and hasattr(self.config.tool_node_conf, 'prompt_set_id'):
self.config.tool_node_conf.prompt_set_id = self.config.prompt_set_id
self.tool_agent = self.config.tool_node_conf.setup(tool_manager=self.tool_manager, self.tool_agent = self.config.tool_node_conf.setup(tool_manager=self.tool_manager,
memory=self.mem) memory=self.mem)
with open(self.config.chatty_sys_prompt_f, "r") as f: self.chatty_prompt_store = build_prompt_store(
self.chatty_sys_prompt = f.read() pipeline_id=self.config.pipeline_id,
prompt_set_id=self.config.prompt_set_id,
file_path=self.config.chatty_sys_prompt_f,
default_key="chatty_prompt",
)
self.chatty_sys_prompt = self.chatty_prompt_store.get("chatty_prompt")
with open(self.config.tool_prompt_f, "r") as f: self.tool_prompt_store = build_prompt_store(
self.tool_sys_prompt = f.read() pipeline_id=self.config.pipeline_id,
prompt_set_id=self.config.prompt_set_id,
file_path=self.config.tool_prompt_f,
default_key="tool_prompt",
)
self.tool_sys_prompt = self.tool_prompt_store.get("tool_prompt")
def get_streamable_tags(self): def get_streamable_tags(self):
return [["chatty_llm"], ["reit_llm"]] return [["chatty_llm"], ["reit_llm"]]

View File

@@ -6,7 +6,7 @@ Vision-enabled routing graph that:
""" """
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Type, TypedDict, List, Dict, Any, Tuple from typing import Type, TypedDict, List, Dict, Any, Tuple, Optional
import tyro import tyro
import base64 import base64
import json import json
@@ -14,6 +14,7 @@ from loguru import logger
from lang_agent.config import LLMKeyConfig from lang_agent.config import LLMKeyConfig
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
from lang_agent.components.prompt_store import build_prompt_store
from lang_agent.base import GraphBase, ToolNodeBase from lang_agent.base import GraphBase, ToolNodeBase
from lang_agent.components.client_tool_manager import ClientToolManagerConfig from lang_agent.components.client_tool_manager import ClientToolManagerConfig
@@ -104,6 +105,12 @@ class VisionRoutingConfig(LLMKeyConfig):
base_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1" base_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
"""base url for API""" """base url for API"""
pipeline_id: Optional[str] = None
"""If set, load prompts from database (with hardcoded fallback)"""
prompt_set_id: Optional[str] = None
"""If set, load from this specific prompt set instead of the active one"""
tool_manager_config: ToolManagerConfig = field(default_factory=ClientToolManagerConfig) tool_manager_config: ToolManagerConfig = field(default_factory=ClientToolManagerConfig)
@@ -167,6 +174,17 @@ class VisionRoutingGraph(GraphBase):
# Create tool node for executing tools # Create tool node for executing tools
self.tool_node = ToolNode(self.camera_tools) self.tool_node = ToolNode(self.camera_tools)
# Build prompt store: DB (if pipeline_id) > hardcoded defaults
self.prompt_store = build_prompt_store(
pipeline_id=self.config.pipeline_id,
prompt_set_id=self.config.prompt_set_id,
hardcoded={
"camera_decision_prompt": CAMERA_DECISION_PROMPT,
"vision_description_prompt": VISION_DESCRIPTION_PROMPT,
"conversation_prompt": CONVERSATION_PROMPT,
},
)
def _get_human_msg(self, state: VisionRoutingState) -> HumanMessage: def _get_human_msg(self, state: VisionRoutingState) -> HumanMessage:
"""Get user message from current invocation""" """Get user message from current invocation"""
msgs = state["inp"][0]["messages"] msgs = state["inp"][0]["messages"]
@@ -180,7 +198,7 @@ class VisionRoutingGraph(GraphBase):
human_msg = self._get_human_msg(state) human_msg = self._get_human_msg(state)
messages = [ messages = [
SystemMessage(content=CAMERA_DECISION_PROMPT), SystemMessage(content=self.prompt_store.get("camera_decision_prompt")),
human_msg human_msg
] ]
@@ -279,7 +297,7 @@ class VisionRoutingGraph(GraphBase):
) )
messages = [ messages = [
SystemMessage(content=VISION_DESCRIPTION_PROMPT), SystemMessage(content=self.prompt_store.get("vision_description_prompt")),
vision_message vision_message
] ]
@@ -292,7 +310,7 @@ class VisionRoutingGraph(GraphBase):
human_msg = self._get_human_msg(state) human_msg = self._get_human_msg(state)
messages = [ messages = [
SystemMessage(content=CONVERSATION_PROMPT), SystemMessage(content=self.prompt_store.get("conversation_prompt")),
human_msg human_msg
] ]

View File

@@ -73,6 +73,12 @@ class PipelineConfig(KeyConfig):
port:int = 23 port:int = 23
"""what is my port""" """what is my port"""
pipeline_id: str = None
"""If set, load prompts from database (with file fallback)"""
prompt_set_id: str = None
"""If set, load from this specific prompt set instead of the active one"""
# graph_config: AnnotatedGraph = field(default_factory=ReactGraphConfig) # graph_config: AnnotatedGraph = field(default_factory=ReactGraphConfig)
graph_config: AnnotatedGraph = field(default_factory=RoutingConfig) graph_config: AnnotatedGraph = field(default_factory=RoutingConfig)
@@ -97,6 +103,12 @@ class Pipeline:
self.config.graph_config.base_url = self.config.base_url if self.config.base_url is not None else self.config.graph_config.base_url self.config.graph_config.base_url = self.config.base_url if self.config.base_url is not None else self.config.graph_config.base_url
self.config.graph_config.api_key = self.config.api_key self.config.graph_config.api_key = self.config.api_key
# Propagate pipeline_id and prompt_set_id to graph config for DB prompt loading
if self.config.pipeline_id is not None and hasattr(self.config.graph_config, 'pipeline_id'):
self.config.graph_config.pipeline_id = self.config.pipeline_id
if self.config.prompt_set_id is not None and hasattr(self.config.graph_config, 'prompt_set_id'):
self.config.graph_config.prompt_set_id = self.config.prompt_set_id
self.graph:GraphBase = self.config.graph_config.setup() self.graph:GraphBase = self.config.graph_config.setup()
def show_graph(self): def show_graph(self):