338 lines
12 KiB
Python
338 lines
12 KiB
Python
"""Tests for Phase 3.4: Prompt template injection into LLM services.
|
|
|
|
Verifies that QueryDecomposer, RelevanceFilter, and RAGService
|
|
correctly fetch templates from PromptService and substitute placeholders.
|
|
|
|
Uses real PromptService (SQLite via tmp_path), real ChromaDB (tmp_path),
|
|
and only mocks the external LLM API.
|
|
"""
|
|
import sqlite3
|
|
|
|
import chromadb
|
|
import pytest
|
|
|
|
from app.core.sqlite_db import init_prompts_db, seed_default_profiles
|
|
from app.services.prompt_service import PromptService
|
|
|
|
|
|
# ── helpers ──────────────────────────────────────────────────────────────
|
|
|
|
|
|
class _MockLLM:
|
|
"""Mock external LLM API — only external dependency we're allowed to mock."""
|
|
|
|
def __init__(self, response: str = '["sub-q"]', side_effect: Exception | None = None):
|
|
self._response = response
|
|
self._side_effect = side_effect
|
|
self.last_prompt: str | None = None
|
|
self.calls: list[dict] = []
|
|
self._call_count: int = 0
|
|
|
|
async def complete(
|
|
self, prompt: str, temperature: float = 0.7, step_name: str = "LLM"
|
|
) -> str:
|
|
self.calls.append({"prompt": prompt, "step": step_name})
|
|
self.last_prompt = prompt
|
|
self._call_count += 1
|
|
if self._side_effect:
|
|
raise self._side_effect
|
|
return self._response
|
|
|
|
@property
|
|
def call_count(self) -> int:
|
|
return self._call_count
|
|
|
|
def assert_called(self):
|
|
assert self._call_count > 0, "LLM.complete was not called"
|
|
|
|
def assert_not_called(self):
|
|
assert self._call_count == 0, f"LLM.complete was called {self._call_count} time(s)"
|
|
|
|
|
|
def _create_prompt_service(
    tmp_path, custom_templates: dict[str, str] | None = None
) -> PromptService:
    """Create a real PromptService backed by real SQLite in tmp_path.

    Initialises the prompts schema and seeds the default profiles, then
    optionally overwrites templates on profile "A" with *custom_templates*
    so the service returns controlled templates.

    Args:
        tmp_path: Directory in which ``prompts.db`` is created.
        custom_templates: Optional mapping of step name to template text,
            applied to profile "A" via ``PromptService.update_prompt``.

    Returns:
        A PromptService pointing at the freshly initialised database file.
    """
    db_path = str(tmp_path / "prompts.db")
    conn = sqlite3.connect(db_path)
    try:
        conn.row_factory = sqlite3.Row
        conn.execute("PRAGMA foreign_keys=ON")
        init_prompts_db(conn)
        seed_default_profiles(conn)
    finally:
        # Close even when schema setup or seeding raises, so a failing test
        # does not leave an open handle on the temporary database file.
        conn.close()

    svc = PromptService(db_path=db_path)
    if custom_templates:
        for step, template in custom_templates.items():
            svc.update_prompt("A", step, template)
    return svc
|
|
|
|
|
|
def _setup_chroma(tmp_path):
    """Return a real ChromaDB PersistentClient stored under ``tmp_path/chroma``."""
    storage = tmp_path / "chroma"
    # Make sure the storage directory exists before handing it to ChromaDB
    storage.mkdir(parents=True, exist_ok=True)
    return chromadb.PersistentClient(path=str(storage))
