use prompt store
This commit is contained in:
248
lang_agent/components/prompt_store.py
Normal file
248
lang_agent/components/prompt_store.py
Normal file
@@ -0,0 +1,248 @@
|
||||
from typing import Dict, Optional
|
||||
from abc import ABC, abstractmethod
|
||||
import os
|
||||
import os.path as osp
|
||||
import glob
|
||||
import commentjson
|
||||
import psycopg
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class PromptStoreBase(ABC):
    """Abstract contract for looking up prompt text by key.

    Subclasses implement ``get`` and ``get_all``; membership tests
    (``key in store``) are derived from ``get``.
    """

    @abstractmethod
    def get(self, key: str) -> str:
        """Return the prompt stored under ``key``; raise KeyError if absent."""

    @abstractmethod
    def get_all(self) -> Dict[str, str]:
        """Return every available prompt as a ``{key: content}`` mapping."""

    def __contains__(self, key: str) -> bool:
        # A key is "contained" exactly when ``get`` does not raise KeyError.
        try:
            self.get(key)
        except KeyError:
            return False
        return True
|
||||
|
||||
|
||||
class FilePromptStore(PromptStoreBase):
    """
    Loads prompts from files on disk at construction time.

    Supports:
    - A directory of .txt files (key = part of the filename before the
      first dot; files whose name contains "optional" are skipped)
    - A single .json file (keys from the JSON object; comments allowed,
      parsed with commentjson)
    - A single .txt file (stored under a provided default_key)

    A missing or empty path is not an error: a warning is logged and the
    store stays empty.
    """

    def __init__(self, path: str, default_key: str = "sys_prompt"):
        """
        Args:
            path: Directory, .json file, or .txt file to load prompts from.
            default_key: Key used when ``path`` is a single .txt file.

        Raises:
            ValueError: If ``path`` exists but is neither a directory, a
                .json file, nor a .txt file.
        """
        self._prompts: Dict[str, str] = {}
        self._load(path, default_key)

    def _load(self, path: str, default_key: str):
        """Populate ``self._prompts`` from ``path`` (see class docstring)."""
        if not path or not osp.exists(path):
            logger.warning(f"Prompt path does not exist: {path}")
            return

        if osp.isdir(path):
            # Escape the directory part so glob metacharacters in the path
            # itself ("[", "*", "?") are matched literally.
            pattern = osp.join(glob.escape(path), "*.txt")
            # Filter on the file name only: testing the full path would
            # drop every file when the directory path contains "optional".
            txt_files = sorted(
                fp for fp in glob.glob(pattern)
                if "optional" not in osp.basename(fp)
            )
            for txt_file in txt_files:
                key = osp.basename(txt_file).split(".")[0]
                # Explicit UTF-8 so non-ASCII prompts load the same on
                # every platform instead of depending on the locale.
                with open(txt_file, "r", encoding="utf-8") as f:
                    self._prompts[key] = f.read()

        elif path.endswith(".json"):
            with open(path, "r", encoding="utf-8") as f:
                self._prompts = commentjson.load(f)

        elif path.endswith(".txt"):
            with open(path, "r", encoding="utf-8") as f:
                self._prompts[default_key] = f.read()

        else:
            raise ValueError(f"Unsupported prompt path format: {path}")

        for k in self._prompts:
            logger.info(f"FilePromptStore loaded: '{k}'")

    def get(self, key: str) -> str:
        """Return the prompt for ``key``; raise KeyError if it was not loaded."""
        if key not in self._prompts:
            raise KeyError(f"Prompt '{key}' not found in file store")
        return self._prompts[key]

    def get_all(self) -> Dict[str, str]:
        """Return a shallow copy of all loaded prompts."""
        return dict(self._prompts)
|
||||
|
||||
|
||||
class DBPromptStore(PromptStoreBase):
    """
    Prompt store backed by PostgreSQL (prompt_sets + prompt_templates).

    Schema:
        prompt_sets      (id, pipeline_id, name, is_active, ...)
        prompt_templates (id, prompt_set_id FK, prompt_key, content, ...)

    Without a prompt_set_id, prompts come from the active prompt set of
    the given pipeline_id. Passing a prompt_set_id targets that set
    directly, active or not (useful for previewing or A/B testing).

    Prompts are fetched on first access and cached. If the fetch fails, a
    warning is logged and the cache is left empty, so lookups raise
    KeyError until ``invalidate_cache`` is called.
    """

    def __init__(
        self,
        pipeline_id: str,
        prompt_set_id: str = None,
        conn_str: str = None,
    ):
        """
        Args:
            pipeline_id: Pipeline whose active prompt set is loaded.
            prompt_set_id: Optional explicit prompt set to load instead.
            conn_str: Connection string; falls back to the CONN_STR
                environment variable when not given.

        Raises:
            ValueError: If no connection string is available.
        """
        self.pipeline_id = pipeline_id
        self.prompt_set_id = prompt_set_id
        resolved_conn_str = conn_str or os.environ.get("CONN_STR")
        self.conn_str = resolved_conn_str
        if not resolved_conn_str:
            raise ValueError("CONN_STR not set for DBPromptStore")
        # None means "not fetched yet"; a dict (possibly empty) means fetched.
        self._cache: Optional[Dict[str, str]] = None

    def _load(self):
        """Fill the cache from the DB once; no-op when already filled."""
        if self._cache is not None:
            return
        self._cache = {}

        # Pick the statement up front so a single execute() serves both modes.
        if self.prompt_set_id:
            query = (
                "SELECT prompt_key, content FROM prompt_templates "
                "WHERE prompt_set_id = %s"
            )
            params = (self.prompt_set_id,)
        else:
            query = (
                "SELECT pt.prompt_key, pt.content "
                "FROM prompt_templates pt "
                "JOIN prompt_sets ps ON pt.prompt_set_id = ps.id "
                "WHERE ps.pipeline_id = %s AND ps.is_active = true"
            )
            params = (self.pipeline_id,)

        try:
            with psycopg.connect(self.conn_str) as conn, conn.cursor() as cur:
                cur.execute(query, params)
                for prompt_key, content in cur.fetchall():
                    self._cache[prompt_key] = content
            if self.prompt_set_id:
                source = f"set '{self.prompt_set_id}'"
            else:
                source = "active set"
            logger.info(
                f"DBPromptStore loaded {len(self._cache)} prompts for pipeline "
                f"'{self.pipeline_id}' ({source})"
            )
        except Exception as e:
            # Best effort: report the failure and expose an empty store.
            logger.warning(f"DBPromptStore failed to load: {e}")
            self._cache = {}

    def invalidate_cache(self):
        """Drop cached prompts so the next access re-queries the DB."""
        self._cache = None

    def get(self, key: str) -> str:
        """Return the prompt for ``key``; raise KeyError if it is not cached."""
        self._load()
        if key in self._cache:
            return self._cache[key]
        raise KeyError(
            f"Prompt '{key}' not in DB for pipeline '{self.pipeline_id}'"
        )

    def get_all(self) -> Dict[str, str]:
        """Return a shallow copy of all cached prompts."""
        self._load()
        return dict(self._cache)
