diff --git a/lang_agent/graphs/routing.py b/lang_agent/graphs/routing.py index ed4c7f9..91985f8 100644 --- a/lang_agent/graphs/routing.py +++ b/lang_agent/graphs/routing.py @@ -12,7 +12,7 @@ from lang_agent.config import KeyConfig from lang_agent.tool_manager import ToolManager, ToolManagerConfig from lang_agent.base import GraphBase from lang_agent.graphs.graph_states import State -from lang_agent.graphs.tool_nodes import AnnotatedToolNode, ToolNode +from lang_agent.graphs.tool_nodes import AnnotatedToolNode, ToolNodeConfig from langchain.chat_models import init_chat_model from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage @@ -42,6 +42,8 @@ class RoutingConfig(KeyConfig): tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig) + tool_node_config: AnnotatedToolNode = field(default_factory=ToolNodeConfig) + class Route(BaseModel): @@ -126,7 +128,9 @@ class RoutingGraph(GraphBase): tool_manager:ToolManager = self.config.tool_manager_config.setup() self.chat_model = create_agent(self.llm, self._get_chat_tools(tool_manager), checkpointer=self.memory) - self.tool_model = create_agent(self.llm, tool_manager.get_list_langchain_tools(), checkpointer=self.memory) + self.tool_node:GraphBase = self.config.tool_node_config.setup(tool_manager=tool_manager, + llm=self.llm, + memory=self.memory) self._load_sys_prompts() @@ -206,15 +210,8 @@ class RoutingGraph(GraphBase): def _tool_model_call(self, state:State): - inp = {"messages":[ - SystemMessage( - self.prompt_dict["tool_prompt"] - ), - *state["inp"][0]["messages"][1:] - ]}, state["inp"][1] - - out = self.tool_model.invoke(*inp) - return {"messages": out} + out = self.tool_node.invoke(state) + return {"messages": out["messages"]} def _build_graph(self): builder = StateGraph(State)