239 lines
8.9 KiB
Python
239 lines
8.9 KiB
Python
"""Tests for Phase 3.4: Prompt template injection into LLM services.
|
|
|
|
Verifies that QueryDecomposer, RelevanceFilter, and RAGService
|
|
correctly fetch templates from PromptService and substitute placeholders.
|
|
"""
|
|
import pytest
|
|
from unittest.mock import MagicMock, AsyncMock
|
|
|
|
from app.services.query_decomposer import QueryDecomposer
|
|
from app.services.relevance_filter import RelevanceFilter
|
|
from app.services.rag import RAGService
|
|
|
|
|
|
# ── helpers ──────────────────────────────────────────────────────────────
|
|
|
|
|
|
def _make_custom_prompt_service(templates: dict[str, str]) -> MagicMock:
    """Build a mock PromptService returning *templates* for get_prompt_template.

    Args:
        templates: Mapping of step name (e.g. ``"decompose"``) to the
            template text the mock should return for that step.

    Returns:
        A ``MagicMock`` whose ``get_prompt_template(step)`` returns
        ``templates[step]``, or ``""`` when *step* is not in *templates*.
    """
    svc = MagicMock()
    # side_effect rather than return_value: the result depends on the step.
    svc.get_prompt_template = MagicMock(
        side_effect=lambda step: templates.get(step, "")
    )
    return svc
|
|
|
|
|
|
def _make_llm(response: str = '["sub-q"]') -> MagicMock:
    """Build a mock LLM client that records the prompt sent.

    Args:
        response: Value the mocked ``complete`` coroutine resolves to.

    Returns:
        A ``MagicMock`` whose ``complete`` attribute is an ``AsyncMock``
        returning *response*; the arguments of each call are available
        through ``llm.complete.call_args``.
    """
    llm = MagicMock()
    # AsyncMock so that ``await llm.complete(...)`` works in the code under test.
    llm.complete = AsyncMock(return_value=response)
    return llm
|
|
|
|
|
|
# ── QueryDecomposer tests ───────────────────────────────────────────────
|
|
|
|
|
|
async def test_decomposer_fetches_template_from_prompt_service():
    """QueryDecomposer should use the template returned by PromptService.

    The prompt sent to the LLM must start with the custom template's prefix,
    contain the question, and equal the prompt that ``decompose`` returns as
    the second element of its result.
    """
    custom_template = "CUSTOM DECOMPOSE: {question} -> split"
    ps = _make_custom_prompt_service({"decompose": custom_template})
    llm = _make_llm('["a"]')

    d = QueryDecomposer(llm, prompt_service=ps)
    _questions, returned_prompt = await d.decompose("What is X?")

    # The prompt is read from the first positional argument of llm.complete.
    sent_prompt = llm.complete.call_args.args[0]
    assert sent_prompt.startswith("CUSTOM DECOMPOSE:")
    assert "What is X?" in sent_prompt
    assert returned_prompt == sent_prompt
|
|
|
|
|
|
async def test_decomposer_uses_builtin_when_no_prompt_service():
    """Without prompt_service, the built-in seed template is used.

    The built-in template is recognised by a known phrase; the question must
    still be present and the returned prompt must match the one sent.
    """
    llm = _make_llm('["a"]')
    d = QueryDecomposer(llm, prompt_service=None)
    _questions, returned_prompt = await d.decompose("What is X?")

    # The prompt is read from the first positional argument of llm.complete.
    sent_prompt = llm.complete.call_args.args[0]
    assert "Break it down into 2-5 simplified sub-questions" in sent_prompt
    assert "What is X?" in sent_prompt
    assert returned_prompt == sent_prompt
