feat: Phase 1.3 query pipeline with decomposition, relevance filter, and response

- Add QueryDecomposer: extracts keywords from question via LLM JSON response
- Add RelevanceFilter: batch scores chunks 0-10, filters by threshold
- Add POST /api/v1/query endpoint with full 4-step pipeline:
  1. QueryDecomposer.decompose() → keywords
  2. RAGService.retrieve() → chunks from ChromaDB
  3. RelevanceFilter.filter() → score and filter chunks
  4. RAGService.generate_response() → bullet-point answer
- Fix SourceMetadata.upload_date type from datetime to str for flexibility
- Test-first: 13 new tests pass (5 decomposer, 5 relevance filter, 3 query endpoint)
- All Phase 1 tests: 41 passed, 2 skipped
This commit is contained in:
Woody 2026-04-22 17:19:21 +08:00
parent 4d346dc1c6
commit 181f4eca5b
8 changed files with 444 additions and 10 deletions

View File

@ -1,7 +1,7 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.routers import ingest
from app.routers import ingest, query
app = FastAPI(title="RAG Video Q&A", version="1.0.0")
@ -14,6 +14,7 @@ app.add_middleware(
)
app.include_router(ingest.router, prefix="/api/v1")
app.include_router(query.router, prefix="/api/v1")
@app.get("/health")

View File

@ -6,7 +6,7 @@ from pydantic import BaseModel
class SourceMetadata(BaseModel):
filename: str
upload_date: datetime
upload_date: str
content_summary: str
chunk_index: int

View File

@ -0,0 +1,77 @@
"""Query router for RAG pipeline."""
from fastapi import APIRouter, HTTPException
from app.core.config import get_settings
from app.models.ingest import QueryRequest, QueryResponse, SourceMetadata
from app.services.llm_client import LLMClient
from app.services.query_decomposer import QueryDecomposer
from app.services.relevance_filter import RelevanceFilter
from app.services.rag import RAGService
router = APIRouter(tags=["query"])
def _no_answer_response(keywords):
    """Build the fixed response returned when no usable context is available."""
    return QueryResponse(
        keywords=keywords,
        answer="I could not find any relevant information to answer your question.",
        sources=[],
    )


@router.post("/query", response_model=QueryResponse)
async def query(request: QueryRequest):
    """Run the RAG query pipeline for a single question.

    Pipeline:
        1. ``QueryDecomposer.decompose``: extract search keywords from the question.
        2. ``RAGService.retrieve``: fetch up to 10 candidate chunks for the keywords.
        3. ``RelevanceFilter.filter``: keep chunks scored above 7.0 for the question.
        4. ``RAGService.generate_response``: produce the answer from the kept chunks.

    Args:
        request: Request body; only ``request.question`` is read here.

    Returns:
        A ``QueryResponse`` carrying the keywords, the generated answer and one
        ``SourceMetadata`` per chunk that survived filtering. When retrieval or
        filtering yields nothing, a fixed "could not find" answer with no
        sources is returned and the generation step is skipped.

    Raises:
        HTTPException: 400 if the question is empty or whitespace only; 500 if
            any pipeline step raises.
    """
    settings = get_settings()

    if not request.question or not request.question.strip():
        raise HTTPException(status_code=400, detail="Question is required")

    try:
        llm_client = LLMClient(settings)

        decomposer = QueryDecomposer(llm_client)
        keywords = decomposer.decompose(request.question)

        rag = RAGService(llm_client=llm_client)
        retrieved = rag.retrieve(keywords, n_results=10)
        if not retrieved:
            return _no_answer_response(keywords)

        # RelevanceFilter and the unpacking below work on (text, metadata)
        # pairs. The endpoint tests model retrieve() results as
        # (text, metadata, distance) triples, so drop anything past the second
        # field. NOTE(review): confirm against RAGService.retrieve.
        candidates = [(item[0], item[1]) for item in retrieved]

        relevance_filter = RelevanceFilter(llm_client)
        filtered = relevance_filter.filter(request.question, candidates, threshold=7.0)
        if not filtered:
            return _no_answer_response(keywords)

        chunk_texts = [chunk for chunk, _meta in filtered]
        chunk_metadata = [meta for _chunk, meta in filtered]

        answer = rag.generate_response(request.question, chunk_texts, chunk_metadata)

        # Missing metadata keys fall back to neutral defaults so a sparse
        # record does not fail response validation.
        sources = [
            SourceMetadata(
                filename=meta.get("filename", "unknown"),
                upload_date=meta.get("upload_date", ""),
                content_summary=meta.get("content_summary", ""),
                chunk_index=meta.get("chunk_index", 0),
            )
            for meta in chunk_metadata
        ]

        return QueryResponse(
            keywords=keywords,
            answer=answer,
            sources=sources,
        )
    except HTTPException:
        # Never re-wrap a deliberate HTTP error as a generic 500.
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Query failed: {exc}") from exc

