update base async to also return
This commit is contained in:
@@ -37,21 +37,6 @@ class GraphBase(ABC):
|
|||||||
|
|
||||||
def _stream_result(self, *nargs, **kwargs):
|
def _stream_result(self, *nargs, **kwargs):
|
||||||
|
|
||||||
# def text_iterator():
|
|
||||||
# for chunk, metadata in self.workflow.stream({"inp": nargs},
|
|
||||||
# stream_mode="messages",
|
|
||||||
# subgraphs=True,
|
|
||||||
# **kwargs):
|
|
||||||
# if isinstance(metadata, tuple):
|
|
||||||
# chunk, metadata = metadata
|
|
||||||
|
|
||||||
# tags = metadata.get("tags")
|
|
||||||
# if not (tags in self.streamable_tags):
|
|
||||||
# continue
|
|
||||||
|
|
||||||
# if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
|
|
||||||
# yield chunk.content
|
|
||||||
|
|
||||||
def text_iterator():
|
def text_iterator():
|
||||||
for _, mode, out in self.workflow.stream({"inp": nargs},
|
for _, mode, out in self.workflow.stream({"inp": nargs},
|
||||||
stream_mode=["messages", "values"],
|
stream_mode=["messages", "values"],
|
||||||
@@ -127,15 +112,17 @@ class GraphBase(ABC):
|
|||||||
"""Async streaming using LangGraph's astream method."""
|
"""Async streaming using LangGraph's astream method."""
|
||||||
|
|
||||||
async def text_iterator():
|
async def text_iterator():
|
||||||
async for chunk, metadata in self.workflow.astream(
|
async for _, mode, out in self.workflow.astream({"inp": nargs},
|
||||||
{"inp": nargs},
|
stream_mode=["messages", "values"],
|
||||||
stream_mode="messages",
|
|
||||||
subgraphs=True,
|
subgraphs=True,
|
||||||
**kwargs
|
**kwargs):
|
||||||
):
|
if mode == "values":
|
||||||
if isinstance(metadata, tuple):
|
val = out.get("messages")
|
||||||
chunk, metadata = metadata
|
if val is not None:
|
||||||
|
yield val
|
||||||
|
continue
|
||||||
|
|
||||||
|
chunk, metadata = out
|
||||||
tags = metadata.get("tags")
|
tags = metadata.get("tags")
|
||||||
if not (tags in self.streamable_tags):
|
if not (tags in self.streamable_tags):
|
||||||
continue
|
continue
|
||||||
@@ -144,7 +131,9 @@ class GraphBase(ABC):
|
|||||||
yield chunk.content
|
yield chunk.content
|
||||||
|
|
||||||
text_releaser = AsyncTextReleaser(*self.textreleaser_delay_keys)
|
text_releaser = AsyncTextReleaser(*self.textreleaser_delay_keys)
|
||||||
|
logger.info("streaming output")
|
||||||
async for chunk in text_releaser.release(text_iterator()):
|
async for chunk in text_releaser.release(text_iterator()):
|
||||||
|
if isinstance(chunk, str):
|
||||||
print(f"\033[92m{chunk}\033[0m", end="", flush=True)
|
print(f"\033[92m{chunk}\033[0m", end="", flush=True)
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user