feat: add Sub-Phase 9.3 evaluation API endpoint and 9.4 polish

This commit is contained in:
Woody 2026-05-25 19:30:17 +08:00
parent 098be359e7
commit 032dd75e17
4 changed files with 523 additions and 1 deletions

View File

@ -7,7 +7,7 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from app.routers import ingest, query, documents, prompts, history, chunks, video, ws_asr, test_generate from app.routers import ingest, query, documents, prompts, history, chunks, video, ws_asr, test_generate, test_evaluate
from app.core.config import get_settings from app.core.config import get_settings
from app.core.sqlite_db import ( from app.core.sqlite_db import (
get_prompts_db, get_prompts_db,
@ -59,6 +59,7 @@ app.include_router(chunks.router)
app.include_router(video.router, prefix="/api/v1") app.include_router(video.router, prefix="/api/v1")
app.include_router(ws_asr.router) app.include_router(ws_asr.router)
app.include_router(test_generate.router, prefix="/api/v1") app.include_router(test_generate.router, prefix="/api/v1")
app.include_router(test_evaluate.router, prefix="/api/v1")
_prompts_conn = get_prompts_db() _prompts_conn = get_prompts_db()
init_prompts_db(_prompts_conn) init_prompts_db(_prompts_conn)

View File

@ -0,0 +1,73 @@
import logging
from fastapi import APIRouter, HTTPException, Query
from app.core.config import get_settings
from app.core.dependencies import get_rag_service
from app.models.testing import EvaluateRequest
from app.services.prompt_service import PromptService
from app.services.test_evaluation_service import run_evaluation
from app.services.test_storage_service import TestStorageService
logger = logging.getLogger(__name__)
router = APIRouter(tags=["test"])
def _get_storage_service() -> TestStorageService:
    """Build a ``TestStorageService`` from the current application settings.

    Returns:
        A storage service pointed at the configured ``test_results_dir`` and
        ``test_evaluations_dir``.
    """
    cfg = get_settings()
    storage_kwargs = {
        "results_dir": cfg.test_results_dir,
        "evaluations_dir": cfg.test_evaluations_dir,
    }
    return TestStorageService(**storage_kwargs)
@router.post("/test/evaluate")
async def evaluate(request: EvaluateRequest):
    """Run the evaluation pipeline for a stored or inline generation result.

    Activates prompt profile ``"A"``, tries to obtain the RAG service
    (continuing without it when that fails) and delegates to
    ``run_evaluation``.

    Returns:
        The evaluation result serialised with ``model_dump()``.

    Raises:
        HTTPException: 404 when ``run_evaluation`` raises ``ValueError``;
            500 for any other failure.
    """
    settings = get_settings()
    storage = _get_storage_service()

    prompt_service = PromptService(db_path=settings.prompts_db_path)
    prompt_service.activate_profile("A")

    try:
        rag = get_rag_service()
    except Exception:
        # Deliberately best-effort: the evaluation still runs without RAG,
        # but leave a trace so a missing chunk/response section can be
        # explained afterwards instead of failing silently.
        logger.warning(
            "RAG service unavailable; continuing evaluation without it",
            exc_info=True,
        )
        rag = None

    try:
        result = await run_evaluation(
            request=request,
            settings=settings,
            storage=storage,
            rag=rag,
        )
        return result.model_dump()
    except ValueError as e:
        # Chain the cause so the original traceback survives in logs.
        raise HTTPException(status_code=404, detail=str(e)) from e
    except Exception as e:
        # logger.exception == logger.error(..., exc_info=True)
        logger.exception("Evaluation failed: %s", e)
        raise HTTPException(status_code=500, detail=f"Evaluation failed: {str(e)}") from e
@router.get("/test/evaluations")
async def list_evaluations(
    limit: int = Query(50, ge=1, le=200),
    offset: int = Query(0, ge=0),
):
    """Return one page of stored evaluations, as produced by the storage service."""
    return _get_storage_service().list_evaluations(limit=limit, offset=offset)
@router.get("/test/evaluations/{eval_id}")
async def get_evaluation(eval_id: str):
    """Return a single stored evaluation, or respond 404 when it does not exist."""
    evaluation = _get_storage_service().load_evaluation(eval_id)
    if evaluation is not None:
        return evaluation.model_dump()
    raise HTTPException(status_code=404, detail="Evaluation not found")
@router.delete("/test/evaluations/{eval_id}")
async def delete_evaluation(eval_id: str):
    """Delete a stored evaluation, or respond 404 when nothing was deleted."""
    if _get_storage_service().delete_evaluation(eval_id):
        return {"status": "deleted", "evaluation_id": eval_id}
    raise HTTPException(status_code=404, detail="Evaluation not found")

View File

@ -0,0 +1,313 @@
import logging
import time
import uuid
from typing import Any, Dict, List, Optional, Set, Tuple
from app.core.config import Settings
from app.models.testing import (
AudioEvalResult,
ChunkAccuracy,
ChunkEvalResult,
EvaluateRequest,
EvaluationResult,
EvaluationTiming,
FilteredResult,
GenerateResult,
GroundTruthInfo,
KeyQuestionsEvalResult,
ResponseEvalResult,
RetrievalResult,
SubQuestionChunkEval,
SubQuestionResponseEval,
)
from app.services.cer_wer import calculate_cer, calculate_wer
from app.services.chunk_evaluator import (
_calculate_accuracy,
_determine_ground_truth_chunks,
)
from app.services.key_questions_evaluator import evaluate_key_questions
from app.services.response_evaluator import evaluate_response
from app.services.rag import RAGService
from app.services.test_storage_service import TestStorageService
logger = logging.getLogger(__name__)
def _extract_chunk_sets(
retrieval: RetrievalResult,
filtered: FilteredResult,
) -> Tuple[Set[Tuple[str, int]], Set[Tuple[str, int]]]:
"""Extract (document_id, chunk_index) sets from retrieval and filtered results."""
retrieved = set()
filtered_set = set()
for sq in retrieval.per_sub_question:
for chunk in sq.chunks:
doc_id = chunk.metadata.get("document_id", "unknown")
chunk_idx = chunk.metadata.get("chunk_index", chunk.chunk_index)
retrieved.add((str(doc_id), int(chunk_idx)))
for sq in filtered.per_sub_question:
for chunk in sq.chunks:
doc_id = chunk.metadata.get("document_id", "unknown")
chunk_idx = chunk.metadata.get("chunk_index", chunk.chunk_index)
filtered_set.add((str(doc_id), int(chunk_idx)))
return retrieved, filtered_set
def _collect_all_chunks(
rag: RAGService,
) -> List[Tuple[str, int, str, Dict[str, Any]]]:
"""Fetch all chunks from all documents in ChromaDB.
Returns list of (document_id, chunk_index, text, metadata) tuples.
"""
docs, _, _ = rag.list_documents()
all_chunks = []
for doc in docs:
doc_id = doc["document_id"]
chunks = rag.list_chunks(doc_id)
for chunk in chunks:
chunk_idx = chunk.get("chunk_index", 0)
text = chunk.get("text", "")
all_chunks.append((doc_id, chunk_idx, text, chunk))
return all_chunks
# Upper bound on how many ground-truth chunks per sub-question are matched
# against the pipeline's cited sources for the response evaluation.
_MAX_GT_CHUNKS_PER_QUESTION = 10


def _average_accuracy(accuracies: List[ChunkAccuracy]) -> ChunkAccuracy:
    """Fold per-sub-question accuracies into one overall ``ChunkAccuracy``.

    Precision, recall and F1 are averaged over the entries and rounded to
    four decimals; ``pipeline_chunks`` and ``relevant_in_pipeline`` are
    summed. An empty list yields an all-zero result.
    """
    count = len(accuracies)
    if not count:
        return ChunkAccuracy(
            precision=0.0,
            recall=0.0,
            f1=0.0,
            pipeline_chunks=0,
            relevant_in_pipeline=0,
        )
    return ChunkAccuracy(
        precision=round(sum(a.precision for a in accuracies) / count, 4),
        recall=round(sum(a.recall for a in accuracies) / count, 4),
        f1=round(sum(a.f1 for a in accuracies) / count, 4),
        pipeline_chunks=sum(a.pipeline_chunks for a in accuracies),
        relevant_in_pipeline=sum(a.relevant_in_pipeline for a in accuracies),
    )


def _match_pipeline_source(
    sources: Any, doc_id: Any, chunk_idx: Any
) -> Optional[Tuple[Any, Any]]:
    """Return ``(content_summary, source_meta)`` for the first cited source
    matching ``(doc_id, chunk_idx)``, or ``None`` when nothing matches.

    Entries of ``sources`` may be plain dicts or objects with a ``sources``
    attribute. Returning on the first hit guarantees a ground-truth chunk is
    reported at most once even when several sub-questions cite it.
    """
    for group in sources:
        if isinstance(group, dict):
            for src in group.get("sources", []):
                if src.get("document_id") == doc_id and src.get("chunk_index") == chunk_idx:
                    return src.get("content_summary", ""), src
        elif hasattr(group, "sources"):
            for src in group.sources:
                if (
                    hasattr(src, "document_id")
                    and src.document_id == doc_id
                    and src.chunk_index == chunk_idx
                ):
                    meta = src.model_dump() if hasattr(src, "model_dump") else {}
                    return src.content_summary, meta
    return None


def _section_text_for(sources: Any, sq_idx: int) -> str:
    """Return ``str()`` of the last source entry whose ``sub_question_index``
    equals ``sq_idx``, or ``""`` when there is none."""
    section_text = ""
    for src in sources:
        if isinstance(src, dict):
            if src.get("sub_question_index") == sq_idx:
                section_text = str(src)
        elif hasattr(src, "sub_question_index") and src.sub_question_index == sq_idx:
            section_text = str(src)
    return section_text


async def run_evaluation(
    request: EvaluateRequest,
    settings: Settings,
    storage: TestStorageService,
    rag: Optional[RAGService] = None,
) -> EvaluationResult:
    """Evaluate one generation result and persist the evaluation.

    Stages, each optional and individually timed in milliseconds:

    1. Audio: CER/WER of the transcript against the reference transcript
       (only for ``input_type == "audio"`` with a reference transcript).
    2. Key questions: runs when ``key_questions_evaluators`` is non-empty.
    3. Chunks: needs ``rag`` and a ``chunk_evaluator``; determines
       ground-truth chunks per sub-question and scores the retrieved and
       filtered chunk sets against them.
    4. Response: needs ``rag``, a ``response_evaluator`` and a chunk
       evaluation; scores the pipeline answer per sub-question.

    Args:
        request: Carries either ``result_id`` (loaded from ``storage``) or
            inline ``results``, plus the ``evaluation_config``.
        settings: Supplies ``eval_max_concurrent_batches`` and
            ``eval_chunk_batch_size`` for the chunk stage.
        storage: Used to load the result and to save the evaluation.
        rag: RAG service; when falsy, stages 3 and 4 are skipped.

    Returns:
        The saved ``EvaluationResult`` (status ``"completed"``).

    Raises:
        ValueError: The referenced result does not exist, or the request has
            neither ``result_id`` nor inline ``results``.
    """
    import asyncio  # function-scope import, kept local as in the original

    evaluation_id = uuid.uuid4().hex[:12]
    overall_start = time.perf_counter()

    # Load result
    if request.result_id:
        result = storage.load_result(request.result_id)
        if result is None:
            raise ValueError(f"Result not found: {request.result_id}")
    elif request.results:
        result = request.results
    else:
        raise ValueError("No result_id or inline results provided")

    cfg = request.evaluation_config
    audio_eval_result = None
    kq_eval_result = None
    chunk_eval_result = None
    resp_eval_result = None
    audio_time = 0
    kq_time = 0
    chunk_time = 0
    resp_time = 0

    # (i) Audio evaluation
    if result.input_type == "audio" and result.input.reference_transcript:
        t0 = time.perf_counter()
        cer_data = calculate_cer(result.input.reference_transcript, result.input.text)
        wer_data = calculate_wer(result.input.reference_transcript, result.input.text)
        # Length and edit-operation counts are taken from the CER computation.
        audio_eval_result = AudioEvalResult(
            status="completed",
            cer=cer_data["cer"],
            wer=wer_data["wer"],
            reference_length=cer_data["reference_length"],
            transcribed_length=cer_data["transcribed_length"],
            substitutions=cer_data["substitutions"],
            deletions=cer_data["deletions"],
            insertions=cer_data["insertions"],
            hits=cer_data["hits"],
        )
        audio_time = int((time.perf_counter() - t0) * 1000)

    # (ii) Key questions evaluation
    if cfg.key_questions_evaluators:
        t0 = time.perf_counter()
        kq_eval_result = await evaluate_key_questions(
            original_text=result.input.text,
            extracted_questions=result.extracted_key_questions,
            evaluator_configs=cfg.key_questions_evaluators,
        )
        kq_time = int((time.perf_counter() - t0) * 1000)

    # (iii) Chunk evaluation
    if rag and cfg.chunk_evaluator:
        t0 = time.perf_counter()
        all_chunks = _collect_all_chunks(rag)
        retrieved_set, filtered_set = _extract_chunk_sets(result.retrieval, result.filtered)
        per_sub_q = []
        unfiltered_accuracies = []
        filtered_accuracies = []
        semaphore = asyncio.Semaphore(settings.eval_max_concurrent_batches)

        for sq_idx, sq_text in enumerate(result.extracted_key_questions):
            gt_set, _total_chunks, gt_time = await _determine_ground_truth_chunks(
                sub_question=sq_text,
                all_chunks=all_chunks,
                config=cfg.chunk_evaluator,
                semaphore=semaphore,
                model_idx=sq_idx,
                batch_size=settings.eval_chunk_batch_size,
            )
            gt_info = GroundTruthInfo(
                relevant_documents=list({doc_id for doc_id, _ in gt_set}),
                relevant_chunks=[
                    {"document_id": doc_id, "chunk_index": idx}
                    for doc_id, idx in gt_set
                ],
                total_relevant_chunks=len(gt_set),
                chunk_evaluation_time_ms=gt_time,
            )
            # NOTE(review): retrieved_set / filtered_set are the union over
            # every sub-question (see _extract_chunk_sets), so each
            # sub-question is scored against the whole pipeline output rather
            # than only its own chunks - confirm this is intended.
            # Fresh copies are handed over on each iteration, as before.
            unf_acc = _calculate_accuracy(set(retrieved_set), gt_set)
            fil_acc = _calculate_accuracy(set(filtered_set), gt_set)
            unfiltered_accuracies.append(unf_acc)
            filtered_accuracies.append(fil_acc)
            per_sub_q.append(
                SubQuestionChunkEval(
                    sub_question_index=sq_idx,
                    sub_question_text=sq_text,
                    ground_truth=gt_info,
                    unfiltered_accuracy=unf_acc,
                    filtered_accuracy=fil_acc,
                )
            )

        chunk_eval_result = ChunkEvalResult(
            per_sub_question=per_sub_q,
            overall_unfiltered=_average_accuracy(unfiltered_accuracies),
            overall_filtered=_average_accuracy(filtered_accuracies),
        )
        chunk_time = int((time.perf_counter() - t0) * 1000)

    # (iv) Response evaluation
    if rag and cfg.response_evaluator and chunk_eval_result:
        t0 = time.perf_counter()
        per_sub_q_resp = []
        for sq_idx, sq_text in enumerate(result.extracted_key_questions):
            if sq_idx >= len(chunk_eval_result.per_sub_question):
                continue
            gt_chunks_data = chunk_eval_result.per_sub_question[sq_idx].ground_truth

            # Keep only ground-truth chunks the pipeline actually cited.
            # Each chunk is matched at most once; previously a chunk cited
            # under several sub-questions was appended once per citation.
            relevant_chunks_meta = []
            for rc in gt_chunks_data.relevant_chunks[:_MAX_GT_CHUNKS_PER_QUESTION]:
                match = _match_pipeline_source(
                    result.response.sub_question_sources,
                    rc["document_id"],
                    rc["chunk_index"],
                )
                if match is not None:
                    relevant_chunks_meta.append(match)
            if not relevant_chunks_meta:
                continue

            section_text = _section_text_for(result.response.sub_question_sources, sq_idx)
            resp_eval = await evaluate_response(
                key_question=sq_text,
                ground_truth_chunks=relevant_chunks_meta,
                # Fall back to the whole answer when no per-question section exists.
                pipeline_response=section_text or result.response.final_answer,
                evaluator_config=cfg.response_evaluator,
            )
            if resp_eval:
                resp_eval.sub_question_index = sq_idx
                resp_eval.sub_question_text = sq_text
                per_sub_q_resp.append(resp_eval)

        overall_completeness = 0.0
        overall_factual = 0.0
        if per_sub_q_resp:
            answered = len(per_sub_q_resp)
            overall_completeness = round(
                sum(r.completeness_score for r in per_sub_q_resp) / answered, 4
            )
            overall_factual = round(
                sum(r.factual_accuracy_score for r in per_sub_q_resp) / answered, 4
            )
        resp_eval_result = ResponseEvalResult(
            per_sub_question=per_sub_q_resp,
            overall_completeness=overall_completeness,
            overall_factual_accuracy=overall_factual,
        )
        resp_time = int((time.perf_counter() - t0) * 1000)

    total_ms = int((time.perf_counter() - overall_start) * 1000)
    eval_result = EvaluationResult(
        evaluation_id=evaluation_id,
        result_id=result.result_id,
        status="completed",
        audio_evaluation=audio_eval_result,
        key_questions_evaluation=kq_eval_result,
        chunk_evaluation=chunk_eval_result,
        response_evaluation=resp_eval_result,
        timing=EvaluationTiming(
            audio_evaluation_time_ms=audio_time,
            key_questions_evaluation_time_ms=kq_time,
            chunk_evaluation_time_ms=chunk_time,
            response_evaluation_time_ms=resp_time,
            total_evaluation_time_ms=total_ms,
        ),
    )
    storage.save_evaluation(eval_result)
    return eval_result

View File

@ -0,0 +1,135 @@
"""Phase 9 tests: Evaluation API endpoint integration (Sub-Phase 9.3)."""
import json
from unittest.mock import AsyncMock, patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from app.models.testing import (
ChunkAccuracy,
DimensionScores,
EvaluatorConfig,
EvaluationResult,
FilteredResult,
GenerateResult,
InputInfo,
KeyQuestionsEvalEntry,
KeyQuestionsEvalResult,
ResponseResult,
RetrievalResult,
TimingInfo,
)
@pytest.fixture(autouse=True)
def _set_api_keys(monkeypatch):
    """Set placeholder values for the API-key environment variables before every test."""
    placeholder_keys = (
        ("LLM_API_KEY", "test-key"),
        ("DP_API_KEY", "test-dp-key"),
        ("DASHSCOPE_API_KEY", "test-dashscope-key"),
    )
    for env_name, env_value in placeholder_keys:
        monkeypatch.setenv(env_name, env_value)
@pytest.fixture
def client(tmp_path, monkeypatch):
    """Yield a ``TestClient`` for an app that mounts only the evaluation router.

    Storage directories and both SQLite databases are redirected into
    ``tmp_path`` through environment variables, the settings cache is cleared
    so those variables are picked up, and the prompts/history databases are
    initialised (the prompts database also gets the default profiles). The
    settings cache is cleared again once the test has finished.
    """
    prompts_path = str(tmp_path / "prompts.db")
    history_path = str(tmp_path / "history.db")
    env_overrides = {
        "TEST_RESULTS_DIR": str(tmp_path / "test_results"),
        "TEST_EVALUATIONS_DIR": str(tmp_path / "test_evaluations"),
        "PROMPTS_DB_PATH": prompts_path,
        "HISTORY_DB_PATH": history_path,
        "LLM_API_KEY": "test-key",
        "LLM_BASE_URL": "https://test.example.com/v1",
        "LLM_MODEL_NAME": "test-model",
        "EMBEDDING_MODEL": "test-embedding",
    }
    for env_name, env_value in env_overrides.items():
        monkeypatch.setenv(env_name, env_value)

    # Imported lazily so the environment above is in place first.
    from app.core.config import get_settings
    get_settings.cache_clear()

    from app.core.sqlite_db import _get_db, init_prompts_db, init_history_db, seed_default_profiles
    prompts_conn = _get_db(prompts_path)
    init_prompts_db(prompts_conn)
    seed_default_profiles(prompts_conn)
    prompts_conn.close()

    history_conn = _get_db(history_path)
    init_history_db(history_conn)
    history_conn.close()

    from app.routers.test_evaluate import router
    api = FastAPI()
    api.include_router(router, prefix="/api/v1")
    yield TestClient(api)

    get_settings.cache_clear()
def _make_sample_result():
    """Build a minimal text-input ``GenerateResult`` with two key questions and no chunks."""
    sample_input = InputInfo(text="test question")
    retrieval = RetrievalResult(
        per_sub_question=[],
        total_chunks_retrieved=10,
        retriever_time_ms=100,
    )
    filtered = FilteredResult(
        per_sub_question=[],
        total_chunks_filtered=5,
        filter_time_ms=100,
    )
    response = ResponseResult(
        final_answer="answer",
        sub_question_sources=[],
        generate_time_ms=100,
    )
    timing = TimingInfo(
        decomposer_time_ms=100,
        retriever_time_ms=100,
        filter_time_ms=100,
        generator_time_ms=100,
        total_time_ms=400,
    )
    return GenerateResult(
        result_id="test-result-001",
        input_type="text",
        profile="A",
        input=sample_input,
        extracted_key_questions=["q1", "q2"],
        retrieval=retrieval,
        filtered=filtered,
        response=response,
        timing=timing,
    )
@pytest.fixture
def saved_result(client):
    """Persist the sample result via ``TestStorageService`` and return its ``result_id``.

    Depends on ``client`` so the storage directories already point into the
    test's temporary directory.
    """
    from app.services.test_storage_service import TestStorageService
    from app.core.config import get_settings

    sample = _make_sample_result()
    results_dir = get_settings().test_results_dir
    evaluations_dir = get_settings().test_evaluations_dir
    storage = TestStorageService(results_dir, evaluations_dir)
    storage.save_result(sample)
    return sample.result_id
class TestEvaluateEndpoint:
    """Integration tests for ``POST /api/v1/test/evaluate``."""

    def test_valid_evaluate_returns_200(self, client, saved_result):
        """A stored result plus a full evaluator config evaluates successfully.

        The key-question evaluator is replaced by an ``AsyncMock`` returning
        ``mock_kq``, so the test no longer calls the configured model
        endpoints. The test is a plain function because ``TestClient`` is
        synchronous and nothing here is awaited.
        """
        # NOTE(review): the chunk/response stages still depend on whatever
        # get_rag_service() yields in the test environment - confirm that is
        # hermetic, or patch it as well.
        mock_scores = DimensionScores(dimension_1_準確性=35.0, dimension_2_完整性=22.0, dimension_3_清晰度=18.0, dimension_4_簡潔性=13.0)
        mock_kq = KeyQuestionsEvalResult(
            evaluations=[
                KeyQuestionsEvalEntry(model_name="m1", scores=mock_scores, total_score=88, max_score=100, comments="ok", thinking_trace="", time_ms=100),
                KeyQuestionsEvalEntry(model_name="m2", scores=mock_scores, total_score=88, max_score=100, comments="ok", thinking_trace="", time_ms=100),
            ],
            average_scores=mock_scores,
            average_total=88.0,
        )
        payload = {
            "result_id": saved_result,
            "evaluation_config": {
                "key_questions_evaluators": [
                    {"model_name": "deepseek-v4-pro", "base_url": "https://api.deepseek.com", "api_key_env": "DP_API_KEY", "enable_thinking": True},
                    {"model_name": "qwen3-7b-max", "base_url": "https://dashscope.example.com/v1", "api_key_env": "DASHSCOPE_API_KEY", "enable_thinking": True},
                ],
                "chunk_evaluator": {"model_name": "test", "base_url": "https://test.example.com", "api_key_env": "LLM_API_KEY", "enable_thinking": True},
                "response_evaluator": {"model_name": "test", "base_url": "https://test.example.com", "api_key_env": "LLM_API_KEY", "enable_thinking": True},
            },
        }
        # Patch the name where run_evaluation looks it up (its own module),
        # not where the evaluator is defined.
        with patch(
            "app.services.test_evaluation_service.evaluate_key_questions",
            new=AsyncMock(return_value=mock_kq),
        ) as mocked_kq:
            resp = client.post("/api/v1/test/evaluate", json=payload)

        assert resp.status_code == 200
        data = resp.json()
        assert data["status"] in ("completed", "partial")
        assert "evaluation_id" in data
        mocked_kq.assert_awaited_once()
        assert data["key_questions_evaluation"]["average_total"] == 88.0

    def test_missing_result_returns_404(self, client):
        """An unknown ``result_id`` is reported as 404."""
        payload = {
            "result_id": "no-such-id",
            "evaluation_config": {
                "key_questions_evaluators": [],
                "chunk_evaluator": {"model_name": "t", "base_url": "https://x.com", "api_key_env": "LLM_API_KEY"},
                "response_evaluator": {"model_name": "t", "base_url": "https://x.com", "api_key_env": "LLM_API_KEY"},
            },
        }
        resp = client.post("/api/v1/test/evaluate", json=payload)
        assert resp.status_code == 404