validate tool use

This commit is contained in:
2025-10-27 16:39:20 +08:00
parent 880a573c42
commit 51ac83401b

View File

@@ -6,6 +6,7 @@ from lang_agent.config import KeyConfig
from lang_agent.pipeline import Pipeline, PipelineConfig
from langchain.chat_models import init_chat_model
from langchain_core.messages import BaseMessage, ToolMessage
@tyro.conf.configure(tyro.conf.SuppressFixed)
@dataclass
@@ -21,12 +22,12 @@ class Validator:
# NOTE: Need to register function here
self.dict_corr_map = {
"Toxic Queries" : [self.Toxic_Queries_correct]
"dev_langagent" : [self.default_correct, self.val_tool_use]
}
# NOTE: Need to register function here
self.dict_inp_map = {
"Toxic Queries" : self.Toxic_Queries_inp_parse
"dev_langagent" : self.default_inp_parse
}
@@ -39,7 +40,7 @@ class Validator:
)
# NOTE: for every dataset; need one of these
def Toxic_Queries_correct(self, inputs: dict, outputs: list, reference_outputs: dict) -> bool:
def default_correct(self, inputs: dict, outputs: dict, reference_outputs: dict) -> bool:
instructions = (
"Given an actual answer and an expected answer, determine whether"
" the actual answer contains all of the information in the"
@@ -48,7 +49,7 @@ class Validator:
" otherwise. Do not include anything else in your response."
)
actual_answer = outputs["output"][-1].content
expected_answer = reference_outputs["label"]
expected_answer = reference_outputs["answer"]
user_msg = (
f"ACTUAL ANSWER: {actual_answer}"
@@ -64,17 +65,42 @@ class Validator:
return response.content.upper() == "CORRECT"
def val_tool_use(self, inputs:dict, outputs:dict, reference_outputs:dict)->bool:
tool_uses:List[str] = reference_outputs.get("tool_use")
if tool_uses is None:
return True
tool_msgs = [e for e in outputs["output"] if isinstance(e, ToolMessage)]
# check if all tools are used
tool_used = []
for ref_tool in tool_uses:
st_cond = False
ref_tool = ref_tool.lower()
for msg in tool_msgs:
st_cond = ref_tool in msg.name.lower()
if st_cond:
break
tool_used.append(st_cond)
for cond in tool_used:
if not cond:
return False
return True
# NOTE: for every dataset; need one of these
def Toxic_Queries_inp_parse(self, inp, pipeline:Pipeline):
def default_inp_parse(self, inp, pipeline:Pipeline):
inp = inp["text"]
return pipeline.chat(inp, as_raw=True)
def get_val_fnc(self, dataset_name:str)->List[Callable]:
return self.dict_corr_map.get(dataset_name, [self.Toxic_Queries_correct])
return self.dict_corr_map.get(dataset_name, [self.default_correct])
def get_inp_fnc(self,dataset_name:str)->Callable:
# return self.dict_inp_map[dataset_name]
return self.dict_corr_map.get(dataset_name, self.Toxic_Queries_inp_parse)
return self.dict_inp_map.get(dataset_name, self.default_inp_parse)