use prompt store
@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import Type, TypedDict, Literal, Dict, List
+from typing import Type, TypedDict, Literal, Dict, List, Optional
 import tyro
 from pydantic import BaseModel, Field
 from loguru import logger
@@ -9,6 +9,7 @@ from langchain.chat_models import init_chat_model
 from lang_agent.config import LLMKeyConfig
 from lang_agent.base import GraphBase
 from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
+from lang_agent.components.prompt_store import build_prompt_store
 from lang_agent.graphs.graph_states import State

 from langchain.agents import create_agent
@@ -50,6 +51,12 @@ TOOL_SYS_PROMPT = """根据用户的心情使用self_led_control改变灯的颜
 class DualConfig(LLMKeyConfig):
     _target: Type = field(default_factory=lambda:Dual)

+    pipeline_id: Optional[str] = None
+    """If set, load prompts from database (with hardcoded fallback)"""
+
+    prompt_set_id: Optional[str] = None
+    """If set, load from this specific prompt set instead of the active one"""
+
     tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)

 from langchain.tools import tool
@@ -96,14 +103,23 @@ class Dual(GraphBase):
         self.tool_agent = create_agent(self.tool_llm, self.tool_manager.get_langchain_tools())
         # self.tool_agent = create_agent(self.tool_llm, [turn_lights])

+        self.prompt_store = build_prompt_store(
+            pipeline_id=self.config.pipeline_id,
+            prompt_set_id=self.config.prompt_set_id,
+            hardcoded={
+                "sys_prompt": SYS_PROMPT,
+                "tool_sys_prompt": TOOL_SYS_PROMPT,
+            },
+        )
+
         self.streamable_tags = [["dual_chat_llm"]]


     def _chat_call(self, state:State):
-        return self._agent_call_template(SYS_PROMPT, self.chat_agent, state)
+        return self._agent_call_template(self.prompt_store.get("sys_prompt"), self.chat_agent, state)

     def _tool_call(self, state:State):
-        self._agent_call_template(TOOL_SYS_PROMPT, self.tool_agent, state)
+        self._agent_call_template(self.prompt_store.get("tool_sys_prompt"), self.tool_agent, state)
         return {}

     def _join(self, state:State):
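The hunks above, and the ones that follow from four more files in the same commit, only show call sites of `build_prompt_store`. For orientation, a facade consistent with those call sites might look like the sketch below. This is an assumption, not the actual `lang_agent/components/prompt_store.py` (which this diff does not include): the database lookup is stubbed, and the file handling mirrors the json/txt logic that the routing-graph hunks further down delete. The resolution order implied by the config docstrings is database first (when `pipeline_id` is set), then file, then hardcoded defaults.

```python
# Hypothetical sketch; the real implementation lives in
# lang_agent/components/prompt_store.py and is not shown in this diff.
import glob
import json
import os.path as osp
from typing import Dict, Optional


class PromptStore:
    """Read-only view over a resolved {key: prompt} mapping."""

    def __init__(self, prompts: Dict[str, str]):
        self._prompts = prompts

    def get(self, key: str) -> str:
        return self._prompts[key]

    def get_all(self) -> Dict[str, str]:
        return dict(self._prompts)


def _load_from_db(pipeline_id: str, prompt_set_id: Optional[str]) -> Optional[Dict[str, str]]:
    # Assumption: the real code queries the pipeline's active prompt set,
    # or the one named by prompt_set_id. Stubbed here so the fallbacks apply.
    return None


def build_prompt_store(
    pipeline_id: Optional[str] = None,
    prompt_set_id: Optional[str] = None,
    file_path: Optional[str] = None,
    default_key: str = "sys_prompt",
    hardcoded: Optional[Dict[str, str]] = None,
) -> PromptStore:
    # 1. Database, when a pipeline is configured.
    if pipeline_id is not None:
        prompts = _load_from_db(pipeline_id, prompt_set_id)
        if prompts:
            return PromptStore(prompts)

    # 2. File fallback: a JSON dict, a directory of *.txt files,
    #    or a single prompt file stored under default_key.
    if file_path is not None:
        if file_path.endswith(".json"):
            with open(file_path, "r") as f:
                return PromptStore(json.load(f))
        if osp.isdir(file_path):
            prompts = {}
            for p in sorted(glob.glob(osp.join(file_path, "*.txt"))):
                if "optional" in p:
                    continue  # the old loader skipped optional prompts too
                with open(p, "r") as f:
                    prompts[osp.basename(p).split(".")[0]] = f.read()
            return PromptStore(prompts)
        with open(file_path, "r") as f:
            return PromptStore({default_key: f.read()})

    # 3. Hardcoded defaults.
    return PromptStore(hardcoded or {})
```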
@@ -1,11 +1,12 @@
 from dataclasses import dataclass, field
-from typing import Type
+from typing import Type, Optional
 import tyro
 import os.path as osp
 from loguru import logger

 from lang_agent.config import KeyConfig
 from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
+from lang_agent.components.prompt_store import build_prompt_store
 from lang_agent.base import GraphBase
 from lang_agent.utils import tree_leaves
 from lang_agent.graphs.graph_states import State
@@ -34,6 +35,12 @@ class ReactGraphConfig(KeyConfig):
     base_url:str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
     """base url; could be used to overwrite the baseurl in llm provider"""

+    pipeline_id: Optional[str] = None
+    """If set, load prompts from database (with file fallback)"""
+
+    prompt_set_id: Optional[str] = None
+    """If set, load from this specific prompt set instead of the active one"""
+
     tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)

     def __post_init__(self):
@@ -65,8 +72,13 @@ class ReactGraph(GraphBase):
         tools = self.tool_manager.get_langchain_tools()
         self.agent = create_agent(self.llm, tools, checkpointer=self.memory)

-        with open(self.config.sys_prompt_f, "r") as f:
-            self.sys_prompt = f.read()
+        self.prompt_store = build_prompt_store(
+            pipeline_id=self.config.pipeline_id,
+            prompt_set_id=self.config.prompt_set_id,
+            file_path=self.config.sys_prompt_f,
+            default_key="sys_prompt",
+        )
+        self.sys_prompt = self.prompt_store.get("sys_prompt")

     def _agent_call(self, state:State):
         if state.get("messages") is not None:
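With `pipeline_id` left at `None`, the new path reduces to the old one: the store wraps the contents of `sys_prompt_f` under the key `"sys_prompt"`. A minimal check under the sketch above, with a hypothetical prompt file:

```python
# pipeline_id is None, so the store falls through to the single-file branch
# and reproduces the old open()/read() behaviour.
store = build_prompt_store(
    file_path="configs/react_sys_prompt.txt",  # hypothetical path
    default_key="sys_prompt",
)
with open("configs/react_sys_prompt.txt", "r") as f:
    assert store.get("sys_prompt") == f.read()
```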
@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field, is_dataclass
-from typing import Type, TypedDict, Literal, Dict, List, Tuple, Any, AsyncIterator
+from typing import Type, TypedDict, Literal, Dict, List, Tuple, Any, AsyncIterator, Optional
 import tyro
 from pydantic import BaseModel, Field
 from loguru import logger
@@ -10,6 +10,7 @@ import time

 from lang_agent.config import LLMKeyConfig
 from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
+from lang_agent.components.prompt_store import build_prompt_store
 from lang_agent.base import GraphBase, ToolNodeBase
 from lang_agent.graphs.graph_states import State
 from lang_agent.graphs.tool_nodes import AnnotatedToolNode, ToolNodeConfig
@@ -41,6 +42,12 @@ class RoutingConfig(LLMKeyConfig):
     sys_promp_dir: str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts")
     """path to directory or json contantaining system prompt for graphs; Will overwrite systemprompt from xiaozhi if 'chat_prompt' is provided"""

+    pipeline_id: Optional[str] = None
+    """If set, load prompts from database (with file fallback)"""
+
+    prompt_set_id: Optional[str] = None
+    """If set, load from this specific prompt set instead of the active one"""
+
     tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)

     tool_node_config: AnnotatedToolNode = field(default_factory=ToolNodeConfig)
@@ -83,32 +90,28 @@ class RoutingGraph(GraphBase):

         tool_manager:ToolManager = self.config.tool_manager_config.setup()
         self.chat_model = create_agent(self.chat_llm, [], checkpointer=self.memory)
+
+        # Propagate pipeline_id and prompt_set_id to tool node config
+        if self.config.pipeline_id and hasattr(self.config.tool_node_config, 'pipeline_id'):
+            self.config.tool_node_config.pipeline_id = self.config.pipeline_id
+        if self.config.prompt_set_id and hasattr(self.config.tool_node_config, 'prompt_set_id'):
+            self.config.tool_node_config.prompt_set_id = self.config.prompt_set_id

         self.tool_node:ToolNodeBase = self.config.tool_node_config.setup(tool_manager=tool_manager,
                                                                          memory=self.memory)

         self._load_sys_prompts()

     def _load_sys_prompts(self):
-        if "json" in self.config.sys_promp_dir[-5:]:
-            logger.info("loading sys prompt from json")
-            with open(self.config.sys_promp_dir , "r") as f:
-                self.prompt_dict:Dict[str, str] = commentjson.load(f)
+        self.prompt_store = build_prompt_store(
+            pipeline_id=self.config.pipeline_id,
+            prompt_set_id=self.config.prompt_set_id,
+            file_path=self.config.sys_promp_dir,
+        )
+        self.prompt_dict: Dict[str, str] = self.prompt_store.get_all()

-        elif osp.isdir(self.config.sys_promp_dir):
-            logger.info("loading sys_prompt from txt")
-            sys_fs = glob.glob(osp.join(self.config.sys_promp_dir, "*.txt"))
-            sys_fs = sorted([e for e in sys_fs if not ("optional" in e)])
-            self.prompt_dict = {}
-            for sys_f in sys_fs:
-                key = osp.basename(sys_f).split(".")[0]
-                with open(sys_f, "r") as f:
-                    self.prompt_dict[key] = f.read()
-        else:
-            err_msg = f"{self.config.sys_promp_dir} is not supported"
-            assert 0, err_msg

-        for k, _ in self.prompt_dict.items():
-            logger.info(f"loaded {k} system prompt")
+        for k in self.prompt_dict:
+            logger.info(f"loaded '{k}' system prompt")
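The json-vs-directory branching deleted here is not gone, only relocated: `sys_promp_dir` now goes in as `file_path`, and `get_all()` returns the same `{file-stem: prompt}` dict the globbing loop used to build, with `optional` files still skipped. Under the sketch above:

```python
# A directory resolves to one prompt per *.txt file stem, so get_all()
# plays the role of the deleted loading loop.
store = build_prompt_store(file_path="configs/route_sys_prompts")  # hypothetical path
for key, prompt in store.get_all().items():
    print(f"loaded '{key}' system prompt ({len(prompt)} chars)")
```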
@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field, is_dataclass
-from typing import Type, TypedDict, Literal, Dict, List, Tuple
+from typing import Type, TypedDict, Literal, Dict, List, Tuple, Optional
 import tyro
 import os.path as osp
 import time
@@ -8,6 +8,7 @@ from loguru import logger

 from lang_agent.config import InstantiateConfig, KeyConfig
 from lang_agent.components.tool_manager import ToolManager
+from lang_agent.components.prompt_store import build_prompt_store
 from lang_agent.components.reit_llm import ReitLLM
 from lang_agent.base import ToolNodeBase
 from lang_agent.graphs.graph_states import State, ChattyToolState
@@ -27,6 +28,12 @@ class ToolNodeConfig(InstantiateConfig):

     tool_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "tool_prompt.txt")

+    pipeline_id: Optional[str] = None
+    """If set, load prompts from database (with file fallback)"""
+
+    prompt_set_id: Optional[str] = None
+    """If set, load from this specific prompt set instead of the active one"""
+

 class ToolNode(ToolNodeBase):
     def __init__(self, config: ToolNodeConfig,
@@ -42,8 +49,13 @@ class ToolNode(ToolNodeBase):
         self.llm = make_llm(tags=["tool_llm"])

         self.tool_agent = create_agent(self.llm, self.tool_manager.get_langchain_tools(), checkpointer=self.mem)
-        with open(self.config.tool_prompt_f, "r") as f:
-            self.sys_prompt = f.read()
+        self.prompt_store = build_prompt_store(
+            pipeline_id=self.config.pipeline_id,
+            prompt_set_id=self.config.prompt_set_id,
+            file_path=self.config.tool_prompt_f,
+            default_key="tool_prompt",
+        )
+        self.sys_prompt = self.prompt_store.get("tool_prompt")

     def invoke(self, state:State):
         inp = {"messages":[
@@ -88,6 +100,8 @@ class ChattyToolNodeConfig(KeyConfig, ToolNodeConfig):
     chatty_sys_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "chatty_prompt.txt")
     """path to chatty system prompt"""

+    # pipeline_id and prompt_set_id are inherited from ToolNodeConfig
+
     tool_node_conf:ToolNodeConfig = field(default_factory=ToolNodeConfig)


@@ -124,14 +138,31 @@ class ChattyToolNode(ToolNodeBase):

         self.chatty_agent = create_agent(self.chatty_llm, [], checkpointer=self.mem)
         # self.tool_agent = create_agent(self.tool_llm, self.tool_manager.get_list_langchain_tools(), checkpointer=self.mem)
+
+        # Propagate pipeline_id and prompt_set_id to inner tool_node_conf
+        if self.config.pipeline_id and hasattr(self.config.tool_node_conf, 'pipeline_id'):
+            self.config.tool_node_conf.pipeline_id = self.config.pipeline_id
+        if self.config.prompt_set_id and hasattr(self.config.tool_node_conf, 'prompt_set_id'):
+            self.config.tool_node_conf.prompt_set_id = self.config.prompt_set_id

         self.tool_agent = self.config.tool_node_conf.setup(tool_manager=self.tool_manager,
                                                            memory=self.mem)

-        with open(self.config.chatty_sys_prompt_f, "r") as f:
-            self.chatty_sys_prompt = f.read()
-
-        with open(self.config.tool_prompt_f, "r") as f:
-            self.tool_sys_prompt = f.read()
+        self.chatty_prompt_store = build_prompt_store(
+            pipeline_id=self.config.pipeline_id,
+            prompt_set_id=self.config.prompt_set_id,
+            file_path=self.config.chatty_sys_prompt_f,
+            default_key="chatty_prompt",
+        )
+        self.chatty_sys_prompt = self.chatty_prompt_store.get("chatty_prompt")
+
+        self.tool_prompt_store = build_prompt_store(
+            pipeline_id=self.config.pipeline_id,
+            prompt_set_id=self.config.prompt_set_id,
+            file_path=self.config.tool_prompt_f,
+            default_key="tool_prompt",
+        )
+        self.tool_sys_prompt = self.tool_prompt_store.get("tool_prompt")

     def get_streamable_tags(self):
         return [["chatty_llm"], ["reit_llm"]]
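The propagation block in `ChattyToolNode` is the same four lines as the one in `RoutingGraph`; the `hasattr` guards keep the copy safe for child configs that predate the new fields. If the pattern spreads to a third site, it could be factored into a helper along these lines (a suggestion, not part of the commit):

```python
def propagate_prompt_ids(src_cfg, dst_cfg) -> None:
    """Copy pipeline_id / prompt_set_id onto a child config that accepts them."""
    for attr in ("pipeline_id", "prompt_set_id"):
        value = getattr(src_cfg, attr, None)
        if value and hasattr(dst_cfg, attr):
            setattr(dst_cfg, attr, value)

# e.g. propagate_prompt_ids(self.config, self.config.tool_node_conf)
```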
@@ -6,7 +6,7 @@ Vision-enabled routing graph that:
 """

 from dataclasses import dataclass, field
-from typing import Type, TypedDict, List, Dict, Any, Tuple
+from typing import Type, TypedDict, List, Dict, Any, Tuple, Optional
 import tyro
 import base64
 import json
@@ -14,6 +14,7 @@ from loguru import logger

 from lang_agent.config import LLMKeyConfig
 from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
+from lang_agent.components.prompt_store import build_prompt_store
 from lang_agent.base import GraphBase, ToolNodeBase
 from lang_agent.components.client_tool_manager import ClientToolManagerConfig

@@ -104,6 +105,12 @@ class VisionRoutingConfig(LLMKeyConfig):
     base_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
     """base url for API"""

+    pipeline_id: Optional[str] = None
+    """If set, load prompts from database (with hardcoded fallback)"""
+
+    prompt_set_id: Optional[str] = None
+    """If set, load from this specific prompt set instead of the active one"""
+
     tool_manager_config: ToolManagerConfig = field(default_factory=ClientToolManagerConfig)


@@ -167,6 +174,17 @@ class VisionRoutingGraph(GraphBase):
         # Create tool node for executing tools
         self.tool_node = ToolNode(self.camera_tools)

+        # Build prompt store: DB (if pipeline_id) > hardcoded defaults
+        self.prompt_store = build_prompt_store(
+            pipeline_id=self.config.pipeline_id,
+            prompt_set_id=self.config.prompt_set_id,
+            hardcoded={
+                "camera_decision_prompt": CAMERA_DECISION_PROMPT,
+                "vision_description_prompt": VISION_DESCRIPTION_PROMPT,
+                "conversation_prompt": CONVERSATION_PROMPT,
+            },
+        )
+
     def _get_human_msg(self, state: VisionRoutingState) -> HumanMessage:
         """Get user message from current invocation"""
         msgs = state["inp"][0]["messages"]
@@ -180,7 +198,7 @@ class VisionRoutingGraph(GraphBase):
         human_msg = self._get_human_msg(state)

         messages = [
-            SystemMessage(content=CAMERA_DECISION_PROMPT),
+            SystemMessage(content=self.prompt_store.get("camera_decision_prompt")),
             human_msg
         ]

@@ -279,7 +297,7 @@ class VisionRoutingGraph(GraphBase):
         )

         messages = [
-            SystemMessage(content=VISION_DESCRIPTION_PROMPT),
+            SystemMessage(content=self.prompt_store.get("vision_description_prompt")),
             vision_message
         ]

@@ -292,7 +310,7 @@ class VisionRoutingGraph(GraphBase):
         human_msg = self._get_human_msg(state)

         messages = [
-            SystemMessage(content=CONVERSATION_PROMPT),
+            SystemMessage(content=self.prompt_store.get("conversation_prompt")),
             human_msg
         ]
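The vision graph has no prompt files, so its fallback layer is the in-module constants: with `pipeline_id` unset, each `prompt_store.get(...)` call site resolves to exactly the constant it replaced, and behaviour is unchanged until a pipeline is configured. Under the sketch above:

```python
# No pipeline_id and no file_path: the store just wraps the hardcoded dict.
store = build_prompt_store(hardcoded={"conversation_prompt": "example prompt text"})
assert store.get("conversation_prompt") == "example prompt text"
```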