def format_result_df(self, df: pd.DataFrame) -> pd.DataFrame:
    """Flatten the ``outputs.output`` message lists into plain columns.

    Every cell of ``df["outputs.output"]`` that holds a list (or tuple) of
    messages is replaced by the ``content`` of its last message. A new
    ``tool_out`` column receives the ``content`` of the most recent
    ``ToolMessage`` found in that same list.

    Cells that cannot be flattened are handled without raising:

    * an empty list yields ``None`` in both columns;
    * a value that is not a list/tuple (for example a missing value) is
      kept unchanged in ``outputs.output`` and yields ``None`` in
      ``tool_out``.

    Args:
        df: Result table containing an ``outputs.output`` column. The frame
            is modified in place.

    Returns:
        The same ``DataFrame`` object that was passed in, with
        ``outputs.output`` rewritten and ``tool_out`` added.

    Raises:
        KeyError: If ``df`` has no ``outputs.output`` column.
    """

    def is_message_list(value: object) -> bool:
        # Only real sequences can be indexed/reversed safely; NaN, None and
        # plain strings must not be treated as message lists.
        return isinstance(value, (list, tuple))

    def final_content(value: object) -> object:
        if not is_message_list(value):
            return value
        return value[-1].content if value else None

    def last_tool_content(value: object) -> object:
        if not is_message_list(value):
            return None
        # Walk backwards without copying so the latest tool result wins.
        for msg in reversed(value):
            if isinstance(msg, ToolMessage):
                return msg.content
        return None

    # Keep a handle on the raw message lists: the column is overwritten
    # below, but ``tool_out`` must still be derived from the original cells.
    raw_outputs = df["outputs.output"]
    df["outputs.output"] = raw_outputs.apply(final_content)
    df["tool_out"] = raw_outputs.apply(last_tool_content)

    return df