# legco_ai_assistant/backend/app/test/test_phase3_prompt_injectio...
# (file-viewer metadata: 285 lines, 11 KiB, Python)

"""Tests for Phase 3.4: Prompt template injection into LLM services.
Verifies that QueryDecomposer, RelevanceFilter, and RAGService
correctly fetch templates from PromptService and substitute placeholders.
"""
import pytest
from unittest.mock import MagicMock, AsyncMock
from app.services.query_decomposer import QueryDecomposer
from app.services.relevance_filter import RelevanceFilter
from app.services.rag import RAGService
# ── helpers ──────────────────────────────────────────────────────────────
def _make_custom_prompt_service(templates: dict[str, str]):
    """Return a mock PromptService backed by the *templates* mapping.

    ``get_prompt_template(step)`` yields ``templates[step]`` when the step is
    present and an empty string otherwise. Calls are recorded by the mock so
    tests can assert on them.
    """

    def _lookup(step):
        # Unknown steps fall back to "" rather than raising.
        return templates.get(step, "")

    service = MagicMock()
    service.get_prompt_template = MagicMock(side_effect=_lookup)
    return service
def _make_llm(response: str = '["sub-q"]'):
    """Return a mock LLM client whose async ``complete`` resolves to *response*.

    ``complete`` is an ``AsyncMock``, so the arguments of every call (including
    the prompt) remain available through ``complete.call_args``.
    """
    client = MagicMock()
    complete = AsyncMock()
    complete.return_value = response
    client.complete = complete
    return client
# ── QueryDecomposer tests ───────────────────────────────────────────────
async def test_decomposer_fetches_template_from_prompt_service():
    """A PromptService-supplied "decompose" template drives the LLM prompt."""
    question = "What is X?"
    llm = _make_llm('["a"]')
    prompt_service = _make_custom_prompt_service(
        {"decompose": "CUSTOM DECOMPOSE: {question} -> split"}
    )
    decomposer = QueryDecomposer(llm, prompt_service=prompt_service)

    _questions, returned_prompt = await decomposer.decompose(question)

    # The prompt is read from the first positional argument of the LLM call.
    sent_prompt = llm.complete.call_args.args[0]
    assert sent_prompt.startswith("CUSTOM DECOMPOSE:")
    assert question in sent_prompt
    assert returned_prompt == sent_prompt
async def test_decomposer_uses_builtin_when_no_prompt_service():
    """With ``prompt_service=None`` the decomposer uses its built-in template."""
    question = "What is X?"
    llm = _make_llm('["a"]')
    decomposer = QueryDecomposer(llm, prompt_service=None)

    _questions, returned_prompt = await decomposer.decompose(question)

    sent_prompt = llm.complete.call_args.args[0]
    # Phrase expected to appear in the built-in seed template.
    assert "Break it down into 2-5 simplified sub-questions" in sent_prompt
    assert question in sent_prompt
    assert returned_prompt == sent_prompt
# ── RelevanceFilter tests ───────────────────────────────────────────────
async def test_filter_fetches_template_from_prompt_service():
    """A PromptService-supplied "filter" template drives the LLM prompt."""
    question = "My question"
    chunk_text = "text A"
    llm = _make_llm("[5.0]")
    prompt_service = _make_custom_prompt_service(
        {"filter": "FILTER: q={question} chunks={chunks}"}
    )
    relevance_filter = RelevanceFilter(llm, prompt_service=prompt_service)

    # A single (text, metadata) chunk is enough to exercise substitution.
    _filtered, returned_prompt = await relevance_filter.filter(
        question,
        [(chunk_text, {"filename": "a.pdf"})],
        threshold=3.0,
    )

    sent_prompt = llm.complete.call_args.args[0]
    assert sent_prompt.startswith("FILTER:")
    assert question in sent_prompt
    assert chunk_text in sent_prompt
    assert returned_prompt == sent_prompt
