use prompt store
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Type, TypedDict, Literal, Dict, List
|
||||
from typing import Type, TypedDict, Literal, Dict, List, Optional
|
||||
import tyro
|
||||
from pydantic import BaseModel, Field
|
||||
from loguru import logger
|
||||
@@ -9,6 +9,7 @@ from langchain.chat_models import init_chat_model
|
||||
from lang_agent.config import LLMKeyConfig
|
||||
from lang_agent.base import GraphBase
|
||||
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
|
||||
from lang_agent.components.prompt_store import build_prompt_store
|
||||
from lang_agent.graphs.graph_states import State
|
||||
|
||||
from langchain.agents import create_agent
|
||||
@@ -50,6 +51,12 @@ TOOL_SYS_PROMPT = """根据用户的心情使用self_led_control改变灯的颜
|
||||
class DualConfig(LLMKeyConfig):
|
||||
_target: Type = field(default_factory=lambda:Dual)
|
||||
|
||||
pipeline_id: Optional[str] = None
|
||||
"""If set, load prompts from database (with hardcoded fallback)"""
|
||||
|
||||
prompt_set_id: Optional[str] = None
|
||||
"""If set, load from this specific prompt set instead of the active one"""
|
||||
|
||||
tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)
|
||||
|
||||
from langchain.tools import tool
|
||||
@@ -96,14 +103,23 @@ class Dual(GraphBase):
|
||||
self.tool_agent = create_agent(self.tool_llm, self.tool_manager.get_langchain_tools())
|
||||
# self.tool_agent = create_agent(self.tool_llm, [turn_lights])
|
||||
|
||||
self.prompt_store = build_prompt_store(
|
||||
pipeline_id=self.config.pipeline_id,
|
||||
prompt_set_id=self.config.prompt_set_id,
|
||||
hardcoded={
|
||||
"sys_prompt": SYS_PROMPT,
|
||||
"tool_sys_prompt": TOOL_SYS_PROMPT,
|
||||
},
|
||||
)
|
||||
|
||||
self.streamable_tags = [["dual_chat_llm"]]
|
||||
|
||||
|
||||
def _chat_call(self, state:State):
|
||||
return self._agent_call_template(SYS_PROMPT, self.chat_agent, state)
|
||||
return self._agent_call_template(self.prompt_store.get("sys_prompt"), self.chat_agent, state)
|
||||
|
||||
def _tool_call(self, state:State):
|
||||
self._agent_call_template(TOOL_SYS_PROMPT, self.tool_agent, state)
|
||||
self._agent_call_template(self.prompt_store.get("tool_sys_prompt"), self.tool_agent, state)
|
||||
return {}
|
||||
|
||||
def _join(self, state:State):
|
||||
|
||||
Reference in New Issue
Block a user