314 lines
12 KiB
Python
314 lines
12 KiB
Python
import logging
|
|
import time
|
|
import uuid
|
|
from typing import Any, Dict, List, Optional, Set, Tuple
|
|
|
|
from app.core.config import Settings
|
|
from app.models.testing import (
|
|
AudioEvalResult,
|
|
ChunkAccuracy,
|
|
ChunkEvalResult,
|
|
EvaluateRequest,
|
|
EvaluationResult,
|
|
EvaluationTiming,
|
|
FilteredResult,
|
|
GenerateResult,
|
|
GroundTruthInfo,
|
|
KeyQuestionsEvalResult,
|
|
ResponseEvalResult,
|
|
RetrievalResult,
|
|
SubQuestionChunkEval,
|
|
SubQuestionResponseEval,
|
|
)
|
|
from app.services.cer_wer import calculate_cer, calculate_wer
|
|
from app.services.chunk_evaluator import (
|
|
_calculate_accuracy,
|
|
_determine_ground_truth_chunks,
|
|
)
|
|
from app.services.key_questions_evaluator import evaluate_key_questions
|
|
from app.services.response_evaluator import evaluate_response
|
|
from app.services.rag import RAGService
|
|
from app.services.test_storage_service import TestStorageService
|
|
|
|
# Module-scoped logger; records are emitted under this module's dotted name.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _extract_chunk_sets(
|
|
retrieval: RetrievalResult,
|
|
filtered: FilteredResult,
|
|
) -> Tuple[Set[Tuple[str, int]], Set[Tuple[str, int]]]:
|
|
"""Extract (document_id, chunk_index) sets from retrieval and filtered results."""
|
|
retrieved = set()
|
|
filtered_set = set()
|
|
|
|
for sq in retrieval.per_sub_question:
|
|
for chunk in sq.chunks:
|
|
doc_id = chunk.metadata.get("document_id", "unknown")
|
|
chunk_idx = chunk.metadata.get("chunk_index", chunk.chunk_index)
|
|
retrieved.add((str(doc_id), int(chunk_idx)))
|
|
|
|
for sq in filtered.per_sub_question:
|
|
for chunk in sq.chunks:
|
|
doc_id = chunk.metadata.get("document_id", "unknown")
|
|
chunk_idx = chunk.metadata.get("chunk_index", chunk.chunk_index)
|
|
filtered_set.add((str(doc_id), int(chunk_idx)))
|
|
|
|
return retrieved, filtered_set
|
|
|
|
|
|
def _collect_all_chunks(
|
|
rag: RAGService,
|
|
) -> List[Tuple[str, int, str, Dict[str, Any]]]:
|
|
"""Fetch all chunks from all documents in ChromaDB.
|
|
|
|
Returns list of (document_id, chunk_index, text, metadata) tuples.
|
|
"""
|
|
docs, _, _ = rag.list_documents()
|
|
all_chunks = []
|
|
|
|
for doc in docs:
|
|
doc_id = doc["document_id"]
|
|
chunks = rag.list_chunks(doc_id)
|
|
for chunk in chunks:
|
|
chunk_idx = chunk.get("chunk_index", 0)
|
|
text = chunk.get("text", "")
|
|
all_chunks.append((doc_id, chunk_idx, text, chunk))
|
|
|
|
return all_chunks
|
|
|
|
|
|
# Maximum number of ground-truth chunks per sub-question that are matched
# against the pipeline's cited sources during response evaluation.
_MAX_GROUND_TRUTH_CHUNKS_FOR_RESPONSE_EVAL = 10


def _elapsed_ms(start: float) -> int:
    """Return whole milliseconds elapsed since the ``time.perf_counter()`` value *start*."""
    return int((time.perf_counter() - start) * 1000)


def _load_result(request: EvaluateRequest, storage: TestStorageService) -> Any:
    """Resolve the pipeline result to evaluate.

    A stored result referenced by ``request.result_id`` takes precedence over
    inline ``request.results``.

    Raises:
        ValueError: if ``result_id`` is given but not found in *storage*, or if
            the request carries neither a ``result_id`` nor inline results.
    """
    if request.result_id:
        result = storage.load_result(request.result_id)
        if result is None:
            raise ValueError(f"Result not found: {request.result_id}")
        return result
    if request.results:
        return request.results
    raise ValueError("No result_id or inline results provided")


