# legco_ai_assistant/backend/app/services/test_runner_service.py
#
# 221 lines
# 7.6 KiB
# Python

import logging
import uuid
from typing import Optional
from app.core.config import Settings
from app.models.testing import (
ChunkEntry,
FilteredResult,
GenerateResult,
InputInfo,
ResponseResult,
RetrievalResult,
SubQuestionChunks,
TimingInfo,
)
from app.services.asr_client import ASRClient
from app.services.llm_client import LLMClient
from app.services.llm_client_dp import LLMClientDP
from app.services.prompt_service import PromptService
from app.services.query_decomposer import QueryDecomposer
from app.services.rag_pipeline import RAGPipeline, PipelineSnapshot
from app.services.rag import RAGService
from app.services.relevance_filter import RelevanceFilter
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class TestRunnerService:
    """Runs the full RAG pipeline and captures all intermediate data for accuracy testing.

    Every run builds its own LLM clients, RAG service, query decomposer,
    relevance filter and pipeline from ``settings``; the only state kept on
    the instance is ``settings`` itself.
    """

    # The class name starts with "Test" and the class defines ``__init__``.
    # Opting out of collection stops pytest from trying to treat this service
    # as a test class (and warning about it) when the module or the class is
    # visible to a test session.
    __test__ = False

    # Pipeline phases whose most recent snapshot is used to build the result.
    _TRACKED_PHASES = ("retrieving", "filtering", "completed")

    def __init__(self, settings: Settings):
        """Store the settings used to build pipeline components for each run.

        Args:
            settings: Application settings. Passed to the LLM clients, the RAG
                service and the ASR client, and read for
                ``retrieval_n_results`` and ``relevance_threshold``.
        """
        self.settings = settings

    async def run_text_test(
        self,
        question: str,
        profile: str,
        prompt_service: PromptService,
        label: str = "",
    ) -> GenerateResult:
        """Run the pipeline on a text question and record every stage.

        Args:
            question: The question text handed to the pipeline.
            profile: Prompt profile to activate on ``prompt_service`` before
                the run. Note that this changes the state of the service
                object passed in.
            prompt_service: Prompt service shared by the RAG service, the
                decomposer and the relevance filter.
            label: Free-form label copied into the result.

        Returns:
            A ``GenerateResult`` with ``input_type="text"`` holding the
            retrieved chunks, the filtered chunks, the final response and the
            per-stage timings.

        Raises:
            RuntimeError: If the pipeline finishes without yielding a
                snapshot whose phase is ``"completed"``.
        """
        result_id = uuid.uuid4().hex[:12]

        prompt_service.activate_profile(profile)
        # Read the name back from the service rather than trusting ``profile``.
        active_profile = prompt_service.get_active_profile_name()

        pipeline = self._build_pipeline(prompt_service)
        snapshots = await self._collect_snapshots(pipeline, question)

        completed_snap: Optional[PipelineSnapshot] = snapshots.get("completed")
        if completed_snap is None:
            raise RuntimeError("Pipeline did not produce a completed snapshot")
        retrieval_snap: Optional[PipelineSnapshot] = snapshots.get("retrieving")
        filtering_snap: Optional[PipelineSnapshot] = snapshots.get("filtering")

        # Retrieval chunks arrive as (text, metadata, distance) triples.
        retrieval_per_subq = []
        if retrieval_snap and retrieval_snap.retrieval_results:
            retrieval_per_subq = self._group_chunks(
                retrieval_snap.retrieval_results, with_distance=True
            )

        # Filtered chunks arrive as (text, metadata) pairs; distance is
        # recorded as 0.0 for them.
        filtered_per_subq = []
        if filtering_snap and filtering_snap.filtered_by_subq:
            filtered_per_subq = self._group_chunks(
                filtering_snap.filtered_by_subq, with_distance=False
            )

        # A stage that never reported a snapshot contributes zero time/count.
        retriever_time_ms = retrieval_snap.retriever_time_ms if retrieval_snap else 0
        filter_time_ms = filtering_snap.filter_time_ms if filtering_snap else 0

        # Build response result — serialize through dicts to convert between
        # query.py SubQuestionSources and testing.py SubQuestionSources (same fields)
        response_result = ResponseResult(
            final_answer=completed_snap.answer,
            sub_question_sources=[
                {
                    "sub_question_index": sq.sub_question_index,
                    "sub_question_text": sq.sub_question_text,
                    "sources": [s.model_dump() for s in sq.sources],
                }
                for sq in completed_snap.sub_question_sources
            ],
            generate_time_ms=completed_snap.generator_time_ms,
        )

        return GenerateResult(
            result_id=result_id,
            input_type="text",
            profile=active_profile,
            label=label,
            input=InputInfo(text=question),
            extracted_key_questions=completed_snap.extracted_questions,
            retrieval=RetrievalResult(
                per_sub_question=retrieval_per_subq,
                total_chunks_retrieved=(
                    retrieval_snap.chunks_retrieved_count if retrieval_snap else 0
                ),
                retriever_time_ms=retriever_time_ms,
            ),
            filtered=FilteredResult(
                per_sub_question=filtered_per_subq,
                total_chunks_filtered=(
                    filtering_snap.chunks_filtered_count if filtering_snap else 0
                ),
                filter_time_ms=filter_time_ms,
            ),
            response=response_result,
            timing=TimingInfo(
                decomposer_time_ms=completed_snap.decomposer_time_ms,
                retriever_time_ms=retriever_time_ms,
                filter_time_ms=filter_time_ms,
                generator_time_ms=completed_snap.generator_time_ms,
                total_time_ms=completed_snap.total_time_ms,
            ),
        )

    async def run_audio_test(
        self,
        audio_bytes: bytes,
        reference_transcript: str,
        profile: str,
        prompt_service: PromptService,
        language: str = "yue",
        label: str = "",
        audio_filename: str = "",
    ) -> GenerateResult:
        """Transcribe audio, then run the text test on the transcript.

        Args:
            audio_bytes: Raw audio payload handed to the ASR client.
            reference_transcript: Reference transcript stored alongside the
                ASR output in the result's input info.
            profile: Prompt profile forwarded to ``run_text_test``.
            prompt_service: Prompt service forwarded to ``run_text_test``.
            language: Language code passed to the ASR client and recorded as
                ``asr_language``.
            label: Free-form label forwarded to ``run_text_test``.
            audio_filename: File name recorded in the result's input info.

        Returns:
            The ``GenerateResult`` produced by ``run_text_test``, updated in
            place with ``input_type="audio"``, a separately generated
            ``result_id`` and audio-specific input info.

        Raises:
            RuntimeError: Propagated from ``run_text_test`` when the pipeline
                yields no completed snapshot.
        """
        result_id = uuid.uuid4().hex[:12]

        # Run ASR
        asr_client = ASRClient(self.settings)
        transcribed_text = await asr_client.transcribe_full(
            audio_bytes, language=language
        )

        # Run text test on transcribed text
        result = await self.run_text_test(
            question=transcribed_text,
            profile=profile,
            prompt_service=prompt_service,
            label=label,
        )

        # Override with audio-specific fields
        result.result_id = result_id
        result.input_type = "audio"
        result.input = InputInfo(
            text=transcribed_text,
            reference_transcript=reference_transcript,
            audio_filename=audio_filename,
            # NOTE(review): the duration is always recorded as 0.0; it is not
            # derived from ``audio_bytes`` anywhere in this method.
            audio_duration_seconds=0.0,
            asr_language=language,
        )
        return result

    def _build_pipeline(self, prompt_service: PromptService) -> RAGPipeline:
        """Wire fresh clients and services into a new ``RAGPipeline``."""
        # The decomposer gets its own LLMClientDP; the RAG service and the
        # relevance filter share one LLMClient.
        llm_client_dp = LLMClientDP(self.settings)
        llm_client = LLMClient(self.settings)
        rag = RAGService(
            llm_client=llm_client,
            settings=self.settings,
            prompt_service=prompt_service,
        )
        decomposer = QueryDecomposer(llm_client_dp, prompt_service=prompt_service)
        relevance_filter = RelevanceFilter(
            llm_client, prompt_service=prompt_service
        )
        return RAGPipeline(
            decomposer=decomposer,
            rag=rag,
            relevance_filter=relevance_filter,
            retrieval_n_results=self.settings.retrieval_n_results,
            relevance_threshold=self.settings.relevance_threshold,
        )

    async def _collect_snapshots(self, pipeline: RAGPipeline, question: str) -> dict:
        """Drain the pipeline and return the last snapshot seen per tracked phase."""
        latest = {}
        async for snap in pipeline.execute(question):
            if snap.phase in self._TRACKED_PHASES:
                # A later snapshot for the same phase replaces the earlier one.
                latest[snap.phase] = snap
        return latest

    @staticmethod
    def _group_chunks(groups, *, with_distance: bool) -> list:
        """Convert ``(sub_question_text, chunks)`` pairs into ``SubQuestionChunks``.

        With ``with_distance`` set, each chunk is unpacked as
        ``(text, metadata, distance)``; otherwise as ``(text, metadata)`` and
        the distance is recorded as ``0.0``.
        """
        grouped = []
        for sq_idx, (sub_q_text, chunks) in enumerate(groups):
            entries = []
            for i, chunk in enumerate(chunks):
                if with_distance:
                    text, meta, distance = chunk
                else:
                    text, meta = chunk
                    distance = 0.0
                entries.append(
                    ChunkEntry(
                        chunk_index=i,
                        text=text,
                        metadata=meta,
                        distance=distance,
                    )
                )
            grouped.append(
                SubQuestionChunks(
                    sub_question_index=sq_idx,
                    sub_question_text=sub_q_text,
                    chunks=entries,
                )
            )
        return grouped