From 9022540b79c6c8dae51c2e9245ae12effc7986ef Mon Sep 17 00:00:00 2001 From: goulustis Date: Fri, 21 Nov 2025 14:22:01 +0800 Subject: [PATCH] moved visualization to base --- lang_agent/base.py | 24 ++++++++++++++++++------ lang_agent/graphs/routing.py | 9 --------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/lang_agent/base.py b/lang_agent/base.py index 3c17c00..31e3096 100644 --- a/lang_agent/base.py +++ b/lang_agent/base.py @@ -1,5 +1,12 @@ -from typing import List, Callable +from typing import List, Callable, TYPE_CHECKING from abc import ABC, abstractmethod +from PIL import Image +from io import BytesIO +import matplotlib.pyplot as plt +from loguru import logger + +if TYPE_CHECKING: + from langgraph.graph.state import CompiledStateGraph class LangToolBase(ABC): @@ -10,14 +17,19 @@ class LangToolBase(ABC): class GraphBase(ABC): + workflow: CompiledStateGraph @abstractmethod def invoke(self, *nargs, **kwargs): pass + def show_graph(self): + #NOTE: just a useful tool for debugging; has zero useful functionality + + err_str = f"{type(self)} does not have workflow, this is unsupported" + assert hasattr(self, "workflow"), err_str -class ToolNodeBase(ABC): - - @abstractmethod - def tool_node_call(self, *nargs, **kwargs): - pass \ No newline at end of file + logger.info("creating image") + img = Image.open(BytesIO(self.workflow.get_graph().draw_mermaid_png())) + plt.imshow(img) + plt.show() \ No newline at end of file diff --git a/lang_agent/graphs/routing.py b/lang_agent/graphs/routing.py index 29ee60a..2dc791a 100644 --- a/lang_agent/graphs/routing.py +++ b/lang_agent/graphs/routing.py @@ -3,9 +3,6 @@ from typing import Type, TypedDict, Literal, Dict, List, Tuple, Any import tyro from pydantic import BaseModel, Field from loguru import logger -from PIL import Image -from io import BytesIO -import matplotlib.pyplot as plt import jax import os.path as osp import commentjson @@ -254,12 +251,6 @@ class RoutingGraph(GraphBase): return workflow - def show_graph(self): - logger.info("creating image") - img = Image.open(BytesIO(self.workflow.get_graph().draw_mermaid_png())) - plt.imshow(img) - plt.show() - if __name__ == "__main__": from dotenv import load_dotenv from langchain.messages import SystemMessage, HumanMessage