pass outputs as message
This commit is contained in:
@@ -43,8 +43,7 @@ class Route(BaseModel):
|
|||||||
class State(TypedDict):
|
class State(TypedDict):
|
||||||
inp: Tuple[Dict[str, List[SystemMessage | HumanMessage]],
|
inp: Tuple[Dict[str, List[SystemMessage | HumanMessage]],
|
||||||
Dict[str, Dict[str, str|int]]]
|
Dict[str, Dict[str, str|int]]]
|
||||||
output: str
|
messages: List[SystemMessage | HumanMessage]
|
||||||
tool_output: str
|
|
||||||
decision:str
|
decision:str
|
||||||
|
|
||||||
|
|
||||||
@@ -62,14 +61,15 @@ class RoutingGraph(GraphBase):
|
|||||||
assert len(kwargs) == 0, "due to inp assumptions"
|
assert len(kwargs) == 0, "due to inp assumptions"
|
||||||
|
|
||||||
if as_stream:
|
if as_stream:
|
||||||
|
# TODO: this doesn't stream the entire process, we are blind
|
||||||
for step in self.workflow.stream({"inp": nargs}, stream_mode="values", **kwargs):
|
for step in self.workflow.stream({"inp": nargs}, stream_mode="values", **kwargs):
|
||||||
if "messages" in step:
|
if "messages" in step:
|
||||||
step["messages"][-1].pretty_print()
|
step["messages"]["messages"][-1].pretty_print()
|
||||||
state = step
|
state = step
|
||||||
else:
|
else:
|
||||||
state = self.workflow.invoke({"inp": nargs})
|
state = self.workflow.invoke({"inp": nargs})
|
||||||
|
|
||||||
return state["output"]
|
return state["messages"]
|
||||||
|
|
||||||
def _build_modules(self):
|
def _build_modules(self):
|
||||||
self.llm = init_chat_model(model=self.config.llm_name,
|
self.llm = init_chat_model(model=self.config.llm_name,
|
||||||
@@ -117,8 +117,13 @@ class RoutingGraph(GraphBase):
|
|||||||
|
|
||||||
|
|
||||||
def _chat_model_call(self, state:State):
|
def _chat_model_call(self, state:State):
|
||||||
out = self.chat_model.invoke(*state["inp"])
|
if state.get("messages") is not None:
|
||||||
return {"output":out["messages"][-1].content}
|
inp = state["messages"], state["inp"][1]
|
||||||
|
else:
|
||||||
|
inp = state["inp"]
|
||||||
|
|
||||||
|
out = self.chat_model.invoke(*inp)
|
||||||
|
return {"messages": out}
|
||||||
|
|
||||||
|
|
||||||
def _tool_model_call(self, state:State):
|
def _tool_model_call(self, state:State):
|
||||||
@@ -129,7 +134,7 @@ class RoutingGraph(GraphBase):
|
|||||||
]}, state["inp"][1]
|
]}, state["inp"][1]
|
||||||
|
|
||||||
out = self.tool_model.invoke(*inp)
|
out = self.tool_model.invoke(*inp)
|
||||||
return {"output": out["messages"][-1].content}
|
return {"messages": out}
|
||||||
|
|
||||||
def _build_graph(self):
|
def _build_graph(self):
|
||||||
builder = StateGraph(State)
|
builder = StateGraph(State)
|
||||||
@@ -150,7 +155,7 @@ class RoutingGraph(GraphBase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
builder.add_edge("tool_model_call", END)
|
builder.add_edge("tool_model_call", END)
|
||||||
builder.add_edge("tool_model_call", "chat_model_call")
|
# builder.add_edge("tool_model_call", "chat_model_call")
|
||||||
builder.add_edge("chat_model_call", END)
|
builder.add_edge("chat_model_call", END)
|
||||||
|
|
||||||
workflow = builder.compile()
|
workflow = builder.compile()
|
||||||
|
|||||||
Reference in New Issue
Block a user