import asyncio
import json
import logging
import time
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from app.core.config import get_settings
from app.models.query import QueryRequest, SubQuestionSources
from app.models.common import SourceMetadata
from app.services.history_service import HistoryService
from app.services.llm_client import LLMClient
from app.services.llm_client_dp import LLMClientDP
from app.services.prompt_service import PromptService
from app.services.query_decomposer import QueryDecomposer
from app.services.relevance_filter import RelevanceFilter
from app.services.rag import RAGService
# Fixed reply used whenever retrieval or relevance filtering leaves no chunks.
NO_RESULTS_ANSWER = "I could not find any relevant information to answer your question."

logger = logging.getLogger(__name__)
# Router carrying the streaming /query endpoint defined in this module.
router = APIRouter(tags=["query"])
def _format_sse(data: dict) -> str:
    """Serialise *data* as a single Server-Sent Events ``data:`` frame.

    The payload is JSON-encoded and followed by the blank line that
    terminates an SSE event.
    """
    payload = json.dumps(data)
    return "data: " + payload + "\n\n"
def format_chunks_retrieved_xml(chunks: list) -> str:
    """Render retrieved chunks as newline-joined text blocks, one per chunk.

    Args:
        chunks: ``[(text, metadata, distance), ...]`` as returned by
            ``RAGService.retrieve()``. The distance is ignored.

    Returns:
        For each chunk, a block of lines: an opening marker line, a
        ``Filename:`` line (``'unknown'`` when missing), a ``Page:`` line only
        when ``page_number`` is present and not ``None``, a ``Content:`` line,
        and a closing marker line. Blocks are joined with ``"\\n"``; an empty
        input gives ``""``.
    """
    # NOTE(review): the opening/closing marker lines are empty strings here
    # even though this is described as XML-tagged output — confirm the tag
    # text is intended to be empty.
    blocks = []
    for text, meta, _ in chunks:
        block_lines = ["", f"Filename: {meta.get('filename', 'unknown')}"]
        page = meta.get("page_number")
        if page is not None:
            block_lines.append(f"Page: {page}")
        block_lines += [f"Content: {text}", ""]
        blocks.append("\n".join(block_lines))
    return "\n".join(blocks)
def format_chunks_retrieved_per_subq(results: list) -> str:
    """Render per-sub-question retrieval results, grouped by sub-question.

    Args:
        results: ``[(sub_question, [(text, metadata, distance), ...]), ...]``.
            Distances are ignored.

    Returns:
        ``""`` for empty *results*. Otherwise, for each sub-question: an
        opening marker line, one space-indented block per chunk (marker,
        ``Filename:``, optional ``Page:`` when ``page_number`` is not
        ``None``, ``Content:``, marker), then a closing marker line — all
        joined with ``"\\n"``.
    """
    # NOTE(review): the per-sub-question and per-chunk marker strings are
    # empty/blank, so the sub-question text itself is never rendered —
    # confirm this is intentional given the "XML with sub_q wrappers" intent.
    if not results:
        return ""
    out = []
    for _sub_question, chunks in results:
        out.append("")
        for text, meta, _ in chunks:
            chunk_lines = [" ", f" Filename: {meta.get('filename', 'unknown')}"]
            page = meta.get("page_number")
            if page is not None:
                chunk_lines.append(f" Page: {page}")
            chunk_lines.extend([f" Content: {text}", " "])
            out.append("\n".join(chunk_lines))
        out.append("")
    return "\n".join(out)
def format_chunks_filtered_xml(filtered: list) -> str:
    """Render relevance-filtered chunks as newline-joined text blocks.

    Args:
        filtered: ``[(text, metadata), ...]``; the relevance score, if any,
            is read from ``metadata["relevance_score"]``.

    Returns:
        One block per chunk: opening marker line, ``Filename:``
        (``'unknown'`` when missing), ``Page:`` only when ``page_number`` is
        not ``None``, ``Relevance:`` only when ``relevance_score`` is not
        ``None``, ``Content:``, closing marker line. Blocks are joined with
        ``"\\n"``; an empty input gives ``""``.
    """
    # NOTE(review): marker lines are empty strings despite the "XML-tagged"
    # description — confirm.
    rendered = []
    for text, meta in filtered:
        fields = ["", f"Filename: {meta.get('filename', 'unknown')}"]
        page = meta.get("page_number")
        if page is not None:
            fields.append(f"Page: {page}")
        score = meta.get("relevance_score")
        if score is not None:
            fields.append(f"Relevance: {score}")
        fields.append(f"Content: {text}")
        fields.append("")
        rendered.append("\n".join(fields))
    return "\n".join(rendered)
