get chat_branch tool

This commit is contained in:
2025-10-29 13:48:22 +08:00
parent a8a16a5363
commit 7765cebefa

View File

@@ -68,7 +68,11 @@ class State(TypedDict):
class RoutingGraph(GraphBase):
def __init__(self, config: RoutingConfig):
self.config = config
self.chat_sys_msg = None
# NOTE: tool that the chatbranch should have
self.chat_tool_names = ["retrieve",
"get_resources"]
self._build_modules()
self.workflow = self._build_graph()
@@ -100,16 +104,19 @@ class RoutingGraph(GraphBase):
assert len(nargs[0]["messages"]) >= 2, "need at least 1 system and 1 human message"
assert len(kwargs) == 0, "due to inp assumptions"
def _get_chat_tools(self, man:ToolManager):
return [lang_tool for lang_tool in man.get_list_langchain_tools() if lang_tool.name in self.chat_tool_names]
def _build_modules(self):
self.llm = init_chat_model(model=self.config.llm_name,
model_provider=self.config.llm_provider,
api_key=self.config.api_key,
base_url=self.config.base_url)
self.memory = MemorySaver()
self.memory = MemorySaver() # shared memory between the two branch
self.router = self.llm.with_structured_output(Route)
tool_manager:ToolManager = self.config.tool_manager_config.setup()
self.chat_model = create_agent(self.llm, [], checkpointer=self.memory)
self.chat_model = create_agent(self.llm, self._get_chat_tools(tool_manager), checkpointer=self.memory)
self.tool_model = create_agent(self.llm, tool_manager.get_list_langchain_tools(), checkpointer=self.memory)
with open(self.config.sys_promp_json , "r") as f: