"""Chunk highlight service — batch LLM highlight computation and HTML rendering."""

import json
import logging
import re
import time
from collections import defaultdict
from html import escape
from typing import Any

from app.models.highlight import (
    ChunkHighlightTarget,
    ChunkHighlights,
    HighlightBatchResponse,
    HighlightBatchResult,
    RelevantSentence,
)
from app.services.highlight_cache import compute_cache_key
from app.utils.sentence_splitter import split_sentences

logger = logging.getLogger(__name__)


def render_highlight_html(
    chunk_text: str,
    sentences: list[str],
    relevant_sentences: list[RelevantSentence],
    metadata: dict[str, Any],
) -> str:
    """Render a standalone HTML page for one chunk with its relevant sentences highlighted.

    Chunk text, LLM reasons and metadata values are HTML-escaped before being
    interpolated into the markup. ``chunk_text`` is currently unused; rendering
    works from the pre-split ``sentences``.
    """
    highlighted_indices = {rs.sentence_index for rs in relevant_sentences}
    index_to_reason: dict[int, str] = {rs.sentence_index: rs.reason for rs in relevant_sentences}

    filename = escape(str(metadata.get("filename", "Unknown")))
    page_number = metadata.get("page_number")
    chunk_file_path = metadata.get("chunk_file_path")
    sub_question = metadata.get("sub_question", "")
    chunk_index = metadata.get("chunk_index", 0)

    parts: list[str] = []
    parts.append("<!DOCTYPE html>")
    parts.append("<html>")
    parts.append("<head>")
    parts.append('<meta charset="utf-8">')
    parts.append("</head>")
    parts.append("<body>")
    parts.append('<div class="chunk">')
    parts.append(f"<h1>{filename} — Chunk {chunk_index}</h1>")
    if page_number is not None:
        parts.append(f'<p class="page">Page {page_number}</p>')
    if sub_question:
        parts.append(f'<p class="sub-question">Sub-question: {escape(str(sub_question))}</p>')

    parts.append('<div class="sentences">')
    for i, sentence in enumerate(sentences):
        text = escape(sentence)
        if i in highlighted_indices:
            # Highlighted sentence, followed by the LLM's short reason when one was given.
            reason = index_to_reason.get(i, "")
            parts.append('<p class="sentence highlighted">')
            parts.append(f"<mark>{text}</mark>")
            if reason:
                parts.append(f'<span class="reason">{escape(reason)}</span>')
            parts.append("</p>")
        else:
            parts.append(f'<p class="sentence">{text}</p>')
    parts.append("</div>")

    if chunk_file_path:
        parts.append(f'<p class="source"><code>{escape(str(chunk_file_path))}</code></p>')

    parts.append("</div>")
    parts.append("</body>")
    parts.append("</html>")
    return "\n".join(parts)