|
||||
|
||||
|
||||
class FallbackPromptStore(PromptStoreBase):
    """
    Two-level store: consult ``primary`` first, then ``fallback``.

    A KeyError from the primary store triggers the fallback lookup; any
    other exception propagates to the caller.
    """

    def __init__(self, primary: PromptStoreBase, fallback: PromptStoreBase):
        self.primary = primary
        self.fallback = fallback

    def get(self, key: str) -> str:
        """Return the prompt from primary, else from fallback (may raise KeyError)."""
        try:
            found = self.primary.get(key)
        except KeyError:
            logger.debug(f"Prompt '{key}' not in primary, trying fallback")
            return self.fallback.get(key)
        logger.debug(f"Prompt '{key}' resolved from primary store")
        return found

    def get_all(self) -> Dict[str, str]:
        """Return fallback prompts overlaid with primary prompts."""
        combined = self.fallback.get_all()
        primary_prompts = self.primary.get_all()
        # Primary entries win on key collisions.
        combined.update(primary_prompts)
        return combined
|
||||
|
||||
|
||||
class HardcodedPromptStore(PromptStoreBase):
    """In-memory store wrapping a caller-supplied ``{key: content}`` dict."""

    def __init__(self, prompts: Dict[str, str]):
        # The dict is held by reference, not copied.
        self._prompts = prompts

    def get(self, key: str) -> str:
        """Return the prompt for ``key``; raise KeyError if it is missing."""
        if key in self._prompts:
            return self._prompts[key]
        raise KeyError(f"Prompt '{key}' not in hardcoded store")

    def get_all(self) -> Dict[str, str]:
        """Return a shallow copy of the wrapped prompts."""
        return dict(self._prompts)
|
||||
|
||||
|
||||
def build_prompt_store(
    pipeline_id: Optional[str] = None,
    prompt_set_id: Optional[str] = None,
    file_path: Optional[str] = None,
    default_key: str = "sys_prompt",
    hardcoded: Optional[Dict[str, str]] = None,
) -> PromptStoreBase:
    """
    Factory function — builds the right prompt store based on what's provided.

    Priority: DB (if pipeline_id or prompt_set_id) > Files (if file_path
    exists) > Hardcoded.

    When neither pipeline_id nor prompt_set_id is given (default), the DB
    layer is skipped entirely and file-based / hardcoded behavior is used.

    Args:
        pipeline_id: Loads from the active prompt_set for this pipeline.
        prompt_set_id: If provided, loads from this specific prompt set
            instead of the active one (useful for preview / A/B).
        file_path: Path to file or directory for file-based fallback.
            Ignored when the path does not exist.
        default_key: Key name when file_path points to a single .txt file.
        hardcoded: Dict of prompt_key → content as last-resort defaults.

    Returns:
        A single store when only one source is available, otherwise nested
        FallbackPromptStore instances ordered by the priority above.

    Raises:
        ValueError: If no prompt source could be built.
    """
    stores = []

    # The DB layer was previously gated on prompt_set_id alone, so passing
    # only pipeline_id silently skipped the DB despite the documented
    # contract. Either identifier now enables it.
    if pipeline_id or prompt_set_id:
        try:
            stores.append(DBPromptStore(pipeline_id, prompt_set_id=prompt_set_id))
        except ValueError:
            # DBPromptStore raises ValueError when no connection string is set.
            logger.warning("CONN_STR not set, skipping DB prompt store")

    if file_path and osp.exists(file_path):
        stores.append(FilePromptStore(file_path, default_key))

    if hardcoded:
        stores.append(HardcodedPromptStore(hardcoded))

    if not stores:
        raise ValueError("No prompt source available")

    # Chain them: first store is highest priority. Start from the lowest
    # priority store and wrap it outward so that stores[0] is consulted first.
    result = stores[-1]
    for store in reversed(stores[:-1]):
        result = FallbackPromptStore(primary=store, fallback=result)

    return result
|
||||
|
||||
Reference in New Issue
Block a user