def format_chunks_filtered_per_subq(results: list) -> str:
    """Render per-sub-question filtered chunks, grouped by sub-question.

    Args:
        results: ``[(sub_question, [(text, metadata), ...]), ...]`` as
            produced by ``filter_per_subquestion()``; each chunk's score is
            read from ``metadata["relevance_score"]``.

    Returns:
        ``""`` for empty *results*. Otherwise, for each sub-question: an
        opening marker line, one space-indented block per chunk (marker,
        ``Filename:``, optional ``Page:`` when ``page_number`` is not
        ``None``, ``Relevance:`` — always present, ``N/A`` when the key is
        missing — ``Content:``, marker), then a closing marker line, all
        joined with ``"\\n"``.
    """
    # NOTE(review): the wrapper/marker strings are empty or blank, so the
    # sub-question text is never rendered — confirm this is intended.
    if not results:
        return ""
    pieces = []
    for _sub_question, filtered_chunks in results:
        pieces.append("")
        for text, meta in filtered_chunks:
            score = meta.get("relevance_score", "N/A")
            entry = [" ", f" Filename: {meta.get('filename', 'unknown')}"]
            page = meta.get("page_number")
            if page is not None:
                entry.append(f" Page: {page}")
            entry += [f" Relevance: {score}", f" Content: {text}", " "]
            pieces.append("\n".join(entry))
        pieces.append("")
    return "\n".join(pieces)
async def _record_history(history_service, input_text, extracted_questions,
                          decompose_prompt, decomposer_time_ms, retriever_time_ms,
                          chunks_retrieved_count, chunks_retrieved, filter_prompt,
                          filter_time_ms, chunks_filtered_count, chunks_filtered,
                          generate_prompt, generator_time_ms, profile_used,
                          final_answer, sources, total_time_ms) -> int:
    """Persist one query's pipeline trace through ``history_service.record``.

    ``extracted_questions`` is JSON-encoded when it is a list and passed
    through unchanged otherwise. ``record`` is called synchronously (it is
    not awaited), so it runs on the event-loop thread.

    Returns:
        The value returned by ``history_service.record`` (the history record
        ID), or ``-1`` when building the row or the ``record`` call raises;
        such failures are logged as a warning, never propagated.
    """
    try:
        if isinstance(extracted_questions, list):
            questions_field = json.dumps(extracted_questions)
        else:
            questions_field = extracted_questions
        row = {
            "input_text": input_text,
            "extracted_questions": questions_field,
            "decompose_prompt": decompose_prompt,
            "decomposer_time_ms": decomposer_time_ms,
            "retriever_time_ms": retriever_time_ms,
            "chunks_retrieved": chunks_retrieved,
            "chunks_retrieved_count": chunks_retrieved_count,
            "filter_prompt": filter_prompt,
            "filter_time_ms": filter_time_ms,
            "chunks_filtered": chunks_filtered,
            "chunks_filtered_count": chunks_filtered_count,
            "generate_prompt": generate_prompt,
            "generator_time_ms": generator_time_ms,
            "total_time_ms": total_time_ms,
            "final_answer": final_answer,
            "sources": sources,
            "profile_used": profile_used,
        }
        return history_service.record(row)
    except Exception:
        logger.warning("History recording failed", exc_info=True)
        return -1
# Strong references to in-flight background history tasks. The event loop
# keeps only weak references to tasks, so an otherwise-unreferenced
# fire-and-forget task may be garbage-collected before it finishes.
_background_history_tasks = set()


