working chatty tool node
This commit is contained in:
@@ -8,6 +8,7 @@ from lang_agent.config import InstantiateConfig, KeyConfig
|
|||||||
from lang_agent.tool_manager import ToolManager
|
from lang_agent.tool_manager import ToolManager
|
||||||
from lang_agent.base import GraphBase
|
from lang_agent.base import GraphBase
|
||||||
from lang_agent.graphs.graph_states import State, ChattyToolState
|
from lang_agent.graphs.graph_states import State, ChattyToolState
|
||||||
|
from lang_agent.utils import make_llm
|
||||||
|
|
||||||
from langchain_core.language_models import BaseChatModel
|
from langchain_core.language_models import BaseChatModel
|
||||||
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
|
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
|
||||||
@@ -28,16 +29,16 @@ class ToolNodeConfig(InstantiateConfig):
|
|||||||
class ToolNode(GraphBase):
|
class ToolNode(GraphBase):
|
||||||
def __init__(self, config: ToolNodeConfig,
|
def __init__(self, config: ToolNodeConfig,
|
||||||
tool_manager:ToolManager,
|
tool_manager:ToolManager,
|
||||||
llm:BaseChatModel,
|
|
||||||
memory:MemorySaver):
|
memory:MemorySaver):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.tool_manager = tool_manager
|
self.tool_manager = tool_manager
|
||||||
self.llm = llm
|
|
||||||
self.mem = memory
|
self.mem = memory
|
||||||
|
|
||||||
self.populate_modules()
|
self.populate_modules()
|
||||||
|
|
||||||
def populate_modules(self):
|
def populate_modules(self):
|
||||||
|
self.llm = make_llm(tags=["tool_llm"])
|
||||||
|
|
||||||
self.tool_agent = create_agent(self.llm, self.tool_manager.get_list_langchain_tools(), checkpointer=self.mem)
|
self.tool_agent = create_agent(self.llm, self.tool_manager.get_list_langchain_tools(), checkpointer=self.mem)
|
||||||
with open(self.config.tool_prompt_f, "r") as f:
|
with open(self.config.tool_prompt_f, "r") as f:
|
||||||
self.sys_prompt = f.read()
|
self.sys_prompt = f.read()
|
||||||
@@ -73,11 +74,9 @@ class ChattyToolNodeConfig(KeyConfig, ToolNodeConfig):
|
|||||||
class ChattyToolNode(GraphBase):
|
class ChattyToolNode(GraphBase):
|
||||||
def __init__(self, config:ChattyToolNodeConfig,
|
def __init__(self, config:ChattyToolNodeConfig,
|
||||||
tool_manager:ToolManager,
|
tool_manager:ToolManager,
|
||||||
llm:BaseChatModel,
|
|
||||||
memory:MemorySaver):
|
memory:MemorySaver):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.tool_manager = tool_manager
|
self.tool_manager = tool_manager
|
||||||
self.tool_llm = llm
|
|
||||||
self.mem = memory
|
self.mem = memory
|
||||||
self.tool_done = False
|
self.tool_done = False
|
||||||
|
|
||||||
@@ -90,7 +89,15 @@ class ChattyToolNode(GraphBase):
|
|||||||
model_provider=self.config.llm_provider,
|
model_provider=self.config.llm_provider,
|
||||||
api_key=self.config.api_key,
|
api_key=self.config.api_key,
|
||||||
base_url=self.config.base_url,
|
base_url=self.config.base_url,
|
||||||
temperature=0)
|
temperature=0,
|
||||||
|
tags=["chatty_llm"])
|
||||||
|
self.tool_llm = init_chat_model(model=self.config.llm_name,
|
||||||
|
model_provider=self.config.llm_provider,
|
||||||
|
api_key=self.config.api_key,
|
||||||
|
base_url=self.config.base_url,
|
||||||
|
temperature=0,
|
||||||
|
tags=["tool_llm"])
|
||||||
|
self.reit_llm = make_llm(model="qwen-flash", tags=["reit_llm"])
|
||||||
|
|
||||||
self.chatty_agent = create_agent(self.chatty_llm, [], checkpointer=self.mem)
|
self.chatty_agent = create_agent(self.chatty_llm, [], checkpointer=self.mem)
|
||||||
self.tool_agent = create_agent(self.tool_llm, self.tool_manager.get_list_langchain_tools(), checkpointer=self.mem)
|
self.tool_agent = create_agent(self.tool_llm, self.tool_manager.get_list_langchain_tools(), checkpointer=self.mem)
|
||||||
@@ -141,11 +148,15 @@ class ChattyToolNode(GraphBase):
|
|||||||
|
|
||||||
|
|
||||||
def _handoff_node(self, state:ChattyToolState):
|
def _handoff_node(self, state:ChattyToolState):
|
||||||
# NOTE: this exist to have both results
|
# NOTE: This exists just to stream the thing correctly
|
||||||
# chat_msgs = state.get("chatty_message")
|
tool_msgs = state.get("tool_messages")["messages"]
|
||||||
# tool_msgs = state.get("tool_message")
|
inp = [
|
||||||
|
SystemMessage(
|
||||||
# return {"messages": chat_msgs + tool_msgs}
|
"do nothing and repeat the last message"
|
||||||
|
),
|
||||||
|
tool_msgs[-1].content
|
||||||
|
]
|
||||||
|
self.reit_llm.invoke(inp)
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
@@ -174,36 +185,29 @@ tool_node_union = tyro.extras.subcommand_type_from_defaults(tool_node_dict, pref
|
|||||||
AnnotatedToolNode = tyro.conf.OmitSubcommandPrefixes[tyro.conf.SuppressFixed[tool_node_union]]
|
AnnotatedToolNode = tyro.conf.OmitSubcommandPrefixes[tyro.conf.SuppressFixed[tool_node_union]]
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from langchain.chat_models import init_chat_model
|
|
||||||
from langchain_core.messages.base import BaseMessageChunk
|
from langchain_core.messages.base import BaseMessageChunk
|
||||||
from langchain_core.messages import BaseMessage
|
from langchain_core.messages import BaseMessage
|
||||||
|
|
||||||
from lang_agent.tool_manager import ToolManagerConfig
|
from lang_agent.tool_manager import ToolManagerConfig
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
import os
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
llm = init_chat_model(model="qwen-flash",
|
|
||||||
model_provider="openai",
|
|
||||||
api_key=os.environ.get("ALI_API_KEY"),
|
|
||||||
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
|
|
||||||
temperature=0)
|
|
||||||
mem = MemorySaver()
|
mem = MemorySaver()
|
||||||
tool_manager = ToolManagerConfig().setup()
|
tool_manager = ToolManagerConfig().setup()
|
||||||
chatty_node:ChattyToolNode = ChattyToolNodeConfig().setup(tool_manager=tool_manager,
|
chatty_node:ChattyToolNode = ChattyToolNodeConfig().setup(tool_manager=tool_manager,
|
||||||
llm=llm,
|
|
||||||
memory=mem)
|
memory=mem)
|
||||||
|
|
||||||
query = "use calculator to calculate 33*42"
|
query = "use calculator to calculate 33*42"
|
||||||
input = {"inp" : ({"messages":[SystemMessage("you are a kind helper"), HumanMessage(query)]},
|
input = {"inp" : ({"messages":[SystemMessage("you are a kind helper"), HumanMessage(query)]},
|
||||||
{"configurable": {"thread_id": '3'}})}
|
{"configurable": {"thread_id": '3'}})}
|
||||||
inp = input
|
inp = input
|
||||||
|
|
||||||
graph = chatty_node.workflow
|
graph = chatty_node.workflow
|
||||||
|
|
||||||
# for chunk, metadata in graph.stream(inp, stream_mode="messages"):
|
for chunk, metadata in graph.stream(inp, stream_mode="messages"):
|
||||||
# if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
|
tags = metadata.get("tags")
|
||||||
# print(chunk.content, end="", flush=True)
|
if not (tags in [["chatty_llm"], ["reit_llm"]]):
|
||||||
out = graph.invoke(inp)
|
continue
|
||||||
assert 0
|
|
||||||
|
if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
|
||||||
|
print(chunk.content, end="", flush=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user