239 lines
9.2 KiB
Python
239 lines
9.2 KiB
Python
"""Chunk highlight service — batch LLM highlight computation and HTML rendering."""
|
|
|
|
import json
|
|
import logging
|
|
import time
|
|
from collections import defaultdict
|
|
from typing import Any
|
|
|
|
from app.models.highlight import (
|
|
ChunkHighlightTarget,
|
|
ChunkHighlights,
|
|
HighlightBatchResponse,
|
|
HighlightBatchResult,
|
|
RelevantSentence,
|
|
)
|
|
from app.services.highlight_cache import compute_cache_key
|
|
from app.utils.sentence_splitter import split_sentences
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def render_highlight_html(
    chunk_text: str,
    sentences: list[str],
    relevant_sentences: list[RelevantSentence],
    metadata: dict[str, Any],
) -> str:
    """Render one chunk as a standalone HTML page with relevant sentences highlighted.

    Every dynamic value (file name, page number, sub-question, sentence text,
    highlight reason, chunk file path) is HTML-escaped before being embedded,
    because chunk text comes from ingested documents and reasons come from an
    LLM; neither may be trusted to be markup-free.

    Args:
        chunk_text: Full chunk text. Not used by the renderer (the page is
            built from ``sentences``); kept for interface compatibility.
        sentences: The chunk split into sentences, in display order.
        relevant_sentences: Highlight entries; each provides ``sentence_index``
            (position in ``sentences``) and ``reason``. Indices that do not
            match any sentence are ignored. If an index repeats, the last
            reason wins.
        metadata: Optional display fields read with ``.get``: ``filename``
            (default ``"Unknown"``), ``page_number``, ``chunk_file_path``,
            ``sub_question`` and ``chunk_index`` (default ``0``).

    Returns:
        The complete HTML document as a single newline-joined string.
    """
    # Function-scope imports keep this block self-contained.
    from html import escape
    from urllib.parse import quote

    index_to_reason: dict[int, str] = {
        rs.sentence_index: rs.reason for rs in relevant_sentences
    }

    filename = metadata.get("filename", "Unknown")
    page_number = metadata.get("page_number")
    chunk_file_path = metadata.get("chunk_file_path")
    sub_question = metadata.get("sub_question", "")
    chunk_index = metadata.get("chunk_index", 0)

    # Static document head: contains no dynamic content, so nothing to escape.
    parts: list[str] = [
        "<!DOCTYPE html>",
        "<html>",
        "<head>",
        '<meta charset="utf-8">',
        "<style>",
        "body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; margin: 24px; color: #1e293b; line-height: 1.7; }",
        ".highlighted { background-color: #fef08a; padding: 2px 4px; border-radius: 3px; }",
        ".reason { color: #6b7280; font-style: italic; font-size: 0.9em; margin-left: 12px; }",
        ".header { margin-bottom: 16px; }",
        ".sub-header { color: #475569; font-size: 0.95em; margin: 4px 0; }",
        ".sentence { margin: 8px 0; }",
        ".footer { margin-top: 24px; padding-top: 12px; border-top: 1px solid #e2e8f0; }",
        ".footer a { color: #2563eb; text-decoration: none; }",
        ".footer a:hover { text-decoration: underline; }",
        "</style>",
        "</head>",
        "<body>",
    ]

    parts.append('<div class="header">')
    parts.append(f"<h2>{escape(str(filename))} — Chunk {escape(str(chunk_index))}</h2>")
    if page_number is not None:
        parts.append(f'<p class="sub-header">Page {escape(str(page_number))}</p>')
    if sub_question:
        parts.append(
            f'<p class="sub-header">Sub-question: {escape(str(sub_question))}</p>'
        )
    parts.append("</div>")

    for i, sentence in enumerate(sentences):
        safe_sentence = escape(str(sentence))
        if i in index_to_reason:
            reason = index_to_reason[i]
            parts.append('<div class="sentence">')
            parts.append(f'<span class="highlighted">{safe_sentence}</span>')
            if reason:
                parts.append(f'<span class="reason">{escape(str(reason))}</span>')
            parts.append("</div>")
        else:
            parts.append(f'<p class="sentence">{safe_sentence}</p>')

    if chunk_file_path:
        # Percent-encode first so '&', '#', '?' or spaces in the path cannot
        # corrupt the query string, then HTML-escape for the attribute context.
        href_path = escape(quote(str(chunk_file_path), safe="/"))
        parts.append('<div class="footer">')
        parts.append(
            f'<a href="/api/v1/chunks/view?file={href_path}">View Original PDF →</a>'
        )
        parts.append("</div>")

    parts.append("</body>")
    parts.append("</html>")

    return "\n".join(parts)
