diff --git a/lang_agent/graphs/routing.py b/lang_agent/graphs/routing.py index ff58d1f..d039c92 100644 --- a/lang_agent/graphs/routing.py +++ b/lang_agent/graphs/routing.py @@ -68,7 +68,11 @@ class State(TypedDict): class RoutingGraph(GraphBase): def __init__(self, config: RoutingConfig): self.config = config - self.chat_sys_msg = None + + # NOTE: tool that the chatbranch should have + self.chat_tool_names = ["retrieve", + "get_resources"] + self._build_modules() self.workflow = self._build_graph() @@ -100,16 +104,19 @@ class RoutingGraph(GraphBase): assert len(nargs[0]["messages"]) >= 2, "need at least 1 system and 1 human message" assert len(kwargs) == 0, "due to inp assumptions" + def _get_chat_tools(self, man:ToolManager): + return [lang_tool for lang_tool in man.get_list_langchain_tools() if lang_tool.name in self.chat_tool_names] + def _build_modules(self): self.llm = init_chat_model(model=self.config.llm_name, model_provider=self.config.llm_provider, api_key=self.config.api_key, base_url=self.config.base_url) - self.memory = MemorySaver() + self.memory = MemorySaver() # shared memory between the two branch self.router = self.llm.with_structured_output(Route) tool_manager:ToolManager = self.config.tool_manager_config.setup() - self.chat_model = create_agent(self.llm, [], checkpointer=self.memory) + self.chat_model = create_agent(self.llm, self._get_chat_tools(tool_manager), checkpointer=self.memory) self.tool_model = create_agent(self.llm, tool_manager.get_list_langchain_tools(), checkpointer=self.memory) with open(self.config.sys_promp_json , "r") as f: