"""RAG service for embedding, retrieval, and response generation.""" import uuid from typing import TYPE_CHECKING, List, Tuple, Dict, Any, Optional import logging from app.core.config import Settings from app.core.database import get_chroma_client if TYPE_CHECKING: from app.services.prompt_service import PromptService logger = logging.getLogger(__name__) class RAGService: """Service for document ingestion, retrieval, and response generation.""" def __init__( self, chroma_client=None, llm_client=None, settings: Optional[Settings] = None, prompt_service: "PromptService | None" = None, ): self.chroma_client = chroma_client or get_chroma_client() self.llm_client = llm_client self.settings = settings self._prompt_service = prompt_service self._collection = None @property def collection(self): if self._collection is None: from app.core.database import get_or_create_collection, get_embedding_function_settings embedding_fn = None if self.settings is not None: embedding_fn = get_embedding_function_settings(self.settings) self._collection = get_or_create_collection( self.chroma_client, "documents", embedding_function=embedding_fn ) return self._collection def ingest_document( self, file_path: str, chunks: List[str], metadata_list: List[Dict[str, Any]], document_id: Optional[str] = None, ) -> str: if not chunks: return "" document_id = document_id or str(uuid.uuid4()) ids = [f"{document_id}_{i}" for i in range(len(chunks))] self.collection.add( documents=chunks, metadatas=metadata_list, ids=ids, ) return document_id def retrieve_per_subquestion( self, sub_questions: List[str], n_results: int = 10, ) -> List[Tuple[str, List[Tuple[str, Dict[str, Any], float]]]]: """Retrieve chunks for each sub-question independently. Calls retrieve() once per sub-question to get chunks specifically relevant to each decomposed question, rather than joining all sub-questions into a single query string. Args: sub_questions: List of decomposed sub-questions from QueryDecomposer. n_results: Number of chunks to retrieve per sub-question. Returns: List of (sub_question, chunks) tuples. Each chunks list contains (text, metadata, distance) tuples in the standard retrieve() format. Returns empty list if sub_questions is empty. """ if not sub_questions: return [] results: List[Tuple[str, List[Tuple[str, Dict[str, Any], float]]]] = [] for sub_q in sub_questions: chunks = self.retrieve([sub_q], n_results=n_results) results.append((sub_q, chunks)) return results def retrieve( self, query_keywords: List[str], n_results: int = 10, ) -> List[Tuple[str, Dict[str, Any], float]]: query_text = " ".join(query_keywords) results = self.collection.query( query_texts=[query_text], n_results=n_results, ) chunks = [] if results["documents"] and results["documents"][0]: for i, doc in enumerate(results["documents"][0]): metadata = results["metadatas"][0][i] if results["metadatas"][0] else {} distance = results["distances"][0][i] if results["distances"][0] else 0.0 chunks.append((doc, metadata, distance)) return chunks async def generate_response( self, question: str, chunks: List[str], metadata_list: List[Dict[str, Any]], ) -> Tuple[str, str]: """Generate a RAG response and return it alongside the prompt used. Args: question: The user's question. chunks: Retrieved chunk texts. metadata_list: Metadata dicts corresponding to each chunk. Returns: A tuple of (answer, prompt). answer is the LLM-generated response (or a fallback message). prompt is the rendered prompt string, or ``""`` when no prompt was built. """ if not chunks: return "I could not find any relevant information to answer your question.", "" if self.llm_client is None: return "LLM client not configured.", "" context_parts = [] for i, (chunk, meta) in enumerate(zip(chunks, metadata_list)): source = meta.get("filename", "unknown") summary = meta.get("content_summary", "") page_num = meta.get("page_number") citation_label = f"{source}, page {page_num}" if page_num else source context_parts.append( f"[{citation_label}] Source: {source}\n" f"Summary: {summary}\n" f"Content: {chunk}\n" ) context = "\n".join(context_parts) if self._prompt_service is not None: template = self._prompt_service.get_prompt_template("generate") else: template = ( f"Question: {'{question}'}\n\n" f"Answer the question using ONLY these document chunks. " f"Do not use any external knowledge. " f"Format your answer as bullet points. " f"Cite your sources inline using the exact bracket labels provided, " f"e.g. [filename, page N]. Place the citation at the end of each relevant point.\n\n" f"Document chunks:\n{'{context}'}\n\n" f"Answer:" ) # str.replace is safe even with stray curly braces in user text. prompt = template.replace("{question}", question).replace("{context}", context) result = await self.llm_client.complete(prompt=prompt, temperature=0.3, step_name="ResponseGeneration") return result, prompt async def generate_response_per_subquestion( self, sub_questions: List[str], sub_chunks: List[List[str]], sub_metadata: List[List[Dict[str, Any]]], ) -> Tuple[str, str, List[List[Dict[str, Any]]]]: """Generate sub-question-organized RAG response. Builds context sections for each sub-question and asks the LLM to answer each one using only its own document chunks. Returns the full markdown answer plus sources organized by sub-question. Args: sub_questions: List of decomposed sub-questions. sub_chunks: List of chunk text lists (one per sub-question). sub_metadata: List of metadata dict lists (one per sub-question). Must be same length as sub_chunks, with inner lists matching. Returns: Tuple of (answer, prompt, grouped_sources). answer: Markdown string with ## Sub-question N: sections. prompt: The rendered LLM prompt string. grouped_sources: List of metadata dict lists (one per sub-question), each metadata dict is a SourceMetadata-compatible dict. """ if not sub_questions: return ( "I could not find any relevant information to answer your question.", "", [], ) has_chunks = any(len(c) > 0 for c in sub_chunks) if not has_chunks: return ( "I could not find any relevant information to answer your question.", "", [], ) if self.llm_client is None: return ("LLM client not configured.", "", []) context_parts = [] for idx, (sq, chunks, metas) in enumerate( zip(sub_questions, sub_chunks, sub_metadata) ): context_parts.append( f'### Context for Sub-question {idx}: "{sq}"' ) for chunk, meta in zip(chunks, metas): source = meta.get("filename", "unknown") summary = meta.get("content_summary", "") page_num = meta.get("page_number") citation_label = ( f"{source}, page {page_num}" if page_num else source ) context_parts.append( f"[{citation_label}] Source: {source}\n" f"Summary: {summary}\n" f"Content: {chunk}\n" ) context_sections = "\n".join(context_parts) if self._prompt_service is not None: template = self._prompt_service.get_prompt_template( "generate_per_subq" ) else: template = ( "You must answer each sub-question using ONLY the document " "chunks provided for it.\n" "Do not use any external knowledge.\n" "Format your answer as markdown sections — one section per " "sub-question.\n" 'Each section should start with "## Sub-question N: ' '"\n' "Each section should contain 1-5 bullet points.\n" "Cite your sources inline using bracket labels, " "e.g. [filename, page N].\n" "Copy the exact bracket labels shown in the document chunks — " "do not modify filenames or add/remove extensions.\n" "Place the citation at the end of each relevant bullet point." "\n\n" "{context_sections}\n\n" "Answer:" ) prompt = template.replace("{context_sections}", context_sections) answer = await self.llm_client.complete( prompt=prompt, temperature=0.3, step_name="ResponseGeneration" ) grouped_sources: List[List[Dict[str, Any]]] = [] for metas in sub_metadata: grouped_sources.append(list(metas)) return answer, prompt, grouped_sources def list_documents(self) -> Tuple[List[Dict[str, Any]], int, int]: from collections import defaultdict all_data = self.collection.get(include=["metadatas"]) if not all_data["metadatas"]: return [], 0, 0 docs = defaultdict(lambda: {"filename": "", "chunk_count": 0, "upload_date": "", "chunking_strategy": "token"}) for chunk_id, meta in zip(all_data["ids"], all_data["metadatas"]): parts = chunk_id.rsplit("_", 1) doc_id = parts[0] if len(parts) == 2 else chunk_id docs[doc_id]["filename"] = meta.get("filename", "unknown") docs[doc_id]["chunk_count"] += 1 docs[doc_id]["upload_date"] = meta.get("upload_date", "") if meta.get("strategy_type") == "question": docs[doc_id]["chunking_strategy"] = "question" total_chunks = sum(d["chunk_count"] for d in docs.values()) doc_list = [ { "document_id": doc_id, "filename": info["filename"], "chunk_count": info["chunk_count"], "upload_date": info["upload_date"], "chunking_strategy": info["chunking_strategy"], } for doc_id, info in docs.items() ] return doc_list, len(doc_list), total_chunks def list_chunks(self, document_id: str) -> List[Dict[str, Any]]: all_data = self.collection.get(include=["metadatas"]) chunks = [] for chunk_id, meta in zip(all_data["ids"], all_data["metadatas"]): if chunk_id.startswith(f"{document_id}_"): chunks.append({ "chunk_id": chunk_id, "chunk_index": meta.get("chunk_index", 0), "content_summary": meta.get("content_summary", ""), "page_number": meta.get("page_number"), "chunk_file_path": meta.get("chunk_file_path"), "strategy_type": meta.get("strategy_type"), "question_index": meta.get("question_index"), "question_id": meta.get("question_id"), "question_text": meta.get("question_text"), "section_heading": meta.get("section_heading"), "answer_contains_table": meta.get("answer_contains_table"), "source_page_range": meta.get("source_page_range"), "parent_topic": meta.get("parent_topic"), }) chunks.sort(key=lambda x: x["chunk_index"]) return chunks def delete_document(self, document_id: str) -> Tuple[bool, int]: all_data = self.collection.get(include=["metadatas"]) ids_to_delete = [ chunk_id for chunk_id in all_data["ids"] if chunk_id.startswith(f"{document_id}_") ] if not ids_to_delete: return False, 0 self.collection.delete(ids=ids_to_delete) return True, len(ids_to_delete) def delete_chunk(self, chunk_id: str) -> bool: all_data = self.collection.get(include=["metadatas"]) if chunk_id not in all_data["ids"]: return False self.collection.delete(ids=[chunk_id]) return True