12
README.md
12
README.md
@@ -37,3 +37,15 @@ python scripts/start_mcp_server.py
|
||||
# update configs/ws_mcp_config.json with link from the command above
|
||||
python scripts/ws_start_register_tools.py
|
||||
```
|
||||
|
||||
# Eval Dataset Format
|
||||
See `scripts/make_eval_dataset.py` for an example. The meaning of each entry:
|
||||
```json
|
||||
[
|
||||
{
|
||||
"inputs": {"text": "用retrieve查询光予尘然后介绍"}, // model input
|
||||
"outputs": {"answer": "光予尘茉莉绿茶为底", // reference answer
|
||||
"tool_use": ["retrieve"]} // expected tool calls; if more than one tool is listed, the model is assumed to need all of them
|
||||
}
|
||||
]
|
||||
```
|
||||
8
configs/route_sys_prompts.json
Normal file
8
configs/route_sys_prompts.json
Normal file
@@ -0,0 +1,8 @@
|
||||
{ // Prompt for the router; route_prompt must explicitly ask for the result as JSON, otherwise there will be bugs
|
||||
"route_prompt" : "Return a JSON object with 'step'.the value should be one of 'chat' or 'order' based on the user input",
|
||||
|
||||
// Prompt for tool branch
|
||||
"tool_prompt" : "You must use tool to complete the possible task"
|
||||
|
||||
// Optionally set chat_prompt to overwrite the system prompt from xiaozhi
|
||||
}
|
||||
@@ -54,6 +54,9 @@ class InstantiateConfig(PrintableConfig):
|
||||
yaml.dump(self, f)
|
||||
logger.info(f"[yellow]config saved to: {filename}[/yellow]")
|
||||
|
||||
def get_name(self):
    """Return the name of the class this object is an instance of."""
    cls = self.__class__
    return cls.__name__
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ class EvaluatorConfig(InstantiateConfig):
|
||||
experiment_desc:str = "testing if this works or not"
|
||||
"""describe the experiment"""
|
||||
|
||||
dataset_name:Literal["Toxic Queries"] = "Toxic Queries"
|
||||
dataset_name:Literal["Toxic Queries"] = "dev_langagent"
|
||||
"""name of the dataset to evaluate"""
|
||||
|
||||
pipe_config: PipelineConfig = field(default_factory=PipelineConfig)
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Type, Literal
|
||||
from typing import Type, Callable, List
|
||||
import tyro
|
||||
|
||||
from lang_agent.config import KeyConfig
|
||||
from lang_agent.pipeline import Pipeline, PipelineConfig
|
||||
|
||||
from langchain.chat_models import init_chat_model
|
||||
from langchain_core.messages import BaseMessage, ToolMessage
|
||||
|
||||
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
||||
@dataclass
|
||||
@@ -21,12 +22,12 @@ class Validator:
|
||||
|
||||
# NOTE: Need to register function here
|
||||
self.dict_corr_map = {
|
||||
"Toxic Queries" : [self.Toxic_Queries_correct]
|
||||
"dev_langagent" : [self.default_correct, self.val_tool_use]
|
||||
}
|
||||
|
||||
# NOTE: Need to register function here
|
||||
self.dict_inp_map = {
|
||||
"Toxic Queries" : self.Toxic_Queries_inp_parse
|
||||
"dev_langagent" : self.default_inp_parse
|
||||
}
|
||||
|
||||
|
||||
@@ -39,7 +40,7 @@ class Validator:
|
||||
)
|
||||
|
||||
# NOTE: for every dataset; need one of these
|
||||
def Toxic_Queries_correct(self, inputs: dict, outputs: list, reference_outputs: dict) -> bool:
|
||||
def default_correct(self, inputs: dict, outputs: dict, reference_outputs: dict) -> bool:
|
||||
instructions = (
|
||||
"Given an actual answer and an expected answer, determine whether"
|
||||
" the actual answer contains all of the information in the"
|
||||
@@ -48,7 +49,7 @@ class Validator:
|
||||
" otherwise. Do not include anything else in your response."
|
||||
)
|
||||
actual_answer = outputs["output"][-1].content
|
||||
expected_answer = reference_outputs["label"]
|
||||
expected_answer = reference_outputs["answer"]
|
||||
|
||||
user_msg = (
|
||||
f"ACTUAL ANSWER: {actual_answer}"
|
||||
@@ -64,16 +65,38 @@ class Validator:
|
||||
|
||||
return response.content.upper() == "CORRECT"
|
||||
|
||||
def val_tool_use(self, inputs:dict, outputs:dict, reference_outputs:dict)->float:
    """Score how many of the expected tools show up in the run's tool messages.

    Args:
        inputs: dataset inputs of the example (not used by this check).
        outputs: run result; ``outputs["output"]`` is iterated as the sequence
            of messages produced by the run.
        reference_outputs: reference entry; the optional ``"tool_use"`` key
            lists the names of the tools the run is expected to call.

    Returns:
        ``True`` when the reference lists no expected tools (``"tool_use"``
        missing, ``None`` or empty); otherwise the fraction of expected tools
        for which a ``ToolMessage`` with a matching name was found.
    """
    expected_tools = reference_outputs.get("tool_use")
    if not expected_tools:
        # Nothing to verify; this also avoids dividing by zero for an empty list.
        return True

    # Lower-cased names of every tool message in the run. A message without a
    # name is treated as "" so it simply never matches.
    used_names = [
        (msg.name or "").lower()
        for msg in outputs["output"]
        if isinstance(msg, ToolMessage)
    ]

    # An expected tool counts as used when its name is a substring of the name
    # of at least one tool message (case-insensitive).
    hits = sum(
        any(ref_tool.lower() in name for name in used_names)
        for ref_tool in expected_tools
    )
    return hits / len(expected_tools)
|
||||
|
||||
|
||||
|
||||
# NOTE: for every dataset; need one of these
|
||||
def Toxic_Queries_inp_parse(self, inp, pipeline:Pipeline):
|
||||
def default_inp_parse(self, inp, pipeline:Pipeline):
    """Send the example's ``"text"`` field through ``pipeline.chat`` with ``as_raw=True`` and return the result."""
    text = inp["text"]
    raw_result = pipeline.chat(text, as_raw=True)
    return raw_result
|
||||
|
||||
|
||||
def get_val_fnc(self, dataset_name:str):
|
||||
return self.dict_corr_map[dataset_name]
|
||||
def get_val_fnc(self, dataset_name:str)->List[Callable]:
    """Look up the validation functions registered for ``dataset_name``.

    Falls back to ``[self.default_correct]`` when the dataset has no entry in
    ``self.dict_corr_map``.
    """
    fallback = [self.default_correct]
    registered = self.dict_corr_map
    if dataset_name in registered:
        return registered[dataset_name]
    return fallback
|
||||
|
||||
|
||||
def get_inp_fnc(self,dataset_name:str):
|
||||
return self.dict_inp_map[dataset_name]
|
||||
def get_inp_fnc(self,dataset_name:str)->Callable:
    """Look up the input-parsing function registered for ``dataset_name``.

    Falls back to ``self.default_inp_parse`` when the dataset has no entry in
    ``self.dict_inp_map``.
    """
    fallback = self.default_inp_parse
    registered = self.dict_inp_map
    if dataset_name in registered:
        return registered[dataset_name]
    return fallback
|
||||
@@ -47,7 +47,7 @@ class ReactGraph(GraphBase):
|
||||
|
||||
self.tool_manager:ToolManager = self.config.tool_manager_config.setup()
|
||||
memory = MemorySaver()
|
||||
tools = self.tool_manager.get_langchain_tools()
|
||||
tools = self.tool_manager.get_list_langchain_tools()
|
||||
self.agent = create_agent(self.llm, tools, checkpointer=memory)
|
||||
|
||||
def invoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs):
|
||||
|
||||
@@ -7,6 +7,8 @@ from PIL import Image
|
||||
from io import BytesIO
|
||||
import matplotlib.pyplot as plt
|
||||
import jax
|
||||
import os.path as osp
|
||||
import commentjson
|
||||
|
||||
from lang_agent.config import KeyConfig
|
||||
from lang_agent.tool_manager import ToolManager, ToolManagerConfig
|
||||
@@ -34,9 +36,20 @@ class RoutingConfig(KeyConfig):
|
||||
base_url:str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
"""base url; could be used to overwrite the baseurl in llm provider"""
|
||||
|
||||
sys_promp_json: str = None
|
||||
"path to the json containing the system prompts for the graphs; overwrites the system prompt from xiaozhi if 'chat_prompt' is provided"
|
||||
|
||||
tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)
|
||||
|
||||
|
||||
def __post_init__(self):
    """Resolve and validate the path of the system-prompt JSON file.

    When ``sys_promp_json`` is unset it defaults to
    ``configs/route_sys_prompts.json`` under the directory obtained by applying
    ``dirname`` three times to this file's path. The resulting path must exist.
    """
    super().__post_init__()

    if self.sys_promp_json is None:
        # Strip three path components from this file's location.
        base_dir = __file__
        for _ in range(3):
            base_dir = osp.dirname(base_dir)
        self.sys_promp_json = osp.join(base_dir, "configs", "route_sys_prompts.json")
        logger.warning(f"config_f was not provided. Using default: {self.sys_promp_json}")

    prompt_file_found = osp.exists(self.sys_promp_json)
    assert prompt_file_found, f"config_f {self.sys_promp_json} does not exist."
|
||||
|
||||
|
||||
class Route(BaseModel):
|
||||
step: Literal["chat", "order"] = Field(
|
||||
@@ -55,7 +68,11 @@ class State(TypedDict):
|
||||
class RoutingGraph(GraphBase):
|
||||
def __init__(self, config: RoutingConfig):
|
||||
self.config = config
|
||||
self.chat_sys_msg = None
|
||||
|
||||
# NOTE: tools that the chat branch should have
|
||||
self.chat_tool_names = ["retrieve",
|
||||
"get_resources"]
|
||||
|
||||
self._build_modules()
|
||||
|
||||
self.workflow = self._build_graph()
|
||||
@@ -87,24 +104,30 @@ class RoutingGraph(GraphBase):
|
||||
assert len(nargs[0]["messages"]) >= 2, "need at least 1 system and 1 human message"
|
||||
assert len(kwargs) == 0, "due to inp assumptions"
|
||||
|
||||
def _get_chat_tools(self, man:ToolManager):
    """Return the tools from ``man`` whose names are listed in ``self.chat_tool_names``."""
    selected = []
    for candidate in man.get_list_langchain_tools():
        if candidate.name in self.chat_tool_names:
            selected.append(candidate)
    return selected
|
||||
|
||||
def _build_modules(self):
|
||||
self.llm = init_chat_model(model=self.config.llm_name,
|
||||
model_provider=self.config.llm_provider,
|
||||
api_key=self.config.api_key,
|
||||
base_url=self.config.base_url)
|
||||
self.memory = MemorySaver()
|
||||
model_provider=self.config.llm_provider,
|
||||
api_key=self.config.api_key,
|
||||
base_url=self.config.base_url)
|
||||
self.memory = MemorySaver() # shared memory between the two branch
|
||||
self.router = self.llm.with_structured_output(Route)
|
||||
|
||||
tool_manager:ToolManager = self.config.tool_manager_config.setup()
|
||||
self.chat_model = create_agent(self.llm, [], checkpointer=self.memory)
|
||||
self.tool_model = create_agent(self.llm, tool_manager.get_langchain_tools(), checkpointer=self.memory)
|
||||
self.chat_model = create_agent(self.llm, self._get_chat_tools(tool_manager), checkpointer=self.memory)
|
||||
self.tool_model = create_agent(self.llm, tool_manager.get_list_langchain_tools(), checkpointer=self.memory)
|
||||
|
||||
with open(self.config.sys_promp_json , "r") as f:
|
||||
self.prompt_dict:Dict[str, str] = commentjson.load(f)
|
||||
|
||||
|
||||
def _router_call(self, state:State):
|
||||
decision:Route = self.router.invoke(
|
||||
[
|
||||
SystemMessage(
|
||||
content="Return a JSON object with 'step'.the value should be one of 'chat' or 'order' based on the user input"
|
||||
content=self.prompt_dict["route_prompt"]
|
||||
),
|
||||
self._get_human_msg(state)
|
||||
]
|
||||
@@ -138,6 +161,15 @@ class RoutingGraph(GraphBase):
|
||||
else:
|
||||
inp = state["inp"]
|
||||
|
||||
if self.prompt_dict.get("chat_prompt") is not None:
|
||||
inp = {"messages":[
|
||||
SystemMessage(
|
||||
self.prompt_dict["chat_prompt"]
|
||||
),
|
||||
*state["inp"][0]["messages"][1:]
|
||||
]}, state["inp"][1]
|
||||
|
||||
|
||||
out = self.chat_model.invoke(*inp)
|
||||
return {"messages": out}
|
||||
|
||||
@@ -145,9 +177,8 @@ class RoutingGraph(GraphBase):
|
||||
def _tool_model_call(self, state:State):
|
||||
inp = {"messages":[
|
||||
SystemMessage(
|
||||
"You must use tool to complete the possible task"
|
||||
self.prompt_dict["tool_prompt"]
|
||||
),
|
||||
# self._get_human_msg(state)
|
||||
*state["inp"][0]["messages"][1:]
|
||||
]}, state["inp"][1]
|
||||
|
||||
|
||||
@@ -11,20 +11,19 @@ from fastmcp.tools.tool import FunctionTool
|
||||
from lang_agent.config import InstantiateConfig, ToolConfig
|
||||
from lang_agent.base import LangToolBase
|
||||
|
||||
## import tool configs
|
||||
from lang_agent.rag.simple import SimpleRagConfig
|
||||
from lang_agent.dummy.calculator import CalculatorConfig
|
||||
from catering_end.lang_tool import CartToolConfig, CartTool
|
||||
|
||||
# from langchain.tools import StructuredTool
|
||||
from langchain_core.tools.structured import StructuredTool
|
||||
import jax
|
||||
|
||||
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
||||
@dataclass
|
||||
class ToolManagerConfig(InstantiateConfig):
|
||||
_target: Type = field(default_factory=lambda: ToolManager)
|
||||
|
||||
# tool configs here;
|
||||
# tool configs here; MUST HAVE 'config' in name and must be dataclass
|
||||
rag_config: SimpleRagConfig = field(default_factory=SimpleRagConfig)
|
||||
|
||||
cart_config: CartToolConfig = field(default_factory=CartToolConfig)
|
||||
@@ -59,6 +58,7 @@ def async_to_sync(async_func: Callable) -> Callable:
|
||||
|
||||
return sync_wrapper
|
||||
|
||||
|
||||
class ToolManager:
|
||||
def __init__(self, config:ToolManagerConfig):
|
||||
self.config = config
|
||||
@@ -91,29 +91,39 @@ class ToolManager:
|
||||
self.tool_fncs = []
|
||||
tool_configs = self._get_tool_config()
|
||||
for tool_conf in tool_configs:
|
||||
tool_name = tool_conf.get_name()[:-6]
|
||||
if tool_conf.use_tool:
|
||||
logger.info(f"making tool:{tool_conf._target}")
|
||||
self.tool_fncs.extend(self._get_tool_fnc(tool_conf.setup()))
|
||||
logger.info(f"making tool:{tool_name}")
|
||||
fnc_list = self._get_tool_fnc(tool_conf.setup())
|
||||
self.tool_fncs.extend(fnc_list)
|
||||
else:
|
||||
logger.info(f"skipping tool:{tool_conf._target}")
|
||||
logger.info(f"skipping tool:{tool_name}")
|
||||
|
||||
self._build_langchain_tools()
|
||||
|
||||
|
||||
def get_tool_fncs(self):
|
||||
return self.tool_fncs
|
||||
|
||||
def get_tool_dict(self):
|
||||
return self.tool_dict
|
||||
|
||||
def get_langchain_tools(self):
|
||||
out = []
|
||||
|
||||
def fnc_to_structool(self, func):
    """Wrap ``func`` in a ``StructuredTool``.

    A coroutine function is passed as the tool's ``coroutine`` and additionally
    given a synchronous entry point built with ``async_to_sync``; any other
    callable is wrapped directly.
    """
    if not inspect.iscoroutinefunction(func):
        return StructuredTool.from_function(func=func)

    sync_entry = async_to_sync(func)
    return StructuredTool.from_function(func=sync_entry, coroutine=func)
|
||||
|
||||
def _build_langchain_tools(self):
|
||||
self.langchain_tools = []
|
||||
for func in self.get_tool_fncs():
|
||||
if inspect.iscoroutinefunction(func):
|
||||
out.append(
|
||||
StructuredTool.from_function(
|
||||
func=async_to_sync(func),
|
||||
coroutine=func)
|
||||
)
|
||||
else:
|
||||
out.append(
|
||||
StructuredTool.from_function(func=func)
|
||||
)
|
||||
self.langchain_tools.append(self.fnc_to_structool(func))
|
||||
|
||||
return out
|
||||
return self.langchain_tools
|
||||
|
||||
def get_list_langchain_tools(self)->List[StructuredTool]:
|
||||
return self.langchain_tools
|
||||
@@ -5,6 +5,7 @@ dependencies = [
|
||||
"langchain==1.0",
|
||||
"langchain_community",
|
||||
"langchain-openai",
|
||||
"langchain_mcp_adapters",
|
||||
"httpx[socks]",
|
||||
"dashscope",
|
||||
"python-dotenv>=1.0.0",
|
||||
@@ -19,7 +20,8 @@ dependencies = [
|
||||
"fastapi",
|
||||
"matplotlib",
|
||||
"Pillow",
|
||||
"jax"
|
||||
"jax",
|
||||
"commentjson"
|
||||
]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
|
||||
@@ -18,7 +18,8 @@ def main(conf:PipelineConfig):
|
||||
# response = pipeline.chat(user_input, as_stream=True)
|
||||
# print(f"回答: {response}")
|
||||
|
||||
out = pipeline.chat("用工具算6856854-416846等于多少;然后解释它是怎么算出来的", as_stream=True)
|
||||
# out = pipeline.chat("用工具算6856854-416846等于多少;然后解释它是怎么算出来的", as_stream=True)
|
||||
out = pipeline.chat("介绍一下自己", as_stream=True)
|
||||
# out = pipeline.chat("testing", as_stream=True)
|
||||
print("=========== final ==========")
|
||||
print(out)
|
||||
|
||||
31
scripts/make_eval_dataset.py
Normal file
31
scripts/make_eval_dataset.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""Create (or reuse) the evaluation dataset on LangSmith and upload the examples below."""
from langsmith import Client
from langsmith.utils import LangSmithNotFoundError
from loguru import logger


DATASET_NAME = "dev_langagent"

# Each example has "inputs" (the text sent to the model) and "outputs" (the
# reference answer plus, optionally, "tool_use": the tools the model is
# expected to call).
examples = [
    {
        "inputs": {"text": "用retrieve查询光予尘然后介绍"},
        "outputs": {"answer": "茉莉绿茶为底,清冽茶香中漫出玫珑蜜瓜的绵甜与凤梨的明亮果香,层次鲜活;顶部白柚茉莉泡沫轻盈漫过舌尖,带着微酸的清新感,让整体风味更显灵动",
                    "tool_use": ["retrieve"]}
    },
    {
        "inputs": {"text": "介绍一下自己"},
        "outputs": {"answer": "我是小盏,是一个点餐助手"}
    }
]

cli = Client()

try:
    dataset = cli.read_dataset(dataset_name=DATASET_NAME)
    logger.info("read dataset")
except LangSmithNotFoundError:
    # Only a missing dataset should trigger creation. Auth or network failures
    # and Ctrl-C must surface instead of being mistaken for "not found".
    dataset = cli.create_dataset(dataset_name=DATASET_NAME)
    logger.info("created dataset")

cli.create_examples(
    dataset_id=dataset.id,
    examples=examples
)
|
||||
Reference in New Issue
Block a user