from dataclasses import dataclass, field, is_dataclass
from typing import Type, TypedDict, Literal, Dict, List, Tuple
import tyro
from pydantic import BaseModel, Field
from loguru import logger
from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt
import jax
import os.path as osp
import commentjson

from lang_agent.config import KeyConfig
from lang_agent.tool_manager import ToolManager, ToolManagerConfig
from lang_agent.base import GraphBase

from langchain.chat_models import init_chat_model
from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
from langchain.agents import create_agent

from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver


@tyro.conf.configure(tyro.conf.SuppressFixed)
@dataclass
class RoutingConfig(KeyConfig):
    """Configuration for RoutingGraph.

    Holds the LLM connection settings, the path of the JSON file with the
    system prompts, and the tool-manager configuration.
    """

    # Resolved lazily through a lambda because RoutingGraph is defined
    # further down in this module.
    _target: Type = field(default_factory=lambda: RoutingGraph)

    llm_name: str = "qwen-turbo"
    """name of llm"""

    llm_provider: str = "openai"
    """provider of the llm"""

    base_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
    """base url; could be used to overwrite the baseurl in llm provider"""

    sys_promp_json: str | None = None
    """path to json containing system prompts for graphs; Will overwrite systemprompt from xiaozhi if 'chat_prompt' is provided"""

    tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)

    def __post_init__(self):
        """Fill in the default prompt-file path and check that the file exists.

        Raises:
            AssertionError: if the resolved prompt file does not exist.
        """
        super().__post_init__()
        if self.sys_promp_json is None:
            # Default: <three directories above this file>/configs/route_sys_prompts.json
            root_dir = osp.dirname(osp.dirname(osp.dirname(__file__)))
            self.sys_promp_json = osp.join(root_dir, "configs", "route_sys_prompts.json")
            logger.warning(f"sys_promp_json was not provided. Using default: {self.sys_promp_json}")

        # Kept as an assert so callers relying on AssertionError are unaffected.
        assert osp.exists(self.sys_promp_json), f"sys_promp_json {self.sys_promp_json} does not exist."
class Route(BaseModel):
    """Structured output of the router LLM: which branch should handle the input."""

    # NOTE(review): the default None lies outside the Literal; a missing step
    # falls through to the tool branch in RoutingGraph._route_decision.
    step: Literal["chat", "order"] = Field(
        None, description="The next step in the routing process"
    )


class State(TypedDict):
    """Graph state shared by the router, chat and tool nodes."""

    # Positional arguments of RoutingGraph.invoke: (input dict, config dict).
    inp: Tuple[Dict[str, List[SystemMessage | HumanMessage]],
               Dict[str, Dict[str, str|int]]]
    # Written by the chat/tool nodes with whatever the agent's invoke returns.
    messages: List[SystemMessage | HumanMessage]
    # Route.step as produced by the router node.
    decision:str


