diff --git a/lang_agent/eval/evaluator.py b/lang_agent/eval/evaluator.py index 9994d46..4e0d9f7 100644 --- a/lang_agent/eval/evaluator.py +++ b/lang_agent/eval/evaluator.py @@ -2,6 +2,7 @@ from dataclasses import dataclass, field from typing import Type, Literal import tyro from loguru import logger +import functools from lang_agent.config import InstantiateConfig from lang_agent.pipeline import Pipeline, PipelineConfig @@ -9,6 +10,9 @@ from lang_agent.eval.validator import ValidatorConfig, Validator from langsmith import Client +from langchain_core.messages import HumanMessage +from langchain_core.runnables import RunnableLambda + @tyro.conf.configure(tyro.conf.SuppressFixed) @dataclass class EvaluatorConfig(InstantiateConfig): @@ -44,12 +48,18 @@ class Evaluator: def evaluate(self): logger.info("running experiment") + + + inp_fnc = self.validator.get_inp_fnc(self.config.dataset_name) + runnable = functools.partial(inp_fnc, pipeline=self.pipeline) + self.result = self.cli.evaluate( - self.pipeline.chat, + runnable, data=self.dataset.name, evaluators=[self.validator.get_val_fnc(self.config.dataset_name)], experiment_prefix=self.config.experiment_prefix, - description=self.config.experiment_desc + description=self.config.experiment_desc, + max_concurrency=4 )