async def test_filter_uses_builtin_when_no_prompt_service():
    """With ``prompt_service=None`` the filter uses its built-in template."""
    question = "My question"
    llm = _make_llm("[5.0]")
    relevance_filter = RelevanceFilter(llm, prompt_service=None)

    _filtered, _returned_prompt = await relevance_filter.filter(
        question,
        [("text A", {"filename": "a.pdf"})],
        threshold=3.0,
    )

    sent_prompt = llm.complete.call_args.args[0]
    # Phrase expected to appear in the built-in filter template.
    assert "rate each 0-10 for relevance" in sent_prompt
    assert question in sent_prompt
# ── RAGService generate tests ───────────────────────────────────────────
async def test_generate_fetches_template_from_prompt_service():
    """``RAGService.generate_response`` builds its prompt from PromptService."""
    question = "What is X?"
    chunk_text = "chunk data"
    llm = _make_llm("answer")
    prompt_service = _make_custom_prompt_service(
        {"generate": "GEN: {question} --- {context} END"}
    )
    chroma = MagicMock()
    chroma.get_or_create_collection.return_value = MagicMock()
    service = RAGService(
        chroma_client=chroma, llm_client=llm, prompt_service=prompt_service
    )

    _answer, gen_prompt = await service.generate_response(
        question,
        [chunk_text],
        [{"filename": "f.txt", "content_summary": "sum"}],
    )

    # The prompt is read from the ``prompt`` keyword argument of the LLM call.
    sent_prompt = llm.complete.call_args.kwargs["prompt"]
    assert sent_prompt.startswith("GEN:")
    assert question in sent_prompt
    assert chunk_text in sent_prompt
    assert sent_prompt.endswith("END")
    assert gen_prompt == sent_prompt
async def test_generate_uses_builtin_when_no_prompt_service():
    """With ``prompt_service=None`` generation still produces a usable prompt."""
    question = "What is X?"
    llm = _make_llm("answer")
    chroma = MagicMock()
    chroma.get_or_create_collection.return_value = MagicMock()
    service = RAGService(chroma_client=chroma, llm_client=llm, prompt_service=None)

    _answer, gen_prompt = await service.generate_response(
        question,
        ["chunk data"],
        [{"filename": "f.txt", "content_summary": "sum"}],
    )

    sent_prompt = llm.complete.call_args.kwargs["prompt"]
    assert question in sent_prompt
    assert gen_prompt == sent_prompt
# ── Placeholder substitution safety tests ───────────────────────────────
async def test_placeholder_substitution_safe_with_curly_braces():
    """A question containing literal braces is substituted without raising."""
    braced_question = "What about {key: value}?"
    llm = _make_llm('["a"]')
    prompt_service = _make_custom_prompt_service(
        {"decompose": "Question: {question} — decompose it"}
    )
    decomposer = QueryDecomposer(llm, prompt_service=prompt_service)

    # Literal braces in the user text must not raise KeyError.
    result, returned_prompt = await decomposer.decompose(braced_question)

    assert isinstance(result, list)
    sent_prompt = llm.complete.call_args.args[0]
    assert "{key: value}" in sent_prompt
    assert returned_prompt == sent_prompt
async def test_unknown_placeholder_left_untouched():
    """Unrecognised placeholders survive verbatim in the prompt sent to the LLM."""
    llm = _make_llm('["a"]')
    prompt_service = _make_custom_prompt_service(
        {"decompose": "HELLO {fake_var} and {question}"}
    )
    decomposer = QueryDecomposer(llm, prompt_service=prompt_service)

    _questions, _returned_prompt = await decomposer.decompose("Q?")

    sent_prompt = llm.complete.call_args.args[0]
    assert "{fake_var}" in sent_prompt
    assert "Q?" in sent_prompt
async def test_empty_template_produces_empty_prompt():
    """An empty "decompose" template yields an empty prompt for the LLM."""
    llm = _make_llm('["a"]')
    prompt_service = _make_custom_prompt_service({"decompose": ""})
    decomposer = QueryDecomposer(llm, prompt_service=prompt_service)

    _questions, _returned_prompt = await decomposer.decompose("Doesn't matter")

    # With no "{question}" in an empty template, the prompt stays empty.
    sent_prompt = llm.complete.call_args.args[0]
    assert sent_prompt == ""
