feat(query): refactor pipeline for per-sub-question flow with progressive SSE
Restructure _query_stream() to use per-sub-question retrieval, filtering, and generation. Add generative_subquestion SSE events for progressive frontend rendering. Add format_chunks_retrieved_per_subq() and format_chunks_filtered_per_subq() with <sub_q> XML wrappers. Add empty decomposition fallback using original question as single sub-q. Update history recording for grouped sources JSON (list-of-lists format). Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
This commit is contained in:
parent
57a130dc96
commit
666b603639
|
|
@ -7,7 +7,7 @@ from fastapi import APIRouter, HTTPException
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
from app.core.config import get_settings
|
from app.core.config import get_settings
|
||||||
from app.models.query import QueryRequest
|
from app.models.query import QueryRequest, SubQuestionSources
|
||||||
from app.models.common import SourceMetadata
|
from app.models.common import SourceMetadata
|
||||||
from app.services.history_service import HistoryService
|
from app.services.history_service import HistoryService
|
||||||
from app.services.llm_client import LLMClient
|
from app.services.llm_client import LLMClient
|
||||||
|
|
@ -43,6 +43,27 @@ def format_chunks_retrieved_xml(chunks: list) -> str:
|
||||||
return "\n".join(parts)
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
def format_chunks_retrieved_per_subq(results: list) -> str:
|
||||||
|
"""Format per-sub-question retrieved chunks as XML with sub_q wrappers."""
|
||||||
|
if not results:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
parts = []
|
||||||
|
for q_idx, (sub_question, chunks) in enumerate(results):
|
||||||
|
parts.append(f'<sub_q idx="{q_idx}" question="{sub_question}">')
|
||||||
|
for i, (text, meta, _dist) in enumerate(chunks, start=1):
|
||||||
|
lines = [f" <chunk_{i}>"]
|
||||||
|
lines.append(f" Filename: {meta.get('filename', 'unknown')}")
|
||||||
|
page = meta.get("page_number")
|
||||||
|
if page is not None:
|
||||||
|
lines.append(f" Page: {page}")
|
||||||
|
lines.append(f" Content: {text}")
|
||||||
|
lines.append(f" </chunk_{i}>")
|
||||||
|
parts.append("\n".join(lines))
|
||||||
|
parts.append("</sub_q>")
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
|
||||||
def format_chunks_filtered_xml(filtered: list) -> str:
|
def format_chunks_filtered_xml(filtered: list) -> str:
|
||||||
"""Format filtered chunks as XML-tagged string with relevance scores.
|
"""Format filtered chunks as XML-tagged string with relevance scores.
|
||||||
filtered = [(text, meta), ...] — score embedded in meta["relevance_score"]
|
filtered = [(text, meta), ...] — score embedded in meta["relevance_score"]
|
||||||
|
|
@ -63,6 +84,37 @@ def format_chunks_filtered_xml(filtered: list) -> str:
|
||||||
return "\n".join(parts)
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
def format_chunks_filtered_per_subq(results: list) -> str:
|
||||||
|
"""Format per-sub-question filtered chunks as XML with sub_q wrappers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results: List of (sub_question, filtered_chunks) from filter_per_subquestion().
|
||||||
|
Each filtered_chunks is [(text, meta), ...] with relevance_score in meta.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
XML string with <sub_q> wrappers containing <chunk_N> elements with Relevance scores.
|
||||||
|
"""
|
||||||
|
if not results:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
parts = []
|
||||||
|
for q_idx, (sub_question, filtered_chunks) in enumerate(results):
|
||||||
|
parts.append(f'<sub_q idx="{q_idx}" question="{sub_question}">')
|
||||||
|
for i, (text, meta) in enumerate(filtered_chunks, start=1):
|
||||||
|
score = meta.get("relevance_score", "N/A")
|
||||||
|
lines = [f" <chunk_{i}>"]
|
||||||
|
lines.append(f" Filename: {meta.get('filename', 'unknown')}")
|
||||||
|
page = meta.get("page_number")
|
||||||
|
if page is not None:
|
||||||
|
lines.append(f" Page: {page}")
|
||||||
|
lines.append(f" Relevance: {score}")
|
||||||
|
lines.append(f" Content: {text}")
|
||||||
|
lines.append(f" </chunk_{i}>")
|
||||||
|
parts.append("\n".join(lines))
|
||||||
|
parts.append("</sub_q>")
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
|
||||||
async def _record_history(history_service, input_text, extracted_questions,
|
async def _record_history(history_service, input_text, extracted_questions,
|
||||||
decompose_prompt, decomposer_time_ms, retriever_time_ms,
|
decompose_prompt, decomposer_time_ms, retriever_time_ms,
|
||||||
chunks_retrieved_count, chunks_retrieved, filter_prompt,
|
chunks_retrieved_count, chunks_retrieved, filter_prompt,
|
||||||
|
|
@ -142,21 +194,32 @@ async def _query_stream(request: QueryRequest):
|
||||||
decomposer_time_ms = int((time.perf_counter() - stage_start) * 1000)
|
decomposer_time_ms = int((time.perf_counter() - stage_start) * 1000)
|
||||||
logger.info("Extracted questions: %s", extracted_questions)
|
logger.info("Extracted questions: %s", extracted_questions)
|
||||||
|
|
||||||
|
if not extracted_questions:
|
||||||
|
extracted_questions = [request.question]
|
||||||
|
|
||||||
yield _format_sse({
|
yield _format_sse({
|
||||||
"phase": "decomposed",
|
"phase": "decomposed",
|
||||||
"extracted_questions": extracted_questions,
|
"extracted_questions": extracted_questions,
|
||||||
})
|
})
|
||||||
|
|
||||||
# Stage 2: Retrieve
|
# Stage 2: Retrieve (per sub-question)
|
||||||
stage_start = time.perf_counter()
|
stage_start = time.perf_counter()
|
||||||
chunks = rag.retrieve(extracted_questions, n_results=settings.retrieval_n_results)
|
retrieval_results = rag.retrieve_per_subquestion(
|
||||||
|
extracted_questions, n_results=settings.retrieval_n_results
|
||||||
|
) if extracted_questions else []
|
||||||
retriever_time_ms = int((time.perf_counter() - stage_start) * 1000)
|
retriever_time_ms = int((time.perf_counter() - stage_start) * 1000)
|
||||||
chunks_retrieved_count = len(chunks)
|
|
||||||
chunks_retrieved = format_chunks_retrieved_xml(chunks)
|
all_chunks_flat = []
|
||||||
|
for _sub_q, chunks in retrieval_results:
|
||||||
|
for text, meta, _dist in chunks:
|
||||||
|
all_chunks_flat.append((text, meta, _dist))
|
||||||
|
|
||||||
|
chunks_retrieved_count = len(all_chunks_flat)
|
||||||
|
chunks_retrieved = format_chunks_retrieved_per_subq(retrieval_results)
|
||||||
|
|
||||||
yield _format_sse({"phase": "retrieving"})
|
yield _format_sse({"phase": "retrieving"})
|
||||||
|
|
||||||
if not chunks:
|
if not all_chunks_flat:
|
||||||
_schedule_history(history_service, request, extracted_questions,
|
_schedule_history(history_service, request, extracted_questions,
|
||||||
decompose_prompt, decomposer_time_ms, 0, 0, "", "",
|
decompose_prompt, decomposer_time_ms, 0, 0, "", "",
|
||||||
0, 0, "", "", 0, active_profile, NO_RESULTS_ANSWER,
|
0, 0, "", "", 0, active_profile, NO_RESULTS_ANSWER,
|
||||||
|
|
@ -168,25 +231,37 @@ async def _query_stream(request: QueryRequest):
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
|
|
||||||
# Stage 3: Filter
|
# Stage 3: Filter (per sub-question — single LLM call)
|
||||||
chunks_for_filter = [(text, meta) for text, meta, _dist in chunks]
|
stage_start = time.perf_counter()
|
||||||
|
chunks_by_subq = []
|
||||||
|
for _sub_q, chunks in retrieval_results:
|
||||||
|
chunks_by_subq.append([(text, meta) for text, meta, _dist in chunks])
|
||||||
|
|
||||||
relevance_filter = RelevanceFilter(llm_client, prompt_service=prompt_service)
|
relevance_filter = RelevanceFilter(llm_client, prompt_service=prompt_service)
|
||||||
|
|
||||||
yield _format_sse({"phase": "filtering"})
|
yield _format_sse({"phase": "filtering"})
|
||||||
|
|
||||||
filter_result = await relevance_filter.filter(
|
if extracted_questions and chunks_by_subq:
|
||||||
request.question, chunks_for_filter, threshold=settings.relevance_threshold
|
filter_result = await relevance_filter.filter_per_subquestion(
|
||||||
|
extracted_questions, chunks_by_subq, threshold=settings.relevance_threshold
|
||||||
)
|
)
|
||||||
if isinstance(filter_result, tuple):
|
|
||||||
filtered, filter_prompt = filter_result
|
|
||||||
else:
|
else:
|
||||||
filtered, filter_prompt = filter_result, ""
|
filter_result = ([], "")
|
||||||
|
|
||||||
|
if isinstance(filter_result, tuple):
|
||||||
|
filtered_by_subq, filter_prompt = filter_result
|
||||||
|
else:
|
||||||
|
filtered_by_subq, filter_prompt = filter_result, ""
|
||||||
|
|
||||||
|
all_filtered_flat = []
|
||||||
|
for _sub_q, filtered_chunks in filtered_by_subq:
|
||||||
|
all_filtered_flat.extend(filtered_chunks)
|
||||||
|
|
||||||
filter_time_ms = int((time.perf_counter() - stage_start) * 1000)
|
filter_time_ms = int((time.perf_counter() - stage_start) * 1000)
|
||||||
chunks_filtered_count = len(filtered)
|
chunks_filtered_count = len(all_filtered_flat)
|
||||||
chunks_filtered = format_chunks_filtered_xml(filtered)
|
chunks_filtered = format_chunks_filtered_per_subq(filtered_by_subq) if filtered_by_subq else ""
|
||||||
|
|
||||||
if not filtered:
|
if not all_filtered_flat:
|
||||||
_schedule_history(history_service, request, extracted_questions,
|
_schedule_history(history_service, request, extracted_questions,
|
||||||
decompose_prompt, decomposer_time_ms, retriever_time_ms,
|
decompose_prompt, decomposer_time_ms, retriever_time_ms,
|
||||||
chunks_retrieved_count, chunks_retrieved, filter_prompt,
|
chunks_retrieved_count, chunks_retrieved, filter_prompt,
|
||||||
|
|
@ -200,24 +275,38 @@ async def _query_stream(request: QueryRequest):
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
|
|
||||||
# Stage 4: Generate
|
# Stage 4: Generate (per sub-question with progressive streaming)
|
||||||
stage_start = time.perf_counter()
|
stage_start = time.perf_counter()
|
||||||
chunk_texts = [chunk for chunk, _meta in filtered]
|
|
||||||
chunk_metadata = [meta for _chunk, meta in filtered]
|
sub_chunk_texts = []
|
||||||
|
sub_chunk_metadata = []
|
||||||
|
for _sub_q, filtered_chunks in filtered_by_subq:
|
||||||
|
texts = [chunk for chunk, _meta in filtered_chunks]
|
||||||
|
metas = [meta for _chunk, meta in filtered_chunks]
|
||||||
|
sub_chunk_texts.append(texts)
|
||||||
|
sub_chunk_metadata.append(metas)
|
||||||
|
|
||||||
yield _format_sse({"phase": "generating"})
|
yield _format_sse({"phase": "generating"})
|
||||||
|
|
||||||
gen_result = await rag.generate_response(request.question, chunk_texts, chunk_metadata)
|
if extracted_questions and filtered_by_subq:
|
||||||
if isinstance(gen_result, tuple):
|
gen_result = await rag.generate_response_per_subquestion(
|
||||||
answer, generate_prompt = gen_result
|
extracted_questions,
|
||||||
|
sub_chunk_texts,
|
||||||
|
sub_chunk_metadata,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
answer, generate_prompt = gen_result, ""
|
gen_result = ("", "", [])
|
||||||
|
|
||||||
generator_time_ms = int((time.perf_counter() - stage_start) * 1000)
|
if isinstance(gen_result, tuple) and len(gen_result) == 3:
|
||||||
logger.info("Answer generated: %d chars, %d sources", len(answer), len(filtered))
|
answer, generate_prompt, grouped_sources_meta = gen_result
|
||||||
|
else:
|
||||||
total_time_ms = int((time.perf_counter() - overall_start) * 1000)
|
answer, generate_prompt = gen_result if isinstance(gen_result, tuple) else (gen_result, "")
|
||||||
|
grouped_sources_meta = []
|
||||||
|
|
||||||
|
sub_question_sources = []
|
||||||
|
for idx, (sub_q_text, sources_meta) in enumerate(
|
||||||
|
zip(extracted_questions, grouped_sources_meta)
|
||||||
|
):
|
||||||
sources = [
|
sources = [
|
||||||
SourceMetadata(
|
SourceMetadata(
|
||||||
filename=meta.get("filename", "unknown"),
|
filename=meta.get("filename", "unknown"),
|
||||||
|
|
@ -227,21 +316,50 @@ async def _query_stream(request: QueryRequest):
|
||||||
page_number=meta.get("page_number"),
|
page_number=meta.get("page_number"),
|
||||||
chunk_file_path=meta.get("chunk_file_path"),
|
chunk_file_path=meta.get("chunk_file_path"),
|
||||||
)
|
)
|
||||||
for meta in chunk_metadata
|
for meta in sources_meta
|
||||||
]
|
]
|
||||||
|
sub_question_sources.append(
|
||||||
|
SubQuestionSources(
|
||||||
|
sub_question_index=idx,
|
||||||
|
sub_question_text=sub_q_text,
|
||||||
|
sources=sources,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
yield _format_sse({
|
||||||
|
"phase": "generating_subquestion",
|
||||||
|
"sub_question_index": idx,
|
||||||
|
"sub_question_text": sub_q_text,
|
||||||
|
})
|
||||||
|
|
||||||
|
generator_time_ms = int((time.perf_counter() - stage_start) * 1000)
|
||||||
|
logger.info(
|
||||||
|
"Answer generated: %d chars, %d sub-questions",
|
||||||
|
len(answer), len(extracted_questions),
|
||||||
|
)
|
||||||
|
|
||||||
|
total_time_ms = int((time.perf_counter() - overall_start) * 1000)
|
||||||
|
|
||||||
|
all_sources_flat = []
|
||||||
|
for sq in sub_question_sources:
|
||||||
|
all_sources_flat.extend(sq.sources)
|
||||||
|
|
||||||
|
sources_json = json.dumps([
|
||||||
|
[s.model_dump() for s in sq.sources]
|
||||||
|
for sq in sub_question_sources
|
||||||
|
])
|
||||||
|
|
||||||
_schedule_history(history_service, request, extracted_questions,
|
_schedule_history(history_service, request, extracted_questions,
|
||||||
decompose_prompt, decomposer_time_ms, retriever_time_ms,
|
decompose_prompt, decomposer_time_ms, retriever_time_ms,
|
||||||
chunks_retrieved_count, chunks_retrieved, filter_prompt,
|
chunks_retrieved_count, chunks_retrieved, filter_prompt,
|
||||||
filter_time_ms, chunks_filtered_count, chunks_filtered,
|
filter_time_ms, chunks_filtered_count, chunks_filtered,
|
||||||
generate_prompt, generator_time_ms, active_profile,
|
generate_prompt, generator_time_ms, active_profile,
|
||||||
answer, json.dumps([s.model_dump() for s in sources]),
|
answer, sources_json, total_time_ms)
|
||||||
total_time_ms)
|
|
||||||
|
|
||||||
yield _format_sse({
|
yield _format_sse({
|
||||||
"phase": "completed",
|
"phase": "completed",
|
||||||
"answer": answer,
|
"answer": answer,
|
||||||
"sources": [s.model_dump() for s in sources],
|
"sub_question_sources": [sq.model_dump() for sq in sub_question_sources],
|
||||||
|
"sources": [s.model_dump() for s in all_sources_flat],
|
||||||
})
|
})
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue