# legco_ai_assistant/backend/app/services/test_evaluation_service.py
# (Source-viewer header from the original paste: 314 lines, 12 KiB, Python.)

import asyncio
import logging
import time
import uuid
from typing import Any, Dict, List, Optional, Set, Tuple

from app.core.config import Settings
from app.models.testing import (
    AudioEvalResult,
    ChunkAccuracy,
    ChunkEvalResult,
    EvaluateRequest,
    EvaluationResult,
    EvaluationTiming,
    FilteredResult,
    GenerateResult,
    GroundTruthInfo,
    KeyQuestionsEvalResult,
    ResponseEvalResult,
    RetrievalResult,
    SubQuestionChunkEval,
    SubQuestionResponseEval,
)
from app.services.cer_wer import calculate_cer, calculate_wer
from app.services.chunk_evaluator import (
    _calculate_accuracy,
    _determine_ground_truth_chunks,
)
from app.services.key_questions_evaluator import evaluate_key_questions
from app.services.rag import RAGService
from app.services.response_evaluator import evaluate_response
from app.services.test_storage_service import TestStorageService
# Module-level logger named after this module (standard `logging` convention).
logger = logging.getLogger(__name__)
def _extract_chunk_sets(
retrieval: RetrievalResult,
filtered: FilteredResult,
) -> Tuple[Set[Tuple[str, int]], Set[Tuple[str, int]]]:
"""Extract (document_id, chunk_index) sets from retrieval and filtered results."""
retrieved = set()
filtered_set = set()
for sq in retrieval.per_sub_question:
for chunk in sq.chunks:
doc_id = chunk.metadata.get("document_id", "unknown")
chunk_idx = chunk.metadata.get("chunk_index", chunk.chunk_index)
retrieved.add((str(doc_id), int(chunk_idx)))
for sq in filtered.per_sub_question:
for chunk in sq.chunks:
doc_id = chunk.metadata.get("document_id", "unknown")
chunk_idx = chunk.metadata.get("chunk_index", chunk.chunk_index)
filtered_set.add((str(doc_id), int(chunk_idx)))
return retrieved, filtered_set
def _collect_all_chunks(
rag: RAGService,
) -> List[Tuple[str, int, str, Dict[str, Any]]]:
"""Fetch all chunks from all documents in ChromaDB.
Returns list of (document_id, chunk_index, text, metadata) tuples.
"""
docs, _, _ = rag.list_documents()
all_chunks = []
for doc in docs:
doc_id = doc["document_id"]
chunks = rag.list_chunks(doc_id)
for chunk in chunks:
chunk_idx = chunk.get("chunk_index", 0)
text = chunk.get("text", "")
all_chunks.append((doc_id, chunk_idx, text, chunk))
return all_chunks
# Maximum number of ground-truth chunks per sub-question forwarded to the
# response evaluator (presumably to bound the evaluator prompt — confirm).
_MAX_GT_CHUNKS_FOR_RESPONSE_EVAL = 10


def _elapsed_ms(start: float) -> int:
    """Return whole milliseconds elapsed since ``start`` (a ``time.perf_counter()`` reading)."""
    return int((time.perf_counter() - start) * 1000)


def _chunk_sort_key(ref: Tuple[Any, Any]) -> Tuple[str, int, int, str]:
    """Total-order sort key for a ``(document_id, chunk_index)`` pair.

    Integer chunk indices sort numerically; any non-integer index sorts after
    them by its string form, so mixed index types never raise ``TypeError``.
    """
    doc_id, chunk_idx = ref
    if isinstance(chunk_idx, int):
        return (str(doc_id), 0, chunk_idx, "")
    return (str(doc_id), 1, 0, str(chunk_idx))