|
|
|
|
|
|
# ── QueryDecomposer tests ───────────────────────────────────────────────
|
|
|
|
|
|
async def test_decomposer_fetches_template_from_prompt_service(tmp_path):
    """A custom "decompose" template stored in PromptService reaches the LLM."""
    from app.services.query_decomposer import QueryDecomposer

    prompts = _create_prompt_service(
        tmp_path, {"decompose": "CUSTOM DECOMPOSE: {question} -> split"}
    )
    mock_llm = _MockLLM('["a"]')
    decomposer = QueryDecomposer(mock_llm, prompt_service=prompts)

    _, returned_prompt = await decomposer.decompose("What is X?")

    # The LLM must receive the custom template with the question filled in,
    # and decompose() must hand that same prompt back to the caller.
    prompt_seen = mock_llm.last_prompt
    assert prompt_seen.startswith("CUSTOM DECOMPOSE:")
    assert "What is X?" in prompt_seen
    assert returned_prompt == prompt_seen
|
|
|
|
|
|
async def test_decomposer_uses_builtin_when_no_prompt_service():
    """With ``prompt_service=None`` the decomposer falls back to its built-in template."""
    from app.services.query_decomposer import QueryDecomposer

    mock_llm = _MockLLM('["a"]')
    decomposer = QueryDecomposer(mock_llm, prompt_service=None)

    _, returned_prompt = await decomposer.decompose("What is X?")

    # A phrase from the built-in template plus the question must both appear
    prompt_seen = mock_llm.last_prompt
    assert "Break it down into 2-5 simplified sub-questions" in prompt_seen
    assert "What is X?" in prompt_seen
    assert returned_prompt == prompt_seen
|
|
|
|
|
|
# ── RelevanceFilter tests ───────────────────────────────────────────────
|
|
|
|
|
|
async def test_filter_fetches_template_from_prompt_service(tmp_path):
    """A custom "filter" template stored in PromptService reaches the LLM."""
    from app.services.relevance_filter import RelevanceFilter

    prompts = _create_prompt_service(
        tmp_path, {"filter": "FILTER: q={question} chunks={chunks}"}
    )
    mock_llm = _MockLLM("[5.0]")
    relevance = RelevanceFilter(mock_llm, prompt_service=prompts)

    candidates = [("text A", {"filename": "a.pdf"})]
    _, returned_prompt = await relevance.filter("My question", candidates, threshold=3.0)

    # Both placeholders ({question} and {chunks}) must have been substituted
    prompt_seen = mock_llm.last_prompt
    assert prompt_seen.startswith("FILTER:")
    assert "My question" in prompt_seen
    assert "text A" in prompt_seen
    assert returned_prompt == prompt_seen
|
|
|
|
|
|
async def test_filter_uses_builtin_when_no_prompt_service():
    """With ``prompt_service=None`` the filter falls back to its built-in template."""
    from app.services.relevance_filter import RelevanceFilter

    mock_llm = _MockLLM("[5.0]")
    relevance = RelevanceFilter(mock_llm, prompt_service=None)

    candidates = [("text A", {"filename": "a.pdf"})]
    _filtered, _returned = await relevance.filter("My question", candidates, threshold=3.0)

    # A phrase from the built-in template plus the question must both appear
    prompt_seen = mock_llm.last_prompt
    assert "rate each 0-10 for relevance" in prompt_seen
    assert "My question" in prompt_seen
|
|
|
|
|
|
# ── RAGService generate tests ───────────────────────────────────────────
|
|
|
|
|
|
async def test_generate_fetches_template_from_prompt_service(tmp_path):
    """A custom "generate" template stored in PromptService reaches the LLM."""
    from app.services.rag import RAGService

    prompts = _create_prompt_service(
        tmp_path, {"generate": "GEN: {question} --- {context} END"}
    )
    mock_llm = _MockLLM("answer")
    chroma = _setup_chroma(tmp_path)
    rag = RAGService(chroma_client=chroma, llm_client=mock_llm, prompt_service=prompts)

    question = "What is X?"
    chunk_texts = ["chunk data"]
    chunk_meta = [{"filename": "f.txt", "content_summary": "sum"}]
    _, gen_prompt = await rag.generate_response(question, chunk_texts, chunk_meta)

    # The template's fixed prefix/suffix must survive and both placeholders be filled
    prompt_seen = mock_llm.last_prompt
    assert prompt_seen.startswith("GEN:")
    assert "What is X?" in prompt_seen
    assert "chunk data" in prompt_seen
    assert prompt_seen.endswith("END")
    assert gen_prompt == prompt_seen
