type checking

This commit is contained in:
2026-01-06 00:22:46 +08:00
parent f1b8da8209
commit 4c106a597a

View File

@@ -1,4 +1,6 @@
from typing import List, Callable, Tuple, Dict, AsyncIterator
from __future__ import annotations
from typing import List, Callable, Tuple, Dict, AsyncIterator, TYPE_CHECKING
from abc import ABC, abstractmethod
from PIL import Image
from io import BytesIO
@@ -7,11 +9,14 @@ from loguru import logger
import jax
from langgraph.graph.state import CompiledStateGraph
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_core.messages.base import BaseMessageChunk
from lang_agent.components.text_releaser import TextReleaser, AsyncTextReleaser
if TYPE_CHECKING:
from lang_agent.graphs.graph_states import State
class LangToolBase(ABC):
"""
@@ -130,6 +135,25 @@ class GraphBase(ABC):
assert len(nargs[0]["messages"]) >= 2, "need at least 1 system and 1 human message"
assert len(kwargs) == 0, "due to inp assumptions"
def _agent_call_template(self, system_prompt:str,
model:CompiledStateGraph,
state:State):
if state.get("messages") is not None:
inp = state["messages"], state["inp"][1]
else:
inp = state["inp"]
inp = {"messages":[
SystemMessage(
system_prompt
),
*state["inp"][0]["messages"][1:]
]}, state["inp"][1]
out = model.invoke(*inp)
return {"messages": out}
def show_graph(self, ret_img:bool=False):
#NOTE: just a useful tool for debugging; has zero useful functionality