def _evaluate_audio(result: Any) -> AudioEvalResult:
    """Score ``result.input.text`` against ``result.input.reference_transcript`` with CER and WER."""
    reference = result.input.reference_transcript
    hypothesis = result.input.text
    cer_data = calculate_cer(reference, hypothesis)
    wer_data = calculate_wer(reference, hypothesis)
    # Lengths and edit counts are taken from the CER computation; only the WER
    # score itself comes from calculate_wer.
    return AudioEvalResult(
        status="completed",
        cer=cer_data["cer"],
        wer=wer_data["wer"],
        reference_length=cer_data["reference_length"],
        transcribed_length=cer_data["transcribed_length"],
        substitutions=cer_data["substitutions"],
        deletions=cer_data["deletions"],
        insertions=cer_data["insertions"],
        hits=cer_data["hits"],
    )


def _aggregate_accuracy(accuracies: List[Any]) -> ChunkAccuracy:
    """Combine per-sub-question accuracies into one ``ChunkAccuracy``.

    Precision, recall and F1 are unweighted means over sub-questions, rounded to
    4 decimals; ``pipeline_chunks`` and ``relevant_in_pipeline`` are summed. An
    empty list yields all-zero metrics.
    """
    if not accuracies:
        return ChunkAccuracy(
            precision=0.0,
            recall=0.0,
            f1=0.0,
            pipeline_chunks=0,
            relevant_in_pipeline=0,
        )
    n = len(accuracies)
    return ChunkAccuracy(
        precision=round(sum(a.precision for a in accuracies) / n, 4),
        recall=round(sum(a.recall for a in accuracies) / n, 4),
        f1=round(sum(a.f1 for a in accuracies) / n, 4),
        pipeline_chunks=sum(a.pipeline_chunks for a in accuracies),
        relevant_in_pipeline=sum(a.relevant_in_pipeline for a in accuracies),
    )


async def _evaluate_chunks(
    result: Any,
    cfg: Any,
    settings: Settings,
    rag: RAGService,
) -> ChunkEvalResult:
    """Determine ground-truth chunks per sub-question and score the pipeline's chunks against them.

    Ground truth for each entry of ``result.extracted_key_questions`` comes from
    ``_determine_ground_truth_chunks`` over every chunk in *rag*. Concurrency is
    bounded by ``settings.eval_max_concurrent_batches``.
    """
    import asyncio

    all_chunks = _collect_all_chunks(rag)
    retrieved_set, filtered_set = _extract_chunk_sets(result.retrieval, result.filtered)
    # One semaphore shared by all sub-questions caps concurrent judge batches.
    semaphore = asyncio.Semaphore(settings.eval_max_concurrent_batches)

    per_sub_q = []
    unfiltered_accuracies = []
    filtered_accuracies = []

    for sq_idx, sq_text in enumerate(result.extracted_key_questions):
        gt_set, _total_chunks, gt_time = await _determine_ground_truth_chunks(
            sub_question=sq_text,
            all_chunks=all_chunks,
            config=cfg.chunk_evaluator,
            semaphore=semaphore,
            model_idx=sq_idx,
            batch_size=settings.eval_chunk_batch_size,
        )

        gt_info = GroundTruthInfo(
            relevant_documents=list({doc_id for doc_id, _ in gt_set}),
            relevant_chunks=[
                {"document_id": doc_id, "chunk_index": idx} for doc_id, idx in gt_set
            ],
            total_relevant_chunks=len(gt_set),
            chunk_evaluation_time_ms=gt_time,
        )

        # NOTE(review): retrieved_set / filtered_set are unions over *all*
        # sub-questions, so every sub-question is scored against the same
        # pipeline chunk set, and the summed pipeline_chunks repeats that set
        # once per sub-question. If per-sub-question scoring was intended, this
        # needs the chunks of the matching retrieval/filtered entry. Confirm how
        # per_sub_question aligns with extracted_key_questions.
        # Fresh copies are passed, as before, in case _calculate_accuracy
        # mutates its inputs.
        unf_acc = _calculate_accuracy(set(retrieved_set), gt_set)
        fil_acc = _calculate_accuracy(set(filtered_set), gt_set)
        unfiltered_accuracies.append(unf_acc)
        filtered_accuracies.append(fil_acc)

        per_sub_q.append(
            SubQuestionChunkEval(
                sub_question_index=sq_idx,
                sub_question_text=sq_text,
                ground_truth=gt_info,
                unfiltered_accuracy=unf_acc,
                filtered_accuracy=fil_acc,
            )
        )

    return ChunkEvalResult(
        per_sub_question=per_sub_q,
        overall_unfiltered=_aggregate_accuracy(unfiltered_accuracies),
        overall_filtered=_aggregate_accuracy(filtered_accuracies),
    )