class RoutingGraph(GraphBase):
    """Route a request to either a plain chat agent or a tool-using agent.

    The graph is START -> router_call -> (chat_model_call | tool_model_call) -> END.
    """

    def __init__(self, config: RoutingConfig):
        """Store the config, build the LLM modules and compile the graph."""
        self.config = config
        self.chat_sys_msg = None

        self._build_modules()
        self.workflow = self._build_graph()

    def invoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs) -> str | list:
        """Run the routing workflow on one request.

        Args:
            *nargs: forwarded to the graph as state["inp"]; nargs[0] must be a
                dict with a "messages" list, nargs[1] is passed to the agent's
                invoke as its second positional argument.
            as_stream: if True, run with stream_mode="updates" and log/print
                the last leaf of every update as it arrives.
            as_raw: if True, return every leaf of the final state instead of
                only the content of the last one.
            **kwargs: must be empty (enforced by _validate_input).

        Returns:
            The `.content` of the last leaf of the final state, or the full
            list of leaves when as_raw is True.

        Raises:
            AssertionError: if the input fails _validate_input.
            RuntimeError: if streaming was requested but produced no updates.
        """
        self._validate_input(*nargs, **kwargs)

        if as_stream:
            # TODO: this doesn't stream the entire process, we are blind
            state = None
            for step in self.workflow.stream({"inp": nargs}, stream_mode="updates", **kwargs):
                last_el = jax.tree.leaves(step)[-1]
                if isinstance(last_el, str):
                    logger.info(last_el)
                elif isinstance(last_el, BaseMessage):
                    last_el.pretty_print()
                # The last update stands in for the final state below.
                state = step

            if state is None:
                # Previously this surfaced as an unbound-variable error.
                raise RuntimeError("workflow stream produced no updates")
        else:
            state = self.workflow.invoke({"inp": nargs})

        # Flatten the nested state; the final leaf is treated as the reply.
        msg_list = jax.tree.leaves(state)

        if as_raw:
            return msg_list

        return msg_list[-1].content

    def _validate_input(self, *nargs, **kwargs):
        """Assert that the input has at least two messages and no keyword args."""
        assert len(nargs[0]["messages"]) >= 2, "need at least 1 system and 1 human message"
        assert len(kwargs) == 0, "due to inp assumptions"

    def _build_modules(self):
        """Create the LLM, router, chat/tool agents and load the prompt JSON."""
        self.llm = init_chat_model(model=self.config.llm_name,
                                   model_provider=self.config.llm_provider,
                                   api_key=self.config.api_key,
                                   base_url=self.config.base_url)
        self.memory = MemorySaver()
        self.router = self.llm.with_structured_output(Route)

        tool_manager:ToolManager = self.config.tool_manager_config.setup()
        # Both agents share the same checkpointer instance.
        self.chat_model = create_agent(self.llm, [], checkpointer=self.memory)
        self.tool_model = create_agent(self.llm, tool_manager.get_langchain_tools(), checkpointer=self.memory)

        # Explicit encoding so non-ASCII prompts do not depend on the platform locale.
        with open(self.config.sys_promp_json, "r", encoding="utf-8") as f:
            self.prompt_dict = commentjson.load(f)

    def _router_call(self, state:State):
        """Ask the router LLM which branch to take; returns {"decision": step}."""
        # "route_prompt" comes from the prompt JSON; presumably it tells the
        # model to answer with 'chat' or 'order' — verify against the config file.
        decision:Route = self.router.invoke(
            [
                SystemMessage(
                    content=self.prompt_dict.get("route_prompt")
                ),
                self._get_human_msg(state)
            ]
        )

        return {"decision": decision.step}

    def _get_human_msg(self, state: State)->HumanMessage:
        """
        get user message of current invocation

        Reads the message at index 1 of the input messages.
        NOTE(review): this assumes the input is [system, human, ...] — confirm
        that callers never pass earlier history ahead of the current message.
        """
        msgs = state["inp"][0]["messages"]
        candidate_hum_msg = msgs[1]
        assert isinstance(candidate_hum_msg, HumanMessage), "not a human message"

        return candidate_hum_msg

    def _route_decision(self, state:State):
        """Map the router decision to an edge key: "chat", otherwise "tool"."""
        # Single quotes inside the f-string keep this valid before Python 3.12.
        logger.info(f"decision:{state['decision']}")
        if state["decision"] == "chat":
            return "chat"
        else:
            return "tool"

    def _with_sys_prompt(self, state:State, prompt:str):
        """Return (input, config) for an agent with the first message replaced by `prompt`."""
        return {"messages":[
            SystemMessage(prompt),
            *state["inp"][0]["messages"][1:]
        ]}, state["inp"][1]

    def _chat_model_call(self, state:State):
        """Run the chat agent and store its output under "messages"."""
        if state.get("messages") is not None:
            inp = state["messages"], state["inp"][1]
        else:
            inp = state["inp"]

        # A configured "chat_prompt" overrides the system prompt supplied by the caller.
        if self.prompt_dict.get("chat_prompt") is not None:
            inp = self._with_sys_prompt(state, self.prompt_dict["chat_prompt"])

        out = self.chat_model.invoke(*inp)
        return {"messages": out}

    def _tool_model_call(self, state:State):
        """Run the tool agent with the configured "tool_prompt" as system prompt."""
        inp = self._with_sys_prompt(state, self.prompt_dict["tool_prompt"])

        out = self.tool_model.invoke(*inp)
        return {"messages": out}

    def _build_graph(self):
        """Assemble and compile the routing StateGraph."""
        builder = StateGraph(State)

        # add nodes
        builder.add_node("chat_model_call", self._chat_model_call)
        builder.add_node("tool_model_call", self._tool_model_call)
        builder.add_node("router_call", self._router_call)

        # add edge connections
        builder.add_edge(START, "router_call")
        builder.add_conditional_edges(
            "router_call",
            self._route_decision,
            {
                "chat": "chat_model_call",
                "tool": "tool_model_call"
            }
        )
        # The tool branch ends the run directly; it is not fed back into the chat node.
        builder.add_edge("tool_model_call", END)
        builder.add_edge("chat_model_call", END)

        workflow = builder.compile()

        return workflow

    def show_graph(self):
        """Render the compiled graph as a Mermaid PNG and display it with matplotlib."""
        img = Image.open(BytesIO(self.workflow.get_graph().draw_mermaid_png()))
        plt.imshow(img)
        plt.show()