better prints

This commit is contained in:
2025-10-22 20:06:23 +08:00
parent 29fb7fd927
commit d5e8eb4781

View File

@@ -6,13 +6,14 @@ from loguru import logger
from PIL import Image from PIL import Image
from io import BytesIO from io import BytesIO
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import jax
from lang_agent.config import KeyConfig from lang_agent.config import KeyConfig
from lang_agent.tool_manager import ToolManager, ToolManagerConfig from lang_agent.tool_manager import ToolManager, ToolManagerConfig
from lang_agent.base import GraphBase from lang_agent.base import GraphBase
from langchain.chat_models import init_chat_model from langchain.chat_models import init_chat_model
from langchain_core.messages import SystemMessage, HumanMessage from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
from langchain.agents import create_agent from langchain.agents import create_agent
from langgraph.graph import StateGraph, START, END from langgraph.graph import StateGraph, START, END
@@ -65,14 +66,21 @@ class RoutingGraph(GraphBase):
if as_stream: if as_stream:
# TODO this doesn't stream the entire process, we are blind # TODO this doesn't stream the entire process, we are blind
for step in self.workflow.stream({"inp": nargs}, stream_mode="values", **kwargs): for step in self.workflow.stream({"inp": nargs}, stream_mode="updates", **kwargs):
if "messages" in step: # if "messages" in step:
step["messages"]["messages"][-1].pretty_print() # step["messages"]["messages"][-1].pretty_print()
last_el = jax.tree.leaves(step)[-1]
if isinstance(last_el, str):
logger.info(last_el)
elif isinstance(last_el, BaseMessage):
last_el.pretty_print()
state = step state = step
else: else:
state = self.workflow.invoke({"inp": nargs}) state = self.workflow.invoke({"inp": nargs})
return state["messages"] # return state["messages"]
return jax.tree.leaves(state)[-1].content
def _build_modules(self): def _build_modules(self):
self.llm = init_chat_model(model=self.config.llm_name, self.llm = init_chat_model(model=self.config.llm_name,
@@ -135,7 +143,7 @@ class RoutingGraph(GraphBase):
"You must use tool to complete the possible task" "You must use tool to complete the possible task"
), ),
# self._get_human_msg(state) # self._get_human_msg(state)
*state["inp"][0][1:] *state["inp"][0]["messages"][1:]
]}, state["inp"][1] ]}, state["inp"][1]
out = self.tool_model.invoke(*inp) out = self.tool_model.invoke(*inp)