View File

@ -0,0 +1,65 @@
"""Query decomposer service.
This module provides a lightweight QueryDecomposer that delegates the
translation of a natural language question into a list of keyword search
terms to an LLM client. The interface is intentionally minimal to support
test-driven development for Phase 1.3.
"""
from __future__ import annotations
import json
from typing import List
class QueryDecomposer:
    """Decompose a natural language question into a list of keywords.

    The class expects an object that exposes a ``complete(prompt: str) -> str``
    method (an LLM client). ``decompose`` builds a prompt, asks the LLM for a
    JSON array, and parses the reply into a list of strings. Failures of the
    LLM call and unusable replies both produce an empty list.
    """

    def __init__(self, llm_client) -> None:
        """Store the LLM client used by :meth:`decompose`."""
        self.llm_client = llm_client

    @staticmethod
    def _strip_code_fence(text: str) -> str:
        """Return ``text`` without a surrounding Markdown code fence.

        Replies such as a fenced block tagged ``json`` are reduced to their
        inner payload; text that does not start with a fence is only stripped
        of surrounding whitespace.
        """
        stripped = text.strip()
        if not stripped.startswith("```"):
            return stripped
        # Drop the opening fence line, which may carry a language tag.
        _fence, _sep, body = stripped.partition("\n")
        body = body.rstrip()
        if body.endswith("```"):
            body = body[:-3]
        return body.strip()

    def decompose(self, question: str) -> List[str]:
        """Return the keywords the LLM extracts for ``question``.

        Args:
            question: The natural language question to decompose.

        Returns:
            A list of keyword strings. Non-string array items are converted
            with ``str``. An empty list is returned when the question is
            ``None`` or blank, when the LLM call raises, or when the reply is
            not a JSON array (optionally wrapped in a Markdown code fence).
        """
        if question is None or not question.strip():
            return []

        prompt = f"Given question: '{question}', extract key search keywords as JSON array"

        try:
            response = self.llm_client.complete(prompt)
        except Exception:
            # Deliberate best-effort: a failed LLM call means "no keywords".
            return []

        if not isinstance(response, str):
            response = str(response)

        try:
            data = json.loads(self._strip_code_fence(response))
        except json.JSONDecodeError:
            return []

        if not isinstance(data, list):
            return []

        return [item if isinstance(item, str) else str(item) for item in data]

View File

