116 lines
3.5 KiB
Python
116 lines
3.5 KiB
Python
"""RAG service for embedding, retrieval, and response generation."""
|
|
import uuid
|
|
from typing import List, Tuple, Dict, Any, Optional
|
|
import logging
|
|
|
|
from app.core.config import Settings
|
|
from app.core.database import get_chroma_client
|
|
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RAGService:
    """Service for document ingestion, retrieval, and response generation.

    Stores text chunks in a collection named ``"documents"`` (resolved lazily
    on first access), queries that collection with keyword text, and asks an
    optional asynchronous LLM client to answer a question from the chunks.
    """

    def __init__(
        self,
        chroma_client=None,
        llm_client=None,
        settings: Optional[Settings] = None,
    ):
        """Store the collaborators used by the other methods.

        Args:
            chroma_client: Client later passed to ``get_or_create_collection``.
                When falsy, ``get_chroma_client()`` supplies one.
            llm_client: Object exposing an awaitable
                ``complete(prompt=..., temperature=...)``. May be ``None``, in
                which case ``generate_response`` returns a fixed notice.
            settings: Optional settings. When given, they are passed to
                ``get_embedding_function_settings`` to obtain the embedding
                function for the collection.
        """
        self.chroma_client = chroma_client or get_chroma_client()
        self.llm_client = llm_client
        self.settings = settings

        # Filled in by the ``collection`` property on first access.
        self._collection = None

    @property
    def collection(self):
        """Return the ``"documents"`` collection, resolving it on first use.

        The resolved collection is cached on the instance, so the helper
        functions are called at most once per service object.
        """
        if self._collection is None:
            # Deferred import: only needed the first time the collection is
            # resolved.
            from app.core.database import (
                get_embedding_function_settings,
                get_or_create_collection,
            )

            embedding_fn = None
            if self.settings is not None:
                embedding_fn = get_embedding_function_settings(self.settings)

            self._collection = get_or_create_collection(
                self.chroma_client, "documents", embedding_function=embedding_fn
            )
        return self._collection

    def ingest_document(
        self,
        file_path: str,
        chunks: List[str],
        metadata_list: List[Dict[str, Any]],
    ) -> str:
        """Add the chunks of one document to the collection.

        Args:
            file_path: Path of the source document. Kept for interface
                compatibility; this method does not read it.
            chunks: Text chunks to store, in order.
            metadata_list: One metadata dict per chunk, aligned by index. An
                empty list means "no metadata" and is passed on as ``None``.

        Returns:
            The generated document id (a UUID4 string), or ``""`` when
            ``chunks`` is empty and nothing was stored. Each chunk is stored
            under the id ``"<document_id>_<index>"``.

        Raises:
            ValueError: If ``metadata_list`` is non-empty and its length
                differs from the number of chunks.
        """
        if not chunks:
            return ""

        # Fail before touching the collection so the caller gets a message
        # that names the actual arguments at fault.
        if metadata_list and len(metadata_list) != len(chunks):
            raise ValueError(
                f"metadata_list has {len(metadata_list)} entries "
                f"but there are {len(chunks)} chunks"
            )

        document_id = str(uuid.uuid4())
        ids = [f"{document_id}_{i}" for i in range(len(chunks))]

        self.collection.add(
            documents=chunks,
            metadatas=metadata_list or None,
            ids=ids,
        )

        return document_id

    def retrieve(
        self,
        query_keywords: List[str],
        n_results: int = 10,
    ) -> List[Tuple[str, Dict[str, Any], float]]:
        """Query the collection with the space-joined keywords.

        Args:
            query_keywords: Keywords joined with single spaces to form the
                query text.
            n_results: Maximum number of results requested from the
                collection.

        Returns:
            A list of ``(document, metadata, distance)`` tuples in the order
            the collection returned them. ``metadata`` falls back to ``{}``
            and ``distance`` to ``0.0`` when the collection did not supply
            them. The list is empty when there are no keywords (the
            collection is not queried in that case) or no matching documents.
        """
        query_text = " ".join(query_keywords)
        if not query_text.strip():
            # Nothing to search for: skip the query instead of sending an
            # empty string to the embedding function.
            return []

        results = self.collection.query(
            query_texts=[query_text],
            n_results=n_results,
        )

        documents = self._first_batch(results, "documents")
        metadatas = self._first_batch(results, "metadatas")
        distances = self._first_batch(results, "distances")

        hits = []
        for i, doc in enumerate(documents):
            metadata = metadatas[i] if i < len(metadatas) else None
            distance = distances[i] if i < len(distances) else None
            hits.append(
                (
                    doc,
                    # A record stored without metadata may come back as None.
                    metadata or {},
                    0.0 if distance is None else distance,
                )
            )

        return hits

    @staticmethod
    def _first_batch(results, key):
        """Return the result list for the single query text, or ``[]``.

        Only one query text is ever sent, so only index 0 of each field is
        relevant. A missing key, a ``None`` field, or an empty outer list all
        yield ``[]``.
        """
        batches = results.get(key) if results else None
        if not batches:
            return []
        return batches[0] or []

    async def generate_response(
        self,
        question: str,
        chunks: List[str],
        metadata_list: List[Dict[str, Any]],
    ) -> str:
        """Build a citation-style prompt from the chunks and ask the LLM.

        Args:
            question: The user question placed at the top of the prompt.
            chunks: Document chunks offered to the model as its only context.
            metadata_list: Metadata aligned with ``chunks`` by index. The keys
                ``"filename"`` and ``"content_summary"`` are read; a missing
                or ``None`` entry is treated as ``{}`` so that every chunk
                still appears in the prompt.

        Returns:
            The value returned by awaiting ``llm_client.complete``, or a fixed
            notice when there are no chunks or no LLM client is configured.
        """
        if not chunks:
            return "I could not find any relevant information to answer your question."

        if self.llm_client is None:
            return "LLM client not configured."

        metadata_list = metadata_list or []

        context_parts = []
        for i, chunk in enumerate(chunks):
            # Index instead of zip(): zip would silently drop trailing chunks
            # whenever metadata_list is shorter than chunks.
            meta = (metadata_list[i] if i < len(metadata_list) else None) or {}
            source = meta.get("filename", "unknown")
            summary = meta.get("content_summary", "")
            context_parts.append(
                f"[{i + 1}] Source: {source}\n"
                f"Summary: {summary}\n"
                f"Content: {chunk}\n"
            )

        context = "\n".join(context_parts)

        prompt = (
            f"Question: {question}\n\n"
            "Answer the question using ONLY these document chunks. "
            "Do not use any external knowledge. "
            "Format your answer as bullet points. "
            "Cite the source number [N] for each point.\n\n"
            f"Document chunks:\n{context}\n\n"
            "Answer:"
        )

        return await self.llm_client.complete(prompt=prompt, temperature=0.3)