Files
lang-agent/lang_agent/base.py
2026-01-05 22:41:40 +08:00

172 lines
6.3 KiB
Python

from abc import ABC, abstractmethod
from io import BytesIO
from typing import AsyncIterator, Callable, Dict, List, Optional, Tuple

import jax
import matplotlib.pyplot as plt
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.messages.base import BaseMessageChunk
from langgraph.graph.state import CompiledStateGraph
from loguru import logger
from PIL import Image

from lang_agent.components.text_releaser import TextReleaser, AsyncTextReleaser
class LangToolBase(ABC):
    """
    Abstract base class to inherit from when creating a new local tool.

    Concrete subclasses must implement :meth:`get_tool_fnc` and return the
    list of callables that the tool exposes to the agent.
    """

    @abstractmethod
    def get_tool_fnc(self) -> List[Callable]:
        """Return the callables implementing this tool."""
        pass
class GraphBase(ABC):
    """
    Base class wrapping a compiled LangGraph workflow with synchronous and
    asynchronous invocation, plus tag-filtered token streaming.

    Subclasses are expected to provide (not declared or set here):
      - workflow: the compiled LangGraph state graph to run
      - streamable_tags: tag-lists; only message chunks whose metadata ``tags``
        equal one of these entries are streamed to the caller
      - tool_node: object exposing ``get_delay_keys()`` — TODO confirm, see
        ToolNodeBase below
      - config: object exposing ``llm_name`` (used for logging only)
    """
    # Populated by subclasses; annotations only — no values assigned here.
    workflow: CompiledStateGraph
    streamable_tags: List[List[str]]

    def _stream_result(self, *nargs, **kwargs):
        # Synchronous token streaming: filter workflow message chunks by tag,
        # then pass them through TextReleaser (which presumably withholds text
        # between the tool node's delay keys — see get_delay_keys) before
        # yielding to the caller.
        def text_iterator():
            for chunk, metadata in self.workflow.stream({"inp": nargs},
                                                        stream_mode="messages",
                                                        subgraphs=True,
                                                        **kwargs):
                # With subgraphs=True the stream may yield
                # (namespace, (chunk, metadata)); detect that shape and
                # unpack the real (chunk, metadata) pair.
                if isinstance(metadata, tuple):
                    chunk, metadata = metadata
                # Skip chunks whose tag-list is not one of the streamable
                # tag-lists (whole-list membership, not per-tag).
                tags = metadata.get("tags")
                if not (tags in self.streamable_tags):
                    continue
                # Only forward message objects that carry non-empty content.
                if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
                    yield chunk.content
        text_releaser = TextReleaser(*self.tool_node.get_delay_keys())
        logger.info("streaming output")
        for chunk in text_releaser.release(text_iterator()):
            # Echo streamed text to stdout in green for interactive debugging.
            print(f"\033[92m{chunk}\033[0m", end="", flush=True)
            yield chunk

    # NOTE: DEFAULT IMPLEMENTATION; override to support your class
    def invoke(self, *nargs, as_stream: bool = False, as_raw: bool = False, **kwargs):
        """
        Run the workflow synchronously.

        nargs[0] must be a dict with a "messages" key (see _validate_input).
        as_stream: if True, return a generator of streamed text chunks
                   instead of the final result.
        as_raw: if True, return the full list of leaf values from the final
                state instead of only the last message's content.
        """
        self._validate_input(*nargs, **kwargs)
        if as_stream:
            # Stream messages from the workflow
            print("\033[93m====================STREAM OUTPUT=============================\033[0m")
            return self._stream_result(*nargs, **kwargs)
        else:
            state = self.workflow.invoke({"inp": nargs})
            # jax.tree.leaves flattens the (possibly nested) state into a flat
            # list of leaves; messages among them are pretty-printed for debug.
            msg_list = jax.tree.leaves(state)
            for e in msg_list:
                if isinstance(e, BaseMessage):
                    e.pretty_print()
            if as_raw:
                return msg_list
            # assumes the last leaf has .content (a message) — TODO confirm
            # this holds for every graph's final state
            return msg_list[-1].content

    # NOTE: DEFAULT IMPLEMENTATION; override to support your class
    async def ainvoke(self, *nargs, as_stream: bool = False, as_raw: bool = False, **kwargs):
        """Async version of invoke using LangGraph's native async support."""
        self._validate_input(*nargs, **kwargs)
        if as_stream:
            # Stream messages from the workflow asynchronously.
            # NOTE(review): returns the async generator without awaiting it —
            # callers must `async for` over the result.
            print("\033[93m====================ASYNC STREAM OUTPUT=============================\033[0m")
            return self._astream_result(*nargs, **kwargs)
        else:
            state = await self.workflow.ainvoke({"inp": nargs})
            # Same flatten-and-pretty-print flow as the sync invoke above.
            msg_list = jax.tree.leaves(state)
            for e in msg_list:
                if isinstance(e, BaseMessage):
                    e.pretty_print()
            if as_raw:
                return msg_list
            # assumes the last leaf has .content (a message) — TODO confirm
            return msg_list[-1].content

    async def _astream_result(self, *nargs, **kwargs) -> AsyncIterator[str]:
        """Async streaming using LangGraph's astream method."""
        async def text_iterator():
            async for chunk, metadata in self.workflow.astream(
                {"inp": nargs},
                stream_mode="messages",
                subgraphs=True,
                **kwargs
            ):
                # Same (namespace, (chunk, metadata)) unpacking as the sync path.
                if isinstance(metadata, tuple):
                    chunk, metadata = metadata
                # Whole-list tag membership filter, as in _stream_result.
                tags = metadata.get("tags")
                if not (tags in self.streamable_tags):
                    continue
                if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
                    yield chunk.content
        text_releaser = AsyncTextReleaser(*self.tool_node.get_delay_keys())
        async for chunk in text_releaser.release(text_iterator()):
            # Echo streamed text to stdout in green for interactive debugging.
            print(f"\033[92m{chunk}\033[0m", end="", flush=True)
            yield chunk

    def _validate_input(self, *nargs, **kwargs):
        # Print the incoming human messages and enforce the input contract:
        # nargs[0] is a dict holding "messages" with at least two entries,
        # and no extra keyword arguments are allowed.
        # NOTE(review): uses `assert` for validation — stripped under `python -O`.
        print("\033[93m====================INPUT HUMAN MESSAGES=============================\033[0m")
        for e in nargs[0]["messages"]:
            if isinstance(e, HumanMessage):
                e.pretty_print()
        print("\033[93m====================END INPUT HUMAN MESSAGES=============================\033[0m")
        print(f"\033[93m model used: {self.config.llm_name}\033[0m")
        assert len(nargs[0]["messages"]) >= 2, "need at least 1 system and 1 human message"
        assert len(kwargs) == 0, "due to inp assumptions"

    def show_graph(self, ret_img: bool = False):
        # NOTE: just a useful tool for debugging; has zero useful functionality
        err_str = f"{type(self)} does not have workflow, this is unsupported"
        assert hasattr(self, "workflow"), err_str
        logger.info("creating image")
        # Render the graph's mermaid diagram to PNG bytes and load via PIL.
        img = Image.open(BytesIO(self.workflow.get_graph().draw_mermaid_png()))
        if not ret_img:
            # Display inline with matplotlib (blocks until the window closes).
            plt.imshow(img)
            plt.show()
        else:
            return img