def _find_matching_source(
    sub_question_sources: Any,
    doc_id: Any,
    chunk_idx: Any,
) -> Optional[Tuple[Any, Any]]:
    """Return ``(content_summary, source_metadata)`` for the first cited source matching a chunk.

    *sub_question_sources* entries may be dicts with a ``"sources"`` list of
    dicts, or objects with a ``sources`` attribute of objects. Returns ``None``
    when no cited source has the given ``document_id`` and ``chunk_index``.
    """
    for sq in sub_question_sources:
        if isinstance(sq, dict):
            for s in sq.get("sources", []):
                if s.get("document_id") == doc_id and s.get("chunk_index") == chunk_idx:
                    return (s.get("content_summary", ""), s)
        elif hasattr(sq, "sources"):
            for s in sq.sources:
                if hasattr(s, "document_id") and s.document_id == doc_id and s.chunk_index == chunk_idx:
                    return (s.content_summary, s.model_dump() if hasattr(s, "model_dump") else {})
    return None


def _find_section_text(sub_question_sources: Any, sq_idx: int) -> str:
    """Return ``str()`` of the last source entry whose ``sub_question_index`` equals *sq_idx*, or ``""``."""
    # NOTE(review): str(src) serializes the whole source record (dict/model),
    # not an answer section. Confirm this is the text the response judge
    # should see.
    section_text = ""
    for src in sub_question_sources:
        if isinstance(src, dict):
            if src.get("sub_question_index") == sq_idx:
                section_text = str(src)
        elif hasattr(src, "sub_question_index") and src.sub_question_index == sq_idx:
            section_text = str(src)
    return section_text


async def _evaluate_responses(
    result: Any,
    cfg: Any,
    chunk_eval_result: ChunkEvalResult,
) -> ResponseEvalResult:
    """Judge the pipeline response per sub-question against its ground-truth chunks."""
    per_sub_q_resp = []

    for sq_idx, sq_text in enumerate(result.extracted_key_questions):
        if sq_idx >= len(chunk_eval_result.per_sub_question):
            continue
        ground_truth = chunk_eval_result.per_sub_question[sq_idx].ground_truth

        # Each ground-truth chunk contributes at most one entry: the first
        # cited source with the same (document_id, chunk_index).
        # NOTE(review): ground-truth chunks the pipeline never cited are dropped,
        # and sub-questions with no cited ground truth are skipped entirely,
        # which may bias the overall averages. Confirm this is intended.
        relevant_chunks_meta = []
        for rc in ground_truth.relevant_chunks[:_MAX_GROUND_TRUTH_CHUNKS_FOR_RESPONSE_EVAL]:
            match = _find_matching_source(
                result.response.sub_question_sources,
                rc["document_id"],
                rc["chunk_index"],
            )
            if match is not None:
                relevant_chunks_meta.append(match)

        if not relevant_chunks_meta:
            continue

        section_text = _find_section_text(result.response.sub_question_sources, sq_idx)
        resp_eval = await evaluate_response(
            key_question=sq_text,
            ground_truth_chunks=relevant_chunks_meta,
            pipeline_response=section_text or result.response.final_answer,
            evaluator_config=cfg.response_evaluator,
        )
        if resp_eval:
            resp_eval.sub_question_index = sq_idx
            resp_eval.sub_question_text = sq_text
            per_sub_q_resp.append(resp_eval)

    n = len(per_sub_q_resp)
    overall_completeness = (
        round(sum(r.completeness_score for r in per_sub_q_resp) / n, 4) if n else 0.0
    )
    overall_factual = (
        round(sum(r.factual_accuracy_score for r in per_sub_q_resp) / n, 4) if n else 0.0
    )

    return ResponseEvalResult(
        per_sub_question=per_sub_q_resp,
        overall_completeness=overall_completeness,
        overall_factual_accuracy=overall_factual,
    )


