use chatty tool state instead
This commit is contained in:
@@ -1,17 +1,16 @@
|
|||||||
from dataclasses import dataclass, field, is_dataclass
|
from dataclasses import dataclass, field, is_dataclass
|
||||||
from typing import Type, TypedDict, Literal, Dict, List, Tuple
|
from typing import Type, TypedDict, Literal, Dict, List, Tuple
|
||||||
import tyro
|
import tyro
|
||||||
import os
|
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
|
|
||||||
|
|
||||||
from lang_agent.config import InstantiateConfig, KeyConfig
|
from lang_agent.config import InstantiateConfig, KeyConfig
|
||||||
from lang_agent.tool_manager import ToolManager
|
from lang_agent.tool_manager import ToolManager
|
||||||
from lang_agent.base import ToolNodeBase
|
from lang_agent.base import GraphBase
|
||||||
from lang_agent.graphs.graph_states import State
|
from lang_agent.graphs.graph_states import State, ChattyToolState
|
||||||
|
|
||||||
from langchain_core.language_models import BaseChatModel
|
from langchain_core.language_models import BaseChatModel
|
||||||
from langchain_core.messages import SystemMessage
|
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
|
||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
from langchain.chat_models import init_chat_model
|
from langchain.chat_models import init_chat_model
|
||||||
|
|
||||||
@@ -26,7 +25,7 @@ class ToolNodeConfig(InstantiateConfig):
|
|||||||
tool_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "tool_prompt.txt")
|
tool_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "tool_prompt.txt")
|
||||||
|
|
||||||
|
|
||||||
class ToolNode(ToolNodeBase):
|
class ToolNode(GraphBase):
|
||||||
def __init__(self, config: ToolNodeConfig,
|
def __init__(self, config: ToolNodeConfig,
|
||||||
tool_manager:ToolManager,
|
tool_manager:ToolManager,
|
||||||
llm:BaseChatModel,
|
llm:BaseChatModel,
|
||||||
@@ -43,7 +42,7 @@ class ToolNode(ToolNodeBase):
|
|||||||
with open(self.config.tool_prompt_f, "r") as f:
|
with open(self.config.tool_prompt_f, "r") as f:
|
||||||
self.sys_prompt = f.read()
|
self.sys_prompt = f.read()
|
||||||
|
|
||||||
def tool_node_call(self, state:State):
|
def invoke(self, state:State):
|
||||||
inp = {"messages":[
|
inp = {"messages":[
|
||||||
SystemMessage(
|
SystemMessage(
|
||||||
self.sys_prompt
|
self.sys_prompt
|
||||||
@@ -71,7 +70,7 @@ class ChattyToolNodeConfig(KeyConfig, ToolNodeConfig):
|
|||||||
chatty_sys_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "chatty_prompt.txt")
|
chatty_sys_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "chatty_prompt.txt")
|
||||||
|
|
||||||
|
|
||||||
class ChattyToolNode:
|
class ChattyToolNode(GraphBase):
|
||||||
def __init__(self, config:ChattyToolNodeConfig,
|
def __init__(self, config:ChattyToolNodeConfig,
|
||||||
tool_manager:ToolManager,
|
tool_manager:ToolManager,
|
||||||
llm:BaseChatModel,
|
llm:BaseChatModel,
|
||||||
@@ -102,8 +101,15 @@ class ChattyToolNode:
|
|||||||
with open(self.config.tool_prompt_f, "r") as f:
|
with open(self.config.tool_prompt_f, "r") as f:
|
||||||
self.tool_sys_prompt = f.read()
|
self.tool_sys_prompt = f.read()
|
||||||
|
|
||||||
|
def invoke(self, state:State):
|
||||||
|
inp = {"inp": state["inp"]}
|
||||||
|
out = self.workflow.invoke(inp)
|
||||||
|
chat_msgs = out.get("chatty_message")
|
||||||
|
tool_msgs = out.get("tool_message")
|
||||||
|
|
||||||
|
return {"messages": state["messages"] + chat_msgs + tool_msgs}
|
||||||
|
|
||||||
def _tool_node_call(self, state:State):
|
def _tool_node_call(self, state:ChattyToolState):
|
||||||
inp = {"messages":[
|
inp = {"messages":[
|
||||||
SystemMessage(
|
SystemMessage(
|
||||||
self.tool_sys_prompt
|
self.tool_sys_prompt
|
||||||
@@ -113,11 +119,10 @@ class ChattyToolNode:
|
|||||||
|
|
||||||
out = self.tool_agent.invoke(*inp)
|
out = self.tool_agent.invoke(*inp)
|
||||||
|
|
||||||
|
return {"tool_messages": out}
|
||||||
return {"subgraph_states":{"tool_message": out}}
|
|
||||||
|
|
||||||
|
|
||||||
def _chat_node_call(self, state:State):
|
def _chat_node_call(self, state:ChattyToolState):
|
||||||
outs = []
|
outs = []
|
||||||
|
|
||||||
while not self.tool_done:
|
while not self.tool_done:
|
||||||
@@ -129,18 +134,20 @@ class ChattyToolNode:
|
|||||||
]}, state["inp"][1]
|
]}, state["inp"][1]
|
||||||
outs.append(self.chatty_agent.invoke(*inp))
|
outs.append(self.chatty_agent.invoke(*inp))
|
||||||
|
|
||||||
return {"subgraph_states":{"chatty_message": outs}}
|
return {"chatty_message": outs}
|
||||||
|
|
||||||
|
|
||||||
def _handoff_node(self, state:State):
|
def _handoff_node(self, state:ChattyToolState):
|
||||||
chat_msgs = state.get("subgraph_states").get("chatty_message")
|
# NOTE: this exist to have both results
|
||||||
tool_msgs = state.get("subgraph_states").get("tool_message")
|
# chat_msgs = state.get("chatty_message")
|
||||||
|
# tool_msgs = state.get("tool_message")
|
||||||
|
|
||||||
return {"messages": state["messages"] + chat_msgs + tool_msgs}
|
# return {"messages": chat_msgs + tool_msgs}
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
def build_graph(self):
|
def build_graph(self):
|
||||||
builder = StateGraph(State)
|
builder = StateGraph(ChattyToolState)
|
||||||
builder.add_node("chatty_tool_call", self._tool_node_call)
|
builder.add_node("chatty_tool_call", self._tool_node_call)
|
||||||
builder.add_node("chatty_chat_call", self._chat_node_call)
|
builder.add_node("chatty_chat_call", self._chat_node_call)
|
||||||
builder.add_node("chatty_handoff_node", self._handoff_node)
|
builder.add_node("chatty_handoff_node", self._handoff_node)
|
||||||
|
|||||||
Reference in New Issue
Block a user