legco_ai_assistant/backend/app/services/chunk_evaluator.py

170 lines
5.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import json
import logging
import os
import time
from typing import Any, Dict, List, Optional, Set, Tuple
from app.models.testing import (
ChunkAccuracy,
EvaluatorConfig,
GroundTruthInfo,
SubQuestionChunkEval,
)
from app.services.llm_client import LLMClient
logger = logging.getLogger(__name__)
# Default number of chunks packed into a single evaluation prompt.
CHUNK_BATCH_SIZE = 10
# Default number of retries after the first attempt (so up to 3 attempts per batch).
CHUNK_MAX_RETRIES = 2
# Seconds to sleep between consecutive attempts for the same batch.
CHUNK_RETRY_DELAY = 2.0
# Instruction preamble (Traditional Chinese) placed at the top of every
# chunk-evaluation prompt. It asks the model to judge, for each <chunk_N>,
# whether it contains information relevant to <sub_question>, and to answer
# with JSON of the form {"relevant_chunk_indices": [...]} using 0-based
# indices counted from the first chunk of the current batch.
# NOTE(review): the last sentence appears to lack punctuation; the hosting UI
# flagged ambiguous Unicode in this file - confirm against the repository copy.
_CHUNK_EVAL_SYSTEM = """你正在評估文檔塊與關鍵問題的相關性。
對於每個<chunk_N>,判斷其是否包含與<sub_question>相關的信息。
返回JSON{"relevant_chunk_indices": [0, 3, 7]}僅包含相關的塊索引0-based從本批次的第一個塊算起"""
def _split_into_batches(
    chunks: List[Tuple[str, int, str, Dict[str, Any]]], batch_size: int = CHUNK_BATCH_SIZE
) -> List[List[Tuple[str, int, str, Dict[str, Any]]]]:
    """Partition ``chunks`` into consecutive slices of at most ``batch_size`` items.

    Order is preserved and the final slice may be shorter than ``batch_size``.
    An empty input produces an empty list.
    """
    return [
        chunks[offset : offset + batch_size]
        for offset in range(0, len(chunks), batch_size)
    ]
def _parse_relevance_response(raw: str) -> Optional[List[int]]:
"""Parse LLM response for chunk relevance indices."""
try:
data = json.loads(raw)
except json.JSONDecodeError:
return None
if not isinstance(data, dict) or "relevant_chunk_indices" not in data:
return None
indices = data["relevant_chunk_indices"]
if not isinstance(indices, list):
return None
return [int(i) for i in indices]
def _build_chunk_batch_prompt(
    sub_question: str, batch: List[Tuple[str, int, str, Dict[str, Any]]]
) -> str:
    """Assemble the XML-style prompt used to evaluate one batch of chunks.

    The prompt starts with the shared instruction preamble, followed by the
    question wrapped in ``<sub_question>`` tags and then one
    ``<chunk_N doc=... page=...>`` element per chunk, where ``N`` is the
    chunk's 0-based position inside ``batch``. The page comes from the chunk
    metadata key ``page_number`` and falls back to ``"?"`` when absent.
    """
    lines = [
        _CHUNK_EVAL_SYSTEM,
        "",
        "<sub_question>",
        sub_question,
        "</sub_question>",
        "",
    ]
    # The chunk's global index is not part of the prompt; only its position
    # within the batch is shown to the model.
    for position, (doc_id, _global_idx, text, meta) in enumerate(batch):
        page = meta.get("page_number", "?")
        lines.extend(
            (
                f'<chunk_{position} doc="{doc_id}" page="{page}">',
                text,
                f"</chunk_{position}>",
                "",
            )
        )
    return "\n".join(lines)
def _make_eval_client(config: EvaluatorConfig, model_idx: int) -> LLMClient:
    """Build an ``LLMClient`` wired to the evaluator endpoint described by ``config``.

    The instance is allocated with ``__new__`` so ``LLMClient.__init__`` does
    not run; every attribute set here is assigned by hand. The API key is read
    from the environment variable named by ``config.api_key_env`` and defaults
    to an empty string when that variable is unset.

    Args:
        config: Evaluator settings (model name, base URL, API-key variable
            name, thinking flag).
        model_idx: Index used only to name the per-client logger.

    Returns:
        The configured client.
    """
    import httpx
    from openai import AsyncOpenAI

    evaluator = LLMClient.__new__(LLMClient)

    # Minimal stand-in settings object carrying just these two attributes.
    settings_cls = type(
        "_Settings",
        (),
        {"vllm_engine": False, "llm_enable_thinking": config.enable_thinking},
    )
    evaluator.settings = settings_cls()
    evaluator.model = config.model_name
    evaluator.enable_thinking = config.enable_thinking
    evaluator.logger = logging.getLogger(f"{__name__}.eval_{model_idx}")
    evaluator._client = AsyncOpenAI(
        base_url=config.base_url.rstrip("/"),
        api_key=os.environ.get(config.api_key_env, ""),
        timeout=120.0,
        http_client=httpx.AsyncClient(headers={"Content-Type": "application/json"}),
    )
    evaluator._langchain_model = None
    return evaluator
