# legco_ai_assistant/backend/app/services/rag.py

"""RAG service for embedding, retrieval, and response generation."""
import uuid
from typing import List, Tuple, Dict, Any, Optional
import logging
from app.core.config import Settings
from app.core.database import get_chroma_client
# Module-scoped logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)
class RAGService:
    """Service for document ingestion, retrieval, and response generation.

    Chunks live in a single collection named ``"documents"``. Each chunk id has
    the form ``"<document_id>_<index>"``. The owning document id is recovered
    by splitting on the LAST underscore, so document ids may themselves
    contain underscores.
    """

    def __init__(
        self,
        chroma_client=None,
        llm_client=None,
        settings: Optional["Settings"] = None,
    ):
        """Create the service.

        Args:
            chroma_client: Client used to open the collection. When ``None``,
                one is obtained from ``get_chroma_client()``.
            llm_client: Object exposing an async ``complete(...)`` method, used
                by ``generate_response``. May be ``None``.
            settings: Application settings. When given, they are used to build
                the collection's embedding function.
        """
        # Explicit None check: a client object that happens to be falsy must
        # not be silently replaced by the default client.
        self.chroma_client = (
            chroma_client if chroma_client is not None else get_chroma_client()
        )
        self.llm_client = llm_client
        self.settings = settings
        self._collection = None

    @property
    def collection(self):
        """The ``"documents"`` collection, created on first access and cached."""
        if self._collection is None:
            from app.core.database import (
                get_embedding_function_settings,
                get_or_create_collection,
            )

            embedding_fn = None
            if self.settings is not None:
                embedding_fn = get_embedding_function_settings(self.settings)
            self._collection = get_or_create_collection(
                self.chroma_client, "documents", embedding_function=embedding_fn
            )
        return self._collection

    @staticmethod
    def _document_id_of(chunk_id: str) -> str:
        """Return the document id encoded in ``chunk_id``.

        This is the text before the last ``_``. An id without any underscore
        is treated as its own document id.
        """
        document_id, sep, _ = chunk_id.rpartition("_")
        return document_id if sep else chunk_id

    def ingest_document(
        self,
        file_path: str,
        chunks: List[str],
        metadata_list: List[Dict[str, Any]],
        document_id: Optional[str] = None,
    ) -> str:
        """Store ``chunks`` and their metadata under a single document id.

        Args:
            file_path: Path of the source file (currently unused here).
            chunks: Text of each chunk. Chunk ``i`` is stored with id
                ``f"{document_id}_{i}"``.
            metadata_list: Per-chunk metadata, passed to ``collection.add``
                as-is.
            document_id: Id to store the chunks under. A random UUID4 string is
                generated when it is omitted or empty.

        Returns:
            The document id, or ``""`` when ``chunks`` is empty (nothing is
            stored).
        """
        if not chunks:
            return ""
        document_id = document_id or str(uuid.uuid4())
        self.collection.add(
            documents=chunks,
            metadatas=metadata_list,
            ids=[f"{document_id}_{i}" for i in range(len(chunks))],
        )
        return document_id

    def retrieve(
        self,
        query_keywords: List[str],
        n_results: int = 10,
    ) -> List[Tuple[str, Dict[str, Any], float]]:
        """Query the collection with the space-joined ``query_keywords``.

        Returns:
            ``(chunk_text, metadata, distance)`` tuples, in the order the query
            returned them. Missing or ``None`` metadata becomes ``{}``, and a
            missing distance becomes ``0.0``.
        """
        results = self.collection.query(
            query_texts=[" ".join(query_keywords)],
            n_results=n_results,
        )

        def first_row(key):
            # One query text was sent, so each field is a one-element outer
            # list. It may also be None/empty when the field was not returned.
            value = results.get(key)
            return (value[0] if value else None) or []

        documents = first_row("documents")
        metadatas = first_row("metadatas")
        distances = first_row("distances")

        chunks = []
        for i, doc in enumerate(documents):
            metadata = (metadatas[i] if i < len(metadatas) else None) or {}
            distance = distances[i] if i < len(distances) else 0.0
            chunks.append((doc, metadata, distance))
        return chunks

    async def generate_response(
        self,
        question: str,
        chunks: List[str],
        metadata_list: List[Dict[str, Any]],
    ) -> str:
        """Ask the LLM to answer ``question`` using only ``chunks`` as context.

        ``metadata_list[i]`` describes ``chunks[i]``. Its ``filename``,
        ``content_summary`` and ``page_number`` keys build the citation label.
        A chunk without matching metadata is still included and labelled
        ``unknown``.

        Returns:
            The LLM completion. A fixed message is returned instead when there
            are no chunks or no LLM client is configured.
        """
        if not chunks:
            return "I could not find any relevant information to answer your question."
        if self.llm_client is None:
            return "LLM client not configured."

        metadata_list = metadata_list or []
        context_parts = []
        for i, chunk in enumerate(chunks):
            # Index rather than zip(): zip() would silently drop trailing
            # chunks when the metadata list is shorter than the chunk list.
            meta = (metadata_list[i] if i < len(metadata_list) else None) or {}
            source = meta.get("filename", "unknown")
            summary = meta.get("content_summary", "")
            page_num = meta.get("page_number")
            # A falsy page number (missing, 0, "") leaves the page out.
            citation_label = f"{source}, page {page_num}" if page_num else source
            context_parts.append(
                f"[{citation_label}] Source: {source}\n"
                f"Summary: {summary}\n"
                f"Content: {chunk}\n"
            )
        context = "\n".join(context_parts)
        prompt = (
            f"Question: {question}\n\n"
            "Answer the question using ONLY these document chunks. "
            "Do not use any external knowledge. "
            "Format your answer as bullet points. "
            "Cite your sources inline using the exact bracket labels provided, "
            "e.g. [filename, page N]. Place the citation at the end of each relevant point.\n\n"
            f"Document chunks:\n{context}\n\n"
            "Answer:"
        )
        return await self.llm_client.complete(
            prompt=prompt, temperature=0.3, step_name="ResponseGeneration"
        )

    def list_documents(self) -> Tuple[List[Dict[str, Any]], int, int]:
        """Summarise every stored document.

        Returns:
            ``(documents, document_count, total_chunk_count)``. Each entry in
            ``documents`` has ``document_id``, ``filename``, ``chunk_count`` and
            ``upload_date``. ``filename`` and ``upload_date`` come from the
            last chunk seen for that document.
        """
        from collections import defaultdict

        all_data = self.collection.get(include=["metadatas"])
        if not all_data["metadatas"]:
            return [], 0, 0
        docs = defaultdict(lambda: {"filename": "", "chunk_count": 0, "upload_date": ""})
        for chunk_id, meta in zip(all_data["ids"], all_data["metadatas"]):
            meta = meta or {}
            info = docs[self._document_id_of(chunk_id)]
            info["filename"] = meta.get("filename", "unknown")
            info["chunk_count"] += 1
            info["upload_date"] = meta.get("upload_date", "")
        doc_list = [{"document_id": doc_id, **info} for doc_id, info in docs.items()]
        total_chunks = sum(info["chunk_count"] for info in docs.values())
        return doc_list, len(doc_list), total_chunks

    def list_chunks(self, document_id: str) -> List[Dict[str, Any]]:
        """Return the chunks of ``document_id``, sorted by their ``chunk_index``.

        Only chunks whose id decodes to exactly ``document_id`` are returned.
        A plain prefix test would also match documents such as
        ``f"{document_id}_1"``.
        """
        all_data = self.collection.get(include=["metadatas"])
        chunks = []
        for chunk_id, meta in zip(all_data["ids"], all_data["metadatas"]):
            if self._document_id_of(chunk_id) != document_id:
                continue
            meta = meta or {}
            chunks.append({
                "chunk_id": chunk_id,
                "chunk_index": meta.get("chunk_index", 0),
                "content_summary": meta.get("content_summary", ""),
                "page_number": meta.get("page_number"),
                "chunk_file_path": meta.get("chunk_file_path"),
            })
        chunks.sort(key=lambda c: c["chunk_index"])
        return chunks

    def delete_document(self, document_id: str) -> Tuple[bool, int]:
        """Delete every chunk of ``document_id``.

        Returns:
            ``(deleted_any, number_of_chunks_deleted)``.
        """
        # Only ids are needed; skip fetching metadata.
        all_data = self.collection.get(include=[])
        ids_to_delete = [
            chunk_id for chunk_id in all_data["ids"]
            if self._document_id_of(chunk_id) == document_id
        ]
        if not ids_to_delete:
            return False, 0
        self.collection.delete(ids=ids_to_delete)
        return True, len(ids_to_delete)

    def delete_chunk(self, chunk_id: str) -> bool:
        """Delete a single chunk. Returns ``False`` if it does not exist."""
        # Look up just this id instead of scanning the whole collection.
        found = self.collection.get(ids=[chunk_id], include=[])
        if chunk_id not in found["ids"]:
            return False
        self.collection.delete(ids=[chunk_id])
        return True