170 lines
5.6 KiB
Python
170 lines
5.6 KiB
Python
import asyncio
|
||
import json
|
||
import logging
|
||
import os
|
||
import time
|
||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||
|
||
from app.models.testing import (
|
||
ChunkAccuracy,
|
||
EvaluatorConfig,
|
||
GroundTruthInfo,
|
||
SubQuestionChunkEval,
|
||
)
|
||
from app.services.llm_client import LLMClient
|
||
|
||
# Module-level logger shared by the helpers in this file.
logger = logging.getLogger(__name__)

# Default number of chunks placed in one evaluation prompt.
CHUNK_BATCH_SIZE = 10
# Default number of retries after the first attempt (attempts = retries + 1).
CHUNK_MAX_RETRIES = 2
# Seconds slept between two consecutive attempts for the same batch.
CHUNK_RETRY_DELAY = 2.0

# Instruction header placed at the top of every chunk-evaluation prompt.
# English gloss (the runtime text itself is Chinese and must stay as is):
#   "You are evaluating how relevant document chunks are to a key question.
#    For each <chunk_N>, decide whether it contains information relevant to
#    <sub_question>. Return JSON: {"relevant_chunk_indices": [0, 3, 7]}
#    (only the relevant chunk indices, 0-based, counted from the first chunk
#    of this batch)."
_CHUNK_EVAL_SYSTEM = """你正在評估文檔塊與關鍵問題的相關性。
對於每個<chunk_N>,判斷其是否包含與<sub_question>相關的信息。
返回JSON:{"relevant_chunk_indices": [0, 3, 7]}(僅包含相關的塊索引,0-based,從本批次的第一個塊算起)"""
|
||
|
||
|
||
def _split_into_batches(
    chunks: List[Tuple[str, int, str, Dict[str, Any]]], batch_size: int = CHUNK_BATCH_SIZE
) -> List[List[Tuple[str, int, str, Dict[str, Any]]]]:
    """Split a flat chunk list into consecutive batches.

    Args:
        chunks: Chunk tuples in evaluation order.
        batch_size: Maximum number of chunks per batch; must be positive.

    Returns:
        Slices of ``chunks`` in their original order. Every batch holds
        ``batch_size`` items except possibly the last one. An empty input
        yields an empty list.

    Raises:
        ValueError: If ``batch_size`` is zero or negative.
    """
    if batch_size <= 0:
        # A negative step would make range() empty and silently drop every
        # chunk; a zero step fails inside range() with an unrelated message.
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    return [chunks[start : start + batch_size] for start in range(0, len(chunks), batch_size)]
|
||
|
||
|
||
def _parse_relevance_response(raw: str) -> Optional[List[int]]:
    """Parse the evaluator's JSON reply into a list of chunk indices.

    The reply is expected to be a JSON object of the form
    ``{"relevant_chunk_indices": [0, 3, 7]}``.

    Args:
        raw: Raw text returned by the LLM.

    Returns:
        The listed indices converted with ``int()``, in reply order (possibly
        empty). ``None`` when the reply is unusable: it is not valid JSON, is
        not an object, lacks the ``relevant_chunk_indices`` key, holds a
        non-list value there, or contains an entry that cannot be converted
        to an integer.
    """
    try:
        data = json.loads(raw)
    except (TypeError, ValueError):
        # json.JSONDecodeError is a ValueError subclass; TypeError covers a
        # non-text reply such as None.
        return None

    if not isinstance(data, dict):
        return None

    indices = data.get("relevant_chunk_indices")
    if not isinstance(indices, list):
        return None

    try:
        return [int(i) for i in indices]
    except (TypeError, ValueError, OverflowError):
        # Entries such as null, "abc", NaN or Infinity make the whole reply
        # a parse failure instead of an exception.
        return None