|
|
|
|
|
|
# ── RelevanceFilter tests ───────────────────────────────────────────────
|
|
|
|
|
|
async def test_filter_fetches_template_from_prompt_service():
    """RelevanceFilter should use the template from PromptService.

    The prompt sent to the LLM must start with the custom template's prefix,
    contain both the question and the chunk text, and equal the prompt that
    ``filter`` returns as the second element of its result.
    """
    custom_template = "FILTER: q={question} chunks={chunks}"
    ps = _make_custom_prompt_service({"filter": custom_template})
    llm = _make_llm("[5.0]")

    rf = RelevanceFilter(llm, prompt_service=ps)
    chunks = [("text A", {"filename": "a.pdf"})]
    _filtered, returned_prompt = await rf.filter(
        "My question", chunks, threshold=3.0
    )

    # The prompt is read from the first positional argument of llm.complete.
    sent_prompt = llm.complete.call_args.args[0]
    assert sent_prompt.startswith("FILTER:")
    assert "My question" in sent_prompt
    assert "text A" in sent_prompt
    assert returned_prompt == sent_prompt
|
|
|
|
|
|
async def test_filter_uses_builtin_when_no_prompt_service():
    """Without prompt_service, the built-in filter template is used.

    The built-in template is recognised by a known phrase, and the question
    must appear in the prompt sent to the LLM.
    """
    llm = _make_llm("[5.0]")
    rf = RelevanceFilter(llm, prompt_service=None)
    chunks = [("text A", {"filename": "a.pdf"})]
    # Unpacking still checks that filter() yields a two-element result.
    _filtered, _returned_prompt = await rf.filter(
        "My question", chunks, threshold=3.0
    )

    # The prompt is read from the first positional argument of llm.complete.
    sent_prompt = llm.complete.call_args.args[0]
    assert "rate each 0-10 for relevance" in sent_prompt
    assert "My question" in sent_prompt
|
|
|
|
|
|
# ── RAGService generate tests ───────────────────────────────────────────
|
|
|
|
|
|
async def test_generate_fetches_template_from_prompt_service():
    """RAGService.generate_response should use PromptService template.

    The prompt sent to the LLM must start and end with the custom template's
    literal text, contain the question and the chunk text, and equal the
    prompt that ``generate_response`` returns as the second element.
    """
    custom_template = "GEN: {question} --- {context} END"
    ps = _make_custom_prompt_service({"generate": custom_template})
    llm = _make_llm("answer")

    mock_collection = MagicMock()
    mock_client = MagicMock()
    mock_client.get_or_create_collection.return_value = mock_collection

    svc = RAGService(chroma_client=mock_client, llm_client=llm, prompt_service=ps)
    _answer, gen_prompt = await svc.generate_response(
        "What is X?",
        ["chunk data"],
        [{"filename": "f.txt", "content_summary": "sum"}],
    )

    # Here the prompt is read from the ``prompt`` keyword argument.
    sent_prompt = llm.complete.call_args.kwargs["prompt"]
    assert sent_prompt.startswith("GEN:")
    assert "What is X?" in sent_prompt
    assert "chunk data" in sent_prompt
    assert sent_prompt.endswith("END")
    assert gen_prompt == sent_prompt
|
|
|
|
|
|
async def test_generate_uses_builtin_when_no_prompt_service():
    """Without prompt_service, the built-in generate template is used.

    Only checks that the question reaches the LLM prompt and that the
    returned prompt equals the one sent.
    """
    llm = _make_llm("answer")

    mock_collection = MagicMock()
    mock_client = MagicMock()
    mock_client.get_or_create_collection.return_value = mock_collection

    svc = RAGService(
        chroma_client=mock_client, llm_client=llm, prompt_service=None
    )
    _answer, gen_prompt = await svc.generate_response(
        "What is X?",
        ["chunk data"],
        [{"filename": "f.txt", "content_summary": "sum"}],
    )

    # Here the prompt is read from the ``prompt`` keyword argument.
    sent_prompt = llm.complete.call_args.kwargs["prompt"]
    assert "What is X?" in sent_prompt
    assert gen_prompt == sent_prompt
|
|
|
|
|
|
# ── Placeholder substitution safety tests ───────────────────────────────
|
|
|
|
|
|
async def test_placeholder_substitution_safe_with_curly_braces():
    """User text containing curly braces must survive placeholder substitution.

    The literal braces must not raise during substitution and must reach the
    LLM prompt unchanged.
    """
    ps = _make_custom_prompt_service({
        "decompose": "Question: {question} — decompose it",
    })
    llm = _make_llm('["a"]')

    d = QueryDecomposer(llm, prompt_service=ps)
    # This question has literal braces — must not raise KeyError
    result, returned_prompt = await d.decompose("What about {key: value}?")
    assert isinstance(result, list)

    # The prompt is read from the first positional argument of llm.complete.
    sent_prompt = llm.complete.call_args.args[0]
    assert "{key: value}" in sent_prompt
    assert returned_prompt == sent_prompt
