pass outputs as message

This commit is contained in:
2025-10-22 18:01:20 +08:00
parent 7fae2599fa
commit 91093fada9

View File

@@ -43,8 +43,7 @@ class Route(BaseModel):
class State(TypedDict):
inp: Tuple[Dict[str, List[SystemMessage | HumanMessage]],
Dict[str, Dict[str, str|int]]]
output: str
tool_output: str
messages: List[SystemMessage | HumanMessage]
decision:str
@@ -62,14 +61,15 @@ class RoutingGraph(GraphBase):
assert len(kwargs) == 0, "due to inp assumptions"
if as_stream:
# TODO this doesn't stream the entire process, we are blind
for step in self.workflow.stream({"inp": nargs}, stream_mode="values", **kwargs):
if "messages" in step:
step["messages"][-1].pretty_print()
step["messages"]["messages"][-1].pretty_print()
state = step
else:
state = self.workflow.invoke({"inp": nargs})
return state["output"]
return state["messages"]
def _build_modules(self):
self.llm = init_chat_model(model=self.config.llm_name,
@@ -117,8 +117,13 @@ class RoutingGraph(GraphBase):
def _chat_model_call(self, state:State):
out = self.chat_model.invoke(*state["inp"])
return {"output":out["messages"][-1].content}
if state.get("messages") is not None:
inp = state["messages"], state["inp"][1]
else:
inp = state["inp"]
out = self.chat_model.invoke(*inp)
return {"messages": out}
def _tool_model_call(self, state:State):
@@ -129,7 +134,7 @@ class RoutingGraph(GraphBase):
]}, state["inp"][1]
out = self.tool_model.invoke(*inp)
return {"output": out["messages"][-1].content}
return {"messages": out}
def _build_graph(self):
builder = StateGraph(State)
@@ -150,7 +155,7 @@ class RoutingGraph(GraphBase):
}
)
builder.add_edge("tool_model_call", END)
builder.add_edge("tool_model_call", "chat_model_call")
# builder.add_edge("tool_model_call", "chat_model_call")
builder.add_edge("chat_model_call", END)
workflow = builder.compile()