use chatty tool state instead

This commit is contained in:
2025-11-21 14:48:00 +08:00
parent 57fdd420db
commit 51835a4ced

View File

@@ -1,17 +1,16 @@
from dataclasses import dataclass, field, is_dataclass from dataclasses import dataclass, field, is_dataclass
from typing import Type, TypedDict, Literal, Dict, List, Tuple from typing import Type, TypedDict, Literal, Dict, List, Tuple
import tyro import tyro
import os
import os.path as osp import os.path as osp
from lang_agent.config import InstantiateConfig, KeyConfig from lang_agent.config import InstantiateConfig, KeyConfig
from lang_agent.tool_manager import ToolManager from lang_agent.tool_manager import ToolManager
from lang_agent.base import ToolNodeBase from lang_agent.base import GraphBase
from lang_agent.graphs.graph_states import State from lang_agent.graphs.graph_states import State, ChattyToolState
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_core.messages import SystemMessage from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain.agents import create_agent from langchain.agents import create_agent
from langchain.chat_models import init_chat_model from langchain.chat_models import init_chat_model
@@ -26,7 +25,7 @@ class ToolNodeConfig(InstantiateConfig):
tool_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "tool_prompt.txt") tool_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "tool_prompt.txt")
class ToolNode(ToolNodeBase): class ToolNode(GraphBase):
def __init__(self, config: ToolNodeConfig, def __init__(self, config: ToolNodeConfig,
tool_manager:ToolManager, tool_manager:ToolManager,
llm:BaseChatModel, llm:BaseChatModel,
@@ -43,7 +42,7 @@ class ToolNode(ToolNodeBase):
with open(self.config.tool_prompt_f, "r") as f: with open(self.config.tool_prompt_f, "r") as f:
self.sys_prompt = f.read() self.sys_prompt = f.read()
def tool_node_call(self, state:State): def invoke(self, state:State):
inp = {"messages":[ inp = {"messages":[
SystemMessage( SystemMessage(
self.sys_prompt self.sys_prompt
@@ -71,7 +70,7 @@ class ChattyToolNodeConfig(KeyConfig, ToolNodeConfig):
chatty_sys_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "chatty_prompt.txt") chatty_sys_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "chatty_prompt.txt")
class ChattyToolNode: class ChattyToolNode(GraphBase):
def __init__(self, config:ChattyToolNodeConfig, def __init__(self, config:ChattyToolNodeConfig,
tool_manager:ToolManager, tool_manager:ToolManager,
llm:BaseChatModel, llm:BaseChatModel,
@@ -102,8 +101,15 @@ class ChattyToolNode:
with open(self.config.tool_prompt_f, "r") as f: with open(self.config.tool_prompt_f, "r") as f:
self.tool_sys_prompt = f.read() self.tool_sys_prompt = f.read()
def invoke(self, state:State):
inp = {"inp": state["inp"]}
out = self.workflow.invoke(inp)
chat_msgs = out.get("chatty_message")
tool_msgs = out.get("tool_message")
return {"messages": state["messages"] + chat_msgs + tool_msgs}
def _tool_node_call(self, state:State): def _tool_node_call(self, state:ChattyToolState):
inp = {"messages":[ inp = {"messages":[
SystemMessage( SystemMessage(
self.tool_sys_prompt self.tool_sys_prompt
@@ -113,11 +119,10 @@ class ChattyToolNode:
out = self.tool_agent.invoke(*inp) out = self.tool_agent.invoke(*inp)
return {"tool_messages": out}
return {"subgraph_states":{"tool_message": out}}
def _chat_node_call(self, state:State): def _chat_node_call(self, state:ChattyToolState):
outs = [] outs = []
while not self.tool_done: while not self.tool_done:
@@ -129,18 +134,20 @@ class ChattyToolNode:
]}, state["inp"][1] ]}, state["inp"][1]
outs.append(self.chatty_agent.invoke(*inp)) outs.append(self.chatty_agent.invoke(*inp))
return {"subgraph_states":{"chatty_message": outs}} return {"chatty_message": outs}
def _handoff_node(self, state:State): def _handoff_node(self, state:ChattyToolState):
chat_msgs = state.get("subgraph_states").get("chatty_message") # NOTE: this exist to have both results
tool_msgs = state.get("subgraph_states").get("tool_message") # chat_msgs = state.get("chatty_message")
# tool_msgs = state.get("tool_message")
return {"messages": state["messages"] + chat_msgs + tool_msgs} # return {"messages": chat_msgs + tool_msgs}
return {}
def build_graph(self): def build_graph(self):
builder = StateGraph(State) builder = StateGraph(ChattyToolState)
builder.add_node("chatty_tool_call", self._tool_node_call) builder.add_node("chatty_tool_call", self._tool_node_call)
builder.add_node("chatty_chat_call", self._chat_node_call) builder.add_node("chatty_chat_call", self._chat_node_call)
builder.add_node("chatty_handoff_node", self._handoff_node) builder.add_node("chatty_handoff_node", self._handoff_node)