update tool_node config
This commit is contained in:
@@ -23,17 +23,11 @@ from langgraph.graph import StateGraph, START, END
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolNodeConfig(InstantiateConfig):
|
||||
class ToolNodeConfig(LLMNodeConfig):
|
||||
_target: Type = field(default_factory=lambda: ToolNode)
|
||||
|
||||
tool_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "tool_prompt.txt")
|
||||
|
||||
pipeline_id: Optional[str] = None
|
||||
"""If set, load prompts from database (with file fallback)"""
|
||||
|
||||
prompt_set_id: Optional[str] = None
|
||||
"""If set, load from this specific prompt set instead of the active one"""
|
||||
|
||||
|
||||
class ToolNode(ToolNodeBase):
|
||||
def __init__(self, config: ToolNodeConfig,
|
||||
@@ -46,7 +40,9 @@ class ToolNode(ToolNodeBase):
|
||||
self.populate_modules()
|
||||
|
||||
def populate_modules(self):
|
||||
self.llm = make_llm(tags=["tool_llm"])
|
||||
self.llm = make_llm(model=self.config.llm_name,
|
||||
api_key=self.config.api_key,
|
||||
tags=["tool_llm"])
|
||||
|
||||
self.tool_agent = create_agent(self.llm, self.tool_manager.get_langchain_tools(), checkpointer=self.mem)
|
||||
self.prompt_store = build_prompt_store(
|
||||
@@ -85,14 +81,12 @@ class ToolNode(ToolNodeBase):
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChattyToolNodeConfig(LLMNodeConfig, ToolNodeConfig):
|
||||
class ChattyToolNodeConfig(LLMNodeConfig):
|
||||
_target: Type = field(default_factory=lambda: ChattyToolNode)
|
||||
|
||||
chatty_sys_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "chatty_prompt.txt")
|
||||
"""path to chatty system prompt"""
|
||||
|
||||
# pipeline_id and prompt_set_id are inherited from ToolNodeConfig
|
||||
|
||||
tool_node_conf:ToolNodeConfig = field(default_factory=ToolNodeConfig)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user