"""Chunk highlight service — batch LLM highlight computation and HTML rendering.""" import json import logging from collections import defaultdict from typing import Any from app.models.highlight import ( ChunkHighlightTarget, ChunkHighlights, HighlightBatchResponse, HighlightBatchResult, RelevantSentence, ) from app.services.highlight_cache import compute_cache_key from app.utils.sentence_splitter import split_sentences logger = logging.getLogger(__name__) def render_highlight_html( chunk_text: str, sentences: list[str], relevant_sentences: list[RelevantSentence], metadata: dict[str, Any], ) -> str: highlighted_indices = {rs.sentence_index for rs in relevant_sentences} index_to_reason: dict[int, str] = {rs.sentence_index: rs.reason for rs in relevant_sentences} filename = metadata.get("filename", "Unknown") page_number = metadata.get("page_number") chunk_file_path = metadata.get("chunk_file_path") sub_question = metadata.get("sub_question", "") chunk_index = metadata.get("chunk_index", 0) parts: list[str] = [] parts.append("") parts.append("") parts.append("
") parts.append('') parts.append("") parts.append("") parts.append("") parts.append('Page {page_number}
') if sub_question: parts.append(f'Sub-question: {sub_question}
') parts.append("{sentence}
') if chunk_file_path: parts.append('") parts.append("") parts.append("") return "\n".join(parts) class ChunkHighlightService: def __init__(self, rag_service, llm_client, highlight_cache, settings): self._rag = rag_service self._llm = llm_client self._cache = highlight_cache self._settings = settings async def compute_highlights_batch( self, targets: list[ChunkHighlightTarget], ) -> HighlightBatchResponse: if not targets: return HighlightBatchResponse(status="completed", cached_count=0) errors: list[str] = [] fetched: list[tuple[ChunkHighlightTarget, list[str], dict[str, Any]]] = [] for target in targets: chunk_id = f"{target.document_id}_{target.chunk_index}" try: result = self._rag.collection.get( ids=[chunk_id], include=["documents", "metadatas"], ) if not result.get("documents"): logger.warning("No documents returned for chunk_id=%s, skipping.", chunk_id) continue chunk_text = result["documents"][0] metadata = result["metadatas"][0] if result.get("metadatas") else {} metadata["sub_question"] = target.sub_question_text metadata["chunk_index"] = target.chunk_index sentences = split_sentences(chunk_text) fetched.append((target, sentences, metadata)) except Exception as exc: msg = f"Failed to fetch chunk {chunk_id}: {exc}" logger.error(msg) errors.append(msg) if not fetched: return HighlightBatchResponse( status="completed" if not errors else "partial", cached_count=0, errors=errors, ) prompt = self._build_prompt(fetched) try: llm_result: HighlightBatchResult = await self._llm.complete_structured( prompt, HighlightBatchResult, step_name="HighlightBatch" ) except Exception as exc: logger.error("HighlightBatch LLM call failed: %s", exc) return HighlightBatchResponse( status="failed", cached_count=0, errors=[str(exc)] ) cached_count = self._cache_results(fetched, llm_result) result_ids = {(r.document_id, r.chunk_index) for r in llm_result.results} fetched_ids = {(t.document_id, t.chunk_index) for t, _, _ in fetched} missing = fetched_ids - result_ids if errors or missing: for doc_id, chunk_idx in missing: errors.append(f"No highlight result for {doc_id}_{chunk_idx}") status = "partial" if (errors or missing) else "completed" return HighlightBatchResponse( status=status, cached_count=cached_count, errors=errors, ) def _build_prompt( self, fetched: list[tuple[ChunkHighlightTarget, list[str], dict[str, Any]]], ) -> str: by_sub_q: dict[int, list[tuple[ChunkHighlightTarget, list[str]]]] = defaultdict(list) for target, sentences, _meta in fetched: by_sub_q[target.sub_question_index].append((target, sentences)) lines: list[str] = [ "For each sub-question below, identify which sentences in each cited chunk are relevant to answering that sub-question.", "Return a HighlightBatchResult with a results list containing one ChunkHighlights per (document_id, chunk_index) pair.", "Each ChunkHighlights should list the relevant sentence indices (0-based) with a brief reason (max 80 chars).", "", ] for sq_idx in sorted(by_sub_q.keys()): items = by_sub_q[sq_idx] sub_q_text = items[0][0].sub_question_text lines.append(f"## Sub-question {sq_idx}: {sub_q_text}") lines.append("") for target, sentences in items: lines.append(f"### Chunk: document_id={target.document_id}, chunk_index={target.chunk_index}") for i, s in enumerate(sentences): lines.append(f"[{i}] {s}") lines.append("") return "\n".join(lines) def _cache_results( self, fetched: list[tuple[ChunkHighlightTarget, list[str], dict[str, Any]]], llm_result: HighlightBatchResult, ) -> int: lookup: dict[tuple[str, int], tuple[ChunkHighlightTarget, list[str], dict[str, Any]]] = { (t.document_id, t.chunk_index): (t, s, m) for t, s, m in fetched } cached_count = 0 for chunk_hl in llm_result.results: key = (chunk_hl.document_id, chunk_hl.chunk_index) entry = lookup.get(key) if entry is None: continue target, sentences, metadata = entry html = render_highlight_html( chunk_text=" ".join(sentences), sentences=sentences, relevant_sentences=chunk_hl.relevant_sentences, metadata=metadata, ) cache_key = compute_cache_key( target.document_id, target.chunk_index, target.sub_question_text, ) self._cache.set_highlight( cache_key=cache_key, document_id=target.document_id, chunk_index=target.chunk_index, sub_question=target.sub_question_text, relevant_sentences_json=json.dumps( [rs.model_dump() for rs in chunk_hl.relevant_sentences], default=str, ), html_content=html, ) cached_count += 1 return cached_count