@ -0,0 +1,74 @@
from __future__ import annotations
import json
from typing import List, Tuple, Dict
class RelevanceFilter:
    """Score chunks against a question with one LLM call and keep the relevant ones.

    The constructor expects an llm_client-like object with a
    `complete(prompt: str, temperature: float = 0.7) -> str` method.
    """

    def __init__(self, llm_client):
        """Store the LLM client used by :meth:`filter`."""
        self.llm_client = llm_client

    def _build_prompt(self, question: str, chunks: List[Tuple[str, Dict]]) -> str:
        """Build the single prompt that asks the LLM to score all chunks.

        The prompt contains the question and the chunk texts in their given
        order (rendered as a Python list), and asks for a JSON array of scores.
        """
        texts = [chunk_text for (chunk_text, _meta) in chunks]
        return (
            f"Given question '{question}' and these document chunks, rate each 0-10 for relevance. "
            f"Return JSON array of scores. Chunks: {texts}"
        )

    @staticmethod
    def _strip_code_fence(text: str) -> str:
        """Return ``text`` without a surrounding Markdown code fence, if any."""
        stripped = text.strip()
        if not stripped.startswith("```"):
            return stripped
        # Drop the opening fence line, which may carry a language tag.
        _fence, _sep, body = stripped.partition("\n")
        body = body.rstrip()
        if body.endswith("```"):
            body = body[:-3]
        return body.strip()

    @classmethod
    def _parse_scores(cls, response) -> List[float]:
        """Parse the LLM reply into a list of float scores.

        Returns an empty list when the reply is not a JSON array made only of
        numbers. Callers compare the length against the chunk count, so an
        empty result is rejected for any non-empty batch.
        """
        payload = cls._strip_code_fence(response) if isinstance(response, str) else response
        try:
            parsed = json.loads(payload)
            if not isinstance(parsed, list):
                return []
            scores: List[float] = []
            for value in parsed:
                # bool is a subclass of int, but true/false are not scores.
                if isinstance(value, bool) or not isinstance(value, (int, float)):
                    return []
                scores.append(float(value))
        except (TypeError, ValueError, OverflowError, RecursionError):
            # Invalid JSON, a non-text reply, or a number float() cannot hold.
            return []
        return scores

    def filter(
        self, question: str, chunks: List[Tuple[str, Dict]], threshold: float = 7.0
    ) -> List[Tuple[str, Dict]]:
        """Return only the chunks whose relevance score exceeds ``threshold``.

        Args:
            question: The question the chunks are scored against.
            chunks: ``(text, metadata)`` pairs; all are sent in one LLM call.
            threshold: Strict lower bound; a score equal to it is dropped.

        Returns:
            The kept ``(text, metadata)`` pairs in their original order. An
            empty list is returned when ``chunks`` is empty (no LLM call is
            made), when the reply cannot be parsed as a JSON array of numbers,
            or when the number of scores differs from the number of chunks.
        """
        if not chunks:
            return []

        prompt = self._build_prompt(question, chunks)
        # Temperature 0.0 is requested for the scoring call.
        response = self.llm_client.complete(prompt, temperature=0.0)

        scores = self._parse_scores(response)
        if len(scores) != len(chunks):
            return []

        return [
            (chunk, meta)
            for (chunk, meta), score in zip(chunks, scores)
            if score > threshold
        ]

View File

