From 666b603639161a9df5a80382a102806db1de8c63 Mon Sep 17 00:00:00 2001 From: Woody Date: Sun, 26 Apr 2026 23:28:06 +0800 Subject: [PATCH] feat(query): refactor pipeline for per-sub-question flow with progressive SSE Restructure _query_stream() to use per-sub-question retrieval, filtering, and generation. Add generating_subquestion SSE events for progressive frontend rendering. Add format_chunks_retrieved_per_subq() and format_chunks_filtered_per_subq() with XML wrappers. Add empty decomposition fallback using original question as single sub-q. Update history recording for grouped sources JSON (list-of-lists format). Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus --- backend/app/routers/query.py | 196 ++++++++++++++++++++++++++++------- 1 file changed, 157 insertions(+), 39 deletions(-) diff --git a/backend/app/routers/query.py b/backend/app/routers/query.py index 72e1dac..3374dd4 100644 --- a/backend/app/routers/query.py +++ b/backend/app/routers/query.py @@ -7,7 +7,7 @@ from fastapi import APIRouter, HTTPException from fastapi.responses import StreamingResponse from app.core.config import get_settings -from app.models.query import QueryRequest +from app.models.query import QueryRequest, SubQuestionSources from app.models.common import SourceMetadata from app.services.history_service import HistoryService from app.services.llm_client import LLMClient @@ -43,6 +43,27 @@ def format_chunks_retrieved_xml(chunks: list) -> str: return "\n".join(parts) +def format_chunks_retrieved_per_subq(results: list) -> str: + """Format per-sub-question retrieved chunks as XML with sub_q wrappers.""" + if not results: + return "" + + parts = [] + for q_idx, (sub_question, chunks) in enumerate(results): + parts.append(f'') + for i, (text, meta, _dist) in enumerate(chunks, start=1): + lines = [f" "] + lines.append(f" Filename: {meta.get('filename', 'unknown')}") + page = meta.get("page_number") + if page is not None: + 
lines.append(f" Page: {page}") + lines.append(f" Content: {text}") + lines.append(f" ") + parts.append("\n".join(lines)) + parts.append("") + return "\n".join(parts) + + def format_chunks_filtered_xml(filtered: list) -> str: """Format filtered chunks as XML-tagged string with relevance scores. filtered = [(text, meta), ...] — score embedded in meta["relevance_score"] @@ -63,6 +84,37 @@ def format_chunks_filtered_xml(filtered: list) -> str: return "\n".join(parts) +def format_chunks_filtered_per_subq(results: list) -> str: + """Format per-sub-question filtered chunks as XML with sub_q wrappers. + + Args: + results: List of (sub_question, filtered_chunks) from filter_per_subquestion(). + Each filtered_chunks is [(text, meta), ...] with relevance_score in meta. + + Returns: + XML string with wrappers containing elements with Relevance scores. + """ + if not results: + return "" + + parts = [] + for q_idx, (sub_question, filtered_chunks) in enumerate(results): + parts.append(f'') + for i, (text, meta) in enumerate(filtered_chunks, start=1): + score = meta.get("relevance_score", "N/A") + lines = [f" "] + lines.append(f" Filename: {meta.get('filename', 'unknown')}") + page = meta.get("page_number") + if page is not None: + lines.append(f" Page: {page}") + lines.append(f" Relevance: {score}") + lines.append(f" Content: {text}") + lines.append(f" ") + parts.append("\n".join(lines)) + parts.append("") + return "\n".join(parts) + + async def _record_history(history_service, input_text, extracted_questions, decompose_prompt, decomposer_time_ms, retriever_time_ms, chunks_retrieved_count, chunks_retrieved, filter_prompt, @@ -142,21 +194,32 @@ async def _query_stream(request: QueryRequest): decomposer_time_ms = int((time.perf_counter() - stage_start) * 1000) logger.info("Extracted questions: %s", extracted_questions) + if not extracted_questions: + extracted_questions = [request.question] + yield _format_sse({ "phase": "decomposed", "extracted_questions": extracted_questions, }) 
- # Stage 2: Retrieve + # Stage 2: Retrieve (per sub-question) stage_start = time.perf_counter() - chunks = rag.retrieve(extracted_questions, n_results=settings.retrieval_n_results) + retrieval_results = rag.retrieve_per_subquestion( + extracted_questions, n_results=settings.retrieval_n_results + ) if extracted_questions else [] retriever_time_ms = int((time.perf_counter() - stage_start) * 1000) - chunks_retrieved_count = len(chunks) - chunks_retrieved = format_chunks_retrieved_xml(chunks) + + all_chunks_flat = [] + for _sub_q, chunks in retrieval_results: + for text, meta, _dist in chunks: + all_chunks_flat.append((text, meta, _dist)) + + chunks_retrieved_count = len(all_chunks_flat) + chunks_retrieved = format_chunks_retrieved_per_subq(retrieval_results) yield _format_sse({"phase": "retrieving"}) - if not chunks: + if not all_chunks_flat: _schedule_history(history_service, request, extracted_questions, decompose_prompt, decomposer_time_ms, 0, 0, "", "", 0, 0, "", "", 0, active_profile, NO_RESULTS_ANSWER, @@ -168,25 +231,37 @@ async def _query_stream(request: QueryRequest): }) return - # Stage 3: Filter - chunks_for_filter = [(text, meta) for text, meta, _dist in chunks] + # Stage 3: Filter (per sub-question — single LLM call) + stage_start = time.perf_counter() + chunks_by_subq = [] + for _sub_q, chunks in retrieval_results: + chunks_by_subq.append([(text, meta) for text, meta, _dist in chunks]) + relevance_filter = RelevanceFilter(llm_client, prompt_service=prompt_service) yield _format_sse({"phase": "filtering"}) - filter_result = await relevance_filter.filter( - request.question, chunks_for_filter, threshold=settings.relevance_threshold - ) - if isinstance(filter_result, tuple): - filtered, filter_prompt = filter_result + if extracted_questions and chunks_by_subq: + filter_result = await relevance_filter.filter_per_subquestion( + extracted_questions, chunks_by_subq, threshold=settings.relevance_threshold + ) else: - filtered, filter_prompt = filter_result, "" 
+ filter_result = ([], "") + + if isinstance(filter_result, tuple): + filtered_by_subq, filter_prompt = filter_result + else: + filtered_by_subq, filter_prompt = filter_result, "" + + all_filtered_flat = [] + for _sub_q, filtered_chunks in filtered_by_subq: + all_filtered_flat.extend(filtered_chunks) filter_time_ms = int((time.perf_counter() - stage_start) * 1000) - chunks_filtered_count = len(filtered) - chunks_filtered = format_chunks_filtered_xml(filtered) + chunks_filtered_count = len(all_filtered_flat) + chunks_filtered = format_chunks_filtered_per_subq(filtered_by_subq) if filtered_by_subq else "" - if not filtered: + if not all_filtered_flat: _schedule_history(history_service, request, extracted_questions, decompose_prompt, decomposer_time_ms, retriever_time_ms, chunks_retrieved_count, chunks_retrieved, filter_prompt, @@ -200,48 +275,91 @@ async def _query_stream(request: QueryRequest): }) return - # Stage 4: Generate + # Stage 4: Generate (per sub-question with progressive streaming) stage_start = time.perf_counter() - chunk_texts = [chunk for chunk, _meta in filtered] - chunk_metadata = [meta for _chunk, meta in filtered] + + sub_chunk_texts = [] + sub_chunk_metadata = [] + for _sub_q, filtered_chunks in filtered_by_subq: + texts = [chunk for chunk, _meta in filtered_chunks] + metas = [meta for _chunk, meta in filtered_chunks] + sub_chunk_texts.append(texts) + sub_chunk_metadata.append(metas) yield _format_sse({"phase": "generating"}) - gen_result = await rag.generate_response(request.question, chunk_texts, chunk_metadata) - if isinstance(gen_result, tuple): - answer, generate_prompt = gen_result + if extracted_questions and filtered_by_subq: + gen_result = await rag.generate_response_per_subquestion( + extracted_questions, + sub_chunk_texts, + sub_chunk_metadata, + ) else: - answer, generate_prompt = gen_result, "" + gen_result = ("", "", []) + + if isinstance(gen_result, tuple) and len(gen_result) == 3: + answer, generate_prompt, grouped_sources_meta = 
gen_result + else: + answer, generate_prompt = gen_result if isinstance(gen_result, tuple) else (gen_result, "") + grouped_sources_meta = [] + + sub_question_sources = [] + for idx, (sub_q_text, sources_meta) in enumerate( + zip(extracted_questions, grouped_sources_meta) + ): + sources = [ + SourceMetadata( + filename=meta.get("filename", "unknown"), + upload_date=meta.get("upload_date", ""), + content_summary=meta.get("content_summary", ""), + chunk_index=meta.get("chunk_index", 0), + page_number=meta.get("page_number"), + chunk_file_path=meta.get("chunk_file_path"), + ) + for meta in sources_meta + ] + sub_question_sources.append( + SubQuestionSources( + sub_question_index=idx, + sub_question_text=sub_q_text, + sources=sources, + ) + ) + yield _format_sse({ + "phase": "generating_subquestion", + "sub_question_index": idx, + "sub_question_text": sub_q_text, + }) generator_time_ms = int((time.perf_counter() - stage_start) * 1000) - logger.info("Answer generated: %d chars, %d sources", len(answer), len(filtered)) + logger.info( + "Answer generated: %d chars, %d sub-questions", + len(answer), len(extracted_questions), + ) total_time_ms = int((time.perf_counter() - overall_start) * 1000) - sources = [ - SourceMetadata( - filename=meta.get("filename", "unknown"), - upload_date=meta.get("upload_date", ""), - content_summary=meta.get("content_summary", ""), - chunk_index=meta.get("chunk_index", 0), - page_number=meta.get("page_number"), - chunk_file_path=meta.get("chunk_file_path"), - ) - for meta in chunk_metadata - ] + all_sources_flat = [] + for sq in sub_question_sources: + all_sources_flat.extend(sq.sources) + + sources_json = json.dumps([ + [s.model_dump() for s in sq.sources] + for sq in sub_question_sources + ]) _schedule_history(history_service, request, extracted_questions, decompose_prompt, decomposer_time_ms, retriever_time_ms, chunks_retrieved_count, chunks_retrieved, filter_prompt, filter_time_ms, chunks_filtered_count, chunks_filtered, generate_prompt, 
generator_time_ms, active_profile, - answer, json.dumps([s.model_dump() for s in sources]), - total_time_ms) + answer, sources_json, total_time_ms) yield _format_sse({ "phase": "completed", "answer": answer, - "sources": [s.model_dump() for s in sources], + "sub_question_sources": [sq.model_dump() for sq in sub_question_sources], + "sources": [s.model_dump() for s in all_sources_flat], }) except HTTPException: