"""RAG service for embedding, retrieval, and response generation."""

import logging
import re
import uuid
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

from app.core.config import Settings
from app.core.database import get_chroma_client

if TYPE_CHECKING:
    from app.services.prompt_service import PromptService

logger = logging.getLogger(__name__)

# Fallback "generate" prompt used when no PromptService is configured.
# ``{question}`` and ``{context}`` are placeholders filled in by
# RAGService.generate_response.
_DEFAULT_GENERATE_TEMPLATE = (
    "Question: {question}\n\n"
    "Answer the question using ONLY these document chunks. "
    "Do not use any external knowledge. "
    "Format your answer as bullet points. "
    "Cite your sources inline using the exact bracket labels provided, "
    "e.g. [filename, page N]. Place the citation at the end of each relevant point.\n\n"
    "Document chunks:\n{context}\n\n"
    "Answer:"
)

# Matches the placeholders a prompt template may contain.
_PLACEHOLDER_RE = re.compile(r"\{(question|context)\}")


class RAGService:
    """Service for document ingestion, retrieval, and response generation.

    All chunks live in one Chroma collection named ``"documents"``. Chunk ids
    are built by ``ingest_document`` as ``f"{document_id}_{index}"``; the
    owning document id is recovered by splitting a chunk id on its *last*
    underscore, so document ids may themselves contain underscores.
    """

    def __init__(
        self,
        chroma_client=None,
        llm_client=None,
        settings: Optional[Settings] = None,
        prompt_service: "PromptService | None" = None,
    ):
        """Create the service.

        Args:
            chroma_client: Chroma client to use. When ``None``, one is obtained
                from ``get_chroma_client()``.
            llm_client: Object with an awaitable
                ``complete(prompt=..., temperature=..., step_name=...)`` method.
                May be ``None``; ``generate_response`` then returns a fallback
                message instead of calling an LLM.
            settings: Application settings. When provided, they are used to
                build the collection's embedding function.
            prompt_service: Optional source of the ``"generate"`` prompt
                template. When ``None``, a built-in template is used.
        """
        self.chroma_client = (
            chroma_client if chroma_client is not None else get_chroma_client()
        )
        self.llm_client = llm_client
        self.settings = settings
        self._prompt_service = prompt_service
        self._collection = None  # created lazily by the ``collection`` property

    @property
    def collection(self):
        """Return the ``"documents"`` collection, creating it on first access.

        When ``settings`` were supplied, the embedding function is derived
        from them; otherwise ``None`` is passed.
        """
        if self._collection is None:
            from app.core.database import get_or_create_collection, get_embedding_function_settings

            embedding_fn = None
            if self.settings is not None:
                embedding_fn = get_embedding_function_settings(self.settings)
            self._collection = get_or_create_collection(
                self.chroma_client, "documents", embedding_function=embedding_fn
            )
        return self._collection

    # ------------------------------------------------------------------ #
    # Chunk-id helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _split_chunk_id(chunk_id: str) -> Tuple[str, Optional[str]]:
        """Split a chunk id into ``(document_id, chunk_suffix)``.

        The split happens on the last underscore. An id without any
        underscore is returned whole, with a ``None`` suffix.
        """
        document_id, sep, suffix = chunk_id.rpartition("_")
        if not sep:
            return chunk_id, None
        return document_id, suffix

    @classmethod
    def _belongs_to(cls, chunk_id: str, document_id: str) -> bool:
        """Return True if ``chunk_id`` was generated for exactly ``document_id``.

        Comparing the parsed document id (rather than using a string-prefix
        test) keeps document ``"doc"`` from matching chunks of ``"doc_1"``.
        """
        owner, suffix = cls._split_chunk_id(chunk_id)
        return suffix is not None and owner == document_id

    @staticmethod
    def _first_query_row(results: Dict[str, Any], key: str) -> list:
        """Return the result row for the first query text, or ``[]`` if absent."""
        rows = results.get(key) or [[]]
        return rows[0] or []

    @staticmethod
    def _format_context_entry(chunk: str, meta: Optional[Dict[str, Any]]) -> str:
        """Render one chunk as a labelled context entry for the prompt."""
        meta = meta or {}
        source = meta.get("filename", "unknown")
        summary = meta.get("content_summary", "")
        page_num = meta.get("page_number")
        citation_label = f"{source}, page {page_num}" if page_num else source
        return (
            f"[{citation_label}] Source: {source}\n"
            f"Summary: {summary}\n"
            f"Content: {chunk}\n"
        )

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #

    def ingest_document(
        self,
        file_path: str,
        chunks: List[str],
        metadata_list: List[Dict[str, Any]],
        document_id: Optional[str] = None,
    ) -> str:
        """Store a document's chunks in the collection.

        Args:
            file_path: Path of the source file. Not used by this method; kept
                for interface compatibility.
            chunks: Chunk texts to store.
            metadata_list: One metadata dict per chunk, in the same order.
            document_id: Id to store the chunks under. A random UUID4 string is
                generated when omitted or empty.

        Returns:
            The document id used, or ``""`` when ``chunks`` is empty (nothing
            is stored in that case).

        Raises:
            ValueError: If ``metadata_list`` and ``chunks`` differ in length.
        """
        if not chunks:
            return ""
        if len(metadata_list) != len(chunks):
            raise ValueError(
                f"metadata_list has {len(metadata_list)} entries but chunks has {len(chunks)}"
            )

        document_id = document_id or str(uuid.uuid4())
        ids = [f"{document_id}_{i}" for i in range(len(chunks))]
        self.collection.add(
            documents=chunks,
            metadatas=metadata_list,
            ids=ids,
        )
        return document_id

    def retrieve(
        self,
        query_keywords: List[str],
        n_results: int = 10,
    ) -> List[Tuple[str, Dict[str, Any], float]]:
        """Query the collection with the space-joined keywords.

        Returns:
            A list of ``(chunk_text, metadata, distance)`` tuples in the order
            the collection returned them. Missing metadata becomes ``{}`` and
            a missing distance becomes ``0.0``.
        """
        query_text = " ".join(query_keywords)
        results = self.collection.query(
            query_texts=[query_text],
            n_results=n_results,
        )

        # A single query text was sent, so only row 0 of each result is used.
        documents = self._first_query_row(results, "documents")
        metadatas = self._first_query_row(results, "metadatas")
        distances = self._first_query_row(results, "distances")

        chunks = []
        for i, doc in enumerate(documents):
            metadata = metadatas[i] if i < len(metadatas) and metadatas[i] is not None else {}
            distance = distances[i] if i < len(distances) else 0.0
            chunks.append((doc, metadata, distance))
        return chunks

    async def generate_response(
        self,
        question: str,
        chunks: List[str],
        metadata_list: List[Dict[str, Any]],
    ) -> Tuple[str, str]:
        """Generate a RAG response and return it alongside the prompt used.

        Args:
            question: The user's question.
            chunks: Retrieved chunk texts.
            metadata_list: Metadata dicts corresponding to each chunk. Chunks
                without a matching (or with a ``None``) metadata entry are
                still included, labelled with source ``"unknown"``.

        Returns:
            A tuple of (answer, prompt). answer is the LLM-generated response
            (or a fallback message). prompt is the rendered prompt string, or
            ``""`` when no prompt was built.
        """
        if not chunks:
            return "I could not find any relevant information to answer your question.", ""

        if self.llm_client is None:
            return "LLM client not configured.", ""

        if len(metadata_list) != len(chunks):
            logger.warning(
                "generate_response received %d chunks but %d metadata entries; "
                "missing metadata is treated as empty",
                len(chunks),
                len(metadata_list),
            )

        context = "\n".join(
            self._format_context_entry(
                chunk, metadata_list[i] if i < len(metadata_list) else None
            )
            for i, chunk in enumerate(chunks)
        )

        if self._prompt_service is not None:
            template = self._prompt_service.get_prompt_template("generate")
        else:
            template = _DEFAULT_GENERATE_TEMPLATE

        # Fill both placeholders in one pass over the template only. Chained
        # str.replace would also substitute a literal "{context}" appearing
        # inside the user's question. A callable replacement is used so that
        # backslashes in user text are inserted literally, not as escapes.
        values = {"question": question, "context": context}
        prompt = _PLACEHOLDER_RE.sub(lambda m: values[m.group(1)], template)

        result = await self.llm_client.complete(prompt=prompt, temperature=0.3, step_name="ResponseGeneration")
        return result, prompt

    def list_documents(self) -> Tuple[List[Dict[str, Any]], int, int]:
        """Summarise every stored document.

        Returns:
            ``(documents, document_count, total_chunk_count)``. Each document
            dict has ``document_id``, ``filename``, ``chunk_count`` and
            ``upload_date``. The filename and upload date are taken from the
            last chunk encountered for that document.
        """
        all_data = self.collection.get(include=["metadatas"])
        if not all_data["metadatas"]:
            return [], 0, 0

        docs = defaultdict(lambda: {"filename": "", "chunk_count": 0, "upload_date": ""})
        for chunk_id, meta in zip(all_data["ids"], all_data["metadatas"]):
            meta = meta or {}
            doc_id, _ = self._split_chunk_id(chunk_id)
            info = docs[doc_id]
            info["filename"] = meta.get("filename", "unknown")
            info["chunk_count"] += 1
            info["upload_date"] = meta.get("upload_date", "")

        total_chunks = sum(d["chunk_count"] for d in docs.values())
        doc_list = [
            {
                "document_id": doc_id,
                "filename": info["filename"],
                "chunk_count": info["chunk_count"],
                "upload_date": info["upload_date"],
            }
            for doc_id, info in docs.items()
        ]
        return doc_list, len(doc_list), total_chunks

    def list_chunks(self, document_id: str) -> List[Dict[str, Any]]:
        """Return the chunks of exactly ``document_id``, sorted by ``chunk_index``.

        ``chunk_index`` comes from chunk metadata and defaults to 0 when
        missing.
        """
        all_data = self.collection.get(include=["metadatas"])
        chunks = []
        for chunk_id, meta in zip(all_data["ids"], all_data["metadatas"]):
            if not self._belongs_to(chunk_id, document_id):
                continue
            meta = meta or {}
            chunks.append({
                "chunk_id": chunk_id,
                "chunk_index": meta.get("chunk_index", 0),
                "content_summary": meta.get("content_summary", ""),
                "page_number": meta.get("page_number"),
                "chunk_file_path": meta.get("chunk_file_path"),
            })
        chunks.sort(key=lambda x: x["chunk_index"])
        return chunks

    def delete_document(self, document_id: str) -> Tuple[bool, int]:
        """Delete every chunk of exactly ``document_id``.

        Returns:
            ``(deleted, count)``; ``(False, 0)`` when no chunk matched.
        """
        # Only ids are needed; Chroma always returns ids regardless of include.
        all_ids = self.collection.get(include=[])["ids"]
        ids_to_delete = [
            chunk_id for chunk_id in all_ids if self._belongs_to(chunk_id, document_id)
        ]
        if not ids_to_delete:
            return False, 0
        self.collection.delete(ids=ids_to_delete)
        return True, len(ids_to_delete)

    def delete_chunk(self, chunk_id: str) -> bool:
        """Delete a single chunk by id.

        Returns:
            True if the chunk existed and was deleted, False otherwise.
        """
        # Look up just this id instead of scanning the whole collection.
        found = self.collection.get(ids=[chunk_id], include=[])
        if not found["ids"]:
            return False
        self.collection.delete(ids=[chunk_id])
        return True