# legco_ai_assistant/backend/app/services/prompt_service.py
# (177 lines, 6.8 KiB, Python)

"""Prompt profile management service.
Reads and writes prompt templates in the prompts SQLite database.
Uses sync sqlite3 — all operations are instant local reads/writes.
"""
import logging
import sqlite3
from contextlib import closing

from app.core.sqlite_db import _SEED_TEMPLATES
logger = logging.getLogger(__name__)

# Profile slots accepted by PromptService; any other name is rejected.
_VALID_NAMES = {"A", "B", "C"}

# Step names accepted by PromptService when reading, updating, or resetting
# a template (listed alphabetically; a set has no meaningful order).
_VALID_STEPS = {
    "decompose",
    "filter",
    "filter_intro",
    "filter_outro",
    "filter_section",
    "generate",
    "generate_per_subq",
}
def _connect(db_path: str) -> sqlite3.Connection:
    """Open *db_path* and return a connection that yields ``sqlite3.Row`` rows.

    ``sqlite3.Row`` lets callers read columns by name (``row["name"]``).
    The caller owns the returned connection and is responsible for closing it.
    """
    connection = sqlite3.connect(db_path)
    connection.row_factory = sqlite3.Row
    return connection
class PromptService:
    """CRUD operations for prompt profiles and their per-step templates.

    Every method opens its own SQLite connection and closes it before
    returning, so an instance holds no open file handles between calls and
    is cheap to create once per request.  Write methods commit once at the
    end; if anything raises before that, the connection is closed without
    committing and the pending changes are discarded.
    """

    # Shared by every write path: overwrite the template of one
    # (profile name, step) pair and stamp the modification time.
    _UPDATE_TEMPLATE_SQL = """
        UPDATE system_prompts
        SET prompt_template=?, updated_at=datetime('now')
        WHERE profile_id=(SELECT id FROM system_prompt_profiles WHERE name=?)
        AND step_name=?
    """

    def __init__(self, db_path: str) -> None:
        """Remember the path of the prompts database; nothing is opened here."""
        self._db_path = db_path

    # ── helpers ────────────────────────────────────────────────────────────

    def _validate_name(self, name: str) -> None:
        """Raise ValueError unless *name* is one of the known profile slots."""
        if name not in _VALID_NAMES:
            allowed = ", ".join(sorted(_VALID_NAMES))
            raise ValueError(f"Invalid profile name '{name}'. Must be one of {allowed}.")

    def _validate_step(self, step: str) -> None:
        """Raise ValueError unless *step* is one of the known step names."""
        if step not in _VALID_STEPS:
            # Derive the list from _VALID_STEPS so the message cannot drift
            # out of date (it previously named only three of the steps).
            allowed = ", ".join(sorted(_VALID_STEPS))
            raise ValueError(f"Invalid step '{step}'. Must be one of {allowed}.")

    def _write_template(
        self, conn: sqlite3.Connection, name: str, step: str, template: str
    ) -> bool:
        """Overwrite one stored template; return False if no row matched."""
        cursor = conn.execute(self._UPDATE_TEMPLATE_SQL, (template, name, step))
        return cursor.rowcount > 0

    # ── read operations ────────────────────────────────────────────────────

    def get_active_profile_name(self) -> str:
        """Return the name of the currently active profile.

        Raises:
            RuntimeError: if no profile row has ``is_active=1``.
        """
        with closing(_connect(self._db_path)) as conn:
            row = conn.execute(
                "SELECT name FROM system_prompt_profiles WHERE is_active=1"
            ).fetchone()
        if row is None:
            raise RuntimeError("No active prompt profile found.")
        return row["name"]

    def get_prompt_template(self, step: str) -> str:
        """Return the template stored for *step* under the active profile.

        Raises:
            ValueError: if *step* is not a known step name.
            RuntimeError: if the active profile has no row for *step*.
        """
        self._validate_step(step)
        with closing(_connect(self._db_path)) as conn:
            row = conn.execute(
                """
                SELECT sp.prompt_template
                FROM system_prompts sp
                JOIN system_prompt_profiles spp ON sp.profile_id = spp.id
                WHERE spp.is_active=1 AND sp.step_name=?
                """,
                (step,),
            ).fetchone()
        if row is None:
            raise RuntimeError(f"No template found for step '{step}'.")
        return row["prompt_template"]

    def list_profiles(self) -> list[dict]:
        """Return every profile as ``{"name": str, "is_active": bool}``, sorted by name."""
        with closing(_connect(self._db_path)) as conn:
            rows = conn.execute(
                "SELECT name, is_active FROM system_prompt_profiles ORDER BY name"
            ).fetchall()
        return [{"name": r["name"], "is_active": bool(r["is_active"])} for r in rows]

    def get_profile_prompts(self, name: str) -> dict:
        """Return every stored template of profile *name*, keyed by step name.

        The dict is empty if the profile has no template rows.

        Raises:
            ValueError: if *name* is not a known profile slot.
        """
        self._validate_name(name)
        with closing(_connect(self._db_path)) as conn:
            rows = conn.execute(
                """
                SELECT sp.step_name, sp.prompt_template
                FROM system_prompts sp
                JOIN system_prompt_profiles spp ON sp.profile_id = spp.id
                WHERE spp.name=?
                ORDER BY sp.step_name
                """,
                (name,),
            ).fetchall()
        return {r["step_name"]: r["prompt_template"] for r in rows}

    # ── write operations ───────────────────────────────────────────────────

    def activate_profile(self, name: str) -> None:
        """Make *name* the only active profile.

        Raises:
            ValueError: if *name* is not a known profile slot, or no profile
                row with that name exists (nothing is changed in that case).
        """
        self._validate_name(name)
        with closing(_connect(self._db_path)) as conn:
            conn.execute("UPDATE system_prompt_profiles SET is_active=0")
            cursor = conn.execute(
                "UPDATE system_prompt_profiles SET is_active=1 WHERE name=?",
                (name,),
            )
            if cursor.rowcount == 0:
                # Raise before commit: closing the connection discards the
                # blanket deactivation above, so the previously active
                # profile stays active instead of leaving none active.
                raise ValueError(f"Prompt profile '{name}' does not exist.")
            conn.commit()
        logger.info("Activated prompt profile '%s'.", name)

    def update_prompt(self, name: str, step: str, template: str) -> None:
        """Overwrite the template for *step* in profile *name*.

        If no matching row exists, nothing is written and a warning is logged.

        Raises:
            ValueError: if *name* or *step* is not recognised.
        """
        self._validate_name(name)
        self._validate_step(step)
        with closing(_connect(self._db_path)) as conn:
            updated = self._write_template(conn, name, step, template)
            conn.commit()
        if updated:
            logger.info("Updated prompt: profile='%s' step='%s'.", name, step)
        else:
            logger.warning(
                "No stored prompt for profile='%s' step='%s'; nothing updated.",
                name,
                step,
            )

    def update_all_prompts(self, name: str, prompts: dict[str, str]) -> None:
        """Overwrite several templates of profile *name* in one transaction.

        *prompts* maps step name to new template text. Every step is
        validated before anything is written. Steps with no stored row are
        skipped and reported in a warning.

        Raises:
            ValueError: if *name* or any step in *prompts* is not recognised.
        """
        self._validate_name(name)
        for step in prompts:
            self._validate_step(step)
        missing: list[str] = []
        with closing(_connect(self._db_path)) as conn:
            for step, template in prompts.items():
                if not self._write_template(conn, name, step, template):
                    missing.append(step)
            conn.commit()
        if missing:
            logger.warning(
                "No stored prompt for profile='%s' steps=%s; those were not updated.",
                name,
                missing,
            )
        logger.info(
            "Batch-updated prompts for profile '%s': steps=%s.", name, sorted(prompts)
        )

    def reset_to_defaults(self, name: str, step: str | None = None) -> None:
        """Reset template(s) of profile *name* to the built-in seed defaults.

        If *step* is ``None``, every step in ``_VALID_STEPS`` is reset;
        otherwise only *step*. Steps with no stored row are skipped and
        reported in a warning.

        Raises:
            ValueError: if *name* or *step* is not recognised.
        """
        self._validate_name(name)
        if step is None:
            # Sorted so that the write order and the log output are deterministic.
            steps = sorted(_VALID_STEPS)
        else:
            self._validate_step(step)
            steps = [step]
        missing: list[str] = []
        with closing(_connect(self._db_path)) as conn:
            for s in steps:
                if not self._write_template(conn, name, s, _SEED_TEMPLATES[s]):
                    missing.append(s)
            conn.commit()
        if missing:
            logger.warning(
                "No stored prompt for profile='%s' steps=%s; those were not reset.",
                name,
                missing,
            )
        logger.info("Reset prompts for profile '%s': steps=%s.", name, steps)