import asyncio
import html
import json
import logging
import time

from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse

from app.core.config import get_settings
from app.models.query import QueryRequest, SubQuestionSources
from app.models.common import SourceMetadata
from app.services.history_service import HistoryService
from app.services.llm_client import LLMClient
from app.services.llm_client_dp import LLMClientDP
from app.services.prompt_service import PromptService
from app.services.query_decomposer import QueryDecomposer
from app.services.relevance_filter import RelevanceFilter
from app.services.rag import RAGService

logger = logging.getLogger(__name__)
router = APIRouter(tags=["query"])

NO_RESULTS_ANSWER = "I could not find any relevant information to answer your question."

# Indentation applied to <chunk_N> elements nested inside a <sub_q> wrapper.
_SUBQ_CHUNK_INDENT = "  "

# Strong references to fire-and-forget history tasks. The event loop keeps only
# weak references to tasks, so an otherwise-unreferenced task may be garbage
# collected before it finishes (see the asyncio.create_task documentation).
_background_tasks = set()


def _format_sse(data: dict) -> str:
    """Serialize *data* as one Server-Sent Events ``data:`` frame."""
    return f"data: {json.dumps(data)}\n\n"


def _elapsed_ms(start: float) -> int:
    """Return whole milliseconds elapsed since *start* (a ``time.perf_counter()`` value)."""
    return int((time.perf_counter() - start) * 1000)


def _format_chunk(index: int, text: str, meta: dict, *, indent: str = "", relevance=None) -> str:
    """Render one chunk as a ``<chunk_N>`` element, one field per line.

    Args:
        index: 1-based position used in the tag name.
        text: Chunk content.
        meta: Chunk metadata; ``filename`` and ``page_number`` are read.
        indent: Prefix applied to every emitted line.
        relevance: Emitted as a ``Relevance:`` line unless None.
    """
    lines = [
        f"{indent}<chunk_{index}>",
        f"{indent}Filename: {meta.get('filename', 'unknown')}",
    ]
    page = meta.get("page_number")
    if page is not None:
        lines.append(f"{indent}Page: {page}")
    if relevance is not None:
        lines.append(f"{indent}Relevance: {relevance}")
    lines.append(f"{indent}Content: {text}")
    lines.append(f"{indent}</chunk_{index}>")
    return "\n".join(lines)


def _open_sub_q(q_idx: int, sub_question) -> str:
    """Return the opening ``<sub_q>`` tag; the question is escaped for use as an attribute."""
    question = html.escape(str(sub_question), quote=True)
    return f'<sub_q index="{q_idx}" question="{question}">'


def format_chunks_retrieved_xml(chunks: list) -> str:
    """Format retrieved chunks as an XML-tagged string of ``<chunk_N>`` elements.

    chunks = [(text, metadata, distance), ...] from RAGService.retrieve()
    """
    return "\n".join(
        _format_chunk(i, text, meta)
        for i, (text, meta, _dist) in enumerate(chunks, start=1)
    )


def format_chunks_retrieved_per_subq(results: list) -> str:
    """Format per-sub-question retrieved chunks as XML with ``<sub_q>`` wrappers.

    Args:
        results: List of (sub_question, chunks); each chunk is (text, meta, distance).

    Returns:
        One ``<sub_q>`` element per sub-question (0-based ``index``), each wrapping
        its ``<chunk_N>`` elements; "" when *results* is empty or None.
    """
    if not results:
        return ""
    parts = []
    for q_idx, (sub_question, chunks) in enumerate(results):
        parts.append(_open_sub_q(q_idx, sub_question))
        parts.extend(
            _format_chunk(i, text, meta, indent=_SUBQ_CHUNK_INDENT)
            for i, (text, meta, _dist) in enumerate(chunks, start=1)
        )
        parts.append("</sub_q>")
    return "\n".join(parts)


def format_chunks_filtered_xml(filtered: list) -> str:
    """Format filtered chunks as an XML-tagged string with relevance scores.

    filtered = [(text, meta), ...] — score embedded in meta["relevance_score"];
    the ``Relevance`` line is omitted when the score is missing or None.
    """
    return "\n".join(
        _format_chunk(i, text, meta, relevance=meta.get("relevance_score"))
        for i, (text, meta) in enumerate(filtered, start=1)
    )


