221 lines
7.6 KiB
Python
221 lines
7.6 KiB
Python
import logging
|
|
import uuid
|
|
from typing import Optional
|
|
|
|
from app.core.config import Settings
|
|
from app.models.testing import (
|
|
ChunkEntry,
|
|
FilteredResult,
|
|
GenerateResult,
|
|
InputInfo,
|
|
ResponseResult,
|
|
RetrievalResult,
|
|
SubQuestionChunks,
|
|
TimingInfo,
|
|
)
|
|
from app.services.asr_client import ASRClient
|
|
from app.services.llm_client import LLMClient
|
|
from app.services.llm_client_dp import LLMClientDP
|
|
from app.services.prompt_service import PromptService
|
|
from app.services.query_decomposer import QueryDecomposer
|
|
from app.services.rag_pipeline import RAGPipeline, PipelineSnapshot
|
|
from app.services.rag import RAGService
|
|
from app.services.relevance_filter import RelevanceFilter
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class TestRunnerService:
    """Runs the full RAG pipeline and captures all intermediate data for accuracy testing.

    Every test run builds fresh LLM clients and a fresh ``RAGPipeline`` from
    the stored settings, drains the pipeline's snapshot stream, and packages
    the latest ``"retrieving"``, ``"filtering"`` and ``"completed"`` snapshots
    into a single ``GenerateResult``.
    """

    def __init__(self, settings: Settings):
        """Store the settings used to build clients and the pipeline on each run."""
        self.settings = settings

    def _build_pipeline(self, prompt_service: PromptService) -> RAGPipeline:
        """Construct a RAG pipeline wired to freshly created LLM clients.

        The query decomposer gets its own ``LLMClientDP``; the RAG service and
        the relevance filter share a single ``LLMClient``.
        """
        llm_client_dp = LLMClientDP(self.settings)
        llm_client = LLMClient(self.settings)
        rag = RAGService(
            llm_client=llm_client,
            settings=self.settings,
            prompt_service=prompt_service,
        )
        decomposer = QueryDecomposer(llm_client_dp, prompt_service=prompt_service)
        relevance_filter = RelevanceFilter(llm_client, prompt_service=prompt_service)
        return RAGPipeline(
            decomposer=decomposer,
            rag=rag,
            relevance_filter=relevance_filter,
            retrieval_n_results=self.settings.retrieval_n_results,
            relevance_threshold=self.settings.relevance_threshold,
        )

    @staticmethod
    def _to_sub_question_chunks(per_sub_question) -> "list[SubQuestionChunks]":
        """Convert ``(sub_question_text, [(text, metadata, distance), ...])`` pairs.

        Sub-questions and their chunks are numbered by their position in the
        input, starting at 0.
        """
        return [
            SubQuestionChunks(
                sub_question_index=sq_idx,
                sub_question_text=sub_q_text,
                chunks=[
                    ChunkEntry(
                        chunk_index=i,
                        text=text,
                        metadata=meta,
                        distance=distance,
                    )
                    for i, (text, meta, distance) in enumerate(chunks)
                ],
            )
            for sq_idx, (sub_q_text, chunks) in enumerate(per_sub_question)
        ]

    async def run_text_test(
        self,
        question: str,
        profile: str,
        prompt_service: PromptService,
        label: str = "",
    ) -> GenerateResult:
        """Run the RAG pipeline on a text question and capture every stage.

        Args:
            question: Question text fed to the pipeline.
            profile: Prompt profile activated on ``prompt_service`` before the run.
            prompt_service: Prompt service shared by all pipeline components.
            label: Free-form label copied into the result.

        Returns:
            A ``GenerateResult`` holding the retrieved chunks, the filtered
            chunks, the final answer with its per-sub-question sources, and
            per-stage timings. Stages whose snapshot never appeared are
            reported as empty with zero counts/timings.

        Raises:
            RuntimeError: If the pipeline never emits a ``"completed"`` snapshot.
        """
        result_id = uuid.uuid4().hex[:12]

        # NOTE(review): this switches the active profile on the caller's
        # prompt_service and never restores it — confirm the instance is not
        # shared with live traffic.
        prompt_service.activate_profile(profile)
        active_profile = prompt_service.get_active_profile_name()

        pipeline = self._build_pipeline(prompt_service)

        # Keep only the most recent snapshot of each phase we report on; a
        # later snapshot of the same phase supersedes an earlier one.
        retrieval_snap = None
        filtering_snap = None
        completed_snap = None
        async for snap in pipeline.execute(question):
            if snap.phase == "retrieving":
                retrieval_snap = snap
            elif snap.phase == "filtering":
                filtering_snap = snap
            elif snap.phase == "completed":
                completed_snap = snap

        if completed_snap is None:
            raise RuntimeError("Pipeline did not produce a completed snapshot")

        # Retrieval results carry (text, metadata, distance) per chunk.
        retrieval_per_subq = (
            self._to_sub_question_chunks(retrieval_snap.retrieval_results)
            if retrieval_snap and retrieval_snap.retrieval_results
            else []
        )

        # Filtered results carry only (text, metadata) per chunk, so 0.0 is
        # recorded as a placeholder distance.
        filtered_per_subq = (
            self._to_sub_question_chunks(
                (sub_q_text, [(text, meta, 0.0) for text, meta in chunks])
                for sub_q_text, chunks in filtering_snap.filtered_by_subq
            )
            if filtering_snap and filtering_snap.filtered_by_subq
            else []
        )

        retriever_time_ms = retrieval_snap.retriever_time_ms if retrieval_snap else 0
        filter_time_ms = filtering_snap.filter_time_ms if filtering_snap else 0

        # Serialize through dicts to convert between query.py
        # SubQuestionSources and testing.py SubQuestionSources (same fields).
        response_result = ResponseResult(
            final_answer=completed_snap.answer,
            sub_question_sources=[
                {
                    "sub_question_index": sq.sub_question_index,
                    "sub_question_text": sq.sub_question_text,
                    "sources": [s.model_dump() for s in sq.sources],
                }
                for sq in completed_snap.sub_question_sources
            ],
            generate_time_ms=completed_snap.generator_time_ms,
        )

        return GenerateResult(
            result_id=result_id,
            input_type="text",
            profile=active_profile,
            label=label,
            input=InputInfo(text=question),
            extracted_key_questions=completed_snap.extracted_questions,
            retrieval=RetrievalResult(
                per_sub_question=retrieval_per_subq,
                total_chunks_retrieved=(
                    retrieval_snap.chunks_retrieved_count if retrieval_snap else 0
                ),
                retriever_time_ms=retriever_time_ms,
            ),
            filtered=FilteredResult(
                per_sub_question=filtered_per_subq,
                total_chunks_filtered=(
                    filtering_snap.chunks_filtered_count if filtering_snap else 0
                ),
                filter_time_ms=filter_time_ms,
            ),
            response=response_result,
            timing=TimingInfo(
                decomposer_time_ms=completed_snap.decomposer_time_ms,
                retriever_time_ms=retriever_time_ms,
                filter_time_ms=filter_time_ms,
                generator_time_ms=completed_snap.generator_time_ms,
                total_time_ms=completed_snap.total_time_ms,
            ),
        )

    async def run_audio_test(
        self,
        audio_bytes: bytes,
        reference_transcript: str,
        profile: str,
        prompt_service: PromptService,
        language: str = "yue",
        label: str = "",
        audio_filename: str = "",
    ) -> GenerateResult:
        """Transcribe audio with ASR, then run the text test on the transcript.

        Args:
            audio_bytes: Raw audio payload passed to the ASR client.
            reference_transcript: Expected transcript, recorded for comparison.
            profile: Prompt profile to activate for the pipeline run.
            prompt_service: Prompt service shared by all pipeline components.
            language: Language code passed to the ASR client.
            label: Free-form label copied into the result.
            audio_filename: Original file name, recorded in the result input.

        Returns:
            The ``GenerateResult`` from ``run_text_test`` with ``input_type``
            set to ``"audio"`` and ``input`` replaced by the audio details.
        """
        asr_client = ASRClient(self.settings)
        transcribed_text = await asr_client.transcribe_full(
            audio_bytes, language=language
        )

        result = await self.run_text_test(
            question=transcribed_text,
            profile=profile,
            prompt_service=prompt_service,
            label=label,
        )

        # run_text_test already assigned a fresh result_id; only the
        # input-related fields differ for an audio run.
        result.input_type = "audio"
        result.input = InputInfo(
            text=transcribed_text,
            reference_transcript=reference_transcript,
            audio_filename=audio_filename,
            # NOTE(review): placeholder — the audio duration is not computed here.
            audio_duration_seconds=0.0,
            asr_language=language,
        )
        return result
|