bug fixes

This commit is contained in:
2025-10-22 17:41:45 +08:00
parent 3e395e6b9c
commit 7fae2599fa

View File

@@ -2,6 +2,7 @@ from dataclasses import dataclass, field, is_dataclass
from typing import Type, TypedDict, Literal, Dict, List, Tuple from typing import Type, TypedDict, Literal, Dict, List, Tuple
import tyro import tyro
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from loguru import logger
from lang_agent.config import KeyConfig from lang_agent.config import KeyConfig
from lang_agent.tool_manager import ToolManager, ToolManagerConfig from lang_agent.tool_manager import ToolManager, ToolManagerConfig
@@ -43,6 +44,7 @@ class State(TypedDict):
inp: Tuple[Dict[str, List[SystemMessage | HumanMessage]], inp: Tuple[Dict[str, List[SystemMessage | HumanMessage]],
Dict[str, Dict[str, str|int]]] Dict[str, Dict[str, str|int]]]
output: str output: str
tool_output: str
decision:str decision:str
@@ -61,7 +63,8 @@ class RoutingGraph(GraphBase):
if as_stream: if as_stream:
for step in self.workflow.stream({"inp": nargs}, stream_mode="values", **kwargs): for step in self.workflow.stream({"inp": nargs}, stream_mode="values", **kwargs):
step["messages"][-1].pretty_print() if "messages" in step:
step["messages"][-1].pretty_print()
state = step state = step
else: else:
state = self.workflow.invoke({"inp": nargs}) state = self.workflow.invoke({"inp": nargs})
@@ -85,7 +88,7 @@ class RoutingGraph(GraphBase):
decision:Route = self.router.invoke( decision:Route = self.router.invoke(
[ [
SystemMessage( SystemMessage(
content="Route to chat or order based on the need of the user" content="Return a JSON object with 'step'.the value should be one of 'chat' or 'order' based on the user input"
), ),
self._get_human_msg(state) self._get_human_msg(state)
] ]
@@ -95,8 +98,10 @@ class RoutingGraph(GraphBase):
def _get_human_msg(self, state: State)->HumanMessage: def _get_human_msg(self, state: State)->HumanMessage:
"""
get user message of current invocation
"""
msgs = state["inp"][0]["messages"] msgs = state["inp"][0]["messages"]
assert len(msgs) == 2, "Expect 1 systemMessage, 1 HumanMessage"
candidate_hum_msg = msgs[1] candidate_hum_msg = msgs[1]
assert isinstance(candidate_hum_msg, HumanMessage), "not a human message" assert isinstance(candidate_hum_msg, HumanMessage), "not a human message"
@@ -104,10 +109,11 @@ class RoutingGraph(GraphBase):
def _route_decision(self, state:State): def _route_decision(self, state:State):
if state.decision == "chat": logger.info(f"decision:{state["decision"]}")
return "_chat_model_call" if state["decision"] == "chat":
return "chat"
else: else:
return "_tool_model_call" return "tool"
def _chat_model_call(self, state:State): def _chat_model_call(self, state:State):
@@ -116,11 +122,11 @@ class RoutingGraph(GraphBase):
def _tool_model_call(self, state:State): def _tool_model_call(self, state:State):
inp = [ inp = {"messages":[
SystemMessage( SystemMessage(
"You must use tool to complete the possible task" "You must use tool to complete the possible task"
),self._get_human_msg(state) ),self._get_human_msg(state)
], state["inp"][1] ]}, state["inp"][1]
out = self.tool_model.invoke(*inp) out = self.tool_model.invoke(*inp)
return {"output": out["messages"][-1].content} return {"output": out["messages"][-1].content}
@@ -143,8 +149,9 @@ class RoutingGraph(GraphBase):
"tool": "tool_model_call" "tool": "tool_model_call"
} }
) )
builder.add_edge("chat_model_call", END)
builder.add_edge("tool_model_call", END) builder.add_edge("tool_model_call", END)
builder.add_edge("tool_model_call", "chat_model_call")
builder.add_edge("chat_model_call", END)
workflow = builder.compile() workflow = builder.compile()