# Files: lang-agent/lang_agent/graphs/routing.py
# Snapshot: 2025-10-22 17:41:45 +08:00 (158 lines, 4.9 KiB, Python)

from dataclasses import dataclass, field, is_dataclass
from typing import Type, TypedDict, Literal, Dict, List, Tuple
import tyro
from pydantic import BaseModel, Field
from loguru import logger
from lang_agent.config import KeyConfig
from lang_agent.tool_manager import ToolManager, ToolManagerConfig
from lang_agent.base import GraphBase
from langchain.chat_models import init_chat_model
from langchain_core.messages import SystemMessage, HumanMessage
from langchain.agents import create_agent
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver
@tyro.conf.configure(tyro.conf.SuppressFixed)
@dataclass
class RoutingConfig(KeyConfig):
    """Configuration for :class:`RoutingGraph`.

    The defaults target DashScope's OpenAI-compatible endpoint serving
    ``qwen-turbo``. ``RoutingGraph`` also reads ``api_key`` from this config.
    That field is not declared here; presumably it comes from ``KeyConfig``
    (confirm).
    """

    # Class built from this config. A lambda is used because RoutingGraph is
    # defined later in this module. NOTE(review): presumably consumed by a
    # KeyConfig/base-config setup() helper; confirm.
    _target: Type = field(default_factory=lambda: RoutingGraph)
    llm_name: str = "qwen-turbo"
    """name of llm"""
    llm_provider: str = "openai"
    """provider of the llm"""
    base_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
    """base url; could be used to overwrite the baseurl in llm provider"""
    # Settings for the ToolManager. Its LangChain tools are handed to the
    # tool-enabled agent.
    tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)
class Route(BaseModel):
    # Structured-output schema that the router LLM fills in.
    # The field description below is part of the JSON schema the model sees.
    # If the model omits the field, pydantic falls back to the None default.
    # Any value other than "chat" is later sent down the tool branch.
    step: Literal["chat", "order"] = Field(
        default=None,
        description="The next step in the routing process",
    )
class State(TypedDict):
    """Graph state shared by the nodes of RoutingGraph."""

    # Positional arguments forwarded from RoutingGraph.invoke.
    # Element 0 is the agent input, {"messages": [SystemMessage, HumanMessage, ...]}.
    # Element 1 is passed through to the agents. It is presumably a runnable
    # config such as {"configurable": {"thread_id": ...}}; confirm with callers.
    inp: Tuple[Dict[str, List[SystemMessage | HumanMessage]],
               Dict[str, Dict[str, str|int]]]
    # Text content of the last message from the agent node that ran last.
    output: str
    # NOTE(review): no node in this module writes this key.
    tool_output: str
    # Router verdict: "chat", "order", or None if the model omitted it.
    decision: str
