better formatting

This commit is contained in:
2026-01-08 19:56:11 +08:00
parent e7bece0be0
commit 6d9f7fe5b9

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Type, Literal
from typing import Type, Literal, List
import tyro
from loguru import logger
import functools
@@ -14,6 +14,8 @@ from lang_agent.eval.validator import ValidatorConfig, Validator
from langsmith import Client
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage, BaseMessage # NOTE: this is used in 'eval()
@tyro.conf.configure(tyro.conf.SuppressFixed)
@dataclass
class EvaluatorConfig(InstantiateConfig):
@@ -76,6 +78,7 @@ class Evaluator:
exp_save_f = f"{head_path}-{n_exp}.csv"
df = self.result.to_pandas()
df = self.format_result_df(df)
logger.info(f"saving experiment results to: {exp_save_f}")
df.to_csv(exp_save_f, index=False)
@@ -95,4 +98,21 @@ class Evaluator:
self.config.save_config(f"{head_path}-{n_exp}.yml")
def format_result_df(self, df:pd.DataFrame) -> pd.DataFrame:
    """Flatten the message-list column of a results frame for saving.

    Each cell of ``df["outputs.output"]`` is expected to hold a list of
    messages. The cell is replaced by the ``content`` of the last message,
    and a new ``tool_out`` column receives the ``content`` of the most
    recent ``ToolMessage`` in that list (``None`` when there is none).

    Cells that are not a list/tuple (for example a missing value) are left
    unchanged in ``outputs.output`` and produce ``None`` in ``tool_out``.
    An empty list produces ``None`` in both columns. When the column is
    absent the frame is returned untouched.

    Args:
        df: Results frame; it is modified in place.

    Returns:
        The same ``df`` object, for convenient reassignment by the caller.
    """
    column = "outputs.output"
    if column not in df.columns:
        # Nothing to flatten; return as-is instead of raising KeyError so
        # the caller can still write out whatever columns it does have.
        return df

    def final_content(out):
        # Non-sequence cells (NaN/None) cannot be indexed; pass them through
        # so missing values stay missing in the saved file.
        if not isinstance(out, (list, tuple)):
            return out
        return out[-1].content if out else None

    def latest_tool_content(out):
        if not isinstance(out, (list, tuple)):
            return None
        # Walk from the end so the most recent tool result wins.
        return next(
            (msg.content for msg in reversed(out) if isinstance(msg, ToolMessage)),
            None,
        )

    # Capture the raw message lists first: the column is overwritten below
    # and ``tool_out`` must be derived from the original lists.
    raw_outputs = df[column]
    df[column] = raw_outputs.apply(final_content)
    df["tool_out"] = raw_outputs.apply(latest_tool_content)
    return df