support graph showing
This commit is contained in:
@@ -3,6 +3,9 @@ from typing import Type, TypedDict, Literal, Dict, List, Tuple
|
|||||||
import tyro
|
import tyro
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
from PIL import Image
|
||||||
|
from io import BytesIO
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
from lang_agent.config import KeyConfig
|
from lang_agent.config import KeyConfig
|
||||||
from lang_agent.tool_manager import ToolManager, ToolManagerConfig
|
from lang_agent.tool_manager import ToolManager, ToolManagerConfig
|
||||||
@@ -130,7 +133,9 @@ class RoutingGraph(GraphBase):
|
|||||||
inp = {"messages":[
|
inp = {"messages":[
|
||||||
SystemMessage(
|
SystemMessage(
|
||||||
"You must use tool to complete the possible task"
|
"You must use tool to complete the possible task"
|
||||||
),self._get_human_msg(state)
|
),
|
||||||
|
# self._get_human_msg(state)
|
||||||
|
*state["inp"][0][1:]
|
||||||
]}, state["inp"][1]
|
]}, state["inp"][1]
|
||||||
|
|
||||||
out = self.tool_model.invoke(*inp)
|
out = self.tool_model.invoke(*inp)
|
||||||
@@ -161,3 +166,8 @@ class RoutingGraph(GraphBase):
|
|||||||
workflow = builder.compile()
|
workflow = builder.compile()
|
||||||
|
|
||||||
return workflow
|
return workflow
|
||||||
|
|
||||||
|
def show_graph(self):
|
||||||
|
img = Image.open(BytesIO(self.workflow.get_graph().draw_mermaid_png()))
|
||||||
|
plt.imshow(img)
|
||||||
|
plt.show()
|
||||||
@@ -14,8 +14,8 @@ from langchain.agents import create_agent
|
|||||||
from langgraph.checkpoint.memory import MemorySaver
|
from langgraph.checkpoint.memory import MemorySaver
|
||||||
|
|
||||||
from lang_agent.config import InstantiateConfig, KeyConfig
|
from lang_agent.config import InstantiateConfig, KeyConfig
|
||||||
from lang_agent.graphs import AnnotatedGraph
|
from lang_agent.graphs import AnnotatedGraph, ReactGraphConfig, RoutingConfig
|
||||||
from lang_agent.graphs.react import ReactGraph, ReactGraphConfig
|
from lang_agent.base import GraphBase
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -43,7 +43,8 @@ class PipelineConfig(KeyConfig):
|
|||||||
"""what is my port"""
|
"""what is my port"""
|
||||||
|
|
||||||
# graph_config: ReactGraphConfig = field(default_factory=ReactGraphConfig)
|
# graph_config: ReactGraphConfig = field(default_factory=ReactGraphConfig)
|
||||||
graph_config: AnnotatedGraph = field(default_factory=ReactGraphConfig)
|
# graph_config: AnnotatedGraph = field(default_factory=ReactGraphConfig)
|
||||||
|
graph_config: AnnotatedGraph = field(default_factory=RoutingConfig)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -65,9 +66,14 @@ class Pipeline:
|
|||||||
self.config.graph_config.base_url = self.config.base_url if self.config.base_url is not None else self.config.graph_config.base_url
|
self.config.graph_config.base_url = self.config.base_url if self.config.base_url is not None else self.config.graph_config.base_url
|
||||||
self.config.graph_config.api_key = self.config.api_key
|
self.config.graph_config.api_key = self.config.api_key
|
||||||
|
|
||||||
self.graph:ReactGraph = self.config.graph_config.setup()
|
self.graph:GraphBase = self.config.graph_config.setup()
|
||||||
|
|
||||||
|
|
||||||
|
def show_graph(self):
|
||||||
|
if hasattr(self.graph, "show_graph"):
|
||||||
|
logger.info("showing graph")
|
||||||
|
self.graph.show_graph()
|
||||||
|
else:
|
||||||
|
logger.info(f"show graph not supported for {type(self.graph)}")
|
||||||
|
|
||||||
def invoke(self, *nargs, **kwargs):
|
def invoke(self, *nargs, **kwargs):
|
||||||
return self.graph.invoke(*nargs, **kwargs)
|
return self.graph.invoke(*nargs, **kwargs)
|
||||||
|
|||||||
Reference in New Issue
Block a user