# ── Edge case: no question / no chunks ──────────────────────────────────
async def test_decomposer_no_question_returns_empty():
    """An empty question short-circuits: no template fetch and no LLM call."""
    llm = _make_llm('["should_not_see"]')
    prompt_service = MagicMock()
    prompt_service.get_prompt_template = MagicMock(return_value="tmpl")
    decomposer = QueryDecomposer(llm, prompt_service=prompt_service)

    result, returned_prompt = await decomposer.decompose("")

    assert result == []
    assert returned_prompt == ""
    # Neither collaborator may be touched on the early-return path.
    llm.complete.assert_not_called()
    prompt_service.get_prompt_template.assert_not_called()
async def test_filter_empty_chunks_no_template_fetch():
    """An empty chunk list short-circuits: no template fetch and no LLM call."""
    llm = _make_llm("[5]")
    prompt_service = MagicMock()
    prompt_service.get_prompt_template = MagicMock(return_value="tmpl")
    relevance_filter = RelevanceFilter(llm, prompt_service=prompt_service)

    result, returned_prompt = await relevance_filter.filter("Q?", [], threshold=5.0)

    assert result == []
    assert returned_prompt == ""
    # Neither collaborator may be touched on the early-return path.
    llm.complete.assert_not_called()
    prompt_service.get_prompt_template.assert_not_called()
async def test_generate_no_chunks_returns_fallback():
    """With no chunks, generation returns a fallback answer and an empty prompt."""
    llm = _make_llm("answer")
    prompt_service = MagicMock()
    prompt_service.get_prompt_template = MagicMock(return_value="tmpl")
    chroma = MagicMock()
    chroma.get_or_create_collection.return_value = MagicMock()
    service = RAGService(
        chroma_client=chroma, llm_client=llm, prompt_service=prompt_service
    )

    answer, gen_prompt = await service.generate_response("Q?", [], [])

    # The fallback wording is matched case-insensitively.
    assert "could not find" in answer.lower()
    assert gen_prompt == ""
    llm.complete.assert_not_called()
    prompt_service.get_prompt_template.assert_not_called()
async def test_generate_per_subq_fetches_template_from_prompt_service():
    """``generate_response_per_subquestion`` builds its prompt from PromptService."""
    chunk_text = "chunk data"
    llm = _make_llm("answer")
    prompt_service = _make_custom_prompt_service(
        {"generate_per_subq": "PER_SUBQ: {context_sections} DONE"}
    )
    chroma = MagicMock()
    chroma.get_or_create_collection.return_value = MagicMock()
    service = RAGService(
        chroma_client=chroma, llm_client=llm, prompt_service=prompt_service
    )

    # One sub-question with one chunk and one matching metadata entry.
    _answer, gen_prompt, grouped_sources = (
        await service.generate_response_per_subquestion(
            ["What is X?"],
            [[chunk_text]],
            [[{"filename": "f.txt", "content_summary": "sum"}]],
        )
    )

    sent_prompt = llm.complete.call_args.kwargs["prompt"]
    assert sent_prompt.startswith("PER_SUBQ:")
    assert chunk_text in sent_prompt
    assert sent_prompt.endswith("DONE")
    assert gen_prompt == sent_prompt
    assert len(grouped_sources) == 1
async def test_generate_per_subq_uses_builtin_when_no_prompt_service():
    """With ``prompt_service=None`` the built-in per-sub-question template is used."""
    chunk_text = "chunk data"
    llm = _make_llm("answer")
    chroma = MagicMock()
    chroma.get_or_create_collection.return_value = MagicMock()
    service = RAGService(chroma_client=chroma, llm_client=llm, prompt_service=None)

    _answer, _gen_prompt, _grouped_sources = (
        await service.generate_response_per_subquestion(
            ["What is X?"],
            [[chunk_text]],
            [[{"filename": "f.txt", "content_summary": "sum"}]],
        )
    )

    sent_prompt = llm.complete.call_args.kwargs["prompt"]
    assert "Sub-question" in sent_prompt
    assert chunk_text in sent_prompt
    # The placeholder must have been substituted, not left in the prompt.
    assert "{context_sections}" not in sent_prompt