build graph
This commit is contained in:
@@ -21,6 +21,7 @@ from langchain.agents import create_agent
|
||||
from langchain.chat_models import init_chat_model
|
||||
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -87,6 +88,8 @@ class ChattyToolNode:
|
||||
self.tool_done = False
|
||||
|
||||
self.populate_modules()
|
||||
self.build_graph()
|
||||
|
||||
|
||||
def populate_modules(self):
|
||||
self.chatty_llm = init_chat_model(model=self.config.llm_name,
|
||||
@@ -141,6 +144,22 @@ class ChattyToolNode:
|
||||
return {"messages": state["messages"] + chat_msgs + tool_msgs}
|
||||
|
||||
|
||||
def build_graph(self):
|
||||
builder = StateGraph(State)
|
||||
builder.add_node("chatty_tool_call", self._tool_node_call)
|
||||
builder.add_node("chatty_chat_call", self._chat_node_call)
|
||||
builder.add_node("chatty_handoff_node", self._handoff_node)
|
||||
|
||||
builder.add_edge(START, "chatty_tool_call")
|
||||
builder.add_edge(START, "chatty_chat_call")
|
||||
builder.add_edge("chatty_chat_call", "chatty_handoff_node")
|
||||
builder.add_edge("chatty_node_call", "chatty_handoff_node")
|
||||
builder.add_edge("chatty_handoff_node", END)
|
||||
|
||||
self.workflow = builder.compile()
|
||||
|
||||
|
||||
|
||||
tool_node_dict = {
|
||||
"tool_node" : ToolNodeConfig(),
|
||||
"chatty_tool_node" : ChattyToolNodeConfig()
|
||||
|
||||
Reference in New Issue
Block a user