better prints

This commit is contained in:
2025-10-22 20:06:23 +08:00
parent 29fb7fd927
commit d5e8eb4781

View File

@@ -6,13 +6,14 @@ from loguru import logger
from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt
import jax
from lang_agent.config import KeyConfig
from lang_agent.tool_manager import ToolManager, ToolManagerConfig
from lang_agent.base import GraphBase
from langchain.chat_models import init_chat_model
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
from langchain.agents import create_agent
from langgraph.graph import StateGraph, START, END
@@ -65,14 +66,21 @@ class RoutingGraph(GraphBase):
if as_stream:
# TODO this doesn't stream the entire process, we are blind
for step in self.workflow.stream({"inp": nargs}, stream_mode="values", **kwargs):
if "messages" in step:
step["messages"]["messages"][-1].pretty_print()
for step in self.workflow.stream({"inp": nargs}, stream_mode="updates", **kwargs):
# if "messages" in step:
# step["messages"]["messages"][-1].pretty_print()
last_el = jax.tree.leaves(step)[-1]
if isinstance(last_el, str):
logger.info(last_el)
elif isinstance(last_el, BaseMessage):
last_el.pretty_print()
state = step
else:
state = self.workflow.invoke({"inp": nargs})
return state["messages"]
# return state["messages"]
return jax.tree.leaves(state)[-1].content
def _build_modules(self):
self.llm = init_chat_model(model=self.config.llm_name,
@@ -135,7 +143,7 @@ class RoutingGraph(GraphBase):
"You must use tool to complete the possible task"
),
# self._get_human_msg(state)
*state["inp"][0][1:]
*state["inp"][0]["messages"][1:]
]}, state["inp"][1]
out = self.tool_model.invoke(*inp)