290 lines
10 KiB
Python
290 lines
10 KiB
Python
import logging
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
|
|
|
|
from app.models.common import SourceMetadata
|
|
from app.models.query import SubQuestionSources
|
|
from app.services.query_decomposer import QueryDecomposer
|
|
from app.services.relevance_filter import RelevanceFilter
|
|
from app.services.rag import RAGService
|
|
|
|
logger: logging.Logger = logging.getLogger(__name__)

# Answer reported by the pipeline when retrieval or relevance filtering
# leaves no chunks to generate from.
NO_RESULTS_ANSWER: str = (
    "I could not find any relevant information to answer your question."
)
|
|
|
|
|
|
@dataclass
class PipelineSnapshot:
    """State captured at one pipeline stage boundary, for accuracy-testing capture.

    Only ``phase`` is required; every other field defaults to an empty string,
    zero, or an empty list, so a snapshot need only fill in the stages that
    have actually run.
    """

    # Stage-boundary label (RAGPipeline emits e.g. "decomposed", "completed").
    phase: str
    # The user question the run was started with.
    question: str = ""

    # ----- Stage 1: decompose -----
    # Sub-questions produced by the decomposer.
    extracted_questions: List[str] = field(default_factory=list)
    # Prompt text reported by the decomposer ("" when none was returned).
    decompose_prompt: str = ""
    decomposer_time_ms: int = 0

    # ----- Stage 2: retrieve -----
    # One entry per sub-question: (label, [(text, metadata, float), ...]).
    # NOTE(review): the label is presumably the sub-question text and the float
    # a retrieval distance/score — confirm against RAGService.
    retrieval_results: List[Tuple[str, List[Tuple[str, Dict[str, Any], float]]]] = (
        field(default_factory=list)
    )
    # Total number of chunks across all sub-questions.
    chunks_retrieved_count: int = 0
    retriever_time_ms: int = 0

    # ----- Stage 3: filter -----
    # Prompt text reported by the relevance filter ("" when none was returned).
    filter_prompt: str = ""
    # One entry per sub-question: (label, [(text, metadata), ...]) kept by the filter.
    filtered_by_subq: List[Tuple[str, List[Tuple[str, Dict[str, Any]]]]] = (
        field(default_factory=list)
    )
    # Total number of chunks kept across all sub-questions.
    chunks_filtered_count: int = 0
    filter_time_ms: int = 0

    # ----- Stage 4: generate -----
    generate_prompt: str = ""
    answer: str = ""
    # Sources grouped by sub-question, as attached to the final answer.
    sub_question_sources: List[SubQuestionSources] = field(default_factory=list)
    generator_time_ms: int = 0

    # ----- Run metadata -----
    # Wall-clock duration of the whole run up to this snapshot.
    total_time_ms: int = 0
    error_message: str = ""
