use prompt store
This commit is contained in:
@@ -6,7 +6,7 @@ Vision-enabled routing graph that:
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Type, TypedDict, List, Dict, Any, Tuple
|
||||
from typing import Type, TypedDict, List, Dict, Any, Tuple, Optional
|
||||
import tyro
|
||||
import base64
|
||||
import json
|
||||
@@ -14,6 +14,7 @@ from loguru import logger
|
||||
|
||||
from lang_agent.config import LLMKeyConfig
|
||||
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
|
||||
from lang_agent.components.prompt_store import build_prompt_store
|
||||
from lang_agent.base import GraphBase, ToolNodeBase
|
||||
from lang_agent.components.client_tool_manager import ClientToolManagerConfig
|
||||
|
||||
@@ -104,6 +105,12 @@ class VisionRoutingConfig(LLMKeyConfig):
|
||||
base_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
"""base url for API"""
|
||||
|
||||
pipeline_id: Optional[str] = None
|
||||
"""If set, load prompts from database (with hardcoded fallback)"""
|
||||
|
||||
prompt_set_id: Optional[str] = None
|
||||
"""If set, load from this specific prompt set instead of the active one"""
|
||||
|
||||
tool_manager_config: ToolManagerConfig = field(default_factory=ClientToolManagerConfig)
|
||||
|
||||
|
||||
@@ -167,6 +174,17 @@ class VisionRoutingGraph(GraphBase):
|
||||
# Create tool node for executing tools
|
||||
self.tool_node = ToolNode(self.camera_tools)
|
||||
|
||||
# Build prompt store: DB (if pipeline_id) > hardcoded defaults
|
||||
self.prompt_store = build_prompt_store(
|
||||
pipeline_id=self.config.pipeline_id,
|
||||
prompt_set_id=self.config.prompt_set_id,
|
||||
hardcoded={
|
||||
"camera_decision_prompt": CAMERA_DECISION_PROMPT,
|
||||
"vision_description_prompt": VISION_DESCRIPTION_PROMPT,
|
||||
"conversation_prompt": CONVERSATION_PROMPT,
|
||||
},
|
||||
)
|
||||
|
||||
def _get_human_msg(self, state: VisionRoutingState) -> HumanMessage:
|
||||
"""Get user message from current invocation"""
|
||||
msgs = state["inp"][0]["messages"]
|
||||
@@ -180,7 +198,7 @@ class VisionRoutingGraph(GraphBase):
|
||||
human_msg = self._get_human_msg(state)
|
||||
|
||||
messages = [
|
||||
SystemMessage(content=CAMERA_DECISION_PROMPT),
|
||||
SystemMessage(content=self.prompt_store.get("camera_decision_prompt")),
|
||||
human_msg
|
||||
]
|
||||
|
||||
@@ -279,7 +297,7 @@ class VisionRoutingGraph(GraphBase):
|
||||
)
|
||||
|
||||
messages = [
|
||||
SystemMessage(content=VISION_DESCRIPTION_PROMPT),
|
||||
SystemMessage(content=self.prompt_store.get("vision_description_prompt")),
|
||||
vision_message
|
||||
]
|
||||
|
||||
@@ -292,7 +310,7 @@ class VisionRoutingGraph(GraphBase):
|
||||
human_msg = self._get_human_msg(state)
|
||||
|
||||
messages = [
|
||||
SystemMessage(content=CONVERSATION_PROMPT),
|
||||
SystemMessage(content=self.prompt_store.get("conversation_prompt")),
|
||||
human_msg
|
||||
]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user