feat: add Sub-Phase 9.1 results generation APIs with reusable RAGPipeline
This commit is contained in:
parent
852430f1f1
commit
ac81df0704
|
|
@ -7,7 +7,7 @@ from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
|
|
||||||
from app.routers import ingest, query, documents, prompts, history, chunks, video, ws_asr
|
from app.routers import ingest, query, documents, prompts, history, chunks, video, ws_asr, test_generate
|
||||||
from app.core.config import get_settings
|
from app.core.config import get_settings
|
||||||
from app.core.sqlite_db import (
|
from app.core.sqlite_db import (
|
||||||
get_prompts_db,
|
get_prompts_db,
|
||||||
|
|
@ -58,6 +58,7 @@ app.include_router(history.router)
|
||||||
app.include_router(chunks.router)
|
app.include_router(chunks.router)
|
||||||
app.include_router(video.router, prefix="/api/v1")
|
app.include_router(video.router, prefix="/api/v1")
|
||||||
app.include_router(ws_asr.router)
|
app.include_router(ws_asr.router)
|
||||||
|
app.include_router(test_generate.router, prefix="/api/v1")
|
||||||
|
|
||||||
_prompts_conn = get_prompts_db()
|
_prompts_conn = get_prompts_db()
|
||||||
init_prompts_db(_prompts_conn)
|
init_prompts_db(_prompts_conn)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,104 @@
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from fastapi import APIRouter, File, Form, HTTPException, Query, UploadFile
|
||||||
|
|
||||||
|
from app.core.config import get_settings
|
||||||
|
from app.models.testing import GenerateTextRequest
|
||||||
|
from app.services.prompt_service import PromptService
|
||||||
|
from app.services.test_runner_service import TestRunnerService
|
||||||
|
from app.services.test_storage_service import TestStorageService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
router = APIRouter(tags=["test"])
|
||||||
|
|
||||||
|
|
||||||
|
def _get_prompt_service() -> PromptService:
|
||||||
|
settings = get_settings()
|
||||||
|
return PromptService(db_path=settings.prompts_db_path)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_storage_service() -> TestStorageService:
|
||||||
|
settings = get_settings()
|
||||||
|
return TestStorageService(
|
||||||
|
results_dir=settings.test_results_dir,
|
||||||
|
evaluations_dir=settings.test_evaluations_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/test/generate/text")
|
||||||
|
async def generate_text(request: GenerateTextRequest):
|
||||||
|
settings = get_settings()
|
||||||
|
prompt_service = _get_prompt_service()
|
||||||
|
|
||||||
|
runner = TestRunnerService(settings)
|
||||||
|
result = await runner.run_text_test(
|
||||||
|
question=request.question,
|
||||||
|
profile=request.profile,
|
||||||
|
prompt_service=prompt_service,
|
||||||
|
label=request.label,
|
||||||
|
)
|
||||||
|
|
||||||
|
storage = _get_storage_service()
|
||||||
|
storage.save_result(result)
|
||||||
|
|
||||||
|
return result.model_dump()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/test/generate/audio")
|
||||||
|
async def generate_audio(
|
||||||
|
audio_file: UploadFile = File(...),
|
||||||
|
profile: str = Form(...),
|
||||||
|
reference_transcript: str = Form(""),
|
||||||
|
label: str = Form(""),
|
||||||
|
language: str = Form("yue"),
|
||||||
|
):
|
||||||
|
if profile not in ("A", "B", "C"):
|
||||||
|
raise HTTPException(status_code=400, detail="profile must be A, B, or C")
|
||||||
|
|
||||||
|
settings = get_settings()
|
||||||
|
prompt_service = _get_prompt_service()
|
||||||
|
|
||||||
|
audio_bytes = await audio_file.read()
|
||||||
|
if not audio_bytes:
|
||||||
|
raise HTTPException(status_code=400, detail="Audio file is empty")
|
||||||
|
|
||||||
|
runner = TestRunnerService(settings)
|
||||||
|
result = await runner.run_audio_test(
|
||||||
|
audio_bytes=audio_bytes,
|
||||||
|
reference_transcript=reference_transcript,
|
||||||
|
profile=profile,
|
||||||
|
prompt_service=prompt_service,
|
||||||
|
language=language,
|
||||||
|
label=label,
|
||||||
|
audio_filename=audio_file.filename or "unknown",
|
||||||
|
)
|
||||||
|
|
||||||
|
storage = _get_storage_service()
|
||||||
|
storage.save_result(result)
|
||||||
|
|
||||||
|
return result.model_dump()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/test/results")
|
||||||
|
async def list_results(limit: int = Query(50, ge=1, le=200), offset: int = Query(0, ge=0)):
|
||||||
|
storage = _get_storage_service()
|
||||||
|
return storage.list_results(limit=limit, offset=offset)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/test/results/{result_id}")
|
||||||
|
async def get_result(result_id: str):
|
||||||
|
storage = _get_storage_service()
|
||||||
|
result = storage.load_result(result_id)
|
||||||
|
if result is None:
|
||||||
|
raise HTTPException(status_code=404, detail="Result not found")
|
||||||
|
return result.model_dump()
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/test/results/{result_id}")
|
||||||
|
async def delete_result(result_id: str):
|
||||||
|
storage = _get_storage_service()
|
||||||
|
deleted = storage.delete_result(result_id)
|
||||||
|
if not deleted:
|
||||||
|
raise HTTPException(status_code=404, detail="Result not found")
|
||||||
|
return {"status": "deleted", "result_id": result_id}
|
||||||
|
|
@ -0,0 +1,289 @@
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
from app.models.common import SourceMetadata
|
||||||
|
from app.models.query import SubQuestionSources
|
||||||
|
from app.services.query_decomposer import QueryDecomposer
|
||||||
|
from app.services.relevance_filter import RelevanceFilter
|
||||||
|
from app.services.rag import RAGService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
NO_RESULTS_ANSWER = "I could not find any relevant information to answer your question."
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PipelineSnapshot:
|
||||||
|
"""Complete snapshot at a given pipeline stage for accuracy testing capture."""
|
||||||
|
|
||||||
|
phase: str
|
||||||
|
question: str = ""
|
||||||
|
|
||||||
|
# Stage 1: Decompose
|
||||||
|
extracted_questions: List[str] = field(default_factory=list)
|
||||||
|
decompose_prompt: str = ""
|
||||||
|
decomposer_time_ms: int = 0
|
||||||
|
|
||||||
|
# Stage 2: Retrieve
|
||||||
|
retrieval_results: List[Tuple[str, List[Tuple[str, Dict[str, Any], float]]]] = field(
|
||||||
|
default_factory=list
|
||||||
|
)
|
||||||
|
chunks_retrieved_count: int = 0
|
||||||
|
retriever_time_ms: int = 0
|
||||||
|
|
||||||
|
# Stage 3: Filter
|
||||||
|
filter_prompt: str = ""
|
||||||
|
filtered_by_subq: List[Tuple[str, List[Tuple[str, Dict[str, Any]]]]] = field(
|
||||||
|
default_factory=list
|
||||||
|
)
|
||||||
|
chunks_filtered_count: int = 0
|
||||||
|
filter_time_ms: int = 0
|
||||||
|
|
||||||
|
# Stage 4: Generate
|
||||||
|
generate_prompt: str = ""
|
||||||
|
answer: str = ""
|
||||||
|
sub_question_sources: List[SubQuestionSources] = field(default_factory=list)
|
||||||
|
generator_time_ms: int = 0
|
||||||
|
|
||||||
|
# Metadata
|
||||||
|
total_time_ms: int = 0
|
||||||
|
error_message: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class RAGPipeline:
|
||||||
|
"""Reusable RAG pipeline: decompose → retrieve → filter → generate.
|
||||||
|
|
||||||
|
Yields PipelineSnapshot at each stage boundary. No SSE or HTTP coupling.
|
||||||
|
Use in streaming endpoints (wrap snapshots as SSE) or capture endpoints
|
||||||
|
(collect snapshots into GenerateResult).
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
pipeline = RAGPipeline(decomposer=..., rag=..., relevance_filter=..., settings=...)
|
||||||
|
async for snap in pipeline.execute("question text"):
|
||||||
|
if snap.phase == "completed":
|
||||||
|
print(snap.answer)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
decomposer: QueryDecomposer,
|
||||||
|
rag: RAGService,
|
||||||
|
relevance_filter: RelevanceFilter,
|
||||||
|
retrieval_n_results: int = 10,
|
||||||
|
relevance_threshold: float = 7.0,
|
||||||
|
):
|
||||||
|
self._decomposer = decomposer
|
||||||
|
self._rag = rag
|
||||||
|
self._relevance_filter = relevance_filter
|
||||||
|
self._retrieval_n_results = retrieval_n_results
|
||||||
|
self._relevance_threshold = relevance_threshold
|
||||||
|
|
||||||
|
async def execute(
|
||||||
|
self,
|
||||||
|
question: str,
|
||||||
|
stop_after_decompose: bool = False,
|
||||||
|
) -> AsyncGenerator[PipelineSnapshot, None]:
|
||||||
|
"""Execute the full pipeline, yielding one snapshot per stage."""
|
||||||
|
overall_start = time.perf_counter()
|
||||||
|
|
||||||
|
# --- Stage 1: Decompose ---
|
||||||
|
stage_start = time.perf_counter()
|
||||||
|
decompose_result = await self._decomposer.decompose(question)
|
||||||
|
if isinstance(decompose_result, tuple):
|
||||||
|
extracted_questions, decompose_prompt = decompose_result
|
||||||
|
else:
|
||||||
|
extracted_questions, decompose_prompt = decompose_result, ""
|
||||||
|
|
||||||
|
decomposer_time_ms = int((time.perf_counter() - stage_start) * 1000)
|
||||||
|
|
||||||
|
if not extracted_questions:
|
||||||
|
extracted_questions = [question]
|
||||||
|
|
||||||
|
yield PipelineSnapshot(
|
||||||
|
phase="decomposed",
|
||||||
|
question=question,
|
||||||
|
extracted_questions=extracted_questions,
|
||||||
|
decompose_prompt=decompose_prompt,
|
||||||
|
decomposer_time_ms=decomposer_time_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
if stop_after_decompose:
|
||||||
|
total_ms = int((time.perf_counter() - overall_start) * 1000)
|
||||||
|
yield PipelineSnapshot(
|
||||||
|
phase="completed",
|
||||||
|
question=question,
|
||||||
|
extracted_questions=extracted_questions,
|
||||||
|
decompose_prompt=decompose_prompt,
|
||||||
|
decomposer_time_ms=decomposer_time_ms,
|
||||||
|
total_time_ms=total_ms,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# --- Stage 2: Retrieve ---
|
||||||
|
stage_start = time.perf_counter()
|
||||||
|
retrieval_results = (
|
||||||
|
self._rag.retrieve_per_subquestion(
|
||||||
|
extracted_questions, n_results=self._retrieval_n_results
|
||||||
|
)
|
||||||
|
if extracted_questions
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
retriever_time_ms = int((time.perf_counter() - stage_start) * 1000)
|
||||||
|
|
||||||
|
chunks_retrieved_count = sum(
|
||||||
|
len(chunks) for _, chunks in retrieval_results
|
||||||
|
)
|
||||||
|
|
||||||
|
yield PipelineSnapshot(
|
||||||
|
phase="retrieving",
|
||||||
|
question=question,
|
||||||
|
extracted_questions=extracted_questions,
|
||||||
|
decompose_prompt=decompose_prompt,
|
||||||
|
decomposer_time_ms=decomposer_time_ms,
|
||||||
|
retrieval_results=retrieval_results,
|
||||||
|
chunks_retrieved_count=chunks_retrieved_count,
|
||||||
|
retriever_time_ms=retriever_time_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not any(chunks for _, chunks in retrieval_results):
|
||||||
|
total_ms = int((time.perf_counter() - overall_start) * 1000)
|
||||||
|
yield PipelineSnapshot(
|
||||||
|
phase="completed",
|
||||||
|
question=question,
|
||||||
|
extracted_questions=extracted_questions,
|
||||||
|
decompose_prompt=decompose_prompt,
|
||||||
|
decomposer_time_ms=decomposer_time_ms,
|
||||||
|
chunks_retrieved_count=0,
|
||||||
|
retriever_time_ms=retriever_time_ms,
|
||||||
|
answer=NO_RESULTS_ANSWER,
|
||||||
|
total_time_ms=total_ms,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# --- Stage 3: Filter ---
|
||||||
|
stage_start = time.perf_counter()
|
||||||
|
chunks_by_subq = [
|
||||||
|
[(text, meta) for text, meta, _dist in chunks]
|
||||||
|
for _, chunks in retrieval_results
|
||||||
|
]
|
||||||
|
|
||||||
|
if extracted_questions and chunks_by_subq:
|
||||||
|
filter_result = await self._relevance_filter.filter_per_subquestion(
|
||||||
|
extracted_questions, chunks_by_subq, threshold=self._relevance_threshold
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
filter_result = ([], "")
|
||||||
|
|
||||||
|
if isinstance(filter_result, tuple):
|
||||||
|
filtered_by_subq, filter_prompt = filter_result
|
||||||
|
else:
|
||||||
|
filtered_by_subq, filter_prompt = filter_result, ""
|
||||||
|
|
||||||
|
filter_time_ms = int((time.perf_counter() - stage_start) * 1000)
|
||||||
|
chunks_filtered_count = sum(
|
||||||
|
len(chunks) for _, chunks in filtered_by_subq
|
||||||
|
)
|
||||||
|
|
||||||
|
yield PipelineSnapshot(
|
||||||
|
phase="filtering",
|
||||||
|
question=question,
|
||||||
|
extracted_questions=extracted_questions,
|
||||||
|
decompose_prompt=decompose_prompt,
|
||||||
|
decomposer_time_ms=decomposer_time_ms,
|
||||||
|
retrieval_results=retrieval_results,
|
||||||
|
chunks_retrieved_count=chunks_retrieved_count,
|
||||||
|
retriever_time_ms=retriever_time_ms,
|
||||||
|
filter_prompt=filter_prompt,
|
||||||
|
filtered_by_subq=filtered_by_subq,
|
||||||
|
chunks_filtered_count=chunks_filtered_count,
|
||||||
|
filter_time_ms=filter_time_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not filtered_by_subq or not any(
|
||||||
|
chunks for _, chunks in filtered_by_subq
|
||||||
|
):
|
||||||
|
total_ms = int((time.perf_counter() - overall_start) * 1000)
|
||||||
|
yield PipelineSnapshot(
|
||||||
|
phase="completed",
|
||||||
|
question=question,
|
||||||
|
extracted_questions=extracted_questions,
|
||||||
|
decomposer_time_ms=decomposer_time_ms,
|
||||||
|
retriever_time_ms=retriever_time_ms,
|
||||||
|
filter_time_ms=filter_time_ms,
|
||||||
|
answer=NO_RESULTS_ANSWER,
|
||||||
|
total_time_ms=total_ms,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# --- Stage 4: Generate ---
|
||||||
|
stage_start = time.perf_counter()
|
||||||
|
sub_chunk_texts = []
|
||||||
|
sub_chunk_metadata = []
|
||||||
|
for _, filtered_chunks in filtered_by_subq:
|
||||||
|
sub_chunk_texts.append([chunk for chunk, _meta in filtered_chunks])
|
||||||
|
sub_chunk_metadata.append([meta for _chunk, meta in filtered_chunks])
|
||||||
|
|
||||||
|
if extracted_questions and filtered_by_subq:
|
||||||
|
gen_result = await self._rag.generate_response_per_subquestion(
|
||||||
|
extracted_questions, sub_chunk_texts, sub_chunk_metadata
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
gen_result = ("", "", [])
|
||||||
|
|
||||||
|
if isinstance(gen_result, tuple) and len(gen_result) == 3:
|
||||||
|
answer, generate_prompt, grouped_sources_meta = gen_result
|
||||||
|
else:
|
||||||
|
answer, generate_prompt = (
|
||||||
|
gen_result if isinstance(gen_result, tuple) else (gen_result, "")
|
||||||
|
)
|
||||||
|
grouped_sources_meta = []
|
||||||
|
|
||||||
|
sub_question_sources = []
|
||||||
|
for idx, (sub_q_text, sources_meta) in enumerate(
|
||||||
|
zip(extracted_questions, grouped_sources_meta)
|
||||||
|
):
|
||||||
|
sources = [
|
||||||
|
SourceMetadata(
|
||||||
|
filename=meta.get("filename", "unknown"),
|
||||||
|
upload_date=meta.get("upload_date", ""),
|
||||||
|
content_summary=meta.get("content_summary", ""),
|
||||||
|
chunk_index=meta.get("chunk_index", 0),
|
||||||
|
page_number=meta.get("page_number"),
|
||||||
|
chunk_file_path=meta.get("chunk_file_path"),
|
||||||
|
document_id=meta.get("document_id"),
|
||||||
|
)
|
||||||
|
for meta in sources_meta
|
||||||
|
]
|
||||||
|
sub_question_sources.append(
|
||||||
|
SubQuestionSources(
|
||||||
|
sub_question_index=idx,
|
||||||
|
sub_question_text=sub_q_text,
|
||||||
|
sources=sources,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
generator_time_ms = int((time.perf_counter() - stage_start) * 1000)
|
||||||
|
total_time_ms = int((time.perf_counter() - overall_start) * 1000)
|
||||||
|
|
||||||
|
yield PipelineSnapshot(
|
||||||
|
phase="completed",
|
||||||
|
question=question,
|
||||||
|
extracted_questions=extracted_questions,
|
||||||
|
decompose_prompt=decompose_prompt,
|
||||||
|
decomposer_time_ms=decomposer_time_ms,
|
||||||
|
retrieval_results=retrieval_results,
|
||||||
|
chunks_retrieved_count=chunks_retrieved_count,
|
||||||
|
retriever_time_ms=retriever_time_ms,
|
||||||
|
filter_prompt=filter_prompt,
|
||||||
|
filtered_by_subq=filtered_by_subq,
|
||||||
|
chunks_filtered_count=chunks_filtered_count,
|
||||||
|
filter_time_ms=filter_time_ms,
|
||||||
|
generate_prompt=generate_prompt,
|
||||||
|
answer=answer,
|
||||||
|
sub_question_sources=sub_question_sources,
|
||||||
|
generator_time_ms=generator_time_ms,
|
||||||
|
total_time_ms=total_time_ms,
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,220 @@
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from app.core.config import Settings
|
||||||
|
from app.models.testing import (
|
||||||
|
ChunkEntry,
|
||||||
|
FilteredResult,
|
||||||
|
GenerateResult,
|
||||||
|
InputInfo,
|
||||||
|
ResponseResult,
|
||||||
|
RetrievalResult,
|
||||||
|
SubQuestionChunks,
|
||||||
|
TimingInfo,
|
||||||
|
)
|
||||||
|
from app.services.asr_client import ASRClient
|
||||||
|
from app.services.llm_client import LLMClient
|
||||||
|
from app.services.llm_client_dp import LLMClientDP
|
||||||
|
from app.services.prompt_service import PromptService
|
||||||
|
from app.services.query_decomposer import QueryDecomposer
|
||||||
|
from app.services.rag_pipeline import RAGPipeline, PipelineSnapshot
|
||||||
|
from app.services.rag import RAGService
|
||||||
|
from app.services.relevance_filter import RelevanceFilter
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRunnerService:
|
||||||
|
"""Runs the full RAG pipeline and captures all intermediate data for accuracy testing."""
|
||||||
|
|
||||||
|
def __init__(self, settings: Settings):
|
||||||
|
self.settings = settings
|
||||||
|
|
||||||
|
async def run_text_test(
|
||||||
|
self,
|
||||||
|
question: str,
|
||||||
|
profile: str,
|
||||||
|
prompt_service: PromptService,
|
||||||
|
label: str = "",
|
||||||
|
) -> GenerateResult:
|
||||||
|
result_id = uuid.uuid4().hex[:12]
|
||||||
|
|
||||||
|
prompt_service.activate_profile(profile)
|
||||||
|
active_profile = prompt_service.get_active_profile_name()
|
||||||
|
|
||||||
|
llm_client_dp = LLMClientDP(self.settings)
|
||||||
|
llm_client = LLMClient(self.settings)
|
||||||
|
rag = RAGService(
|
||||||
|
llm_client=llm_client,
|
||||||
|
settings=self.settings,
|
||||||
|
prompt_service=prompt_service,
|
||||||
|
)
|
||||||
|
decomposer = QueryDecomposer(llm_client_dp, prompt_service=prompt_service)
|
||||||
|
relevance_filter = RelevanceFilter(
|
||||||
|
llm_client, prompt_service=prompt_service
|
||||||
|
)
|
||||||
|
|
||||||
|
pipeline = RAGPipeline(
|
||||||
|
decomposer=decomposer,
|
||||||
|
rag=rag,
|
||||||
|
relevance_filter=relevance_filter,
|
||||||
|
retrieval_n_results=self.settings.retrieval_n_results,
|
||||||
|
relevance_threshold=self.settings.relevance_threshold,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Collect all snapshots — use the last "completed" one for the final result
|
||||||
|
decomposed_snap = None
|
||||||
|
retrieval_snap = None
|
||||||
|
filtering_snap = None
|
||||||
|
completed_snap = None
|
||||||
|
|
||||||
|
async for snap in pipeline.execute(question):
|
||||||
|
if snap.phase == "decomposed":
|
||||||
|
decomposed_snap = snap
|
||||||
|
elif snap.phase == "retrieving":
|
||||||
|
retrieval_snap = snap
|
||||||
|
elif snap.phase == "filtering":
|
||||||
|
filtering_snap = snap
|
||||||
|
elif snap.phase == "completed":
|
||||||
|
completed_snap = snap
|
||||||
|
|
||||||
|
if completed_snap is None:
|
||||||
|
raise RuntimeError("Pipeline did not produce a completed snapshot")
|
||||||
|
|
||||||
|
# Build retrieval result
|
||||||
|
retrieval_per_subq = []
|
||||||
|
if retrieval_snap and retrieval_snap.retrieval_results:
|
||||||
|
for sq_idx, (sub_q_text, chunks) in enumerate(
|
||||||
|
retrieval_snap.retrieval_results
|
||||||
|
):
|
||||||
|
retrieval_per_subq.append(
|
||||||
|
SubQuestionChunks(
|
||||||
|
sub_question_index=sq_idx,
|
||||||
|
sub_question_text=sub_q_text,
|
||||||
|
chunks=[
|
||||||
|
ChunkEntry(
|
||||||
|
chunk_index=i,
|
||||||
|
text=text,
|
||||||
|
metadata=meta,
|
||||||
|
distance=distance,
|
||||||
|
)
|
||||||
|
for i, (text, meta, distance) in enumerate(chunks)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build filtered result
|
||||||
|
filtered_per_subq = []
|
||||||
|
if filtering_snap and filtering_snap.filtered_by_subq:
|
||||||
|
for sq_idx, (sub_q_text, chunks) in enumerate(
|
||||||
|
filtering_snap.filtered_by_subq
|
||||||
|
):
|
||||||
|
filtered_per_subq.append(
|
||||||
|
SubQuestionChunks(
|
||||||
|
sub_question_index=sq_idx,
|
||||||
|
sub_question_text=sub_q_text,
|
||||||
|
chunks=[
|
||||||
|
ChunkEntry(
|
||||||
|
chunk_index=i,
|
||||||
|
text=text,
|
||||||
|
metadata=meta,
|
||||||
|
distance=0.0,
|
||||||
|
)
|
||||||
|
for i, (text, meta) in enumerate(chunks)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build response result — serialize through dicts to convert between
|
||||||
|
# query.py SubQuestionSources and testing.py SubQuestionSources (same fields)
|
||||||
|
response_result = ResponseResult(
|
||||||
|
final_answer=completed_snap.answer,
|
||||||
|
sub_question_sources=[
|
||||||
|
{
|
||||||
|
"sub_question_index": sq.sub_question_index,
|
||||||
|
"sub_question_text": sq.sub_question_text,
|
||||||
|
"sources": [s.model_dump() for s in sq.sources],
|
||||||
|
}
|
||||||
|
for sq in completed_snap.sub_question_sources
|
||||||
|
],
|
||||||
|
generate_time_ms=completed_snap.generator_time_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
return GenerateResult(
|
||||||
|
result_id=result_id,
|
||||||
|
input_type="text",
|
||||||
|
profile=active_profile,
|
||||||
|
label=label,
|
||||||
|
input=InputInfo(text=question),
|
||||||
|
extracted_key_questions=completed_snap.extracted_questions,
|
||||||
|
retrieval=RetrievalResult(
|
||||||
|
per_sub_question=retrieval_per_subq,
|
||||||
|
total_chunks_retrieved=(
|
||||||
|
retrieval_snap.chunks_retrieved_count if retrieval_snap else 0
|
||||||
|
),
|
||||||
|
retriever_time_ms=(
|
||||||
|
retrieval_snap.retriever_time_ms if retrieval_snap else 0
|
||||||
|
),
|
||||||
|
),
|
||||||
|
filtered=FilteredResult(
|
||||||
|
per_sub_question=filtered_per_subq,
|
||||||
|
total_chunks_filtered=(
|
||||||
|
filtering_snap.chunks_filtered_count if filtering_snap else 0
|
||||||
|
),
|
||||||
|
filter_time_ms=(
|
||||||
|
filtering_snap.filter_time_ms if filtering_snap else 0
|
||||||
|
),
|
||||||
|
),
|
||||||
|
response=response_result,
|
||||||
|
timing=TimingInfo(
|
||||||
|
decomposer_time_ms=completed_snap.decomposer_time_ms,
|
||||||
|
retriever_time_ms=(
|
||||||
|
retrieval_snap.retriever_time_ms if retrieval_snap else 0
|
||||||
|
),
|
||||||
|
filter_time_ms=(
|
||||||
|
filtering_snap.filter_time_ms if filtering_snap else 0
|
||||||
|
),
|
||||||
|
generator_time_ms=completed_snap.generator_time_ms,
|
||||||
|
total_time_ms=completed_snap.total_time_ms,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def run_audio_test(
|
||||||
|
self,
|
||||||
|
audio_bytes: bytes,
|
||||||
|
reference_transcript: str,
|
||||||
|
profile: str,
|
||||||
|
prompt_service: PromptService,
|
||||||
|
language: str = "yue",
|
||||||
|
label: str = "",
|
||||||
|
audio_filename: str = "",
|
||||||
|
) -> GenerateResult:
|
||||||
|
result_id = uuid.uuid4().hex[:12]
|
||||||
|
|
||||||
|
# Run ASR
|
||||||
|
asr_client = ASRClient(self.settings)
|
||||||
|
transcribed_text = await asr_client.transcribe_full(
|
||||||
|
audio_bytes, language=language
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run text test on transcribed text
|
||||||
|
result = await self.run_text_test(
|
||||||
|
question=transcribed_text,
|
||||||
|
profile=profile,
|
||||||
|
prompt_service=prompt_service,
|
||||||
|
label=label,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Override with audio-specific fields
|
||||||
|
result.result_id = result_id
|
||||||
|
result.input_type = "audio"
|
||||||
|
result.input = InputInfo(
|
||||||
|
text=transcribed_text,
|
||||||
|
reference_transcript=reference_transcript,
|
||||||
|
audio_filename=audio_filename,
|
||||||
|
audio_duration_seconds=0.0,
|
||||||
|
asr_language=language,
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
@ -0,0 +1,101 @@
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from app.models.testing import GenerateResult, EvaluationResult
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TestStorageService:
|
||||||
|
def __init__(self, results_dir: str, evaluations_dir: str):
|
||||||
|
self.results_dir = results_dir
|
||||||
|
self.evaluations_dir = evaluations_dir
|
||||||
|
Path(results_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
Path(evaluations_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# --- Results ---
|
||||||
|
|
||||||
|
def save_result(self, result: GenerateResult) -> str:
|
||||||
|
filepath = os.path.join(self.results_dir, f"{result.result_id}.json")
|
||||||
|
with open(filepath, "w", encoding="utf-8") as f:
|
||||||
|
f.write(result.model_dump_json(indent=2))
|
||||||
|
logger.info("Saved test result: %s", filepath)
|
||||||
|
return filepath
|
||||||
|
|
||||||
|
def load_result(self, result_id: str) -> Optional[GenerateResult]:
|
||||||
|
filepath = os.path.join(self.results_dir, f"{result_id}.json")
|
||||||
|
if not os.path.isfile(filepath):
|
||||||
|
return None
|
||||||
|
with open(filepath, "r", encoding="utf-8") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
return GenerateResult.model_validate(data)
|
||||||
|
|
||||||
|
def list_results(self, limit: int = 50, offset: int = 0) -> List[dict]:
|
||||||
|
items = []
|
||||||
|
try:
|
||||||
|
for entry in sorted(
|
||||||
|
Path(self.results_dir).iterdir(),
|
||||||
|
key=lambda p: p.stat().st_mtime,
|
||||||
|
reverse=True,
|
||||||
|
):
|
||||||
|
if entry.suffix == ".json":
|
||||||
|
stat = entry.stat()
|
||||||
|
items.append({
|
||||||
|
"result_id": entry.stem,
|
||||||
|
"file_size_bytes": stat.st_size,
|
||||||
|
})
|
||||||
|
except FileNotFoundError:
|
||||||
|
return []
|
||||||
|
return items[offset : offset + limit]
|
||||||
|
|
||||||
|
def delete_result(self, result_id: str) -> bool:
|
||||||
|
filepath = os.path.join(self.results_dir, f"{result_id}.json")
|
||||||
|
if not os.path.isfile(filepath):
|
||||||
|
return False
|
||||||
|
os.remove(filepath)
|
||||||
|
return True
|
||||||
|
|
||||||
|
# --- Evaluations ---
|
||||||
|
|
||||||
|
def save_evaluation(self, evaluation: EvaluationResult) -> str:
|
||||||
|
filepath = os.path.join(self.evaluations_dir, f"{evaluation.evaluation_id}.json")
|
||||||
|
with open(filepath, "w", encoding="utf-8") as f:
|
||||||
|
f.write(evaluation.model_dump_json(indent=2))
|
||||||
|
logger.info("Saved evaluation: %s", filepath)
|
||||||
|
return filepath
|
||||||
|
|
||||||
|
def load_evaluation(self, eval_id: str) -> Optional[EvaluationResult]:
|
||||||
|
filepath = os.path.join(self.evaluations_dir, f"{eval_id}.json")
|
||||||
|
if not os.path.isfile(filepath):
|
||||||
|
return None
|
||||||
|
with open(filepath, "r", encoding="utf-8") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
return EvaluationResult.model_validate(data)
|
||||||
|
|
||||||
|
def list_evaluations(self, limit: int = 50, offset: int = 0) -> List[dict]:
|
||||||
|
items = []
|
||||||
|
try:
|
||||||
|
for entry in sorted(
|
||||||
|
Path(self.evaluations_dir).iterdir(),
|
||||||
|
key=lambda p: p.stat().st_mtime,
|
||||||
|
reverse=True,
|
||||||
|
):
|
||||||
|
if entry.suffix == ".json":
|
||||||
|
stat = entry.stat()
|
||||||
|
items.append({
|
||||||
|
"evaluation_id": entry.stem,
|
||||||
|
"file_size_bytes": stat.st_size,
|
||||||
|
})
|
||||||
|
except FileNotFoundError:
|
||||||
|
return []
|
||||||
|
return items[offset : offset + limit]
|
||||||
|
|
||||||
|
def delete_evaluation(self, eval_id: str) -> bool:
|
||||||
|
filepath = os.path.join(self.evaluations_dir, f"{eval_id}.json")
|
||||||
|
if not os.path.isfile(filepath):
|
||||||
|
return False
|
||||||
|
os.remove(filepath)
|
||||||
|
return True
|
||||||
|
|
@ -0,0 +1,173 @@
|
||||||
|
"""Phase 9 tests: Text generation endpoint integration (Sub-Phase 9.1).
|
||||||
|
|
||||||
|
Covers:
|
||||||
|
- POST /api/v1/test/generate/text with valid request returns 200
|
||||||
|
- Response includes all pipeline stages (retrieval, filtered, response, timing)
|
||||||
|
- Invalid profile returns validation error
|
||||||
|
- Empty question returns validation error
|
||||||
|
- GET /api/v1/test/results lists saved results
|
||||||
|
- GET /api/v1/test/results/{id} retrieves specific result
|
||||||
|
- DELETE /api/v1/test/results/{id} deletes result
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from app.routers.test_generate import router
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client(tmp_path, monkeypatch):
|
||||||
|
results_dir = str(tmp_path / "test_results")
|
||||||
|
evals_dir = str(tmp_path / "test_evaluations")
|
||||||
|
prompts_path = str(tmp_path / "prompts.db")
|
||||||
|
history_path = str(tmp_path / "history.db")
|
||||||
|
|
||||||
|
monkeypatch.setenv("TEST_RESULTS_DIR", results_dir)
|
||||||
|
monkeypatch.setenv("TEST_EVALUATIONS_DIR", evals_dir)
|
||||||
|
monkeypatch.setenv("PROMPTS_DB_PATH", prompts_path)
|
||||||
|
monkeypatch.setenv("HISTORY_DB_PATH", history_path)
|
||||||
|
monkeypatch.setenv("LLM_API_KEY", "test-key")
|
||||||
|
monkeypatch.setenv("LLM_BASE_URL", "https://test.example.com/v1")
|
||||||
|
monkeypatch.setenv("LLM_MODEL_NAME", "test-model")
|
||||||
|
monkeypatch.setenv("DP_API_KEY", "test-dp-key")
|
||||||
|
monkeypatch.setenv("EMBEDDING_MODEL", "test-embedding")
|
||||||
|
|
||||||
|
from app.core.config import get_settings
|
||||||
|
get_settings.cache_clear()
|
||||||
|
|
||||||
|
from app.core.sqlite_db import _get_db, init_prompts_db, init_history_db, seed_default_profiles
|
||||||
|
conn = _get_db(prompts_path)
|
||||||
|
init_prompts_db(conn)
|
||||||
|
seed_default_profiles(conn)
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
hconn = _get_db(history_path)
|
||||||
|
init_history_db(hconn)
|
||||||
|
hconn.close()
|
||||||
|
|
||||||
|
test_app = FastAPI()
|
||||||
|
test_app.include_router(router, prefix="/api/v1")
|
||||||
|
yield TestClient(test_app)
|
||||||
|
|
||||||
|
get_settings.cache_clear()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_pipeline(monkeypatch):
|
||||||
|
"""Mock LLM and RAG to avoid real API calls."""
|
||||||
|
|
||||||
|
async def _mock_decompose(self, question):
|
||||||
|
return (["sub question 1", "sub question 2"], "mocked decompose prompt")
|
||||||
|
|
||||||
|
def _mock_retrieve(self, sub_questions, n_results=10):
|
||||||
|
return [
|
||||||
|
(sq, [("chunk text content", {"filename": "test.pdf", "upload_date": "2026-01-01",
|
||||||
|
"content_summary": "test chunk", "chunk_index": 0,
|
||||||
|
"page_number": 1, "document_id": "doc-1"}, 0.15)])
|
||||||
|
for sq in sub_questions
|
||||||
|
]
|
||||||
|
|
||||||
|
async def _mock_filter(self, sub_questions, chunks_by_subq, threshold=7.0):
|
||||||
|
result = [
|
||||||
|
(sq, [(text, {**meta, "relevance_score": 8.5}) for text, meta in chunks])
|
||||||
|
for sq, chunks in zip(sub_questions, chunks_by_subq)
|
||||||
|
]
|
||||||
|
return (result, "mocked filter prompt")
|
||||||
|
|
||||||
|
async def _mock_generate(self, sub_questions, sub_chunks, sub_metadata):
|
||||||
|
answer = "## Sub-question 0: sub question 1\n\n- Test answer with citation [test.pdf, page 1]\n\n## Sub-question 1: sub question 2\n\n- More answer content"
|
||||||
|
grouped_sources = [[meta_list[0]] for meta_list in sub_metadata] if sub_metadata else [[] for _ in sub_questions]
|
||||||
|
return (answer, "mocked generate prompt", grouped_sources)
|
||||||
|
|
||||||
|
monkeypatch.setattr("app.services.query_decomposer.QueryDecomposer.decompose", _mock_decompose)
|
||||||
|
monkeypatch.setattr("app.services.rag.RAGService.retrieve_per_subquestion", _mock_retrieve)
|
||||||
|
monkeypatch.setattr("app.services.relevance_filter.RelevanceFilter.filter_per_subquestion", _mock_filter)
|
||||||
|
monkeypatch.setattr("app.services.rag.RAGService.generate_response_per_subquestion", _mock_generate)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("mock_pipeline")
|
||||||
|
class TestGenerateTextEndpoint:
|
||||||
|
def test_valid_request_returns_200(self, client):
|
||||||
|
resp = client.post("/api/v1/test/generate/text", json={
|
||||||
|
"question": "test question",
|
||||||
|
"profile": "A",
|
||||||
|
"label": "my label",
|
||||||
|
})
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["input_type"] == "text"
|
||||||
|
assert data["profile"] == "A"
|
||||||
|
assert data["label"] == "my label"
|
||||||
|
assert "result_id" in data
|
||||||
|
|
||||||
|
def test_result_contains_all_stages(self, client):
|
||||||
|
resp = client.post("/api/v1/test/generate/text", json={
|
||||||
|
"question": "test",
|
||||||
|
"profile": "B",
|
||||||
|
})
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert len(data["extracted_key_questions"]) == 2
|
||||||
|
assert data["retrieval"]["total_chunks_retrieved"] > 0
|
||||||
|
assert data["filtered"]["total_chunks_filtered"] > 0
|
||||||
|
assert len(data["response"]["final_answer"]) > 0
|
||||||
|
assert data["timing"]["total_time_ms"] >= 0
|
||||||
|
|
||||||
|
def test_invalid_profile_rejected(self, client):
|
||||||
|
resp = client.post("/api/v1/test/generate/text", json={
|
||||||
|
"question": "test",
|
||||||
|
"profile": "D",
|
||||||
|
})
|
||||||
|
assert resp.status_code == 422
|
||||||
|
|
||||||
|
def test_empty_question_rejected(self, client):
|
||||||
|
resp = client.post("/api/v1/test/generate/text", json={
|
||||||
|
"question": "",
|
||||||
|
"profile": "A",
|
||||||
|
})
|
||||||
|
assert resp.status_code == 422
|
||||||
|
|
||||||
|
def test_result_saved_and_retrievable(self, client):
|
||||||
|
resp = client.post("/api/v1/test/generate/text", json={
|
||||||
|
"question": "save test",
|
||||||
|
"profile": "A",
|
||||||
|
})
|
||||||
|
assert resp.status_code == 200
|
||||||
|
result_id = resp.json()["result_id"]
|
||||||
|
|
||||||
|
get_resp = client.get(f"/api/v1/test/results/{result_id}")
|
||||||
|
assert get_resp.status_code == 200
|
||||||
|
assert get_resp.json()["result_id"] == result_id
|
||||||
|
|
||||||
|
def test_list_results(self, client):
|
||||||
|
client.post("/api/v1/test/generate/text", json={
|
||||||
|
"question": "list test 1", "profile": "A",
|
||||||
|
})
|
||||||
|
client.post("/api/v1/test/generate/text", json={
|
||||||
|
"question": "list test 2", "profile": "B",
|
||||||
|
})
|
||||||
|
|
||||||
|
resp = client.get("/api/v1/test/results?limit=10")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert len(data) >= 2
|
||||||
|
|
||||||
|
def test_get_nonexistent_result(self, client):
|
||||||
|
resp = client.get("/api/v1/test/results/no-such-id")
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
def test_delete_result(self, client):
|
||||||
|
resp = client.post("/api/v1/test/generate/text", json={
|
||||||
|
"question": "delete test", "profile": "A",
|
||||||
|
})
|
||||||
|
result_id = resp.json()["result_id"]
|
||||||
|
|
||||||
|
del_resp = client.delete(f"/api/v1/test/results/{result_id}")
|
||||||
|
assert del_resp.status_code == 200
|
||||||
|
|
||||||
|
get_resp = client.get(f"/api/v1/test/results/{result_id}")
|
||||||
|
assert get_resp.status_code == 404
|
||||||
|
|
@ -0,0 +1,232 @@
|
||||||
|
"""Phase 9 tests: Results storage service CRUD operations (Sub-Phase 9.1).
|
||||||
|
|
||||||
|
Covers:
|
||||||
|
- save_result writes JSON file and returns result_id
|
||||||
|
- load_result reads and parses JSON file
|
||||||
|
- list_results returns list of result metadata
|
||||||
|
- delete_result removes file
|
||||||
|
- Nonexistent result loading returns None
|
||||||
|
- save_evaluation / load_evaluation / list_evaluations / delete_evaluation
|
||||||
|
- Empty storage dirs don't error
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.models.testing import (
|
||||||
|
GenerateResult,
|
||||||
|
InputInfo,
|
||||||
|
TimingInfo,
|
||||||
|
RetrievalResult,
|
||||||
|
FilteredResult,
|
||||||
|
ResponseResult,
|
||||||
|
EvaluationResult,
|
||||||
|
EvaluationTiming,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def storage_dirs(tmp_path):
|
||||||
|
results_dir = tmp_path / "test_results"
|
||||||
|
evals_dir = tmp_path / "test_evaluations"
|
||||||
|
results_dir.mkdir()
|
||||||
|
evals_dir.mkdir()
|
||||||
|
return str(results_dir), str(evals_dir)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_result():
|
||||||
|
return GenerateResult(
|
||||||
|
result_id="test-001",
|
||||||
|
input_type="text",
|
||||||
|
profile="A",
|
||||||
|
label="sample test",
|
||||||
|
input=InputInfo(text="sample question"),
|
||||||
|
extracted_key_questions=["key q1"],
|
||||||
|
retrieval=RetrievalResult(
|
||||||
|
per_sub_question=[],
|
||||||
|
total_chunks_retrieved=0,
|
||||||
|
retriever_time_ms=100,
|
||||||
|
),
|
||||||
|
filtered=FilteredResult(
|
||||||
|
per_sub_question=[],
|
||||||
|
total_chunks_filtered=0,
|
||||||
|
filter_time_ms=100,
|
||||||
|
),
|
||||||
|
response=ResponseResult(
|
||||||
|
final_answer="answer",
|
||||||
|
sub_question_sources=[],
|
||||||
|
generate_time_ms=100,
|
||||||
|
),
|
||||||
|
timing=TimingInfo(
|
||||||
|
decomposer_time_ms=100,
|
||||||
|
retriever_time_ms=100,
|
||||||
|
filter_time_ms=100,
|
||||||
|
generator_time_ms=100,
|
||||||
|
total_time_ms=400,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestResultStorage:
|
||||||
|
def test_save_and_load(self, storage_dirs, sample_result, monkeypatch):
|
||||||
|
results_dir, _ = storage_dirs
|
||||||
|
monkeypatch.setenv("TEST_RESULTS_DIR", results_dir)
|
||||||
|
|
||||||
|
from app.services.test_storage_service import TestStorageService
|
||||||
|
svc = TestStorageService(results_dir, results_dir)
|
||||||
|
|
||||||
|
path = svc.save_result(sample_result)
|
||||||
|
assert os.path.exists(path)
|
||||||
|
|
||||||
|
loaded = svc.load_result(sample_result.result_id)
|
||||||
|
assert loaded is not None
|
||||||
|
assert loaded.result_id == sample_result.result_id
|
||||||
|
assert loaded.input_type == "text"
|
||||||
|
assert loaded.profile == "A"
|
||||||
|
|
||||||
|
def test_load_nonexistent(self, storage_dirs):
|
||||||
|
results_dir, _ = storage_dirs
|
||||||
|
from app.services.test_storage_service import TestStorageService
|
||||||
|
svc = TestStorageService(results_dir, results_dir)
|
||||||
|
|
||||||
|
result = svc.load_result("nonexistent-id")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_list_results(self, storage_dirs, sample_result, monkeypatch):
|
||||||
|
results_dir, _ = storage_dirs
|
||||||
|
monkeypatch.setenv("TEST_RESULTS_DIR", results_dir)
|
||||||
|
|
||||||
|
from app.services.test_storage_service import TestStorageService
|
||||||
|
svc = TestStorageService(results_dir, results_dir)
|
||||||
|
|
||||||
|
svc.save_result(sample_result)
|
||||||
|
|
||||||
|
items = svc.list_results()
|
||||||
|
assert len(items) >= 1
|
||||||
|
assert any(r["result_id"] == "test-001" for r in items)
|
||||||
|
|
||||||
|
def test_list_results_with_limit_offset(self, storage_dirs, sample_result):
|
||||||
|
results_dir, _ = storage_dirs
|
||||||
|
from app.services.test_storage_service import TestStorageService
|
||||||
|
svc = TestStorageService(results_dir, results_dir)
|
||||||
|
|
||||||
|
for i in range(5):
|
||||||
|
r = sample_result.model_copy(update={"result_id": f"test-{i:03d}"})
|
||||||
|
svc.save_result(r)
|
||||||
|
|
||||||
|
items = svc.list_results(limit=2, offset=1)
|
||||||
|
assert len(items) == 2
|
||||||
|
|
||||||
|
def test_delete_result(self, storage_dirs, sample_result):
|
||||||
|
results_dir, _ = storage_dirs
|
||||||
|
from app.services.test_storage_service import TestStorageService
|
||||||
|
svc = TestStorageService(results_dir, results_dir)
|
||||||
|
|
||||||
|
svc.save_result(sample_result)
|
||||||
|
filepath = os.path.join(results_dir, f"{sample_result.result_id}.json")
|
||||||
|
assert os.path.exists(filepath)
|
||||||
|
|
||||||
|
result = svc.delete_result(sample_result.result_id)
|
||||||
|
assert result is True
|
||||||
|
assert not os.path.exists(filepath)
|
||||||
|
|
||||||
|
def test_delete_nonexistent(self, storage_dirs):
|
||||||
|
results_dir, _ = storage_dirs
|
||||||
|
from app.services.test_storage_service import TestStorageService
|
||||||
|
svc = TestStorageService(results_dir, results_dir)
|
||||||
|
|
||||||
|
result = svc.delete_result("no-such-id")
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
def test_creates_dir_if_missing(self, storage_dirs):
|
||||||
|
results_dir, evals_dir = storage_dirs
|
||||||
|
new_results = os.path.join(results_dir, "auto_created")
|
||||||
|
from app.services.test_storage_service import TestStorageService
|
||||||
|
svc = TestStorageService(new_results, evals_dir)
|
||||||
|
assert os.path.isdir(new_results)
|
||||||
|
|
||||||
|
def test_list_empty_dir(self, storage_dirs):
|
||||||
|
results_dir, _ = storage_dirs
|
||||||
|
from app.services.test_storage_service import TestStorageService
|
||||||
|
svc = TestStorageService(results_dir, results_dir)
|
||||||
|
|
||||||
|
items = svc.list_results()
|
||||||
|
assert items == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestEvaluationStorage:
|
||||||
|
def test_save_and_load_eval(self, storage_dirs):
|
||||||
|
results_dir, evals_dir = storage_dirs
|
||||||
|
from app.services.test_storage_service import TestStorageService
|
||||||
|
svc = TestStorageService(results_dir, evals_dir)
|
||||||
|
|
||||||
|
eval_result = EvaluationResult(
|
||||||
|
evaluation_id="eval-001",
|
||||||
|
result_id="result-001",
|
||||||
|
status="completed",
|
||||||
|
timing=EvaluationTiming(
|
||||||
|
audio_evaluation_time_ms=10,
|
||||||
|
key_questions_evaluation_time_ms=100,
|
||||||
|
chunk_evaluation_time_ms=200,
|
||||||
|
response_evaluation_time_ms=300,
|
||||||
|
total_evaluation_time_ms=610,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
path = svc.save_evaluation(eval_result)
|
||||||
|
assert os.path.exists(path)
|
||||||
|
|
||||||
|
loaded = svc.load_evaluation("eval-001")
|
||||||
|
assert loaded is not None
|
||||||
|
assert loaded.evaluation_id == "eval-001"
|
||||||
|
assert loaded.status == "completed"
|
||||||
|
|
||||||
|
def test_list_evaluations(self, storage_dirs):
|
||||||
|
results_dir, evals_dir = storage_dirs
|
||||||
|
from app.services.test_storage_service import TestStorageService
|
||||||
|
svc = TestStorageService(results_dir, evals_dir)
|
||||||
|
|
||||||
|
eval_result = EvaluationResult(
|
||||||
|
evaluation_id="eval-001",
|
||||||
|
result_id="result-001",
|
||||||
|
status="completed",
|
||||||
|
timing=EvaluationTiming(
|
||||||
|
audio_evaluation_time_ms=10,
|
||||||
|
key_questions_evaluation_time_ms=100,
|
||||||
|
chunk_evaluation_time_ms=200,
|
||||||
|
response_evaluation_time_ms=300,
|
||||||
|
total_evaluation_time_ms=610,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
svc.save_evaluation(eval_result)
|
||||||
|
|
||||||
|
items = svc.list_evaluations()
|
||||||
|
assert len(items) >= 1
|
||||||
|
|
||||||
|
def test_delete_evaluation(self, storage_dirs):
|
||||||
|
results_dir, evals_dir = storage_dirs
|
||||||
|
from app.services.test_storage_service import TestStorageService
|
||||||
|
svc = TestStorageService(results_dir, evals_dir)
|
||||||
|
|
||||||
|
eval_result = EvaluationResult(
|
||||||
|
evaluation_id="eval-002",
|
||||||
|
result_id="result-002",
|
||||||
|
status="completed",
|
||||||
|
timing=EvaluationTiming(
|
||||||
|
audio_evaluation_time_ms=10,
|
||||||
|
key_questions_evaluation_time_ms=100,
|
||||||
|
chunk_evaluation_time_ms=200,
|
||||||
|
response_evaluation_time_ms=300,
|
||||||
|
total_evaluation_time_ms=610,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
svc.save_evaluation(eval_result)
|
||||||
|
filepath = os.path.join(evals_dir, "eval-002.json")
|
||||||
|
assert os.path.exists(filepath)
|
||||||
|
|
||||||
|
result = svc.delete_evaluation("eval-002")
|
||||||
|
assert result is True
|
||||||
|
assert not os.path.exists(filepath)
|
||||||
Loading…
Reference in New Issue