|
||
|
||
|
||
def _build_chunk_batch_prompt(
    sub_question: str, batch: List[Tuple[str, int, str, Dict[str, Any]]]
) -> str:
    """Build the XML-style prompt used to evaluate one batch of chunks.

    Layout, one item per line: the instruction header, a blank line, the
    ``<sub_question>`` element, a blank line, then for every chunk a
    ``<chunk_i doc="..." page="...">`` element followed by a blank line.
    ``i`` is the chunk's 0-based position inside ``batch``.

    Args:
        sub_question: Question text placed inside ``<sub_question>``.
        batch: Chunk tuples ``(doc_id, global_idx, text, meta)``. Only
            ``doc_id``, ``text`` and ``meta["page_number"]`` are rendered.

    Returns:
        The prompt as a single newline-joined string.
    """
    parts = [
        _CHUNK_EVAL_SYSTEM,
        "",
        "<sub_question>",
        sub_question,
        "</sub_question>",
        "",
    ]

    for idx, (doc_id, _global_idx, text, meta) in enumerate(batch):
        # Missing metadata (None or no page_number) is rendered as page "?".
        page = (meta or {}).get("page_number", "?")
        parts.extend(
            (
                f'<chunk_{idx} doc="{doc_id}" page="{page}">',
                text,
                f"</chunk_{idx}>",
                "",
            )
        )

    return "\n".join(parts)
|
||
|
||
|
||
def _make_eval_client(config: EvaluatorConfig, model_idx: int) -> LLMClient:
    """Build an ``LLMClient`` that talks to one evaluator model's endpoint.

    The instance is allocated with ``LLMClient.__new__`` so that
    ``LLMClient.__init__`` does not run; the attributes assigned below are the
    only state it receives.
    NOTE(review): presumably these mirror what ``LLMClient.__init__`` sets up;
    confirm against ``app.services.llm_client`` when that class changes.

    Args:
        config: Evaluator settings. Reads ``api_key_env``, ``model_name``,
            ``enable_thinking`` and ``base_url``.
        model_idx: Index of the evaluator model; only used to name the
            client's logger (``<module>.eval_<model_idx>``).

    Returns:
        A client whose ``_client`` is a fresh ``AsyncOpenAI`` instance with a
        120-second timeout.
    """
    # Kept at function scope, as in the original code.
    import httpx
    from openai import AsyncOpenAI

    env_name = config.api_key_env
    # os.environ.get() raises TypeError on a non-string key, so a falsy
    # variable name is treated the same as an unset variable.
    api_key = os.environ.get(env_name, "") if env_name else ""
    if not api_key:
        logger.warning(
            "API key variable %r is unset or empty; evaluator %d will use an empty API key",
            env_name,
            model_idx,
        )

    client = LLMClient.__new__(LLMClient)
    # Throwaway settings object exposing only the two attributes set here.
    client.settings = type(
        "_Settings",
        (),
        {"vllm_engine": False, "llm_enable_thinking": config.enable_thinking},
    )()
    client.model = config.model_name
    client.enable_thinking = config.enable_thinking
    client.logger = logging.getLogger(f"{__name__}.eval_{model_idx}")

    client._client = AsyncOpenAI(
        base_url=config.base_url.rstrip("/"),
        api_key=api_key,
        timeout=120.0,
        http_client=httpx.AsyncClient(headers={"Content-Type": "application/json"}),
    )
    client._langchain_model = None
    return client
|
||
|
||
|
||
async def _evaluate_batch(
    client: LLMClient, prompt: str, retries: int = CHUNK_MAX_RETRIES
) -> Optional[List[int]]:
    """Ask the evaluator which chunks of one batch are relevant, with retries.

    Each attempt sends ``prompt`` through ``client.complete`` and parses the
    reply with ``_parse_relevance_response``. An attempt fails when the call
    or the parsing raises, or when the reply cannot be parsed.

    Args:
        client: Client used to send the prompt.
        prompt: Fully rendered batch prompt.
        retries: Number of extra attempts after the first one.

    Returns:
        The batch-local indices from the first usable reply, or ``None`` when
        every attempt failed.
    """
    for attempt in range(retries + 1):
        try:
            raw = await client.complete(prompt=prompt, temperature=0.1, step_name="ChunkEval")
            result = _parse_relevance_response(raw)
            if result is not None:
                return result
            # A reply that does not parse is retried like an error; log it so
            # the retry is visible.
            logger.warning(
                "Chunk batch eval attempt %d returned an unparseable response", attempt + 1
            )
        except Exception as exc:
            # Retry boundary: any failure of the call or the parsing is
            # retried. asyncio.CancelledError is a BaseException on
            # Python 3.8+, so cancellation still propagates.
            logger.warning("Chunk batch eval attempt %d failed: %s", attempt + 1, exc)

        if attempt < retries:
            await asyncio.sleep(CHUNK_RETRY_DELAY)

    logger.warning("Chunk batch eval gave up after %d attempt(s)", max(retries + 1, 0))
    return None