class RoutingGraph(GraphBase):
    """Route each user message to either a chat agent or a tool agent.

    A structured-output LLM call classifies the current user message as
    ``"chat"`` or ``"order"``. ``"chat"`` goes straight to a tool-less chat
    agent. Any other verdict goes to a tool-enabled agent, after which the
    chat agent also runs (see ``_build_graph``). Both agents share one
    ``MemorySaver`` checkpointer.
    """

    def __init__(self, config: RoutingConfig):
        self.config = config
        self.chat_sys_msg = None
        self._build_modules()
        self.workflow = self._build_graph()

    def invoke(self, *nargs, as_stream: bool = False, **kwargs):
        """Run the routing workflow and return the final reply text.

        Args:
            *nargs: Stored verbatim as ``state["inp"]``.
                ``nargs[0]`` must be ``{"messages": [SystemMessage, HumanMessage, ...]}``.
                ``nargs[1]``, if given, is forwarded to the agents. It is
                presumably a runnable config carrying the checkpointer
                ``thread_id``; confirm with callers.
            as_stream: If True, run the graph via ``stream(stream_mode="values")``
                and keep the last emitted state. Otherwise use ``invoke``.
            **kwargs: Not supported; must be empty (checked by an assert).

        Returns:
            ``state["output"]`` of the final graph state.
        """
        assert len(kwargs) == 0, "due to inp assumptions"
        if as_stream:
            for step in self.workflow.stream({"inp": nargs}, stream_mode="values"):
                # NOTE(review): State declares no "messages" key, so this
                # branch looks unreachable with the current schema.
                if "messages" in step:
                    step["messages"][-1].pretty_print()
                state = step
        else:
            state = self.workflow.invoke({"inp": nargs})
        return state["output"]

    def _build_modules(self):
        """Create the LLM, the structured-output router, a shared checkpointer and both agents."""
        self.llm = init_chat_model(model=self.config.llm_name,
                                   model_provider=self.config.llm_provider,
                                   api_key=self.config.api_key,
                                   base_url=self.config.base_url)
        # One checkpointer is shared by both agents. For the same config
        # (thread) they presumably read and write one conversation history.
        self.memory = MemorySaver()
        self.router = self.llm.with_structured_output(Route)

        tool_manager: ToolManager = self.config.tool_manager_config.setup()
        self.chat_model = create_agent(self.llm, [], checkpointer=self.memory)
        self.tool_model = create_agent(self.llm, tool_manager.get_langchain_tools(),
                                       checkpointer=self.memory)

    def _router_call(self, state: State):
        """Classify the current user message and return ``{"decision": step}``."""
        decision: Route = self.router.invoke(
            [
                SystemMessage(
                    content="Return a JSON object with 'step'.the value should be one of 'chat' or 'order' based on the user input"
                ),
                self._get_human_msg(state)
            ]
        )
        # Structured-output parsing may return None, e.g. when the model
        # produces nothing parsable. Treat that like a missing step, which
        # _route_decision sends to the tool branch, instead of raising
        # AttributeError.
        step = decision.step if decision is not None else None
        return {"decision": step}

    def _get_human_msg(self, state: State) -> HumanMessage:
        """Return the user message of the current invocation.

        Assumes ``state["inp"][0]["messages"]`` is laid out as
        ``[SystemMessage, HumanMessage, ...]`` and takes index 1.

        Raises:
            AssertionError: if that element is not a ``HumanMessage``.
        """
        msgs = state["inp"][0]["messages"]
        candidate_hum_msg = msgs[1]
        assert isinstance(candidate_hum_msg, HumanMessage), "not a human message"
        return candidate_hum_msg

    def _route_decision(self, state: State):
        """Map the router verdict to an edge label: ``"chat"`` or ``"tool"`` (the fallback)."""
        # The inner key uses single quotes. Reusing double quotes inside a
        # double-quoted f-string is a SyntaxError before Python 3.12 (PEP 701).
        logger.info(f"decision:{state['decision']}")
        if state["decision"] == "chat":
            return "chat"
        return "tool"

    def _chat_model_call(self, state: State):
        """Run the tool-less chat agent on the original invocation arguments."""
        out = self.chat_model.invoke(*state["inp"])
        return {"output": out["messages"][-1].content}

    def _tool_model_call(self, state: State):
        """Run the tool agent on the user message under a tool-forcing system prompt."""
        agent_input = {"messages": [
            SystemMessage(
                "You must use tool to complete the possible task"
            ),
            self._get_human_msg(state),
        ]}
        # Forward the second invocation argument if there is one, mirroring
        # how _chat_model_call splats state["inp"]. Indexing [1] directly
        # raised IndexError when invoke() got a single argument.
        out = self.tool_model.invoke(agent_input, *state["inp"][1:2])
        return {"output": out["messages"][-1].content}

    def _build_graph(self):
        """Wire START -> router -> (chat | tool -> chat) -> END and compile."""
        builder = StateGraph(State)

        # add nodes
        builder.add_node("chat_model_call", self._chat_model_call)
        builder.add_node("tool_model_call", self._tool_model_call)
        builder.add_node("router_call", self._router_call)

        # add edge connections
        builder.add_edge(START, "router_call")
        builder.add_conditional_edges(
            "router_call",
            self._route_decision,
            {
                "chat": "chat_model_call",
                "tool": "tool_model_call"
            }
        )
        # The tool agent always hands off to the chat agent, and the chat
        # agent's reply becomes the final state["output"]. A parallel
        # tool_model_call -> END edge would be a no-op, so none is declared.
        # NOTE(review): this overwrites the tool agent's own output text. The
        # chat agent only sees the tool results if the shared checkpointer
        # thread carries them over; confirm this is intended.
        builder.add_edge("tool_model_call", "chat_model_call")
        builder.add_edge("chat_model_call", END)

        workflow = builder.compile()
        return workflow