def _mean_accuracy(accuracies: List[Any]) -> ChunkAccuracy:
    """Macro-average per-sub-question accuracies into one ``ChunkAccuracy``.

    ``accuracies`` are the objects returned by ``_calculate_accuracy``.
    Precision, recall and F1 are averaged (rounded to 4 places); the chunk
    counts are summed. An empty list yields all zeros.
    """
    if not accuracies:
        return ChunkAccuracy(
            precision=0.0,
            recall=0.0,
            f1=0.0,
            pipeline_chunks=0,
            relevant_in_pipeline=0,
        )
    n = len(accuracies)
    return ChunkAccuracy(
        precision=round(sum(a.precision for a in accuracies) / n, 4),
        recall=round(sum(a.recall for a in accuracies) / n, 4),
        f1=round(sum(a.f1 for a in accuracies) / n, 4),
        pipeline_chunks=sum(a.pipeline_chunks for a in accuracies),
        relevant_in_pipeline=sum(a.relevant_in_pipeline for a in accuracies),
    )


def _find_source_for_chunk(
    sub_question_sources: Any,
    doc_id: Any,
    chunk_idx: Any,
) -> Optional[Tuple[Any, Dict[str, Any]]]:
    """Return ``(content_summary, source_dict)`` for the first source citing the chunk.

    Entries of ``sub_question_sources`` may be plain dicts (with a ``"sources"``
    list of dicts) or objects with a ``sources`` attribute; both shapes are
    searched. Returns ``None`` when no source matches. Stopping at the first
    match guarantees one entry per ground-truth chunk even when several
    sub-question sections cite the same chunk.
    """
    for sq in sub_question_sources:
        if isinstance(sq, dict):
            for s in sq.get("sources", []):
                if s.get("document_id") == doc_id and s.get("chunk_index") == chunk_idx:
                    return s.get("content_summary", ""), s
        elif hasattr(sq, "sources"):
            for s in sq.sources:
                if hasattr(s, "document_id") and s.document_id == doc_id and s.chunk_index == chunk_idx:
                    dumped = s.model_dump() if hasattr(s, "model_dump") else {}
                    return s.content_summary, dumped
    return None


def _section_text_for(sub_question_sources: Any, sq_idx: int) -> str:
    """Return ``str()`` of the last sources entry whose ``sub_question_index`` is ``sq_idx``.

    Returns ``""`` when no entry matches.

    NOTE(review): this stringifies the whole sources entry (citations and
    metadata), not an answer section; confirm this is what the response
    evaluator should receive as ``pipeline_response``.
    """
    section_text = ""
    for src in sub_question_sources:
        if isinstance(src, dict):
            if src.get("sub_question_index") == sq_idx:
                section_text = str(src)
        elif hasattr(src, "sub_question_index") and src.sub_question_index == sq_idx:
            section_text = str(src)
    return section_text


def _evaluate_audio(result: Any) -> AudioEvalResult:
    """Score the transcribed text (``result.input.text``) against the reference transcript.

    The caller must ensure ``result.input.reference_transcript`` is set.
    The edit counts (substitutions/deletions/insertions/hits) come from the
    CER computation.
    """
    reference = result.input.reference_transcript
    transcribed = result.input.text
    cer_data = calculate_cer(reference, transcribed)
    wer_data = calculate_wer(reference, transcribed)
    return AudioEvalResult(
        status="completed",
        cer=cer_data["cer"],
        wer=wer_data["wer"],
        reference_length=cer_data["reference_length"],
        transcribed_length=cer_data["transcribed_length"],
        substitutions=cer_data["substitutions"],
        deletions=cer_data["deletions"],
        insertions=cer_data["insertions"],
        hits=cer_data["hits"],
    )


async def _evaluate_chunks(
    result: Any,
    settings: Settings,
    rag: RAGService,
    chunk_config: Any,
) -> ChunkEvalResult:
    """Score retrieved and filtered chunks against per-sub-question ground truth.

    Ground truth comes from ``_determine_ground_truth_chunks`` run over every
    chunk in the RAG store, once per extracted key question.
    """
    all_chunks = _collect_all_chunks(rag)
    retrieved_set, filtered_set = _extract_chunk_sets(result.retrieval, result.filtered)
    semaphore = asyncio.Semaphore(settings.eval_max_concurrent_batches)

    per_sub_q = []
    unfiltered_accuracies = []
    filtered_accuracies = []

    for sq_idx, sq_text in enumerate(result.extracted_key_questions):
        gt_set, _total_chunks, gt_time = await _determine_ground_truth_chunks(
            sub_question=sq_text,
            all_chunks=all_chunks,
            config=chunk_config,
            semaphore=semaphore,
            model_idx=sq_idx,
            batch_size=settings.eval_chunk_batch_size,
        )

        # Set iteration order for str keys changes between interpreter runs
        # (hash randomisation); sort so the stored report, and the first-N
        # slice later used for response evaluation, are reproducible.
        ordered_gt = sorted(gt_set, key=_chunk_sort_key)
        gt_info = GroundTruthInfo(
            relevant_documents=sorted({doc_id for doc_id, _ in gt_set}, key=str),
            relevant_chunks=[
                {"document_id": doc_id, "chunk_index": idx} for doc_id, idx in ordered_gt
            ],
            total_relevant_chunks=len(gt_set),
            chunk_evaluation_time_ms=gt_time,
        )

        # NOTE(review): both accuracies compare this sub-question's ground truth
        # with the chunks retrieved/filtered for *all* sub-questions combined
        # (the union from _extract_chunk_sets), which dilutes per-sub-question
        # precision. Confirm whether per-sub-question pipeline sets are intended.
        unf_acc = _calculate_accuracy(retrieved_set, gt_set)
        fil_acc = _calculate_accuracy(filtered_set, gt_set)
        unfiltered_accuracies.append(unf_acc)
        filtered_accuracies.append(fil_acc)

        per_sub_q.append(
            SubQuestionChunkEval(
                sub_question_index=sq_idx,
                sub_question_text=sq_text,
                ground_truth=gt_info,
                unfiltered_accuracy=unf_acc,
                filtered_accuracy=fil_acc,
            )
        )

    return ChunkEvalResult(
        per_sub_question=per_sub_q,
        overall_unfiltered=_mean_accuracy(unfiltered_accuracies),
        overall_filtered=_mean_accuracy(filtered_accuracies),
    )


async def _evaluate_responses(
    result: Any,
    chunk_eval_result: ChunkEvalResult,
    response_config: Any,
) -> ResponseEvalResult:
    """Judge each sub-question's answer against ground-truth chunks the pipeline cited.

    Per sub-question, up to ``_MAX_GT_CHUNKS_FOR_RESPONSE_EVAL`` ground-truth
    chunks are looked up among the pipeline's response sources. Sub-questions
    with no cited ground-truth chunk, or for which ``evaluate_response``
    returns a falsy value, are omitted. Overall scores are means over the
    evaluated sub-questions (0.0 when none).
    """
    sources = result.response.sub_question_sources
    per_sub_q_resp = []

    for sq_idx, sq_text in enumerate(result.extracted_key_questions):
        if sq_idx >= len(chunk_eval_result.per_sub_question):
            continue
        ground_truth = chunk_eval_result.per_sub_question[sq_idx].ground_truth

        # Chunk text is taken from the pipeline's cited sources, so ground-truth
        # chunks the pipeline never cited are not forwarded.
        relevant_chunks_meta = []
        for rc in ground_truth.relevant_chunks[:_MAX_GT_CHUNKS_FOR_RESPONSE_EVAL]:
            match = _find_source_for_chunk(sources, rc["document_id"], rc["chunk_index"])
            if match is not None:
                relevant_chunks_meta.append(match)
        if not relevant_chunks_meta:
            continue

        section_text = _section_text_for(sources, sq_idx)
        resp_eval = await evaluate_response(
            key_question=sq_text,
            ground_truth_chunks=relevant_chunks_meta,
            pipeline_response=section_text or result.response.final_answer,
            evaluator_config=response_config,
        )
        if resp_eval:
            resp_eval.sub_question_index = sq_idx
            resp_eval.sub_question_text = sq_text
            per_sub_q_resp.append(resp_eval)

    count = len(per_sub_q_resp)
    overall_completeness = (
        round(sum(r.completeness_score for r in per_sub_q_resp) / count, 4) if count else 0.0
    )
    overall_factual = (
        round(sum(r.factual_accuracy_score for r in per_sub_q_resp) / count, 4) if count else 0.0
    )
    return ResponseEvalResult(
        per_sub_question=per_sub_q_resp,
        overall_completeness=overall_completeness,
        overall_factual_accuracy=overall_factual,
    )