|
||
|
||
|
||
async def _determine_ground_truth_chunks(
    sub_question: str,
    all_chunks: List[Tuple[str, int, str, Dict[str, Any]]],
    config: EvaluatorConfig,
    semaphore: asyncio.Semaphore,
    model_idx: int = 0,
    batch_size: int = CHUNK_BATCH_SIZE,
) -> Tuple[Set[Tuple[str, int]], int, int]:
    """Determine which chunks are relevant to a key question.

    The chunks are split into batches, each batch is evaluated by the
    evaluator model described by ``config``, and the batch-local indices it
    returns are mapped back to ``(doc_id, chunk_global_idx)`` pairs.

    Args:
        sub_question: Question the chunks are judged against.
        all_chunks: Chunk tuples ``(doc_id, global_idx, text, meta)``.
        config: Settings of the evaluator model to use.
        semaphore: Limits how many batch evaluations run at the same time.
        model_idx: Index of the evaluator model, passed on to the client
            factory and used in log messages.
        batch_size: Number of chunks per evaluation prompt.

    Returns:
        ``(ground_truth_set, total_chunks, elapsed_ms)``. Batches whose
        evaluation failed contribute nothing to the set, and indices outside
        the batch are ignored.
    """
    start = time.perf_counter()
    batches = _split_into_batches(all_chunks, batch_size)

    client = _make_eval_client(config, model_idx)

    async def _eval_with_limit(batch):
        # The prompt is rendered only once a slot is free.
        async with semaphore:
            prompt = _build_chunk_batch_prompt(sub_question, batch)
            return await _evaluate_batch(client, prompt)

    try:
        batch_results = await asyncio.gather(*(_eval_with_limit(b) for b in batches))
    finally:
        # The client is created for this call only; close its underlying
        # HTTP client so connections do not leak.
        openai_client = getattr(client, "_client", None)
        if openai_client is not None:
            try:
                await openai_client.close()
            except Exception as exc:
                logger.debug("Failed to close evaluator client %d: %s", model_idx, exc)

    ground_truth: Set[Tuple[str, int]] = set()
    failed_batches = 0
    for batch, result in zip(batches, batch_results):
        if result is None:
            failed_batches += 1
            continue
        for batch_local_idx in result:
            # Ignore indices that do not point into this batch.
            if 0 <= batch_local_idx < len(batch):
                doc_id, chunk_global_idx = batch[batch_local_idx][:2]
                ground_truth.add((doc_id, chunk_global_idx))

    if failed_batches:
        logger.warning(
            "%d of %d chunk batches could not be evaluated by evaluator %d; "
            "ground truth may be incomplete",
            failed_batches,
            len(batches),
            model_idx,
        )

    elapsed_ms = int((time.perf_counter() - start) * 1000)
    return ground_truth, len(all_chunks), elapsed_ms
|
||
|
||
|
||
def _calculate_accuracy(
    pipeline_chunks: Set[Tuple[str, int]], ground_truth: Set[Tuple[str, int]]
) -> ChunkAccuracy:
    """Calculate precision, recall and F1 of the pipeline's chunk selection.

    Args:
        pipeline_chunks: ``(doc_id, chunk_idx)`` pairs selected by the pipeline.
        ground_truth: ``(doc_id, chunk_idx)`` pairs judged relevant.

    Returns:
        A ``ChunkAccuracy`` with the three scores rounded to 4 decimals, the
        number of pipeline chunks, and how many of them are in the ground
        truth. All fields are zero when ``pipeline_chunks`` is empty; recall
        is 0.0 when ``ground_truth`` is empty.
    """
    if not pipeline_chunks:
        return ChunkAccuracy(precision=0.0, recall=0.0, f1=0.0, pipeline_chunks=0, relevant_in_pipeline=0)

    tp = len(pipeline_chunks & ground_truth)
    # pipeline_chunks is non-empty past the guard above, so this cannot
    # divide by zero.
    precision = tp / len(pipeline_chunks)
    recall = tp / len(ground_truth) if ground_truth else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

    return ChunkAccuracy(
        precision=round(precision, 4),
        recall=round(recall, 4),
        f1=round(f1, 4),
        pipeline_chunks=len(pipeline_chunks),
        relevant_in_pipeline=tp,
    )
|