better prints
This commit is contained in:
@@ -6,13 +6,14 @@ from loguru import logger
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
import jax
|
||||||
|
|
||||||
from lang_agent.config import KeyConfig
|
from lang_agent.config import KeyConfig
|
||||||
from lang_agent.tool_manager import ToolManager, ToolManagerConfig
|
from lang_agent.tool_manager import ToolManager, ToolManagerConfig
|
||||||
from lang_agent.base import GraphBase
|
from lang_agent.base import GraphBase
|
||||||
|
|
||||||
from langchain.chat_models import init_chat_model
|
from langchain.chat_models import init_chat_model
|
||||||
from langchain_core.messages import SystemMessage, HumanMessage
|
from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
|
||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
|
|
||||||
from langgraph.graph import StateGraph, START, END
|
from langgraph.graph import StateGraph, START, END
|
||||||
@@ -65,14 +66,21 @@ class RoutingGraph(GraphBase):
|
|||||||
|
|
||||||
if as_stream:
|
if as_stream:
|
||||||
# TODO: this doesn't stream the entire process, we are blind
|
# TODO: this doesn't stream the entire process, we are blind
|
||||||
for step in self.workflow.stream({"inp": nargs}, stream_mode="values", **kwargs):
|
for step in self.workflow.stream({"inp": nargs}, stream_mode="updates", **kwargs):
|
||||||
if "messages" in step:
|
# if "messages" in step:
|
||||||
step["messages"]["messages"][-1].pretty_print()
|
# step["messages"]["messages"][-1].pretty_print()
|
||||||
|
last_el = jax.tree.leaves(step)[-1]
|
||||||
|
if isinstance(last_el, str):
|
||||||
|
logger.info(last_el)
|
||||||
|
elif isinstance(last_el, BaseMessage):
|
||||||
|
last_el.pretty_print()
|
||||||
|
|
||||||
state = step
|
state = step
|
||||||
else:
|
else:
|
||||||
state = self.workflow.invoke({"inp": nargs})
|
state = self.workflow.invoke({"inp": nargs})
|
||||||
|
|
||||||
return state["messages"]
|
# return state["messages"]
|
||||||
|
return jax.tree.leaves(state)[-1].content
|
||||||
|
|
||||||
def _build_modules(self):
|
def _build_modules(self):
|
||||||
self.llm = init_chat_model(model=self.config.llm_name,
|
self.llm = init_chat_model(model=self.config.llm_name,
|
||||||
@@ -135,7 +143,7 @@ class RoutingGraph(GraphBase):
|
|||||||
"You must use tool to complete the possible task"
|
"You must use tool to complete the possible task"
|
||||||
),
|
),
|
||||||
# self._get_human_msg(state)
|
# self._get_human_msg(state)
|
||||||
*state["inp"][0][1:]
|
*state["inp"][0]["messages"][1:]
|
||||||
]}, state["inp"][1]
|
]}, state["inp"][1]
|
||||||
|
|
||||||
out = self.tool_model.invoke(*inp)
|
out = self.tool_model.invoke(*inp)
|
||||||
|
|||||||
Reference in New Issue
Block a user