legco_ai_assistant/backend/app/services/response_evaluator.py

120 lines
3.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import logging
import math
import os
import time
from types import SimpleNamespace
from typing import Any, Dict, List, Optional, Tuple

from app.models.testing import (
    EvaluatorConfig,
    SubQuestionResponseEval,
)
from app.services.llm_client import LLMClient
# Module-level logger for this evaluator service.
logger = logging.getLogger(__name__)
# Prompt (Traditional Chinese) for generating the reference answer. Roughly:
# "Answer the key question using the document chunks below. Use only the
# information in the provided chunks, no external knowledge. Cite sources in
# the answer." Placeholders {key_question} and {chunks} are filled with
# str.format in evaluate_response.
_RESPONSE_GEN_PROMPT = """使用以下文檔塊回答關鍵問題。僅使用提供的文檔塊信息,不要使用外部知識。在答案中引用來源。
關鍵問題:{key_question}
文檔塊:
{chunks}
回答:"""
# Prompt (Traditional Chinese) asking the evaluator to compare completeness and
# factual accuracy of answer A (reference, generated from relevant chunks) and
# answer B (the answer under evaluation), and to reply in JSON with
# completeness_score, factual_accuracy_score and comments. Placeholders
# {key_question}, {ground_truth_response} and {pipeline_response} are filled
# with str.format; the doubled braces {{ }} are format escapes that render as
# literal JSON braces.
# NOTE(review): the "回答 A…" / "回答 B…" label lines and "返回JSON格式" look like
# they lost punctuation (e.g. full-width parentheses/colons); the page this was
# copied from flagged ambiguous Unicode characters — confirm against the
# original file.
_RESPONSE_COMPARE_PROMPT = """比較以下兩個回答的完整性和事實準確性。
關鍵問題:{key_question}
回答 A基準答案從相關塊生成
{ground_truth_response}
回答 B要評估的答案
{pipeline_response}
請評估回答 B 是否包含回答 A 中的所有關鍵信息。返回JSON格式
{{"completeness_score": 0.0-1.0, "factual_accuracy_score": 0.0-1.0, "comments": "簡要評語"}}"""
def _make_eval_client(config: EvaluatorConfig) -> LLMClient:
    """Build an ``LLMClient`` pointed at the evaluator endpoint described by *config*.

    ``LLMClient.__init__`` is bypassed via ``__new__`` and only the attributes
    set below are populated by hand (presumably so the evaluator can use its
    own base URL, model and key instead of application settings).
    NOTE(review): this relies on ``LLMClient``'s private attribute names
    (``_client``, ``_langchain_model``) — keep in sync with that class.

    The returned object owns a fresh ``AsyncOpenAI`` client (wrapping its own
    ``httpx.AsyncClient``) stored on ``_client``; the caller is responsible
    for closing it when finished.

    Args:
        config: Evaluator settings. ``api_key_env`` names the environment
            variable holding the API key; ``base_url``, ``model_name`` and
            ``enable_thinking`` configure the client.

    Returns:
        A configured ``LLMClient`` instance.
    """
    # Third-party clients are imported at call time, not at module import.
    import httpx
    from openai import AsyncOpenAI

    api_key = os.environ.get(config.api_key_env, "")
    if not api_key:
        # An empty key is still passed through, but warn so that the
        # authentication failures it causes later are easy to diagnose.
        logger.warning(
            "Response evaluator API key env var %r is unset or empty; "
            "requests to %s may be rejected",
            config.api_key_env,
            config.base_url,
        )

    client = LLMClient.__new__(LLMClient)
    client.settings = SimpleNamespace(
        vllm_engine=False,
        llm_enable_thinking=config.enable_thinking,
    )
    client.model = config.model_name
    client.enable_thinking = config.enable_thinking
    client.logger = logging.getLogger(f"{__name__}.resp_eval")
    client._client = AsyncOpenAI(
        base_url=config.base_url.rstrip("/"),
        api_key=api_key,
        timeout=120.0,
        http_client=httpx.AsyncClient(headers={"Content-Type": "application/json"}),
    )
    client._langchain_model = None
    return client
def _format_chunks(chunks: List[Tuple[str, Dict[str, Any]]]) -> str:
    """Render ``(text, metadata)`` chunks as source-labelled blocks for the prompt."""
    return "\n\n".join(
        f"[{meta.get('filename', 'unknown')}, page {meta.get('page_number', '?')}]\n{text}"
        for text, meta in chunks
    )


def _extract_json_object(raw: Any) -> Dict[str, Any]:
    """Return the first JSON object found in an LLM reply.

    Accepts a bare JSON object, or one wrapped in Markdown code fences or
    surrounded by prose.

    Raises:
        ValueError: if *raw* is not a string or contains no JSON object.
    """
    if not isinstance(raw, str):
        raise ValueError(f"expected a str response, got {type(raw).__name__}")
    text = raw.strip()
    try:
        parsed = json.loads(text)
    except ValueError:
        parsed = None
    if isinstance(parsed, dict):
        return parsed
    # Fall back to scanning for the first "{" that starts a decodable object.
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            candidate, _ = decoder.raw_decode(text, start)
        except ValueError:
            candidate = None
        if isinstance(candidate, dict):
            return candidate
        start = text.find("{", start + 1)
    raise ValueError("no JSON object found in comparison response")


