diff --git a/README.md b/README.md
index 12bf47f..7619330 100644
--- a/README.md
+++ b/README.md
@@ -36,4 +36,16 @@
 python scripts/start_mcp_server.py
 
 # update configs/ws_mcp_config.json with link from the command above
 python scripts/ws_start_register_tools.py
+```
+
+# Eval Dataset Format
+See `scripts/make_eval_dataset.py` for an example. The meaning of each entry:
+```json
+[
+    {
+        "inputs": {"text": "用retrieve查询光予尘然后介绍"},     // model input
+        "outputs": {"answer": "光予尘茉莉绿茶为底",             // reference answer
+                    "tool_use": ["retrieve"]}                   // expected tool calls; if more than one is listed, the model is expected to use all of them
+    }
+]
+```
\ No newline at end of file
diff --git a/configs/route_sys_prompts.json b/configs/route_sys_prompts.json
new file mode 100644
index 0000000..0e35c59
--- /dev/null
+++ b/configs/route_sys_prompts.json
@@ -0,0 +1,8 @@
+{   // Prompts for the router; route_prompt MUST ask the model to return JSON, otherwise structured routing breaks
+    "route_prompt" : "Return a JSON object with 'step'. The value should be one of 'chat' or 'order' based on the user input",
+
+    // Prompt for the tool branch
+    "tool_prompt" : "You must use tools to complete the task where possible"
+
+    // Optionally set "chat_prompt" to overwrite the system prompt from xiaozhi
+}
\ No newline at end of file
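Because of the `//` comments, `route_sys_prompts.json` is not strict JSON; `routing.py` below parses it with `commentjson`. A minimal sketch of that loading step (path per this repo's layout):

```python
# Minimal sketch: parse the commented prompt file the way routing.py does.
import commentjson

with open("configs/route_sys_prompts.json", "r") as f:
    prompt_dict = commentjson.load(f)  # tolerates // comments, unlike json.load

print(prompt_dict["route_prompt"])     # router instruction
print(prompt_dict.get("chat_prompt"))  # None unless explicitly set
```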
diff --git a/lang_agent/config.py b/lang_agent/config.py
index 285a743..dd9d307 100644
--- a/lang_agent/config.py
+++ b/lang_agent/config.py
@@ -53,6 +53,9 @@ class InstantiateConfig(PrintableConfig):
         with open(filename, 'w') as f:
             yaml.dump(self, f)
         logger.info(f"[yellow]config saved to: {filename}[/yellow]")
+
+    def get_name(self):
+        # class name of the config, e.g. "SimpleRagConfig"
+        return self.__class__.__name__
diff --git a/lang_agent/eval/evaluator.py b/lang_agent/eval/evaluator.py
index 16b5e2a..b3ad69a 100644
--- a/lang_agent/eval/evaluator.py
+++ b/lang_agent/eval/evaluator.py
@@ -24,7 +24,7 @@ class EvaluatorConfig(InstantiateConfig):
     experiment_desc:str = "testing if this works or not"
     """describe the experiment"""
 
-    dataset_name:Literal["Toxic Queries"] = "Toxic Queries"
+    dataset_name:Literal["dev_langagent"] = "dev_langagent"
     """name of the dataset to evaluate"""
 
     pipe_config: PipelineConfig = field(default_factory=PipelineConfig)
diff --git a/lang_agent/eval/validator.py b/lang_agent/eval/validator.py
index 47eeedb..a2a5098 100644
--- a/lang_agent/eval/validator.py
+++ b/lang_agent/eval/validator.py
@@ -1,11 +1,12 @@
 from dataclasses import dataclass, field
-from typing import Type, Literal
+from typing import Type, Callable, List
 
 import tyro
 
 from lang_agent.config import KeyConfig
 from lang_agent.pipeline import Pipeline, PipelineConfig
 from langchain.chat_models import init_chat_model
+from langchain_core.messages import BaseMessage, ToolMessage
 
 @tyro.conf.configure(tyro.conf.SuppressFixed)
 @dataclass
@@ -21,12 +22,12 @@ class Validator:
 
         # NOTE: Need to register function here
         self.dict_corr_map = {
-            "Toxic Queries" : [self.Toxic_Queries_correct]
+            "dev_langagent" : [self.default_correct, self.val_tool_use]
         }
 
         # NOTE: Need to register function here
         self.dict_inp_map = {
-            "Toxic Queries" : self.Toxic_Queries_inp_parse
+            "dev_langagent" : self.default_inp_parse
         }
 
@@ -39,7 +40,7 @@ class Validator:
         )
 
     # NOTE: for every dataset; need one of these
-    def Toxic_Queries_correct(self, inputs: dict, outputs: list, reference_outputs: dict) -> bool:
+    def default_correct(self, inputs: dict, outputs: dict, reference_outputs: dict) -> bool:
         instructions = (
             "Given an actual answer and an expected answer, determine whether"
             " the actual answer contains all of the information in the"
@@ -48,7 +49,7 @@ class Validator:
             " otherwise. Do not include anything else in your response."
         )
         actual_answer = outputs["output"][-1].content
-        expected_answer = reference_outputs["label"]
+        expected_answer = reference_outputs["answer"]
 
         user_msg = (
             f"ACTUAL ANSWER: {actual_answer}"
@@ -64,16 +65,37 @@ class Validator:
 
         return response.content.upper() == "CORRECT"
 
+    def val_tool_use(self, inputs:dict, outputs:dict, reference_outputs:dict) -> float:
+        # fraction of reference tools that were actually called; 1.0 when no tools are expected
+        tool_uses:List[str] = reference_outputs.get("tool_use")
+        if tool_uses is None:
+            return 1.0
+
+        tool_msgs = [e for e in outputs["output"] if isinstance(e, ToolMessage)]
+
+        # check that every reference tool shows up in at least one tool message
+        tool_used = []
+        for ref_tool in tool_uses:
+            ref_tool = ref_tool.lower()
+            st_cond = any(ref_tool in msg.name.lower() for msg in tool_msgs)
+            tool_used.append(st_cond)
+
+        return sum(tool_used) / len(tool_uses)
+
     # NOTE: for every dataset; need one of these
-    def Toxic_Queries_inp_parse(self, inp, pipeline:Pipeline):
+    def default_inp_parse(self, inp, pipeline:Pipeline):
         inp = inp["text"]
         return pipeline.chat(inp, as_raw=True)
 
-    def get_val_fnc(self, dataset_name:str):
-        return self.dict_corr_map[dataset_name]
+    def get_val_fnc(self, dataset_name:str) -> List[Callable]:
+        return self.dict_corr_map.get(dataset_name, [self.default_correct])
 
-    def get_inp_fnc(self,dataset_name:str):
-        return self.dict_inp_map[dataset_name]
\ No newline at end of file
+    def get_inp_fnc(self, dataset_name:str) -> Callable:
+        return self.dict_inp_map.get(dataset_name, self.default_inp_parse)
\ No newline at end of file
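For reference, `val_tool_use` gives partial credit: the fraction of reference tools whose names appear (as substrings) among the run's `ToolMessage`s. A standalone sketch of the same matching on hand-built messages (contents made up for illustration):

```python
# Sketch of the tool-use scoring idea on hand-built messages.
from langchain_core.messages import AIMessage, ToolMessage

output = [
    AIMessage(content="let me look that up"),
    ToolMessage(content="光予尘:茉莉绿茶为底……", name="retrieve", tool_call_id="call_1"),
]
reference_tools = ["retrieve", "get_resources"]

used = [
    any(ref.lower() in (m.name or "").lower()
        for m in output if isinstance(m, ToolMessage))
    for ref in reference_tools
]
print(sum(used) / len(reference_tools))  # 0.5: retrieve matched, get_resources did not
```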
diff --git a/lang_agent/graphs/react.py b/lang_agent/graphs/react.py
index ff46214..51de1d9 100644
--- a/lang_agent/graphs/react.py
+++ b/lang_agent/graphs/react.py
@@ -47,7 +47,7 @@ class ReactGraph(GraphBase):
         self.tool_manager:ToolManager = self.config.tool_manager_config.setup()
 
         memory = MemorySaver()
-        tools = self.tool_manager.get_langchain_tools()
+        tools = self.tool_manager.get_list_langchain_tools()
         self.agent = create_agent(self.llm, tools, checkpointer=memory)
 
     def invoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs):
diff --git a/lang_agent/graphs/routing.py b/lang_agent/graphs/routing.py
index 85948c1..7dfaed9 100644
--- a/lang_agent/graphs/routing.py
+++ b/lang_agent/graphs/routing.py
@@ -7,6 +7,8 @@ from PIL import Image
 from io import BytesIO
 import matplotlib.pyplot as plt
 import jax
+import os.path as osp
+import commentjson
 
 from lang_agent.config import KeyConfig
 from lang_agent.tool_manager import ToolManager, ToolManagerConfig
@@ -34,9 +36,20 @@ class RoutingConfig(KeyConfig):
     base_url:str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
     """base url; could be used to overwrite the baseurl in llm provider"""
 
+    sys_prompt_json: str = None
+    "path to a JSON file containing system prompts for the graphs; its 'chat_prompt' entry, if present, overwrites the system prompt from xiaozhi"
+
     tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)
 
+    def __post_init__(self):
+        super().__post_init__()
+        if self.sys_prompt_json is None:
+            self.sys_prompt_json = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts.json")
+            logger.warning(f"sys_prompt_json was not provided. Using default: {self.sys_prompt_json}")
+
+        assert osp.exists(self.sys_prompt_json), f"sys_prompt_json {self.sys_prompt_json} does not exist."
+
 class Route(BaseModel):
     step: Literal["chat", "order"] = Field(
@@ -55,7 +68,11 @@ class State(TypedDict):
 class RoutingGraph(GraphBase):
     def __init__(self, config: RoutingConfig):
         self.config = config
-        self.chat_sys_msg = None
+
+        # NOTE: tools that the chat branch should have
+        self.chat_tool_names = ["retrieve",
+                                "get_resources"]
+
         self._build_modules()
         self.workflow = self._build_graph()
@@ -87,24 +104,30 @@ class RoutingGraph(GraphBase):
         assert len(nargs[0]["messages"]) >= 2, "need at least 1 system and 1 human message"
         assert len(kwargs) == 0, "due to inp assumptions"
 
+    def _get_chat_tools(self, man:ToolManager):
+        return [lang_tool for lang_tool in man.get_list_langchain_tools() if lang_tool.name in self.chat_tool_names]
+
     def _build_modules(self):
         self.llm = init_chat_model(model=self.config.llm_name,
-                            model_provider=self.config.llm_provider,
-                            api_key=self.config.api_key,
-                            base_url=self.config.base_url)
-        self.memory = MemorySaver()
+                                   model_provider=self.config.llm_provider,
+                                   api_key=self.config.api_key,
+                                   base_url=self.config.base_url)
+        self.memory = MemorySaver()  # shared memory between the two branches
         self.router = self.llm.with_structured_output(Route)
 
         tool_manager:ToolManager = self.config.tool_manager_config.setup()
-        self.chat_model = create_agent(self.llm, [], checkpointer=self.memory)
-        self.tool_model = create_agent(self.llm, tool_manager.get_langchain_tools(), checkpointer=self.memory)
+        self.chat_model = create_agent(self.llm, self._get_chat_tools(tool_manager), checkpointer=self.memory)
+        self.tool_model = create_agent(self.llm, tool_manager.get_list_langchain_tools(), checkpointer=self.memory)
+
+        with open(self.config.sys_prompt_json, "r") as f:
+            self.prompt_dict:Dict[str, str] = commentjson.load(f)
 
     def _router_call(self, state:State):
         decision:Route = self.router.invoke(
             [
                 SystemMessage(
-                    content="Return a JSON object with 'step'.the value should be one of 'chat' or 'order' based on the user input"
+                    content=self.prompt_dict["route_prompt"]
                 ),
                 self._get_human_msg(state)
             ]
@@ -137,6 +160,15 @@ class RoutingGraph(GraphBase):
             inp = state["messages"], state["inp"][1]
         else:
             inp = state["inp"]
+
+        # if a chat_prompt is configured, replace the incoming system message with it
+        if self.prompt_dict.get("chat_prompt") is not None:
+            inp = {"messages":[
+                SystemMessage(
+                    self.prompt_dict["chat_prompt"]
+                ),
+                *state["inp"][0]["messages"][1:]
+            ]}, state["inp"][1]
+
         out = self.chat_model.invoke(*inp)
 
         return {"messages": out}
@@ -145,9 +177,8 @@
     def _tool_model_call(self, state:State):
         inp = {"messages":[
             SystemMessage(
-                "You must use tool to complete the possible task"
+                self.prompt_dict["tool_prompt"]
             ),
-            # self._get_human_msg(state)
             *state["inp"][0]["messages"][1:]
         ]}, state["inp"][1]
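The router's JSON requirement comes from `with_structured_output`, which validates the model reply against the `Route` schema. A minimal sketch, assuming an OpenAI-compatible model is configured (model name and provider below are placeholders, not this repo's defaults):

```python
# Sketch: structured routing with a Pydantic schema, as in RoutingGraph.
from typing import Literal
from pydantic import BaseModel, Field
from langchain.chat_models import init_chat_model

class Route(BaseModel):
    step: Literal["chat", "order"] = Field(description="the branch to route to")

# placeholder model/provider; the repo configures these via RoutingConfig
llm = init_chat_model(model="qwen-plus", model_provider="openai")
router = llm.with_structured_output(Route)

decision = router.invoke("帮我点一杯光予尘")  # returns a Route instance
print(decision.step)  # "order" expected for an ordering request
```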
diff --git a/lang_agent/tool_manager.py b/lang_agent/tool_manager.py
index 0b21f0a..85bbff4 100644
--- a/lang_agent/tool_manager.py
+++ b/lang_agent/tool_manager.py
@@ -11,20 +11,19 @@ from fastmcp.tools.tool import FunctionTool
 from lang_agent.config import InstantiateConfig, ToolConfig
 from lang_agent.base import LangToolBase
 
-## import tool configs
 from lang_agent.rag.simple import SimpleRagConfig
 from lang_agent.dummy.calculator import CalculatorConfig
 from catering_end.lang_tool import CartToolConfig, CartTool
 
-# from langchain.tools import StructuredTool
 from langchain_core.tools.structured import StructuredTool
 
 @tyro.conf.configure(tyro.conf.SuppressFixed)
 @dataclass
 class ToolManagerConfig(InstantiateConfig):
     _target: Type = field(default_factory=lambda: ToolManager)
 
-    # tool configs here;
+    # tool configs here; field names MUST contain 'config' and values must be dataclasses
     rag_config: SimpleRagConfig = field(default_factory=SimpleRagConfig)
     cart_config: CartToolConfig = field(default_factory=CartToolConfig)
@@ -59,6 +58,7 @@ def async_to_sync(async_func: Callable) -> Callable:
 
     return sync_wrapper
 
+
 class ToolManager:
     def __init__(self, config:ToolManagerConfig):
         self.config = config
@@ -91,29 +91,39 @@ class ToolManager:
         self.tool_fncs = []
         tool_configs = self._get_tool_config()
         for tool_conf in tool_configs:
+            tool_name = tool_conf.get_name()[:-6]  # strip the trailing "Config" from the class name
             if tool_conf.use_tool:
-                logger.info(f"making tool:{tool_conf._target}")
-                self.tool_fncs.extend(self._get_tool_fnc(tool_conf.setup()))
+                logger.info(f"making tool:{tool_name}")
+                fnc_list = self._get_tool_fnc(tool_conf.setup())
+                self.tool_fncs.extend(fnc_list)
             else:
-                logger.info(f"skipping tool:{tool_conf._target}")
+                logger.info(f"skipping tool:{tool_name}")
+
+        self._build_langchain_tools()
 
     def get_tool_fncs(self):
         return self.tool_fncs
 
-    def get_langchain_tools(self):
-        out = []
-        for func in self.get_tool_fncs():
-            if inspect.iscoroutinefunction(func):
-                out.append(
-                    StructuredTool.from_function(
-                        func=async_to_sync(func),
-                        coroutine=func)
-                )
-            else:
-                out.append(
-                    StructuredTool.from_function(func=func)
-                )
-
-        return out
\ No newline at end of file
+    def get_tool_dict(self):
+        return self.tool_dict
+
+    def fnc_to_structool(self, func):
+        # async tools get a sync wrapper so they can be called from both sync and async contexts
+        if inspect.iscoroutinefunction(func):
+            return StructuredTool.from_function(
+                func=async_to_sync(func),
+                coroutine=func)
+        else:
+            return StructuredTool.from_function(func=func)
+
+    def _build_langchain_tools(self):
+        self.langchain_tools = []
+        for func in self.get_tool_fncs():
+            self.langchain_tools.append(self.fnc_to_structool(func))
+
+        return self.langchain_tools
+
+    def get_list_langchain_tools(self) -> List[StructuredTool]:
+        return self.langchain_tools
\ No newline at end of file
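`fnc_to_structool` registers async tool functions with both a sync wrapper and the original coroutine, so the resulting tool works under `.invoke()` and `.ainvoke()`. A self-contained sketch of the same pattern (the `add` tool is a toy stand-in, and the wrapper here is simplified relative to the repo's `async_to_sync`):

```python
# Sketch: wrap an async function as a StructuredTool, mirroring fnc_to_structool.
import asyncio
from langchain_core.tools.structured import StructuredTool

async def add(a: int, b: int) -> int:
    """Add two numbers."""
    return a + b

def add_sync(a: int, b: int) -> int:
    """Add two numbers."""  # simplified wrapper; async_to_sync also handles already-running loops
    return asyncio.run(add(a, b))

tool = StructuredTool.from_function(func=add_sync, coroutine=add, name="add")
print(tool.invoke({"a": 2, "b": 3}))  # 5, via the sync path
```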
dataset") +except: + dataset = cli.create_dataset(dataset_name=DATASET_NAME) + logger.info("created dataset") + +cli.create_examples( + dataset_id=dataset.id, + examples=examples +) \ No newline at end of file