test(backend): extend existing tests for per-sub-q methods and templates
Add 6 tests for retrieve_per_subquestion and generate_response_per_subquestion to the Phase 1 RAG service tests. Add 4 tests for filter_per_subquestion to the Phase 1 relevance filter tests. Add 2 tests for the new {context_sections} generate template to the Phase 3 prompt injection tests. Add a TestPerSubQPipelineHistory class with 3 per-sub-q pipeline simulation tests to the Phase 3 integration tests. Add a generate_per_subq template seed to the conftest mock_prompt_service fixture.
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)
Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
This commit is contained in:
parent
201bddecf0
commit
3f50f81bfe
|
|
@ -60,6 +60,12 @@ def mock_prompt_service():
|
||||||
"Document chunks:\n{context}\n\n"
|
"Document chunks:\n{context}\n\n"
|
||||||
"Answer:"
|
"Answer:"
|
||||||
),
|
),
|
||||||
|
"generate_per_subq": (
|
||||||
|
"Answer each sub-question using ONLY its document chunks.\n"
|
||||||
|
"Format as markdown sections with ## Sub-question N: headers.\n"
|
||||||
|
"{context_sections}\n\n"
|
||||||
|
"Answer:"
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
class _MockPromptService:
|
class _MockPromptService:
|
||||||
|
|
|
||||||
|
|
@ -137,3 +137,109 @@ class TestRAGService:
|
||||||
|
|
||||||
assert "no relevant" in answer.lower() or "could not find" in answer.lower()
|
assert "no relevant" in answer.lower() or "could not find" in answer.lower()
|
||||||
assert gen_prompt == ""
|
assert gen_prompt == ""
|
||||||
|
|
||||||
|
def test_retrieve_per_subquestion_returns_per_query(self):
    """Each sub-question is queried separately and paired with its own hits."""
    from app.services.rag import RAGService

    # One canned collection response per expected query, in call order.
    canned_responses = [
        {
            "documents": [["chunk A1", "chunk A2"]],
            "metadatas": [[{"filename": "a.pdf"}, {"filename": "a.pdf"}]],
            "distances": [[0.1, 0.2]],
        },
        {
            "documents": [["chunk B1"]],
            "metadatas": [[{"filename": "b.pdf"}]],
            "distances": [[0.3]],
        },
    ]
    collection = MagicMock()
    client = MagicMock()
    client.get_or_create_collection.return_value = collection
    collection.query.side_effect = canned_responses

    rag = RAGService(chroma_client=client)
    per_query = rag.retrieve_per_subquestion(["query A", "query B"], n_results=5)

    assert len(per_query) == 2

    group_a = per_query[0]
    assert group_a[0] == "query A"
    assert len(group_a[1]) == 2

    group_b = per_query[1]
    assert group_b[0] == "query B"
    assert len(group_b[1]) == 1

    # Exactly one collection query per sub-question.
    assert collection.query.call_count == 2
|
||||||
|
|
||||||
|
def test_retrieve_per_subquestion_empty_list(self):
    """With no sub-questions the result is empty and the collection is never queried."""
    from app.services.rag import RAGService

    collection = MagicMock()
    client = MagicMock()
    client.get_or_create_collection.return_value = collection

    rag = RAGService(chroma_client=client)
    per_query = rag.retrieve_per_subquestion([], n_results=5)

    assert per_query == []
    collection.query.assert_not_called()
|
||||||
|
|
||||||
|
async def test_generate_response_per_subquestion_calls_llm(self, mock_prompt_service):
    """The LLM is called once with a prompt embedding the chunks; its reply is returned as-is."""
    from app.services.rag import RAGService

    llm_reply = "## Sub-question 1: Q?\n- Answer"

    collection = MagicMock()
    client = MagicMock()
    client.get_or_create_collection.return_value = collection

    llm = MagicMock()
    llm.complete = AsyncMock(return_value=llm_reply)

    rag = RAGService(
        chroma_client=client,
        llm_client=llm,
        prompt_service=mock_prompt_service,
    )

    sub_questions = ["What is X?"]
    chunk_groups = [["chunk data"]]
    metadata_groups = [[{"filename": "f.txt", "content_summary": "sum"}]]
    answer, gen_prompt, grouped_sources = await rag.generate_response_per_subquestion(
        sub_questions,
        chunk_groups,
        metadata_groups,
    )

    llm.complete.assert_called_once()

    # The prompt is passed to the LLM as a keyword argument.
    sent_prompt = llm.complete.call_args.kwargs["prompt"]
    assert "chunk data" in sent_prompt
    assert "Sub-question 0" in sent_prompt

    assert answer == llm_reply
    assert len(grouped_sources) == 1
    assert grouped_sources[0][0]["filename"] == "f.txt"
|
||||||
|
|
||||||
|
async def test_generate_response_per_subquestion_no_subquestions(self):
    """Empty inputs produce the fallback answer, an empty prompt and no sources."""
    from app.services.rag import RAGService

    collection = MagicMock()
    client = MagicMock()
    client.get_or_create_collection.return_value = collection

    rag = RAGService(chroma_client=client, llm_client=MagicMock())

    outcome = await rag.generate_response_per_subquestion([], [], [])
    answer, gen_prompt, grouped_sources = outcome

    assert "could not find" in answer.lower()
    assert gen_prompt == ""
    assert grouped_sources == []
|
||||||
|
|
||||||
|
async def test_generate_response_per_subquestion_no_chunks(self):
    """A sub-question with no chunks produces the fallback answer and an empty prompt."""
    from app.services.rag import RAGService

    collection = MagicMock()
    client = MagicMock()
    client.get_or_create_collection.return_value = collection

    rag = RAGService(chroma_client=client, llm_client=MagicMock())

    # The third element (grouped sources) is unpacked but not checked here.
    outcome = await rag.generate_response_per_subquestion(["Q?"], [[]], [[]])
    answer, gen_prompt, _grouped_sources = outcome

    assert "could not find" in answer.lower()
    assert gen_prompt == ""
|
||||||
|
|
|
||||||
|
|
@ -91,3 +91,72 @@ async def test_filter_json_in_markdown_code_block(mock_prompt_service):
|
||||||
assert result[0][1]["relevance_score"] == 8.0
|
assert result[0][1]["relevance_score"] == 8.0
|
||||||
assert result[1][0] == chunks[2][0]
|
assert result[1][0] == chunks[2][0]
|
||||||
assert result[1][1]["relevance_score"] == 9.0
|
assert result[1][1]["relevance_score"] == 9.0
|
||||||
|
|
||||||
|
|
||||||
|
async def test_filter_per_subquestion_basic(mock_prompt_service):
    """Chunks scoring at or above the threshold are kept per sub-question; others are dropped."""
    questions = ["Question A?", "Question B?"]
    chunks_by_question = [
        [("Chunk A", {"filename": "a.pdf", "chunk_index": 0})],
        [("Chunk B", {"filename": "b.pdf", "chunk_index": 1})],
    ]

    # Scores keyed by sub-question index: A passes (8.5), B fails (3.0).
    scorer = MagicMock()
    scorer.complete = AsyncMock(return_value='{"0": [8.5], "1": [3.0]}')

    relevance_filter = RelevanceFilter(scorer, prompt_service=mock_prompt_service)
    filtered, used_prompt = await relevance_filter.filter_per_subquestion(
        questions,
        chunks_by_question,
        threshold=7.0,
    )

    assert len(filtered) == 2

    kept_a = filtered[0]
    assert kept_a[0] == "Question A?"
    assert len(kept_a[1]) == 1
    assert kept_a[1][0][1]["relevance_score"] == 8.5

    kept_b = filtered[1]
    assert kept_b[0] == "Question B?"
    assert len(kept_b[1]) == 0

    assert used_prompt != ""
    scorer.complete.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_filter_per_subquestion_empty_subquestions(mock_prompt_service):
    """With nothing to filter, the result and prompt are empty and the LLM is not called."""
    scorer = MagicMock()
    scorer.complete = AsyncMock()

    relevance_filter = RelevanceFilter(scorer, prompt_service=mock_prompt_service)
    filtered, used_prompt = await relevance_filter.filter_per_subquestion(
        [],
        [],
        threshold=7.0,
    )

    assert filtered == []
    assert used_prompt == ""
    scorer.complete.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_filter_per_subquestion_invalid_json(mock_prompt_service):
    """An unparseable LLM reply yields an empty result; the prompt is still returned."""
    chunks_by_question = [
        [("Chunk A", {"filename": "a.pdf", "chunk_index": 0})],
    ]

    scorer = MagicMock()
    scorer.complete = AsyncMock(return_value="not valid json")

    relevance_filter = RelevanceFilter(scorer, prompt_service=mock_prompt_service)
    filtered, used_prompt = await relevance_filter.filter_per_subquestion(
        ["Question?"],
        chunks_by_question,
        threshold=7.0,
    )

    assert filtered == []
    assert used_prompt != ""
|
||||||
|
|
||||||
|
|
||||||
|
async def test_filter_per_subquestion_all_below_threshold(mock_prompt_service):
    """Every sub-question keeps its entry in the result even when all its chunks are dropped."""
    chunks_by_question = [
        [("Chunk A", {"filename": "a.pdf", "chunk_index": 0})],
        [("Chunk B", {"filename": "b.pdf", "chunk_index": 1})],
    ]

    # Both scores fall under the 7.0 threshold used below.
    scorer = MagicMock()
    scorer.complete = AsyncMock(return_value='{"0": [2.0], "1": [1.5]}')

    relevance_filter = RelevanceFilter(scorer, prompt_service=mock_prompt_service)
    filtered, _used_prompt = await relevance_filter.filter_per_subquestion(
        ["Q1?", "Q2?"],
        chunks_by_question,
        threshold=7.0,
    )

    assert len(filtered) == 2
    assert len(filtered[0][1]) == 0
    assert len(filtered[1][1]) == 0
|
||||||
|
|
|
||||||
|
|
@ -236,3 +236,49 @@ async def test_generate_no_chunks_returns_fallback():
|
||||||
assert gen_prompt == ""
|
assert gen_prompt == ""
|
||||||
llm.complete.assert_not_called()
|
llm.complete.assert_not_called()
|
||||||
ps.get_prompt_template.assert_not_called()
|
ps.get_prompt_template.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_generate_per_subq_fetches_template_from_prompt_service():
    """generate_response_per_subquestion renders the template supplied by the prompt service."""
    template = "PER_SUBQ: {context_sections} DONE"
    prompt_service = _make_custom_prompt_service({"generate_per_subq": template})
    llm = _make_llm("answer")

    collection = MagicMock()
    client = MagicMock()
    client.get_or_create_collection.return_value = collection

    rag = RAGService(chroma_client=client, llm_client=llm, prompt_service=prompt_service)
    _answer, gen_prompt, grouped_sources = await rag.generate_response_per_subquestion(
        ["What is X?"],
        [["chunk data"]],
        [[{"filename": "f.txt", "content_summary": "sum"}]],
    )

    # The custom template's fixed prefix and suffix must surround the injected context.
    sent_prompt = llm.complete.call_args.kwargs["prompt"]
    assert sent_prompt.startswith("PER_SUBQ:")
    assert "chunk data" in sent_prompt
    assert sent_prompt.endswith("DONE")

    # The returned prompt is exactly what was sent to the LLM.
    assert gen_prompt == sent_prompt
    assert len(grouped_sources) == 1
|
||||||
|
|
||||||
|
|
||||||
|
async def test_generate_per_subq_uses_builtin_when_no_prompt_service():
    """With prompt_service=None the service falls back to its built-in per-sub-q template."""
    llm = _make_llm("answer")

    collection = MagicMock()
    client = MagicMock()
    client.get_or_create_collection.return_value = collection

    rag = RAGService(chroma_client=client, llm_client=llm, prompt_service=None)
    _answer, _gen_prompt, _grouped_sources = await rag.generate_response_per_subquestion(
        ["What is X?"],
        [["chunk data"]],
        [[{"filename": "f.txt", "content_summary": "sum"}]],
    )

    sent_prompt = llm.complete.call_args.kwargs["prompt"]
    assert "Sub-question" in sent_prompt
    assert "chunk data" in sent_prompt
    # The placeholder must have been substituted, not sent verbatim.
    assert "{context_sections}" not in sent_prompt
|
||||||
|
|
|
||||||
|
|
@ -606,3 +606,60 @@ async def test_history_not_created_on_error():
|
||||||
|
|
||||||
# No history record
|
# No history record
|
||||||
assert rec is None, "History record must not be created on pipeline error"
|
assert rec is None, "History record must not be created on pipeline error"
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
# Phase 4: Per-sub-question pipeline history tests
|
||||||
|
#
|
||||||
|
# These tests verify the new per-sub-question pipeline records history
|
||||||
|
# correctly while the old flat pipeline tests above remain for backward
|
||||||
|
# compatibility.
|
||||||
|
# ═══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestPerSubQPipelineHistory:
    """Checks on the history record produced by the per-sub-question pipeline."""

    async def test_per_subq_pipeline_records_history(self):
        """A pipeline run records one history entry with input, profile, questions and timings."""
        history_svc = _make_mock_history_service()
        _events, rec = await _run_pipeline_and_collect_history(
            history_service=history_svc,
        )

        assert rec is not None
        assert rec["input_text"] == "What is the NEC4 clause about time extensions?"
        assert rec["profile_used"] == "A"

        # extracted_questions is stored as a JSON-encoded, non-empty list.
        questions = json.loads(rec["extracted_questions"])
        assert isinstance(questions, list)
        assert len(questions) >= 1

        timing_keys = (
            "decomposer_time_ms",
            "retriever_time_ms",
            "filter_time_ms",
            "generator_time_ms",
            "total_time_ms",
        )
        for timing_key in timing_keys:
            assert rec[timing_key] >= 0, f"{timing_key} should be >= 0"

        history_svc.record.assert_awaited_once()

    async def test_per_subq_history_contains_chunk_xml(self):
        """chunks_retrieved and chunks_filtered are non-empty and carry the expected markers."""
        _events, rec = await _run_pipeline_and_collect_history()

        assert rec is not None
        assert rec["chunks_retrieved"], "chunks_retrieved must not be empty"
        assert rec["chunks_filtered"], "chunks_filtered must not be empty"

        retrieved = rec["chunks_retrieved"]
        assert "<chunk_" in retrieved
        assert "Filename:" in retrieved
        assert "Relevance:" in rec["chunks_filtered"]

    async def test_per_subq_history_prompts_are_strings(self):
        """Every prompt field exists and is a string (possibly empty pre-implementation)."""
        _events, rec = await _run_pipeline_and_collect_history()

        assert rec is not None
        prompt_fields = ("decompose_prompt", "filter_prompt", "generate_prompt")
        for key in prompt_fields:
            assert key in rec
            assert isinstance(rec[key], str)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue