"""Streaming query endpoint: decompose -> retrieve -> filter -> generate.

Progress and the final answer are delivered to the client as Server-Sent
Events; each completed query is also recorded to history in the background.
"""

import asyncio
import json
import logging
import time

from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse

from app.core.config import get_settings
from app.models.query import QueryRequest
from app.models.common import SourceMetadata
from app.services.history_service import HistoryService
from app.services.llm_client import LLMClient
from app.services.prompt_service import PromptService
from app.services.query_decomposer import QueryDecomposer
from app.services.relevance_filter import RelevanceFilter
from app.services.rag import RAGService

logger = logging.getLogger(__name__)

router = APIRouter(tags=["query"])

NO_RESULTS_ANSWER = "I could not find any relevant information to answer your question."

# Strong references to in-flight history tasks. The event loop only keeps
# weak references to tasks, so a fire-and-forget task that nobody else holds
# can be garbage-collected before it finishes.
_background_tasks = set()


def _format_sse(data: dict) -> str:
    """Serialize *data* as a single Server-Sent Events ``data:`` frame."""
    return f"data: {json.dumps(data)}\n\n"


def _elapsed_ms(start: float) -> int:
    """Return whole milliseconds elapsed since *start* (a ``perf_counter`` value)."""
    return int((time.perf_counter() - start) * 1000)


def _split_result(result):
    """Normalize a service result to a ``(value, prompt)`` pair.

    The pipeline services are handled both when they return a
    ``(value, prompt)`` tuple and when they return the bare value; in the
    latter case the prompt is reported as an empty string.
    """
    if isinstance(result, tuple):
        value, prompt = result
        return value, prompt
    return result, ""


def _chunk_block(text, meta: dict, include_score: bool = False) -> str:
    """Render one chunk as a newline-joined block of ``Key: value`` lines."""
    # NOTE(review): the first and last lines are empty strings in the source
    # as received, although the public docstrings describe XML tags. The tags
    # were presumably lost upstream; confirm the intended tag names before
    # restoring them. Output is kept as-is so stored history stays compatible.
    lines = [""]
    lines.append(f"Filename: {meta.get('filename', 'unknown')}")
    page = meta.get("page_number")
    if page is not None:
        lines.append(f"Page: {page}")
    if include_score:
        score = meta.get("relevance_score")
        if score is not None:
            lines.append(f"Relevance: {score}")
    lines.append(f"Content: {text}")
    lines.append("")
    return "\n".join(lines)


def format_chunks_retrieved_xml(chunks: list) -> str:
    """Format retrieved chunks as XML-tagged string.

    chunks = [(text, metadata, distance), ...] from RAGService.retrieve()

    Returns an empty string when *chunks* is empty.
    """
    return "\n".join(_chunk_block(text, meta) for text, meta, _dist in chunks)


def format_chunks_filtered_xml(filtered: list) -> str:
    """Format filtered chunks as XML-tagged string with relevance scores.

    filtered = [(text, meta), ...] - score embedded in meta["relevance_score"]

    Returns an empty string when *filtered* is empty.
    """
    return "\n".join(
        _chunk_block(text, meta, include_score=True) for text, meta in filtered
    )


async def _record_history(history_service, input_text, extracted_questions,
                          decompose_prompt, decomposer_time_ms,
                          retriever_time_ms, chunks_retrieved_count,
                          chunks_retrieved, filter_prompt, filter_time_ms,
                          chunks_filtered_count, chunks_filtered,
                          generate_prompt, generator_time_ms, profile_used,
                          final_answer, sources, total_time_ms):
    """Record a query to history. Runs as a fire-and-forget task.

    Any exception from the history service is logged and swallowed so that a
    history failure can never surface to the client.
    """
    # A list of questions is stored as a JSON string; anything else is
    # passed through unchanged.
    if isinstance(extracted_questions, list):
        extracted_questions = json.dumps(extracted_questions)
    try:
        # NOTE(review): record() is called synchronously, so it runs on the
        # event loop thread. Moving it to a worker thread would need the
        # history service to be usable from another thread - confirm first.
        history_service.record({
            "input_text": input_text,
            "extracted_questions": extracted_questions,
            "decompose_prompt": decompose_prompt,
            "decomposer_time_ms": decomposer_time_ms,
            "retriever_time_ms": retriever_time_ms,
            "chunks_retrieved": chunks_retrieved,
            "chunks_retrieved_count": chunks_retrieved_count,
            "filter_prompt": filter_prompt,
            "filter_time_ms": filter_time_ms,
            "chunks_filtered": chunks_filtered,
            "chunks_filtered_count": chunks_filtered_count,
            "generate_prompt": generate_prompt,
            "generator_time_ms": generator_time_ms,
            "total_time_ms": total_time_ms,
            "final_answer": final_answer,
            "sources": sources,
            "profile_used": profile_used,
        })
    except Exception:
        logger.warning("History recording failed", exc_info=True)


def _schedule_history(history_service, request, extracted_questions,
                      decompose_prompt, decomposer_time_ms, retriever_time_ms,
                      chunks_retrieved_count, chunks_retrieved, filter_prompt,
                      filter_time_ms, chunks_filtered_count, chunks_filtered,
                      generate_prompt, generator_time_ms, active_profile,
                      final_answer, sources_json, total_time_ms):
    """Schedule history recording as a background task.

    Never raises: scheduling failures (for example, no running event loop)
    are logged as warnings. The created task is kept in ``_background_tasks``
    until it completes so it cannot be garbage-collected mid-execution.
    """
    coro = None
    try:
        coro = _record_history(
            history_service, request.question, extracted_questions,
            decompose_prompt, decomposer_time_ms, retriever_time_ms,
            chunks_retrieved_count, chunks_retrieved, filter_prompt,
            filter_time_ms, chunks_filtered_count, chunks_filtered,
            generate_prompt, generator_time_ms, active_profile,
            final_answer, sources_json, total_time_ms,
        )
        task = asyncio.create_task(coro)
    except Exception:
        if coro is not None:
            # The coroutine was never handed to a task; close it to avoid a
            # "coroutine was never awaited" RuntimeWarning.
            coro.close()
        logger.warning("Failed to schedule history recording", exc_info=True)
        return
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)


