diff --git a/lang_agent/eval/validator.py b/lang_agent/eval/validator.py index 8d608c2..95905d4 100644 --- a/lang_agent/eval/validator.py +++ b/lang_agent/eval/validator.py @@ -40,19 +40,21 @@ class Validator: api_key=self.config.api_key ) - def default_correct(self, inputs: dict, outputs: dict, reference_outputs: dict) -> bool: + def default_correct(self, inputs: dict, outputs: dict, reference_outputs: dict) -> dict: instructions = ( "Given an actual answer and an expected answer, determine whether" " the actual answer contains all of the information in the" - " expected answer. Respond with 'CORRECT' if the actual answer" - " does contain all of the expected information and 'INCORRECT'" - " otherwise. Do not include anything else in your response." + " expected answer. First provide your reasoning, then respond with" + " your final judgment.\n\n" + "Format your response EXACTLY as follows:\n" + "EXPLANATION: \n" + "JUDGMENT: " ) actual_answer = outputs["output"][-1].content expected_answer = reference_outputs["answer"] if expected_answer is None: - return True + return {"score": True, "comment": "No expected answer provided, auto-pass."} user_msg = ( f"ACTUAL ANSWER: {actual_answer}" @@ -66,7 +68,24 @@ class Validator: ] ) - return response.content.upper() == "CORRECT" + response_text = response.content + + # Parse the explanation and judgment from the response + explanation = "" + is_correct = False + + if "EXPLANATION:" in response_text: + parts = response_text.split("JUDGMENT:") + explanation = parts[0].replace("EXPLANATION:", "").strip() + if len(parts) > 1: + judgment = parts[1].strip().upper() + is_correct = "CORRECT" in judgment and "INCORRECT" not in judgment + else: + # Fallback: check if response contains CORRECT/INCORRECT + explanation = response_text + is_correct = "CORRECT" in response_text.upper() and "INCORRECT" not in response_text.upper() + + return {"score": is_correct, "comment": explanation} def val_tool_use(self, inputs:dict, outputs:dict, reference_outputs:dict)->float: tool_uses:List[str] = reference_outputs.get("tool_use")