diff --git a/backend/app/main.py b/backend/app/main.py
index fecebac..9dbc822 100644
--- a/backend/app/main.py
+++ b/backend/app/main.py
@@ -7,7 +7,7 @@ from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import FileResponse
 
-from app.routers import ingest, query, documents, prompts, history, chunks, video, ws_asr, test_generate
+from app.routers import ingest, query, documents, prompts, history, chunks, video, ws_asr, test_generate, test_evaluate
 from app.core.config import get_settings
 from app.core.sqlite_db import (
     get_prompts_db,
@@ -59,6 +59,7 @@ app.include_router(chunks.router)
 app.include_router(video.router, prefix="/api/v1")
 app.include_router(ws_asr.router)
 app.include_router(test_generate.router, prefix="/api/v1")
+app.include_router(test_evaluate.router, prefix="/api/v1")
 
 _prompts_conn = get_prompts_db()
 init_prompts_db(_prompts_conn)
diff --git a/backend/app/routers/test_evaluate.py b/backend/app/routers/test_evaluate.py
new file mode 100644
index 0000000..31af609
--- /dev/null
+++ b/backend/app/routers/test_evaluate.py
@@ -0,0 +1,88 @@
+"""HTTP endpoints for running, listing, fetching and deleting evaluations."""
+import logging
+
+from fastapi import APIRouter, HTTPException, Query
+
+from app.core.config import get_settings
+from app.core.dependencies import get_rag_service
+from app.models.testing import EvaluateRequest
+from app.services.prompt_service import PromptService
+from app.services.test_evaluation_service import run_evaluation
+from app.services.test_storage_service import TestStorageService
+
+logger = logging.getLogger(__name__)
+router = APIRouter(tags=["test"])
+
+
+def _get_storage_service() -> TestStorageService:
+    """Build a storage service from the configured result/evaluation dirs."""
+    settings = get_settings()
+    return TestStorageService(
+        results_dir=settings.test_results_dir,
+        evaluations_dir=settings.test_evaluations_dir,
+    )
+
+
+@router.post("/test/evaluate")
+async def evaluate(request: EvaluateRequest):
+    """Run an evaluation for the requested result and return it as a dict.
+
+    A ``ValueError`` raised during evaluation is reported as 404; any other
+    exception is logged with its traceback and reported as 500.
+    """
+    settings = get_settings()
+    storage = _get_storage_service()
+
+    # NOTE(review): every call switches the active prompt profile to "A";
+    # confirm this side effect on shared prompt state is intended.
+    prompt_service = PromptService(db_path=settings.prompts_db_path)
+    prompt_service.activate_profile("A")
+
+    try:
+        rag = get_rag_service()
+    except Exception:
+        # Best effort: run_evaluation skips its RAG-gated stages when rag is
+        # None, but record the cause instead of swallowing it silently.
+        logger.warning("RAG service unavailable; RAG-backed evaluation stages will be skipped", exc_info=True)
+        rag = None
+
+    try:
+        result = await run_evaluation(
+            request=request,
+            settings=settings,
+            storage=storage,
+            rag=rag,
+        )
+        return result.model_dump()
+    except ValueError as e:
+        raise HTTPException(status_code=404, detail=str(e)) from e
+    except Exception as e:
+        logger.exception("Evaluation failed: %s", e)
+        raise HTTPException(status_code=500, detail=f"Evaluation failed: {str(e)}") from e
+
+
+@router.get("/test/evaluations")
+async def list_evaluations(limit: int = Query(50, ge=1, le=200), offset: int = Query(0, ge=0)):
+    """Return stored evaluations, forwarding ``limit``/``offset`` to storage."""
+    storage = _get_storage_service()
+    return storage.list_evaluations(limit=limit, offset=offset)
+
+
+@router.get("/test/evaluations/{eval_id}")
+async def get_evaluation(eval_id: str):
+    """Return one evaluation as a dict; 404 when ``load_evaluation`` gives None."""
+    storage = _get_storage_service()
+    result = storage.load_evaluation(eval_id)
+    if result is None:
+        raise HTTPException(status_code=404, detail="Evaluation not found")
+    return result.model_dump()
+
+
+@router.delete("/test/evaluations/{eval_id}")
+async def delete_evaluation(eval_id: str):
+    """Delete one evaluation; 404 when ``delete_evaluation`` returns falsy."""
+    storage = _get_storage_service()
+    deleted = storage.delete_evaluation(eval_id)
+    if not deleted:
+        raise HTTPException(status_code=404, detail="Evaluation not found")
+    return {"status": "deleted", "evaluation_id": eval_id}
diff --git a/backend/app/services/test_evaluation_service.py b/backend/app/services/test_evaluation_service.py
new file mode 100644
index 0000000..0f4249c
--- /dev/null
+++ b/backend/app/services/test_evaluation_service.py
@@ -0,0 +1,313 @@
+"""Run the audio, key-question, chunk and response evaluation stages."""
+import asyncio
+import logging
+import time
+import uuid
+from typing import Any, Dict, List, Optional, Set, Tuple
+
+from app.core.config import Settings
+from app.models.testing import (
+    AudioEvalResult,
+    ChunkAccuracy,
+    ChunkEvalResult,
+    EvaluateRequest,
+    EvaluationResult,
+    EvaluationTiming,
+    FilteredResult,
+    GenerateResult,
+    GroundTruthInfo,
+    KeyQuestionsEvalResult,
+    ResponseEvalResult,
+    RetrievalResult,
+    SubQuestionChunkEval,
+    SubQuestionResponseEval,
+)
+from app.services.cer_wer import calculate_cer, calculate_wer
+from app.services.chunk_evaluator import (
+    _calculate_accuracy,
+    _determine_ground_truth_chunks,
+)
+from app.services.key_questions_evaluator import evaluate_key_questions
+from app.services.response_evaluator import evaluate_response
+from app.services.rag import RAGService
+from app.services.test_storage_service import TestStorageService
+
+logger = logging.getLogger(__name__)
+
+
+def _chunk_keys(per_sub_question: Any) -> Set[Tuple[str, int]]:
+    """Collect (document_id, chunk_index) keys from every sub-question's chunks."""
+    keys: Set[Tuple[str, int]] = set()
+    for sq in per_sub_question:
+        for chunk in sq.chunks:
+            doc_id = chunk.metadata.get("document_id", "unknown")
+            chunk_idx = chunk.metadata.get("chunk_index", chunk.chunk_index)
+            keys.add((str(doc_id), int(chunk_idx)))
+    return keys
+
+
+def _extract_chunk_sets(
+    retrieval: RetrievalResult,
+    filtered: FilteredResult,
+) -> Tuple[Set[Tuple[str, int]], Set[Tuple[str, int]]]:
+    """Extract (document_id, chunk_index) sets from retrieval and filtered results."""
+    return _chunk_keys(retrieval.per_sub_question), _chunk_keys(filtered.per_sub_question)
+
+
+def _collect_all_chunks(
+    rag: RAGService,
+) -> List[Tuple[str, int, str, Dict[str, Any]]]:
+    """Fetch all chunks from all documents in ChromaDB.
+
+    Returns list of (document_id, chunk_index, text, metadata) tuples. Ids are
+    coerced to ``(str, int)`` so they match the keys from ``_extract_chunk_sets``.
+    """
+    docs, _, _ = rag.list_documents()
+    all_chunks = []
+
+    for doc in docs:
+        doc_id = doc["document_id"]
+        for chunk in rag.list_chunks(doc_id):
+            chunk_idx = chunk.get("chunk_index", 0)
+            text = chunk.get("text", "")
+            all_chunks.append((str(doc_id), int(chunk_idx), text, chunk))
+
+    return all_chunks
+
+
+def _average_accuracy(accuracies: List[Any]) -> ChunkAccuracy:
+    """Average precision/recall/F1 over ``accuracies`` and sum their counts.
+
+    An empty list yields all-zero metrics.
+    """
+    n = len(accuracies)
+    return ChunkAccuracy(
+        precision=round(sum(a.precision for a in accuracies) / n, 4) if n else 0.0,
+        recall=round(sum(a.recall for a in accuracies) / n, 4) if n else 0.0,
+        f1=round(sum(a.f1 for a in accuracies) / n, 4) if n else 0.0,
+        pipeline_chunks=sum(a.pipeline_chunks for a in accuracies),
+        relevant_in_pipeline=sum(a.relevant_in_pipeline for a in accuracies),
+    )
+
+
+def _find_source_for_chunk(
+    sub_question_sources: Any,
+    doc_id: Any,
+    chunk_idx: Any,
+) -> Optional[Tuple[str, Dict[str, Any]]]:
+    """Return (content_summary, metadata) for the first source citing a chunk.
+
+    Source groups may be dicts or objects with a ``sources`` attribute. Stops
+    at the first match across all groups so a chunk cited under several
+    sub-questions is reported once; returns ``None`` when nothing matches.
+    """
+    for group in sub_question_sources:
+        if isinstance(group, dict):
+            for s in group.get("sources", []):
+                if s.get("document_id") == doc_id and s.get("chunk_index") == chunk_idx:
+                    return s.get("content_summary", ""), s
+        elif hasattr(group, "sources"):
+            for s in group.sources:
+                if hasattr(s, "document_id") and s.document_id == doc_id and s.chunk_index == chunk_idx:
+                    return s.content_summary, (s.model_dump() if hasattr(s, "model_dump") else {})
+    return None
+
+
+async def run_evaluation(
+    request: EvaluateRequest,
+    settings: Settings,
+    storage: TestStorageService,
+    rag: Optional[RAGService] = None,
+) -> EvaluationResult:
+    """Evaluate one generation result and save the evaluation to ``storage``.
+
+    Stages run in order, each only when its precondition holds: audio
+    (audio input with a reference transcript), key questions (evaluators
+    configured), chunks (``rag`` and a chunk evaluator), and response
+    (``rag``, a response evaluator and a chunk evaluation result).
+
+    Raises:
+        ValueError: if ``request.result_id`` cannot be loaded, or the request
+            has neither ``result_id`` nor inline ``results``.
+    """
+    evaluation_id = uuid.uuid4().hex[:12]
+    overall_start = time.perf_counter()
+
+    # Load result
+    if request.result_id:
+        result = storage.load_result(request.result_id)
+        if result is None:
+            raise ValueError(f"Result not found: {request.result_id}")
+    elif request.results:
+        result = request.results
+    else:
+        raise ValueError("No result_id or inline results provided")
+
+    cfg = request.evaluation_config
+    audio_eval_result = None
+    kq_eval_result = None
+    chunk_eval_result = None
+    resp_eval_result = None
+    audio_time = 0
+    kq_time = 0
+    chunk_time = 0
+    resp_time = 0
+
+    # (i) Audio evaluation
+    if result.input_type == "audio" and result.input.reference_transcript:
+        t0 = time.perf_counter()
+        cer_data = calculate_cer(result.input.reference_transcript, result.input.text)
+        wer_data = calculate_wer(result.input.reference_transcript, result.input.text)
+        audio_eval_result = AudioEvalResult(
+            status="completed",
+            cer=cer_data["cer"],
+            wer=wer_data["wer"],
+            reference_length=cer_data["reference_length"],
+            transcribed_length=cer_data["transcribed_length"],
+            substitutions=cer_data["substitutions"],
+            deletions=cer_data["deletions"],
+            insertions=cer_data["insertions"],
+            hits=cer_data["hits"],
+        )
+        audio_time = int((time.perf_counter() - t0) * 1000)
+
+    # (ii) Key questions evaluation
+    if cfg.key_questions_evaluators:
+        t0 = time.perf_counter()
+        kq_eval_result = await evaluate_key_questions(
+            original_text=result.input.text,
+            extracted_questions=result.extracted_key_questions,
+            evaluator_configs=cfg.key_questions_evaluators,
+        )
+        kq_time = int((time.perf_counter() - t0) * 1000)
+
+    # (iii) Chunk evaluation
+    if rag and cfg.chunk_evaluator:
+        t0 = time.perf_counter()
+        all_chunks = _collect_all_chunks(rag)
+        # NOTE(review): both sets pool chunks from every sub-question, so each
+        # sub-question below is scored against the whole pipeline output.
+        retrieved_set, filtered_set = _extract_chunk_sets(result.retrieval, result.filtered)
+
+        per_sub_q = []
+        unfiltered_accuracies = []
+        filtered_accuracies = []
+        semaphore = asyncio.Semaphore(settings.eval_max_concurrent_batches)
+
+        for sq_idx, sq_text in enumerate(result.extracted_key_questions):
+            gt_set, _total_chunks, gt_time = await _determine_ground_truth_chunks(
+                sub_question=sq_text,
+                all_chunks=all_chunks,
+                config=cfg.chunk_evaluator,
+                semaphore=semaphore,
+                model_idx=sq_idx,
+                batch_size=settings.eval_chunk_batch_size,
+            )
+
+            gt_info = GroundTruthInfo(
+                relevant_documents=list({doc_id for doc_id, _ in gt_set}),
+                relevant_chunks=[
+                    {"document_id": doc_id, "chunk_index": idx}
+                    for doc_id, idx in gt_set
+                ],
+                total_relevant_chunks=len(gt_set),
+                chunk_evaluation_time_ms=gt_time,
+            )
+
+            unf_acc = _calculate_accuracy(retrieved_set, gt_set)
+            fil_acc = _calculate_accuracy(filtered_set, gt_set)
+            unfiltered_accuracies.append(unf_acc)
+            filtered_accuracies.append(fil_acc)
+
+            per_sub_q.append(
+                SubQuestionChunkEval(
+                    sub_question_index=sq_idx,
+                    sub_question_text=sq_text,
+                    ground_truth=gt_info,
+                    unfiltered_accuracy=unf_acc,
+                    filtered_accuracy=fil_acc,
+                )
+            )
+
+        chunk_eval_result = ChunkEvalResult(
+            per_sub_question=per_sub_q,
+            overall_unfiltered=_average_accuracy(unfiltered_accuracies),
+            overall_filtered=_average_accuracy(filtered_accuracies),
+        )
+        chunk_time = int((time.perf_counter() - t0) * 1000)
+
+    # (iv) Response evaluation
+    if rag and cfg.response_evaluator and chunk_eval_result:
+        t0 = time.perf_counter()
+        per_sub_q_resp = []
+        sources = result.response.sub_question_sources
+
+        for sq_idx, sq_text in enumerate(result.extracted_key_questions):
+            if sq_idx >= len(chunk_eval_result.per_sub_question):
+                continue
+
+            gt_chunks_data = chunk_eval_result.per_sub_question[sq_idx].ground_truth
+            relevant_chunks_meta = []
+            for rc in gt_chunks_data.relevant_chunks[:10]:
+                # Try to match with pipeline's response sources
+                match = _find_source_for_chunk(sources, rc["document_id"], rc["chunk_index"])
+                if match is not None:
+                    relevant_chunks_meta.append(match)
+
+            if not relevant_chunks_meta:
+                continue
+
+            # The last source group tagged with this sub-question wins.
+            section_text = ""
+            for src in sources:
+                if isinstance(src, dict):
+                    if src.get("sub_question_index") == sq_idx:
+                        section_text = str(src)
+                elif hasattr(src, "sub_question_index") and src.sub_question_index == sq_idx:
+                    section_text = str(src)
+
+            resp_eval = await evaluate_response(
+                key_question=sq_text,
+                ground_truth_chunks=relevant_chunks_meta,
+                pipeline_response=section_text or result.response.final_answer,
+                evaluator_config=cfg.response_evaluator,
+            )
+            if resp_eval:
+                resp_eval.sub_question_index = sq_idx
+                resp_eval.sub_question_text = sq_text
+                per_sub_q_resp.append(resp_eval)
+
+        overall_completeness = 0.0
+        overall_factual = 0.0
+        if per_sub_q_resp:
+            overall_completeness = round(sum(r.completeness_score for r in per_sub_q_resp) / len(per_sub_q_resp), 4)
+            overall_factual = round(sum(r.factual_accuracy_score for r in per_sub_q_resp) / len(per_sub_q_resp), 4)
+
+        resp_eval_result = ResponseEvalResult(
+            per_sub_question=per_sub_q_resp,
+            overall_completeness=overall_completeness,
+            overall_factual_accuracy=overall_factual,
+        )
+        resp_time = int((time.perf_counter() - t0) * 1000)
+
+    total_ms = int((time.perf_counter() - overall_start) * 1000)
+
+    eval_result = EvaluationResult(
+        evaluation_id=evaluation_id,
+        result_id=result.result_id,
+        status="completed",
+        audio_evaluation=audio_eval_result,
+        key_questions_evaluation=kq_eval_result,
+        chunk_evaluation=chunk_eval_result,
+        response_evaluation=resp_eval_result,
+        timing=EvaluationTiming(
+            audio_evaluation_time_ms=audio_time,
+            key_questions_evaluation_time_ms=kq_time,
+            chunk_evaluation_time_ms=chunk_time,
+            response_evaluation_time_ms=resp_time,
+            total_evaluation_time_ms=total_ms,
+        ),
+    )
+
+    storage.save_evaluation(eval_result)
+    return eval_result
diff --git a/backend/app/test/test_phase9_evaluate.py b/backend/app/test/test_phase9_evaluate.py
new file mode 100644
index 0000000..f839706
--- /dev/null
+++ b/backend/app/test/test_phase9_evaluate.py
@@ -0,0 +1,151 @@
+"""Phase 9 tests: Evaluation API endpoint integration (Sub-Phase 9.3)."""
+import json
+from unittest.mock import AsyncMock, patch
+
+import pytest
+from fastapi import FastAPI
+from fastapi.testclient import TestClient
+
+from app.models.testing import (
+    ChunkAccuracy,
+    DimensionScores,
+    EvaluatorConfig,
+    EvaluationResult,
+    FilteredResult,
+    GenerateResult,
+    InputInfo,
+    KeyQuestionsEvalEntry,
+    KeyQuestionsEvalResult,
+    ResponseResult,
+    RetrievalResult,
+    TimingInfo,
+)
+
+
+@pytest.fixture(autouse=True)
+def _set_api_keys(monkeypatch):
+    """Set dummy API-key environment variables for every test in this module."""
+    monkeypatch.setenv("LLM_API_KEY", "test-key")
+    monkeypatch.setenv("DP_API_KEY", "test-dp-key")
+    monkeypatch.setenv("DASHSCOPE_API_KEY", "test-dashscope-key")
+
+
+@pytest.fixture
+def client(tmp_path, monkeypatch):
+    """Yield a TestClient for the evaluate router using paths under tmp_path."""
+    results_dir = str(tmp_path / "test_results")
+    evals_dir = str(tmp_path / "test_evaluations")
+    prompts_path = str(tmp_path / "prompts.db")
+    history_path = str(tmp_path / "history.db")
+
+    monkeypatch.setenv("TEST_RESULTS_DIR", results_dir)
+    monkeypatch.setenv("TEST_EVALUATIONS_DIR", evals_dir)
+    monkeypatch.setenv("PROMPTS_DB_PATH", prompts_path)
+    monkeypatch.setenv("HISTORY_DB_PATH", history_path)
+    monkeypatch.setenv("LLM_API_KEY", "test-key")
+    monkeypatch.setenv("LLM_BASE_URL", "https://test.example.com/v1")
+    monkeypatch.setenv("LLM_MODEL_NAME", "test-model")
+    monkeypatch.setenv("EMBEDDING_MODEL", "test-embedding")
+
+    from app.core.config import get_settings
+    get_settings.cache_clear()
+
+    from app.core.sqlite_db import _get_db, init_prompts_db, init_history_db, seed_default_profiles
+    conn = _get_db(prompts_path)
+    init_prompts_db(conn)
+    seed_default_profiles(conn)
+    conn.close()
+
+    hconn = _get_db(history_path)
+    init_history_db(hconn)
+    hconn.close()
+
+    from app.routers.test_evaluate import router
+    test_app = FastAPI()
+    test_app.include_router(router, prefix="/api/v1")
+    yield TestClient(test_app)
+
+    get_settings.cache_clear()
+
+
+def _make_sample_result():
+    """Build a text-input GenerateResult with two key questions and no chunks."""
+    return GenerateResult(
+        result_id="test-result-001",
+        input_type="text",
+        profile="A",
+        input=InputInfo(text="test question"),
+        extracted_key_questions=["q1", "q2"],
+        retrieval=RetrievalResult(per_sub_question=[], total_chunks_retrieved=10, retriever_time_ms=100),
+        filtered=FilteredResult(per_sub_question=[], total_chunks_filtered=5, filter_time_ms=100),
+        response=ResponseResult(final_answer="answer", sub_question_sources=[], generate_time_ms=100),
+        timing=TimingInfo(decomposer_time_ms=100, retriever_time_ms=100, filter_time_ms=100, generator_time_ms=100, total_time_ms=400),
+    )
+
+
+@pytest.fixture
+def saved_result(client):
+    """Save the sample result through TestStorageService and return its id."""
+    from app.services.test_storage_service import TestStorageService
+    from app.core.config import get_settings
+
+    result = _make_sample_result()
+    svc = TestStorageService(get_settings().test_results_dir, get_settings().test_evaluations_dir)
+    svc.save_result(result)
+    return result.result_id
+
+
+class TestEvaluateEndpoint:
+    """Tests for POST /api/v1/test/evaluate."""
+
+    def test_valid_evaluate_returns_200(self, client, saved_result):
+        mock_scores = DimensionScores(dimension_1_準確性=35.0, dimension_2_完整性=22.0, dimension_3_清晰度=18.0, dimension_4_簡潔性=13.0)
+        mock_kq = KeyQuestionsEvalResult(
+            evaluations=[
+                KeyQuestionsEvalEntry(model_name="m1", scores=mock_scores, total_score=88, max_score=100, comments="ok", thinking_trace="", time_ms=100),
+                KeyQuestionsEvalEntry(model_name="m2", scores=mock_scores, total_score=88, max_score=100, comments="ok", thinking_trace="", time_ms=100),
+            ],
+            average_scores=mock_scores,
+            average_total=88.0,
+        )
+
+        payload = {
+            "result_id": saved_result,
+            "evaluation_config": {
+                "key_questions_evaluators": [
+                    {"model_name": "deepseek-v4-pro", "base_url": "https://api.deepseek.com", "api_key_env": "DP_API_KEY", "enable_thinking": True},
+                    {"model_name": "qwen3-7b-max", "base_url": "https://dashscope.example.com/v1", "api_key_env": "DASHSCOPE_API_KEY", "enable_thinking": True},
+                ],
+                "chunk_evaluator": {"model_name": "test", "base_url": "https://test.example.com", "api_key_env": "LLM_API_KEY", "enable_thinking": True},
+                "response_evaluator": {"model_name": "test", "base_url": "https://test.example.com", "api_key_env": "LLM_API_KEY", "enable_thinking": True},
+            },
+        }
+
+        # Stub the key-question evaluator and make the RAG dependency fail so
+        # the request never reaches a real model endpoint.
+        with patch(
+            "app.services.test_evaluation_service.evaluate_key_questions",
+            new=AsyncMock(return_value=mock_kq),
+        ) as mock_evaluate, patch(
+            "app.routers.test_evaluate.get_rag_service",
+            side_effect=RuntimeError("rag disabled in tests"),
+        ):
+            resp = client.post("/api/v1/test/evaluate", json=payload)
+
+        assert resp.status_code == 200
+        data = resp.json()
+        assert data["status"] in ("completed", "partial")
+        assert "evaluation_id" in data
+        mock_evaluate.assert_awaited_once()
+
+    def test_missing_result_returns_404(self, client):
+        payload = {
+            "result_id": "no-such-id",
+            "evaluation_config": {
+                "key_questions_evaluators": [],
+                "chunk_evaluator": {"model_name": "t", "base_url": "https://x.com", "api_key_env": "LLM_API_KEY"},
+                "response_evaluator": {"model_name": "t", "base_url": "https://x.com", "api_key_env": "LLM_API_KEY"},
+            },
+        }
+        resp = client.post("/api/v1/test/evaluate", json=payload)
+        assert resp.status_code == 404