async def _query_stream(request: QueryRequest):
    """Run the query pipeline and yield SSE frames describing its progress.

    Emits, in order, the phases ``decomposed``, ``retrieving``, ``filtering``,
    ``generating`` and ``completed``. The stream ends early with a
    ``completed`` frame carrying ``NO_RESULTS_ANSWER`` when retrieval or
    filtering yields nothing. Any unexpected exception is reported to the
    client as an ``error`` frame instead of propagating.
    """
    settings = get_settings()
    prompt_service = PromptService(db_path=settings.prompts_db_path)
    overall_start = time.perf_counter()

    try:
        history_service = HistoryService(db_path=settings.history_db_path)
        llm_client = LLMClient(settings)
        rag = RAGService(
            llm_client=llm_client,
            settings=settings,
            prompt_service=prompt_service,
        )
        active_profile = prompt_service.get_active_profile_name()
        logger.info(
            "Query: %s. Active prompt profile: %s",
            request.question, active_profile,
        )
        decomposer = QueryDecomposer(llm_client, prompt_service=prompt_service)

        # Per-stage history fields, filled in as each stage completes. Stages
        # that never run keep these zero/empty defaults.
        record = {
            "extracted_questions": None,
            "decompose_prompt": "",
            "decomposer_time_ms": 0,
            "retriever_time_ms": 0,
            "chunks_retrieved_count": 0,
            "chunks_retrieved": "",
            "filter_prompt": "",
            "filter_time_ms": 0,
            "chunks_filtered_count": 0,
            "chunks_filtered": "",
            "generate_prompt": "",
            "generator_time_ms": 0,
        }

        def schedule_history(final_answer, sources_json):
            _schedule_history(
                history_service,
                request,
                active_profile=active_profile,
                final_answer=final_answer,
                sources_json=sources_json,
                total_time_ms=_elapsed_ms(overall_start),
                **record,
            )

        no_results_frame = _format_sse({
            "phase": "completed",
            "answer": NO_RESULTS_ANSWER,
            "sources": [],
        })

        # Stage 1: Decompose
        stage_start = time.perf_counter()
        extracted_questions, decompose_prompt = _split_result(
            await decomposer.decompose(request.question)
        )
        record["extracted_questions"] = extracted_questions
        record["decompose_prompt"] = decompose_prompt
        record["decomposer_time_ms"] = _elapsed_ms(stage_start)
        logger.info("Extracted questions: %s", extracted_questions)
        yield _format_sse({
            "phase": "decomposed",
            "extracted_questions": extracted_questions,
        })

        # Stage 2: Retrieve
        # Announce the phase before doing the work, like the later stages.
        yield _format_sse({"phase": "retrieving"})
        # Timers start after the progress yield so they measure only the
        # stage's own work, not the time the consumer takes to resume us.
        stage_start = time.perf_counter()
        # NOTE(review): retrieve() is a synchronous call made on the event
        # loop; confirm it is cheap or thread-safe before offloading it.
        chunks = rag.retrieve(
            extracted_questions, n_results=settings.retrieval_n_results
        )
        record["retriever_time_ms"] = _elapsed_ms(stage_start)
        record["chunks_retrieved_count"] = len(chunks)
        record["chunks_retrieved"] = format_chunks_retrieved_xml(chunks)

        if not chunks:
            schedule_history(NO_RESULTS_ANSWER, "[]")
            yield no_results_frame
            return

        # Stage 3: Filter
        chunks_for_filter = [(text, meta) for text, meta, _dist in chunks]
        relevance_filter = RelevanceFilter(
            llm_client, prompt_service=prompt_service
        )
        yield _format_sse({"phase": "filtering"})
        # Reset the timer: without this the filter time would also include
        # the whole retrieval stage.
        stage_start = time.perf_counter()
        filtered, filter_prompt = _split_result(
            await relevance_filter.filter(
                request.question,
                chunks_for_filter,
                threshold=settings.relevance_threshold,
            )
        )
        record["filter_prompt"] = filter_prompt
        record["filter_time_ms"] = _elapsed_ms(stage_start)
        record["chunks_filtered_count"] = len(filtered)
        record["chunks_filtered"] = format_chunks_filtered_xml(filtered)

        if not filtered:
            schedule_history(NO_RESULTS_ANSWER, "[]")
            yield no_results_frame
            return

        # Stage 4: Generate
        chunk_texts = [chunk for chunk, _meta in filtered]
        chunk_metadata = [meta for _chunk, meta in filtered]
        yield _format_sse({"phase": "generating"})
        stage_start = time.perf_counter()
        answer, generate_prompt = _split_result(
            await rag.generate_response(
                request.question, chunk_texts, chunk_metadata
            )
        )
        record["generate_prompt"] = generate_prompt
        record["generator_time_ms"] = _elapsed_ms(stage_start)
        logger.info(
            "Answer generated: %d chars, %d sources",
            len(answer), len(filtered),
        )

        sources = [
            SourceMetadata(
                filename=meta.get("filename", "unknown"),
                upload_date=meta.get("upload_date", ""),
                content_summary=meta.get("content_summary", ""),
                chunk_index=meta.get("chunk_index", 0),
                page_number=meta.get("page_number"),
                chunk_file_path=meta.get("chunk_file_path"),
            )
            for meta in chunk_metadata
        ]
        # Serialize once; the same payload goes to history and to the client.
        sources_payload = [source.model_dump() for source in sources]

        schedule_history(answer, json.dumps(sources_payload))
        yield _format_sse({
            "phase": "completed",
            "answer": answer,
            "sources": sources_payload,
        })
    except HTTPException:
        raise
    except Exception as e:
        # logger.exception keeps the traceback, which a plain error() loses.
        logger.exception("Query stream failed: %s", e)
        yield _format_sse({
            "phase": "error",
            "message": f"Query failed: {str(e)}",
        })


@router.post("/query")
async def query(request: QueryRequest):
    """Answer a question, streaming progress and the result as SSE.

    Raises:
        HTTPException: 400 when the question is missing or blank.
    """
    if not request.question or not request.question.strip():
        raise HTTPException(status_code=400, detail="Question is required")
    return StreamingResponse(
        _query_stream(request),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            # Ask nginx-style reverse proxies not to buffer the event stream.
            "X-Accel-Buffering": "no",
        },
    )