return default if not specified
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Type, Literal
|
from typing import Type, Callable, List
|
||||||
import tyro
|
import tyro
|
||||||
|
|
||||||
from lang_agent.config import KeyConfig
|
from lang_agent.config import KeyConfig
|
||||||
@@ -71,9 +71,10 @@ class Validator:
|
|||||||
return pipeline.chat(inp, as_raw=True)
|
return pipeline.chat(inp, as_raw=True)
|
||||||
|
|
||||||
|
|
||||||
def get_val_fnc(self, dataset_name:str):
|
def get_val_fnc(self, dataset_name:str)->List[Callable]:
|
||||||
return self.dict_corr_map[dataset_name]
|
return self.dict_corr_map.get(dataset_name, [self.Toxic_Queries_correct])
|
||||||
|
|
||||||
|
|
||||||
def get_inp_fnc(self,dataset_name:str):
|
def get_inp_fnc(self,dataset_name:str)->Callable:
|
||||||
return self.dict_inp_map[dataset_name]
|
# return self.dict_inp_map[dataset_name]
|
||||||
|
return self.dict_corr_map.get(dataset_name, self.Toxic_Queries_inp_parse)
|
||||||
Reference in New Issue
Block a user