This commit is contained in:
2025-10-23 20:25:43 +08:00
parent 8823e9039d
commit 4f14017364

View File

@@ -61,14 +61,12 @@ class RoutingGraph(GraphBase):
self.workflow = self._build_graph() self.workflow = self._build_graph()
def invoke(self, *nargs, as_stream:bool=False, **kwargs): def invoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs):
assert len(kwargs) == 0, "due to inp assumptions" self._validate_input(*nargs, **kwargs)
if as_stream: if as_stream:
# TODO this doesn't stream the entire process, we are blind # TODO this doesn't stream the entire process, we are blind
for step in self.workflow.stream({"inp": nargs}, stream_mode="updates", **kwargs): for step in self.workflow.stream({"inp": nargs}, stream_mode="updates", **kwargs):
# if "messages" in step:
# step["messages"]["messages"][-1].pretty_print()
last_el = jax.tree.leaves(step)[-1] last_el = jax.tree.leaves(step)[-1]
if isinstance(last_el, str): if isinstance(last_el, str):
logger.info(last_el) logger.info(last_el)
@@ -79,8 +77,15 @@ class RoutingGraph(GraphBase):
else: else:
state = self.workflow.invoke({"inp": nargs}) state = self.workflow.invoke({"inp": nargs})
# return state["messages"] msg_list = jax.tree.leaves(state)
return jax.tree.leaves(state)[-1].content if as_raw:
return msg_list
return msg_list[-1].content
def _validate_input(self, *nargs, **kwargs):
assert len(nargs[0]["messages"]) >= 2, "need at least 1 system and 1 human message"
assert len(kwargs) == 0, "due to inp assumptions"
def _build_modules(self): def _build_modules(self):
self.llm = init_chat_model(model=self.config.llm_name, self.llm = init_chat_model(model=self.config.llm_name,