def _schedule_history(history_service, request, extracted_questions,
                      decompose_prompt, decomposer_time_ms, retriever_time_ms,
                      chunks_retrieved_count, chunks_retrieved, filter_prompt,
                      filter_time_ms, chunks_filtered_count, chunks_filtered,
                      generate_prompt, generator_time_ms, active_profile,
                      final_answer, sources_json, total_time_ms):
    """Schedule ``_record_history`` as a background task without awaiting it.

    Takes the same fields as ``_record_history``, except the input text is
    read from ``request.question``. Any failure while scheduling (for
    example, no running event loop) is logged as a warning and swallowed,
    so history bookkeeping never interrupts the caller's SSE stream.

    Note: the task still runs on the same event loop, and ``_record_history``
    calls ``history_service.record`` synchronously.
    """
    coro = None
    try:
        coro = _record_history(
            history_service, request.question, extracted_questions,
            decompose_prompt, decomposer_time_ms, retriever_time_ms,
            chunks_retrieved_count, chunks_retrieved, filter_prompt,
            filter_time_ms, chunks_filtered_count, chunks_filtered,
            generate_prompt, generator_time_ms, active_profile,
            final_answer, sources_json, total_time_ms
        )
        task = asyncio.create_task(coro)
    except Exception:
        if coro is not None:
            # The coroutine was never scheduled; close it so Python does not
            # emit a "coroutine ... was never awaited" RuntimeWarning.
            coro.close()
        logger.warning("Failed to schedule history recording", exc_info=True)
        return
    _background_history_tasks.add(task)
    # Release the reference once the task completes so the set stays small.
    task.add_done_callback(_background_history_tasks.discard)
def _elapsed_ms(start: float) -> int:
    """Return whole milliseconds elapsed since the ``time.perf_counter()`` reading *start*."""
    return int((time.perf_counter() - start) * 1000)


def _split_result_and_prompt(result):
    """Normalise a pipeline-stage result to ``(value, prompt)``.

    A tuple result is unpacked as ``(value, prompt)``; any other result is
    treated as the bare value with an empty prompt.
    """
    if isinstance(result, tuple):
        value, prompt = result
        return value, prompt
    return result, ""


