diff --git a/lang_agent/config/db_config_manager.py b/lang_agent/config/db_config_manager.py index bb5ce7c..e40102c 100644 --- a/lang_agent/config/db_config_manager.py +++ b/lang_agent/config/db_config_manager.py @@ -17,26 +17,48 @@ class DBConfigManager: cur.execute("DELETE FROM prompt_sets WHERE pipeline_id = %s AND id = %s", (pipeline_id, prompt_set_id)) conn.commit() - def list_prompt_sets(self, pipeline_id: Optional[str] = None) -> List[Dict[str, object]]: + def list_prompt_sets( + self, pipeline_id: Optional[str] = None, graph_id: Optional[str] = None + ) -> List[Dict[str, object]]: """ List prompt_set metadata for UI listing. """ with psycopg.connect(self.conn_str) as conn: with conn.cursor(row_factory=dict_row) as cur: - if pipeline_id: + if pipeline_id and graph_id: cur.execute( """ - SELECT id, pipeline_id, name, description, is_active, created_at, updated_at, list + SELECT id, pipeline_id, graph_id, name, description, is_active, created_at, updated_at, list + FROM prompt_sets + WHERE pipeline_id = %s AND graph_id = %s + ORDER BY updated_at DESC, created_at DESC + """, + (pipeline_id, graph_id), + ) + elif pipeline_id: + cur.execute( + """ + SELECT id, pipeline_id, graph_id, name, description, is_active, created_at, updated_at, list FROM prompt_sets WHERE pipeline_id = %s ORDER BY updated_at DESC, created_at DESC """, (pipeline_id,), ) + elif graph_id: + cur.execute( + """ + SELECT id, pipeline_id, graph_id, name, description, is_active, created_at, updated_at, list + FROM prompt_sets + WHERE graph_id = %s + ORDER BY updated_at DESC, created_at DESC + """, + (graph_id,), + ) else: cur.execute( """ - SELECT id, pipeline_id, name, description, is_active, created_at, updated_at, list + SELECT id, pipeline_id, graph_id, name, description, is_active, created_at, updated_at, list FROM prompt_sets ORDER BY updated_at DESC, created_at DESC """ @@ -47,6 +69,7 @@ class DBConfigManager: { "prompt_set_id": str(row["id"]), "pipeline_id": row["pipeline_id"], + "graph_id": row.get("graph_id"), "name": row["name"], "description": row["description"] or "", "is_active": bool(row["is_active"]), @@ -65,7 +88,7 @@ class DBConfigManager: with conn.cursor(row_factory=dict_row) as cur: cur.execute( """ - SELECT id, pipeline_id, name, description, is_active, created_at, updated_at, list + SELECT id, pipeline_id, graph_id, name, description, is_active, created_at, updated_at, list FROM prompt_sets WHERE id = %s AND pipeline_id = %s """, @@ -79,6 +102,7 @@ class DBConfigManager: return { "prompt_set_id": str(row["id"]), "pipeline_id": row["pipeline_id"], + "graph_id": row.get("graph_id"), "name": row["name"], "description": row["description"] or "", "is_active": bool(row["is_active"]), @@ -108,6 +132,7 @@ class DBConfigManager: resolved_set_id, tool_csv = self._resolve_prompt_set( conn, pipeline_id=pipeline_id, + graph_id=None, prompt_set_id=prompt_set_id, create_if_missing=False, ) @@ -131,6 +156,7 @@ class DBConfigManager: def set_config( self, pipeline_id: str, + graph_id: Optional[str], prompt_set_id: Optional[str], tool_list: Optional[Sequence[str]], prompt_dict: Optional[Mapping[str, str]], @@ -150,6 +176,9 @@ class DBConfigManager: """ if not pipeline_id: raise ValueError("pipeline_id is required") + normalized_graph_id = self._normalize_graph_id(graph_id) + if prompt_set_id is None and not normalized_graph_id: + raise ValueError("graph_id is required when creating a new prompt set") normalized_prompt_dict = self._normalize_prompt_dict(prompt_dict) tool_csv = self._join_tool_list(tool_list) @@ -158,6 +187,7 @@ class DBConfigManager: resolved_set_id, _ = self._resolve_prompt_set( conn, pipeline_id=pipeline_id, + graph_id=normalized_graph_id, prompt_set_id=prompt_set_id, create_if_missing=prompt_set_id is None, ) @@ -170,10 +200,10 @@ class DBConfigManager: cur.execute( """ UPDATE prompt_sets - SET list = %s, updated_at = now() + SET list = %s, graph_id = COALESCE(%s, graph_id), updated_at = now() WHERE id = %s """, - (tool_csv, resolved_set_id), + (tool_csv, normalized_graph_id, resolved_set_id), ) keys = list(normalized_prompt_dict.keys()) @@ -218,6 +248,7 @@ class DBConfigManager: self, conn: psycopg.Connection, pipeline_id: str, + graph_id: Optional[str], prompt_set_id: Optional[str], create_if_missing: bool, ) -> Tuple[Optional[str], str]: @@ -255,12 +286,13 @@ class DBConfigManager: cur.execute( """ - INSERT INTO prompt_sets (pipeline_id, name, description, is_active, list) - VALUES (%s, %s, %s, %s, %s) + INSERT INTO prompt_sets (pipeline_id, graph_id, name, description, is_active, list) + VALUES (%s, %s, %s, %s, %s, %s) RETURNING id, list """, ( pipeline_id, + graph_id, "default", "Auto-created by DBConfigManager", True, @@ -302,4 +334,10 @@ class DBConfigManager: if not norm_key: continue out[norm_key] = value if isinstance(value, str) else str(value) - return out \ No newline at end of file + return out + + def _normalize_graph_id(self, graph_id: Optional[str]) -> Optional[str]: + if graph_id is None: + return None + value = str(graph_id).strip() + return value or None \ No newline at end of file