return default if not specified
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Type, Literal
|
||||
from typing import Type, Callable, List
|
||||
import tyro
|
||||
|
||||
from lang_agent.config import KeyConfig
|
||||
@@ -71,9 +71,10 @@ class Validator:
|
||||
return pipeline.chat(inp, as_raw=True)
|
||||
|
||||
|
||||
def get_val_fnc(self, dataset_name:str):
|
||||
return self.dict_corr_map[dataset_name]
|
||||
def get_val_fnc(self, dataset_name:str)->List[Callable]:
|
||||
return self.dict_corr_map.get(dataset_name, [self.Toxic_Queries_correct])
|
||||
|
||||
|
||||
def get_inp_fnc(self,dataset_name:str):
|
||||
return self.dict_inp_map[dataset_name]
|
||||
def get_inp_fnc(self,dataset_name:str)->Callable:
|
||||
# return self.dict_inp_map[dataset_name]
|
||||
return self.dict_corr_map.get(dataset_name, self.Toxic_Queries_inp_parse)
|
||||
Reference in New Issue
Block a user