make toolnode configurable
This commit is contained in:
@@ -12,7 +12,7 @@ from lang_agent.config import KeyConfig
|
|||||||
from lang_agent.tool_manager import ToolManager, ToolManagerConfig
|
from lang_agent.tool_manager import ToolManager, ToolManagerConfig
|
||||||
from lang_agent.base import GraphBase
|
from lang_agent.base import GraphBase
|
||||||
from lang_agent.graphs.graph_states import State
|
from lang_agent.graphs.graph_states import State
|
||||||
from lang_agent.graphs.tool_nodes import AnnotatedToolNode, ToolNode
|
from lang_agent.graphs.tool_nodes import AnnotatedToolNode, ToolNodeConfig
|
||||||
|
|
||||||
from langchain.chat_models import init_chat_model
|
from langchain.chat_models import init_chat_model
|
||||||
from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
|
from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
|
||||||
@@ -42,6 +42,8 @@ class RoutingConfig(KeyConfig):
|
|||||||
|
|
||||||
tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)
|
tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)
|
||||||
|
|
||||||
|
tool_node_config: AnnotatedToolNode = field(default_factory=ToolNodeConfig)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Route(BaseModel):
|
class Route(BaseModel):
|
||||||
@@ -126,7 +128,9 @@ class RoutingGraph(GraphBase):
|
|||||||
|
|
||||||
tool_manager:ToolManager = self.config.tool_manager_config.setup()
|
tool_manager:ToolManager = self.config.tool_manager_config.setup()
|
||||||
self.chat_model = create_agent(self.llm, self._get_chat_tools(tool_manager), checkpointer=self.memory)
|
self.chat_model = create_agent(self.llm, self._get_chat_tools(tool_manager), checkpointer=self.memory)
|
||||||
self.tool_model = create_agent(self.llm, tool_manager.get_list_langchain_tools(), checkpointer=self.memory)
|
self.tool_node:GraphBase = self.config.tool_node_config.setup(tool_manager=tool_manager,
|
||||||
|
llm=self.llm,
|
||||||
|
memory=self.memory)
|
||||||
|
|
||||||
self._load_sys_prompts()
|
self._load_sys_prompts()
|
||||||
|
|
||||||
@@ -206,15 +210,8 @@ class RoutingGraph(GraphBase):
|
|||||||
|
|
||||||
|
|
||||||
def _tool_model_call(self, state:State):
|
def _tool_model_call(self, state:State):
|
||||||
inp = {"messages":[
|
out = self.tool_node.invoke(state)
|
||||||
SystemMessage(
|
return {"messages": out["messages"]}
|
||||||
self.prompt_dict["tool_prompt"]
|
|
||||||
),
|
|
||||||
*state["inp"][0]["messages"][1:]
|
|
||||||
]}, state["inp"][1]
|
|
||||||
|
|
||||||
out = self.tool_model.invoke(*inp)
|
|
||||||
return {"messages": out}
|
|
||||||
|
|
||||||
def _build_graph(self):
|
def _build_graph(self):
|
||||||
builder = StateGraph(State)
|
builder = StateGraph(State)
|
||||||
|
|||||||
Reference in New Issue
Block a user