return default if not specified

This commit is contained in:
2025-10-27 15:21:12 +08:00
parent b3f35d5319
commit 880a573c42

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Type, Literal from typing import Type, Callable, List
import tyro import tyro
from lang_agent.config import KeyConfig from lang_agent.config import KeyConfig
@@ -71,9 +71,10 @@ class Validator:
return pipeline.chat(inp, as_raw=True) return pipeline.chat(inp, as_raw=True)
def get_val_fnc(self, dataset_name:str): def get_val_fnc(self, dataset_name:str)->List[Callable]:
return self.dict_corr_map[dataset_name] return self.dict_corr_map.get(dataset_name, [self.Toxic_Queries_correct])
def get_inp_fnc(self,dataset_name:str): def get_inp_fnc(self,dataset_name:str)->Callable:
return self.dict_inp_map[dataset_name] # return self.dict_inp_map[dataset_name]
return self.dict_corr_map.get(dataset_name, self.Toxic_Queries_inp_parse)