def format_chunks_filtered_per_subq(results: list) -> str:
    """Format per-sub-question filtered chunks as XML with ``<sub_q>`` wrappers.

    Args:
        results: List of (sub_question, filtered_chunks) from filter_per_subquestion().
            Each filtered_chunks is [(text, meta), ...] with relevance_score in meta.

    Returns:
        XML string of ``<sub_q>`` wrappers containing ``<chunk_N>`` elements, each
        with a ``Relevance`` line ("N/A" when unscored); "" when *results* is empty or None.
    """
    if not results:
        return ""
    parts = []
    for q_idx, (sub_question, filtered_chunks) in enumerate(results):
        parts.append(_open_sub_q(q_idx, sub_question))
        parts.extend(
            _format_chunk(
                i,
                text,
                meta,
                indent=_SUBQ_CHUNK_INDENT,
                # str() keeps the line present even for an explicit None score.
                relevance=str(meta.get("relevance_score", "N/A")),
            )
            for i, (text, meta) in enumerate(filtered_chunks, start=1)
        )
        parts.append("</sub_q>")
    return "\n".join(parts)


async def _record_history(history_service, input_text, extracted_questions, decompose_prompt,
                          decomposer_time_ms, retriever_time_ms, chunks_retrieved_count,
                          chunks_retrieved, filter_prompt, filter_time_ms,
                          chunks_filtered_count, chunks_filtered, generate_prompt,
                          generator_time_ms, profile_used, final_answer, sources,
                          total_time_ms) -> int:
    """Record a query's pipeline trace to history.

    Declared ``async`` so it can be scheduled as a task, but it contains no
    ``await``: ``history_service.record`` runs synchronously on the event loop.
    Recording is best-effort; failures are logged, not raised.

    Returns:
        The value returned by ``history_service.record`` (the history record ID),
        or -1 if recording raised.
    """
    try:
        return history_service.record({
            "input_text": input_text,
            # Lists are stored as JSON text; anything else is stored as given.
            "extracted_questions": (
                json.dumps(extracted_questions)
                if isinstance(extracted_questions, list)
                else extracted_questions
            ),
            "decompose_prompt": decompose_prompt,
            "decomposer_time_ms": decomposer_time_ms,
            "retriever_time_ms": retriever_time_ms,
            "chunks_retrieved": chunks_retrieved,
            "chunks_retrieved_count": chunks_retrieved_count,
            "filter_prompt": filter_prompt,
            "filter_time_ms": filter_time_ms,
            "chunks_filtered": chunks_filtered,
            "chunks_filtered_count": chunks_filtered_count,
            "generate_prompt": generate_prompt,
            "generator_time_ms": generator_time_ms,
            "total_time_ms": total_time_ms,
            "final_answer": final_answer,
            "sources": sources,
            "profile_used": profile_used,
        })
    except Exception:
        logger.warning("History recording failed", exc_info=True)
        return -1


def _schedule_history(history_service, request, extracted_questions, decompose_prompt,
                      decomposer_time_ms, retriever_time_ms, chunks_retrieved_count,
                      chunks_retrieved, filter_prompt, filter_time_ms, chunks_filtered_count,
                      chunks_filtered, generate_prompt, generator_time_ms, active_profile,
                      final_answer, sources_json, total_time_ms):
    """Fire-and-forget history recording. Never blocks the SSE stream.

    The task is kept in ``_background_tasks`` until it completes so it cannot be
    garbage-collected mid-flight. Scheduling failures are logged, not raised.
    """
    coro = _record_history(
        history_service,
        input_text=request.question,
        extracted_questions=extracted_questions,
        decompose_prompt=decompose_prompt,
        decomposer_time_ms=decomposer_time_ms,
        retriever_time_ms=retriever_time_ms,
        chunks_retrieved_count=chunks_retrieved_count,
        chunks_retrieved=chunks_retrieved,
        filter_prompt=filter_prompt,
        filter_time_ms=filter_time_ms,
        chunks_filtered_count=chunks_filtered_count,
        chunks_filtered=chunks_filtered,
        generate_prompt=generate_prompt,
        generator_time_ms=generator_time_ms,
        profile_used=active_profile,
        final_answer=final_answer,
        sources=sources_json,
        total_time_ms=total_time_ms,
    )
    try:
        task = asyncio.create_task(coro)
    except Exception:
        coro.close()  # avoid a "coroutine was never awaited" warning
        logger.warning("Failed to schedule history recording", exc_info=True)
        return
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)


def _split_result_and_prompt(result):
    """Normalize a service result that is either ``(value, prompt)`` or a bare ``value``."""
    if isinstance(result, tuple):
        value, prompt = result
        return value, prompt
    return result, ""