async def _query_stream(request: QueryRequest):
    """Answer ``request.question`` via a four-stage pipeline, streamed as SSE.

    Stages: decompose the question into sub-questions, retrieve chunks per
    sub-question, filter them for relevance, then generate the answer. Each
    stage announces itself with a ``phase`` frame *before* its work starts.
    Frames yielded (as ``_format_sse`` strings) have ``phase`` set to
    ``decomposed``, ``retrieving``, ``filtering``, ``generating``,
    ``generating_subquestion``, ``completed``, or ``error``.

    Early exits, each recorded to history in the background:
      * ``request.stop_after_decompose`` — completes with only the
        sub-questions (``half_question: True``);
      * no chunks retrieved, or none survive filtering — completes with
        ``NO_RESULTS_ANSWER``.

    On the full path the history record is written inline because its ID is
    returned to the client as ``history_id``. Any other exception is logged
    with its traceback and reported as a final ``error`` frame;
    ``HTTPException`` is re-raised unchanged.
    """
    try:
        # Set-up lives inside the try so configuration/DB failures reach the
        # client as an "error" frame instead of aborting the open stream.
        settings = get_settings()
        prompt_service = PromptService(db_path=settings.prompts_db_path)
        overall_start = time.perf_counter()
        history_service = HistoryService(db_path=settings.history_db_path)
        llm_client_dp = LLMClientDP(settings)
        llm_client = LLMClient(settings)
        rag = RAGService(llm_client=llm_client, settings=settings, prompt_service=prompt_service)
        active_profile = prompt_service.get_active_profile_name()
        logger.info("Query: %s. Active prompt profile: %s", request.question, active_profile)
        decomposer = QueryDecomposer(llm_client_dp, prompt_service=prompt_service)

        # Stage 1: Decompose
        stage_start = time.perf_counter()
        extracted_questions, decompose_prompt = _split_result_and_prompt(
            await decomposer.decompose(request.question)
        )
        decomposer_time_ms = _elapsed_ms(stage_start)
        logger.info("Extracted questions: %s", extracted_questions)
        if not extracted_questions:
            # Fall back to the raw question; from here on the list is never empty.
            extracted_questions = [request.question]
        yield _format_sse({
            "phase": "decomposed",
            "extracted_questions": extracted_questions,
        })

        # Half-question mode: stop after decomposition
        if request.stop_after_decompose:
            _schedule_history(
                history_service, request, extracted_questions,
                decompose_prompt=decompose_prompt,
                decomposer_time_ms=decomposer_time_ms,
                retriever_time_ms=0,
                chunks_retrieved_count=0,
                chunks_retrieved="",
                filter_prompt="",
                filter_time_ms=0,
                chunks_filtered_count=0,
                chunks_filtered="",
                generate_prompt="",
                generator_time_ms=0,
                active_profile=active_profile,
                final_answer="",
                sources_json="[]",
                total_time_ms=_elapsed_ms(overall_start),
            )
            yield _format_sse({
                "phase": "completed",
                "answer": "",
                "half_question": True,
                "extracted_questions": extracted_questions,
                "sources": [],
            })
            return

        # Stage 2: Retrieve (per sub-question). Announce the phase before the
        # work starts, as the filtering/generating phases do. The timer starts
        # after the yield so it measures retrieval only.
        yield _format_sse({"phase": "retrieving"})
        stage_start = time.perf_counter()
        retrieval_results = rag.retrieve_per_subquestion(
            extracted_questions, n_results=settings.retrieval_n_results
        )
        retriever_time_ms = _elapsed_ms(stage_start)
        chunks_retrieved_count = sum(len(chunks) for _sub_q, chunks in retrieval_results)
        chunks_retrieved = format_chunks_retrieved_per_subq(retrieval_results)
        if not chunks_retrieved_count:
            _schedule_history(
                history_service, request, extracted_questions,
                decompose_prompt=decompose_prompt,
                decomposer_time_ms=decomposer_time_ms,
                # Retrieval did run; record how long it took.
                retriever_time_ms=retriever_time_ms,
                chunks_retrieved_count=0,
                chunks_retrieved="",
                filter_prompt="",
                filter_time_ms=0,
                chunks_filtered_count=0,
                chunks_filtered="",
                generate_prompt="",
                generator_time_ms=0,
                active_profile=active_profile,
                final_answer=NO_RESULTS_ANSWER,
                sources_json="[]",
                total_time_ms=_elapsed_ms(overall_start),
            )
            yield _format_sse({
                "phase": "completed",
                "answer": NO_RESULTS_ANSWER,
                "sources": [],
            })
            return

        # Stage 3: Filter (per sub-question — single LLM call)
        yield _format_sse({"phase": "filtering"})
        stage_start = time.perf_counter()
        # Drop distances; the chunks for sub-question i stay at index i.
        chunks_by_subq = [
            [(text, meta) for text, meta, _dist in chunks]
            for _sub_q, chunks in retrieval_results
        ]
        relevance_filter = RelevanceFilter(llm_client, prompt_service=prompt_service)
        filtered_by_subq, filter_prompt = _split_result_and_prompt(
            await relevance_filter.filter_per_subquestion(
                extracted_questions, chunks_by_subq, threshold=settings.relevance_threshold
            )
        )
        filter_time_ms = _elapsed_ms(stage_start)
        chunks_filtered_count = sum(
            len(filtered_chunks) for _sub_q, filtered_chunks in filtered_by_subq
        )
        chunks_filtered = format_chunks_filtered_per_subq(filtered_by_subq) if filtered_by_subq else ""
        if not chunks_filtered_count:
            _schedule_history(
                history_service, request, extracted_questions,
                decompose_prompt=decompose_prompt,
                decomposer_time_ms=decomposer_time_ms,
                retriever_time_ms=retriever_time_ms,
                chunks_retrieved_count=chunks_retrieved_count,
                chunks_retrieved=chunks_retrieved,
                filter_prompt=filter_prompt,
                filter_time_ms=filter_time_ms,
                chunks_filtered_count=0,
                chunks_filtered="",
                generate_prompt="",
                generator_time_ms=0,
                active_profile=active_profile,
                final_answer=NO_RESULTS_ANSWER,
                sources_json="[]",
                total_time_ms=_elapsed_ms(overall_start),
            )
            yield _format_sse({
                "phase": "completed",
                "answer": NO_RESULTS_ANSWER,
                "sources": [],
            })
            return

        # Stage 4: Generate. The per-sub-question "generating_subquestion"
        # frames are emitted after generation returns, one per sub-question.
        yield _format_sse({"phase": "generating"})
        stage_start = time.perf_counter()
        sub_chunk_texts = [
            [chunk for chunk, _meta in filtered_chunks]
            for _sub_q, filtered_chunks in filtered_by_subq
        ]
        sub_chunk_metadata = [
            [meta for _chunk, meta in filtered_chunks]
            for _sub_q, filtered_chunks in filtered_by_subq
        ]
        gen_result = await rag.generate_response_per_subquestion(
            extracted_questions,
            sub_chunk_texts,
            sub_chunk_metadata,
        )
        # Accept (answer, prompt, grouped_sources), (answer, prompt) or answer.
        if isinstance(gen_result, tuple) and len(gen_result) == 3:
            answer, generate_prompt, grouped_sources_meta = gen_result
        else:
            answer, generate_prompt = _split_result_and_prompt(gen_result)
            grouped_sources_meta = []
        generator_time_ms = _elapsed_ms(stage_start)

        sub_question_sources = []
        for idx, (sub_q_text, sources_meta) in enumerate(
            zip(extracted_questions, grouped_sources_meta)
        ):
            sources = [
                SourceMetadata(
                    filename=meta.get("filename", "unknown"),
                    upload_date=meta.get("upload_date", ""),
                    content_summary=meta.get("content_summary", ""),
                    chunk_index=meta.get("chunk_index", 0),
                    page_number=meta.get("page_number"),
                    chunk_file_path=meta.get("chunk_file_path"),
                    document_id=meta.get("document_id"),
                )
                for meta in sources_meta
            ]
            sub_question_sources.append(
                SubQuestionSources(
                    sub_question_index=idx,
                    sub_question_text=sub_q_text,
                    sources=sources,
                )
            )
            yield _format_sse({
                "phase": "generating_subquestion",
                "sub_question_index": idx,
                "sub_question_text": sub_q_text,
            })
        logger.info(
            "Answer generated: %d chars, %d sub-questions",
            len(answer), len(extracted_questions),
        )
        total_time_ms = _elapsed_ms(overall_start)
        all_sources_flat = [s for sq in sub_question_sources for s in sq.sources]
        sources_json = json.dumps([
            [s.model_dump() for s in sq.sources]
            for sq in sub_question_sources
        ])
        # Recorded inline (not scheduled) because the ID is sent to the client.
        history_id = await _record_history(
            history_service, request.question, extracted_questions,
            decompose_prompt=decompose_prompt,
            decomposer_time_ms=decomposer_time_ms,
            retriever_time_ms=retriever_time_ms,
            chunks_retrieved_count=chunks_retrieved_count,
            chunks_retrieved=chunks_retrieved,
            filter_prompt=filter_prompt,
            filter_time_ms=filter_time_ms,
            chunks_filtered_count=chunks_filtered_count,
            chunks_filtered=chunks_filtered,
            generate_prompt=generate_prompt,
            generator_time_ms=generator_time_ms,
            profile_used=active_profile,
            final_answer=answer,
            sources=sources_json,
            total_time_ms=total_time_ms,
        )
        yield _format_sse({
            "phase": "completed",
            "answer": answer,
            "sub_question_sources": [sq.model_dump() for sq in sub_question_sources],
            "sources": [s.model_dump() for s in all_sources_flat],
            "history_id": history_id,
        })
    except HTTPException:
        raise
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception("Query stream failed: %s", str(e))
        yield _format_sse({
            "phase": "error",
            "message": f"Query failed: {str(e)}",
        })
@router.post("/query")
async def query(request: QueryRequest):
    """Validate the question, then stream the query pipeline as SSE.

    Raises:
        HTTPException: 400 when ``request.question`` is missing or blank.
    """
    question_text = (request.question or "").strip()
    if not question_text:
        raise HTTPException(status_code=400, detail="Question is required")
    sse_headers = {
        "Cache-Control": "no-cache",
        # Ask nginx-style reverse proxies not to buffer the event stream.
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(
        _query_stream(request),
        media_type="text/event-stream",
        headers=sse_headers,
    )