legco_ai_assistant/backend/app/services/rag.py

215 lines
7.4 KiB
Python

"""RAG service for embedding, retrieval, and response generation."""
import uuid
from typing import TYPE_CHECKING, List, Tuple, Dict, Any, Optional
import logging
from app.core.config import Settings
from app.core.database import get_chroma_client
if TYPE_CHECKING:
from app.services.prompt_service import PromptService
logger = logging.getLogger(__name__)
class RAGService:
    """Service for document ingestion, retrieval, and response generation.

    Chunks are stored in one vector collection under IDs of the form
    ``<document_id>_<chunk_index>`` (see :meth:`ingest_document`). The
    listing and deletion methods rely on that layout to map a chunk back to
    its document.
    """

    # Name of the collection that holds every ingested chunk.
    COLLECTION_NAME = "documents"

    # Prompt used by generate_response() when no prompt service is supplied.
    # "{question}" and "{context}" are the only placeholders that get filled.
    _DEFAULT_GENERATE_TEMPLATE = (
        "Question: {question}\n\n"
        "Answer the question using ONLY these document chunks. "
        "Do not use any external knowledge. "
        "Format your answer as bullet points. "
        "Cite your sources inline using the exact bracket labels provided, "
        "e.g. [filename, page N]. Place the citation at the end of each relevant point.\n\n"
        "Document chunks:\n{context}\n\n"
        "Answer:"
    )

    def __init__(
        self,
        chroma_client=None,
        llm_client=None,
        settings: "Optional[Settings]" = None,
        prompt_service: "PromptService | None" = None,
    ):
        """Store the collaborators; the collection itself is opened lazily.

        Args:
            chroma_client: Vector-store client. When omitted, the shared
                client from ``get_chroma_client()`` is used.
            llm_client: Object exposing an awaitable
                ``complete(prompt=..., temperature=..., step_name=...)``.
                May be ``None``, in which case generation returns a fallback
                message.
            settings: Application settings; when given, they select the
                embedding function used for the collection.
            prompt_service: Optional provider of the ``"generate"`` prompt
                template. The built-in template is used when omitted.
        """
        self.chroma_client = (
            chroma_client if chroma_client is not None else get_chroma_client()
        )
        self.llm_client = llm_client
        self.settings = settings
        self._prompt_service = prompt_service
        self._collection = None

    @property
    def collection(self):
        """Return the chunk collection, creating and caching it on first use."""
        if self._collection is None:
            from app.core.database import (
                get_or_create_collection,
                get_embedding_function_settings,
            )

            embedding_fn = None
            if self.settings is not None:
                embedding_fn = get_embedding_function_settings(self.settings)
            self._collection = get_or_create_collection(
                self.chroma_client,
                self.COLLECTION_NAME,
                embedding_function=embedding_fn,
            )
        return self._collection

    @staticmethod
    def _document_id_of(chunk_id: str) -> Optional[str]:
        """Return the document ID encoded in a ``<document_id>_<index>`` chunk ID.

        The split happens at the LAST underscore so that document IDs which
        themselves contain underscores are recovered intact. Returns ``None``
        when *chunk_id* contains no underscore at all.
        """
        head, separator, _index = chunk_id.rpartition("_")
        return head if separator else None

    @staticmethod
    def _build_context(chunks: List[str], metadata_list: List[Dict[str, Any]]) -> str:
        """Render the retrieved chunks as the citation-labelled context section."""
        sections = []
        for index, chunk in enumerate(chunks):
            # A missing or None metadata entry must not drop the chunk (zip()
            # would truncate) or crash on .get(); treat it as "no metadata".
            meta = (metadata_list[index] if index < len(metadata_list) else None) or {}
            source = meta.get("filename", "unknown")
            summary = meta.get("content_summary", "")
            page_num = meta.get("page_number")
            citation_label = f"{source}, page {page_num}" if page_num else source
            sections.append(
                f"[{citation_label}] Source: {source}\n"
                f"Summary: {summary}\n"
                f"Content: {chunk}\n"
            )
        return "\n".join(sections)

    @staticmethod
    def _render_prompt(template: str, question: str, context: str) -> str:
        """Fill ``{question}`` and ``{context}`` in *template* in a single pass."""
        # Splitting on one placeholder before filling the other means only the
        # template text is ever scanned. A question that contains the literal
        # text "{context}" (or a chunk that contains "{question}") is inserted
        # verbatim instead of being expanded by a second replace(). Other
        # curly braces in the template or the inputs are left untouched.
        pieces = template.split("{question}")
        return question.join(piece.replace("{context}", context) for piece in pieces)

    def ingest_document(
        self,
        file_path: str,
        chunks: List[str],
        metadata_list: List[Dict[str, Any]],
        document_id: Optional[str] = None,
    ) -> str:
        """Add a document's chunks to the collection.

        Args:
            file_path: Path of the source file. Accepted for interface
                compatibility; this method does not read it.
            chunks: Chunk texts to store.
            metadata_list: One metadata dict per chunk, in the same order.
            document_id: ID to store the chunks under. A random UUID4 string
                is generated when omitted.

        Returns:
            The document ID used, or ``""`` when *chunks* is empty (nothing
            is stored in that case).
        """
        if not chunks:
            return ""
        document_id = document_id or str(uuid.uuid4())
        ids = [f"{document_id}_{i}" for i in range(len(chunks))]
        self.collection.add(
            documents=chunks,
            metadatas=metadata_list,
            ids=ids,
        )
        return document_id

    def retrieve(
        self,
        query_keywords: List[str],
        n_results: int = 10,
    ) -> List[Tuple[str, Dict[str, Any], float]]:
        """Query the collection with the space-joined keywords.

        Args:
            query_keywords: Keywords that are joined into one query string.
            n_results: Maximum number of chunks to request.

        Returns:
            A list of ``(chunk_text, metadata, distance)`` tuples in the order
            the collection returned them. ``metadata`` is ``{}`` and
            ``distance`` is ``0.0`` when the result does not carry them.
            Empty when the query is blank or nothing matched.
        """
        query_text = " ".join(query_keywords)
        if not query_text.strip():
            # Nothing to match on; skip embedding an empty string.
            return []

        results = self.collection.query(
            query_texts=[query_text],
            n_results=n_results,
        )

        def first_batch(key: str) -> list:
            # Results are nested one list per query text; only one was sent.
            # A field may be None/absent when it was not included.
            batches = results.get(key)
            return (batches[0] if batches else None) or []

        documents = first_batch("documents")
        metadatas = first_batch("metadatas")
        distances = first_batch("distances")

        hits = []
        for i, doc in enumerate(documents):
            metadata = (metadatas[i] if i < len(metadatas) else None) or {}
            distance = distances[i] if i < len(distances) else 0.0
            hits.append((doc, metadata, distance))
        return hits

    async def generate_response(
        self,
        question: str,
        chunks: List[str],
        metadata_list: List[Dict[str, Any]],
    ) -> Tuple[str, str]:
        """Generate a RAG response and return it alongside the prompt used.

        Args:
            question: The user's question.
            chunks: Retrieved chunk texts.
            metadata_list: Metadata dicts corresponding to each chunk. Chunks
                without a matching (or with a ``None``) entry are still
                included, labelled with the ``"unknown"`` source.

        Returns:
            A tuple of (answer, prompt). answer is the LLM-generated response
            (or a fallback message). prompt is the rendered prompt string, or
            ``""`` when no prompt was built.
        """
        if not chunks:
            return "I could not find any relevant information to answer your question.", ""
        if self.llm_client is None:
            return "LLM client not configured.", ""

        context = self._build_context(chunks, metadata_list)
        if self._prompt_service is not None:
            template = self._prompt_service.get_prompt_template("generate")
        else:
            template = self._DEFAULT_GENERATE_TEMPLATE

        prompt = self._render_prompt(template, question, context)
        result = await self.llm_client.complete(
            prompt=prompt, temperature=0.3, step_name="ResponseGeneration"
        )
        return result, prompt

    def list_documents(self) -> Tuple[List[Dict[str, Any]], int, int]:
        """Summarise the stored documents by grouping chunks on their document ID.

        Returns:
            ``(documents, document_count, total_chunks)``. Each document is a
            dict with ``document_id``, ``filename``, ``chunk_count`` and
            ``upload_date``; ``filename`` and ``upload_date`` come from the
            last chunk seen for that document. Documents appear in
            first-seen order.
        """
        all_data = self.collection.get(include=["metadatas"])
        if not all_data["metadatas"]:
            return [], 0, 0

        docs: Dict[str, Dict[str, Any]] = {}
        for chunk_id, meta in zip(all_data["ids"], all_data["metadatas"]):
            meta = meta or {}
            doc_id = self._document_id_of(chunk_id)
            if doc_id is None:
                # An ID without the "_<index>" suffix is listed as its own document.
                doc_id = chunk_id
            entry = docs.setdefault(
                doc_id,
                {"document_id": doc_id, "filename": "", "chunk_count": 0, "upload_date": ""},
            )
            entry["filename"] = meta.get("filename", "unknown")
            entry["chunk_count"] += 1
            entry["upload_date"] = meta.get("upload_date", "")

        doc_list = list(docs.values())
        total_chunks = sum(doc["chunk_count"] for doc in doc_list)
        return doc_list, len(doc_list), total_chunks

    def list_chunks(self, document_id: str) -> List[Dict[str, Any]]:
        """Return the chunks of *document_id*, sorted by ``chunk_index``.

        A chunk belongs to the document when its ID is exactly
        ``<document_id>_<index>``; a document whose ID merely starts with
        ``<document_id>_`` is not matched.
        """
        all_data = self.collection.get(include=["metadatas"])
        chunks = []
        for chunk_id, meta in zip(all_data["ids"], all_data["metadatas"]):
            if self._document_id_of(chunk_id) != document_id:
                continue
            meta = meta or {}
            chunks.append({
                "chunk_id": chunk_id,
                "chunk_index": meta.get("chunk_index", 0),
                "content_summary": meta.get("content_summary", ""),
                "page_number": meta.get("page_number"),
                "chunk_file_path": meta.get("chunk_file_path"),
            })
        chunks.sort(key=lambda item: item["chunk_index"])
        return chunks

    def delete_document(self, document_id: str) -> Tuple[bool, int]:
        """Delete every chunk of *document_id*.

        Uses the same ownership rule as :meth:`list_chunks`, so deleting
        ``"doc"`` never removes chunks of a document named ``"doc_1"``.

        Returns:
            ``(deleted, chunk_count)``; ``(False, 0)`` when no chunk matched.
        """
        # Only the IDs are needed to decide ownership, so request no payload.
        all_ids = self.collection.get(include=[])["ids"]
        ids_to_delete = [
            chunk_id for chunk_id in all_ids
            if self._document_id_of(chunk_id) == document_id
        ]
        if not ids_to_delete:
            return False, 0
        self.collection.delete(ids=ids_to_delete)
        return True, len(ids_to_delete)

    def delete_chunk(self, chunk_id: str) -> bool:
        """Delete a single chunk; return ``False`` when it does not exist."""
        # Look up just this ID instead of loading the whole collection.
        found = self.collection.get(ids=[chunk_id], include=[])
        if not found["ids"]:
            return False
        self.collection.delete(ids=[chunk_id])
        return True