289 lines
10 KiB
Python
289 lines
10 KiB
Python
"""Phase 1 tests: RAG query endpoint.

Covers:
- POST /api/v1/query with SSE stream response
- Strict RAG prompt enforcement (only use retrieved context)
- Source metadata inclusion in SSE events
- 422 on missing question field
- 400 on empty/whitespace-only question

Uses TestClient + real ChromaDB + real SQLite. Only external dependencies
are mocked: the LLMClient (with controlled responses) and the embedding
function (with deterministic vectors).
"""
|
|
import json
|
|
|
|
import pytest
|
|
from fastapi import FastAPI
|
|
from fastapi.testclient import TestClient
|
|
|
|
from app.core.sqlite_db import (
|
|
_get_db,
|
|
init_history_db,
|
|
init_prompts_db,
|
|
seed_default_profiles,
|
|
)
|
|
from app.routers.query import router
|
|
|
|
|
|
# ── Helpers ─────────────────────────────────────────────────────────────────
|
|
|
|
|
|
class _MockLLMClient:
    """Scripted stand-in for ``LLMClient`` that answers by pipeline step.

    ``complete`` ignores ``prompt`` and ``temperature`` and picks a canned
    reply from ``step_name``:

    1. ``"QueryDecomposer"`` -> JSON array holding one sub-question.
    2. ``"RelevanceFilter"`` -> JSON object mapping sub-question index ``"0"``
       to the relevance scores ``[8.0, 7.5]``.
    3. any other step (answer generation) -> a two-line markdown bullet list.

    Every call, whatever the step, increments ``_call_count``.
    """

    def __init__(self):
        self._call_count = 0

    async def complete(self, prompt: str, temperature: float = 0.7,
                       step_name: str = "LLM") -> str:
        self._call_count += 1
        if step_name == "QueryDecomposer":
            reply = json.dumps(["test sub-question"])
        elif step_name == "RelevanceFilter":
            reply = json.dumps({"0": [8.0, 7.5]})
        else:
            reply = "- Bullet point answer\n- Another point"
        return reply
|
|
|
|
|
|
class _MockLLMClientNoChunks:
    """LLM stand-in for the "nothing relevant" path.

    Decomposition yields a single sub-question (``"test"``). The relevance
    step scores sub-question ``"0"`` low (``[2.0, 1.5]``), which the
    no-relevant-chunks test relies on to have every chunk filtered out. Any
    other step returns a fixed "could not find" answer. Unlike
    ``_MockLLMClient``, no call count is kept.
    """

    async def complete(self, prompt: str, temperature: float = 0.7,
                       step_name: str = "LLM") -> str:
        scripted = {
            "QueryDecomposer": json.dumps(["test"]),
            "RelevanceFilter": json.dumps({"0": [2.0, 1.5]}),
        }
        return scripted.get(step_name, "I could not find any relevant information.")
|
|
|
|
|
|
class _DeterministicEmbedding:
    """Lightweight embedding function that returns deterministic vectors.

    Each input text becomes a fixed-length vector of ``_DIM`` (384) floats:
    component ``i`` is ``ord(text[i]) / 1000.0`` for the first ``_DIM``
    characters, and the remaining components are ``0.0`` when the text is
    shorter. Text beyond ``_DIM`` characters is ignored. No hashing or
    external API calls are involved, so identical inputs always give
    identical vectors.
    """

    # Output dimensionality of every vector produced by ``_embed``.
    _DIM = 384

    def name(self) -> str:
        return "test_deterministic"

    def __call__(self, input):
        # Parameter is named ``input`` (shadowing the builtin) on purpose:
        # ChromaDB's embedding-function protocol uses that name.
        return self._embed(input)

    def embed_query(self, input):
        return self._embed(input)

    @staticmethod
    def _embed(texts):
        """Return one ``_DIM``-length list of floats per text in *texts*."""
        dim = _DeterministicEmbedding._DIM
        vectors = []
        for text in texts:
            codes = [ord(ch) / 1000.0 for ch in text[:dim]]
            # Zero-pad short texts so every vector has the same length.
            vectors.append(codes + [0.0] * (dim - len(codes)))
        return vectors
