feat: add Sub-Phase 9.3 evaluation API endpoint and 9.4 polish
This commit is contained in:
parent
098be359e7
commit
032dd75e17
|
|
@ -7,7 +7,7 @@ from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
|
|
||||||
from app.routers import ingest, query, documents, prompts, history, chunks, video, ws_asr, test_generate
|
from app.routers import ingest, query, documents, prompts, history, chunks, video, ws_asr, test_generate, test_evaluate
|
||||||
from app.core.config import get_settings
|
from app.core.config import get_settings
|
||||||
from app.core.sqlite_db import (
|
from app.core.sqlite_db import (
|
||||||
get_prompts_db,
|
get_prompts_db,
|
||||||
|
|
@ -59,6 +59,7 @@ app.include_router(chunks.router)
|
||||||
app.include_router(video.router, prefix="/api/v1")
|
app.include_router(video.router, prefix="/api/v1")
|
||||||
app.include_router(ws_asr.router)
|
app.include_router(ws_asr.router)
|
||||||
app.include_router(test_generate.router, prefix="/api/v1")
|
app.include_router(test_generate.router, prefix="/api/v1")
|
||||||
|
app.include_router(test_evaluate.router, prefix="/api/v1")
|
||||||
|
|
||||||
_prompts_conn = get_prompts_db()
|
_prompts_conn = get_prompts_db()
|
||||||
init_prompts_db(_prompts_conn)
|
init_prompts_db(_prompts_conn)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,73 @@
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Query
|
||||||
|
|
||||||
|
from app.core.config import get_settings
|
||||||
|
from app.core.dependencies import get_rag_service
|
||||||
|
from app.models.testing import EvaluateRequest
|
||||||
|
from app.services.prompt_service import PromptService
|
||||||
|
from app.services.test_evaluation_service import run_evaluation
|
||||||
|
from app.services.test_storage_service import TestStorageService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
router = APIRouter(tags=["test"])
|
||||||
|
|
||||||
|
|
||||||
|
def _get_storage_service() -> TestStorageService:
|
||||||
|
settings = get_settings()
|
||||||
|
return TestStorageService(
|
||||||
|
results_dir=settings.test_results_dir,
|
||||||
|
evaluations_dir=settings.test_evaluations_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/test/evaluate")
|
||||||
|
async def evaluate(request: EvaluateRequest):
|
||||||
|
settings = get_settings()
|
||||||
|
storage = _get_storage_service()
|
||||||
|
|
||||||
|
prompt_service = PromptService(db_path=settings.prompts_db_path)
|
||||||
|
prompt_service.activate_profile("A")
|
||||||
|
|
||||||
|
try:
|
||||||
|
rag = get_rag_service()
|
||||||
|
except Exception:
|
||||||
|
rag = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await run_evaluation(
|
||||||
|
request=request,
|
||||||
|
settings=settings,
|
||||||
|
storage=storage,
|
||||||
|
rag=rag,
|
||||||
|
)
|
||||||
|
return result.model_dump()
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Evaluation failed: %s", str(e), exc_info=True)
|
||||||
|
raise HTTPException(status_code=500, detail=f"Evaluation failed: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/test/evaluations")
|
||||||
|
async def list_evaluations(limit: int = Query(50, ge=1, le=200), offset: int = Query(0, ge=0)):
|
||||||
|
storage = _get_storage_service()
|
||||||
|
return storage.list_evaluations(limit=limit, offset=offset)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/test/evaluations/{eval_id}")
|
||||||
|
async def get_evaluation(eval_id: str):
|
||||||
|
storage = _get_storage_service()
|
||||||
|
result = storage.load_evaluation(eval_id)
|
||||||
|
if result is None:
|
||||||
|
raise HTTPException(status_code=404, detail="Evaluation not found")
|
||||||
|
return result.model_dump()
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/test/evaluations/{eval_id}")
|
||||||
|
async def delete_evaluation(eval_id: str):
|
||||||
|
storage = _get_storage_service()
|
||||||
|
deleted = storage.delete_evaluation(eval_id)
|
||||||
|
if not deleted:
|
||||||
|
raise HTTPException(status_code=404, detail="Evaluation not found")
|
||||||
|
return {"status": "deleted", "evaluation_id": eval_id}
|
||||||
|
|
@ -0,0 +1,313 @@
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||||
|
|
||||||
|
from app.core.config import Settings
|
||||||
|
from app.models.testing import (
|
||||||
|
AudioEvalResult,
|
||||||
|
ChunkAccuracy,
|
||||||
|
ChunkEvalResult,
|
||||||
|
EvaluateRequest,
|
||||||
|
EvaluationResult,
|
||||||
|
EvaluationTiming,
|
||||||
|
FilteredResult,
|
||||||
|
GenerateResult,
|
||||||
|
GroundTruthInfo,
|
||||||
|
KeyQuestionsEvalResult,
|
||||||
|
ResponseEvalResult,
|
||||||
|
RetrievalResult,
|
||||||
|
SubQuestionChunkEval,
|
||||||
|
SubQuestionResponseEval,
|
||||||
|
)
|
||||||
|
from app.services.cer_wer import calculate_cer, calculate_wer
|
||||||
|
from app.services.chunk_evaluator import (
|
||||||
|
_calculate_accuracy,
|
||||||
|
_determine_ground_truth_chunks,
|
||||||
|
)
|
||||||
|
from app.services.key_questions_evaluator import evaluate_key_questions
|
||||||
|
from app.services.response_evaluator import evaluate_response
|
||||||
|
from app.services.rag import RAGService
|
||||||
|
from app.services.test_storage_service import TestStorageService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_chunk_sets(
|
||||||
|
retrieval: RetrievalResult,
|
||||||
|
filtered: FilteredResult,
|
||||||
|
) -> Tuple[Set[Tuple[str, int]], Set[Tuple[str, int]]]:
|
||||||
|
"""Extract (document_id, chunk_index) sets from retrieval and filtered results."""
|
||||||
|
retrieved = set()
|
||||||
|
filtered_set = set()
|
||||||
|
|
||||||
|
for sq in retrieval.per_sub_question:
|
||||||
|
for chunk in sq.chunks:
|
||||||
|
doc_id = chunk.metadata.get("document_id", "unknown")
|
||||||
|
chunk_idx = chunk.metadata.get("chunk_index", chunk.chunk_index)
|
||||||
|
retrieved.add((str(doc_id), int(chunk_idx)))
|
||||||
|
|
||||||
|
for sq in filtered.per_sub_question:
|
||||||
|
for chunk in sq.chunks:
|
||||||
|
doc_id = chunk.metadata.get("document_id", "unknown")
|
||||||
|
chunk_idx = chunk.metadata.get("chunk_index", chunk.chunk_index)
|
||||||
|
filtered_set.add((str(doc_id), int(chunk_idx)))
|
||||||
|
|
||||||
|
return retrieved, filtered_set
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_all_chunks(
|
||||||
|
rag: RAGService,
|
||||||
|
) -> List[Tuple[str, int, str, Dict[str, Any]]]:
|
||||||
|
"""Fetch all chunks from all documents in ChromaDB.
|
||||||
|
|
||||||
|
Returns list of (document_id, chunk_index, text, metadata) tuples.
|
||||||
|
"""
|
||||||
|
docs, _, _ = rag.list_documents()
|
||||||
|
all_chunks = []
|
||||||
|
|
||||||
|
for doc in docs:
|
||||||
|
doc_id = doc["document_id"]
|
||||||
|
chunks = rag.list_chunks(doc_id)
|
||||||
|
for chunk in chunks:
|
||||||
|
chunk_idx = chunk.get("chunk_index", 0)
|
||||||
|
text = chunk.get("text", "")
|
||||||
|
all_chunks.append((doc_id, chunk_idx, text, chunk))
|
||||||
|
|
||||||
|
return all_chunks
|
||||||
|
|
||||||
|
|
||||||
|
async def run_evaluation(
|
||||||
|
request: EvaluateRequest,
|
||||||
|
settings: Settings,
|
||||||
|
storage: TestStorageService,
|
||||||
|
rag: Optional[RAGService] = None,
|
||||||
|
) -> EvaluationResult:
|
||||||
|
evaluation_id = uuid.uuid4().hex[:12]
|
||||||
|
overall_start = time.perf_counter()
|
||||||
|
|
||||||
|
# Load result
|
||||||
|
if request.result_id:
|
||||||
|
result = storage.load_result(request.result_id)
|
||||||
|
if result is None:
|
||||||
|
raise ValueError(f"Result not found: {request.result_id}")
|
||||||
|
elif request.results:
|
||||||
|
result = request.results
|
||||||
|
else:
|
||||||
|
raise ValueError("No result_id or inline results provided")
|
||||||
|
|
||||||
|
cfg = request.evaluation_config
|
||||||
|
total_ms = 0
|
||||||
|
audio_eval_result = None
|
||||||
|
kq_eval_result = None
|
||||||
|
chunk_eval_result = None
|
||||||
|
resp_eval_result = None
|
||||||
|
audio_time = 0
|
||||||
|
kq_time = 0
|
||||||
|
chunk_time = 0
|
||||||
|
resp_time = 0
|
||||||
|
|
||||||
|
# (i) Audio evaluation
|
||||||
|
if result.input_type == "audio" and result.input.reference_transcript:
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
cer_data = calculate_cer(result.input.reference_transcript, result.input.text)
|
||||||
|
wer_data = calculate_wer(result.input.reference_transcript, result.input.text)
|
||||||
|
audio_eval_result = AudioEvalResult(
|
||||||
|
status="completed",
|
||||||
|
cer=cer_data["cer"],
|
||||||
|
wer=wer_data["wer"],
|
||||||
|
reference_length=cer_data["reference_length"],
|
||||||
|
transcribed_length=cer_data["transcribed_length"],
|
||||||
|
substitutions=cer_data["substitutions"],
|
||||||
|
deletions=cer_data["deletions"],
|
||||||
|
insertions=cer_data["insertions"],
|
||||||
|
hits=cer_data["hits"],
|
||||||
|
)
|
||||||
|
audio_time = int((time.perf_counter() - t0) * 1000)
|
||||||
|
|
||||||
|
# (ii) Key questions evaluation
|
||||||
|
if cfg.key_questions_evaluators:
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
kq_eval_result = await evaluate_key_questions(
|
||||||
|
original_text=result.input.text,
|
||||||
|
extracted_questions=result.extracted_key_questions,
|
||||||
|
evaluator_configs=cfg.key_questions_evaluators,
|
||||||
|
)
|
||||||
|
kq_time = int((time.perf_counter() - t0) * 1000)
|
||||||
|
|
||||||
|
# (iii) Chunk evaluation
|
||||||
|
if rag and cfg.chunk_evaluator:
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
all_chunks = _collect_all_chunks(rag)
|
||||||
|
retrieved_set, filtered_set = _extract_chunk_sets(result.retrieval, result.filtered)
|
||||||
|
|
||||||
|
per_sub_q = []
|
||||||
|
overall_unfiltered_metrics = {"precision": 0.0, "recall": 0.0, "f1": 0.0}
|
||||||
|
overall_filtered_metrics = {"precision": 0.0, "recall": 0.0, "f1": 0.0}
|
||||||
|
unfiltered_accuracies = []
|
||||||
|
filtered_accuracies = []
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
semaphore = asyncio.Semaphore(settings.eval_max_concurrent_batches)
|
||||||
|
|
||||||
|
for sq_idx, sq_text in enumerate(result.extracted_key_questions):
|
||||||
|
gt_set, total_chunks, gt_time = await _determine_ground_truth_chunks(
|
||||||
|
sub_question=sq_text,
|
||||||
|
all_chunks=all_chunks,
|
||||||
|
config=cfg.chunk_evaluator,
|
||||||
|
semaphore=semaphore,
|
||||||
|
model_idx=sq_idx,
|
||||||
|
batch_size=settings.eval_chunk_batch_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
relevant_docs = list(set(doc_id for doc_id, _ in gt_set))
|
||||||
|
relevant_chunk_dicts = [
|
||||||
|
{"document_id": doc_id, "chunk_index": idx}
|
||||||
|
for doc_id, idx in gt_set
|
||||||
|
]
|
||||||
|
|
||||||
|
gt_info = GroundTruthInfo(
|
||||||
|
relevant_documents=relevant_docs,
|
||||||
|
relevant_chunks=relevant_chunk_dicts,
|
||||||
|
total_relevant_chunks=len(gt_set),
|
||||||
|
chunk_evaluation_time_ms=gt_time,
|
||||||
|
)
|
||||||
|
|
||||||
|
sq_retrieved = {
|
||||||
|
(doc_id, idx)
|
||||||
|
for doc_id, idx in retrieved_set
|
||||||
|
}
|
||||||
|
sq_filtered = {
|
||||||
|
(doc_id, idx)
|
||||||
|
for doc_id, idx in filtered_set
|
||||||
|
}
|
||||||
|
|
||||||
|
unf_acc = _calculate_accuracy(sq_retrieved, gt_set)
|
||||||
|
fil_acc = _calculate_accuracy(sq_filtered, gt_set)
|
||||||
|
|
||||||
|
unfiltered_accuracies.append(unf_acc)
|
||||||
|
filtered_accuracies.append(fil_acc)
|
||||||
|
|
||||||
|
per_sub_q.append(
|
||||||
|
SubQuestionChunkEval(
|
||||||
|
sub_question_index=sq_idx,
|
||||||
|
sub_question_text=sq_text,
|
||||||
|
ground_truth=gt_info,
|
||||||
|
unfiltered_accuracy=unf_acc,
|
||||||
|
filtered_accuracy=fil_acc,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if unfiltered_accuracies:
|
||||||
|
n = len(unfiltered_accuracies)
|
||||||
|
overall_unfiltered_metrics = {
|
||||||
|
"precision": round(sum(a.precision for a in unfiltered_accuracies) / n, 4),
|
||||||
|
"recall": round(sum(a.recall for a in unfiltered_accuracies) / n, 4),
|
||||||
|
"f1": round(sum(a.f1 for a in unfiltered_accuracies) / n, 4),
|
||||||
|
}
|
||||||
|
if filtered_accuracies:
|
||||||
|
n = len(filtered_accuracies)
|
||||||
|
overall_filtered_metrics = {
|
||||||
|
"precision": round(sum(a.precision for a in filtered_accuracies) / n, 4),
|
||||||
|
"recall": round(sum(a.recall for a in filtered_accuracies) / n, 4),
|
||||||
|
"f1": round(sum(a.f1 for a in filtered_accuracies) / n, 4),
|
||||||
|
}
|
||||||
|
|
||||||
|
chunk_eval_result = ChunkEvalResult(
|
||||||
|
per_sub_question=per_sub_q,
|
||||||
|
overall_unfiltered=ChunkAccuracy(
|
||||||
|
precision=overall_unfiltered_metrics["precision"],
|
||||||
|
recall=overall_unfiltered_metrics["recall"],
|
||||||
|
f1=overall_unfiltered_metrics["f1"],
|
||||||
|
pipeline_chunks=sum(a.pipeline_chunks for a in unfiltered_accuracies) if unfiltered_accuracies else 0,
|
||||||
|
relevant_in_pipeline=sum(a.relevant_in_pipeline for a in unfiltered_accuracies) if unfiltered_accuracies else 0,
|
||||||
|
),
|
||||||
|
overall_filtered=ChunkAccuracy(
|
||||||
|
precision=overall_filtered_metrics["precision"],
|
||||||
|
recall=overall_filtered_metrics["recall"],
|
||||||
|
f1=overall_filtered_metrics["f1"],
|
||||||
|
pipeline_chunks=sum(a.pipeline_chunks for a in filtered_accuracies) if filtered_accuracies else 0,
|
||||||
|
relevant_in_pipeline=sum(a.relevant_in_pipeline for a in filtered_accuracies) if filtered_accuracies else 0,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
chunk_time = int((time.perf_counter() - t0) * 1000)
|
||||||
|
|
||||||
|
# (iv) Response evaluation
|
||||||
|
if rag and cfg.response_evaluator and chunk_eval_result:
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
per_sub_q_resp = []
|
||||||
|
|
||||||
|
for sq_idx, sq_text in enumerate(result.extracted_key_questions):
|
||||||
|
if sq_idx < len(chunk_eval_result.per_sub_question):
|
||||||
|
gt_chunks_data = chunk_eval_result.per_sub_question[sq_idx].ground_truth
|
||||||
|
relevant_chunks_meta = []
|
||||||
|
for rc in gt_chunks_data.relevant_chunks[:10]:
|
||||||
|
doc_id = rc["document_id"]
|
||||||
|
chunk_idx = rc["chunk_index"]
|
||||||
|
# Try to match with pipeline's response sources
|
||||||
|
for sq in result.response.sub_question_sources:
|
||||||
|
if isinstance(sq, dict):
|
||||||
|
for s in sq.get("sources", []):
|
||||||
|
if s.get("document_id") == doc_id and s.get("chunk_index") == chunk_idx:
|
||||||
|
relevant_chunks_meta.append((s.get("content_summary", ""), s))
|
||||||
|
break
|
||||||
|
elif hasattr(sq, "sources"):
|
||||||
|
for s in sq.sources:
|
||||||
|
if hasattr(s, "document_id") and s.document_id == doc_id and s.chunk_index == chunk_idx:
|
||||||
|
relevant_chunks_meta.append((s.content_summary, s.model_dump() if hasattr(s, "model_dump") else {}))
|
||||||
|
break
|
||||||
|
|
||||||
|
if relevant_chunks_meta:
|
||||||
|
section_text = ""
|
||||||
|
for src in result.response.sub_question_sources:
|
||||||
|
if isinstance(src, dict):
|
||||||
|
if src.get("sub_question_index") == sq_idx:
|
||||||
|
section_text = str(src)
|
||||||
|
elif hasattr(src, "sub_question_index") and src.sub_question_index == sq_idx:
|
||||||
|
section_text = str(src)
|
||||||
|
|
||||||
|
resp_eval = await evaluate_response(
|
||||||
|
key_question=sq_text,
|
||||||
|
ground_truth_chunks=relevant_chunks_meta,
|
||||||
|
pipeline_response=section_text or result.response.final_answer,
|
||||||
|
evaluator_config=cfg.response_evaluator,
|
||||||
|
)
|
||||||
|
if resp_eval:
|
||||||
|
resp_eval.sub_question_index = sq_idx
|
||||||
|
resp_eval.sub_question_text = sq_text
|
||||||
|
per_sub_q_resp.append(resp_eval)
|
||||||
|
|
||||||
|
overall_completeness = 0.0
|
||||||
|
overall_factual = 0.0
|
||||||
|
if per_sub_q_resp:
|
||||||
|
overall_completeness = round(sum(r.completeness_score for r in per_sub_q_resp) / len(per_sub_q_resp), 4)
|
||||||
|
overall_factual = round(sum(r.factual_accuracy_score for r in per_sub_q_resp) / len(per_sub_q_resp), 4)
|
||||||
|
|
||||||
|
resp_eval_result = ResponseEvalResult(
|
||||||
|
per_sub_question=per_sub_q_resp,
|
||||||
|
overall_completeness=overall_completeness,
|
||||||
|
overall_factual_accuracy=overall_factual,
|
||||||
|
)
|
||||||
|
resp_time = int((time.perf_counter() - t0) * 1000)
|
||||||
|
|
||||||
|
total_ms = int((time.perf_counter() - overall_start) * 1000)
|
||||||
|
|
||||||
|
eval_result = EvaluationResult(
|
||||||
|
evaluation_id=evaluation_id,
|
||||||
|
result_id=result.result_id,
|
||||||
|
status="completed",
|
||||||
|
audio_evaluation=audio_eval_result,
|
||||||
|
key_questions_evaluation=kq_eval_result,
|
||||||
|
chunk_evaluation=chunk_eval_result,
|
||||||
|
response_evaluation=resp_eval_result,
|
||||||
|
timing=EvaluationTiming(
|
||||||
|
audio_evaluation_time_ms=audio_time,
|
||||||
|
key_questions_evaluation_time_ms=kq_time,
|
||||||
|
chunk_evaluation_time_ms=chunk_time,
|
||||||
|
response_evaluation_time_ms=resp_time,
|
||||||
|
total_evaluation_time_ms=total_ms,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
storage.save_evaluation(eval_result)
|
||||||
|
return eval_result
|
||||||
|
|
@ -0,0 +1,135 @@
|
||||||
|
"""Phase 9 tests: Evaluation API endpoint integration (Sub-Phase 9.3)."""
|
||||||
|
import json
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from app.models.testing import (
|
||||||
|
ChunkAccuracy,
|
||||||
|
DimensionScores,
|
||||||
|
EvaluatorConfig,
|
||||||
|
EvaluationResult,
|
||||||
|
FilteredResult,
|
||||||
|
GenerateResult,
|
||||||
|
InputInfo,
|
||||||
|
KeyQuestionsEvalEntry,
|
||||||
|
KeyQuestionsEvalResult,
|
||||||
|
ResponseResult,
|
||||||
|
RetrievalResult,
|
||||||
|
TimingInfo,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _set_api_keys(monkeypatch):
|
||||||
|
monkeypatch.setenv("LLM_API_KEY", "test-key")
|
||||||
|
monkeypatch.setenv("DP_API_KEY", "test-dp-key")
|
||||||
|
monkeypatch.setenv("DASHSCOPE_API_KEY", "test-dashscope-key")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client(tmp_path, monkeypatch):
|
||||||
|
results_dir = str(tmp_path / "test_results")
|
||||||
|
evals_dir = str(tmp_path / "test_evaluations")
|
||||||
|
prompts_path = str(tmp_path / "prompts.db")
|
||||||
|
history_path = str(tmp_path / "history.db")
|
||||||
|
|
||||||
|
monkeypatch.setenv("TEST_RESULTS_DIR", results_dir)
|
||||||
|
monkeypatch.setenv("TEST_EVALUATIONS_DIR", evals_dir)
|
||||||
|
monkeypatch.setenv("PROMPTS_DB_PATH", prompts_path)
|
||||||
|
monkeypatch.setenv("HISTORY_DB_PATH", history_path)
|
||||||
|
monkeypatch.setenv("LLM_API_KEY", "test-key")
|
||||||
|
monkeypatch.setenv("LLM_BASE_URL", "https://test.example.com/v1")
|
||||||
|
monkeypatch.setenv("LLM_MODEL_NAME", "test-model")
|
||||||
|
monkeypatch.setenv("EMBEDDING_MODEL", "test-embedding")
|
||||||
|
|
||||||
|
from app.core.config import get_settings
|
||||||
|
get_settings.cache_clear()
|
||||||
|
|
||||||
|
from app.core.sqlite_db import _get_db, init_prompts_db, init_history_db, seed_default_profiles
|
||||||
|
conn = _get_db(prompts_path)
|
||||||
|
init_prompts_db(conn)
|
||||||
|
seed_default_profiles(conn)
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
hconn = _get_db(history_path)
|
||||||
|
init_history_db(hconn)
|
||||||
|
hconn.close()
|
||||||
|
|
||||||
|
from app.routers.test_evaluate import router
|
||||||
|
test_app = FastAPI()
|
||||||
|
test_app.include_router(router, prefix="/api/v1")
|
||||||
|
yield TestClient(test_app)
|
||||||
|
|
||||||
|
get_settings.cache_clear()
|
||||||
|
|
||||||
|
|
||||||
|
def _make_sample_result():
|
||||||
|
return GenerateResult(
|
||||||
|
result_id="test-result-001",
|
||||||
|
input_type="text",
|
||||||
|
profile="A",
|
||||||
|
input=InputInfo(text="test question"),
|
||||||
|
extracted_key_questions=["q1", "q2"],
|
||||||
|
retrieval=RetrievalResult(per_sub_question=[], total_chunks_retrieved=10, retriever_time_ms=100),
|
||||||
|
filtered=FilteredResult(per_sub_question=[], total_chunks_filtered=5, filter_time_ms=100),
|
||||||
|
response=ResponseResult(final_answer="answer", sub_question_sources=[], generate_time_ms=100),
|
||||||
|
timing=TimingInfo(decomposer_time_ms=100, retriever_time_ms=100, filter_time_ms=100, generator_time_ms=100, total_time_ms=400),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def saved_result(client):
|
||||||
|
from app.services.test_storage_service import TestStorageService
|
||||||
|
from app.core.config import get_settings
|
||||||
|
|
||||||
|
result = _make_sample_result()
|
||||||
|
svc = TestStorageService(get_settings().test_results_dir, get_settings().test_evaluations_dir)
|
||||||
|
svc.save_result(result)
|
||||||
|
return result.result_id
|
||||||
|
|
||||||
|
|
||||||
|
class TestEvaluateEndpoint:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_valid_evaluate_returns_200(self, client, saved_result):
|
||||||
|
mock_scores = DimensionScores(dimension_1_準確性=35.0, dimension_2_完整性=22.0, dimension_3_清晰度=18.0, dimension_4_簡潔性=13.0)
|
||||||
|
mock_kq = KeyQuestionsEvalResult(
|
||||||
|
evaluations=[
|
||||||
|
KeyQuestionsEvalEntry(model_name="m1", scores=mock_scores, total_score=88, max_score=100, comments="ok", thinking_trace="", time_ms=100),
|
||||||
|
KeyQuestionsEvalEntry(model_name="m2", scores=mock_scores, total_score=88, max_score=100, comments="ok", thinking_trace="", time_ms=100),
|
||||||
|
],
|
||||||
|
average_scores=mock_scores,
|
||||||
|
average_total=88.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"result_id": saved_result,
|
||||||
|
"evaluation_config": {
|
||||||
|
"key_questions_evaluators": [
|
||||||
|
{"model_name": "deepseek-v4-pro", "base_url": "https://api.deepseek.com", "api_key_env": "DP_API_KEY", "enable_thinking": True},
|
||||||
|
{"model_name": "qwen3-7b-max", "base_url": "https://dashscope.example.com/v1", "api_key_env": "DASHSCOPE_API_KEY", "enable_thinking": True},
|
||||||
|
],
|
||||||
|
"chunk_evaluator": {"model_name": "test", "base_url": "https://test.example.com", "api_key_env": "LLM_API_KEY", "enable_thinking": True},
|
||||||
|
"response_evaluator": {"model_name": "test", "base_url": "https://test.example.com", "api_key_env": "LLM_API_KEY", "enable_thinking": True},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp = client.post("/api/v1/test/evaluate", json=payload)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["status"] in ("completed", "partial")
|
||||||
|
assert "evaluation_id" in data
|
||||||
|
|
||||||
|
def test_missing_result_returns_404(self, client):
|
||||||
|
payload = {
|
||||||
|
"result_id": "no-such-id",
|
||||||
|
"evaluation_config": {
|
||||||
|
"key_questions_evaluators": [],
|
||||||
|
"chunk_evaluator": {"model_name": "t", "base_url": "https://x.com", "api_key_env": "LLM_API_KEY"},
|
||||||
|
"response_evaluator": {"model_name": "t", "base_url": "https://x.com", "api_key_env": "LLM_API_KEY"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp = client.post("/api/v1/test/evaluate", json=payload)
|
||||||
|
assert resp.status_code == 404
|
||||||
Loading…
Reference in New Issue