|
|
|
|
|
|
async def test_unknown_placeholder_left_untouched():
    """Placeholders not matched by str.replace stay as-is in the prompt.

    ``{fake_var}`` is not a placeholder the decomposer substitutes, so it
    must appear verbatim while ``{question}`` is still replaced.
    """
    ps = _make_custom_prompt_service({
        "decompose": "HELLO {fake_var} and {question}",
    })
    llm = _make_llm('["a"]')

    d = QueryDecomposer(llm, prompt_service=ps)
    # Unpacking still checks that decompose() yields a two-element result.
    _questions, _returned_prompt = await d.decompose("Q?")

    # The prompt is read from the first positional argument of llm.complete.
    sent_prompt = llm.complete.call_args.args[0]
    assert "{fake_var}" in sent_prompt
    assert "Q?" in sent_prompt
|
|
|
|
|
|
async def test_empty_template_produces_empty_prompt():
    """An empty template string results in an empty prompt.

    With no ``{question}`` placeholder to replace, nothing of the question
    is inserted, so the prompt sent to the LLM is exactly ``""``.
    """
    ps = _make_custom_prompt_service({"decompose": ""})
    llm = _make_llm('["a"]')

    d = QueryDecomposer(llm, prompt_service=ps)
    # Unpacking still checks that decompose() yields a two-element result.
    _questions, _returned_prompt = await d.decompose("Doesn't matter")

    # The prompt is read from the first positional argument of llm.complete.
    sent_prompt = llm.complete.call_args.args[0]
    # Empty template with .replace("{question}", ...) still has no text
    assert sent_prompt == ""
|
|
|
|
|
|
# ── Edge case: no question / no chunks ──────────────────────────────────
|
|
|
|
|
|
async def test_decomposer_no_question_returns_empty():
    """Empty question returns [] without calling prompt_service.

    Also expects an empty returned prompt and no LLM call at all.
    """
    ps = MagicMock()
    ps.get_prompt_template = MagicMock(return_value="tmpl")

    # The canned response must never surface: the LLM should not be called.
    llm = _make_llm('["should_not_see"]')
    d = QueryDecomposer(llm, prompt_service=ps)
    result, returned_prompt = await d.decompose("")

    assert result == []
    assert returned_prompt == ""
    llm.complete.assert_not_called()
    ps.get_prompt_template.assert_not_called()
|
|
|
|
|
|
async def test_filter_empty_chunks_no_template_fetch():
    """Empty chunks list returns [] without fetching a template.

    Also expects an empty returned prompt and no LLM call at all.
    """
    ps = MagicMock()
    ps.get_prompt_template = MagicMock(return_value="tmpl")

    llm = _make_llm("[5]")
    rf = RelevanceFilter(llm, prompt_service=ps)
    result, returned_prompt = await rf.filter("Q?", [], threshold=5.0)

    assert result == []
    assert returned_prompt == ""
    llm.complete.assert_not_called()
    ps.get_prompt_template.assert_not_called()
|
|
|
|
|
|
async def test_generate_no_chunks_returns_fallback():
    """No chunks returns fallback message without touching PromptService.

    The fallback answer is matched case-insensitively on "could not find";
    the returned prompt must be empty and the LLM must not be called.
    """
    ps = MagicMock()
    ps.get_prompt_template = MagicMock(return_value="tmpl")

    llm = _make_llm("answer")
    mock_collection = MagicMock()
    mock_client = MagicMock()
    mock_client.get_or_create_collection.return_value = mock_collection

    svc = RAGService(chroma_client=mock_client, llm_client=llm, prompt_service=ps)
    answer, gen_prompt = await svc.generate_response("Q?", [], [])

    assert "could not find" in answer.lower()
    assert gen_prompt == ""
    llm.complete.assert_not_called()
    ps.get_prompt_template.assert_not_called()
|