diff --git a/lang_agent/graphs/tool_nodes.py b/lang_agent/graphs/tool_nodes.py index fa4487f..06b8907 100644 --- a/lang_agent/graphs/tool_nodes.py +++ b/lang_agent/graphs/tool_nodes.py @@ -1,17 +1,16 @@ from dataclasses import dataclass, field, is_dataclass from typing import Type, TypedDict, Literal, Dict, List, Tuple import tyro -import os import os.path as osp from lang_agent.config import InstantiateConfig, KeyConfig from lang_agent.tool_manager import ToolManager -from lang_agent.base import ToolNodeBase -from lang_agent.graphs.graph_states import State +from lang_agent.base import GraphBase +from lang_agent.graphs.graph_states import State, ChattyToolState from langchain_core.language_models import BaseChatModel -from langchain_core.messages import SystemMessage +from langchain_core.messages import SystemMessage, HumanMessage, AIMessage from langchain.agents import create_agent from langchain.chat_models import init_chat_model @@ -26,7 +25,7 @@ class ToolNodeConfig(InstantiateConfig): tool_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "tool_prompt.txt") -class ToolNode(ToolNodeBase): +class ToolNode(GraphBase): def __init__(self, config: ToolNodeConfig, tool_manager:ToolManager, llm:BaseChatModel, @@ -43,7 +42,7 @@ class ToolNode(ToolNodeBase): with open(self.config.tool_prompt_f, "r") as f: self.sys_prompt = f.read() - def tool_node_call(self, state:State): + def invoke(self, state:State): inp = {"messages":[ SystemMessage( self.sys_prompt @@ -71,7 +70,7 @@ class ChattyToolNodeConfig(KeyConfig, ToolNodeConfig): chatty_sys_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "chatty_prompt.txt") -class ChattyToolNode: +class ChattyToolNode(GraphBase): def __init__(self, config:ChattyToolNodeConfig, tool_manager:ToolManager, llm:BaseChatModel, @@ -102,8 +101,15 @@ class ChattyToolNode: with open(self.config.tool_prompt_f, "r") as f: self.tool_sys_prompt = f.read() + def invoke(self, state:State): + inp = {"inp": state["inp"]} + out = self.workflow.invoke(inp) + chat_msgs = out.get("chatty_message") + tool_msgs = out.get("tool_message") + + return {"messages": state["messages"] + chat_msgs + tool_msgs} - def _tool_node_call(self, state:State): + def _tool_node_call(self, state:ChattyToolState): inp = {"messages":[ SystemMessage( self.tool_sys_prompt @@ -113,11 +119,10 @@ class ChattyToolNode: out = self.tool_agent.invoke(*inp) - - return {"subgraph_states":{"tool_message": out}} + return {"tool_messages": out} - def _chat_node_call(self, state:State): + def _chat_node_call(self, state:ChattyToolState): outs = [] while not self.tool_done: @@ -129,18 +134,20 @@ class ChattyToolNode: ]}, state["inp"][1] outs.append(self.chatty_agent.invoke(*inp)) - return {"subgraph_states":{"chatty_message": outs}} + return {"chatty_message": outs} - def _handoff_node(self, state:State): - chat_msgs = state.get("subgraph_states").get("chatty_message") - tool_msgs = state.get("subgraph_states").get("tool_message") + def _handoff_node(self, state:ChattyToolState): + # NOTE: this exist to have both results + # chat_msgs = state.get("chatty_message") + # tool_msgs = state.get("tool_message") - return {"messages": state["messages"] + chat_msgs + tool_msgs} + # return {"messages": chat_msgs + tool_msgs} + return {} def build_graph(self): - builder = StateGraph(State) + builder = StateGraph(ChattyToolState) builder.add_node("chatty_tool_call", self._tool_node_call) builder.add_node("chatty_chat_call", self._chat_node_call) builder.add_node("chatty_handoff_node", self._handoff_node)