diff --git a/lang_agent/graphs/tool_nodes.py b/lang_agent/graphs/tool_nodes.py index 1daab72..aed0852 100644 --- a/lang_agent/graphs/tool_nodes.py +++ b/lang_agent/graphs/tool_nodes.py @@ -6,11 +6,10 @@ import os.path as osp from lang_agent.config import InstantiateConfig, KeyConfig from lang_agent.tool_manager import ToolManager -from lang_agent.base import GraphBase +from lang_agent.base import ToolNodeBase from lang_agent.graphs.graph_states import State, ChattyToolState from lang_agent.utils import make_llm -from langchain_core.language_models import BaseChatModel from langchain_core.messages import SystemMessage, HumanMessage, AIMessage from langchain.agents import create_agent from langchain.chat_models import init_chat_model @@ -26,7 +25,7 @@ class ToolNodeConfig(InstantiateConfig): tool_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "tool_prompt.txt") -class ToolNode(GraphBase): +class ToolNode(ToolNodeBase): def __init__(self, config: ToolNodeConfig, tool_manager:ToolManager, memory:MemorySaver): @@ -71,7 +70,7 @@ class ChattyToolNodeConfig(KeyConfig, ToolNodeConfig): chatty_sys_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "chatty_prompt.txt") -class ChattyToolNode(GraphBase): +class ChattyToolNode(ToolNodeBase): def __init__(self, config:ChattyToolNodeConfig, tool_manager:ToolManager, memory:MemorySaver): @@ -107,6 +106,9 @@ class ChattyToolNode(GraphBase): with open(self.config.tool_prompt_f, "r") as f: self.tool_sys_prompt = f.read() + + def get_streamable_tags(self): + return [["chatty_llm"], ["reit_llm"]] def invoke(self, state:State): self.tool_done = False