diff --git a/lang_agent/graphs/tool_nodes.py b/lang_agent/graphs/tool_nodes.py index 7f553f2..539b1c5 100644 --- a/lang_agent/graphs/tool_nodes.py +++ b/lang_agent/graphs/tool_nodes.py @@ -21,6 +21,7 @@ from langchain.agents import create_agent from langchain.chat_models import init_chat_model from langgraph.checkpoint.memory import MemorySaver +from langgraph.graph import StateGraph, START, END @dataclass @@ -87,6 +88,8 @@ class ChattyToolNode: self.tool_done = False self.populate_modules() + self.build_graph() + def populate_modules(self): self.chatty_llm = init_chat_model(model=self.config.llm_name, @@ -141,6 +144,22 @@ class ChattyToolNode: return {"messages": state["messages"] + chat_msgs + tool_msgs} + def build_graph(self): + builder = StateGraph(State) + builder.add_node("chatty_tool_call", self._tool_node_call) + builder.add_node("chatty_chat_call", self._chat_node_call) + builder.add_node("chatty_handoff_node", self._handoff_node) + + builder.add_edge(START, "chatty_tool_call") + builder.add_edge(START, "chatty_chat_call") + builder.add_edge("chatty_chat_call", "chatty_handoff_node") + builder.add_edge("chatty_node_call", "chatty_handoff_node") + builder.add_edge("chatty_handoff_node", END) + + self.workflow = builder.compile() + + + tool_node_dict = { "tool_node" : ToolNodeConfig(), "chatty_tool_node" : ChattyToolNodeConfig()