diff --git a/lang_agent/graphs/react.py b/lang_agent/graphs/react.py
index c88730e..c040e63 100644
--- a/lang_agent/graphs/react.py
+++ b/lang_agent/graphs/react.py
@@ -1,6 +1,7 @@
-from dataclasses import dataclass, field, is_dataclass
-from typing import Type, List, Callable, Any, AsyncIterator
+from dataclasses import dataclass, field
+from typing import Type
 import tyro
+import os.path as osp
 
 from lang_agent.config import KeyConfig
 from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
@@ -24,12 +25,18 @@ class ReactGraphConfig(KeyConfig):
 
     llm_provider:str = "openai"
     """provider of the llm"""
 
+    sys_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "prompts", "blueberry.txt")
+    """path to system prompt"""
+
     base_url:str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
     """base url; could be used to overwrite the baseurl in llm provider"""
 
     tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)
-
+    def __post_init__(self):
+        super().__post_init__()
+        if not osp.exists(self.sys_prompt_f):
+            raise FileNotFoundError(f"{self.sys_prompt_f} does not exist")
 
 
 class ReactGraph(GraphBase):
@@ -50,6 +57,23 @@ class ReactGraph(GraphBase):
         self.memory = MemorySaver()
         tools = self.tool_manager.get_langchain_tools()
         self.agent = create_agent(self.llm, tools, checkpointer=self.memory)
+
+        with open(self.config.sys_prompt_f, "r", encoding="utf-8") as f:
+            self.sys_prompt = f.read()
+
+    def _get_human_msg(self, *nargs):
+        """Return the last HumanMessage found in nargs[0]["messages"]."""
+        msgs = nargs[0]["messages"]
+
+        candidate_hum_msg = None
+        for msg in msgs:
+            if isinstance(msg, HumanMessage):
+                candidate_hum_msg = msg
+
+        if candidate_hum_msg is None:
+            raise ValueError("not a human message")
+
+        return candidate_hum_msg
 
     def invoke(self, *nargs, as_stream:bool=False, as_raw:bool=False, **kwargs):
         """