# legco_ai_assistant/backend/app/services/chunk_highlight_service.py
#
# 239 lines
# 9.2 KiB
# Python

"""Chunk highlight service — batch LLM highlight computation and HTML rendering."""
import json
import logging
import time
from collections import defaultdict
from typing import Any
from app.models.highlight import (
ChunkHighlightTarget,
ChunkHighlights,
HighlightBatchResponse,
HighlightBatchResult,
RelevantSentence,
)
from app.services.highlight_cache import compute_cache_key
from app.utils.sentence_splitter import split_sentences
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
def render_highlight_html(
    chunk_text: str,
    sentences: list[str],
    relevant_sentences: list[RelevantSentence],
    metadata: dict[str, Any],
) -> str:
    """Render one chunk as a standalone HTML page with relevant sentences highlighted.

    Every dynamic value (sentence text, reasons, filename, page number,
    sub-question, chunk index) is HTML-escaped before being interpolated, and
    the chunk file path is percent-encoded before being placed in the link's
    query string, so text containing ``<``, ``&`` or quotes cannot inject
    markup or break the link.

    Args:
        chunk_text: Full chunk text. Not used by the renderer (the page is
            built from ``sentences``); kept for interface compatibility.
        sentences: The chunk split into sentences; each is rendered in order.
        relevant_sentences: Objects exposing ``sentence_index`` and ``reason``.
            A sentence whose position appears as a ``sentence_index`` is
            highlighted; indices with no matching sentence are ignored.
        metadata: Optional keys read here: ``filename`` (default
            ``"Unknown"``), ``page_number``, ``chunk_file_path``,
            ``sub_question`` (default ``""``) and ``chunk_index`` (default 0).

    Returns:
        The complete HTML document as a single newline-joined string.
    """
    # Function-scope imports keep this block self-contained (standard library only).
    from html import escape
    from urllib.parse import quote

    def _text(value: Any) -> str:
        # Element content only needs &, < and > escaped; quotes stay readable.
        return escape(str(value), quote=False)

    # If an index is listed more than once, the last reason wins.
    index_to_reason: dict[int, str] = {
        rs.sentence_index: rs.reason for rs in relevant_sentences
    }

    filename = metadata.get("filename", "Unknown")
    page_number = metadata.get("page_number")
    chunk_file_path = metadata.get("chunk_file_path")
    sub_question = metadata.get("sub_question", "")
    chunk_index = metadata.get("chunk_index", 0)

    parts: list[str] = [
        "<!DOCTYPE html>",
        "<html>",
        "<head>",
        '<meta charset="utf-8">',
        "<style>",
        "body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; margin: 24px; color: #1e293b; line-height: 1.7; }",
        ".highlighted { background-color: #fef08a; padding: 2px 4px; border-radius: 3px; }",
        ".reason { color: #6b7280; font-style: italic; font-size: 0.9em; margin-left: 12px; }",
        ".header { margin-bottom: 16px; }",
        ".sub-header { color: #475569; font-size: 0.95em; margin: 4px 0; }",
        ".sentence { margin: 8px 0; }",
        ".footer { margin-top: 24px; padding-top: 12px; border-top: 1px solid #e2e8f0; }",
        ".footer a { color: #2563eb; text-decoration: none; }",
        ".footer a:hover { text-decoration: underline; }",
        "</style>",
        "</head>",
        "<body>",
        '<div class="header">',
        f"<h2>{_text(filename)} — Chunk {_text(chunk_index)}</h2>",
    ]
    if page_number is not None:
        parts.append(f'<p class="sub-header">Page {_text(page_number)}</p>')
    if sub_question:
        parts.append(f'<p class="sub-header">Sub-question: {_text(sub_question)}</p>')
    parts.append("</div>")

    for i, sentence in enumerate(sentences):
        if i not in index_to_reason:
            parts.append(f'<p class="sentence">{_text(sentence)}</p>')
            continue
        reason = index_to_reason[i]
        parts.append('<div class="sentence">')
        parts.append(f'<span class="highlighted">{_text(sentence)}</span>')
        if reason:
            parts.append(f'<span class="reason">{_text(reason)}</span>')
        parts.append("</div>")

    if chunk_file_path:
        # Percent-encode so spaces, '&', '#', quotes etc. cannot terminate the
        # query value or the attribute; '/' is kept literal for readability.
        encoded_path = quote(str(chunk_file_path), safe="/")
        parts.append('<div class="footer">')
        parts.append(f'<a href="/api/v1/chunks/view?file={encoded_path}">View Original PDF →</a>')
        parts.append("</div>")

    parts.append("</body>")
    parts.append("</html>")
    return "\n".join(parts)
