From 2656f9ca0849131eaed217aa4a707671be8bc631 Mon Sep 17 00:00:00 2001 From: Woody Date: Mon, 27 Apr 2026 11:46:58 +0800 Subject: [PATCH] refactor(test): rewrite tests to comply with integration-first rules Replace mocked DB/internal-services with real ChromaDB/SQLite via tmp_path. Only mock truly external APIs (LLM, embedding for deterministic vectors). 13 test files rewritten (314 pass, 0 fail): - Route tests: use TestClient + real ChromaDB, seed test data - Service tests: use real PersistentClient/SQLite instances - Pipeline tests: TestClient hits SSE /query endpoint, verify history - Converted unittest.TestCase to pytest where applicable Plus: fix metadata.py to filter None values from ChromaDB metadata (pre-existing bug caught by real-DB ingestion tests) --- backend/app/test/test_phase1_chunk_serving.py | 95 +- .../app/test/test_phase1_documents_router.py | 256 +++-- .../app/test/test_phase1_enhanced_metadata.py | 2 +- backend/app/test/test_phase1_ingest.py | 228 +++-- .../app/test/test_phase1_ingest_page_aware.py | 656 ++++++------- backend/app/test/test_phase1_query.py | 337 +++++-- backend/app/test/test_phase1_rag_service.py | 258 ++--- .../app/test/test_phase3_history_router.py | 177 ++-- .../app/test/test_phase3_prompt_injection.py | 257 +++-- .../test_phase3_query_history_integration.py | 888 ++++++++---------- .../app/test/test_phase4_generate_per_subq.py | 111 ++- .../test_phase4_integration_query_pipeline.py | 475 ++++------ .../test_phase4_relevance_filter_per_subq.py | 154 ++- .../app/test/test_phase4_response_format.py | 73 +- .../test_phase4_retrieve_per_subquestion.py | 212 +++-- backend/app/utils/metadata.py | 8 +- 16 files changed, 2225 insertions(+), 1962 deletions(-) diff --git a/backend/app/test/test_phase1_chunk_serving.py b/backend/app/test/test_phase1_chunk_serving.py index 782006f..3ab200b 100644 --- a/backend/app/test/test_phase1_chunk_serving.py +++ b/backend/app/test/test_phase1_chunk_serving.py @@ -2,72 +2,69 @@ Coverage: - GET /api/v1/chunks/{file_path}/pdf — success, 404, path traversal 400 + +Uses real filesystem (tmp_path) via monkeypatch.setenv — no mocks on internal services. """ import os -import tempfile -import unittest -from unittest.mock import patch +import pytest from fastapi.testclient import TestClient -from app.main import app + +@pytest.fixture +def client(tmp_path, monkeypatch): + """TestClient with DOCUMENT_CHUNK_PATH pointing to a temp directory.""" + chunk_dir = tmp_path / "chunks" + chunk_dir.mkdir() + monkeypatch.setenv("DOCUMENT_CHUNK_PATH", str(chunk_dir)) + monkeypatch.setenv("CHROMA_DB_PATH", str(tmp_path / "chroma_test")) + from app.core.config import get_settings + get_settings.cache_clear() + from app.main import app + yield TestClient(app) + get_settings.cache_clear() -class TestChunkServing(unittest.TestCase): - """Test GET /api/v1/chunks/{file_path}/pdf endpoint.""" +def test_get_chunk_pdf_success(client, tmp_path): + """Should serve chunk PDF file with 200 and application/pdf.""" + chunk_dir = tmp_path / "chunks" + test_file = chunk_dir / "test_page_1.pdf" + test_file.write_bytes(b"%PDF-1.4 fake content") - def setUp(self): - self.client = TestClient(app) + response = client.get("/api/v1/chunks/test_page_1.pdf/pdf") - def test_get_chunk_pdf_success(self): - """Should serve chunk PDF file with 200 and application/pdf.""" - with tempfile.TemporaryDirectory() as tmp_dir: - test_file = os.path.join(tmp_dir, "test_page_1.pdf") - with open(test_file, "wb") as f: - f.write(b"%PDF-1.4 fake content") + assert response.status_code == 200 + assert "application/pdf" in response.headers["content-type"] - with patch("app.core.config.get_settings") as mock_settings: - mock_settings.return_value.document_chunk_path = tmp_dir - response = self.client.get("/api/v1/chunks/test_page_1.pdf/pdf") - self.assertEqual(response.status_code, 200) - self.assertIn("application/pdf", response.headers["content-type"]) +def test_get_chunk_pdf_not_found(client): + """Should return 404 for non-existent chunk file.""" + response = client.get("/api/v1/chunks/nonexistent.pdf/pdf") - def test_get_chunk_pdf_not_found(self): - """Should return 404 for non-existent chunk file.""" - with patch("app.core.config.get_settings") as mock_settings: - mock_settings.return_value.document_chunk_path = "/tmp/nonexistent_chunk_dir" - response = self.client.get("/api/v1/chunks/nonexistent.pdf/pdf") + assert response.status_code == 404 - self.assertEqual(response.status_code, 404) - def test_get_chunk_pdf_path_traversal_double_dot(self): - """Should reject path traversal with .. (404 due to Starlette normalization).""" - with patch("app.core.config.get_settings") as mock_settings: - mock_settings.return_value.document_chunk_path = "/tmp/fake_chunk_dir" - response = self.client.get("/api/v1/chunks/../etc/passwd/pdf") +def test_get_chunk_pdf_path_traversal_double_dot(client): + """Should reject path traversal with .. (400 or 404 from Starlette normalization).""" + response = client.get("/api/v1/chunks/../etc/passwd/pdf") - self.assertIn(response.status_code, [400, 404]) + assert response.status_code in (400, 404) - def test_get_chunk_pdf_path_traversal_symlink_escape(self): - """Should reject resolved path escaping base directory (404 from normalization).""" - with tempfile.TemporaryDirectory() as tmp_dir: - with patch("app.core.config.get_settings") as mock_settings: - mock_settings.return_value.document_chunk_path = tmp_dir - response = self.client.get("/api/v1/chunks/../../etc/passwd/pdf") - self.assertIn(response.status_code, [400, 404]) +def test_get_chunk_pdf_path_traversal_symlink_escape(client): + """Should reject resolved path escaping base directory (400 or 404).""" + response = client.get("/api/v1/chunks/../../etc/passwd/pdf") - def test_get_chunk_pdf_with_spaces_in_filename(self): - """Should serve files with spaces in the filename.""" - with tempfile.TemporaryDirectory() as tmp_dir: - test_file = os.path.join(tmp_dir, "NEC4 ACC_page_3.pdf") - with open(test_file, "wb") as f: - f.write(b"%PDF-1.4 fake content") + assert response.status_code in (400, 404) - with patch("app.core.config.get_settings") as mock_settings: - mock_settings.return_value.document_chunk_path = tmp_dir - response = self.client.get("/api/v1/chunks/NEC4 ACC_page_3.pdf/pdf") - self.assertEqual(response.status_code, 200) - self.assertIn("application/pdf", response.headers["content-type"]) +def test_get_chunk_pdf_with_spaces_in_filename(client, tmp_path): + """Should serve files with spaces in the filename.""" + chunk_dir = tmp_path / "chunks" + test_file = chunk_dir / "NEC4 ACC_page_3.pdf" + test_file.write_bytes(b"%PDF-1.4 fake content") + + response = client.get("/api/v1/chunks/NEC4 ACC_page_3.pdf/pdf") + + assert response.status_code == 200 + assert "application/pdf" in response.headers["content-type"] diff --git a/backend/app/test/test_phase1_documents_router.py b/backend/app/test/test_phase1_documents_router.py index 2601925..bee9864 100644 --- a/backend/app/test/test_phase1_documents_router.py +++ b/backend/app/test/test_phase1_documents_router.py @@ -5,164 +5,162 @@ Covers: - GET /documents/{id}/chunks - DELETE /documents/{id} - DELETE /chunks/{id} + +Uses real ChromaDB via tmp_path + TestClient — no mocks on internal services. """ import pytest from fastapi.testclient import TestClient -from unittest.mock import MagicMock, patch -class TestDocumentsRouter: - """Documents CRUD endpoint tests.""" +@pytest.fixture +def client(tmp_path, monkeypatch): + """TestClient with real ChromaDB isolated in tmp_path.""" + chroma_dir = tmp_path / "chroma_test" + chunk_dir = tmp_path / "chunks" + chunk_dir.mkdir() + monkeypatch.setenv("CHROMA_DB_PATH", str(chroma_dir)) + monkeypatch.setenv("DOCUMENT_CHUNK_PATH", str(chunk_dir)) + from app.core.config import get_settings + get_settings.cache_clear() + from app.main import app + yield TestClient(app) + get_settings.cache_clear() - @pytest.fixture - def client(self): - """Create test client with mocked dependencies.""" - from app.main import app - return TestClient(app) - def test_list_documents_empty(self, client): - """Should return empty list when no documents exist.""" - with patch("app.routers.documents.RAGService") as mock_rag_class: - mock_rag = MagicMock() - mock_rag.list_documents.return_value = ([], 0, 0) - mock_rag_class.return_value = mock_rag +def _seed_document(tmp_path, monkeypatch, document_id, filename, num_chunks, chunk_file_paths=None): + """Ingest test document into the real ChromaDB used by the client fixture. - response = client.get("/api/v1/documents") + Must be called AFTER the `client` fixture has been established so that + get_settings() resolves to the same tmp_path ChromaDB directory. + """ + from app.core.config import get_settings + from app.services.rag import RAGService - assert response.status_code == 200 - data = response.json() - assert data["documents"] == [] - assert data["total_documents"] == 0 - assert data["total_chunks"] == 0 + settings = get_settings() + rag = RAGService(settings=settings) - def test_list_documents_with_data(self, client): - """Should return grouped documents with chunk counts.""" - doc_list = [ - { - "document_id": "abc-123", - "filename": "report.pdf", - "chunk_count": 3, - "upload_date": "2026-04-23", - }, - { - "document_id": "def-456", - "filename": "notes.txt", - "chunk_count": 1, - "upload_date": "2026-04-22", - }, - ] + chunks = [f"chunk content {i}" for i in range(num_chunks)] + metadata_list = [] + for i in range(num_chunks): + meta = { + "filename": filename, + "upload_date": "2026-04-23", + "content_summary": f"summary {i}", + "chunk_index": i, + } + if chunk_file_paths and i < len(chunk_file_paths): + meta["chunk_file_path"] = chunk_file_paths[i] + metadata_list.append(meta) - with patch("app.routers.documents.RAGService") as mock_rag_class: - mock_rag = MagicMock() - mock_rag.list_documents.return_value = (doc_list, 2, 4) - mock_rag_class.return_value = mock_rag + rag.ingest_document( + file_path=filename, + chunks=chunks, + metadata_list=metadata_list, + document_id=document_id, + ) + return document_id - response = client.get("/api/v1/documents") - assert response.status_code == 200 - data = response.json() - assert data["total_documents"] == 2 - assert data["total_chunks"] == 4 - assert len(data["documents"]) == 2 - assert data["documents"][0]["document_id"] == "abc-123" - assert data["documents"][0]["filename"] == "report.pdf" - assert data["documents"][0]["chunk_count"] == 3 +def test_list_documents_empty(client): + """Should return empty list when no documents exist.""" + response = client.get("/api/v1/documents") - def test_list_chunks_for_document(self, client): - """Should return all chunks for a given document_id.""" - chunks = [ - { - "chunk_id": "abc-123_0", - "chunk_index": 0, - "content_summary": "First chunk summary", - "page_number": 1, - "chunk_file_path": None, - }, - { - "chunk_id": "abc-123_1", - "chunk_index": 1, - "content_summary": "Second chunk summary", - "page_number": 2, - "chunk_file_path": None, - }, - ] + assert response.status_code == 200 + data = response.json() + assert data["documents"] == [] + assert data["total_documents"] == 0 + assert data["total_chunks"] == 0 - with patch("app.routers.documents.RAGService") as mock_rag_class: - mock_rag = MagicMock() - mock_rag.list_chunks.return_value = chunks - mock_rag_class.return_value = mock_rag - response = client.get("/api/v1/documents/abc-123/chunks") +def test_list_documents_with_data(client, tmp_path, monkeypatch): + """Should return grouped documents with chunk counts.""" + _seed_document(tmp_path, monkeypatch, "abc-123", "report.pdf", 3) + _seed_document(tmp_path, monkeypatch, "def-456", "notes.txt", 1) - assert response.status_code == 200 - data = response.json() - assert len(data) == 2 - assert data[0]["chunk_id"] == "abc-123_0" - assert data[0]["chunk_index"] == 0 - assert data[0]["content_summary"] == "First chunk summary" - assert data[1]["chunk_index"] == 1 + response = client.get("/api/v1/documents") - def test_list_chunks_document_not_found(self, client): - """Should return empty list for nonexistent document.""" - with patch("app.routers.documents.RAGService") as mock_rag_class: - mock_rag = MagicMock() - mock_rag.list_chunks.return_value = [] - mock_rag_class.return_value = mock_rag + assert response.status_code == 200 + data = response.json() + assert data["total_documents"] == 2 + assert data["total_chunks"] == 4 + assert len(data["documents"]) == 2 - response = client.get("/api/v1/documents/nonexistent-id/chunks") + by_id = {d["document_id"]: d for d in data["documents"]} + assert by_id["abc-123"]["filename"] == "report.pdf" + assert by_id["abc-123"]["chunk_count"] == 3 + assert by_id["def-456"]["filename"] == "notes.txt" + assert by_id["def-456"]["chunk_count"] == 1 - assert response.status_code == 200 - data = response.json() - assert data == [] - def test_delete_document_success(self, client): - """Should delete all chunks for a document and return confirmation.""" - with patch("app.routers.documents.RAGService") as mock_rag_class: - mock_rag = MagicMock() - mock_rag.delete_document.return_value = (True, 3) - mock_rag_class.return_value = mock_rag +def test_list_chunks_for_document(client, tmp_path, monkeypatch): + """Should return all chunks for a given document_id.""" + _seed_document(tmp_path, monkeypatch, "abc-123", "report.pdf", 2) - response = client.delete("/api/v1/documents/abc-123") + response = client.get("/api/v1/documents/abc-123/chunks") - assert response.status_code == 200 - data = response.json() - assert data["deleted"] is True - assert "3 chunks removed" in data["message"] + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + assert data[0]["chunk_id"] == "abc-123_0" + assert data[0]["chunk_index"] == 0 + assert data[0]["content_summary"] == "summary 0" + assert data[1]["chunk_index"] == 1 - def test_delete_document_not_found(self, client): - """Should return 404 for nonexistent document.""" - with patch("app.routers.documents.RAGService") as mock_rag_class: - mock_rag = MagicMock() - mock_rag.delete_document.return_value = (False, 0) - mock_rag_class.return_value = mock_rag - response = client.delete("/api/v1/documents/nonexistent-id") +def test_list_chunks_document_not_found(client): + """Should return empty list for nonexistent document.""" + response = client.get("/api/v1/documents/nonexistent-id/chunks") - assert response.status_code == 404 - assert "not found" in response.json()["detail"].lower() + assert response.status_code == 200 + data = response.json() + assert data == [] - def test_delete_chunk_success(self, client): - """Should delete a single chunk and return confirmation.""" - with patch("app.routers.documents.RAGService") as mock_rag_class: - mock_rag = MagicMock() - mock_rag.delete_chunk.return_value = True - mock_rag_class.return_value = mock_rag - response = client.delete("/api/v1/chunks/abc-123_0") +def test_delete_document_success(client, tmp_path, monkeypatch): + """Should delete all chunks for a document and return confirmation.""" + _seed_document(tmp_path, monkeypatch, "abc-123", "report.pdf", 3) - assert response.status_code == 200 - data = response.json() - assert data["deleted"] is True - assert "abc-123_0" in data["message"] + response = client.delete("/api/v1/documents/abc-123") - def test_delete_chunk_not_found(self, client): - """Should return 404 for nonexistent chunk.""" - with patch("app.routers.documents.RAGService") as mock_rag_class: - mock_rag = MagicMock() - mock_rag.delete_chunk.return_value = False - mock_rag_class.return_value = mock_rag + assert response.status_code == 200 + data = response.json() + assert data["deleted"] is True + assert "3 chunks removed" in data["message"] - response = client.delete("/api/v1/chunks/nonexistent-chunk") + # Verify actually deleted + response = client.get("/api/v1/documents") + assert response.json()["total_documents"] == 0 - assert response.status_code == 404 - assert "not found" in response.json()["detail"].lower() + +def test_delete_document_not_found(client): + """Should return 404 for nonexistent document.""" + response = client.delete("/api/v1/documents/nonexistent-id") + + assert response.status_code == 404 + assert "not found" in response.json()["detail"].lower() + + +def test_delete_chunk_success(client, tmp_path, monkeypatch): + """Should delete a single chunk and return confirmation.""" + _seed_document(tmp_path, monkeypatch, "abc-123", "report.pdf", 2) + + response = client.delete("/api/v1/chunks/abc-123_0") + + assert response.status_code == 200 + data = response.json() + assert data["deleted"] is True + assert "abc-123_0" in data["message"] + + # Verify chunk gone but other chunk remains + response = client.get("/api/v1/documents/abc-123/chunks") + chunks = response.json() + assert len(chunks) == 1 + assert chunks[0]["chunk_id"] == "abc-123_1" + + +def test_delete_chunk_not_found(client): + """Should return 404 for nonexistent chunk.""" + response = client.delete("/api/v1/chunks/nonexistent-chunk") + + assert response.status_code == 404 + assert "not found" in response.json()["detail"].lower() diff --git a/backend/app/test/test_phase1_enhanced_metadata.py b/backend/app/test/test_phase1_enhanced_metadata.py index a20567b..25f77dc 100644 --- a/backend/app/test/test_phase1_enhanced_metadata.py +++ b/backend/app/test/test_phase1_enhanced_metadata.py @@ -184,5 +184,5 @@ def test_extract_metadata_page_numbers_none_in_list(tmp_path): ) assert len(metadata) == 2 - assert metadata[0]["page_number"] is None + assert "page_number" not in metadata[0] assert metadata[1]["page_number"] == 1 diff --git a/backend/app/test/test_phase1_ingest.py b/backend/app/test/test_phase1_ingest.py index bce3ad7..6f251af 100644 --- a/backend/app/test/test_phase1_ingest.py +++ b/backend/app/test/test_phase1_ingest.py @@ -1,96 +1,194 @@ """Phase 1 tests: Document ingestion endpoint. Covers: -- POST /api/v1/ingest with valid documents +- POST /api/v1/ingest with valid documents (PDF, DOCX, TXT) - Metadata extraction (filename, upload_date, content_summary) -- ChromaDB persistence with embeddings +- ChromaDB persistence (verify by querying real collection) - Error handling for unsupported file types +- Error handling for missing file field + +Uses TestClient + real ChromaDB + real chunking + real metadata extraction. +Embedding function is mocked with deterministic vectors (external API). +No LLM calls involved in the ingest pipeline. """ +import io +import os + import pytest +from fastapi import FastAPI from fastapi.testclient import TestClient -from unittest.mock import MagicMock, patch +from pypdf import PdfWriter + +from app.routers.ingest import router + + +class _DeterministicEmbedding: + def name(self) -> str: + return "test_deterministic" + + def __call__(self, input): + return self._embed(input) + + def embed_query(self, input): + return self._embed(input) + + @staticmethod + def _embed(texts): + vectors = [] + for text in texts: + vec = [0.0] * 384 + for i, ch in enumerate(text[:384]): + vec[i] = ord(ch) / 1000.0 + vectors.append(vec) + return vectors + + +def _create_real_pdf(content: str) -> bytes: + from pypdf import PdfWriter + writer = PdfWriter() + writer.add_blank_page(width=200, height=200) + page = writer.pages[0] + # Add text content via page-level operator (simple approach) + # pypdf blank pages have no text — we write the content as annotation + # For testing, we just need a valid PDF; actual text extraction tested separately + buf = io.BytesIO() + writer.write(buf) + return buf.getvalue() + + +def _create_text_pdf(lines: list[str]) -> bytes: + """Create a PDF with actual extractable text using reportlab if available.""" + try: + from reportlab.pdfgen import canvas as rl_canvas + buf = io.BytesIO() + c = rl_canvas.Canvas(buf) + y = 750 + for line in lines: + c.drawString(72, y, line) + y -= 20 + c.save() + return buf.getvalue() + except ImportError: + # Fallback: pypdf blank PDF (no extractable text) + return _create_real_pdf("") + + +def _create_real_docx(paragraphs: list[str]) -> bytes: + try: + from docx import Document + doc = Document() + for para in paragraphs: + doc.add_paragraph(para) + buf = io.BytesIO() + doc.save(buf) + return buf.getvalue() + except ImportError: + return b"" + + +@pytest.fixture +def client(tmp_path, monkeypatch): + chroma_path = str(tmp_path / "chroma_db") + chunk_path = str(tmp_path / "document_chunk") + prompts_path = str(tmp_path / "prompts.db") + history_path = str(tmp_path / "history.db") + + monkeypatch.setenv("CHROMA_DB_PATH", chroma_path) + monkeypatch.setenv("DOCUMENT_CHUNK_PATH", chunk_path) + monkeypatch.setenv("PROMPTS_DB_PATH", prompts_path) + monkeypatch.setenv("HISTORY_DB_PATH", history_path) + monkeypatch.setenv("EMBEDDING_MODEL", "test-mock") + monkeypatch.setenv("LLM_API_KEY", "test-key") + + from app.core.config import get_settings + get_settings.cache_clear() + from app.core.dependencies import get_settings_cached + get_settings_cached.cache_clear() + + from app.core.sqlite_db import _get_db, init_prompts_db, init_history_db, seed_default_profiles + conn = _get_db(prompts_path) + init_prompts_db(conn) + seed_default_profiles(conn) + conn.close() + + hconn = _get_db(history_path) + init_history_db(hconn) + hconn.close() + + monkeypatch.setattr( + "app.core.database.get_embedding_function_settings", + lambda settings: _DeterministicEmbedding(), + ) + + test_app = FastAPI() + test_app.include_router(router, prefix="/api/v1") + + yield TestClient(test_app) + + get_settings_cached.cache_clear() + get_settings.cache_clear() class TestIngest: - """Document upload and ChromaDB ingestion tests.""" - @pytest.fixture - def client(self): - """Create test client with mocked dependencies.""" - from app.main import app - return TestClient(app) + def test_ingest_txt_success(self, client, tmp_path): + """Should ingest TXT and return document ID with metadata. Verify real ChromaDB.""" + import chromadb + from app.core.config import get_settings + settings = get_settings() - def test_ingest_pdf_success(self, client, tmp_path): - """Should ingest PDF and return document ID with metadata.""" - import io - - with patch("app.services.rag.RAGService") as mock_rag_class: - mock_rag = MagicMock() - mock_rag.ingest_document.return_value = "doc-123" - mock_rag.list_documents.return_value = ([], 0, 0) - mock_rag_class.return_value = mock_rag - - with patch("app.utils.pdf_parser.parse_pdf_by_page") as mock_parse: - mock_parse.return_value = [(1, "Page 1 text"), (2, "Page 2 text")] - - with patch("app.utils.chunking.TokenChunkingStrategy") as mock_chunk_class: - mock_chunker = MagicMock() - mock_chunker.chunk_pages.return_value = [("chunk 1", 1), ("chunk 2", 2)] - mock_chunk_class.return_value = mock_chunker - - with patch("app.utils.metadata.extract_metadata") as mock_meta: - mock_meta.return_value = [ - {"filename": "test.pdf", "chunk_index": 0}, - {"filename": "test.pdf", "chunk_index": 1}, - ] - - with patch("app.utils.pdf_extractor.extract_page_as_pdf"): - response = client.post( - "/api/v1/ingest", - files={"file": ("test.pdf", io.BytesIO(b"%PDF-1.4"), "application/pdf")}, - ) + response = client.post( + "/api/v1/ingest", + files={"file": ("notes.txt", io.BytesIO(b"This is a test document about testing.\nIt has multiple lines of content."), "text/plain")}, + ) assert response.status_code == 200 data = response.json() assert "document_id" in data - assert data["chunk_count"] == 2 - assert data["filename"] == "test.pdf" + assert data["chunk_count"] >= 1 + assert data["filename"] == "notes.txt" + + # Verify data persisted in real ChromaDB + db_client = chromadb.PersistentClient(path=settings.chroma_db_path) + collection = db_client.get_collection("documents") + all_data = collection.get(include=["metadatas"]) + assert len(all_data["ids"]) >= 1 + filenames = [m["filename"] for m in all_data["metadatas"]] + assert "notes.txt" in filenames def test_ingest_docx_success(self, client, tmp_path): """Should ingest DOCX and return document ID with metadata.""" - import io + docx_bytes = _create_real_docx(["Paragraph one content.", "Paragraph two content."]) + if not docx_bytes: + pytest.skip("python-docx not installed") - with patch("app.services.rag.RAGService") as mock_rag_class: - mock_rag = MagicMock() - mock_rag.ingest_document.return_value = "doc-456" - mock_rag.list_documents.return_value = ([], 0, 0) - mock_rag_class.return_value = mock_rag - - with patch("app.utils.docx_parser.parse_docx") as mock_parse: - mock_parse.return_value = "Parsed DOCX text content" - - with patch("app.utils.chunking.TokenChunkingStrategy") as mock_chunk_class: - mock_chunker = MagicMock() - mock_chunker.chunk.return_value = ["chunk 1"] - mock_chunk_class.return_value = mock_chunker - - with patch("app.utils.metadata.extract_metadata") as mock_meta: - mock_meta.return_value = [{"filename": "test.docx", "chunk_index": 0}] - - response = client.post( - "/api/v1/ingest", - files={"file": ("test.docx", io.BytesIO(b"docx content"), "application/vnd.openxmlformats-officedocument.wordprocessingml.document")}, - ) + response = client.post( + "/api/v1/ingest", + files={"file": ("test.docx", io.BytesIO(docx_bytes), + "application/vnd.openxmlformats-officedocument.wordprocessingml.document")}, + ) assert response.status_code == 200 data = response.json() - assert data["chunk_count"] == 1 + assert data["chunk_count"] >= 1 assert data["filename"] == "test.docx" + def test_ingest_pdf_success(self, client, tmp_path): + """Should ingest PDF and return document ID with metadata.""" + pdf_bytes = _create_text_pdf(["Page 1 line one", "Page 1 line two"]) + + response = client.post( + "/api/v1/ingest", + files={"file": ("test.pdf", io.BytesIO(pdf_bytes), "application/pdf")}, + ) + + assert response.status_code == 200 + data = response.json() + assert "document_id" in data + assert data["filename"] == "test.pdf" + def test_ingest_unsupported_format(self, client): """Should reject unsupported file formats.""" - import io - response = client.post( "/api/v1/ingest", files={"file": ("test.jpg", io.BytesIO(b"image data"), "image/jpeg")}, diff --git a/backend/app/test/test_phase1_ingest_page_aware.py b/backend/app/test/test_phase1_ingest_page_aware.py index 1a6e4f6..8128ee7 100644 --- a/backend/app/test/test_phase1_ingest_page_aware.py +++ b/backend/app/test/test_phase1_ingest_page_aware.py @@ -1,435 +1,361 @@ """Phase 1.5.5c tests: Page-aware ingest router. Covers: -1. PDF upload triggers page-aware pipeline (parse_pdf_by_page, chunk_pages, extract_page_as_pdf) -2. DOCX upload uses old pipeline with document_id -3. TXT upload uses old pipeline with document_id +1. PDF upload triggers page-aware pipeline (page_number in metadata, page PDFs saved) +2. DOCX upload uses old pipeline (no page_number in metadata) +3. TXT upload uses old pipeline (no page_number in metadata) 4. Same-filename replacement: existing document found → old chunks + PDFs deleted 5. Same-filename replacement: no existing document → no deletion 6. Empty PDF (no pages with text) → 400 error 7. Page PDFs saved to correct directory with correct naming 8. Metadata includes page_number and chunk_file_path for PDF uploads -9. Metadata does NOT include page_number for DOCX uploads (None) +9. Metadata does NOT include page_number for DOCX uploads + +Uses TestClient + real ChromaDB + real file parsing + real chunking. +Embedding function mocked with deterministic vectors (external API). """ import io import os -import uuid -from pathlib import Path -from unittest.mock import MagicMock, patch, call +import chromadb import pytest +from fastapi import FastAPI from fastapi.testclient import TestClient +from app.routers.ingest import router + + +class _DeterministicEmbedding: + def name(self) -> str: + return "test_deterministic" + + def __call__(self, input): + return self._embed(input) + + def embed_query(self, input): + return self._embed(input) + + @staticmethod + def _embed(texts): + vectors = [] + for text in texts: + vec = [0.0] * 384 + for i, ch in enumerate(text[:384]): + vec[i] = ord(ch) / 1000.0 + vectors.append(vec) + return vectors + + +def _create_text_pdf(lines: list[str]) -> bytes: + try: + from reportlab.pdfgen import canvas as rl_canvas + buf = io.BytesIO() + c = rl_canvas.Canvas(buf) + y = 750 + for line in lines: + c.drawString(72, y, line) + y -= 20 + c.save() + return buf.getvalue() + except ImportError: + pytest.skip("reportlab not installed") + + +def _create_multipage_pdf(pages_text: list[list[str]]) -> bytes: + try: + from reportlab.pdfgen import canvas as rl_canvas + buf = io.BytesIO() + c = rl_canvas.Canvas(buf) + for page_lines in pages_text: + y = 750 + for line in page_lines: + c.drawString(72, y, line) + y -= 20 + c.showPage() + c.save() + return buf.getvalue() + except ImportError: + pytest.skip("reportlab not installed") + + +def _create_real_docx(paragraphs: list[str]) -> bytes: + try: + from docx import Document + doc = Document() + for para in paragraphs: + doc.add_paragraph(para) + buf = io.BytesIO() + doc.save(buf) + return buf.getvalue() + except ImportError: + pytest.skip("python-docx not installed") + + +@pytest.fixture +def client(tmp_path, monkeypatch): + chroma_path = str(tmp_path / "chroma_db") + chunk_path = str(tmp_path / "document_chunk") + prompts_path = str(tmp_path / "prompts.db") + history_path = str(tmp_path / "history.db") + + monkeypatch.setenv("CHROMA_DB_PATH", chroma_path) + monkeypatch.setenv("DOCUMENT_CHUNK_PATH", chunk_path) + monkeypatch.setenv("PROMPTS_DB_PATH", prompts_path) + monkeypatch.setenv("HISTORY_DB_PATH", history_path) + monkeypatch.setenv("EMBEDDING_MODEL", "test-mock") + monkeypatch.setenv("LLM_API_KEY", "test-key") + + from app.core.config import get_settings + get_settings.cache_clear() + from app.core.dependencies import get_settings_cached + get_settings_cached.cache_clear() + + from app.core.sqlite_db import _get_db, init_prompts_db, init_history_db, seed_default_profiles + conn = _get_db(prompts_path) + init_prompts_db(conn) + seed_default_profiles(conn) + conn.close() + + hconn = _get_db(history_path) + init_history_db(hconn) + hconn.close() + + monkeypatch.setattr( + "app.core.database.get_embedding_function_settings", + lambda settings: _DeterministicEmbedding(), + ) + + test_app = FastAPI() + test_app.include_router(router, prefix="/api/v1") + + yield TestClient(test_app) + + get_settings_cached.cache_clear() + get_settings.cache_clear() + + +def _get_collection(client_fixture, chroma_path: str): + db_client = chromadb.PersistentClient(path=chroma_path) + return db_client.get_collection("documents") + class TestPageAwareIngest: - """Page-aware document ingestion tests.""" - @pytest.fixture - def client(self): - """Create test client with mocked dependencies.""" - from app.main import app - return TestClient(app) + def test_pdf_upload_uses_page_aware_pipeline(self, client, tmp_path): + """PDF should produce chunks with page_number metadata and page PDF files on disk.""" + pdf_bytes = _create_multipage_pdf([ + ["Page 1 content about testing"], + ["Page 2 content about more testing"], + ]) - @pytest.fixture - def mock_settings(self): - """Mock settings with document_chunk_path.""" - settings = MagicMock() - settings.chunk_size = 1000 - settings.chunk_overlap = 200 - settings.document_chunk_path = "/tmp/test_document_chunk" - return settings - - # ------------------------------------------------------------------ # - # Test 1: PDF upload triggers page-aware pipeline - # ------------------------------------------------------------------ # - def test_pdf_upload_uses_page_aware_pipeline(self, client, mock_settings): - """PDF should go through parse_pdf_by_page → chunk_pages → extract_page_as_pdf.""" - doc_id = str(uuid.uuid4()) - - with patch("app.services.rag.RAGService") as mock_rag_class, \ - patch("app.core.config.get_settings", return_value=mock_settings), \ - patch("app.utils.pdf_parser.parse_pdf_by_page") as mock_parse_by_page, \ - patch("app.utils.chunking.TokenChunkingStrategy") as mock_chunk_class, \ - patch("app.utils.metadata.extract_metadata") as mock_meta, \ - patch("app.utils.pdf_extractor.extract_page_as_pdf") as mock_extract_page, \ - patch("app.services.rag.RAGService.list_documents") as mock_list_docs: - - # RAGService instance - mock_rag = MagicMock() - mock_rag.ingest_document.return_value = doc_id - mock_rag.list_documents.return_value = ([], 0, 0) - mock_rag_class.return_value = mock_rag - mock_rag_class.list_documents = MagicMock(return_value=([], 0, 0)) - - # parse_pdf_by_page returns 2 pages - mock_parse_by_page.return_value = [ - (1, "Page 1 text content"), - (2, "Page 2 text content"), - ] - - # chunk_pages returns one chunk per page - mock_chunker = MagicMock() - mock_chunker.chunk_pages.return_value = [ - ("Page 1 text content", 1), - ("Page 2 text content", 2), - ] - mock_chunk_class.return_value = mock_chunker - - # metadata - mock_meta.return_value = [ - {"filename": "test.pdf", "chunk_index": 0, "page_number": 1}, - {"filename": "test.pdf", "chunk_index": 1, "page_number": 2}, - ] - - response = client.post( - "/api/v1/ingest", - files={"file": ("test.pdf", io.BytesIO(b"%PDF-1.4"), "application/pdf")}, - ) + response = client.post( + "/api/v1/ingest", + files={"file": ("test.pdf", io.BytesIO(pdf_bytes), "application/pdf")}, + ) assert response.status_code == 200 data = response.json() - assert data["chunk_count"] == 2 - assert data["filename"] == "test.pdf" + assert data["chunk_count"] >= 1 - # Verify page-aware parsing was called - mock_parse_by_page.assert_called_once() + # Verify page_number metadata in real ChromaDB + settings = _get_settings() + collection = _get_collection(client, settings.chroma_db_path) + all_data = collection.get(include=["metadatas"]) + page_numbers = [m.get("page_number") for m in all_data["metadatas"]] + assert any(pn is not None for pn in page_numbers) - # Verify chunk_pages was used (not chunk) - mock_chunker.chunk_pages.assert_called_once() - mock_chunker.chunk.assert_not_called() + # Verify page PDF files exist in chunk_dir + chunk_dir = settings.document_chunk_path + if os.path.isdir(chunk_dir): + pdf_files = [f for f in os.listdir(chunk_dir) if f.endswith(".pdf")] + assert len(pdf_files) >= 1 - # Verify extract_page_as_pdf was called for each page - assert mock_extract_page.call_count == 2 + def test_docx_upload_uses_old_pipeline(self, client, tmp_path): + """DOCX should produce chunks without page_number metadata.""" + docx_bytes = _create_real_docx(["DOCX paragraph one.", "DOCX paragraph two."]) - # ------------------------------------------------------------------ # - # Test 2: DOCX upload uses old pipeline - # ------------------------------------------------------------------ # - def test_docx_upload_uses_old_pipeline(self, client, mock_settings): - """DOCX should use parse_docx → chunk → metadata with document_id only.""" - doc_id = str(uuid.uuid4()) - - with patch("app.services.rag.RAGService") as mock_rag_class, \ - patch("app.core.config.get_settings", return_value=mock_settings), \ - patch("app.utils.docx_parser.parse_docx") as mock_parse, \ - patch("app.utils.chunking.TokenChunkingStrategy") as mock_chunk_class, \ - patch("app.utils.metadata.extract_metadata") as mock_meta: - - mock_rag = MagicMock() - mock_rag.ingest_document.return_value = doc_id - mock_rag.list_documents.return_value = ([], 0, 0) - mock_rag_class.return_value = mock_rag - - mock_parse.return_value = "DOCX text content" - - mock_chunker = MagicMock() - mock_chunker.chunk.return_value = ["chunk 1"] - mock_chunk_class.return_value = mock_chunker - - mock_meta.return_value = [{"filename": "test.docx", "chunk_index": 0}] - - response = client.post( - "/api/v1/ingest", - files={"file": ("test.docx", io.BytesIO(b"docx"), "application/vnd.openxmlformats-officedocument.wordprocessingml.document")}, - ) + response = client.post( + "/api/v1/ingest", + files={"file": ("test.docx", io.BytesIO(docx_bytes), + "application/vnd.openxmlformats-officedocument.wordprocessingml.document")}, + ) assert response.status_code == 200 data = response.json() - assert data["chunk_count"] == 1 - assert data["filename"] == "test.docx" + assert data["chunk_count"] >= 1 - # Verify old pipeline: parse_docx → chunk (not chunk_pages) - mock_parse.assert_called_once() - mock_chunker.chunk.assert_called_once() - mock_chunker.chunk_pages.assert_not_called() + # Verify no page_number in metadata + settings = _get_settings() + collection = _get_collection(client, settings.chroma_db_path) + all_data = collection.get(include=["metadatas"]) + for meta in all_data["metadatas"]: + if meta.get("filename") == "test.docx": + assert meta.get("page_number") is None + assert meta.get("chunk_file_path") is None - # Verify extract_metadata was called with document_id - meta_call = mock_meta.call_args - assert meta_call[1].get("document_id") is not None or \ - (len(meta_call[0]) > 3 and meta_call[0][3] is not None) or \ - "document_id" in str(meta_call) - - # ------------------------------------------------------------------ # - # Test 3: TXT upload uses old pipeline - # ------------------------------------------------------------------ # - def test_txt_upload_uses_old_pipeline(self, client, mock_settings): - """TXT should read file → chunk → metadata with document_id.""" - doc_id = str(uuid.uuid4()) - - with patch("app.services.rag.RAGService") as mock_rag_class, \ - patch("app.core.config.get_settings", return_value=mock_settings), \ - patch("app.utils.chunking.TokenChunkingStrategy") as mock_chunk_class, \ - patch("app.utils.metadata.extract_metadata") as mock_meta: - - mock_rag = MagicMock() - mock_rag.ingest_document.return_value = doc_id - mock_rag.list_documents.return_value = ([], 0, 0) - mock_rag_class.return_value = mock_rag - - mock_chunker = MagicMock() - mock_chunker.chunk.return_value = ["txt chunk"] - mock_chunk_class.return_value = mock_chunker - - mock_meta.return_value = [{"filename": "notes.txt", "chunk_index": 0}] - - response = client.post( - "/api/v1/ingest", - files={"file": ("notes.txt", io.BytesIO(b"Text content here"), "text/plain")}, - ) + def test_txt_upload_uses_old_pipeline(self, client, tmp_path): + """TXT should produce chunks without page_number metadata.""" + response = client.post( + "/api/v1/ingest", + files={"file": ("notes.txt", io.BytesIO(b"Text content with enough words to form a chunk."), + "text/plain")}, + ) assert response.status_code == 200 data = response.json() - assert data["chunk_count"] == 1 - assert data["filename"] == "notes.txt" + assert data["chunk_count"] >= 1 - mock_chunker.chunk.assert_called_once() - mock_chunker.chunk_pages.assert_not_called() + settings = _get_settings() + collection = _get_collection(client, settings.chroma_db_path) + all_data = collection.get(include=["metadatas"]) + for meta in all_data["metadatas"]: + if meta.get("filename") == "notes.txt": + assert meta.get("page_number") is None - # ------------------------------------------------------------------ # - # Test 4: Same-filename replacement: existing document → deletion - # ------------------------------------------------------------------ # - def test_same_filename_replacement_deletes_old(self, client, mock_settings, tmp_path): - """Uploading file with same filename should delete old chunks and chunk PDFs.""" - doc_id = str(uuid.uuid4()) - old_doc_id = "old-doc-uuid-1234" - chunk_dir = tmp_path / "document_chunk" - chunk_dir.mkdir() - old_pdf = chunk_dir / "test_page_3.pdf" - old_pdf.write_text("old chunk pdf") + def test_same_filename_replacement_deletes_old(self, client, tmp_path): + """Uploading file with same filename should replace old chunks in ChromaDB.""" + settings = _get_settings() + pdf_bytes = _create_text_pdf(["First upload content"]) - mock_settings.document_chunk_path = str(chunk_dir) + # First upload + response1 = client.post( + "/api/v1/ingest", + files={"file": ("test.pdf", io.BytesIO(pdf_bytes), "application/pdf")}, + ) + assert response1.status_code == 200 + first_doc_id = response1.json()["document_id"] - with patch("app.services.rag.RAGService") as mock_rag_class, \ - patch("app.core.config.get_settings", return_value=mock_settings), \ - patch("app.utils.pdf_parser.parse_pdf_by_page") as mock_parse_by_page, \ - patch("app.utils.chunking.TokenChunkingStrategy") as mock_chunk_class, \ - patch("app.utils.metadata.extract_metadata") as mock_meta, \ - patch("app.utils.pdf_extractor.extract_page_as_pdf"): + # Verify first doc exists + collection = _get_collection(client, settings.chroma_db_path) + all_data = collection.get(include=["metadatas"]) + first_ids = [cid for cid in all_data["ids"] if cid.startswith(first_doc_id)] + assert len(first_ids) >= 1 - mock_rag = MagicMock() - mock_rag.ingest_document.return_value = doc_id - # list_documents returns existing document with same filename - mock_rag.list_documents.return_value = ( - [{"document_id": old_doc_id, "filename": "test.pdf", "chunk_count": 3}], - 1, 3 - ) - mock_rag_class.return_value = mock_rag + # Second upload with same filename + pdf_bytes2 = _create_text_pdf(["Second upload content replacement"]) + response2 = client.post( + "/api/v1/ingest", + files={"file": ("test.pdf", io.BytesIO(pdf_bytes2), "application/pdf")}, + ) + assert response2.status_code == 200 + second_doc_id = response2.json()["document_id"] - # list_chunks returns chunk with file path - mock_rag.list_chunks.return_value = [ - {"chunk_id": f"{old_doc_id}_0", "chunk_file_path": "test_page_3.pdf"}, - {"chunk_id": f"{old_doc_id}_1", "chunk_file_path": "test_page_4.pdf"}, - ] + # Verify old doc chunks are gone + collection = _get_collection(client, settings.chroma_db_path) + all_data = collection.get(include=["metadatas"]) + remaining_ids = all_data["ids"] + assert not any(rid.startswith(first_doc_id) for rid in remaining_ids) - mock_parse_by_page.return_value = [(1, "New page text")] - mock_chunker = MagicMock() - mock_chunker.chunk_pages.return_value = [("New page text", 1)] - mock_chunk_class.return_value = mock_chunker - mock_meta.return_value = [{"filename": "test.pdf", "chunk_index": 0}] + # Verify new doc chunks exist + assert any(rid.startswith(second_doc_id) for rid in remaining_ids) - response = client.post( - "/api/v1/ingest", - files={"file": ("test.pdf", io.BytesIO(b"%PDF-1.4"), "application/pdf")}, - ) + def test_no_existing_document_no_deletion(self, client, tmp_path): + """Uploading new filename should succeed normally.""" + settings = _get_settings() + pdf_bytes = _create_text_pdf(["Brand new document content"]) + + response = client.post( + "/api/v1/ingest", + files={"file": ("newdoc.pdf", io.BytesIO(pdf_bytes), "application/pdf")}, + ) assert response.status_code == 200 + data = response.json() + assert data["filename"] == "newdoc.pdf" - # Verify delete_document was called for old doc - mock_rag.delete_document.assert_called_once_with(old_doc_id) + collection = _get_collection(client, settings.chroma_db_path) + all_data = collection.get(include=["metadatas"]) + assert len(all_data["ids"]) >= 1 - # ------------------------------------------------------------------ # - # Test 5: Same-filename replacement: no existing document - # ------------------------------------------------------------------ # - def test_no_existing_document_no_deletion(self, client, mock_settings): - """Uploading new filename should NOT trigger any deletion.""" - doc_id = str(uuid.uuid4()) - - with patch("app.services.rag.RAGService") as mock_rag_class, \ - patch("app.core.config.get_settings", return_value=mock_settings), \ - patch("app.utils.pdf_parser.parse_pdf_by_page") as mock_parse_by_page, \ - patch("app.utils.chunking.TokenChunkingStrategy") as mock_chunk_class, \ - patch("app.utils.metadata.extract_metadata") as mock_meta, \ - patch("app.utils.pdf_extractor.extract_page_as_pdf"): - - mock_rag = MagicMock() - mock_rag.ingest_document.return_value = doc_id - mock_rag.list_documents.return_value = ([], 0, 0) - mock_rag_class.return_value = mock_rag - - mock_parse_by_page.return_value = [(1, "Page text")] - mock_chunker = MagicMock() - mock_chunker.chunk_pages.return_value = [("Page text", 1)] - mock_chunk_class.return_value = mock_chunker - mock_meta.return_value = [{"filename": "newdoc.pdf", "chunk_index": 0}] - - response = client.post( - "/api/v1/ingest", - files={"file": ("newdoc.pdf", io.BytesIO(b"%PDF-1.4"), "application/pdf")}, - ) - - assert response.status_code == 200 - - # Verify NO deletion happened - mock_rag.delete_document.assert_not_called() - - # ------------------------------------------------------------------ # - # Test 6: Empty PDF → 400 error - # ------------------------------------------------------------------ # - def test_empty_pdf_returns_400(self, client, mock_settings): + def test_empty_pdf_returns_400(self, client, tmp_path): """PDF with no extractable text should return 400.""" - with patch("app.core.config.get_settings", return_value=mock_settings), \ - patch("app.utils.pdf_parser.parse_pdf_by_page") as mock_parse_by_page, \ - patch("app.services.rag.RAGService") as mock_rag_class: + from pypdf import PdfWriter + writer = PdfWriter() + writer.add_blank_page(width=200, height=200) + buf = io.BytesIO() + writer.write(buf) - mock_rag = MagicMock() - mock_rag.list_documents.return_value = ([], 0, 0) - mock_rag_class.return_value = mock_rag - - # Empty PDF: no pages - mock_parse_by_page.return_value = [] - - response = client.post( - "/api/v1/ingest", - files={"file": ("empty.pdf", io.BytesIO(b"%PDF-1.4"), "application/pdf")}, - ) + response = client.post( + "/api/v1/ingest", + files={"file": ("empty.pdf", io.BytesIO(buf.getvalue()), "application/pdf")}, + ) assert response.status_code == 400 assert "empty" in response.json()["detail"].lower() - # ------------------------------------------------------------------ # - # Test 7: Page PDFs saved with correct naming - # ------------------------------------------------------------------ # - def test_page_pdf_naming_convention(self, client, mock_settings, tmp_path): - """Chunk PDFs should be named {stem}_page_{N}.pdf with relative paths in metadata.""" - doc_id = str(uuid.uuid4()) - chunk_dir = tmp_path / "document_chunk" - chunk_dir.mkdir() - mock_settings.document_chunk_path = str(chunk_dir) + def test_page_pdf_naming_convention(self, client, tmp_path): + """Chunk PDFs should be named {stem}_page_{N}.pdf in document_chunk_path.""" + settings = _get_settings() + pdf_bytes = _create_multipage_pdf([ + ["Page one content"], + ["Page two content"], + ]) - with patch("app.services.rag.RAGService") as mock_rag_class, \ - patch("app.core.config.get_settings", return_value=mock_settings), \ - patch("app.utils.pdf_parser.parse_pdf_by_page") as mock_parse_by_page, \ - patch("app.utils.chunking.TokenChunkingStrategy") as mock_chunk_class, \ - patch("app.utils.metadata.extract_metadata") as mock_meta, \ - patch("app.utils.pdf_extractor.extract_page_as_pdf") as mock_extract_page: - - mock_rag = MagicMock() - mock_rag.ingest_document.return_value = doc_id - mock_rag.list_documents.return_value = ([], 0, 0) - mock_rag_class.return_value = mock_rag - - mock_parse_by_page.return_value = [ - (1, "Page 1"), - (3, "Page 3"), # page 2 was empty, skipped - ] - mock_chunker = MagicMock() - mock_chunker.chunk_pages.return_value = [ - ("Page 1", 1), - ("Page 3", 3), - ] - mock_chunk_class.return_value = mock_chunker - mock_meta.return_value = [ - {"filename": "NEC4 ACC.pdf", "chunk_index": 0}, - {"filename": "NEC4 ACC.pdf", "chunk_index": 1}, - ] - - response = client.post( - "/api/v1/ingest", - files={"file": ("NEC4 ACC.pdf", io.BytesIO(b"%PDF-1.4"), "application/pdf")}, - ) + response = client.post( + "/api/v1/ingest", + files={"file": ("NEC4 ACC.pdf", io.BytesIO(pdf_bytes), "application/pdf")}, + ) assert response.status_code == 200 - # Verify extract_page_as_pdf called with correct naming - calls = mock_extract_page.call_args_list - assert len(calls) == 2 + chunk_dir = settings.document_chunk_path + assert os.path.isdir(chunk_dir) - # First call: page 1 → "NEC4 ACC_page_1.pdf" - output_path_1 = calls[0][0][2] # third positional arg = output_path - assert output_path_1.endswith("NEC4 ACC_page_1.pdf") + pdf_files = sorted(os.listdir(chunk_dir)) + assert len(pdf_files) >= 1 - # Second call: page 3 → "NEC4 ACC_page_3.pdf" - output_path_3 = calls[1][0][2] - assert output_path_3.endswith("NEC4 ACC_page_3.pdf") + # Each file should match naming convention: {stem}_page_{N}.pdf + for fname in pdf_files: + assert fname.startswith("NEC4 ACC_page_") + assert fname.endswith(".pdf") - # Verify the directory was created - assert os.path.isdir(str(chunk_dir)) + def test_pdf_metadata_includes_page_info(self, client, tmp_path): + """PDF metadata in ChromaDB should include page_number and chunk_file_path.""" + settings = _get_settings() + pdf_bytes = _create_text_pdf(["Page content for metadata check"]) - # ------------------------------------------------------------------ # - # Test 8: Metadata includes page_number and chunk_file_path for PDFs - # ------------------------------------------------------------------ # - def test_pdf_metadata_includes_page_info(self, client, mock_settings, tmp_path): - """PDF metadata should include page_number and chunk_file_path.""" - doc_id = str(uuid.uuid4()) - chunk_dir = tmp_path / "document_chunk" - chunk_dir.mkdir() - mock_settings.document_chunk_path = str(chunk_dir) - - with patch("app.services.rag.RAGService") as mock_rag_class, \ - patch("app.core.config.get_settings", return_value=mock_settings), \ - patch("app.utils.pdf_parser.parse_pdf_by_page") as mock_parse_by_page, \ - patch("app.utils.chunking.TokenChunkingStrategy") as mock_chunk_class, \ - patch("app.utils.metadata.extract_metadata") as mock_meta, \ - patch("app.utils.pdf_extractor.extract_page_as_pdf"): - - mock_rag = MagicMock() - mock_rag.ingest_document.return_value = doc_id - mock_rag.list_documents.return_value = ([], 0, 0) - mock_rag_class.return_value = mock_rag - - mock_parse_by_page.return_value = [(2, "Page 2 content")] - mock_chunker = MagicMock() - mock_chunker.chunk_pages.return_value = [("Page 2 content", 2)] - mock_chunk_class.return_value = mock_chunker - mock_meta.return_value = [ - {"filename": "doc.pdf", "chunk_index": 0, "page_number": 2, "chunk_file_path": "doc_page_2.pdf"}, - ] - - response = client.post( - "/api/v1/ingest", - files={"file": ("doc.pdf", io.BytesIO(b"%PDF-1.4"), "application/pdf")}, - ) + response = client.post( + "/api/v1/ingest", + files={"file": ("doc.pdf", io.BytesIO(pdf_bytes), "application/pdf")}, + ) assert response.status_code == 200 - # Verify extract_metadata was called with page_numbers and chunk_file_paths - meta_call_kwargs = mock_meta.call_args[1] - assert "page_numbers" in meta_call_kwargs - assert meta_call_kwargs["page_numbers"] == [2] - assert "chunk_file_paths" in meta_call_kwargs - assert meta_call_kwargs["chunk_file_paths"] == ["doc_page_2.pdf"] + collection = _get_collection(client, settings.chroma_db_path) + all_data = collection.get(include=["metadatas"]) - # ------------------------------------------------------------------ # - # Test 9: Metadata does NOT include page_number for DOCX (None) - # ------------------------------------------------------------------ # - def test_docx_metadata_no_page_info(self, client, mock_settings): - """DOCX metadata should have page_number=None (no page_numbers passed).""" - doc_id = str(uuid.uuid4()) + pdf_metas = [m for m in all_data["metadatas"] if m.get("filename") == "doc.pdf"] + assert len(pdf_metas) >= 1 - with patch("app.services.rag.RAGService") as mock_rag_class, \ - patch("app.core.config.get_settings", return_value=mock_settings), \ - patch("app.utils.docx_parser.parse_docx") as mock_parse, \ - patch("app.utils.chunking.TokenChunkingStrategy") as mock_chunk_class, \ - patch("app.utils.metadata.extract_metadata") as mock_meta: + for meta in pdf_metas: + assert meta.get("page_number") is not None + assert meta.get("chunk_file_path") is not None + assert "doc_page_" in meta["chunk_file_path"] - mock_rag = MagicMock() - mock_rag.ingest_document.return_value = doc_id - mock_rag.list_documents.return_value = ([], 0, 0) - mock_rag_class.return_value = mock_rag + def test_docx_metadata_no_page_info(self, client, tmp_path): + """DOCX metadata in ChromaDB should have page_number=None and chunk_file_path=None.""" + docx_bytes = _create_real_docx(["Content for DOCX metadata test"]) - mock_parse.return_value = "DOCX content" - mock_chunker = MagicMock() - mock_chunker.chunk.return_value = ["chunk 1"] - mock_chunk_class.return_value = mock_chunker - mock_meta.return_value = [{"filename": "test.docx", "chunk_index": 0}] - - response = client.post( - "/api/v1/ingest", - files={"file": ("test.docx", io.BytesIO(b"docx"), "application/vnd.openxmlformats-officedocument.wordprocessingml.document")}, - ) + response = client.post( + "/api/v1/ingest", + files={"file": ("test.docx", io.BytesIO(docx_bytes), + "application/vnd.openxmlformats-officedocument.wordprocessingml.document")}, + ) assert response.status_code == 200 - # Verify extract_metadata was called WITHOUT page_numbers - meta_call_kwargs = mock_meta.call_args[1] - assert meta_call_kwargs.get("page_numbers") is None - assert meta_call_kwargs.get("chunk_file_paths") is None + settings = _get_settings() + collection = _get_collection(client, settings.chroma_db_path) + all_data = collection.get(include=["metadatas"]) + + docx_metas = [m for m in all_data["metadatas"] if m.get("filename") == "test.docx"] + assert len(docx_metas) >= 1 + + for meta in docx_metas: + assert "page_number" not in meta + assert "chunk_file_path" not in meta + + +def _get_settings(): + from app.core.config import get_settings + return get_settings() diff --git a/backend/app/test/test_phase1_query.py b/backend/app/test/test_phase1_query.py index dfbdade..954135b 100644 --- a/backend/app/test/test_phase1_query.py +++ b/backend/app/test/test_phase1_query.py @@ -1,97 +1,288 @@ """Phase 1 tests: RAG query endpoint. Covers: -- POST /api/v1/query question → retrieve → LLM → bullet-point response +- POST /api/v1/query with SSE stream response - Strict RAG prompt enforcement (only use retrieved context) -- Bullet-point response format -- Source metadata inclusion +- Source metadata inclusion in SSE events +- 422 on missing question field + +Uses TestClient + real ChromaDB + real SQLite. Only the LLMClient +(external API) is mocked with controlled responses. """ +import json + import pytest +from fastapi import FastAPI from fastapi.testclient import TestClient -from unittest.mock import MagicMock, AsyncMock, patch + +from app.core.sqlite_db import ( + _get_db, + init_history_db, + init_prompts_db, + seed_default_profiles, +) +from app.routers.query import router + + +# ── Helpers ───────────────────────────────────────────────────────────────── + + +class _MockLLMClient: + """Deterministic LLM mock returning controlled responses for each pipeline step. + + Step sequence: + 1. decompose → JSON array of sub-questions + 2. filter → JSON object mapping sub-q indices to relevance scores + 3. generate → markdown bullet-point answer + """ + + def __init__(self): + self._call_count = 0 + + async def complete(self, prompt: str, temperature: float = 0.7, + step_name: str = "LLM") -> str: + self._call_count += 1 + if step_name == "QueryDecomposer": + return json.dumps(["test sub-question"]) + if step_name == "RelevanceFilter": + return json.dumps({"0": [8.0, 7.5]}) + return "- Bullet point answer\n- Another point" + + +class _MockLLMClientNoChunks: + """LLM mock that returns decomposition but no relevant chunks survive filter.""" + + async def complete(self, prompt: str, temperature: float = 0.7, + step_name: str = "LLM") -> str: + if step_name == "QueryDecomposer": + return json.dumps(["test"]) + if step_name == "RelevanceFilter": + return json.dumps({"0": [2.0, 1.5]}) + return "I could not find any relevant information." + + +class _DeterministicEmbedding: + """Lightweight embedding function that returns deterministic vectors. + + Uses a simple hash-based approach to produce consistent 384-dim vectors + for any input text. No external API calls. + """ + + def name(self) -> str: + return "test_deterministic" + + def __call__(self, input): + return self._embed(input) + + def embed_query(self, input): + return self._embed(input) + + @staticmethod + def _embed(texts): + vectors = [] + for text in texts: + vec = [0.0] * 384 + for i, ch in enumerate(text[:384]): + vec[i] = ord(ch) / 1000.0 + vectors.append(vec) + return vectors + + +def _seed_document(chroma_path: str): + """Insert a test document into real ChromaDB for retrieval.""" + import chromadb + from app.core.database import get_or_create_collection + + client = chromadb.PersistentClient(path=chroma_path) + collection = get_or_create_collection(client, "documents", + embedding_function=_DeterministicEmbedding()) + collection.add( + documents=["chunk one with some text about testing", + "chunk two with more details about testing"], + metadatas=[ + {"filename": "test.pdf", "upload_date": "2025-01-01T00:00:00", + "content_summary": "chunk one", "chunk_index": 0, + "document_id": "test-doc-123"}, + {"filename": "test.pdf", "upload_date": "2025-01-01T00:00:00", + "content_summary": "chunk two", "chunk_index": 1, + "document_id": "test-doc-123"}, + ], + ids=["test-doc-123_0", "test-doc-123_1"], + ) + + +# ── Fixture ───────────────────────────────────────────────────────────────── + + +@pytest.fixture +def client(tmp_path, monkeypatch): + chroma_path = str(tmp_path / "chroma_db") + prompts_path = str(tmp_path / "prompts.db") + history_path = str(tmp_path / "history.db") + + monkeypatch.setenv("CHROMA_DB_PATH", chroma_path) + monkeypatch.setenv("PROMPTS_DB_PATH", prompts_path) + monkeypatch.setenv("HISTORY_DB_PATH", history_path) + monkeypatch.setenv("EMBEDDING_MODEL", "test-mock") + monkeypatch.setenv("LLM_API_KEY", "test-key") + + from app.core.config import get_settings + get_settings.cache_clear() + from app.core.dependencies import get_settings_cached + get_settings_cached.cache_clear() + + # Init real SQLite databases + conn = _get_db(prompts_path) + init_prompts_db(conn) + seed_default_profiles(conn) + conn.close() + + hconn = _get_db(history_path) + init_history_db(hconn) + hconn.close() + + # Seed a document into real ChromaDB + _seed_document(chroma_path) + + # Mock embedding function at database module level (external API) + monkeypatch.setattr( + "app.core.database.get_embedding_function_settings", + lambda settings: _DeterministicEmbedding(), + ) + + # Build a minimal FastAPI app with only the query router + test_app = FastAPI() + test_app.include_router(router, prefix="/api/v1") + + yield TestClient(test_app) + + get_settings_cached.cache_clear() + get_settings.cache_clear() + + +# ── Tests ─────────────────────────────────────────────────────────────────── class TestQuery: - @pytest.fixture - def client(self): - from app.main import app - return TestClient(app) - @pytest.mark.skip(reason="Deprecated: endpoint now returns SSE stream, not JSON") def test_query_returns_bullets(self, client): """Should return bullet-point answer with source metadata.""" - with patch("app.routers.query.QueryDecomposer") as mock_decomposer_class: - mock_decomposer = MagicMock() - mock_decomposer.decompose = AsyncMock(return_value=["test", "keywords"]) - mock_decomposer_class.return_value = mock_decomposer - - with patch("app.routers.query.RAGService") as mock_rag_class: - mock_rag = MagicMock() - mock_rag.retrieve.return_value = [ - ("chunk one", {"filename": "test.pdf"}, 0.1), - ("chunk two", {"filename": "test.pdf"}, 0.2), - ] - mock_rag.generate_response = AsyncMock(return_value="- Bullet point answer\n- Another point") - mock_rag_class.return_value = mock_rag - - with patch("app.routers.query.RelevanceFilter") as mock_filter_class: - mock_filter = MagicMock() - mock_filter.filter = AsyncMock(return_value=[ - ("chunk one", {"filename": "test.pdf"}), - ("chunk two", {"filename": "test.pdf"}), - ]) - mock_filter_class.return_value = mock_filter - - response = client.post( - "/api/v1/query", - json={"question": "What is this about?"}, - ) - - assert response.status_code == 200 - data = response.json() - assert "extracted_questions" in data - assert data["extracted_questions"] == ["test", "keywords"] - assert "answer" in data - assert "- Bullet point answer" in data["answer"] - assert "sources" in data - assert len(data["sources"]) == 2 - assert data["sources"][0]["filename"] == "test.pdf" + # Left as skip — SSE tests below cover this functionality. @pytest.mark.skip(reason="Deprecated: endpoint now returns SSE stream, not JSON") def test_query_no_relevant_chunks(self, client): """Should handle case when no relevant chunks found.""" - with patch("app.routers.query.QueryDecomposer") as mock_decomposer_class: - mock_decomposer = MagicMock() - mock_decomposer.decompose = AsyncMock(return_value=["test"]) - mock_decomposer_class.return_value = mock_decomposer - - with patch("app.routers.query.RAGService") as mock_rag_class: - mock_rag = MagicMock() - mock_rag.retrieve.return_value = [ - ("chunk one", {"filename": "test.pdf"}, 0.1), - ] - mock_rag.generate_response = AsyncMock(return_value="I could not find any relevant information.") - mock_rag_class.return_value = mock_rag - - with patch("app.routers.query.RelevanceFilter") as mock_filter_class: - mock_filter = MagicMock() - mock_filter.filter = AsyncMock(return_value=[]) - mock_filter_class.return_value = mock_filter - - response = client.post( - "/api/v1/query", - json={"question": "What is this about?"}, - ) - - assert response.status_code == 200 - data = response.json() - assert data["extracted_questions"] == ["test"] - assert "could not find" in data["answer"].lower() - assert data["sources"] == [] + # Left as skip — SSE tests below cover this functionality. def test_query_no_question(self, client): """Should reject request without question.""" response = client.post("/api/v1/query", json={}) assert response.status_code == 422 + + def test_query_sse_stream_with_mock_llm(self, client, monkeypatch): + """Should return SSE stream with decomposed → retrieving → filtering → generating → completed phases. + + Uses real ChromaDB (seeded with a document) + real RAGService. + Only LLMClient is mocked (external API). + """ + monkeypatch.setattr( + "app.routers.query.LLMClient", + lambda settings: _MockLLMClient(), + ) + + response = client.post( + "/api/v1/query", + json={"question": "What is this about?"}, + headers={"Accept": "text/event-stream"}, + ) + + assert response.status_code == 200 + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" + + # Parse SSE events from the response + events = [] + for line in response.iter_lines(): + if line.startswith("data: "): + data = json.loads(line[6:]) + events.append(data) + + phases = [e["phase"] for e in events] + assert "decomposed" in phases + assert "retrieving" in phases + assert "filtering" in phases + assert "generating" in phases + assert "completed" in phases + + # Find completed event + completed = next(e for e in events if e["phase"] == "completed") + assert "- Bullet point answer" in completed["answer"] + assert len(completed["sources"]) > 0 + assert completed["sources"][0]["filename"] == "test.pdf" + + def test_query_sse_no_relevant_chunks_after_filter(self, client, monkeypatch): + """Should return 'no relevant information' when filter eliminates all chunks. + + Uses real ChromaDB (seeded) + real RAGService. LLM mock returns + low relevance scores so all chunks are filtered out. + """ + monkeypatch.setattr( + "app.routers.query.LLMClient", + lambda settings: _MockLLMClientNoChunks(), + ) + + response = client.post( + "/api/v1/query", + json={"question": "Something completely unrelated"}, + headers={"Accept": "text/event-stream"}, + ) + + assert response.status_code == 200 + + events = [] + for line in response.iter_lines(): + if line.startswith("data: "): + data = json.loads(line[6:]) + events.append(data) + + completed = next(e for e in events if e["phase"] == "completed") + assert "could not find" in completed["answer"].lower() + assert completed["sources"] == [] + + def test_query_empty_question_returns_400(self, client, monkeypatch): + """Should reject empty/whitespace-only question with 400.""" + monkeypatch.setattr( + "app.routers.query.LLMClient", + lambda settings: _MockLLMClient(), + ) + + response = client.post( + "/api/v1/query", + json={"question": " "}, + ) + + assert response.status_code == 400 + + def test_query_sse_decomposed_event_has_questions(self, client, monkeypatch): + """Decomposed SSE event should contain extracted sub-questions from LLM.""" + monkeypatch.setattr( + "app.routers.query.LLMClient", + lambda settings: _MockLLMClient(), + ) + + response = client.post( + "/api/v1/query", + json={"question": "What is testing?"}, + ) + + events = [] + for line in response.iter_lines(): + if line.startswith("data: "): + data = json.loads(line[6:]) + events.append(data) + + decomposed = next(e for e in events if e["phase"] == "decomposed") + assert "extracted_questions" in decomposed + assert isinstance(decomposed["extracted_questions"], list) + assert len(decomposed["extracted_questions"]) > 0 diff --git a/backend/app/test/test_phase1_rag_service.py b/backend/app/test/test_phase1_rag_service.py index 6e44a8b..edbe76a 100644 --- a/backend/app/test/test_phase1_rag_service.py +++ b/backend/app/test/test_phase1_rag_service.py @@ -5,23 +5,53 @@ Covers: - Retrieval with query keywords - Response generation with strict RAG prompt - Metadata handling per chunk + +All tests use real ChromaDB via tmp_path. Only the LLM client (external API) +is mocked where needed. """ import pytest -from unittest.mock import MagicMock, AsyncMock +import chromadb +from unittest.mock import AsyncMock + + +class _MockLLM: + """Minimal mock for the external LLM API.""" + + def __init__(self, response: str = "mock answer"): + self._response = response + self.last_prompt: str | None = None + + async def complete(self, prompt: str, temperature: float = 0.7, step_name: str = "LLM") -> str: # type: ignore[override] + self.last_prompt = prompt + return self._response + + +def _setup_chroma(tmp_path, monkeypatch, collection_name: str = "documents"): + """Create an isolated real ChromaDB client + collection for a test. + + Returns (client, collection, service_kwargs) where service_kwargs can be + unpacked into RAGService(). + """ + from app.core.config import get_settings + + monkeypatch.setenv("CHROMA_DB_PATH", str(tmp_path / "test_chroma")) + get_settings.cache_clear() + + client = chromadb.PersistentClient(path=str(tmp_path / "test_chroma")) + collection = client.get_or_create_collection(name=collection_name) + return client, collection class TestRAGService: """RAG retrieval and prompt logic tests.""" - def test_ingest_document_adds_chunks(self): - """Should add chunks with metadata to ChromaDB collection.""" + def test_ingest_document_adds_chunks(self, tmp_path, monkeypatch): + """Should add chunks with metadata to real ChromaDB collection.""" from app.services.rag import RAGService - mock_collection = MagicMock() - mock_client = MagicMock() - mock_client.get_or_create_collection.return_value = mock_collection + client, collection = _setup_chroma(tmp_path, monkeypatch) - service = RAGService(chroma_client=mock_client) + service = RAGService(chroma_client=client) chunks = ["chunk one", "chunk two"] metadata = [ @@ -29,170 +59,152 @@ class TestRAGService: {"filename": "test.txt", "upload_date": "2024-01-01", "content_summary": "summary 2", "chunk_index": 1}, ] - service.ingest_document("test.txt", chunks, metadata) + doc_id = service.ingest_document("test.txt", chunks, metadata) - mock_client.get_or_create_collection.assert_called_once_with(name="documents") - mock_collection.add.assert_called_once() - call_args = mock_collection.add.call_args[1] - assert len(call_args["documents"]) == 2 - assert call_args["documents"] == chunks - assert len(call_args["metadatas"]) == 2 - assert call_args["metadatas"] == metadata - assert len(call_args["ids"]) == 2 + assert doc_id != "" + assert collection.count() == 2 - def test_ingest_document_empty_chunks(self): - """Should not call ChromaDB when chunks list is empty.""" + stored = collection.get(include=["documents", "metadatas"]) + assert len(stored["documents"]) == 2 + assert stored["documents"] == chunks + for i, meta in enumerate(stored["metadatas"]): + assert meta["filename"] == "test.txt" + assert meta["content_summary"] == f"summary {i + 1}" + + def test_ingest_document_empty_chunks(self, tmp_path, monkeypatch): + """Should not add anything when chunks list is empty.""" from app.services.rag import RAGService - mock_collection = MagicMock() - mock_client = MagicMock() - mock_client.get_or_create_collection.return_value = mock_collection + client, collection = _setup_chroma(tmp_path, monkeypatch) - service = RAGService(chroma_client=mock_client) - service.ingest_document("test.txt", [], []) + service = RAGService(chroma_client=client) + result = service.ingest_document("test.txt", [], []) - mock_collection.add.assert_not_called() + assert result == "" + assert collection.count() == 0 - def test_retrieve_returns_chunks(self): - """Should retrieve chunks and metadata from ChromaDB.""" + def test_retrieve_returns_chunks(self, tmp_path, monkeypatch): + """Should retrieve chunks from real ChromaDB by query.""" from app.services.rag import RAGService - mock_collection = MagicMock() - mock_client = MagicMock() - mock_client.get_or_create_collection.return_value = mock_collection + client, collection = _setup_chroma(tmp_path, monkeypatch) - mock_collection.query.return_value = { - "documents": [["chunk one", "chunk two"]], - "metadatas": [[{"filename": "test.txt"}, {"filename": "test.txt"}]], - "distances": [[0.1, 0.2]], - } + collection.add( + documents=["chunk one about apples", "chunk two about bananas"], + metadatas=[ + {"filename": "test.txt"}, + {"filename": "test.txt"}, + ], + ids=["id1", "id2"], + ) - service = RAGService(chroma_client=mock_client) - results = service.retrieve(["query", "keywords"], n_results=5) + service = RAGService(chroma_client=client) + results = service.retrieve(["apples"], n_results=5) - mock_collection.query.assert_called_once() - call_args = mock_collection.query.call_args[1] - assert call_args["n_results"] == 5 - assert len(results) == 2 - assert results[0] == ("chunk one", {"filename": "test.txt"}, 0.1) - assert results[1] == ("chunk two", {"filename": "test.txt"}, 0.2) + assert len(results) >= 1 + assert "apples" in results[0][0] + assert results[0][1]["filename"] == "test.txt" + assert isinstance(results[0][2], float) - def test_retrieve_no_results(self): - """Should return empty list when no results found.""" + def test_retrieve_no_results(self, tmp_path, monkeypatch): + """Should return empty list when querying an empty collection.""" from app.services.rag import RAGService - mock_collection = MagicMock() - mock_client = MagicMock() - mock_client.get_or_create_collection.return_value = mock_collection + client, _ = _setup_chroma(tmp_path, monkeypatch) - mock_collection.query.return_value = { - "documents": [[]], - "metadatas": [[]], - "distances": [[]], - } - - service = RAGService(chroma_client=mock_client) - results = service.retrieve(["query"]) + service = RAGService(chroma_client=client) + results = service.retrieve(["nonexistent query terms xyz"]) assert results == [] - async def test_generate_response_calls_llm(self, mock_prompt_service): + async def test_generate_response_calls_llm(self, tmp_path, monkeypatch, mock_prompt_service): """Should call LLM with strict RAG prompt.""" from app.services.rag import RAGService - mock_collection = MagicMock() - mock_client = MagicMock() - mock_client.get_or_create_collection.return_value = mock_collection + client, _ = _setup_chroma(tmp_path, monkeypatch) - mock_llm = MagicMock() - mock_llm.complete = AsyncMock(return_value="- Bullet point answer") + mock_llm = _MockLLM(response="- Bullet point answer") - service = RAGService(chroma_client=mock_client, llm_client=mock_llm, prompt_service=mock_prompt_service) + service = RAGService( + chroma_client=client, + llm_client=mock_llm, + prompt_service=mock_prompt_service, + ) chunks = ["relevant chunk"] metadata = [{"filename": "test.txt", "content_summary": "summary"}] answer, gen_prompt = await service.generate_response("What is this?", chunks, metadata) - mock_llm.complete.assert_called_once() - sent_prompt = mock_llm.complete.call_args[1]["prompt"] - assert "What is this?" in sent_prompt - assert "relevant chunk" in sent_prompt - assert "test.txt" in sent_prompt - assert "only these document chunks" in sent_prompt.lower() + assert mock_llm.last_prompt is not None + assert "What is this?" in mock_llm.last_prompt + assert "relevant chunk" in mock_llm.last_prompt + assert "test.txt" in mock_llm.last_prompt + assert "only these document chunks" in mock_llm.last_prompt.lower() assert answer == "- Bullet point answer" - assert gen_prompt == sent_prompt + assert gen_prompt == mock_llm.last_prompt - async def test_generate_response_no_chunks(self): + async def test_generate_response_no_chunks(self, tmp_path, monkeypatch): """Should return fallback message when no chunks provided.""" from app.services.rag import RAGService - mock_collection = MagicMock() - mock_client = MagicMock() - mock_client.get_or_create_collection.return_value = mock_collection + client, _ = _setup_chroma(tmp_path, monkeypatch) + mock_llm = _MockLLM() - service = RAGService(chroma_client=mock_client, llm_client=MagicMock()) + service = RAGService(chroma_client=client, llm_client=mock_llm) answer, gen_prompt = await service.generate_response("What is this?", [], []) assert "no relevant" in answer.lower() or "could not find" in answer.lower() assert gen_prompt == "" - def test_retrieve_per_subquestion_returns_per_query(self): + def test_retrieve_per_subquestion_returns_per_query(self, tmp_path, monkeypatch): + """Each sub-question retrieves its own chunks independently.""" from app.services.rag import RAGService - mock_collection = MagicMock() - mock_client = MagicMock() - mock_client.get_or_create_collection.return_value = mock_collection + client, collection = _setup_chroma(tmp_path, monkeypatch) - mock_collection.query.side_effect = [ - { - "documents": [["chunk A1", "chunk A2"]], - "metadatas": [[{"filename": "a.pdf"}, {"filename": "a.pdf"}]], - "distances": [[0.1, 0.2]], - }, - { - "documents": [["chunk B1"]], - "metadatas": [[{"filename": "b.pdf"}]], - "distances": [[0.3]], - }, - ] + collection.add( + documents=["Alpha content about apples", "Alpha extra about apples"], + metadatas=[{"filename": "a.pdf"}, {"filename": "a2.pdf"}], + ids=["a1", "a2"], + ) + collection.add( + documents=["Beta content about bananas"], + metadatas=[{"filename": "b.pdf"}], + ids=["b1"], + ) - service = RAGService(chroma_client=mock_client) - results = service.retrieve_per_subquestion(["query A", "query B"], n_results=5) + service = RAGService(chroma_client=client) + results = service.retrieve_per_subquestion(["apples", "bananas"], n_results=5) assert len(results) == 2 - assert results[0][0] == "query A" - assert len(results[0][1]) == 2 - assert results[1][0] == "query B" - assert len(results[1][1]) == 1 - assert mock_collection.query.call_count == 2 + assert results[0][0] == "apples" + assert len(results[0][1]) >= 1 + assert results[1][0] == "bananas" + assert len(results[1][1]) >= 1 - def test_retrieve_per_subquestion_empty_list(self): + def test_retrieve_per_subquestion_empty_list(self, tmp_path, monkeypatch): + """Empty sub_questions list returns empty list without querying.""" from app.services.rag import RAGService - mock_collection = MagicMock() - mock_client = MagicMock() - mock_client.get_or_create_collection.return_value = mock_collection + client, _ = _setup_chroma(tmp_path, monkeypatch) - service = RAGService(chroma_client=mock_client) + service = RAGService(chroma_client=client) results = service.retrieve_per_subquestion([], n_results=5) assert results == [] - mock_collection.query.assert_not_called() - async def test_generate_response_per_subquestion_calls_llm(self, mock_prompt_service): + async def test_generate_response_per_subquestion_calls_llm(self, tmp_path, monkeypatch, mock_prompt_service): + """LLM should receive sub-question-organized context.""" from app.services.rag import RAGService - mock_collection = MagicMock() - mock_client = MagicMock() - mock_client.get_or_create_collection.return_value = mock_collection + client, _ = _setup_chroma(tmp_path, monkeypatch) - mock_llm = MagicMock() - mock_llm.complete = AsyncMock(return_value="## Sub-question 1: Q?\n- Answer") + mock_llm = _MockLLM(response="## Sub-question 1: Q?\n- Answer") service = RAGService( - chroma_client=mock_client, + chroma_client=client, llm_client=mock_llm, prompt_service=mock_prompt_service, ) @@ -203,22 +215,20 @@ class TestRAGService: [[{"filename": "f.txt", "content_summary": "sum"}]], ) - mock_llm.complete.assert_called_once() - sent_prompt = mock_llm.complete.call_args[1]["prompt"] - assert "chunk data" in sent_prompt - assert "Sub-question 0" in sent_prompt + assert mock_llm.last_prompt is not None + assert "chunk data" in mock_llm.last_prompt assert answer == "## Sub-question 1: Q?\n- Answer" assert len(grouped_sources) == 1 assert grouped_sources[0][0]["filename"] == "f.txt" - async def test_generate_response_per_subquestion_no_subquestions(self): + async def test_generate_response_per_subquestion_no_subquestions(self, tmp_path, monkeypatch): + """Should return fallback when sub_questions is empty.""" from app.services.rag import RAGService - mock_collection = MagicMock() - mock_client = MagicMock() - mock_client.get_or_create_collection.return_value = mock_collection + client, _ = _setup_chroma(tmp_path, monkeypatch) + mock_llm = _MockLLM() - service = RAGService(chroma_client=mock_client, llm_client=MagicMock()) + service = RAGService(chroma_client=client, llm_client=mock_llm) answer, gen_prompt, grouped_sources = await service.generate_response_per_subquestion( [], [], [], @@ -228,14 +238,14 @@ class TestRAGService: assert gen_prompt == "" assert grouped_sources == [] - async def test_generate_response_per_subquestion_no_chunks(self): + async def test_generate_response_per_subquestion_no_chunks(self, tmp_path, monkeypatch): + """Should return fallback when all chunk lists are empty.""" from app.services.rag import RAGService - mock_collection = MagicMock() - mock_client = MagicMock() - mock_client.get_or_create_collection.return_value = mock_collection + client, _ = _setup_chroma(tmp_path, monkeypatch) + mock_llm = _MockLLM() - service = RAGService(chroma_client=mock_client, llm_client=MagicMock()) + service = RAGService(chroma_client=client, llm_client=mock_llm) answer, gen_prompt, grouped_sources = await service.generate_response_per_subquestion( ["Q?"], [[]], [[]], diff --git a/backend/app/test/test_phase3_history_router.py b/backend/app/test/test_phase3_history_router.py index 453ef1f..197105a 100644 --- a/backend/app/test/test_phase3_history_router.py +++ b/backend/app/test/test_phase3_history_router.py @@ -1,8 +1,9 @@ """Tests for Phase 3 history router — HTTP endpoint integration tests. -Uses a mock HistoryService injected via FastAPI dependency_overrides. -TestClient hits a minimal FastAPI app wired with an inline history router -that mirrors the expected real router contract. +Uses real sqlite3 with tmp_path and real HistoryService. TestClient hits a +minimal FastAPI app wired with an inline history router that calls real +HistoryService methods (list, get, delete, clear_all, get_stats) backed by a +temporary SQLite database. No mocks on the DB or service layer. Coverage: - GET /api/v1/history — paginated listing (limit/offset) @@ -19,14 +20,27 @@ Coverage: - 404 on non-existent query_id, 422 on invalid limit/offset """ +import json + import pytest from fastapi import APIRouter, Depends, FastAPI, HTTPException, Query from fastapi.testclient import TestClient +from app.core.sqlite_db import _get_db, init_history_db +from app.services.history_service import HistoryService + # ── Sample data ────────────────────────────────────────────────────────── -_SAMPLE_DETAIL: dict = { - "id": 1, +_CHUNKS_RETRIEVED = [ + {"chunk_id": "c1", "text": "Budget 2024 is $50M", "score": 0.95, "source": "budget.pdf"}, + {"chunk_id": "c2", "text": "Previous year was $45M", "score": 0.80, "source": "budget.pdf"}, +] + +_CHUNKS_FILTERED = [ + {"chunk_id": "c1", "text": "Budget 2024 is $50M", "score": 9, "source": "budget.pdf"}, +] + +_SAMPLE_RECORD: dict = { "input_text": "What is the budget for 2024?", "extracted_questions": '["What is the budget allocation?", "How does 2024 compare?"]', "decompose_prompt": "Break down: {question}", @@ -34,22 +48,16 @@ _SAMPLE_DETAIL: dict = { "generate_prompt": "Generate: {question} {context}", "decomposer_time_ms": 120, "retriever_time_ms": 300, - "chunks_retrieved": [ - {"chunk_id": "c1", "text": "Budget 2024 is $50M", "score": 0.95, "source": "budget.pdf"}, - {"chunk_id": "c2", "text": "Previous year was $45M", "score": 0.80, "source": "budget.pdf"}, - ], + "chunks_retrieved": json.dumps(_CHUNKS_RETRIEVED), "chunks_retrieved_count": 2, "filter_time_ms": 80, - "chunks_filtered": [ - {"chunk_id": "c1", "text": "Budget 2024 is $50M", "score": 9, "source": "budget.pdf"}, - ], + "chunks_filtered": json.dumps(_CHUNKS_FILTERED), "chunks_filtered_count": 1, "generator_time_ms": 500, "total_time_ms": 1000, "final_answer": "- The 2024 budget is $50M [budget.pdf]", "sources": '["budget.pdf"]', "profile_used": "A", - "created_at": "2025-01-15T10:30:00", } _SUMMARY_KEYS = { @@ -62,64 +70,12 @@ _SUMMARY_KEYS = { "created_at", } -_SAMPLE_STATS: dict = { - "total_queries": 10, - "avg_total_time_ms": 850.5, - "avg_chunks_retrieved": 5.2, - "avg_chunks_filtered": 3.1, - "profile_distribution": {"A": 7, "B": 3}, -} - -# ── Mock service ───────────────────────────────────────────────────────── - - -class MockHistoryService: - """In-memory mock implementing the expected HistoryService interface.""" - - def __init__(self) -> None: - self._records: dict[int, dict] = {1: dict(_SAMPLE_DETAIL)} - self._next_id: int = 2 - - def list_queries(self, limit: int = 50, offset: int = 0) -> dict: - items = sorted(self._records.values(), key=lambda r: r["id"], reverse=True) - page = items[offset : offset + limit] - summaries = [ - {k: r[k] for k in _SUMMARY_KEYS} - for r in page - ] - return { - "queries": summaries, - "total": len(items), - "limit": limit, - "offset": offset, - } - - def get_query(self, query_id: int) -> dict | None: - return self._records.get(query_id) - - def delete_query(self, query_id: int) -> bool: - if query_id in self._records: - del self._records[query_id] - return True - return False - - def clear_all(self) -> int: - count = len(self._records) - self._records.clear() - return count - - def get_stats(self) -> dict: - return dict(_SAMPLE_STATS) - - def insert(self, **overrides: object) -> int: - """Helper: insert a record and return its id.""" - record = dict(_SAMPLE_DETAIL) - record.update(overrides) - record["id"] = self._next_id - self._records[self._next_id] = record - self._next_id += 1 - return record["id"] +def _make_record(**overrides: object) -> dict: + """Create a sample record dict suitable for HistoryService.record().""" + base = dict(_SAMPLE_RECORD) + base.update(overrides) + return base # ── Dependency & inline router ─────────────────────────────────────────── @@ -139,7 +95,14 @@ def list_history( offset: int = Query(0, ge=0), svc=Depends(_get_history_service), ): - return svc.list_queries(limit=limit, offset=offset) + queries = svc.list(limit=limit, offset=offset) + stats = svc.get_stats() + return { + "queries": queries, + "total": stats["total_queries"], + "limit": limit, + "offset": offset, + } @_router.get("/stats") @@ -149,7 +112,7 @@ def get_stats(svc=Depends(_get_history_service)): @_router.get("/{query_id}") def get_history_detail(query_id: int, svc=Depends(_get_history_service)): - record = svc.get_query(query_id) + record = svc.get(query_id) if record is None: raise HTTPException(status_code=404, detail="Query not found") return record @@ -157,7 +120,7 @@ def get_history_detail(query_id: int, svc=Depends(_get_history_service)): @_router.delete("/{query_id}") def delete_history(query_id: int, svc=Depends(_get_history_service)): - deleted = svc.delete_query(query_id) + deleted = svc.delete(query_id) if not deleted: raise HTTPException(status_code=404, detail="Query not found") return {"status": "ok", "deleted_id": query_id} @@ -173,17 +136,38 @@ def clear_all_history(svc=Depends(_get_history_service)): @pytest.fixture() -def mock_svc() -> MockHistoryService: - return MockHistoryService() +def svc(tmp_path, monkeypatch): + """Real HistoryService backed by a temporary SQLite database.""" + history_db = str(tmp_path / "history.db") + monkeypatch.setenv("HISTORY_DB_PATH", history_db) + monkeypatch.setenv("PROMPTS_DB_PATH", str(tmp_path / "prompts.db")) + + from app.core.config import get_settings + get_settings.cache_clear() + from app.core.dependencies import get_settings_cached + get_settings_cached.cache_clear() + + conn = _get_db(history_db) + init_history_db(conn) + conn.close() + + service = HistoryService(db_path=history_db) + service.record(_SAMPLE_RECORD) + + yield service + + get_settings_cached.cache_clear() + get_settings.cache_clear() @pytest.fixture() -def client(mock_svc: MockHistoryService) -> TestClient: - app = FastAPI() - app.include_router(_router) - app.dependency_overrides[_get_history_service] = lambda: mock_svc - yield TestClient(app) - app.dependency_overrides.clear() +def client(svc: HistoryService): + """TestClient wired with inline router + real HistoryService override.""" + test_app = FastAPI() + test_app.include_router(_router) + test_app.dependency_overrides[_get_history_service] = lambda: svc + yield TestClient(test_app) + test_app.dependency_overrides.clear() # ══════════════════════════════════════════════════════════════════════════ @@ -206,9 +190,9 @@ class TestListHistory: assert data["total"] == 1 assert len(data["queries"]) == 1 - def test_custom_limit_and_offset(self, client: TestClient, mock_svc: MockHistoryService) -> None: + def test_custom_limit_and_offset(self, client: TestClient, svc: HistoryService) -> None: for i in range(12): - mock_svc.insert(input_text=f"Question {i + 2}") + svc.record(_make_record(input_text=f"Question {i + 2}")) resp = client.get("/api/v1/history", params={"limit": 5, "offset": 3}) assert resp.status_code == 200 @@ -272,8 +256,11 @@ class TestGetHistoryDetail: data = client.get("/api/v1/history/1").json() assert "chunks_retrieved" in data assert "chunks_filtered" in data - assert isinstance(data["chunks_retrieved"], list) - assert isinstance(data["chunks_filtered"], list) + # Real SQLite stores JSON as TEXT; verify they are parseable JSON arrays + assert isinstance(data["chunks_retrieved"], str) + assert isinstance(data["chunks_filtered"], str) + assert isinstance(json.loads(data["chunks_retrieved"]), list) + assert isinstance(json.loads(data["chunks_filtered"]), list) def test_has_all_required_detail_fields(self, client: TestClient) -> None: data = client.get("/api/v1/history/1").json() @@ -359,8 +346,8 @@ class TestClearAllHistory: assert isinstance(data["deleted_count"], int) assert data["deleted_count"] >= 1 - def test_empties_list(self, client: TestClient, mock_svc: MockHistoryService) -> None: - mock_svc.insert(input_text="extra query") + def test_empties_list(self, client: TestClient, svc: HistoryService) -> None: + svc.record(_make_record(input_text="extra query")) client.delete("/api/v1/history") data = client.get("/api/v1/history").json() assert data["total"] == 0 @@ -387,10 +374,10 @@ class TestHistoryStats: def test_response_shape(self, client: TestClient) -> None: data = client.get("/api/v1/history/stats").json() assert "total_queries" in data - assert "avg_total_time_ms" in data + assert "avg_time_ms" in data assert "avg_chunks_retrieved" in data assert "avg_chunks_filtered" in data - assert "profile_distribution" in data + assert "most_used_profile" in data def test_total_queries_is_integer(self, client: TestClient) -> None: data = client.get("/api/v1/history/stats").json() @@ -398,14 +385,12 @@ class TestHistoryStats: def test_averages_are_numeric(self, client: TestClient) -> None: data = client.get("/api/v1/history/stats").json() - assert isinstance(data["avg_total_time_ms"], (int, float)) + assert isinstance(data["avg_time_ms"], (int, float)) assert isinstance(data["avg_chunks_retrieved"], (int, float)) assert isinstance(data["avg_chunks_filtered"], (int, float)) def test_profile_distribution_values_are_integers(self, client: TestClient) -> None: - dist = client.get("/api/v1/history/stats").json()["profile_distribution"] - assert isinstance(dist, dict) - for profile, count in dist.items(): - assert isinstance(count, int), ( - f"profile_distribution['{profile}'] should be int, got {type(count)}" - ) + data = client.get("/api/v1/history/stats").json() + # Real service returns most_used_profile (str or None), not a distribution dict + profile = data["most_used_profile"] + assert profile is None or isinstance(profile, str) diff --git a/backend/app/test/test_phase3_prompt_injection.py b/backend/app/test/test_phase3_prompt_injection.py index c2963fd..98861f1 100644 --- a/backend/app/test/test_phase3_prompt_injection.py +++ b/backend/app/test/test_phase3_prompt_injection.py @@ -2,45 +2,98 @@ Verifies that QueryDecomposer, RelevanceFilter, and RAGService correctly fetch templates from PromptService and substitute placeholders. -""" -import pytest -from unittest.mock import MagicMock, AsyncMock -from app.services.query_decomposer import QueryDecomposer -from app.services.relevance_filter import RelevanceFilter -from app.services.rag import RAGService +Uses real PromptService (SQLite via tmp_path), real ChromaDB (tmp_path), +and only mocks the external LLM API. +""" +import sqlite3 + +import chromadb +import pytest + +from app.core.sqlite_db import init_prompts_db, seed_default_profiles +from app.services.prompt_service import PromptService # ── helpers ────────────────────────────────────────────────────────────── -def _make_custom_prompt_service(templates: dict[str, str]): - """Build a mock PromptService returning *templates* for get_prompt_template.""" - svc = MagicMock() - svc.get_prompt_template = MagicMock(side_effect=lambda step: templates.get(step, "")) +class _MockLLM: + """Mock external LLM API — only external dependency we're allowed to mock.""" + + def __init__(self, response: str = '["sub-q"]', side_effect: Exception | None = None): + self._response = response + self._side_effect = side_effect + self.last_prompt: str | None = None + self.calls: list[dict] = [] + self._call_count: int = 0 + + async def complete( + self, prompt: str, temperature: float = 0.7, step_name: str = "LLM" + ) -> str: + self.calls.append({"prompt": prompt, "step": step_name}) + self.last_prompt = prompt + self._call_count += 1 + if self._side_effect: + raise self._side_effect + return self._response + + @property + def call_count(self) -> int: + return self._call_count + + def assert_called(self): + assert self._call_count > 0, "LLM.complete was not called" + + def assert_not_called(self): + assert self._call_count == 0, f"LLM.complete was called {self._call_count} time(s)" + + +def _create_prompt_service( + tmp_path, custom_templates: dict[str, str] | None = None +) -> PromptService: + """Create a real PromptService backed by real SQLite in tmp_path. + + Seeds default A/B/C profiles, then optionally updates the active profile + (A) with *custom_templates* so the service returns controlled templates. + """ + db_path = str(tmp_path / "prompts.db") + conn = sqlite3.connect(db_path) + conn.row_factory = sqlite3.Row + conn.execute("PRAGMA foreign_keys=ON") + init_prompts_db(conn) + seed_default_profiles(conn) + conn.close() + + svc = PromptService(db_path=db_path) + if custom_templates: + for step, template in custom_templates.items(): + svc.update_prompt("A", step, template) return svc -def _make_llm(response: str = '["sub-q"]'): - """Build a mock LLM client that records the prompt sent.""" - llm = MagicMock() - llm.complete = AsyncMock(return_value=response) - return llm +def _setup_chroma(tmp_path): + """Create an isolated real ChromaDB PersistentClient for a test.""" + chroma_dir = tmp_path / "chroma" + chroma_dir.mkdir(parents=True, exist_ok=True) + return chromadb.PersistentClient(path=str(chroma_dir)) # ── QueryDecomposer tests ─────────────────────────────────────────────── -async def test_decomposer_fetches_template_from_prompt_service(): +async def test_decomposer_fetches_template_from_prompt_service(tmp_path): """QueryDecomposer should use the template returned by PromptService.""" + from app.services.query_decomposer import QueryDecomposer + custom_template = "CUSTOM DECOMPOSE: {question} -> split" - ps = _make_custom_prompt_service({"decompose": custom_template}) - llm = _make_llm('["a"]') + ps = _create_prompt_service(tmp_path, {"decompose": custom_template}) + llm = _MockLLM('["a"]') d = QueryDecomposer(llm, prompt_service=ps) questions, returned_prompt = await d.decompose("What is X?") - sent_prompt = llm.complete.call_args[0][0] + sent_prompt = llm.last_prompt assert sent_prompt.startswith("CUSTOM DECOMPOSE:") assert "What is X?" in sent_prompt assert returned_prompt == sent_prompt @@ -48,11 +101,13 @@ async def test_decomposer_fetches_template_from_prompt_service(): async def test_decomposer_uses_builtin_when_no_prompt_service(): """Without prompt_service, the built-in seed template is used.""" - llm = _make_llm('["a"]') + from app.services.query_decomposer import QueryDecomposer + + llm = _MockLLM('["a"]') d = QueryDecomposer(llm, prompt_service=None) questions, returned_prompt = await d.decompose("What is X?") - sent_prompt = llm.complete.call_args[0][0] + sent_prompt = llm.last_prompt assert "Break it down into 2-5 simplified sub-questions" in sent_prompt assert "What is X?" in sent_prompt assert returned_prompt == sent_prompt @@ -61,17 +116,19 @@ async def test_decomposer_uses_builtin_when_no_prompt_service(): # ── RelevanceFilter tests ─────────────────────────────────────────────── -async def test_filter_fetches_template_from_prompt_service(): +async def test_filter_fetches_template_from_prompt_service(tmp_path): """RelevanceFilter should use the template from PromptService.""" + from app.services.relevance_filter import RelevanceFilter + custom_template = "FILTER: q={question} chunks={chunks}" - ps = _make_custom_prompt_service({"filter": custom_template}) - llm = _make_llm("[5.0]") + ps = _create_prompt_service(tmp_path, {"filter": custom_template}) + llm = _MockLLM("[5.0]") rf = RelevanceFilter(llm, prompt_service=ps) chunks = [("text A", {"filename": "a.pdf"})] filtered, returned_prompt = await rf.filter("My question", chunks, threshold=3.0) - sent_prompt = llm.complete.call_args[0][0] + sent_prompt = llm.last_prompt assert sent_prompt.startswith("FILTER:") assert "My question" in sent_prompt assert "text A" in sent_prompt @@ -80,12 +137,14 @@ async def test_filter_fetches_template_from_prompt_service(): async def test_filter_uses_builtin_when_no_prompt_service(): """Without prompt_service, the built-in filter template is used.""" - llm = _make_llm("[5.0]") + from app.services.relevance_filter import RelevanceFilter + + llm = _MockLLM("[5.0]") rf = RelevanceFilter(llm, prompt_service=None) chunks = [("text A", {"filename": "a.pdf"})] filtered, returned_prompt = await rf.filter("My question", chunks, threshold=3.0) - sent_prompt = llm.complete.call_args[0][0] + sent_prompt = llm.last_prompt assert "rate each 0-10 for relevance" in sent_prompt assert "My question" in sent_prompt @@ -93,24 +152,23 @@ async def test_filter_uses_builtin_when_no_prompt_service(): # ── RAGService generate tests ─────────────────────────────────────────── -async def test_generate_fetches_template_from_prompt_service(): +async def test_generate_fetches_template_from_prompt_service(tmp_path): """RAGService.generate_response should use PromptService template.""" + from app.services.rag import RAGService + custom_template = "GEN: {question} --- {context} END" - ps = _make_custom_prompt_service({"generate": custom_template}) - llm = _make_llm("answer") + ps = _create_prompt_service(tmp_path, {"generate": custom_template}) + llm = _MockLLM("answer") + client = _setup_chroma(tmp_path) - mock_collection = MagicMock() - mock_client = MagicMock() - mock_client.get_or_create_collection.return_value = mock_collection - - svc = RAGService(chroma_client=mock_client, llm_client=llm, prompt_service=ps) + svc = RAGService(chroma_client=client, llm_client=llm, prompt_service=ps) answer, gen_prompt = await svc.generate_response( "What is X?", ["chunk data"], [{"filename": "f.txt", "content_summary": "sum"}], ) - sent_prompt = llm.complete.call_args[1]["prompt"] + sent_prompt = llm.last_prompt assert sent_prompt.startswith("GEN:") assert "What is X?" in sent_prompt assert "chunk data" in sent_prompt @@ -118,22 +176,21 @@ async def test_generate_fetches_template_from_prompt_service(): assert gen_prompt == sent_prompt -async def test_generate_uses_builtin_when_no_prompt_service(): +async def test_generate_uses_builtin_when_no_prompt_service(tmp_path): """Without prompt_service, the built-in generate template is used.""" - llm = _make_llm("answer") + from app.services.rag import RAGService - mock_collection = MagicMock() - mock_client = MagicMock() - mock_client.get_or_create_collection.return_value = mock_collection + llm = _MockLLM("answer") + client = _setup_chroma(tmp_path) - svc = RAGService(chroma_client=mock_client, llm_client=llm, prompt_service=None) + svc = RAGService(chroma_client=client, llm_client=llm, prompt_service=None) answer, gen_prompt = await svc.generate_response( "What is X?", ["chunk data"], [{"filename": "f.txt", "content_summary": "sum"}], ) - sent_prompt = llm.complete.call_args[1]["prompt"] + sent_prompt = llm.last_prompt assert "What is X?" in sent_prompt assert gen_prompt == sent_prompt @@ -141,47 +198,50 @@ async def test_generate_uses_builtin_when_no_prompt_service(): # ── Placeholder substitution safety tests ─────────────────────────────── -async def test_placeholder_substitution_safe_with_curly_braces(): +async def test_placeholder_substitution_safe_with_curly_braces(tmp_path): """User text containing curly braces must not crash str.replace.""" - ps = _make_custom_prompt_service({ - "decompose": "Question: {question} — decompose it" - }) - llm = _make_llm('["a"]') + from app.services.query_decomposer import QueryDecomposer + + ps = _create_prompt_service(tmp_path) + llm = _MockLLM('["a"]') d = QueryDecomposer(llm, prompt_service=ps) - # This question has literal braces — must not raise KeyError result, returned_prompt = await d.decompose("What about {key: value}?") assert isinstance(result, list) - sent_prompt = llm.complete.call_args[0][0] + sent_prompt = llm.last_prompt assert "{key: value}" in sent_prompt assert returned_prompt == sent_prompt -async def test_unknown_placeholder_left_untouched(): +async def test_unknown_placeholder_left_untouched(tmp_path): """Placeholders not matched by str.replace stay as-is in the prompt.""" - ps = _make_custom_prompt_service({ - "decompose": "HELLO {fake_var} and {question}" - }) - llm = _make_llm('["a"]') + from app.services.query_decomposer import QueryDecomposer + + ps = _create_prompt_service( + tmp_path, {"decompose": "HELLO {fake_var} and {question}"} + ) + llm = _MockLLM('["a"]') d = QueryDecomposer(llm, prompt_service=ps) questions, returned_prompt = await d.decompose("Q?") - sent_prompt = llm.complete.call_args[0][0] + sent_prompt = llm.last_prompt assert "{fake_var}" in sent_prompt assert "Q?" in sent_prompt -async def test_empty_template_produces_empty_prompt(): - """An empty template string results in an empty (or question-only) prompt.""" - ps = _make_custom_prompt_service({"decompose": ""}) - llm = _make_llm('["a"]') +async def test_empty_template_produces_empty_prompt(tmp_path): + """An empty template string results in an empty prompt.""" + from app.services.query_decomposer import QueryDecomposer + + ps = _create_prompt_service(tmp_path, {"decompose": ""}) + llm = _MockLLM('["a"]') d = QueryDecomposer(llm, prompt_service=ps) questions, returned_prompt = await d.decompose("Doesn't matter") - sent_prompt = llm.complete.call_args[0][0] + sent_prompt = llm.last_prompt # Empty template with .replace("{question}", ...) still has no text assert sent_prompt == "" @@ -189,73 +249,67 @@ async def test_empty_template_produces_empty_prompt(): # ── Edge case: no question / no chunks ────────────────────────────────── -async def test_decomposer_no_question_returns_empty(): - """Empty question returns [] without calling prompt_service.""" - ps = MagicMock() - ps.get_prompt_template = MagicMock(return_value="tmpl") +async def test_decomposer_no_question_returns_empty(tmp_path): + """Empty question returns [] without calling LLM.""" + from app.services.query_decomposer import QueryDecomposer - llm = _make_llm('["should_not_see"]') + ps = _create_prompt_service(tmp_path) + llm = _MockLLM('["should_not_see"]') d = QueryDecomposer(llm, prompt_service=ps) result, returned_prompt = await d.decompose("") assert result == [] assert returned_prompt == "" - llm.complete.assert_not_called() - ps.get_prompt_template.assert_not_called() + llm.assert_not_called() -async def test_filter_empty_chunks_no_template_fetch(): - """Empty chunks list returns [] without fetching a template.""" - ps = MagicMock() - ps.get_prompt_template = MagicMock(return_value="tmpl") +async def test_filter_empty_chunks_no_template_fetch(tmp_path): + """Empty chunks list returns [] without calling LLM.""" + from app.services.relevance_filter import RelevanceFilter - llm = _make_llm("[5]") + ps = _create_prompt_service(tmp_path) + llm = _MockLLM("[5]") rf = RelevanceFilter(llm, prompt_service=ps) result, returned_prompt = await rf.filter("Q?", [], threshold=5.0) assert result == [] assert returned_prompt == "" - llm.complete.assert_not_called() - ps.get_prompt_template.assert_not_called() + llm.assert_not_called() -async def test_generate_no_chunks_returns_fallback(): - """No chunks returns fallback message without touching PromptService.""" - ps = MagicMock() - ps.get_prompt_template = MagicMock(return_value="tmpl") +async def test_generate_no_chunks_returns_fallback(tmp_path): + """No chunks returns fallback message without calling LLM.""" + from app.services.rag import RAGService - llm = _make_llm("answer") - mock_collection = MagicMock() - mock_client = MagicMock() - mock_client.get_or_create_collection.return_value = mock_collection + ps = _create_prompt_service(tmp_path) + llm = _MockLLM("answer") + client = _setup_chroma(tmp_path) - svc = RAGService(chroma_client=mock_client, llm_client=llm, prompt_service=ps) + svc = RAGService(chroma_client=client, llm_client=llm, prompt_service=ps) answer, gen_prompt = await svc.generate_response("Q?", [], []) assert "could not find" in answer.lower() assert gen_prompt == "" - llm.complete.assert_not_called() - ps.get_prompt_template.assert_not_called() + llm.assert_not_called() -async def test_generate_per_subq_fetches_template_from_prompt_service(): +async def test_generate_per_subq_fetches_template_from_prompt_service(tmp_path): """RAGService.generate_response_per_subquestion should use PromptService template.""" + from app.services.rag import RAGService + custom_template = "PER_SUBQ: {context_sections} DONE" - ps = _make_custom_prompt_service({"generate_per_subq": custom_template}) - llm = _make_llm("answer") + ps = _create_prompt_service(tmp_path, {"generate_per_subq": custom_template}) + llm = _MockLLM("answer") + client = _setup_chroma(tmp_path) - mock_collection = MagicMock() - mock_client = MagicMock() - mock_client.get_or_create_collection.return_value = mock_collection - - svc = RAGService(chroma_client=mock_client, llm_client=llm, prompt_service=ps) + svc = RAGService(chroma_client=client, llm_client=llm, prompt_service=ps) answer, gen_prompt, grouped_sources = await svc.generate_response_per_subquestion( ["What is X?"], [["chunk data"]], [[{"filename": "f.txt", "content_summary": "sum"}]], ) - sent_prompt = llm.complete.call_args[1]["prompt"] + sent_prompt = llm.last_prompt assert sent_prompt.startswith("PER_SUBQ:") assert "chunk data" in sent_prompt assert sent_prompt.endswith("DONE") @@ -263,22 +317,21 @@ async def test_generate_per_subq_fetches_template_from_prompt_service(): assert len(grouped_sources) == 1 -async def test_generate_per_subq_uses_builtin_when_no_prompt_service(): +async def test_generate_per_subq_uses_builtin_when_no_prompt_service(tmp_path): """Without prompt_service, the built-in per-subq template is used.""" - llm = _make_llm("answer") + from app.services.rag import RAGService - mock_collection = MagicMock() - mock_client = MagicMock() - mock_client.get_or_create_collection.return_value = mock_collection + llm = _MockLLM("answer") + client = _setup_chroma(tmp_path) - svc = RAGService(chroma_client=mock_client, llm_client=llm, prompt_service=None) + svc = RAGService(chroma_client=client, llm_client=llm, prompt_service=None) answer, gen_prompt, grouped_sources = await svc.generate_response_per_subquestion( ["What is X?"], [["chunk data"]], [[{"filename": "f.txt", "content_summary": "sum"}]], ) - sent_prompt = llm.complete.call_args[1]["prompt"] + sent_prompt = llm.last_prompt assert "Sub-question" in sent_prompt assert "chunk data" in sent_prompt assert "{context_sections}" not in sent_prompt diff --git a/backend/app/test/test_phase3_query_history_integration.py b/backend/app/test/test_phase3_query_history_integration.py index 5a36252..b4774b4 100644 --- a/backend/app/test/test_phase3_query_history_integration.py +++ b/backend/app/test/test_phase3_query_history_integration.py @@ -1,8 +1,10 @@ """Tests for Phase 3.5: Query history integration (end-to-end pipeline). -Verifies that the query pipeline in ``_query_stream()`` captures timing data, -actual LLM prompts, chunk XML, and records them to a history service after the -SSE stream completes. +Verifies that the query pipeline via POST /api/v1/query captures timing data, +actual LLM prompts, chunk XML, and records them to a real history service +after the SSE stream completes. + +Uses real ChromaDB and SQLite (tmp_path). Only the LLM (external API) is mocked. Key behaviours under test: - Full query → history record created with correct fields @@ -14,384 +16,198 @@ Key behaviours under test: chunk counts - Query completes successfully even if history recording fails (fire-and-forget) - No history record created when the query pipeline errors out early - -All external services (LLM, ChromaDB, history_service) are mocked. -The tests call ``_query_stream()`` directly — no HTTP layer involved. - -NOTE: This test targets the post-3.5 API where each service method returns a -``(result, prompt)`` tuple. The module patches the real service classes so -that the tests remain valid even before the implementation lands. """ from __future__ import annotations -import asyncio import json -import re +import os +import sqlite3 import time -from typing import Any, Dict, List, Tuple -from unittest.mock import AsyncMock, MagicMock, patch +import chromadb import pytest +from fastapi.testclient import TestClient -from app.models.query import QueryRequest +from app.core.config import Settings +from app.core.sqlite_db import init_history_db, init_prompts_db, seed_default_profiles +from app.services.history_service import HistoryService -# ── Shared fixtures & helpers ──────────────────────────────────────────── +# ── Test seed data ────────────────────────────────────────────────────── -# Sample chunks that ChromaDB would return from ``RAGService.retrieve()``. -# Each element is ``(text, metadata_dict, distance)``. -SAMPLE_CHUNKS = [ - ( - "Clause 61.3 states that time extensions must be notified within 8 weeks.", - {"filename": "NEC4 ACC.pdf", "page_number": 3, "content_summary": "Time extension provisions", "chunk_index": 0}, - 0.15, - ), - ( - "Notice must be given to the project manager before expiry of the period.", - {"filename": "NEC4 Contract.pdf", "page_number": 12, "content_summary": "Notification requirements", "chunk_index": 0}, - 0.22, - ), - ( - "The contractor may be entitled to additional time under clause X2.", - {"filename": "NEC4 ACC.pdf", "page_number": 7, "content_summary": "Additional time entitlements", "chunk_index": 1}, - 0.31, - ), +SEED_DOCS = [ + { + "text": "Time extensions must be notified within 8 weeks.", + "metadata": { + "filename": "NEC4.pdf", + "page_number": 3, + "content_summary": "Time extensions", + "chunk_index": 0, + "upload_date": "2024-01-01", + }, + }, + { + "text": "Notice must be given to the project manager before expiry of the period.", + "metadata": { + "filename": "NEC4.pdf", + "page_number": 12, + "content_summary": "Notification", + "chunk_index": 1, + "upload_date": "2024-01-01", + }, + }, ] -# Metadata after filtering — same structure but ``RelevanceFilter`` will embed -# ``relevance_score`` into the metadata dict. -SAMPLE_FILTERED = [ - ( - "Clause 61.3 states that time extensions must be notified within 8 weeks.", - {"filename": "NEC4 ACC.pdf", "page_number": 3, "content_summary": "Time extension provisions", "chunk_index": 0, "relevance_score": 8.5}, - ), - ( - "Notice must be given to the project manager before expiry of the period.", - {"filename": "NEC4 Contract.pdf", "page_number": 12, "content_summary": "Notification requirements", "chunk_index": 0, "relevance_score": 9.0}, - ), -] +# ── Helpers ──────────────────────────────────────────────────────────── -def _make_mock_settings(): - """Build a lightweight mock Settings object.""" - settings = MagicMock() - settings.retrieval_n_results = 10 - settings.relevance_threshold = 7.0 - settings.prompts_db_path = ":memory:" - return settings +class _ConstantEmbedding: + """Identical vectors for all inputs — all docs equally match any query.""" + + DIM = 10 + + def __call__(self, input): + return [[0.1] * self.DIM for _ in input] + + def embed_query(self, input): + return self(input) + + def name(self): + return "constant_test" -def _make_mock_prompt_service(): - """Build a mock PromptService with default templates.""" - ps = MagicMock() - ps.get_active_profile_name.return_value = "A" - ps.get_prompt_template = MagicMock( - side_effect=lambda step: { - "decompose": "Given this question: '{question}'\n\nBreak it down into 2-5 simplified sub-questions.", - "filter": "Given question '{question}' and these document chunks, rate each 0-10 for relevance.\n{chunks}\n", - "generate": "Question: {question}\n\nAnswer using ONLY these document chunks.\n\nDocument chunks:\n{context}\n\nAnswer:", - }.get(step, "") - ) - return ps +def _make_mock_llm_class(responses): + """Create a mock LLMClient class that returns *responses* in sequence. - -def _make_mock_llm(): - """Build a mock LLM client whose ``complete()`` returns controlled responses. - - Call sequence: - 1st call → decompose response (JSON array of sub-questions) - 2nd call → filter response (JSON array of relevance scores) - 3rd call → generate response (bullet-point answer) + The returned class has the same constructor signature as LLMClient(settings) + so it can replace LLMClient via monkeypatch. """ - llm = MagicMock() - llm.complete = AsyncMock( - side_effect=[ - '["What are the time extension provisions?", "What notice is required for time extensions?"]', - "[8.5, 9.0, 3.2]", - "• Time extensions must be notified within 8 weeks [NEC4 ACC.pdf, page 3]\n• Notice must be given to the project manager [NEC4 Contract.pdf, page 12]", - ] + + class _MockLLM: + def __init__(self, settings): + self.settings = settings + self._responses = list(responses) + self._idx = 0 + + async def complete(self, prompt, temperature=0.7, step_name="LLM"): + if self._idx < len(self._responses): + resp = self._responses[self._idx] + self._idx += 1 + return resp + raise RuntimeError(f"No more mock responses (call #{self._idx + 1})") + + async def close(self): + pass + + return _MockLLM + + +# Standard mock responses for a successful 2-sub-question pipeline +_STANDARD_RESPONSES = [ + '["What are time extensions?", "What notice is required?"]', + '{"0": [8.5, 9.0], "1": [8.5, 9.0]}', + ( + "## Sub-question 1: What are time extensions?\n" + "- Extensions need 8 weeks notice [NEC4.pdf, page 3]\n\n" + "## Sub-question 2: What notice is required?\n" + "- Notify the project manager [NEC4.pdf, page 12]\n" + ), +] + + +def _setup_env(tmp_path, monkeypatch, seed_docs=None, init_history=True): + """Set up real ChromaDB + SQLite via tmp_path for pipeline tests. + + Seeds ChromaDB with *seed_docs*, initializes prompts/history SQLite, + and monkeypatches get_settings + embedding function. + + Returns dict with settings, db paths, and seed docs. + """ + seed_docs = seed_docs or SEED_DOCS + chroma_dir = tmp_path / "chroma_db" + chroma_dir.mkdir() + prompts_db = str(tmp_path / "prompts.db") + history_db = str(tmp_path / "history.db") + + test_settings = Settings( + chroma_db_path=str(chroma_dir), + prompts_db_path=prompts_db, + history_db_path=history_db, + llm_base_url="http://mock-llm", + llm_api_key="test-key", + llm_model_name="test-model", + embedding_base_url="http://mock-emb", + embedding_api_key="test-key", + embedding_model="test-model", ) - return llm + conn = sqlite3.connect(prompts_db) + conn.row_factory = sqlite3.Row + init_prompts_db(conn) + seed_default_profiles(conn) + conn.close() -def _make_mock_chroma_collection(chunks): - """Build a mock ChromaDB collection that returns *chunks* from ``query()``.""" - collection = MagicMock() - docs = [c[0] for c in chunks] - metas = [c[1] for c in chunks] - dists = [c[2] for c in chunks] - collection.query.return_value = { - "documents": [docs], - "metadatas": [metas], - "distances": [dists], + if init_history: + conn = sqlite3.connect(history_db) + conn.row_factory = sqlite3.Row + init_history_db(conn) + conn.close() + + # Seed ChromaDB with constant embedding + embedding_fn = _ConstantEmbedding() + client = chromadb.PersistentClient(path=str(chroma_dir)) + collection = client.get_or_create_collection( + "documents", embedding_function=embedding_fn + ) + for i, doc in enumerate(seed_docs): + collection.add( + documents=[doc["text"]], + metadatas=[doc["metadata"]], + ids=[f"test_doc_{i}"], + ) + + # Monkeypatch settings + embedding so _query_stream uses test paths + monkeypatch.setattr("app.routers.query.get_settings", lambda: test_settings) + monkeypatch.setattr("app.core.database.get_settings", lambda: test_settings) + monkeypatch.setattr( + "app.core.database.get_embedding_function_settings", + lambda s: embedding_fn, + ) + + return { + "settings": test_settings, + "history_db": history_db, + "prompts_db": prompts_db, + "chroma_dir": str(chroma_dir), + "seed_docs": seed_docs, } - return collection -def _make_mock_chroma_client(collection): - """Build a mock ChromaDB client.""" - client = MagicMock() - client.get_or_create_collection.return_value = collection - return client +def _collect_sse(client, question): + """POST to /api/v1/query and collect SSE events as parsed dicts.""" + events = [] + with client.stream( + "POST", "/api/v1/query", json={"question": question} + ) as response: + assert response.status_code == 200 + for line in response.iter_lines(): + if line.startswith("data: "): + events.append(json.loads(line[6:])) + return events -def _make_mock_history_service(): - """Build a mock ``HistoryService`` with an async ``record()`` method.""" - svc = MagicMock() - svc.record = AsyncMock() - return svc - - -# ── XML formatting helpers (mirror the implementation spec) ────────────── - - -def format_chunks_retrieved_xml(chunks): - """Format retrieved chunks as XML-tagged string. - - Parameters match ``RAGService.retrieve()`` output: - ``[(text, metadata, distance), ...]`` - """ - parts = [] - for i, (text, meta, _dist) in enumerate(chunks, start=1): - lines = [f""] - lines.append(f"Filename: {meta.get('filename', 'unknown')}") - page = meta.get("page_number") - if page is not None: - lines.append(f"Page: {page}") - lines.append(f"Content: {text}") - lines.append(f"") - parts.append("\n".join(lines)) - return "\n".join(parts) - - -def format_chunks_filtered_xml(filtered): - """Format filtered chunks as XML with relevance scores. - - Parameters: ``[(text, metadata), ...]`` where - ``metadata["relevance_score"]`` holds the score. - """ - parts = [] - for i, (text, meta) in enumerate(filtered, start=1): - lines = [f""] - lines.append(f"Filename: {meta.get('filename', 'unknown')}") - page = meta.get("page_number") - if page is not None: - lines.append(f"Page: {page}") - score = meta.get("relevance_score") - if score is not None: - lines.append(f"Relevance: {score}") - lines.append(f"Content: {text}") - lines.append(f"") - parts.append("\n".join(lines)) - return "\n".join(parts) - - -# ── Pipeline simulation helper ────────────────────────────────────────── - - -async def _run_pipeline_and_collect_history( - question: str = "What is the NEC4 clause about time extensions?", - llm=None, - chunks=None, - filtered=None, - prompt_service=None, - settings=None, - history_service=None, - *, - # Toggle: simulate the post-3.5 return-signature (result, prompt) tuples - use_tuple_returns: bool = True, - # Toggle: inject failures - llm_error_on_call: int | None = None, -): - """Simulate ``_query_stream`` logic and return the history record kwargs. - - This function reproduces the pipeline flow that ``_query_stream()`` will - implement after sub-phase 3.5, including timing capture and prompt capture - from service return values. It returns ``(sse_events, history_kwargs)`` - where *history_kwargs* is the dict that would be passed to - ``history_service.record()``. - """ - if llm is None: - llm = _make_mock_llm() - if chunks is None: - chunks = SAMPLE_CHUNKS - if filtered is None: - filtered = SAMPLE_FILTERED - if prompt_service is None: - prompt_service = _make_mock_prompt_service() - if settings is None: - settings = _make_mock_settings() - if history_service is None: - history_service = _make_mock_history_service() - - from app.services.query_decomposer import QueryDecomposer - from app.services.relevance_filter import RelevanceFilter - from app.services.rag import RAGService - - sse_events: list[dict] = [] - history_kwargs: dict | None = None - error_occurred = False - - overall_start = time.perf_counter() - active_profile = prompt_service.get_active_profile_name() - - try: - # Stage 1: Decompose - decomposer = QueryDecomposer(llm, prompt_service=prompt_service) - stage_start = time.perf_counter() - - if llm_error_on_call == 1: - raise RuntimeError("LLM decompose error") - - decompose_result = await decomposer.decompose(question) - if use_tuple_returns and isinstance(decompose_result, tuple): - questions: List[str] = decompose_result[0] - decompose_prompt: str = decompose_result[1] - else: - questions = decompose_result if isinstance(decompose_result, list) else [] - decompose_prompt = "" - - decomposer_time_ms = int((time.perf_counter() - stage_start) * 1000) - sse_events.append({"phase": "decomposed", "extracted_questions": questions}) - - # Stage 2: Retrieve (mocked) - mock_collection = _make_mock_chroma_collection(chunks) - mock_client = _make_mock_chroma_client(mock_collection) - rag = RAGService(chroma_client=mock_client, llm_client=llm, settings=settings, prompt_service=prompt_service) - - stage_start = time.perf_counter() - retrieved_chunks: List[Tuple[str, Dict[str, Any], float]] = rag.retrieve( - questions, n_results=settings.retrieval_n_results - ) - retriever_time_ms = int((time.perf_counter() - stage_start) * 1000) - chunks_retrieved_count = len(retrieved_chunks) - chunks_retrieved_xml = format_chunks_retrieved_xml(retrieved_chunks) - - sse_events.append({"phase": "retrieving"}) - - if not retrieved_chunks: - sse_events.append({"phase": "completed", "answer": "I could not find any relevant information.", "sources": []}) - return sse_events, None - - # Stage 3: Filter - chunks_for_filter: List[Tuple[str, Dict[str, Any]]] = [ - (text, meta) for text, meta, _dist in retrieved_chunks - ] - relevance_filter = RelevanceFilter(llm, prompt_service=prompt_service) - - stage_start = time.perf_counter() - - if llm_error_on_call == 2: - raise RuntimeError("LLM filter error") - - filter_result = await relevance_filter.filter( - question, chunks_for_filter, threshold=settings.relevance_threshold - ) - if use_tuple_returns and isinstance(filter_result, tuple): - filtered_chunks = list(filter_result[0]) # type: ignore[arg-type] - filter_prompt: str = str(filter_result[1]) - else: - filtered_chunks = list(filter_result) if isinstance(filter_result, list) else [] # type: ignore[arg-type] - filter_prompt = "" - - # Embed relevance scores into metadata for XML formatting (per plan decision #17) - if use_tuple_returns and filtered_chunks: - scored_filtered: list = [] - for item in filtered_chunks: - chunk_text_item, meta_item = item # type: ignore[misc] - if "relevance_score" not in meta_item: # type: ignore[operator] - meta_copy: Dict[str, Any] = dict(meta_item) # type: ignore[arg-type] - meta_copy["relevance_score"] = 8.5 - scored_filtered.append((chunk_text_item, meta_copy)) - else: - scored_filtered.append((chunk_text_item, meta_item)) - filtered_chunks = scored_filtered - - filter_time_ms = int((time.perf_counter() - stage_start) * 1000) - chunks_filtered_count = len(filtered_chunks) - chunks_filtered_xml = format_chunks_filtered_xml(filtered_chunks) if filtered_chunks else "" - - sse_events.append({"phase": "filtering"}) - - if not filtered_chunks: - sse_events.append({"phase": "completed", "answer": "I could not find any relevant information.", "sources": []}) - return sse_events, None - - # Stage 4: Generate - chunk_texts: list = [chunk for chunk, _meta in filtered_chunks] # type: ignore[misc] - chunk_metadata: list = [meta for _chunk, meta in filtered_chunks] # type: ignore[misc] - - stage_start = time.perf_counter() - - if llm_error_on_call == 3: - raise RuntimeError("LLM generate error") - - gen_result = await rag.generate_response(question, chunk_texts, chunk_metadata) - if use_tuple_returns and isinstance(gen_result, tuple): - answer: str = gen_result[0] - generate_prompt: str = gen_result[1] - else: - answer = gen_result if isinstance(gen_result, str) else "" - generate_prompt = "" - - generator_time_ms = int((time.perf_counter() - stage_start) * 1000) - - total_time_ms = int((time.perf_counter() - overall_start) * 1000) - - # Build sources - from app.models.common import SourceMetadata - sources = [ - SourceMetadata( - filename=meta.get("filename", "unknown"), - upload_date=meta.get("upload_date", ""), - content_summary=meta.get("content_summary", ""), - chunk_index=meta.get("chunk_index", 0), - page_number=meta.get("page_number"), - chunk_file_path=meta.get("chunk_file_path"), - ) - for meta in chunk_metadata - ] - - sse_events.append({ - "phase": "completed", - "answer": answer, - "sources": [s.model_dump() for s in sources], - }) - - # Assemble history record kwargs - history_kwargs = { - "input_text": question, - "extracted_questions": json.dumps(questions), - "decompose_prompt": decompose_prompt, - "decomposer_time_ms": decomposer_time_ms, - "retriever_time_ms": retriever_time_ms, - "chunks_retrieved": chunks_retrieved_xml, - "chunks_retrieved_count": chunks_retrieved_count, - "filter_prompt": filter_prompt, - "filter_time_ms": filter_time_ms, - "chunks_filtered": chunks_filtered_xml, - "chunks_filtered_count": chunks_filtered_count, - "generate_prompt": generate_prompt, - "generator_time_ms": generator_time_ms, - "total_time_ms": total_time_ms, - "final_answer": answer, - "sources": json.dumps([s.model_dump() for s in sources]), - "profile_used": active_profile, - } - - # Fire-and-forget history recording - try: - await history_service.record(history_kwargs) - except Exception: - pass # best-effort - - except Exception as exc: - error_occurred = True - sse_events.append({"phase": "error", "message": f"Query failed: {exc}"}) - - return sse_events, history_kwargs +def _wait_for_history(history_db, timeout=2.0): + """Poll history DB until a record appears or timeout.""" + start = time.time() + while time.time() - start < timeout: + hs = HistoryService(db_path=history_db) + records = hs.list() + if records: + return hs.get(records[0]["id"]) + time.sleep(0.05) + return None # ═══════════════════════════════════════════════════════════════════════ @@ -399,13 +215,20 @@ async def _run_pipeline_and_collect_history( # ═══════════════════════════════════════════════════════════════════════ -async def test_query_pipeline_creates_history_record(): +def test_query_pipeline_creates_history_record(tmp_path, monkeypatch): """Simulate a full query and verify a history record is created with correct ``input_text``, ``extracted_questions``, positive timing values, and ``profile_used = "A"``. """ - history_svc = _make_mock_history_service() - events, rec = await _run_pipeline_and_collect_history(history_service=history_svc) + env = _setup_env(tmp_path, monkeypatch) + monkeypatch.setattr( + "app.routers.query.LLMClient", _make_mock_llm_class(_STANDARD_RESPONSES) + ) + + from app.main import app + + client = TestClient(app) + events = _collect_sse(client, "What is the NEC4 clause about time extensions?") # SSE stream should contain all phases phases = [e["phase"] for e in events] @@ -415,112 +238,130 @@ async def test_query_pipeline_creates_history_record(): assert "completed" in phases # History record must exist - assert rec is not None + record = _wait_for_history(env["history_db"]) + assert record is not None, "History record should be created" # Core fields - assert rec["input_text"] == "What is the NEC4 clause about time extensions?" - assert rec["profile_used"] == "A" + assert record["input_text"] == "What is the NEC4 clause about time extensions?" + assert record["profile_used"] == "A" # extracted_questions is a JSON array - questions = json.loads(rec["extracted_questions"]) + questions = json.loads(record["extracted_questions"]) assert isinstance(questions, list) assert len(questions) >= 1 - # All timing fields positive + # All timing fields non-negative for timing_key in ( - "decomposer_time_ms", "retriever_time_ms", - "filter_time_ms", "generator_time_ms", "total_time_ms", + "decomposer_time_ms", + "retriever_time_ms", + "filter_time_ms", + "generator_time_ms", + "total_time_ms", ): - assert rec[timing_key] >= 0, f"{timing_key} should be >= 0" - - # history_service.record was called once - history_svc.record.assert_awaited_once() + assert record[timing_key] >= 0, f"{timing_key} should be >= 0" -async def test_history_record_contains_prompts(): +def test_history_record_contains_prompts(tmp_path, monkeypatch): """Verify ``decompose_prompt``, ``filter_prompt``, and ``generate_prompt`` are stored as non-empty strings in the history record. """ - events, rec = await _run_pipeline_and_collect_history() + env = _setup_env(tmp_path, monkeypatch) + monkeypatch.setattr( + "app.routers.query.LLMClient", _make_mock_llm_class(_STANDARD_RESPONSES) + ) - assert rec is not None + from app.main import app - # After 3.5, services return prompts alongside results. When the mock - # services still return plain values (pre-3.5), prompts will be "". - # This test validates the post-3.5 contract: prompts must be non-empty. - # We check the contract — if the mock LLM was called, the prompt was sent. - from app.services.query_decomposer import QueryDecomposer - from app.services.relevance_filter import RelevanceFilter - from app.services.rag import RAGService + client = TestClient(app) + _collect_sse(client, "What about time extensions?") - # The prompts may be "" if tuple returns aren't wired yet. - # But the fields must exist in the record. - assert "decompose_prompt" in rec - assert "filter_prompt" in rec - assert "generate_prompt" in rec + record = _wait_for_history(env["history_db"]) + assert record is not None - # When tuple returns are active, prompts should be non-empty - # (the mock LLM.complete was called with actual prompt strings) - # We verify the mock LLM received calls — proving prompts were built. - # The actual prompt capture depends on the service returning tuples. - # For now, we verify the field exists and is a string. - assert isinstance(rec["decompose_prompt"], str) - assert isinstance(rec["filter_prompt"], str) - assert isinstance(rec["generate_prompt"], str) + # Real services render actual prompts — must be non-empty strings + for key in ("decompose_prompt", "filter_prompt", "generate_prompt"): + assert key in record, f"Missing {key} in history record" + assert isinstance(record[key], str), f"{key} must be a string" + assert record[key], f"{key} should be non-empty with real services" -async def test_history_record_contains_chunk_xml(): - """Verify ``chunks_retrieved`` XML contains ```` tags with - Filename, Page, and Content fields. +def test_history_record_contains_chunk_xml(tmp_path, monkeypatch): + """Verify ``chunks_retrieved`` XML contains ```` wrappers with + ```` tags including Filename, Page, and Content fields. """ - events, rec = await _run_pipeline_and_collect_history() + env = _setup_env(tmp_path, monkeypatch) + monkeypatch.setattr( + "app.routers.query.LLMClient", _make_mock_llm_class(_STANDARD_RESPONSES) + ) - assert rec is not None - xml = rec["chunks_retrieved"] + from app.main import app + + client = TestClient(app) + _collect_sse(client, "What about time extensions?") + + record = _wait_for_history(env["history_db"]) + assert record is not None + + xml = record["chunks_retrieved"] assert xml, "chunks_retrieved XML must not be empty" - # Must contain , , (3 retrieved chunks) - for i in range(1, len(SAMPLE_CHUNKS) + 1): - assert f"" in xml, f"Missing opening tag" - assert f"" in xml, f"Missing closing tag" + # Per-sub-question XML format + assert "" in xml - # Must contain Filename and Content fields - assert "Filename: NEC4 ACC.pdf" in xml - assert "Filename: NEC4 Contract.pdf" in xml + # Must contain chunk tags + assert " tags for filtered chunks - for i in range(1, len(SAMPLE_FILTERED) + 1): - assert f"" in xml, f"Missing in filtered XML" - - # Must contain Relevance scores + # Per-sub-question XML with relevance scores + assert "= 0, f"{field} must be >= 0, got {value}" # Total time should be >= sum of individual stages stage_sum = ( - rec["decomposer_time_ms"] - + rec["retriever_time_ms"] - + rec["filter_time_ms"] - + rec["generator_time_ms"] + record["decomposer_time_ms"] + + record["retriever_time_ms"] + + record["filter_time_ms"] + + record["generator_time_ms"] ) - assert rec["total_time_ms"] >= stage_sum, ( + assert record["total_time_ms"] >= stage_sum, ( "total_time_ms should be >= sum of individual stage times" ) -async def test_history_count_fields_are_ints(): +def test_history_count_fields_are_ints(tmp_path, monkeypatch): """Verify ``chunks_retrieved_count`` and ``chunks_filtered_count`` are integers matching actual chunk counts. + + With 2 seed docs and 2 sub-questions, each sub-q retrieves 2 chunks + via constant embedding → 4 total retrieved. Mock filter keeps all + (scores 8.5, 9.0 > threshold 7.0) → 4 total filtered. """ - events, rec = await _run_pipeline_and_collect_history() - - assert rec is not None - - retrieved_count = rec["chunks_retrieved_count"] - filtered_count = rec["chunks_filtered_count"] - - assert isinstance(retrieved_count, int), f"chunks_retrieved_count must be int, got {type(retrieved_count).__name__}" - assert isinstance(filtered_count, int), f"chunks_filtered_count must be int, got {type(filtered_count).__name__}" - - # Retrieved count should match the number of chunks returned by ChromaDB - assert retrieved_count == len(SAMPLE_CHUNKS), ( - f"Expected {len(SAMPLE_CHUNKS)} retrieved chunks, got {retrieved_count}" + env = _setup_env(tmp_path, monkeypatch) + monkeypatch.setattr( + "app.routers.query.LLMClient", _make_mock_llm_class(_STANDARD_RESPONSES) ) - # Filtered count should match the number of chunks that passed the filter - assert filtered_count == len(SAMPLE_FILTERED), ( - f"Expected {len(SAMPLE_FILTERED)} filtered chunks, got {filtered_count}" + from app.main import app + + client = TestClient(app) + _collect_sse(client, "What about time extensions?") + + record = _wait_for_history(env["history_db"]) + assert record is not None + + retrieved_count = record["chunks_retrieved_count"] + filtered_count = record["chunks_filtered_count"] + + assert isinstance(retrieved_count, int), ( + f"chunks_retrieved_count must be int, got {type(retrieved_count).__name__}" + ) + assert isinstance(filtered_count, int), ( + f"chunks_filtered_count must be int, got {type(filtered_count).__name__}" + ) + + # 2 sub-questions × 2 seed docs each = 4 retrieved + assert retrieved_count == 4, ( + f"Expected 4 retrieved chunks (2 sub-qs × 2 docs), got {retrieved_count}" + ) + # All pass filter (mock scores 8.5, 9.0 > threshold 7.0) + assert filtered_count == 4, ( + f"Expected 4 filtered chunks, got {filtered_count}" ) -async def test_history_fire_and_forget(): +def test_history_fire_and_forget(tmp_path, monkeypatch): """Verify query response returns successfully even if history recording fails. - The history service ``record()`` raises an exception — the pipeline must - still return a completed SSE event. + The history DB is not initialised (no tables), so record() raises. + The pipeline must still return a completed SSE event. """ - failing_history = _make_mock_history_service() - failing_history.record = AsyncMock(side_effect=RuntimeError("DB write failed")) + env = _setup_env(tmp_path, monkeypatch, init_history=False) + # Ensure the DB file is gone so record() creates a bare file with no table + if os.path.exists(env["history_db"]): + os.remove(env["history_db"]) - events, rec = await _run_pipeline_and_collect_history(history_service=failing_history) + monkeypatch.setattr( + "app.routers.query.LLMClient", _make_mock_llm_class(_STANDARD_RESPONSES) + ) + + from app.main import app + + client = TestClient(app) + events = _collect_sse(client, "What about time extensions?") # Pipeline must still produce a completed event phases = [e["phase"] for e in events] assert "completed" in phases, "Query should complete even if history fails" - # The history record was assembled (rec is not None) but - # record() was attempted and raised — that's fine (fire-and-forget). - # The mock propagates the error, but the real implementation swallows it. - failing_history.record.assert_awaited_once() +def test_history_not_created_on_error(tmp_path, monkeypatch): + """If the query fails (e.g. LLM generate error), no history record is created.""" + env = _setup_env(tmp_path, monkeypatch) -async def test_history_not_created_on_error(): - """If the query fails (e.g. LLM error), no history record is created.""" - # Simulate LLM failure on the first call (decompose stage) - events, rec = await _run_pipeline_and_collect_history( - llm_error_on_call=1, - ) + # Mock LLM: succeeds on decompose + filter, fails on generate + class _ErrorOnGenerateLLM: + def __init__(self, settings): + self.settings = settings + self._call_count = 0 + + async def complete(self, prompt, temperature=0.7, step_name="LLM"): + self._call_count += 1 + if self._call_count == 1: + return '["test question"]' + if self._call_count == 2: + return '{"0": [8.5, 9.0]}' + raise RuntimeError("LLM generate error") + + async def close(self): + pass + + monkeypatch.setattr("app.routers.query.LLMClient", _ErrorOnGenerateLLM) + + from app.main import app + + client = TestClient(app) + events = _collect_sse(client, "Some question?") # Should have an error event phases = [e["phase"] for e in events] assert "error" in phases, "Expected an error SSE event" # No history record - assert rec is None, "History record must not be created on pipeline error" + record = _wait_for_history(env["history_db"], timeout=0.5) + assert record is None, "History record must not be created on pipeline error" # ═══════════════════════════════════════════════════════════════════════ @@ -620,46 +503,75 @@ async def test_history_not_created_on_error(): class TestPerSubQPipelineHistory: """History recording for the per-sub-question pipeline.""" - async def test_per_subq_pipeline_records_history(self): + def test_per_subq_pipeline_records_history(self, tmp_path, monkeypatch): """Per-sub-q pipeline should record history with sub_question_sources.""" - history_svc = _make_mock_history_service() - events, rec = await _run_pipeline_and_collect_history( - history_service=history_svc, + env = _setup_env(tmp_path, monkeypatch) + monkeypatch.setattr( + "app.routers.query.LLMClient", + _make_mock_llm_class(_STANDARD_RESPONSES), ) - assert rec is not None - assert rec["input_text"] == "What is the NEC4 clause about time extensions?" - assert rec["profile_used"] == "A" + from app.main import app - questions = json.loads(rec["extracted_questions"]) + client = TestClient(app) + _collect_sse(client, "What about time extensions?") + + record = _wait_for_history(env["history_db"]) + assert record is not None + assert record["input_text"] == "What about time extensions?" + assert record["profile_used"] == "A" + + questions = json.loads(record["extracted_questions"]) assert isinstance(questions, list) assert len(questions) >= 1 for timing_key in ( - "decomposer_time_ms", "retriever_time_ms", - "filter_time_ms", "generator_time_ms", "total_time_ms", + "decomposer_time_ms", + "retriever_time_ms", + "filter_time_ms", + "generator_time_ms", + "total_time_ms", ): - assert rec[timing_key] >= 0, f"{timing_key} should be >= 0" + assert record[timing_key] >= 0, f"{timing_key} should be >= 0" - history_svc.record.assert_awaited_once() - - async def test_per_subq_history_contains_chunk_xml(self): + def test_per_subq_history_contains_chunk_xml(self, tmp_path, monkeypatch): """History should contain XML-tagged chunks_retrieved and chunks_filtered.""" - events, rec = await _run_pipeline_and_collect_history() + env = _setup_env(tmp_path, monkeypatch) + monkeypatch.setattr( + "app.routers.query.LLMClient", + _make_mock_llm_class(_STANDARD_RESPONSES), + ) - assert rec is not None - assert rec["chunks_retrieved"], "chunks_retrieved must not be empty" - assert rec["chunks_filtered"], "chunks_filtered must not be empty" + from app.main import app - assert " wrappers + assert " str: + self.calls.append({"prompt": prompt, "step": step_name}) + self.last_prompt = prompt + self._call_count += 1 + if self._side_effect: + raise self._side_effect + return self._response + + @property + def call_count(self) -> int: + return self._call_count + + def assert_called(self): + assert self._call_count > 0, "LLM.complete was not called" + + def assert_not_called(self): + assert self._call_count == 0, f"LLM.complete was called {self._call_count} time(s)" + + +def _setup_chroma(tmp_path): + """Create an isolated real ChromaDB PersistentClient.""" + chroma_dir = tmp_path / "chroma" + chroma_dir.mkdir(parents=True, exist_ok=True) + return chromadb.PersistentClient(path=str(chroma_dir)) # --------------------------------------------------------------------------- # Test: two sub-questions, LLM returns markdown with headers # --------------------------------------------------------------------------- @pytest.mark.asyncio -async def test_generate_per_subq_two_questions(): +async def test_generate_per_subq_two_questions(tmp_path): """Two sub-questions, 2 chunks for first, 1 for second. LLM returns markdown with ## Sub-question 1/2 headers. Assert answer contains both headers and grouped_sources has correct shape. """ - llm = MagicMock() - llm.complete = AsyncMock(return_value=( + from app.services.rag import RAGService + + llm = _MockLLM(response=( "## Sub-question 1: What is A?\n" "- Bullet point A1 [file_a.pdf, page 1]\n" "- Bullet point A2 [file_a.pdf, page 2]\n\n" "## Sub-question 2: What is B?\n" "- Bullet point B1 [file_b.pdf, page 1]\n" )) + client = _setup_chroma(tmp_path) - service = RAGService(llm_client=llm) + service = RAGService(chroma_client=client, llm_client=llm) answer, prompt, grouped_sources = await service.generate_response_per_subquestion( sub_questions=["What is A?", "What is B?"], sub_chunks=[ @@ -53,23 +93,26 @@ async def test_generate_per_subq_two_questions(): assert "## Sub-question 1: What is A?" in answer assert "## Sub-question 2: What is B?" in answer assert len(grouped_sources) == 2 - assert len(grouped_sources[0]) == 2 # 2 sources for sub-q 0 - assert len(grouped_sources[1]) == 1 # 1 source for sub-q 1 + assert len(grouped_sources[0]) == 2 + assert len(grouped_sources[1]) == 1 assert grouped_sources[0][0]["filename"] == "file_a.pdf" assert grouped_sources[1][0]["filename"] == "file_b.pdf" - llm.complete.assert_called_once() + llm.assert_called() + assert llm.call_count == 1 # --------------------------------------------------------------------------- # Test: empty input # --------------------------------------------------------------------------- @pytest.mark.asyncio -async def test_generate_per_subq_empty_input(): +async def test_generate_per_subq_empty_input(tmp_path): """Empty sub_questions returns fallback message and empty grouped_sources.""" - llm = MagicMock() - llm.complete = AsyncMock() + from app.services.rag import RAGService - service = RAGService(llm_client=llm) + llm = _MockLLM() + client = _setup_chroma(tmp_path) + + service = RAGService(chroma_client=client, llm_client=llm) answer, prompt, grouped_sources = await service.generate_response_per_subquestion( sub_questions=[], sub_chunks=[], @@ -78,19 +121,21 @@ async def test_generate_per_subq_empty_input(): assert answer == "I could not find any relevant information to answer your question." assert grouped_sources == [] - llm.complete.assert_not_called() + llm.assert_not_called() # --------------------------------------------------------------------------- # Test: sub-questions provided but all chunk lists empty # --------------------------------------------------------------------------- @pytest.mark.asyncio -async def test_generate_per_subq_no_chunks(): +async def test_generate_per_subq_no_chunks(tmp_path): """Sub-questions provided but all chunk lists empty → fallback message.""" - llm = MagicMock() - llm.complete = AsyncMock() + from app.services.rag import RAGService - service = RAGService(llm_client=llm) + llm = _MockLLM() + client = _setup_chroma(tmp_path) + + service = RAGService(chroma_client=client, llm_client=llm) answer, prompt, grouped_sources = await service.generate_response_per_subquestion( sub_questions=["What is A?", "What is B?"], sub_chunks=[[], []], @@ -99,33 +144,29 @@ async def test_generate_per_subq_no_chunks(): assert answer == "I could not find any relevant information to answer your question." assert grouped_sources == [] - llm.complete.assert_not_called() + llm.assert_not_called() # --------------------------------------------------------------------------- # Test: prompt contains context_sections placeholder # --------------------------------------------------------------------------- @pytest.mark.asyncio -async def test_generate_per_subq_prompt_contains_context_sections(): +async def test_generate_per_subq_prompt_contains_context_sections(tmp_path): """Verify the prompt sent to LLM contains ### Context for Sub-question 0: header and chunk content.""" - captured_prompt = None + from app.services.rag import RAGService - async def capture_complete(prompt, **kwargs): - nonlocal captured_prompt - captured_prompt = prompt - return "## Sub-question 1: What is A?\n- Answer" + llm = _MockLLM(response="## Sub-question 1: What is A?\n- Answer") + client = _setup_chroma(tmp_path) - llm = MagicMock() - llm.complete = AsyncMock(side_effect=capture_complete) - - service = RAGService(llm_client=llm) + service = RAGService(chroma_client=client, llm_client=llm) await service.generate_response_per_subquestion( sub_questions=["What is A?"], sub_chunks=[["chunk text here"]], sub_metadata=[[{"filename": "file_a.pdf", "page_number": 1, "content_summary": "Sum"}]], ) + captured_prompt = llm.last_prompt assert captured_prompt is not None assert "### Context for Sub-question 0:" in captured_prompt assert "chunk text here" in captured_prompt @@ -136,9 +177,13 @@ async def test_generate_per_subq_prompt_contains_context_sections(): # Test: LLM client not configured # --------------------------------------------------------------------------- @pytest.mark.asyncio -async def test_generate_per_subq_llm_not_configured(): +async def test_generate_per_subq_llm_not_configured(tmp_path): """llm_client=None → returns 'LLM client not configured' message.""" - service = RAGService(llm_client=None) + from app.services.rag import RAGService + + client = _setup_chroma(tmp_path) + + service = RAGService(chroma_client=client, llm_client=None) answer, prompt, grouped_sources = await service.generate_response_per_subquestion( sub_questions=["What is A?"], sub_chunks=[["some chunk"]], diff --git a/backend/app/test/test_phase4_integration_query_pipeline.py b/backend/app/test/test_phase4_integration_query_pipeline.py index 36b5f81..56445bc 100644 --- a/backend/app/test/test_phase4_integration_query_pipeline.py +++ b/backend/app/test/test_phase4_integration_query_pipeline.py @@ -1,96 +1,157 @@ """Phase 4 integration test: Full per-sub-question query pipeline. -Simulates the complete 4-stage pipeline (decompose → retrieve → filter → generate) -using mocked services, verifying end-to-end data flow and SSE event emission. +Uses TestClient hitting POST /api/v1/query with real ChromaDB and SQLite. +Only the LLM (external API) is mocked. + +Verifies end-to-end data flow and SSE event emission through the real +HTTP endpoint, real ChromaDB retrieval, and real SQLite services. Key behaviours under test: - Full pipeline with 2 sub-questions produces grouped results - Empty decomposition falls back to original question (Decision #13) - Single sub-question still uses ## Sub-question N format - All chunks filtered out returns "no relevant information" -- One sub-q with empty retrieval still produces partial answer - -All external services (LLM, ChromaDB) are mocked. -Tests call ``_query_stream()`` directly via ``async for`` — no HTTP layer. +- One sub-q with all chunks filtered out produces partial answer """ from __future__ import annotations import json -from typing import Any, Dict, List, Tuple -from unittest.mock import AsyncMock, MagicMock, patch +import sqlite3 +import chromadb import pytest +from fastapi.testclient import TestClient -from app.models.query import QueryRequest +from app.core.config import Settings +from app.core.sqlite_db import init_history_db, init_prompts_db, seed_default_profiles -# ── Shared fixtures ────────────────────────────────────────────────────── +# ── Test seed data ────────────────────────────────────────────────────── -CHUNK_A = ( - "Time extensions must be notified within 8 weeks.", - {"filename": "NEC4.pdf", "page_number": 3, "content_summary": "Time extensions", "chunk_index": 0, "upload_date": "2024-01-01"}, - 0.15, -) -CHUNK_B = ( - "Notice must be given to the project manager.", - {"filename": "NEC4.pdf", "page_number": 12, "content_summary": "Notification", "chunk_index": 1, "upload_date": "2024-01-01"}, - 0.22, -) +SEED_DOCS = [ + { + "text": "Time extensions must be notified within 8 weeks.", + "metadata": { + "filename": "NEC4.pdf", + "page_number": 3, + "content_summary": "Time extensions", + "chunk_index": 0, + "upload_date": "2024-01-01", + }, + }, + { + "text": "Notice must be given to the project manager before expiry of the period.", + "metadata": { + "filename": "NEC4.pdf", + "page_number": 12, + "content_summary": "Notification", + "chunk_index": 1, + "upload_date": "2024-01-01", + }, + }, +] -def _make_settings(): - s = MagicMock() - s.retrieval_n_results = 10 - s.relevance_threshold = 7.0 - s.prompts_db_path = ":memory:" - s.history_db_path = ":memory:" - return s +# ── Helpers ──────────────────────────────────────────────────────────── -def _make_prompt_service(): - ps = MagicMock() - ps.get_active_profile_name.return_value = "default" - ps.get_prompt_template = MagicMock( - side_effect=lambda step: { - "decompose": "Given question: '{question}' — decompose.", - "filter": "Rate chunks 0-10 for: {question}\n{chunks}", - "generate": "Answer: {question}\nContext:\n{context}", - "generate_per_subq": "Answer per sub-q:\n{context_sections}", - }.get(step, "") +class _ConstantEmbedding: + """Identical vectors for all inputs — all docs equally match any query.""" + + DIM = 10 + + def __call__(self, input): + return [[0.1] * self.DIM for _ in input] + + def embed_query(self, input): + return self(input) + + def name(self): + return "constant_test" + + +def _make_mock_llm_class(responses): + """Create a mock LLMClient class that returns *responses* in sequence.""" + + class _MockLLM: + def __init__(self, settings): + self.settings = settings + self._responses = list(responses) + self._idx = 0 + + async def complete(self, prompt, temperature=0.7, step_name="LLM"): + if self._idx < len(self._responses): + resp = self._responses[self._idx] + self._idx += 1 + return resp + raise RuntimeError(f"No more mock responses (call #{self._idx + 1})") + + async def close(self): + pass + + return _MockLLM + + +def _setup_env(tmp_path, monkeypatch, seed_docs=None): + """Set up real ChromaDB + SQLite via tmp_path for pipeline tests.""" + seed_docs = seed_docs or SEED_DOCS + chroma_dir = tmp_path / "chroma_db" + chroma_dir.mkdir() + prompts_db = str(tmp_path / "prompts.db") + history_db = str(tmp_path / "history.db") + + test_settings = Settings( + chroma_db_path=str(chroma_dir), + prompts_db_path=prompts_db, + history_db_path=history_db, + llm_base_url="http://mock-llm", + llm_api_key="test-key", + llm_model_name="test-model", + embedding_base_url="http://mock-emb", + embedding_api_key="test-key", + embedding_model="test-model", + ) + + conn = sqlite3.connect(prompts_db) + conn.row_factory = sqlite3.Row + init_prompts_db(conn) + seed_default_profiles(conn) + conn.close() + + conn = sqlite3.connect(history_db) + conn.row_factory = sqlite3.Row + init_history_db(conn) + conn.close() + + embedding_fn = _ConstantEmbedding() + client = chromadb.PersistentClient(path=str(chroma_dir)) + collection = client.get_or_create_collection( + "documents", embedding_function=embedding_fn + ) + for i, doc in enumerate(seed_docs): + collection.add( + documents=[doc["text"]], + metadatas=[doc["metadata"]], + ids=[f"test_doc_{i}"], + ) + + monkeypatch.setattr("app.routers.query.get_settings", lambda: test_settings) + monkeypatch.setattr("app.core.database.get_settings", lambda: test_settings) + monkeypatch.setattr( + "app.core.database.get_embedding_function_settings", + lambda s: embedding_fn, ) - return ps -def _make_llm(decompose_resp, filter_resp, generate_resp): - llm = MagicMock() - llm.complete = AsyncMock(side_effect=[decompose_resp, filter_resp, generate_resp]) - return llm - - -def _make_chroma(chunks: list): - """Return a mock collection that returns *chunks* from query().""" - col = MagicMock() - col.query.return_value = { - "documents": [[c[0] for c in chunks]], - "metadatas": [[c[1] for c in chunks]], - "distances": [[c[2] for c in chunks]], - } - return col - - -def _mock_chroma_client(collection): - client = MagicMock() - return client - - -async def _collect_sse(request: QueryRequest): - """Run _query_stream and return list of parsed SSE event dicts.""" - from app.routers.query import _query_stream +def _collect_sse(client, question): + """POST to /api/v1/query and collect SSE events as parsed dicts.""" events = [] - async for raw in _query_stream(request): - # raw is like "data: {...}\n\n" - for line in raw.split("\n"): + with client.stream( + "POST", "/api/v1/query", json={"question": question} + ) as response: + assert response.status_code == 200 + for line in response.iter_lines(): if line.startswith("data: "): events.append(json.loads(line[6:])) return events @@ -99,10 +160,13 @@ async def _collect_sse(request: QueryRequest): # ── Tests ──────────────────────────────────────────────────────────────── -async def test_full_pipeline_with_two_subquestions(): +def test_full_pipeline_with_two_subquestions(tmp_path, monkeypatch): """Two sub-questions flow through all 4 stages with per-sub-q grouping.""" + _setup_env(tmp_path, monkeypatch) + decompose_resp = '["What are time extensions?", "What notice is required?"]' - filter_resp = '{"0": [8.5, 3.2], "1": [9.0]}' + # 2 sub-qs × 2 chunks each; all above threshold 7.0 + filter_resp = '{"0": [8.5, 9.0], "1": [8.5, 9.0]}' generate_resp = ( "## Sub-question 1: What are time extensions?\n" "- Extensions need 8 weeks notice [NEC4.pdf, page 3]\n\n" @@ -110,63 +174,18 @@ async def test_full_pipeline_with_two_subquestions(): "- Notify the project manager [NEC4.pdf, page 12]\n" ) - llm = _make_llm(decompose_resp, filter_resp, generate_resp) - chroma = _make_chroma([CHUNK_A, CHUNK_B]) - settings = _make_settings() - ps = _make_prompt_service() + monkeypatch.setattr( + "app.routers.query.LLMClient", + _make_mock_llm_class([decompose_resp, filter_resp, generate_resp]), + ) - request = QueryRequest(question="What are the time extension rules?") + from app.main import app - with patch("app.routers.query.get_settings", return_value=settings), \ - patch("app.routers.query.PromptService", return_value=ps), \ - patch("app.routers.query.LLMClient", return_value=llm), \ - patch("app.routers.query.RAGService") as MockRAG, \ - patch("app.routers.query.QueryDecomposer") as MockDec, \ - patch("app.routers.query.RelevanceFilter") as MockFilter, \ - patch("app.routers.query.HistoryService") as MockHist, \ - patch("app.routers.query._schedule_history"): - - # Wire decomposer - dec = MockDec.return_value - dec.decompose = AsyncMock(return_value=( - ["What are time extensions?", "What notice is required?"], - "decompose-prompt-text" - )) - - # Wire RAG - rag = MockRAG.return_value - rag.retrieve_per_subquestion.return_value = [ - ("What are time extensions?", [CHUNK_A, CHUNK_B]), - ("What notice is required?", [CHUNK_A]), - ] - rag.generate_response_per_subquestion = AsyncMock(return_value=( - generate_resp, - "gen-prompt-text", - [ - [CHUNK_A[1], CHUNK_B[1]], # sources for sub-q 0 - [CHUNK_A[1]], # sources for sub-q 1 - ], - )) - - # Wire filter - filt = MockFilter.return_value - filt.filter_per_subquestion = AsyncMock(return_value=( - [ - ("What are time extensions?", [ - (CHUNK_A[0], {**CHUNK_A[1], "relevance_score": 8.5}), - ]), - ("What notice is required?", [ - (CHUNK_A[0], {**CHUNK_A[1], "relevance_score": 9.0}), - ]), - ], - "filter-prompt-text" - )) - - events = await _collect_sse(request) + client = TestClient(app) + events = _collect_sse(client, "What are the time extension rules?") phases = [e["phase"] for e in events] - # Should emit all expected phases assert "decomposed" in phases assert "retrieving" in phases assert "filtering" in phases @@ -174,11 +193,9 @@ async def test_full_pipeline_with_two_subquestions(): assert "generating_subquestion" in phases assert "completed" in phases - # Decomposed event has extracted questions dec_evt = next(e for e in events if e["phase"] == "decomposed") assert len(dec_evt["extracted_questions"]) == 2 - # Completed event has per-sub-q sources comp_evt = next(e for e in events if e["phase"] == "completed") assert "sub_question_sources" in comp_evt sq_sources = comp_evt["sub_question_sources"] @@ -186,56 +203,35 @@ async def test_full_pipeline_with_two_subquestions(): assert sq_sources[0]["sub_question_text"] == "What are time extensions?" assert sq_sources[1]["sub_question_text"] == "What notice is required?" - # Answer has sub-question headers assert "## Sub-question 1:" in comp_evt["answer"] assert "## Sub-question 2:" in comp_evt["answer"] - # generating_subquestion events gen_subq = [e for e in events if e["phase"] == "generating_subquestion"] assert len(gen_subq) == 2 assert gen_subq[0]["sub_question_index"] == 0 assert gen_subq[1]["sub_question_index"] == 1 -async def test_pipeline_with_empty_decomposition(): +def test_pipeline_with_empty_decomposition(tmp_path, monkeypatch): """Empty decomposition falls back to original question as single sub-q.""" - generate_resp = "## Sub-question 1: What is the time limit?\n- Answer here\n" + _setup_env(tmp_path, monkeypatch) - llm = MagicMock() - ps = _make_prompt_service() - settings = _make_settings() + decompose_resp = "[]" + # 1 fallback sub-q × 2 chunks + filter_resp = '{"0": [8.5, 9.0]}' + generate_resp = ( + "## Sub-question 1: What is the time limit?\n- Answer here\n" + ) - request = QueryRequest(question="What is the time limit?") + monkeypatch.setattr( + "app.routers.query.LLMClient", + _make_mock_llm_class([decompose_resp, filter_resp, generate_resp]), + ) - with patch("app.routers.query.get_settings", return_value=settings), \ - patch("app.routers.query.PromptService", return_value=ps), \ - patch("app.routers.query.LLMClient", return_value=llm), \ - patch("app.routers.query.RAGService") as MockRAG, \ - patch("app.routers.query.QueryDecomposer") as MockDec, \ - patch("app.routers.query.RelevanceFilter") as MockFilter, \ - patch("app.routers.query.HistoryService") as MockHist, \ - patch("app.routers.query._schedule_history"): + from app.main import app - dec = MockDec.return_value - dec.decompose = AsyncMock(return_value=([], "decompose-prompt")) - - rag = MockRAG.return_value - rag.retrieve_per_subquestion.return_value = [ - ("What is the time limit?", [CHUNK_A]), - ] - rag.generate_response_per_subquestion = AsyncMock(return_value=( - generate_resp, - "gen-prompt", - [[CHUNK_A[1]]], - )) - - filt = MockFilter.return_value - filt.filter_per_subquestion = AsyncMock(return_value=( - [("What is the time limit?", [(CHUNK_A[0], {**CHUNK_A[1], "relevance_score": 8.5})])], - "filter-prompt" - )) - - events = await _collect_sse(request) + client = TestClient(app) + events = _collect_sse(client, "What is the time limit?") phases = [e["phase"] for e in events] assert "decomposed" in phases @@ -246,103 +242,65 @@ async def test_pipeline_with_empty_decomposition(): comp_evt = next(e for e in events if e["phase"] == "completed") assert "## Sub-question 1:" in comp_evt["answer"] - rag.retrieve_per_subquestion.assert_called_once_with( - ["What is the time limit?"], n_results=10, - ) - -async def test_pipeline_single_subquestion(): +def test_pipeline_single_subquestion(tmp_path, monkeypatch): """Single sub-question still uses per-sub-q format with ## header.""" + _setup_env(tmp_path, monkeypatch) + + decompose_resp = '["What is X?"]' + filter_resp = '{"0": [8.5, 9.0]}' generate_resp = "## Sub-question 1: What is X?\n- Answer here\n" - llm = MagicMock() - ps = _make_prompt_service() - settings = _make_settings() + monkeypatch.setattr( + "app.routers.query.LLMClient", + _make_mock_llm_class([decompose_resp, filter_resp, generate_resp]), + ) - request = QueryRequest(question="What is X?") + from app.main import app - with patch("app.routers.query.get_settings", return_value=settings), \ - patch("app.routers.query.PromptService", return_value=ps), \ - patch("app.routers.query.LLMClient", return_value=llm), \ - patch("app.routers.query.RAGService") as MockRAG, \ - patch("app.routers.query.QueryDecomposer") as MockDec, \ - patch("app.routers.query.RelevanceFilter") as MockFilter, \ - patch("app.routers.query.HistoryService") as MockHist, \ - patch("app.routers.query._schedule_history"): - - dec = MockDec.return_value - dec.decompose = AsyncMock(return_value=( - ["What is X?"], - "decompose-prompt" - )) - - rag = MockRAG.return_value - rag.retrieve_per_subquestion.return_value = [ - ("What is X?", [CHUNK_A]), - ] - rag.generate_response_per_subquestion = AsyncMock(return_value=( - generate_resp, - "gen-prompt", - [[CHUNK_A[1]]], - )) - - filt = MockFilter.return_value - filt.filter_per_subquestion = AsyncMock(return_value=( - [("What is X?", [(CHUNK_A[0], {**CHUNK_A[1], "relevance_score": 8.5})])], - "filter-prompt" - )) - - events = await _collect_sse(request) + client = TestClient(app) + events = _collect_sse(client, "What is X?") comp_evt = next(e for e in events if e["phase"] == "completed") assert "## Sub-question 1:" in comp_evt["answer"] assert len(comp_evt["sub_question_sources"]) == 1 -async def test_pipeline_filter_all_rejected(): +def test_pipeline_filter_all_rejected(tmp_path, monkeypatch): """All chunks score below threshold — returns 'no relevant information'.""" - llm = MagicMock() - ps = _make_prompt_service() - settings = _make_settings() + _setup_env(tmp_path, monkeypatch) - request = QueryRequest(question="Irrelevant question?") + decompose_resp = '["sub-q-1"]' + # Both chunks score below threshold 7.0 + filter_resp = '{"0": [2.0, 3.0]}' - with patch("app.routers.query.get_settings", return_value=settings), \ - patch("app.routers.query.PromptService", return_value=ps), \ - patch("app.routers.query.LLMClient", return_value=llm), \ - patch("app.routers.query.RAGService") as MockRAG, \ - patch("app.routers.query.QueryDecomposer") as MockDec, \ - patch("app.routers.query.RelevanceFilter") as MockFilter, \ - patch("app.routers.query.HistoryService") as MockHist, \ - patch("app.routers.query._schedule_history"): + monkeypatch.setattr( + "app.routers.query.LLMClient", + _make_mock_llm_class([decompose_resp, filter_resp]), + ) - dec = MockDec.return_value - dec.decompose = AsyncMock(return_value=( - ["sub-q-1"], - "decompose-prompt" - )) + from app.main import app - rag = MockRAG.return_value - rag.retrieve_per_subquestion.return_value = [ - ("sub-q-1", [CHUNK_A]), - ] - - # All chunks filtered out - filt = MockFilter.return_value - filt.filter_per_subquestion = AsyncMock(return_value=( - [("sub-q-1", [])], - "filter-prompt" - )) - - events = await _collect_sse(request) + client = TestClient(app) + events = _collect_sse(client, "Irrelevant question?") comp_evt = next(e for e in events if e["phase"] == "completed") assert "could not find" in comp_evt["answer"].lower() assert comp_evt["sources"] == [] -async def test_pipeline_retrieval_empty_for_one_subq(): - """One sub-q gets chunks, another gets nothing — partial answer produced.""" +def test_pipeline_retrieval_empty_for_one_subq(tmp_path, monkeypatch): + """One sub-q's chunks all filtered out — partial answer produced. + + With real ChromaDB (constant embedding), both sub-queries retrieve all + seed documents. The filter mock rejects all chunks for sub-q 1 while + keeping sub-q 0's chunks, producing a partial answer. + """ + _setup_env(tmp_path, monkeypatch) + + decompose_resp = '["Has chunks?", "No chunks?"]' + # sub-q 0 keeps all (above threshold), sub-q 1 rejects all (below) + filter_resp = '{"0": [8.5, 9.0], "1": [2.0, 3.0]}' generate_resp = ( "## Sub-question 1: Has chunks?\n" "- Yes [NEC4.pdf, page 3]\n\n" @@ -350,56 +308,21 @@ async def test_pipeline_retrieval_empty_for_one_subq(): "- No relevant information found.\n" ) - llm = MagicMock() - ps = _make_prompt_service() - settings = _make_settings() + monkeypatch.setattr( + "app.routers.query.LLMClient", + _make_mock_llm_class([decompose_resp, filter_resp, generate_resp]), + ) - request = QueryRequest(question="Compare two things") + from app.main import app - with patch("app.routers.query.get_settings", return_value=settings), \ - patch("app.routers.query.PromptService", return_value=ps), \ - patch("app.routers.query.LLMClient", return_value=llm), \ - patch("app.routers.query.RAGService") as MockRAG, \ - patch("app.routers.query.QueryDecomposer") as MockDec, \ - patch("app.routers.query.RelevanceFilter") as MockFilter, \ - patch("app.routers.query.HistoryService") as MockHist, \ - patch("app.routers.query._schedule_history"): - - dec = MockDec.return_value - dec.decompose = AsyncMock(return_value=( - ["Has chunks?", "No chunks?"], - "decompose-prompt" - )) - - rag = MockRAG.return_value - # First sub-q has chunks, second has none - rag.retrieve_per_subquestion.return_value = [ - ("Has chunks?", [CHUNK_A]), - ("No chunks?", []), - ] - rag.generate_response_per_subquestion = AsyncMock(return_value=( - generate_resp, - "gen-prompt", - [[CHUNK_A[1]], []], - )) - - filt = MockFilter.return_value - filt.filter_per_subquestion = AsyncMock(return_value=( - [ - ("Has chunks?", [(CHUNK_A[0], {**CHUNK_A[1], "relevance_score": 8.5})]), - ("No chunks?", []), - ], - "filter-prompt" - )) - - events = await _collect_sse(request) + client = TestClient(app) + events = _collect_sse(client, "Compare two things") comp_evt = next(e for e in events if e["phase"] == "completed") assert "## Sub-question 1:" in comp_evt["answer"] assert "## Sub-question 2:" in comp_evt["answer"] - # sub_question_sources has 2 entries sq_sources = comp_evt["sub_question_sources"] assert len(sq_sources) == 2 - assert len(sq_sources[0]["sources"]) > 0 # first sub-q has sources - assert len(sq_sources[1]["sources"]) == 0 # second sub-q has no sources + assert len(sq_sources[0]["sources"]) > 0 + assert len(sq_sources[1]["sources"]) == 0 diff --git a/backend/app/test/test_phase4_relevance_filter_per_subq.py b/backend/app/test/test_phase4_relevance_filter_per_subq.py index fd637df..cc000a3 100644 --- a/backend/app/test/test_phase4_relevance_filter_per_subq.py +++ b/backend/app/test/test_phase4_relevance_filter_per_subq.py @@ -5,23 +5,72 @@ Covers per-sub-question chunk filtering in a single LLM call: - Empty inputs and edge cases - Invalid JSON / score-count mismatch error handling - Threshold boundary behaviour (strict >) + +Uses real PromptService (SQLite via tmp_path) and only mocks the external LLM API. """ import json -import pytest -from unittest.mock import AsyncMock, MagicMock +import sqlite3 -from app.services.relevance_filter import RelevanceFilter +import pytest + +from app.core.sqlite_db import init_prompts_db, seed_default_profiles +from app.services.prompt_service import PromptService + + +class _MockLLM: + """Mock external LLM API.""" + + def __init__(self, response: str = "[]", side_effect: Exception | None = None): + self._response = response + self._side_effect = side_effect + self.last_prompt: str | None = None + self.calls: list[dict] = [] + self._call_count: int = 0 + + async def complete( + self, prompt: str, temperature: float = 0.7, step_name: str = "LLM" + ) -> str: + self.calls.append({"prompt": prompt, "step": step_name}) + self.last_prompt = prompt + self._call_count += 1 + if self._side_effect: + raise self._side_effect + return self._response + + @property + def call_count(self) -> int: + return self._call_count + + def assert_called(self): + assert self._call_count > 0, "LLM.complete was not called" + + def assert_not_called(self): + assert self._call_count == 0, f"LLM.complete was called {self._call_count} time(s)" + + +def _create_prompt_service(tmp_path) -> PromptService: + """Create a real PromptService backed by real SQLite with seed data.""" + db_path = str(tmp_path / "prompts.db") + conn = sqlite3.connect(db_path) + conn.row_factory = sqlite3.Row + conn.execute("PRAGMA foreign_keys=ON") + init_prompts_db(conn) + seed_default_profiles(conn) + conn.close() + return PromptService(db_path=db_path) # --------------------------------------------------------------------------- # Test: basic per-sub-question filtering # --------------------------------------------------------------------------- -async def test_filter_per_subq_basic(mock_prompt_service): +async def test_filter_per_subq_basic(tmp_path): """Two sub-questions, LLM returns per-sub-q scores, threshold filters correctly.""" - llm = MagicMock() - llm.complete = AsyncMock(return_value='{"0": [8.5, 3.2], "1": [9.0]}') + from app.services.relevance_filter import RelevanceFilter - rf = RelevanceFilter(llm, prompt_service=mock_prompt_service) + llm = _MockLLM(response='{"0": [8.5, 3.2], "1": [9.0]}') + ps = _create_prompt_service(tmp_path) + + rf = RelevanceFilter(llm, prompt_service=ps) results, prompt = await rf.filter_per_subquestion( ["What is A?", "What is B?"], [ @@ -31,7 +80,6 @@ async def test_filter_per_subq_basic(mock_prompt_service): threshold=7.0, ) - # Structure check assert len(results) == 2 assert results[0][0] == "What is A?" assert results[1][0] == "What is B?" @@ -47,38 +95,42 @@ async def test_filter_per_subq_basic(mock_prompt_service): assert results[1][1][0][0] == "chunk B1" assert results[1][1][0][1]["relevance_score"] == 9.0 - # Prompt contains sub-question labels assert prompt != "" assert "Sub-question 0" in prompt assert "Sub-question 1" in prompt - llm.complete.assert_called_once() + llm.assert_called() + assert llm.call_count == 1 # --------------------------------------------------------------------------- # Test: empty input # --------------------------------------------------------------------------- -async def test_filter_per_subq_empty_input(mock_prompt_service): +async def test_filter_per_subq_empty_input(tmp_path): """Empty sub_questions list returns ([], '').""" - llm = MagicMock() - llm.complete = AsyncMock() + from app.services.relevance_filter import RelevanceFilter - rf = RelevanceFilter(llm, prompt_service=mock_prompt_service) + llm = _MockLLM() + ps = _create_prompt_service(tmp_path) + + rf = RelevanceFilter(llm, prompt_service=ps) results, prompt = await rf.filter_per_subquestion([], [], threshold=7.0) assert results == [] assert prompt == "" - llm.complete.assert_not_called() + llm.assert_not_called() # --------------------------------------------------------------------------- # Test: sub-questions with all-empty chunk lists # --------------------------------------------------------------------------- -async def test_filter_per_subq_all_empty_chunks(mock_prompt_service): +async def test_filter_per_subq_all_empty_chunks(tmp_path): """Two sub-questions, both with empty chunk lists → empty filtered lists.""" - llm = MagicMock() - llm.complete = AsyncMock() + from app.services.relevance_filter import RelevanceFilter - rf = RelevanceFilter(llm, prompt_service=mock_prompt_service) + llm = _MockLLM() + ps = _create_prompt_service(tmp_path) + + rf = RelevanceFilter(llm, prompt_service=ps) results, prompt = await rf.filter_per_subquestion( ["What is A?", "What is B?"], [[], []], @@ -90,19 +142,20 @@ async def test_filter_per_subq_all_empty_chunks(mock_prompt_service): assert results[0][1] == [] assert results[1][0] == "What is B?" assert results[1][1] == [] - # No LLM call needed when all chunk lists are empty - llm.complete.assert_not_called() + llm.assert_not_called() # --------------------------------------------------------------------------- # Test: LLM returns invalid JSON # --------------------------------------------------------------------------- -async def test_filter_per_subq_llm_returns_invalid_json(mock_prompt_service): +async def test_filter_per_subq_llm_returns_invalid_json(tmp_path): """LLM returns non-JSON string → returns ([], prompt).""" - llm = MagicMock() - llm.complete = AsyncMock(return_value="not json at all") + from app.services.relevance_filter import RelevanceFilter - rf = RelevanceFilter(llm, prompt_service=mock_prompt_service) + llm = _MockLLM(response="not json at all") + ps = _create_prompt_service(tmp_path) + + rf = RelevanceFilter(llm, prompt_service=ps) results, prompt = await rf.filter_per_subquestion( ["What is A?"], [[("chunk A1", {"filename": "a.pdf"})]], @@ -116,12 +169,14 @@ async def test_filter_per_subq_llm_returns_invalid_json(mock_prompt_service): # --------------------------------------------------------------------------- # Test: score count mismatch # --------------------------------------------------------------------------- -async def test_filter_per_subq_score_count_mismatch(mock_prompt_service): +async def test_filter_per_subq_score_count_mismatch(tmp_path): """Sub-q 0 has 2 chunks but LLM returns only 1 score → returns ([], prompt).""" - llm = MagicMock() - llm.complete = AsyncMock(return_value='{"0": [8.5]}') + from app.services.relevance_filter import RelevanceFilter - rf = RelevanceFilter(llm, prompt_service=mock_prompt_service) + llm = _MockLLM(response='{"0": [8.5]}') + ps = _create_prompt_service(tmp_path) + + rf = RelevanceFilter(llm, prompt_service=ps) results, prompt = await rf.filter_per_subquestion( ["What is A?"], [[("chunk A1", {"filename": "a.pdf"}), ("chunk A2", {"filename": "a2.pdf"})]], @@ -135,13 +190,14 @@ async def test_filter_per_subq_score_count_mismatch(mock_prompt_service): # --------------------------------------------------------------------------- # Test: strict threshold boundary # --------------------------------------------------------------------------- -async def test_filter_per_subq_passes_threshold_correctly(mock_prompt_service): +async def test_filter_per_subq_passes_threshold_correctly(tmp_path): """Score == threshold is NOT kept (strict >). Score > threshold IS kept.""" - llm = MagicMock() - # Sub-q 0: scores [7.0, 7.1] with threshold 7.0 → only 7.1 kept - llm.complete = AsyncMock(return_value='{"0": [7.0, 7.1]}') + from app.services.relevance_filter import RelevanceFilter - rf = RelevanceFilter(llm, prompt_service=mock_prompt_service) + llm = _MockLLM(response='{"0": [7.0, 7.1]}') + ps = _create_prompt_service(tmp_path) + + rf = RelevanceFilter(llm, prompt_service=ps) results, prompt = await rf.filter_per_subquestion( ["Boundary test?"], [[("exact threshold", {"filename": "f1.pdf"}), ("above threshold", {"filename": "f2.pdf"})]], @@ -157,12 +213,14 @@ async def test_filter_per_subq_passes_threshold_correctly(mock_prompt_service): # --------------------------------------------------------------------------- # Test: LLM exception # --------------------------------------------------------------------------- -async def test_filter_per_subq_llm_exception(mock_prompt_service): +async def test_filter_per_subq_llm_exception(tmp_path): """LLM call raises an exception → returns ([], '').""" - llm = MagicMock() - llm.complete = AsyncMock(side_effect=RuntimeError("LLM unavailable")) + from app.services.relevance_filter import RelevanceFilter - rf = RelevanceFilter(llm, prompt_service=mock_prompt_service) + llm = _MockLLM(side_effect=RuntimeError("LLM unavailable")) + ps = _create_prompt_service(tmp_path) + + rf = RelevanceFilter(llm, prompt_service=ps) results, prompt = await rf.filter_per_subquestion( ["What is A?"], [[("chunk A1", {"filename": "a.pdf"})]], @@ -176,12 +234,14 @@ async def test_filter_per_subq_llm_exception(mock_prompt_service): # --------------------------------------------------------------------------- # Test: JSON wrapped in markdown code block # --------------------------------------------------------------------------- -async def test_filter_per_subq_json_in_markdown_code_block(mock_prompt_service): +async def test_filter_per_subq_json_in_markdown_code_block(tmp_path): """LLM returns JSON inside ```json ... ``` block → should parse correctly.""" - llm = MagicMock() - llm.complete = AsyncMock(return_value='```json\n{"0": [9.0]}\n```') + from app.services.relevance_filter import RelevanceFilter - rf = RelevanceFilter(llm, prompt_service=mock_prompt_service) + llm = _MockLLM(response='```json\n{"0": [9.0]}\n```') + ps = _create_prompt_service(tmp_path) + + rf = RelevanceFilter(llm, prompt_service=ps) results, prompt = await rf.filter_per_subquestion( ["What is A?"], [[("chunk A1", {"filename": "a.pdf"})]], @@ -196,12 +256,14 @@ async def test_filter_per_subq_json_in_markdown_code_block(mock_prompt_service): # --------------------------------------------------------------------------- # Test: mixed empty and non-empty sub-questions # --------------------------------------------------------------------------- -async def test_filter_per_subq_mixed_empty_and_nonempty(mock_prompt_service): +async def test_filter_per_subq_mixed_empty_and_nonempty(tmp_path): """One sub-q with chunks, one without. Only non-empty ones get scored.""" - llm = MagicMock() - llm.complete = AsyncMock(return_value='{"0": [8.5]}') + from app.services.relevance_filter import RelevanceFilter - rf = RelevanceFilter(llm, prompt_service=mock_prompt_service) + llm = _MockLLM(response='{"0": [8.5]}') + ps = _create_prompt_service(tmp_path) + + rf = RelevanceFilter(llm, prompt_service=ps) results, prompt = await rf.filter_per_subquestion( ["What is A?", "What is B?"], [[("chunk A1", {"filename": "a.pdf"})], []], diff --git a/backend/app/test/test_phase4_response_format.py b/backend/app/test/test_phase4_response_format.py index 88cf7f1..a91bc70 100644 --- a/backend/app/test/test_phase4_response_format.py +++ b/backend/app/test/test_phase4_response_format.py @@ -5,28 +5,53 @@ Covers answer format invariants: - Citation bracket labels in answer text - grouped_sources match sub-question boundaries - Single sub-question still uses header format -""" -import pytest -from unittest.mock import AsyncMock, MagicMock -from app.services.rag import RAGService +Uses real ChromaDB (tmp_path) and only mocks the external LLM API. +""" +import chromadb +import pytest + + +class _MockLLM: + """Mock external LLM API.""" + + def __init__(self, response: str = "mock answer"): + self._response = response + self.last_prompt: str | None = None + self._call_count: int = 0 + + async def complete( + self, prompt: str, temperature: float = 0.7, step_name: str = "LLM" + ) -> str: + self.last_prompt = prompt + self._call_count += 1 + return self._response + + +def _setup_chroma(tmp_path): + """Create an isolated real ChromaDB PersistentClient.""" + chroma_dir = tmp_path / "chroma" + chroma_dir.mkdir(parents=True, exist_ok=True) + return chromadb.PersistentClient(path=str(chroma_dir)) # --------------------------------------------------------------------------- # Test: answer has sub-question headers # --------------------------------------------------------------------------- @pytest.mark.asyncio -async def test_answer_has_subquestion_headers(): +async def test_answer_has_subquestion_headers(tmp_path): """Answer string contains ## Sub-question N: headers.""" - llm = MagicMock() - llm.complete = AsyncMock(return_value=( + from app.services.rag import RAGService + + llm = _MockLLM(response=( "## Sub-question 1: First question?\n" "- Point one [doc.pdf, page 1]\n\n" "## Sub-question 2: Second question?\n" "- Point two [doc.pdf, page 2]\n" )) + client = _setup_chroma(tmp_path) - service = RAGService(llm_client=llm) + service = RAGService(chroma_client=client, llm_client=llm) answer, _prompt, _sources = await service.generate_response_per_subquestion( sub_questions=["First question?", "Second question?"], sub_chunks=[["chunk1"], ["chunk2"]], @@ -44,15 +69,17 @@ async def test_answer_has_subquestion_headers(): # Test: citations use bracket labels # --------------------------------------------------------------------------- @pytest.mark.asyncio -async def test_answer_citations_use_bracket_labels(): +async def test_answer_citations_use_bracket_labels(tmp_path): """Answer contains [filename, page N] citation format.""" - llm = MagicMock() - llm.complete = AsyncMock(return_value=( + from app.services.rag import RAGService + + llm = _MockLLM(response=( "## Sub-question 1: What is X?\n" "- X is defined as a variable [report.pdf, page 5]\n" )) + client = _setup_chroma(tmp_path) - service = RAGService(llm_client=llm) + service = RAGService(chroma_client=client, llm_client=llm) answer, _prompt, _sources = await service.generate_response_per_subquestion( sub_questions=["What is X?"], sub_chunks=[["chunk about X"]], @@ -66,14 +93,16 @@ async def test_answer_citations_use_bracket_labels(): # Test: grouped_sources match sub-question boundaries # --------------------------------------------------------------------------- @pytest.mark.asyncio -async def test_grouped_sources_match_subquestions(): +async def test_grouped_sources_match_subquestions(tmp_path): """Each sub-question's source list only contains metadata from its own chunks.""" - llm = MagicMock() - llm.complete = AsyncMock(return_value=( + from app.services.rag import RAGService + + llm = _MockLLM(response=( "## Sub-question 1: Q1?\n- A1\n\n## Sub-question 2: Q2?\n- A2\n" )) + client = _setup_chroma(tmp_path) - service = RAGService(llm_client=llm) + service = RAGService(chroma_client=client, llm_client=llm) _answer, _prompt, grouped_sources = await service.generate_response_per_subquestion( sub_questions=["Q1?", "Q2?"], sub_chunks=[ @@ -92,10 +121,8 @@ async def test_grouped_sources_match_subquestions(): ) assert len(grouped_sources) == 2 - # Sub-q 0 sources should only contain alpha and beta filenames_0 = {m["filename"] for m in grouped_sources[0]} assert filenames_0 == {"alpha.pdf", "beta.pdf"} - # Sub-q 1 sources should only contain gamma filenames_1 = {m["filename"] for m in grouped_sources[1]} assert filenames_1 == {"gamma.pdf"} @@ -104,15 +131,17 @@ async def test_grouped_sources_match_subquestions(): # Test: single sub-question still uses header format # --------------------------------------------------------------------------- @pytest.mark.asyncio -async def test_single_subquestion_format(): +async def test_single_subquestion_format(tmp_path): """When only one sub-question, answer still uses ## Sub-question 1: header.""" - llm = MagicMock() - llm.complete = AsyncMock(return_value=( + from app.services.rag import RAGService + + llm = _MockLLM(response=( "## Sub-question 1: What is this?\n" "- It is a test [test.pdf, page 1]\n" )) + client = _setup_chroma(tmp_path) - service = RAGService(llm_client=llm) + service = RAGService(chroma_client=client, llm_client=llm) answer, _prompt, grouped_sources = await service.generate_response_per_subquestion( sub_questions=["What is this?"], sub_chunks=[["test chunk"]], diff --git a/backend/app/test/test_phase4_retrieve_per_subquestion.py b/backend/app/test/test_phase4_retrieve_per_subquestion.py index 2676fef..3a5914a 100644 --- a/backend/app/test/test_phase4_retrieve_per_subquestion.py +++ b/backend/app/test/test_phase4_retrieve_per_subquestion.py @@ -7,126 +7,156 @@ Covers: - Verify retrieve() is called once per sub-question - n_results parameter passthrough - Handling of empty results for individual sub-questions + +All tests use real ChromaDB via tmp_path. No mocks for DB or internal services. """ import pytest -from unittest.mock import MagicMock +import chromadb from app.services.rag import RAGService +def _setup_chroma(tmp_path, monkeypatch, collection_name: str = "documents"): + from app.core.config import get_settings + + monkeypatch.setenv("CHROMA_DB_PATH", str(tmp_path / "test_chroma")) + get_settings.cache_clear() + + client = chromadb.PersistentClient(path=str(tmp_path / "test_chroma")) + collection = client.get_or_create_collection(name=collection_name) + return client, collection + + class TestRetrievePerSubquestion: """Tests for RAGService.retrieve_per_subquestion().""" - @staticmethod - def _make_service() -> RAGService: - """Create a RAGService with a mocked collection.""" - mock_collection = MagicMock() - mock_client = MagicMock() - mock_client.get_or_create_collection.return_value = mock_collection - service = RAGService(chroma_client=mock_client) - service._collection = mock_collection - return service - - def test_retrieve_per_subquestion_two_subqs(self): + def test_retrieve_per_subquestion_two_subqs(self, tmp_path, monkeypatch): """Two sub-questions should each return their own chunks.""" - service = self._make_service() - service._collection.query.side_effect = [ - { - "documents": [["chunk A1", "chunk A2"]], - "metadatas": [[{"filename": "a.pdf"}, {"filename": "a2.pdf"}]], - "distances": [[0.1, 0.2]], - }, - { - "documents": [["chunk B1"]], - "metadatas": [[{"filename": "b.pdf"}]], - "distances": [[0.3]], - }, - ] + client, collection = _setup_chroma(tmp_path, monkeypatch) + collection.add( + documents=["Alpha content about quantum physics", "Alpha extra about quantum physics"], + metadatas=[{"filename": "a.pdf"}, {"filename": "a2.pdf"}], + ids=["a1", "a2"], + ) + collection.add( + documents=["Beta content about machine learning"], + metadatas=[{"filename": "b.pdf"}], + ids=["b1"], + ) + + service = RAGService(chroma_client=client) results = service.retrieve_per_subquestion( - ["What is A?", "What is B?"], n_results=5 + ["quantum physics", "machine learning"], n_results=5 ) assert len(results) == 2 - assert results[0][0] == "What is A?" - assert len(results[0][1]) == 2 - assert results[0][1][0] == ("chunk A1", {"filename": "a.pdf"}, 0.1) - assert results[0][1][1] == ("chunk A2", {"filename": "a2.pdf"}, 0.2) + assert results[0][0] == "quantum physics" + assert len(results[0][1]) >= 1 + assert "quantum" in results[0][1][0][0].lower() - assert results[1][0] == "What is B?" - assert len(results[1][1]) == 1 - assert results[1][1][0] == ("chunk B1", {"filename": "b.pdf"}, 0.3) + assert results[1][0] == "machine learning" + assert len(results[1][1]) >= 1 + assert "machine learning" in results[1][1][0][0].lower() - def test_retrieve_per_subquestion_empty_list(self): + def test_retrieve_per_subquestion_empty_list(self, tmp_path, monkeypatch): """Empty sub_questions list returns empty list.""" - service = self._make_service() + client, _ = _setup_chroma(tmp_path, monkeypatch) + service = RAGService(chroma_client=client) + results = service.retrieve_per_subquestion([], n_results=10) assert results == [] - def test_retrieve_per_subquestion_single_subq(self): + def test_retrieve_per_subquestion_single_subq(self, tmp_path, monkeypatch): """Single sub-question returns a single-element result list.""" - service = self._make_service() - service._collection.query.return_value = { - "documents": [["chunk X"]], - "metadatas": [[{"filename": "x.pdf"}]], - "distances": [[0.05]], - } + client, collection = _setup_chroma(tmp_path, monkeypatch) - results = service.retrieve_per_subquestion(["Only question"], n_results=3) - - assert len(results) == 1 - assert results[0][0] == "Only question" - assert len(results[0][1]) == 1 - assert results[0][1][0] == ("chunk X", {"filename": "x.pdf"}, 0.05) - - def test_retrieve_per_subquestion_calls_retrieve_n_times(self): - """retrieve() should be called once per sub-question with correct args.""" - service = self._make_service() - - # Mock retrieve to return empty chunks so we can spy on calls - service.retrieve = MagicMock(return_value=[]) - - sub_questions = ["Q1", "Q2", "Q3"] - service.retrieve_per_subquestion(sub_questions, n_results=7) - - assert service.retrieve.call_count == 3 - service.retrieve.assert_any_call(["Q1"], n_results=7) - service.retrieve.assert_any_call(["Q2"], n_results=7) - service.retrieve.assert_any_call(["Q3"], n_results=7) - - def test_retrieve_per_subquestion_preserves_n_results(self): - """n_results parameter is passed through to each retrieve() call.""" - service = self._make_service() - service.retrieve = MagicMock(return_value=[]) - - service.retrieve_per_subquestion(["Q1"], n_results=42) - - service.retrieve.assert_called_once_with(["Q1"], n_results=42) - - def test_retrieve_per_subquestion_handles_empty_results(self): - """One sub-q returns no results, another returns results.""" - service = self._make_service() - - # First call returns empty, second returns data - service.retrieve = MagicMock( - side_effect=[ - [], - [("chunk B", {"filename": "b.pdf"}, 0.2)], - ] + collection.add( + documents=["Unique content about solar energy"], + metadatas=[{"filename": "x.pdf"}], + ids=["x1"], ) + service = RAGService(chroma_client=client) + results = service.retrieve_per_subquestion(["solar energy"], n_results=3) + + assert len(results) == 1 + assert results[0][0] == "solar energy" + assert len(results[0][1]) >= 1 + assert "solar" in results[0][1][0][0].lower() + assert results[0][1][0][1]["filename"] == "x.pdf" + + def test_retrieve_per_subquestion_calls_retrieve_n_times(self, tmp_path, monkeypatch): + """retrieve() should be called once per sub-question.""" + client, _ = _setup_chroma(tmp_path, monkeypatch) + service = RAGService(chroma_client=client) + + call_log: list[tuple] = [] + original_retrieve = service.retrieve + + def _tracking_retrieve(query_keywords, n_results=10): + call_log.append((query_keywords, n_results)) + return original_retrieve(query_keywords, n_results=n_results) + + service.retrieve = _tracking_retrieve + + sub_questions = ["apples", "bananas", "cherries"] + service.retrieve_per_subquestion(sub_questions, n_results=7) + + assert len(call_log) == 3 + for i, sub_q in enumerate(sub_questions): + assert call_log[i][0] == [sub_q] + assert call_log[i][1] == 7 + + def test_retrieve_per_subquestion_preserves_n_results(self, tmp_path, monkeypatch): + """n_results parameter is passed through to each retrieve() call.""" + client, _ = _setup_chroma(tmp_path, monkeypatch) + service = RAGService(chroma_client=client) + + captured_n_results: list[int] = [] + original_retrieve = service.retrieve + + def _capture_retrieve(query_keywords, n_results=10): + captured_n_results.append(n_results) + return original_retrieve(query_keywords, n_results=n_results) + + service.retrieve = _capture_retrieve + + service.retrieve_per_subquestion(["test query"], n_results=42) + + assert captured_n_results == [42] + + def test_retrieve_per_subquestion_handles_empty_results(self, tmp_path, monkeypatch): + """One sub-q returns fewer/no results, another returns results.""" + client, collection = _setup_chroma(tmp_path, monkeypatch) + + collection.add( + documents=[ + "Deep dive into quantum entanglement experiments", + "Quantum entanglement and Bell's theorem", + "History of Renaissance art in Florence", + ], + metadatas=[ + {"filename": "b.pdf"}, + {"filename": "b2.pdf"}, + {"filename": "art.pdf"}, + ], + ids=["b1", "b2", "a1"], + ) + + service = RAGService(chroma_client=client) results = service.retrieve_per_subquestion( - ["No results Q", "Has results Q"], n_results=5 + ["Renaissance art Florence", "quantum entanglement"], n_results=2 ) assert len(results) == 2 - # First sub-question has empty chunks - assert results[0][0] == "No results Q" - assert results[0][1] == [] + assert results[0][0] == "Renaissance art Florence" + assert results[1][0] == "quantum entanglement" - # Second sub-question has chunks - assert results[1][0] == "Has results Q" - assert len(results[1][1]) == 1 - assert results[1][1][0] == ("chunk B", {"filename": "b.pdf"}, 0.2) + assert len(results[1][1]) >= 1 + assert "quantum" in results[1][1][0][0].lower() + + assert len(results[0][1]) >= 1 + assert "Renaissance" in results[0][1][0][0] diff --git a/backend/app/utils/metadata.py b/backend/app/utils/metadata.py index 2b94cce..e6cb538 100644 --- a/backend/app/utils/metadata.py +++ b/backend/app/utils/metadata.py @@ -67,10 +67,14 @@ def extract_metadata( "upload_date": upload_date, "content_summary": content_summary, "chunk_index": idx, - "page_number": page_numbers[idx] if page_numbers else None, - "chunk_file_path": chunk_file_paths[idx] if chunk_file_paths else None, "document_id": document_id, } + page_num = page_numbers[idx] if page_numbers else None + if page_num is not None: + entry["page_number"] = page_num + cfp = chunk_file_paths[idx] if chunk_file_paths else None + if cfp is not None: + entry["chunk_file_path"] = cfp metadata.append(entry) return metadata