|
|
|
|
|
|
class RAGPipeline:
    """Reusable RAG pipeline: decompose → retrieve → filter → generate.

    Yields a PipelineSnapshot at each stage boundary. No SSE or HTTP coupling.
    Use in streaming endpoints (wrap snapshots as SSE) or capture endpoints
    (collect snapshots into GenerateResult).

    Unless a collaborator raises (exceptions propagate to the caller), every
    run ends with exactly one snapshot whose ``phase`` is ``"completed"``.
    That snapshot carries every field gathered up to the point where the run
    stopped, including on the early-exit paths, so a capture consumer can
    rely on it alone.

    Usage:
        pipeline = RAGPipeline(
            decomposer=...,
            rag=...,
            relevance_filter=...,
            retrieval_n_results=10,   # optional
            relevance_threshold=7.0,  # optional
        )
        async for snap in pipeline.execute("question text"):
            if snap.phase == "completed":
                print(snap.answer)
    """

    def __init__(
        self,
        *,
        decomposer: QueryDecomposer,
        rag: RAGService,
        relevance_filter: RelevanceFilter,
        retrieval_n_results: int = 10,
        relevance_threshold: float = 7.0,
    ):
        """Store collaborators and tuning knobs (keyword-only).

        Args:
            decomposer: Its ``decompose(question)`` yields the sub-questions.
            rag: Provides ``retrieve_per_subquestion`` and
                ``generate_response_per_subquestion``.
            relevance_filter: Provides ``filter_per_subquestion``.
            retrieval_n_results: Passed as ``n_results`` to retrieval.
            relevance_threshold: Passed as ``threshold`` to the filter.
        """
        self._decomposer = decomposer
        self._rag = rag
        self._relevance_filter = relevance_filter
        self._retrieval_n_results = retrieval_n_results
        self._relevance_threshold = relevance_threshold

    @staticmethod
    def _elapsed_ms(start: float) -> int:
        """Return whole milliseconds elapsed since ``start`` (a perf_counter value)."""
        return int((time.perf_counter() - start) * 1000)

    @staticmethod
    def _split_prompt(result: Any) -> Tuple[Any, str]:
        """Normalize a collaborator result to ``(value, prompt)``.

        A tuple is unpacked as ``(value, prompt)``. Any other value is treated
        as a bare value with an empty prompt.
        """
        if isinstance(result, tuple):
            value, prompt = result
            return value, prompt
        return result, ""

    @staticmethod
    def _build_sub_question_sources(
        questions: List[str], grouped_sources_meta: List[Any]
    ) -> List["SubQuestionSources"]:
        """Pair each sub-question with SourceMetadata built from its metadata dicts.

        ``zip`` stops at the shorter input, so unmatched questions or source
        groups are dropped.
        """
        return [
            SubQuestionSources(
                sub_question_index=idx,
                sub_question_text=sub_q_text,
                sources=[
                    SourceMetadata(
                        filename=meta.get("filename", "unknown"),
                        upload_date=meta.get("upload_date", ""),
                        content_summary=meta.get("content_summary", ""),
                        chunk_index=meta.get("chunk_index", 0),
                        page_number=meta.get("page_number"),
                        chunk_file_path=meta.get("chunk_file_path"),
                        document_id=meta.get("document_id"),
                    )
                    for meta in sources_meta
                ],
            )
            for idx, (sub_q_text, sources_meta) in enumerate(
                zip(questions, grouped_sources_meta)
            )
        ]

    async def execute(
        self,
        question: str,
        stop_after_decompose: bool = False,
    ) -> AsyncGenerator[PipelineSnapshot, None]:
        """Run the pipeline for ``question``, yielding one snapshot per stage.

        Phases are yielded in order: ``"decomposed"``, ``"retrieving"``,
        ``"filtering"``, ``"completed"``. The run jumps straight to
        ``"completed"`` in three cases:

        * when ``stop_after_decompose`` is set (answer left empty);
        * when retrieval returns no chunks;
        * when filtering keeps no chunks.

        The last two answer with ``NO_RESULTS_ANSWER``.

        Args:
            question: The user question.
            stop_after_decompose: Finish right after the decompose stage.

        Yields:
            PipelineSnapshot objects, each carrying every field produced so far.
        """
        overall_start = time.perf_counter()
        # Fields accumulated across stages. Every snapshot is built from this
        # dict, so later snapshots (including early "completed" ones) never
        # drop data an earlier stage already produced.
        state: Dict[str, Any] = {"question": question}

        # --- Stage 1: Decompose ---
        stage_start = time.perf_counter()
        extracted_questions, decompose_prompt = self._split_prompt(
            await self._decomposer.decompose(question)
        )
        decomposer_time_ms = self._elapsed_ms(stage_start)
        if not extracted_questions:
            # Fall back to the whole question so later stages always have input.
            extracted_questions = [question]
        state.update(
            extracted_questions=extracted_questions,
            decompose_prompt=decompose_prompt,
            decomposer_time_ms=decomposer_time_ms,
        )
        yield PipelineSnapshot(phase="decomposed", **state)

        if stop_after_decompose:
            yield PipelineSnapshot(
                phase="completed",
                total_time_ms=self._elapsed_ms(overall_start),
                **state,
            )
            return

        # --- Stage 2: Retrieve ---
        # NOTE(review): retrieve_per_subquestion is called without ``await``, so it
        # presumably runs synchronously and blocks the event loop while it runs.
        stage_start = time.perf_counter()
        retrieval_results = self._rag.retrieve_per_subquestion(
            extracted_questions, n_results=self._retrieval_n_results
        )
        state.update(
            retrieval_results=retrieval_results,
            retriever_time_ms=self._elapsed_ms(stage_start),
            chunks_retrieved_count=sum(len(chunks) for _, chunks in retrieval_results),
        )
        yield PipelineSnapshot(phase="retrieving", **state)

        if not any(chunks for _, chunks in retrieval_results):
            state["answer"] = NO_RESULTS_ANSWER
            yield PipelineSnapshot(
                phase="completed",
                total_time_ms=self._elapsed_ms(overall_start),
                **state,
            )
            return

        # --- Stage 3: Filter ---
        stage_start = time.perf_counter()
        # The filter takes (text, metadata) pairs; drop the third tuple element.
        chunks_by_subq = [
            [(text, meta) for text, meta, _dist in chunks]
            for _, chunks in retrieval_results
        ]
        filtered_by_subq, filter_prompt = self._split_prompt(
            await self._relevance_filter.filter_per_subquestion(
                extracted_questions, chunks_by_subq, threshold=self._relevance_threshold
            )
        )
        state.update(
            filter_prompt=filter_prompt,
            filtered_by_subq=filtered_by_subq,
            filter_time_ms=self._elapsed_ms(stage_start),
            chunks_filtered_count=sum(len(chunks) for _, chunks in filtered_by_subq),
        )
        yield PipelineSnapshot(phase="filtering", **state)

        if not any(chunks for _, chunks in filtered_by_subq):
            state["answer"] = NO_RESULTS_ANSWER
            yield PipelineSnapshot(
                phase="completed",
                total_time_ms=self._elapsed_ms(overall_start),
                **state,
            )
            return

        # --- Stage 4: Generate ---
        stage_start = time.perf_counter()
        sub_chunk_texts = [
            [text for text, _meta in chunks] for _, chunks in filtered_by_subq
        ]
        sub_chunk_metadata = [
            [meta for _text, meta in chunks] for _, chunks in filtered_by_subq
        ]
        gen_result = await self._rag.generate_response_per_subquestion(
            extracted_questions, sub_chunk_texts, sub_chunk_metadata
        )
        # Accept (answer, prompt, grouped_sources), (answer, prompt) or a bare answer.
        if isinstance(gen_result, tuple) and len(gen_result) == 3:
            answer, generate_prompt, grouped_sources_meta = gen_result
        else:
            answer, generate_prompt = self._split_prompt(gen_result)
            grouped_sources_meta = []

        state.update(
            generate_prompt=generate_prompt,
            answer=answer,
            sub_question_sources=self._build_sub_question_sources(
                extracted_questions, grouped_sources_meta
            ),
            generator_time_ms=self._elapsed_ms(stage_start),
        )
        yield PipelineSnapshot(
            phase="completed",
            total_time_ms=self._elapsed_ms(overall_start),
            **state,
        )
|