bug fixes
This commit is contained in:
@@ -2,6 +2,7 @@ from dataclasses import dataclass, field, is_dataclass
|
||||
from typing import Type, TypedDict, Literal, Dict, List, Tuple
|
||||
import tyro
|
||||
from pydantic import BaseModel, Field
|
||||
from loguru import logger
|
||||
|
||||
from lang_agent.config import KeyConfig
|
||||
from lang_agent.tool_manager import ToolManager, ToolManagerConfig
|
||||
@@ -43,6 +44,7 @@ class State(TypedDict):
|
||||
inp: Tuple[Dict[str, List[SystemMessage | HumanMessage]],
|
||||
Dict[str, Dict[str, str|int]]]
|
||||
output: str
|
||||
tool_output: str
|
||||
decision:str
|
||||
|
||||
|
||||
@@ -61,7 +63,8 @@ class RoutingGraph(GraphBase):
|
||||
|
||||
if as_stream:
|
||||
for step in self.workflow.stream({"inp": nargs}, stream_mode="values", **kwargs):
|
||||
step["messages"][-1].pretty_print()
|
||||
if "messages" in step:
|
||||
step["messages"][-1].pretty_print()
|
||||
state = step
|
||||
else:
|
||||
state = self.workflow.invoke({"inp": nargs})
|
||||
@@ -85,7 +88,7 @@ class RoutingGraph(GraphBase):
|
||||
decision:Route = self.router.invoke(
|
||||
[
|
||||
SystemMessage(
|
||||
content="Route to chat or order based on the need of the user"
|
||||
content="Return a JSON object with 'step'.the value should be one of 'chat' or 'order' based on the user input"
|
||||
),
|
||||
self._get_human_msg(state)
|
||||
]
|
||||
@@ -95,8 +98,10 @@ class RoutingGraph(GraphBase):
|
||||
|
||||
|
||||
def _get_human_msg(self, state: State)->HumanMessage:
|
||||
"""
|
||||
get user message of current invocation
|
||||
"""
|
||||
msgs = state["inp"][0]["messages"]
|
||||
assert len(msgs) == 2, "Expect 1 systemMessage, 1 HumanMessage"
|
||||
candidate_hum_msg = msgs[1]
|
||||
assert isinstance(candidate_hum_msg, HumanMessage), "not a human message"
|
||||
|
||||
@@ -104,10 +109,11 @@ class RoutingGraph(GraphBase):
|
||||
|
||||
|
||||
def _route_decision(self, state:State):
|
||||
if state.decision == "chat":
|
||||
return "_chat_model_call"
|
||||
logger.info(f"decision:{state["decision"]}")
|
||||
if state["decision"] == "chat":
|
||||
return "chat"
|
||||
else:
|
||||
return "_tool_model_call"
|
||||
return "tool"
|
||||
|
||||
|
||||
def _chat_model_call(self, state:State):
|
||||
@@ -116,11 +122,11 @@ class RoutingGraph(GraphBase):
|
||||
|
||||
|
||||
def _tool_model_call(self, state:State):
|
||||
inp = [
|
||||
inp = {"messages":[
|
||||
SystemMessage(
|
||||
"You must use tool to complete the possible task"
|
||||
),self._get_human_msg(state)
|
||||
], state["inp"][1]
|
||||
]}, state["inp"][1]
|
||||
|
||||
out = self.tool_model.invoke(*inp)
|
||||
return {"output": out["messages"][-1].content}
|
||||
@@ -143,8 +149,9 @@ class RoutingGraph(GraphBase):
|
||||
"tool": "tool_model_call"
|
||||
}
|
||||
)
|
||||
builder.add_edge("chat_model_call", END)
|
||||
builder.add_edge("tool_model_call", END)
|
||||
builder.add_edge("tool_model_call", "chat_model_call")
|
||||
builder.add_edge("chat_model_call", END)
|
||||
|
||||
workflow = builder.compile()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user