class ToolNodeBase(GraphBase):
    """
    Base class for tool-execution nodes used inside a GraphBase workflow.

    Subclasses must implement invoke() (and optionally override ainvoke())
    to run the tools, and get_streamable_tags() to declare which LLM
    outputs should be streamed.
    """

    @abstractmethod
    def get_streamable_tags(self) -> List[List[str]]:
        """
        Return the tag-lists of the llm models to listen to when streaming.

        NOTE: must be of shape [['A1'], ['A2'], ...]
        """
        return [["tool_llm"]]

    # FIX: annotation was Tuple[str, str], but the default returns (None, None);
    # widened to Optional so the declared type matches the actual contract.
    def get_delay_keys(self) -> Tuple[Optional[str], Optional[str]]:
        """
        Return two markers: one that starts delayed yielding, one that ends it.

        They should be of format ('[key1]', '[key2]'); key1 starts the delay,
        key2 ends it. The default (None, None) disables delayed yielding.
        """
        return None, None

    @abstractmethod
    def invoke(self, inp) -> Dict[str, List[BaseMessage]]:
        """Run the tool node on *inp* and return the resulting message mapping."""
        pass

    async def ainvoke(self, inp) -> Dict[str, List[BaseMessage]]:
        """Async version of invoke. Subclasses should override for true async support."""
        raise NotImplementedError("Subclass should implement ainvoke for async support")