class ChunkHighlightService:
    """Compute sentence-level highlights for cited chunks in one LLM call and cache the rendered HTML."""

    def __init__(self, rag_service, llm_client, highlight_cache, settings):
        """Store the collaborators used by the batch pipeline.

        Args:
            rag_service: Must expose ``collection.get(ids=..., include=...)``.
            llm_client: Must expose an awaitable
                ``complete_structured(prompt, schema, step_name=...)``.
            highlight_cache: Must expose ``set_highlight(...)``.
            settings: Stored as-is; not read by the methods in this class.
        """
        self._rag = rag_service
        self._llm = llm_client
        self._cache = highlight_cache
        self._settings = settings

    @staticmethod
    def _chunk_metadata(
        result: dict[str, Any],
        target: ChunkHighlightTarget,
    ) -> dict[str, Any]:
        """Build a private metadata dict for ``target`` from a ``collection.get`` result.

        The stored metadata is copied rather than mutated, and a missing or
        ``None`` metadata entry is treated as empty, so a chunk without stored
        metadata is still processed instead of failing on item assignment.
        (NOTE(review): stores such as Chroma are believed to return ``None``
        for records without metadata — confirm against the collection in use.)
        """
        stored = result.get("metadatas") or [None]
        metadata: dict[str, Any] = dict(stored[0] or {})
        metadata["sub_question"] = target.sub_question_text
        metadata["chunk_index"] = target.chunk_index
        return metadata

    async def compute_highlights_batch(
        self,
        targets: list[ChunkHighlightTarget],
    ) -> HighlightBatchResponse:
        """Fetch each target chunk, ask the LLM for relevant sentences, and cache the HTML.

        Steps: (1) look up every target by id ``"{document_id}_{chunk_index}"``
        and split its text into sentences; (2) send one structured-output
        prompt covering all fetched chunks; (3) render and cache each returned
        result.

        Args:
            targets: Chunks to highlight, each carrying its sub-question.

        Returns:
            A ``HighlightBatchResponse`` whose ``status`` is:
            ``"completed"`` when there were no targets or nothing failed,
            ``"partial"`` when a fetch failed or the LLM omitted a fetched
            chunk, and ``"failed"`` when the LLM call itself raised. Chunks
            for which the store returns no documents are skipped with a
            warning and are not counted as errors.
        """
        if not targets:
            return HighlightBatchResponse(status="completed", cached_count=0)

        errors: list[str] = []
        fetched: list[tuple[ChunkHighlightTarget, list[str], dict[str, Any]]] = []
        for target in targets:
            chunk_id = f"{target.document_id}_{target.chunk_index}"
            try:
                result = self._rag.collection.get(
                    ids=[chunk_id],
                    include=["documents", "metadatas"],
                )
                if not result.get("documents"):
                    logger.warning("No documents returned for chunk_id=%s, skipping.", chunk_id)
                    continue
                metadata = self._chunk_metadata(result, target)
                sentences = split_sentences(result["documents"][0])
                fetched.append((target, sentences, metadata))
            except Exception as exc:
                # Best-effort per chunk: record the failure and keep going.
                logger.exception("Failed to fetch chunk %s: %s", chunk_id, exc)
                errors.append(f"Failed to fetch chunk {chunk_id}: {exc}")

        if not fetched:
            return HighlightBatchResponse(
                status="completed" if not errors else "partial",
                cached_count=0,
                errors=errors,
            )

        prompt = self._build_prompt(fetched)
        highlight_start = time.perf_counter()
        try:
            llm_result: HighlightBatchResult = await self._llm.complete_structured(
                prompt, HighlightBatchResult, step_name="HighlightBatch"
            )
        except Exception as exc:
            logger.exception("HighlightBatch LLM call failed: %s", exc)
            return HighlightBatchResponse(
                status="failed", cached_count=0, errors=[str(exc)]
            )
        # Timing covers only the LLM round trip, not fetching or caching.
        highlight_time_ms = int((time.perf_counter() - highlight_start) * 1000)

        cached_count = self._cache_results(fetched, llm_result)
        highlight_response_json = llm_result.model_dump_json()

        # Report every fetched chunk the LLM did not return a result for.
        result_ids = {(r.document_id, r.chunk_index) for r in llm_result.results}
        fetched_ids = {(t.document_id, t.chunk_index) for t, _, _ in fetched}
        missing = fetched_ids - result_ids
        # Sorted so the error list is deterministic rather than set-ordered.
        errors.extend(
            f"No highlight result for {doc_id}_{chunk_idx}"
            for doc_id, chunk_idx in sorted(missing)
        )

        return HighlightBatchResponse(
            status="partial" if errors else "completed",
            cached_count=cached_count,
            errors=errors,
            highlight_prompt=prompt,
            highlight_response_json=highlight_response_json,
            highlight_time_ms=highlight_time_ms,
        )

    def _build_prompt(
        self,
        fetched: list[tuple[ChunkHighlightTarget, list[str], dict[str, Any]]],
    ) -> str:
        """Build the single batch prompt, grouping chunks under their sub-question.

        Sub-questions appear in ascending ``sub_question_index`` order; the
        heading text is taken from the first target seen for that index. Each
        chunk lists its sentences as ``[i] sentence`` with 0-based indices.
        """
        by_sub_q: dict[int, list[tuple[ChunkHighlightTarget, list[str]]]] = defaultdict(list)
        for target, sentences, _meta in fetched:
            by_sub_q[target.sub_question_index].append((target, sentences))

        lines: list[str] = [
            "For each sub-question below, identify which sentences in each cited chunk are relevant to answering that sub-question.",
            "Return a HighlightBatchResult with a results list containing one ChunkHighlights per (document_id, chunk_index) pair.",
            "Each ChunkHighlights should list the relevant sentence indices (0-based) with a brief reason (max 80 chars).",
            "",
        ]
        for sq_idx in sorted(by_sub_q):
            items = by_sub_q[sq_idx]
            sub_q_text = items[0][0].sub_question_text
            lines.append(f"## Sub-question {sq_idx}: {sub_q_text}")
            lines.append("")
            for target, sentences in items:
                lines.append(f"### Chunk: document_id={target.document_id}, chunk_index={target.chunk_index}")
                lines.extend(f"[{i}] {s}" for i, s in enumerate(sentences))
                lines.append("")
        return "\n".join(lines)

    def _cache_results(
        self,
        fetched: list[tuple[ChunkHighlightTarget, list[str], dict[str, Any]]],
        llm_result: HighlightBatchResult,
    ) -> int:
        """Render and cache HTML for every LLM result that matches a fetched chunk.

        Results whose ``(document_id, chunk_index)`` was not fetched are
        ignored.

        Returns:
            The number of cache entries written.
        """
        # NOTE(review): keyed by (document_id, chunk_index) only, so if the same
        # chunk is targeted under several sub-questions the last target wins
        # here and the others are never cached. Disambiguating would need the
        # sub-question carried on each result — confirm the model supports it.
        lookup: dict[tuple[str, int], tuple[ChunkHighlightTarget, list[str], dict[str, Any]]] = {
            (t.document_id, t.chunk_index): (t, s, m)
            for t, s, m in fetched
        }

        cached_count = 0
        for chunk_hl in llm_result.results:
            entry = lookup.get((chunk_hl.document_id, chunk_hl.chunk_index))
            if entry is None:
                continue
            target, sentences, metadata = entry
            rendered_html = render_highlight_html(
                chunk_text=" ".join(sentences),
                sentences=sentences,
                relevant_sentences=chunk_hl.relevant_sentences,
                metadata=metadata,
            )
            cache_key = compute_cache_key(
                target.document_id,
                target.chunk_index,
                target.sub_question_text,
            )
            self._cache.set_highlight(
                cache_key=cache_key,
                document_id=target.document_id,
                chunk_index=target.chunk_index,
                sub_question=target.sub_question_text,
                relevant_sentences_json=json.dumps(
                    [rs.model_dump() for rs in chunk_hl.relevant_sentences],
                    default=str,
                ),
                html_content=rendered_html,
            )
            cached_count += 1
        return cached_count