|
|
|
|
|
|
class ChunkHighlightService:
    """Compute sentence-level highlights for cited chunks and cache the results.

    For a batch of targets the service fetches each chunk from the RAG
    collection, asks the LLM (one structured call for the whole batch) which
    sentences are relevant to each sub-question, renders highlighted HTML and
    stores both the raw highlight data and the HTML in the highlight cache.
    """

    def __init__(self, rag_service, llm_client, highlight_cache, settings):
        """Store collaborators.

        Args:
            rag_service: Object exposing ``collection.get(ids=..., include=...)``.
            llm_client: Object exposing the awaitable ``complete_structured``.
            highlight_cache: Object exposing ``set_highlight(...)``.
            settings: Application settings; stored but not read in this class.
        """
        self._rag = rag_service
        self._llm = llm_client
        self._cache = highlight_cache
        self._settings = settings

    async def compute_highlights_batch(
        self,
        targets: list[ChunkHighlightTarget],
    ) -> HighlightBatchResponse:
        """Fetch, highlight and cache every target in one LLM round-trip.

        Args:
            targets: Chunks to highlight, each tied to one sub-question.

        Returns:
            A ``HighlightBatchResponse`` whose ``status`` is ``"completed"``
            when nothing went wrong, ``"partial"`` when some chunk could not
            be fetched or the LLM returned no result for it, and ``"failed"``
            when the LLM call itself raised. Exceptions raised by the cache
            layer are not caught here.
        """
        if not targets:
            return HighlightBatchResponse(status="completed", cached_count=0)

        errors: list[str] = []
        fetched: list[tuple[ChunkHighlightTarget, list[str], dict[str, Any]]] = []

        for target in targets:
            chunk_id = f"{target.document_id}_{target.chunk_index}"
            # Best-effort per chunk: one bad chunk must not abort the batch.
            # NOTE(review): collection.get is called synchronously inside a
            # coroutine — confirm it is cheap enough not to stall the loop.
            try:
                result = self._rag.collection.get(
                    ids=[chunk_id],
                    include=["documents", "metadatas"],
                )
                if not result.get("documents"):
                    # NOTE(review): skipped chunks are only logged, not
                    # reported in `errors` — confirm that is intended.
                    logger.warning("No documents returned for chunk_id=%s, skipping.", chunk_id)
                    continue
                chunk_text = result["documents"][0]
                metadata = self._chunk_metadata(result, target)
                sentences = split_sentences(chunk_text)
                fetched.append((target, sentences, metadata))
            except Exception as exc:
                msg = f"Failed to fetch chunk {chunk_id}: {exc}"
                # logger.exception keeps the traceback for diagnosis.
                logger.exception(msg)
                errors.append(msg)

        if not fetched:
            return HighlightBatchResponse(
                status="completed" if not errors else "partial",
                cached_count=0,
                errors=errors,
            )

        prompt = self._build_prompt(fetched)

        highlight_start = time.perf_counter()
        try:
            llm_result: HighlightBatchResult = await self._llm.complete_structured(
                prompt, HighlightBatchResult, step_name="HighlightBatch"
            )
        except Exception as exc:
            logger.exception("HighlightBatch LLM call failed: %s", exc)
            # Keep fetch errors gathered so far instead of discarding them.
            return HighlightBatchResponse(
                status="failed", cached_count=0, errors=[*errors, str(exc)]
            )
        highlight_time_ms = int((time.perf_counter() - highlight_start) * 1000)

        cached_count = self._cache_results(fetched, llm_result)
        highlight_response_json = llm_result.model_dump_json()

        errors.extend(self._missing_result_errors(fetched, llm_result))

        return HighlightBatchResponse(
            status="partial" if errors else "completed",
            cached_count=cached_count,
            errors=errors,
            highlight_prompt=prompt,
            highlight_response_json=highlight_response_json,
            highlight_time_ms=highlight_time_ms,
        )

    @staticmethod
    def _chunk_metadata(
        result: dict[str, Any],
        target: ChunkHighlightTarget,
    ) -> dict[str, Any]:
        """Return a private copy of the first stored metadata plus target fields.

        Copying avoids mutating whatever dict the store handed back, and a
        missing or ``None`` metadata entry is treated as empty instead of
        raising on item assignment.
        """
        metadatas = result.get("metadatas") or []
        stored = metadatas[0] if metadatas else None
        metadata: dict[str, Any] = dict(stored) if stored else {}
        metadata["sub_question"] = target.sub_question_text
        metadata["chunk_index"] = target.chunk_index
        return metadata

    @staticmethod
    def _missing_result_errors(
        fetched: list[tuple[ChunkHighlightTarget, list[str], dict[str, Any]]],
        llm_result: HighlightBatchResult,
    ) -> list[str]:
        """List one error per fetched chunk the LLM returned no result for.

        Messages follow the order of ``fetched`` (deterministic, unlike set
        iteration) and each ``(document_id, chunk_index)`` is reported once.
        """
        returned = {(r.document_id, r.chunk_index) for r in llm_result.results}
        reported: set[tuple[Any, Any]] = set()
        messages: list[str] = []
        for target, _sentences, _meta in fetched:
            key = (target.document_id, target.chunk_index)
            if key in returned or key in reported:
                continue
            reported.add(key)
            messages.append(f"No highlight result for {key[0]}_{key[1]}")
        return messages

    def _build_prompt(
        self,
        fetched: list[tuple[ChunkHighlightTarget, list[str], dict[str, Any]]],
    ) -> str:
        """Build the batch prompt: chunks grouped by sub-question, sentences numbered.

        Sub-questions appear in ascending ``sub_question_index`` order; the
        heading text is taken from the first target of each group.
        """
        by_sub_q: dict[int, list[tuple[ChunkHighlightTarget, list[str]]]] = defaultdict(list)
        for target, sentences, _meta in fetched:
            by_sub_q[target.sub_question_index].append((target, sentences))

        lines: list[str] = [
            "For each sub-question below, identify which sentences in each cited chunk are relevant to answering that sub-question.",
            "Return a HighlightBatchResult with a results list containing one ChunkHighlights per (document_id, chunk_index) pair.",
            "Each ChunkHighlights should list the relevant sentence indices (0-based) with a brief reason (max 80 chars).",
            "",
        ]

        for sq_idx in sorted(by_sub_q):
            items = by_sub_q[sq_idx]
            sub_q_text = items[0][0].sub_question_text
            lines.append(f"## Sub-question {sq_idx}: {sub_q_text}")
            lines.append("")
            for target, sentences in items:
                lines.append(f"### Chunk: document_id={target.document_id}, chunk_index={target.chunk_index}")
                lines.extend(f"[{i}] {s}" for i, s in enumerate(sentences))
                lines.append("")

        return "\n".join(lines)

    def _cache_results(
        self,
        fetched: list[tuple[ChunkHighlightTarget, list[str], dict[str, Any]]],
        llm_result: HighlightBatchResult,
    ) -> int:
        """Render and cache each LLM result that matches a fetched chunk.

        Results whose ``(document_id, chunk_index)`` was not fetched are
        ignored.

        Returns:
            The number of ``set_highlight`` calls made.
        """
        # NOTE(review): keyed by (document_id, chunk_index) only. If the same
        # chunk is cited by several sub-questions, later targets overwrite
        # earlier ones here and only one sub-question gets cached — confirm
        # whether the result model can carry a sub-question identifier.
        lookup: dict[tuple[str, int], tuple[ChunkHighlightTarget, list[str], dict[str, Any]]] = {
            (t.document_id, t.chunk_index): (t, s, m)
            for t, s, m in fetched
        }

        cached_count = 0
        for chunk_hl in llm_result.results:
            entry = lookup.get((chunk_hl.document_id, chunk_hl.chunk_index))
            if entry is None:
                continue

            target, sentences, metadata = entry
            html = render_highlight_html(
                chunk_text=" ".join(sentences),
                sentences=sentences,
                relevant_sentences=chunk_hl.relevant_sentences,
                metadata=metadata,
            )

            cache_key = compute_cache_key(
                target.document_id,
                target.chunk_index,
                target.sub_question_text,
            )
            self._cache.set_highlight(
                cache_key=cache_key,
                document_id=target.document_id,
                chunk_index=target.chunk_index,
                sub_question=target.sub_question_text,
                relevant_sentences_json=json.dumps(
                    [rs.model_dump() for rs in chunk_hl.relevant_sentences],
                    default=str,
                ),
                html_content=html,
            )
            cached_count += 1

        return cached_count
|