From 4c106a597a6393feee633050c6c5cb74971062b1 Mon Sep 17 00:00:00 2001 From: goulustis Date: Tue, 6 Jan 2026 00:22:46 +0800 Subject: [PATCH] type checking --- lang_agent/base.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/lang_agent/base.py b/lang_agent/base.py index 101a425..9c6f97f 100644 --- a/lang_agent/base.py +++ b/lang_agent/base.py @@ -1,4 +1,6 @@ -from typing import List, Callable, Tuple, Dict, AsyncIterator +from __future__ import annotations + +from typing import List, Callable, Tuple, Dict, AsyncIterator, TYPE_CHECKING from abc import ABC, abstractmethod from PIL import Image from io import BytesIO @@ -7,11 +9,14 @@ from loguru import logger import jax from langgraph.graph.state import CompiledStateGraph -from langchain_core.messages import BaseMessage, HumanMessage +from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage from langchain_core.messages.base import BaseMessageChunk from lang_agent.components.text_releaser import TextReleaser, AsyncTextReleaser +if TYPE_CHECKING: + from lang_agent.graphs.graph_states import State + class LangToolBase(ABC): """ @@ -130,6 +135,25 @@ class GraphBase(ABC): assert len(nargs[0]["messages"]) >= 2, "need at least 1 system and 1 human message" assert len(kwargs) == 0, "due to inp assumptions" + + def _agent_call_template(self, system_prompt:str, + model:CompiledStateGraph, + state:State): + if state.get("messages") is not None: + inp = state["messages"], state["inp"][1] + else: + inp = state["inp"] + + inp = {"messages":[ + SystemMessage( + system_prompt + ), + *state["inp"][0]["messages"][1:] + ]}, state["inp"][1] + + + out = model.invoke(*inp) + return {"messages": out} def show_graph(self, ret_img:bool=False): #NOTE: just a useful tool for debugging; has zero useful functionality