moved visualization to base
This commit is contained in:
@@ -1,5 +1,12 @@
|
|||||||
from typing import List, Callable
|
from typing import List, Callable, TYPE_CHECKING
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from PIL import Image
|
||||||
|
from io import BytesIO
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from langgraph.graph.state import CompiledStateGraph
|
||||||
|
|
||||||
|
|
||||||
class LangToolBase(ABC):
|
class LangToolBase(ABC):
|
||||||
@@ -10,14 +17,19 @@ class LangToolBase(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class GraphBase(ABC):
|
class GraphBase(ABC):
|
||||||
|
workflow: CompiledStateGraph
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def invoke(self, *nargs, **kwargs):
|
def invoke(self, *nargs, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def show_graph(self):
|
||||||
|
#NOTE: just a useful tool for debugging; has zero useful functionality
|
||||||
|
|
||||||
class ToolNodeBase(ABC):
|
err_str = f"{type(self)} does not have workflow, this is unsupported"
|
||||||
|
assert hasattr(self, "workflow"), err_str
|
||||||
|
|
||||||
@abstractmethod
|
logger.info("creating image")
|
||||||
def tool_node_call(self, *nargs, **kwargs):
|
img = Image.open(BytesIO(self.workflow.get_graph().draw_mermaid_png()))
|
||||||
pass
|
plt.imshow(img)
|
||||||
|
plt.show()
|
||||||
@@ -3,9 +3,6 @@ from typing import Type, TypedDict, Literal, Dict, List, Tuple, Any
|
|||||||
import tyro
|
import tyro
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from PIL import Image
|
|
||||||
from io import BytesIO
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import jax
|
import jax
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
import commentjson
|
import commentjson
|
||||||
@@ -254,12 +251,6 @@ class RoutingGraph(GraphBase):
|
|||||||
|
|
||||||
return workflow
|
return workflow
|
||||||
|
|
||||||
def show_graph(self):
|
|
||||||
logger.info("creating image")
|
|
||||||
img = Image.open(BytesIO(self.workflow.get_graph().draw_mermaid_png()))
|
|
||||||
plt.imshow(img)
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from langchain.messages import SystemMessage, HumanMessage
|
from langchain.messages import SystemMessage, HumanMessage
|
||||||
|
|||||||
Reference in New Issue
Block a user