From 880a573c428b5018f32a2552d107c58b08a6dd34 Mon Sep 17 00:00:00 2001 From: goulustis Date: Mon, 27 Oct 2025 15:21:12 +0800 Subject: [PATCH 01/18] return default if not specified --- lang_agent/eval/validator.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/lang_agent/eval/validator.py b/lang_agent/eval/validator.py index 47eeedb..614ae80 100644 --- a/lang_agent/eval/validator.py +++ b/lang_agent/eval/validator.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Type, Literal +from typing import Type, Callable, List import tyro from lang_agent.config import KeyConfig @@ -71,9 +71,10 @@ class Validator: return pipeline.chat(inp, as_raw=True) - def get_val_fnc(self, dataset_name:str): - return self.dict_corr_map[dataset_name] + def get_val_fnc(self, dataset_name:str)->List[Callable]: + return self.dict_corr_map.get(dataset_name, [self.Toxic_Queries_correct]) - def get_inp_fnc(self,dataset_name:str): - return self.dict_inp_map[dataset_name] \ No newline at end of file + def get_inp_fnc(self,dataset_name:str)->Callable: + # return self.dict_inp_map[dataset_name] + return self.dict_corr_map.get(dataset_name, self.Toxic_Queries_inp_parse) \ No newline at end of file From 51ac83401bdd1b66972dbfa75681cce3076f89d6 Mon Sep 17 00:00:00 2001 From: goulustis Date: Mon, 27 Oct 2025 16:39:20 +0800 Subject: [PATCH 02/18] validate tool use --- lang_agent/eval/validator.py | 40 +++++++++++++++++++++++++++++------- 1 file changed, 33 insertions(+), 7 deletions(-) diff --git a/lang_agent/eval/validator.py b/lang_agent/eval/validator.py index 614ae80..9e570b5 100644 --- a/lang_agent/eval/validator.py +++ b/lang_agent/eval/validator.py @@ -6,6 +6,7 @@ from lang_agent.config import KeyConfig from lang_agent.pipeline import Pipeline, PipelineConfig from langchain.chat_models import init_chat_model +from langchain_core.messages import BaseMessage, ToolMessage @tyro.conf.configure(tyro.conf.SuppressFixed) @dataclass @@ -21,12 +22,12 @@ class Validator: # NOTE: Need to register function here self.dict_corr_map = { - "Toxic Queries" : [self.Toxic_Queries_correct] + "dev_langagent" : [self.default_correct, self.val_tool_use] } # NOTE: Need to register function here self.dict_inp_map = { - "Toxic Queries" : self.Toxic_Queries_inp_parse + "dev_langagent" : self.default_inp_parse } @@ -39,7 +40,7 @@ class Validator: ) # NOTE: for every dataset; need one of these - def Toxic_Queries_correct(self, inputs: dict, outputs: list, reference_outputs: dict) -> bool: + def default_correct(self, inputs: dict, outputs: dict, reference_outputs: dict) -> bool: instructions = ( "Given an actual answer and an expected answer, determine whether" " the actual answer contains all of the information in the" @@ -48,7 +49,7 @@ class Validator: " otherwise. Do not include anything else in your response." ) actual_answer = outputs["output"][-1].content - expected_answer = reference_outputs["label"] + expected_answer = reference_outputs["answer"] user_msg = ( f"ACTUAL ANSWER: {actual_answer}" @@ -64,17 +65,42 @@ class Validator: return response.content.upper() == "CORRECT" + def val_tool_use(self, inputs:dict, outputs:dict, reference_outputs:dict)->bool: + tool_uses:List[str] = reference_outputs.get("tool_use") + if tool_uses is None: + return True + + tool_msgs = [e for e in outputs["output"] if isinstance(e, ToolMessage)] + + # check if all tools are used + tool_used = [] + for ref_tool in tool_uses: + st_cond = False + ref_tool = ref_tool.lower() + for msg in tool_msgs: + st_cond = ref_tool in msg.name.lower() + if st_cond: + break + tool_used.append(st_cond) + + for cond in tool_used: + if not cond: + return False + + return True + + # NOTE: for every dataset; need one of these - def Toxic_Queries_inp_parse(self, inp, pipeline:Pipeline): + def default_inp_parse(self, inp, pipeline:Pipeline): inp = inp["text"] return pipeline.chat(inp, as_raw=True) def get_val_fnc(self, dataset_name:str)->List[Callable]: - return self.dict_corr_map.get(dataset_name, [self.Toxic_Queries_correct]) + return self.dict_corr_map.get(dataset_name, [self.default_correct]) def get_inp_fnc(self,dataset_name:str)->Callable: # return self.dict_inp_map[dataset_name] - return self.dict_corr_map.get(dataset_name, self.Toxic_Queries_inp_parse) \ No newline at end of file + return self.dict_inp_map.get(dataset_name, self.default_inp_parse) \ No newline at end of file From 35a01ac18f28a51a2f3b489dc98684132a2ed75d Mon Sep 17 00:00:00 2001 From: goulustis Date: Mon, 27 Oct 2025 16:39:54 +0800 Subject: [PATCH 03/18] make dataset --- scripts/make_eval_dataset.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 scripts/make_eval_dataset.py diff --git a/scripts/make_eval_dataset.py b/scripts/make_eval_dataset.py new file mode 100644 index 0000000..0106066 --- /dev/null +++ b/scripts/make_eval_dataset.py @@ -0,0 +1,31 @@ +from langsmith import Client +from loguru import logger + + +DATASET_NAME = "dev_langagent" + +examples = [ + { + "inputs": {"text": "用retrieve查询光予尘然后介绍"}, + "outputs": {"answer": "茉莉绿茶为底,清冽茶香中漫出玫珑蜜瓜的绵甜与凤梨的明亮果香,层次鲜活;顶部白柚茉莉泡沫轻盈漫过舌尖,带着微酸的清新感,让整体风味更显灵动", + "tool_use": ["retrieve"]} + }, + { + "inputs": {"text": "介绍一下自己"}, + "outputs": {"answer": "我是小盏,是一个点餐助手"} + } +] + +cli = Client() + +try: + dataset = cli.read_dataset(dataset_name=DATASET_NAME) + logger.info("read dataset") +except: + dataset = cli.create_dataset(dataset_name=DATASET_NAME) + logger.info("created dataset") + +cli.create_examples( + dataset_id=dataset.id, + examples=examples +) \ No newline at end of file From a6d3fe8c7691c234a8d72387c4a12e295fd2f82e Mon Sep 17 00:00:00 2001 From: goulustis Date: Mon, 27 Oct 2025 16:45:45 +0800 Subject: [PATCH 04/18] update with example instruct --- README.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/README.md b/README.md index 12bf47f..7619330 100644 --- a/README.md +++ b/README.md @@ -36,4 +36,16 @@ python scripts/start_mcp_server.py # update configs/ws_mcp_config.json with link from the command above python scripts/ws_start_register_tools.py +``` + +# Eval Dataset Format +see `scripts/make_eval_dataset.py` for example. Specific meaning of each entry: +```json +[ + { + "inputs": {"text": "用retrieve查询光予尘然后介绍"}, // model input + "outputs": {"answer": "光予尘茉莉绿茶为底", // reference answer + "tool_use": ["retrieve"]} // tool uses; assume model need to use all tools if more than 1 provided + } +] ``` \ No newline at end of file From e9d13878c99d2b21b20acd75fc02895d7ebe0250 Mon Sep 17 00:00:00 2001 From: goulustis Date: Mon, 27 Oct 2025 16:45:57 +0800 Subject: [PATCH 05/18] change to simple datset --- lang_agent/eval/evaluator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lang_agent/eval/evaluator.py b/lang_agent/eval/evaluator.py index 16b5e2a..b3ad69a 100644 --- a/lang_agent/eval/evaluator.py +++ b/lang_agent/eval/evaluator.py @@ -24,7 +24,7 @@ class EvaluatorConfig(InstantiateConfig): experiment_desc:str = "testing if this works or not" """describe the experiment""" - dataset_name:Literal["Toxic Queries"] = "Toxic Queries" + dataset_name:Literal["Toxic Queries"] = "dev_langagent" """name of the dataset to evaluate""" pipe_config: PipelineConfig = field(default_factory=PipelineConfig) From 5822c4a5721025a657d9399ff528a08f3a8a183b Mon Sep 17 00:00:00 2001 From: goulustis Date: Mon, 27 Oct 2025 17:09:43 +0800 Subject: [PATCH 06/18] add dependency --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index c11109f..2e001e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,6 +5,7 @@ dependencies = [ "langchain==1.0", "langchain_community", "langchain-openai", + "langchain_mcp_adapters", "httpx[socks]", "dashscope", "python-dotenv>=1.0.0", From 4bab1bad0b49fab30e9d567cbd6fa0e0ae241d75 Mon Sep 17 00:00:00 2001 From: goulustis Date: Mon, 27 Oct 2025 17:34:49 +0800 Subject: [PATCH 07/18] configure system prompt with json --- configs/route_sys_prompts.json | 8 +++++++ lang_agent/graphs/routing.py | 39 ++++++++++++++++++++++++++++------ 2 files changed, 41 insertions(+), 6 deletions(-) create mode 100644 configs/route_sys_prompts.json diff --git a/configs/route_sys_prompts.json b/configs/route_sys_prompts.json new file mode 100644 index 0000000..0e35c59 --- /dev/null +++ b/configs/route_sys_prompts.json @@ -0,0 +1,8 @@ +{ // Prompt for router; manditory to say return as json for route_prompt; if not, there will be bugs + "route_prompt" : "Return a JSON object with 'step'.the value should be one of 'chat' or 'order' based on the user input", + + // Prompt for tool branch + "tool_prompt" : "You must use tool to complete the possible task" + + // Optionally set chat_prompt to overwrite the system prompt from xiaozhi +} \ No newline at end of file diff --git a/lang_agent/graphs/routing.py b/lang_agent/graphs/routing.py index 85948c1..592a5ba 100644 --- a/lang_agent/graphs/routing.py +++ b/lang_agent/graphs/routing.py @@ -7,6 +7,8 @@ from PIL import Image from io import BytesIO import matplotlib.pyplot as plt import jax +import os.path as osp +import commentjson from lang_agent.config import KeyConfig from lang_agent.tool_manager import ToolManager, ToolManagerConfig @@ -34,9 +36,20 @@ class RoutingConfig(KeyConfig): base_url:str = "https://dashscope.aliyuncs.com/compatible-mode/v1" """base url; could be used to overwrite the baseurl in llm provider""" + sys_promp_json: str = None + "path to json contantaining system prompt for graphs; Will overwrite systemprompt from xiaozhi if 'chat_prompt' is provided" + tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig) + def __post_init__(self): + super().__post_init__() + if self.sys_promp_json is None: + self.sys_promp_json = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts.json") + logger.warning(f"config_f was not provided. Using default: {self.sys_promp_json}") + + assert osp.exists(self.sys_promp_json), f"config_f {self.sys_promp_json} does not exist." + class Route(BaseModel): step: Literal["chat", "order"] = Field( @@ -89,9 +102,9 @@ class RoutingGraph(GraphBase): def _build_modules(self): self.llm = init_chat_model(model=self.config.llm_name, - model_provider=self.config.llm_provider, - api_key=self.config.api_key, - base_url=self.config.base_url) + model_provider=self.config.llm_provider, + api_key=self.config.api_key, + base_url=self.config.base_url) self.memory = MemorySaver() self.router = self.llm.with_structured_output(Route) @@ -99,12 +112,16 @@ class RoutingGraph(GraphBase): self.chat_model = create_agent(self.llm, [], checkpointer=self.memory) self.tool_model = create_agent(self.llm, tool_manager.get_langchain_tools(), checkpointer=self.memory) + with open(self.config.sys_promp_json , "r") as f: + self.prompt_dict = commentjson.load(f) + def _router_call(self, state:State): decision:Route = self.router.invoke( [ SystemMessage( - content="Return a JSON object with 'step'.the value should be one of 'chat' or 'order' based on the user input" + # content="Return a JSON object with 'step'.the value should be one of 'chat' or 'order' based on the user input" + content=self.prompt_dict.get("route_prompt") ), self._get_human_msg(state) ] @@ -137,6 +154,16 @@ class RoutingGraph(GraphBase): inp = state["messages"], state["inp"][1] else: inp = state["inp"] + + if self.prompt_dict.get("chat_prompt") is not None: + inp = {"messages":[ + SystemMessage( + # "You must use tool to complete the possible task" + self.prompt_dict["chat_prompt"] + ), + *state["inp"][0]["messages"][1:] + ]}, state["inp"][1] + out = self.chat_model.invoke(*inp) return {"messages": out} @@ -145,9 +172,9 @@ class RoutingGraph(GraphBase): def _tool_model_call(self, state:State): inp = {"messages":[ SystemMessage( - "You must use tool to complete the possible task" + # "You must use tool to complete the possible task" + self.prompt_dict["tool_prompt"] ), - # self._get_human_msg(state) *state["inp"][0]["messages"][1:] ]}, state["inp"][1] From 31a64401fdb62c8505c107d7ef930f429a79a58b Mon Sep 17 00:00:00 2001 From: goulustis Date: Mon, 27 Oct 2025 17:34:54 +0800 Subject: [PATCH 08/18] update req --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 2e001e2..54f9588 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,8 @@ dependencies = [ "fastapi", "matplotlib", "Pillow", - "jax" + "jax", + "commentjson" ] [tool.setuptools.packages.find] From 3f563c8bf61c790a175da4d5e1e602f995983968 Mon Sep 17 00:00:00 2001 From: goulustis Date: Mon, 27 Oct 2025 17:35:05 +0800 Subject: [PATCH 09/18] simple --- scripts/demo_chat.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/demo_chat.py b/scripts/demo_chat.py index d3c2e04..02fb421 100644 --- a/scripts/demo_chat.py +++ b/scripts/demo_chat.py @@ -18,7 +18,8 @@ def main(conf:PipelineConfig): # response = pipeline.chat(user_input, as_stream=True) # print(f"回答: {response}") - out = pipeline.chat("用工具算6856854-416846等于多少;然后解释它是怎么算出来的", as_stream=True) + # out = pipeline.chat("用工具算6856854-416846等于多少;然后解释它是怎么算出来的", as_stream=True) + out = pipeline.chat("介绍一下自己", as_stream=True) # out = pipeline.chat("testing", as_stream=True) print("=========== final ==========") print(out) From 666e0c4d23db9e3ab1d99a79b6c81452e3593538 Mon Sep 17 00:00:00 2001 From: goulustis Date: Wed, 29 Oct 2025 11:52:10 +0800 Subject: [PATCH 10/18] function rename --- lang_agent/graphs/react.py | 2 +- lang_agent/graphs/routing.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lang_agent/graphs/react.py b/lang_agent/graphs/react.py index ff46214..51de1d9 100644 --- a/lang_agent/graphs/react.py +++ b/lang_agent/graphs/react.py @@ -47,7 +47,7 @@ class ReactGraph(GraphBase): self.tool_manager:ToolManager = self.config.tool_manager_config.setup() memory = MemorySaver() - tools = self.tool_manager.get_langchain_tools() + tools = self.tool_manager.get_list_langchain_tools() self.agent = create_agent(self.llm, tools, checkpointer=memory) def invoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs): diff --git a/lang_agent/graphs/routing.py b/lang_agent/graphs/routing.py index 592a5ba..ac8cbbf 100644 --- a/lang_agent/graphs/routing.py +++ b/lang_agent/graphs/routing.py @@ -110,7 +110,7 @@ class RoutingGraph(GraphBase): tool_manager:ToolManager = self.config.tool_manager_config.setup() self.chat_model = create_agent(self.llm, [], checkpointer=self.memory) - self.tool_model = create_agent(self.llm, tool_manager.get_langchain_tools(), checkpointer=self.memory) + self.tool_model = create_agent(self.llm, tool_manager.get_list_langchain_tools(), checkpointer=self.memory) with open(self.config.sys_promp_json , "r") as f: self.prompt_dict = commentjson.load(f) From 8f6b181ff834c6cb134b36e0a3641e4145fc75e2 Mon Sep 17 00:00:00 2001 From: goulustis Date: Wed, 29 Oct 2025 11:52:40 +0800 Subject: [PATCH 11/18] make tool_dict for better usage --- lang_agent/tool_manager.py | 47 +++++++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 18 deletions(-) diff --git a/lang_agent/tool_manager.py b/lang_agent/tool_manager.py index 0b21f0a..12ae52b 100644 --- a/lang_agent/tool_manager.py +++ b/lang_agent/tool_manager.py @@ -11,20 +11,19 @@ from fastmcp.tools.tool import FunctionTool from lang_agent.config import InstantiateConfig, ToolConfig from lang_agent.base import LangToolBase -## import tool configs from lang_agent.rag.simple import SimpleRagConfig from lang_agent.dummy.calculator import CalculatorConfig from catering_end.lang_tool import CartToolConfig, CartTool -# from langchain.tools import StructuredTool from langchain_core.tools.structured import StructuredTool +import jax @tyro.conf.configure(tyro.conf.SuppressFixed) @dataclass class ToolManagerConfig(InstantiateConfig): _target: Type = field(default_factory=lambda: ToolManager) - # tool configs here; + # tool configs here; MUST HAVE 'config' in name and must be dataclass rag_config: SimpleRagConfig = field(default_factory=SimpleRagConfig) cart_config: CartToolConfig = field(default_factory=CartToolConfig) @@ -59,6 +58,7 @@ def async_to_sync(async_func: Callable) -> Callable: return sync_wrapper + class ToolManager: def __init__(self, config:ToolManagerConfig): self.config = config @@ -89,31 +89,42 @@ class ToolManager: """instantiate all object with tools""" self.tool_fncs = [] + self.tool_dict = {} tool_configs = self._get_tool_config() for tool_conf in tool_configs: + tool_name = tool_conf.get_name()[:-6] if tool_conf.use_tool: - logger.info(f"making tool:{tool_conf._target}") - self.tool_fncs.extend(self._get_tool_fnc(tool_conf.setup())) + logger.info(f"making tool:{tool_name}") + fnc_list = self._get_tool_fnc(tool_conf.setup()) + self.tool_fncs.extend(fnc_list) + self.tool_dict[tool_name] = fnc_list else: - logger.info(f"skipping tool:{tool_conf._target}") + logger.info(f"skipping tool:{tool_name}") def get_tool_fncs(self): return self.tool_fncs + def get_tool_dict(self): + return self.tool_dict + + + def fnc_to_structool(self, func): + if inspect.iscoroutinefunction(func): + return StructuredTool.from_function( + func=async_to_sync(func), + coroutine=func) + + else: + return StructuredTool.from_function(func=func) + - def get_langchain_tools(self): + def get_list_langchain_tools(self): out = [] for func in self.get_tool_fncs(): - if inspect.iscoroutinefunction(func): - out.append( - StructuredTool.from_function( - func=async_to_sync(func), - coroutine=func) - ) - else: - out.append( - StructuredTool.from_function(func=func) - ) + out.append(self.fnc_to_structool(func)) - return out \ No newline at end of file + return out + + def get_dict_langchain_tools(self): + return jax.tree_util.tree_map(self.fnc_to_structool, self.tool_dict) From c0cd0149a14d18538ab0cf53841ac8aed0c53084 Mon Sep 17 00:00:00 2001 From: goulustis Date: Wed, 29 Oct 2025 11:52:49 +0800 Subject: [PATCH 12/18] get name --- lang_agent/config.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lang_agent/config.py b/lang_agent/config.py index 285a743..dd9d307 100644 --- a/lang_agent/config.py +++ b/lang_agent/config.py @@ -53,6 +53,9 @@ class InstantiateConfig(PrintableConfig): with open(filename, 'w') as f: yaml.dump(self, f) logger.info(f"[yellow]config saved to: {filename}[/yellow]") + + def get_name(self): + return self.__class__.__name__ From a661e3a724c541e7502171d70cefafd44ee40ec8 Mon Sep 17 00:00:00 2001 From: goulustis Date: Wed, 29 Oct 2025 13:33:18 +0800 Subject: [PATCH 13/18] build langchain tools --- lang_agent/tool_manager.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/lang_agent/tool_manager.py b/lang_agent/tool_manager.py index 12ae52b..83c8d14 100644 --- a/lang_agent/tool_manager.py +++ b/lang_agent/tool_manager.py @@ -89,7 +89,6 @@ class ToolManager: """instantiate all object with tools""" self.tool_fncs = [] - self.tool_dict = {} tool_configs = self._get_tool_config() for tool_conf in tool_configs: tool_name = tool_conf.get_name()[:-6] @@ -97,9 +96,10 @@ class ToolManager: logger.info(f"making tool:{tool_name}") fnc_list = self._get_tool_fnc(tool_conf.setup()) self.tool_fncs.extend(fnc_list) - self.tool_dict[tool_name] = fnc_list else: logger.info(f"skipping tool:{tool_name}") + + self._build_langchain_tools() def get_tool_fncs(self): @@ -118,13 +118,12 @@ class ToolManager: else: return StructuredTool.from_function(func=func) + def _build_langchain_tools(self): + self.langchain_tools = [] + for func in self.get_tool_fncs(): + self.langchain_tools.append(self.fnc_to_structool(func)) + + return self.langchain_tools def get_list_langchain_tools(self): - out = [] - for func in self.get_tool_fncs(): - out.append(self.fnc_to_structool(func)) - - return out - - def get_dict_langchain_tools(self): - return jax.tree_util.tree_map(self.fnc_to_structool, self.tool_dict) + return self.langchain_tools \ No newline at end of file From a8a16a5363f6a328f5a27b0469a7cc0a780a708b Mon Sep 17 00:00:00 2001 From: goulustis Date: Wed, 29 Oct 2025 13:37:30 +0800 Subject: [PATCH 14/18] remove comments --- lang_agent/graphs/routing.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/lang_agent/graphs/routing.py b/lang_agent/graphs/routing.py index ac8cbbf..ff58d1f 100644 --- a/lang_agent/graphs/routing.py +++ b/lang_agent/graphs/routing.py @@ -120,8 +120,7 @@ class RoutingGraph(GraphBase): decision:Route = self.router.invoke( [ SystemMessage( - # content="Return a JSON object with 'step'.the value should be one of 'chat' or 'order' based on the user input" - content=self.prompt_dict.get("route_prompt") + content=self.prompt_dict["route_prompt"] ), self._get_human_msg(state) ] @@ -157,12 +156,11 @@ class RoutingGraph(GraphBase): if self.prompt_dict.get("chat_prompt") is not None: inp = {"messages":[ - SystemMessage( - # "You must use tool to complete the possible task" - self.prompt_dict["chat_prompt"] - ), - *state["inp"][0]["messages"][1:] - ]}, state["inp"][1] + SystemMessage( + self.prompt_dict["chat_prompt"] + ), + *state["inp"][0]["messages"][1:] + ]}, state["inp"][1] out = self.chat_model.invoke(*inp) @@ -172,7 +170,6 @@ class RoutingGraph(GraphBase): def _tool_model_call(self, state:State): inp = {"messages":[ SystemMessage( - # "You must use tool to complete the possible task" self.prompt_dict["tool_prompt"] ), *state["inp"][0]["messages"][1:] From 7765cebefa58367ef004e0afe3f1db2776cb4ebd Mon Sep 17 00:00:00 2001 From: goulustis Date: Wed, 29 Oct 2025 13:48:22 +0800 Subject: [PATCH 15/18] get chat_branch tool --- lang_agent/graphs/routing.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/lang_agent/graphs/routing.py b/lang_agent/graphs/routing.py index ff58d1f..d039c92 100644 --- a/lang_agent/graphs/routing.py +++ b/lang_agent/graphs/routing.py @@ -68,7 +68,11 @@ class State(TypedDict): class RoutingGraph(GraphBase): def __init__(self, config: RoutingConfig): self.config = config - self.chat_sys_msg = None + + # NOTE: tool that the chatbranch should have + self.chat_tool_names = ["retrieve", + "get_resources"] + self._build_modules() self.workflow = self._build_graph() @@ -100,16 +104,19 @@ class RoutingGraph(GraphBase): assert len(nargs[0]["messages"]) >= 2, "need at least 1 system and 1 human message" assert len(kwargs) == 0, "due to inp assumptions" + def _get_chat_tools(self, man:ToolManager): + return [lang_tool for lang_tool in man.get_list_langchain_tools() if lang_tool.name in self.chat_tool_names] + def _build_modules(self): self.llm = init_chat_model(model=self.config.llm_name, model_provider=self.config.llm_provider, api_key=self.config.api_key, base_url=self.config.base_url) - self.memory = MemorySaver() + self.memory = MemorySaver() # shared memory between the two branch self.router = self.llm.with_structured_output(Route) tool_manager:ToolManager = self.config.tool_manager_config.setup() - self.chat_model = create_agent(self.llm, [], checkpointer=self.memory) + self.chat_model = create_agent(self.llm, self._get_chat_tools(tool_manager), checkpointer=self.memory) self.tool_model = create_agent(self.llm, tool_manager.get_list_langchain_tools(), checkpointer=self.memory) with open(self.config.sys_promp_json , "r") as f: From 6aa1a362c1d7054f1a4108bc278300be6ee262ed Mon Sep 17 00:00:00 2001 From: goulustis Date: Wed, 29 Oct 2025 13:48:30 +0800 Subject: [PATCH 16/18] add typing --- lang_agent/tool_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lang_agent/tool_manager.py b/lang_agent/tool_manager.py index 83c8d14..85bbff4 100644 --- a/lang_agent/tool_manager.py +++ b/lang_agent/tool_manager.py @@ -125,5 +125,5 @@ class ToolManager: return self.langchain_tools - def get_list_langchain_tools(self): + def get_list_langchain_tools(self)->List[StructuredTool]: return self.langchain_tools \ No newline at end of file From 23535e3b8f894f250ea953935e3b9fb28202a941 Mon Sep 17 00:00:00 2001 From: goulustis Date: Wed, 29 Oct 2025 14:09:14 +0800 Subject: [PATCH 17/18] add key type --- lang_agent/graphs/routing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lang_agent/graphs/routing.py b/lang_agent/graphs/routing.py index d039c92..7dfaed9 100644 --- a/lang_agent/graphs/routing.py +++ b/lang_agent/graphs/routing.py @@ -120,7 +120,7 @@ class RoutingGraph(GraphBase): self.tool_model = create_agent(self.llm, tool_manager.get_list_langchain_tools(), checkpointer=self.memory) with open(self.config.sys_promp_json , "r") as f: - self.prompt_dict = commentjson.load(f) + self.prompt_dict:Dict[str, str] = commentjson.load(f) def _router_call(self, state:State): From 703e42929369148e9ec41c93e39ed5affa394dd0 Mon Sep 17 00:00:00 2001 From: goulustis Date: Wed, 29 Oct 2025 14:42:12 +0800 Subject: [PATCH 18/18] change metric --- lang_agent/eval/validator.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/lang_agent/eval/validator.py b/lang_agent/eval/validator.py index 9e570b5..a2a5098 100644 --- a/lang_agent/eval/validator.py +++ b/lang_agent/eval/validator.py @@ -83,11 +83,7 @@ class Validator: break tool_used.append(st_cond) - for cond in tool_used: - if not cond: - return False - - return True + return sum(tool_used)/len(tool_uses)