class ChunkHighlightService:
    """Compute sentence-level highlights for cited chunks in one LLM call and cache the rendered HTML."""

    def __init__(self, rag_service, llm_client, highlight_cache, settings):
        self._rag = rag_service
        self._llm = llm_client
        self._cache = highlight_cache
        self._settings = settings

    async def compute_highlights_batch(
        self,
        targets: list[ChunkHighlightTarget],
    ) -> HighlightBatchResponse:
        if not targets:
            return HighlightBatchResponse(status="completed", cached_count=0)

        logger.info(
            "Highlight batch: %d targets received, fetching from ChromaDB...",
            len(targets),
        )

        errors: list[str] = []
        fetched: list[tuple[ChunkHighlightTarget, list[str], dict[str, Any]]] = []

        # 1. Fetch each target chunk from the vector store and split it into sentences.
        for target in targets:
            chunk_id = f"{target.document_id}_{target.chunk_index}"
            try:
                result = self._rag.collection.get(
                    ids=[chunk_id],
                    include=["documents", "metadatas"],
                )
                if not result.get("documents"):
                    logger.warning("No documents returned for chunk_id=%s, skipping.", chunk_id)
                    continue

                chunk_text = result["documents"][0]
                metadatas = result.get("metadatas") or []
                # Copy so the annotations below do not mutate the store's result;
                # a stored metadata entry may also be None.
                metadata = dict(metadatas[0] or {}) if metadatas else {}
                metadata["sub_question"] = target.sub_question_text
                metadata["chunk_index"] = target.chunk_index

                sentences = split_sentences(chunk_text)
                fetched.append((target, sentences, metadata))
            except Exception as exc:
                msg = f"Failed to fetch chunk {chunk_id}: {exc}"
                logger.error(msg)
                errors.append(msg)

        if not fetched:
            logger.warning(
                "Highlight batch: no chunks fetched from %d targets (errors=%d)",
                len(targets),
                len(errors),
            )
            return HighlightBatchResponse(
                status="completed" if not errors else "partial",
                cached_count=0,
                errors=errors,
            )

        # 2. Ask the LLM for the relevant sentences of every fetched chunk in a single call.
        prompt = self._build_prompt(fetched)
        logger.info(
            "Highlight batch: %d/%d targets fetched, calling LLM (prompt len=%d)...",
            len(fetched),
            len(targets),
            len(prompt),
        )

        highlight_start = time.perf_counter()
        try:
            llm_result: HighlightBatchResult = await self._llm.complete_structured(
                prompt, HighlightBatchResult, step_name="HighlightBatch"
            )
            highlight_time_ms = int((time.perf_counter() - highlight_start) * 1000)
            logger.info(
                "Highlight batch: LLM structured succeeded in %dms — %d results",
                highlight_time_ms,
                len(llm_result.results) if llm_result else 0,
            )
        except Exception as structured_exc:
            # Structured output failed: retry with a plain completion and parse the JSON here.
            logger.warning(
                "HighlightBatch structured output failed: %s. "
                "Falling back to plain complete() with JSON instructions.",
                structured_exc,
            )
            try:
                fallback_prompt = (
                    prompt
                    + "\n\nRespond ONLY with a valid JSON object matching the HighlightBatchResult schema. "
                    "Do NOT include markdown code fences, extra commentary, or any text outside the JSON."
                )
                raw_response = await self._llm.complete(
                    fallback_prompt, temperature=0.0, step_name="HighlightBatch-Fallback"
                )
                # Strip any markdown fences the model may have emitted anyway.
                match = re.search(r"```(?:json)?\s*\n?(.*?)\n?```", raw_response, re.DOTALL)
                if match:
                    raw_response = match.group(1).strip()
                llm_result = HighlightBatchResult.model_validate_json(raw_response)
                highlight_time_ms = int((time.perf_counter() - highlight_start) * 1000)
                logger.info(
                    "HighlightBatch fallback complete() succeeded in %dms — %d results",
                    highlight_time_ms,
                    len(llm_result.results) if llm_result else 0,
                )
            except Exception as fallback_exc:
                logger.error("HighlightBatch fallback also failed: %s", fallback_exc)
                return HighlightBatchResponse(
                    status="failed",
                    cached_count=0,
                    # Keep any fetch errors alongside the two LLM failures.
                    errors=[*errors, str(structured_exc), str(fallback_exc)],
                )

        # 3. Render and cache HTML for every returned chunk that was actually requested.
        cached_count = self._cache_results(fetched, llm_result)
        highlight_response_json = llm_result.model_dump_json()

        # Report every fetched chunk the LLM did not return a result for.
        result_ids = {(r.document_id, r.chunk_index) for r in llm_result.results}
        fetched_ids = {(t.document_id, t.chunk_index) for t, _, _ in fetched}
        missing = fetched_ids - result_ids
        for doc_id, chunk_idx in sorted(missing):
            errors.append(f"No highlight result for {doc_id}_{chunk_idx}")

        status = "partial" if errors else "completed"
        logger.info(
            "Highlight batch: done — status=%s cached=%d/%d errors=%d missing=%d time=%dms",
            status,
            cached_count,
            len(fetched),
            len(errors),
            len(missing),
            highlight_time_ms,
        )

        return HighlightBatchResponse(
            status=status,
            cached_count=cached_count,
            errors=errors,
            highlight_prompt=prompt,
            highlight_response_json=highlight_response_json,
            highlight_time_ms=highlight_time_ms,
        )

    def _build_prompt(
        self,
        fetched: list[tuple[ChunkHighlightTarget, list[str], dict[str, Any]]],
    ) -> str:
        """Build one prompt that lists each chunk's numbered sentences, grouped by sub-question."""
        by_sub_q: dict[int, list[tuple[ChunkHighlightTarget, list[str]]]] = defaultdict(list)
        for target, sentences, _meta in fetched:
            by_sub_q[target.sub_question_index].append((target, sentences))

        lines: list[str] = [
            "For each sub-question below, identify which sentences in each cited chunk are relevant to answering that sub-question.",
            "Return a HighlightBatchResult with a results list containing one ChunkHighlights per (document_id, chunk_index) pair.",
            "Each ChunkHighlights should list the relevant sentence indices (0-based) with a brief reason (max 80 chars).",
            "",
        ]

        for sq_idx in sorted(by_sub_q):
            items = by_sub_q[sq_idx]
            sub_q_text = items[0][0].sub_question_text
            lines.append(f"## Sub-question {sq_idx}: {sub_q_text}")
            lines.append("")
            for target, sentences in items:
                lines.append(
                    f"### Chunk: document_id={target.document_id}, chunk_index={target.chunk_index}"
                )
                for i, s in enumerate(sentences):
                    lines.append(f"[{i}] {s}")
                lines.append("")

        return "\n".join(lines)

    def _cache_results(
        self,
        fetched: list[tuple[ChunkHighlightTarget, list[str], dict[str, Any]]],
        llm_result: HighlightBatchResult,
    ) -> int:
        """Render and cache HTML for each LLM result that matches a fetched chunk; return the count."""
        # Keyed by (document_id, chunk_index): if one chunk is cited under several
        # sub-questions, only the last target for that chunk is kept.
        lookup: dict[tuple[str, int], tuple[ChunkHighlightTarget, list[str], dict[str, Any]]] = {
            (t.document_id, t.chunk_index): (t, s, m) for t, s, m in fetched
        }

        cached_count = 0
        for chunk_hl in llm_result.results:
            key = (chunk_hl.document_id, chunk_hl.chunk_index)
            entry = lookup.get(key)
            if entry is None:
                # The LLM returned a chunk that was not requested; ignore it.
                continue

            target, sentences, metadata = entry
            html = render_highlight_html(
                chunk_text=" ".join(sentences),
                sentences=sentences,
                relevant_sentences=chunk_hl.relevant_sentences,
                metadata=metadata,
            )
            cache_key = compute_cache_key(
                target.document_id,
                target.chunk_index,
                target.sub_question_text,
            )
            self._cache.set_highlight(
                cache_key=cache_key,
                document_id=target.document_id,
                chunk_index=target.chunk_index,
                sub_question=target.sub_question_text,
                relevant_sentences_json=json.dumps(
                    [rs.model_dump() for rs in chunk_hl.relevant_sentences],
                    default=str,
                ),
                html_content=html,
            )
            cached_count += 1

        return cached_count