161 lines
5.4 KiB
Python
161 lines
5.4 KiB
Python
from dataclasses import dataclass, field
|
|
from typing import Type
|
|
import tyro
|
|
import os.path as osp
|
|
from loguru import logger
|
|
|
|
from lang_agent.config import KeyConfig
|
|
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
|
|
from lang_agent.base import GraphBase
|
|
from lang_agent.utils import tree_leaves
|
|
|
|
from langchain.chat_models import init_chat_model
|
|
from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
|
|
from langchain.agents import create_agent
|
|
from langgraph.checkpoint.memory import MemorySaver
|
|
|
|
# NOTE: maybe make this into a base_graph_config?
|
|
@tyro.conf.configure(tyro.conf.SuppressFixed)
@dataclass
class ReactGraphConfig(KeyConfig):
    # Configuration for ReactGraph: which chat model to use, where the system
    # prompt lives, and how to build the tool manager.
    #
    # NOTE(review): the bare string literals after each field are kept
    # byte-for-byte; they look like per-field help text for the CLI parser —
    # confirm before rewording them.

    # Resolved lazily through a lambda because ReactGraph is defined further
    # down in this module.
    _target: Type = field(default_factory=lambda: ReactGraph)

    llm_name: str = "qwen-plus"
    """name of llm"""

    llm_provider: str = "openai"
    """provider of the llm"""

    # Default points three directory levels above this file, then
    # configs/prompts/blueberry.txt.
    sys_prompt_f: str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "prompts", "blueberry.txt")
    """path to system prompt"""

    base_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
    """base url; could be used to overwrite the baseurl in llm provider"""

    tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)

    def __post_init__(self):
        """Run the parent post-init, then verify the system prompt file exists.

        Raises:
            FileNotFoundError: if ``sys_prompt_f`` does not exist on disk.
        """
        super().__post_init__()

        # Explicit raise instead of `assert`: assertions are stripped when
        # Python runs with -O, which would defer the failure to the later
        # open() call.
        if not osp.exists(self.sys_prompt_f):
            raise FileNotFoundError(f"{self.sys_prompt_f} does not exist")

        logger.info(f"will be loading react sys promtp from {self.sys_prompt_f}")
