feat: add Sub-Phase 9.3 evaluation API endpoint and 9.4 polish

This commit is contained in:
Woody 2026-05-25 19:30:17 +08:00
parent 098be359e7
commit 032dd75e17
4 changed files with 523 additions and 1 deletions

View File

@ -7,7 +7,7 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from app.routers import ingest, query, documents, prompts, history, chunks, video, ws_asr, test_generate
from app.routers import ingest, query, documents, prompts, history, chunks, video, ws_asr, test_generate, test_evaluate
from app.core.config import get_settings
from app.core.sqlite_db import (
get_prompts_db,
@ -59,6 +59,7 @@ app.include_router(chunks.router)
app.include_router(video.router, prefix="/api/v1")
app.include_router(ws_asr.router)
app.include_router(test_generate.router, prefix="/api/v1")
app.include_router(test_evaluate.router, prefix="/api/v1")
_prompts_conn = get_prompts_db()
init_prompts_db(_prompts_conn)

View File

@ -0,0 +1,73 @@
import logging
from fastapi import APIRouter, HTTPException, Query
from app.core.config import get_settings
from app.core.dependencies import get_rag_service
from app.models.testing import EvaluateRequest
from app.services.prompt_service import PromptService
from app.services.test_evaluation_service import run_evaluation
from app.services.test_storage_service import TestStorageService
logger = logging.getLogger(__name__)
router = APIRouter(tags=["test"])
def _get_storage_service() -> TestStorageService:
settings = get_settings()
return TestStorageService(
results_dir=settings.test_results_dir,
evaluations_dir=settings.test_evaluations_dir,
)
@router.post("/test/evaluate")
async def evaluate(request: EvaluateRequest):
settings = get_settings()
storage = _get_storage_service()
prompt_service = PromptService(db_path=settings.prompts_db_path)
prompt_service.activate_profile("A")
try:
rag = get_rag_service()
except Exception:
rag = None
try:
result = await run_evaluation(
request=request,
settings=settings,
storage=storage,
rag=rag,
)
return result.model_dump()
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
logger.error("Evaluation failed: %s", str(e), exc_info=True)
raise HTTPException(status_code=500, detail=f"Evaluation failed: {str(e)}")
@router.get("/test/evaluations")
async def list_evaluations(limit: int = Query(50, ge=1, le=200), offset: int = Query(0, ge=0)):
storage = _get_storage_service()
return storage.list_evaluations(limit=limit, offset=offset)
@router.get("/test/evaluations/{eval_id}")
async def get_evaluation(eval_id: str):
storage = _get_storage_service()
result = storage.load_evaluation(eval_id)
if result is None:
raise HTTPException(status_code=404, detail="Evaluation not found")
return result.model_dump()
@router.delete("/test/evaluations/{eval_id}")
async def delete_evaluation(eval_id: str):
storage = _get_storage_service()
deleted = storage.delete_evaluation(eval_id)
if not deleted:
raise HTTPException(status_code=404, detail="Evaluation not found")
return {"status": "deleted", "evaluation_id": eval_id}

View File

@ -0,0 +1,313 @@
import logging
import time
import uuid
from typing import Any, Dict, List, Optional, Set, Tuple
from app.core.config import Settings
from app.models.testing import (
AudioEvalResult,
ChunkAccuracy,
ChunkEvalResult,
EvaluateRequest,
EvaluationResult,
EvaluationTiming,
FilteredResult,
GenerateResult,
GroundTruthInfo,
KeyQuestionsEvalResult,
ResponseEvalResult,
RetrievalResult,
SubQuestionChunkEval,
SubQuestionResponseEval,
)
from app.services.cer_wer import calculate_cer, calculate_wer
from app.services.chunk_evaluator import (
_calculate_accuracy,
_determine_ground_truth_chunks,
)
from app.services.key_questions_evaluator import evaluate_key_questions
from app.services.response_evaluator import evaluate_response
from app.services.rag import RAGService
from app.services.test_storage_service import TestStorageService
logger = logging.getLogger(__name__)
def _extract_chunk_sets(
retrieval: RetrievalResult,
filtered: FilteredResult,
) -> Tuple[Set[Tuple[str, int]], Set[Tuple[str, int]]]:
"""Extract (document_id, chunk_index) sets from retrieval and filtered results."""
retrieved = set()
filtered_set = set()
for sq in retrieval.per_sub_question:
for chunk in sq.chunks:
doc_id = chunk.metadata.get("document_id", "unknown")
chunk_idx = chunk.metadata.get("chunk_index", chunk.chunk_index)
retrieved.add((str(doc_id), int(chunk_idx)))
for sq in filtered.per_sub_question:
for chunk in sq.chunks:
doc_id = chunk.metadata.get("document_id", "unknown")
chunk_idx = chunk.metadata.get("chunk_index", chunk.chunk_index)
filtered_set.add((str(doc_id), int(chunk_idx)))
return retrieved, filtered_set
def _collect_all_chunks(
rag: RAGService,
) -> List[Tuple[str, int, str, Dict[str, Any]]]:
"""Fetch all chunks from all documents in ChromaDB.
Returns list of (document_id, chunk_index, text, metadata) tuples.
"""
docs, _, _ = rag.list_documents()
all_chunks = []
for doc in docs:
doc_id = doc["document_id"]
chunks = rag.list_chunks(doc_id)
for chunk in chunks:
chunk_idx = chunk.get("chunk_index", 0)
text = chunk.get("text", "")
all_chunks.append((doc_id, chunk_idx, text, chunk))
return all_chunks
async def run_evaluation(
request: EvaluateRequest,
settings: Settings,
storage: TestStorageService,
rag: Optional[RAGService] = None,
) -> EvaluationResult:
evaluation_id = uuid.uuid4().hex[:12]
overall_start = time.perf_counter()
# Load result
if request.result_id:
result = storage.load_result(request.result_id)
if result is None:
raise ValueError(f"Result not found: {request.result_id}")
elif request.results:
result = request.results
else:
raise ValueError("No result_id or inline results provided")
cfg = request.evaluation_config
total_ms = 0
audio_eval_result = None
kq_eval_result = None
chunk_eval_result = None
resp_eval_result = None
audio_time = 0
kq_time = 0
chunk_time = 0
resp_time = 0
# (i) Audio evaluation
if result.input_type == "audio" and result.input.reference_transcript:
t0 = time.perf_counter()
cer_data = calculate_cer(result.input.reference_transcript, result.input.text)
wer_data = calculate_wer(result.input.reference_transcript, result.input.text)
audio_eval_result = AudioEvalResult(
status="completed",
cer=cer_data["cer"],
wer=wer_data["wer"],
reference_length=cer_data["reference_length"],
transcribed_length=cer_data["transcribed_length"],
substitutions=cer_data["substitutions"],
deletions=cer_data["deletions"],
insertions=cer_data["insertions"],
hits=cer_data["hits"],
)
audio_time = int((time.perf_counter() - t0) * 1000)
# (ii) Key questions evaluation
if cfg.key_questions_evaluators:
t0 = time.perf_counter()
kq_eval_result = await evaluate_key_questions(
original_text=result.input.text,
extracted_questions=result.extracted_key_questions,
evaluator_configs=cfg.key_questions_evaluators,
)
kq_time = int((time.perf_counter() - t0) * 1000)
# (iii) Chunk evaluation
if rag and cfg.chunk_evaluator:
t0 = time.perf_counter()
all_chunks = _collect_all_chunks(rag)
retrieved_set, filtered_set = _extract_chunk_sets(result.retrieval, result.filtered)
per_sub_q = []
overall_unfiltered_metrics = {"precision": 0.0, "recall": 0.0, "f1": 0.0}
overall_filtered_metrics = {"precision": 0.0, "recall": 0.0, "f1": 0.0}
unfiltered_accuracies = []
filtered_accuracies = []
import asyncio
semaphore = asyncio.Semaphore(settings.eval_max_concurrent_batches)
for sq_idx, sq_text in enumerate(result.extracted_key_questions):
gt_set, total_chunks, gt_time = await _determine_ground_truth_chunks(
sub_question=sq_text,
all_chunks=all_chunks,
config=cfg.chunk_evaluator,
semaphore=semaphore,
model_idx=sq_idx,
batch_size=settings.eval_chunk_batch_size,
)
relevant_docs = list(set(doc_id for doc_id, _ in gt_set))
relevant_chunk_dicts = [
{"document_id": doc_id, "chunk_index": idx}
for doc_id, idx in gt_set
]
gt_info = GroundTruthInfo(
relevant_documents=relevant_docs,
relevant_chunks=relevant_chunk_dicts,
total_relevant_chunks=len(gt_set),
chunk_evaluation_time_ms=gt_time,
)
sq_retrieved = {
(doc_id, idx)
for doc_id, idx in retrieved_set
}
sq_filtered = {
(doc_id, idx)
for doc_id, idx in filtered_set
}
unf_acc = _calculate_accuracy(sq_retrieved, gt_set)
fil_acc = _calculate_accuracy(sq_filtered, gt_set)
unfiltered_accuracies.append(unf_acc)
filtered_accuracies.append(fil_acc)
per_sub_q.append(
SubQuestionChunkEval(
sub_question_index=sq_idx,
sub_question_text=sq_text,
ground_truth=gt_info,
unfiltered_accuracy=unf_acc,
filtered_accuracy=fil_acc,
)
)
if unfiltered_accuracies:
n = len(unfiltered_accuracies)
overall_unfiltered_metrics = {
"precision": round(sum(a.precision for a in unfiltered_accuracies) / n, 4),
"recall": round(sum(a.recall for a in unfiltered_accuracies) / n, 4),
"f1": round(sum(a.f1 for a in unfiltered_accuracies) / n, 4),
}
if filtered_accuracies:
n = len(filtered_accuracies)
overall_filtered_metrics = {
"precision": round(sum(a.precision for a in filtered_accuracies) / n, 4),
"recall": round(sum(a.recall for a in filtered_accuracies) / n, 4),
"f1": round(sum(a.f1 for a in filtered_accuracies) / n, 4),
}
chunk_eval_result = ChunkEvalResult(
per_sub_question=per_sub_q,
overall_unfiltered=ChunkAccuracy(
precision=overall_unfiltered_metrics["precision"],
recall=overall_unfiltered_metrics["recall"],
f1=overall_unfiltered_metrics["f1"],
pipeline_chunks=sum(a.pipeline_chunks for a in unfiltered_accuracies) if unfiltered_accuracies else 0,
relevant_in_pipeline=sum(a.relevant_in_pipeline for a in unfiltered_accuracies) if unfiltered_accuracies else 0,
),
overall_filtered=ChunkAccuracy(
precision=overall_filtered_metrics["precision"],
recall=overall_filtered_metrics["recall"],
f1=overall_filtered_metrics["f1"],
pipeline_chunks=sum(a.pipeline_chunks for a in filtered_accuracies) if filtered_accuracies else 0,
relevant_in_pipeline=sum(a.relevant_in_pipeline for a in filtered_accuracies) if filtered_accuracies else 0,
),
)
chunk_time = int((time.perf_counter() - t0) * 1000)
# (iv) Response evaluation
if rag and cfg.response_evaluator and chunk_eval_result:
t0 = time.perf_counter()
per_sub_q_resp = []
for sq_idx, sq_text in enumerate(result.extracted_key_questions):
if sq_idx < len(chunk_eval_result.per_sub_question):
gt_chunks_data = chunk_eval_result.per_sub_question[sq_idx].ground_truth
relevant_chunks_meta = []
for rc in gt_chunks_data.relevant_chunks[:10]:
doc_id = rc["document_id"]
chunk_idx = rc["chunk_index"]
# Try to match with pipeline's response sources
for sq in result.response.sub_question_sources:
if isinstance(sq, dict):
for s in sq.get("sources", []):
if s.get("document_id") == doc_id and s.get("chunk_index") == chunk_idx:
relevant_chunks_meta.append((s.get("content_summary", ""), s))
break
elif hasattr(sq, "sources"):
for s in sq.sources:
if hasattr(s, "document_id") and s.document_id == doc_id and s.chunk_index == chunk_idx:
relevant_chunks_meta.append((s.content_summary, s.model_dump() if hasattr(s, "model_dump") else {}))
break
if relevant_chunks_meta:
section_text = ""
for src in result.response.sub_question_sources:
if isinstance(src, dict):
if src.get("sub_question_index") == sq_idx:
section_text = str(src)
elif hasattr(src, "sub_question_index") and src.sub_question_index == sq_idx:
section_text = str(src)
resp_eval = await evaluate_response(
key_question=sq_text,
ground_truth_chunks=relevant_chunks_meta,
pipeline_response=section_text or result.response.final_answer,
evaluator_config=cfg.response_evaluator,
)
if resp_eval:
resp_eval.sub_question_index = sq_idx
resp_eval.sub_question_text = sq_text
per_sub_q_resp.append(resp_eval)
overall_completeness = 0.0
overall_factual = 0.0
if per_sub_q_resp:
overall_completeness = round(sum(r.completeness_score for r in per_sub_q_resp) / len(per_sub_q_resp), 4)
overall_factual = round(sum(r.factual_accuracy_score for r in per_sub_q_resp) / len(per_sub_q_resp), 4)
resp_eval_result = ResponseEvalResult(
per_sub_question=per_sub_q_resp,
overall_completeness=overall_completeness,
overall_factual_accuracy=overall_factual,
)
resp_time = int((time.perf_counter() - t0) * 1000)
total_ms = int((time.perf_counter() - overall_start) * 1000)
eval_result = EvaluationResult(
evaluation_id=evaluation_id,
result_id=result.result_id,
status="completed",
audio_evaluation=audio_eval_result,
key_questions_evaluation=kq_eval_result,
chunk_evaluation=chunk_eval_result,
response_evaluation=resp_eval_result,
timing=EvaluationTiming(
audio_evaluation_time_ms=audio_time,
key_questions_evaluation_time_ms=kq_time,
chunk_evaluation_time_ms=chunk_time,
response_evaluation_time_ms=resp_time,
total_evaluation_time_ms=total_ms,
),
)
storage.save_evaluation(eval_result)
return eval_result

View File

@ -0,0 +1,135 @@
"""Phase 9 tests: Evaluation API endpoint integration (Sub-Phase 9.3)."""
import json
from unittest.mock import AsyncMock, patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from app.models.testing import (
ChunkAccuracy,
DimensionScores,
EvaluatorConfig,
EvaluationResult,
FilteredResult,
GenerateResult,
InputInfo,
KeyQuestionsEvalEntry,
KeyQuestionsEvalResult,
ResponseResult,
RetrievalResult,
TimingInfo,
)
@pytest.fixture(autouse=True)
def _set_api_keys(monkeypatch):
monkeypatch.setenv("LLM_API_KEY", "test-key")
monkeypatch.setenv("DP_API_KEY", "test-dp-key")
monkeypatch.setenv("DASHSCOPE_API_KEY", "test-dashscope-key")
@pytest.fixture
def client(tmp_path, monkeypatch):
results_dir = str(tmp_path / "test_results")
evals_dir = str(tmp_path / "test_evaluations")
prompts_path = str(tmp_path / "prompts.db")
history_path = str(tmp_path / "history.db")
monkeypatch.setenv("TEST_RESULTS_DIR", results_dir)
monkeypatch.setenv("TEST_EVALUATIONS_DIR", evals_dir)
monkeypatch.setenv("PROMPTS_DB_PATH", prompts_path)
monkeypatch.setenv("HISTORY_DB_PATH", history_path)
monkeypatch.setenv("LLM_API_KEY", "test-key")
monkeypatch.setenv("LLM_BASE_URL", "https://test.example.com/v1")
monkeypatch.setenv("LLM_MODEL_NAME", "test-model")
monkeypatch.setenv("EMBEDDING_MODEL", "test-embedding")
from app.core.config import get_settings
get_settings.cache_clear()
from app.core.sqlite_db import _get_db, init_prompts_db, init_history_db, seed_default_profiles
conn = _get_db(prompts_path)
init_prompts_db(conn)
seed_default_profiles(conn)
conn.close()
hconn = _get_db(history_path)
init_history_db(hconn)
hconn.close()
from app.routers.test_evaluate import router
test_app = FastAPI()
test_app.include_router(router, prefix="/api/v1")
yield TestClient(test_app)
get_settings.cache_clear()
def _make_sample_result():
return GenerateResult(
result_id="test-result-001",
input_type="text",
profile="A",
input=InputInfo(text="test question"),
extracted_key_questions=["q1", "q2"],
retrieval=RetrievalResult(per_sub_question=[], total_chunks_retrieved=10, retriever_time_ms=100),
filtered=FilteredResult(per_sub_question=[], total_chunks_filtered=5, filter_time_ms=100),
response=ResponseResult(final_answer="answer", sub_question_sources=[], generate_time_ms=100),
timing=TimingInfo(decomposer_time_ms=100, retriever_time_ms=100, filter_time_ms=100, generator_time_ms=100, total_time_ms=400),
)
@pytest.fixture
def saved_result(client):
from app.services.test_storage_service import TestStorageService
from app.core.config import get_settings
result = _make_sample_result()
svc = TestStorageService(get_settings().test_results_dir, get_settings().test_evaluations_dir)
svc.save_result(result)
return result.result_id
class TestEvaluateEndpoint:
@pytest.mark.asyncio
async def test_valid_evaluate_returns_200(self, client, saved_result):
mock_scores = DimensionScores(dimension_1_準確性=35.0, dimension_2_完整性=22.0, dimension_3_清晰度=18.0, dimension_4_簡潔性=13.0)
mock_kq = KeyQuestionsEvalResult(
evaluations=[
KeyQuestionsEvalEntry(model_name="m1", scores=mock_scores, total_score=88, max_score=100, comments="ok", thinking_trace="", time_ms=100),
KeyQuestionsEvalEntry(model_name="m2", scores=mock_scores, total_score=88, max_score=100, comments="ok", thinking_trace="", time_ms=100),
],
average_scores=mock_scores,
average_total=88.0,
)
payload = {
"result_id": saved_result,
"evaluation_config": {
"key_questions_evaluators": [
{"model_name": "deepseek-v4-pro", "base_url": "https://api.deepseek.com", "api_key_env": "DP_API_KEY", "enable_thinking": True},
{"model_name": "qwen3-7b-max", "base_url": "https://dashscope.example.com/v1", "api_key_env": "DASHSCOPE_API_KEY", "enable_thinking": True},
],
"chunk_evaluator": {"model_name": "test", "base_url": "https://test.example.com", "api_key_env": "LLM_API_KEY", "enable_thinking": True},
"response_evaluator": {"model_name": "test", "base_url": "https://test.example.com", "api_key_env": "LLM_API_KEY", "enable_thinking": True},
},
}
resp = client.post("/api/v1/test/evaluate", json=payload)
assert resp.status_code == 200
data = resp.json()
assert data["status"] in ("completed", "partial")
assert "evaluation_id" in data
def test_missing_result_returns_404(self, client):
payload = {
"result_id": "no-such-id",
"evaluation_config": {
"key_questions_evaluators": [],
"chunk_evaluator": {"model_name": "t", "base_url": "https://x.com", "api_key_env": "LLM_API_KEY"},
"response_evaluator": {"model_name": "t", "base_url": "https://x.com", "api_key_env": "LLM_API_KEY"},
},
}
resp = client.post("/api/v1/test/evaluate", json=payload)
assert resp.status_code == 404