from dataclasses import dataclass, field
from typing import Literal, Type
import functools

import tyro
from loguru import logger
from langsmith import Client
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableLambda

from lang_agent.config import InstantiateConfig
from lang_agent.pipeline import Pipeline, PipelineConfig
from lang_agent.eval.validator import Validator, ValidatorConfig


@tyro.conf.configure(tyro.conf.SuppressFixed)
@dataclass
class EvaluatorConfig(InstantiateConfig):
    """Configuration for a LangSmith evaluation run.

    The bare-string docstrings after each field are intentional: tyro reads
    them as CLI help text for the corresponding option.
    """

    _target: Type = field(default_factory=lambda: Evaluator)

    experiment_prefix: str = "simple test"
    """name of experiment"""

    experiment_desc: str = "testing if this works or not"
    """describe the experiment"""

    # FIX: the original annotation was Literal["Toxic Queries"] while the
    # default was "dev_langagent" — the default was not a member of its own
    # Literal, which fails type checking and tyro's choice validation.
    # Adding "dev_langagent" to the Literal keeps the default and remains
    # backward-compatible for callers passing "Toxic Queries".
    dataset_name: Literal["Toxic Queries", "dev_langagent"] = "dev_langagent"
    """name of the dataset to evaluate"""

    pipe_config: PipelineConfig = field(default_factory=PipelineConfig)
    validator_config: ValidatorConfig = field(default_factory=ValidatorConfig)


class Evaluator:
    """Runs a configured pipeline against a LangSmith dataset and scores it.

    Built from an :class:`EvaluatorConfig`; all heavy collaborators
    (pipeline, LangSmith client, validator, dataset handle) are created in
    :meth:`populate_modules` during ``__init__``.
    """

    def __init__(self, config: EvaluatorConfig):
        self.config = config
        self.populate_modules()

    def populate_modules(self) -> None:
        """Instantiate the pipeline, LangSmith client, validator, and dataset.

        NOTE(review): ``Client()`` and ``read_dataset`` perform network I/O —
        construction of an Evaluator requires LangSmith credentials/access.
        """
        logger.info("preparing to run experiment")
        self.pipeline: Pipeline = self.config.pipe_config.setup()
        self.cli = Client()
        self.validator: Validator = self.config.validator_config.setup()
        self.dataset = self.cli.read_dataset(dataset_name=self.config.dataset_name)

    def evaluate(self) -> None:
        """Run the experiment on LangSmith and store the outcome on ``self.result``.

        The validator supplies both the input-mapping function and the
        dataset-specific evaluator functions; the pipeline is bound via
        ``functools.partial`` so LangSmith can invoke the runnable with just
        the example input.
        """
        logger.info("running experiment")
        inp_fnc = self.validator.get_inp_fnc(self.config.dataset_name)
        runnable = functools.partial(inp_fnc, pipeline=self.pipeline)
        self.result = self.cli.evaluate(
            runnable,
            data=self.dataset.name,
            evaluators=self.validator.get_val_fnc(self.config.dataset_name),
            experiment_prefix=self.config.experiment_prefix,
            description=self.config.experiment_desc,
            max_concurrency=4,
        )