better formatting
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Type, Literal
|
||||
from typing import Type, Literal, List
|
||||
import tyro
|
||||
from loguru import logger
|
||||
import functools
|
||||
@@ -14,6 +14,8 @@ from lang_agent.eval.validator import ValidatorConfig, Validator
|
||||
|
||||
from langsmith import Client
|
||||
|
||||
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage, BaseMessage # NOTE: these are used in 'eval()'
|
||||
|
||||
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
||||
@dataclass
|
||||
class EvaluatorConfig(InstantiateConfig):
|
||||
@@ -76,6 +78,7 @@ class Evaluator:
|
||||
exp_save_f = f"{head_path}-{n_exp}.csv"
|
||||
|
||||
df = self.result.to_pandas()
|
||||
df = self.format_result_df(df)
|
||||
logger.info(f"saving experiment results to: {exp_save_f}")
|
||||
df.to_csv(exp_save_f, index=False)
|
||||
|
||||
@@ -95,4 +98,21 @@ class Evaluator:
|
||||
|
||||
self.config.save_config(f"{head_path}-{n_exp}.yml")
|
||||
|
||||
def format_result_df(self, df: pd.DataFrame) -> pd.DataFrame:
    """Flatten the message lists in ``outputs.output`` into plain columns.

    Each cell of ``df["outputs.output"]`` is treated as a sequence of
    message objects exposing a ``.content`` attribute.

    * ``outputs.output`` is overwritten with the ``.content`` of the last
      message in the sequence.
    * ``tool_out`` is added, holding the ``.content`` of the most recent
      ``ToolMessage`` in the sequence, or ``None`` when there is none.

    Empty (or ``None``) cells produce ``None`` in both columns instead of
    raising.

    Args:
        df: Result frame containing an ``outputs.output`` column. It is
            modified in place.

    Returns:
        The same ``df`` object, with the two columns described above.

    Raises:
        KeyError: If ``df`` has no ``outputs.output`` column.
    """

    def last_content(msgs: "List[BaseMessage]"):
        # An empty/None cell has no final message to read from.
        return msgs[-1].content if msgs else None

    def last_tool_content(msgs: "List[BaseMessage]"):
        # Walk newest-to-oldest without copying the list.
        for msg in reversed(msgs or ()):
            if isinstance(msg, ToolMessage):
                return msg.content
        return None

    # Keep a handle on the raw message lists: the column is overwritten
    # below, but ``tool_out`` must still be derived from the raw messages.
    raw_outputs = df["outputs.output"]
    df["outputs.output"] = raw_outputs.apply(last_content)
    df["tool_out"] = raw_outputs.apply(last_tool_content)

    return df
|
||||
Reference in New Issue
Block a user