validate tool use
This commit is contained in:
@@ -6,6 +6,7 @@ from lang_agent.config import KeyConfig
|
|||||||
from lang_agent.pipeline import Pipeline, PipelineConfig
|
from lang_agent.pipeline import Pipeline, PipelineConfig
|
||||||
|
|
||||||
from langchain.chat_models import init_chat_model
|
from langchain.chat_models import init_chat_model
|
||||||
|
from langchain_core.messages import BaseMessage, ToolMessage
|
||||||
|
|
||||||
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -21,12 +22,12 @@ class Validator:
|
|||||||
|
|
||||||
# NOTE: Need to register function here
|
# NOTE: Need to register function here
|
||||||
self.dict_corr_map = {
|
self.dict_corr_map = {
|
||||||
"Toxic Queries" : [self.Toxic_Queries_correct]
|
"dev_langagent" : [self.default_correct, self.val_tool_use]
|
||||||
}
|
}
|
||||||
|
|
||||||
# NOTE: Need to register function here
|
# NOTE: Need to register function here
|
||||||
self.dict_inp_map = {
|
self.dict_inp_map = {
|
||||||
"Toxic Queries" : self.Toxic_Queries_inp_parse
|
"dev_langagent" : self.default_inp_parse
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -39,7 +40,7 @@ class Validator:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# NOTE: for every dataset; need one of these
|
# NOTE: for every dataset; need one of these
|
||||||
def Toxic_Queries_correct(self, inputs: dict, outputs: list, reference_outputs: dict) -> bool:
|
def default_correct(self, inputs: dict, outputs: dict, reference_outputs: dict) -> bool:
|
||||||
instructions = (
|
instructions = (
|
||||||
"Given an actual answer and an expected answer, determine whether"
|
"Given an actual answer and an expected answer, determine whether"
|
||||||
" the actual answer contains all of the information in the"
|
" the actual answer contains all of the information in the"
|
||||||
@@ -48,7 +49,7 @@ class Validator:
|
|||||||
" otherwise. Do not include anything else in your response."
|
" otherwise. Do not include anything else in your response."
|
||||||
)
|
)
|
||||||
actual_answer = outputs["output"][-1].content
|
actual_answer = outputs["output"][-1].content
|
||||||
expected_answer = reference_outputs["label"]
|
expected_answer = reference_outputs["answer"]
|
||||||
|
|
||||||
user_msg = (
|
user_msg = (
|
||||||
f"ACTUAL ANSWER: {actual_answer}"
|
f"ACTUAL ANSWER: {actual_answer}"
|
||||||
@@ -64,17 +65,42 @@ class Validator:
|
|||||||
|
|
||||||
return response.content.upper() == "CORRECT"
|
return response.content.upper() == "CORRECT"
|
||||||
|
|
||||||
|
def val_tool_use(self, inputs:dict, outputs:dict, reference_outputs:dict)->bool:
|
||||||
|
tool_uses:List[str] = reference_outputs.get("tool_use")
|
||||||
|
if tool_uses is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
tool_msgs = [e for e in outputs["output"] if isinstance(e, ToolMessage)]
|
||||||
|
|
||||||
|
# check if all tools are used
|
||||||
|
tool_used = []
|
||||||
|
for ref_tool in tool_uses:
|
||||||
|
st_cond = False
|
||||||
|
ref_tool = ref_tool.lower()
|
||||||
|
for msg in tool_msgs:
|
||||||
|
st_cond = ref_tool in msg.name.lower()
|
||||||
|
if st_cond:
|
||||||
|
break
|
||||||
|
tool_used.append(st_cond)
|
||||||
|
|
||||||
|
for cond in tool_used:
|
||||||
|
if not cond:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# NOTE: for every dataset; need one of these
|
# NOTE: for every dataset; need one of these
|
||||||
def Toxic_Queries_inp_parse(self, inp, pipeline:Pipeline):
|
def default_inp_parse(self, inp, pipeline:Pipeline):
|
||||||
inp = inp["text"]
|
inp = inp["text"]
|
||||||
return pipeline.chat(inp, as_raw=True)
|
return pipeline.chat(inp, as_raw=True)
|
||||||
|
|
||||||
|
|
||||||
def get_val_fnc(self, dataset_name:str)->List[Callable]:
|
def get_val_fnc(self, dataset_name:str)->List[Callable]:
|
||||||
return self.dict_corr_map.get(dataset_name, [self.Toxic_Queries_correct])
|
return self.dict_corr_map.get(dataset_name, [self.default_correct])
|
||||||
|
|
||||||
|
|
||||||
def get_inp_fnc(self,dataset_name:str)->Callable:
|
def get_inp_fnc(self,dataset_name:str)->Callable:
|
||||||
# return self.dict_inp_map[dataset_name]
|
# return self.dict_inp_map[dataset_name]
|
||||||
return self.dict_corr_map.get(dataset_name, self.Toxic_Queries_inp_parse)
|
return self.dict_inp_map.get(dataset_name, self.default_inp_parse)
|
||||||
Reference in New Issue
Block a user