from dataclasses import dataclass, field from typing import Type, Literal, List import tyro from loguru import logger import functools import os import os.path as osp import glob import pandas as pd from lang_agent.config import InstantiateConfig from lang_agent.pipeline import Pipeline, PipelineConfig from lang_agent.eval.validator import ValidatorConfig, Validator from langsmith import Client from langchain_core.messages import ToolMessage, BaseMessage @tyro.conf.configure(tyro.conf.SuppressFixed) @dataclass class EvaluatorConfig(InstantiateConfig): _target: Type = field(default_factory=lambda:Evaluator) experiment_prefix:str = "simple test" """name of experiment""" experiment_desc:str = "testing if this works or not" """describe the experiment""" dataset_name:Literal["Toxic Queries"] = "QA_xiaozhan_sub" """name of the dataset to evaluate""" log_dir:str = "logs" pipe_config: PipelineConfig = field(default_factory=PipelineConfig) validator_config: ValidatorConfig = field(default_factory=ValidatorConfig) class Evaluator: def __init__(self, config: EvaluatorConfig): self.config = config self.populate_modules() def populate_modules(self): logger.info("preparing to run experiment") self.pipeline:Pipeline = self.config.pipe_config.setup() self.cli = Client() self.validator:Validator = self.config.validator_config.setup() self.dataset = self.cli.read_dataset(dataset_name=self.config.dataset_name) def evaluate(self): logger.info("running experiment") inp_fnc = self.validator.get_inp_fnc(self.config.dataset_name) runnable = functools.partial(inp_fnc, pipeline=self.pipeline) self.result = self.cli.evaluate( runnable, data=self.dataset.name, evaluators=self.validator.get_val_fnc(self.config.dataset_name), experiment_prefix=self.config.experiment_prefix, description=self.config.experiment_desc, max_concurrency=4, upload_results=False ) def save_results(self): os.makedirs(self.config.log_dir, exist_ok=True) assert hasattr(self, "result"), "NO RESULTS, run evaluate() before saving results" head_path = osp.join(self.config.log_dir, f"{self.dataset.name}-{self.config.experiment_prefix}") n_exp = len(glob.glob(f"{head_path}*")) exp_save_f = f"{head_path}-{n_exp}.csv" df = self.result.to_pandas() df = self.format_result_df(df) logger.info(f"saving experiment results to: {exp_save_f}") df.to_csv(exp_save_f, index=False) metric_col = [e for e in df.columns if "feedback" in e and not e.endswith(".comment")] + ["execution_time"] df_curr_m = df[metric_col].mean().to_frame().T df_curr_m.index = [f'{osp.basename(head_path)}-{n_exp}'] metric_f = osp.join(self.config.log_dir, "0_exp_metrics.csv") # start with 0 for first file in folder if osp.exists(metric_f): df_m = pd.read_csv(metric_f, index_col=0) df_m = pd.concat([df_m, df_curr_m]) else: df_m = df_curr_m df_m.to_csv(metric_f) self.config.save_config(f"{head_path}-{n_exp}.yaml") def format_result_df(self, df:pd.DataFrame): def map_fnc(out:List[BaseMessage]): return out[-1].content def extract_tool_out(out:List[BaseMessage]): rev = out[::-1] for msg in rev: if isinstance(msg, ToolMessage): return msg.content return None # outs = df["outputs.output"] # df["outputs.output"] = outs.apply(map_fnc) # df["tool_out"] = outs.apply(extract_tool_out) # Extract comments from raw results if enabled # if self.config.save_explanation: comments_map = {} # example_id -> {evaluator_key: comment} for res in self.result: example_id = str(res.get("example").id) if res.get("example") else "" eval_results = res.get("evaluation_results", {}).get("results", []) for eval_res in eval_results: key = eval_res.key if hasattr(eval_res, "key") else "" comment = eval_res.comment if hasattr(eval_res, "comment") else "" if comment and example_id: if example_id not in comments_map: comments_map[example_id] = {} comments_map[example_id][key] = comment # Add comment columns to DataFrame if comments_map and "example_id" in df.columns: evaluator_keys = set(k for v in comments_map.values() for k in v.keys()) for key in evaluator_keys: col_name = f"feedback.{key}.comment" df[col_name] = df["example_id"].apply( lambda eid: comments_map.get(str(eid), {}).get(key, "") ) return df