# legco_ai_assistant/backend/app/test/test_phase1_query.py
# (repository-viewer metadata: 313 lines, 11 KiB, Python)

"""Phase 1 tests: RAG query endpoint.
Covers:
- POST /api/v1/query with SSE stream response
- Strict RAG prompt enforcement (only use retrieved context)
- Source metadata inclusion in SSE events
- 422 on missing question field
Uses TestClient + real ChromaDB + real SQLite. Only the LLMClient
(external API) is mocked with controlled responses.
"""
import json
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from app.core.sqlite_db import (
_get_db,
init_history_db,
init_prompts_db,
seed_default_profiles,
)
from app.routers.query import router
# ── Helpers ─────────────────────────────────────────────────────────────────
class _ScriptedLLMClient:
    """Base for deterministic LLM mocks that answer according to ``step_name``.

    Subclasses provide the canned payloads:

    * ``_DECOMPOSITION`` - JSON-encoded and returned for ``"QueryDecomposer"``
    * ``_RELEVANCE``     - JSON-encoded and returned for ``"RelevanceFilter"``
    * ``_ANSWER``        - returned verbatim for any other step (generation)

    ``complete_structured`` always raises, so callers exercise their legacy
    (plain-text / JSON) fallback path instead of structured output.
    """

    # Canned payloads; read-only, overridden by subclasses.
    _DECOMPOSITION: list = []
    _RELEVANCE: dict = {}
    _ANSWER: str = ""

    def __init__(self):
        # Number of complete() calls made on this instance.
        self._call_count = 0

    async def complete(self, prompt: str, temperature: float = 0.7,
                       step_name: str = "LLM") -> str:
        """Return the canned response for ``step_name``; prompt/temperature are ignored."""
        self._call_count += 1
        if step_name == "QueryDecomposer":
            return json.dumps(self._DECOMPOSITION)
        if step_name == "RelevanceFilter":
            return json.dumps(self._RELEVANCE)
        return self._ANSWER

    async def complete_structured(self, prompt, pydantic_model, step_name="LLM"):
        """Structured output path — raise to trigger legacy fallback."""
        raise RuntimeError("structured output not mocked")


class _MockLLMClient(_ScriptedLLMClient):
    """Happy-path LLM mock returning controlled responses for each pipeline step.

    Step sequence:
    1. decompose → JSON array with one sub-question
    2. filter → JSON object mapping sub-q indices to relevance scores
       (8.0 / 7.5 — presumably above the filter threshold; confirm in RAGService)
    3. generate → markdown bullet-point answer
    """

    _DECOMPOSITION = ["test sub-question"]
    _RELEVANCE = {"0": [8.0, 7.5]}
    _ANSWER = "- Bullet point answer\n- Another point"


class _MockLLMClientNoChunks(_ScriptedLLMClient):
    """LLM mock that returns a decomposition but low relevance scores.

    The scores (2.0 / 1.5) are intended to make every chunk fail the relevance
    filter — presumably below its threshold; confirm against RAGService.
    """

    _DECOMPOSITION = ["test"]
    _RELEVANCE = {"0": [2.0, 1.5]}
    _ANSWER = "I could not find any relevant information."
class _DeterministicEmbedding:
    """Lightweight, deterministic embedding function for tests.

    Each input text maps to a fixed-length vector whose i-th component is
    ``ord(text[i]) / 1000`` for the first ``_DIM`` characters; any remaining
    components are 0.0 (longer texts are truncated). The same text always
    yields the same vector, and no external API is called.
    """

    _DIM = 384  # length of every vector produced

    def name(self) -> str:
        return "test_deterministic"

    # The parameter is named ``input`` (shadowing the builtin) deliberately —
    # presumably required by ChromaDB's EmbeddingFunction call signature.
    def __call__(self, input):
        return self._embed(input)

    def embed_query(self, input):
        return self._embed(input)

    @classmethod
    def _embed(cls, texts):
        """Return one ``_DIM``-length float vector per text in ``texts``."""
        vectors = []
        for text in texts:
            prefix = text[:cls._DIM]
            vec = [ord(ch) / 1000.0 for ch in prefix]
            vec.extend([0.0] * (cls._DIM - len(prefix)))
            vectors.append(vec)
        return vectors
