inherit toolnodebase
This commit is contained in:
@@ -6,11 +6,10 @@ import os.path as osp
|
|||||||
|
|
||||||
from lang_agent.config import InstantiateConfig, KeyConfig
|
from lang_agent.config import InstantiateConfig, KeyConfig
|
||||||
from lang_agent.tool_manager import ToolManager
|
from lang_agent.tool_manager import ToolManager
|
||||||
from lang_agent.base import GraphBase
|
from lang_agent.base import ToolNodeBase
|
||||||
from lang_agent.graphs.graph_states import State, ChattyToolState
|
from lang_agent.graphs.graph_states import State, ChattyToolState
|
||||||
from lang_agent.utils import make_llm
|
from lang_agent.utils import make_llm
|
||||||
|
|
||||||
from langchain_core.language_models import BaseChatModel
|
|
||||||
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
|
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
|
||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
from langchain.chat_models import init_chat_model
|
from langchain.chat_models import init_chat_model
|
||||||
@@ -26,7 +25,7 @@ class ToolNodeConfig(InstantiateConfig):
|
|||||||
tool_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "tool_prompt.txt")
|
tool_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "tool_prompt.txt")
|
||||||
|
|
||||||
|
|
||||||
class ToolNode(GraphBase):
|
class ToolNode(ToolNodeBase):
|
||||||
def __init__(self, config: ToolNodeConfig,
|
def __init__(self, config: ToolNodeConfig,
|
||||||
tool_manager:ToolManager,
|
tool_manager:ToolManager,
|
||||||
memory:MemorySaver):
|
memory:MemorySaver):
|
||||||
@@ -71,7 +70,7 @@ class ChattyToolNodeConfig(KeyConfig, ToolNodeConfig):
|
|||||||
chatty_sys_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "chatty_prompt.txt")
|
chatty_sys_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "chatty_prompt.txt")
|
||||||
|
|
||||||
|
|
||||||
class ChattyToolNode(GraphBase):
|
class ChattyToolNode(ToolNodeBase):
|
||||||
def __init__(self, config:ChattyToolNodeConfig,
|
def __init__(self, config:ChattyToolNodeConfig,
|
||||||
tool_manager:ToolManager,
|
tool_manager:ToolManager,
|
||||||
memory:MemorySaver):
|
memory:MemorySaver):
|
||||||
@@ -108,6 +107,9 @@ class ChattyToolNode(GraphBase):
|
|||||||
with open(self.config.tool_prompt_f, "r") as f:
|
with open(self.config.tool_prompt_f, "r") as f:
|
||||||
self.tool_sys_prompt = f.read()
|
self.tool_sys_prompt = f.read()
|
||||||
|
|
||||||
|
def get_streamable_tags(self):
|
||||||
|
return [["chatty_llm"], ["reit_llm"]]
|
||||||
|
|
||||||
def invoke(self, state:State):
|
def invoke(self, state:State):
|
||||||
self.tool_done = False
|
self.tool_done = False
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user