# --- Review notes (patch preamble; `git apply` / `git am` ignore free text
#     before the `diff --git` header) ---------------------------------------
# Purpose of this patch, as shown by the hunks below:
#   (a) adds an `as_raw` keyword to `RoutingGraph.invoke` so callers can get
#       the raw `jax.tree.leaves(state)` list instead of only the last
#       leaf's `.content`;
#   (b) moves the input checks out of `invoke` into a new `_validate_input`
#       helper and extends them with a message-count check;
#   (c) deletes previously commented-out debug code in the streaming loop.
# NOTE(review): `_validate_input` relies on bare `assert` statements — these
# are stripped when Python runs with `-O`, so both checks silently disappear
# in optimized runs; raising TypeError/ValueError would be safer.
# NOTE(review): `nargs[0]["messages"]` raises IndexError when `invoke()` is
# called with no positional args, and KeyError when the first arg lacks a
# "messages" key — either failure fires before the intended assert message
# can explain the problem.
# NOTE(review): the assert only checks `len(...) >= 2`; it does not verify
# that the messages are actually one system + one human message, as its
# message text claims.
# NOTE(review): `msg_list[-1].content` presumably assumes the last tree leaf
# is a message object — TODO confirm against the workflow's state schema.
# CAUTION: this patch was flattened onto a single line somewhere in transit;
# hunk line counts (e.g. `@@ -61,14 +61,12 @@`) cannot be re-verified without
# the original newlines, so the patch text below is left byte-identical.
diff --git a/lang_agent/graphs/routing.py b/lang_agent/graphs/routing.py index 93f4e11..e303879 100644 --- a/lang_agent/graphs/routing.py +++ b/lang_agent/graphs/routing.py @@ -61,14 +61,12 @@ class RoutingGraph(GraphBase): self.workflow = self._build_graph() - def invoke(self, *nargs, as_stream:bool=False, **kwargs): - assert len(kwargs) == 0, "due to inp assumptions" + def invoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs): + self._validate_input(*nargs, **kwargs) if as_stream: # TODO: this doesn't stream the entire process, we are blind for step in self.workflow.stream({"inp": nargs}, stream_mode="updates", **kwargs): - # if "messages" in step: - # step["messages"]["messages"][-1].pretty_print() last_el = jax.tree.leaves(step)[-1] if isinstance(last_el, str): logger.info(last_el) @@ -78,9 +76,16 @@ class RoutingGraph(GraphBase): state = step else: state = self.workflow.invoke({"inp": nargs}) - - # return state["messages"] - return jax.tree.leaves(state)[-1].content + + msg_list = jax.tree.leaves(state) + if as_raw: + return msg_list + + return msg_list[-1].content + + def _validate_input(self, *nargs, **kwargs): + assert len(nargs[0]["messages"]) >= 2, "need at least 1 system and 1 human message" + assert len(kwargs) == 0, "due to inp assumptions" def _build_modules(self): self.llm = init_chat_model(model=self.config.llm_name,