get chat_branch tool
This commit is contained in:
@@ -68,7 +68,11 @@ class State(TypedDict):
|
||||
class RoutingGraph(GraphBase):
|
||||
def __init__(self, config: RoutingConfig):
|
||||
self.config = config
|
||||
self.chat_sys_msg = None
|
||||
|
||||
# NOTE: tool that the chatbranch should have
|
||||
self.chat_tool_names = ["retrieve",
|
||||
"get_resources"]
|
||||
|
||||
self._build_modules()
|
||||
|
||||
self.workflow = self._build_graph()
|
||||
@@ -100,16 +104,19 @@ class RoutingGraph(GraphBase):
|
||||
assert len(nargs[0]["messages"]) >= 2, "need at least 1 system and 1 human message"
|
||||
assert len(kwargs) == 0, "due to inp assumptions"
|
||||
|
||||
def _get_chat_tools(self, man:ToolManager):
|
||||
return [lang_tool for lang_tool in man.get_list_langchain_tools() if lang_tool.name in self.chat_tool_names]
|
||||
|
||||
def _build_modules(self):
|
||||
self.llm = init_chat_model(model=self.config.llm_name,
|
||||
model_provider=self.config.llm_provider,
|
||||
api_key=self.config.api_key,
|
||||
base_url=self.config.base_url)
|
||||
self.memory = MemorySaver()
|
||||
self.memory = MemorySaver() # shared memory between the two branch
|
||||
self.router = self.llm.with_structured_output(Route)
|
||||
|
||||
tool_manager:ToolManager = self.config.tool_manager_config.setup()
|
||||
self.chat_model = create_agent(self.llm, [], checkpointer=self.memory)
|
||||
self.chat_model = create_agent(self.llm, self._get_chat_tools(tool_manager), checkpointer=self.memory)
|
||||
self.tool_model = create_agent(self.llm, tool_manager.get_list_langchain_tools(), checkpointer=self.memory)
|
||||
|
||||
with open(self.config.sys_promp_json , "r") as f:
|
||||
|
||||
Reference in New Issue
Block a user