343 lines
12 KiB
Python
343 lines
12 KiB
Python
"""RAG service for embedding, retrieval, and response generation."""
|
|
import uuid
|
|
from typing import TYPE_CHECKING, List, Tuple, Dict, Any, Optional
|
|
import logging
|
|
|
|
from app.core.config import Settings
|
|
from app.core.database import get_chroma_client
|
|
|
|
if TYPE_CHECKING:
|
|
from app.services.prompt_service import PromptService
|
|
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RAGService:
    """Service for document ingestion, retrieval, and response generation.

    Chunks live in a Chroma collection named ``"documents"``. Every chunk id
    written by :meth:`ingest_document` has the form
    ``"{document_id}_{chunk_index}"``; the document-level operations
    (:meth:`list_documents`, :meth:`list_chunks`, :meth:`delete_document`)
    recover the owning document id from that format.
    """

    # Fallback prompt for generate_response when no PromptService is set.
    # Placeholders: {question}, {context}.
    _DEFAULT_GENERATE_TEMPLATE = (
        "Question: {question}\n\n"
        "Answer the question using ONLY these document chunks. "
        "Do not use any external knowledge. "
        "Format your answer as bullet points. "
        "Cite your sources inline using the exact bracket labels provided, "
        "e.g. [filename, page N]. Place the citation at the end of each relevant point.\n\n"
        "Document chunks:\n{context}\n\n"
        "Answer:"
    )

    # Fallback prompt for generate_response_per_subquestion.
    # Placeholder: {context_sections}.
    _DEFAULT_PER_SUBQ_TEMPLATE = (
        "You must answer each sub-question using ONLY the document "
        "chunks provided for it.\n"
        "Do not use any external knowledge.\n"
        "Format your answer as markdown sections — one section per "
        "sub-question.\n"
        'Each section should start with "## Sub-question N: '
        '<the question>"\n'
        "Each section should contain 1-5 bullet points.\n"
        "Cite your sources inline using bracket labels, "
        "e.g. [filename, page N].\n"
        "Place the citation at the end of each relevant bullet point."
        "\n\n"
        "{context_sections}\n\n"
        "Answer:"
    )

    _NO_RESULTS_MESSAGE = "I could not find any relevant information to answer your question."
    _NO_LLM_MESSAGE = "LLM client not configured."

    def __init__(
        self,
        chroma_client=None,
        llm_client=None,
        settings: "Optional[Settings]" = None,
        prompt_service: "PromptService | None" = None,
    ):
        """Create the service.

        Args:
            chroma_client: Chroma client to use. When ``None``, one is obtained
                from ``get_chroma_client()``.
            llm_client: Object with an async ``complete(prompt=..., temperature=...,
                step_name=...)`` method. When ``None``, the generation methods
                return a "not configured" message instead of calling an LLM.
            settings: Optional settings; when given, they select the embedding
                function used for the collection.
            prompt_service: Optional source of prompt templates. When ``None``,
                the built-in fallback templates are used.
        """
        self.chroma_client = (
            chroma_client if chroma_client is not None else get_chroma_client()
        )
        self.llm_client = llm_client
        self.settings = settings
        self._prompt_service = prompt_service
        # Created lazily by the ``collection`` property.
        self._collection = None

    @property
    def collection(self):
        """Return the ``"documents"`` collection, creating it on first access."""
        if self._collection is None:
            # Deferred import, kept as in the original (possibly to avoid an
            # import cycle — not verifiable from this file).
            from app.core.database import get_or_create_collection, get_embedding_function_settings

            embedding_fn = (
                get_embedding_function_settings(self.settings)
                if self.settings is not None
                else None
            )
            self._collection = get_or_create_collection(
                self.chroma_client, "documents", embedding_function=embedding_fn
            )
        return self._collection

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _document_id_from_chunk_id(chunk_id: str) -> str:
        """Return the document id encoded in ``chunk_id``.

        This is the text before the last ``"_"``. Ids without an underscore are
        returned unchanged, so they are treated as their own document.
        """
        head, sep, _ = chunk_id.rpartition("_")
        return head if sep else chunk_id

    @staticmethod
    def _format_chunk(chunk: str, meta: Optional[Dict[str, Any]]) -> str:
        """Render one chunk as a citation-labelled context entry.

        ``meta`` may be ``None`` (treated as an empty dict). The citation label
        is ``"<filename>, page <n>"`` when a truthy ``page_number`` is present,
        otherwise just the filename.
        """
        meta = meta or {}
        source = meta.get("filename", "unknown")
        summary = meta.get("content_summary", "")
        page_num = meta.get("page_number")
        citation_label = f"{source}, page {page_num}" if page_num else source
        return (
            f"[{citation_label}] Source: {source}\n"
            f"Summary: {summary}\n"
            f"Content: {chunk}\n"
        )

    @staticmethod
    def _render_template(template: str, **values: str) -> str:
        """Substitute ``{name}`` placeholders in ``template`` in a single pass.

        Unlike chained ``str.replace`` calls, text inserted for one placeholder
        is never rescanned. A user question containing ``{context}`` is
        therefore kept verbatim instead of receiving a second copy of the
        context. Braces that are not one of the given placeholders are left
        untouched, so stray ``{``/``}`` in templates or user text is harmless.
        """
        import re

        if not values:
            return template
        pattern = re.compile(
            r"\{(" + "|".join(re.escape(name) for name in values) + r")\}"
        )
        # A callable replacement is inserted literally (no backslash processing).
        return pattern.sub(lambda match: values[match.group(1)], template)

    # ---------------------------------------------------------------- ingestion

    def ingest_document(
        self,
        file_path: str,
        chunks: List[str],
        metadata_list: List[Dict[str, Any]],
        document_id: Optional[str] = None,
    ) -> str:
        """Store ``chunks`` in the collection under a single document id.

        Args:
            file_path: Path of the source file. Not used by this method.
            chunks: Chunk texts to store.
            metadata_list: One metadata dict per chunk, in the same order.
            document_id: Id to store the chunks under. A random UUID4 string is
                generated when omitted or empty.

        Returns:
            The document id, or ``""`` when ``chunks`` is empty (nothing stored).

        Raises:
            ValueError: If ``metadata_list`` and ``chunks`` differ in length.
        """
        if not chunks:
            return ""
        if len(metadata_list) != len(chunks):
            raise ValueError(
                f"metadata_list has {len(metadata_list)} entries but chunks has "
                f"{len(chunks)}; they must match one-to-one"
            )

        document_id = document_id or str(uuid.uuid4())
        # The "<document_id>_<index>" format is what _document_id_from_chunk_id
        # parses back.
        ids = [f"{document_id}_{i}" for i in range(len(chunks))]
        self.collection.add(documents=chunks, metadatas=metadata_list, ids=ids)
        return document_id

    # ---------------------------------------------------------------- retrieval

    def retrieve_per_subquestion(
        self,
        sub_questions: List[str],
        n_results: int = 10,
    ) -> List[Tuple[str, List[Tuple[str, Dict[str, Any], float]]]]:
        """Retrieve chunks for each sub-question independently.

        Calls :meth:`retrieve` once per sub-question, rather than joining all
        sub-questions into a single query string.

        Args:
            sub_questions: Decomposed sub-questions (e.g. from QueryDecomposer).
            n_results: Number of chunks to retrieve per sub-question.

        Returns:
            A list of ``(sub_question, chunks)`` tuples, in input order. Each
            ``chunks`` list holds ``(text, metadata, distance)`` tuples as
            returned by :meth:`retrieve`. Empty input yields an empty list.
        """
        return [
            (sub_q, self.retrieve([sub_q], n_results=n_results))
            for sub_q in sub_questions
        ]

    def retrieve(
        self,
        query_keywords: List[str],
        n_results: int = 10,
    ) -> List[Tuple[str, Dict[str, Any], float]]:
        """Query the collection with ``query_keywords`` joined by spaces.

        Returns:
            ``(text, metadata, distance)`` tuples for the top ``n_results``
            hits of the single query. If the query result has no metadata list,
            each metadata is ``{}``; if it has no distance list, each distance
            is ``0.0``. Returns ``[]`` when no documents come back.
        """
        query_text = " ".join(query_keywords)
        results = self.collection.query(
            query_texts=[query_text],
            n_results=n_results,
        )

        documents = results["documents"]
        if not documents or not documents[0]:
            return []

        # Only one query text was sent, so only row 0 of each result is used.
        metadatas = results["metadatas"][0]
        distances = results["distances"][0]
        return [
            (
                doc,
                metadatas[i] if metadatas else {},
                distances[i] if distances else 0.0,
            )
            for i, doc in enumerate(documents[0])
        ]

    # --------------------------------------------------------------- generation

    async def generate_response(
        self,
        question: str,
        chunks: List[str],
        metadata_list: List[Dict[str, Any]],
    ) -> Tuple[str, str]:
        """Generate a RAG response and return it alongside the prompt used.

        Args:
            question: The user's question.
            chunks: Retrieved chunk texts.
            metadata_list: Metadata dicts corresponding to each chunk. Pairing
                uses ``zip``, so extra items in the longer list are ignored.

        Returns:
            A tuple ``(answer, prompt)``. ``answer`` is the LLM-generated
            response, or a fallback message when there are no chunks or no LLM
            client. ``prompt`` is the rendered prompt, or ``""`` when no prompt
            was built.
        """
        if not chunks:
            return self._NO_RESULTS_MESSAGE, ""
        if self.llm_client is None:
            return self._NO_LLM_MESSAGE, ""

        context = "\n".join(
            self._format_chunk(chunk, meta)
            for chunk, meta in zip(chunks, metadata_list)
        )

        template = (
            self._prompt_service.get_prompt_template("generate")
            if self._prompt_service is not None
            else self._DEFAULT_GENERATE_TEMPLATE
        )
        prompt = self._render_template(template, question=question, context=context)

        result = await self.llm_client.complete(
            prompt=prompt, temperature=0.3, step_name="ResponseGeneration"
        )
        return result, prompt

    async def generate_response_per_subquestion(
        self,
        sub_questions: List[str],
        sub_chunks: List[List[str]],
        sub_metadata: List[List[Dict[str, Any]]],
    ) -> Tuple[str, str, List[List[Dict[str, Any]]]]:
        """Generate a RAG response organised by sub-question.

        Builds one context section per sub-question and asks the LLM to answer
        each sub-question using only its own chunks.

        Args:
            sub_questions: Decomposed sub-questions.
            sub_chunks: One list of chunk texts per sub-question.
            sub_metadata: One list of metadata dicts per sub-question; should
                match ``sub_chunks`` in shape. Sections are paired with
                ``zip``, so surplus entries in any of the three lists are not
                included in the prompt.

        Returns:
            A tuple ``(answer, prompt, grouped_sources)``:

            - ``answer``: the LLM output, or a fallback message.
            - ``prompt``: the rendered prompt, or ``""`` when none was built.
            - ``grouped_sources``: a shallow copy of each inner ``sub_metadata``
              list; ``[]`` on the fallback paths.
        """
        if not sub_questions or not any(sub_chunks):
            return (self._NO_RESULTS_MESSAGE, "", [])
        if self.llm_client is None:
            return (self._NO_LLM_MESSAGE, "", [])

        context_parts: List[str] = []
        # Section headers are numbered from 0 (unchanged from the original).
        for idx, (sq, chunks, metas) in enumerate(
            zip(sub_questions, sub_chunks, sub_metadata)
        ):
            context_parts.append(f'### Context for Sub-question {idx}: "{sq}"')
            context_parts.extend(
                self._format_chunk(chunk, meta) for chunk, meta in zip(chunks, metas)
            )
        context_sections = "\n".join(context_parts)

        template = (
            self._prompt_service.get_prompt_template("generate_per_subq")
            if self._prompt_service is not None
            else self._DEFAULT_PER_SUBQ_TEMPLATE
        )
        prompt = self._render_template(template, context_sections=context_sections)

        answer = await self.llm_client.complete(
            prompt=prompt, temperature=0.3, step_name="ResponseGeneration"
        )

        grouped_sources = [list(metas) for metas in sub_metadata]
        return answer, prompt, grouped_sources

    # --------------------------------------------------------------- management

    def list_documents(self) -> Tuple[List[Dict[str, Any]], int, int]:
        """Summarise stored documents.

        Returns:
            ``(documents, document_count, total_chunk_count)``. Each document
            dict has ``document_id``, ``filename``, ``chunk_count`` and
            ``upload_date``. ``filename`` and ``upload_date`` come from the
            last chunk seen for that document.
        """
        from collections import defaultdict

        all_data = self.collection.get(include=["metadatas"])
        if not all_data["metadatas"]:
            return [], 0, 0

        docs = defaultdict(lambda: {"filename": "", "chunk_count": 0, "upload_date": ""})
        for chunk_id, meta in zip(all_data["ids"], all_data["metadatas"]):
            meta = meta or {}
            info = docs[self._document_id_from_chunk_id(chunk_id)]
            info["filename"] = meta.get("filename", "unknown")
            info["chunk_count"] += 1
            info["upload_date"] = meta.get("upload_date", "")

        total_chunks = sum(info["chunk_count"] for info in docs.values())
        doc_list = [{"document_id": doc_id, **info} for doc_id, info in docs.items()]
        return doc_list, len(doc_list), total_chunks

    def list_chunks(self, document_id: str) -> List[Dict[str, Any]]:
        """Return the chunks of ``document_id``, sorted by ``chunk_index``.

        A chunk belongs to the document when the part of its id before the
        last ``"_"`` equals ``document_id``. This is the same rule
        :meth:`list_documents` uses, so ``"report"`` does not pick up chunks of
        ``"report_v2"``.
        """
        all_data = self.collection.get(include=["metadatas"])

        chunks = []
        for chunk_id, meta in zip(all_data["ids"], all_data["metadatas"]):
            if self._document_id_from_chunk_id(chunk_id) != document_id:
                continue
            meta = meta or {}
            chunks.append({
                "chunk_id": chunk_id,
                "chunk_index": meta.get("chunk_index", 0),
                "content_summary": meta.get("content_summary", ""),
                "page_number": meta.get("page_number"),
                "chunk_file_path": meta.get("chunk_file_path"),
            })

        chunks.sort(key=lambda entry: entry["chunk_index"])
        return chunks

    def delete_document(self, document_id: str) -> Tuple[bool, int]:
        """Delete every chunk of ``document_id``.

        Chunks are matched with the same rule as :meth:`list_chunks`.

        Returns:
            ``(deleted_any, number_of_chunks_deleted)``.
        """
        all_data = self.collection.get(include=["metadatas"])

        ids_to_delete = [
            chunk_id
            for chunk_id in all_data["ids"]
            if self._document_id_from_chunk_id(chunk_id) == document_id
        ]
        if not ids_to_delete:
            return False, 0

        self.collection.delete(ids=ids_to_delete)
        return True, len(ids_to_delete)

    def delete_chunk(self, chunk_id: str) -> bool:
        """Delete a single chunk by id.

        Returns:
            ``True`` if the chunk existed and was deleted, ``False`` otherwise.
        """
        # Look up only this id instead of fetching the whole collection.
        existing = self.collection.get(ids=[chunk_id], include=["metadatas"])
        if chunk_id not in existing["ids"]:
            return False

        self.collection.delete(ids=[chunk_id])
        return True
|