feat: add Sub-Phase 9.1 results generation APIs with reusable RAGPipeline

This commit is contained in:
Woody 2026-05-25 18:35:55 +08:00
parent 852430f1f1
commit ac81df0704
7 changed files with 1121 additions and 1 deletions

View File

@ -7,7 +7,7 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from app.routers import ingest, query, documents, prompts, history, chunks, video, ws_asr
from app.routers import ingest, query, documents, prompts, history, chunks, video, ws_asr, test_generate
from app.core.config import get_settings
from app.core.sqlite_db import (
get_prompts_db,
@ -58,6 +58,7 @@ app.include_router(history.router)
app.include_router(chunks.router)
app.include_router(video.router, prefix="/api/v1")
app.include_router(ws_asr.router)
app.include_router(test_generate.router, prefix="/api/v1")
_prompts_conn = get_prompts_db()
init_prompts_db(_prompts_conn)

View File

@ -0,0 +1,104 @@
import io
import logging
from fastapi import APIRouter, File, Form, HTTPException, Query, UploadFile
from app.core.config import get_settings
from app.models.testing import GenerateTextRequest
from app.services.prompt_service import PromptService
from app.services.test_runner_service import TestRunnerService
from app.services.test_storage_service import TestStorageService
logger = logging.getLogger(__name__)
router = APIRouter(tags=["test"])
def _get_prompt_service() -> PromptService:
settings = get_settings()
return PromptService(db_path=settings.prompts_db_path)
def _get_storage_service() -> TestStorageService:
settings = get_settings()
return TestStorageService(
results_dir=settings.test_results_dir,
evaluations_dir=settings.test_evaluations_dir,
)
@router.post("/test/generate/text")
async def generate_text(request: GenerateTextRequest):
settings = get_settings()
prompt_service = _get_prompt_service()
runner = TestRunnerService(settings)
result = await runner.run_text_test(
question=request.question,
profile=request.profile,
prompt_service=prompt_service,
label=request.label,
)
storage = _get_storage_service()
storage.save_result(result)
return result.model_dump()
@router.post("/test/generate/audio")
async def generate_audio(
audio_file: UploadFile = File(...),
profile: str = Form(...),
reference_transcript: str = Form(""),
label: str = Form(""),
language: str = Form("yue"),
):
if profile not in ("A", "B", "C"):
raise HTTPException(status_code=400, detail="profile must be A, B, or C")
settings = get_settings()
prompt_service = _get_prompt_service()
audio_bytes = await audio_file.read()
if not audio_bytes:
raise HTTPException(status_code=400, detail="Audio file is empty")
runner = TestRunnerService(settings)
result = await runner.run_audio_test(
audio_bytes=audio_bytes,
reference_transcript=reference_transcript,
profile=profile,
prompt_service=prompt_service,
language=language,
label=label,
audio_filename=audio_file.filename or "unknown",
)
storage = _get_storage_service()
storage.save_result(result)
return result.model_dump()
@router.get("/test/results")
async def list_results(limit: int = Query(50, ge=1, le=200), offset: int = Query(0, ge=0)):
storage = _get_storage_service()
return storage.list_results(limit=limit, offset=offset)
@router.get("/test/results/{result_id}")
async def get_result(result_id: str):
storage = _get_storage_service()
result = storage.load_result(result_id)
if result is None:
raise HTTPException(status_code=404, detail="Result not found")
return result.model_dump()
@router.delete("/test/results/{result_id}")
async def delete_result(result_id: str):
storage = _get_storage_service()
deleted = storage.delete_result(result_id)
if not deleted:
raise HTTPException(status_code=404, detail="Result not found")
return {"status": "deleted", "result_id": result_id}

View File

@ -0,0 +1,289 @@
import logging
import time
from dataclasses import dataclass, field
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
from app.models.common import SourceMetadata
from app.models.query import SubQuestionSources
from app.services.query_decomposer import QueryDecomposer
from app.services.relevance_filter import RelevanceFilter
from app.services.rag import RAGService
logger = logging.getLogger(__name__)
NO_RESULTS_ANSWER = "I could not find any relevant information to answer your question."
@dataclass
class PipelineSnapshot:
"""Complete snapshot at a given pipeline stage for accuracy testing capture."""
phase: str
question: str = ""
# Stage 1: Decompose
extracted_questions: List[str] = field(default_factory=list)
decompose_prompt: str = ""
decomposer_time_ms: int = 0
# Stage 2: Retrieve
retrieval_results: List[Tuple[str, List[Tuple[str, Dict[str, Any], float]]]] = field(
default_factory=list
)
chunks_retrieved_count: int = 0
retriever_time_ms: int = 0
# Stage 3: Filter
filter_prompt: str = ""
filtered_by_subq: List[Tuple[str, List[Tuple[str, Dict[str, Any]]]]] = field(
default_factory=list
)
chunks_filtered_count: int = 0
filter_time_ms: int = 0
# Stage 4: Generate
generate_prompt: str = ""
answer: str = ""
sub_question_sources: List[SubQuestionSources] = field(default_factory=list)
generator_time_ms: int = 0
# Metadata
total_time_ms: int = 0
error_message: str = ""
class RAGPipeline:
"""Reusable RAG pipeline: decompose → retrieve → filter → generate.
Yields PipelineSnapshot at each stage boundary. No SSE or HTTP coupling.
Use in streaming endpoints (wrap snapshots as SSE) or capture endpoints
(collect snapshots into GenerateResult).
Usage:
pipeline = RAGPipeline(decomposer=..., rag=..., relevance_filter=..., settings=...)
async for snap in pipeline.execute("question text"):
if snap.phase == "completed":
print(snap.answer)
"""
def __init__(
self,
*,
decomposer: QueryDecomposer,
rag: RAGService,
relevance_filter: RelevanceFilter,
retrieval_n_results: int = 10,
relevance_threshold: float = 7.0,
):
self._decomposer = decomposer
self._rag = rag
self._relevance_filter = relevance_filter
self._retrieval_n_results = retrieval_n_results
self._relevance_threshold = relevance_threshold
async def execute(
self,
question: str,
stop_after_decompose: bool = False,
) -> AsyncGenerator[PipelineSnapshot, None]:
"""Execute the full pipeline, yielding one snapshot per stage."""
overall_start = time.perf_counter()
# --- Stage 1: Decompose ---
stage_start = time.perf_counter()
decompose_result = await self._decomposer.decompose(question)
if isinstance(decompose_result, tuple):
extracted_questions, decompose_prompt = decompose_result
else:
extracted_questions, decompose_prompt = decompose_result, ""
decomposer_time_ms = int((time.perf_counter() - stage_start) * 1000)
if not extracted_questions:
extracted_questions = [question]
yield PipelineSnapshot(
phase="decomposed",
question=question,
extracted_questions=extracted_questions,
decompose_prompt=decompose_prompt,
decomposer_time_ms=decomposer_time_ms,
)
if stop_after_decompose:
total_ms = int((time.perf_counter() - overall_start) * 1000)
yield PipelineSnapshot(
phase="completed",
question=question,
extracted_questions=extracted_questions,
decompose_prompt=decompose_prompt,
decomposer_time_ms=decomposer_time_ms,
total_time_ms=total_ms,
)
return
# --- Stage 2: Retrieve ---
stage_start = time.perf_counter()
retrieval_results = (
self._rag.retrieve_per_subquestion(
extracted_questions, n_results=self._retrieval_n_results
)
if extracted_questions
else []
)
retriever_time_ms = int((time.perf_counter() - stage_start) * 1000)
chunks_retrieved_count = sum(
len(chunks) for _, chunks in retrieval_results
)
yield PipelineSnapshot(
phase="retrieving",
question=question,
extracted_questions=extracted_questions,
decompose_prompt=decompose_prompt,
decomposer_time_ms=decomposer_time_ms,
retrieval_results=retrieval_results,
chunks_retrieved_count=chunks_retrieved_count,
retriever_time_ms=retriever_time_ms,
)
if not any(chunks for _, chunks in retrieval_results):
total_ms = int((time.perf_counter() - overall_start) * 1000)
yield PipelineSnapshot(
phase="completed",
question=question,
extracted_questions=extracted_questions,
decompose_prompt=decompose_prompt,
decomposer_time_ms=decomposer_time_ms,
chunks_retrieved_count=0,
retriever_time_ms=retriever_time_ms,
answer=NO_RESULTS_ANSWER,
total_time_ms=total_ms,
)
return
# --- Stage 3: Filter ---
stage_start = time.perf_counter()
chunks_by_subq = [
[(text, meta) for text, meta, _dist in chunks]
for _, chunks in retrieval_results
]
if extracted_questions and chunks_by_subq:
filter_result = await self._relevance_filter.filter_per_subquestion(
extracted_questions, chunks_by_subq, threshold=self._relevance_threshold
)
else:
filter_result = ([], "")
if isinstance(filter_result, tuple):
filtered_by_subq, filter_prompt = filter_result
else:
filtered_by_subq, filter_prompt = filter_result, ""
filter_time_ms = int((time.perf_counter() - stage_start) * 1000)
chunks_filtered_count = sum(
len(chunks) for _, chunks in filtered_by_subq
)
yield PipelineSnapshot(
phase="filtering",
question=question,
extracted_questions=extracted_questions,
decompose_prompt=decompose_prompt,
decomposer_time_ms=decomposer_time_ms,
retrieval_results=retrieval_results,
chunks_retrieved_count=chunks_retrieved_count,
retriever_time_ms=retriever_time_ms,
filter_prompt=filter_prompt,
filtered_by_subq=filtered_by_subq,
chunks_filtered_count=chunks_filtered_count,
filter_time_ms=filter_time_ms,
)
if not filtered_by_subq or not any(
chunks for _, chunks in filtered_by_subq
):
total_ms = int((time.perf_counter() - overall_start) * 1000)
yield PipelineSnapshot(
phase="completed",
question=question,
extracted_questions=extracted_questions,
decomposer_time_ms=decomposer_time_ms,
retriever_time_ms=retriever_time_ms,
filter_time_ms=filter_time_ms,
answer=NO_RESULTS_ANSWER,
total_time_ms=total_ms,
)
return
# --- Stage 4: Generate ---
stage_start = time.perf_counter()
sub_chunk_texts = []
sub_chunk_metadata = []
for _, filtered_chunks in filtered_by_subq:
sub_chunk_texts.append([chunk for chunk, _meta in filtered_chunks])
sub_chunk_metadata.append([meta for _chunk, meta in filtered_chunks])
if extracted_questions and filtered_by_subq:
gen_result = await self._rag.generate_response_per_subquestion(
extracted_questions, sub_chunk_texts, sub_chunk_metadata
)
else:
gen_result = ("", "", [])
if isinstance(gen_result, tuple) and len(gen_result) == 3:
answer, generate_prompt, grouped_sources_meta = gen_result
else:
answer, generate_prompt = (
gen_result if isinstance(gen_result, tuple) else (gen_result, "")
)
grouped_sources_meta = []
sub_question_sources = []
for idx, (sub_q_text, sources_meta) in enumerate(
zip(extracted_questions, grouped_sources_meta)
):
sources = [
SourceMetadata(
filename=meta.get("filename", "unknown"),
upload_date=meta.get("upload_date", ""),
content_summary=meta.get("content_summary", ""),
chunk_index=meta.get("chunk_index", 0),
page_number=meta.get("page_number"),
chunk_file_path=meta.get("chunk_file_path"),
document_id=meta.get("document_id"),
)
for meta in sources_meta
]
sub_question_sources.append(
SubQuestionSources(
sub_question_index=idx,
sub_question_text=sub_q_text,
sources=sources,
)
)
generator_time_ms = int((time.perf_counter() - stage_start) * 1000)
total_time_ms = int((time.perf_counter() - overall_start) * 1000)
yield PipelineSnapshot(
phase="completed",
question=question,
extracted_questions=extracted_questions,
decompose_prompt=decompose_prompt,
decomposer_time_ms=decomposer_time_ms,
retrieval_results=retrieval_results,
chunks_retrieved_count=chunks_retrieved_count,
retriever_time_ms=retriever_time_ms,
filter_prompt=filter_prompt,
filtered_by_subq=filtered_by_subq,
chunks_filtered_count=chunks_filtered_count,
filter_time_ms=filter_time_ms,
generate_prompt=generate_prompt,
answer=answer,
sub_question_sources=sub_question_sources,
generator_time_ms=generator_time_ms,
total_time_ms=total_time_ms,
)

View File

@ -0,0 +1,220 @@
import logging
import uuid
from typing import Optional
from app.core.config import Settings
from app.models.testing import (
ChunkEntry,
FilteredResult,
GenerateResult,
InputInfo,
ResponseResult,
RetrievalResult,
SubQuestionChunks,
TimingInfo,
)
from app.services.asr_client import ASRClient
from app.services.llm_client import LLMClient
from app.services.llm_client_dp import LLMClientDP
from app.services.prompt_service import PromptService
from app.services.query_decomposer import QueryDecomposer
from app.services.rag_pipeline import RAGPipeline, PipelineSnapshot
from app.services.rag import RAGService
from app.services.relevance_filter import RelevanceFilter
logger = logging.getLogger(__name__)
class TestRunnerService:
"""Runs the full RAG pipeline and captures all intermediate data for accuracy testing."""
def __init__(self, settings: Settings):
self.settings = settings
async def run_text_test(
self,
question: str,
profile: str,
prompt_service: PromptService,
label: str = "",
) -> GenerateResult:
result_id = uuid.uuid4().hex[:12]
prompt_service.activate_profile(profile)
active_profile = prompt_service.get_active_profile_name()
llm_client_dp = LLMClientDP(self.settings)
llm_client = LLMClient(self.settings)
rag = RAGService(
llm_client=llm_client,
settings=self.settings,
prompt_service=prompt_service,
)
decomposer = QueryDecomposer(llm_client_dp, prompt_service=prompt_service)
relevance_filter = RelevanceFilter(
llm_client, prompt_service=prompt_service
)
pipeline = RAGPipeline(
decomposer=decomposer,
rag=rag,
relevance_filter=relevance_filter,
retrieval_n_results=self.settings.retrieval_n_results,
relevance_threshold=self.settings.relevance_threshold,
)
# Collect all snapshots — use the last "completed" one for the final result
decomposed_snap = None
retrieval_snap = None
filtering_snap = None
completed_snap = None
async for snap in pipeline.execute(question):
if snap.phase == "decomposed":
decomposed_snap = snap
elif snap.phase == "retrieving":
retrieval_snap = snap
elif snap.phase == "filtering":
filtering_snap = snap
elif snap.phase == "completed":
completed_snap = snap
if completed_snap is None:
raise RuntimeError("Pipeline did not produce a completed snapshot")
# Build retrieval result
retrieval_per_subq = []
if retrieval_snap and retrieval_snap.retrieval_results:
for sq_idx, (sub_q_text, chunks) in enumerate(
retrieval_snap.retrieval_results
):
retrieval_per_subq.append(
SubQuestionChunks(
sub_question_index=sq_idx,
sub_question_text=sub_q_text,
chunks=[
ChunkEntry(
chunk_index=i,
text=text,
metadata=meta,
distance=distance,
)
for i, (text, meta, distance) in enumerate(chunks)
],
)
)
# Build filtered result
filtered_per_subq = []
if filtering_snap and filtering_snap.filtered_by_subq:
for sq_idx, (sub_q_text, chunks) in enumerate(
filtering_snap.filtered_by_subq
):
filtered_per_subq.append(
SubQuestionChunks(
sub_question_index=sq_idx,
sub_question_text=sub_q_text,
chunks=[
ChunkEntry(
chunk_index=i,
text=text,
metadata=meta,
distance=0.0,
)
for i, (text, meta) in enumerate(chunks)
],
)
)
# Build response result — serialize through dicts to convert between
# query.py SubQuestionSources and testing.py SubQuestionSources (same fields)
response_result = ResponseResult(
final_answer=completed_snap.answer,
sub_question_sources=[
{
"sub_question_index": sq.sub_question_index,
"sub_question_text": sq.sub_question_text,
"sources": [s.model_dump() for s in sq.sources],
}
for sq in completed_snap.sub_question_sources
],
generate_time_ms=completed_snap.generator_time_ms,
)
return GenerateResult(
result_id=result_id,
input_type="text",
profile=active_profile,
label=label,
input=InputInfo(text=question),
extracted_key_questions=completed_snap.extracted_questions,
retrieval=RetrievalResult(
per_sub_question=retrieval_per_subq,
total_chunks_retrieved=(
retrieval_snap.chunks_retrieved_count if retrieval_snap else 0
),
retriever_time_ms=(
retrieval_snap.retriever_time_ms if retrieval_snap else 0
),
),
filtered=FilteredResult(
per_sub_question=filtered_per_subq,
total_chunks_filtered=(
filtering_snap.chunks_filtered_count if filtering_snap else 0
),
filter_time_ms=(
filtering_snap.filter_time_ms if filtering_snap else 0
),
),
response=response_result,
timing=TimingInfo(
decomposer_time_ms=completed_snap.decomposer_time_ms,
retriever_time_ms=(
retrieval_snap.retriever_time_ms if retrieval_snap else 0
),
filter_time_ms=(
filtering_snap.filter_time_ms if filtering_snap else 0
),
generator_time_ms=completed_snap.generator_time_ms,
total_time_ms=completed_snap.total_time_ms,
),
)
async def run_audio_test(
self,
audio_bytes: bytes,
reference_transcript: str,
profile: str,
prompt_service: PromptService,
language: str = "yue",
label: str = "",
audio_filename: str = "",
) -> GenerateResult:
result_id = uuid.uuid4().hex[:12]
# Run ASR
asr_client = ASRClient(self.settings)
transcribed_text = await asr_client.transcribe_full(
audio_bytes, language=language
)
# Run text test on transcribed text
result = await self.run_text_test(
question=transcribed_text,
profile=profile,
prompt_service=prompt_service,
label=label,
)
# Override with audio-specific fields
result.result_id = result_id
result.input_type = "audio"
result.input = InputInfo(
text=transcribed_text,
reference_transcript=reference_transcript,
audio_filename=audio_filename,
audio_duration_seconds=0.0,
asr_language=language,
)
return result

View File

@ -0,0 +1,101 @@
import json
import logging
import os
from pathlib import Path
from typing import List, Optional
from app.models.testing import GenerateResult, EvaluationResult
logger = logging.getLogger(__name__)
class TestStorageService:
def __init__(self, results_dir: str, evaluations_dir: str):
self.results_dir = results_dir
self.evaluations_dir = evaluations_dir
Path(results_dir).mkdir(parents=True, exist_ok=True)
Path(evaluations_dir).mkdir(parents=True, exist_ok=True)
# --- Results ---
def save_result(self, result: GenerateResult) -> str:
filepath = os.path.join(self.results_dir, f"{result.result_id}.json")
with open(filepath, "w", encoding="utf-8") as f:
f.write(result.model_dump_json(indent=2))
logger.info("Saved test result: %s", filepath)
return filepath
def load_result(self, result_id: str) -> Optional[GenerateResult]:
filepath = os.path.join(self.results_dir, f"{result_id}.json")
if not os.path.isfile(filepath):
return None
with open(filepath, "r", encoding="utf-8") as f:
data = json.load(f)
return GenerateResult.model_validate(data)
def list_results(self, limit: int = 50, offset: int = 0) -> List[dict]:
items = []
try:
for entry in sorted(
Path(self.results_dir).iterdir(),
key=lambda p: p.stat().st_mtime,
reverse=True,
):
if entry.suffix == ".json":
stat = entry.stat()
items.append({
"result_id": entry.stem,
"file_size_bytes": stat.st_size,
})
except FileNotFoundError:
return []
return items[offset : offset + limit]
def delete_result(self, result_id: str) -> bool:
filepath = os.path.join(self.results_dir, f"{result_id}.json")
if not os.path.isfile(filepath):
return False
os.remove(filepath)
return True
# --- Evaluations ---
def save_evaluation(self, evaluation: EvaluationResult) -> str:
filepath = os.path.join(self.evaluations_dir, f"{evaluation.evaluation_id}.json")
with open(filepath, "w", encoding="utf-8") as f:
f.write(evaluation.model_dump_json(indent=2))
logger.info("Saved evaluation: %s", filepath)
return filepath
def load_evaluation(self, eval_id: str) -> Optional[EvaluationResult]:
filepath = os.path.join(self.evaluations_dir, f"{eval_id}.json")
if not os.path.isfile(filepath):
return None
with open(filepath, "r", encoding="utf-8") as f:
data = json.load(f)
return EvaluationResult.model_validate(data)
def list_evaluations(self, limit: int = 50, offset: int = 0) -> List[dict]:
items = []
try:
for entry in sorted(
Path(self.evaluations_dir).iterdir(),
key=lambda p: p.stat().st_mtime,
reverse=True,
):
if entry.suffix == ".json":
stat = entry.stat()
items.append({
"evaluation_id": entry.stem,
"file_size_bytes": stat.st_size,
})
except FileNotFoundError:
return []
return items[offset : offset + limit]
def delete_evaluation(self, eval_id: str) -> bool:
filepath = os.path.join(self.evaluations_dir, f"{eval_id}.json")
if not os.path.isfile(filepath):
return False
os.remove(filepath)
return True

View File

@ -0,0 +1,173 @@
"""Phase 9 tests: Text generation endpoint integration (Sub-Phase 9.1).
Covers:
- POST /api/v1/test/generate/text with valid request returns 200
- Response includes all pipeline stages (retrieval, filtered, response, timing)
- Invalid profile returns validation error
- Empty question returns validation error
- GET /api/v1/test/results lists saved results
- GET /api/v1/test/results/{id} retrieves specific result
- DELETE /api/v1/test/results/{id} deletes result
"""
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from app.routers.test_generate import router
@pytest.fixture
def client(tmp_path, monkeypatch):
results_dir = str(tmp_path / "test_results")
evals_dir = str(tmp_path / "test_evaluations")
prompts_path = str(tmp_path / "prompts.db")
history_path = str(tmp_path / "history.db")
monkeypatch.setenv("TEST_RESULTS_DIR", results_dir)
monkeypatch.setenv("TEST_EVALUATIONS_DIR", evals_dir)
monkeypatch.setenv("PROMPTS_DB_PATH", prompts_path)
monkeypatch.setenv("HISTORY_DB_PATH", history_path)
monkeypatch.setenv("LLM_API_KEY", "test-key")
monkeypatch.setenv("LLM_BASE_URL", "https://test.example.com/v1")
monkeypatch.setenv("LLM_MODEL_NAME", "test-model")
monkeypatch.setenv("DP_API_KEY", "test-dp-key")
monkeypatch.setenv("EMBEDDING_MODEL", "test-embedding")
from app.core.config import get_settings
get_settings.cache_clear()
from app.core.sqlite_db import _get_db, init_prompts_db, init_history_db, seed_default_profiles
conn = _get_db(prompts_path)
init_prompts_db(conn)
seed_default_profiles(conn)
conn.close()
hconn = _get_db(history_path)
init_history_db(hconn)
hconn.close()
test_app = FastAPI()
test_app.include_router(router, prefix="/api/v1")
yield TestClient(test_app)
get_settings.cache_clear()
@pytest.fixture
def mock_pipeline(monkeypatch):
"""Mock LLM and RAG to avoid real API calls."""
async def _mock_decompose(self, question):
return (["sub question 1", "sub question 2"], "mocked decompose prompt")
def _mock_retrieve(self, sub_questions, n_results=10):
return [
(sq, [("chunk text content", {"filename": "test.pdf", "upload_date": "2026-01-01",
"content_summary": "test chunk", "chunk_index": 0,
"page_number": 1, "document_id": "doc-1"}, 0.15)])
for sq in sub_questions
]
async def _mock_filter(self, sub_questions, chunks_by_subq, threshold=7.0):
result = [
(sq, [(text, {**meta, "relevance_score": 8.5}) for text, meta in chunks])
for sq, chunks in zip(sub_questions, chunks_by_subq)
]
return (result, "mocked filter prompt")
async def _mock_generate(self, sub_questions, sub_chunks, sub_metadata):
answer = "## Sub-question 0: sub question 1\n\n- Test answer with citation [test.pdf, page 1]\n\n## Sub-question 1: sub question 2\n\n- More answer content"
grouped_sources = [[meta_list[0]] for meta_list in sub_metadata] if sub_metadata else [[] for _ in sub_questions]
return (answer, "mocked generate prompt", grouped_sources)
monkeypatch.setattr("app.services.query_decomposer.QueryDecomposer.decompose", _mock_decompose)
monkeypatch.setattr("app.services.rag.RAGService.retrieve_per_subquestion", _mock_retrieve)
monkeypatch.setattr("app.services.relevance_filter.RelevanceFilter.filter_per_subquestion", _mock_filter)
monkeypatch.setattr("app.services.rag.RAGService.generate_response_per_subquestion", _mock_generate)
@pytest.mark.usefixtures("mock_pipeline")
class TestGenerateTextEndpoint:
def test_valid_request_returns_200(self, client):
resp = client.post("/api/v1/test/generate/text", json={
"question": "test question",
"profile": "A",
"label": "my label",
})
assert resp.status_code == 200
data = resp.json()
assert data["input_type"] == "text"
assert data["profile"] == "A"
assert data["label"] == "my label"
assert "result_id" in data
def test_result_contains_all_stages(self, client):
resp = client.post("/api/v1/test/generate/text", json={
"question": "test",
"profile": "B",
})
assert resp.status_code == 200
data = resp.json()
assert len(data["extracted_key_questions"]) == 2
assert data["retrieval"]["total_chunks_retrieved"] > 0
assert data["filtered"]["total_chunks_filtered"] > 0
assert len(data["response"]["final_answer"]) > 0
assert data["timing"]["total_time_ms"] >= 0
def test_invalid_profile_rejected(self, client):
resp = client.post("/api/v1/test/generate/text", json={
"question": "test",
"profile": "D",
})
assert resp.status_code == 422
def test_empty_question_rejected(self, client):
resp = client.post("/api/v1/test/generate/text", json={
"question": "",
"profile": "A",
})
assert resp.status_code == 422
def test_result_saved_and_retrievable(self, client):
resp = client.post("/api/v1/test/generate/text", json={
"question": "save test",
"profile": "A",
})
assert resp.status_code == 200
result_id = resp.json()["result_id"]
get_resp = client.get(f"/api/v1/test/results/{result_id}")
assert get_resp.status_code == 200
assert get_resp.json()["result_id"] == result_id
def test_list_results(self, client):
client.post("/api/v1/test/generate/text", json={
"question": "list test 1", "profile": "A",
})
client.post("/api/v1/test/generate/text", json={
"question": "list test 2", "profile": "B",
})
resp = client.get("/api/v1/test/results?limit=10")
assert resp.status_code == 200
data = resp.json()
assert len(data) >= 2
def test_get_nonexistent_result(self, client):
resp = client.get("/api/v1/test/results/no-such-id")
assert resp.status_code == 404
def test_delete_result(self, client):
resp = client.post("/api/v1/test/generate/text", json={
"question": "delete test", "profile": "A",
})
result_id = resp.json()["result_id"]
del_resp = client.delete(f"/api/v1/test/results/{result_id}")
assert del_resp.status_code == 200
get_resp = client.get(f"/api/v1/test/results/{result_id}")
assert get_resp.status_code == 404

View File

@ -0,0 +1,232 @@
"""Phase 9 tests: Results storage service CRUD operations (Sub-Phase 9.1).
Covers:
- save_result writes JSON file and returns result_id
- load_result reads and parses JSON file
- list_results returns list of result metadata
- delete_result removes file
- Nonexistent result loading returns None
- save_evaluation / load_evaluation / list_evaluations / delete_evaluation
- Empty storage dirs don't error
"""
import json
import os
from pathlib import Path
import pytest
from app.models.testing import (
GenerateResult,
InputInfo,
TimingInfo,
RetrievalResult,
FilteredResult,
ResponseResult,
EvaluationResult,
EvaluationTiming,
)
@pytest.fixture
def storage_dirs(tmp_path):
results_dir = tmp_path / "test_results"
evals_dir = tmp_path / "test_evaluations"
results_dir.mkdir()
evals_dir.mkdir()
return str(results_dir), str(evals_dir)
@pytest.fixture
def sample_result():
return GenerateResult(
result_id="test-001",
input_type="text",
profile="A",
label="sample test",
input=InputInfo(text="sample question"),
extracted_key_questions=["key q1"],
retrieval=RetrievalResult(
per_sub_question=[],
total_chunks_retrieved=0,
retriever_time_ms=100,
),
filtered=FilteredResult(
per_sub_question=[],
total_chunks_filtered=0,
filter_time_ms=100,
),
response=ResponseResult(
final_answer="answer",
sub_question_sources=[],
generate_time_ms=100,
),
timing=TimingInfo(
decomposer_time_ms=100,
retriever_time_ms=100,
filter_time_ms=100,
generator_time_ms=100,
total_time_ms=400,
),
)
class TestResultStorage:
def test_save_and_load(self, storage_dirs, sample_result, monkeypatch):
results_dir, _ = storage_dirs
monkeypatch.setenv("TEST_RESULTS_DIR", results_dir)
from app.services.test_storage_service import TestStorageService
svc = TestStorageService(results_dir, results_dir)
path = svc.save_result(sample_result)
assert os.path.exists(path)
loaded = svc.load_result(sample_result.result_id)
assert loaded is not None
assert loaded.result_id == sample_result.result_id
assert loaded.input_type == "text"
assert loaded.profile == "A"
def test_load_nonexistent(self, storage_dirs):
results_dir, _ = storage_dirs
from app.services.test_storage_service import TestStorageService
svc = TestStorageService(results_dir, results_dir)
result = svc.load_result("nonexistent-id")
assert result is None
def test_list_results(self, storage_dirs, sample_result, monkeypatch):
results_dir, _ = storage_dirs
monkeypatch.setenv("TEST_RESULTS_DIR", results_dir)
from app.services.test_storage_service import TestStorageService
svc = TestStorageService(results_dir, results_dir)
svc.save_result(sample_result)
items = svc.list_results()
assert len(items) >= 1
assert any(r["result_id"] == "test-001" for r in items)
def test_list_results_with_limit_offset(self, storage_dirs, sample_result):
results_dir, _ = storage_dirs
from app.services.test_storage_service import TestStorageService
svc = TestStorageService(results_dir, results_dir)
for i in range(5):
r = sample_result.model_copy(update={"result_id": f"test-{i:03d}"})
svc.save_result(r)
items = svc.list_results(limit=2, offset=1)
assert len(items) == 2
def test_delete_result(self, storage_dirs, sample_result):
results_dir, _ = storage_dirs
from app.services.test_storage_service import TestStorageService
svc = TestStorageService(results_dir, results_dir)
svc.save_result(sample_result)
filepath = os.path.join(results_dir, f"{sample_result.result_id}.json")
assert os.path.exists(filepath)
result = svc.delete_result(sample_result.result_id)
assert result is True
assert not os.path.exists(filepath)
def test_delete_nonexistent(self, storage_dirs):
results_dir, _ = storage_dirs
from app.services.test_storage_service import TestStorageService
svc = TestStorageService(results_dir, results_dir)
result = svc.delete_result("no-such-id")
assert result is False
def test_creates_dir_if_missing(self, storage_dirs):
results_dir, evals_dir = storage_dirs
new_results = os.path.join(results_dir, "auto_created")
from app.services.test_storage_service import TestStorageService
svc = TestStorageService(new_results, evals_dir)
assert os.path.isdir(new_results)
def test_list_empty_dir(self, storage_dirs):
results_dir, _ = storage_dirs
from app.services.test_storage_service import TestStorageService
svc = TestStorageService(results_dir, results_dir)
items = svc.list_results()
assert items == []
class TestEvaluationStorage:
def test_save_and_load_eval(self, storage_dirs):
results_dir, evals_dir = storage_dirs
from app.services.test_storage_service import TestStorageService
svc = TestStorageService(results_dir, evals_dir)
eval_result = EvaluationResult(
evaluation_id="eval-001",
result_id="result-001",
status="completed",
timing=EvaluationTiming(
audio_evaluation_time_ms=10,
key_questions_evaluation_time_ms=100,
chunk_evaluation_time_ms=200,
response_evaluation_time_ms=300,
total_evaluation_time_ms=610,
),
)
path = svc.save_evaluation(eval_result)
assert os.path.exists(path)
loaded = svc.load_evaluation("eval-001")
assert loaded is not None
assert loaded.evaluation_id == "eval-001"
assert loaded.status == "completed"
def test_list_evaluations(self, storage_dirs):
results_dir, evals_dir = storage_dirs
from app.services.test_storage_service import TestStorageService
svc = TestStorageService(results_dir, evals_dir)
eval_result = EvaluationResult(
evaluation_id="eval-001",
result_id="result-001",
status="completed",
timing=EvaluationTiming(
audio_evaluation_time_ms=10,
key_questions_evaluation_time_ms=100,
chunk_evaluation_time_ms=200,
response_evaluation_time_ms=300,
total_evaluation_time_ms=610,
),
)
svc.save_evaluation(eval_result)
items = svc.list_evaluations()
assert len(items) >= 1
def test_delete_evaluation(self, storage_dirs):
results_dir, evals_dir = storage_dirs
from app.services.test_storage_service import TestStorageService
svc = TestStorageService(results_dir, evals_dir)
eval_result = EvaluationResult(
evaluation_id="eval-002",
result_id="result-002",
status="completed",
timing=EvaluationTiming(
audio_evaluation_time_ms=10,
key_questions_evaluation_time_ms=100,
chunk_evaluation_time_ms=200,
response_evaluation_time_ms=300,
total_evaluation_time_ms=610,
),
)
svc.save_evaluation(eval_result)
filepath = os.path.join(evals_dir, "eval-002.json")
assert os.path.exists(filepath)
result = svc.delete_evaluation("eval-002")
assert result is True
assert not os.path.exists(filepath)