"""Phase 1 tests: RAG service logic.

Covers:
- ChromaDB document ingestion with metadata
- Retrieval with query keywords
- Response generation with strict RAG prompt
- Metadata handling per chunk
- Per-sub-question retrieval and response generation
"""
|
|
import pytest
|
|
from unittest.mock import MagicMock, AsyncMock
|
|
|
|
|
|
class TestRAGService:
    """RAG retrieval and prompt logic tests.

    Every test drives ``RAGService`` against ``MagicMock`` stand-ins for the
    ChromaDB client and collection (plus an ``AsyncMock``-backed LLM client
    where a completion is expected), so assertions are made purely on the
    calls the service issues and on the values it returns.

    Coroutine tests carry an explicit ``pytest.mark.asyncio`` marker so they
    are executed under pytest-asyncio's strict mode as well; the marker is
    harmless when the project runs in auto mode.
    """

    @staticmethod
    def _mock_chroma():
        """Return ``(client, collection)`` mocks wired together.

        ``client.get_or_create_collection(...)`` hands back ``collection``,
        which lets a test stub ``collection.query`` and inspect
        ``collection.add`` / ``collection.query`` calls afterwards.
        """
        collection = MagicMock()
        client = MagicMock()
        client.get_or_create_collection.return_value = collection
        return client, collection

    def test_ingest_document_adds_chunks(self):
        """Should add chunks with metadata to ChromaDB collection."""
        from app.services.rag import RAGService

        client, collection = self._mock_chroma()
        service = RAGService(chroma_client=client)

        chunks = ["chunk one", "chunk two"]
        metadata = [
            {
                "filename": "test.txt",
                "upload_date": "2024-01-01",
                "content_summary": "summary 1",
                "chunk_index": 0,
            },
            {
                "filename": "test.txt",
                "upload_date": "2024-01-01",
                "content_summary": "summary 2",
                "chunk_index": 1,
            },
        ]

        service.ingest_document("test.txt", chunks, metadata)

        client.get_or_create_collection.assert_called_once_with(name="documents")
        collection.add.assert_called_once()
        # The test reads keyword arguments only, so ``add`` must be called
        # with ``documents=``, ``metadatas=`` and ``ids=``.
        add_kwargs = collection.add.call_args.kwargs
        assert len(add_kwargs["documents"]) == 2
        assert add_kwargs["documents"] == chunks
        assert len(add_kwargs["metadatas"]) == 2
        assert add_kwargs["metadatas"] == metadata
        assert len(add_kwargs["ids"]) == 2

    def test_ingest_document_empty_chunks(self):
        """Should not call ChromaDB when chunks list is empty."""
        from app.services.rag import RAGService

        client, collection = self._mock_chroma()

        service = RAGService(chroma_client=client)
        service.ingest_document("test.txt", [], [])

        collection.add.assert_not_called()

    def test_retrieve_returns_chunks(self):
        """Should retrieve chunks and metadata from ChromaDB."""
        from app.services.rag import RAGService

        client, collection = self._mock_chroma()
        # Canned query result: one outer list per query, one inner entry per hit.
        collection.query.return_value = {
            "documents": [["chunk one", "chunk two"]],
            "metadatas": [[{"filename": "test.txt"}, {"filename": "test.txt"}]],
            "distances": [[0.1, 0.2]],
        }

        service = RAGService(chroma_client=client)
        results = service.retrieve(["query", "keywords"], n_results=5)

        collection.query.assert_called_once()
        query_kwargs = collection.query.call_args.kwargs
        assert query_kwargs["n_results"] == 5
        assert len(results) == 2
        # Each hit is flattened to a (document, metadata, distance) tuple.
        assert results[0] == ("chunk one", {"filename": "test.txt"}, 0.1)
        assert results[1] == ("chunk two", {"filename": "test.txt"}, 0.2)

    def test_retrieve_no_results(self):
        """Should return empty list when no results found."""
        from app.services.rag import RAGService

        client, collection = self._mock_chroma()
        collection.query.return_value = {
            "documents": [[]],
            "metadatas": [[]],
            "distances": [[]],
        }

        service = RAGService(chroma_client=client)
        results = service.retrieve(["query"])

        assert results == []

    @pytest.mark.asyncio
    async def test_generate_response_calls_llm(self, mock_prompt_service):
        """Should call LLM with strict RAG prompt."""
        from app.services.rag import RAGService

        client, _collection = self._mock_chroma()

        mock_llm = MagicMock()
        mock_llm.complete = AsyncMock(return_value="- Bullet point answer")

        service = RAGService(
            chroma_client=client,
            llm_client=mock_llm,
            prompt_service=mock_prompt_service,
        )

        chunks = ["relevant chunk"]
        metadata = [{"filename": "test.txt", "content_summary": "summary"}]

        answer, gen_prompt = await service.generate_response("What is this?", chunks, metadata)

        mock_llm.complete.assert_called_once()
        # The prompt is read from keyword arguments: ``complete(prompt=...)``.
        sent_prompt = mock_llm.complete.call_args.kwargs["prompt"]
        assert "What is this?" in sent_prompt
        assert "relevant chunk" in sent_prompt
        assert "test.txt" in sent_prompt
        assert "only these document chunks" in sent_prompt.lower()
        assert answer == "- Bullet point answer"
        assert gen_prompt == sent_prompt

    @pytest.mark.asyncio
    async def test_generate_response_no_chunks(self):
        """Should return fallback message when no chunks provided."""
        from app.services.rag import RAGService

        client, _collection = self._mock_chroma()
        service = RAGService(chroma_client=client, llm_client=MagicMock())

        answer, gen_prompt = await service.generate_response("What is this?", [], [])

        lowered = answer.lower()
        assert "no relevant" in lowered or "could not find" in lowered
        assert gen_prompt == ""

    def test_retrieve_per_subquestion_returns_per_query(self):
        """Should return one (query, results) entry per sub-question, querying ChromaDB once for each."""
        from app.services.rag import RAGService

        client, collection = self._mock_chroma()
        # ``side_effect`` yields one canned result per successive ``query`` call.
        collection.query.side_effect = [
            {
                "documents": [["chunk A1", "chunk A2"]],
                "metadatas": [[{"filename": "a.pdf"}, {"filename": "a.pdf"}]],
                "distances": [[0.1, 0.2]],
            },
            {
                "documents": [["chunk B1"]],
                "metadatas": [[{"filename": "b.pdf"}]],
                "distances": [[0.3]],
            },
        ]

        service = RAGService(chroma_client=client)
        results = service.retrieve_per_subquestion(["query A", "query B"], n_results=5)

        assert len(results) == 2
        assert results[0][0] == "query A"
        assert len(results[0][1]) == 2
        assert results[1][0] == "query B"
        assert len(results[1][1]) == 1
        assert collection.query.call_count == 2

    def test_retrieve_per_subquestion_empty_list(self):
        """Should return an empty list and never query ChromaDB when given no sub-questions."""
        from app.services.rag import RAGService

        client, collection = self._mock_chroma()

        service = RAGService(chroma_client=client)
        results = service.retrieve_per_subquestion([], n_results=5)

        assert results == []
        collection.query.assert_not_called()

    @pytest.mark.asyncio
    async def test_generate_response_per_subquestion_calls_llm(self, mock_prompt_service):
        """Should send a single prompt containing the chunk text and return the LLM answer with grouped sources."""
        from app.services.rag import RAGService

        client, _collection = self._mock_chroma()

        mock_llm = MagicMock()
        mock_llm.complete = AsyncMock(return_value="## Sub-question 1: Q?\n- Answer")

        service = RAGService(
            chroma_client=client,
            llm_client=mock_llm,
            prompt_service=mock_prompt_service,
        )

        answer, gen_prompt, grouped_sources = await service.generate_response_per_subquestion(
            ["What is X?"],
            [["chunk data"]],
            [[{"filename": "f.txt", "content_summary": "sum"}]],
        )

        mock_llm.complete.assert_called_once()
        sent_prompt = mock_llm.complete.call_args.kwargs["prompt"]
        assert "chunk data" in sent_prompt
        # NOTE(review): the prompt is expected to label sub-questions from 0,
        # while the canned LLM answer above uses "Sub-question 1" - confirm
        # the numbering difference is intended.
        assert "Sub-question 0" in sent_prompt
        assert answer == "## Sub-question 1: Q?\n- Answer"
        assert len(grouped_sources) == 1
        assert grouped_sources[0][0]["filename"] == "f.txt"

    @pytest.mark.asyncio
    async def test_generate_response_per_subquestion_no_subquestions(self):
        """Should return the fallback message, an empty prompt and no sources when given no sub-questions."""
        from app.services.rag import RAGService

        client, _collection = self._mock_chroma()
        service = RAGService(chroma_client=client, llm_client=MagicMock())

        answer, gen_prompt, grouped_sources = await service.generate_response_per_subquestion(
            [], [], [],
        )

        assert "could not find" in answer.lower()
        assert gen_prompt == ""
        assert grouped_sources == []

    @pytest.mark.asyncio
    async def test_generate_response_per_subquestion_no_chunks(self):
        """Should return the fallback message and an empty prompt when the only sub-question has no chunks."""
        from app.services.rag import RAGService

        client, _collection = self._mock_chroma()
        service = RAGService(chroma_client=client, llm_client=MagicMock())

        answer, gen_prompt, grouped_sources = await service.generate_response_per_subquestion(
            ["Q?"], [[]], [[]],
        )

        assert "could not find" in answer.lower()
        assert gen_prompt == ""