# legco_ai_assistant/backend/app/services/rag_pipeline.py
# (source-viewer header: 290 lines, 10 KiB, Python)

import logging
import time
from dataclasses import dataclass, field
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
from app.models.common import SourceMetadata
from app.models.query import SubQuestionSources
from app.services.query_decomposer import QueryDecomposer
from app.services.relevance_filter import RelevanceFilter
from app.services.rag import RAGService
logger = logging.getLogger(name=__name__)
# Canned answer returned when retrieval or relevance filtering leaves no
# chunks to generate an answer from.
NO_RESULTS_ANSWER = (
    "I could not find any relevant information "
    "to answer your question."
)
# Shorthand for the chunk shapes carried between stages.
_ScoredChunk = Tuple[str, Dict[str, Any], float]  # (text, metadata, distance)
_TextChunk = Tuple[str, Dict[str, Any]]  # (text, metadata)


@dataclass
class PipelineSnapshot:
    """State of the RAG pipeline captured at one stage boundary.

    Used both for streaming progress and for accuracy-testing capture.
    Only ``phase`` is required. Every other field starts empty or zero and
    is filled in as the stage that owns it runs. ``*_time_ms`` fields are
    integer milliseconds.
    """

    phase: str
    question: str = ""

    # ---- decompose stage ----
    extracted_questions: List[str] = field(default_factory=list)
    decompose_prompt: str = ""
    decomposer_time_ms: int = 0

    # ---- retrieve stage: one (sub_question, scored chunks) pair per sub-question ----
    retrieval_results: List[Tuple[str, List[_ScoredChunk]]] = field(default_factory=list)
    chunks_retrieved_count: int = 0
    retriever_time_ms: int = 0

    # ---- filter stage: one (sub_question, kept chunks) pair per sub-question ----
    filter_prompt: str = ""
    filtered_by_subq: List[Tuple[str, List[_TextChunk]]] = field(default_factory=list)
    chunks_filtered_count: int = 0
    filter_time_ms: int = 0

    # ---- generate stage ----
    generate_prompt: str = ""
    answer: str = ""
    sub_question_sources: List[SubQuestionSources] = field(default_factory=list)
    generator_time_ms: int = 0

    # ---- whole-run metadata ----
    total_time_ms: int = 0
    error_message: str = ""
