import logging
import time
from dataclasses import dataclass, field, replace
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple

from app.models.common import SourceMetadata
from app.models.query import SubQuestionSources
from app.services.query_decomposer import QueryDecomposer
from app.services.relevance_filter import RelevanceFilter
from app.services.rag import RAGService

logger = logging.getLogger(__name__)

NO_RESULTS_ANSWER = "I could not find any relevant information to answer your question."


@dataclass
class PipelineSnapshot:
    """Complete snapshot at a given pipeline stage for accuracy testing capture.

    Snapshots yielded by ``RAGPipeline.execute`` are cumulative: each carries
    every field produced by the stages that have run so far, and fields of
    stages that have not run keep their defaults.
    """

    # Stage name set by RAGPipeline.execute: "decomposed", "retrieving",
    # "filtering" or "completed".
    phase: str
    question: str = ""
    # Stage 1: Decompose
    # Sub-questions from the decomposer; [question] when it returned none.
    extracted_questions: List[str] = field(default_factory=list)
    decompose_prompt: str = ""
    decomposer_time_ms: int = 0
    # Stage 2: Retrieve
    # One (key, [(chunk_text, metadata, score), ...]) entry per sub-question as
    # returned by RAGService.retrieve_per_subquestion (score is presumably a
    # vector distance — it is named `_dist` and unused here).
    retrieval_results: List[Tuple[str, List[Tuple[str, Dict[str, Any], float]]]] = field(
        default_factory=list
    )
    # Total number of chunks across all sub-questions in retrieval_results.
    chunks_retrieved_count: int = 0
    retriever_time_ms: int = 0
    # Stage 3: Filter
    filter_prompt: str = ""
    # (key, [(chunk_text, metadata), ...]) entries as returned by
    # RelevanceFilter.filter_per_subquestion.
    filtered_by_subq: List[Tuple[str, List[Tuple[str, Dict[str, Any]]]]] = field(
        default_factory=list
    )
    # Total number of chunks across all entries in filtered_by_subq.
    chunks_filtered_count: int = 0
    filter_time_ms: int = 0
    # Stage 4: Generate
    generate_prompt: str = ""
    # Final answer, or NO_RESULTS_ANSWER when nothing was retrieved/kept.
    answer: str = ""
    sub_question_sources: List[SubQuestionSources] = field(default_factory=list)
    generator_time_ms: int = 0
    # Metadata
    # Wall-clock time (perf_counter, ms) from the start of execute().
    total_time_ms: int = 0
    # Not set by RAGPipeline.execute (exceptions propagate instead); available
    # for callers that record failures themselves.
    error_message: str = ""


