diff --git a/backend/app/routers/query.py b/backend/app/routers/query.py
index 72e1dac..3374dd4 100644
--- a/backend/app/routers/query.py
+++ b/backend/app/routers/query.py
@@ -7,7 +7,7 @@ from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from app.core.config import get_settings
-from app.models.query import QueryRequest
+from app.models.query import QueryRequest, SubQuestionSources
from app.models.common import SourceMetadata
from app.services.history_service import HistoryService
from app.services.llm_client import LLMClient
@@ -43,6 +43,27 @@ def format_chunks_retrieved_xml(chunks: list) -> str:
return "\n".join(parts)
+def format_chunks_retrieved_per_subq(results: list) -> str:
+ """Format per-sub-question retrieved chunks as XML with sub_q wrappers."""
+ if not results:
+ return ""
+
+ parts = []
+ for q_idx, (sub_question, chunks) in enumerate(results):
+ parts.append(f'')
+ for i, (text, meta, _dist) in enumerate(chunks, start=1):
+ lines = [f" "]
+ lines.append(f" Filename: {meta.get('filename', 'unknown')}")
+ page = meta.get("page_number")
+ if page is not None:
+ lines.append(f" Page: {page}")
+ lines.append(f" Content: {text}")
+ lines.append(f" ")
+ parts.append("\n".join(lines))
+ parts.append("")
+ return "\n".join(parts)
+
+
def format_chunks_filtered_xml(filtered: list) -> str:
"""Format filtered chunks as XML-tagged string with relevance scores.
filtered = [(text, meta), ...] — score embedded in meta["relevance_score"]
@@ -63,6 +84,37 @@ def format_chunks_filtered_xml(filtered: list) -> str:
return "\n".join(parts)
+def format_chunks_filtered_per_subq(results: list) -> str:
+ """Format per-sub-question filtered chunks as XML with sub_q wrappers.
+
+ Args:
+ results: List of (sub_question, filtered_chunks) from filter_per_subquestion().
+ Each filtered_chunks is [(text, meta), ...] with relevance_score in meta.
+
+ Returns:
+ XML string with wrappers containing elements with Relevance scores.
+ """
+ if not results:
+ return ""
+
+ parts = []
+ for q_idx, (sub_question, filtered_chunks) in enumerate(results):
+ parts.append(f'')
+ for i, (text, meta) in enumerate(filtered_chunks, start=1):
+ score = meta.get("relevance_score", "N/A")
+ lines = [f" "]
+ lines.append(f" Filename: {meta.get('filename', 'unknown')}")
+ page = meta.get("page_number")
+ if page is not None:
+ lines.append(f" Page: {page}")
+ lines.append(f" Relevance: {score}")
+ lines.append(f" Content: {text}")
+ lines.append(f" ")
+ parts.append("\n".join(lines))
+ parts.append("")
+ return "\n".join(parts)
+
+
async def _record_history(history_service, input_text, extracted_questions,
decompose_prompt, decomposer_time_ms, retriever_time_ms,
chunks_retrieved_count, chunks_retrieved, filter_prompt,
@@ -142,21 +194,32 @@ async def _query_stream(request: QueryRequest):
decomposer_time_ms = int((time.perf_counter() - stage_start) * 1000)
logger.info("Extracted questions: %s", extracted_questions)
+ if not extracted_questions:
+ extracted_questions = [request.question]
+
yield _format_sse({
"phase": "decomposed",
"extracted_questions": extracted_questions,
})
- # Stage 2: Retrieve
+ # Stage 2: Retrieve (per sub-question)
stage_start = time.perf_counter()
- chunks = rag.retrieve(extracted_questions, n_results=settings.retrieval_n_results)
+ retrieval_results = rag.retrieve_per_subquestion(
+ extracted_questions, n_results=settings.retrieval_n_results
+ ) if extracted_questions else []
retriever_time_ms = int((time.perf_counter() - stage_start) * 1000)
- chunks_retrieved_count = len(chunks)
- chunks_retrieved = format_chunks_retrieved_xml(chunks)
+
+ all_chunks_flat = []
+ for _sub_q, chunks in retrieval_results:
+ for text, meta, _dist in chunks:
+ all_chunks_flat.append((text, meta, _dist))
+
+ chunks_retrieved_count = len(all_chunks_flat)
+ chunks_retrieved = format_chunks_retrieved_per_subq(retrieval_results)
yield _format_sse({"phase": "retrieving"})
- if not chunks:
+ if not all_chunks_flat:
_schedule_history(history_service, request, extracted_questions,
decompose_prompt, decomposer_time_ms, 0, 0, "", "",
0, 0, "", "", 0, active_profile, NO_RESULTS_ANSWER,
@@ -168,25 +231,37 @@ async def _query_stream(request: QueryRequest):
})
return
- # Stage 3: Filter
- chunks_for_filter = [(text, meta) for text, meta, _dist in chunks]
+ # Stage 3: Filter (per sub-question — single LLM call)
+ stage_start = time.perf_counter()
+ chunks_by_subq = []
+ for _sub_q, chunks in retrieval_results:
+ chunks_by_subq.append([(text, meta) for text, meta, _dist in chunks])
+
relevance_filter = RelevanceFilter(llm_client, prompt_service=prompt_service)
yield _format_sse({"phase": "filtering"})
- filter_result = await relevance_filter.filter(
- request.question, chunks_for_filter, threshold=settings.relevance_threshold
- )
- if isinstance(filter_result, tuple):
- filtered, filter_prompt = filter_result
+ if extracted_questions and chunks_by_subq:
+ filter_result = await relevance_filter.filter_per_subquestion(
+ extracted_questions, chunks_by_subq, threshold=settings.relevance_threshold
+ )
else:
- filtered, filter_prompt = filter_result, ""
+ filter_result = ([], "")
+
+ if isinstance(filter_result, tuple):
+ filtered_by_subq, filter_prompt = filter_result
+ else:
+ filtered_by_subq, filter_prompt = filter_result, ""
+
+ all_filtered_flat = []
+ for _sub_q, filtered_chunks in filtered_by_subq:
+ all_filtered_flat.extend(filtered_chunks)
filter_time_ms = int((time.perf_counter() - stage_start) * 1000)
- chunks_filtered_count = len(filtered)
- chunks_filtered = format_chunks_filtered_xml(filtered)
+ chunks_filtered_count = len(all_filtered_flat)
+ chunks_filtered = format_chunks_filtered_per_subq(filtered_by_subq) if filtered_by_subq else ""
- if not filtered:
+ if not all_filtered_flat:
_schedule_history(history_service, request, extracted_questions,
decompose_prompt, decomposer_time_ms, retriever_time_ms,
chunks_retrieved_count, chunks_retrieved, filter_prompt,
@@ -200,48 +275,91 @@ async def _query_stream(request: QueryRequest):
})
return
- # Stage 4: Generate
+ # Stage 4: Generate (per sub-question with progressive streaming)
stage_start = time.perf_counter()
- chunk_texts = [chunk for chunk, _meta in filtered]
- chunk_metadata = [meta for _chunk, meta in filtered]
+
+ sub_chunk_texts = []
+ sub_chunk_metadata = []
+ for _sub_q, filtered_chunks in filtered_by_subq:
+ texts = [chunk for chunk, _meta in filtered_chunks]
+ metas = [meta for _chunk, meta in filtered_chunks]
+ sub_chunk_texts.append(texts)
+ sub_chunk_metadata.append(metas)
yield _format_sse({"phase": "generating"})
- gen_result = await rag.generate_response(request.question, chunk_texts, chunk_metadata)
- if isinstance(gen_result, tuple):
- answer, generate_prompt = gen_result
+ if extracted_questions and filtered_by_subq:
+ gen_result = await rag.generate_response_per_subquestion(
+ extracted_questions,
+ sub_chunk_texts,
+ sub_chunk_metadata,
+ )
else:
- answer, generate_prompt = gen_result, ""
+ gen_result = ("", "", [])
+
+ if isinstance(gen_result, tuple) and len(gen_result) == 3:
+ answer, generate_prompt, grouped_sources_meta = gen_result
+ else:
+ answer, generate_prompt = gen_result if isinstance(gen_result, tuple) else (gen_result, "")
+ grouped_sources_meta = []
+
+ sub_question_sources = []
+ for idx, (sub_q_text, sources_meta) in enumerate(
+ zip(extracted_questions, grouped_sources_meta)
+ ):
+ sources = [
+ SourceMetadata(
+ filename=meta.get("filename", "unknown"),
+ upload_date=meta.get("upload_date", ""),
+ content_summary=meta.get("content_summary", ""),
+ chunk_index=meta.get("chunk_index", 0),
+ page_number=meta.get("page_number"),
+ chunk_file_path=meta.get("chunk_file_path"),
+ )
+ for meta in sources_meta
+ ]
+ sub_question_sources.append(
+ SubQuestionSources(
+ sub_question_index=idx,
+ sub_question_text=sub_q_text,
+ sources=sources,
+ )
+ )
+ yield _format_sse({
+ "phase": "generating_subquestion",
+ "sub_question_index": idx,
+ "sub_question_text": sub_q_text,
+ })
generator_time_ms = int((time.perf_counter() - stage_start) * 1000)
- logger.info("Answer generated: %d chars, %d sources", len(answer), len(filtered))
+ logger.info(
+ "Answer generated: %d chars, %d sub-questions",
+ len(answer), len(extracted_questions),
+ )
total_time_ms = int((time.perf_counter() - overall_start) * 1000)
- sources = [
- SourceMetadata(
- filename=meta.get("filename", "unknown"),
- upload_date=meta.get("upload_date", ""),
- content_summary=meta.get("content_summary", ""),
- chunk_index=meta.get("chunk_index", 0),
- page_number=meta.get("page_number"),
- chunk_file_path=meta.get("chunk_file_path"),
- )
- for meta in chunk_metadata
- ]
+ all_sources_flat = []
+ for sq in sub_question_sources:
+ all_sources_flat.extend(sq.sources)
+
+ sources_json = json.dumps([
+ [s.model_dump() for s in sq.sources]
+ for sq in sub_question_sources
+ ])
_schedule_history(history_service, request, extracted_questions,
decompose_prompt, decomposer_time_ms, retriever_time_ms,
chunks_retrieved_count, chunks_retrieved, filter_prompt,
filter_time_ms, chunks_filtered_count, chunks_filtered,
generate_prompt, generator_time_ms, active_profile,
- answer, json.dumps([s.model_dump() for s in sources]),
- total_time_ms)
+ answer, sources_json, total_time_ms)
yield _format_sse({
"phase": "completed",
"answer": answer,
- "sources": [s.model_dump() for s in sources],
+ "sub_question_sources": [sq.model_dump() for sq in sub_question_sources],
+ "sources": [s.model_dump() for s in all_sources_flat],
})
except HTTPException: