from dataclasses import dataclass, field
|
|
from typing import Type, Literal, List
|
|
import tyro
|
|
from loguru import logger
|
|
import functools
|
|
import os
|
|
import os.path as osp
|
|
import glob
|
|
import pandas as pd
|
|
|
|
from lang_agent.config import InstantiateConfig
|
|
from lang_agent.pipeline import Pipeline, PipelineConfig
|
|
from lang_agent.eval.validator import ValidatorConfig, Validator
|
|
|
|
from langsmith import Client
|
|
|
|
from langchain_core.messages import ToolMessage, BaseMessage
|
|
|
|
@tyro.conf.configure(tyro.conf.SuppressFixed)
@dataclass
class EvaluatorConfig(InstantiateConfig):
    """Configuration for :class:`Evaluator`, instantiable from the CLI via tyro.

    Attribute docstrings below double as tyro helptext, so they are kept as
    bare strings rather than ``#`` comments.
    """

    # Lazy factory so the forward reference to Evaluator (defined later in this
    # module) is resolved at call time, not at class-definition time.
    _target: Type = field(default_factory=lambda: Evaluator)

    experiment_prefix: str = "simple test"
    """name of experiment"""

    experiment_desc: str = "testing if this works or not"
    """describe the experiment"""

    # BUG FIX: the default "QA_xiaozhan_sub" was not a member of the original
    # Literal["Toxic Queries"], making the annotation self-contradictory
    # (rejected by tyro and static type checkers). The Literal is widened to
    # include the default; existing callers passing "Toxic Queries" still work.
    dataset_name: Literal["Toxic Queries", "QA_xiaozhan_sub"] = "QA_xiaozhan_sub"
    """name of the dataset to evaluate"""

    log_dir: str = "logs"
    """directory where result CSVs, metrics, and config snapshots are written"""

    pipe_config: PipelineConfig = field(default_factory=PipelineConfig)
    """configuration of the pipeline under evaluation"""

    validator_config: ValidatorConfig = field(default_factory=ValidatorConfig)
    """configuration of the validator providing input/eval functions"""
|
|
|
class Evaluator:
    """Run a LangSmith experiment over a dataset using a configured pipeline
    and validator, then persist per-example results and aggregate metrics.

    Typical usage: construct with an :class:`EvaluatorConfig`, then call
    ``evaluate()`` followed by ``save_results()``.
    """

    def __init__(self, config: EvaluatorConfig):
        self.config = config
        self.populate_modules()

    def populate_modules(self):
        """Instantiate the pipeline, LangSmith client, validator, and dataset handle."""
        logger.info("preparing to run experiment")
        self.pipeline: Pipeline = self.config.pipe_config.setup()
        self.cli = Client()
        self.validator: Validator = self.config.validator_config.setup()
        # NOTE(review): presumably raises if the dataset is missing on the
        # LangSmith server — confirm against langsmith.Client.read_dataset.
        self.dataset = self.cli.read_dataset(dataset_name=self.config.dataset_name)

    def evaluate(self):
        """Run the LangSmith evaluation and store the result in ``self.result``."""
        logger.info("running experiment")

        # The validator supplies a dataset-specific input function; bind the
        # pipeline to it so LangSmith can invoke it per example.
        inp_fnc = self.validator.get_inp_fnc(self.config.dataset_name)
        runnable = functools.partial(inp_fnc, pipeline=self.pipeline)

        self.result = self.cli.evaluate(
            runnable,
            data=self.dataset.name,
            evaluators=self.validator.get_val_fnc(self.config.dataset_name),
            experiment_prefix=self.config.experiment_prefix,
            description=self.config.experiment_desc,
            max_concurrency=4,
            upload_results=False,
        )

    def save_results(self):
        """Write per-example results, aggregated metrics, and the config to disk.

        Produces three files under ``config.log_dir``:
        ``<dataset>-<prefix>-<n>.csv`` (per-example results),
        ``0_exp_metrics.csv`` (one row of mean metrics per experiment, appended),
        and ``<dataset>-<prefix>-<n>.yaml`` (the config snapshot).

        Raises:
            RuntimeError: if ``evaluate()`` has not been run yet.
        """
        os.makedirs(self.config.log_dir, exist_ok=True)

        # BUG FIX: `assert` is stripped under `python -O`; validate explicitly.
        if not hasattr(self, "result"):
            raise RuntimeError("NO RESULTS, run evaluate() before saving results")

        # Experiments are versioned by counting prior files with the same prefix.
        head_path = osp.join(
            self.config.log_dir,
            f"{self.dataset.name}-{self.config.experiment_prefix}",
        )
        n_exp = len(glob.glob(f"{head_path}*"))
        exp_save_f = f"{head_path}-{n_exp}.csv"

        df = self.result.to_pandas()
        df = self.format_result_df(df)
        logger.info(f"saving experiment results to: {exp_save_f}")
        df.to_csv(exp_save_f, index=False)

        # Aggregate the mean of every feedback score column (skipping the
        # textual ".comment" columns) plus execution time.
        metric_col = [
            e for e in df.columns if "feedback" in e and not e.endswith(".comment")
        ] + ["execution_time"]

        df_curr_m = df[metric_col].mean().to_frame().T
        df_curr_m.index = [f"{osp.basename(head_path)}-{n_exp}"]

        metric_f = osp.join(self.config.log_dir, "0_exp_metrics.csv")  # start with 0 for first file in folder
        if osp.exists(metric_f):
            df_m = pd.read_csv(metric_f, index_col=0)
            df_m = pd.concat([df_m, df_curr_m])
        else:
            df_m = df_curr_m

        df_m.to_csv(metric_f)

        self.config.save_config(f"{head_path}-{n_exp}.yaml")

    def _collect_comments(self):
        """Build ``{example_id: {evaluator_key: comment}}`` from ``self.result``.

        Entries with an empty comment or missing example are skipped.
        """
        comments_map = {}
        for res in self.result:
            example_id = str(res.get("example").id) if res.get("example") else ""
            eval_results = res.get("evaluation_results", {}).get("results", [])
            for eval_res in eval_results:
                # getattr-with-default replaces the hasattr/ternary probing.
                key = getattr(eval_res, "key", "")
                comment = getattr(eval_res, "comment", "")
                if comment and example_id:
                    comments_map.setdefault(example_id, {})[key] = comment
        return comments_map

    def format_result_df(self, df: pd.DataFrame):
        """Attach per-evaluator explanation columns (``feedback.<key>.comment``).

        Columns are only added when comments exist and the DataFrame carries an
        ``example_id`` column; otherwise ``df`` is returned unchanged.
        """
        comments_map = self._collect_comments()

        if comments_map and "example_id" in df.columns:
            evaluator_keys = {k for v in comments_map.values() for k in v.keys()}
            for key in evaluator_keys:
                col_name = f"feedback.{key}.comment"
                # Bind `key` as a default argument to avoid the late-binding
                # closure pitfall (each column must use its own key).
                df[col_name] = df["example_id"].apply(
                    lambda eid, key=key: comments_map.get(str(eid), {}).get(key, "")
                )

        return df