|
|
|
|
|
|
class ReactGraph(GraphBase):
    """Agent wrapper built with ``create_agent`` and an in-memory checkpointer.

    On construction it builds the chat model, the tools exposed by the tool
    manager, a ``MemorySaver`` checkpointer and the agent itself, and loads
    the system prompt text from ``config.sys_prompt_f``.
    """

    def __init__(self, config: ReactGraphConfig):
        """Store ``config`` and build every runtime component from it."""
        self.config = config

        self.populate_modules()

    def populate_modules(self):
        """Create the LLM, tool manager, checkpointer and agent; load the prompt."""
        # NOTE(review): `api_key` is not declared on ReactGraphConfig in this
        # file; presumably it is provided by KeyConfig — confirm.
        self.llm = init_chat_model(model=self.config.llm_name,
                                   model_provider=self.config.llm_provider,
                                   api_key=self.config.api_key,
                                   base_url=self.config.base_url,
                                   tags=["main_llm"])

        self.tool_manager: ToolManager = self.config.tool_manager_config.setup()
        self.memory = MemorySaver()
        tools = self.tool_manager.get_langchain_tools()
        self.agent = create_agent(self.llm, tools, checkpointer=self.memory)

        # Decode explicitly as UTF-8 so the prompt text does not depend on
        # the platform's locale encoding.
        with open(self.config.sys_prompt_f, "r", encoding="utf-8") as f:
            self.sys_prompt = f.read()

    def _get_human_msg(self, *nargs):
        """Return the first ``HumanMessage`` found in ``nargs[0]["messages"]``.

        Raises:
            AssertionError: if the message list contains no ``HumanMessage``.
        """
        msgs = nargs[0]["messages"]

        # First match wins; later human messages in the list are ignored.
        candidate_hum_msg = next(
            (msg for msg in msgs if isinstance(msg, HumanMessage)),
            None,
        )

        assert isinstance(candidate_hum_msg, HumanMessage), "not a human message"

        return candidate_hum_msg

    def _prep_inp(self, *nargs):
        """Build the ``(state, config)`` pair that is handed to the agent.

        ``nargs`` must be exactly ``(state, config)`` where ``state`` is a
        mapping with a ``"messages"`` list. The returned state contains only
        the configured system prompt followed by the first human message of
        the input; any other input messages (including a caller-supplied
        system message) are dropped. ``config`` is passed through untouched.
        """
        assert len(nargs) == 2, "should have 2 arguements"

        human_msg = self._get_human_msg(*nargs)
        conf = nargs[1]
        return {"messages": [SystemMessage(self.sys_prompt), human_msg]}, conf

    def _finalize_output(self, out, as_raw: bool):
        """Flatten the agent output, pretty-print its messages, pick the return value.

        Shared by ``invoke`` and ``ainvoke`` so both produce identical results.

        Returns:
            The flattened leaves of ``out`` when ``as_raw`` is true, otherwise
            the ``content`` of the last leaf.
        """
        msgs_list = tree_leaves(out)

        for e in msgs_list:
            if isinstance(e, BaseMessage):
                e.pretty_print()

        if as_raw:
            return msgs_list
        return msgs_list[-1].content

    def invoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs):
        """
        Run the agent synchronously.

        *nargs: exactly two positional arguments, the input state (a mapping
            with a "messages" list) and the run config forwarded to the agent
        as_stream (bool): for debug only, gets the agent to print its thoughts
        as_raw (bool): return all flattened output leaves instead of only the
            content of the last one
        **kwargs: forwarded to the agent's ``stream`` / ``invoke`` call
        """
        nargs = self._prep_inp(*nargs)
        if as_stream:
            # Print the newest message of every intermediate state; the last
            # state seen is used as the final output.
            for step in self.agent.stream(*nargs, stream_mode="values", **kwargs):
                step["messages"][-1].pretty_print()
                out = step
        else:
            out = self.agent.invoke(*nargs, **kwargs)

        return self._finalize_output(out, as_raw)

    async def ainvoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs):
        """
        Async version of invoke using LangGraph's native async support.

        *nargs: exactly two positional arguments, the input state (a mapping
            with a "messages" list) and the run config forwarded to the agent
        as_stream (bool): for debug only, gets the agent to print its thoughts
        as_raw (bool): return all flattened output leaves instead of only the
            content of the last one
        **kwargs: forwarded to the agent's ``astream`` / ``ainvoke`` call
        """
        nargs = self._prep_inp(*nargs)
        if as_stream:
            # Print the newest message of every intermediate state; the last
            # state seen is used as the final output.
            async for step in self.agent.astream(*nargs, stream_mode="values", **kwargs):
                step["messages"][-1].pretty_print()
                out = step
        else:
            out = await self.agent.ainvoke(*nargs, **kwargs)

        return self._finalize_output(out, as_raw)
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: build the default graph and run one tool-using query.
    from dotenv import load_dotenv

    # Load variables from a local .env file before the config is built.
    load_dotenv()

    route: ReactGraph = ReactGraphConfig().setup()

    # (state, config) pair expected by ReactGraph.invoke.
    # NOTE: ReactGraph._prep_inp replaces the system message below with the
    # prompt loaded from the config, so only the human message is forwarded.
    nargs = {
        "messages": [SystemMessage("you are a helpful bot named jarvis"),
                     HumanMessage("use the calculator tool to calculate 92*55 and say the answer")]
    }, {"configurable": {"thread_id": "3"}}

    # invoke() pretty-prints every message it collects, so the return value
    # is not needed here. The previous trailing `assert 0` debug halt was
    # removed so the script exits cleanly.
    route.invoke(*nargs)

    # Reference snippet (disabled): stream message chunks straight from the
    # underlying agent graph.
    #
    # from langchain_core.messages.base import BaseMessageChunk
    # graph = route.agent
    # for chunk, metadata in graph.stream({"inp": nargs}, stream_mode="messages"):
    #     node = metadata.get("langgraph_node")
    #     if node not in ("model"):
    #         print(node)
    #         continue  # skip router or other intermediate nodes
    #
    #     # Print only the final message content
    #     if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
    #         print(chunk.content, end="", flush=True)
|
|
|