249 lines
8.4 KiB
Python
249 lines
8.4 KiB
Python
from typing import Dict, Optional
|
|
from abc import ABC, abstractmethod
|
|
import os
|
|
import os.path as osp
|
|
import glob
|
|
import commentjson
|
|
import psycopg
|
|
from loguru import logger
|
|
|
|
|
|
class PromptStoreBase(ABC):
|
|
"""Interface for getting prompts by key."""
|
|
|
|
@abstractmethod
|
|
def get(self, key: str) -> str:
|
|
"""Get a prompt by key. Raises KeyError if not found."""
|
|
...
|
|
|
|
@abstractmethod
|
|
def get_all(self) -> Dict[str, str]:
|
|
"""Get all available prompts as {key: content}."""
|
|
...
|
|
|
|
def __contains__(self, key: str) -> bool:
|
|
try:
|
|
self.get(key)
|
|
return True
|
|
except KeyError:
|
|
return False
|
|
|
|
|
|
class FilePromptStore(PromptStoreBase):
|
|
"""
|
|
Loads prompts from files — preserves existing behavior exactly.
|
|
|
|
Supports:
|
|
- A directory of .txt files (key = filename without extension)
|
|
- A single .json file (keys from JSON object)
|
|
- A single .txt file (stored under a provided default_key)
|
|
"""
|
|
|
|
def __init__(self, path: str, default_key: str = "sys_prompt"):
|
|
self._prompts: Dict[str, str] = {}
|
|
self._load(path, default_key)
|
|
|
|
def _load(self, path: str, default_key: str):
|
|
if not path or not osp.exists(path):
|
|
logger.warning(f"Prompt path does not exist: {path}")
|
|
return
|
|
|
|
if osp.isdir(path):
|
|
# Directory of .txt files — same as RoutingGraph._load_sys_prompts()
|
|
sys_fs = glob.glob(osp.join(path, "*.txt"))
|
|
sys_fs = sorted([e for e in sys_fs if "optional" not in e])
|
|
for sys_f in sys_fs:
|
|
key = osp.basename(sys_f).split(".")[0]
|
|
with open(sys_f, "r") as f:
|
|
self._prompts[key] = f.read()
|
|
|
|
elif path.endswith(".json"):
|
|
# JSON file — same as RoutingGraph._load_sys_prompts()
|
|
with open(path, "r") as f:
|
|
self._prompts = commentjson.load(f)
|
|
|
|
elif path.endswith(".txt"):
|
|
# Single text file — same as ReactGraph / ToolNode
|
|
with open(path, "r") as f:
|
|
self._prompts[default_key] = f.read()
|
|
else:
|
|
raise ValueError(f"Unsupported prompt path format: {path}")
|
|
|
|
for k in self._prompts:
|
|
logger.info(f"FilePromptStore loaded: '{k}'")
|
|
|
|
def get(self, key: str) -> str:
|
|
if key not in self._prompts:
|
|
raise KeyError(f"Prompt '{key}' not found in file store")
|
|
return self._prompts[key]
|
|
|
|
def get_all(self) -> Dict[str, str]:
|
|
return dict(self._prompts)
|
|
|
|
|
|
class DBPromptStore(PromptStoreBase):
|
|
"""
|
|
Loads prompts from PostgreSQL via prompt_sets + prompt_templates tables.
|
|
|
|
Schema:
|
|
prompt_sets (id, pipeline_id, name, is_active, ...)
|
|
prompt_templates (id, prompt_set_id FK, prompt_key, content, ...)
|
|
|
|
By default loads from the active prompt set for the given pipeline_id.
|
|
A specific prompt_set_id can be provided to target a non-active set
|
|
(useful for previewing or A/B testing).
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
pipeline_id: str,
|
|
prompt_set_id: str = None,
|
|
conn_str: str = None,
|
|
):
|
|
self.pipeline_id = pipeline_id
|
|
self.prompt_set_id = prompt_set_id
|
|
self.conn_str = conn_str or os.environ.get("CONN_STR")
|
|
if not self.conn_str:
|
|
raise ValueError("CONN_STR not set for DBPromptStore")
|
|
self._cache: Optional[Dict[str, str]] = None # lazy loaded
|
|
|
|
def _load(self):
|
|
"""Load all prompts for the active (or specified) prompt set from DB."""
|
|
if self._cache is not None:
|
|
return
|
|
self._cache = {}
|
|
try:
|
|
with psycopg.connect(self.conn_str) as conn:
|
|
with conn.cursor() as cur:
|
|
if self.prompt_set_id:
|
|
# Load from a specific prompt set
|
|
cur.execute(
|
|
"SELECT prompt_key, content FROM prompt_templates "
|
|
"WHERE prompt_set_id = %s",
|
|
(self.prompt_set_id,),
|
|
)
|
|
else:
|
|
# Load from the active prompt set for this pipeline
|
|
cur.execute(
|
|
"SELECT pt.prompt_key, pt.content "
|
|
"FROM prompt_templates pt "
|
|
"JOIN prompt_sets ps ON pt.prompt_set_id = ps.id "
|
|
"WHERE ps.pipeline_id = %s AND ps.is_active = true",
|
|
(self.pipeline_id,),
|
|
)
|
|
for row in cur.fetchall():
|
|
self._cache[row[0]] = row[1]
|
|
source = f"set '{self.prompt_set_id}'" if self.prompt_set_id else "active set"
|
|
logger.info(
|
|
f"DBPromptStore loaded {len(self._cache)} prompts for pipeline "
|
|
f"'{self.pipeline_id}' ({source})"
|
|
)
|
|
except Exception as e:
|
|
logger.warning(f"DBPromptStore failed to load: {e}")
|
|
self._cache = {}
|
|
|
|
def invalidate_cache(self):
|
|
"""Force reload on next access (call after prompt update via API)."""
|
|
self._cache = None
|
|
|
|
def get(self, key: str) -> str:
|
|
self._load()
|
|
if key not in self._cache:
|
|
raise KeyError(
|
|
f"Prompt '{key}' not in DB for pipeline '{self.pipeline_id}'"
|
|
)
|
|
return self._cache[key]
|
|
|
|
def get_all(self) -> Dict[str, str]:
|
|
self._load()
|
|
return dict(self._cache)
|
|
|
|
|
|
class FallbackPromptStore(PromptStoreBase):
|
|
"""
|
|
Tries primary store (DB) first, falls back to secondary (files).
|
|
This is the main store graphs should use.
|
|
"""
|
|
|
|
def __init__(self, primary: PromptStoreBase, fallback: PromptStoreBase):
|
|
self.primary = primary
|
|
self.fallback = fallback
|
|
|
|
def get(self, key: str) -> str:
|
|
try:
|
|
val = self.primary.get(key)
|
|
logger.debug(f"Prompt '{key}' resolved from primary store")
|
|
return val
|
|
except KeyError:
|
|
logger.debug(f"Prompt '{key}' not in primary, trying fallback")
|
|
return self.fallback.get(key)
|
|
|
|
def get_all(self) -> Dict[str, str]:
|
|
merged = self.fallback.get_all()
|
|
merged.update(self.primary.get_all()) # primary overrides fallback
|
|
return merged
|
|
|
|
|
|
class HardcodedPromptStore(PromptStoreBase):
|
|
"""For graphs that currently use module-level constants."""
|
|
|
|
def __init__(self, prompts: Dict[str, str]):
|
|
self._prompts = prompts
|
|
|
|
def get(self, key: str) -> str:
|
|
if key not in self._prompts:
|
|
raise KeyError(f"Prompt '{key}' not in hardcoded store")
|
|
return self._prompts[key]
|
|
|
|
def get_all(self) -> Dict[str, str]:
|
|
return dict(self._prompts)
|
|
|
|
|
|
def build_prompt_store(
|
|
pipeline_id: Optional[str] = None,
|
|
prompt_set_id: Optional[str] = None,
|
|
file_path: Optional[str] = None,
|
|
default_key: str = "sys_prompt",
|
|
hardcoded: Optional[Dict[str, str]] = None,
|
|
) -> PromptStoreBase:
|
|
"""
|
|
Factory function — builds the right prompt store based on what's provided.
|
|
|
|
Priority: DB (if pipeline_id) > Files (if file_path) > Hardcoded
|
|
|
|
When pipeline_id is None (default), DB layer is skipped entirely and
|
|
existing file-based / hardcoded behavior is preserved.
|
|
|
|
Args:
|
|
pipeline_id: Loads from the active prompt_set for this pipeline.
|
|
prompt_set_id: If provided, loads from this specific prompt set
|
|
instead of the active one (useful for preview / A/B).
|
|
file_path: Path to file or directory for file-based fallback.
|
|
default_key: Key name when file_path points to a single .txt file.
|
|
hardcoded: Dict of prompt_key → content as last-resort defaults.
|
|
"""
|
|
stores = []
|
|
|
|
if prompt_set_id:
|
|
try:
|
|
stores.append(DBPromptStore(pipeline_id, prompt_set_id=prompt_set_id))
|
|
except ValueError:
|
|
logger.warning("CONN_STR not set, skipping DB prompt store")
|
|
|
|
if file_path and osp.exists(file_path):
|
|
stores.append(FilePromptStore(file_path, default_key))
|
|
|
|
if hardcoded:
|
|
stores.append(HardcodedPromptStore(hardcoded))
|
|
|
|
if not stores:
|
|
raise ValueError("No prompt source available")
|
|
|
|
# Chain them: first store is highest priority
|
|
result = stores[-1]
|
|
for store in reversed(stores[:-1]):
|
|
result = FallbackPromptStore(primary=store, fallback=result)
|
|
|
|
return result
|
|
|