def _coerce_score(value: Any) -> float:
    """Convert an LLM-reported score to a finite float clamped to [0.0, 1.0].

    Raises:
        TypeError / ValueError: if *value* is not numeric or is NaN/infinite.
    """
    score = float(value)
    if not math.isfinite(score):
        raise ValueError(f"non-finite score: {value!r}")
    return min(1.0, max(0.0, score))


async def _close_eval_client(client: "LLMClient") -> None:
    """Best-effort close of the ``AsyncOpenAI`` client stored on ``client._client``."""
    inner = getattr(client, "_client", None)
    close = getattr(inner, "close", None)
    if close is None:
        return
    try:
        await close()
    except Exception as exc:  # closing must never mask the evaluation outcome
        logger.debug("Failed to close response evaluator client: %s", exc)


async def evaluate_response(
    key_question: str,
    ground_truth_chunks: List[Tuple[str, Dict[str, Any]]],
    pipeline_response: str,
    evaluator_config: EvaluatorConfig,
) -> Optional[SubQuestionResponseEval]:
    """Score *pipeline_response* against a reference answer built from ground-truth chunks.

    Two evaluator-LLM calls are made: the first answers *key_question* using
    only *ground_truth_chunks*; the second compares that reference with
    *pipeline_response* and is expected to reply with a JSON object holding
    ``completeness_score``, ``factual_accuracy_score`` and ``comments``.

    Args:
        key_question: The question both answers address.
        ground_truth_chunks: ``(text, metadata)`` pairs; ``metadata`` may
            carry ``filename`` and ``page_number``, used as source labels.
        pipeline_response: The answer produced by the pipeline under test.
        evaluator_config: Endpoint/model settings for the evaluator LLM.

    Returns:
        A ``SubQuestionResponseEval`` with both scores clamped to [0, 1] and
        rounded to 4 decimals, or ``None`` when either LLM call fails, the
        reference answer is empty, or the comparison reply cannot be parsed.
        ``sub_question_index`` is always 0 — NOTE(review): presumably callers
        re-index; confirm.
    """
    client = _make_eval_client(evaluator_config)
    try:
        # Step 1: generate the reference ("ground truth") answer from the chunks.
        gen_start = time.perf_counter()
        gen_prompt = _RESPONSE_GEN_PROMPT.format(
            key_question=key_question,
            chunks=_format_chunks(ground_truth_chunks),
        )
        try:
            ground_truth_response = await client.complete(
                prompt=gen_prompt, temperature=0.3, step_name="ResponseGen-GroundTruth"
            )
        except Exception as exc:  # LLM/network boundary: degrade to "no evaluation"
            logger.warning("Failed to generate ground truth response: %s", exc)
            return None
        gen_time_ms = int((time.perf_counter() - gen_start) * 1000)
        if not isinstance(ground_truth_response, str) or not ground_truth_response.strip():
            # Comparing against an empty reference would yield meaningless scores.
            logger.warning("Ground truth generation returned an empty response")
            return None

        # Step 2: have the evaluator score the pipeline answer against the reference.
        comp_start = time.perf_counter()
        comp_prompt = _RESPONSE_COMPARE_PROMPT.format(
            key_question=key_question,
            ground_truth_response=ground_truth_response,
            pipeline_response=pipeline_response,
        )
        try:
            raw = await client.complete(
                prompt=comp_prompt, temperature=0.3, step_name="ResponseCompare"
            )
        except Exception as exc:  # LLM/network boundary: degrade to "no evaluation"
            logger.warning("Failed to compare responses: %s", exc)
            return None
        try:
            data = _extract_json_object(raw)
            completeness = _coerce_score(data.get("completeness_score", 0.0))
            factual = _coerce_score(data.get("factual_accuracy_score", 0.0))
        except (TypeError, ValueError) as exc:
            logger.warning("Failed to parse comparison response: %s (raw=%.500r)", exc, raw)
            return None
        comp_time_ms = int((time.perf_counter() - comp_start) * 1000)

        comments = data.get("comments", "")
        if comments is None:
            comments = ""
        elif not isinstance(comments, str):
            comments = str(comments)

        return SubQuestionResponseEval(
            sub_question_index=0,
            sub_question_text=key_question,
            ground_truth_response=ground_truth_response,
            pipeline_response_section=pipeline_response,
            completeness_score=round(completeness, 4),
            factual_accuracy_score=round(factual, 4),
            comments=comments,
            ground_truth_generation_time_ms=gen_time_ms,
            comparison_time_ms=comp_time_ms,
        )
    finally:
        # A fresh client is built per call; close it so its connection pool
        # is not leaked across evaluations.
        await _close_eval_client(client)