validate tool use

This commit is contained in:
2025-10-27 16:39:20 +08:00
parent 880a573c42
commit 51ac83401b

View File

@@ -6,6 +6,7 @@ from lang_agent.config import KeyConfig
from lang_agent.pipeline import Pipeline, PipelineConfig from lang_agent.pipeline import Pipeline, PipelineConfig
from langchain.chat_models import init_chat_model from langchain.chat_models import init_chat_model
from langchain_core.messages import BaseMessage, ToolMessage
@tyro.conf.configure(tyro.conf.SuppressFixed) @tyro.conf.configure(tyro.conf.SuppressFixed)
@dataclass @dataclass
@@ -21,12 +22,12 @@ class Validator:
# NOTE: Need to register function here # NOTE: Need to register function here
self.dict_corr_map = { self.dict_corr_map = {
"Toxic Queries" : [self.Toxic_Queries_correct] "dev_langagent" : [self.default_correct, self.val_tool_use]
} }
# NOTE: Need to register function here # NOTE: Need to register function here
self.dict_inp_map = { self.dict_inp_map = {
"Toxic Queries" : self.Toxic_Queries_inp_parse "dev_langagent" : self.default_inp_parse
} }
@@ -39,7 +40,7 @@ class Validator:
) )
# NOTE: for every dataset; need one of these # NOTE: for every dataset; need one of these
def Toxic_Queries_correct(self, inputs: dict, outputs: list, reference_outputs: dict) -> bool: def default_correct(self, inputs: dict, outputs: dict, reference_outputs: dict) -> bool:
instructions = ( instructions = (
"Given an actual answer and an expected answer, determine whether" "Given an actual answer and an expected answer, determine whether"
" the actual answer contains all of the information in the" " the actual answer contains all of the information in the"
@@ -48,7 +49,7 @@ class Validator:
" otherwise. Do not include anything else in your response." " otherwise. Do not include anything else in your response."
) )
actual_answer = outputs["output"][-1].content actual_answer = outputs["output"][-1].content
expected_answer = reference_outputs["label"] expected_answer = reference_outputs["answer"]
user_msg = ( user_msg = (
f"ACTUAL ANSWER: {actual_answer}" f"ACTUAL ANSWER: {actual_answer}"
@@ -64,17 +65,42 @@ class Validator:
return response.content.upper() == "CORRECT" return response.content.upper() == "CORRECT"
def val_tool_use(self, inputs:dict, outputs:dict, reference_outputs:dict)->bool:
tool_uses:List[str] = reference_outputs.get("tool_use")
if tool_uses is None:
return True
tool_msgs = [e for e in outputs["output"] if isinstance(e, ToolMessage)]
# check if all tools are used
tool_used = []
for ref_tool in tool_uses:
st_cond = False
ref_tool = ref_tool.lower()
for msg in tool_msgs:
st_cond = ref_tool in msg.name.lower()
if st_cond:
break
tool_used.append(st_cond)
for cond in tool_used:
if not cond:
return False
return True
# NOTE: for every dataset; need one of these # NOTE: for every dataset; need one of these
def Toxic_Queries_inp_parse(self, inp, pipeline:Pipeline): def default_inp_parse(self, inp, pipeline:Pipeline):
inp = inp["text"] inp = inp["text"]
return pipeline.chat(inp, as_raw=True) return pipeline.chat(inp, as_raw=True)
def get_val_fnc(self, dataset_name:str)->List[Callable]: def get_val_fnc(self, dataset_name:str)->List[Callable]:
return self.dict_corr_map.get(dataset_name, [self.Toxic_Queries_correct]) return self.dict_corr_map.get(dataset_name, [self.default_correct])
def get_inp_fnc(self,dataset_name:str)->Callable: def get_inp_fnc(self,dataset_name:str)->Callable:
# return self.dict_inp_map[dataset_name] # return self.dict_inp_map[dataset_name]
return self.dict_corr_map.get(dataset_name, self.Toxic_Queries_inp_parse) return self.dict_inp_map.get(dataset_name, self.default_inp_parse)