async def run_evaluation(
    request: EvaluateRequest,
    settings: Settings,
    storage: TestStorageService,
    rag: Optional[RAGService] = None,
) -> EvaluationResult:
    """Run the configured evaluation stages on a test result and persist the outcome.

    Stages (each optional, each timed separately in milliseconds):
      (i) audio: CER/WER when the input was audio with a reference transcript;
      (ii) key questions: when ``key_questions_evaluators`` is configured;
      (iii) chunks: when ``chunk_evaluator`` is configured and ``rag`` is given;
      (iv) response: when ``response_evaluator`` is configured and (iii) ran.

    Args:
        request: Identifies the result (``result_id`` or inline ``results``)
            and carries ``evaluation_config``.
        settings: Supplies ``eval_max_concurrent_batches`` and
            ``eval_chunk_batch_size`` for chunk evaluation.
        storage: Loads the result by id and saves the finished evaluation.
        rag: Source of all indexed chunks; without it, stages (iii) and (iv)
            are skipped with a warning.

    Returns:
        The ``EvaluationResult``, after it has been passed to
        ``storage.save_evaluation``.

    Raises:
        ValueError: If ``result_id`` is unknown, or neither ``result_id`` nor
            inline results are provided.
    """
    evaluation_id = uuid.uuid4().hex[:12]
    overall_start = time.perf_counter()

    # Load result: a stored one by id takes precedence over inline results.
    if request.result_id:
        result = storage.load_result(request.result_id)
        if result is None:
            raise ValueError(f"Result not found: {request.result_id}")
    elif request.results:
        result = request.results
    else:
        raise ValueError("No result_id or inline results provided")

    cfg = request.evaluation_config

    audio_eval_result = None
    kq_eval_result = None
    chunk_eval_result = None
    resp_eval_result = None
    audio_time = kq_time = chunk_time = resp_time = 0

    # (i) Audio evaluation
    if result.input_type == "audio" and result.input.reference_transcript:
        t0 = time.perf_counter()
        audio_eval_result = _evaluate_audio(result)
        audio_time = _elapsed_ms(t0)

    # (ii) Key questions evaluation
    if cfg.key_questions_evaluators:
        t0 = time.perf_counter()
        kq_eval_result = await evaluate_key_questions(
            original_text=result.input.text,
            extracted_questions=result.extracted_key_questions,
            evaluator_configs=cfg.key_questions_evaluators,
        )
        kq_time = _elapsed_ms(t0)

    # (iii) Chunk evaluation: needs the RAG store to enumerate every chunk.
    if cfg.chunk_evaluator:
        if rag is None:
            logger.warning(
                "Evaluation %s: chunk evaluator configured but no RAG service "
                "provided; skipping chunk evaluation",
                evaluation_id,
            )
        else:
            t0 = time.perf_counter()
            chunk_eval_result = await _evaluate_chunks(result, settings, rag, cfg.chunk_evaluator)
            chunk_time = _elapsed_ms(t0)

    # (iv) Response evaluation: reuses the chunk ground truth, so it only runs
    # when chunk evaluation produced a result.
    if cfg.response_evaluator:
        if chunk_eval_result is None:
            logger.warning(
                "Evaluation %s: response evaluator configured but chunk evaluation "
                "did not run; skipping response evaluation",
                evaluation_id,
            )
        else:
            t0 = time.perf_counter()
            resp_eval_result = await _evaluate_responses(
                result, chunk_eval_result, cfg.response_evaluator
            )
            resp_time = _elapsed_ms(t0)

    eval_result = EvaluationResult(
        evaluation_id=evaluation_id,
        result_id=result.result_id,
        status="completed",
        audio_evaluation=audio_eval_result,
        key_questions_evaluation=kq_eval_result,
        chunk_evaluation=chunk_eval_result,
        response_evaluation=resp_eval_result,
        timing=EvaluationTiming(
            audio_evaluation_time_ms=audio_time,
            key_questions_evaluation_time_ms=kq_time,
            chunk_evaluation_time_ms=chunk_time,
            response_evaluation_time_ms=resp_time,
            total_evaluation_time_ms=_elapsed_ms(overall_start),
        ),
    )
    storage.save_evaluation(eval_result)
    return eval_result