Merge pull request #1 from tangledup-ai/main

Get updates from main.
This commit is contained in:
jijiahao
2025-10-29 15:51:36 +08:00
committed by GitHub
11 changed files with 166 additions and 45 deletions

View File

@@ -37,3 +37,15 @@ python scripts/start_mcp_server.py
# update configs/ws_mcp_config.json with link from the command above
python scripts/ws_start_register_tools.py
```
# Eval Dataset Format
See `scripts/make_eval_dataset.py` for an example. The meaning of each entry is:
```json
[
{
"inputs": {"text": "用retrieve查询光予尘然后介绍"}, // model input
"outputs": {"answer": "光予尘茉莉绿茶为底", // reference answer
"tool_use": ["retrieve"]} // expected tool calls; if more than one tool is listed, the model is assumed to need all of them
}
]
```

View File

@@ -0,0 +1,8 @@
{ // Prompt for the router; route_prompt must explicitly ask for the result to be returned as JSON, otherwise there will be bugs
"route_prompt" : "Return a JSON object with 'step'.the value should be one of 'chat' or 'order' based on the user input",
// Prompt for tool branch
"tool_prompt" : "You must use tool to complete the possible task"
// Optionally set chat_prompt to overwrite the system prompt from xiaozhi
}

View File

@@ -54,6 +54,9 @@ class InstantiateConfig(PrintableConfig):
yaml.dump(self, f)
logger.info(f"[yellow]config saved to: {filename}[/yellow]")
def get_name(self):
    """Return the name of this object's class as a string."""
    return type(self).__name__

View File

@@ -24,7 +24,7 @@ class EvaluatorConfig(InstantiateConfig):
experiment_desc:str = "testing if this works or not"
"""describe the experiment"""
dataset_name:Literal["Toxic Queries"] = "Toxic Queries"
dataset_name:Literal["Toxic Queries"] = "dev_langagent"
"""name of the dataset to evaluate"""
pipe_config: PipelineConfig = field(default_factory=PipelineConfig)

View File

@@ -1,11 +1,12 @@
from dataclasses import dataclass, field
from typing import Type, Literal
from typing import Type, Callable, List
import tyro
from lang_agent.config import KeyConfig
from lang_agent.pipeline import Pipeline, PipelineConfig
from langchain.chat_models import init_chat_model
from langchain_core.messages import BaseMessage, ToolMessage
@tyro.conf.configure(tyro.conf.SuppressFixed)
@dataclass
@@ -21,12 +22,12 @@ class Validator:
# NOTE: Need to register function here
self.dict_corr_map = {
"Toxic Queries" : [self.Toxic_Queries_correct]
"dev_langagent" : [self.default_correct, self.val_tool_use]
}
# NOTE: Need to register function here
self.dict_inp_map = {
"Toxic Queries" : self.Toxic_Queries_inp_parse
"dev_langagent" : self.default_inp_parse
}
@@ -39,7 +40,7 @@ class Validator:
)
# NOTE: for every dataset; need one of these
def Toxic_Queries_correct(self, inputs: dict, outputs: list, reference_outputs: dict) -> bool:
def default_correct(self, inputs: dict, outputs: dict, reference_outputs: dict) -> bool:
instructions = (
"Given an actual answer and an expected answer, determine whether"
" the actual answer contains all of the information in the"
@@ -48,7 +49,7 @@ class Validator:
" otherwise. Do not include anything else in your response."
)
actual_answer = outputs["output"][-1].content
expected_answer = reference_outputs["label"]
expected_answer = reference_outputs["answer"]
user_msg = (
f"ACTUAL ANSWER: {actual_answer}"
@@ -64,16 +65,38 @@ class Validator:
return response.content.upper() == "CORRECT"
def val_tool_use(self, inputs:dict, outputs:dict, reference_outputs:dict)->bool:
tool_uses:List[str] = reference_outputs.get("tool_use")
if tool_uses is None:
return True
tool_msgs = [e for e in outputs["output"] if isinstance(e, ToolMessage)]
# check if all tools are used
tool_used = []
for ref_tool in tool_uses:
st_cond = False
ref_tool = ref_tool.lower()
for msg in tool_msgs:
st_cond = ref_tool in msg.name.lower()
if st_cond:
break
tool_used.append(st_cond)
return sum(tool_used)/len(tool_uses)
# NOTE: for every dataset; need one of these
def Toxic_Queries_inp_parse(self, inp, pipeline:Pipeline):
def default_inp_parse(self, inp, pipeline:Pipeline):
inp = inp["text"]
return pipeline.chat(inp, as_raw=True)
def get_val_fnc(self, dataset_name:str):
return self.dict_corr_map[dataset_name]
def get_val_fnc(self, dataset_name:str)->List[Callable]:
return self.dict_corr_map.get(dataset_name, [self.default_correct])
def get_inp_fnc(self,dataset_name:str):
return self.dict_inp_map[dataset_name]
def get_inp_fnc(self,dataset_name:str)->Callable:
# return self.dict_inp_map[dataset_name]
return self.dict_inp_map.get(dataset_name, self.default_inp_parse)

View File

@@ -47,7 +47,7 @@ class ReactGraph(GraphBase):
self.tool_manager:ToolManager = self.config.tool_manager_config.setup()
memory = MemorySaver()
tools = self.tool_manager.get_langchain_tools()
tools = self.tool_manager.get_list_langchain_tools()
self.agent = create_agent(self.llm, tools, checkpointer=memory)
def invoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs):

View File

@@ -7,6 +7,8 @@ from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt
import jax
import os.path as osp
import commentjson
from lang_agent.config import KeyConfig
from lang_agent.tool_manager import ToolManager, ToolManagerConfig
@@ -34,9 +36,20 @@ class RoutingConfig(KeyConfig):
base_url:str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
"""base url; could be used to overwrite the baseurl in llm provider"""
sys_promp_json: str = None
"path to json containing system prompt for graphs; Will overwrite system prompt from xiaozhi if 'chat_prompt' is provided"
tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)
def __post_init__(self):
    """Resolve and validate the path of the system-prompt JSON file.

    Runs the parent ``__post_init__`` first. When ``sys_promp_json`` is not
    set, it falls back to ``configs/route_sys_prompts.json`` under the
    directory reached by applying ``dirname`` three times to this file's
    path. Finally asserts that the resulting path exists.

    Raises:
        AssertionError: if the resolved ``sys_promp_json`` path does not exist.
    """
    super().__post_init__()
    if self.sys_promp_json is None:
        # Three dirname() hops up from this file, then into configs/.
        base_dir = osp.dirname(osp.dirname(osp.dirname(__file__)))
        self.sys_promp_json = osp.join(base_dir, "configs", "route_sys_prompts.json")
        # Messages name the real field so the user knows which option to set.
        logger.warning(f"sys_promp_json was not provided. Using default: {self.sys_promp_json}")
    assert osp.exists(self.sys_promp_json), f"sys_promp_json {self.sys_promp_json} does not exist."
class Route(BaseModel):
step: Literal["chat", "order"] = Field(
@@ -55,7 +68,11 @@ class State(TypedDict):
class RoutingGraph(GraphBase):
def __init__(self, config: RoutingConfig):
self.config = config
self.chat_sys_msg = None
# NOTE: tools that the chat branch should have
self.chat_tool_names = ["retrieve",
"get_resources"]
self._build_modules()
self.workflow = self._build_graph()
@@ -87,24 +104,30 @@ class RoutingGraph(GraphBase):
assert len(nargs[0]["messages"]) >= 2, "need at least 1 system and 1 human message"
assert len(kwargs) == 0, "due to inp assumptions"
def _get_chat_tools(self, man:ToolManager):
    """Return the tools of *man* whose name is listed in ``self.chat_tool_names``."""
    selected = []
    for candidate in man.get_list_langchain_tools():
        if candidate.name in self.chat_tool_names:
            selected.append(candidate)
    return selected
def _build_modules(self):
self.llm = init_chat_model(model=self.config.llm_name,
model_provider=self.config.llm_provider,
api_key=self.config.api_key,
base_url=self.config.base_url)
self.memory = MemorySaver()
model_provider=self.config.llm_provider,
api_key=self.config.api_key,
base_url=self.config.base_url)
self.memory = MemorySaver() # shared memory between the two branch
self.router = self.llm.with_structured_output(Route)
tool_manager:ToolManager = self.config.tool_manager_config.setup()
self.chat_model = create_agent(self.llm, [], checkpointer=self.memory)
self.tool_model = create_agent(self.llm, tool_manager.get_langchain_tools(), checkpointer=self.memory)
self.chat_model = create_agent(self.llm, self._get_chat_tools(tool_manager), checkpointer=self.memory)
self.tool_model = create_agent(self.llm, tool_manager.get_list_langchain_tools(), checkpointer=self.memory)
with open(self.config.sys_promp_json , "r") as f:
self.prompt_dict:Dict[str, str] = commentjson.load(f)
def _router_call(self, state:State):
decision:Route = self.router.invoke(
[
SystemMessage(
content="Return a JSON object with 'step'.the value should be one of 'chat' or 'order' based on the user input"
content=self.prompt_dict["route_prompt"]
),
self._get_human_msg(state)
]
@@ -138,6 +161,15 @@ class RoutingGraph(GraphBase):
else:
inp = state["inp"]
if self.prompt_dict.get("chat_prompt") is not None:
inp = {"messages":[
SystemMessage(
self.prompt_dict["chat_prompt"]
),
*state["inp"][0]["messages"][1:]
]}, state["inp"][1]
out = self.chat_model.invoke(*inp)
return {"messages": out}
@@ -145,9 +177,8 @@ class RoutingGraph(GraphBase):
def _tool_model_call(self, state:State):
inp = {"messages":[
SystemMessage(
"You must use tool to complete the possible task"
self.prompt_dict["tool_prompt"]
),
# self._get_human_msg(state)
*state["inp"][0]["messages"][1:]
]}, state["inp"][1]

View File

@@ -11,20 +11,19 @@ from fastmcp.tools.tool import FunctionTool
from lang_agent.config import InstantiateConfig, ToolConfig
from lang_agent.base import LangToolBase
## import tool configs
from lang_agent.rag.simple import SimpleRagConfig
from lang_agent.dummy.calculator import CalculatorConfig
from catering_end.lang_tool import CartToolConfig, CartTool
# from langchain.tools import StructuredTool
from langchain_core.tools.structured import StructuredTool
import jax
@tyro.conf.configure(tyro.conf.SuppressFixed)
@dataclass
class ToolManagerConfig(InstantiateConfig):
_target: Type = field(default_factory=lambda: ToolManager)
# tool configs here;
# tool configs here; MUST HAVE 'config' in name and must be dataclass
rag_config: SimpleRagConfig = field(default_factory=SimpleRagConfig)
cart_config: CartToolConfig = field(default_factory=CartToolConfig)
@@ -59,6 +58,7 @@ def async_to_sync(async_func: Callable) -> Callable:
return sync_wrapper
class ToolManager:
def __init__(self, config:ToolManagerConfig):
self.config = config
@@ -91,29 +91,39 @@ class ToolManager:
self.tool_fncs = []
tool_configs = self._get_tool_config()
for tool_conf in tool_configs:
tool_name = tool_conf.get_name()[:-6]
if tool_conf.use_tool:
logger.info(f"making tool:{tool_conf._target}")
self.tool_fncs.extend(self._get_tool_fnc(tool_conf.setup()))
logger.info(f"making tool:{tool_name}")
fnc_list = self._get_tool_fnc(tool_conf.setup())
self.tool_fncs.extend(fnc_list)
else:
logger.info(f"skipping tool:{tool_conf._target}")
logger.info(f"skipping tool:{tool_name}")
self._build_langchain_tools()
def get_tool_fncs(self):
    """Return the list of tool functions stored in ``self.tool_fncs``."""
    fncs = self.tool_fncs
    return fncs
def get_tool_dict(self):
    """Return the object stored in ``self.tool_dict``."""
    mapping = self.tool_dict
    return mapping
def get_langchain_tools(self):
out = []
def fnc_to_structool(self, func):
    """Wrap *func* in a ``StructuredTool``.

    A coroutine function is registered as the tool's ``coroutine`` and its
    ``async_to_sync`` wrapper as the synchronous ``func``. Any other
    callable is passed through as ``func`` directly.
    """
    if not inspect.iscoroutinefunction(func):
        return StructuredTool.from_function(func=func)

    sync_variant = async_to_sync(func)
    return StructuredTool.from_function(
        func=sync_variant,
        coroutine=func)
def _build_langchain_tools(self):
self.langchain_tools = []
for func in self.get_tool_fncs():
if inspect.iscoroutinefunction(func):
out.append(
StructuredTool.from_function(
func=async_to_sync(func),
coroutine=func)
)
else:
out.append(
StructuredTool.from_function(func=func)
)
self.langchain_tools.append(self.fnc_to_structool(func))
return out
return self.langchain_tools
def get_list_langchain_tools(self)->List[StructuredTool]:
    """Return the list stored in ``self.langchain_tools``."""
    built_tools = self.langchain_tools
    return built_tools

View File

@@ -5,6 +5,7 @@ dependencies = [
"langchain==1.0",
"langchain_community",
"langchain-openai",
"langchain_mcp_adapters",
"httpx[socks]",
"dashscope",
"python-dotenv>=1.0.0",
@@ -19,7 +20,8 @@ dependencies = [
"fastapi",
"matplotlib",
"Pillow",
"jax"
"jax",
"commentjson"
]
[tool.setuptools.packages.find]

View File

@@ -18,7 +18,8 @@ def main(conf:PipelineConfig):
# response = pipeline.chat(user_input, as_stream=True)
# print(f"回答: {response}")
out = pipeline.chat("用工具算6856854-416846等于多少;然后解释它是怎么算出来的", as_stream=True)
# out = pipeline.chat("用工具算6856854-416846等于多少;然后解释它是怎么算出来的", as_stream=True)
out = pipeline.chat("介绍一下自己", as_stream=True)
# out = pipeline.chat("testing", as_stream=True)
print("=========== final ==========")
print(out)

View File

@@ -0,0 +1,31 @@
from langsmith import Client
from langsmith.utils import LangSmithNotFoundError
from loguru import logger

# Name of the LangSmith dataset to read or create.
DATASET_NAME = "dev_langagent"

# Each example pairs the model input ("inputs") with the reference data
# ("outputs"): the reference answer and, optionally, the tools the model is
# expected to use.
examples = [
    {
        "inputs": {"text": "用retrieve查询光予尘然后介绍"},
        "outputs": {"answer": "茉莉绿茶为底,清冽茶香中漫出玫珑蜜瓜的绵甜与凤梨的明亮果香,层次鲜活;顶部白柚茉莉泡沫轻盈漫过舌尖,带着微酸的清新感,让整体风味更显灵动",
                    "tool_use": ["retrieve"]}
    },
    {
        "inputs": {"text": "介绍一下自己"},
        "outputs": {"answer": "我是小盏,是一个点餐助手"}
    }
]

cli = Client()

# Reuse the dataset when it already exists and create it only when LangSmith
# reports it as missing. Any other failure (authentication, network, Ctrl-C)
# propagates instead of being mistaken for "dataset not found".
try:
    dataset = cli.read_dataset(dataset_name=DATASET_NAME)
    logger.info("read dataset")
except LangSmithNotFoundError:
    dataset = cli.create_dataset(dataset_name=DATASET_NAME)
    logger.info("created dataset")

# NOTE(review): examples are uploaded on every run, so re-running the script
# presumably adds duplicates to an existing dataset -- confirm this is intended.
cli.create_examples(
    dataset_id=dataset.id,
    examples=examples
)