sync returns non-string when streaming

This commit is contained in:
2026-01-30 11:18:52 +08:00
parent 94dcc95881
commit f8791104cb

View File

@@ -37,14 +37,33 @@ class GraphBase(ABC):
def _stream_result(self, *nargs, **kwargs): def _stream_result(self, *nargs, **kwargs):
# def text_iterator():
# for chunk, metadata in self.workflow.stream({"inp": nargs},
# stream_mode="messages",
# subgraphs=True,
# **kwargs):
# if isinstance(metadata, tuple):
# chunk, metadata = metadata
# tags = metadata.get("tags")
# if not (tags in self.streamable_tags):
# continue
# if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
# yield chunk.content
def text_iterator(): def text_iterator():
for chunk, metadata in self.workflow.stream({"inp": nargs}, for _, mode, out in self.workflow.stream({"inp": nargs},
stream_mode="messages", stream_mode=["messages", "values"],
subgraphs=True, subgraphs=True,
**kwargs): **kwargs):
if isinstance(metadata, tuple): if mode == "values":
chunk, metadata = metadata val = out.get("messages")
if val is not None:
yield val
continue
chunk, metadata = out
tags = metadata.get("tags") tags = metadata.get("tags")
if not (tags in self.streamable_tags): if not (tags in self.streamable_tags):
continue continue
@@ -55,7 +74,8 @@ class GraphBase(ABC):
text_releaser = TextReleaser(*self.textreleaser_delay_keys) text_releaser = TextReleaser(*self.textreleaser_delay_keys)
logger.info("streaming output") logger.info("streaming output")
for chunk in text_releaser.release(text_iterator()): for chunk in text_releaser.release(text_iterator()):
print(f"\033[92m{chunk}\033[0m", end="", flush=True) if isinstance(chunk, str):
print(f"\033[92m{chunk}\033[0m", end="", flush=True)
yield chunk yield chunk
# NOTE: DEFAULT IMPLEMENTATION; Overide to support your class # NOTE: DEFAULT IMPLEMENTATION; Overide to support your class
@@ -138,6 +158,10 @@ class GraphBase(ABC):
assert len(kwargs) == 0, "due to inp assumptions" assert len(kwargs) == 0, "due to inp assumptions"
def _get_inp_msgs(self, state:State):
msgs = state["inp"][0]["messages"]
return [e for e in msgs if not isinstance(e, SystemMessage)]
def _agent_call_template(self, system_prompt:str, def _agent_call_template(self, system_prompt:str,
model:CompiledStateGraph, model:CompiledStateGraph,
state:State, state:State,
@@ -149,7 +173,7 @@ class GraphBase(ABC):
messages = [ messages = [
SystemMessage(system_prompt), SystemMessage(system_prompt),
*state["inp"][0]["messages"][1:] *self._get_inp_msgs(state)
] ]
if human_msg is not None: if human_msg is not None: