moved visualization to base

This commit is contained in:
2025-11-21 14:22:01 +08:00
parent 643f3a9dc3
commit 9022540b79
2 changed files with 18 additions and 15 deletions

View File

@@ -1,5 +1,12 @@
from typing import List, Callable
from typing import List, Callable, TYPE_CHECKING
from abc import ABC, abstractmethod
from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt
from loguru import logger
if TYPE_CHECKING:
from langgraph.graph.state import CompiledStateGraph
class LangToolBase(ABC):
@@ -10,14 +17,19 @@ class LangToolBase(ABC):
class GraphBase(ABC):
workflow: CompiledStateGraph
@abstractmethod
def invoke(self, *nargs, **kwargs):
pass
def show_graph(self):
#NOTE: just a useful tool for debugging; has zero useful functionality
err_str = f"{type(self)} does not have workflow, this is unsupported"
assert hasattr(self, "workflow"), err_str
class ToolNodeBase(ABC):
@abstractmethod
def tool_node_call(self, *nargs, **kwargs):
pass
logger.info("creating image")
img = Image.open(BytesIO(self.workflow.get_graph().draw_mermaid_png()))
plt.imshow(img)
plt.show()

View File

@@ -3,9 +3,6 @@ from typing import Type, TypedDict, Literal, Dict, List, Tuple, Any
import tyro
from pydantic import BaseModel, Field
from loguru import logger
from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt
import jax
import os.path as osp
import commentjson
@@ -254,12 +251,6 @@ class RoutingGraph(GraphBase):
return workflow
def show_graph(self):
logger.info("creating image")
img = Image.open(BytesIO(self.workflow.get_graph().draw_mermaid_png()))
plt.imshow(img)
plt.show()
if __name__ == "__main__":
from dotenv import load_dotenv
from langchain.messages import SystemMessage, HumanMessage