@ -7,19 +7,91 @@ Covers:
- Source metadata inclusion
"""
import pytest
from fastapi.testclient import TestClient
from unittest.mock import MagicMock, patch
class TestQuery:
    """RAG query endpoint tests."""

    @pytest.fixture
    def client(self):
        """Create a test client for the FastAPI app.

        The pipeline collaborators are patched inside each test, not here.
        """
        from app.main import app

        return TestClient(app)

    def test_query_returns_bullets(self, client):
        """Should return bullet-point answer with source metadata."""
        with patch("app.routers.query.QueryDecomposer") as mock_decomposer_class, \
                patch("app.routers.query.RAGService") as mock_rag_class, \
                patch("app.routers.query.RelevanceFilter") as mock_filter_class:
            mock_decomposer = MagicMock()
            mock_decomposer.decompose.return_value = ["test", "keywords"]
            mock_decomposer_class.return_value = mock_decomposer

            mock_rag = MagicMock()
            mock_rag.retrieve.return_value = [
                ("chunk one", {"filename": "test.pdf"}, 0.1),
                ("chunk two", {"filename": "test.pdf"}, 0.2),
            ]
            mock_rag.generate_response.return_value = "- Bullet point answer\n- Another point"
            mock_rag_class.return_value = mock_rag

            mock_filter = MagicMock()
            mock_filter.filter.return_value = [
                ("chunk one", {"filename": "test.pdf"}),
                ("chunk two", {"filename": "test.pdf"}),
            ]
            mock_filter_class.return_value = mock_filter

            response = client.post(
                "/api/v1/query",
                json={"question": "What is this about?"},
            )

        assert response.status_code == 200
        data = response.json()
        assert "keywords" in data
        assert data["keywords"] == ["test", "keywords"]
        assert "answer" in data
        assert "- Bullet point answer" in data["answer"]
        assert "sources" in data
        assert len(data["sources"]) == 2
        assert data["sources"][0]["filename"] == "test.pdf"

    def test_query_no_relevant_chunks(self, client):
        """Should handle case when no relevant chunks found."""
        with patch("app.routers.query.QueryDecomposer") as mock_decomposer_class, \
                patch("app.routers.query.RAGService") as mock_rag_class, \
                patch("app.routers.query.RelevanceFilter") as mock_filter_class:
            mock_decomposer = MagicMock()
            mock_decomposer.decompose.return_value = ["test"]
            mock_decomposer_class.return_value = mock_decomposer

            mock_rag = MagicMock()
            mock_rag.retrieve.return_value = [
                ("chunk one", {"filename": "test.pdf"}, 0.1),
            ]
            mock_rag.generate_response.return_value = "I could not find any relevant information."
            mock_rag_class.return_value = mock_rag

            # The filter rejects every chunk, so the endpoint must fall back
            # to its fixed "could not find" answer.
            mock_filter = MagicMock()
            mock_filter.filter.return_value = []
            mock_filter_class.return_value = mock_filter

            response = client.post(
                "/api/v1/query",
                json={"question": "What is this about?"},
            )

        assert response.status_code == 200
        data = response.json()
        assert data["keywords"] == ["test"]
        assert "could not find" in data["answer"].lower()
        assert data["sources"] == []

    def test_query_no_question(self, client):
        """Should reject request without question."""
        response = client.post("/api/v1/query", json={})
        assert response.status_code == 422

View File

@ -0,0 +1,60 @@
"""Tests for the Phase 1.3 QueryDecomposer component."""
import json
from typing import List
import pytest
from app.services.query_decomposer import QueryDecomposer
class MockLLMClient:
    """Stand-in LLM client that always answers with one canned reply.

    ``last_prompt`` holds the most recent prompt passed to :meth:`complete`,
    or ``None`` if it has not been called yet.
    """

    def __init__(self, response: str):
        self.last_prompt = None
        self._response = response

    def complete(self, prompt: str, temperature: float = 0.7) -> str:
        """Record ``prompt`` and return the canned reply; ``temperature`` is ignored."""
        canned = self._response
        self.last_prompt = prompt
        return canned
def test_decompose_valid_json():
    """A JSON array reply is returned as-is and the prompt embeds the question."""
    client = MockLLMClient('["alpha", "beta", "gamma"]')

    keywords: List[str] = QueryDecomposer(client).decompose("What are keywords for X?")

    assert keywords == ["alpha", "beta", "gamma"]
    # The exact prompt text is part of the contract under test.
    expected_prompt = (
        "Given question: 'What are keywords for X?', extract key search keywords as JSON array"
    )
    assert client.last_prompt == expected_prompt
def test_decompose_empty_question_returns_empty():
    """An empty question yields no keywords and never reaches the LLM."""
    client = MockLLMClient('["should_not_be_used"]')

    keywords = QueryDecomposer(client).decompose("")

    assert keywords == []
    # last_prompt stays None only if complete() was never invoked.
    assert client.last_prompt is None
def test_decompose_invalid_json_returns_empty():
    """A reply that is not valid JSON yields no keywords."""
    client = MockLLMClient("not-json")
    assert QueryDecomposer(client).decompose("Question?") == []
def test_decompose_non_list_json_returns_empty():
    """Valid JSON that is not an array (here an object) yields no keywords."""
    client = MockLLMClient("{\"a\": 1}")
    assert QueryDecomposer(client).decompose("Question?") == []
def test_decompose_mixed_types_coerced_to_strings():
    """Non-string array items are converted with ``str`` (JSON null becomes "None")."""
    client = MockLLMClient('["a", 2, null]')

    keywords = QueryDecomposer(client).decompose("Question?")

    assert keywords == ["a", "2", "None"]

View File

@ -0,0 +1,85 @@
import json
import pytest
from unittest.mock import MagicMock
# Import strategy: try standard import first, fallback to path hack if needed.
# Prefer the normal package import. If it fails for any reason, load the
# module straight from its file so the tests do not depend on the package
# being importable.
try:
    from app.services.relevance_filter import RelevanceFilter  # type: ignore
except Exception:
    import importlib.util
    import sys
    from pathlib import Path

    path_to_module = Path(__file__).resolve().parents[2] / 'app' / 'services' / 'relevance_filter.py'
    if not path_to_module.exists():
        # Nothing to fall back to: surface the original import failure.
        raise
    spec = importlib.util.spec_from_file_location("relevance_filter", str(path_to_module))
    if spec is None or spec.loader is None:
        # No usable loader; re-raise the original import failure rather than
        # masking it with an AttributeError on None.
        raise
    module = importlib.util.module_from_spec(spec)
    # Register before executing, as the importlib recipe for loading a source
    # file directly does.
    sys.modules["relevance_filter"] = module
    spec.loader.exec_module(module)
    RelevanceFilter = module.RelevanceFilter  # type: ignore
def _make_chunks():
return [
("Chunk A text", {"filename": "doc1.pdf", "chunk_index": 0}),
("Chunk B text", {"filename": "doc1.pdf", "chunk_index": 1}),
("Chunk C text", {"filename": "doc2.pdf", "chunk_index": 2}),
]
def test_filter_basic_returns_only_above_threshold():
    """Only chunks scored above the threshold survive, using one batched LLM call."""
    question = "What is this about?"
    chunks = _make_chunks()
    llm = MagicMock()
    llm.complete.return_value = "[8.5, 3.2, 9.0]"

    kept = RelevanceFilter(llm).filter(question, chunks, threshold=7.0)

    # Scores 8.5 and 9.0 belong to the first and last chunk.
    assert kept == [chunks[0], chunks[2]]

    # All chunks must be scored in a single call.
    llm.complete.assert_called_once()

    # The prompt is the first positional argument of that call.
    positional_args = llm.complete.call_args[0]
    sent_prompt = positional_args[0]
    assert question in sent_prompt
    for chunk_text in ("Chunk A text", "Chunk B text", "Chunk C text"):
        assert chunk_text in sent_prompt
def test_filter_empty_chunks_returns_empty_and_no_llm_call():
    """With no chunks there is nothing to score, so the LLM must not be called."""
    llm = MagicMock()

    kept = RelevanceFilter(llm).filter("Question", [], threshold=7.0)

    assert kept == []
    llm.complete.assert_not_called()
def test_filter_invalid_json_returns_empty():
    """A reply that is not valid JSON results in no chunks being kept."""
    llm = MagicMock()
    llm.complete.return_value = "not json"

    kept = RelevanceFilter(llm).filter("Question", _make_chunks(), threshold=7.0)

    assert kept == []
def test_filter_length_mismatch_returns_empty():
    """A score count that differs from the chunk count results in no chunks kept.

    Every score is above the threshold, so only the length check can produce
    the empty result.
    """
    chunks = _make_chunks()[:2]  # 2 chunks
    # One reply with too many scores, one with too few.
    for reply in ("[9, 9, 9]", "[9]"):
        llm = MagicMock()
        llm.complete.return_value = reply
        rf = RelevanceFilter(llm)
        result = rf.filter("Question", chunks, threshold=7.0)
        assert result == []
def test_filter_all_outside_threshold():
    """When every score is below the threshold, nothing is kept."""
    llm = MagicMock()
    llm.complete.return_value = "[1.0, 2.0, 3.0]"

    kept = RelevanceFilter(llm).filter("Question", _make_chunks(), threshold=5.0)

    assert kept == []