bug fixes + test
This commit is contained in:
@@ -92,7 +92,7 @@ class ChattyToolNode(GraphBase):
|
|||||||
base_url=self.config.base_url,
|
base_url=self.config.base_url,
|
||||||
temperature=0)
|
temperature=0)
|
||||||
|
|
||||||
self.chatty_agent = create_agent(self.chatty_agent, [], checkpointer=self.mem)
|
self.chatty_agent = create_agent(self.chatty_llm, [], checkpointer=self.mem)
|
||||||
self.tool_agent = create_agent(self.tool_llm, self.tool_manager.get_list_langchain_tools(), checkpointer=self.mem)
|
self.tool_agent = create_agent(self.tool_llm, self.tool_manager.get_list_langchain_tools(), checkpointer=self.mem)
|
||||||
|
|
||||||
with open(self.config.chatty_sys_prompt_f, "r") as f:
|
with open(self.config.chatty_sys_prompt_f, "r") as f:
|
||||||
@@ -102,6 +102,8 @@ class ChattyToolNode(GraphBase):
|
|||||||
self.tool_sys_prompt = f.read()
|
self.tool_sys_prompt = f.read()
|
||||||
|
|
||||||
def invoke(self, state:State):
|
def invoke(self, state:State):
|
||||||
|
self.tool_done = False
|
||||||
|
|
||||||
inp = {"inp": state["inp"]}
|
inp = {"inp": state["inp"]}
|
||||||
out = self.workflow.invoke(inp)
|
out = self.workflow.invoke(inp)
|
||||||
chat_msgs = out.get("chatty_message")
|
chat_msgs = out.get("chatty_message")
|
||||||
@@ -119,6 +121,7 @@ class ChattyToolNode(GraphBase):
|
|||||||
|
|
||||||
out = self.tool_agent.invoke(*inp)
|
out = self.tool_agent.invoke(*inp)
|
||||||
|
|
||||||
|
self.tool_done = True
|
||||||
return {"tool_messages": out}
|
return {"tool_messages": out}
|
||||||
|
|
||||||
|
|
||||||
@@ -155,7 +158,7 @@ class ChattyToolNode(GraphBase):
|
|||||||
builder.add_edge(START, "chatty_tool_call")
|
builder.add_edge(START, "chatty_tool_call")
|
||||||
builder.add_edge(START, "chatty_chat_call")
|
builder.add_edge(START, "chatty_chat_call")
|
||||||
builder.add_edge("chatty_chat_call", "chatty_handoff_node")
|
builder.add_edge("chatty_chat_call", "chatty_handoff_node")
|
||||||
builder.add_edge("chatty_node_call", "chatty_handoff_node")
|
builder.add_edge("chatty_tool_call", "chatty_handoff_node")
|
||||||
builder.add_edge("chatty_handoff_node", END)
|
builder.add_edge("chatty_handoff_node", END)
|
||||||
|
|
||||||
self.workflow = builder.compile()
|
self.workflow = builder.compile()
|
||||||
@@ -171,5 +174,35 @@ tool_node_union = tyro.extras.subcommand_type_from_defaults(tool_node_dict, pref
|
|||||||
AnnotatedToolNode = tyro.conf.OmitSubcommandPrefixes[tyro.conf.SuppressFixed[tool_node_union]]
|
AnnotatedToolNode = tyro.conf.OmitSubcommandPrefixes[tyro.conf.SuppressFixed[tool_node_union]]
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
tyro.cli(ToolNodeConfig)
|
from langchain.chat_models import init_chat_model
|
||||||
|
from langchain_core.messages.base import BaseMessageChunk
|
||||||
|
from langchain_core.messages import BaseMessage
|
||||||
|
|
||||||
|
from lang_agent.tool_manager import ToolManagerConfig
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
import os
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
llm = init_chat_model(model="qwen-flash",
|
||||||
|
model_provider="openai",
|
||||||
|
api_key=os.environ.get("ALI_API_KEY"),
|
||||||
|
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||||
|
temperature=0)
|
||||||
|
mem = MemorySaver()
|
||||||
|
tool_manager = ToolManagerConfig().setup()
|
||||||
|
chatty_node:ChattyToolNode = ChattyToolNodeConfig().setup(tool_manager=tool_manager,
|
||||||
|
llm=llm,
|
||||||
|
memory=mem)
|
||||||
|
|
||||||
|
query = "use calculator to calculate 33*42"
|
||||||
|
input = {"inp" : ({"messages":[SystemMessage("you are a kind helper"), HumanMessage(query)]},
|
||||||
|
{"configurable": {"thread_id": '3'}})}
|
||||||
|
inp = input["inp"]
|
||||||
|
|
||||||
|
graph = chatty_node.workflow
|
||||||
|
|
||||||
|
for chunk, metadata in graph.stream(*inp, stream_mode="messages"):
|
||||||
|
if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
|
||||||
|
print(chunk.content, end="", flush=True)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user