pass outputs as message

This commit is contained in:
2025-10-22 18:01:20 +08:00
parent 7fae2599fa
commit 91093fada9

View File

@@ -43,8 +43,7 @@ class Route(BaseModel):
class State(TypedDict): class State(TypedDict):
inp: Tuple[Dict[str, List[SystemMessage | HumanMessage]], inp: Tuple[Dict[str, List[SystemMessage | HumanMessage]],
Dict[str, Dict[str, str|int]]] Dict[str, Dict[str, str|int]]]
output: str messages: List[SystemMessage | HumanMessage]
tool_output: str
decision:str decision:str
@@ -62,14 +61,15 @@ class RoutingGraph(GraphBase):
assert len(kwargs) == 0, "due to inp assumptions" assert len(kwargs) == 0, "due to inp assumptions"
if as_stream: if as_stream:
# TODO this doesn't stream the entire process, we are blind
for step in self.workflow.stream({"inp": nargs}, stream_mode="values", **kwargs): for step in self.workflow.stream({"inp": nargs}, stream_mode="values", **kwargs):
if "messages" in step: if "messages" in step:
step["messages"][-1].pretty_print() step["messages"]["messages"][-1].pretty_print()
state = step state = step
else: else:
state = self.workflow.invoke({"inp": nargs}) state = self.workflow.invoke({"inp": nargs})
return state["output"] return state["messages"]
def _build_modules(self): def _build_modules(self):
self.llm = init_chat_model(model=self.config.llm_name, self.llm = init_chat_model(model=self.config.llm_name,
@@ -117,8 +117,13 @@ class RoutingGraph(GraphBase):
def _chat_model_call(self, state:State): def _chat_model_call(self, state:State):
out = self.chat_model.invoke(*state["inp"]) if state.get("messages") is not None:
return {"output":out["messages"][-1].content} inp = state["messages"], state["inp"][1]
else:
inp = state["inp"]
out = self.chat_model.invoke(*inp)
return {"messages": out}
def _tool_model_call(self, state:State): def _tool_model_call(self, state:State):
@@ -129,7 +134,7 @@ class RoutingGraph(GraphBase):
]}, state["inp"][1] ]}, state["inp"][1]
out = self.tool_model.invoke(*inp) out = self.tool_model.invoke(*inp)
return {"output": out["messages"][-1].content} return {"messages": out}
def _build_graph(self): def _build_graph(self):
builder = StateGraph(State) builder = StateGraph(State)
@@ -150,7 +155,7 @@ class RoutingGraph(GraphBase):
} }
) )
builder.add_edge("tool_model_call", END) builder.add_edge("tool_model_call", END)
builder.add_edge("tool_model_call", "chat_model_call") # builder.add_edge("tool_model_call", "chat_model_call")
builder.add_edge("chat_model_call", END) builder.add_edge("chat_model_call", END)
workflow = builder.compile() workflow = builder.compile()