def _to_source_metadata(meta: dict) -> SourceMetadata:
    """Build a SourceMetadata from a chunk metadata dict, defaulting missing keys."""
    return SourceMetadata(
        filename=meta.get("filename", "unknown"),
        upload_date=meta.get("upload_date", ""),
        content_summary=meta.get("content_summary", ""),
        chunk_index=meta.get("chunk_index", 0),
        page_number=meta.get("page_number"),
        chunk_file_path=meta.get("chunk_file_path"),
        document_id=meta.get("document_id"),
    )


async def _query_stream(request: QueryRequest):
    """Run decompose -> retrieve -> filter -> generate and stream progress as SSE.

    Emitted ``phase`` values, in order: ``decomposed``, ``retrieving``,
    ``filtering``, ``generating``, one ``generating_subquestion`` per sub-question
    that has a source group, then ``completed``. If retrieval or filtering leaves
    no chunks, the stream ends early with ``completed`` and NO_RESULTS_ANSWER.
    Any other exception ends the stream with an ``error`` frame.
    """
    overall_start = time.perf_counter()
    try:
        # Setup lives inside the try so failures reach the client as an error frame.
        settings = get_settings()
        prompt_service = PromptService(db_path=settings.prompts_db_path)
        history_service = HistoryService(db_path=settings.history_db_path)
        llm_client_dp = LLMClientDP(settings)
        llm_client = LLMClient(settings)
        rag = RAGService(llm_client=llm_client, settings=settings, prompt_service=prompt_service)
        active_profile = prompt_service.get_active_profile_name()
        logger.info("Query: %s. Active prompt profile: %s", request.question, active_profile)
        decomposer = QueryDecomposer(llm_client_dp, prompt_service=prompt_service)

        # Stage 1: Decompose
        stage_start = time.perf_counter()
        extracted_questions, decompose_prompt = _split_result_and_prompt(
            await decomposer.decompose(request.question)
        )
        decomposer_time_ms = _elapsed_ms(stage_start)
        logger.info("Extracted questions: %s", extracted_questions)
        if not extracted_questions:
            # Fall back to treating the whole question as the only sub-question,
            # so extracted_questions is non-empty from here on.
            extracted_questions = [request.question]
        yield _format_sse({
            "phase": "decomposed",
            "extracted_questions": extracted_questions,
        })

        # Stage 2: Retrieve (per sub-question). The phase is announced before the
        # work, like the later phases, so the client sees progress during it.
        yield _format_sse({"phase": "retrieving"})
        stage_start = time.perf_counter()
        # NOTE(review): this call is synchronous and blocks the event loop while it runs.
        retrieval_results = rag.retrieve_per_subquestion(
            extracted_questions, n_results=settings.retrieval_n_results
        )
        retriever_time_ms = _elapsed_ms(stage_start)
        chunks_retrieved_count = sum(len(chunks) for _sub_q, chunks in retrieval_results)
        chunks_retrieved = format_chunks_retrieved_per_subq(retrieval_results)

        if not chunks_retrieved_count:
            _schedule_history(
                history_service, request, extracted_questions, decompose_prompt,
                decomposer_time_ms=decomposer_time_ms,
                retriever_time_ms=retriever_time_ms,
                chunks_retrieved_count=chunks_retrieved_count,
                chunks_retrieved=chunks_retrieved,
                filter_prompt="",
                filter_time_ms=0,
                chunks_filtered_count=0,
                chunks_filtered="",
                generate_prompt="",
                generator_time_ms=0,
                active_profile=active_profile,
                final_answer=NO_RESULTS_ANSWER,
                sources_json="[]",
                total_time_ms=_elapsed_ms(overall_start),
            )
            yield _format_sse({
                "phase": "completed",
                "answer": NO_RESULTS_ANSWER,
                "sources": [],
            })
            return

        # Stage 3: Filter (per sub-question — single LLM call)
        chunks_by_subq = [
            [(text, meta) for text, meta, _dist in chunks]
            for _sub_q, chunks in retrieval_results
        ]
        relevance_filter = RelevanceFilter(llm_client, prompt_service=prompt_service)
        yield _format_sse({"phase": "filtering"})
        stage_start = time.perf_counter()
        filtered_by_subq, filter_prompt = _split_result_and_prompt(
            await relevance_filter.filter_per_subquestion(
                extracted_questions, chunks_by_subq, threshold=settings.relevance_threshold
            )
        )
        filter_time_ms = _elapsed_ms(stage_start)
        chunks_filtered_count = sum(len(chunks) for _sub_q, chunks in filtered_by_subq)
        chunks_filtered = format_chunks_filtered_per_subq(filtered_by_subq)

        if not chunks_filtered_count:
            _schedule_history(
                history_service, request, extracted_questions, decompose_prompt,
                decomposer_time_ms=decomposer_time_ms,
                retriever_time_ms=retriever_time_ms,
                chunks_retrieved_count=chunks_retrieved_count,
                chunks_retrieved=chunks_retrieved,
                filter_prompt=filter_prompt,
                filter_time_ms=filter_time_ms,
                chunks_filtered_count=chunks_filtered_count,
                chunks_filtered=chunks_filtered,
                generate_prompt="",
                generator_time_ms=0,
                active_profile=active_profile,
                final_answer=NO_RESULTS_ANSWER,
                sources_json="[]",
                total_time_ms=_elapsed_ms(overall_start),
            )
            yield _format_sse({
                "phase": "completed",
                "answer": NO_RESULTS_ANSWER,
                "sources": [],
            })
            return

        # Stage 4: Generate. A single call produces the answer for all
        # sub-questions; per-sub-question events are emitted after it returns.
        sub_chunk_texts = [[text for text, _meta in chunks] for _sub_q, chunks in filtered_by_subq]
        sub_chunk_metadata = [[meta for _text, meta in chunks] for _sub_q, chunks in filtered_by_subq]
        yield _format_sse({"phase": "generating"})
        stage_start = time.perf_counter()
        gen_result = await rag.generate_response_per_subquestion(
            extracted_questions,
            sub_chunk_texts,
            sub_chunk_metadata,
        )
        generator_time_ms = _elapsed_ms(stage_start)
        if isinstance(gen_result, tuple) and len(gen_result) == 3:
            answer, generate_prompt, grouped_sources_meta = gen_result
        else:
            answer, generate_prompt = _split_result_and_prompt(gen_result)
            grouped_sources_meta = []

        sub_question_sources = []
        # zip() stops at the shorter sequence: sub-questions without a source group are skipped.
        for idx, (sub_q_text, sources_meta) in enumerate(
            zip(extracted_questions, grouped_sources_meta)
        ):
            sub_question_sources.append(
                SubQuestionSources(
                    sub_question_index=idx,
                    sub_question_text=sub_q_text,
                    sources=[_to_source_metadata(meta) for meta in sources_meta],
                )
            )
            yield _format_sse({
                "phase": "generating_subquestion",
                "sub_question_index": idx,
                "sub_question_text": sub_q_text,
            })

        logger.info(
            "Answer generated: %d chars, %d sub-questions",
            len(answer), len(extracted_questions),
        )

        total_time_ms = _elapsed_ms(overall_start)
        all_sources_flat = [source for sq in sub_question_sources for source in sq.sources]
        sources_json = json.dumps([
            [s.model_dump() for s in sq.sources] for sq in sub_question_sources
        ])
        # Awaited (not scheduled) because the record ID is part of the final frame.
        history_id = await _record_history(
            history_service,
            input_text=request.question,
            extracted_questions=extracted_questions,
            decompose_prompt=decompose_prompt,
            decomposer_time_ms=decomposer_time_ms,
            retriever_time_ms=retriever_time_ms,
            chunks_retrieved_count=chunks_retrieved_count,
            chunks_retrieved=chunks_retrieved,
            filter_prompt=filter_prompt,
            filter_time_ms=filter_time_ms,
            chunks_filtered_count=chunks_filtered_count,
            chunks_filtered=chunks_filtered,
            generate_prompt=generate_prompt,
            generator_time_ms=generator_time_ms,
            profile_used=active_profile,
            final_answer=answer,
            sources=sources_json,
            total_time_ms=total_time_ms,
        )

        yield _format_sse({
            "phase": "completed",
            "answer": answer,
            "sub_question_sources": [sq.model_dump() for sq in sub_question_sources],
            "sources": [s.model_dump() for s in all_sources_flat],
            "history_id": history_id,
        })
    except HTTPException:
        raise
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error(str(e)) discarded.
        logger.exception("Query stream failed: %s", e)
        yield _format_sse({
            "phase": "error",
            "message": f"Query failed: {str(e)}",
        })


@router.post("/query")
async def query(request: QueryRequest):
    """Validate the question and stream the query pipeline as ``text/event-stream``.

    Raises:
        HTTPException: 400 when the question is missing or whitespace-only.
    """
    if not request.question or not request.question.strip():
        raise HTTPException(status_code=400, detail="Question is required")

    return StreamingResponse(
        _query_stream(request),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            # Ask reverse proxies (e.g. nginx) not to buffer the event stream.
            "X-Accel-Buffering": "no",
        },
    )