moved visualization to base

This commit is contained in:
2025-11-21 14:22:01 +08:00
parent 643f3a9dc3
commit 9022540b79
2 changed files with 18 additions and 15 deletions

View File

@@ -1,5 +1,12 @@
from typing import List, Callable from typing import List, Callable, TYPE_CHECKING
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt
from loguru import logger
if TYPE_CHECKING:
from langgraph.graph.state import CompiledStateGraph
class LangToolBase(ABC): class LangToolBase(ABC):
@@ -10,14 +17,19 @@ class LangToolBase(ABC):
class GraphBase(ABC): class GraphBase(ABC):
workflow: CompiledStateGraph
@abstractmethod @abstractmethod
def invoke(self, *nargs, **kwargs): def invoke(self, *nargs, **kwargs):
pass pass
def show_graph(self):
#NOTE: just a useful tool for debugging; has zero useful functionality
err_str = f"{type(self)} does not have workflow, this is unsupported"
assert hasattr(self, "workflow"), err_str
class ToolNodeBase(ABC): logger.info("creating image")
img = Image.open(BytesIO(self.workflow.get_graph().draw_mermaid_png()))
@abstractmethod plt.imshow(img)
def tool_node_call(self, *nargs, **kwargs): plt.show()
pass

View File

@@ -3,9 +3,6 @@ from typing import Type, TypedDict, Literal, Dict, List, Tuple, Any
import tyro import tyro
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from loguru import logger from loguru import logger
from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt
import jax import jax
import os.path as osp import os.path as osp
import commentjson import commentjson
@@ -254,12 +251,6 @@ class RoutingGraph(GraphBase):
return workflow return workflow
def show_graph(self):
logger.info("creating image")
img = Image.open(BytesIO(self.workflow.get_graph().draw_mermaid_png()))
plt.imshow(img)
plt.show()
if __name__ == "__main__": if __name__ == "__main__":
from dotenv import load_dotenv from dotenv import load_dotenv
from langchain.messages import SystemMessage, HumanMessage from langchain.messages import SystemMessage, HumanMessage