pass outputs as message
This commit is contained in:
@@ -43,8 +43,7 @@ class Route(BaseModel):
|
||||
class State(TypedDict):
|
||||
inp: Tuple[Dict[str, List[SystemMessage | HumanMessage]],
|
||||
Dict[str, Dict[str, str|int]]]
|
||||
output: str
|
||||
tool_output: str
|
||||
messages: List[SystemMessage | HumanMessage]
|
||||
decision:str
|
||||
|
||||
|
||||
@@ -62,14 +61,15 @@ class RoutingGraph(GraphBase):
|
||||
assert len(kwargs) == 0, "due to inp assumptions"
|
||||
|
||||
if as_stream:
|
||||
# TODO: this doesn't stream the entire process, we are blind
|
||||
for step in self.workflow.stream({"inp": nargs}, stream_mode="values", **kwargs):
|
||||
if "messages" in step:
|
||||
step["messages"][-1].pretty_print()
|
||||
step["messages"]["messages"][-1].pretty_print()
|
||||
state = step
|
||||
else:
|
||||
state = self.workflow.invoke({"inp": nargs})
|
||||
|
||||
return state["output"]
|
||||
return state["messages"]
|
||||
|
||||
def _build_modules(self):
|
||||
self.llm = init_chat_model(model=self.config.llm_name,
|
||||
@@ -117,8 +117,13 @@ class RoutingGraph(GraphBase):
|
||||
|
||||
|
||||
def _chat_model_call(self, state:State):
|
||||
out = self.chat_model.invoke(*state["inp"])
|
||||
return {"output":out["messages"][-1].content}
|
||||
if state.get("messages") is not None:
|
||||
inp = state["messages"], state["inp"][1]
|
||||
else:
|
||||
inp = state["inp"]
|
||||
|
||||
out = self.chat_model.invoke(*inp)
|
||||
return {"messages": out}
|
||||
|
||||
|
||||
def _tool_model_call(self, state:State):
|
||||
@@ -129,7 +134,7 @@ class RoutingGraph(GraphBase):
|
||||
]}, state["inp"][1]
|
||||
|
||||
out = self.tool_model.invoke(*inp)
|
||||
return {"output": out["messages"][-1].content}
|
||||
return {"messages": out}
|
||||
|
||||
def _build_graph(self):
|
||||
builder = StateGraph(State)
|
||||
@@ -150,7 +155,7 @@ class RoutingGraph(GraphBase):
|
||||
}
|
||||
)
|
||||
builder.add_edge("tool_model_call", END)
|
||||
builder.add_edge("tool_model_call", "chat_model_call")
|
||||
# builder.add_edge("tool_model_call", "chat_model_call")
|
||||
builder.add_edge("chat_model_call", END)
|
||||
|
||||
workflow = builder.compile()
|
||||
|
||||
Reference in New Issue
Block a user