Files
lang-agent/lang_agent/graphs/deepagents_qt.py

96 lines
3.8 KiB
Python

from dataclasses import dataclass, field
from typing import Type, Literal
import tyro
import os.path as osp
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, START, END
from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
from deepagents import create_deep_agent
from lang_agent.utils import make_llm
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
from lang_agent.components.prompt_store import build_prompt_store
from lang_agent.graphs.graph_states import State
from lang_agent.config import LLMNodeConfig
from lang_agent.base import GraphBase
# from lang_agent.fs_bkends import StateBk, StateBkConfig, LocalShell, LocalShellConfig, DaytonaSandboxBk, DaytonaSandboxConfig
from lang_agent.fs_bkends import BaseFilesystemBackend, StateBkConfig, AnnotatedStateBk
@tyro.conf.configure(tyro.conf.SuppressFixed)
@dataclass
class DeepAgentConfig(LLMNodeConfig):
    """Configuration for :class:`DeepAgent`.

    Extends ``LLMNodeConfig`` with the system-prompt file location, the tool
    manager configuration and the filesystem backend configuration.

    Raises:
        FileNotFoundError: from ``__post_init__`` when ``sys_prompt_f`` does
            not point to an existing path.
    """

    # A lambda is used so that ``DeepAgent`` (defined further down in this
    # module) is only looked up when the config is instantiated.
    _target: Type = field(default_factory=lambda: DeepAgent)

    sys_prompt_f: str = osp.join(
        osp.dirname(osp.dirname(osp.dirname(__file__))),
        "configs",
        "prompts",
        "deepagent.txt",
    )
    """path to system prompt"""

    tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)

    file_backend_config: AnnotatedStateBk = field(default_factory=StateBkConfig)

    def __post_init__(self):
        """Run the parent post-init, then verify the system prompt file exists."""
        super().__post_init__()
        # Explicit raise instead of ``assert``: assertions are removed when
        # Python runs with ``-O``, which would let a bad path through.
        if not osp.exists(self.sys_prompt_f):
            raise FileNotFoundError(
                f"prompt path does not exist: {self.sys_prompt_f}"
            )
class DeepAgent(GraphBase):
    """Graph that runs a deep agent as the only node of a LangGraph workflow."""

    def __init__(self, config: DeepAgentConfig):
        """Keep the config, build the agent components and compile the workflow.

        Args:
            config: settings for the LLM, tools, filesystem backend and prompt.
        """
        self.config = config
        self._build_modules()
        self.workflow = self._build_graph()

    def _build_modules(self):
        """Create the LLM, tool manager, filesystem backend, deep agent and prompt.

        Populates ``tool_man``, ``file_backend``, ``mem``, ``deep_agent``,
        ``prompt_store`` and ``sys_prompt`` on the instance.
        """
        cfg = self.config

        main_llm = make_llm(
            cfg.llm_name,
            cfg.llm_provider,
            api_key=cfg.api_key,
            tags=["main_llm"],
        )

        self.tool_man: ToolManager = cfg.tool_manager_config.setup()
        self.file_backend: BaseFilesystemBackend = cfg.file_backend_config.setup()

        # Extra keyword arguments supplied by the backend for create_deep_agent.
        extra_agent_kwargs = self.file_backend.get_deepagent_params()

        self.mem = MemorySaver()

        agent_tools = self.tool_man.get_langchain_tools()
        agent_backend = self.file_backend.get_backend()
        self.deep_agent = create_deep_agent(
            model=main_llm,
            tools=agent_tools,
            backend=agent_backend,
            checkpointer=self.mem,
            **extra_agent_kwargs,
        )

        self.prompt_store = build_prompt_store(
            pipeline_id=cfg.pipeline_id,
            prompt_set_id=cfg.prompt_set_id,
            file_path=cfg.sys_prompt_f,
            default_key="sys_prompt",
        )
        # The prompt is read once here and reused for every agent call.
        self.sys_prompt = self.prompt_store.get("sys_prompt")

    def _agent_call(self, state: State):
        """Invoke the deep agent for one graph step.

        Args:
            state: graph state; ``state["inp"][1]`` is forwarded as the second
                positional argument of ``deep_agent.invoke`` (presumably the
                runnable config carrying the thread id -- confirm with callers).

        Returns:
            A dict whose ``"messages"`` entry is the agent output's messages.
        """
        # System prompt first, followed by the messages extracted from the state
        # by ``_get_inp_msgs`` (not defined in this class; presumably inherited).
        conversation = [SystemMessage(self.sys_prompt)]
        conversation.extend(self._get_inp_msgs(state))

        agent_input = {"messages": conversation}
        # Merge in whatever additional inputs the filesystem backend provides.
        agent_input.update(self.file_backend.get_inf_inp())

        run_config = state["inp"][1]
        result = self.deep_agent.invoke(agent_input, run_config)
        return {"messages": result["messages"]}

    def _build_graph(self):
        """Compile a linear graph: START -> "agent_call" -> END."""
        graph = StateGraph(State)
        graph.add_node("agent_call", self._agent_call)
        graph.add_edge(START, "agent_call")
        graph.add_edge("agent_call", END)
        return graph.compile()
if __name__ == "__main__":
    # Manual smoke run: build the agent with default settings and invoke the
    # compiled workflow once on a fixed thread id.
    config = DeepAgentConfig()
    deepagent = DeepAgent(config)

    seed_messages = {
        "messages": [SystemMessage("you are a helpful bot enhanced with skills")]
    }
    run_config = {"configurable": {"thread_id": "3"}}
    deepagent.workflow.invoke({"inp": (seed_messages, run_config)})