|
|
|
|
|
|
async def test_generate_uses_builtin_when_no_prompt_service(tmp_path):
    """With ``prompt_service=None`` generate_response still builds a prompt containing the question."""
    from app.services.rag import RAGService

    mock_llm = _MockLLM("answer")
    chroma = _setup_chroma(tmp_path)
    rag = RAGService(chroma_client=chroma, llm_client=mock_llm, prompt_service=None)

    question = "What is X?"
    chunk_texts = ["chunk data"]
    chunk_meta = [{"filename": "f.txt", "content_summary": "sum"}]
    _, gen_prompt = await rag.generate_response(question, chunk_texts, chunk_meta)

    # Only the question and the returned-prompt round trip are checked here
    prompt_seen = mock_llm.last_prompt
    assert "What is X?" in prompt_seen
    assert gen_prompt == prompt_seen
|
|
|
|
|
|
# ── Placeholder substitution safety tests ───────────────────────────────
|
|
|
|
|
|
async def test_placeholder_substitution_safe_with_curly_braces(tmp_path):
    """Curly braces in the user's question must pass through substitution verbatim."""
    from app.services.query_decomposer import QueryDecomposer

    prompts = _create_prompt_service(tmp_path)
    mock_llm = _MockLLM('["a"]')
    decomposer = QueryDecomposer(mock_llm, prompt_service=prompts)

    sub_questions, returned_prompt = await decomposer.decompose("What about {key: value}?")
    assert isinstance(sub_questions, list)

    # The brace-delimited text must appear unchanged in what the LLM received
    prompt_seen = mock_llm.last_prompt
    assert "{key: value}" in prompt_seen
    assert returned_prompt == prompt_seen
|
|
|
|
|
|
async def test_unknown_placeholder_left_untouched(tmp_path):
    """A placeholder the decomposer does not recognise stays literally in the prompt."""
    from app.services.query_decomposer import QueryDecomposer

    template_overrides = {"decompose": "HELLO {fake_var} and {question}"}
    prompts = _create_prompt_service(tmp_path, template_overrides)
    mock_llm = _MockLLM('["a"]')
    decomposer = QueryDecomposer(mock_llm, prompt_service=prompts)

    _sub_questions, _returned = await decomposer.decompose("Q?")

    # {fake_var} is kept as-is while {question} is replaced by the question
    prompt_seen = mock_llm.last_prompt
    assert "{fake_var}" in prompt_seen
    assert "Q?" in prompt_seen
|
|
|
|
|
|
async def test_empty_template_produces_empty_prompt(tmp_path):
    """An empty "decompose" template yields an empty prompt sent to the LLM."""
    from app.services.query_decomposer import QueryDecomposer

    prompts = _create_prompt_service(tmp_path, {"decompose": ""})
    mock_llm = _MockLLM('["a"]')
    decomposer = QueryDecomposer(mock_llm, prompt_service=prompts)

    _sub_questions, _returned = await decomposer.decompose("Doesn't matter")

    # With nothing in the template there is no placeholder to fill,
    # so the prompt handed to the LLM stays empty.
    prompt_seen = mock_llm.last_prompt
    assert prompt_seen == ""
|
|
|
|
|
|
# ── Edge case: no question / no chunks ──────────────────────────────────
|
|
|
|
|
|
async def test_decomposer_no_question_returns_empty(tmp_path):
    """An empty question yields ``([], "")`` and the LLM is never called."""
    from app.services.query_decomposer import QueryDecomposer

    prompts = _create_prompt_service(tmp_path)
    mock_llm = _MockLLM('["should_not_see"]')
    decomposer = QueryDecomposer(mock_llm, prompt_service=prompts)

    sub_questions, returned_prompt = await decomposer.decompose("")

    assert sub_questions == []
    assert returned_prompt == ""
    mock_llm.assert_not_called()
