bug fixes

This commit is contained in:
2025-10-22 17:41:45 +08:00
parent 3e395e6b9c
commit 7fae2599fa

View File

@@ -2,6 +2,7 @@ from dataclasses import dataclass, field, is_dataclass
from typing import Type, TypedDict, Literal, Dict, List, Tuple
import tyro
from pydantic import BaseModel, Field
from loguru import logger
from lang_agent.config import KeyConfig
from lang_agent.tool_manager import ToolManager, ToolManagerConfig
@@ -43,6 +44,7 @@ class State(TypedDict):
inp: Tuple[Dict[str, List[SystemMessage | HumanMessage]],
Dict[str, Dict[str, str|int]]]
output: str
tool_output: str
decision:str
@@ -61,7 +63,8 @@ class RoutingGraph(GraphBase):
if as_stream:
for step in self.workflow.stream({"inp": nargs}, stream_mode="values", **kwargs):
step["messages"][-1].pretty_print()
if "messages" in step:
step["messages"][-1].pretty_print()
state = step
else:
state = self.workflow.invoke({"inp": nargs})
@@ -85,7 +88,7 @@ class RoutingGraph(GraphBase):
decision:Route = self.router.invoke(
[
SystemMessage(
content="Route to chat or order based on the need of the user"
content="Return a JSON object with 'step'.the value should be one of 'chat' or 'order' based on the user input"
),
self._get_human_msg(state)
]
@@ -95,8 +98,10 @@ class RoutingGraph(GraphBase):
def _get_human_msg(self, state: State)->HumanMessage:
"""
get user message of current invocation
"""
msgs = state["inp"][0]["messages"]
assert len(msgs) == 2, "Expect 1 systemMessage, 1 HumanMessage"
candidate_hum_msg = msgs[1]
assert isinstance(candidate_hum_msg, HumanMessage), "not a human message"
@@ -104,10 +109,11 @@ class RoutingGraph(GraphBase):
def _route_decision(self, state:State):
if state.decision == "chat":
return "_chat_model_call"
logger.info(f"decision:{state["decision"]}")
if state["decision"] == "chat":
return "chat"
else:
return "_tool_model_call"
return "tool"
def _chat_model_call(self, state:State):
@@ -116,11 +122,11 @@ class RoutingGraph(GraphBase):
def _tool_model_call(self, state:State):
inp = [
inp = {"messages":[
SystemMessage(
"You must use tool to complete the possible task"
),self._get_human_msg(state)
], state["inp"][1]
]}, state["inp"][1]
out = self.tool_model.invoke(*inp)
return {"output": out["messages"][-1].content}
@@ -143,8 +149,9 @@ class RoutingGraph(GraphBase):
"tool": "tool_model_call"
}
)
builder.add_edge("chat_model_call", END)
builder.add_edge("tool_model_call", END)
builder.add_edge("tool_model_call", "chat_model_call")
builder.add_edge("chat_model_call", END)
workflow = builder.compile()