class RAGPipeline:
    """Reusable RAG pipeline: decompose → retrieve → filter → generate.

    Yields PipelineSnapshot at each stage boundary. No SSE or HTTP coupling.
    Use in streaming endpoints (wrap snapshots as SSE) or capture endpoints
    (collect snapshots into GenerateResult).

    Every snapshot is cumulative, so the final ``"completed"`` snapshot is a
    full record of the run, including when the pipeline stops early.

    Usage:
        pipeline = RAGPipeline(
            decomposer=..., rag=..., relevance_filter=...,
            retrieval_n_results=10, relevance_threshold=7.0,
        )
        async for snap in pipeline.execute("question text"):
            if snap.phase == "completed":
                print(snap.answer)
    """

    def __init__(
        self,
        *,
        decomposer: QueryDecomposer,
        rag: RAGService,
        relevance_filter: RelevanceFilter,
        retrieval_n_results: int = 10,
        relevance_threshold: float = 7.0,
    ):
        """Store the collaborating services and tuning knobs.

        Args:
            decomposer: Splits the user question into sub-questions.
            rag: Performs per-sub-question retrieval and answer generation.
            relevance_filter: Filters retrieved chunks per sub-question.
            retrieval_n_results: ``n_results`` passed to retrieval for each
                sub-question.
            relevance_threshold: ``threshold`` passed to the relevance filter.
        """
        self._decomposer = decomposer
        self._rag = rag
        self._relevance_filter = relevance_filter
        self._retrieval_n_results = retrieval_n_results
        self._relevance_threshold = relevance_threshold

    @staticmethod
    def _elapsed_ms(start: float) -> int:
        """Return whole milliseconds elapsed since ``start`` (a perf_counter value)."""
        return int((time.perf_counter() - start) * 1000)

    @staticmethod
    def _split_prompt(result: Any) -> Tuple[Any, str]:
        """Normalize a service result to ``(value, prompt)``.

        Services may return ``(value, prompt)`` or a bare ``value``; the bare
        form gets an empty prompt.
        """
        if isinstance(result, tuple):
            value, prompt = result
            return value, prompt
        return result, ""

    @staticmethod
    def _split_generation(result: Any) -> Tuple[Any, str, Any]:
        """Normalize the generator result to ``(answer, prompt, grouped_sources_meta)``.

        Accepts ``(answer, prompt, grouped_sources_meta)``, ``(answer, prompt)``
        or a bare answer; missing parts default to ``""`` and ``[]``.
        """
        if isinstance(result, tuple) and len(result) == 3:
            answer, prompt, grouped_sources_meta = result
            return answer, prompt, grouped_sources_meta
        answer, prompt = result if isinstance(result, tuple) else (result, "")
        return answer, prompt, []

    @staticmethod
    def _to_source_metadata(meta: Dict[str, Any]) -> SourceMetadata:
        """Build a SourceMetadata model from one chunk's metadata dict."""
        return SourceMetadata(
            filename=meta.get("filename", "unknown"),
            upload_date=meta.get("upload_date", ""),
            content_summary=meta.get("content_summary", ""),
            chunk_index=meta.get("chunk_index", 0),
            page_number=meta.get("page_number"),
            chunk_file_path=meta.get("chunk_file_path"),
            document_id=meta.get("document_id"),
        )

    @classmethod
    def _build_sub_question_sources(
        cls, questions: List[str], grouped_sources_meta: Any
    ) -> List[SubQuestionSources]:
        """Pair each sub-question with its source metadata group by position.

        Groups beyond the number of sub-questions (or vice versa) are dropped,
        matching ``zip`` semantics; a mismatch is logged because it usually
        means sources were attributed to the wrong sub-question.
        """
        groups = list(grouped_sources_meta)
        if groups and len(groups) != len(questions):
            logger.warning(
                "Source groups (%d) do not match sub-questions (%d); "
                "unmatched entries are dropped",
                len(groups),
                len(questions),
            )
        return [
            SubQuestionSources(
                sub_question_index=idx,
                sub_question_text=sub_q_text,
                sources=[cls._to_source_metadata(meta) for meta in sources_meta],
            )
            for idx, (sub_q_text, sources_meta) in enumerate(zip(questions, groups))
        ]

    async def execute(
        self,
        question: str,
        stop_after_decompose: bool = False,
    ) -> AsyncGenerator[PipelineSnapshot, None]:
        """Execute the pipeline, yielding one cumulative snapshot per stage.

        Phases are yielded in order: ``"decomposed"``, ``"retrieving"``,
        ``"filtering"``, ``"completed"``. The run ends early with a
        ``"completed"`` snapshot when ``stop_after_decompose`` is set, or with
        ``answer=NO_RESULTS_ANSWER`` when retrieval returns no chunks or the
        filter keeps none. Every snapshot includes all fields produced so far.

        Args:
            question: The user question.
            stop_after_decompose: Stop right after decomposition, skipping
                retrieval, filtering and generation.

        Yields:
            PipelineSnapshot copies; the last one always has phase "completed".

        Exceptions raised by the decomposer, retriever, filter or generator
        propagate to the caller; ``error_message`` is never set here.
        """
        overall_start = time.perf_counter()
        # Private accumulator; consumers receive copies via `replace`, so a
        # caller mutating a yielded snapshot cannot leak into later ones.
        state = PipelineSnapshot(phase="", question=question)

        # --- Stage 1: Decompose ---
        stage_start = time.perf_counter()
        extracted_questions, decompose_prompt = self._split_prompt(
            await self._decomposer.decompose(question)
        )
        state.decomposer_time_ms = self._elapsed_ms(stage_start)
        # An empty decomposition still gets answered: fall back to the
        # original question as the single sub-question.
        if not extracted_questions:
            extracted_questions = [question]
        state.extracted_questions = extracted_questions
        state.decompose_prompt = decompose_prompt
        yield replace(state, phase="decomposed")

        if stop_after_decompose:
            state.total_time_ms = self._elapsed_ms(overall_start)
            yield replace(state, phase="completed")
            return

        # --- Stage 2: Retrieve ---
        stage_start = time.perf_counter()
        # NOTE(review): this call is not awaited, so it runs synchronously on
        # the event loop; if it performs blocking I/O, consider offloading it
        # (e.g. asyncio.to_thread) once RAGService is confirmed thread-safe.
        retrieval_results = self._rag.retrieve_per_subquestion(
            extracted_questions, n_results=self._retrieval_n_results
        )
        state.retriever_time_ms = self._elapsed_ms(stage_start)
        state.retrieval_results = retrieval_results
        state.chunks_retrieved_count = sum(len(chunks) for _, chunks in retrieval_results)
        yield replace(state, phase="retrieving")

        if not any(chunks for _, chunks in retrieval_results):
            state.answer = NO_RESULTS_ANSWER
            state.total_time_ms = self._elapsed_ms(overall_start)
            yield replace(state, phase="completed")
            return

        # --- Stage 3: Filter ---
        stage_start = time.perf_counter()
        # The filter works on (text, metadata) pairs; drop the retrieval score.
        chunks_by_subq = [
            [(text, meta) for text, meta, _dist in chunks]
            for _, chunks in retrieval_results
        ]
        filtered_by_subq, filter_prompt = self._split_prompt(
            await self._relevance_filter.filter_per_subquestion(
                extracted_questions, chunks_by_subq, threshold=self._relevance_threshold
            )
        )
        state.filter_time_ms = self._elapsed_ms(stage_start)
        state.filter_prompt = filter_prompt
        state.filtered_by_subq = filtered_by_subq
        state.chunks_filtered_count = sum(len(chunks) for _, chunks in filtered_by_subq)
        yield replace(state, phase="filtering")

        # any() over an empty list is False, so this also covers "no entries".
        if not any(chunks for _, chunks in filtered_by_subq):
            state.answer = NO_RESULTS_ANSWER
            state.total_time_ms = self._elapsed_ms(overall_start)
            yield replace(state, phase="completed")
            return

        # --- Stage 4: Generate ---
        stage_start = time.perf_counter()
        sub_chunk_texts = [
            [text for text, _meta in chunks] for _, chunks in filtered_by_subq
        ]
        sub_chunk_metadata = [
            [meta for _text, meta in chunks] for _, chunks in filtered_by_subq
        ]
        answer, generate_prompt, grouped_sources_meta = self._split_generation(
            await self._rag.generate_response_per_subquestion(
                extracted_questions, sub_chunk_texts, sub_chunk_metadata
            )
        )
        state.sub_question_sources = self._build_sub_question_sources(
            extracted_questions, grouped_sources_meta
        )
        state.generate_prompt = generate_prompt
        state.answer = answer
        state.generator_time_ms = self._elapsed_ms(stage_start)
        state.total_time_ms = self._elapsed_ms(overall_start)
        yield replace(state, phase="completed")