|
|
|
|
|
|
def _seed_document(chroma_path: str):
    """Insert a two-chunk test document into a real ChromaDB at *chroma_path*.

    Both chunks belong to ``test.pdf`` (document id ``test-doc-123``). They
    are written to the ``documents`` collection and embedded with
    ``_DeterministicEmbedding``, so no external embedding API is used.
    """
    import chromadb
    from app.core.database import get_or_create_collection

    def chunk_meta(summary, index):
        # Metadata shared by both chunks; only the summary and position differ.
        return {
            "filename": "test.pdf",
            "upload_date": "2025-01-01T00:00:00",
            "content_summary": summary,
            "chunk_index": index,
            "document_id": "test-doc-123",
        }

    chroma = chromadb.PersistentClient(path=chroma_path)
    docs = get_or_create_collection(
        chroma,
        "documents",
        embedding_function=_DeterministicEmbedding(),
    )
    docs.add(
        documents=[
            "chunk one with some text about testing",
            "chunk two with more details about testing",
        ],
        metadatas=[chunk_meta("chunk one", 0), chunk_meta("chunk two", 1)],
        ids=["test-doc-123_0", "test-doc-123_1"],
    )
|
|
|
|
|
|
# ── Fixture ─────────────────────────────────────────────────────────────────
|
|
|
|
|
|
@pytest.fixture
def client(tmp_path, monkeypatch):
    """Yield a TestClient for a minimal FastAPI app mounting only the query router.

    Setup, all isolated under ``tmp_path``:

    * env vars point Chroma, prompts and history storage at temp paths;
    * both settings caches are cleared so those env vars are re-read;
    * the prompts DB is initialised and seeded with default profiles, and the
      history DB is initialised (real SQLite);
    * one two-chunk document is seeded into a real ChromaDB;
    * ``app.core.database.get_embedding_function_settings`` is patched to
      return ``_DeterministicEmbedding`` so no embedding API is called.

    ``LLMClient`` is NOT patched here; tests that reach the LLM patch
    ``app.routers.query.LLMClient`` themselves. The settings caches are
    cleared again on teardown -- also when setup fails part-way -- so the
    temp-path settings never leak into other tests.
    """
    from contextlib import closing

    chroma_path = str(tmp_path / "chroma_db")
    prompts_path = str(tmp_path / "prompts.db")
    history_path = str(tmp_path / "history.db")

    monkeypatch.setenv("CHROMA_DB_PATH", chroma_path)
    monkeypatch.setenv("PROMPTS_DB_PATH", prompts_path)
    monkeypatch.setenv("HISTORY_DB_PATH", history_path)
    monkeypatch.setenv("EMBEDDING_MODEL", "test-mock")
    monkeypatch.setenv("LLM_API_KEY", "test-key")

    from app.core.config import get_settings
    get_settings.cache_clear()
    from app.core.dependencies import get_settings_cached
    get_settings_cached.cache_clear()

    # Everything after the cache clear runs inside try/finally: code after a
    # bare ``yield`` is skipped when setup raises, which would leave settings
    # cached during setup (pointing at tmp_path) visible to later tests.
    try:
        # Init real SQLite databases. closing() releases each connection
        # even if schema creation or seeding raises.
        with closing(_get_db(prompts_path)) as conn:
            init_prompts_db(conn)
            seed_default_profiles(conn)

        with closing(_get_db(history_path)) as hconn:
            init_history_db(hconn)

        # Seed a document into real ChromaDB
        _seed_document(chroma_path)

        # Mock embedding function at database module level (external API)
        monkeypatch.setattr(
            "app.core.database.get_embedding_function_settings",
            lambda settings: _DeterministicEmbedding(),
        )

        # Build a minimal FastAPI app with only the query router
        test_app = FastAPI()
        test_app.include_router(router, prefix="/api/v1")

        yield TestClient(test_app)
    finally:
        get_settings_cached.cache_clear()
        get_settings.cache_clear()
|
|
|
|
|
|
# ── Tests ───────────────────────────────────────────────────────────────────
|
|
|
|
|
|
def _parse_sse_events(response):
    """Decode every ``data: <json>`` line of an SSE response into a list.

    Lines without the ``data: `` prefix (blank event separators, other SSE
    fields) are skipped. Each payload is expected to be JSON.
    """
    prefix = "data: "
    return [
        json.loads(line[len(prefix):])
        for line in response.iter_lines()
        if line.startswith(prefix)
    ]


