refactor(test): rewrite tests to comply with integration-first rules
Replace mocked DB/internal-services with real ChromaDB/SQLite via tmp_path. Only mock truly external APIs (LLM, embedding for deterministic vectors). 13 test files rewritten (314 pass, 0 fail): - Route tests: use TestClient + real ChromaDB, seed test data - Service tests: use real PersistentClient/SQLite instances - Pipeline tests: TestClient hits SSE /query endpoint, verify history - Converted unittest.TestCase to pytest where applicable Plus: fix metadata.py to filter None values from ChromaDB metadata (pre-existing bug caught by real-DB ingestion tests)
This commit is contained in:
parent
3b868a0133
commit
2656f9ca08
|
|
@ -2,72 +2,69 @@
|
|||
|
||||
Coverage:
|
||||
- GET /api/v1/chunks/{file_path}/pdf — success, 404, path traversal 400
|
||||
|
||||
Uses real filesystem (tmp_path) via monkeypatch.setenv — no mocks on internal services.
|
||||
"""
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.main import app
|
||||
|
||||
@pytest.fixture
|
||||
def client(tmp_path, monkeypatch):
|
||||
"""TestClient with DOCUMENT_CHUNK_PATH pointing to a temp directory."""
|
||||
chunk_dir = tmp_path / "chunks"
|
||||
chunk_dir.mkdir()
|
||||
monkeypatch.setenv("DOCUMENT_CHUNK_PATH", str(chunk_dir))
|
||||
monkeypatch.setenv("CHROMA_DB_PATH", str(tmp_path / "chroma_test"))
|
||||
from app.core.config import get_settings
|
||||
get_settings.cache_clear()
|
||||
from app.main import app
|
||||
yield TestClient(app)
|
||||
get_settings.cache_clear()
|
||||
|
||||
|
||||
class TestChunkServing(unittest.TestCase):
|
||||
"""Test GET /api/v1/chunks/{file_path}/pdf endpoint."""
|
||||
def test_get_chunk_pdf_success(client, tmp_path):
|
||||
"""Should serve chunk PDF file with 200 and application/pdf."""
|
||||
chunk_dir = tmp_path / "chunks"
|
||||
test_file = chunk_dir / "test_page_1.pdf"
|
||||
test_file.write_bytes(b"%PDF-1.4 fake content")
|
||||
|
||||
def setUp(self):
|
||||
self.client = TestClient(app)
|
||||
response = client.get("/api/v1/chunks/test_page_1.pdf/pdf")
|
||||
|
||||
def test_get_chunk_pdf_success(self):
|
||||
"""Should serve chunk PDF file with 200 and application/pdf."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
test_file = os.path.join(tmp_dir, "test_page_1.pdf")
|
||||
with open(test_file, "wb") as f:
|
||||
f.write(b"%PDF-1.4 fake content")
|
||||
assert response.status_code == 200
|
||||
assert "application/pdf" in response.headers["content-type"]
|
||||
|
||||
with patch("app.core.config.get_settings") as mock_settings:
|
||||
mock_settings.return_value.document_chunk_path = tmp_dir
|
||||
response = self.client.get("/api/v1/chunks/test_page_1.pdf/pdf")
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertIn("application/pdf", response.headers["content-type"])
|
||||
def test_get_chunk_pdf_not_found(client):
|
||||
"""Should return 404 for non-existent chunk file."""
|
||||
response = client.get("/api/v1/chunks/nonexistent.pdf/pdf")
|
||||
|
||||
def test_get_chunk_pdf_not_found(self):
|
||||
"""Should return 404 for non-existent chunk file."""
|
||||
with patch("app.core.config.get_settings") as mock_settings:
|
||||
mock_settings.return_value.document_chunk_path = "/tmp/nonexistent_chunk_dir"
|
||||
response = self.client.get("/api/v1/chunks/nonexistent.pdf/pdf")
|
||||
assert response.status_code == 404
|
||||
|
||||
self.assertEqual(response.status_code, 404)
|
||||
|
||||
def test_get_chunk_pdf_path_traversal_double_dot(self):
|
||||
"""Should reject path traversal with .. (404 due to Starlette normalization)."""
|
||||
with patch("app.core.config.get_settings") as mock_settings:
|
||||
mock_settings.return_value.document_chunk_path = "/tmp/fake_chunk_dir"
|
||||
response = self.client.get("/api/v1/chunks/../etc/passwd/pdf")
|
||||
def test_get_chunk_pdf_path_traversal_double_dot(client):
|
||||
"""Should reject path traversal with .. (400 or 404 from Starlette normalization)."""
|
||||
response = client.get("/api/v1/chunks/../etc/passwd/pdf")
|
||||
|
||||
self.assertIn(response.status_code, [400, 404])
|
||||
assert response.status_code in (400, 404)
|
||||
|
||||
def test_get_chunk_pdf_path_traversal_symlink_escape(self):
|
||||
"""Should reject resolved path escaping base directory (404 from normalization)."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
with patch("app.core.config.get_settings") as mock_settings:
|
||||
mock_settings.return_value.document_chunk_path = tmp_dir
|
||||
response = self.client.get("/api/v1/chunks/../../etc/passwd/pdf")
|
||||
|
||||
self.assertIn(response.status_code, [400, 404])
|
||||
def test_get_chunk_pdf_path_traversal_symlink_escape(client):
|
||||
"""Should reject resolved path escaping base directory (400 or 404)."""
|
||||
response = client.get("/api/v1/chunks/../../etc/passwd/pdf")
|
||||
|
||||
def test_get_chunk_pdf_with_spaces_in_filename(self):
|
||||
"""Should serve files with spaces in the filename."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
test_file = os.path.join(tmp_dir, "NEC4 ACC_page_3.pdf")
|
||||
with open(test_file, "wb") as f:
|
||||
f.write(b"%PDF-1.4 fake content")
|
||||
assert response.status_code in (400, 404)
|
||||
|
||||
with patch("app.core.config.get_settings") as mock_settings:
|
||||
mock_settings.return_value.document_chunk_path = tmp_dir
|
||||
response = self.client.get("/api/v1/chunks/NEC4 ACC_page_3.pdf/pdf")
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertIn("application/pdf", response.headers["content-type"])
|
||||
def test_get_chunk_pdf_with_spaces_in_filename(client, tmp_path):
|
||||
"""Should serve files with spaces in the filename."""
|
||||
chunk_dir = tmp_path / "chunks"
|
||||
test_file = chunk_dir / "NEC4 ACC_page_3.pdf"
|
||||
test_file.write_bytes(b"%PDF-1.4 fake content")
|
||||
|
||||
response = client.get("/api/v1/chunks/NEC4 ACC_page_3.pdf/pdf")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert "application/pdf" in response.headers["content-type"]
|
||||
|
|
|
|||
|
|
@ -5,164 +5,162 @@ Covers:
|
|||
- GET /documents/{id}/chunks
|
||||
- DELETE /documents/{id}
|
||||
- DELETE /chunks/{id}
|
||||
|
||||
Uses real ChromaDB via tmp_path + TestClient — no mocks on internal services.
|
||||
"""
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
class TestDocumentsRouter:
|
||||
"""Documents CRUD endpoint tests."""
|
||||
@pytest.fixture
|
||||
def client(tmp_path, monkeypatch):
|
||||
"""TestClient with real ChromaDB isolated in tmp_path."""
|
||||
chroma_dir = tmp_path / "chroma_test"
|
||||
chunk_dir = tmp_path / "chunks"
|
||||
chunk_dir.mkdir()
|
||||
monkeypatch.setenv("CHROMA_DB_PATH", str(chroma_dir))
|
||||
monkeypatch.setenv("DOCUMENT_CHUNK_PATH", str(chunk_dir))
|
||||
from app.core.config import get_settings
|
||||
get_settings.cache_clear()
|
||||
from app.main import app
|
||||
yield TestClient(app)
|
||||
get_settings.cache_clear()
|
||||
|
||||
@pytest.fixture
|
||||
def client(self):
|
||||
"""Create test client with mocked dependencies."""
|
||||
from app.main import app
|
||||
return TestClient(app)
|
||||
|
||||
def test_list_documents_empty(self, client):
|
||||
"""Should return empty list when no documents exist."""
|
||||
with patch("app.routers.documents.RAGService") as mock_rag_class:
|
||||
mock_rag = MagicMock()
|
||||
mock_rag.list_documents.return_value = ([], 0, 0)
|
||||
mock_rag_class.return_value = mock_rag
|
||||
def _seed_document(tmp_path, monkeypatch, document_id, filename, num_chunks, chunk_file_paths=None):
|
||||
"""Ingest test document into the real ChromaDB used by the client fixture.
|
||||
|
||||
response = client.get("/api/v1/documents")
|
||||
Must be called AFTER the `client` fixture has been established so that
|
||||
get_settings() resolves to the same tmp_path ChromaDB directory.
|
||||
"""
|
||||
from app.core.config import get_settings
|
||||
from app.services.rag import RAGService
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["documents"] == []
|
||||
assert data["total_documents"] == 0
|
||||
assert data["total_chunks"] == 0
|
||||
settings = get_settings()
|
||||
rag = RAGService(settings=settings)
|
||||
|
||||
def test_list_documents_with_data(self, client):
|
||||
"""Should return grouped documents with chunk counts."""
|
||||
doc_list = [
|
||||
{
|
||||
"document_id": "abc-123",
|
||||
"filename": "report.pdf",
|
||||
"chunk_count": 3,
|
||||
"upload_date": "2026-04-23",
|
||||
},
|
||||
{
|
||||
"document_id": "def-456",
|
||||
"filename": "notes.txt",
|
||||
"chunk_count": 1,
|
||||
"upload_date": "2026-04-22",
|
||||
},
|
||||
]
|
||||
chunks = [f"chunk content {i}" for i in range(num_chunks)]
|
||||
metadata_list = []
|
||||
for i in range(num_chunks):
|
||||
meta = {
|
||||
"filename": filename,
|
||||
"upload_date": "2026-04-23",
|
||||
"content_summary": f"summary {i}",
|
||||
"chunk_index": i,
|
||||
}
|
||||
if chunk_file_paths and i < len(chunk_file_paths):
|
||||
meta["chunk_file_path"] = chunk_file_paths[i]
|
||||
metadata_list.append(meta)
|
||||
|
||||
with patch("app.routers.documents.RAGService") as mock_rag_class:
|
||||
mock_rag = MagicMock()
|
||||
mock_rag.list_documents.return_value = (doc_list, 2, 4)
|
||||
mock_rag_class.return_value = mock_rag
|
||||
rag.ingest_document(
|
||||
file_path=filename,
|
||||
chunks=chunks,
|
||||
metadata_list=metadata_list,
|
||||
document_id=document_id,
|
||||
)
|
||||
return document_id
|
||||
|
||||
response = client.get("/api/v1/documents")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total_documents"] == 2
|
||||
assert data["total_chunks"] == 4
|
||||
assert len(data["documents"]) == 2
|
||||
assert data["documents"][0]["document_id"] == "abc-123"
|
||||
assert data["documents"][0]["filename"] == "report.pdf"
|
||||
assert data["documents"][0]["chunk_count"] == 3
|
||||
def test_list_documents_empty(client):
|
||||
"""Should return empty list when no documents exist."""
|
||||
response = client.get("/api/v1/documents")
|
||||
|
||||
def test_list_chunks_for_document(self, client):
|
||||
"""Should return all chunks for a given document_id."""
|
||||
chunks = [
|
||||
{
|
||||
"chunk_id": "abc-123_0",
|
||||
"chunk_index": 0,
|
||||
"content_summary": "First chunk summary",
|
||||
"page_number": 1,
|
||||
"chunk_file_path": None,
|
||||
},
|
||||
{
|
||||
"chunk_id": "abc-123_1",
|
||||
"chunk_index": 1,
|
||||
"content_summary": "Second chunk summary",
|
||||
"page_number": 2,
|
||||
"chunk_file_path": None,
|
||||
},
|
||||
]
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["documents"] == []
|
||||
assert data["total_documents"] == 0
|
||||
assert data["total_chunks"] == 0
|
||||
|
||||
with patch("app.routers.documents.RAGService") as mock_rag_class:
|
||||
mock_rag = MagicMock()
|
||||
mock_rag.list_chunks.return_value = chunks
|
||||
mock_rag_class.return_value = mock_rag
|
||||
|
||||
response = client.get("/api/v1/documents/abc-123/chunks")
|
||||
def test_list_documents_with_data(client, tmp_path, monkeypatch):
|
||||
"""Should return grouped documents with chunk counts."""
|
||||
_seed_document(tmp_path, monkeypatch, "abc-123", "report.pdf", 3)
|
||||
_seed_document(tmp_path, monkeypatch, "def-456", "notes.txt", 1)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data) == 2
|
||||
assert data[0]["chunk_id"] == "abc-123_0"
|
||||
assert data[0]["chunk_index"] == 0
|
||||
assert data[0]["content_summary"] == "First chunk summary"
|
||||
assert data[1]["chunk_index"] == 1
|
||||
response = client.get("/api/v1/documents")
|
||||
|
||||
def test_list_chunks_document_not_found(self, client):
|
||||
"""Should return empty list for nonexistent document."""
|
||||
with patch("app.routers.documents.RAGService") as mock_rag_class:
|
||||
mock_rag = MagicMock()
|
||||
mock_rag.list_chunks.return_value = []
|
||||
mock_rag_class.return_value = mock_rag
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total_documents"] == 2
|
||||
assert data["total_chunks"] == 4
|
||||
assert len(data["documents"]) == 2
|
||||
|
||||
response = client.get("/api/v1/documents/nonexistent-id/chunks")
|
||||
by_id = {d["document_id"]: d for d in data["documents"]}
|
||||
assert by_id["abc-123"]["filename"] == "report.pdf"
|
||||
assert by_id["abc-123"]["chunk_count"] == 3
|
||||
assert by_id["def-456"]["filename"] == "notes.txt"
|
||||
assert by_id["def-456"]["chunk_count"] == 1
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data == []
|
||||
|
||||
def test_delete_document_success(self, client):
|
||||
"""Should delete all chunks for a document and return confirmation."""
|
||||
with patch("app.routers.documents.RAGService") as mock_rag_class:
|
||||
mock_rag = MagicMock()
|
||||
mock_rag.delete_document.return_value = (True, 3)
|
||||
mock_rag_class.return_value = mock_rag
|
||||
def test_list_chunks_for_document(client, tmp_path, monkeypatch):
|
||||
"""Should return all chunks for a given document_id."""
|
||||
_seed_document(tmp_path, monkeypatch, "abc-123", "report.pdf", 2)
|
||||
|
||||
response = client.delete("/api/v1/documents/abc-123")
|
||||
response = client.get("/api/v1/documents/abc-123/chunks")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["deleted"] is True
|
||||
assert "3 chunks removed" in data["message"]
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data) == 2
|
||||
assert data[0]["chunk_id"] == "abc-123_0"
|
||||
assert data[0]["chunk_index"] == 0
|
||||
assert data[0]["content_summary"] == "summary 0"
|
||||
assert data[1]["chunk_index"] == 1
|
||||
|
||||
def test_delete_document_not_found(self, client):
|
||||
"""Should return 404 for nonexistent document."""
|
||||
with patch("app.routers.documents.RAGService") as mock_rag_class:
|
||||
mock_rag = MagicMock()
|
||||
mock_rag.delete_document.return_value = (False, 0)
|
||||
mock_rag_class.return_value = mock_rag
|
||||
|
||||
response = client.delete("/api/v1/documents/nonexistent-id")
|
||||
def test_list_chunks_document_not_found(client):
|
||||
"""Should return empty list for nonexistent document."""
|
||||
response = client.get("/api/v1/documents/nonexistent-id/chunks")
|
||||
|
||||
assert response.status_code == 404
|
||||
assert "not found" in response.json()["detail"].lower()
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data == []
|
||||
|
||||
def test_delete_chunk_success(self, client):
|
||||
"""Should delete a single chunk and return confirmation."""
|
||||
with patch("app.routers.documents.RAGService") as mock_rag_class:
|
||||
mock_rag = MagicMock()
|
||||
mock_rag.delete_chunk.return_value = True
|
||||
mock_rag_class.return_value = mock_rag
|
||||
|
||||
response = client.delete("/api/v1/chunks/abc-123_0")
|
||||
def test_delete_document_success(client, tmp_path, monkeypatch):
|
||||
"""Should delete all chunks for a document and return confirmation."""
|
||||
_seed_document(tmp_path, monkeypatch, "abc-123", "report.pdf", 3)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["deleted"] is True
|
||||
assert "abc-123_0" in data["message"]
|
||||
response = client.delete("/api/v1/documents/abc-123")
|
||||
|
||||
def test_delete_chunk_not_found(self, client):
|
||||
"""Should return 404 for nonexistent chunk."""
|
||||
with patch("app.routers.documents.RAGService") as mock_rag_class:
|
||||
mock_rag = MagicMock()
|
||||
mock_rag.delete_chunk.return_value = False
|
||||
mock_rag_class.return_value = mock_rag
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["deleted"] is True
|
||||
assert "3 chunks removed" in data["message"]
|
||||
|
||||
response = client.delete("/api/v1/chunks/nonexistent-chunk")
|
||||
# Verify actually deleted
|
||||
response = client.get("/api/v1/documents")
|
||||
assert response.json()["total_documents"] == 0
|
||||
|
||||
assert response.status_code == 404
|
||||
assert "not found" in response.json()["detail"].lower()
|
||||
|
||||
def test_delete_document_not_found(client):
|
||||
"""Should return 404 for nonexistent document."""
|
||||
response = client.delete("/api/v1/documents/nonexistent-id")
|
||||
|
||||
assert response.status_code == 404
|
||||
assert "not found" in response.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_delete_chunk_success(client, tmp_path, monkeypatch):
|
||||
"""Should delete a single chunk and return confirmation."""
|
||||
_seed_document(tmp_path, monkeypatch, "abc-123", "report.pdf", 2)
|
||||
|
||||
response = client.delete("/api/v1/chunks/abc-123_0")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["deleted"] is True
|
||||
assert "abc-123_0" in data["message"]
|
||||
|
||||
# Verify chunk gone but other chunk remains
|
||||
response = client.get("/api/v1/documents/abc-123/chunks")
|
||||
chunks = response.json()
|
||||
assert len(chunks) == 1
|
||||
assert chunks[0]["chunk_id"] == "abc-123_1"
|
||||
|
||||
|
||||
def test_delete_chunk_not_found(client):
|
||||
"""Should return 404 for nonexistent chunk."""
|
||||
response = client.delete("/api/v1/chunks/nonexistent-chunk")
|
||||
|
||||
assert response.status_code == 404
|
||||
assert "not found" in response.json()["detail"].lower()
|
||||
|
|
|
|||
|
|
@ -184,5 +184,5 @@ def test_extract_metadata_page_numbers_none_in_list(tmp_path):
|
|||
)
|
||||
|
||||
assert len(metadata) == 2
|
||||
assert metadata[0]["page_number"] is None
|
||||
assert "page_number" not in metadata[0]
|
||||
assert metadata[1]["page_number"] == 1
|
||||
|
|
|
|||
|
|
@ -1,96 +1,194 @@
|
|||
"""Phase 1 tests: Document ingestion endpoint.
|
||||
|
||||
Covers:
|
||||
- POST /api/v1/ingest with valid documents
|
||||
- POST /api/v1/ingest with valid documents (PDF, DOCX, TXT)
|
||||
- Metadata extraction (filename, upload_date, content_summary)
|
||||
- ChromaDB persistence with embeddings
|
||||
- ChromaDB persistence (verify by querying real collection)
|
||||
- Error handling for unsupported file types
|
||||
- Error handling for missing file field
|
||||
|
||||
Uses TestClient + real ChromaDB + real chunking + real metadata extraction.
|
||||
Embedding function is mocked with deterministic vectors (external API).
|
||||
No LLM calls involved in the ingest pipeline.
|
||||
"""
|
||||
import io
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import MagicMock, patch
|
||||
from pypdf import PdfWriter
|
||||
|
||||
from app.routers.ingest import router
|
||||
|
||||
|
||||
class _DeterministicEmbedding:
|
||||
def name(self) -> str:
|
||||
return "test_deterministic"
|
||||
|
||||
def __call__(self, input):
|
||||
return self._embed(input)
|
||||
|
||||
def embed_query(self, input):
|
||||
return self._embed(input)
|
||||
|
||||
@staticmethod
|
||||
def _embed(texts):
|
||||
vectors = []
|
||||
for text in texts:
|
||||
vec = [0.0] * 384
|
||||
for i, ch in enumerate(text[:384]):
|
||||
vec[i] = ord(ch) / 1000.0
|
||||
vectors.append(vec)
|
||||
return vectors
|
||||
|
||||
|
||||
def _create_real_pdf(content: str) -> bytes:
|
||||
from pypdf import PdfWriter
|
||||
writer = PdfWriter()
|
||||
writer.add_blank_page(width=200, height=200)
|
||||
page = writer.pages[0]
|
||||
# Add text content via page-level operator (simple approach)
|
||||
# pypdf blank pages have no text — we write the content as annotation
|
||||
# For testing, we just need a valid PDF; actual text extraction tested separately
|
||||
buf = io.BytesIO()
|
||||
writer.write(buf)
|
||||
return buf.getvalue()
|
||||
|
||||
|
||||
def _create_text_pdf(lines: list[str]) -> bytes:
|
||||
"""Create a PDF with actual extractable text using reportlab if available."""
|
||||
try:
|
||||
from reportlab.pdfgen import canvas as rl_canvas
|
||||
buf = io.BytesIO()
|
||||
c = rl_canvas.Canvas(buf)
|
||||
y = 750
|
||||
for line in lines:
|
||||
c.drawString(72, y, line)
|
||||
y -= 20
|
||||
c.save()
|
||||
return buf.getvalue()
|
||||
except ImportError:
|
||||
# Fallback: pypdf blank PDF (no extractable text)
|
||||
return _create_real_pdf("")
|
||||
|
||||
|
||||
def _create_real_docx(paragraphs: list[str]) -> bytes:
|
||||
try:
|
||||
from docx import Document
|
||||
doc = Document()
|
||||
for para in paragraphs:
|
||||
doc.add_paragraph(para)
|
||||
buf = io.BytesIO()
|
||||
doc.save(buf)
|
||||
return buf.getvalue()
|
||||
except ImportError:
|
||||
return b""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(tmp_path, monkeypatch):
|
||||
chroma_path = str(tmp_path / "chroma_db")
|
||||
chunk_path = str(tmp_path / "document_chunk")
|
||||
prompts_path = str(tmp_path / "prompts.db")
|
||||
history_path = str(tmp_path / "history.db")
|
||||
|
||||
monkeypatch.setenv("CHROMA_DB_PATH", chroma_path)
|
||||
monkeypatch.setenv("DOCUMENT_CHUNK_PATH", chunk_path)
|
||||
monkeypatch.setenv("PROMPTS_DB_PATH", prompts_path)
|
||||
monkeypatch.setenv("HISTORY_DB_PATH", history_path)
|
||||
monkeypatch.setenv("EMBEDDING_MODEL", "test-mock")
|
||||
monkeypatch.setenv("LLM_API_KEY", "test-key")
|
||||
|
||||
from app.core.config import get_settings
|
||||
get_settings.cache_clear()
|
||||
from app.core.dependencies import get_settings_cached
|
||||
get_settings_cached.cache_clear()
|
||||
|
||||
from app.core.sqlite_db import _get_db, init_prompts_db, init_history_db, seed_default_profiles
|
||||
conn = _get_db(prompts_path)
|
||||
init_prompts_db(conn)
|
||||
seed_default_profiles(conn)
|
||||
conn.close()
|
||||
|
||||
hconn = _get_db(history_path)
|
||||
init_history_db(hconn)
|
||||
hconn.close()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.core.database.get_embedding_function_settings",
|
||||
lambda settings: _DeterministicEmbedding(),
|
||||
)
|
||||
|
||||
test_app = FastAPI()
|
||||
test_app.include_router(router, prefix="/api/v1")
|
||||
|
||||
yield TestClient(test_app)
|
||||
|
||||
get_settings_cached.cache_clear()
|
||||
get_settings.cache_clear()
|
||||
|
||||
|
||||
class TestIngest:
|
||||
"""Document upload and ChromaDB ingestion tests."""
|
||||
|
||||
@pytest.fixture
|
||||
def client(self):
|
||||
"""Create test client with mocked dependencies."""
|
||||
from app.main import app
|
||||
return TestClient(app)
|
||||
def test_ingest_txt_success(self, client, tmp_path):
|
||||
"""Should ingest TXT and return document ID with metadata. Verify real ChromaDB."""
|
||||
import chromadb
|
||||
from app.core.config import get_settings
|
||||
settings = get_settings()
|
||||
|
||||
def test_ingest_pdf_success(self, client, tmp_path):
|
||||
"""Should ingest PDF and return document ID with metadata."""
|
||||
import io
|
||||
|
||||
with patch("app.services.rag.RAGService") as mock_rag_class:
|
||||
mock_rag = MagicMock()
|
||||
mock_rag.ingest_document.return_value = "doc-123"
|
||||
mock_rag.list_documents.return_value = ([], 0, 0)
|
||||
mock_rag_class.return_value = mock_rag
|
||||
|
||||
with patch("app.utils.pdf_parser.parse_pdf_by_page") as mock_parse:
|
||||
mock_parse.return_value = [(1, "Page 1 text"), (2, "Page 2 text")]
|
||||
|
||||
with patch("app.utils.chunking.TokenChunkingStrategy") as mock_chunk_class:
|
||||
mock_chunker = MagicMock()
|
||||
mock_chunker.chunk_pages.return_value = [("chunk 1", 1), ("chunk 2", 2)]
|
||||
mock_chunk_class.return_value = mock_chunker
|
||||
|
||||
with patch("app.utils.metadata.extract_metadata") as mock_meta:
|
||||
mock_meta.return_value = [
|
||||
{"filename": "test.pdf", "chunk_index": 0},
|
||||
{"filename": "test.pdf", "chunk_index": 1},
|
||||
]
|
||||
|
||||
with patch("app.utils.pdf_extractor.extract_page_as_pdf"):
|
||||
response = client.post(
|
||||
"/api/v1/ingest",
|
||||
files={"file": ("test.pdf", io.BytesIO(b"%PDF-1.4"), "application/pdf")},
|
||||
)
|
||||
response = client.post(
|
||||
"/api/v1/ingest",
|
||||
files={"file": ("notes.txt", io.BytesIO(b"This is a test document about testing.\nIt has multiple lines of content."), "text/plain")},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "document_id" in data
|
||||
assert data["chunk_count"] == 2
|
||||
assert data["filename"] == "test.pdf"
|
||||
assert data["chunk_count"] >= 1
|
||||
assert data["filename"] == "notes.txt"
|
||||
|
||||
# Verify data persisted in real ChromaDB
|
||||
db_client = chromadb.PersistentClient(path=settings.chroma_db_path)
|
||||
collection = db_client.get_collection("documents")
|
||||
all_data = collection.get(include=["metadatas"])
|
||||
assert len(all_data["ids"]) >= 1
|
||||
filenames = [m["filename"] for m in all_data["metadatas"]]
|
||||
assert "notes.txt" in filenames
|
||||
|
||||
def test_ingest_docx_success(self, client, tmp_path):
|
||||
"""Should ingest DOCX and return document ID with metadata."""
|
||||
import io
|
||||
docx_bytes = _create_real_docx(["Paragraph one content.", "Paragraph two content."])
|
||||
if not docx_bytes:
|
||||
pytest.skip("python-docx not installed")
|
||||
|
||||
with patch("app.services.rag.RAGService") as mock_rag_class:
|
||||
mock_rag = MagicMock()
|
||||
mock_rag.ingest_document.return_value = "doc-456"
|
||||
mock_rag.list_documents.return_value = ([], 0, 0)
|
||||
mock_rag_class.return_value = mock_rag
|
||||
|
||||
with patch("app.utils.docx_parser.parse_docx") as mock_parse:
|
||||
mock_parse.return_value = "Parsed DOCX text content"
|
||||
|
||||
with patch("app.utils.chunking.TokenChunkingStrategy") as mock_chunk_class:
|
||||
mock_chunker = MagicMock()
|
||||
mock_chunker.chunk.return_value = ["chunk 1"]
|
||||
mock_chunk_class.return_value = mock_chunker
|
||||
|
||||
with patch("app.utils.metadata.extract_metadata") as mock_meta:
|
||||
mock_meta.return_value = [{"filename": "test.docx", "chunk_index": 0}]
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/ingest",
|
||||
files={"file": ("test.docx", io.BytesIO(b"docx content"), "application/vnd.openxmlformats-officedocument.wordprocessingml.document")},
|
||||
)
|
||||
response = client.post(
|
||||
"/api/v1/ingest",
|
||||
files={"file": ("test.docx", io.BytesIO(docx_bytes),
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document")},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["chunk_count"] == 1
|
||||
assert data["chunk_count"] >= 1
|
||||
assert data["filename"] == "test.docx"
|
||||
|
||||
def test_ingest_pdf_success(self, client, tmp_path):
|
||||
"""Should ingest PDF and return document ID with metadata."""
|
||||
pdf_bytes = _create_text_pdf(["Page 1 line one", "Page 1 line two"])
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/ingest",
|
||||
files={"file": ("test.pdf", io.BytesIO(pdf_bytes), "application/pdf")},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "document_id" in data
|
||||
assert data["filename"] == "test.pdf"
|
||||
|
||||
def test_ingest_unsupported_format(self, client):
|
||||
"""Should reject unsupported file formats."""
|
||||
import io
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/ingest",
|
||||
files={"file": ("test.jpg", io.BytesIO(b"image data"), "image/jpeg")},
|
||||
|
|
|
|||
|
|
@ -1,435 +1,361 @@
|
|||
"""Phase 1.5.5c tests: Page-aware ingest router.
|
||||
|
||||
Covers:
|
||||
1. PDF upload triggers page-aware pipeline (parse_pdf_by_page, chunk_pages, extract_page_as_pdf)
|
||||
2. DOCX upload uses old pipeline with document_id
|
||||
3. TXT upload uses old pipeline with document_id
|
||||
1. PDF upload triggers page-aware pipeline (page_number in metadata, page PDFs saved)
|
||||
2. DOCX upload uses old pipeline (no page_number in metadata)
|
||||
3. TXT upload uses old pipeline (no page_number in metadata)
|
||||
4. Same-filename replacement: existing document found → old chunks + PDFs deleted
|
||||
5. Same-filename replacement: no existing document → no deletion
|
||||
6. Empty PDF (no pages with text) → 400 error
|
||||
7. Page PDFs saved to correct directory with correct naming
|
||||
8. Metadata includes page_number and chunk_file_path for PDF uploads
|
||||
9. Metadata does NOT include page_number for DOCX uploads (None)
|
||||
9. Metadata does NOT include page_number for DOCX uploads
|
||||
|
||||
Uses TestClient + real ChromaDB + real file parsing + real chunking.
|
||||
Embedding function mocked with deterministic vectors (external API).
|
||||
"""
|
||||
import io
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch, call
|
||||
|
||||
import chromadb
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.routers.ingest import router
|
||||
|
||||
|
||||
class _DeterministicEmbedding:
|
||||
def name(self) -> str:
|
||||
return "test_deterministic"
|
||||
|
||||
def __call__(self, input):
|
||||
return self._embed(input)
|
||||
|
||||
def embed_query(self, input):
|
||||
return self._embed(input)
|
||||
|
||||
@staticmethod
|
||||
def _embed(texts):
|
||||
vectors = []
|
||||
for text in texts:
|
||||
vec = [0.0] * 384
|
||||
for i, ch in enumerate(text[:384]):
|
||||
vec[i] = ord(ch) / 1000.0
|
||||
vectors.append(vec)
|
||||
return vectors
|
||||
|
||||
|
||||
def _create_text_pdf(lines: list[str]) -> bytes:
|
||||
try:
|
||||
from reportlab.pdfgen import canvas as rl_canvas
|
||||
buf = io.BytesIO()
|
||||
c = rl_canvas.Canvas(buf)
|
||||
y = 750
|
||||
for line in lines:
|
||||
c.drawString(72, y, line)
|
||||
y -= 20
|
||||
c.save()
|
||||
return buf.getvalue()
|
||||
except ImportError:
|
||||
pytest.skip("reportlab not installed")
|
||||
|
||||
|
||||
def _create_multipage_pdf(pages_text: list[list[str]]) -> bytes:
|
||||
try:
|
||||
from reportlab.pdfgen import canvas as rl_canvas
|
||||
buf = io.BytesIO()
|
||||
c = rl_canvas.Canvas(buf)
|
||||
for page_lines in pages_text:
|
||||
y = 750
|
||||
for line in page_lines:
|
||||
c.drawString(72, y, line)
|
||||
y -= 20
|
||||
c.showPage()
|
||||
c.save()
|
||||
return buf.getvalue()
|
||||
except ImportError:
|
||||
pytest.skip("reportlab not installed")
|
||||
|
||||
|
||||
def _create_real_docx(paragraphs: list[str]) -> bytes:
|
||||
try:
|
||||
from docx import Document
|
||||
doc = Document()
|
||||
for para in paragraphs:
|
||||
doc.add_paragraph(para)
|
||||
buf = io.BytesIO()
|
||||
doc.save(buf)
|
||||
return buf.getvalue()
|
||||
except ImportError:
|
||||
pytest.skip("python-docx not installed")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(tmp_path, monkeypatch):
|
||||
chroma_path = str(tmp_path / "chroma_db")
|
||||
chunk_path = str(tmp_path / "document_chunk")
|
||||
prompts_path = str(tmp_path / "prompts.db")
|
||||
history_path = str(tmp_path / "history.db")
|
||||
|
||||
monkeypatch.setenv("CHROMA_DB_PATH", chroma_path)
|
||||
monkeypatch.setenv("DOCUMENT_CHUNK_PATH", chunk_path)
|
||||
monkeypatch.setenv("PROMPTS_DB_PATH", prompts_path)
|
||||
monkeypatch.setenv("HISTORY_DB_PATH", history_path)
|
||||
monkeypatch.setenv("EMBEDDING_MODEL", "test-mock")
|
||||
monkeypatch.setenv("LLM_API_KEY", "test-key")
|
||||
|
||||
from app.core.config import get_settings
|
||||
get_settings.cache_clear()
|
||||
from app.core.dependencies import get_settings_cached
|
||||
get_settings_cached.cache_clear()
|
||||
|
||||
from app.core.sqlite_db import _get_db, init_prompts_db, init_history_db, seed_default_profiles
|
||||
conn = _get_db(prompts_path)
|
||||
init_prompts_db(conn)
|
||||
seed_default_profiles(conn)
|
||||
conn.close()
|
||||
|
||||
hconn = _get_db(history_path)
|
||||
init_history_db(hconn)
|
||||
hconn.close()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"app.core.database.get_embedding_function_settings",
|
||||
lambda settings: _DeterministicEmbedding(),
|
||||
)
|
||||
|
||||
test_app = FastAPI()
|
||||
test_app.include_router(router, prefix="/api/v1")
|
||||
|
||||
yield TestClient(test_app)
|
||||
|
||||
get_settings_cached.cache_clear()
|
||||
get_settings.cache_clear()
|
||||
|
||||
|
||||
def _get_collection(client_fixture, chroma_path: str):
|
||||
db_client = chromadb.PersistentClient(path=chroma_path)
|
||||
return db_client.get_collection("documents")
|
||||
|
||||
|
||||
class TestPageAwareIngest:
|
||||
"""Page-aware document ingestion tests."""
|
||||
|
||||
@pytest.fixture
|
||||
def client(self):
|
||||
"""Create test client with mocked dependencies."""
|
||||
from app.main import app
|
||||
return TestClient(app)
|
||||
def test_pdf_upload_uses_page_aware_pipeline(self, client, tmp_path):
|
||||
"""PDF should produce chunks with page_number metadata and page PDF files on disk."""
|
||||
pdf_bytes = _create_multipage_pdf([
|
||||
["Page 1 content about testing"],
|
||||
["Page 2 content about more testing"],
|
||||
])
|
||||
|
||||
@pytest.fixture
|
||||
def mock_settings(self):
|
||||
"""Mock settings with document_chunk_path."""
|
||||
settings = MagicMock()
|
||||
settings.chunk_size = 1000
|
||||
settings.chunk_overlap = 200
|
||||
settings.document_chunk_path = "/tmp/test_document_chunk"
|
||||
return settings
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Test 1: PDF upload triggers page-aware pipeline
|
||||
# ------------------------------------------------------------------ #
|
||||
def test_pdf_upload_uses_page_aware_pipeline(self, client, mock_settings):
|
||||
"""PDF should go through parse_pdf_by_page → chunk_pages → extract_page_as_pdf."""
|
||||
doc_id = str(uuid.uuid4())
|
||||
|
||||
with patch("app.services.rag.RAGService") as mock_rag_class, \
|
||||
patch("app.core.config.get_settings", return_value=mock_settings), \
|
||||
patch("app.utils.pdf_parser.parse_pdf_by_page") as mock_parse_by_page, \
|
||||
patch("app.utils.chunking.TokenChunkingStrategy") as mock_chunk_class, \
|
||||
patch("app.utils.metadata.extract_metadata") as mock_meta, \
|
||||
patch("app.utils.pdf_extractor.extract_page_as_pdf") as mock_extract_page, \
|
||||
patch("app.services.rag.RAGService.list_documents") as mock_list_docs:
|
||||
|
||||
# RAGService instance
|
||||
mock_rag = MagicMock()
|
||||
mock_rag.ingest_document.return_value = doc_id
|
||||
mock_rag.list_documents.return_value = ([], 0, 0)
|
||||
mock_rag_class.return_value = mock_rag
|
||||
mock_rag_class.list_documents = MagicMock(return_value=([], 0, 0))
|
||||
|
||||
# parse_pdf_by_page returns 2 pages
|
||||
mock_parse_by_page.return_value = [
|
||||
(1, "Page 1 text content"),
|
||||
(2, "Page 2 text content"),
|
||||
]
|
||||
|
||||
# chunk_pages returns one chunk per page
|
||||
mock_chunker = MagicMock()
|
||||
mock_chunker.chunk_pages.return_value = [
|
||||
("Page 1 text content", 1),
|
||||
("Page 2 text content", 2),
|
||||
]
|
||||
mock_chunk_class.return_value = mock_chunker
|
||||
|
||||
# metadata
|
||||
mock_meta.return_value = [
|
||||
{"filename": "test.pdf", "chunk_index": 0, "page_number": 1},
|
||||
{"filename": "test.pdf", "chunk_index": 1, "page_number": 2},
|
||||
]
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/ingest",
|
||||
files={"file": ("test.pdf", io.BytesIO(b"%PDF-1.4"), "application/pdf")},
|
||||
)
|
||||
response = client.post(
|
||||
"/api/v1/ingest",
|
||||
files={"file": ("test.pdf", io.BytesIO(pdf_bytes), "application/pdf")},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["chunk_count"] == 2
|
||||
assert data["filename"] == "test.pdf"
|
||||
assert data["chunk_count"] >= 1
|
||||
|
||||
# Verify page-aware parsing was called
|
||||
mock_parse_by_page.assert_called_once()
|
||||
# Verify page_number metadata in real ChromaDB
|
||||
settings = _get_settings()
|
||||
collection = _get_collection(client, settings.chroma_db_path)
|
||||
all_data = collection.get(include=["metadatas"])
|
||||
page_numbers = [m.get("page_number") for m in all_data["metadatas"]]
|
||||
assert any(pn is not None for pn in page_numbers)
|
||||
|
||||
# Verify chunk_pages was used (not chunk)
|
||||
mock_chunker.chunk_pages.assert_called_once()
|
||||
mock_chunker.chunk.assert_not_called()
|
||||
# Verify page PDF files exist in chunk_dir
|
||||
chunk_dir = settings.document_chunk_path
|
||||
if os.path.isdir(chunk_dir):
|
||||
pdf_files = [f for f in os.listdir(chunk_dir) if f.endswith(".pdf")]
|
||||
assert len(pdf_files) >= 1
|
||||
|
||||
# Verify extract_page_as_pdf was called for each page
|
||||
assert mock_extract_page.call_count == 2
|
||||
def test_docx_upload_uses_old_pipeline(self, client, tmp_path):
|
||||
"""DOCX should produce chunks without page_number metadata."""
|
||||
docx_bytes = _create_real_docx(["DOCX paragraph one.", "DOCX paragraph two."])
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Test 2: DOCX upload uses old pipeline
|
||||
# ------------------------------------------------------------------ #
|
||||
def test_docx_upload_uses_old_pipeline(self, client, mock_settings):
|
||||
"""DOCX should use parse_docx → chunk → metadata with document_id only."""
|
||||
doc_id = str(uuid.uuid4())
|
||||
|
||||
with patch("app.services.rag.RAGService") as mock_rag_class, \
|
||||
patch("app.core.config.get_settings", return_value=mock_settings), \
|
||||
patch("app.utils.docx_parser.parse_docx") as mock_parse, \
|
||||
patch("app.utils.chunking.TokenChunkingStrategy") as mock_chunk_class, \
|
||||
patch("app.utils.metadata.extract_metadata") as mock_meta:
|
||||
|
||||
mock_rag = MagicMock()
|
||||
mock_rag.ingest_document.return_value = doc_id
|
||||
mock_rag.list_documents.return_value = ([], 0, 0)
|
||||
mock_rag_class.return_value = mock_rag
|
||||
|
||||
mock_parse.return_value = "DOCX text content"
|
||||
|
||||
mock_chunker = MagicMock()
|
||||
mock_chunker.chunk.return_value = ["chunk 1"]
|
||||
mock_chunk_class.return_value = mock_chunker
|
||||
|
||||
mock_meta.return_value = [{"filename": "test.docx", "chunk_index": 0}]
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/ingest",
|
||||
files={"file": ("test.docx", io.BytesIO(b"docx"), "application/vnd.openxmlformats-officedocument.wordprocessingml.document")},
|
||||
)
|
||||
response = client.post(
|
||||
"/api/v1/ingest",
|
||||
files={"file": ("test.docx", io.BytesIO(docx_bytes),
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document")},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["chunk_count"] == 1
|
||||
assert data["filename"] == "test.docx"
|
||||
assert data["chunk_count"] >= 1
|
||||
|
||||
# Verify old pipeline: parse_docx → chunk (not chunk_pages)
|
||||
mock_parse.assert_called_once()
|
||||
mock_chunker.chunk.assert_called_once()
|
||||
mock_chunker.chunk_pages.assert_not_called()
|
||||
# Verify no page_number in metadata
|
||||
settings = _get_settings()
|
||||
collection = _get_collection(client, settings.chroma_db_path)
|
||||
all_data = collection.get(include=["metadatas"])
|
||||
for meta in all_data["metadatas"]:
|
||||
if meta.get("filename") == "test.docx":
|
||||
assert meta.get("page_number") is None
|
||||
assert meta.get("chunk_file_path") is None
|
||||
|
||||
# Verify extract_metadata was called with document_id
|
||||
meta_call = mock_meta.call_args
|
||||
assert meta_call[1].get("document_id") is not None or \
|
||||
(len(meta_call[0]) > 3 and meta_call[0][3] is not None) or \
|
||||
"document_id" in str(meta_call)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Test 3: TXT upload uses old pipeline
|
||||
# ------------------------------------------------------------------ #
|
||||
def test_txt_upload_uses_old_pipeline(self, client, mock_settings):
|
||||
"""TXT should read file → chunk → metadata with document_id."""
|
||||
doc_id = str(uuid.uuid4())
|
||||
|
||||
with patch("app.services.rag.RAGService") as mock_rag_class, \
|
||||
patch("app.core.config.get_settings", return_value=mock_settings), \
|
||||
patch("app.utils.chunking.TokenChunkingStrategy") as mock_chunk_class, \
|
||||
patch("app.utils.metadata.extract_metadata") as mock_meta:
|
||||
|
||||
mock_rag = MagicMock()
|
||||
mock_rag.ingest_document.return_value = doc_id
|
||||
mock_rag.list_documents.return_value = ([], 0, 0)
|
||||
mock_rag_class.return_value = mock_rag
|
||||
|
||||
mock_chunker = MagicMock()
|
||||
mock_chunker.chunk.return_value = ["txt chunk"]
|
||||
mock_chunk_class.return_value = mock_chunker
|
||||
|
||||
mock_meta.return_value = [{"filename": "notes.txt", "chunk_index": 0}]
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/ingest",
|
||||
files={"file": ("notes.txt", io.BytesIO(b"Text content here"), "text/plain")},
|
||||
)
|
||||
def test_txt_upload_uses_old_pipeline(self, client, tmp_path):
|
||||
"""TXT should produce chunks without page_number metadata."""
|
||||
response = client.post(
|
||||
"/api/v1/ingest",
|
||||
files={"file": ("notes.txt", io.BytesIO(b"Text content with enough words to form a chunk."),
|
||||
"text/plain")},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["chunk_count"] == 1
|
||||
assert data["filename"] == "notes.txt"
|
||||
assert data["chunk_count"] >= 1
|
||||
|
||||
mock_chunker.chunk.assert_called_once()
|
||||
mock_chunker.chunk_pages.assert_not_called()
|
||||
settings = _get_settings()
|
||||
collection = _get_collection(client, settings.chroma_db_path)
|
||||
all_data = collection.get(include=["metadatas"])
|
||||
for meta in all_data["metadatas"]:
|
||||
if meta.get("filename") == "notes.txt":
|
||||
assert meta.get("page_number") is None
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Test 4: Same-filename replacement: existing document → deletion
|
||||
# ------------------------------------------------------------------ #
|
||||
def test_same_filename_replacement_deletes_old(self, client, mock_settings, tmp_path):
|
||||
"""Uploading file with same filename should delete old chunks and chunk PDFs."""
|
||||
doc_id = str(uuid.uuid4())
|
||||
old_doc_id = "old-doc-uuid-1234"
|
||||
chunk_dir = tmp_path / "document_chunk"
|
||||
chunk_dir.mkdir()
|
||||
old_pdf = chunk_dir / "test_page_3.pdf"
|
||||
old_pdf.write_text("old chunk pdf")
|
||||
def test_same_filename_replacement_deletes_old(self, client, tmp_path):
|
||||
"""Uploading file with same filename should replace old chunks in ChromaDB."""
|
||||
settings = _get_settings()
|
||||
pdf_bytes = _create_text_pdf(["First upload content"])
|
||||
|
||||
mock_settings.document_chunk_path = str(chunk_dir)
|
||||
# First upload
|
||||
response1 = client.post(
|
||||
"/api/v1/ingest",
|
||||
files={"file": ("test.pdf", io.BytesIO(pdf_bytes), "application/pdf")},
|
||||
)
|
||||
assert response1.status_code == 200
|
||||
first_doc_id = response1.json()["document_id"]
|
||||
|
||||
with patch("app.services.rag.RAGService") as mock_rag_class, \
|
||||
patch("app.core.config.get_settings", return_value=mock_settings), \
|
||||
patch("app.utils.pdf_parser.parse_pdf_by_page") as mock_parse_by_page, \
|
||||
patch("app.utils.chunking.TokenChunkingStrategy") as mock_chunk_class, \
|
||||
patch("app.utils.metadata.extract_metadata") as mock_meta, \
|
||||
patch("app.utils.pdf_extractor.extract_page_as_pdf"):
|
||||
# Verify first doc exists
|
||||
collection = _get_collection(client, settings.chroma_db_path)
|
||||
all_data = collection.get(include=["metadatas"])
|
||||
first_ids = [cid for cid in all_data["ids"] if cid.startswith(first_doc_id)]
|
||||
assert len(first_ids) >= 1
|
||||
|
||||
mock_rag = MagicMock()
|
||||
mock_rag.ingest_document.return_value = doc_id
|
||||
# list_documents returns existing document with same filename
|
||||
mock_rag.list_documents.return_value = (
|
||||
[{"document_id": old_doc_id, "filename": "test.pdf", "chunk_count": 3}],
|
||||
1, 3
|
||||
)
|
||||
mock_rag_class.return_value = mock_rag
|
||||
# Second upload with same filename
|
||||
pdf_bytes2 = _create_text_pdf(["Second upload content replacement"])
|
||||
response2 = client.post(
|
||||
"/api/v1/ingest",
|
||||
files={"file": ("test.pdf", io.BytesIO(pdf_bytes2), "application/pdf")},
|
||||
)
|
||||
assert response2.status_code == 200
|
||||
second_doc_id = response2.json()["document_id"]
|
||||
|
||||
# list_chunks returns chunk with file path
|
||||
mock_rag.list_chunks.return_value = [
|
||||
{"chunk_id": f"{old_doc_id}_0", "chunk_file_path": "test_page_3.pdf"},
|
||||
{"chunk_id": f"{old_doc_id}_1", "chunk_file_path": "test_page_4.pdf"},
|
||||
]
|
||||
# Verify old doc chunks are gone
|
||||
collection = _get_collection(client, settings.chroma_db_path)
|
||||
all_data = collection.get(include=["metadatas"])
|
||||
remaining_ids = all_data["ids"]
|
||||
assert not any(rid.startswith(first_doc_id) for rid in remaining_ids)
|
||||
|
||||
mock_parse_by_page.return_value = [(1, "New page text")]
|
||||
mock_chunker = MagicMock()
|
||||
mock_chunker.chunk_pages.return_value = [("New page text", 1)]
|
||||
mock_chunk_class.return_value = mock_chunker
|
||||
mock_meta.return_value = [{"filename": "test.pdf", "chunk_index": 0}]
|
||||
# Verify new doc chunks exist
|
||||
assert any(rid.startswith(second_doc_id) for rid in remaining_ids)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/ingest",
|
||||
files={"file": ("test.pdf", io.BytesIO(b"%PDF-1.4"), "application/pdf")},
|
||||
)
|
||||
def test_no_existing_document_no_deletion(self, client, tmp_path):
|
||||
"""Uploading new filename should succeed normally."""
|
||||
settings = _get_settings()
|
||||
pdf_bytes = _create_text_pdf(["Brand new document content"])
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/ingest",
|
||||
files={"file": ("newdoc.pdf", io.BytesIO(pdf_bytes), "application/pdf")},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["filename"] == "newdoc.pdf"
|
||||
|
||||
# Verify delete_document was called for old doc
|
||||
mock_rag.delete_document.assert_called_once_with(old_doc_id)
|
||||
collection = _get_collection(client, settings.chroma_db_path)
|
||||
all_data = collection.get(include=["metadatas"])
|
||||
assert len(all_data["ids"]) >= 1
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Test 5: Same-filename replacement: no existing document
|
||||
# ------------------------------------------------------------------ #
|
||||
def test_no_existing_document_no_deletion(self, client, mock_settings):
|
||||
"""Uploading new filename should NOT trigger any deletion."""
|
||||
doc_id = str(uuid.uuid4())
|
||||
|
||||
with patch("app.services.rag.RAGService") as mock_rag_class, \
|
||||
patch("app.core.config.get_settings", return_value=mock_settings), \
|
||||
patch("app.utils.pdf_parser.parse_pdf_by_page") as mock_parse_by_page, \
|
||||
patch("app.utils.chunking.TokenChunkingStrategy") as mock_chunk_class, \
|
||||
patch("app.utils.metadata.extract_metadata") as mock_meta, \
|
||||
patch("app.utils.pdf_extractor.extract_page_as_pdf"):
|
||||
|
||||
mock_rag = MagicMock()
|
||||
mock_rag.ingest_document.return_value = doc_id
|
||||
mock_rag.list_documents.return_value = ([], 0, 0)
|
||||
mock_rag_class.return_value = mock_rag
|
||||
|
||||
mock_parse_by_page.return_value = [(1, "Page text")]
|
||||
mock_chunker = MagicMock()
|
||||
mock_chunker.chunk_pages.return_value = [("Page text", 1)]
|
||||
mock_chunk_class.return_value = mock_chunker
|
||||
mock_meta.return_value = [{"filename": "newdoc.pdf", "chunk_index": 0}]
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/ingest",
|
||||
files={"file": ("newdoc.pdf", io.BytesIO(b"%PDF-1.4"), "application/pdf")},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify NO deletion happened
|
||||
mock_rag.delete_document.assert_not_called()
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Test 6: Empty PDF → 400 error
|
||||
# ------------------------------------------------------------------ #
|
||||
def test_empty_pdf_returns_400(self, client, mock_settings):
|
||||
def test_empty_pdf_returns_400(self, client, tmp_path):
|
||||
"""PDF with no extractable text should return 400."""
|
||||
with patch("app.core.config.get_settings", return_value=mock_settings), \
|
||||
patch("app.utils.pdf_parser.parse_pdf_by_page") as mock_parse_by_page, \
|
||||
patch("app.services.rag.RAGService") as mock_rag_class:
|
||||
from pypdf import PdfWriter
|
||||
writer = PdfWriter()
|
||||
writer.add_blank_page(width=200, height=200)
|
||||
buf = io.BytesIO()
|
||||
writer.write(buf)
|
||||
|
||||
mock_rag = MagicMock()
|
||||
mock_rag.list_documents.return_value = ([], 0, 0)
|
||||
mock_rag_class.return_value = mock_rag
|
||||
|
||||
# Empty PDF: no pages
|
||||
mock_parse_by_page.return_value = []
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/ingest",
|
||||
files={"file": ("empty.pdf", io.BytesIO(b"%PDF-1.4"), "application/pdf")},
|
||||
)
|
||||
response = client.post(
|
||||
"/api/v1/ingest",
|
||||
files={"file": ("empty.pdf", io.BytesIO(buf.getvalue()), "application/pdf")},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "empty" in response.json()["detail"].lower()
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Test 7: Page PDFs saved with correct naming
|
||||
# ------------------------------------------------------------------ #
|
||||
def test_page_pdf_naming_convention(self, client, mock_settings, tmp_path):
|
||||
"""Chunk PDFs should be named {stem}_page_{N}.pdf with relative paths in metadata."""
|
||||
doc_id = str(uuid.uuid4())
|
||||
chunk_dir = tmp_path / "document_chunk"
|
||||
chunk_dir.mkdir()
|
||||
mock_settings.document_chunk_path = str(chunk_dir)
|
||||
def test_page_pdf_naming_convention(self, client, tmp_path):
|
||||
"""Chunk PDFs should be named {stem}_page_{N}.pdf in document_chunk_path."""
|
||||
settings = _get_settings()
|
||||
pdf_bytes = _create_multipage_pdf([
|
||||
["Page one content"],
|
||||
["Page two content"],
|
||||
])
|
||||
|
||||
with patch("app.services.rag.RAGService") as mock_rag_class, \
|
||||
patch("app.core.config.get_settings", return_value=mock_settings), \
|
||||
patch("app.utils.pdf_parser.parse_pdf_by_page") as mock_parse_by_page, \
|
||||
patch("app.utils.chunking.TokenChunkingStrategy") as mock_chunk_class, \
|
||||
patch("app.utils.metadata.extract_metadata") as mock_meta, \
|
||||
patch("app.utils.pdf_extractor.extract_page_as_pdf") as mock_extract_page:
|
||||
|
||||
mock_rag = MagicMock()
|
||||
mock_rag.ingest_document.return_value = doc_id
|
||||
mock_rag.list_documents.return_value = ([], 0, 0)
|
||||
mock_rag_class.return_value = mock_rag
|
||||
|
||||
mock_parse_by_page.return_value = [
|
||||
(1, "Page 1"),
|
||||
(3, "Page 3"), # page 2 was empty, skipped
|
||||
]
|
||||
mock_chunker = MagicMock()
|
||||
mock_chunker.chunk_pages.return_value = [
|
||||
("Page 1", 1),
|
||||
("Page 3", 3),
|
||||
]
|
||||
mock_chunk_class.return_value = mock_chunker
|
||||
mock_meta.return_value = [
|
||||
{"filename": "NEC4 ACC.pdf", "chunk_index": 0},
|
||||
{"filename": "NEC4 ACC.pdf", "chunk_index": 1},
|
||||
]
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/ingest",
|
||||
files={"file": ("NEC4 ACC.pdf", io.BytesIO(b"%PDF-1.4"), "application/pdf")},
|
||||
)
|
||||
response = client.post(
|
||||
"/api/v1/ingest",
|
||||
files={"file": ("NEC4 ACC.pdf", io.BytesIO(pdf_bytes), "application/pdf")},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify extract_page_as_pdf called with correct naming
|
||||
calls = mock_extract_page.call_args_list
|
||||
assert len(calls) == 2
|
||||
chunk_dir = settings.document_chunk_path
|
||||
assert os.path.isdir(chunk_dir)
|
||||
|
||||
# First call: page 1 → "NEC4 ACC_page_1.pdf"
|
||||
output_path_1 = calls[0][0][2] # third positional arg = output_path
|
||||
assert output_path_1.endswith("NEC4 ACC_page_1.pdf")
|
||||
pdf_files = sorted(os.listdir(chunk_dir))
|
||||
assert len(pdf_files) >= 1
|
||||
|
||||
# Second call: page 3 → "NEC4 ACC_page_3.pdf"
|
||||
output_path_3 = calls[1][0][2]
|
||||
assert output_path_3.endswith("NEC4 ACC_page_3.pdf")
|
||||
# Each file should match naming convention: {stem}_page_{N}.pdf
|
||||
for fname in pdf_files:
|
||||
assert fname.startswith("NEC4 ACC_page_")
|
||||
assert fname.endswith(".pdf")
|
||||
|
||||
# Verify the directory was created
|
||||
assert os.path.isdir(str(chunk_dir))
|
||||
def test_pdf_metadata_includes_page_info(self, client, tmp_path):
|
||||
"""PDF metadata in ChromaDB should include page_number and chunk_file_path."""
|
||||
settings = _get_settings()
|
||||
pdf_bytes = _create_text_pdf(["Page content for metadata check"])
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Test 8: Metadata includes page_number and chunk_file_path for PDFs
|
||||
# ------------------------------------------------------------------ #
|
||||
def test_pdf_metadata_includes_page_info(self, client, mock_settings, tmp_path):
|
||||
"""PDF metadata should include page_number and chunk_file_path."""
|
||||
doc_id = str(uuid.uuid4())
|
||||
chunk_dir = tmp_path / "document_chunk"
|
||||
chunk_dir.mkdir()
|
||||
mock_settings.document_chunk_path = str(chunk_dir)
|
||||
|
||||
with patch("app.services.rag.RAGService") as mock_rag_class, \
|
||||
patch("app.core.config.get_settings", return_value=mock_settings), \
|
||||
patch("app.utils.pdf_parser.parse_pdf_by_page") as mock_parse_by_page, \
|
||||
patch("app.utils.chunking.TokenChunkingStrategy") as mock_chunk_class, \
|
||||
patch("app.utils.metadata.extract_metadata") as mock_meta, \
|
||||
patch("app.utils.pdf_extractor.extract_page_as_pdf"):
|
||||
|
||||
mock_rag = MagicMock()
|
||||
mock_rag.ingest_document.return_value = doc_id
|
||||
mock_rag.list_documents.return_value = ([], 0, 0)
|
||||
mock_rag_class.return_value = mock_rag
|
||||
|
||||
mock_parse_by_page.return_value = [(2, "Page 2 content")]
|
||||
mock_chunker = MagicMock()
|
||||
mock_chunker.chunk_pages.return_value = [("Page 2 content", 2)]
|
||||
mock_chunk_class.return_value = mock_chunker
|
||||
mock_meta.return_value = [
|
||||
{"filename": "doc.pdf", "chunk_index": 0, "page_number": 2, "chunk_file_path": "doc_page_2.pdf"},
|
||||
]
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/ingest",
|
||||
files={"file": ("doc.pdf", io.BytesIO(b"%PDF-1.4"), "application/pdf")},
|
||||
)
|
||||
response = client.post(
|
||||
"/api/v1/ingest",
|
||||
files={"file": ("doc.pdf", io.BytesIO(pdf_bytes), "application/pdf")},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify extract_metadata was called with page_numbers and chunk_file_paths
|
||||
meta_call_kwargs = mock_meta.call_args[1]
|
||||
assert "page_numbers" in meta_call_kwargs
|
||||
assert meta_call_kwargs["page_numbers"] == [2]
|
||||
assert "chunk_file_paths" in meta_call_kwargs
|
||||
assert meta_call_kwargs["chunk_file_paths"] == ["doc_page_2.pdf"]
|
||||
collection = _get_collection(client, settings.chroma_db_path)
|
||||
all_data = collection.get(include=["metadatas"])
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Test 9: Metadata does NOT include page_number for DOCX (None)
|
||||
# ------------------------------------------------------------------ #
|
||||
def test_docx_metadata_no_page_info(self, client, mock_settings):
|
||||
"""DOCX metadata should have page_number=None (no page_numbers passed)."""
|
||||
doc_id = str(uuid.uuid4())
|
||||
pdf_metas = [m for m in all_data["metadatas"] if m.get("filename") == "doc.pdf"]
|
||||
assert len(pdf_metas) >= 1
|
||||
|
||||
with patch("app.services.rag.RAGService") as mock_rag_class, \
|
||||
patch("app.core.config.get_settings", return_value=mock_settings), \
|
||||
patch("app.utils.docx_parser.parse_docx") as mock_parse, \
|
||||
patch("app.utils.chunking.TokenChunkingStrategy") as mock_chunk_class, \
|
||||
patch("app.utils.metadata.extract_metadata") as mock_meta:
|
||||
for meta in pdf_metas:
|
||||
assert meta.get("page_number") is not None
|
||||
assert meta.get("chunk_file_path") is not None
|
||||
assert "doc_page_" in meta["chunk_file_path"]
|
||||
|
||||
mock_rag = MagicMock()
|
||||
mock_rag.ingest_document.return_value = doc_id
|
||||
mock_rag.list_documents.return_value = ([], 0, 0)
|
||||
mock_rag_class.return_value = mock_rag
|
||||
def test_docx_metadata_no_page_info(self, client, tmp_path):
|
||||
"""DOCX metadata in ChromaDB should have page_number=None and chunk_file_path=None."""
|
||||
docx_bytes = _create_real_docx(["Content for DOCX metadata test"])
|
||||
|
||||
mock_parse.return_value = "DOCX content"
|
||||
mock_chunker = MagicMock()
|
||||
mock_chunker.chunk.return_value = ["chunk 1"]
|
||||
mock_chunk_class.return_value = mock_chunker
|
||||
mock_meta.return_value = [{"filename": "test.docx", "chunk_index": 0}]
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/ingest",
|
||||
files={"file": ("test.docx", io.BytesIO(b"docx"), "application/vnd.openxmlformats-officedocument.wordprocessingml.document")},
|
||||
)
|
||||
response = client.post(
|
||||
"/api/v1/ingest",
|
||||
files={"file": ("test.docx", io.BytesIO(docx_bytes),
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document")},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify extract_metadata was called WITHOUT page_numbers
|
||||
meta_call_kwargs = mock_meta.call_args[1]
|
||||
assert meta_call_kwargs.get("page_numbers") is None
|
||||
assert meta_call_kwargs.get("chunk_file_paths") is None
|
||||
settings = _get_settings()
|
||||
collection = _get_collection(client, settings.chroma_db_path)
|
||||
all_data = collection.get(include=["metadatas"])
|
||||
|
||||
docx_metas = [m for m in all_data["metadatas"] if m.get("filename") == "test.docx"]
|
||||
assert len(docx_metas) >= 1
|
||||
|
||||
for meta in docx_metas:
|
||||
assert "page_number" not in meta
|
||||
assert "chunk_file_path" not in meta
|
||||
|
||||
|
||||
def _get_settings():
|
||||
from app.core.config import get_settings
|
||||
return get_settings()
|
||||
|
|
|
|||
|
|
@ -1,97 +1,288 @@
|
|||
"""Phase 1 tests: RAG query endpoint.
|
||||
|
||||
Covers:
|
||||
- POST /api/v1/query question → retrieve → LLM → bullet-point response
|
||||
- POST /api/v1/query with SSE stream response
|
||||
- Strict RAG prompt enforcement (only use retrieved context)
|
||||
- Bullet-point response format
|
||||
- Source metadata inclusion
|
||||
- Source metadata inclusion in SSE events
|
||||
- 422 on missing question field
|
||||
|
||||
Uses TestClient + real ChromaDB + real SQLite. Only the LLMClient
|
||||
(external API) is mocked with controlled responses.
|
||||
"""
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
from app.core.sqlite_db import (
|
||||
_get_db,
|
||||
init_history_db,
|
||||
init_prompts_db,
|
||||
seed_default_profiles,
|
||||
)
|
||||
from app.routers.query import router
|
||||
|
||||
|
||||
# ── Helpers ─────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _MockLLMClient:
|
||||
"""Deterministic LLM mock returning controlled responses for each pipeline step.
|
||||
|
||||
Step sequence:
|
||||
1. decompose → JSON array of sub-questions
|
||||
2. filter → JSON object mapping sub-q indices to relevance scores
|
||||
3. generate → markdown bullet-point answer
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._call_count = 0
|
||||
|
||||
async def complete(self, prompt: str, temperature: float = 0.7,
|
||||
step_name: str = "LLM") -> str:
|
||||
self._call_count += 1
|
||||
if step_name == "QueryDecomposer":
|
||||
return json.dumps(["test sub-question"])
|
||||
if step_name == "RelevanceFilter":
|
||||
return json.dumps({"0": [8.0, 7.5]})
|
||||
return "- Bullet point answer\n- Another point"
|
||||
|
||||
|
||||
class _MockLLMClientNoChunks:
|
||||
"""LLM mock that returns decomposition but no relevant chunks survive filter."""
|
||||
|
||||
async def complete(self, prompt: str, temperature: float = 0.7,
|
||||
step_name: str = "LLM") -> str:
|
||||
if step_name == "QueryDecomposer":
|
||||
return json.dumps(["test"])
|
||||
if step_name == "RelevanceFilter":
|
||||
return json.dumps({"0": [2.0, 1.5]})
|
||||
return "I could not find any relevant information."
|
||||
|
||||
|
||||
class _DeterministicEmbedding:
|
||||
"""Lightweight embedding function that returns deterministic vectors.
|
||||
|
||||
Uses a simple hash-based approach to produce consistent 384-dim vectors
|
||||
for any input text. No external API calls.
|
||||
"""
|
||||
|
||||
def name(self) -> str:
|
||||
return "test_deterministic"
|
||||
|
||||
def __call__(self, input):
|
||||
return self._embed(input)
|
||||
|
||||
def embed_query(self, input):
|
||||
return self._embed(input)
|
||||
|
||||
@staticmethod
|
||||
def _embed(texts):
|
||||
vectors = []
|
||||
for text in texts:
|
||||
vec = [0.0] * 384
|
||||
for i, ch in enumerate(text[:384]):
|
||||
vec[i] = ord(ch) / 1000.0
|
||||
vectors.append(vec)
|
||||
return vectors
|
||||
|
||||
|
||||
def _seed_document(chroma_path: str):
|
||||
"""Insert a test document into real ChromaDB for retrieval."""
|
||||
import chromadb
|
||||
from app.core.database import get_or_create_collection
|
||||
|
||||
client = chromadb.PersistentClient(path=chroma_path)
|
||||
collection = get_or_create_collection(client, "documents",
|
||||
embedding_function=_DeterministicEmbedding())
|
||||
collection.add(
|
||||
documents=["chunk one with some text about testing",
|
||||
"chunk two with more details about testing"],
|
||||
metadatas=[
|
||||
{"filename": "test.pdf", "upload_date": "2025-01-01T00:00:00",
|
||||
"content_summary": "chunk one", "chunk_index": 0,
|
||||
"document_id": "test-doc-123"},
|
||||
{"filename": "test.pdf", "upload_date": "2025-01-01T00:00:00",
|
||||
"content_summary": "chunk two", "chunk_index": 1,
|
||||
"document_id": "test-doc-123"},
|
||||
],
|
||||
ids=["test-doc-123_0", "test-doc-123_1"],
|
||||
)
|
||||
|
||||
|
||||
# ── Fixture ─────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(tmp_path, monkeypatch):
|
||||
chroma_path = str(tmp_path / "chroma_db")
|
||||
prompts_path = str(tmp_path / "prompts.db")
|
||||
history_path = str(tmp_path / "history.db")
|
||||
|
||||
monkeypatch.setenv("CHROMA_DB_PATH", chroma_path)
|
||||
monkeypatch.setenv("PROMPTS_DB_PATH", prompts_path)
|
||||
monkeypatch.setenv("HISTORY_DB_PATH", history_path)
|
||||
monkeypatch.setenv("EMBEDDING_MODEL", "test-mock")
|
||||
monkeypatch.setenv("LLM_API_KEY", "test-key")
|
||||
|
||||
from app.core.config import get_settings
|
||||
get_settings.cache_clear()
|
||||
from app.core.dependencies import get_settings_cached
|
||||
get_settings_cached.cache_clear()
|
||||
|
||||
# Init real SQLite databases
|
||||
conn = _get_db(prompts_path)
|
||||
init_prompts_db(conn)
|
||||
seed_default_profiles(conn)
|
||||
conn.close()
|
||||
|
||||
hconn = _get_db(history_path)
|
||||
init_history_db(hconn)
|
||||
hconn.close()
|
||||
|
||||
# Seed a document into real ChromaDB
|
||||
_seed_document(chroma_path)
|
||||
|
||||
# Mock embedding function at database module level (external API)
|
||||
monkeypatch.setattr(
|
||||
"app.core.database.get_embedding_function_settings",
|
||||
lambda settings: _DeterministicEmbedding(),
|
||||
)
|
||||
|
||||
# Build a minimal FastAPI app with only the query router
|
||||
test_app = FastAPI()
|
||||
test_app.include_router(router, prefix="/api/v1")
|
||||
|
||||
yield TestClient(test_app)
|
||||
|
||||
get_settings_cached.cache_clear()
|
||||
get_settings.cache_clear()
|
||||
|
||||
|
||||
# ── Tests ───────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestQuery:
|
||||
|
||||
@pytest.fixture
|
||||
def client(self):
|
||||
from app.main import app
|
||||
return TestClient(app)
|
||||
|
||||
@pytest.mark.skip(reason="Deprecated: endpoint now returns SSE stream, not JSON")
|
||||
def test_query_returns_bullets(self, client):
|
||||
"""Should return bullet-point answer with source metadata."""
|
||||
with patch("app.routers.query.QueryDecomposer") as mock_decomposer_class:
|
||||
mock_decomposer = MagicMock()
|
||||
mock_decomposer.decompose = AsyncMock(return_value=["test", "keywords"])
|
||||
mock_decomposer_class.return_value = mock_decomposer
|
||||
|
||||
with patch("app.routers.query.RAGService") as mock_rag_class:
|
||||
mock_rag = MagicMock()
|
||||
mock_rag.retrieve.return_value = [
|
||||
("chunk one", {"filename": "test.pdf"}, 0.1),
|
||||
("chunk two", {"filename": "test.pdf"}, 0.2),
|
||||
]
|
||||
mock_rag.generate_response = AsyncMock(return_value="- Bullet point answer\n- Another point")
|
||||
mock_rag_class.return_value = mock_rag
|
||||
|
||||
with patch("app.routers.query.RelevanceFilter") as mock_filter_class:
|
||||
mock_filter = MagicMock()
|
||||
mock_filter.filter = AsyncMock(return_value=[
|
||||
("chunk one", {"filename": "test.pdf"}),
|
||||
("chunk two", {"filename": "test.pdf"}),
|
||||
])
|
||||
mock_filter_class.return_value = mock_filter
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/query",
|
||||
json={"question": "What is this about?"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "extracted_questions" in data
|
||||
assert data["extracted_questions"] == ["test", "keywords"]
|
||||
assert "answer" in data
|
||||
assert "- Bullet point answer" in data["answer"]
|
||||
assert "sources" in data
|
||||
assert len(data["sources"]) == 2
|
||||
assert data["sources"][0]["filename"] == "test.pdf"
|
||||
# Left as skip — SSE tests below cover this functionality.
|
||||
|
||||
@pytest.mark.skip(reason="Deprecated: endpoint now returns SSE stream, not JSON")
|
||||
def test_query_no_relevant_chunks(self, client):
|
||||
"""Should handle case when no relevant chunks found."""
|
||||
with patch("app.routers.query.QueryDecomposer") as mock_decomposer_class:
|
||||
mock_decomposer = MagicMock()
|
||||
mock_decomposer.decompose = AsyncMock(return_value=["test"])
|
||||
mock_decomposer_class.return_value = mock_decomposer
|
||||
|
||||
with patch("app.routers.query.RAGService") as mock_rag_class:
|
||||
mock_rag = MagicMock()
|
||||
mock_rag.retrieve.return_value = [
|
||||
("chunk one", {"filename": "test.pdf"}, 0.1),
|
||||
]
|
||||
mock_rag.generate_response = AsyncMock(return_value="I could not find any relevant information.")
|
||||
mock_rag_class.return_value = mock_rag
|
||||
|
||||
with patch("app.routers.query.RelevanceFilter") as mock_filter_class:
|
||||
mock_filter = MagicMock()
|
||||
mock_filter.filter = AsyncMock(return_value=[])
|
||||
mock_filter_class.return_value = mock_filter
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/query",
|
||||
json={"question": "What is this about?"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["extracted_questions"] == ["test"]
|
||||
assert "could not find" in data["answer"].lower()
|
||||
assert data["sources"] == []
|
||||
# Left as skip — SSE tests below cover this functionality.
|
||||
|
||||
def test_query_no_question(self, client):
|
||||
"""Should reject request without question."""
|
||||
response = client.post("/api/v1/query", json={})
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
def test_query_sse_stream_with_mock_llm(self, client, monkeypatch):
|
||||
"""Should return SSE stream with decomposed → retrieving → filtering → generating → completed phases.
|
||||
|
||||
Uses real ChromaDB (seeded with a document) + real RAGService.
|
||||
Only LLMClient is mocked (external API).
|
||||
"""
|
||||
monkeypatch.setattr(
|
||||
"app.routers.query.LLMClient",
|
||||
lambda settings: _MockLLMClient(),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/query",
|
||||
json={"question": "What is this about?"},
|
||||
headers={"Accept": "text/event-stream"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
|
||||
|
||||
# Parse SSE events from the response
|
||||
events = []
|
||||
for line in response.iter_lines():
|
||||
if line.startswith("data: "):
|
||||
data = json.loads(line[6:])
|
||||
events.append(data)
|
||||
|
||||
phases = [e["phase"] for e in events]
|
||||
assert "decomposed" in phases
|
||||
assert "retrieving" in phases
|
||||
assert "filtering" in phases
|
||||
assert "generating" in phases
|
||||
assert "completed" in phases
|
||||
|
||||
# Find completed event
|
||||
completed = next(e for e in events if e["phase"] == "completed")
|
||||
assert "- Bullet point answer" in completed["answer"]
|
||||
assert len(completed["sources"]) > 0
|
||||
assert completed["sources"][0]["filename"] == "test.pdf"
|
||||
|
||||
def test_query_sse_no_relevant_chunks_after_filter(self, client, monkeypatch):
|
||||
"""Should return 'no relevant information' when filter eliminates all chunks.
|
||||
|
||||
Uses real ChromaDB (seeded) + real RAGService. LLM mock returns
|
||||
low relevance scores so all chunks are filtered out.
|
||||
"""
|
||||
monkeypatch.setattr(
|
||||
"app.routers.query.LLMClient",
|
||||
lambda settings: _MockLLMClientNoChunks(),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/query",
|
||||
json={"question": "Something completely unrelated"},
|
||||
headers={"Accept": "text/event-stream"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
events = []
|
||||
for line in response.iter_lines():
|
||||
if line.startswith("data: "):
|
||||
data = json.loads(line[6:])
|
||||
events.append(data)
|
||||
|
||||
completed = next(e for e in events if e["phase"] == "completed")
|
||||
assert "could not find" in completed["answer"].lower()
|
||||
assert completed["sources"] == []
|
||||
|
||||
def test_query_empty_question_returns_400(self, client, monkeypatch):
|
||||
"""Should reject empty/whitespace-only question with 400."""
|
||||
monkeypatch.setattr(
|
||||
"app.routers.query.LLMClient",
|
||||
lambda settings: _MockLLMClient(),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/query",
|
||||
json={"question": " "},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_query_sse_decomposed_event_has_questions(self, client, monkeypatch):
|
||||
"""Decomposed SSE event should contain extracted sub-questions from LLM."""
|
||||
monkeypatch.setattr(
|
||||
"app.routers.query.LLMClient",
|
||||
lambda settings: _MockLLMClient(),
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/query",
|
||||
json={"question": "What is testing?"},
|
||||
)
|
||||
|
||||
events = []
|
||||
for line in response.iter_lines():
|
||||
if line.startswith("data: "):
|
||||
data = json.loads(line[6:])
|
||||
events.append(data)
|
||||
|
||||
decomposed = next(e for e in events if e["phase"] == "decomposed")
|
||||
assert "extracted_questions" in decomposed
|
||||
assert isinstance(decomposed["extracted_questions"], list)
|
||||
assert len(decomposed["extracted_questions"]) > 0
|
||||
|
|
|
|||
|
|
@ -5,23 +5,53 @@ Covers:
|
|||
- Retrieval with query keywords
|
||||
- Response generation with strict RAG prompt
|
||||
- Metadata handling per chunk
|
||||
|
||||
All tests use real ChromaDB via tmp_path. Only the LLM client (external API)
|
||||
is mocked where needed.
|
||||
"""
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
import chromadb
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
|
||||
class _MockLLM:
|
||||
"""Minimal mock for the external LLM API."""
|
||||
|
||||
def __init__(self, response: str = "mock answer"):
|
||||
self._response = response
|
||||
self.last_prompt: str | None = None
|
||||
|
||||
async def complete(self, prompt: str, temperature: float = 0.7, step_name: str = "LLM") -> str: # type: ignore[override]
|
||||
self.last_prompt = prompt
|
||||
return self._response
|
||||
|
||||
|
||||
def _setup_chroma(tmp_path, monkeypatch, collection_name: str = "documents"):
|
||||
"""Create an isolated real ChromaDB client + collection for a test.
|
||||
|
||||
Returns (client, collection, service_kwargs) where service_kwargs can be
|
||||
unpacked into RAGService().
|
||||
"""
|
||||
from app.core.config import get_settings
|
||||
|
||||
monkeypatch.setenv("CHROMA_DB_PATH", str(tmp_path / "test_chroma"))
|
||||
get_settings.cache_clear()
|
||||
|
||||
client = chromadb.PersistentClient(path=str(tmp_path / "test_chroma"))
|
||||
collection = client.get_or_create_collection(name=collection_name)
|
||||
return client, collection
|
||||
|
||||
|
||||
class TestRAGService:
|
||||
"""RAG retrieval and prompt logic tests."""
|
||||
|
||||
def test_ingest_document_adds_chunks(self):
|
||||
"""Should add chunks with metadata to ChromaDB collection."""
|
||||
def test_ingest_document_adds_chunks(self, tmp_path, monkeypatch):
|
||||
"""Should add chunks with metadata to real ChromaDB collection."""
|
||||
from app.services.rag import RAGService
|
||||
|
||||
mock_collection = MagicMock()
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection.return_value = mock_collection
|
||||
client, collection = _setup_chroma(tmp_path, monkeypatch)
|
||||
|
||||
service = RAGService(chroma_client=mock_client)
|
||||
service = RAGService(chroma_client=client)
|
||||
|
||||
chunks = ["chunk one", "chunk two"]
|
||||
metadata = [
|
||||
|
|
@ -29,170 +59,152 @@ class TestRAGService:
|
|||
{"filename": "test.txt", "upload_date": "2024-01-01", "content_summary": "summary 2", "chunk_index": 1},
|
||||
]
|
||||
|
||||
service.ingest_document("test.txt", chunks, metadata)
|
||||
doc_id = service.ingest_document("test.txt", chunks, metadata)
|
||||
|
||||
mock_client.get_or_create_collection.assert_called_once_with(name="documents")
|
||||
mock_collection.add.assert_called_once()
|
||||
call_args = mock_collection.add.call_args[1]
|
||||
assert len(call_args["documents"]) == 2
|
||||
assert call_args["documents"] == chunks
|
||||
assert len(call_args["metadatas"]) == 2
|
||||
assert call_args["metadatas"] == metadata
|
||||
assert len(call_args["ids"]) == 2
|
||||
assert doc_id != ""
|
||||
assert collection.count() == 2
|
||||
|
||||
def test_ingest_document_empty_chunks(self):
|
||||
"""Should not call ChromaDB when chunks list is empty."""
|
||||
stored = collection.get(include=["documents", "metadatas"])
|
||||
assert len(stored["documents"]) == 2
|
||||
assert stored["documents"] == chunks
|
||||
for i, meta in enumerate(stored["metadatas"]):
|
||||
assert meta["filename"] == "test.txt"
|
||||
assert meta["content_summary"] == f"summary {i + 1}"
|
||||
|
||||
def test_ingest_document_empty_chunks(self, tmp_path, monkeypatch):
|
||||
"""Should not add anything when chunks list is empty."""
|
||||
from app.services.rag import RAGService
|
||||
|
||||
mock_collection = MagicMock()
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection.return_value = mock_collection
|
||||
client, collection = _setup_chroma(tmp_path, monkeypatch)
|
||||
|
||||
service = RAGService(chroma_client=mock_client)
|
||||
service.ingest_document("test.txt", [], [])
|
||||
service = RAGService(chroma_client=client)
|
||||
result = service.ingest_document("test.txt", [], [])
|
||||
|
||||
mock_collection.add.assert_not_called()
|
||||
assert result == ""
|
||||
assert collection.count() == 0
|
||||
|
||||
def test_retrieve_returns_chunks(self):
|
||||
"""Should retrieve chunks and metadata from ChromaDB."""
|
||||
def test_retrieve_returns_chunks(self, tmp_path, monkeypatch):
|
||||
"""Should retrieve chunks from real ChromaDB by query."""
|
||||
from app.services.rag import RAGService
|
||||
|
||||
mock_collection = MagicMock()
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection.return_value = mock_collection
|
||||
client, collection = _setup_chroma(tmp_path, monkeypatch)
|
||||
|
||||
mock_collection.query.return_value = {
|
||||
"documents": [["chunk one", "chunk two"]],
|
||||
"metadatas": [[{"filename": "test.txt"}, {"filename": "test.txt"}]],
|
||||
"distances": [[0.1, 0.2]],
|
||||
}
|
||||
collection.add(
|
||||
documents=["chunk one about apples", "chunk two about bananas"],
|
||||
metadatas=[
|
||||
{"filename": "test.txt"},
|
||||
{"filename": "test.txt"},
|
||||
],
|
||||
ids=["id1", "id2"],
|
||||
)
|
||||
|
||||
service = RAGService(chroma_client=mock_client)
|
||||
results = service.retrieve(["query", "keywords"], n_results=5)
|
||||
service = RAGService(chroma_client=client)
|
||||
results = service.retrieve(["apples"], n_results=5)
|
||||
|
||||
mock_collection.query.assert_called_once()
|
||||
call_args = mock_collection.query.call_args[1]
|
||||
assert call_args["n_results"] == 5
|
||||
assert len(results) == 2
|
||||
assert results[0] == ("chunk one", {"filename": "test.txt"}, 0.1)
|
||||
assert results[1] == ("chunk two", {"filename": "test.txt"}, 0.2)
|
||||
assert len(results) >= 1
|
||||
assert "apples" in results[0][0]
|
||||
assert results[0][1]["filename"] == "test.txt"
|
||||
assert isinstance(results[0][2], float)
|
||||
|
||||
def test_retrieve_no_results(self):
|
||||
"""Should return empty list when no results found."""
|
||||
def test_retrieve_no_results(self, tmp_path, monkeypatch):
|
||||
"""Should return empty list when querying an empty collection."""
|
||||
from app.services.rag import RAGService
|
||||
|
||||
mock_collection = MagicMock()
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection.return_value = mock_collection
|
||||
client, _ = _setup_chroma(tmp_path, monkeypatch)
|
||||
|
||||
mock_collection.query.return_value = {
|
||||
"documents": [[]],
|
||||
"metadatas": [[]],
|
||||
"distances": [[]],
|
||||
}
|
||||
|
||||
service = RAGService(chroma_client=mock_client)
|
||||
results = service.retrieve(["query"])
|
||||
service = RAGService(chroma_client=client)
|
||||
results = service.retrieve(["nonexistent query terms xyz"])
|
||||
|
||||
assert results == []
|
||||
|
||||
async def test_generate_response_calls_llm(self, mock_prompt_service):
|
||||
async def test_generate_response_calls_llm(self, tmp_path, monkeypatch, mock_prompt_service):
|
||||
"""Should call LLM with strict RAG prompt."""
|
||||
from app.services.rag import RAGService
|
||||
|
||||
mock_collection = MagicMock()
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection.return_value = mock_collection
|
||||
client, _ = _setup_chroma(tmp_path, monkeypatch)
|
||||
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.complete = AsyncMock(return_value="- Bullet point answer")
|
||||
mock_llm = _MockLLM(response="- Bullet point answer")
|
||||
|
||||
service = RAGService(chroma_client=mock_client, llm_client=mock_llm, prompt_service=mock_prompt_service)
|
||||
service = RAGService(
|
||||
chroma_client=client,
|
||||
llm_client=mock_llm,
|
||||
prompt_service=mock_prompt_service,
|
||||
)
|
||||
|
||||
chunks = ["relevant chunk"]
|
||||
metadata = [{"filename": "test.txt", "content_summary": "summary"}]
|
||||
|
||||
answer, gen_prompt = await service.generate_response("What is this?", chunks, metadata)
|
||||
|
||||
mock_llm.complete.assert_called_once()
|
||||
sent_prompt = mock_llm.complete.call_args[1]["prompt"]
|
||||
assert "What is this?" in sent_prompt
|
||||
assert "relevant chunk" in sent_prompt
|
||||
assert "test.txt" in sent_prompt
|
||||
assert "only these document chunks" in sent_prompt.lower()
|
||||
assert mock_llm.last_prompt is not None
|
||||
assert "What is this?" in mock_llm.last_prompt
|
||||
assert "relevant chunk" in mock_llm.last_prompt
|
||||
assert "test.txt" in mock_llm.last_prompt
|
||||
assert "only these document chunks" in mock_llm.last_prompt.lower()
|
||||
assert answer == "- Bullet point answer"
|
||||
assert gen_prompt == sent_prompt
|
||||
assert gen_prompt == mock_llm.last_prompt
|
||||
|
||||
async def test_generate_response_no_chunks(self):
|
||||
async def test_generate_response_no_chunks(self, tmp_path, monkeypatch):
|
||||
"""Should return fallback message when no chunks provided."""
|
||||
from app.services.rag import RAGService
|
||||
|
||||
mock_collection = MagicMock()
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection.return_value = mock_collection
|
||||
client, _ = _setup_chroma(tmp_path, monkeypatch)
|
||||
mock_llm = _MockLLM()
|
||||
|
||||
service = RAGService(chroma_client=mock_client, llm_client=MagicMock())
|
||||
service = RAGService(chroma_client=client, llm_client=mock_llm)
|
||||
|
||||
answer, gen_prompt = await service.generate_response("What is this?", [], [])
|
||||
|
||||
assert "no relevant" in answer.lower() or "could not find" in answer.lower()
|
||||
assert gen_prompt == ""
|
||||
|
||||
def test_retrieve_per_subquestion_returns_per_query(self):
|
||||
def test_retrieve_per_subquestion_returns_per_query(self, tmp_path, monkeypatch):
|
||||
"""Each sub-question retrieves its own chunks independently."""
|
||||
from app.services.rag import RAGService
|
||||
|
||||
mock_collection = MagicMock()
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection.return_value = mock_collection
|
||||
client, collection = _setup_chroma(tmp_path, monkeypatch)
|
||||
|
||||
mock_collection.query.side_effect = [
|
||||
{
|
||||
"documents": [["chunk A1", "chunk A2"]],
|
||||
"metadatas": [[{"filename": "a.pdf"}, {"filename": "a.pdf"}]],
|
||||
"distances": [[0.1, 0.2]],
|
||||
},
|
||||
{
|
||||
"documents": [["chunk B1"]],
|
||||
"metadatas": [[{"filename": "b.pdf"}]],
|
||||
"distances": [[0.3]],
|
||||
},
|
||||
]
|
||||
collection.add(
|
||||
documents=["Alpha content about apples", "Alpha extra about apples"],
|
||||
metadatas=[{"filename": "a.pdf"}, {"filename": "a2.pdf"}],
|
||||
ids=["a1", "a2"],
|
||||
)
|
||||
collection.add(
|
||||
documents=["Beta content about bananas"],
|
||||
metadatas=[{"filename": "b.pdf"}],
|
||||
ids=["b1"],
|
||||
)
|
||||
|
||||
service = RAGService(chroma_client=mock_client)
|
||||
results = service.retrieve_per_subquestion(["query A", "query B"], n_results=5)
|
||||
service = RAGService(chroma_client=client)
|
||||
results = service.retrieve_per_subquestion(["apples", "bananas"], n_results=5)
|
||||
|
||||
assert len(results) == 2
|
||||
assert results[0][0] == "query A"
|
||||
assert len(results[0][1]) == 2
|
||||
assert results[1][0] == "query B"
|
||||
assert len(results[1][1]) == 1
|
||||
assert mock_collection.query.call_count == 2
|
||||
assert results[0][0] == "apples"
|
||||
assert len(results[0][1]) >= 1
|
||||
assert results[1][0] == "bananas"
|
||||
assert len(results[1][1]) >= 1
|
||||
|
||||
def test_retrieve_per_subquestion_empty_list(self):
|
||||
def test_retrieve_per_subquestion_empty_list(self, tmp_path, monkeypatch):
|
||||
"""Empty sub_questions list returns empty list without querying."""
|
||||
from app.services.rag import RAGService
|
||||
|
||||
mock_collection = MagicMock()
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection.return_value = mock_collection
|
||||
client, _ = _setup_chroma(tmp_path, monkeypatch)
|
||||
|
||||
service = RAGService(chroma_client=mock_client)
|
||||
service = RAGService(chroma_client=client)
|
||||
results = service.retrieve_per_subquestion([], n_results=5)
|
||||
|
||||
assert results == []
|
||||
mock_collection.query.assert_not_called()
|
||||
|
||||
async def test_generate_response_per_subquestion_calls_llm(self, mock_prompt_service):
|
||||
async def test_generate_response_per_subquestion_calls_llm(self, tmp_path, monkeypatch, mock_prompt_service):
|
||||
"""LLM should receive sub-question-organized context."""
|
||||
from app.services.rag import RAGService
|
||||
|
||||
mock_collection = MagicMock()
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection.return_value = mock_collection
|
||||
client, _ = _setup_chroma(tmp_path, monkeypatch)
|
||||
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.complete = AsyncMock(return_value="## Sub-question 1: Q?\n- Answer")
|
||||
mock_llm = _MockLLM(response="## Sub-question 1: Q?\n- Answer")
|
||||
|
||||
service = RAGService(
|
||||
chroma_client=mock_client,
|
||||
chroma_client=client,
|
||||
llm_client=mock_llm,
|
||||
prompt_service=mock_prompt_service,
|
||||
)
|
||||
|
|
@ -203,22 +215,20 @@ class TestRAGService:
|
|||
[[{"filename": "f.txt", "content_summary": "sum"}]],
|
||||
)
|
||||
|
||||
mock_llm.complete.assert_called_once()
|
||||
sent_prompt = mock_llm.complete.call_args[1]["prompt"]
|
||||
assert "chunk data" in sent_prompt
|
||||
assert "Sub-question 0" in sent_prompt
|
||||
assert mock_llm.last_prompt is not None
|
||||
assert "chunk data" in mock_llm.last_prompt
|
||||
assert answer == "## Sub-question 1: Q?\n- Answer"
|
||||
assert len(grouped_sources) == 1
|
||||
assert grouped_sources[0][0]["filename"] == "f.txt"
|
||||
|
||||
async def test_generate_response_per_subquestion_no_subquestions(self):
|
||||
async def test_generate_response_per_subquestion_no_subquestions(self, tmp_path, monkeypatch):
|
||||
"""Should return fallback when sub_questions is empty."""
|
||||
from app.services.rag import RAGService
|
||||
|
||||
mock_collection = MagicMock()
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection.return_value = mock_collection
|
||||
client, _ = _setup_chroma(tmp_path, monkeypatch)
|
||||
mock_llm = _MockLLM()
|
||||
|
||||
service = RAGService(chroma_client=mock_client, llm_client=MagicMock())
|
||||
service = RAGService(chroma_client=client, llm_client=mock_llm)
|
||||
|
||||
answer, gen_prompt, grouped_sources = await service.generate_response_per_subquestion(
|
||||
[], [], [],
|
||||
|
|
@ -228,14 +238,14 @@ class TestRAGService:
|
|||
assert gen_prompt == ""
|
||||
assert grouped_sources == []
|
||||
|
||||
async def test_generate_response_per_subquestion_no_chunks(self):
|
||||
async def test_generate_response_per_subquestion_no_chunks(self, tmp_path, monkeypatch):
|
||||
"""Should return fallback when all chunk lists are empty."""
|
||||
from app.services.rag import RAGService
|
||||
|
||||
mock_collection = MagicMock()
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection.return_value = mock_collection
|
||||
client, _ = _setup_chroma(tmp_path, monkeypatch)
|
||||
mock_llm = _MockLLM()
|
||||
|
||||
service = RAGService(chroma_client=mock_client, llm_client=MagicMock())
|
||||
service = RAGService(chroma_client=client, llm_client=mock_llm)
|
||||
|
||||
answer, gen_prompt, grouped_sources = await service.generate_response_per_subquestion(
|
||||
["Q?"], [[]], [[]],
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
"""Tests for Phase 3 history router — HTTP endpoint integration tests.
|
||||
|
||||
Uses a mock HistoryService injected via FastAPI dependency_overrides.
|
||||
TestClient hits a minimal FastAPI app wired with an inline history router
|
||||
that mirrors the expected real router contract.
|
||||
Uses real sqlite3 with tmp_path and real HistoryService. TestClient hits a
|
||||
minimal FastAPI app wired with an inline history router that calls real
|
||||
HistoryService methods (list, get, delete, clear_all, get_stats) backed by a
|
||||
temporary SQLite database. No mocks on the DB or service layer.
|
||||
|
||||
Coverage:
|
||||
- GET /api/v1/history — paginated listing (limit/offset)
|
||||
|
|
@ -19,14 +20,27 @@ Coverage:
|
|||
- 404 on non-existent query_id, 422 on invalid limit/offset
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Query
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.core.sqlite_db import _get_db, init_history_db
|
||||
from app.services.history_service import HistoryService
|
||||
|
||||
# ── Sample data ──────────────────────────────────────────────────────────
|
||||
|
||||
_SAMPLE_DETAIL: dict = {
|
||||
"id": 1,
|
||||
_CHUNKS_RETRIEVED = [
|
||||
{"chunk_id": "c1", "text": "Budget 2024 is $50M", "score": 0.95, "source": "budget.pdf"},
|
||||
{"chunk_id": "c2", "text": "Previous year was $45M", "score": 0.80, "source": "budget.pdf"},
|
||||
]
|
||||
|
||||
_CHUNKS_FILTERED = [
|
||||
{"chunk_id": "c1", "text": "Budget 2024 is $50M", "score": 9, "source": "budget.pdf"},
|
||||
]
|
||||
|
||||
_SAMPLE_RECORD: dict = {
|
||||
"input_text": "What is the budget for 2024?",
|
||||
"extracted_questions": '["What is the budget allocation?", "How does 2024 compare?"]',
|
||||
"decompose_prompt": "Break down: {question}",
|
||||
|
|
@ -34,22 +48,16 @@ _SAMPLE_DETAIL: dict = {
|
|||
"generate_prompt": "Generate: {question} {context}",
|
||||
"decomposer_time_ms": 120,
|
||||
"retriever_time_ms": 300,
|
||||
"chunks_retrieved": [
|
||||
{"chunk_id": "c1", "text": "Budget 2024 is $50M", "score": 0.95, "source": "budget.pdf"},
|
||||
{"chunk_id": "c2", "text": "Previous year was $45M", "score": 0.80, "source": "budget.pdf"},
|
||||
],
|
||||
"chunks_retrieved": json.dumps(_CHUNKS_RETRIEVED),
|
||||
"chunks_retrieved_count": 2,
|
||||
"filter_time_ms": 80,
|
||||
"chunks_filtered": [
|
||||
{"chunk_id": "c1", "text": "Budget 2024 is $50M", "score": 9, "source": "budget.pdf"},
|
||||
],
|
||||
"chunks_filtered": json.dumps(_CHUNKS_FILTERED),
|
||||
"chunks_filtered_count": 1,
|
||||
"generator_time_ms": 500,
|
||||
"total_time_ms": 1000,
|
||||
"final_answer": "- The 2024 budget is $50M [budget.pdf]",
|
||||
"sources": '["budget.pdf"]',
|
||||
"profile_used": "A",
|
||||
"created_at": "2025-01-15T10:30:00",
|
||||
}
|
||||
|
||||
_SUMMARY_KEYS = {
|
||||
|
|
@ -62,64 +70,12 @@ _SUMMARY_KEYS = {
|
|||
"created_at",
|
||||
}
|
||||
|
||||
_SAMPLE_STATS: dict = {
|
||||
"total_queries": 10,
|
||||
"avg_total_time_ms": 850.5,
|
||||
"avg_chunks_retrieved": 5.2,
|
||||
"avg_chunks_filtered": 3.1,
|
||||
"profile_distribution": {"A": 7, "B": 3},
|
||||
}
|
||||
|
||||
|
||||
# ── Mock service ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class MockHistoryService:
|
||||
"""In-memory mock implementing the expected HistoryService interface."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._records: dict[int, dict] = {1: dict(_SAMPLE_DETAIL)}
|
||||
self._next_id: int = 2
|
||||
|
||||
def list_queries(self, limit: int = 50, offset: int = 0) -> dict:
|
||||
items = sorted(self._records.values(), key=lambda r: r["id"], reverse=True)
|
||||
page = items[offset : offset + limit]
|
||||
summaries = [
|
||||
{k: r[k] for k in _SUMMARY_KEYS}
|
||||
for r in page
|
||||
]
|
||||
return {
|
||||
"queries": summaries,
|
||||
"total": len(items),
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
}
|
||||
|
||||
def get_query(self, query_id: int) -> dict | None:
|
||||
return self._records.get(query_id)
|
||||
|
||||
def delete_query(self, query_id: int) -> bool:
|
||||
if query_id in self._records:
|
||||
del self._records[query_id]
|
||||
return True
|
||||
return False
|
||||
|
||||
def clear_all(self) -> int:
|
||||
count = len(self._records)
|
||||
self._records.clear()
|
||||
return count
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
return dict(_SAMPLE_STATS)
|
||||
|
||||
def insert(self, **overrides: object) -> int:
|
||||
"""Helper: insert a record and return its id."""
|
||||
record = dict(_SAMPLE_DETAIL)
|
||||
record.update(overrides)
|
||||
record["id"] = self._next_id
|
||||
self._records[self._next_id] = record
|
||||
self._next_id += 1
|
||||
return record["id"]
|
||||
def _make_record(**overrides: object) -> dict:
|
||||
"""Create a sample record dict suitable for HistoryService.record()."""
|
||||
base = dict(_SAMPLE_RECORD)
|
||||
base.update(overrides)
|
||||
return base
|
||||
|
||||
|
||||
# ── Dependency & inline router ───────────────────────────────────────────
|
||||
|
|
@ -139,7 +95,14 @@ def list_history(
|
|||
offset: int = Query(0, ge=0),
|
||||
svc=Depends(_get_history_service),
|
||||
):
|
||||
return svc.list_queries(limit=limit, offset=offset)
|
||||
queries = svc.list(limit=limit, offset=offset)
|
||||
stats = svc.get_stats()
|
||||
return {
|
||||
"queries": queries,
|
||||
"total": stats["total_queries"],
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
}
|
||||
|
||||
|
||||
@_router.get("/stats")
|
||||
|
|
@ -149,7 +112,7 @@ def get_stats(svc=Depends(_get_history_service)):
|
|||
|
||||
@_router.get("/{query_id}")
|
||||
def get_history_detail(query_id: int, svc=Depends(_get_history_service)):
|
||||
record = svc.get_query(query_id)
|
||||
record = svc.get(query_id)
|
||||
if record is None:
|
||||
raise HTTPException(status_code=404, detail="Query not found")
|
||||
return record
|
||||
|
|
@ -157,7 +120,7 @@ def get_history_detail(query_id: int, svc=Depends(_get_history_service)):
|
|||
|
||||
@_router.delete("/{query_id}")
|
||||
def delete_history(query_id: int, svc=Depends(_get_history_service)):
|
||||
deleted = svc.delete_query(query_id)
|
||||
deleted = svc.delete(query_id)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="Query not found")
|
||||
return {"status": "ok", "deleted_id": query_id}
|
||||
|
|
@ -173,17 +136,38 @@ def clear_all_history(svc=Depends(_get_history_service)):
|
|||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_svc() -> MockHistoryService:
|
||||
return MockHistoryService()
|
||||
def svc(tmp_path, monkeypatch):
|
||||
"""Real HistoryService backed by a temporary SQLite database."""
|
||||
history_db = str(tmp_path / "history.db")
|
||||
monkeypatch.setenv("HISTORY_DB_PATH", history_db)
|
||||
monkeypatch.setenv("PROMPTS_DB_PATH", str(tmp_path / "prompts.db"))
|
||||
|
||||
from app.core.config import get_settings
|
||||
get_settings.cache_clear()
|
||||
from app.core.dependencies import get_settings_cached
|
||||
get_settings_cached.cache_clear()
|
||||
|
||||
conn = _get_db(history_db)
|
||||
init_history_db(conn)
|
||||
conn.close()
|
||||
|
||||
service = HistoryService(db_path=history_db)
|
||||
service.record(_SAMPLE_RECORD)
|
||||
|
||||
yield service
|
||||
|
||||
get_settings_cached.cache_clear()
|
||||
get_settings.cache_clear()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def client(mock_svc: MockHistoryService) -> TestClient:
|
||||
app = FastAPI()
|
||||
app.include_router(_router)
|
||||
app.dependency_overrides[_get_history_service] = lambda: mock_svc
|
||||
yield TestClient(app)
|
||||
app.dependency_overrides.clear()
|
||||
def client(svc: HistoryService):
|
||||
"""TestClient wired with inline router + real HistoryService override."""
|
||||
test_app = FastAPI()
|
||||
test_app.include_router(_router)
|
||||
test_app.dependency_overrides[_get_history_service] = lambda: svc
|
||||
yield TestClient(test_app)
|
||||
test_app.dependency_overrides.clear()
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════════════
|
||||
|
|
@ -206,9 +190,9 @@ class TestListHistory:
|
|||
assert data["total"] == 1
|
||||
assert len(data["queries"]) == 1
|
||||
|
||||
def test_custom_limit_and_offset(self, client: TestClient, mock_svc: MockHistoryService) -> None:
|
||||
def test_custom_limit_and_offset(self, client: TestClient, svc: HistoryService) -> None:
|
||||
for i in range(12):
|
||||
mock_svc.insert(input_text=f"Question {i + 2}")
|
||||
svc.record(_make_record(input_text=f"Question {i + 2}"))
|
||||
|
||||
resp = client.get("/api/v1/history", params={"limit": 5, "offset": 3})
|
||||
assert resp.status_code == 200
|
||||
|
|
@ -272,8 +256,11 @@ class TestGetHistoryDetail:
|
|||
data = client.get("/api/v1/history/1").json()
|
||||
assert "chunks_retrieved" in data
|
||||
assert "chunks_filtered" in data
|
||||
assert isinstance(data["chunks_retrieved"], list)
|
||||
assert isinstance(data["chunks_filtered"], list)
|
||||
# Real SQLite stores JSON as TEXT; verify they are parseable JSON arrays
|
||||
assert isinstance(data["chunks_retrieved"], str)
|
||||
assert isinstance(data["chunks_filtered"], str)
|
||||
assert isinstance(json.loads(data["chunks_retrieved"]), list)
|
||||
assert isinstance(json.loads(data["chunks_filtered"]), list)
|
||||
|
||||
def test_has_all_required_detail_fields(self, client: TestClient) -> None:
|
||||
data = client.get("/api/v1/history/1").json()
|
||||
|
|
@ -359,8 +346,8 @@ class TestClearAllHistory:
|
|||
assert isinstance(data["deleted_count"], int)
|
||||
assert data["deleted_count"] >= 1
|
||||
|
||||
def test_empties_list(self, client: TestClient, mock_svc: MockHistoryService) -> None:
|
||||
mock_svc.insert(input_text="extra query")
|
||||
def test_empties_list(self, client: TestClient, svc: HistoryService) -> None:
|
||||
svc.record(_make_record(input_text="extra query"))
|
||||
client.delete("/api/v1/history")
|
||||
data = client.get("/api/v1/history").json()
|
||||
assert data["total"] == 0
|
||||
|
|
@ -387,10 +374,10 @@ class TestHistoryStats:
|
|||
def test_response_shape(self, client: TestClient) -> None:
|
||||
data = client.get("/api/v1/history/stats").json()
|
||||
assert "total_queries" in data
|
||||
assert "avg_total_time_ms" in data
|
||||
assert "avg_time_ms" in data
|
||||
assert "avg_chunks_retrieved" in data
|
||||
assert "avg_chunks_filtered" in data
|
||||
assert "profile_distribution" in data
|
||||
assert "most_used_profile" in data
|
||||
|
||||
def test_total_queries_is_integer(self, client: TestClient) -> None:
|
||||
data = client.get("/api/v1/history/stats").json()
|
||||
|
|
@ -398,14 +385,12 @@ class TestHistoryStats:
|
|||
|
||||
def test_averages_are_numeric(self, client: TestClient) -> None:
|
||||
data = client.get("/api/v1/history/stats").json()
|
||||
assert isinstance(data["avg_total_time_ms"], (int, float))
|
||||
assert isinstance(data["avg_time_ms"], (int, float))
|
||||
assert isinstance(data["avg_chunks_retrieved"], (int, float))
|
||||
assert isinstance(data["avg_chunks_filtered"], (int, float))
|
||||
|
||||
def test_profile_distribution_values_are_integers(self, client: TestClient) -> None:
|
||||
dist = client.get("/api/v1/history/stats").json()["profile_distribution"]
|
||||
assert isinstance(dist, dict)
|
||||
for profile, count in dist.items():
|
||||
assert isinstance(count, int), (
|
||||
f"profile_distribution['{profile}'] should be int, got {type(count)}"
|
||||
)
|
||||
data = client.get("/api/v1/history/stats").json()
|
||||
# Real service returns most_used_profile (str or None), not a distribution dict
|
||||
profile = data["most_used_profile"]
|
||||
assert profile is None or isinstance(profile, str)
|
||||
|
|
|
|||
|
|
@ -2,45 +2,98 @@
|
|||
|
||||
Verifies that QueryDecomposer, RelevanceFilter, and RAGService
|
||||
correctly fetch templates from PromptService and substitute placeholders.
|
||||
"""
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
|
||||
from app.services.query_decomposer import QueryDecomposer
|
||||
from app.services.relevance_filter import RelevanceFilter
|
||||
from app.services.rag import RAGService
|
||||
Uses real PromptService (SQLite via tmp_path), real ChromaDB (tmp_path),
|
||||
and only mocks the external LLM API.
|
||||
"""
|
||||
import sqlite3
|
||||
|
||||
import chromadb
|
||||
import pytest
|
||||
|
||||
from app.core.sqlite_db import init_prompts_db, seed_default_profiles
|
||||
from app.services.prompt_service import PromptService
|
||||
|
||||
|
||||
# ── helpers ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _make_custom_prompt_service(templates: dict[str, str]):
|
||||
"""Build a mock PromptService returning *templates* for get_prompt_template."""
|
||||
svc = MagicMock()
|
||||
svc.get_prompt_template = MagicMock(side_effect=lambda step: templates.get(step, ""))
|
||||
class _MockLLM:
|
||||
"""Mock external LLM API — only external dependency we're allowed to mock."""
|
||||
|
||||
def __init__(self, response: str = '["sub-q"]', side_effect: Exception | None = None):
|
||||
self._response = response
|
||||
self._side_effect = side_effect
|
||||
self.last_prompt: str | None = None
|
||||
self.calls: list[dict] = []
|
||||
self._call_count: int = 0
|
||||
|
||||
async def complete(
|
||||
self, prompt: str, temperature: float = 0.7, step_name: str = "LLM"
|
||||
) -> str:
|
||||
self.calls.append({"prompt": prompt, "step": step_name})
|
||||
self.last_prompt = prompt
|
||||
self._call_count += 1
|
||||
if self._side_effect:
|
||||
raise self._side_effect
|
||||
return self._response
|
||||
|
||||
@property
|
||||
def call_count(self) -> int:
|
||||
return self._call_count
|
||||
|
||||
def assert_called(self):
|
||||
assert self._call_count > 0, "LLM.complete was not called"
|
||||
|
||||
def assert_not_called(self):
|
||||
assert self._call_count == 0, f"LLM.complete was called {self._call_count} time(s)"
|
||||
|
||||
|
||||
def _create_prompt_service(
|
||||
tmp_path, custom_templates: dict[str, str] | None = None
|
||||
) -> PromptService:
|
||||
"""Create a real PromptService backed by real SQLite in tmp_path.
|
||||
|
||||
Seeds default A/B/C profiles, then optionally updates the active profile
|
||||
(A) with *custom_templates* so the service returns controlled templates.
|
||||
"""
|
||||
db_path = str(tmp_path / "prompts.db")
|
||||
conn = sqlite3.connect(db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA foreign_keys=ON")
|
||||
init_prompts_db(conn)
|
||||
seed_default_profiles(conn)
|
||||
conn.close()
|
||||
|
||||
svc = PromptService(db_path=db_path)
|
||||
if custom_templates:
|
||||
for step, template in custom_templates.items():
|
||||
svc.update_prompt("A", step, template)
|
||||
return svc
|
||||
|
||||
|
||||
def _make_llm(response: str = '["sub-q"]'):
|
||||
"""Build a mock LLM client that records the prompt sent."""
|
||||
llm = MagicMock()
|
||||
llm.complete = AsyncMock(return_value=response)
|
||||
return llm
|
||||
def _setup_chroma(tmp_path):
|
||||
"""Create an isolated real ChromaDB PersistentClient for a test."""
|
||||
chroma_dir = tmp_path / "chroma"
|
||||
chroma_dir.mkdir(parents=True, exist_ok=True)
|
||||
return chromadb.PersistentClient(path=str(chroma_dir))
|
||||
|
||||
|
||||
# ── QueryDecomposer tests ───────────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_decomposer_fetches_template_from_prompt_service():
|
||||
async def test_decomposer_fetches_template_from_prompt_service(tmp_path):
|
||||
"""QueryDecomposer should use the template returned by PromptService."""
|
||||
from app.services.query_decomposer import QueryDecomposer
|
||||
|
||||
custom_template = "CUSTOM DECOMPOSE: {question} -> split"
|
||||
ps = _make_custom_prompt_service({"decompose": custom_template})
|
||||
llm = _make_llm('["a"]')
|
||||
ps = _create_prompt_service(tmp_path, {"decompose": custom_template})
|
||||
llm = _MockLLM('["a"]')
|
||||
|
||||
d = QueryDecomposer(llm, prompt_service=ps)
|
||||
questions, returned_prompt = await d.decompose("What is X?")
|
||||
|
||||
sent_prompt = llm.complete.call_args[0][0]
|
||||
sent_prompt = llm.last_prompt
|
||||
assert sent_prompt.startswith("CUSTOM DECOMPOSE:")
|
||||
assert "What is X?" in sent_prompt
|
||||
assert returned_prompt == sent_prompt
|
||||
|
|
@ -48,11 +101,13 @@ async def test_decomposer_fetches_template_from_prompt_service():
|
|||
|
||||
async def test_decomposer_uses_builtin_when_no_prompt_service():
|
||||
"""Without prompt_service, the built-in seed template is used."""
|
||||
llm = _make_llm('["a"]')
|
||||
from app.services.query_decomposer import QueryDecomposer
|
||||
|
||||
llm = _MockLLM('["a"]')
|
||||
d = QueryDecomposer(llm, prompt_service=None)
|
||||
questions, returned_prompt = await d.decompose("What is X?")
|
||||
|
||||
sent_prompt = llm.complete.call_args[0][0]
|
||||
sent_prompt = llm.last_prompt
|
||||
assert "Break it down into 2-5 simplified sub-questions" in sent_prompt
|
||||
assert "What is X?" in sent_prompt
|
||||
assert returned_prompt == sent_prompt
|
||||
|
|
@ -61,17 +116,19 @@ async def test_decomposer_uses_builtin_when_no_prompt_service():
|
|||
# ── RelevanceFilter tests ───────────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_filter_fetches_template_from_prompt_service():
|
||||
async def test_filter_fetches_template_from_prompt_service(tmp_path):
|
||||
"""RelevanceFilter should use the template from PromptService."""
|
||||
from app.services.relevance_filter import RelevanceFilter
|
||||
|
||||
custom_template = "FILTER: q={question} chunks={chunks}"
|
||||
ps = _make_custom_prompt_service({"filter": custom_template})
|
||||
llm = _make_llm("[5.0]")
|
||||
ps = _create_prompt_service(tmp_path, {"filter": custom_template})
|
||||
llm = _MockLLM("[5.0]")
|
||||
|
||||
rf = RelevanceFilter(llm, prompt_service=ps)
|
||||
chunks = [("text A", {"filename": "a.pdf"})]
|
||||
filtered, returned_prompt = await rf.filter("My question", chunks, threshold=3.0)
|
||||
|
||||
sent_prompt = llm.complete.call_args[0][0]
|
||||
sent_prompt = llm.last_prompt
|
||||
assert sent_prompt.startswith("FILTER:")
|
||||
assert "My question" in sent_prompt
|
||||
assert "text A" in sent_prompt
|
||||
|
|
@ -80,12 +137,14 @@ async def test_filter_fetches_template_from_prompt_service():
|
|||
|
||||
async def test_filter_uses_builtin_when_no_prompt_service():
|
||||
"""Without prompt_service, the built-in filter template is used."""
|
||||
llm = _make_llm("[5.0]")
|
||||
from app.services.relevance_filter import RelevanceFilter
|
||||
|
||||
llm = _MockLLM("[5.0]")
|
||||
rf = RelevanceFilter(llm, prompt_service=None)
|
||||
chunks = [("text A", {"filename": "a.pdf"})]
|
||||
filtered, returned_prompt = await rf.filter("My question", chunks, threshold=3.0)
|
||||
|
||||
sent_prompt = llm.complete.call_args[0][0]
|
||||
sent_prompt = llm.last_prompt
|
||||
assert "rate each 0-10 for relevance" in sent_prompt
|
||||
assert "My question" in sent_prompt
|
||||
|
||||
|
|
@ -93,24 +152,23 @@ async def test_filter_uses_builtin_when_no_prompt_service():
|
|||
# ── RAGService generate tests ───────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_generate_fetches_template_from_prompt_service():
|
||||
async def test_generate_fetches_template_from_prompt_service(tmp_path):
|
||||
"""RAGService.generate_response should use PromptService template."""
|
||||
from app.services.rag import RAGService
|
||||
|
||||
custom_template = "GEN: {question} --- {context} END"
|
||||
ps = _make_custom_prompt_service({"generate": custom_template})
|
||||
llm = _make_llm("answer")
|
||||
ps = _create_prompt_service(tmp_path, {"generate": custom_template})
|
||||
llm = _MockLLM("answer")
|
||||
client = _setup_chroma(tmp_path)
|
||||
|
||||
mock_collection = MagicMock()
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection.return_value = mock_collection
|
||||
|
||||
svc = RAGService(chroma_client=mock_client, llm_client=llm, prompt_service=ps)
|
||||
svc = RAGService(chroma_client=client, llm_client=llm, prompt_service=ps)
|
||||
answer, gen_prompt = await svc.generate_response(
|
||||
"What is X?",
|
||||
["chunk data"],
|
||||
[{"filename": "f.txt", "content_summary": "sum"}],
|
||||
)
|
||||
|
||||
sent_prompt = llm.complete.call_args[1]["prompt"]
|
||||
sent_prompt = llm.last_prompt
|
||||
assert sent_prompt.startswith("GEN:")
|
||||
assert "What is X?" in sent_prompt
|
||||
assert "chunk data" in sent_prompt
|
||||
|
|
@ -118,22 +176,21 @@ async def test_generate_fetches_template_from_prompt_service():
|
|||
assert gen_prompt == sent_prompt
|
||||
|
||||
|
||||
async def test_generate_uses_builtin_when_no_prompt_service():
|
||||
async def test_generate_uses_builtin_when_no_prompt_service(tmp_path):
|
||||
"""Without prompt_service, the built-in generate template is used."""
|
||||
llm = _make_llm("answer")
|
||||
from app.services.rag import RAGService
|
||||
|
||||
mock_collection = MagicMock()
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection.return_value = mock_collection
|
||||
llm = _MockLLM("answer")
|
||||
client = _setup_chroma(tmp_path)
|
||||
|
||||
svc = RAGService(chroma_client=mock_client, llm_client=llm, prompt_service=None)
|
||||
svc = RAGService(chroma_client=client, llm_client=llm, prompt_service=None)
|
||||
answer, gen_prompt = await svc.generate_response(
|
||||
"What is X?",
|
||||
["chunk data"],
|
||||
[{"filename": "f.txt", "content_summary": "sum"}],
|
||||
)
|
||||
|
||||
sent_prompt = llm.complete.call_args[1]["prompt"]
|
||||
sent_prompt = llm.last_prompt
|
||||
assert "What is X?" in sent_prompt
|
||||
assert gen_prompt == sent_prompt
|
||||
|
||||
|
|
@ -141,47 +198,50 @@ async def test_generate_uses_builtin_when_no_prompt_service():
|
|||
# ── Placeholder substitution safety tests ───────────────────────────────
|
||||
|
||||
|
||||
async def test_placeholder_substitution_safe_with_curly_braces():
|
||||
async def test_placeholder_substitution_safe_with_curly_braces(tmp_path):
|
||||
"""User text containing curly braces must not crash str.replace."""
|
||||
ps = _make_custom_prompt_service({
|
||||
"decompose": "Question: {question} — decompose it"
|
||||
})
|
||||
llm = _make_llm('["a"]')
|
||||
from app.services.query_decomposer import QueryDecomposer
|
||||
|
||||
ps = _create_prompt_service(tmp_path)
|
||||
llm = _MockLLM('["a"]')
|
||||
|
||||
d = QueryDecomposer(llm, prompt_service=ps)
|
||||
# This question has literal braces — must not raise KeyError
|
||||
result, returned_prompt = await d.decompose("What about {key: value}?")
|
||||
assert isinstance(result, list)
|
||||
|
||||
sent_prompt = llm.complete.call_args[0][0]
|
||||
sent_prompt = llm.last_prompt
|
||||
assert "{key: value}" in sent_prompt
|
||||
assert returned_prompt == sent_prompt
|
||||
|
||||
|
||||
async def test_unknown_placeholder_left_untouched():
|
||||
async def test_unknown_placeholder_left_untouched(tmp_path):
|
||||
"""Placeholders not matched by str.replace stay as-is in the prompt."""
|
||||
ps = _make_custom_prompt_service({
|
||||
"decompose": "HELLO {fake_var} and {question}"
|
||||
})
|
||||
llm = _make_llm('["a"]')
|
||||
from app.services.query_decomposer import QueryDecomposer
|
||||
|
||||
ps = _create_prompt_service(
|
||||
tmp_path, {"decompose": "HELLO {fake_var} and {question}"}
|
||||
)
|
||||
llm = _MockLLM('["a"]')
|
||||
|
||||
d = QueryDecomposer(llm, prompt_service=ps)
|
||||
questions, returned_prompt = await d.decompose("Q?")
|
||||
|
||||
sent_prompt = llm.complete.call_args[0][0]
|
||||
sent_prompt = llm.last_prompt
|
||||
assert "{fake_var}" in sent_prompt
|
||||
assert "Q?" in sent_prompt
|
||||
|
||||
|
||||
async def test_empty_template_produces_empty_prompt():
|
||||
"""An empty template string results in an empty (or question-only) prompt."""
|
||||
ps = _make_custom_prompt_service({"decompose": ""})
|
||||
llm = _make_llm('["a"]')
|
||||
async def test_empty_template_produces_empty_prompt(tmp_path):
|
||||
"""An empty template string results in an empty prompt."""
|
||||
from app.services.query_decomposer import QueryDecomposer
|
||||
|
||||
ps = _create_prompt_service(tmp_path, {"decompose": ""})
|
||||
llm = _MockLLM('["a"]')
|
||||
|
||||
d = QueryDecomposer(llm, prompt_service=ps)
|
||||
questions, returned_prompt = await d.decompose("Doesn't matter")
|
||||
|
||||
sent_prompt = llm.complete.call_args[0][0]
|
||||
sent_prompt = llm.last_prompt
|
||||
# Empty template with .replace("{question}", ...) still has no text
|
||||
assert sent_prompt == ""
|
||||
|
||||
|
|
@ -189,73 +249,67 @@ async def test_empty_template_produces_empty_prompt():
|
|||
# ── Edge case: no question / no chunks ──────────────────────────────────
|
||||
|
||||
|
||||
async def test_decomposer_no_question_returns_empty():
|
||||
"""Empty question returns [] without calling prompt_service."""
|
||||
ps = MagicMock()
|
||||
ps.get_prompt_template = MagicMock(return_value="tmpl")
|
||||
async def test_decomposer_no_question_returns_empty(tmp_path):
|
||||
"""Empty question returns [] without calling LLM."""
|
||||
from app.services.query_decomposer import QueryDecomposer
|
||||
|
||||
llm = _make_llm('["should_not_see"]')
|
||||
ps = _create_prompt_service(tmp_path)
|
||||
llm = _MockLLM('["should_not_see"]')
|
||||
d = QueryDecomposer(llm, prompt_service=ps)
|
||||
result, returned_prompt = await d.decompose("")
|
||||
|
||||
assert result == []
|
||||
assert returned_prompt == ""
|
||||
llm.complete.assert_not_called()
|
||||
ps.get_prompt_template.assert_not_called()
|
||||
llm.assert_not_called()
|
||||
|
||||
|
||||
async def test_filter_empty_chunks_no_template_fetch():
|
||||
"""Empty chunks list returns [] without fetching a template."""
|
||||
ps = MagicMock()
|
||||
ps.get_prompt_template = MagicMock(return_value="tmpl")
|
||||
async def test_filter_empty_chunks_no_template_fetch(tmp_path):
|
||||
"""Empty chunks list returns [] without calling LLM."""
|
||||
from app.services.relevance_filter import RelevanceFilter
|
||||
|
||||
llm = _make_llm("[5]")
|
||||
ps = _create_prompt_service(tmp_path)
|
||||
llm = _MockLLM("[5]")
|
||||
rf = RelevanceFilter(llm, prompt_service=ps)
|
||||
result, returned_prompt = await rf.filter("Q?", [], threshold=5.0)
|
||||
|
||||
assert result == []
|
||||
assert returned_prompt == ""
|
||||
llm.complete.assert_not_called()
|
||||
ps.get_prompt_template.assert_not_called()
|
||||
llm.assert_not_called()
|
||||
|
||||
|
||||
async def test_generate_no_chunks_returns_fallback():
|
||||
"""No chunks returns fallback message without touching PromptService."""
|
||||
ps = MagicMock()
|
||||
ps.get_prompt_template = MagicMock(return_value="tmpl")
|
||||
async def test_generate_no_chunks_returns_fallback(tmp_path):
|
||||
"""No chunks returns fallback message without calling LLM."""
|
||||
from app.services.rag import RAGService
|
||||
|
||||
llm = _make_llm("answer")
|
||||
mock_collection = MagicMock()
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection.return_value = mock_collection
|
||||
ps = _create_prompt_service(tmp_path)
|
||||
llm = _MockLLM("answer")
|
||||
client = _setup_chroma(tmp_path)
|
||||
|
||||
svc = RAGService(chroma_client=mock_client, llm_client=llm, prompt_service=ps)
|
||||
svc = RAGService(chroma_client=client, llm_client=llm, prompt_service=ps)
|
||||
answer, gen_prompt = await svc.generate_response("Q?", [], [])
|
||||
|
||||
assert "could not find" in answer.lower()
|
||||
assert gen_prompt == ""
|
||||
llm.complete.assert_not_called()
|
||||
ps.get_prompt_template.assert_not_called()
|
||||
llm.assert_not_called()
|
||||
|
||||
|
||||
async def test_generate_per_subq_fetches_template_from_prompt_service():
|
||||
async def test_generate_per_subq_fetches_template_from_prompt_service(tmp_path):
|
||||
"""RAGService.generate_response_per_subquestion should use PromptService template."""
|
||||
from app.services.rag import RAGService
|
||||
|
||||
custom_template = "PER_SUBQ: {context_sections} DONE"
|
||||
ps = _make_custom_prompt_service({"generate_per_subq": custom_template})
|
||||
llm = _make_llm("answer")
|
||||
ps = _create_prompt_service(tmp_path, {"generate_per_subq": custom_template})
|
||||
llm = _MockLLM("answer")
|
||||
client = _setup_chroma(tmp_path)
|
||||
|
||||
mock_collection = MagicMock()
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection.return_value = mock_collection
|
||||
|
||||
svc = RAGService(chroma_client=mock_client, llm_client=llm, prompt_service=ps)
|
||||
svc = RAGService(chroma_client=client, llm_client=llm, prompt_service=ps)
|
||||
answer, gen_prompt, grouped_sources = await svc.generate_response_per_subquestion(
|
||||
["What is X?"],
|
||||
[["chunk data"]],
|
||||
[[{"filename": "f.txt", "content_summary": "sum"}]],
|
||||
)
|
||||
|
||||
sent_prompt = llm.complete.call_args[1]["prompt"]
|
||||
sent_prompt = llm.last_prompt
|
||||
assert sent_prompt.startswith("PER_SUBQ:")
|
||||
assert "chunk data" in sent_prompt
|
||||
assert sent_prompt.endswith("DONE")
|
||||
|
|
@ -263,22 +317,21 @@ async def test_generate_per_subq_fetches_template_from_prompt_service():
|
|||
assert len(grouped_sources) == 1
|
||||
|
||||
|
||||
async def test_generate_per_subq_uses_builtin_when_no_prompt_service():
|
||||
async def test_generate_per_subq_uses_builtin_when_no_prompt_service(tmp_path):
|
||||
"""Without prompt_service, the built-in per-subq template is used."""
|
||||
llm = _make_llm("answer")
|
||||
from app.services.rag import RAGService
|
||||
|
||||
mock_collection = MagicMock()
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection.return_value = mock_collection
|
||||
llm = _MockLLM("answer")
|
||||
client = _setup_chroma(tmp_path)
|
||||
|
||||
svc = RAGService(chroma_client=mock_client, llm_client=llm, prompt_service=None)
|
||||
svc = RAGService(chroma_client=client, llm_client=llm, prompt_service=None)
|
||||
answer, gen_prompt, grouped_sources = await svc.generate_response_per_subquestion(
|
||||
["What is X?"],
|
||||
[["chunk data"]],
|
||||
[[{"filename": "f.txt", "content_summary": "sum"}]],
|
||||
)
|
||||
|
||||
sent_prompt = llm.complete.call_args[1]["prompt"]
|
||||
sent_prompt = llm.last_prompt
|
||||
assert "Sub-question" in sent_prompt
|
||||
assert "chunk data" in sent_prompt
|
||||
assert "{context_sections}" not in sent_prompt
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -6,33 +6,73 @@ Covers sub-question-organized response generation:
|
|||
- All-empty chunks fallback
|
||||
- Prompt contains context_sections placeholder
|
||||
- LLM client not configured fallback
|
||||
"""
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from app.services.rag import RAGService
|
||||
Uses real ChromaDB (tmp_path) and only mocks the external LLM API.
|
||||
"""
|
||||
import chromadb
|
||||
import pytest
|
||||
|
||||
|
||||
class _MockLLM:
|
||||
"""Mock external LLM API."""
|
||||
|
||||
def __init__(self, response: str = "mock answer", side_effect: Exception | None = None):
|
||||
self._response = response
|
||||
self._side_effect = side_effect
|
||||
self.last_prompt: str | None = None
|
||||
self.calls: list[dict] = []
|
||||
self._call_count: int = 0
|
||||
|
||||
async def complete(
|
||||
self, prompt: str, temperature: float = 0.7, step_name: str = "LLM"
|
||||
) -> str:
|
||||
self.calls.append({"prompt": prompt, "step": step_name})
|
||||
self.last_prompt = prompt
|
||||
self._call_count += 1
|
||||
if self._side_effect:
|
||||
raise self._side_effect
|
||||
return self._response
|
||||
|
||||
@property
|
||||
def call_count(self) -> int:
|
||||
return self._call_count
|
||||
|
||||
def assert_called(self):
|
||||
assert self._call_count > 0, "LLM.complete was not called"
|
||||
|
||||
def assert_not_called(self):
|
||||
assert self._call_count == 0, f"LLM.complete was called {self._call_count} time(s)"
|
||||
|
||||
|
||||
def _setup_chroma(tmp_path):
|
||||
"""Create an isolated real ChromaDB PersistentClient."""
|
||||
chroma_dir = tmp_path / "chroma"
|
||||
chroma_dir.mkdir(parents=True, exist_ok=True)
|
||||
return chromadb.PersistentClient(path=str(chroma_dir))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: two sub-questions, LLM returns markdown with headers
|
||||
# ---------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_per_subq_two_questions():
|
||||
async def test_generate_per_subq_two_questions(tmp_path):
|
||||
"""Two sub-questions, 2 chunks for first, 1 for second.
|
||||
|
||||
LLM returns markdown with ## Sub-question 1/2 headers.
|
||||
Assert answer contains both headers and grouped_sources has correct shape.
|
||||
"""
|
||||
llm = MagicMock()
|
||||
llm.complete = AsyncMock(return_value=(
|
||||
from app.services.rag import RAGService
|
||||
|
||||
llm = _MockLLM(response=(
|
||||
"## Sub-question 1: What is A?\n"
|
||||
"- Bullet point A1 [file_a.pdf, page 1]\n"
|
||||
"- Bullet point A2 [file_a.pdf, page 2]\n\n"
|
||||
"## Sub-question 2: What is B?\n"
|
||||
"- Bullet point B1 [file_b.pdf, page 1]\n"
|
||||
))
|
||||
client = _setup_chroma(tmp_path)
|
||||
|
||||
service = RAGService(llm_client=llm)
|
||||
service = RAGService(chroma_client=client, llm_client=llm)
|
||||
answer, prompt, grouped_sources = await service.generate_response_per_subquestion(
|
||||
sub_questions=["What is A?", "What is B?"],
|
||||
sub_chunks=[
|
||||
|
|
@ -53,23 +93,26 @@ async def test_generate_per_subq_two_questions():
|
|||
assert "## Sub-question 1: What is A?" in answer
|
||||
assert "## Sub-question 2: What is B?" in answer
|
||||
assert len(grouped_sources) == 2
|
||||
assert len(grouped_sources[0]) == 2 # 2 sources for sub-q 0
|
||||
assert len(grouped_sources[1]) == 1 # 1 source for sub-q 1
|
||||
assert len(grouped_sources[0]) == 2
|
||||
assert len(grouped_sources[1]) == 1
|
||||
assert grouped_sources[0][0]["filename"] == "file_a.pdf"
|
||||
assert grouped_sources[1][0]["filename"] == "file_b.pdf"
|
||||
llm.complete.assert_called_once()
|
||||
llm.assert_called()
|
||||
assert llm.call_count == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: empty input
|
||||
# ---------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_per_subq_empty_input():
|
||||
async def test_generate_per_subq_empty_input(tmp_path):
|
||||
"""Empty sub_questions returns fallback message and empty grouped_sources."""
|
||||
llm = MagicMock()
|
||||
llm.complete = AsyncMock()
|
||||
from app.services.rag import RAGService
|
||||
|
||||
service = RAGService(llm_client=llm)
|
||||
llm = _MockLLM()
|
||||
client = _setup_chroma(tmp_path)
|
||||
|
||||
service = RAGService(chroma_client=client, llm_client=llm)
|
||||
answer, prompt, grouped_sources = await service.generate_response_per_subquestion(
|
||||
sub_questions=[],
|
||||
sub_chunks=[],
|
||||
|
|
@ -78,19 +121,21 @@ async def test_generate_per_subq_empty_input():
|
|||
|
||||
assert answer == "I could not find any relevant information to answer your question."
|
||||
assert grouped_sources == []
|
||||
llm.complete.assert_not_called()
|
||||
llm.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: sub-questions provided but all chunk lists empty
|
||||
# ---------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_per_subq_no_chunks():
|
||||
async def test_generate_per_subq_no_chunks(tmp_path):
|
||||
"""Sub-questions provided but all chunk lists empty → fallback message."""
|
||||
llm = MagicMock()
|
||||
llm.complete = AsyncMock()
|
||||
from app.services.rag import RAGService
|
||||
|
||||
service = RAGService(llm_client=llm)
|
||||
llm = _MockLLM()
|
||||
client = _setup_chroma(tmp_path)
|
||||
|
||||
service = RAGService(chroma_client=client, llm_client=llm)
|
||||
answer, prompt, grouped_sources = await service.generate_response_per_subquestion(
|
||||
sub_questions=["What is A?", "What is B?"],
|
||||
sub_chunks=[[], []],
|
||||
|
|
@ -99,33 +144,29 @@ async def test_generate_per_subq_no_chunks():
|
|||
|
||||
assert answer == "I could not find any relevant information to answer your question."
|
||||
assert grouped_sources == []
|
||||
llm.complete.assert_not_called()
|
||||
llm.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: prompt contains context_sections placeholder
|
||||
# ---------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_per_subq_prompt_contains_context_sections():
|
||||
async def test_generate_per_subq_prompt_contains_context_sections(tmp_path):
|
||||
"""Verify the prompt sent to LLM contains ### Context for Sub-question 0:
|
||||
header and chunk content."""
|
||||
captured_prompt = None
|
||||
from app.services.rag import RAGService
|
||||
|
||||
async def capture_complete(prompt, **kwargs):
|
||||
nonlocal captured_prompt
|
||||
captured_prompt = prompt
|
||||
return "## Sub-question 1: What is A?\n- Answer"
|
||||
llm = _MockLLM(response="## Sub-question 1: What is A?\n- Answer")
|
||||
client = _setup_chroma(tmp_path)
|
||||
|
||||
llm = MagicMock()
|
||||
llm.complete = AsyncMock(side_effect=capture_complete)
|
||||
|
||||
service = RAGService(llm_client=llm)
|
||||
service = RAGService(chroma_client=client, llm_client=llm)
|
||||
await service.generate_response_per_subquestion(
|
||||
sub_questions=["What is A?"],
|
||||
sub_chunks=[["chunk text here"]],
|
||||
sub_metadata=[[{"filename": "file_a.pdf", "page_number": 1, "content_summary": "Sum"}]],
|
||||
)
|
||||
|
||||
captured_prompt = llm.last_prompt
|
||||
assert captured_prompt is not None
|
||||
assert "### Context for Sub-question 0:" in captured_prompt
|
||||
assert "chunk text here" in captured_prompt
|
||||
|
|
@ -136,9 +177,13 @@ async def test_generate_per_subq_prompt_contains_context_sections():
|
|||
# Test: LLM client not configured
|
||||
# ---------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_per_subq_llm_not_configured():
|
||||
async def test_generate_per_subq_llm_not_configured(tmp_path):
|
||||
"""llm_client=None → returns 'LLM client not configured' message."""
|
||||
service = RAGService(llm_client=None)
|
||||
from app.services.rag import RAGService
|
||||
|
||||
client = _setup_chroma(tmp_path)
|
||||
|
||||
service = RAGService(chroma_client=client, llm_client=None)
|
||||
answer, prompt, grouped_sources = await service.generate_response_per_subquestion(
|
||||
sub_questions=["What is A?"],
|
||||
sub_chunks=[["some chunk"]],
|
||||
|
|
|
|||
|
|
@ -1,96 +1,157 @@
|
|||
"""Phase 4 integration test: Full per-sub-question query pipeline.
|
||||
|
||||
Simulates the complete 4-stage pipeline (decompose → retrieve → filter → generate)
|
||||
using mocked services, verifying end-to-end data flow and SSE event emission.
|
||||
Uses TestClient hitting POST /api/v1/query with real ChromaDB and SQLite.
|
||||
Only the LLM (external API) is mocked.
|
||||
|
||||
Verifies end-to-end data flow and SSE event emission through the real
|
||||
HTTP endpoint, real ChromaDB retrieval, and real SQLite services.
|
||||
|
||||
Key behaviours under test:
|
||||
- Full pipeline with 2 sub-questions produces grouped results
|
||||
- Empty decomposition falls back to original question (Decision #13)
|
||||
- Single sub-question still uses ## Sub-question N format
|
||||
- All chunks filtered out returns "no relevant information"
|
||||
- One sub-q with empty retrieval still produces partial answer
|
||||
|
||||
All external services (LLM, ChromaDB) are mocked.
|
||||
Tests call ``_query_stream()`` directly via ``async for`` — no HTTP layer.
|
||||
- One sub-q with all chunks filtered out produces partial answer
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import sqlite3
|
||||
|
||||
import chromadb
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.models.query import QueryRequest
|
||||
from app.core.config import Settings
|
||||
from app.core.sqlite_db import init_history_db, init_prompts_db, seed_default_profiles
|
||||
|
||||
|
||||
# ── Shared fixtures ──────────────────────────────────────────────────────
|
||||
# ── Test seed data ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
CHUNK_A = (
|
||||
"Time extensions must be notified within 8 weeks.",
|
||||
{"filename": "NEC4.pdf", "page_number": 3, "content_summary": "Time extensions", "chunk_index": 0, "upload_date": "2024-01-01"},
|
||||
0.15,
|
||||
)
|
||||
CHUNK_B = (
|
||||
"Notice must be given to the project manager.",
|
||||
{"filename": "NEC4.pdf", "page_number": 12, "content_summary": "Notification", "chunk_index": 1, "upload_date": "2024-01-01"},
|
||||
0.22,
|
||||
)
|
||||
SEED_DOCS = [
|
||||
{
|
||||
"text": "Time extensions must be notified within 8 weeks.",
|
||||
"metadata": {
|
||||
"filename": "NEC4.pdf",
|
||||
"page_number": 3,
|
||||
"content_summary": "Time extensions",
|
||||
"chunk_index": 0,
|
||||
"upload_date": "2024-01-01",
|
||||
},
|
||||
},
|
||||
{
|
||||
"text": "Notice must be given to the project manager before expiry of the period.",
|
||||
"metadata": {
|
||||
"filename": "NEC4.pdf",
|
||||
"page_number": 12,
|
||||
"content_summary": "Notification",
|
||||
"chunk_index": 1,
|
||||
"upload_date": "2024-01-01",
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def _make_settings():
|
||||
s = MagicMock()
|
||||
s.retrieval_n_results = 10
|
||||
s.relevance_threshold = 7.0
|
||||
s.prompts_db_path = ":memory:"
|
||||
s.history_db_path = ":memory:"
|
||||
return s
|
||||
# ── Helpers ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _make_prompt_service():
|
||||
ps = MagicMock()
|
||||
ps.get_active_profile_name.return_value = "default"
|
||||
ps.get_prompt_template = MagicMock(
|
||||
side_effect=lambda step: {
|
||||
"decompose": "Given question: '{question}' — decompose.",
|
||||
"filter": "Rate chunks 0-10 for: {question}\n{chunks}",
|
||||
"generate": "Answer: {question}\nContext:\n{context}",
|
||||
"generate_per_subq": "Answer per sub-q:\n{context_sections}",
|
||||
}.get(step, "")
|
||||
class _ConstantEmbedding:
|
||||
"""Identical vectors for all inputs — all docs equally match any query."""
|
||||
|
||||
DIM = 10
|
||||
|
||||
def __call__(self, input):
|
||||
return [[0.1] * self.DIM for _ in input]
|
||||
|
||||
def embed_query(self, input):
|
||||
return self(input)
|
||||
|
||||
def name(self):
|
||||
return "constant_test"
|
||||
|
||||
|
||||
def _make_mock_llm_class(responses):
|
||||
"""Create a mock LLMClient class that returns *responses* in sequence."""
|
||||
|
||||
class _MockLLM:
|
||||
def __init__(self, settings):
|
||||
self.settings = settings
|
||||
self._responses = list(responses)
|
||||
self._idx = 0
|
||||
|
||||
async def complete(self, prompt, temperature=0.7, step_name="LLM"):
|
||||
if self._idx < len(self._responses):
|
||||
resp = self._responses[self._idx]
|
||||
self._idx += 1
|
||||
return resp
|
||||
raise RuntimeError(f"No more mock responses (call #{self._idx + 1})")
|
||||
|
||||
async def close(self):
|
||||
pass
|
||||
|
||||
return _MockLLM
|
||||
|
||||
|
||||
def _setup_env(tmp_path, monkeypatch, seed_docs=None):
|
||||
"""Set up real ChromaDB + SQLite via tmp_path for pipeline tests."""
|
||||
seed_docs = seed_docs or SEED_DOCS
|
||||
chroma_dir = tmp_path / "chroma_db"
|
||||
chroma_dir.mkdir()
|
||||
prompts_db = str(tmp_path / "prompts.db")
|
||||
history_db = str(tmp_path / "history.db")
|
||||
|
||||
test_settings = Settings(
|
||||
chroma_db_path=str(chroma_dir),
|
||||
prompts_db_path=prompts_db,
|
||||
history_db_path=history_db,
|
||||
llm_base_url="http://mock-llm",
|
||||
llm_api_key="test-key",
|
||||
llm_model_name="test-model",
|
||||
embedding_base_url="http://mock-emb",
|
||||
embedding_api_key="test-key",
|
||||
embedding_model="test-model",
|
||||
)
|
||||
|
||||
conn = sqlite3.connect(prompts_db)
|
||||
conn.row_factory = sqlite3.Row
|
||||
init_prompts_db(conn)
|
||||
seed_default_profiles(conn)
|
||||
conn.close()
|
||||
|
||||
conn = sqlite3.connect(history_db)
|
||||
conn.row_factory = sqlite3.Row
|
||||
init_history_db(conn)
|
||||
conn.close()
|
||||
|
||||
embedding_fn = _ConstantEmbedding()
|
||||
client = chromadb.PersistentClient(path=str(chroma_dir))
|
||||
collection = client.get_or_create_collection(
|
||||
"documents", embedding_function=embedding_fn
|
||||
)
|
||||
for i, doc in enumerate(seed_docs):
|
||||
collection.add(
|
||||
documents=[doc["text"]],
|
||||
metadatas=[doc["metadata"]],
|
||||
ids=[f"test_doc_{i}"],
|
||||
)
|
||||
|
||||
monkeypatch.setattr("app.routers.query.get_settings", lambda: test_settings)
|
||||
monkeypatch.setattr("app.core.database.get_settings", lambda: test_settings)
|
||||
monkeypatch.setattr(
|
||||
"app.core.database.get_embedding_function_settings",
|
||||
lambda s: embedding_fn,
|
||||
)
|
||||
return ps
|
||||
|
||||
|
||||
def _make_llm(decompose_resp, filter_resp, generate_resp):
|
||||
llm = MagicMock()
|
||||
llm.complete = AsyncMock(side_effect=[decompose_resp, filter_resp, generate_resp])
|
||||
return llm
|
||||
|
||||
|
||||
def _make_chroma(chunks: list):
|
||||
"""Return a mock collection that returns *chunks* from query()."""
|
||||
col = MagicMock()
|
||||
col.query.return_value = {
|
||||
"documents": [[c[0] for c in chunks]],
|
||||
"metadatas": [[c[1] for c in chunks]],
|
||||
"distances": [[c[2] for c in chunks]],
|
||||
}
|
||||
return col
|
||||
|
||||
|
||||
def _mock_chroma_client(collection):
|
||||
client = MagicMock()
|
||||
return client
|
||||
|
||||
|
||||
async def _collect_sse(request: QueryRequest):
|
||||
"""Run _query_stream and return list of parsed SSE event dicts."""
|
||||
from app.routers.query import _query_stream
|
||||
def _collect_sse(client, question):
|
||||
"""POST to /api/v1/query and collect SSE events as parsed dicts."""
|
||||
events = []
|
||||
async for raw in _query_stream(request):
|
||||
# raw is like "data: {...}\n\n"
|
||||
for line in raw.split("\n"):
|
||||
with client.stream(
|
||||
"POST", "/api/v1/query", json={"question": question}
|
||||
) as response:
|
||||
assert response.status_code == 200
|
||||
for line in response.iter_lines():
|
||||
if line.startswith("data: "):
|
||||
events.append(json.loads(line[6:]))
|
||||
return events
|
||||
|
|
@ -99,10 +160,13 @@ async def _collect_sse(request: QueryRequest):
|
|||
# ── Tests ────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_full_pipeline_with_two_subquestions():
|
||||
def test_full_pipeline_with_two_subquestions(tmp_path, monkeypatch):
|
||||
"""Two sub-questions flow through all 4 stages with per-sub-q grouping."""
|
||||
_setup_env(tmp_path, monkeypatch)
|
||||
|
||||
decompose_resp = '["What are time extensions?", "What notice is required?"]'
|
||||
filter_resp = '{"0": [8.5, 3.2], "1": [9.0]}'
|
||||
# 2 sub-qs × 2 chunks each; all above threshold 7.0
|
||||
filter_resp = '{"0": [8.5, 9.0], "1": [8.5, 9.0]}'
|
||||
generate_resp = (
|
||||
"## Sub-question 1: What are time extensions?\n"
|
||||
"- Extensions need 8 weeks notice [NEC4.pdf, page 3]\n\n"
|
||||
|
|
@ -110,63 +174,18 @@ async def test_full_pipeline_with_two_subquestions():
|
|||
"- Notify the project manager [NEC4.pdf, page 12]\n"
|
||||
)
|
||||
|
||||
llm = _make_llm(decompose_resp, filter_resp, generate_resp)
|
||||
chroma = _make_chroma([CHUNK_A, CHUNK_B])
|
||||
settings = _make_settings()
|
||||
ps = _make_prompt_service()
|
||||
monkeypatch.setattr(
|
||||
"app.routers.query.LLMClient",
|
||||
_make_mock_llm_class([decompose_resp, filter_resp, generate_resp]),
|
||||
)
|
||||
|
||||
request = QueryRequest(question="What are the time extension rules?")
|
||||
from app.main import app
|
||||
|
||||
with patch("app.routers.query.get_settings", return_value=settings), \
|
||||
patch("app.routers.query.PromptService", return_value=ps), \
|
||||
patch("app.routers.query.LLMClient", return_value=llm), \
|
||||
patch("app.routers.query.RAGService") as MockRAG, \
|
||||
patch("app.routers.query.QueryDecomposer") as MockDec, \
|
||||
patch("app.routers.query.RelevanceFilter") as MockFilter, \
|
||||
patch("app.routers.query.HistoryService") as MockHist, \
|
||||
patch("app.routers.query._schedule_history"):
|
||||
|
||||
# Wire decomposer
|
||||
dec = MockDec.return_value
|
||||
dec.decompose = AsyncMock(return_value=(
|
||||
["What are time extensions?", "What notice is required?"],
|
||||
"decompose-prompt-text"
|
||||
))
|
||||
|
||||
# Wire RAG
|
||||
rag = MockRAG.return_value
|
||||
rag.retrieve_per_subquestion.return_value = [
|
||||
("What are time extensions?", [CHUNK_A, CHUNK_B]),
|
||||
("What notice is required?", [CHUNK_A]),
|
||||
]
|
||||
rag.generate_response_per_subquestion = AsyncMock(return_value=(
|
||||
generate_resp,
|
||||
"gen-prompt-text",
|
||||
[
|
||||
[CHUNK_A[1], CHUNK_B[1]], # sources for sub-q 0
|
||||
[CHUNK_A[1]], # sources for sub-q 1
|
||||
],
|
||||
))
|
||||
|
||||
# Wire filter
|
||||
filt = MockFilter.return_value
|
||||
filt.filter_per_subquestion = AsyncMock(return_value=(
|
||||
[
|
||||
("What are time extensions?", [
|
||||
(CHUNK_A[0], {**CHUNK_A[1], "relevance_score": 8.5}),
|
||||
]),
|
||||
("What notice is required?", [
|
||||
(CHUNK_A[0], {**CHUNK_A[1], "relevance_score": 9.0}),
|
||||
]),
|
||||
],
|
||||
"filter-prompt-text"
|
||||
))
|
||||
|
||||
events = await _collect_sse(request)
|
||||
client = TestClient(app)
|
||||
events = _collect_sse(client, "What are the time extension rules?")
|
||||
|
||||
phases = [e["phase"] for e in events]
|
||||
|
||||
# Should emit all expected phases
|
||||
assert "decomposed" in phases
|
||||
assert "retrieving" in phases
|
||||
assert "filtering" in phases
|
||||
|
|
@ -174,11 +193,9 @@ async def test_full_pipeline_with_two_subquestions():
|
|||
assert "generating_subquestion" in phases
|
||||
assert "completed" in phases
|
||||
|
||||
# Decomposed event has extracted questions
|
||||
dec_evt = next(e for e in events if e["phase"] == "decomposed")
|
||||
assert len(dec_evt["extracted_questions"]) == 2
|
||||
|
||||
# Completed event has per-sub-q sources
|
||||
comp_evt = next(e for e in events if e["phase"] == "completed")
|
||||
assert "sub_question_sources" in comp_evt
|
||||
sq_sources = comp_evt["sub_question_sources"]
|
||||
|
|
@ -186,56 +203,35 @@ async def test_full_pipeline_with_two_subquestions():
|
|||
assert sq_sources[0]["sub_question_text"] == "What are time extensions?"
|
||||
assert sq_sources[1]["sub_question_text"] == "What notice is required?"
|
||||
|
||||
# Answer has sub-question headers
|
||||
assert "## Sub-question 1:" in comp_evt["answer"]
|
||||
assert "## Sub-question 2:" in comp_evt["answer"]
|
||||
|
||||
# generating_subquestion events
|
||||
gen_subq = [e for e in events if e["phase"] == "generating_subquestion"]
|
||||
assert len(gen_subq) == 2
|
||||
assert gen_subq[0]["sub_question_index"] == 0
|
||||
assert gen_subq[1]["sub_question_index"] == 1
|
||||
|
||||
|
||||
async def test_pipeline_with_empty_decomposition():
|
||||
def test_pipeline_with_empty_decomposition(tmp_path, monkeypatch):
|
||||
"""Empty decomposition falls back to original question as single sub-q."""
|
||||
generate_resp = "## Sub-question 1: What is the time limit?\n- Answer here\n"
|
||||
_setup_env(tmp_path, monkeypatch)
|
||||
|
||||
llm = MagicMock()
|
||||
ps = _make_prompt_service()
|
||||
settings = _make_settings()
|
||||
decompose_resp = "[]"
|
||||
# 1 fallback sub-q × 2 chunks
|
||||
filter_resp = '{"0": [8.5, 9.0]}'
|
||||
generate_resp = (
|
||||
"## Sub-question 1: What is the time limit?\n- Answer here\n"
|
||||
)
|
||||
|
||||
request = QueryRequest(question="What is the time limit?")
|
||||
monkeypatch.setattr(
|
||||
"app.routers.query.LLMClient",
|
||||
_make_mock_llm_class([decompose_resp, filter_resp, generate_resp]),
|
||||
)
|
||||
|
||||
with patch("app.routers.query.get_settings", return_value=settings), \
|
||||
patch("app.routers.query.PromptService", return_value=ps), \
|
||||
patch("app.routers.query.LLMClient", return_value=llm), \
|
||||
patch("app.routers.query.RAGService") as MockRAG, \
|
||||
patch("app.routers.query.QueryDecomposer") as MockDec, \
|
||||
patch("app.routers.query.RelevanceFilter") as MockFilter, \
|
||||
patch("app.routers.query.HistoryService") as MockHist, \
|
||||
patch("app.routers.query._schedule_history"):
|
||||
from app.main import app
|
||||
|
||||
dec = MockDec.return_value
|
||||
dec.decompose = AsyncMock(return_value=([], "decompose-prompt"))
|
||||
|
||||
rag = MockRAG.return_value
|
||||
rag.retrieve_per_subquestion.return_value = [
|
||||
("What is the time limit?", [CHUNK_A]),
|
||||
]
|
||||
rag.generate_response_per_subquestion = AsyncMock(return_value=(
|
||||
generate_resp,
|
||||
"gen-prompt",
|
||||
[[CHUNK_A[1]]],
|
||||
))
|
||||
|
||||
filt = MockFilter.return_value
|
||||
filt.filter_per_subquestion = AsyncMock(return_value=(
|
||||
[("What is the time limit?", [(CHUNK_A[0], {**CHUNK_A[1], "relevance_score": 8.5})])],
|
||||
"filter-prompt"
|
||||
))
|
||||
|
||||
events = await _collect_sse(request)
|
||||
client = TestClient(app)
|
||||
events = _collect_sse(client, "What is the time limit?")
|
||||
|
||||
phases = [e["phase"] for e in events]
|
||||
assert "decomposed" in phases
|
||||
|
|
@ -246,103 +242,65 @@ async def test_pipeline_with_empty_decomposition():
|
|||
comp_evt = next(e for e in events if e["phase"] == "completed")
|
||||
assert "## Sub-question 1:" in comp_evt["answer"]
|
||||
|
||||
rag.retrieve_per_subquestion.assert_called_once_with(
|
||||
["What is the time limit?"], n_results=10,
|
||||
)
|
||||
|
||||
|
||||
async def test_pipeline_single_subquestion():
|
||||
def test_pipeline_single_subquestion(tmp_path, monkeypatch):
|
||||
"""Single sub-question still uses per-sub-q format with ## header."""
|
||||
_setup_env(tmp_path, monkeypatch)
|
||||
|
||||
decompose_resp = '["What is X?"]'
|
||||
filter_resp = '{"0": [8.5, 9.0]}'
|
||||
generate_resp = "## Sub-question 1: What is X?\n- Answer here\n"
|
||||
|
||||
llm = MagicMock()
|
||||
ps = _make_prompt_service()
|
||||
settings = _make_settings()
|
||||
monkeypatch.setattr(
|
||||
"app.routers.query.LLMClient",
|
||||
_make_mock_llm_class([decompose_resp, filter_resp, generate_resp]),
|
||||
)
|
||||
|
||||
request = QueryRequest(question="What is X?")
|
||||
from app.main import app
|
||||
|
||||
with patch("app.routers.query.get_settings", return_value=settings), \
|
||||
patch("app.routers.query.PromptService", return_value=ps), \
|
||||
patch("app.routers.query.LLMClient", return_value=llm), \
|
||||
patch("app.routers.query.RAGService") as MockRAG, \
|
||||
patch("app.routers.query.QueryDecomposer") as MockDec, \
|
||||
patch("app.routers.query.RelevanceFilter") as MockFilter, \
|
||||
patch("app.routers.query.HistoryService") as MockHist, \
|
||||
patch("app.routers.query._schedule_history"):
|
||||
|
||||
dec = MockDec.return_value
|
||||
dec.decompose = AsyncMock(return_value=(
|
||||
["What is X?"],
|
||||
"decompose-prompt"
|
||||
))
|
||||
|
||||
rag = MockRAG.return_value
|
||||
rag.retrieve_per_subquestion.return_value = [
|
||||
("What is X?", [CHUNK_A]),
|
||||
]
|
||||
rag.generate_response_per_subquestion = AsyncMock(return_value=(
|
||||
generate_resp,
|
||||
"gen-prompt",
|
||||
[[CHUNK_A[1]]],
|
||||
))
|
||||
|
||||
filt = MockFilter.return_value
|
||||
filt.filter_per_subquestion = AsyncMock(return_value=(
|
||||
[("What is X?", [(CHUNK_A[0], {**CHUNK_A[1], "relevance_score": 8.5})])],
|
||||
"filter-prompt"
|
||||
))
|
||||
|
||||
events = await _collect_sse(request)
|
||||
client = TestClient(app)
|
||||
events = _collect_sse(client, "What is X?")
|
||||
|
||||
comp_evt = next(e for e in events if e["phase"] == "completed")
|
||||
assert "## Sub-question 1:" in comp_evt["answer"]
|
||||
assert len(comp_evt["sub_question_sources"]) == 1
|
||||
|
||||
|
||||
async def test_pipeline_filter_all_rejected():
|
||||
def test_pipeline_filter_all_rejected(tmp_path, monkeypatch):
|
||||
"""All chunks score below threshold — returns 'no relevant information'."""
|
||||
llm = MagicMock()
|
||||
ps = _make_prompt_service()
|
||||
settings = _make_settings()
|
||||
_setup_env(tmp_path, monkeypatch)
|
||||
|
||||
request = QueryRequest(question="Irrelevant question?")
|
||||
decompose_resp = '["sub-q-1"]'
|
||||
# Both chunks score below threshold 7.0
|
||||
filter_resp = '{"0": [2.0, 3.0]}'
|
||||
|
||||
with patch("app.routers.query.get_settings", return_value=settings), \
|
||||
patch("app.routers.query.PromptService", return_value=ps), \
|
||||
patch("app.routers.query.LLMClient", return_value=llm), \
|
||||
patch("app.routers.query.RAGService") as MockRAG, \
|
||||
patch("app.routers.query.QueryDecomposer") as MockDec, \
|
||||
patch("app.routers.query.RelevanceFilter") as MockFilter, \
|
||||
patch("app.routers.query.HistoryService") as MockHist, \
|
||||
patch("app.routers.query._schedule_history"):
|
||||
monkeypatch.setattr(
|
||||
"app.routers.query.LLMClient",
|
||||
_make_mock_llm_class([decompose_resp, filter_resp]),
|
||||
)
|
||||
|
||||
dec = MockDec.return_value
|
||||
dec.decompose = AsyncMock(return_value=(
|
||||
["sub-q-1"],
|
||||
"decompose-prompt"
|
||||
))
|
||||
from app.main import app
|
||||
|
||||
rag = MockRAG.return_value
|
||||
rag.retrieve_per_subquestion.return_value = [
|
||||
("sub-q-1", [CHUNK_A]),
|
||||
]
|
||||
|
||||
# All chunks filtered out
|
||||
filt = MockFilter.return_value
|
||||
filt.filter_per_subquestion = AsyncMock(return_value=(
|
||||
[("sub-q-1", [])],
|
||||
"filter-prompt"
|
||||
))
|
||||
|
||||
events = await _collect_sse(request)
|
||||
client = TestClient(app)
|
||||
events = _collect_sse(client, "Irrelevant question?")
|
||||
|
||||
comp_evt = next(e for e in events if e["phase"] == "completed")
|
||||
assert "could not find" in comp_evt["answer"].lower()
|
||||
assert comp_evt["sources"] == []
|
||||
|
||||
|
||||
async def test_pipeline_retrieval_empty_for_one_subq():
|
||||
"""One sub-q gets chunks, another gets nothing — partial answer produced."""
|
||||
def test_pipeline_retrieval_empty_for_one_subq(tmp_path, monkeypatch):
|
||||
"""One sub-q's chunks all filtered out — partial answer produced.
|
||||
|
||||
With real ChromaDB (constant embedding), both sub-queries retrieve all
|
||||
seed documents. The filter mock rejects all chunks for sub-q 1 while
|
||||
keeping sub-q 0's chunks, producing a partial answer.
|
||||
"""
|
||||
_setup_env(tmp_path, monkeypatch)
|
||||
|
||||
decompose_resp = '["Has chunks?", "No chunks?"]'
|
||||
# sub-q 0 keeps all (above threshold), sub-q 1 rejects all (below)
|
||||
filter_resp = '{"0": [8.5, 9.0], "1": [2.0, 3.0]}'
|
||||
generate_resp = (
|
||||
"## Sub-question 1: Has chunks?\n"
|
||||
"- Yes [NEC4.pdf, page 3]\n\n"
|
||||
|
|
@ -350,56 +308,21 @@ async def test_pipeline_retrieval_empty_for_one_subq():
|
|||
"- No relevant information found.\n"
|
||||
)
|
||||
|
||||
llm = MagicMock()
|
||||
ps = _make_prompt_service()
|
||||
settings = _make_settings()
|
||||
monkeypatch.setattr(
|
||||
"app.routers.query.LLMClient",
|
||||
_make_mock_llm_class([decompose_resp, filter_resp, generate_resp]),
|
||||
)
|
||||
|
||||
request = QueryRequest(question="Compare two things")
|
||||
from app.main import app
|
||||
|
||||
with patch("app.routers.query.get_settings", return_value=settings), \
|
||||
patch("app.routers.query.PromptService", return_value=ps), \
|
||||
patch("app.routers.query.LLMClient", return_value=llm), \
|
||||
patch("app.routers.query.RAGService") as MockRAG, \
|
||||
patch("app.routers.query.QueryDecomposer") as MockDec, \
|
||||
patch("app.routers.query.RelevanceFilter") as MockFilter, \
|
||||
patch("app.routers.query.HistoryService") as MockHist, \
|
||||
patch("app.routers.query._schedule_history"):
|
||||
|
||||
dec = MockDec.return_value
|
||||
dec.decompose = AsyncMock(return_value=(
|
||||
["Has chunks?", "No chunks?"],
|
||||
"decompose-prompt"
|
||||
))
|
||||
|
||||
rag = MockRAG.return_value
|
||||
# First sub-q has chunks, second has none
|
||||
rag.retrieve_per_subquestion.return_value = [
|
||||
("Has chunks?", [CHUNK_A]),
|
||||
("No chunks?", []),
|
||||
]
|
||||
rag.generate_response_per_subquestion = AsyncMock(return_value=(
|
||||
generate_resp,
|
||||
"gen-prompt",
|
||||
[[CHUNK_A[1]], []],
|
||||
))
|
||||
|
||||
filt = MockFilter.return_value
|
||||
filt.filter_per_subquestion = AsyncMock(return_value=(
|
||||
[
|
||||
("Has chunks?", [(CHUNK_A[0], {**CHUNK_A[1], "relevance_score": 8.5})]),
|
||||
("No chunks?", []),
|
||||
],
|
||||
"filter-prompt"
|
||||
))
|
||||
|
||||
events = await _collect_sse(request)
|
||||
client = TestClient(app)
|
||||
events = _collect_sse(client, "Compare two things")
|
||||
|
||||
comp_evt = next(e for e in events if e["phase"] == "completed")
|
||||
assert "## Sub-question 1:" in comp_evt["answer"]
|
||||
assert "## Sub-question 2:" in comp_evt["answer"]
|
||||
|
||||
# sub_question_sources has 2 entries
|
||||
sq_sources = comp_evt["sub_question_sources"]
|
||||
assert len(sq_sources) == 2
|
||||
assert len(sq_sources[0]["sources"]) > 0 # first sub-q has sources
|
||||
assert len(sq_sources[1]["sources"]) == 0 # second sub-q has no sources
|
||||
assert len(sq_sources[0]["sources"]) > 0
|
||||
assert len(sq_sources[1]["sources"]) == 0
|
||||
|
|
|
|||
|
|
@ -5,23 +5,72 @@ Covers per-sub-question chunk filtering in a single LLM call:
|
|||
- Empty inputs and edge cases
|
||||
- Invalid JSON / score-count mismatch error handling
|
||||
- Threshold boundary behaviour (strict >)
|
||||
|
||||
Uses real PromptService (SQLite via tmp_path) and only mocks the external LLM API.
|
||||
"""
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
import sqlite3
|
||||
|
||||
from app.services.relevance_filter import RelevanceFilter
|
||||
import pytest
|
||||
|
||||
from app.core.sqlite_db import init_prompts_db, seed_default_profiles
|
||||
from app.services.prompt_service import PromptService
|
||||
|
||||
|
||||
class _MockLLM:
|
||||
"""Mock external LLM API."""
|
||||
|
||||
def __init__(self, response: str = "[]", side_effect: Exception | None = None):
|
||||
self._response = response
|
||||
self._side_effect = side_effect
|
||||
self.last_prompt: str | None = None
|
||||
self.calls: list[dict] = []
|
||||
self._call_count: int = 0
|
||||
|
||||
async def complete(
|
||||
self, prompt: str, temperature: float = 0.7, step_name: str = "LLM"
|
||||
) -> str:
|
||||
self.calls.append({"prompt": prompt, "step": step_name})
|
||||
self.last_prompt = prompt
|
||||
self._call_count += 1
|
||||
if self._side_effect:
|
||||
raise self._side_effect
|
||||
return self._response
|
||||
|
||||
@property
|
||||
def call_count(self) -> int:
|
||||
return self._call_count
|
||||
|
||||
def assert_called(self):
|
||||
assert self._call_count > 0, "LLM.complete was not called"
|
||||
|
||||
def assert_not_called(self):
|
||||
assert self._call_count == 0, f"LLM.complete was called {self._call_count} time(s)"
|
||||
|
||||
|
||||
def _create_prompt_service(tmp_path) -> PromptService:
|
||||
"""Create a real PromptService backed by real SQLite with seed data."""
|
||||
db_path = str(tmp_path / "prompts.db")
|
||||
conn = sqlite3.connect(db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA foreign_keys=ON")
|
||||
init_prompts_db(conn)
|
||||
seed_default_profiles(conn)
|
||||
conn.close()
|
||||
return PromptService(db_path=db_path)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: basic per-sub-question filtering
|
||||
# ---------------------------------------------------------------------------
|
||||
async def test_filter_per_subq_basic(mock_prompt_service):
|
||||
async def test_filter_per_subq_basic(tmp_path):
|
||||
"""Two sub-questions, LLM returns per-sub-q scores, threshold filters correctly."""
|
||||
llm = MagicMock()
|
||||
llm.complete = AsyncMock(return_value='{"0": [8.5, 3.2], "1": [9.0]}')
|
||||
from app.services.relevance_filter import RelevanceFilter
|
||||
|
||||
rf = RelevanceFilter(llm, prompt_service=mock_prompt_service)
|
||||
llm = _MockLLM(response='{"0": [8.5, 3.2], "1": [9.0]}')
|
||||
ps = _create_prompt_service(tmp_path)
|
||||
|
||||
rf = RelevanceFilter(llm, prompt_service=ps)
|
||||
results, prompt = await rf.filter_per_subquestion(
|
||||
["What is A?", "What is B?"],
|
||||
[
|
||||
|
|
@ -31,7 +80,6 @@ async def test_filter_per_subq_basic(mock_prompt_service):
|
|||
threshold=7.0,
|
||||
)
|
||||
|
||||
# Structure check
|
||||
assert len(results) == 2
|
||||
assert results[0][0] == "What is A?"
|
||||
assert results[1][0] == "What is B?"
|
||||
|
|
@ -47,38 +95,42 @@ async def test_filter_per_subq_basic(mock_prompt_service):
|
|||
assert results[1][1][0][0] == "chunk B1"
|
||||
assert results[1][1][0][1]["relevance_score"] == 9.0
|
||||
|
||||
# Prompt contains sub-question labels
|
||||
assert prompt != ""
|
||||
assert "Sub-question 0" in prompt
|
||||
assert "Sub-question 1" in prompt
|
||||
llm.complete.assert_called_once()
|
||||
llm.assert_called()
|
||||
assert llm.call_count == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: empty input
|
||||
# ---------------------------------------------------------------------------
|
||||
async def test_filter_per_subq_empty_input(mock_prompt_service):
|
||||
async def test_filter_per_subq_empty_input(tmp_path):
|
||||
"""Empty sub_questions list returns ([], '')."""
|
||||
llm = MagicMock()
|
||||
llm.complete = AsyncMock()
|
||||
from app.services.relevance_filter import RelevanceFilter
|
||||
|
||||
rf = RelevanceFilter(llm, prompt_service=mock_prompt_service)
|
||||
llm = _MockLLM()
|
||||
ps = _create_prompt_service(tmp_path)
|
||||
|
||||
rf = RelevanceFilter(llm, prompt_service=ps)
|
||||
results, prompt = await rf.filter_per_subquestion([], [], threshold=7.0)
|
||||
|
||||
assert results == []
|
||||
assert prompt == ""
|
||||
llm.complete.assert_not_called()
|
||||
llm.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: sub-questions with all-empty chunk lists
|
||||
# ---------------------------------------------------------------------------
|
||||
async def test_filter_per_subq_all_empty_chunks(mock_prompt_service):
|
||||
async def test_filter_per_subq_all_empty_chunks(tmp_path):
|
||||
"""Two sub-questions, both with empty chunk lists → empty filtered lists."""
|
||||
llm = MagicMock()
|
||||
llm.complete = AsyncMock()
|
||||
from app.services.relevance_filter import RelevanceFilter
|
||||
|
||||
rf = RelevanceFilter(llm, prompt_service=mock_prompt_service)
|
||||
llm = _MockLLM()
|
||||
ps = _create_prompt_service(tmp_path)
|
||||
|
||||
rf = RelevanceFilter(llm, prompt_service=ps)
|
||||
results, prompt = await rf.filter_per_subquestion(
|
||||
["What is A?", "What is B?"],
|
||||
[[], []],
|
||||
|
|
@ -90,19 +142,20 @@ async def test_filter_per_subq_all_empty_chunks(mock_prompt_service):
|
|||
assert results[0][1] == []
|
||||
assert results[1][0] == "What is B?"
|
||||
assert results[1][1] == []
|
||||
# No LLM call needed when all chunk lists are empty
|
||||
llm.complete.assert_not_called()
|
||||
llm.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: LLM returns invalid JSON
|
||||
# ---------------------------------------------------------------------------
|
||||
async def test_filter_per_subq_llm_returns_invalid_json(mock_prompt_service):
|
||||
async def test_filter_per_subq_llm_returns_invalid_json(tmp_path):
|
||||
"""LLM returns non-JSON string → returns ([], prompt)."""
|
||||
llm = MagicMock()
|
||||
llm.complete = AsyncMock(return_value="not json at all")
|
||||
from app.services.relevance_filter import RelevanceFilter
|
||||
|
||||
rf = RelevanceFilter(llm, prompt_service=mock_prompt_service)
|
||||
llm = _MockLLM(response="not json at all")
|
||||
ps = _create_prompt_service(tmp_path)
|
||||
|
||||
rf = RelevanceFilter(llm, prompt_service=ps)
|
||||
results, prompt = await rf.filter_per_subquestion(
|
||||
["What is A?"],
|
||||
[[("chunk A1", {"filename": "a.pdf"})]],
|
||||
|
|
@ -116,12 +169,14 @@ async def test_filter_per_subq_llm_returns_invalid_json(mock_prompt_service):
|
|||
# ---------------------------------------------------------------------------
|
||||
# Test: score count mismatch
|
||||
# ---------------------------------------------------------------------------
|
||||
async def test_filter_per_subq_score_count_mismatch(mock_prompt_service):
|
||||
async def test_filter_per_subq_score_count_mismatch(tmp_path):
|
||||
"""Sub-q 0 has 2 chunks but LLM returns only 1 score → returns ([], prompt)."""
|
||||
llm = MagicMock()
|
||||
llm.complete = AsyncMock(return_value='{"0": [8.5]}')
|
||||
from app.services.relevance_filter import RelevanceFilter
|
||||
|
||||
rf = RelevanceFilter(llm, prompt_service=mock_prompt_service)
|
||||
llm = _MockLLM(response='{"0": [8.5]}')
|
||||
ps = _create_prompt_service(tmp_path)
|
||||
|
||||
rf = RelevanceFilter(llm, prompt_service=ps)
|
||||
results, prompt = await rf.filter_per_subquestion(
|
||||
["What is A?"],
|
||||
[[("chunk A1", {"filename": "a.pdf"}), ("chunk A2", {"filename": "a2.pdf"})]],
|
||||
|
|
@ -135,13 +190,14 @@ async def test_filter_per_subq_score_count_mismatch(mock_prompt_service):
|
|||
# ---------------------------------------------------------------------------
|
||||
# Test: strict threshold boundary
|
||||
# ---------------------------------------------------------------------------
|
||||
async def test_filter_per_subq_passes_threshold_correctly(mock_prompt_service):
|
||||
async def test_filter_per_subq_passes_threshold_correctly(tmp_path):
|
||||
"""Score == threshold is NOT kept (strict >). Score > threshold IS kept."""
|
||||
llm = MagicMock()
|
||||
# Sub-q 0: scores [7.0, 7.1] with threshold 7.0 → only 7.1 kept
|
||||
llm.complete = AsyncMock(return_value='{"0": [7.0, 7.1]}')
|
||||
from app.services.relevance_filter import RelevanceFilter
|
||||
|
||||
rf = RelevanceFilter(llm, prompt_service=mock_prompt_service)
|
||||
llm = _MockLLM(response='{"0": [7.0, 7.1]}')
|
||||
ps = _create_prompt_service(tmp_path)
|
||||
|
||||
rf = RelevanceFilter(llm, prompt_service=ps)
|
||||
results, prompt = await rf.filter_per_subquestion(
|
||||
["Boundary test?"],
|
||||
[[("exact threshold", {"filename": "f1.pdf"}), ("above threshold", {"filename": "f2.pdf"})]],
|
||||
|
|
@ -157,12 +213,14 @@ async def test_filter_per_subq_passes_threshold_correctly(mock_prompt_service):
|
|||
# ---------------------------------------------------------------------------
|
||||
# Test: LLM exception
|
||||
# ---------------------------------------------------------------------------
|
||||
async def test_filter_per_subq_llm_exception(mock_prompt_service):
|
||||
async def test_filter_per_subq_llm_exception(tmp_path):
|
||||
"""LLM call raises an exception → returns ([], '')."""
|
||||
llm = MagicMock()
|
||||
llm.complete = AsyncMock(side_effect=RuntimeError("LLM unavailable"))
|
||||
from app.services.relevance_filter import RelevanceFilter
|
||||
|
||||
rf = RelevanceFilter(llm, prompt_service=mock_prompt_service)
|
||||
llm = _MockLLM(side_effect=RuntimeError("LLM unavailable"))
|
||||
ps = _create_prompt_service(tmp_path)
|
||||
|
||||
rf = RelevanceFilter(llm, prompt_service=ps)
|
||||
results, prompt = await rf.filter_per_subquestion(
|
||||
["What is A?"],
|
||||
[[("chunk A1", {"filename": "a.pdf"})]],
|
||||
|
|
@ -176,12 +234,14 @@ async def test_filter_per_subq_llm_exception(mock_prompt_service):
|
|||
# ---------------------------------------------------------------------------
|
||||
# Test: JSON wrapped in markdown code block
|
||||
# ---------------------------------------------------------------------------
|
||||
async def test_filter_per_subq_json_in_markdown_code_block(mock_prompt_service):
|
||||
async def test_filter_per_subq_json_in_markdown_code_block(tmp_path):
|
||||
"""LLM returns JSON inside ```json ... ``` block → should parse correctly."""
|
||||
llm = MagicMock()
|
||||
llm.complete = AsyncMock(return_value='```json\n{"0": [9.0]}\n```')
|
||||
from app.services.relevance_filter import RelevanceFilter
|
||||
|
||||
rf = RelevanceFilter(llm, prompt_service=mock_prompt_service)
|
||||
llm = _MockLLM(response='```json\n{"0": [9.0]}\n```')
|
||||
ps = _create_prompt_service(tmp_path)
|
||||
|
||||
rf = RelevanceFilter(llm, prompt_service=ps)
|
||||
results, prompt = await rf.filter_per_subquestion(
|
||||
["What is A?"],
|
||||
[[("chunk A1", {"filename": "a.pdf"})]],
|
||||
|
|
@ -196,12 +256,14 @@ async def test_filter_per_subq_json_in_markdown_code_block(mock_prompt_service):
|
|||
# ---------------------------------------------------------------------------
|
||||
# Test: mixed empty and non-empty sub-questions
|
||||
# ---------------------------------------------------------------------------
|
||||
async def test_filter_per_subq_mixed_empty_and_nonempty(mock_prompt_service):
|
||||
async def test_filter_per_subq_mixed_empty_and_nonempty(tmp_path):
|
||||
"""One sub-q with chunks, one without. Only non-empty ones get scored."""
|
||||
llm = MagicMock()
|
||||
llm.complete = AsyncMock(return_value='{"0": [8.5]}')
|
||||
from app.services.relevance_filter import RelevanceFilter
|
||||
|
||||
rf = RelevanceFilter(llm, prompt_service=mock_prompt_service)
|
||||
llm = _MockLLM(response='{"0": [8.5]}')
|
||||
ps = _create_prompt_service(tmp_path)
|
||||
|
||||
rf = RelevanceFilter(llm, prompt_service=ps)
|
||||
results, prompt = await rf.filter_per_subquestion(
|
||||
["What is A?", "What is B?"],
|
||||
[[("chunk A1", {"filename": "a.pdf"})], []],
|
||||
|
|
|
|||
|
|
@ -5,28 +5,53 @@ Covers answer format invariants:
|
|||
- Citation bracket labels in answer text
|
||||
- grouped_sources match sub-question boundaries
|
||||
- Single sub-question still uses header format
|
||||
"""
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from app.services.rag import RAGService
|
||||
Uses real ChromaDB (tmp_path) and only mocks the external LLM API.
|
||||
"""
|
||||
import chromadb
|
||||
import pytest
|
||||
|
||||
|
||||
class _MockLLM:
|
||||
"""Mock external LLM API."""
|
||||
|
||||
def __init__(self, response: str = "mock answer"):
|
||||
self._response = response
|
||||
self.last_prompt: str | None = None
|
||||
self._call_count: int = 0
|
||||
|
||||
async def complete(
|
||||
self, prompt: str, temperature: float = 0.7, step_name: str = "LLM"
|
||||
) -> str:
|
||||
self.last_prompt = prompt
|
||||
self._call_count += 1
|
||||
return self._response
|
||||
|
||||
|
||||
def _setup_chroma(tmp_path):
|
||||
"""Create an isolated real ChromaDB PersistentClient."""
|
||||
chroma_dir = tmp_path / "chroma"
|
||||
chroma_dir.mkdir(parents=True, exist_ok=True)
|
||||
return chromadb.PersistentClient(path=str(chroma_dir))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test: answer has sub-question headers
|
||||
# ---------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_answer_has_subquestion_headers():
|
||||
async def test_answer_has_subquestion_headers(tmp_path):
|
||||
"""Answer string contains ## Sub-question N: headers."""
|
||||
llm = MagicMock()
|
||||
llm.complete = AsyncMock(return_value=(
|
||||
from app.services.rag import RAGService
|
||||
|
||||
llm = _MockLLM(response=(
|
||||
"## Sub-question 1: First question?\n"
|
||||
"- Point one [doc.pdf, page 1]\n\n"
|
||||
"## Sub-question 2: Second question?\n"
|
||||
"- Point two [doc.pdf, page 2]\n"
|
||||
))
|
||||
client = _setup_chroma(tmp_path)
|
||||
|
||||
service = RAGService(llm_client=llm)
|
||||
service = RAGService(chroma_client=client, llm_client=llm)
|
||||
answer, _prompt, _sources = await service.generate_response_per_subquestion(
|
||||
sub_questions=["First question?", "Second question?"],
|
||||
sub_chunks=[["chunk1"], ["chunk2"]],
|
||||
|
|
@ -44,15 +69,17 @@ async def test_answer_has_subquestion_headers():
|
|||
# Test: citations use bracket labels
|
||||
# ---------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_answer_citations_use_bracket_labels():
|
||||
async def test_answer_citations_use_bracket_labels(tmp_path):
|
||||
"""Answer contains [filename, page N] citation format."""
|
||||
llm = MagicMock()
|
||||
llm.complete = AsyncMock(return_value=(
|
||||
from app.services.rag import RAGService
|
||||
|
||||
llm = _MockLLM(response=(
|
||||
"## Sub-question 1: What is X?\n"
|
||||
"- X is defined as a variable [report.pdf, page 5]\n"
|
||||
))
|
||||
client = _setup_chroma(tmp_path)
|
||||
|
||||
service = RAGService(llm_client=llm)
|
||||
service = RAGService(chroma_client=client, llm_client=llm)
|
||||
answer, _prompt, _sources = await service.generate_response_per_subquestion(
|
||||
sub_questions=["What is X?"],
|
||||
sub_chunks=[["chunk about X"]],
|
||||
|
|
@ -66,14 +93,16 @@ async def test_answer_citations_use_bracket_labels():
|
|||
# Test: grouped_sources match sub-question boundaries
|
||||
# ---------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_grouped_sources_match_subquestions():
|
||||
async def test_grouped_sources_match_subquestions(tmp_path):
|
||||
"""Each sub-question's source list only contains metadata from its own chunks."""
|
||||
llm = MagicMock()
|
||||
llm.complete = AsyncMock(return_value=(
|
||||
from app.services.rag import RAGService
|
||||
|
||||
llm = _MockLLM(response=(
|
||||
"## Sub-question 1: Q1?\n- A1\n\n## Sub-question 2: Q2?\n- A2\n"
|
||||
))
|
||||
client = _setup_chroma(tmp_path)
|
||||
|
||||
service = RAGService(llm_client=llm)
|
||||
service = RAGService(chroma_client=client, llm_client=llm)
|
||||
_answer, _prompt, grouped_sources = await service.generate_response_per_subquestion(
|
||||
sub_questions=["Q1?", "Q2?"],
|
||||
sub_chunks=[
|
||||
|
|
@ -92,10 +121,8 @@ async def test_grouped_sources_match_subquestions():
|
|||
)
|
||||
|
||||
assert len(grouped_sources) == 2
|
||||
# Sub-q 0 sources should only contain alpha and beta
|
||||
filenames_0 = {m["filename"] for m in grouped_sources[0]}
|
||||
assert filenames_0 == {"alpha.pdf", "beta.pdf"}
|
||||
# Sub-q 1 sources should only contain gamma
|
||||
filenames_1 = {m["filename"] for m in grouped_sources[1]}
|
||||
assert filenames_1 == {"gamma.pdf"}
|
||||
|
||||
|
|
@ -104,15 +131,17 @@ async def test_grouped_sources_match_subquestions():
|
|||
# Test: single sub-question still uses header format
|
||||
# ---------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_subquestion_format():
|
||||
async def test_single_subquestion_format(tmp_path):
|
||||
"""When only one sub-question, answer still uses ## Sub-question 1: header."""
|
||||
llm = MagicMock()
|
||||
llm.complete = AsyncMock(return_value=(
|
||||
from app.services.rag import RAGService
|
||||
|
||||
llm = _MockLLM(response=(
|
||||
"## Sub-question 1: What is this?\n"
|
||||
"- It is a test [test.pdf, page 1]\n"
|
||||
))
|
||||
client = _setup_chroma(tmp_path)
|
||||
|
||||
service = RAGService(llm_client=llm)
|
||||
service = RAGService(chroma_client=client, llm_client=llm)
|
||||
answer, _prompt, grouped_sources = await service.generate_response_per_subquestion(
|
||||
sub_questions=["What is this?"],
|
||||
sub_chunks=[["test chunk"]],
|
||||
|
|
|
|||
|
|
@ -7,126 +7,156 @@ Covers:
|
|||
- Verify retrieve() is called once per sub-question
|
||||
- n_results parameter passthrough
|
||||
- Handling of empty results for individual sub-questions
|
||||
|
||||
All tests use real ChromaDB via tmp_path. No mocks for DB or internal services.
|
||||
"""
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
import chromadb
|
||||
|
||||
from app.services.rag import RAGService
|
||||
|
||||
|
||||
def _setup_chroma(tmp_path, monkeypatch, collection_name: str = "documents"):
|
||||
from app.core.config import get_settings
|
||||
|
||||
monkeypatch.setenv("CHROMA_DB_PATH", str(tmp_path / "test_chroma"))
|
||||
get_settings.cache_clear()
|
||||
|
||||
client = chromadb.PersistentClient(path=str(tmp_path / "test_chroma"))
|
||||
collection = client.get_or_create_collection(name=collection_name)
|
||||
return client, collection
|
||||
|
||||
|
||||
class TestRetrievePerSubquestion:
|
||||
"""Tests for RAGService.retrieve_per_subquestion()."""
|
||||
|
||||
@staticmethod
|
||||
def _make_service() -> RAGService:
|
||||
"""Create a RAGService with a mocked collection."""
|
||||
mock_collection = MagicMock()
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection.return_value = mock_collection
|
||||
service = RAGService(chroma_client=mock_client)
|
||||
service._collection = mock_collection
|
||||
return service
|
||||
|
||||
def test_retrieve_per_subquestion_two_subqs(self):
|
||||
def test_retrieve_per_subquestion_two_subqs(self, tmp_path, monkeypatch):
|
||||
"""Two sub-questions should each return their own chunks."""
|
||||
service = self._make_service()
|
||||
service._collection.query.side_effect = [
|
||||
{
|
||||
"documents": [["chunk A1", "chunk A2"]],
|
||||
"metadatas": [[{"filename": "a.pdf"}, {"filename": "a2.pdf"}]],
|
||||
"distances": [[0.1, 0.2]],
|
||||
},
|
||||
{
|
||||
"documents": [["chunk B1"]],
|
||||
"metadatas": [[{"filename": "b.pdf"}]],
|
||||
"distances": [[0.3]],
|
||||
},
|
||||
]
|
||||
client, collection = _setup_chroma(tmp_path, monkeypatch)
|
||||
|
||||
collection.add(
|
||||
documents=["Alpha content about quantum physics", "Alpha extra about quantum physics"],
|
||||
metadatas=[{"filename": "a.pdf"}, {"filename": "a2.pdf"}],
|
||||
ids=["a1", "a2"],
|
||||
)
|
||||
collection.add(
|
||||
documents=["Beta content about machine learning"],
|
||||
metadatas=[{"filename": "b.pdf"}],
|
||||
ids=["b1"],
|
||||
)
|
||||
|
||||
service = RAGService(chroma_client=client)
|
||||
results = service.retrieve_per_subquestion(
|
||||
["What is A?", "What is B?"], n_results=5
|
||||
["quantum physics", "machine learning"], n_results=5
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
|
||||
assert results[0][0] == "What is A?"
|
||||
assert len(results[0][1]) == 2
|
||||
assert results[0][1][0] == ("chunk A1", {"filename": "a.pdf"}, 0.1)
|
||||
assert results[0][1][1] == ("chunk A2", {"filename": "a2.pdf"}, 0.2)
|
||||
assert results[0][0] == "quantum physics"
|
||||
assert len(results[0][1]) >= 1
|
||||
assert "quantum" in results[0][1][0][0].lower()
|
||||
|
||||
assert results[1][0] == "What is B?"
|
||||
assert len(results[1][1]) == 1
|
||||
assert results[1][1][0] == ("chunk B1", {"filename": "b.pdf"}, 0.3)
|
||||
assert results[1][0] == "machine learning"
|
||||
assert len(results[1][1]) >= 1
|
||||
assert "machine learning" in results[1][1][0][0].lower()
|
||||
|
||||
def test_retrieve_per_subquestion_empty_list(self):
|
||||
def test_retrieve_per_subquestion_empty_list(self, tmp_path, monkeypatch):
|
||||
"""Empty sub_questions list returns empty list."""
|
||||
service = self._make_service()
|
||||
client, _ = _setup_chroma(tmp_path, monkeypatch)
|
||||
service = RAGService(chroma_client=client)
|
||||
|
||||
results = service.retrieve_per_subquestion([], n_results=10)
|
||||
assert results == []
|
||||
|
||||
def test_retrieve_per_subquestion_single_subq(self):
|
||||
def test_retrieve_per_subquestion_single_subq(self, tmp_path, monkeypatch):
|
||||
"""Single sub-question returns a single-element result list."""
|
||||
service = self._make_service()
|
||||
service._collection.query.return_value = {
|
||||
"documents": [["chunk X"]],
|
||||
"metadatas": [[{"filename": "x.pdf"}]],
|
||||
"distances": [[0.05]],
|
||||
}
|
||||
client, collection = _setup_chroma(tmp_path, monkeypatch)
|
||||
|
||||
results = service.retrieve_per_subquestion(["Only question"], n_results=3)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0][0] == "Only question"
|
||||
assert len(results[0][1]) == 1
|
||||
assert results[0][1][0] == ("chunk X", {"filename": "x.pdf"}, 0.05)
|
||||
|
||||
def test_retrieve_per_subquestion_calls_retrieve_n_times(self):
|
||||
"""retrieve() should be called once per sub-question with correct args."""
|
||||
service = self._make_service()
|
||||
|
||||
# Mock retrieve to return empty chunks so we can spy on calls
|
||||
service.retrieve = MagicMock(return_value=[])
|
||||
|
||||
sub_questions = ["Q1", "Q2", "Q3"]
|
||||
service.retrieve_per_subquestion(sub_questions, n_results=7)
|
||||
|
||||
assert service.retrieve.call_count == 3
|
||||
service.retrieve.assert_any_call(["Q1"], n_results=7)
|
||||
service.retrieve.assert_any_call(["Q2"], n_results=7)
|
||||
service.retrieve.assert_any_call(["Q3"], n_results=7)
|
||||
|
||||
def test_retrieve_per_subquestion_preserves_n_results(self):
|
||||
"""n_results parameter is passed through to each retrieve() call."""
|
||||
service = self._make_service()
|
||||
service.retrieve = MagicMock(return_value=[])
|
||||
|
||||
service.retrieve_per_subquestion(["Q1"], n_results=42)
|
||||
|
||||
service.retrieve.assert_called_once_with(["Q1"], n_results=42)
|
||||
|
||||
def test_retrieve_per_subquestion_handles_empty_results(self):
|
||||
"""One sub-q returns no results, another returns results."""
|
||||
service = self._make_service()
|
||||
|
||||
# First call returns empty, second returns data
|
||||
service.retrieve = MagicMock(
|
||||
side_effect=[
|
||||
[],
|
||||
[("chunk B", {"filename": "b.pdf"}, 0.2)],
|
||||
]
|
||||
collection.add(
|
||||
documents=["Unique content about solar energy"],
|
||||
metadatas=[{"filename": "x.pdf"}],
|
||||
ids=["x1"],
|
||||
)
|
||||
|
||||
service = RAGService(chroma_client=client)
|
||||
results = service.retrieve_per_subquestion(["solar energy"], n_results=3)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0][0] == "solar energy"
|
||||
assert len(results[0][1]) >= 1
|
||||
assert "solar" in results[0][1][0][0].lower()
|
||||
assert results[0][1][0][1]["filename"] == "x.pdf"
|
||||
|
||||
def test_retrieve_per_subquestion_calls_retrieve_n_times(self, tmp_path, monkeypatch):
|
||||
"""retrieve() should be called once per sub-question."""
|
||||
client, _ = _setup_chroma(tmp_path, monkeypatch)
|
||||
service = RAGService(chroma_client=client)
|
||||
|
||||
call_log: list[tuple] = []
|
||||
original_retrieve = service.retrieve
|
||||
|
||||
def _tracking_retrieve(query_keywords, n_results=10):
|
||||
call_log.append((query_keywords, n_results))
|
||||
return original_retrieve(query_keywords, n_results=n_results)
|
||||
|
||||
service.retrieve = _tracking_retrieve
|
||||
|
||||
sub_questions = ["apples", "bananas", "cherries"]
|
||||
service.retrieve_per_subquestion(sub_questions, n_results=7)
|
||||
|
||||
assert len(call_log) == 3
|
||||
for i, sub_q in enumerate(sub_questions):
|
||||
assert call_log[i][0] == [sub_q]
|
||||
assert call_log[i][1] == 7
|
||||
|
||||
def test_retrieve_per_subquestion_preserves_n_results(self, tmp_path, monkeypatch):
|
||||
"""n_results parameter is passed through to each retrieve() call."""
|
||||
client, _ = _setup_chroma(tmp_path, monkeypatch)
|
||||
service = RAGService(chroma_client=client)
|
||||
|
||||
captured_n_results: list[int] = []
|
||||
original_retrieve = service.retrieve
|
||||
|
||||
def _capture_retrieve(query_keywords, n_results=10):
|
||||
captured_n_results.append(n_results)
|
||||
return original_retrieve(query_keywords, n_results=n_results)
|
||||
|
||||
service.retrieve = _capture_retrieve
|
||||
|
||||
service.retrieve_per_subquestion(["test query"], n_results=42)
|
||||
|
||||
assert captured_n_results == [42]
|
||||
|
||||
def test_retrieve_per_subquestion_handles_empty_results(self, tmp_path, monkeypatch):
|
||||
"""One sub-q returns fewer/no results, another returns results."""
|
||||
client, collection = _setup_chroma(tmp_path, monkeypatch)
|
||||
|
||||
collection.add(
|
||||
documents=[
|
||||
"Deep dive into quantum entanglement experiments",
|
||||
"Quantum entanglement and Bell's theorem",
|
||||
"History of Renaissance art in Florence",
|
||||
],
|
||||
metadatas=[
|
||||
{"filename": "b.pdf"},
|
||||
{"filename": "b2.pdf"},
|
||||
{"filename": "art.pdf"},
|
||||
],
|
||||
ids=["b1", "b2", "a1"],
|
||||
)
|
||||
|
||||
service = RAGService(chroma_client=client)
|
||||
results = service.retrieve_per_subquestion(
|
||||
["No results Q", "Has results Q"], n_results=5
|
||||
["Renaissance art Florence", "quantum entanglement"], n_results=2
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
|
||||
# First sub-question has empty chunks
|
||||
assert results[0][0] == "No results Q"
|
||||
assert results[0][1] == []
|
||||
assert results[0][0] == "Renaissance art Florence"
|
||||
assert results[1][0] == "quantum entanglement"
|
||||
|
||||
# Second sub-question has chunks
|
||||
assert results[1][0] == "Has results Q"
|
||||
assert len(results[1][1]) == 1
|
||||
assert results[1][1][0] == ("chunk B", {"filename": "b.pdf"}, 0.2)
|
||||
assert len(results[1][1]) >= 1
|
||||
assert "quantum" in results[1][1][0][0].lower()
|
||||
|
||||
assert len(results[0][1]) >= 1
|
||||
assert "Renaissance" in results[0][1][0][0]
|
||||
|
|
|
|||
|
|
@ -67,10 +67,14 @@ def extract_metadata(
|
|||
"upload_date": upload_date,
|
||||
"content_summary": content_summary,
|
||||
"chunk_index": idx,
|
||||
"page_number": page_numbers[idx] if page_numbers else None,
|
||||
"chunk_file_path": chunk_file_paths[idx] if chunk_file_paths else None,
|
||||
"document_id": document_id,
|
||||
}
|
||||
page_num = page_numbers[idx] if page_numbers else None
|
||||
if page_num is not None:
|
||||
entry["page_number"] = page_num
|
||||
cfp = chunk_file_paths[idx] if chunk_file_paths else None
|
||||
if cfp is not None:
|
||||
entry["chunk_file_path"] = cfp
|
||||
metadata.append(entry)
|
||||
|
||||
return metadata
|
||||
|
|
|
|||
Loading…
Reference in New Issue