async def run_evaluation(
    request: EvaluateRequest,
    settings: Settings,
    storage: TestStorageService,
    rag: Optional[RAGService] = None,
) -> EvaluationResult:
    """Evaluate a pipeline result, persist the evaluation, and return it.

    Stages, each run only when applicable:

    (i) audio: CER/WER when ``result.input_type == "audio"`` and a reference
        transcript is present;
    (ii) key questions: when ``evaluation_config.key_questions_evaluators`` is set;
    (iii) chunks: when a chunk evaluator is configured and *rag* is provided;
    (iv) response: when a response evaluator is configured and stage (iii) ran.

    Configured stages that cannot run are logged as warnings and skipped.

    Args:
        request: Carries ``result_id`` or inline ``results`` plus ``evaluation_config``.
        settings: Supplies chunk-evaluation batch size and concurrency limits.
        storage: Used to load a stored result and to save the evaluation.
        rag: Source of all stored chunks; required for stages (iii) and (iv).

    Returns:
        The ``EvaluationResult``, identified by a 12-hex-character id, after it
        has been passed to ``storage.save_evaluation``.

    Raises:
        ValueError: if the result cannot be resolved from the request.
    """
    evaluation_id = uuid.uuid4().hex[:12]
    overall_start = time.perf_counter()

    result = _load_result(request, storage)
    cfg = request.evaluation_config

    audio_eval_result = None
    kq_eval_result = None
    chunk_eval_result = None
    resp_eval_result = None
    audio_time = 0
    kq_time = 0
    chunk_time = 0
    resp_time = 0

    # (i) Audio evaluation
    if result.input_type == "audio" and result.input.reference_transcript:
        t0 = time.perf_counter()
        audio_eval_result = _evaluate_audio(result)
        audio_time = _elapsed_ms(t0)

    # (ii) Key questions evaluation
    if cfg.key_questions_evaluators:
        t0 = time.perf_counter()
        kq_eval_result = await evaluate_key_questions(
            original_text=result.input.text,
            extracted_questions=result.extracted_key_questions,
            evaluator_configs=cfg.key_questions_evaluators,
        )
        kq_time = _elapsed_ms(t0)

    # (iii) Chunk evaluation
    if cfg.chunk_evaluator:
        if rag:
            t0 = time.perf_counter()
            chunk_eval_result = await _evaluate_chunks(result, cfg, settings, rag)
            chunk_time = _elapsed_ms(t0)
        else:
            logger.warning(
                "Chunk evaluation skipped (evaluation_id=%s): no RAG service provided",
                evaluation_id,
            )

    # (iv) Response evaluation
    if cfg.response_evaluator:
        if rag and chunk_eval_result is not None:
            t0 = time.perf_counter()
            resp_eval_result = await _evaluate_responses(result, cfg, chunk_eval_result)
            resp_time = _elapsed_ms(t0)
        else:
            logger.warning(
                "Response evaluation skipped (evaluation_id=%s): requires a RAG "
                "service and a completed chunk evaluation",
                evaluation_id,
            )

    total_ms = _elapsed_ms(overall_start)

    eval_result = EvaluationResult(
        evaluation_id=evaluation_id,
        result_id=result.result_id,
        status="completed",
        audio_evaluation=audio_eval_result,
        key_questions_evaluation=kq_eval_result,
        chunk_evaluation=chunk_eval_result,
        response_evaluation=resp_eval_result,
        timing=EvaluationTiming(
            audio_evaluation_time_ms=audio_time,
            key_questions_evaluation_time_ms=kq_time,
            chunk_evaluation_time_ms=chunk_time,
            response_evaluation_time_ms=resp_time,
            total_evaluation_time_ms=total_ms,
        ),
    )

    storage.save_evaluation(eval_result)
    return eval_result
|