get raw
This commit is contained in:
@@ -61,14 +61,12 @@ class RoutingGraph(GraphBase):
|
|||||||
self.workflow = self._build_graph()
|
self.workflow = self._build_graph()
|
||||||
|
|
||||||
|
|
||||||
def invoke(self, *nargs, as_stream:bool=False, **kwargs):
|
def invoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs):
|
||||||
assert len(kwargs) == 0, "due to inp assumptions"
|
self._validate_input(*nargs, **kwargs)
|
||||||
|
|
||||||
if as_stream:
|
if as_stream:
|
||||||
# TODO: this doesn't stream the entire process, we are blind
|
# TODO: this doesn't stream the entire process, we are blind
|
||||||
for step in self.workflow.stream({"inp": nargs}, stream_mode="updates", **kwargs):
|
for step in self.workflow.stream({"inp": nargs}, stream_mode="updates", **kwargs):
|
||||||
# if "messages" in step:
|
|
||||||
# step["messages"]["messages"][-1].pretty_print()
|
|
||||||
last_el = jax.tree.leaves(step)[-1]
|
last_el = jax.tree.leaves(step)[-1]
|
||||||
if isinstance(last_el, str):
|
if isinstance(last_el, str):
|
||||||
logger.info(last_el)
|
logger.info(last_el)
|
||||||
@@ -79,8 +77,15 @@ class RoutingGraph(GraphBase):
|
|||||||
else:
|
else:
|
||||||
state = self.workflow.invoke({"inp": nargs})
|
state = self.workflow.invoke({"inp": nargs})
|
||||||
|
|
||||||
# return state["messages"]
|
msg_list = jax.tree.leaves(state)
|
||||||
return jax.tree.leaves(state)[-1].content
|
if as_raw:
|
||||||
|
return msg_list
|
||||||
|
|
||||||
|
return msg_list[-1].content
|
||||||
|
|
||||||
|
def _validate_input(self, *nargs, **kwargs):
|
||||||
|
assert len(nargs[0]["messages"]) >= 2, "need at least 1 system and 1 human message"
|
||||||
|
assert len(kwargs) == 0, "due to inp assumptions"
|
||||||
|
|
||||||
def _build_modules(self):
|
def _build_modules(self):
|
||||||
self.llm = init_chat_model(model=self.config.llm_name,
|
self.llm = init_chat_model(model=self.config.llm_name,
|
||||||
|
|||||||
Reference in New Issue
Block a user