type checking
This commit is contained in:
@@ -1,4 +1,6 @@
|
|||||||
from typing import List, Callable, Tuple, Dict, AsyncIterator
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import List, Callable, Tuple, Dict, AsyncIterator, TYPE_CHECKING
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
@@ -7,11 +9,14 @@ from loguru import logger
|
|||||||
import jax
|
import jax
|
||||||
|
|
||||||
from langgraph.graph.state import CompiledStateGraph
|
from langgraph.graph.state import CompiledStateGraph
|
||||||
from langchain_core.messages import BaseMessage, HumanMessage
|
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
|
||||||
from langchain_core.messages.base import BaseMessageChunk
|
from langchain_core.messages.base import BaseMessageChunk
|
||||||
|
|
||||||
from lang_agent.components.text_releaser import TextReleaser, AsyncTextReleaser
|
from lang_agent.components.text_releaser import TextReleaser, AsyncTextReleaser
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from lang_agent.graphs.graph_states import State
|
||||||
|
|
||||||
|
|
||||||
class LangToolBase(ABC):
|
class LangToolBase(ABC):
|
||||||
"""
|
"""
|
||||||
@@ -130,6 +135,25 @@ class GraphBase(ABC):
|
|||||||
|
|
||||||
assert len(nargs[0]["messages"]) >= 2, "need at least 1 system and 1 human message"
|
assert len(nargs[0]["messages"]) >= 2, "need at least 1 system and 1 human message"
|
||||||
assert len(kwargs) == 0, "due to inp assumptions"
|
assert len(kwargs) == 0, "due to inp assumptions"
|
||||||
|
|
||||||
|
def _agent_call_template(self, system_prompt:str,
|
||||||
|
model:CompiledStateGraph,
|
||||||
|
state:State):
|
||||||
|
if state.get("messages") is not None:
|
||||||
|
inp = state["messages"], state["inp"][1]
|
||||||
|
else:
|
||||||
|
inp = state["inp"]
|
||||||
|
|
||||||
|
inp = {"messages":[
|
||||||
|
SystemMessage(
|
||||||
|
system_prompt
|
||||||
|
),
|
||||||
|
*state["inp"][0]["messages"][1:]
|
||||||
|
]}, state["inp"][1]
|
||||||
|
|
||||||
|
|
||||||
|
out = model.invoke(*inp)
|
||||||
|
return {"messages": out}
|
||||||
|
|
||||||
def show_graph(self, ret_img:bool=False):
|
def show_graph(self, ret_img:bool=False):
|
||||||
#NOTE: just a useful tool for debugging; has zero useful functionality
|
#NOTE: just a useful tool for debugging; has zero useful functionality
|
||||||
|
|||||||
Reference in New Issue
Block a user