use prompt store

This commit is contained in:
2026-02-10 10:54:58 +08:00
parent ede7199dfc
commit cb5b3afd05
6 changed files with 130 additions and 38 deletions

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Type, TypedDict, Literal, Dict, List
from typing import Type, TypedDict, Literal, Dict, List, Optional
import tyro
from pydantic import BaseModel, Field
from loguru import logger
@@ -9,6 +9,7 @@ from langchain.chat_models import init_chat_model
from lang_agent.config import LLMKeyConfig
from lang_agent.base import GraphBase
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
from lang_agent.components.prompt_store import build_prompt_store
from lang_agent.graphs.graph_states import State
from langchain.agents import create_agent
@@ -50,6 +51,12 @@ TOOL_SYS_PROMPT = """根据用户的心情使用self_led_control改变灯的颜
class DualConfig(LLMKeyConfig):
_target: Type = field(default_factory=lambda:Dual)
pipeline_id: Optional[str] = None
"""If set, load prompts from database (with hardcoded fallback)"""
prompt_set_id: Optional[str] = None
"""If set, load from this specific prompt set instead of the active one"""
tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)
from langchain.tools import tool
@@ -96,14 +103,23 @@ class Dual(GraphBase):
self.tool_agent = create_agent(self.tool_llm, self.tool_manager.get_langchain_tools())
# self.tool_agent = create_agent(self.tool_llm, [turn_lights])
self.prompt_store = build_prompt_store(
pipeline_id=self.config.pipeline_id,
prompt_set_id=self.config.prompt_set_id,
hardcoded={
"sys_prompt": SYS_PROMPT,
"tool_sys_prompt": TOOL_SYS_PROMPT,
},
)
self.streamable_tags = [["dual_chat_llm"]]
def _chat_call(self, state:State):
return self._agent_call_template(SYS_PROMPT, self.chat_agent, state)
return self._agent_call_template(self.prompt_store.get("sys_prompt"), self.chat_agent, state)
def _tool_call(self, state:State):
self._agent_call_template(TOOL_SYS_PROMPT, self.tool_agent, state)
self._agent_call_template(self.prompt_store.get("tool_sys_prompt"), self.tool_agent, state)
return {}
def _join(self, state:State):