support graph showing

This commit is contained in:
2025-10-22 18:19:24 +08:00
parent a26e42fdf5
commit 21ca1be80c
2 changed files with 23 additions and 7 deletions

View File

@@ -3,6 +3,9 @@ from typing import Type, TypedDict, Literal, Dict, List, Tuple
import tyro import tyro
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from loguru import logger from loguru import logger
from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt
from lang_agent.config import KeyConfig from lang_agent.config import KeyConfig
from lang_agent.tool_manager import ToolManager, ToolManagerConfig from lang_agent.tool_manager import ToolManager, ToolManagerConfig
@@ -130,7 +133,9 @@ class RoutingGraph(GraphBase):
inp = {"messages":[ inp = {"messages":[
SystemMessage( SystemMessage(
"You must use tool to complete the possible task" "You must use tool to complete the possible task"
),self._get_human_msg(state) ),
# self._get_human_msg(state)
*state["inp"][0][1:]
]}, state["inp"][1] ]}, state["inp"][1]
out = self.tool_model.invoke(*inp) out = self.tool_model.invoke(*inp)
@@ -161,3 +166,8 @@ class RoutingGraph(GraphBase):
workflow = builder.compile() workflow = builder.compile()
return workflow return workflow
def show_graph(self):
img = Image.open(BytesIO(self.workflow.get_graph().draw_mermaid_png()))
plt.imshow(img)
plt.show()

View File

@@ -14,8 +14,8 @@ from langchain.agents import create_agent
from langgraph.checkpoint.memory import MemorySaver from langgraph.checkpoint.memory import MemorySaver
from lang_agent.config import InstantiateConfig, KeyConfig from lang_agent.config import InstantiateConfig, KeyConfig
from lang_agent.graphs import AnnotatedGraph from lang_agent.graphs import AnnotatedGraph, ReactGraphConfig, RoutingConfig
from lang_agent.graphs.react import ReactGraph, ReactGraphConfig from lang_agent.base import GraphBase
@@ -43,7 +43,8 @@ class PipelineConfig(KeyConfig):
"""what is my port""" """what is my port"""
# graph_config: ReactGraphConfig = field(default_factory=ReactGraphConfig) # graph_config: ReactGraphConfig = field(default_factory=ReactGraphConfig)
graph_config: AnnotatedGraph = field(default_factory=ReactGraphConfig) # graph_config: AnnotatedGraph = field(default_factory=ReactGraphConfig)
graph_config: AnnotatedGraph = field(default_factory=RoutingConfig)
@@ -65,9 +66,14 @@ class Pipeline:
self.config.graph_config.base_url = self.config.base_url if self.config.base_url is not None else self.config.graph_config.base_url self.config.graph_config.base_url = self.config.base_url if self.config.base_url is not None else self.config.graph_config.base_url
self.config.graph_config.api_key = self.config.api_key self.config.graph_config.api_key = self.config.api_key
self.graph:ReactGraph = self.config.graph_config.setup() self.graph:GraphBase = self.config.graph_config.setup()
def show_graph(self):
if hasattr(self.graph, "show_graph"):
logger.info("showing graph")
self.graph.show_graph()
else:
logger.info(f"show graph not supported for {type(self.graph)}")
def invoke(self, *nargs, **kwargs): def invoke(self, *nargs, **kwargs):
return self.graph.invoke(*nargs, **kwargs) return self.graph.invoke(*nargs, **kwargs)