Merge pull request #1 from tangledup-ai/main

Get updates from main
jijiahao
2025-10-29 15:51:36 +08:00
committed by GitHub
11 changed files with 166 additions and 45 deletions

View File

@@ -37,3 +37,15 @@ python scripts/start_mcp_server.py
 # update configs/ws_mcp_config.json with link from the command above
 python scripts/ws_start_register_tools.py
 ```
+
+# Eval Dataset Format
+See `scripts/make_eval_dataset.py` for an example. The meaning of each entry:
+```json
+[
+    {
+        "inputs": {"text": "用retrieve查询光予尘然后介绍"},  // model input ("use retrieve to look up 光予尘, then describe it")
+        "outputs": {"answer": "光予尘茉莉绿茶为底",           // reference answer
+                    "tool_use": ["retrieve"]}                  // expected tool calls; when more than one is listed, the model is expected to use all of them
+    }
+]
+```
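For context, a minimal sketch (not part of this commit) of a schema check one could run before uploading entries in this format; the helper name is illustrative:

```python
# Hypothetical helper: validate one eval-dataset entry against the schema above.
def check_entry(entry: dict) -> None:
    assert isinstance(entry["inputs"]["text"], str), "inputs.text must be a string"
    assert isinstance(entry["outputs"]["answer"], str), "outputs.answer must be a string"
    tool_use = entry["outputs"].get("tool_use")  # optional key
    if tool_use is not None:
        assert isinstance(tool_use, list) and all(isinstance(t, str) for t in tool_use), \
            "tool_use must be a list of tool names"
```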

View File

@@ -0,0 +1,8 @@
+{
+    // Prompt for the router; route_prompt MUST tell the model to return JSON, otherwise routing breaks
+    "route_prompt" : "Return a JSON object with 'step'.the value should be one of 'chat' or 'order' based on the user input",
+
+    // Prompt for the tool branch
+    "tool_prompt" : "You must use tool to complete the possible task"
+    // Optionally set "chat_prompt" to overwrite the system prompt from xiaozhi
+}
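For reference, a minimal sketch of how this file is consumed (mirroring `RoutingGraph._build_modules` later in this PR); `commentjson` tolerates the `//` comments that plain `json.load` would reject:

```python
import commentjson

# load the prompt file; the path assumes the repo's default configs/ location
with open("configs/route_sys_prompts.json", "r") as f:
    prompt_dict = commentjson.load(f)

print(prompt_dict["route_prompt"])      # fed to the router as a SystemMessage
print(prompt_dict["tool_prompt"])       # fed to the tool branch
print(prompt_dict.get("chat_prompt"))   # optional; None when absent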

View File

@@ -54,6 +54,9 @@ class InstantiateConfig(PrintableConfig):
             yaml.dump(self, f)
         logger.info(f"[yellow]config saved to: {filename}[/yellow]")
+
+    def get_name(self):
+        return self.__class__.__name__
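A quick illustration of what `get_name` returns; `CalculatorConfig` is one of the tool configs imported by the tool manager, and the `[:-6]` slice mirrors how `ToolManager` (later in this PR) strips the trailing "Config":

```python
conf = CalculatorConfig()
print(conf.get_name())       # "CalculatorConfig"
print(conf.get_name()[:-6])  # "Calculator" -- the 6-char "Config" suffix removed
```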

View File

@@ -24,7 +24,7 @@ class EvaluatorConfig(InstantiateConfig):
     experiment_desc:str = "testing if this works or not"
     """describe the experiment"""
-    dataset_name:Literal["Toxic Queries"] = "Toxic Queries"
+    dataset_name:Literal["dev_langagent"] = "dev_langagent"
     """name of the dataset to evaluate"""
     pipe_config: PipelineConfig = field(default_factory=PipelineConfig)

View File

@@ -1,11 +1,12 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Type, Literal from typing import Type, Callable, List
import tyro import tyro
from lang_agent.config import KeyConfig from lang_agent.config import KeyConfig
from lang_agent.pipeline import Pipeline, PipelineConfig from lang_agent.pipeline import Pipeline, PipelineConfig
from langchain.chat_models import init_chat_model from langchain.chat_models import init_chat_model
from langchain_core.messages import BaseMessage, ToolMessage
@tyro.conf.configure(tyro.conf.SuppressFixed) @tyro.conf.configure(tyro.conf.SuppressFixed)
@dataclass @dataclass
@@ -21,12 +22,12 @@ class Validator:
# NOTE: Need to register function here # NOTE: Need to register function here
self.dict_corr_map = { self.dict_corr_map = {
"Toxic Queries" : [self.Toxic_Queries_correct] "dev_langagent" : [self.default_correct, self.val_tool_use]
} }
# NOTE: Need to register function here # NOTE: Need to register function here
self.dict_inp_map = { self.dict_inp_map = {
"Toxic Queries" : self.Toxic_Queries_inp_parse "dev_langagent" : self.default_inp_parse
} }
@@ -39,7 +40,7 @@ class Validator:
) )
# NOTE: for every dataset; need one of these # NOTE: for every dataset; need one of these
def Toxic_Queries_correct(self, inputs: dict, outputs: list, reference_outputs: dict) -> bool: def default_correct(self, inputs: dict, outputs: dict, reference_outputs: dict) -> bool:
instructions = ( instructions = (
"Given an actual answer and an expected answer, determine whether" "Given an actual answer and an expected answer, determine whether"
" the actual answer contains all of the information in the" " the actual answer contains all of the information in the"
@@ -48,7 +49,7 @@ class Validator:
" otherwise. Do not include anything else in your response." " otherwise. Do not include anything else in your response."
) )
actual_answer = outputs["output"][-1].content actual_answer = outputs["output"][-1].content
expected_answer = reference_outputs["label"] expected_answer = reference_outputs["answer"]
user_msg = ( user_msg = (
f"ACTUAL ANSWER: {actual_answer}" f"ACTUAL ANSWER: {actual_answer}"
@@ -64,16 +65,38 @@ class Validator:
return response.content.upper() == "CORRECT" return response.content.upper() == "CORRECT"
def val_tool_use(self, inputs:dict, outputs:dict, reference_outputs:dict)->bool:
tool_uses:List[str] = reference_outputs.get("tool_use")
if tool_uses is None:
return True
tool_msgs = [e for e in outputs["output"] if isinstance(e, ToolMessage)]
# check if all tools are used
tool_used = []
for ref_tool in tool_uses:
st_cond = False
ref_tool = ref_tool.lower()
for msg in tool_msgs:
st_cond = ref_tool in msg.name.lower()
if st_cond:
break
tool_used.append(st_cond)
return sum(tool_used)/len(tool_uses)
# NOTE: for every dataset; need one of these # NOTE: for every dataset; need one of these
def Toxic_Queries_inp_parse(self, inp, pipeline:Pipeline): def default_inp_parse(self, inp, pipeline:Pipeline):
inp = inp["text"] inp = inp["text"]
return pipeline.chat(inp, as_raw=True) return pipeline.chat(inp, as_raw=True)
def get_val_fnc(self, dataset_name:str): def get_val_fnc(self, dataset_name:str)->List[Callable]:
return self.dict_corr_map[dataset_name] return self.dict_corr_map.get(dataset_name, [self.default_correct])
def get_inp_fnc(self,dataset_name:str): def get_inp_fnc(self,dataset_name:str)->Callable:
return self.dict_inp_map[dataset_name] # return self.dict_inp_map[dataset_name]
return self.dict_inp_map.get(dataset_name, self.default_inp_parse)
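A standalone sketch of the scoring rule `val_tool_use` implements: the fraction of reference tool names that appear, case-insensitively and as substrings, among the tools the agent actually called (the inputs below are illustrative):

```python
def tool_use_score(called: list, expected: list) -> float:
    # substring, case-insensitive match, mirroring `ref_tool in msg.name.lower()`
    called = [c.lower() for c in called]
    hits = [any(ref.lower() in c for c in called) for ref in expected]
    return sum(hits) / len(expected)

assert tool_use_score(["retrieve", "get_resources"], ["retrieve"]) == 1.0
assert tool_use_score(["get_resources"], ["retrieve", "get_resources"]) == 0.5
```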

View File

@@ -47,7 +47,7 @@ class ReactGraph(GraphBase):
         self.tool_manager:ToolManager = self.config.tool_manager_config.setup()
         memory = MemorySaver()
-        tools = self.tool_manager.get_langchain_tools()
+        tools = self.tool_manager.get_list_langchain_tools()
         self.agent = create_agent(self.llm, tools, checkpointer=memory)

     def invoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs):

View File

@@ -7,6 +7,8 @@ from PIL import Image
 from io import BytesIO
 import matplotlib.pyplot as plt
 import jax
+import os.path as osp
+import commentjson
 from lang_agent.config import KeyConfig
 from lang_agent.tool_manager import ToolManager, ToolManagerConfig
@@ -34,9 +36,20 @@ class RoutingConfig(KeyConfig):
     base_url:str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
     """base url; could be used to overwrite the base url of the llm provider"""
+    sys_promp_json: str = None
+    """path to a JSON file containing system prompts for the graphs; its 'chat_prompt', if provided, overwrites the system prompt from xiaozhi"""
     tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)
+
+    def __post_init__(self):
+        super().__post_init__()
+        if self.sys_promp_json is None:
+            self.sys_promp_json = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts.json")
+            logger.warning(f"sys_promp_json was not provided. Using default: {self.sys_promp_json}")
+        assert osp.exists(self.sys_promp_json), f"sys_promp_json {self.sys_promp_json} does not exist."

 class Route(BaseModel):
     step: Literal["chat", "order"] = Field(
@@ -55,7 +68,11 @@ class State(TypedDict):
 class RoutingGraph(GraphBase):
     def __init__(self, config: RoutingConfig):
         self.config = config
+        self.chat_sys_msg = None
+        # NOTE: tools that the chat branch should have
+        self.chat_tool_names = ["retrieve",
+                                "get_resources"]
         self._build_modules()
         self.workflow = self._build_graph()
@@ -87,24 +104,30 @@ class RoutingGraph(GraphBase):
         assert len(nargs[0]["messages"]) >= 2, "need at least 1 system and 1 human message"
         assert len(kwargs) == 0, "due to inp assumptions"

+    def _get_chat_tools(self, man:ToolManager):
+        return [lang_tool for lang_tool in man.get_list_langchain_tools() if lang_tool.name in self.chat_tool_names]
+
     def _build_modules(self):
         self.llm = init_chat_model(model=self.config.llm_name,
                                    model_provider=self.config.llm_provider,
                                    api_key=self.config.api_key,
                                    base_url=self.config.base_url)
-        self.memory = MemorySaver()
+        self.memory = MemorySaver()  # shared memory between the two branches
         self.router = self.llm.with_structured_output(Route)
         tool_manager:ToolManager = self.config.tool_manager_config.setup()
-        self.chat_model = create_agent(self.llm, [], checkpointer=self.memory)
-        self.tool_model = create_agent(self.llm, tool_manager.get_langchain_tools(), checkpointer=self.memory)
+        self.chat_model = create_agent(self.llm, self._get_chat_tools(tool_manager), checkpointer=self.memory)
+        self.tool_model = create_agent(self.llm, tool_manager.get_list_langchain_tools(), checkpointer=self.memory)
+        with open(self.config.sys_promp_json, "r") as f:
+            self.prompt_dict:Dict[str, str] = commentjson.load(f)

     def _router_call(self, state:State):
         decision:Route = self.router.invoke(
             [
                 SystemMessage(
-                    content="Return a JSON object with 'step'.the value should be one of 'chat' or 'order' based on the user input"
+                    content=self.prompt_dict["route_prompt"]
                 ),
                 self._get_human_msg(state)
             ]
@@ -138,6 +161,15 @@ class RoutingGraph(GraphBase):
         else:
             inp = state["inp"]
+        if self.prompt_dict.get("chat_prompt") is not None:
+            inp = {"messages":[
+                SystemMessage(
+                    self.prompt_dict["chat_prompt"]
+                ),
+                *state["inp"][0]["messages"][1:]
+            ]}, state["inp"][1]
         out = self.chat_model.invoke(*inp)
         return {"messages": out}
@@ -145,9 +177,8 @@ class RoutingGraph(GraphBase):
     def _tool_model_call(self, state:State):
         inp = {"messages":[
             SystemMessage(
-                "You must use tool to complete the possible task"
+                self.prompt_dict["tool_prompt"]
             ),
-            # self._get_human_msg(state)
             *state["inp"][0]["messages"][1:]
         ]}, state["inp"][1]

View File

@@ -11,20 +11,19 @@ from fastmcp.tools.tool import FunctionTool
 from lang_agent.config import InstantiateConfig, ToolConfig
 from lang_agent.base import LangToolBase

+## import tool configs
 from lang_agent.rag.simple import SimpleRagConfig
 from lang_agent.dummy.calculator import CalculatorConfig
 from catering_end.lang_tool import CartToolConfig, CartTool
-# from langchain.tools import StructuredTool
 from langchain_core.tools.structured import StructuredTool
-import jax

 @tyro.conf.configure(tyro.conf.SuppressFixed)
 @dataclass
 class ToolManagerConfig(InstantiateConfig):
     _target: Type = field(default_factory=lambda: ToolManager)
-    # tool configs here;
+    # tool configs here; MUST HAVE 'config' in name and must be dataclass
     rag_config: SimpleRagConfig = field(default_factory=SimpleRagConfig)
     cart_config: CartToolConfig = field(default_factory=CartToolConfig)
@@ -59,6 +58,7 @@ def async_to_sync(async_func: Callable) -> Callable:
     return sync_wrapper

+
 class ToolManager:
     def __init__(self, config:ToolManagerConfig):
         self.config = config
@@ -91,29 +91,39 @@ class ToolManager:
         self.tool_fncs = []
         tool_configs = self._get_tool_config()
         for tool_conf in tool_configs:
+            tool_name = tool_conf.get_name()[:-6]
             if tool_conf.use_tool:
-                logger.info(f"making tool:{tool_conf._target}")
-                self.tool_fncs.extend(self._get_tool_fnc(tool_conf.setup()))
+                logger.info(f"making tool:{tool_name}")
+                fnc_list = self._get_tool_fnc(tool_conf.setup())
+                self.tool_fncs.extend(fnc_list)
             else:
-                logger.info(f"skipping tool:{tool_conf._target}")
+                logger.info(f"skipping tool:{tool_name}")
+        self._build_langchain_tools()

     def get_tool_fncs(self):
         return self.tool_fncs

-    def get_tool_dict(self):
-        return self.tool_dict
-
-    def get_langchain_tools(self):
-        out = []
-        for func in self.get_tool_fncs():
-            if inspect.iscoroutinefunction(func):
-                out.append(
-                    StructuredTool.from_function(
-                        func=async_to_sync(func),
-                        coroutine=func)
-                )
-            else:
-                out.append(
-                    StructuredTool.from_function(func=func)
-                )
-        return out
+    def fnc_to_structool(self, func):
+        if inspect.iscoroutinefunction(func):
+            return StructuredTool.from_function(
+                func=async_to_sync(func),
+                coroutine=func)
+        else:
+            return StructuredTool.from_function(func=func)
+
+    def _build_langchain_tools(self):
+        self.langchain_tools = []
+        for func in self.get_tool_fncs():
+            self.langchain_tools.append(self.fnc_to_structool(func))
+        return self.langchain_tools
+
+    def get_list_langchain_tools(self) -> List[StructuredTool]:
+        return self.langchain_tools
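To make the coroutine handling in `fnc_to_structool` concrete, a self-contained sketch; the `async_to_sync` body here is an assumption about what the module's real wrapper does (`functools.wraps` matters, since `StructuredTool.from_function` infers the argument schema from the function's signature):

```python
import asyncio
import functools
import inspect
from langchain_core.tools.structured import StructuredTool

def async_to_sync(async_func):
    @functools.wraps(async_func)  # preserve name/docstring/signature for schema inference
    def sync_wrapper(*args, **kwargs):
        return asyncio.run(async_func(*args, **kwargs))
    return sync_wrapper

async def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b

# the same branch fnc_to_structool takes for coroutine functions
tool = (StructuredTool.from_function(func=async_to_sync(add), coroutine=add)
        if inspect.iscoroutinefunction(add)
        else StructuredTool.from_function(func=add))

print(tool.invoke({"a": 1, "b": 2}))  # 3, via the sync wrapper
```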

View File

@@ -5,6 +5,7 @@ dependencies = [
     "langchain==1.0",
     "langchain_community",
     "langchain-openai",
+    "langchain_mcp_adapters",
     "httpx[socks]",
     "dashscope",
     "python-dotenv>=1.0.0",
@@ -19,7 +20,8 @@ dependencies = [
     "fastapi",
     "matplotlib",
     "Pillow",
-    "jax"
+    "jax",
+    "commentjson"
 ]

 [tool.setuptools.packages.find]

View File

@@ -18,7 +18,8 @@ def main(conf:PipelineConfig):
     # response = pipeline.chat(user_input, as_stream=True)
     # print(f"回答: {response}")
-    out = pipeline.chat("用工具算6856854-416846等于多少;然后解释它是怎么算出来的", as_stream=True)
+    # out = pipeline.chat("用工具算6856854-416846等于多少;然后解释它是怎么算出来的", as_stream=True)
+    out = pipeline.chat("介绍一下自己", as_stream=True)
     # out = pipeline.chat("testing", as_stream=True)
     print("=========== final ==========")
     print(out)

View File

@@ -0,0 +1,31 @@
+from langsmith import Client
+from loguru import logger
+
+DATASET_NAME = "dev_langagent"
+
+examples = [
+    {
+        "inputs": {"text": "用retrieve查询光予尘然后介绍"},
+        "outputs": {"answer": "茉莉绿茶为底,清冽茶香中漫出玫珑蜜瓜的绵甜与凤梨的明亮果香,层次鲜活;顶部白柚茉莉泡沫轻盈漫过舌尖,带着微酸的清新感,让整体风味更显灵动",
+                    "tool_use": ["retrieve"]}
+    },
+    {
+        "inputs": {"text": "介绍一下自己"},
+        "outputs": {"answer": "我是小盏,是一个点餐助手"}
+    }
+]
+
+cli = Client()
+try:
+    dataset = cli.read_dataset(dataset_name=DATASET_NAME)
+    logger.info("read dataset")
+except Exception:
+    dataset = cli.create_dataset(dataset_name=DATASET_NAME)
+    logger.info("created dataset")
+
+cli.create_examples(
+    dataset_id=dataset.id,
+    examples=examples
+)
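One caveat worth noting: `create_examples` appends, so re-running this script duplicates entries. A hedged sketch of a guard using the client's `list_examples` (deduplicating on `inputs["text"]` is an assumption of this sketch, not repo behavior):

```python
# skip inputs that already exist in the dataset before uploading
existing = {ex.inputs.get("text") for ex in cli.list_examples(dataset_id=dataset.id)}
new_examples = [ex for ex in examples if ex["inputs"]["text"] not in existing]
if new_examples:
    cli.create_examples(dataset_id=dataset.id, examples=new_examples)
```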