diff --git a/backend/app/main.py b/backend/app/main.py index d111747..fecebac 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -7,7 +7,7 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse -from app.routers import ingest, query, documents, prompts, history, chunks, video, ws_asr +from app.routers import ingest, query, documents, prompts, history, chunks, video, ws_asr, test_generate from app.core.config import get_settings from app.core.sqlite_db import ( get_prompts_db, @@ -58,6 +58,7 @@ app.include_router(history.router) app.include_router(chunks.router) app.include_router(video.router, prefix="/api/v1") app.include_router(ws_asr.router) +app.include_router(test_generate.router, prefix="/api/v1") _prompts_conn = get_prompts_db() init_prompts_db(_prompts_conn) diff --git a/backend/app/routers/test_generate.py b/backend/app/routers/test_generate.py new file mode 100644 index 0000000..94e698a --- /dev/null +++ b/backend/app/routers/test_generate.py @@ -0,0 +1,104 @@ +import io +import logging + +from fastapi import APIRouter, File, Form, HTTPException, Query, UploadFile + +from app.core.config import get_settings +from app.models.testing import GenerateTextRequest +from app.services.prompt_service import PromptService +from app.services.test_runner_service import TestRunnerService +from app.services.test_storage_service import TestStorageService + +logger = logging.getLogger(__name__) +router = APIRouter(tags=["test"]) + + +def _get_prompt_service() -> PromptService: + settings = get_settings() + return PromptService(db_path=settings.prompts_db_path) + + +def _get_storage_service() -> TestStorageService: + settings = get_settings() + return TestStorageService( + results_dir=settings.test_results_dir, + evaluations_dir=settings.test_evaluations_dir, + ) + + +@router.post("/test/generate/text") +async def generate_text(request: GenerateTextRequest): + settings = get_settings() + 
prompt_service = _get_prompt_service() + + runner = TestRunnerService(settings) + result = await runner.run_text_test( + question=request.question, + profile=request.profile, + prompt_service=prompt_service, + label=request.label, + ) + + storage = _get_storage_service() + storage.save_result(result) + + return result.model_dump() + + +@router.post("/test/generate/audio") +async def generate_audio( + audio_file: UploadFile = File(...), + profile: str = Form(...), + reference_transcript: str = Form(""), + label: str = Form(""), + language: str = Form("yue"), +): + if profile not in ("A", "B", "C"): + raise HTTPException(status_code=400, detail="profile must be A, B, or C") + + settings = get_settings() + prompt_service = _get_prompt_service() + + audio_bytes = await audio_file.read() + if not audio_bytes: + raise HTTPException(status_code=400, detail="Audio file is empty") + + runner = TestRunnerService(settings) + result = await runner.run_audio_test( + audio_bytes=audio_bytes, + reference_transcript=reference_transcript, + profile=profile, + prompt_service=prompt_service, + language=language, + label=label, + audio_filename=audio_file.filename or "unknown", + ) + + storage = _get_storage_service() + storage.save_result(result) + + return result.model_dump() + + +@router.get("/test/results") +async def list_results(limit: int = Query(50, ge=1, le=200), offset: int = Query(0, ge=0)): + storage = _get_storage_service() + return storage.list_results(limit=limit, offset=offset) + + +@router.get("/test/results/{result_id}") +async def get_result(result_id: str): + storage = _get_storage_service() + result = storage.load_result(result_id) + if result is None: + raise HTTPException(status_code=404, detail="Result not found") + return result.model_dump() + + +@router.delete("/test/results/{result_id}") +async def delete_result(result_id: str): + storage = _get_storage_service() + deleted = storage.delete_result(result_id) + if not deleted: + raise 
HTTPException(status_code=404, detail="Result not found") + return {"status": "deleted", "result_id": result_id} diff --git a/backend/app/services/rag_pipeline.py b/backend/app/services/rag_pipeline.py new file mode 100644 index 0000000..f988ecf --- /dev/null +++ b/backend/app/services/rag_pipeline.py @@ -0,0 +1,289 @@ +import logging +import time +from dataclasses import dataclass, field +from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple + +from app.models.common import SourceMetadata +from app.models.query import SubQuestionSources +from app.services.query_decomposer import QueryDecomposer +from app.services.relevance_filter import RelevanceFilter +from app.services.rag import RAGService + +logger = logging.getLogger(__name__) + +NO_RESULTS_ANSWER = "I could not find any relevant information to answer your question." + + +@dataclass +class PipelineSnapshot: + """Complete snapshot at a given pipeline stage for accuracy testing capture.""" + + phase: str + question: str = "" + + # Stage 1: Decompose + extracted_questions: List[str] = field(default_factory=list) + decompose_prompt: str = "" + decomposer_time_ms: int = 0 + + # Stage 2: Retrieve + retrieval_results: List[Tuple[str, List[Tuple[str, Dict[str, Any], float]]]] = field( + default_factory=list + ) + chunks_retrieved_count: int = 0 + retriever_time_ms: int = 0 + + # Stage 3: Filter + filter_prompt: str = "" + filtered_by_subq: List[Tuple[str, List[Tuple[str, Dict[str, Any]]]]] = field( + default_factory=list + ) + chunks_filtered_count: int = 0 + filter_time_ms: int = 0 + + # Stage 4: Generate + generate_prompt: str = "" + answer: str = "" + sub_question_sources: List[SubQuestionSources] = field(default_factory=list) + generator_time_ms: int = 0 + + # Metadata + total_time_ms: int = 0 + error_message: str = "" + + +class RAGPipeline: + """Reusable RAG pipeline: decompose → retrieve → filter → generate. + + Yields PipelineSnapshot at each stage boundary. No SSE or HTTP coupling. 
+ Use in streaming endpoints (wrap snapshots as SSE) or capture endpoints + (collect snapshots into GenerateResult). + + Usage: + pipeline = RAGPipeline(decomposer=..., rag=..., relevance_filter=...) + async for snap in pipeline.execute("question text"): + if snap.phase == "completed": + print(snap.answer) + """ + + def __init__( + self, + *, + decomposer: QueryDecomposer, + rag: RAGService, + relevance_filter: RelevanceFilter, + retrieval_n_results: int = 10, + relevance_threshold: float = 7.0, + ): + self._decomposer = decomposer + self._rag = rag + self._relevance_filter = relevance_filter + self._retrieval_n_results = retrieval_n_results + self._relevance_threshold = relevance_threshold + + async def execute( + self, + question: str, + stop_after_decompose: bool = False, + ) -> AsyncGenerator[PipelineSnapshot, None]: + """Execute the full pipeline, yielding one snapshot per stage.""" + overall_start = time.perf_counter() + + # --- Stage 1: Decompose --- + stage_start = time.perf_counter() + decompose_result = await self._decomposer.decompose(question) + if isinstance(decompose_result, tuple): + extracted_questions, decompose_prompt = decompose_result + else: + extracted_questions, decompose_prompt = decompose_result, "" + + decomposer_time_ms = int((time.perf_counter() - stage_start) * 1000) + + if not extracted_questions: + extracted_questions = [question] + + yield PipelineSnapshot( + phase="decomposed", + question=question, + extracted_questions=extracted_questions, + decompose_prompt=decompose_prompt, + decomposer_time_ms=decomposer_time_ms, + ) + + if stop_after_decompose: + total_ms = int((time.perf_counter() - overall_start) * 1000) + yield PipelineSnapshot( + phase="completed", + question=question, + extracted_questions=extracted_questions, + decompose_prompt=decompose_prompt, + decomposer_time_ms=decomposer_time_ms, + total_time_ms=total_ms, + ) + return + + # --- Stage 2: Retrieve --- + stage_start = time.perf_counter() + 
retrieval_results = ( + self._rag.retrieve_per_subquestion( + extracted_questions, n_results=self._retrieval_n_results + ) + if extracted_questions + else [] + ) + retriever_time_ms = int((time.perf_counter() - stage_start) * 1000) + + chunks_retrieved_count = sum( + len(chunks) for _, chunks in retrieval_results + ) + + yield PipelineSnapshot( + phase="retrieving", + question=question, + extracted_questions=extracted_questions, + decompose_prompt=decompose_prompt, + decomposer_time_ms=decomposer_time_ms, + retrieval_results=retrieval_results, + chunks_retrieved_count=chunks_retrieved_count, + retriever_time_ms=retriever_time_ms, + ) + + if not any(chunks for _, chunks in retrieval_results): + total_ms = int((time.perf_counter() - overall_start) * 1000) + yield PipelineSnapshot( + phase="completed", + question=question, + extracted_questions=extracted_questions, + decompose_prompt=decompose_prompt, + decomposer_time_ms=decomposer_time_ms, + chunks_retrieved_count=0, + retriever_time_ms=retriever_time_ms, + answer=NO_RESULTS_ANSWER, + total_time_ms=total_ms, + ) + return + + # --- Stage 3: Filter --- + stage_start = time.perf_counter() + chunks_by_subq = [ + [(text, meta) for text, meta, _dist in chunks] + for _, chunks in retrieval_results + ] + + if extracted_questions and chunks_by_subq: + filter_result = await self._relevance_filter.filter_per_subquestion( + extracted_questions, chunks_by_subq, threshold=self._relevance_threshold + ) + else: + filter_result = ([], "") + + if isinstance(filter_result, tuple): + filtered_by_subq, filter_prompt = filter_result + else: + filtered_by_subq, filter_prompt = filter_result, "" + + filter_time_ms = int((time.perf_counter() - stage_start) * 1000) + chunks_filtered_count = sum( + len(chunks) for _, chunks in filtered_by_subq + ) + + yield PipelineSnapshot( + phase="filtering", + question=question, + extracted_questions=extracted_questions, + decompose_prompt=decompose_prompt, + decomposer_time_ms=decomposer_time_ms, + 
retrieval_results=retrieval_results, + chunks_retrieved_count=chunks_retrieved_count, + retriever_time_ms=retriever_time_ms, + filter_prompt=filter_prompt, + filtered_by_subq=filtered_by_subq, + chunks_filtered_count=chunks_filtered_count, + filter_time_ms=filter_time_ms, + ) + + if not filtered_by_subq or not any( + chunks for _, chunks in filtered_by_subq + ): + total_ms = int((time.perf_counter() - overall_start) * 1000) + yield PipelineSnapshot( + phase="completed", + question=question, + extracted_questions=extracted_questions, + decomposer_time_ms=decomposer_time_ms, + retriever_time_ms=retriever_time_ms, + filter_time_ms=filter_time_ms, + answer=NO_RESULTS_ANSWER, + total_time_ms=total_ms, + ) + return + + # --- Stage 4: Generate --- + stage_start = time.perf_counter() + sub_chunk_texts = [] + sub_chunk_metadata = [] + for _, filtered_chunks in filtered_by_subq: + sub_chunk_texts.append([chunk for chunk, _meta in filtered_chunks]) + sub_chunk_metadata.append([meta for _chunk, meta in filtered_chunks]) + + if extracted_questions and filtered_by_subq: + gen_result = await self._rag.generate_response_per_subquestion( + extracted_questions, sub_chunk_texts, sub_chunk_metadata + ) + else: + gen_result = ("", "", []) + + if isinstance(gen_result, tuple) and len(gen_result) == 3: + answer, generate_prompt, grouped_sources_meta = gen_result + else: + answer, generate_prompt = ( + gen_result if isinstance(gen_result, tuple) else (gen_result, "") + ) + grouped_sources_meta = [] + + sub_question_sources = [] + for idx, (sub_q_text, sources_meta) in enumerate( + zip(extracted_questions, grouped_sources_meta) + ): + sources = [ + SourceMetadata( + filename=meta.get("filename", "unknown"), + upload_date=meta.get("upload_date", ""), + content_summary=meta.get("content_summary", ""), + chunk_index=meta.get("chunk_index", 0), + page_number=meta.get("page_number"), + chunk_file_path=meta.get("chunk_file_path"), + document_id=meta.get("document_id"), + ) + for meta in 
sources_meta + ] + sub_question_sources.append( + SubQuestionSources( + sub_question_index=idx, + sub_question_text=sub_q_text, + sources=sources, + ) + ) + + generator_time_ms = int((time.perf_counter() - stage_start) * 1000) + total_time_ms = int((time.perf_counter() - overall_start) * 1000) + + yield PipelineSnapshot( + phase="completed", + question=question, + extracted_questions=extracted_questions, + decompose_prompt=decompose_prompt, + decomposer_time_ms=decomposer_time_ms, + retrieval_results=retrieval_results, + chunks_retrieved_count=chunks_retrieved_count, + retriever_time_ms=retriever_time_ms, + filter_prompt=filter_prompt, + filtered_by_subq=filtered_by_subq, + chunks_filtered_count=chunks_filtered_count, + filter_time_ms=filter_time_ms, + generate_prompt=generate_prompt, + answer=answer, + sub_question_sources=sub_question_sources, + generator_time_ms=generator_time_ms, + total_time_ms=total_time_ms, + ) diff --git a/backend/app/services/test_runner_service.py b/backend/app/services/test_runner_service.py new file mode 100644 index 0000000..35a66a8 --- /dev/null +++ b/backend/app/services/test_runner_service.py @@ -0,0 +1,220 @@ +import logging +import uuid +from typing import Optional + +from app.core.config import Settings +from app.models.testing import ( + ChunkEntry, + FilteredResult, + GenerateResult, + InputInfo, + ResponseResult, + RetrievalResult, + SubQuestionChunks, + TimingInfo, +) +from app.services.asr_client import ASRClient +from app.services.llm_client import LLMClient +from app.services.llm_client_dp import LLMClientDP +from app.services.prompt_service import PromptService +from app.services.query_decomposer import QueryDecomposer +from app.services.rag_pipeline import RAGPipeline, PipelineSnapshot +from app.services.rag import RAGService +from app.services.relevance_filter import RelevanceFilter + +logger = logging.getLogger(__name__) + + +class TestRunnerService: + """Runs the full RAG pipeline and captures all intermediate data 
for accuracy testing.""" + + def __init__(self, settings: Settings): + self.settings = settings + + async def run_text_test( + self, + question: str, + profile: str, + prompt_service: PromptService, + label: str = "", + ) -> GenerateResult: + result_id = uuid.uuid4().hex[:12] + + prompt_service.activate_profile(profile) + active_profile = prompt_service.get_active_profile_name() + + llm_client_dp = LLMClientDP(self.settings) + llm_client = LLMClient(self.settings) + rag = RAGService( + llm_client=llm_client, + settings=self.settings, + prompt_service=prompt_service, + ) + decomposer = QueryDecomposer(llm_client_dp, prompt_service=prompt_service) + relevance_filter = RelevanceFilter( + llm_client, prompt_service=prompt_service + ) + + pipeline = RAGPipeline( + decomposer=decomposer, + rag=rag, + relevance_filter=relevance_filter, + retrieval_n_results=self.settings.retrieval_n_results, + relevance_threshold=self.settings.relevance_threshold, + ) + + # Collect all snapshots — use the last "completed" one for the final result + decomposed_snap = None + retrieval_snap = None + filtering_snap = None + completed_snap = None + + async for snap in pipeline.execute(question): + if snap.phase == "decomposed": + decomposed_snap = snap + elif snap.phase == "retrieving": + retrieval_snap = snap + elif snap.phase == "filtering": + filtering_snap = snap + elif snap.phase == "completed": + completed_snap = snap + + if completed_snap is None: + raise RuntimeError("Pipeline did not produce a completed snapshot") + + # Build retrieval result + retrieval_per_subq = [] + if retrieval_snap and retrieval_snap.retrieval_results: + for sq_idx, (sub_q_text, chunks) in enumerate( + retrieval_snap.retrieval_results + ): + retrieval_per_subq.append( + SubQuestionChunks( + sub_question_index=sq_idx, + sub_question_text=sub_q_text, + chunks=[ + ChunkEntry( + chunk_index=i, + text=text, + metadata=meta, + distance=distance, + ) + for i, (text, meta, distance) in enumerate(chunks) + ], + ) + ) + + 
# Build filtered result + filtered_per_subq = [] + if filtering_snap and filtering_snap.filtered_by_subq: + for sq_idx, (sub_q_text, chunks) in enumerate( + filtering_snap.filtered_by_subq + ): + filtered_per_subq.append( + SubQuestionChunks( + sub_question_index=sq_idx, + sub_question_text=sub_q_text, + chunks=[ + ChunkEntry( + chunk_index=i, + text=text, + metadata=meta, + distance=0.0, + ) + for i, (text, meta) in enumerate(chunks) + ], + ) + ) + + # Build response result — serialize through dicts to convert between + # query.py SubQuestionSources and testing.py SubQuestionSources (same fields) + response_result = ResponseResult( + final_answer=completed_snap.answer, + sub_question_sources=[ + { + "sub_question_index": sq.sub_question_index, + "sub_question_text": sq.sub_question_text, + "sources": [s.model_dump() for s in sq.sources], + } + for sq in completed_snap.sub_question_sources + ], + generate_time_ms=completed_snap.generator_time_ms, + ) + + return GenerateResult( + result_id=result_id, + input_type="text", + profile=active_profile, + label=label, + input=InputInfo(text=question), + extracted_key_questions=completed_snap.extracted_questions, + retrieval=RetrievalResult( + per_sub_question=retrieval_per_subq, + total_chunks_retrieved=( + retrieval_snap.chunks_retrieved_count if retrieval_snap else 0 + ), + retriever_time_ms=( + retrieval_snap.retriever_time_ms if retrieval_snap else 0 + ), + ), + filtered=FilteredResult( + per_sub_question=filtered_per_subq, + total_chunks_filtered=( + filtering_snap.chunks_filtered_count if filtering_snap else 0 + ), + filter_time_ms=( + filtering_snap.filter_time_ms if filtering_snap else 0 + ), + ), + response=response_result, + timing=TimingInfo( + decomposer_time_ms=completed_snap.decomposer_time_ms, + retriever_time_ms=( + retrieval_snap.retriever_time_ms if retrieval_snap else 0 + ), + filter_time_ms=( + filtering_snap.filter_time_ms if filtering_snap else 0 + ), + 
generator_time_ms=completed_snap.generator_time_ms, + total_time_ms=completed_snap.total_time_ms, + ), + ) + + async def run_audio_test( + self, + audio_bytes: bytes, + reference_transcript: str, + profile: str, + prompt_service: PromptService, + language: str = "yue", + label: str = "", + audio_filename: str = "", + ) -> GenerateResult: + result_id = uuid.uuid4().hex[:12] + + # Run ASR + asr_client = ASRClient(self.settings) + transcribed_text = await asr_client.transcribe_full( + audio_bytes, language=language + ) + + # Run text test on transcribed text + result = await self.run_text_test( + question=transcribed_text, + profile=profile, + prompt_service=prompt_service, + label=label, + ) + + # Override with audio-specific fields + result.result_id = result_id + result.input_type = "audio" + result.input = InputInfo( + text=transcribed_text, + reference_transcript=reference_transcript, + audio_filename=audio_filename, + audio_duration_seconds=0.0, + asr_language=language, + ) + + return result diff --git a/backend/app/services/test_storage_service.py b/backend/app/services/test_storage_service.py new file mode 100644 index 0000000..3bb5db6 --- /dev/null +++ b/backend/app/services/test_storage_service.py @@ -0,0 +1,101 @@ +import json +import logging +import os +from pathlib import Path +from typing import List, Optional + +from app.models.testing import GenerateResult, EvaluationResult + +logger = logging.getLogger(__name__) + + +class TestStorageService: + def __init__(self, results_dir: str, evaluations_dir: str): + self.results_dir = results_dir + self.evaluations_dir = evaluations_dir + Path(results_dir).mkdir(parents=True, exist_ok=True) + Path(evaluations_dir).mkdir(parents=True, exist_ok=True) + + # --- Results --- + + def save_result(self, result: GenerateResult) -> str: + filepath = os.path.join(self.results_dir, f"{result.result_id}.json") + with open(filepath, "w", encoding="utf-8") as f: + f.write(result.model_dump_json(indent=2)) + logger.info("Saved 
test result: %s", filepath) + return filepath + + def load_result(self, result_id: str) -> Optional[GenerateResult]: + filepath = os.path.join(self.results_dir, f"{result_id}.json") + if not os.path.isfile(filepath): + return None + with open(filepath, "r", encoding="utf-8") as f: + data = json.load(f) + return GenerateResult.model_validate(data) + + def list_results(self, limit: int = 50, offset: int = 0) -> List[dict]: + items = [] + try: + for entry in sorted( + Path(self.results_dir).iterdir(), + key=lambda p: p.stat().st_mtime, + reverse=True, + ): + if entry.suffix == ".json": + stat = entry.stat() + items.append({ + "result_id": entry.stem, + "file_size_bytes": stat.st_size, + }) + except FileNotFoundError: + return [] + return items[offset : offset + limit] + + def delete_result(self, result_id: str) -> bool: + filepath = os.path.join(self.results_dir, f"{result_id}.json") + if not os.path.isfile(filepath): + return False + os.remove(filepath) + return True + + # --- Evaluations --- + + def save_evaluation(self, evaluation: EvaluationResult) -> str: + filepath = os.path.join(self.evaluations_dir, f"{evaluation.evaluation_id}.json") + with open(filepath, "w", encoding="utf-8") as f: + f.write(evaluation.model_dump_json(indent=2)) + logger.info("Saved evaluation: %s", filepath) + return filepath + + def load_evaluation(self, eval_id: str) -> Optional[EvaluationResult]: + filepath = os.path.join(self.evaluations_dir, f"{eval_id}.json") + if not os.path.isfile(filepath): + return None + with open(filepath, "r", encoding="utf-8") as f: + data = json.load(f) + return EvaluationResult.model_validate(data) + + def list_evaluations(self, limit: int = 50, offset: int = 0) -> List[dict]: + items = [] + try: + for entry in sorted( + Path(self.evaluations_dir).iterdir(), + key=lambda p: p.stat().st_mtime, + reverse=True, + ): + if entry.suffix == ".json": + stat = entry.stat() + items.append({ + "evaluation_id": entry.stem, + "file_size_bytes": stat.st_size, + }) + 
except FileNotFoundError: + return [] + return items[offset : offset + limit] + + def delete_evaluation(self, eval_id: str) -> bool: + filepath = os.path.join(self.evaluations_dir, f"{eval_id}.json") + if not os.path.isfile(filepath): + return False + os.remove(filepath) + return True diff --git a/backend/app/test/test_phase9_generate_text.py b/backend/app/test/test_phase9_generate_text.py new file mode 100644 index 0000000..be090a0 --- /dev/null +++ b/backend/app/test/test_phase9_generate_text.py @@ -0,0 +1,173 @@ +"""Phase 9 tests: Text generation endpoint integration (Sub-Phase 9.1). + +Covers: +- POST /api/v1/test/generate/text with valid request returns 200 +- Response includes all pipeline stages (retrieval, filtered, response, timing) +- Invalid profile returns validation error +- Empty question returns validation error +- GET /api/v1/test/results lists saved results +- GET /api/v1/test/results/{id} retrieves specific result +- DELETE /api/v1/test/results/{id} deletes result +""" +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from app.routers.test_generate import router + + +@pytest.fixture +def client(tmp_path, monkeypatch): + results_dir = str(tmp_path / "test_results") + evals_dir = str(tmp_path / "test_evaluations") + prompts_path = str(tmp_path / "prompts.db") + history_path = str(tmp_path / "history.db") + + monkeypatch.setenv("TEST_RESULTS_DIR", results_dir) + monkeypatch.setenv("TEST_EVALUATIONS_DIR", evals_dir) + monkeypatch.setenv("PROMPTS_DB_PATH", prompts_path) + monkeypatch.setenv("HISTORY_DB_PATH", history_path) + monkeypatch.setenv("LLM_API_KEY", "test-key") + monkeypatch.setenv("LLM_BASE_URL", "https://test.example.com/v1") + monkeypatch.setenv("LLM_MODEL_NAME", "test-model") + monkeypatch.setenv("DP_API_KEY", "test-dp-key") + monkeypatch.setenv("EMBEDDING_MODEL", "test-embedding") + + from app.core.config import get_settings + 
get_settings.cache_clear() + + from app.core.sqlite_db import _get_db, init_prompts_db, init_history_db, seed_default_profiles + conn = _get_db(prompts_path) + init_prompts_db(conn) + seed_default_profiles(conn) + conn.close() + + hconn = _get_db(history_path) + init_history_db(hconn) + hconn.close() + + test_app = FastAPI() + test_app.include_router(router, prefix="/api/v1") + yield TestClient(test_app) + + get_settings.cache_clear() + + +@pytest.fixture +def mock_pipeline(monkeypatch): + """Mock LLM and RAG to avoid real API calls.""" + + async def _mock_decompose(self, question): + return (["sub question 1", "sub question 2"], "mocked decompose prompt") + + def _mock_retrieve(self, sub_questions, n_results=10): + return [ + (sq, [("chunk text content", {"filename": "test.pdf", "upload_date": "2026-01-01", + "content_summary": "test chunk", "chunk_index": 0, + "page_number": 1, "document_id": "doc-1"}, 0.15)]) + for sq in sub_questions + ] + + async def _mock_filter(self, sub_questions, chunks_by_subq, threshold=7.0): + result = [ + (sq, [(text, {**meta, "relevance_score": 8.5}) for text, meta in chunks]) + for sq, chunks in zip(sub_questions, chunks_by_subq) + ] + return (result, "mocked filter prompt") + + async def _mock_generate(self, sub_questions, sub_chunks, sub_metadata): + answer = "## Sub-question 0: sub question 1\n\n- Test answer with citation [test.pdf, page 1]\n\n## Sub-question 1: sub question 2\n\n- More answer content" + grouped_sources = [[meta_list[0]] for meta_list in sub_metadata] if sub_metadata else [[] for _ in sub_questions] + return (answer, "mocked generate prompt", grouped_sources) + + monkeypatch.setattr("app.services.query_decomposer.QueryDecomposer.decompose", _mock_decompose) + monkeypatch.setattr("app.services.rag.RAGService.retrieve_per_subquestion", _mock_retrieve) + monkeypatch.setattr("app.services.relevance_filter.RelevanceFilter.filter_per_subquestion", _mock_filter) + 
monkeypatch.setattr("app.services.rag.RAGService.generate_response_per_subquestion", _mock_generate) + + +@pytest.mark.usefixtures("mock_pipeline") +class TestGenerateTextEndpoint: + def test_valid_request_returns_200(self, client): + resp = client.post("/api/v1/test/generate/text", json={ + "question": "test question", + "profile": "A", + "label": "my label", + }) + assert resp.status_code == 200 + data = resp.json() + assert data["input_type"] == "text" + assert data["profile"] == "A" + assert data["label"] == "my label" + assert "result_id" in data + + def test_result_contains_all_stages(self, client): + resp = client.post("/api/v1/test/generate/text", json={ + "question": "test", + "profile": "B", + }) + assert resp.status_code == 200 + data = resp.json() + assert len(data["extracted_key_questions"]) == 2 + assert data["retrieval"]["total_chunks_retrieved"] > 0 + assert data["filtered"]["total_chunks_filtered"] > 0 + assert len(data["response"]["final_answer"]) > 0 + assert data["timing"]["total_time_ms"] >= 0 + + def test_invalid_profile_rejected(self, client): + resp = client.post("/api/v1/test/generate/text", json={ + "question": "test", + "profile": "D", + }) + assert resp.status_code == 422 + + def test_empty_question_rejected(self, client): + resp = client.post("/api/v1/test/generate/text", json={ + "question": "", + "profile": "A", + }) + assert resp.status_code == 422 + + def test_result_saved_and_retrievable(self, client): + resp = client.post("/api/v1/test/generate/text", json={ + "question": "save test", + "profile": "A", + }) + assert resp.status_code == 200 + result_id = resp.json()["result_id"] + + get_resp = client.get(f"/api/v1/test/results/{result_id}") + assert get_resp.status_code == 200 + assert get_resp.json()["result_id"] == result_id + + def test_list_results(self, client): + client.post("/api/v1/test/generate/text", json={ + "question": "list test 1", "profile": "A", + }) + client.post("/api/v1/test/generate/text", json={ + "question": 
"list test 2", "profile": "B", + }) + + resp = client.get("/api/v1/test/results?limit=10") + assert resp.status_code == 200 + data = resp.json() + assert len(data) >= 2 + + def test_get_nonexistent_result(self, client): + resp = client.get("/api/v1/test/results/no-such-id") + assert resp.status_code == 404 + + def test_delete_result(self, client): + resp = client.post("/api/v1/test/generate/text", json={ + "question": "delete test", "profile": "A", + }) + result_id = resp.json()["result_id"] + + del_resp = client.delete(f"/api/v1/test/results/{result_id}") + assert del_resp.status_code == 200 + + get_resp = client.get(f"/api/v1/test/results/{result_id}") + assert get_resp.status_code == 404 diff --git a/backend/app/test/test_phase9_results_storage.py b/backend/app/test/test_phase9_results_storage.py new file mode 100644 index 0000000..83f2dcd --- /dev/null +++ b/backend/app/test/test_phase9_results_storage.py @@ -0,0 +1,232 @@ +"""Phase 9 tests: Results storage service CRUD operations (Sub-Phase 9.1). 
+ +Covers: +- save_result writes JSON file and returns its file path +- load_result reads and parses JSON file +- list_results returns list of result metadata +- delete_result removes file +- Nonexistent result loading returns None +- save_evaluation / load_evaluation / list_evaluations / delete_evaluation +- Empty storage dirs don't error +""" +import json +import os +from pathlib import Path + +import pytest + +from app.models.testing import ( + GenerateResult, + InputInfo, + TimingInfo, + RetrievalResult, + FilteredResult, + ResponseResult, + EvaluationResult, + EvaluationTiming, +) + + +@pytest.fixture +def storage_dirs(tmp_path): + results_dir = tmp_path / "test_results" + evals_dir = tmp_path / "test_evaluations" + results_dir.mkdir() + evals_dir.mkdir() + return str(results_dir), str(evals_dir) + + +@pytest.fixture +def sample_result(): + return GenerateResult( + result_id="test-001", + input_type="text", + profile="A", + label="sample test", + input=InputInfo(text="sample question"), + extracted_key_questions=["key q1"], + retrieval=RetrievalResult( + per_sub_question=[], + total_chunks_retrieved=0, + retriever_time_ms=100, + ), + filtered=FilteredResult( + per_sub_question=[], + total_chunks_filtered=0, + filter_time_ms=100, + ), + response=ResponseResult( + final_answer="answer", + sub_question_sources=[], + generate_time_ms=100, + ), + timing=TimingInfo( + decomposer_time_ms=100, + retriever_time_ms=100, + filter_time_ms=100, + generator_time_ms=100, + total_time_ms=400, + ), + ) + + +class TestResultStorage: + def test_save_and_load(self, storage_dirs, sample_result, monkeypatch): + results_dir, _ = storage_dirs + monkeypatch.setenv("TEST_RESULTS_DIR", results_dir) + + from app.services.test_storage_service import TestStorageService + svc = TestStorageService(results_dir, results_dir) + + path = svc.save_result(sample_result) + assert os.path.exists(path) + + loaded = svc.load_result(sample_result.result_id) + assert loaded is not None + assert 
loaded.result_id == sample_result.result_id + assert loaded.input_type == "text" + assert loaded.profile == "A" + + def test_load_nonexistent(self, storage_dirs): + results_dir, _ = storage_dirs + from app.services.test_storage_service import TestStorageService + svc = TestStorageService(results_dir, results_dir) + + result = svc.load_result("nonexistent-id") + assert result is None + + def test_list_results(self, storage_dirs, sample_result, monkeypatch): + results_dir, _ = storage_dirs + monkeypatch.setenv("TEST_RESULTS_DIR", results_dir) + + from app.services.test_storage_service import TestStorageService + svc = TestStorageService(results_dir, results_dir) + + svc.save_result(sample_result) + + items = svc.list_results() + assert len(items) >= 1 + assert any(r["result_id"] == "test-001" for r in items) + + def test_list_results_with_limit_offset(self, storage_dirs, sample_result): + results_dir, _ = storage_dirs + from app.services.test_storage_service import TestStorageService + svc = TestStorageService(results_dir, results_dir) + + for i in range(5): + r = sample_result.model_copy(update={"result_id": f"test-{i:03d}"}) + svc.save_result(r) + + items = svc.list_results(limit=2, offset=1) + assert len(items) == 2 + + def test_delete_result(self, storage_dirs, sample_result): + results_dir, _ = storage_dirs + from app.services.test_storage_service import TestStorageService + svc = TestStorageService(results_dir, results_dir) + + svc.save_result(sample_result) + filepath = os.path.join(results_dir, f"{sample_result.result_id}.json") + assert os.path.exists(filepath) + + result = svc.delete_result(sample_result.result_id) + assert result is True + assert not os.path.exists(filepath) + + def test_delete_nonexistent(self, storage_dirs): + results_dir, _ = storage_dirs + from app.services.test_storage_service import TestStorageService + svc = TestStorageService(results_dir, results_dir) + + result = svc.delete_result("no-such-id") + assert result is False + + def 
test_creates_dir_if_missing(self, storage_dirs): + results_dir, evals_dir = storage_dirs + new_results = os.path.join(results_dir, "auto_created") + from app.services.test_storage_service import TestStorageService + svc = TestStorageService(new_results, evals_dir) + assert os.path.isdir(new_results) + + def test_list_empty_dir(self, storage_dirs): + results_dir, _ = storage_dirs + from app.services.test_storage_service import TestStorageService + svc = TestStorageService(results_dir, results_dir) + + items = svc.list_results() + assert items == [] + + +class TestEvaluationStorage: + def test_save_and_load_eval(self, storage_dirs): + results_dir, evals_dir = storage_dirs + from app.services.test_storage_service import TestStorageService + svc = TestStorageService(results_dir, evals_dir) + + eval_result = EvaluationResult( + evaluation_id="eval-001", + result_id="result-001", + status="completed", + timing=EvaluationTiming( + audio_evaluation_time_ms=10, + key_questions_evaluation_time_ms=100, + chunk_evaluation_time_ms=200, + response_evaluation_time_ms=300, + total_evaluation_time_ms=610, + ), + ) + + path = svc.save_evaluation(eval_result) + assert os.path.exists(path) + + loaded = svc.load_evaluation("eval-001") + assert loaded is not None + assert loaded.evaluation_id == "eval-001" + assert loaded.status == "completed" + + def test_list_evaluations(self, storage_dirs): + results_dir, evals_dir = storage_dirs + from app.services.test_storage_service import TestStorageService + svc = TestStorageService(results_dir, evals_dir) + + eval_result = EvaluationResult( + evaluation_id="eval-001", + result_id="result-001", + status="completed", + timing=EvaluationTiming( + audio_evaluation_time_ms=10, + key_questions_evaluation_time_ms=100, + chunk_evaluation_time_ms=200, + response_evaluation_time_ms=300, + total_evaluation_time_ms=610, + ), + ) + svc.save_evaluation(eval_result) + + items = svc.list_evaluations() + assert len(items) >= 1 + + def 
test_delete_evaluation(self, storage_dirs): + results_dir, evals_dir = storage_dirs + from app.services.test_storage_service import TestStorageService + svc = TestStorageService(results_dir, evals_dir) + + eval_result = EvaluationResult( + evaluation_id="eval-002", + result_id="result-002", + status="completed", + timing=EvaluationTiming( + audio_evaluation_time_ms=10, + key_questions_evaluation_time_ms=100, + chunk_evaluation_time_ms=200, + response_evaluation_time_ms=300, + total_evaluation_time_ms=610, + ), + ) + svc.save_evaluation(eval_result) + filepath = os.path.join(evals_dir, "eval-002.json") + assert os.path.exists(filepath) + + result = svc.delete_evaluation("eval-002") + assert result is True + assert not os.path.exists(filepath)