sync returns non-string when streaming
This commit is contained in:
@@ -37,14 +37,33 @@ class GraphBase(ABC):
|
||||
|
||||
def _stream_result(self, *nargs, **kwargs):
|
||||
|
||||
# def text_iterator():
|
||||
# for chunk, metadata in self.workflow.stream({"inp": nargs},
|
||||
# stream_mode="messages",
|
||||
# subgraphs=True,
|
||||
# **kwargs):
|
||||
# if isinstance(metadata, tuple):
|
||||
# chunk, metadata = metadata
|
||||
|
||||
# tags = metadata.get("tags")
|
||||
# if not (tags in self.streamable_tags):
|
||||
# continue
|
||||
|
||||
# if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
|
||||
# yield chunk.content
|
||||
|
||||
def text_iterator():
|
||||
for chunk, metadata in self.workflow.stream({"inp": nargs},
|
||||
stream_mode="messages",
|
||||
for _, mode, out in self.workflow.stream({"inp": nargs},
|
||||
stream_mode=["messages", "values"],
|
||||
subgraphs=True,
|
||||
**kwargs):
|
||||
if isinstance(metadata, tuple):
|
||||
chunk, metadata = metadata
|
||||
if mode == "values":
|
||||
val = out.get("messages")
|
||||
if val is not None:
|
||||
yield val
|
||||
continue
|
||||
|
||||
chunk, metadata = out
|
||||
tags = metadata.get("tags")
|
||||
if not (tags in self.streamable_tags):
|
||||
continue
|
||||
@@ -55,7 +74,8 @@ class GraphBase(ABC):
|
||||
text_releaser = TextReleaser(*self.textreleaser_delay_keys)
|
||||
logger.info("streaming output")
|
||||
for chunk in text_releaser.release(text_iterator()):
|
||||
print(f"\033[92m{chunk}\033[0m", end="", flush=True)
|
||||
if isinstance(chunk, str):
|
||||
print(f"\033[92m{chunk}\033[0m", end="", flush=True)
|
||||
yield chunk
|
||||
|
||||
# NOTE: DEFAULT IMPLEMENTATION; Overide to support your class
|
||||
@@ -138,6 +158,10 @@ class GraphBase(ABC):
|
||||
|
||||
assert len(kwargs) == 0, "due to inp assumptions"
|
||||
|
||||
def _get_inp_msgs(self, state:State):
|
||||
msgs = state["inp"][0]["messages"]
|
||||
return [e for e in msgs if not isinstance(e, SystemMessage)]
|
||||
|
||||
def _agent_call_template(self, system_prompt:str,
|
||||
model:CompiledStateGraph,
|
||||
state:State,
|
||||
@@ -149,7 +173,7 @@ class GraphBase(ABC):
|
||||
|
||||
messages = [
|
||||
SystemMessage(system_prompt),
|
||||
*state["inp"][0]["messages"][1:]
|
||||
*self._get_inp_msgs(state)
|
||||
]
|
||||
|
||||
if human_msg is not None:
|
||||
|
||||
Reference in New Issue
Block a user