class RAGPipeline:
    """Reusable RAG pipeline: decompose → retrieve → filter → generate.

    Yields a ``PipelineSnapshot`` at each stage boundary and has no SSE or
    HTTP coupling. Streaming endpoints can wrap the snapshots as SSE events.
    Capture endpoints can collect them (e.g. into a ``GenerateResult``).

    Phases are yielded in this order: ``"decomposed"``, ``"retrieving"``,
    ``"filtering"``, then exactly one ``"completed"``. The pipeline jumps
    straight to ``"completed"`` in three cases:

    * ``stop_after_decompose`` is set;
    * retrieval returns no chunks;
    * the relevance filter keeps no chunks.

    Every ``"completed"`` snapshot carries all fields gathered by the stages
    that actually ran.

    ``PipelineSnapshot.error_message`` is never set here. Exceptions raised
    by the decomposer, retriever, filter or generator propagate to the
    caller.

    Usage::

        pipeline = RAGPipeline(decomposer=..., rag=..., relevance_filter=...)
        async for snap in pipeline.execute("question text"):
            if snap.phase == "completed":
                print(snap.answer)
    """

    def __init__(
        self,
        *,
        decomposer: QueryDecomposer,
        rag: RAGService,
        relevance_filter: RelevanceFilter,
        retrieval_n_results: int = 10,
        relevance_threshold: float = 7.0,
    ) -> None:
        """Store the stage collaborators and tuning knobs.

        Args:
            decomposer: Provides ``decompose(question)`` for stage 1.
            rag: Provides ``retrieve_per_subquestion`` (stage 2) and
                ``generate_response_per_subquestion`` (stage 4).
            relevance_filter: Provides ``filter_per_subquestion`` (stage 3).
            retrieval_n_results: Passed as ``n_results`` to
                ``rag.retrieve_per_subquestion``.
            relevance_threshold: Passed as ``threshold`` to
                ``relevance_filter.filter_per_subquestion``.
        """
        self._decomposer = decomposer
        self._rag = rag
        self._relevance_filter = relevance_filter
        self._retrieval_n_results = retrieval_n_results
        self._relevance_threshold = relevance_threshold

    async def execute(
        self,
        question: str,
        stop_after_decompose: bool = False,
    ) -> AsyncGenerator[PipelineSnapshot, None]:
        """Run the pipeline for ``question``, yielding one snapshot per stage.

        Args:
            question: The user's question.
            stop_after_decompose: If true, yield the ``"decomposed"`` snapshot
                and then a ``"completed"`` one, skipping retrieval, filtering
                and generation.

        Yields:
            ``PipelineSnapshot`` objects. The last one always has
            ``phase == "completed"``. Each yield is a new snapshot object,
            but list-valued fields are shared between snapshots, so callers
            should not mutate those lists in place.
        """
        overall_start = time.perf_counter()
        # Fields gathered so far. Every snapshot (including every
        # "completed" one) is built from this, so no early exit drops data
        # from stages that already ran.
        captured: Dict[str, Any] = {"question": question}

        # --- Stage 1: Decompose ---
        stage_start = time.perf_counter()
        extracted_questions, decompose_prompt = self._split_prompt(
            await self._decomposer.decompose(question)
        )
        decomposer_time_ms = self._elapsed_ms(stage_start)
        if not extracted_questions:
            # Fall back to the original question so every later stage has
            # at least one sub-question to work with.
            extracted_questions = [question]
        captured.update(
            extracted_questions=extracted_questions,
            decompose_prompt=decompose_prompt,
            decomposer_time_ms=decomposer_time_ms,
        )
        yield PipelineSnapshot(phase="decomposed", **captured)
        if stop_after_decompose:
            yield PipelineSnapshot(
                phase="completed",
                total_time_ms=self._elapsed_ms(overall_start),
                **captured,
            )
            return

        # --- Stage 2: Retrieve ---
        stage_start = time.perf_counter()
        # NOTE(review): this is a synchronous call inside a coroutine. If
        # retrieve_per_subquestion does blocking I/O, it stalls the event loop
        # for its duration (asyncio.to_thread would avoid that) — confirm.
        retrieval_results = self._rag.retrieve_per_subquestion(
            extracted_questions, n_results=self._retrieval_n_results
        )
        retriever_time_ms = self._elapsed_ms(stage_start)
        captured.update(
            retrieval_results=retrieval_results,
            chunks_retrieved_count=sum(len(chunks) for _, chunks in retrieval_results),
            retriever_time_ms=retriever_time_ms,
        )
        yield PipelineSnapshot(phase="retrieving", **captured)
        if not any(chunks for _, chunks in retrieval_results):
            yield PipelineSnapshot(
                phase="completed",
                answer=NO_RESULTS_ANSWER,
                total_time_ms=self._elapsed_ms(overall_start),
                **captured,
            )
            return

        # --- Stage 3: Filter ---
        stage_start = time.perf_counter()
        # The filter takes (text, metadata) pairs; drop the retrieval distance.
        chunks_by_subq = [
            [(text, meta) for text, meta, _dist in chunks]
            for _, chunks in retrieval_results
        ]
        filtered_by_subq, filter_prompt = self._split_prompt(
            await self._relevance_filter.filter_per_subquestion(
                extracted_questions,
                chunks_by_subq,
                threshold=self._relevance_threshold,
            )
        )
        filter_time_ms = self._elapsed_ms(stage_start)
        captured.update(
            filter_prompt=filter_prompt,
            filtered_by_subq=filtered_by_subq,
            chunks_filtered_count=sum(len(chunks) for _, chunks in filtered_by_subq),
            filter_time_ms=filter_time_ms,
        )
        yield PipelineSnapshot(phase="filtering", **captured)
        if not any(chunks for _, chunks in filtered_by_subq):
            yield PipelineSnapshot(
                phase="completed",
                answer=NO_RESULTS_ANSWER,
                total_time_ms=self._elapsed_ms(overall_start),
                **captured,
            )
            return

        # --- Stage 4: Generate ---
        stage_start = time.perf_counter()
        sub_chunk_texts = [
            [text for text, _meta in chunks] for _, chunks in filtered_by_subq
        ]
        sub_chunk_metadata = [
            [meta for _text, meta in chunks] for _, chunks in filtered_by_subq
        ]
        answer, generate_prompt, grouped_sources_meta = self._split_generation(
            await self._rag.generate_response_per_subquestion(
                extracted_questions, sub_chunk_texts, sub_chunk_metadata
            )
        )
        sub_question_sources = self._build_sub_question_sources(
            extracted_questions, grouped_sources_meta
        )
        generator_time_ms = self._elapsed_ms(stage_start)
        yield PipelineSnapshot(
            phase="completed",
            generate_prompt=generate_prompt,
            answer=answer,
            sub_question_sources=sub_question_sources,
            generator_time_ms=generator_time_ms,
            total_time_ms=self._elapsed_ms(overall_start),
            **captured,
        )

    @staticmethod
    def _elapsed_ms(start: float) -> int:
        """Return whole milliseconds elapsed since the ``time.perf_counter()`` reading ``start``."""
        return int((time.perf_counter() - start) * 1000)

    @staticmethod
    def _split_prompt(result: Any) -> Tuple[Any, str]:
        """Normalize a ``(value, prompt)`` tuple or a bare value into ``(value, prompt)``.

        A bare (non-tuple) value is paired with an empty prompt. A tuple that
        is not exactly two items long raises ``ValueError`` on unpacking.
        """
        if isinstance(result, tuple):
            value, prompt = result
            return value, prompt
        return result, ""

    @staticmethod
    def _split_generation(result: Any) -> Tuple[Any, str, List[Any]]:
        """Normalize the generator's return value into ``(answer, prompt, grouped_sources_meta)``.

        Accepts any of these forms; missing parts default to an empty prompt
        and no grouped sources:

        * an ``(answer, prompt, grouped_sources_meta)`` 3-tuple;
        * an ``(answer, prompt)`` 2-tuple;
        * a bare answer.
        """
        if isinstance(result, tuple) and len(result) == 3:
            answer, prompt, grouped_sources_meta = result
            return answer, prompt, grouped_sources_meta
        if isinstance(result, tuple):
            answer, prompt = result
            return answer, prompt, []
        return result, "", []

    @staticmethod
    def _build_sub_question_sources(
        questions: List[str],
        grouped_sources_meta: List[List[Optional[Dict[str, Any]]]],
    ) -> List[SubQuestionSources]:
        """Convert per-sub-question metadata dicts into ``SubQuestionSources`` records.

        Sub-questions and metadata groups are paired positionally with
        ``zip``, so any surplus on either side is ignored. A ``None``
        metadata entry is treated as an empty dict, so every field falls
        back to its default instead of raising ``AttributeError``.
        """

        def to_source(meta: Optional[Dict[str, Any]]) -> SourceMetadata:
            meta = meta or {}
            return SourceMetadata(
                filename=meta.get("filename", "unknown"),
                upload_date=meta.get("upload_date", ""),
                content_summary=meta.get("content_summary", ""),
                chunk_index=meta.get("chunk_index", 0),
                page_number=meta.get("page_number"),
                chunk_file_path=meta.get("chunk_file_path"),
                document_id=meta.get("document_id"),
            )

        return [
            SubQuestionSources(
                sub_question_index=idx,
                sub_question_text=sub_q_text,
                sources=[to_source(meta) for meta in sources_meta],
            )
            for idx, (sub_q_text, sources_meta) in enumerate(
                zip(questions, grouped_sources_meta)
            )
        ]