def _seed_document(chroma_path: str):
    """Write one two-chunk test document into the on-disk ChromaDB at ``chroma_path``.

    Both chunks belong to ``test.pdf`` (document id ``test-doc-123``) and are
    stored in the ``documents`` collection using ``_DeterministicEmbedding``,
    so no network call is made.
    """
    import chromadb
    from app.core.database import get_or_create_collection

    doc_id = "test-doc-123"
    chunk_texts = [
        "chunk one with some text about testing",
        "chunk two with more details about testing",
    ]
    summaries = ["chunk one", "chunk two"]

    chroma = chromadb.PersistentClient(path=chroma_path)
    collection = get_or_create_collection(
        chroma, "documents", embedding_function=_DeterministicEmbedding()
    )
    chunk_metadata = [
        {
            "filename": "test.pdf",
            "upload_date": "2025-01-01T00:00:00",
            "content_summary": summary,
            "chunk_index": idx,
            "document_id": doc_id,
        }
        for idx, summary in enumerate(summaries)
    ]
    collection.add(
        documents=chunk_texts,
        metadatas=chunk_metadata,
        ids=[f"{doc_id}_{idx}" for idx in range(len(chunk_texts))],
    )
# ── Fixture ─────────────────────────────────────────────────────────────────
@pytest.fixture
def client(tmp_path, monkeypatch):
    """Yield a ``TestClient`` for a minimal app that mounts only the query router.

    Setup:
    * points CHROMA_DB_PATH / PROMPTS_DB_PATH / HISTORY_DB_PATH at files under
      ``tmp_path`` and sets dummy EMBEDDING_MODEL / LLM_API_KEY values;
    * clears the cached settings so those env vars are picked up;
    * initialises the prompts DB (with default profiles) and the history DB;
    * seeds one document into ChromaDB;
    * replaces the embedding-function factory with a deterministic local stub.

    Teardown closes the client and clears the settings caches again so the
    test configuration does not leak into later tests.
    """
    from contextlib import closing

    chroma_path = str(tmp_path / "chroma_db")
    prompts_path = str(tmp_path / "prompts.db")
    history_path = str(tmp_path / "history.db")
    monkeypatch.setenv("CHROMA_DB_PATH", chroma_path)
    monkeypatch.setenv("PROMPTS_DB_PATH", prompts_path)
    monkeypatch.setenv("HISTORY_DB_PATH", history_path)
    monkeypatch.setenv("EMBEDDING_MODEL", "test-mock")
    monkeypatch.setenv("LLM_API_KEY", "test-key")

    # The settings accessors are cached (they expose cache_clear). Drop any
    # value computed before the env vars above were set. The imports stay
    # after the env patching, as in the original ordering.
    from app.core.config import get_settings
    get_settings.cache_clear()
    from app.core.dependencies import get_settings_cached
    get_settings_cached.cache_clear()

    # Init real SQLite databases; closing() releases each connection even
    # if initialisation raises.
    with closing(_get_db(prompts_path)) as conn:
        init_prompts_db(conn)
        seed_default_profiles(conn)
    with closing(_get_db(history_path)) as hconn:
        init_history_db(hconn)

    # Seed a document into real ChromaDB
    _seed_document(chroma_path)

    # Mock embedding function at database module level (external API)
    monkeypatch.setattr(
        "app.core.database.get_embedding_function_settings",
        lambda settings: _DeterministicEmbedding(),
    )

    # Build a minimal FastAPI app with only the query router
    test_app = FastAPI()
    test_app.include_router(router, prefix="/api/v1")
    test_client = TestClient(test_app)
    try:
        yield test_client
    finally:
        test_client.close()
        get_settings_cached.cache_clear()
        get_settings.cache_clear()
