diff --git a/lang_agent/base.py b/lang_agent/base.py index 0b496a5..86fcc52 100644 --- a/lang_agent/base.py +++ b/lang_agent/base.py @@ -139,18 +139,28 @@ class GraphBase(ABC): def _agent_call_template(self, system_prompt:str, model:CompiledStateGraph, - state:State): + state:State, + human_msg:str = None): if state.get("messages") is not None: inp = state["messages"], state["inp"][1] else: inp = state["inp"] - inp = {"messages":[ - SystemMessage( - system_prompt - ), - *state["inp"][0]["messages"][1:] - ]}, state["inp"][1] + if human_msg is None: + inp = {"messages":[ + SystemMessage( + system_prompt + ), + *state["inp"][0]["messages"][1:] + ]}, state["inp"][1] + else: + inp = {"messages":[ + SystemMessage( + system_prompt + ), + *state["inp"][0]["messages"][1:], + HumanMessage(human_msg) + ]}, state["inp"][1] out = model.invoke(*inp)