README.md
@@ -36,4 +36,16 @@ python scripts/start_mcp_server.py
 # update configs/ws_mcp_config.json with link from the command above
 python scripts/ws_start_register_tools.py
 ```
+
+# Eval Dataset Format
+
+See `scripts/make_eval_dataset.py` for a working example. The meaning of each entry:
+
+```json
+[
+    {
+        "inputs": {"text": "用retrieve查询光予尘然后介绍"},   // model input
+        "outputs": {"answer": "光予尘茉莉绿茶为底",           // reference answer
+                    "tool_use": ["retrieve"]}                 // tools to use; if more than one is listed, the model is expected to use all of them
+    }
+]
+```
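As a quick sanity check of this format, a minimal loader sketch (the file name `eval_dataset.json` below is hypothetical, and it assumes a comment-free JSON copy of the entries):

```python
# Illustrative only: load entries that follow the eval dataset format above.
import json

with open("eval_dataset.json") as f:   # hypothetical file name
    entries = json.load(f)

for e in entries:
    text = e["inputs"]["text"]                # model input
    answer = e["outputs"]["answer"]           # reference answer
    tools = e["outputs"].get("tool_use", [])  # optional; the model is expected to call every listed tool
    print(text, "->", answer, tools)
```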
configs/route_sys_prompts.json (new file)
@@ -0,0 +1,8 @@
+{ // Prompts for the router; route_prompt MUST ask the model to return JSON; if not, there will be bugs
+    "route_prompt" : "Return a JSON object with 'step'.the value should be one of 'chat' or 'order' based on the user input",
+
+    // Prompt for the tool branch
+    "tool_prompt" : "You must use tool to complete the possible task"
+
+    // Optionally set chat_prompt to overwrite the system prompt from xiaozhi
+}
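For reference, this file is read with `commentjson`, which tolerates the `//` comments above (see the `RoutingGraph._build_modules` change further down). A minimal loading sketch:

```python
# Minimal sketch of how the prompts file is consumed (mirrors _build_modules below).
import commentjson

with open("configs/route_sys_prompts.json") as f:
    prompt_dict = commentjson.load(f)

print(prompt_dict["route_prompt"])      # system prompt for the structured-output router
print(prompt_dict["tool_prompt"])       # system prompt for the tool branch
print(prompt_dict.get("chat_prompt"))   # optional; overrides the xiaozhi system prompt when present
```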
@@ -53,6 +53,9 @@ class InstantiateConfig(PrintableConfig):
         with open(filename, 'w') as f:
             yaml.dump(self, f)
         logger.info(f"[yellow]config saved to: {filename}[/yellow]")
 
+    def get_name(self):
+        return self.__class__.__name__
+
 
@@ -24,7 +24,7 @@ class EvaluatorConfig(InstantiateConfig):
     experiment_desc:str = "testing if this works or not"
     """describe the experiment"""
 
-    dataset_name:Literal["Toxic Queries"] = "Toxic Queries"
+    dataset_name:Literal["Toxic Queries"] = "dev_langagent"
     """name of the dataset to evaluate"""
 
     pipe_config: PipelineConfig = field(default_factory=PipelineConfig)
@@ -1,11 +1,12 @@
 from dataclasses import dataclass, field
-from typing import Type, Literal
+from typing import Type, Callable, List
 import tyro
 
 from lang_agent.config import KeyConfig
 from lang_agent.pipeline import Pipeline, PipelineConfig
 
 from langchain.chat_models import init_chat_model
+from langchain_core.messages import BaseMessage, ToolMessage
 
 @tyro.conf.configure(tyro.conf.SuppressFixed)
 @dataclass
@@ -21,12 +22,12 @@ class Validator:
 
         # NOTE: Need to register function here
         self.dict_corr_map = {
-            "Toxic Queries" : [self.Toxic_Queries_correct]
+            "dev_langagent" : [self.default_correct, self.val_tool_use]
         }
 
         # NOTE: Need to register function here
         self.dict_inp_map = {
-            "Toxic Queries" : self.Toxic_Queries_inp_parse
+            "dev_langagent" : self.default_inp_parse
         }
 
 
@@ -39,7 +40,7 @@ class Validator:
         )
 
     # NOTE: for every dataset; need one of these
-    def Toxic_Queries_correct(self, inputs: dict, outputs: list, reference_outputs: dict) -> bool:
+    def default_correct(self, inputs: dict, outputs: dict, reference_outputs: dict) -> bool:
         instructions = (
             "Given an actual answer and an expected answer, determine whether"
             " the actual answer contains all of the information in the"
@@ -48,7 +49,7 @@ class Validator:
             " otherwise. Do not include anything else in your response."
         )
         actual_answer = outputs["output"][-1].content
-        expected_answer = reference_outputs["label"]
+        expected_answer = reference_outputs["answer"]
 
         user_msg = (
             f"ACTUAL ANSWER: {actual_answer}"
@@ -64,16 +65,38 @@ class Validator:
 
         return response.content.upper() == "CORRECT"
 
+    def val_tool_use(self, inputs:dict, outputs:dict, reference_outputs:dict)->bool:
+        tool_uses:List[str] = reference_outputs.get("tool_use")
+        if tool_uses is None:
+            return True
+
+        tool_msgs = [e for e in outputs["output"] if isinstance(e, ToolMessage)]
+
+        # check if all tools are used
+        tool_used = []
+        for ref_tool in tool_uses:
+            st_cond = False
+            ref_tool = ref_tool.lower()
+            for msg in tool_msgs:
+                st_cond = ref_tool in msg.name.lower()
+                if st_cond:
+                    break
+            tool_used.append(st_cond)
+
+        return sum(tool_used)/len(tool_uses)
+
+
     # NOTE: for every dataset; need one of these
-    def Toxic_Queries_inp_parse(self, inp, pipeline:Pipeline):
+    def default_inp_parse(self, inp, pipeline:Pipeline):
         inp = inp["text"]
         return pipeline.chat(inp, as_raw=True)
 
 
-    def get_val_fnc(self, dataset_name:str):
-        return self.dict_corr_map[dataset_name]
+    def get_val_fnc(self, dataset_name:str)->List[Callable]:
+        return self.dict_corr_map.get(dataset_name, [self.default_correct])
 
 
-    def get_inp_fnc(self,dataset_name:str):
-        return self.dict_inp_map[dataset_name]
+    def get_inp_fnc(self,dataset_name:str)->Callable:
+        # return self.dict_inp_map[dataset_name]
+        return self.dict_inp_map.get(dataset_name, self.default_inp_parse)
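For intuition, the `val_tool_use` check boils down to substring matching of the expected tool names against the names of `ToolMessage`s in the trace. A standalone restatement on a hand-built trace (the messages below are illustrative, not repo code):

```python
# Standalone restatement of the tool-use check for one hand-built trace.
from langchain_core.messages import AIMessage, ToolMessage

trace = [
    ToolMessage(content="光予尘:茉莉绿茶为底……", name="retrieve", tool_call_id="call_1"),
    AIMessage(content="光予尘以茉莉绿茶为底……"),
]
expected_tools = ["retrieve"]   # from the dataset row's outputs.tool_use

tool_msgs = [m for m in trace if isinstance(m, ToolMessage)]
hits = [any(t.lower() in m.name.lower() for m in tool_msgs) for t in expected_tools]
score = sum(hits) / len(expected_tools)   # 1.0 when every expected tool was called
print(score)
```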
@@ -47,7 +47,7 @@ class ReactGraph(GraphBase):
 
         self.tool_manager:ToolManager = self.config.tool_manager_config.setup()
         memory = MemorySaver()
-        tools = self.tool_manager.get_langchain_tools()
+        tools = self.tool_manager.get_list_langchain_tools()
         self.agent = create_agent(self.llm, tools, checkpointer=memory)
 
     def invoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs):
@@ -7,6 +7,8 @@ from PIL import Image
 from io import BytesIO
 import matplotlib.pyplot as plt
 import jax
+import os.path as osp
+import commentjson
 
 from lang_agent.config import KeyConfig
 from lang_agent.tool_manager import ToolManager, ToolManagerConfig
@@ -34,9 +36,20 @@ class RoutingConfig(KeyConfig):
     base_url:str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
     """base url; could be used to overwrite the baseurl in llm provider"""
 
+    sys_promp_json: str = None
+    """path to a JSON file containing the system prompts for the graphs; will overwrite the system prompt from xiaozhi if 'chat_prompt' is provided"""
+
     tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)
 
 
+    def __post_init__(self):
+        super().__post_init__()
+        if self.sys_promp_json is None:
+            self.sys_promp_json = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts.json")
+            logger.warning(f"sys_promp_json was not provided. Using default: {self.sys_promp_json}")
+
+        assert osp.exists(self.sys_promp_json), f"sys_promp_json {self.sys_promp_json} does not exist."
+
+
 class Route(BaseModel):
     step: Literal["chat", "order"] = Field(
@@ -55,7 +68,11 @@ class State(TypedDict):
 class RoutingGraph(GraphBase):
     def __init__(self, config: RoutingConfig):
         self.config = config
-        self.chat_sys_msg = None
+
+        # NOTE: tools that the chat branch should have
+        self.chat_tool_names = ["retrieve",
+                                "get_resources"]
 
         self._build_modules()
 
         self.workflow = self._build_graph()
@@ -87,24 +104,30 @@ class RoutingGraph(GraphBase):
         assert len(nargs[0]["messages"]) >= 2, "need at least 1 system and 1 human message"
         assert len(kwargs) == 0, "due to inp assumptions"
 
+    def _get_chat_tools(self, man:ToolManager):
+        return [lang_tool for lang_tool in man.get_list_langchain_tools() if lang_tool.name in self.chat_tool_names]
+
     def _build_modules(self):
         self.llm = init_chat_model(model=self.config.llm_name,
                                    model_provider=self.config.llm_provider,
                                    api_key=self.config.api_key,
                                    base_url=self.config.base_url)
-        self.memory = MemorySaver()
+        self.memory = MemorySaver()  # shared memory between the two branches
         self.router = self.llm.with_structured_output(Route)
 
         tool_manager:ToolManager = self.config.tool_manager_config.setup()
-        self.chat_model = create_agent(self.llm, [], checkpointer=self.memory)
-        self.tool_model = create_agent(self.llm, tool_manager.get_langchain_tools(), checkpointer=self.memory)
+        self.chat_model = create_agent(self.llm, self._get_chat_tools(tool_manager), checkpointer=self.memory)
+        self.tool_model = create_agent(self.llm, tool_manager.get_list_langchain_tools(), checkpointer=self.memory)
+
+        with open(self.config.sys_promp_json, "r") as f:
+            self.prompt_dict:Dict[str, str] = commentjson.load(f)
 
 
     def _router_call(self, state:State):
         decision:Route = self.router.invoke(
             [
                 SystemMessage(
-                    content="Return a JSON object with 'step'.the value should be one of 'chat' or 'order' based on the user input"
+                    content=self.prompt_dict["route_prompt"]
                 ),
                 self._get_human_msg(state)
             ]
@@ -137,6 +160,15 @@ class RoutingGraph(GraphBase):
             inp = state["messages"], state["inp"][1]
         else:
             inp = state["inp"]
 
+        if self.prompt_dict.get("chat_prompt") is not None:
+            inp = {"messages":[
+                SystemMessage(
+                    self.prompt_dict["chat_prompt"]
+                ),
+                *state["inp"][0]["messages"][1:]
+            ]}, state["inp"][1]
+
+
         out = self.chat_model.invoke(*inp)
         return {"messages": out}
@@ -145,9 +177,8 @@ class RoutingGraph(GraphBase):
     def _tool_model_call(self, state:State):
         inp = {"messages":[
             SystemMessage(
-                "You must use tool to complete the possible task"
+                self.prompt_dict["tool_prompt"]
             ),
-            # self._get_human_msg(state)
             *state["inp"][0]["messages"][1:]
         ]}, state["inp"][1]
 
@@ -11,20 +11,19 @@ from fastmcp.tools.tool import FunctionTool
 from lang_agent.config import InstantiateConfig, ToolConfig
 from lang_agent.base import LangToolBase
 
-## import tool configs
 from lang_agent.rag.simple import SimpleRagConfig
 from lang_agent.dummy.calculator import CalculatorConfig
 from catering_end.lang_tool import CartToolConfig, CartTool
 
-# from langchain.tools import StructuredTool
 from langchain_core.tools.structured import StructuredTool
+import jax
 
 @tyro.conf.configure(tyro.conf.SuppressFixed)
 @dataclass
 class ToolManagerConfig(InstantiateConfig):
     _target: Type = field(default_factory=lambda: ToolManager)
 
-    # tool configs here;
+    # tool configs here; MUST HAVE 'config' in name and must be dataclass
     rag_config: SimpleRagConfig = field(default_factory=SimpleRagConfig)
 
     cart_config: CartToolConfig = field(default_factory=CartToolConfig)
@@ -59,6 +58,7 @@ def async_to_sync(async_func: Callable) -> Callable:
 
     return sync_wrapper
 
+
 class ToolManager:
     def __init__(self, config:ToolManagerConfig):
         self.config = config
@@ -91,29 +91,39 @@ class ToolManager:
         self.tool_fncs = []
         tool_configs = self._get_tool_config()
         for tool_conf in tool_configs:
+            tool_name = tool_conf.get_name()[:-6]
             if tool_conf.use_tool:
-                logger.info(f"making tool:{tool_conf._target}")
-                self.tool_fncs.extend(self._get_tool_fnc(tool_conf.setup()))
+                logger.info(f"making tool:{tool_name}")
+                fnc_list = self._get_tool_fnc(tool_conf.setup())
+                self.tool_fncs.extend(fnc_list)
             else:
-                logger.info(f"skipping tool:{tool_conf._target}")
+                logger.info(f"skipping tool:{tool_name}")
 
+        self._build_langchain_tools()
+
 
     def get_tool_fncs(self):
         return self.tool_fncs
 
+    def get_tool_dict(self):
+        return self.tool_dict
 
-    def get_langchain_tools(self):
-        out = []
-        for func in self.get_tool_fncs():
-            if inspect.iscoroutinefunction(func):
-                out.append(
-                    StructuredTool.from_function(
-                        func=async_to_sync(func),
-                        coroutine=func)
-                )
-            else:
-                out.append(
-                    StructuredTool.from_function(func=func)
-                )
-
-        return out
+    def fnc_to_structool(self, func):
+        if inspect.iscoroutinefunction(func):
+            return StructuredTool.from_function(
+                func=async_to_sync(func),
+                coroutine=func)
+        else:
+            return StructuredTool.from_function(func=func)
+
+    def _build_langchain_tools(self):
+        self.langchain_tools = []
+        for func in self.get_tool_fncs():
+            self.langchain_tools.append(self.fnc_to_structool(func))
+
+        return self.langchain_tools
+
+    def get_list_langchain_tools(self)->List[StructuredTool]:
+        return self.langchain_tools
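Here `get_name()[:-6]` strips the trailing "Config" from the config class name for logging, and `fnc_to_structool` is what turns a plain (possibly async) tool function into a LangChain `StructuredTool`. A standalone sketch of that wrapping pattern follows; the `retrieve` tool and this `async_to_sync` are illustrative stand-ins, not the repo's exact implementations:

```python
# Standalone sketch of the fnc_to_structool pattern with a hypothetical async tool.
import asyncio
import functools
import inspect

from langchain_core.tools.structured import StructuredTool


async def retrieve(query: str) -> str:
    """Look up a drink by name (hypothetical async tool)."""
    await asyncio.sleep(0)
    return f"results for {query}"


def async_to_sync(async_func):
    @functools.wraps(async_func)  # keep name/docstring so StructuredTool can infer them
    def sync_wrapper(*args, **kwargs):
        return asyncio.run(async_func(*args, **kwargs))
    return sync_wrapper


def fnc_to_structool(func):
    if inspect.iscoroutinefunction(func):
        # expose both a sync path (func=) and the native async path (coroutine=)
        return StructuredTool.from_function(func=async_to_sync(func), coroutine=func)
    return StructuredTool.from_function(func=func)


tool = fnc_to_structool(retrieve)
print(tool.invoke({"query": "光予尘"}))  # sync invocation goes through the wrapper
```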
@@ -5,6 +5,7 @@ dependencies = [
     "langchain==1.0",
     "langchain_community",
     "langchain-openai",
+    "langchain_mcp_adapters",
     "httpx[socks]",
     "dashscope",
     "python-dotenv>=1.0.0",
@@ -19,7 +20,8 @@ dependencies = [
     "fastapi",
     "matplotlib",
     "Pillow",
-    "jax"
+    "jax",
+    "commentjson"
 ]
 
 [tool.setuptools.packages.find]
@@ -18,7 +18,8 @@ def main(conf:PipelineConfig):
     # response = pipeline.chat(user_input, as_stream=True)
     # print(f"回答: {response}")
 
-    out = pipeline.chat("用工具算6856854-416846等于多少;然后解释它是怎么算出来的", as_stream=True)
+    # out = pipeline.chat("用工具算6856854-416846等于多少;然后解释它是怎么算出来的", as_stream=True)
+    out = pipeline.chat("介绍一下自己", as_stream=True)
     # out = pipeline.chat("testing", as_stream=True)
     print("=========== final ==========")
     print(out)
scripts/make_eval_dataset.py (new file)
@@ -0,0 +1,31 @@
+from langsmith import Client
+from loguru import logger
+
+
+DATASET_NAME = "dev_langagent"
+
+examples = [
+    {
+        "inputs": {"text": "用retrieve查询光予尘然后介绍"},
+        "outputs": {"answer": "茉莉绿茶为底,清冽茶香中漫出玫珑蜜瓜的绵甜与凤梨的明亮果香,层次鲜活;顶部白柚茉莉泡沫轻盈漫过舌尖,带着微酸的清新感,让整体风味更显灵动",
+                    "tool_use": ["retrieve"]}
+    },
+    {
+        "inputs": {"text": "介绍一下自己"},
+        "outputs": {"answer": "我是小盏,是一个点餐助手"}
+    }
+]
+
+cli = Client()
+
+try:
+    dataset = cli.read_dataset(dataset_name=DATASET_NAME)
+    logger.info("read dataset")
+except:
+    dataset = cli.create_dataset(dataset_name=DATASET_NAME)
+    logger.info("created dataset")
+
+cli.create_examples(
+    dataset_id=dataset.id,
+    examples=examples
+)