validate tool use
This commit is contained in:
@@ -6,6 +6,7 @@ from lang_agent.config import KeyConfig
|
||||
from lang_agent.pipeline import Pipeline, PipelineConfig
|
||||
|
||||
from langchain.chat_models import init_chat_model
|
||||
from langchain_core.messages import BaseMessage, ToolMessage
|
||||
|
||||
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
||||
@dataclass
|
||||
@@ -21,12 +22,12 @@ class Validator:
|
||||
|
||||
# NOTE: Need to register function here
|
||||
self.dict_corr_map = {
|
||||
"Toxic Queries" : [self.Toxic_Queries_correct]
|
||||
"dev_langagent" : [self.default_correct, self.val_tool_use]
|
||||
}
|
||||
|
||||
# NOTE: Need to register function here
|
||||
self.dict_inp_map = {
|
||||
"Toxic Queries" : self.Toxic_Queries_inp_parse
|
||||
"dev_langagent" : self.default_inp_parse
|
||||
}
|
||||
|
||||
|
||||
@@ -39,7 +40,7 @@ class Validator:
|
||||
)
|
||||
|
||||
# NOTE: for every dataset; need one of these
|
||||
def Toxic_Queries_correct(self, inputs: dict, outputs: list, reference_outputs: dict) -> bool:
|
||||
def default_correct(self, inputs: dict, outputs: dict, reference_outputs: dict) -> bool:
|
||||
instructions = (
|
||||
"Given an actual answer and an expected answer, determine whether"
|
||||
" the actual answer contains all of the information in the"
|
||||
@@ -48,7 +49,7 @@ class Validator:
|
||||
" otherwise. Do not include anything else in your response."
|
||||
)
|
||||
actual_answer = outputs["output"][-1].content
|
||||
expected_answer = reference_outputs["label"]
|
||||
expected_answer = reference_outputs["answer"]
|
||||
|
||||
user_msg = (
|
||||
f"ACTUAL ANSWER: {actual_answer}"
|
||||
@@ -64,17 +65,42 @@ class Validator:
|
||||
|
||||
return response.content.upper() == "CORRECT"
|
||||
|
||||
def val_tool_use(self, inputs:dict, outputs:dict, reference_outputs:dict)->bool:
|
||||
tool_uses:List[str] = reference_outputs.get("tool_use")
|
||||
if tool_uses is None:
|
||||
return True
|
||||
|
||||
tool_msgs = [e for e in outputs["output"] if isinstance(e, ToolMessage)]
|
||||
|
||||
# check if all tools are used
|
||||
tool_used = []
|
||||
for ref_tool in tool_uses:
|
||||
st_cond = False
|
||||
ref_tool = ref_tool.lower()
|
||||
for msg in tool_msgs:
|
||||
st_cond = ref_tool in msg.name.lower()
|
||||
if st_cond:
|
||||
break
|
||||
tool_used.append(st_cond)
|
||||
|
||||
for cond in tool_used:
|
||||
if not cond:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
|
||||
# NOTE: for every dataset; need one of these
|
||||
def Toxic_Queries_inp_parse(self, inp, pipeline:Pipeline):
|
||||
def default_inp_parse(self, inp, pipeline:Pipeline):
|
||||
inp = inp["text"]
|
||||
return pipeline.chat(inp, as_raw=True)
|
||||
|
||||
|
||||
def get_val_fnc(self, dataset_name:str)->List[Callable]:
|
||||
return self.dict_corr_map.get(dataset_name, [self.Toxic_Queries_correct])
|
||||
return self.dict_corr_map.get(dataset_name, [self.default_correct])
|
||||
|
||||
|
||||
def get_inp_fnc(self,dataset_name:str)->Callable:
|
||||
# return self.dict_inp_map[dataset_name]
|
||||
return self.dict_corr_map.get(dataset_name, self.Toxic_Queries_inp_parse)
|
||||
return self.dict_inp_map.get(dataset_name, self.default_inp_parse)
|
||||
Reference in New Issue
Block a user