update tool_node config
This commit is contained in:
@@ -23,17 +23,11 @@ from langgraph.graph import StateGraph, START, END
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ToolNodeConfig(InstantiateConfig):
|
class ToolNodeConfig(LLMNodeConfig):
|
||||||
_target: Type = field(default_factory=lambda: ToolNode)
|
_target: Type = field(default_factory=lambda: ToolNode)
|
||||||
|
|
||||||
tool_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "tool_prompt.txt")
|
tool_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "tool_prompt.txt")
|
||||||
|
|
||||||
pipeline_id: Optional[str] = None
|
|
||||||
"""If set, load prompts from database (with file fallback)"""
|
|
||||||
|
|
||||||
prompt_set_id: Optional[str] = None
|
|
||||||
"""If set, load from this specific prompt set instead of the active one"""
|
|
||||||
|
|
||||||
|
|
||||||
class ToolNode(ToolNodeBase):
|
class ToolNode(ToolNodeBase):
|
||||||
def __init__(self, config: ToolNodeConfig,
|
def __init__(self, config: ToolNodeConfig,
|
||||||
@@ -46,7 +40,9 @@ class ToolNode(ToolNodeBase):
|
|||||||
self.populate_modules()
|
self.populate_modules()
|
||||||
|
|
||||||
def populate_modules(self):
|
def populate_modules(self):
|
||||||
self.llm = make_llm(tags=["tool_llm"])
|
self.llm = make_llm(model=self.config.llm_name,
|
||||||
|
api_key=self.config.api_key,
|
||||||
|
tags=["tool_llm"])
|
||||||
|
|
||||||
self.tool_agent = create_agent(self.llm, self.tool_manager.get_langchain_tools(), checkpointer=self.mem)
|
self.tool_agent = create_agent(self.llm, self.tool_manager.get_langchain_tools(), checkpointer=self.mem)
|
||||||
self.prompt_store = build_prompt_store(
|
self.prompt_store = build_prompt_store(
|
||||||
@@ -85,14 +81,12 @@ class ToolNode(ToolNodeBase):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ChattyToolNodeConfig(LLMNodeConfig, ToolNodeConfig):
|
class ChattyToolNodeConfig(LLMNodeConfig):
|
||||||
_target: Type = field(default_factory=lambda: ChattyToolNode)
|
_target: Type = field(default_factory=lambda: ChattyToolNode)
|
||||||
|
|
||||||
chatty_sys_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "chatty_prompt.txt")
|
chatty_sys_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "chatty_prompt.txt")
|
||||||
"""path to chatty system prompt"""
|
"""path to chatty system prompt"""
|
||||||
|
|
||||||
# pipeline_id and prompt_set_id are inherited from ToolNodeConfig
|
|
||||||
|
|
||||||
tool_node_conf:ToolNodeConfig = field(default_factory=ToolNodeConfig)
|
tool_node_conf:ToolNodeConfig = field(default_factory=ToolNodeConfig)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user