bug fixes
This commit is contained in:
@@ -2,6 +2,7 @@ from dataclasses import dataclass, field, is_dataclass
|
|||||||
from typing import Type, TypedDict, Literal, Dict, List, Tuple
|
from typing import Type, TypedDict, Literal, Dict, List, Tuple
|
||||||
import tyro
|
import tyro
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
from lang_agent.config import KeyConfig
|
from lang_agent.config import KeyConfig
|
||||||
from lang_agent.tool_manager import ToolManager, ToolManagerConfig
|
from lang_agent.tool_manager import ToolManager, ToolManagerConfig
|
||||||
@@ -43,6 +44,7 @@ class State(TypedDict):
|
|||||||
inp: Tuple[Dict[str, List[SystemMessage | HumanMessage]],
|
inp: Tuple[Dict[str, List[SystemMessage | HumanMessage]],
|
||||||
Dict[str, Dict[str, str|int]]]
|
Dict[str, Dict[str, str|int]]]
|
||||||
output: str
|
output: str
|
||||||
|
tool_output: str
|
||||||
decision:str
|
decision:str
|
||||||
|
|
||||||
|
|
||||||
@@ -61,6 +63,7 @@ class RoutingGraph(GraphBase):
|
|||||||
|
|
||||||
if as_stream:
|
if as_stream:
|
||||||
for step in self.workflow.stream({"inp": nargs}, stream_mode="values", **kwargs):
|
for step in self.workflow.stream({"inp": nargs}, stream_mode="values", **kwargs):
|
||||||
|
if "messages" in step:
|
||||||
step["messages"][-1].pretty_print()
|
step["messages"][-1].pretty_print()
|
||||||
state = step
|
state = step
|
||||||
else:
|
else:
|
||||||
@@ -85,7 +88,7 @@ class RoutingGraph(GraphBase):
|
|||||||
decision:Route = self.router.invoke(
|
decision:Route = self.router.invoke(
|
||||||
[
|
[
|
||||||
SystemMessage(
|
SystemMessage(
|
||||||
content="Route to chat or order based on the need of the user"
|
content="Return a JSON object with 'step'.the value should be one of 'chat' or 'order' based on the user input"
|
||||||
),
|
),
|
||||||
self._get_human_msg(state)
|
self._get_human_msg(state)
|
||||||
]
|
]
|
||||||
@@ -95,8 +98,10 @@ class RoutingGraph(GraphBase):
|
|||||||
|
|
||||||
|
|
||||||
def _get_human_msg(self, state: State)->HumanMessage:
|
def _get_human_msg(self, state: State)->HumanMessage:
|
||||||
|
"""
|
||||||
|
get user message of current invocation
|
||||||
|
"""
|
||||||
msgs = state["inp"][0]["messages"]
|
msgs = state["inp"][0]["messages"]
|
||||||
assert len(msgs) == 2, "Expect 1 systemMessage, 1 HumanMessage"
|
|
||||||
candidate_hum_msg = msgs[1]
|
candidate_hum_msg = msgs[1]
|
||||||
assert isinstance(candidate_hum_msg, HumanMessage), "not a human message"
|
assert isinstance(candidate_hum_msg, HumanMessage), "not a human message"
|
||||||
|
|
||||||
@@ -104,10 +109,11 @@ class RoutingGraph(GraphBase):
|
|||||||
|
|
||||||
|
|
||||||
def _route_decision(self, state:State):
|
def _route_decision(self, state:State):
|
||||||
if state.decision == "chat":
|
logger.info(f"decision:{state["decision"]}")
|
||||||
return "_chat_model_call"
|
if state["decision"] == "chat":
|
||||||
|
return "chat"
|
||||||
else:
|
else:
|
||||||
return "_tool_model_call"
|
return "tool"
|
||||||
|
|
||||||
|
|
||||||
def _chat_model_call(self, state:State):
|
def _chat_model_call(self, state:State):
|
||||||
@@ -116,11 +122,11 @@ class RoutingGraph(GraphBase):
|
|||||||
|
|
||||||
|
|
||||||
def _tool_model_call(self, state:State):
|
def _tool_model_call(self, state:State):
|
||||||
inp = [
|
inp = {"messages":[
|
||||||
SystemMessage(
|
SystemMessage(
|
||||||
"You must use tool to complete the possible task"
|
"You must use tool to complete the possible task"
|
||||||
),self._get_human_msg(state)
|
),self._get_human_msg(state)
|
||||||
], state["inp"][1]
|
]}, state["inp"][1]
|
||||||
|
|
||||||
out = self.tool_model.invoke(*inp)
|
out = self.tool_model.invoke(*inp)
|
||||||
return {"output": out["messages"][-1].content}
|
return {"output": out["messages"][-1].content}
|
||||||
@@ -143,8 +149,9 @@ class RoutingGraph(GraphBase):
|
|||||||
"tool": "tool_model_call"
|
"tool": "tool_model_call"
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
builder.add_edge("chat_model_call", END)
|
|
||||||
builder.add_edge("tool_model_call", END)
|
builder.add_edge("tool_model_call", END)
|
||||||
|
builder.add_edge("tool_model_call", "chat_model_call")
|
||||||
|
builder.add_edge("chat_model_call", END)
|
||||||
|
|
||||||
workflow = builder.compile()
|
workflow = builder.compile()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user