|
|
|
|
|
|
async def test_filter_empty_chunks_no_template_fetch(tmp_path):
    """An empty chunk list yields ``([], "")`` and the LLM is never called.

    Only the LLM call is observed here; whether the template is fetched
    is not checked by this test.
    """
    from app.services.relevance_filter import RelevanceFilter

    prompts = _create_prompt_service(tmp_path)
    mock_llm = _MockLLM("[5]")
    relevance = RelevanceFilter(mock_llm, prompt_service=prompts)

    kept, returned_prompt = await relevance.filter("Q?", [], threshold=5.0)

    assert kept == []
    assert returned_prompt == ""
    mock_llm.assert_not_called()
|
|
|
|
|
|
async def test_generate_no_chunks_returns_fallback(tmp_path):
    """With no chunks, a fallback answer and empty prompt come back without an LLM call."""
    from app.services.rag import RAGService

    prompts = _create_prompt_service(tmp_path)
    mock_llm = _MockLLM("answer")
    chroma = _setup_chroma(tmp_path)
    rag = RAGService(chroma_client=chroma, llm_client=mock_llm, prompt_service=prompts)

    answer, gen_prompt = await rag.generate_response("Q?", [], [])

    # The fallback wording is matched case-insensitively
    assert "could not find" in answer.lower()
    assert gen_prompt == ""
    mock_llm.assert_not_called()
|
|
|
|
|
|
async def test_generate_per_subq_fetches_template_from_prompt_service(tmp_path):
    """A custom "generate_per_subq" template stored in PromptService reaches the LLM."""
    from app.services.rag import RAGService

    prompts = _create_prompt_service(
        tmp_path, {"generate_per_subq": "PER_SUBQ: {context_sections} DONE"}
    )
    mock_llm = _MockLLM("answer")
    chroma = _setup_chroma(tmp_path)
    rag = RAGService(chroma_client=chroma, llm_client=mock_llm, prompt_service=prompts)

    # One sub-question with one chunk and its metadata
    sub_questions = ["What is X?"]
    chunks_per_subq = [["chunk data"]]
    meta_per_subq = [[{"filename": "f.txt", "content_summary": "sum"}]]
    _, gen_prompt, grouped_sources = await rag.generate_response_per_subquestion(
        sub_questions, chunks_per_subq, meta_per_subq
    )

    prompt_seen = mock_llm.last_prompt
    assert prompt_seen.startswith("PER_SUBQ:")
    assert "chunk data" in prompt_seen
    assert prompt_seen.endswith("DONE")
    assert gen_prompt == prompt_seen
    assert len(grouped_sources) == 1
|
|
|
|
|
|
async def test_generate_per_subq_uses_builtin_when_no_prompt_service(tmp_path):
    """With ``prompt_service=None`` the per-sub-question generator uses its built-in template."""
    from app.services.rag import RAGService

    mock_llm = _MockLLM("answer")
    chroma = _setup_chroma(tmp_path)
    rag = RAGService(chroma_client=chroma, llm_client=mock_llm, prompt_service=None)

    # One sub-question with one chunk and its metadata
    sub_questions = ["What is X?"]
    chunks_per_subq = [["chunk data"]]
    meta_per_subq = [[{"filename": "f.txt", "content_summary": "sum"}]]
    _answer, _gen_prompt, _grouped = await rag.generate_response_per_subquestion(
        sub_questions, chunks_per_subq, meta_per_subq
    )

    # The placeholder must have been replaced, not sent through literally
    prompt_seen = mock_llm.last_prompt
    assert "Sub-question" in prompt_seen
    assert "chunk data" in prompt_seen
    assert "{context_sections}" not in prompt_seen
|