# legco_ai_assistant/backend/app/services/rag.py

"""RAG service for embedding, retrieval, and response generation."""
import uuid
from typing import List, Tuple, Dict, Any, Optional
import logging
from app.core.config import Settings
from app.core.database import get_chroma_client
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class RAGService:
    """Service for document ingestion, retrieval, and response generation.

    Chunks are stored in, and retrieved from, a Chroma collection named
    ``"documents"``; answers are produced by prompting ``llm_client`` with the
    retrieved chunks.
    """

    def __init__(
        self,
        chroma_client=None,
        llm_client=None,
        settings: Optional["Settings"] = None,
    ):
        """Create the service.

        Args:
            chroma_client: Chroma client to use. Any falsy value falls back to
                ``get_chroma_client()``.
            llm_client: Object exposing an awaitable
                ``complete(prompt=..., temperature=..., step_name=...)``. When
                ``None``, ``generate_response`` returns a "not configured"
                message instead of calling a model.
            settings: Settings used to build the collection's embedding
                function. When ``None``, ``embedding_function=None`` is passed
                when the collection is created.
        """
        self.chroma_client = chroma_client or get_chroma_client()
        self.llm_client = llm_client
        self.settings = settings
        # Resolved lazily on first access of ``collection``.
        self._collection = None

    @property
    def collection(self):
        """Return the ``"documents"`` collection, fetching/creating it on first use.

        The collection object is cached in ``self._collection`` afterwards.
        """
        if self._collection is None:
            # Deferred import; NOTE(review): presumably to avoid an import-time
            # dependency/cycle with app.core.database — confirm before hoisting.
            from app.core.database import (
                get_embedding_function_settings,
                get_or_create_collection,
            )

            embedding_fn = None
            if self.settings is not None:
                embedding_fn = get_embedding_function_settings(self.settings)
            self._collection = get_or_create_collection(
                self.chroma_client, "documents", embedding_function=embedding_fn
            )
        return self._collection

    def ingest_document(
        self,
        file_path: str,
        chunks: List[str],
        metadata_list: Optional[List[Dict[str, Any]]],
    ) -> str:
        """Store pre-split ``chunks`` in the collection under a fresh document id.

        Args:
            file_path: Path of the source file. Not used by this method.
            chunks: Text chunks to store, in order.
            metadata_list: One metadata dict per chunk (same order), or ``None``
                to store the chunks without metadata.

        Returns:
            The generated document id (a UUID4 string). Chunk ``i`` is stored
            with id ``"<document_id>_<i>"``. Returns ``""`` without writing
            anything when ``chunks`` is empty.

        Raises:
            ValueError: If ``metadata_list`` is given and its length differs
                from ``len(chunks)``.
        """
        if not chunks:
            return ""
        # Validate before writing so a mismatch is reported clearly instead of
        # depending on the vector store's own checks (or lack of them).
        if metadata_list is not None and len(metadata_list) != len(chunks):
            raise ValueError(
                f"metadata_list has {len(metadata_list)} entries but there are "
                f"{len(chunks)} chunks; expected exactly one metadata dict per chunk"
            )
        document_id = str(uuid.uuid4())
        ids = [f"{document_id}_{i}" for i in range(len(chunks))]
        self.collection.add(
            documents=chunks,
            metadatas=metadata_list,
            ids=ids,
        )
        return document_id

    def retrieve(
        self,
        query_keywords: List[str],
        n_results: int = 10,
    ) -> List[Tuple[str, Dict[str, Any], float]]:
        """Return the chunks nearest to the space-joined ``query_keywords``.

        Args:
            query_keywords: Keywords joined with single spaces into one query text.
            n_results: Maximum number of chunks to request from the collection.

        Returns:
            ``(document, metadata, distance)`` tuples in the order the collection
            returns them. ``metadata`` is always a dict (``{}`` when none is
            available) and ``distance`` falls back to ``0.0`` when the result
            carries no distances. Returns ``[]`` without querying when the
            joined keywords are empty or whitespace-only.
        """
        query_text = " ".join(query_keywords)
        if not query_text.strip():
            # A blank query carries no search intent; don't send it to the store.
            return []

        results = self.collection.query(
            query_texts=[query_text],
            n_results=n_results,
        )
        documents = self._first_query_batch(results, "documents")
        metadatas = self._first_query_batch(results, "metadatas")
        distances = self._first_query_batch(results, "distances")

        chunks = []
        for i, doc in enumerate(documents):
            # Per-chunk metadata may be None (e.g. chunks ingested with
            # metadata_list=None); normalise so callers always get a dict.
            metadata = (metadatas[i] if i < len(metadatas) else None) or {}
            distance = distances[i] if i < len(distances) else 0.0
            chunks.append((doc, metadata, distance))
        return chunks

    @staticmethod
    def _first_query_batch(results: Dict[str, Any], key: str) -> list:
        """Return the list under ``key`` for the first (only) query text.

        Query results hold one inner list per query text; a missing/``None``
        field or an empty first batch yields ``[]``.
        """
        batches = results.get(key)
        if not batches:
            return []
        return batches[0] or []

    async def generate_response(
        self,
        question: str,
        chunks: List[str],
        metadata_list: List[Dict[str, Any]],
    ) -> str:
        """Answer ``question`` from ``chunks`` using the configured LLM client.

        Each chunk is numbered ``[1]``, ``[2]``, ... in the prompt together with
        its metadata's ``filename`` and ``content_summary`` so the model can
        cite sources by number.

        Args:
            question: The user's question.
            chunks: Chunk texts to use as context.
            metadata_list: Metadata for each chunk, in the same order. Missing
                trailing entries and ``None`` entries are treated as ``{}``.

        Returns:
            The LLM completion, or a fixed fallback message when there are no
            chunks or no LLM client is configured.
        """
        if not chunks:
            return "I could not find any relevant information to answer your question."
        if self.llm_client is None:
            return "LLM client not configured."

        metadata_list = metadata_list or []
        context_parts = []
        for i, chunk in enumerate(chunks):
            # Pad instead of zip(): zip would silently drop every chunk past the
            # end of a shorter metadata_list, hiding context from the model.
            meta = (metadata_list[i] if i < len(metadata_list) else None) or {}
            source = meta.get("filename", "unknown")
            summary = meta.get("content_summary", "")
            context_parts.append(
                f"[{i + 1}] Source: {source}\n"
                f"Summary: {summary}\n"
                f"Content: {chunk}\n"
            )
        context = "\n".join(context_parts)
        prompt = (
            f"Question: {question}\n\n"
            "Answer the question using ONLY these document chunks. "
            "Do not use any external knowledge. "
            "Format your answer as bullet points. "
            "Cite the source number [N] for each point.\n\n"
            f"Document chunks:\n{context}\n\n"
            "Answer:"
        )
        return await self.llm_client.complete(
            prompt=prompt, temperature=0.3, step_name="ResponseGeneration"
        )