Files
lang-agent/lang_agent/config/db_config_manager.py
2026-02-11 16:22:37 +08:00

235 lines
7.9 KiB
Python

import os
from typing import Dict, List, Mapping, Optional, Sequence, Tuple
import psycopg
from psycopg.rows import dict_row
class DBConfigManager:
    """Read and write per-pipeline prompt/tool configuration in PostgreSQL.

    The connection string comes from the ``CONN_STR`` environment variable and
    every public method opens its own connection. Two tables are used:

    - ``prompt_sets``: one row per set; the tool keys are stored as a
      comma-separated string in the ``list`` column.
    - ``prompt_templates``: one row per ``(prompt_set_id, prompt_key)``.
    """

    def __init__(self):
        """Capture the connection string from the environment.

        Raises:
            ValueError: If the ``CONN_STR`` environment variable is not set.
        """
        self.conn_str = os.environ.get("CONN_STR")
        if self.conn_str is None:
            raise ValueError("CONN_STR is not set")

    def remove_config(self, pipeline_id: str, prompt_set_id: str):
        """Delete one prompt set that belongs to ``pipeline_id``.

        Only the ``prompt_sets`` row is deleted here.
        NOTE(review): whether the related ``prompt_templates`` rows are removed
        as well depends on the schema (e.g. a cascading foreign key) - confirm.

        Args:
            pipeline_id: Owner pipeline of the set.
            prompt_set_id: Identifier of the set to delete.

        Raises:
            ValueError: If ``pipeline_id`` or ``prompt_set_id`` is empty.
        """
        # Validate up front, like get_config/set_config do, instead of sending
        # a DELETE that can never match (or that fails inside the driver).
        if not pipeline_id:
            raise ValueError("pipeline_id is required")
        if not prompt_set_id:
            raise ValueError("prompt_set_id is required")

        with psycopg.connect(self.conn_str) as conn:
            with conn.cursor() as cur:
                cur.execute(
                    "DELETE FROM prompt_sets WHERE pipeline_id = %s AND id = %s",
                    (pipeline_id, prompt_set_id),
                )
            conn.commit()

    def get_config(
        self, pipeline_id: str, prompt_set_id: Optional[str] = None
    ) -> Tuple[Dict[str, str], List[str]]:
        """Read prompt + tool configuration from the database.

        Resolution order:
            - If ``prompt_set_id`` is provided, read that set.
            - Otherwise, read the active set for ``pipeline_id``.
            - If no matching set exists, return ``({}, [])``.

        Args:
            pipeline_id: Pipeline whose configuration is requested.
            prompt_set_id: Optional explicit set; falsy means "active set".

        Returns:
            ``({prompt_key: content}, [tool_key, ...])``.

        Raises:
            ValueError: If ``pipeline_id`` is empty.
        """
        if not pipeline_id:
            raise ValueError("pipeline_id is required")

        with psycopg.connect(self.conn_str) as conn:
            resolved_set_id, tool_csv = self._resolve_prompt_set(
                conn,
                pipeline_id=pipeline_id,
                prompt_set_id=prompt_set_id,
                create_if_missing=False,
            )
            if resolved_set_id is None:
                return {}, []

            with conn.cursor(row_factory=dict_row) as cur:
                cur.execute(
                    """
                    SELECT prompt_key, content
                    FROM prompt_templates
                    WHERE prompt_set_id = %s
                    """,
                    (resolved_set_id,),
                )
                rows = cur.fetchall()

        prompt_dict: Dict[str, str] = {
            row["prompt_key"]: row["content"] for row in rows
        }
        return prompt_dict, self._parse_tool_list(tool_csv)

    def set_config(
        self,
        pipeline_id: str,
        prompt_set_id: Optional[str],
        tool_list: Optional[Sequence[str]],
        prompt_dict: Optional[Mapping[str, str]],
    ) -> str:
        """Persist prompt + tool configuration.

        Behavior:
            - If ``prompt_set_id`` is provided, update that set (it must
              belong to ``pipeline_id``).
            - If ``prompt_set_id`` is empty/None, update the active set for
              ``pipeline_id``; create one if missing.
            - ``prompt_templates`` for the set are synchronized to
              ``prompt_dict`` (keys not present in it are removed).

        Args:
            pipeline_id: Pipeline that owns the configuration.
            prompt_set_id: Explicit target set, or falsy for the active set.
            tool_list: Tool keys to store; normalized by ``_join_tool_list``.
            prompt_dict: Prompt templates; normalized by
                ``_normalize_prompt_dict``.

        Returns:
            The target prompt_set_id used for the write.

        Raises:
            ValueError: If ``pipeline_id`` is empty, or an explicit
                ``prompt_set_id`` does not exist for that pipeline.
        """
        if not pipeline_id:
            raise ValueError("pipeline_id is required")

        normalized_prompt_dict = self._normalize_prompt_dict(prompt_dict)
        tool_csv = self._join_tool_list(tool_list)

        with psycopg.connect(self.conn_str) as conn:
            resolved_set_id, _ = self._resolve_prompt_set(
                conn,
                pipeline_id=pipeline_id,
                prompt_set_id=prompt_set_id,
                # Use the same truthiness test as _resolve_prompt_set: an empty
                # id means "the active set", which may be auto-created.
                create_if_missing=not prompt_set_id,
            )
            if resolved_set_id is None:
                raise ValueError(
                    f"prompt_set_id '{prompt_set_id}' not found for pipeline '{pipeline_id}'"
                )

            with conn.cursor() as cur:
                cur.execute(
                    """
                    UPDATE prompt_sets
                    SET list = %s, updated_at = now()
                    WHERE id = %s
                    """,
                    (tool_csv, resolved_set_id),
                )

                # Remove templates that are no longer part of the configuration
                # before upserting the ones that are.
                keys = list(normalized_prompt_dict)
                if keys:
                    cur.execute(
                        """
                        DELETE FROM prompt_templates
                        WHERE prompt_set_id = %s
                          AND NOT (prompt_key = ANY(%s))
                        """,
                        (resolved_set_id, keys),
                    )
                else:
                    cur.execute(
                        """
                        DELETE FROM prompt_templates
                        WHERE prompt_set_id = %s
                        """,
                        (resolved_set_id,),
                    )

                if normalized_prompt_dict:
                    cur.executemany(
                        """
                        INSERT INTO prompt_templates (prompt_set_id, prompt_key, content)
                        VALUES (%s, %s, %s)
                        ON CONFLICT (prompt_set_id, prompt_key)
                        DO UPDATE SET
                            content = EXCLUDED.content,
                            updated_at = now()
                        """,
                        [
                            (resolved_set_id, prompt_key, content)
                            for prompt_key, content in normalized_prompt_dict.items()
                        ],
                    )
            conn.commit()

        return str(resolved_set_id)

    def _resolve_prompt_set(
        self,
        conn: "psycopg.Connection",
        pipeline_id: str,
        prompt_set_id: Optional[str],
        create_if_missing: bool,
    ) -> Tuple[Optional[str], str]:
        """Resolve the target prompt set and return ``(id, list_csv)``.

        Args:
            conn: Open connection; the caller owns commit/rollback.
            pipeline_id: Pipeline the set must belong to.
            prompt_set_id: Explicit set id; when falsy, the most recently
                updated active set of the pipeline is used.
            create_if_missing: Insert an active set named ``"default"`` when
                the lookup finds nothing.

        Returns:
            ``(set_id, tool_csv)``, or ``(None, "")`` when nothing matched and
            ``create_if_missing`` is false.
        """
        with conn.cursor(row_factory=dict_row) as cur:
            if prompt_set_id:
                cur.execute(
                    """
                    SELECT id, list
                    FROM prompt_sets
                    WHERE id = %s AND pipeline_id = %s
                    """,
                    (prompt_set_id, pipeline_id),
                )
            else:
                cur.execute(
                    """
                    SELECT id, list
                    FROM prompt_sets
                    WHERE pipeline_id = %s AND is_active = true
                    ORDER BY updated_at DESC, created_at DESC
                    LIMIT 1
                    """,
                    (pipeline_id,),
                )
            row = cur.fetchone()
            if row is not None:
                # A NULL ``list`` column is reported as an empty CSV.
                return str(row["id"]), row.get("list") or ""

            if not create_if_missing:
                return None, ""

            cur.execute(
                """
                INSERT INTO prompt_sets (pipeline_id, name, description, is_active, list)
                VALUES (%s, %s, %s, %s, %s)
                RETURNING id, list
                """,
                (
                    pipeline_id,
                    "default",
                    "Auto-created by DBConfigManager",
                    True,
                    "",
                ),
            )
            created = cur.fetchone()
            return str(created["id"]), created.get("list") or ""

    def _join_tool_list(self, tool_list: Optional[Sequence[str]]) -> str:
        """Serialize tool keys into the CSV form stored in ``prompt_sets.list``.

        Keys are stripped, empty/None entries are dropped and duplicates are
        removed while keeping first-seen order.

        Args:
            tool_list: Tool keys, a single string, or None.

        Returns:
            Comma-separated keys, or ``""`` when nothing remains.
        """
        if not tool_list:
            return ""

        # A bare string is itself a Sequence[str]; iterating it would store
        # one "tool" per character, so treat it as a single entry.
        entries = [tool_list] if isinstance(tool_list, str) else tool_list

        cleaned: List[str] = []
        seen = set()
        for entry in entries:
            if entry is None:
                continue
            # Split on the separator so the stored CSV reads back (via
            # _parse_tool_list) as exactly the de-duplicated keys written here.
            for part in str(entry).split(","):
                key = part.strip()
                if not key or key in seen:
                    continue
                seen.add(key)
                cleaned.append(key)
        return ",".join(cleaned)

    def _parse_tool_list(self, tool_csv: Optional[str]) -> List[str]:
        """Split a stored CSV string into stripped, non-empty tool keys."""
        if not tool_csv:
            return []
        return [k.strip() for k in tool_csv.split(",") if k.strip()]

    def _normalize_prompt_dict(
        self, prompt_dict: Optional[Mapping[str, str]]
    ) -> Dict[str, str]:
        """Return a plain ``{str: str}`` copy of ``prompt_dict``.

        Keys are converted with ``str()`` and stripped; entries whose key is
        empty after stripping are dropped. If two keys collide after
        stripping, the later one wins. Non-string values are converted with
        ``str()``.
        NOTE(review): a ``None`` value therefore becomes the text ``"None"`` -
        confirm whether callers can pass None.
        """
        if not prompt_dict:
            return {}
        out: Dict[str, str] = {}
        for key, value in prompt_dict.items():
            norm_key = str(key).strip()
            if not norm_key:
                continue
            out[norm_key] = value if isinstance(value, str) else str(value)
        return out