feat: add Sub-Phase 9.1 results generation APIs with reusable RAGPipeline

This commit is contained in:
Woody 2026-05-25 18:35:55 +08:00
parent 852430f1f1
commit ac81df0704
7 changed files with 1121 additions and 1 deletions

View File

@ -7,7 +7,7 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from app.routers import ingest, query, documents, prompts, history, chunks, video, ws_asr from app.routers import ingest, query, documents, prompts, history, chunks, video, ws_asr, test_generate
from app.core.config import get_settings from app.core.config import get_settings
from app.core.sqlite_db import ( from app.core.sqlite_db import (
get_prompts_db, get_prompts_db,
@ -58,6 +58,7 @@ app.include_router(history.router)
app.include_router(chunks.router) app.include_router(chunks.router)
app.include_router(video.router, prefix="/api/v1") app.include_router(video.router, prefix="/api/v1")
app.include_router(ws_asr.router) app.include_router(ws_asr.router)
app.include_router(test_generate.router, prefix="/api/v1")
_prompts_conn = get_prompts_db() _prompts_conn = get_prompts_db()
init_prompts_db(_prompts_conn) init_prompts_db(_prompts_conn)

View File

@ -0,0 +1,104 @@
import io
import logging
from fastapi import APIRouter, File, Form, HTTPException, Query, UploadFile
from app.core.config import get_settings
from app.models.testing import GenerateTextRequest
from app.services.prompt_service import PromptService
from app.services.test_runner_service import TestRunnerService
from app.services.test_storage_service import TestStorageService
logger = logging.getLogger(__name__)
router = APIRouter(tags=["test"])
def _get_prompt_service() -> PromptService:
settings = get_settings()
return PromptService(db_path=settings.prompts_db_path)
def _get_storage_service() -> TestStorageService:
settings = get_settings()
return TestStorageService(
results_dir=settings.test_results_dir,
evaluations_dir=settings.test_evaluations_dir,
)
@router.post("/test/generate/text")
async def generate_text(request: GenerateTextRequest):
settings = get_settings()
prompt_service = _get_prompt_service()
runner = TestRunnerService(settings)
result = await runner.run_text_test(
question=request.question,
profile=request.profile,
prompt_service=prompt_service,
label=request.label,
)
storage = _get_storage_service()
storage.save_result(result)
return result.model_dump()
@router.post("/test/generate/audio")
async def generate_audio(
audio_file: UploadFile = File(...),
profile: str = Form(...),
reference_transcript: str = Form(""),
label: str = Form(""),
language: str = Form("yue"),
):
if profile not in ("A", "B", "C"):
raise HTTPException(status_code=400, detail="profile must be A, B, or C")
settings = get_settings()
prompt_service = _get_prompt_service()
audio_bytes = await audio_file.read()
if not audio_bytes:
raise HTTPException(status_code=400, detail="Audio file is empty")
runner = TestRunnerService(settings)
result = await runner.run_audio_test(
audio_bytes=audio_bytes,
reference_transcript=reference_transcript,
profile=profile,
prompt_service=prompt_service,
language=language,
label=label,
audio_filename=audio_file.filename or "unknown",
)
storage = _get_storage_service()
storage.save_result(result)
return result.model_dump()
@router.get("/test/results")
async def list_results(limit: int = Query(50, ge=1, le=200), offset: int = Query(0, ge=0)):
storage = _get_storage_service()
return storage.list_results(limit=limit, offset=offset)
@router.get("/test/results/{result_id}")
async def get_result(result_id: str):
storage = _get_storage_service()
result = storage.load_result(result_id)
if result is None:
raise HTTPException(status_code=404, detail="Result not found")
return result.model_dump()
@router.delete("/test/results/{result_id}")
async def delete_result(result_id: str):
storage = _get_storage_service()
deleted = storage.delete_result(result_id)
if not deleted:
raise HTTPException(status_code=404, detail="Result not found")
return {"status": "deleted", "result_id": result_id}

View File

@ -0,0 +1,289 @@
import logging
import time
from dataclasses import dataclass, field
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
from app.models.common import SourceMetadata
from app.models.query import SubQuestionSources
from app.services.query_decomposer import QueryDecomposer
from app.services.relevance_filter import RelevanceFilter
from app.services.rag import RAGService
logger = logging.getLogger(__name__)
NO_RESULTS_ANSWER = "I could not find any relevant information to answer your question."
@dataclass
class PipelineSnapshot:
"""Complete snapshot at a given pipeline stage for accuracy testing capture."""
phase: str
question: str = ""
# Stage 1: Decompose
extracted_questions: List[str] = field(default_factory=list)
decompose_prompt: str = ""
decomposer_time_ms: int = 0
# Stage 2: Retrieve
retrieval_results: List[Tuple[str, List[Tuple[str, Dict[str, Any], float]]]] = field(
default_factory=list
)
chunks_retrieved_count: int = 0
retriever_time_ms: int = 0
# Stage 3: Filter
filter_prompt: str = ""
filtered_by_subq: List[Tuple[str, List[Tuple[str, Dict[str, Any]]]]] = field(
default_factory=list
)
chunks_filtered_count: int = 0
filter_time_ms: int = 0
# Stage 4: Generate
generate_prompt: str = ""
answer: str = ""
sub_question_sources: List[SubQuestionSources] = field(default_factory=list)
generator_time_ms: int = 0
# Metadata
total_time_ms: int = 0
error_message: str = ""
class RAGPipeline:
"""Reusable RAG pipeline: decompose → retrieve → filter → generate.
Yields PipelineSnapshot at each stage boundary. No SSE or HTTP coupling.
Use in streaming endpoints (wrap snapshots as SSE) or capture endpoints
(collect snapshots into GenerateResult).
Usage:
pipeline = RAGPipeline(decomposer=..., rag=..., relevance_filter=..., settings=...)
async for snap in pipeline.execute("question text"):
if snap.phase == "completed":
print(snap.answer)
"""
def __init__(
self,
*,
decomposer: QueryDecomposer,
rag: RAGService,
relevance_filter: RelevanceFilter,
retrieval_n_results: int = 10,
relevance_threshold: float = 7.0,
):
self._decomposer = decomposer
self._rag = rag
self._relevance_filter = relevance_filter
self._retrieval_n_results = retrieval_n_results
self._relevance_threshold = relevance_threshold
async def execute(
self,
question: str,
stop_after_decompose: bool = False,
) -> AsyncGenerator[PipelineSnapshot, None]:
"""Execute the full pipeline, yielding one snapshot per stage."""
overall_start = time.perf_counter()
# --- Stage 1: Decompose ---
stage_start = time.perf_counter()
decompose_result = await self._decomposer.decompose(question)
if isinstance(decompose_result, tuple):
extracted_questions, decompose_prompt = decompose_result
else:
extracted_questions, decompose_prompt = decompose_result, ""
decomposer_time_ms = int((time.perf_counter() - stage_start) * 1000)
if not extracted_questions:
extracted_questions = [question]
yield PipelineSnapshot(
phase="decomposed",
question=question,
extracted_questions=extracted_questions,
decompose_prompt=decompose_prompt,
decomposer_time_ms=decomposer_time_ms,
)
if stop_after_decompose:
total_ms = int((time.perf_counter() - overall_start) * 1000)
yield PipelineSnapshot(
phase="completed",
question=question,
extracted_questions=extracted_questions,
decompose_prompt=decompose_prompt,
decomposer_time_ms=decomposer_time_ms,
total_time_ms=total_ms,
)
return
# --- Stage 2: Retrieve ---
stage_start = time.perf_counter()
retrieval_results = (
self._rag.retrieve_per_subquestion(
extracted_questions, n_results=self._retrieval_n_results
)
if extracted_questions
else []
)
retriever_time_ms = int((time.perf_counter() - stage_start) * 1000)
chunks_retrieved_count = sum(
len(chunks) for _, chunks in retrieval_results
)
yield PipelineSnapshot(
phase="retrieving",
question=question,
extracted_questions=extracted_questions,
decompose_prompt=decompose_prompt,
decomposer_time_ms=decomposer_time_ms,
retrieval_results=retrieval_results,
chunks_retrieved_count=chunks_retrieved_count,
retriever_time_ms=retriever_time_ms,
)
if not any(chunks for _, chunks in retrieval_results):
total_ms = int((time.perf_counter() - overall_start) * 1000)
yield PipelineSnapshot(
phase="completed",
question=question,
extracted_questions=extracted_questions,
decompose_prompt=decompose_prompt,
decomposer_time_ms=decomposer_time_ms,
chunks_retrieved_count=0,
retriever_time_ms=retriever_time_ms,
answer=NO_RESULTS_ANSWER,
total_time_ms=total_ms,
)
return
# --- Stage 3: Filter ---
stage_start = time.perf_counter()
chunks_by_subq = [
[(text, meta) for text, meta, _dist in chunks]
for _, chunks in retrieval_results
]
if extracted_questions and chunks_by_subq:
filter_result = await self._relevance_filter.filter_per_subquestion(
extracted_questions, chunks_by_subq, threshold=self._relevance_threshold
)
else:
filter_result = ([], "")
if isinstance(filter_result, tuple):
filtered_by_subq, filter_prompt = filter_result
else:
filtered_by_subq, filter_prompt = filter_result, ""
filter_time_ms = int((time.perf_counter() - stage_start) * 1000)
chunks_filtered_count = sum(
len(chunks) for _, chunks in filtered_by_subq
)
yield PipelineSnapshot(
phase="filtering",
question=question,
extracted_questions=extracted_questions,
decompose_prompt=decompose_prompt,
decomposer_time_ms=decomposer_time_ms,
retrieval_results=retrieval_results,
chunks_retrieved_count=chunks_retrieved_count,
retriever_time_ms=retriever_time_ms,
filter_prompt=filter_prompt,
filtered_by_subq=filtered_by_subq,
chunks_filtered_count=chunks_filtered_count,
filter_time_ms=filter_time_ms,
)
if not filtered_by_subq or not any(
chunks for _, chunks in filtered_by_subq
):
total_ms = int((time.perf_counter() - overall_start) * 1000)
yield PipelineSnapshot(
phase="completed",
question=question,
extracted_questions=extracted_questions,
decomposer_time_ms=decomposer_time_ms,
retriever_time_ms=retriever_time_ms,
filter_time_ms=filter_time_ms,
answer=NO_RESULTS_ANSWER,
total_time_ms=total_ms,
)
return
# --- Stage 4: Generate ---
stage_start = time.perf_counter()
sub_chunk_texts = []
sub_chunk_metadata = []
for _, filtered_chunks in filtered_by_subq:
sub_chunk_texts.append([chunk for chunk, _meta in filtered_chunks])
sub_chunk_metadata.append([meta for _chunk, meta in filtered_chunks])
if extracted_questions and filtered_by_subq:
gen_result = await self._rag.generate_response_per_subquestion(
extracted_questions, sub_chunk_texts, sub_chunk_metadata
)
else:
gen_result = ("", "", [])
if isinstance(gen_result, tuple) and len(gen_result) == 3:
answer, generate_prompt, grouped_sources_meta = gen_result
else:
answer, generate_prompt = (
gen_result if isinstance(gen_result, tuple) else (gen_result, "")
)
grouped_sources_meta = []
sub_question_sources = []
for idx, (sub_q_text, sources_meta) in enumerate(
zip(extracted_questions, grouped_sources_meta)
):
sources = [
SourceMetadata(
filename=meta.get("filename", "unknown"),
upload_date=meta.get("upload_date", ""),
content_summary=meta.get("content_summary", ""),
chunk_index=meta.get("chunk_index", 0),
page_number=meta.get("page_number"),
chunk_file_path=meta.get("chunk_file_path"),
document_id=meta.get("document_id"),
)
for meta in sources_meta
]
sub_question_sources.append(
SubQuestionSources(
sub_question_index=idx,
sub_question_text=sub_q_text,
sources=sources,
)
)
generator_time_ms = int((time.perf_counter() - stage_start) * 1000)
total_time_ms = int((time.perf_counter() - overall_start) * 1000)
yield PipelineSnapshot(
phase="completed",
question=question,
extracted_questions=extracted_questions,
decompose_prompt=decompose_prompt,
decomposer_time_ms=decomposer_time_ms,
retrieval_results=retrieval_results,
chunks_retrieved_count=chunks_retrieved_count,
retriever_time_ms=retriever_time_ms,
filter_prompt=filter_prompt,
filtered_by_subq=filtered_by_subq,
chunks_filtered_count=chunks_filtered_count,
filter_time_ms=filter_time_ms,
generate_prompt=generate_prompt,
answer=answer,
sub_question_sources=sub_question_sources,
generator_time_ms=generator_time_ms,
total_time_ms=total_time_ms,
)

View File

@ -0,0 +1,220 @@
import logging
import uuid
from typing import Optional
from app.core.config import Settings
from app.models.testing import (
ChunkEntry,
FilteredResult,
GenerateResult,
InputInfo,
ResponseResult,
RetrievalResult,
SubQuestionChunks,
TimingInfo,
)
from app.services.asr_client import ASRClient
from app.services.llm_client import LLMClient
from app.services.llm_client_dp import LLMClientDP
from app.services.prompt_service import PromptService
from app.services.query_decomposer import QueryDecomposer
from app.services.rag_pipeline import RAGPipeline, PipelineSnapshot
from app.services.rag import RAGService
from app.services.relevance_filter import RelevanceFilter
logger = logging.getLogger(__name__)
class TestRunnerService:
"""Runs the full RAG pipeline and captures all intermediate data for accuracy testing."""
def __init__(self, settings: Settings):
self.settings = settings
async def run_text_test(
self,
question: str,
profile: str,
prompt_service: PromptService,
label: str = "",
) -> GenerateResult:
result_id = uuid.uuid4().hex[:12]
prompt_service.activate_profile(profile)
active_profile = prompt_service.get_active_profile_name()
llm_client_dp = LLMClientDP(self.settings)
llm_client = LLMClient(self.settings)
rag = RAGService(
llm_client=llm_client,
settings=self.settings,
prompt_service=prompt_service,
)
decomposer = QueryDecomposer(llm_client_dp, prompt_service=prompt_service)
relevance_filter = RelevanceFilter(
llm_client, prompt_service=prompt_service
)
pipeline = RAGPipeline(
decomposer=decomposer,
rag=rag,
relevance_filter=relevance_filter,
retrieval_n_results=self.settings.retrieval_n_results,
relevance_threshold=self.settings.relevance_threshold,
)
# Collect all snapshots — use the last "completed" one for the final result
decomposed_snap = None
retrieval_snap = None
filtering_snap = None
completed_snap = None
async for snap in pipeline.execute(question):
if snap.phase == "decomposed":
decomposed_snap = snap
elif snap.phase == "retrieving":
retrieval_snap = snap
elif snap.phase == "filtering":
filtering_snap = snap
elif snap.phase == "completed":
completed_snap = snap
if completed_snap is None:
raise RuntimeError("Pipeline did not produce a completed snapshot")
# Build retrieval result
retrieval_per_subq = []
if retrieval_snap and retrieval_snap.retrieval_results:
for sq_idx, (sub_q_text, chunks) in enumerate(
retrieval_snap.retrieval_results
):
retrieval_per_subq.append(
SubQuestionChunks(
sub_question_index=sq_idx,
sub_question_text=sub_q_text,
chunks=[
ChunkEntry(
chunk_index=i,
text=text,
metadata=meta,
distance=distance,
)
for i, (text, meta, distance) in enumerate(chunks)
],
)
)
# Build filtered result
filtered_per_subq = []
if filtering_snap and filtering_snap.filtered_by_subq:
for sq_idx, (sub_q_text, chunks) in enumerate(
filtering_snap.filtered_by_subq
):
filtered_per_subq.append(
SubQuestionChunks(
sub_question_index=sq_idx,
sub_question_text=sub_q_text,
chunks=[
ChunkEntry(
chunk_index=i,
text=text,
metadata=meta,
distance=0.0,
)
for i, (text, meta) in enumerate(chunks)
],
)
)
# Build response result — serialize through dicts to convert between
# query.py SubQuestionSources and testing.py SubQuestionSources (same fields)
response_result = ResponseResult(
final_answer=completed_snap.answer,
sub_question_sources=[
{
"sub_question_index": sq.sub_question_index,
"sub_question_text": sq.sub_question_text,
"sources": [s.model_dump() for s in sq.sources],
}
for sq in completed_snap.sub_question_sources
],
generate_time_ms=completed_snap.generator_time_ms,
)
return GenerateResult(
result_id=result_id,
input_type="text",
profile=active_profile,
label=label,
input=InputInfo(text=question),
extracted_key_questions=completed_snap.extracted_questions,
retrieval=RetrievalResult(
per_sub_question=retrieval_per_subq,
total_chunks_retrieved=(
retrieval_snap.chunks_retrieved_count if retrieval_snap else 0
),
retriever_time_ms=(
retrieval_snap.retriever_time_ms if retrieval_snap else 0
),
),
filtered=FilteredResult(
per_sub_question=filtered_per_subq,
total_chunks_filtered=(
filtering_snap.chunks_filtered_count if filtering_snap else 0
),
filter_time_ms=(
filtering_snap.filter_time_ms if filtering_snap else 0
),
),
response=response_result,
timing=TimingInfo(
decomposer_time_ms=completed_snap.decomposer_time_ms,
retriever_time_ms=(
retrieval_snap.retriever_time_ms if retrieval_snap else 0
),
filter_time_ms=(
filtering_snap.filter_time_ms if filtering_snap else 0
),
generator_time_ms=completed_snap.generator_time_ms,
total_time_ms=completed_snap.total_time_ms,
),
)
async def run_audio_test(
self,
audio_bytes: bytes,
reference_transcript: str,
profile: str,
prompt_service: PromptService,
language: str = "yue",
label: str = "",
audio_filename: str = "",
) -> GenerateResult:
result_id = uuid.uuid4().hex[:12]
# Run ASR
asr_client = ASRClient(self.settings)
transcribed_text = await asr_client.transcribe_full(
audio_bytes, language=language
)
# Run text test on transcribed text
result = await self.run_text_test(
question=transcribed_text,
profile=profile,
prompt_service=prompt_service,
label=label,
)
# Override with audio-specific fields
result.result_id = result_id
result.input_type = "audio"
result.input = InputInfo(
text=transcribed_text,
reference_transcript=reference_transcript,
audio_filename=audio_filename,
audio_duration_seconds=0.0,
asr_language=language,
)
return result

View File

@ -0,0 +1,101 @@
import json
import logging
import os
from pathlib import Path
from typing import List, Optional
from app.models.testing import GenerateResult, EvaluationResult
logger = logging.getLogger(__name__)
class TestStorageService:
def __init__(self, results_dir: str, evaluations_dir: str):
self.results_dir = results_dir
self.evaluations_dir = evaluations_dir
Path(results_dir).mkdir(parents=True, exist_ok=True)
Path(evaluations_dir).mkdir(parents=True, exist_ok=True)
# --- Results ---
def save_result(self, result: GenerateResult) -> str:
filepath = os.path.join(self.results_dir, f"{result.result_id}.json")
with open(filepath, "w", encoding="utf-8") as f:
f.write(result.model_dump_json(indent=2))
logger.info("Saved test result: %s", filepath)
return filepath
def load_result(self, result_id: str) -> Optional[GenerateResult]:
filepath = os.path.join(self.results_dir, f"{result_id}.json")
if not os.path.isfile(filepath):
return None
with open(filepath, "r", encoding="utf-8") as f:
data = json.load(f)
return GenerateResult.model_validate(data)
def list_results(self, limit: int = 50, offset: int = 0) -> List[dict]:
items = []
try:
for entry in sorted(
Path(self.results_dir).iterdir(),
key=lambda p: p.stat().st_mtime,
reverse=True,
):
if entry.suffix == ".json":
stat = entry.stat()
items.append({
"result_id": entry.stem,
"file_size_bytes": stat.st_size,
})
except FileNotFoundError:
return []
return items[offset : offset + limit]
def delete_result(self, result_id: str) -> bool:
filepath = os.path.join(self.results_dir, f"{result_id}.json")
if not os.path.isfile(filepath):
return False
os.remove(filepath)
return True
# --- Evaluations ---
def save_evaluation(self, evaluation: EvaluationResult) -> str:
filepath = os.path.join(self.evaluations_dir, f"{evaluation.evaluation_id}.json")
with open(filepath, "w", encoding="utf-8") as f:
f.write(evaluation.model_dump_json(indent=2))
logger.info("Saved evaluation: %s", filepath)
return filepath
def load_evaluation(self, eval_id: str) -> Optional[EvaluationResult]:
filepath = os.path.join(self.evaluations_dir, f"{eval_id}.json")
if not os.path.isfile(filepath):
return None
with open(filepath, "r", encoding="utf-8") as f:
data = json.load(f)
return EvaluationResult.model_validate(data)
def list_evaluations(self, limit: int = 50, offset: int = 0) -> List[dict]:
items = []
try:
for entry in sorted(
Path(self.evaluations_dir).iterdir(),
key=lambda p: p.stat().st_mtime,
reverse=True,
):
if entry.suffix == ".json":
stat = entry.stat()
items.append({
"evaluation_id": entry.stem,
"file_size_bytes": stat.st_size,
})
except FileNotFoundError:
return []
return items[offset : offset + limit]
def delete_evaluation(self, eval_id: str) -> bool:
filepath = os.path.join(self.evaluations_dir, f"{eval_id}.json")
if not os.path.isfile(filepath):
return False
os.remove(filepath)
return True

View File

@ -0,0 +1,173 @@
"""Phase 9 tests: Text generation endpoint integration (Sub-Phase 9.1).
Covers:
- POST /api/v1/test/generate/text with valid request returns 200
- Response includes all pipeline stages (retrieval, filtered, response, timing)
- Invalid profile returns validation error
- Empty question returns validation error
- GET /api/v1/test/results lists saved results
- GET /api/v1/test/results/{id} retrieves specific result
- DELETE /api/v1/test/results/{id} deletes result
"""
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from app.routers.test_generate import router
@pytest.fixture
def client(tmp_path, monkeypatch):
results_dir = str(tmp_path / "test_results")
evals_dir = str(tmp_path / "test_evaluations")
prompts_path = str(tmp_path / "prompts.db")
history_path = str(tmp_path / "history.db")
monkeypatch.setenv("TEST_RESULTS_DIR", results_dir)
monkeypatch.setenv("TEST_EVALUATIONS_DIR", evals_dir)
monkeypatch.setenv("PROMPTS_DB_PATH", prompts_path)
monkeypatch.setenv("HISTORY_DB_PATH", history_path)
monkeypatch.setenv("LLM_API_KEY", "test-key")
monkeypatch.setenv("LLM_BASE_URL", "https://test.example.com/v1")
monkeypatch.setenv("LLM_MODEL_NAME", "test-model")
monkeypatch.setenv("DP_API_KEY", "test-dp-key")
monkeypatch.setenv("EMBEDDING_MODEL", "test-embedding")
from app.core.config import get_settings
get_settings.cache_clear()
from app.core.sqlite_db import _get_db, init_prompts_db, init_history_db, seed_default_profiles
conn = _get_db(prompts_path)
init_prompts_db(conn)
seed_default_profiles(conn)
conn.close()
hconn = _get_db(history_path)
init_history_db(hconn)
hconn.close()
test_app = FastAPI()
test_app.include_router(router, prefix="/api/v1")
yield TestClient(test_app)
get_settings.cache_clear()
@pytest.fixture
def mock_pipeline(monkeypatch):
"""Mock LLM and RAG to avoid real API calls."""
async def _mock_decompose(self, question):
return (["sub question 1", "sub question 2"], "mocked decompose prompt")
def _mock_retrieve(self, sub_questions, n_results=10):
return [
(sq, [("chunk text content", {"filename": "test.pdf", "upload_date": "2026-01-01",
"content_summary": "test chunk", "chunk_index": 0,
"page_number": 1, "document_id": "doc-1"}, 0.15)])
for sq in sub_questions
]
async def _mock_filter(self, sub_questions, chunks_by_subq, threshold=7.0):
result = [
(sq, [(text, {**meta, "relevance_score": 8.5}) for text, meta in chunks])
for sq, chunks in zip(sub_questions, chunks_by_subq)
]
return (result, "mocked filter prompt")
async def _mock_generate(self, sub_questions, sub_chunks, sub_metadata):
answer = "## Sub-question 0: sub question 1\n\n- Test answer with citation [test.pdf, page 1]\n\n## Sub-question 1: sub question 2\n\n- More answer content"
grouped_sources = [[meta_list[0]] for meta_list in sub_metadata] if sub_metadata else [[] for _ in sub_questions]
return (answer, "mocked generate prompt", grouped_sources)
monkeypatch.setattr("app.services.query_decomposer.QueryDecomposer.decompose", _mock_decompose)
monkeypatch.setattr("app.services.rag.RAGService.retrieve_per_subquestion", _mock_retrieve)
monkeypatch.setattr("app.services.relevance_filter.RelevanceFilter.filter_per_subquestion", _mock_filter)
monkeypatch.setattr("app.services.rag.RAGService.generate_response_per_subquestion", _mock_generate)
@pytest.mark.usefixtures("mock_pipeline")
class TestGenerateTextEndpoint:
def test_valid_request_returns_200(self, client):
resp = client.post("/api/v1/test/generate/text", json={
"question": "test question",
"profile": "A",
"label": "my label",
})
assert resp.status_code == 200
data = resp.json()
assert data["input_type"] == "text"
assert data["profile"] == "A"
assert data["label"] == "my label"
assert "result_id" in data
def test_result_contains_all_stages(self, client):
resp = client.post("/api/v1/test/generate/text", json={
"question": "test",
"profile": "B",
})
assert resp.status_code == 200
data = resp.json()
assert len(data["extracted_key_questions"]) == 2
assert data["retrieval"]["total_chunks_retrieved"] > 0
assert data["filtered"]["total_chunks_filtered"] > 0
assert len(data["response"]["final_answer"]) > 0
assert data["timing"]["total_time_ms"] >= 0
def test_invalid_profile_rejected(self, client):
resp = client.post("/api/v1/test/generate/text", json={
"question": "test",
"profile": "D",
})
assert resp.status_code == 422
def test_empty_question_rejected(self, client):
resp = client.post("/api/v1/test/generate/text", json={
"question": "",
"profile": "A",
})
assert resp.status_code == 422
def test_result_saved_and_retrievable(self, client):
resp = client.post("/api/v1/test/generate/text", json={
"question": "save test",
"profile": "A",
})
assert resp.status_code == 200
result_id = resp.json()["result_id"]
get_resp = client.get(f"/api/v1/test/results/{result_id}")
assert get_resp.status_code == 200
assert get_resp.json()["result_id"] == result_id
def test_list_results(self, client):
client.post("/api/v1/test/generate/text", json={
"question": "list test 1", "profile": "A",
})
client.post("/api/v1/test/generate/text", json={
"question": "list test 2", "profile": "B",
})
resp = client.get("/api/v1/test/results?limit=10")
assert resp.status_code == 200
data = resp.json()
assert len(data) >= 2
def test_get_nonexistent_result(self, client):
resp = client.get("/api/v1/test/results/no-such-id")
assert resp.status_code == 404
def test_delete_result(self, client):
resp = client.post("/api/v1/test/generate/text", json={
"question": "delete test", "profile": "A",
})
result_id = resp.json()["result_id"]
del_resp = client.delete(f"/api/v1/test/results/{result_id}")
assert del_resp.status_code == 200
get_resp = client.get(f"/api/v1/test/results/{result_id}")
assert get_resp.status_code == 404

View File

@ -0,0 +1,232 @@
"""Phase 9 tests: Results storage service CRUD operations (Sub-Phase 9.1).
Covers:
- save_result writes JSON file and returns result_id
- load_result reads and parses JSON file
- list_results returns list of result metadata
- delete_result removes file
- Nonexistent result loading returns None
- save_evaluation / load_evaluation / list_evaluations / delete_evaluation
- Empty storage dirs don't error
"""
import json
import os
from pathlib import Path
import pytest
from app.models.testing import (
GenerateResult,
InputInfo,
TimingInfo,
RetrievalResult,
FilteredResult,
ResponseResult,
EvaluationResult,
EvaluationTiming,
)
@pytest.fixture
def storage_dirs(tmp_path):
results_dir = tmp_path / "test_results"
evals_dir = tmp_path / "test_evaluations"
results_dir.mkdir()
evals_dir.mkdir()
return str(results_dir), str(evals_dir)
@pytest.fixture
def sample_result():
return GenerateResult(
result_id="test-001",
input_type="text",
profile="A",
label="sample test",
input=InputInfo(text="sample question"),
extracted_key_questions=["key q1"],
retrieval=RetrievalResult(
per_sub_question=[],
total_chunks_retrieved=0,
retriever_time_ms=100,
),
filtered=FilteredResult(
per_sub_question=[],
total_chunks_filtered=0,
filter_time_ms=100,
),
response=ResponseResult(
final_answer="answer",
sub_question_sources=[],
generate_time_ms=100,
),
timing=TimingInfo(
decomposer_time_ms=100,
retriever_time_ms=100,
filter_time_ms=100,
generator_time_ms=100,
total_time_ms=400,
),
)
class TestResultStorage:
def test_save_and_load(self, storage_dirs, sample_result, monkeypatch):
results_dir, _ = storage_dirs
monkeypatch.setenv("TEST_RESULTS_DIR", results_dir)
from app.services.test_storage_service import TestStorageService
svc = TestStorageService(results_dir, results_dir)
path = svc.save_result(sample_result)
assert os.path.exists(path)
loaded = svc.load_result(sample_result.result_id)
assert loaded is not None
assert loaded.result_id == sample_result.result_id
assert loaded.input_type == "text"
assert loaded.profile == "A"
def test_load_nonexistent(self, storage_dirs):
results_dir, _ = storage_dirs
from app.services.test_storage_service import TestStorageService
svc = TestStorageService(results_dir, results_dir)
result = svc.load_result("nonexistent-id")
assert result is None
def test_list_results(self, storage_dirs, sample_result, monkeypatch):
results_dir, _ = storage_dirs
monkeypatch.setenv("TEST_RESULTS_DIR", results_dir)
from app.services.test_storage_service import TestStorageService
svc = TestStorageService(results_dir, results_dir)
svc.save_result(sample_result)
items = svc.list_results()
assert len(items) >= 1
assert any(r["result_id"] == "test-001" for r in items)
def test_list_results_with_limit_offset(self, storage_dirs, sample_result):
results_dir, _ = storage_dirs
from app.services.test_storage_service import TestStorageService
svc = TestStorageService(results_dir, results_dir)
for i in range(5):
r = sample_result.model_copy(update={"result_id": f"test-{i:03d}"})
svc.save_result(r)
items = svc.list_results(limit=2, offset=1)
assert len(items) == 2
def test_delete_result(self, storage_dirs, sample_result):
results_dir, _ = storage_dirs
from app.services.test_storage_service import TestStorageService
svc = TestStorageService(results_dir, results_dir)
svc.save_result(sample_result)
filepath = os.path.join(results_dir, f"{sample_result.result_id}.json")
assert os.path.exists(filepath)
result = svc.delete_result(sample_result.result_id)
assert result is True
assert not os.path.exists(filepath)
def test_delete_nonexistent(self, storage_dirs):
results_dir, _ = storage_dirs
from app.services.test_storage_service import TestStorageService
svc = TestStorageService(results_dir, results_dir)
result = svc.delete_result("no-such-id")
assert result is False
def test_creates_dir_if_missing(self, storage_dirs):
results_dir, evals_dir = storage_dirs
new_results = os.path.join(results_dir, "auto_created")
from app.services.test_storage_service import TestStorageService
svc = TestStorageService(new_results, evals_dir)
assert os.path.isdir(new_results)
def test_list_empty_dir(self, storage_dirs):
results_dir, _ = storage_dirs
from app.services.test_storage_service import TestStorageService
svc = TestStorageService(results_dir, results_dir)
items = svc.list_results()
assert items == []
class TestEvaluationStorage:
def test_save_and_load_eval(self, storage_dirs):
results_dir, evals_dir = storage_dirs
from app.services.test_storage_service import TestStorageService
svc = TestStorageService(results_dir, evals_dir)
eval_result = EvaluationResult(
evaluation_id="eval-001",
result_id="result-001",
status="completed",
timing=EvaluationTiming(
audio_evaluation_time_ms=10,
key_questions_evaluation_time_ms=100,
chunk_evaluation_time_ms=200,
response_evaluation_time_ms=300,
total_evaluation_time_ms=610,
),
)
path = svc.save_evaluation(eval_result)
assert os.path.exists(path)
loaded = svc.load_evaluation("eval-001")
assert loaded is not None
assert loaded.evaluation_id == "eval-001"
assert loaded.status == "completed"
def test_list_evaluations(self, storage_dirs):
results_dir, evals_dir = storage_dirs
from app.services.test_storage_service import TestStorageService
svc = TestStorageService(results_dir, evals_dir)
eval_result = EvaluationResult(
evaluation_id="eval-001",
result_id="result-001",
status="completed",
timing=EvaluationTiming(
audio_evaluation_time_ms=10,
key_questions_evaluation_time_ms=100,
chunk_evaluation_time_ms=200,
response_evaluation_time_ms=300,
total_evaluation_time_ms=610,
),
)
svc.save_evaluation(eval_result)
items = svc.list_evaluations()
assert len(items) >= 1
def test_delete_evaluation(self, storage_dirs):
results_dir, evals_dir = storage_dirs
from app.services.test_storage_service import TestStorageService
svc = TestStorageService(results_dir, evals_dir)
eval_result = EvaluationResult(
evaluation_id="eval-002",
result_id="result-002",
status="completed",
timing=EvaluationTiming(
audio_evaluation_time_ms=10,
key_questions_evaluation_time_ms=100,
chunk_evaluation_time_ms=200,
response_evaluation_time_ms=300,
total_evaluation_time_ms=610,
),
)
svc.save_evaluation(eval_result)
filepath = os.path.join(evals_dir, "eval-002.json")
assert os.path.exists(filepath)
result = svc.delete_evaluation("eval-002")
assert result is True
assert not os.path.exists(filepath)