# ── Tests ───────────────────────────────────────────────────────────────────
class TestQuery:
    """Tests for ``POST /api/v1/query`` against the ``client`` fixture app."""

    # ── Helpers (no ``test`` prefix, so pytest does not collect them) ───────

    @staticmethod
    def _patch_llm(monkeypatch, mock_cls):
        """Make both LLM client factories in the query router build ``mock_cls``."""
        for attr in ("LLMClient", "LLMClientDP"):
            monkeypatch.setattr(
                f"app.routers.query.{attr}",
                lambda settings: mock_cls(),
            )

    @staticmethod
    def _sse_events(response):
        """Return the decoded JSON payload of every ``data: `` line in ``response``."""
        prefix = "data: "
        return [
            json.loads(line[len(prefix):])
            for line in response.iter_lines()
            if line.startswith(prefix)
        ]

    @staticmethod
    def _event_for_phase(events, phase):
        """Return the first event whose ``phase`` matches; fail readably if none does."""
        found = next((e for e in events if e.get("phase") == phase), None)
        assert found is not None, (
            f"no {phase!r} event in SSE stream; phases seen: "
            f"{[e.get('phase') for e in events]}"
        )
        return found

    # ── Tests ──────────────────────────────────────────────────────────────

    @pytest.mark.skip(reason="Deprecated: endpoint now returns SSE stream, not JSON")
    def test_query_returns_bullets(self, client):
        """Should return bullet-point answer with source metadata."""
        # Left as skip — SSE tests below cover this functionality.

    @pytest.mark.skip(reason="Deprecated: endpoint now returns SSE stream, not JSON")
    def test_query_no_relevant_chunks(self, client):
        """Should handle case when no relevant chunks found."""
        # Left as skip — SSE tests below cover this functionality.

    def test_query_no_question(self, client):
        """Should reject request without question."""
        response = client.post("/api/v1/query", json={})
        assert response.status_code == 422

    def test_query_sse_stream_with_mock_llm(self, client, monkeypatch):
        """Should return SSE stream with decomposed → retrieving → filtering → generating → completed phases.

        Uses real ChromaDB (seeded with a document) + real RAGService.
        Only LLMClient is mocked (external API).
        """
        self._patch_llm(monkeypatch, _MockLLMClient)
        response = client.post(
            "/api/v1/query",
            json={"question": "What is this about?"},
            headers={"Accept": "text/event-stream"},
        )
        assert response.status_code == 200
        assert response.headers["content-type"] == "text/event-stream; charset=utf-8"

        events = self._sse_events(response)
        phases = [e.get("phase") for e in events]
        for expected in ("decomposed", "retrieving", "filtering",
                         "generating", "completed"):
            assert expected in phases, f"missing {expected!r} phase; got {phases}"

        completed = self._event_for_phase(events, "completed")
        assert "- Bullet point answer" in completed["answer"]
        assert len(completed["sources"]) > 0
        assert completed["sources"][0]["filename"] == "test.pdf"

    def test_query_sse_no_relevant_chunks_after_filter(self, client, monkeypatch):
        """Should return 'no relevant information' when filter eliminates all chunks.

        Uses real ChromaDB (seeded) + real RAGService. LLM mock returns
        low relevance scores so all chunks are filtered out.
        """
        self._patch_llm(monkeypatch, _MockLLMClientNoChunks)
        response = client.post(
            "/api/v1/query",
            json={"question": "Something completely unrelated"},
            headers={"Accept": "text/event-stream"},
        )
        assert response.status_code == 200

        completed = self._event_for_phase(self._sse_events(response), "completed")
        assert "could not find" in completed["answer"].lower()
        assert completed["sources"] == []

    def test_query_empty_question_returns_400(self, client, monkeypatch):
        """Should reject empty/whitespace-only question with 400."""
        self._patch_llm(monkeypatch, _MockLLMClient)
        response = client.post(
            "/api/v1/query",
            json={"question": " "},
        )
        assert response.status_code == 400

    def test_query_sse_decomposed_event_has_questions(self, client, monkeypatch):
        """Decomposed SSE event should contain extracted sub-questions from LLM."""
        self._patch_llm(monkeypatch, _MockLLMClient)
        response = client.post(
            "/api/v1/query",
            json={"question": "What is testing?"},
        )
        assert response.status_code == 200

        decomposed = self._event_for_phase(self._sse_events(response), "decomposed")
        assert "extracted_questions" in decomposed
        assert isinstance(decomposed["extracted_questions"], list)
        assert len(decomposed["extracted_questions"]) > 0