diff --git a/lang_agent/eval/validator.py b/lang_agent/eval/validator.py index a068913..e23f3a8 100644 --- a/lang_agent/eval/validator.py +++ b/lang_agent/eval/validator.py @@ -39,7 +39,6 @@ class Validator: api_key=self.config.api_key ) - # NOTE: for every dataset; need one of these def default_correct(self, inputs: dict, outputs: dict, reference_outputs: dict) -> bool: instructions = ( "Given an actual answer and an expected answer, determine whether" @@ -72,7 +71,6 @@ class Validator: tool_msgs = [e for e in outputs["output"] if isinstance(e, ToolMessage)] - # check if all tools are used tool_used = [] for ref_tool in tool_uses: st_cond = False @@ -85,9 +83,7 @@ class Validator: return sum(tool_used)/len(tool_uses) - - - # NOTE: for every dataset; need one of these + def default_inp_parse(self, inp, pipeline:Pipeline): inp = inp["text"]