sync returns non-string when streaming
This commit is contained in:
@@ -37,14 +37,33 @@ class GraphBase(ABC):
|
|||||||
|
|
||||||
def _stream_result(self, *nargs, **kwargs):
|
def _stream_result(self, *nargs, **kwargs):
|
||||||
|
|
||||||
|
# def text_iterator():
|
||||||
|
# for chunk, metadata in self.workflow.stream({"inp": nargs},
|
||||||
|
# stream_mode="messages",
|
||||||
|
# subgraphs=True,
|
||||||
|
# **kwargs):
|
||||||
|
# if isinstance(metadata, tuple):
|
||||||
|
# chunk, metadata = metadata
|
||||||
|
|
||||||
|
# tags = metadata.get("tags")
|
||||||
|
# if not (tags in self.streamable_tags):
|
||||||
|
# continue
|
||||||
|
|
||||||
|
# if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
|
||||||
|
# yield chunk.content
|
||||||
|
|
||||||
def text_iterator():
|
def text_iterator():
|
||||||
for chunk, metadata in self.workflow.stream({"inp": nargs},
|
for _, mode, out in self.workflow.stream({"inp": nargs},
|
||||||
stream_mode="messages",
|
stream_mode=["messages", "values"],
|
||||||
subgraphs=True,
|
subgraphs=True,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
if isinstance(metadata, tuple):
|
if mode == "values":
|
||||||
chunk, metadata = metadata
|
val = out.get("messages")
|
||||||
|
if val is not None:
|
||||||
|
yield val
|
||||||
|
continue
|
||||||
|
|
||||||
|
chunk, metadata = out
|
||||||
tags = metadata.get("tags")
|
tags = metadata.get("tags")
|
||||||
if not (tags in self.streamable_tags):
|
if not (tags in self.streamable_tags):
|
||||||
continue
|
continue
|
||||||
@@ -55,6 +74,7 @@ class GraphBase(ABC):
|
|||||||
text_releaser = TextReleaser(*self.textreleaser_delay_keys)
|
text_releaser = TextReleaser(*self.textreleaser_delay_keys)
|
||||||
logger.info("streaming output")
|
logger.info("streaming output")
|
||||||
for chunk in text_releaser.release(text_iterator()):
|
for chunk in text_releaser.release(text_iterator()):
|
||||||
|
if isinstance(chunk, str):
|
||||||
print(f"\033[92m{chunk}\033[0m", end="", flush=True)
|
print(f"\033[92m{chunk}\033[0m", end="", flush=True)
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
@@ -138,6 +158,10 @@ class GraphBase(ABC):
|
|||||||
|
|
||||||
assert len(kwargs) == 0, "due to inp assumptions"
|
assert len(kwargs) == 0, "due to inp assumptions"
|
||||||
|
|
||||||
|
def _get_inp_msgs(self, state:State):
|
||||||
|
msgs = state["inp"][0]["messages"]
|
||||||
|
return [e for e in msgs if not isinstance(e, SystemMessage)]
|
||||||
|
|
||||||
def _agent_call_template(self, system_prompt:str,
|
def _agent_call_template(self, system_prompt:str,
|
||||||
model:CompiledStateGraph,
|
model:CompiledStateGraph,
|
||||||
state:State,
|
state:State,
|
||||||
@@ -149,7 +173,7 @@ class GraphBase(ABC):
|
|||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
SystemMessage(system_prompt),
|
SystemMessage(system_prompt),
|
||||||
*state["inp"][0]["messages"][1:]
|
*self._get_inp_msgs(state)
|
||||||
]
|
]
|
||||||
|
|
||||||
if human_msg is not None:
|
if human_msg is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user