better formatting

This commit is contained in:
2026-01-08 19:56:11 +08:00
parent e7bece0be0
commit 6d9f7fe5b9

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Type, Literal from typing import Type, Literal, List
import tyro import tyro
from loguru import logger from loguru import logger
import functools import functools
@@ -14,6 +14,8 @@ from lang_agent.eval.validator import ValidatorConfig, Validator
from langsmith import Client from langsmith import Client
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage, BaseMessage # NOTE: this is used in 'eval()
@tyro.conf.configure(tyro.conf.SuppressFixed) @tyro.conf.configure(tyro.conf.SuppressFixed)
@dataclass @dataclass
class EvaluatorConfig(InstantiateConfig): class EvaluatorConfig(InstantiateConfig):
@@ -76,6 +78,7 @@ class Evaluator:
exp_save_f = f"{head_path}-{n_exp}.csv" exp_save_f = f"{head_path}-{n_exp}.csv"
df = self.result.to_pandas() df = self.result.to_pandas()
df = self.format_result_df(df)
logger.info(f"saving experiment results to: {exp_save_f}") logger.info(f"saving experiment results to: {exp_save_f}")
df.to_csv(exp_save_f, index=False) df.to_csv(exp_save_f, index=False)
@@ -95,4 +98,21 @@ class Evaluator:
self.config.save_config(f"{head_path}-{n_exp}.yml") self.config.save_config(f"{head_path}-{n_exp}.yml")
def format_result_df(self, df:pd.DataFrame):
def map_fnc(out:List[BaseMessage]):
return out[-1].content
def extract_tool_out(out:List[BaseMessage]):
rev = out[::-1]
for msg in rev:
if isinstance(msg, ToolMessage):
return msg.content
return None
outs = df["outputs.output"]
df["outputs.output"] = outs.apply(map_fnc)
df["tool_out"] = outs.apply(extract_tool_out)
return df