def _event_for_phase(events, phase):
    """Return the first event whose ``phase`` equals *phase*.

    Fails the test with a readable message listing the phases actually
    streamed, instead of the bare ``StopIteration`` that ``next`` raises.
    """
    for event in events:
        if event.get("phase") == phase:
            return event
    pytest.fail(
        f"no {phase!r} event in SSE stream; got phases "
        f"{[e.get('phase') for e in events]}"
    )


def _patch_llm(monkeypatch, mock_cls):
    """Make ``app.routers.query.LLMClient(settings)`` build a fresh *mock_cls*."""
    monkeypatch.setattr(
        "app.routers.query.LLMClient",
        lambda settings: mock_cls(),
    )


class TestQuery:
    """Endpoint tests for ``POST /api/v1/query`` using the ``client`` fixture."""

    @pytest.mark.skip(reason="Deprecated: endpoint now returns SSE stream, not JSON")
    def test_query_returns_bullets(self, client):
        """Should return bullet-point answer with source metadata."""
        # Left as skip — SSE tests below cover this functionality.

    @pytest.mark.skip(reason="Deprecated: endpoint now returns SSE stream, not JSON")
    def test_query_no_relevant_chunks(self, client):
        """Should handle case when no relevant chunks found."""
        # Left as skip — SSE tests below cover this functionality.

    def test_query_no_question(self, client):
        """Should reject request without question."""
        response = client.post("/api/v1/query", json={})

        assert response.status_code == 422

    def test_query_sse_stream_with_mock_llm(self, client, monkeypatch):
        """Should stream every pipeline phase and finish with answer + sources.

        Phases checked (each must appear at least once): decomposed,
        retrieving, filtering, generating, completed.

        Uses real ChromaDB (seeded with a document) + real RAGService.
        Only the LLM client (and, via the fixture, the embedding function)
        is mocked.
        """
        _patch_llm(monkeypatch, _MockLLMClient)

        response = client.post(
            "/api/v1/query",
            json={"question": "What is this about?"},
            headers={"Accept": "text/event-stream"},
        )

        assert response.status_code == 200
        assert response.headers["content-type"] == "text/event-stream; charset=utf-8"

        events = _parse_sse_events(response)
        phases = [e["phase"] for e in events]
        for expected in ("decomposed", "retrieving", "filtering", "generating", "completed"):
            assert expected in phases, f"missing {expected!r} phase; got {phases}"

        completed = _event_for_phase(events, "completed")
        assert "- Bullet point answer" in completed["answer"]
        assert len(completed["sources"]) > 0
        assert completed["sources"][0]["filename"] == "test.pdf"

    def test_query_sse_no_relevant_chunks_after_filter(self, client, monkeypatch):
        """Should return 'no relevant information' when filter eliminates all chunks.

        Uses real ChromaDB (seeded) + real RAGService. LLM mock returns
        low relevance scores so all chunks are filtered out.
        """
        _patch_llm(monkeypatch, _MockLLMClientNoChunks)

        response = client.post(
            "/api/v1/query",
            json={"question": "Something completely unrelated"},
            headers={"Accept": "text/event-stream"},
        )

        assert response.status_code == 200

        completed = _event_for_phase(_parse_sse_events(response), "completed")
        assert "could not find" in completed["answer"].lower()
        assert completed["sources"] == []

    def test_query_empty_question_returns_400(self, client, monkeypatch):
        """Should reject empty/whitespace-only question with 400."""
        _patch_llm(monkeypatch, _MockLLMClient)

        response = client.post(
            "/api/v1/query",
            json={"question": " "},
        )

        assert response.status_code == 400

    def test_query_sse_decomposed_event_has_questions(self, client, monkeypatch):
        """Decomposed SSE event should contain extracted sub-questions from LLM."""
        _patch_llm(monkeypatch, _MockLLMClient)

        response = client.post(
            "/api/v1/query",
            json={"question": "What is testing?"},
        )

        # Check the status first so an error response is reported as such,
        # not as a confusing "no decomposed event" failure.
        assert response.status_code == 200

        decomposed = _event_for_phase(_parse_sse_events(response), "decomposed")
        assert "extracted_questions" in decomposed
        assert isinstance(decomposed["extracted_questions"], list)
        assert len(decomposed["extracted_questions"]) > 0
|