async def _evaluate_batch(
    client: LLMClient, prompt: str, retries: int = CHUNK_MAX_RETRIES
) -> Optional[List[int]]:
    """Ask the model which chunks in one batch are relevant, retrying on failure.

    An attempt fails when the client call or the parsing raises (a warning is
    logged) or when the reply cannot be parsed into an index list. Between
    attempts the coroutine sleeps for ``CHUNK_RETRY_DELAY`` seconds.

    Args:
        client: Client whose ``complete`` coroutine performs the request.
        prompt: Fully assembled batch prompt.
        retries: Extra attempts allowed after the first one.

    Returns:
        The parsed batch-local indices, or ``None`` once every attempt failed.
    """
    total_attempts = retries + 1
    attempt_no = 0
    while attempt_no < total_attempts:
        attempt_no += 1
        try:
            reply = await client.complete(
                prompt=prompt, temperature=0.1, step_name="ChunkEval"
            )
            indices = _parse_relevance_response(reply)
        except Exception as exc:
            logger.warning("Chunk batch eval attempt %d failed: %s", attempt_no, exc)
        else:
            if indices is not None:
                return indices
        # No pause after the final attempt.
        if attempt_no < total_attempts:
            await asyncio.sleep(CHUNK_RETRY_DELAY)
    return None
async def _determine_ground_truth_chunks(
    sub_question: str,
    all_chunks: List[Tuple[str, int, str, Dict[str, Any]]],
    config: EvaluatorConfig,
    semaphore: asyncio.Semaphore,
    model_idx: int = 0,
    batch_size: int = CHUNK_BATCH_SIZE,
) -> Tuple[Set[Tuple[str, int]], int, int]:
    """Determine which chunks are relevant to a key question.

    The chunks are split into batches, each batch is sent to the evaluator
    model concurrently (bounded by ``semaphore``), and the batch-local indices
    the model returns are mapped back to ``(doc_id, chunk_global_idx)`` pairs.
    Batches whose evaluation failed and indices outside the batch are ignored.

    Args:
        sub_question: The question the chunks are judged against.
        all_chunks: Tuples whose first two items are the document id and the
            chunk's global index; the remaining items are the chunk text and
            its metadata.
        config: Evaluator endpoint configuration.
        semaphore: Limits how many batches are evaluated at the same time.
        model_idx: Index forwarded to the client factory.
        batch_size: Maximum number of chunks per prompt.

    Returns:
        ``(ground_truth_set, total_chunks, elapsed_ms)``.
    """
    start = time.perf_counter()
    batches = _split_into_batches(all_chunks, batch_size)
    ground_truth: Set[Tuple[str, int]] = set()
    if not batches:
        # Nothing to evaluate: avoid opening an HTTP client that would go unused.
        elapsed_ms = int((time.perf_counter() - start) * 1000)
        return ground_truth, len(all_chunks), elapsed_ms

    client = _make_eval_client(config, model_idx)

    async def _eval_with_limit(batch):
        async with semaphore:
            prompt = _build_chunk_batch_prompt(sub_question, batch)
            return await _evaluate_batch(client, prompt)

    try:
        batch_results = await asyncio.gather(*(_eval_with_limit(b) for b in batches))
    finally:
        # The factory opens a dedicated AsyncOpenAI/httpx client per call;
        # close it so its connection pool is released. Cleanup is best-effort
        # and must never hide the evaluation outcome.
        try:
            await client._client.close()
        except Exception as exc:
            logger.debug("Failed to close chunk evaluator client: %s", exc)

    for batch, result in zip(batches, batch_results):
        if result is None:
            continue
        for batch_local_idx in result:
            # Discard indices the model invented outside this batch.
            if 0 <= batch_local_idx < len(batch):
                doc_id, chunk_global_idx = batch[batch_local_idx][:2]
                ground_truth.add((doc_id, chunk_global_idx))
    elapsed_ms = int((time.perf_counter() - start) * 1000)
    return ground_truth, len(all_chunks), elapsed_ms
def _calculate_accuracy(
    pipeline_chunks: Set[Tuple[str, int]], ground_truth: Set[Tuple[str, int]]
) -> ChunkAccuracy:
    """Score the pipeline's chunk selection against the ground-truth set.

    Precision is the share of pipeline chunks found in ``ground_truth``;
    recall is the share of ``ground_truth`` found by the pipeline (0.0 when
    ``ground_truth`` is empty); F1 is their harmonic mean (0.0 when both are
    zero). The three scores are rounded to four decimals. An empty
    ``pipeline_chunks`` yields an all-zero result.
    """
    if not pipeline_chunks:
        return ChunkAccuracy(precision=0.0, recall=0.0, f1=0.0, pipeline_chunks=0, relevant_in_pipeline=0)

    retrieved = len(pipeline_chunks)
    hits = len(pipeline_chunks & ground_truth)

    precision = hits / retrieved
    if ground_truth:
        recall = hits / len(ground_truth)
    else:
        recall = 0.0

    score_sum = precision + recall
    if score_sum > 0:
        f1 = 2 * precision * recall / score_sum
    else:
        f1 = 0.0

    return ChunkAccuracy(
        precision=round(precision, 4),
        recall=round(recall, 4),
        f1=round(f1, 4),
        pipeline_chunks=retrieved,
        relevant_in_pipeline=hits,
    )