moved visualization to base
This commit is contained in:
@@ -1,5 +1,12 @@
|
||||
from typing import List, Callable
|
||||
from typing import List, Callable, TYPE_CHECKING
|
||||
from abc import ABC, abstractmethod
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
import matplotlib.pyplot as plt
|
||||
from loguru import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
|
||||
class LangToolBase(ABC):
|
||||
@@ -10,14 +17,19 @@ class LangToolBase(ABC):
|
||||
|
||||
|
||||
class GraphBase(ABC):
|
||||
workflow: CompiledStateGraph
|
||||
|
||||
@abstractmethod
|
||||
def invoke(self, *nargs, **kwargs):
|
||||
pass
|
||||
|
||||
def show_graph(self):
|
||||
#NOTE: just a useful tool for debugging; has zero useful functionality
|
||||
|
||||
err_str = f"{type(self)} does not have workflow, this is unsupported"
|
||||
assert hasattr(self, "workflow"), err_str
|
||||
|
||||
class ToolNodeBase(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def tool_node_call(self, *nargs, **kwargs):
|
||||
pass
|
||||
logger.info("creating image")
|
||||
img = Image.open(BytesIO(self.workflow.get_graph().draw_mermaid_png()))
|
||||
plt.imshow(img)
|
||||
plt.show()
|
||||
Reference in New Issue
Block a user