diff --git a/lang_agent/graphs/routing.py b/lang_agent/graphs/routing.py index b93d6ce..330c22f 100644 --- a/lang_agent/graphs/routing.py +++ b/lang_agent/graphs/routing.py @@ -3,6 +3,9 @@ from typing import Type, TypedDict, Literal, Dict, List, Tuple import tyro from pydantic import BaseModel, Field from loguru import logger +from PIL import Image +from io import BytesIO +import matplotlib.pyplot as plt from lang_agent.config import KeyConfig from lang_agent.tool_manager import ToolManager, ToolManagerConfig @@ -130,7 +133,9 @@ class RoutingGraph(GraphBase): inp = {"messages":[ SystemMessage( "You must use tool to complete the possible task" - ),self._get_human_msg(state) + ), + # self._get_human_msg(state) + *state["inp"][0][1:] ]}, state["inp"][1] out = self.tool_model.invoke(*inp) @@ -160,4 +165,9 @@ class RoutingGraph(GraphBase): workflow = builder.compile() - return workflow \ No newline at end of file + return workflow + + def show_graph(self): + img = Image.open(BytesIO(self.workflow.get_graph().draw_mermaid_png())) + plt.imshow(img) + plt.show() \ No newline at end of file diff --git a/lang_agent/pipeline.py b/lang_agent/pipeline.py index 533b6a2..82d9cbf 100644 --- a/lang_agent/pipeline.py +++ b/lang_agent/pipeline.py @@ -14,8 +14,8 @@ from langchain.agents import create_agent from langgraph.checkpoint.memory import MemorySaver from lang_agent.config import InstantiateConfig, KeyConfig -from lang_agent.graphs import AnnotatedGraph -from lang_agent.graphs.react import ReactGraph, ReactGraphConfig +from lang_agent.graphs import AnnotatedGraph, ReactGraphConfig, RoutingConfig +from lang_agent.base import GraphBase @@ -43,7 +43,8 @@ class PipelineConfig(KeyConfig): """what is my port""" # graph_config: ReactGraphConfig = field(default_factory=ReactGraphConfig) - graph_config: AnnotatedGraph = field(default_factory=ReactGraphConfig) + # graph_config: AnnotatedGraph = field(default_factory=ReactGraphConfig) + graph_config: AnnotatedGraph = field(default_factory=RoutingConfig) @@ -65,9 +66,14 @@ class Pipeline: self.config.graph_config.base_url = self.config.base_url if self.config.base_url is not None else self.config.graph_config.base_url self.config.graph_config.api_key = self.config.api_key - self.graph:ReactGraph = self.config.graph_config.setup() + self.graph:GraphBase = self.config.graph_config.setup() - + def show_graph(self): + if hasattr(self.graph, "show_graph"): + logger.info("showing graph") + self.graph.show_graph() + else: + logger.info(f"show graph not supported for {type(self.graph)}") def invoke(self, *nargs, **kwargs): return self.graph.invoke(*nargs, **kwargs)