226 lines
8.4 KiB
Python
226 lines
8.4 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import List, Callable, Tuple, Dict, AsyncIterator, TYPE_CHECKING
|
|
from abc import ABC, abstractmethod
|
|
from PIL import Image
|
|
from io import BytesIO
|
|
import matplotlib.pyplot as plt
|
|
from loguru import logger
|
|
|
|
from langgraph.graph.state import CompiledStateGraph
|
|
from langgraph.checkpoint.memory import MemorySaver
|
|
|
|
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
|
|
from langchain_core.messages.base import BaseMessageChunk
|
|
|
|
from lang_agent.components.text_releaser import TextReleaser, AsyncTextReleaser
|
|
from lang_agent.utils import tree_leaves
|
|
|
|
if TYPE_CHECKING:
|
|
from lang_agent.graphs.graph_states import State
|
|
|
|
|
|
class LangToolBase(ABC):
    """
    Abstract base for local tools.

    Inherit from this class when adding a new local tool and implement
    ``get_tool_fnc``.
    """

    @abstractmethod
    def get_tool_fnc(self) -> List[Callable]:
        """Return the callables this tool exposes (must be overridden)."""
|
|
|
|
|
|
class GraphBase(ABC):
    """
    Base class for agents built around a compiled LangGraph ``workflow``.

    Provides default sync/async ``invoke`` entry points (plain or streaming),
    a helper for calling a sub-agent with a system prompt, graph rendering
    for debugging, and checkpoint-memory clearing.

    Besides the attributes declared below, methods here also read:
      * ``self.config.llm_name`` (in ``_validate_input``) -- not declared in
        this class; subclasses are expected to provide ``config``.
      * ``self.memory`` (in ``clear_memory`` / ``aclear_memory``) -- optional.
    """
    workflow: CompiledStateGraph    # the main workflow
    streamable_tags: List[List[str]]    # which llm to stream outputs; see routing.py for complex usage
    # use to control when to start streaming; see routing.py for complex usage
    textreleaser_delay_keys: Tuple[str | None, str | None] = (None, None)

    def _stream_item(self, mode, out):
        """
        Pick what (if anything) to forward from one workflow stream event.

        Args:
            mode: stream mode of the event (``"values"`` or, otherwise,
                treated as a ``"messages"`` event).
            out: event payload; a mapping for ``"values"`` events, a
                ``(chunk, metadata)`` pair otherwise.

        Returns:
            The item to hand to the text releaser, or ``None`` to skip the event.
        """
        if mode == "values":
            # full state snapshot: forward its message list when present
            return out.get("messages")

        chunk, metadata = out
        # the event's tag list must equal one of the configured tag lists
        if metadata.get("tags") not in self.streamable_tags:
            return None

        if isinstance(chunk, (BaseMessageChunk, BaseMessage)) and getattr(chunk, "content", None):
            return chunk.content
        return None

    def _format_output(self, state, as_raw: bool):
        """
        Turn a final workflow state into the value returned by invoke/ainvoke.

        Pretty-prints every ``BaseMessage`` found by ``tree_leaves(state)``,
        then returns the whole leaf list (``as_raw=True``) or the ``content``
        of the last leaf.
        """
        msg_list = tree_leaves(state)

        for e in msg_list:
            if isinstance(e, BaseMessage):
                e.pretty_print()

        if as_raw:
            return msg_list

        return msg_list[-1].content

    def _stream_result(self, *nargs, **kwargs):
        """
        Generator that runs the workflow in streaming mode.

        Events are filtered by ``_stream_item`` and passed through a
        ``TextReleaser`` built from ``textreleaser_delay_keys``; whatever the
        releaser yields is yielded to the caller.
        """

        def text_iterator():
            for _, mode, out in self.workflow.stream({"inp": nargs},
                                                     stream_mode=["messages", "values"],
                                                     subgraphs=True,
                                                     **kwargs):
                item = self._stream_item(mode, out)
                if item is not None:
                    yield item

        text_releaser = TextReleaser(*self.textreleaser_delay_keys)
        logger.info("streaming output")
        for chunk in text_releaser.release(text_iterator()):
            yield chunk

    # NOTE: DEFAULT IMPLEMENTATION; Override to support your class
    def invoke(self, *nargs, as_stream: bool = False, as_raw: bool = False, **kwargs):
        """
        Run the workflow synchronously.

        Args:
            *nargs: positional inputs, stored in the state under ``"inp"``;
                ``nargs[0]`` must be a mapping with a ``"messages"`` list.
            as_stream: return a generator of streamed output instead of the
                final result.
            as_raw: when not streaming, return the full list of state leaves
                instead of only the last message's content.
            **kwargs: must be empty (enforced by ``_validate_input``).

        Raises:
            AssertionError: if any keyword arguments are supplied.
        """
        self._validate_input(*nargs, **kwargs)

        if as_stream:
            # Stream messages from the workflow
            print("\033[93m====================STREAM OUTPUT=============================\033[0m")
            return self._stream_result(*nargs, **kwargs)

        state = self.workflow.invoke({"inp": nargs})
        return self._format_output(state, as_raw)

    # NOTE: DEFAULT IMPLEMENTATION; Override to support your class
    async def ainvoke(self, *nargs, as_stream: bool = False, as_raw: bool = False, **kwargs):
        """
        Async version of invoke using LangGraph's native async support.

        Same arguments and return values as ``invoke``; with ``as_stream=True``
        an async generator is returned.
        """
        self._validate_input(*nargs, **kwargs)

        if as_stream:
            # Stream messages from the workflow asynchronously
            print("\033[93m====================ASYNC STREAM OUTPUT=============================\033[0m")
            return self._astream_result(*nargs, **kwargs)

        state = await self.workflow.ainvoke({"inp": nargs})
        return self._format_output(state, as_raw)

    async def _astream_result(self, *nargs, **kwargs) -> AsyncIterator[str]:
        """Async streaming using LangGraph's astream method."""

        async def text_iterator():
            async for _, mode, out in self.workflow.astream({"inp": nargs},
                                                            stream_mode=["messages", "values"],
                                                            subgraphs=True,
                                                            **kwargs):
                item = self._stream_item(mode, out)
                if item is not None:
                    yield item

        text_releaser = AsyncTextReleaser(*self.textreleaser_delay_keys)
        logger.info("streaming output")
        async for chunk in text_releaser.release(text_iterator()):
            yield chunk

    def _validate_input(self, *nargs, **kwargs):
        """
        Echo the incoming human messages and reject keyword arguments.

        Prints every ``HumanMessage`` in ``nargs[0]["messages"]`` and the
        configured model name (``self.config.llm_name``).

        Raises:
            AssertionError: if ``kwargs`` is non-empty; the non-streaming
                paths pass only ``nargs`` to the workflow.
        """
        print("\033[93m====================INPUT HUMAN MESSAGES=============================\033[0m")
        for e in nargs[0]["messages"]:
            if isinstance(e, HumanMessage):
                e.pretty_print()
        print("\033[93m====================END INPUT HUMAN MESSAGES=============================\033[0m")
        print(f"\033[93m model used: {self.config.llm_name}\033[0m")

        # explicit raise (not `assert`) so the check survives `python -O`
        if kwargs:
            raise AssertionError("due to inp assumptions")

    def _get_inp_msgs(self, state: State):
        """Return the input messages from ``state["inp"][0]`` minus any ``SystemMessage``."""
        msgs = state["inp"][0]["messages"]
        return [e for e in msgs if not isinstance(e, SystemMessage)]

    def _agent_call_template(self, system_prompt: str,
                             model: CompiledStateGraph,
                             state: State,
                             human_msg: str | None = None):
        """
        Call ``model`` with a fresh system prompt plus the original input messages.

        Args:
            system_prompt: text for the leading ``SystemMessage``.
            model: compiled graph/agent to invoke.
            state: current graph state; ``state["inp"][0]`` holds the input
                messages and ``state["inp"][1]`` is forwarded unchanged as the
                second positional argument to ``model.invoke``.
            human_msg: optional extra ``HumanMessage`` appended after the
                input messages.

        Returns:
            ``{"messages": <result of model.invoke>}``.
        """
        # NOTE: an earlier pre-computation of `inp` from state["messages"] was
        # always overwritten before use and has been dropped.
        messages = [
            SystemMessage(system_prompt),
            *self._get_inp_msgs(state),
        ]

        if human_msg is not None:
            messages.append(HumanMessage(human_msg))

        inp = ({"messages": messages}, state["inp"][1])

        out = model.invoke(*inp)
        return {"messages": out}

    def show_graph(self, ret_img: bool = False):
        """
        Render the workflow graph as a PNG (debugging aid only).

        Args:
            ret_img: return the PIL image instead of displaying it with
                matplotlib.

        Returns:
            The PIL image when ``ret_img`` is true, otherwise ``None``.

        Raises:
            AssertionError: if the instance has no ``workflow`` attribute.
        """
        #NOTE: just a useful tool for debugging; has zero useful functionality

        # explicit raise (not `assert`) so the check survives `python -O`
        if not hasattr(self, "workflow"):
            raise AssertionError(f"{type(self)} does not have workflow, this is unsupported")

        logger.info("creating image")
        img = Image.open(BytesIO(self.workflow.get_graph().draw_mermaid_png()))

        if ret_img:
            return img

        plt.imshow(img)
        plt.show()

    def clear_memory(self):
        """Delete every thread held by ``self.memory`` when it is a ``MemorySaver``; otherwise do nothing."""
        # Clear the agent's (LangChain) memory if available
        memory = getattr(self, "memory", None)
        if memory is None:
            return

        if isinstance(memory, MemorySaver):
            # snapshot the keys: deleting a thread removes it from `storage`,
            # which would invalidate a live iterator over the dict
            for thread_id in list(memory.storage):
                memory.delete_thread(thread_id)

    async def aclear_memory(self):
        """Async counterpart of ``clear_memory`` using ``adelete_thread``."""
        # Clear the agent's (LangChain) memory if available (async version)
        memory = getattr(self, "memory", None)
        if memory is None:
            return

        if isinstance(memory, MemorySaver):
            # snapshot the keys: deleting a thread removes it from `storage`,
            # which would invalidate a live iterator over the dict
            for thread_id in list(memory.storage):
                await memory.adelete_thread(thread_id)
|
|
|
|
|
|
|
|
class ToolNodeBase(GraphBase):
    """Base class for tool nodes; subclasses must supply ``get_streamable_tags`` and ``invoke``."""

    @abstractmethod
    def get_streamable_tags(self) -> List[List[str]]:
        """
        Names of the llm models to listen to while streaming.

        NOTE: the result must be a list of single-tag lists, i.e. [['A1'], ['A2'] ...]
        """
        default_tags = [["tool_llm"]]
        return default_tags

    def get_delay_keys(self) -> Tuple[str, str]:
        """
        Return two words: the first starts delayed yielding, the second ends it.

        Expected format is ('[key1]', '[key2]'); key1 is starting, key2 is ending.
        This default returns a pair of ``None``.
        """
        return (None, None)

    @abstractmethod
    def invoke(self, inp) -> Dict[str, List[BaseMessage]]:
        """Run the tool node on ``inp`` (must be overridden)."""

    async def ainvoke(self, inp) -> Dict[str, List[BaseMessage]]:
        """Async version of invoke. Subclasses should override for true async support."""
        message = "Subclass should implement ainvoke for async support"
        raise NotImplementedError(message)