diff --git a/lang_agent/graphs/routing.py b/lang_agent/graphs/routing.py index 7b24b8c..8c5a55d 100644 --- a/lang_agent/graphs/routing.py +++ b/lang_agent/graphs/routing.py @@ -2,6 +2,7 @@ from dataclasses import dataclass, field, is_dataclass from typing import Type, TypedDict, Literal, Dict, List, Tuple import tyro from pydantic import BaseModel, Field +from loguru import logger from lang_agent.config import KeyConfig from lang_agent.tool_manager import ToolManager, ToolManagerConfig @@ -43,6 +44,7 @@ class State(TypedDict): inp: Tuple[Dict[str, List[SystemMessage | HumanMessage]], Dict[str, Dict[str, str|int]]] output: str + tool_output: str decision:str @@ -61,7 +63,8 @@ class RoutingGraph(GraphBase): if as_stream: for step in self.workflow.stream({"inp": nargs}, stream_mode="values", **kwargs): - step["messages"][-1].pretty_print() + if "messages" in step: + step["messages"][-1].pretty_print() state = step else: state = self.workflow.invoke({"inp": nargs}) @@ -85,7 +88,7 @@ class RoutingGraph(GraphBase): decision:Route = self.router.invoke( [ SystemMessage( - content="Route to chat or order based on the need of the user" + content="Return a JSON object with 'step'.the value should be one of 'chat' or 'order' based on the user input" ), self._get_human_msg(state) ] @@ -95,8 +98,10 @@ class RoutingGraph(GraphBase): def _get_human_msg(self, state: State)->HumanMessage: + """ + get user message of current invocation + """ msgs = state["inp"][0]["messages"] - assert len(msgs) == 2, "Expect 1 systemMessage, 1 HumanMessage" candidate_hum_msg = msgs[1] assert isinstance(candidate_hum_msg, HumanMessage), "not a human message" @@ -104,10 +109,11 @@ class RoutingGraph(GraphBase): def _route_decision(self, state:State): - if state.decision == "chat": - return "_chat_model_call" + logger.info(f"decision:{state["decision"]}") + if state["decision"] == "chat": + return "chat" else: - return "_tool_model_call" + return "tool" def _chat_model_call(self, state:State): @@ -116,11 +122,11 @@ class RoutingGraph(GraphBase): def _tool_model_call(self, state:State): - inp = [ + inp = {"messages":[ SystemMessage( "You must use tool to complete the possible task" ),self._get_human_msg(state) - ], state["inp"][1] + ]}, state["inp"][1] out = self.tool_model.invoke(*inp) return {"output": out["messages"][-1].content} @@ -143,8 +149,9 @@ class RoutingGraph(GraphBase): "tool": "tool_model_call" } ) - builder.add_edge("chat_model_call", END) builder.add_edge("tool_model_call", END) + builder.add_edge("tool_model_call", "chat_model_call") + builder.add_edge("chat_model_call", END) workflow = builder.compile()