import io
import logging
import uuid
import wave
from typing import Optional

from app.core.config import Settings
from app.models.testing import (
    ChunkEntry,
    FilteredResult,
    GenerateResult,
    InputInfo,
    ResponseResult,
    RetrievalResult,
    SubQuestionChunks,
    TimingInfo,
)
from app.services.asr_client import ASRClient
from app.services.llm_client import LLMClient
from app.services.llm_client_dp import LLMClientDP
from app.services.prompt_service import PromptService
from app.services.query_decomposer import QueryDecomposer
from app.services.rag_pipeline import RAGPipeline, PipelineSnapshot
from app.services.rag import RAGService
from app.services.relevance_filter import RelevanceFilter

logger = logging.getLogger(__name__)


def _wav_duration_seconds(audio_bytes: bytes) -> float:
    """Return the playback length of *audio_bytes* in seconds if it is a WAV file.

    Only the RIFF/WAVE header is inspected via the stdlib ``wave`` module. Any
    other container (mp3, ogg, webm, ...) or a header ``wave`` cannot parse
    yields 0.0, which is the placeholder value previously reported for every
    upload.
    """
    try:
        with wave.open(io.BytesIO(audio_bytes), "rb") as wav:
            frame_rate = wav.getframerate()
            frame_count = wav.getnframes()
            frame_size = wav.getnchannels() * wav.getsampwidth()
    except (wave.Error, EOFError):
        return 0.0
    if frame_rate <= 0 or frame_size <= 0:
        return 0.0
    # Some streaming encoders write a placeholder data-chunk size (e.g.
    # 0xFFFFFFFF); never report more frames than the payload could hold.
    frame_count = min(frame_count, len(audio_bytes) // frame_size)
    return frame_count / frame_rate


class TestRunnerService:
    """Runs the full RAG pipeline and captures all intermediate data for accuracy testing."""

    def __init__(self, settings: Settings):
        self.settings = settings

    def _build_pipeline(self, prompt_service: PromptService) -> RAGPipeline:
        """Create fresh LLM clients and wire them into a RAGPipeline for one test run."""
        llm_client_dp = LLMClientDP(self.settings)
        llm_client = LLMClient(self.settings)
        rag = RAGService(
            llm_client=llm_client,
            settings=self.settings,
            prompt_service=prompt_service,
        )
        # Decomposition uses the DP client; generation and filtering share the
        # regular client.
        decomposer = QueryDecomposer(llm_client_dp, prompt_service=prompt_service)
        relevance_filter = RelevanceFilter(llm_client, prompt_service=prompt_service)
        return RAGPipeline(
            decomposer=decomposer,
            rag=rag,
            relevance_filter=relevance_filter,
            retrieval_n_results=self.settings.retrieval_n_results,
            relevance_threshold=self.settings.relevance_threshold,
        )

    @staticmethod
    def _retrieved_chunks(snap: Optional[PipelineSnapshot]) -> list:
        """Convert a "retrieving" snapshot into per-sub-question chunk records.

        ``snap.retrieval_results`` is unpacked as
        ``(sub_question_text, [(text, metadata, distance), ...])`` pairs.
        Returns an empty list when there is no snapshot or no results.
        """
        if not snap or not snap.retrieval_results:
            return []
        return [
            SubQuestionChunks(
                sub_question_index=sq_idx,
                sub_question_text=sub_q_text,
                chunks=[
                    ChunkEntry(
                        chunk_index=i,
                        text=text,
                        metadata=meta,
                        distance=distance,
                    )
                    for i, (text, meta, distance) in enumerate(chunks)
                ],
            )
            for sq_idx, (sub_q_text, chunks) in enumerate(snap.retrieval_results)
        ]

    @staticmethod
    def _filtered_chunks(snap: Optional[PipelineSnapshot]) -> list:
        """Convert a "filtering" snapshot into per-sub-question chunk records.

        ``snap.filtered_by_subq`` is unpacked as
        ``(sub_question_text, [(text, metadata), ...])`` pairs. Returns an empty
        list when there is no snapshot or no filtered chunks.
        """
        if not snap or not snap.filtered_by_subq:
            return []
        return [
            SubQuestionChunks(
                sub_question_index=sq_idx,
                sub_question_text=sub_q_text,
                chunks=[
                    # Filtered entries carry only (text, metadata); the
                    # retrieval distance is not available here, so 0.0 is a
                    # placeholder.
                    ChunkEntry(
                        chunk_index=i,
                        text=text,
                        metadata=meta,
                        distance=0.0,
                    )
                    for i, (text, meta) in enumerate(chunks)
                ],
            )
            for sq_idx, (sub_q_text, chunks) in enumerate(snap.filtered_by_subq)
        ]

    async def run_text_test(
        self,
        question: str,
        profile: str,
        prompt_service: PromptService,
        label: str = "",
    ) -> GenerateResult:
        """Run the RAG pipeline on *question* and capture every intermediate stage.

        Side effect: activates *profile* on the shared *prompt_service* before
        the pipeline is built.

        Args:
            question: The text question fed to the pipeline.
            profile: Prompt profile name to activate.
            prompt_service: Prompt service shared by all pipeline stages.
            label: Free-form label copied onto the result.

        Returns:
            A GenerateResult holding the extracted sub-questions, the retrieved
            and filtered chunks, the final answer with sources, and per-stage
            timings. Missing retrieval/filter snapshots are reported as empty
            chunk lists with zero counts and zero times.

        Raises:
            RuntimeError: If the pipeline never yields a "completed" snapshot.
        """
        result_id = uuid.uuid4().hex[:12]

        prompt_service.activate_profile(profile)
        active_profile = prompt_service.get_active_profile_name()

        pipeline = self._build_pipeline(prompt_service)

        # Keep the most recent snapshot seen for each phase we report on.
        decomposed_snap: Optional[PipelineSnapshot] = None
        retrieval_snap: Optional[PipelineSnapshot] = None
        filtering_snap: Optional[PipelineSnapshot] = None
        completed_snap: Optional[PipelineSnapshot] = None

        async for snap in pipeline.execute(question):
            if snap.phase == "decomposed":
                decomposed_snap = snap
            elif snap.phase == "retrieving":
                retrieval_snap = snap
            elif snap.phase == "filtering":
                filtering_snap = snap
            elif snap.phase == "completed":
                completed_snap = snap

        if completed_snap is None:
            raise RuntimeError("Pipeline did not produce a completed snapshot")

        retrieval_per_subq = self._retrieved_chunks(retrieval_snap)
        filtered_per_subq = self._filtered_chunks(filtering_snap)

        # Build response result — serialize through dicts to convert between
        # query.py SubQuestionSources and testing.py SubQuestionSources (same fields)
        response_result = ResponseResult(
            final_answer=completed_snap.answer,
            sub_question_sources=[
                {
                    "sub_question_index": sq.sub_question_index,
                    "sub_question_text": sq.sub_question_text,
                    "sources": [s.model_dump() for s in sq.sources],
                }
                for sq in completed_snap.sub_question_sources
            ],
            generate_time_ms=completed_snap.generator_time_ms,
        )

        # Stage timings, defaulting to 0 when a stage emitted no snapshot.
        retriever_time_ms = retrieval_snap.retriever_time_ms if retrieval_snap else 0
        filter_time_ms = filtering_snap.filter_time_ms if filtering_snap else 0

        return GenerateResult(
            result_id=result_id,
            input_type="text",
            profile=active_profile,
            label=label,
            input=InputInfo(text=question),
            extracted_key_questions=completed_snap.extracted_questions,
            retrieval=RetrievalResult(
                per_sub_question=retrieval_per_subq,
                total_chunks_retrieved=(
                    retrieval_snap.chunks_retrieved_count if retrieval_snap else 0
                ),
                retriever_time_ms=retriever_time_ms,
            ),
            filtered=FilteredResult(
                per_sub_question=filtered_per_subq,
                total_chunks_filtered=(
                    filtering_snap.chunks_filtered_count if filtering_snap else 0
                ),
                filter_time_ms=filter_time_ms,
            ),
            response=response_result,
            timing=TimingInfo(
                decomposer_time_ms=completed_snap.decomposer_time_ms,
                retriever_time_ms=retriever_time_ms,
                filter_time_ms=filter_time_ms,
                generator_time_ms=completed_snap.generator_time_ms,
                total_time_ms=completed_snap.total_time_ms,
            ),
        )

    async def run_audio_test(
        self,
        audio_bytes: bytes,
        reference_transcript: str,
        profile: str,
        prompt_service: PromptService,
        language: str = "yue",
        label: str = "",
        audio_filename: str = "",
    ) -> GenerateResult:
        """Transcribe *audio_bytes* with ASR, then run the text test on the transcript.

        Args:
            audio_bytes: Raw uploaded audio.
            reference_transcript: Ground-truth transcript, stored for comparison.
            profile: Prompt profile name to activate.
            prompt_service: Prompt service shared by all pipeline stages.
            language: Language code passed to the ASR client.
            label: Free-form label copied onto the result.
            audio_filename: Original filename, stored on the result.

        Returns:
            The run_text_test result re-tagged as ``input_type="audio"``. Its
            input info carries the transcript, reference, filename, ASR
            language and duration. The duration is read from the WAV header,
            or 0.0 when the audio is not a WAV file. ASR latency is not
            included in the returned timing.
        """
        asr_client = ASRClient(self.settings)
        transcribed_text = await asr_client.transcribe_full(
            audio_bytes, language=language
        )
        if not transcribed_text:
            logger.warning(
                "ASR returned an empty transcript for %r; running pipeline anyway",
                audio_filename,
            )

        # run_text_test already assigns a fresh result_id, so it is reused as-is.
        result = await self.run_text_test(
            question=transcribed_text,
            profile=profile,
            prompt_service=prompt_service,
            label=label,
        )

        # Override with audio-specific fields
        result.input_type = "audio"
        result.input = InputInfo(
            text=transcribed_text,
            reference_transcript=reference_transcript,
            audio_filename=audio_filename,
            audio_duration_seconds=_wav_duration_seconds(audio_bytes),
            asr_language=language,
        )
        return result