tool nodes
This commit is contained in:
154
lang_agent/graphs/tool_nodes.py
Normal file
154
lang_agent/graphs/tool_nodes.py
Normal file
@@ -0,0 +1,154 @@
|
||||
from __future__ import annotations # optional but recommended
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lang_agent.graphs.routing import State# only imported for type hints
|
||||
|
||||
from dataclasses import dataclass, field, is_dataclass
|
||||
from typing import Type, TypedDict, Literal, Dict, List, Tuple
|
||||
import tyro
|
||||
import os
|
||||
import os.path as osp
|
||||
|
||||
|
||||
from lang_agent.config import InstantiateConfig, KeyConfig
|
||||
from lang_agent.tool_manager import ToolManager
|
||||
from lang_agent.base import ToolNodeBase
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import SystemMessage
|
||||
from langchain.agents import create_agent
|
||||
from langchain.chat_models import init_chat_model
|
||||
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
|
||||
|
||||
@dataclass
class ToolNodeConfig(InstantiateConfig):
    """Instantiation config for :class:`ToolNode`."""

    # Class instantiated from this config.
    _target: Type = field(default_factory=lambda: ToolNode)

    # Path to the tool agent's system-prompt file; resolves to
    # <repo root>/configs/route_sys_prompts/tool_prompt.txt relative to this module.
    tool_prompt_f: str = osp.join(
        osp.dirname(osp.dirname(osp.dirname(__file__))),
        "configs", "route_sys_prompts", "tool_prompt.txt",
    )
||||
class ToolNode(ToolNodeBase):
    """Graph node that runs a tool-calling agent over the incoming state."""

    def __init__(self,
                 config: ToolNodeConfig,
                 tool_manager: ToolManager,
                 llm: BaseChatModel,
                 memory: MemorySaver):
        self.config = config
        self.tool_manager = tool_manager
        self.llm = llm
        self.mem = memory

        self.populate_modules()

    def populate_modules(self):
        """Build the tool agent and read its system prompt from disk."""
        langchain_tools = self.tool_manager.get_list_langchain_tools()
        self.tool_agent = create_agent(self.llm, langchain_tools, checkpointer=self.mem)
        with open(self.config.tool_prompt_f, "r") as f:
            self.sys_prompt = f.read()

    def tool_node_call(self, state: State):
        """Invoke the tool agent and return its output under "messages".

        Assumes state["inp"] is (payload, invoke_config); the payload's
        first message is replaced by this node's system prompt.
        """
        payload = state["inp"][0]
        invoke_cfg = state["inp"][1]
        msgs = [SystemMessage(self.sys_prompt), *payload["messages"][1:]]
        out = self.tool_agent.invoke({"messages": msgs}, invoke_cfg)
        return {"messages": out}
||||
@dataclass
class ChattyToolNodeConfig(KeyConfig, ToolNodeConfig):
    """Instantiation config for :class:`ChattyToolNode`."""

    # Class instantiated from this config.
    _target: Type = field(default_factory=lambda: ChattyToolNode)

    llm_name: str = "qwen-plus"
    """name of llm"""

    llm_provider: str = "openai"
    """provider of the llm"""

    base_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
    """base url; could be used to overwrite the baseurl in llm provider"""

    # Path to the chatty agent's system-prompt file, resolved relative to this module.
    chatty_sys_prompt_f: str = osp.join(
        osp.dirname(osp.dirname(osp.dirname(__file__))),
        "configs", "route_sys_prompts", "chatty_prompt.txt",
    )
||||
class ChattyToolNode:
    """Node pairing a conversational ("chatty") agent with a tool-calling agent.

    The chatty agent (its own LLM, no tools) talks while the tool agent
    performs tool calls; ``_handoff_node`` merges both message streams
    back into the main state.
    """

    def __init__(self, config: ChattyToolNodeConfig,
                 tool_manager: ToolManager,
                 llm: BaseChatModel,
                 memory: MemorySaver):
        self.config = config
        self.tool_manager = tool_manager
        self.tool_llm = llm
        self.mem = memory
        # Polled by _chat_node_call. NOTE(review): nothing in this class
        # ever sets it to True, so the chat loop below cannot terminate on
        # its own — confirm who is responsible for flipping this flag.
        self.tool_done = False

        self.populate_modules()

    def populate_modules(self):
        """Build both agents and load their system prompts from disk."""
        self.chatty_llm = init_chat_model(model=self.config.llm_name,
                                          model_provider=self.config.llm_provider,
                                          api_key=self.config.api_key,
                                          base_url=self.config.base_url,
                                          temperature=0)

        # BUG FIX: original passed self.chatty_agent (not yet defined) to
        # create_agent, raising AttributeError; the chatty agent must be
        # built from the freshly initialized chatty LLM.
        self.chatty_agent = create_agent(self.chatty_llm, [], checkpointer=self.mem)
        self.tool_agent = create_agent(self.tool_llm,
                                       self.tool_manager.get_list_langchain_tools(),
                                       checkpointer=self.mem)

        with open(self.config.chatty_sys_prompt_f, "r") as f:
            self.chatty_sys_prompt = f.read()

        with open(self.config.tool_prompt_f, "r") as f:
            self.tool_sys_prompt = f.read()

    def _tool_node_call(self, state: State):
        """Run the tool agent on the incoming request.

        Assumes state["inp"] is (payload, invoke_config); the payload's
        first message is replaced with the tool system prompt.
        """
        inp = {"messages": [
            SystemMessage(self.tool_sys_prompt),
            *state["inp"][0]["messages"][1:],
        ]}, state["inp"][1]

        out = self.tool_agent.invoke(*inp)

        return {"subgraph_states": {"tool_message": out}}

    def _chat_node_call(self, state: State):
        """Invoke the chatty agent repeatedly until the tool agent is done.

        Collects one agent output per iteration of the polling loop (see
        the NOTE on self.tool_done in __init__).
        """
        outs = []

        while not self.tool_done:
            inp = {"messages": [
                SystemMessage(self.chatty_sys_prompt),
                *state["inp"][0]["messages"][1:],
            ]}, state["inp"][1]
            outs.append(self.chatty_agent.invoke(*inp))

        return {"subgraph_states": {"chatty_message": outs}}

    def _handoff_node(self, state: State):
        """Merge the chatty and tool message streams back into "messages"."""
        # NOTE(review): raises if "subgraph_states" or either key is missing
        # (None is not concatenable) — presumably both node calls always run
        # first; confirm against the graph wiring.
        chat_msgs = state.get("subgraph_states").get("chatty_message")
        tool_msgs = state.get("subgraph_states").get("tool_message")

        return {"messages": state["messages"] + chat_msgs + tool_msgs}
||||
# Registry of available tool-node configs, keyed by the subcommand name
# used for tyro CLI selection.
tool_node_dict = {
    "tool_node" : ToolNodeConfig(),
    "chatty_tool_node" : ChattyToolNodeConfig()
}

# Union type over the registered config defaults; prefix_names=False keeps
# the subcommand names exactly as given in tool_node_dict.
tool_node_union = tyro.extras.subcommand_type_from_defaults(tool_node_dict, prefix_names=False)
# CLI annotation: omit subcommand prefixes and suppress fixed (non-overridable) fields.
AnnotatedToolNode = tyro.conf.OmitSubcommandPrefixes[tyro.conf.SuppressFixed[tool_node_union]]
||||
if __name__ == "__main__":
    # Parse a ToolNodeConfig from the command line when run as a script.
    tyro.cli(ToolNodeConfig)
||||
Reference in New Issue
Block a user