diff --git a/lang_agent/eval/validator.py b/lang_agent/eval/validator.py index 614ae80..9e570b5 100644 --- a/lang_agent/eval/validator.py +++ b/lang_agent/eval/validator.py @@ -6,6 +6,7 @@ from lang_agent.config import KeyConfig from lang_agent.pipeline import Pipeline, PipelineConfig from langchain.chat_models import init_chat_model +from langchain_core.messages import BaseMessage, ToolMessage @tyro.conf.configure(tyro.conf.SuppressFixed) @dataclass @@ -21,12 +22,12 @@ class Validator: # NOTE: Need to register function here self.dict_corr_map = { - "Toxic Queries" : [self.Toxic_Queries_correct] + "dev_langagent" : [self.default_correct, self.val_tool_use] } # NOTE: Need to register function here self.dict_inp_map = { - "Toxic Queries" : self.Toxic_Queries_inp_parse + "dev_langagent" : self.default_inp_parse } @@ -39,7 +40,7 @@ class Validator: ) # NOTE: for every dataset; need one of these - def Toxic_Queries_correct(self, inputs: dict, outputs: list, reference_outputs: dict) -> bool: + def default_correct(self, inputs: dict, outputs: dict, reference_outputs: dict) -> bool: instructions = ( "Given an actual answer and an expected answer, determine whether" " the actual answer contains all of the information in the" @@ -48,7 +49,7 @@ class Validator: " otherwise. Do not include anything else in your response." ) actual_answer = outputs["output"][-1].content - expected_answer = reference_outputs["label"] + expected_answer = reference_outputs["answer"] user_msg = ( f"ACTUAL ANSWER: {actual_answer}" @@ -64,17 +65,42 @@ class Validator: return response.content.upper() == "CORRECT" + def val_tool_use(self, inputs:dict, outputs:dict, reference_outputs:dict)->bool: + tool_uses:List[str] = reference_outputs.get("tool_use") + if tool_uses is None: + return True + + tool_msgs = [e for e in outputs["output"] if isinstance(e, ToolMessage)] + + # check if all tools are used + tool_used = [] + for ref_tool in tool_uses: + st_cond = False + ref_tool = ref_tool.lower() + for msg in tool_msgs: + st_cond = ref_tool in msg.name.lower() + if st_cond: + break + tool_used.append(st_cond) + + for cond in tool_used: + if not cond: + return False + + return True + + # NOTE: for every dataset; need one of these - def Toxic_Queries_inp_parse(self, inp, pipeline:Pipeline): + def default_inp_parse(self, inp, pipeline:Pipeline): inp = inp["text"] return pipeline.chat(inp, as_raw=True) def get_val_fnc(self, dataset_name:str)->List[Callable]: - return self.dict_corr_map.get(dataset_name, [self.Toxic_Queries_correct]) + return self.dict_corr_map.get(dataset_name, [self.default_correct]) def get_inp_fnc(self,dataset_name:str)->Callable: # return self.dict_inp_map[dataset_name] - return self.dict_corr_map.get(dataset_name, self.Toxic_Queries_inp_parse) \ No newline at end of file + return self.dict_inp_map.get(dataset_name, self.default_inp_parse) \ No newline at end of file