270 lines
10 KiB
Python
270 lines
10 KiB
Python
import asyncio
|
|
import json
|
|
import logging
|
|
import time
|
|
|
|
from fastapi import APIRouter, HTTPException
|
|
from fastapi.responses import StreamingResponse
|
|
|
|
from app.core.config import get_settings
|
|
from app.models.query import QueryRequest
|
|
from app.models.common import SourceMetadata
|
|
from app.services.history_service import HistoryService
|
|
from app.services.llm_client import LLMClient
|
|
from app.services.prompt_service import PromptService
|
|
from app.services.query_decomposer import QueryDecomposer
|
|
from app.services.relevance_filter import RelevanceFilter
|
|
from app.services.rag import RAGService
|
|
|
|
# Per-module logger so records from this router can be filtered by name.
logger = logging.getLogger(name=__name__)

# All endpoints in this module are grouped under the "query" tag in the OpenAPI docs.
router = APIRouter(tags=["query"])

# Answer returned when retrieval or relevance filtering leaves no usable chunks.
NO_RESULTS_ANSWER = (
    "I could not find any relevant information "
    "to answer your question."
)
|
|
|
|
|
|
def _format_sse(data: dict) -> str:
    """Serialize ``data`` as a single Server-Sent Events ``data:`` frame.

    The payload is JSON-encoded on one line (``json.dumps`` without ``indent``
    escapes embedded newlines), and the frame ends with the blank line that
    terminates an SSE event.
    """
    encoded = json.dumps(data)
    return "data: " + encoded + "\n\n"
|
|
|
|
|
|
def format_chunks_retrieved_xml(chunks: list) -> str:
    """Render retrieved chunks as ``<chunk_N>``-tagged text blocks.

    Args:
        chunks: ``(text, metadata, distance)`` triples, as returned by
            ``RAGService.retrieve()``. The distance is not rendered.

    Returns:
        One block per chunk, numbered from 1, with all lines joined by "\\n".
        Each block contains the filename (``"unknown"`` when missing), a page
        line only when ``page_number`` is not None, and the chunk text.
        An empty input yields ``""``.
    """
    blocks = []
    for index, (content, metadata, _distance) in enumerate(chunks, start=1):
        body = [f"Filename: {metadata.get('filename', 'unknown')}"]
        page_number = metadata.get("page_number")
        if page_number is not None:
            body.append(f"Page: {page_number}")
        body.append(f"Content: {content}")
        blocks.append("\n".join([f"<chunk_{index}>", *body, f"</chunk_{index}>"]))
    return "\n".join(blocks)
|
|
|
|
|
|
def format_chunks_filtered_xml(filtered: list) -> str:
    """Render relevance-filtered chunks as ``<chunk_N>``-tagged text blocks.

    Args:
        filtered: ``(text, metadata)`` pairs. The relevance score, if any, is
            read from ``metadata["relevance_score"]``.

    Returns:
        One block per chunk, numbered from 1, with all lines joined by "\\n".
        Each block contains the filename (``"unknown"`` when missing), then
        page and relevance lines only when those values are not None, then
        the chunk text. An empty input yields ``""``.
    """
    def render(position, content, metadata):
        # Build a single tagged block; optional fields are skipped when None.
        fields = [f"<chunk_{position}>", f"Filename: {metadata.get('filename', 'unknown')}"]
        page = metadata.get("page_number")
        if page is not None:
            fields.append(f"Page: {page}")
        score = metadata.get("relevance_score")
        if score is not None:
            fields.append(f"Relevance: {score}")
        fields.extend([f"Content: {content}", f"</chunk_{position}>"])
        return "\n".join(fields)

    return "\n".join(
        render(position, content, metadata)
        for position, (content, metadata) in enumerate(filtered, start=1)
    )
|
|
|
|
|
|
async def _record_history(history_service, input_text, extracted_questions,
                          decompose_prompt, decomposer_time_ms, retriever_time_ms,
                          chunks_retrieved_count, chunks_retrieved, filter_prompt,
                          filter_time_ms, chunks_filtered_count, chunks_filtered,
                          generate_prompt, generator_time_ms, profile_used,
                          final_answer, sources, total_time_ms):
    """Persist one query's pipeline trace through ``history_service.record``.

    Meant to be scheduled as a background task. Any exception raised while
    building the record or writing it is logged as a warning and swallowed, so
    history problems never propagate to the caller.

    ``extracted_questions`` is JSON-encoded when it is a list and stored as-is
    otherwise. Every other argument is stored unchanged under the key of the
    same name.
    """
    try:
        if isinstance(extracted_questions, list):
            questions_field = json.dumps(extracted_questions)
        else:
            questions_field = extracted_questions

        row = {
            "input_text": input_text,
            "extracted_questions": questions_field,
            "decompose_prompt": decompose_prompt,
            "decomposer_time_ms": decomposer_time_ms,
            "retriever_time_ms": retriever_time_ms,
            "chunks_retrieved": chunks_retrieved,
            "chunks_retrieved_count": chunks_retrieved_count,
            "filter_prompt": filter_prompt,
            "filter_time_ms": filter_time_ms,
            "chunks_filtered": chunks_filtered,
            "chunks_filtered_count": chunks_filtered_count,
            "generate_prompt": generate_prompt,
            "generator_time_ms": generator_time_ms,
            "total_time_ms": total_time_ms,
            "final_answer": final_answer,
            "sources": sources,
            "profile_used": profile_used,
        }
        # NOTE(review): record() is called synchronously (no await), so it runs
        # on the event-loop thread. If it does blocking I/O, consider
        # asyncio.to_thread once thread-safety of HistoryService is confirmed.
        history_service.record(row)
    except Exception:
        logger.warning("History recording failed", exc_info=True)
|
|
|
|
|
|
# Strong references to in-flight history tasks. The event loop keeps only weak
# references to tasks, so an otherwise-unreferenced fire-and-forget task can be
# garbage-collected before it finishes. Each task removes itself when done.
_background_tasks: "set[asyncio.Task]" = set()


def _schedule_history(history_service, request, extracted_questions,
                      decompose_prompt, decomposer_time_ms, retriever_time_ms,
                      chunks_retrieved_count, chunks_retrieved, filter_prompt,
                      filter_time_ms, chunks_filtered_count, chunks_filtered,
                      generate_prompt, generator_time_ms, active_profile,
                      final_answer, sources_json, total_time_ms):
    """Schedule ``_record_history`` as a background task without awaiting it.

    Never blocks or raises into the SSE stream. If the task cannot be created
    (e.g. ``asyncio.create_task`` raises ``RuntimeError`` because no loop is
    running), a warning is logged and the record is dropped.

    ``request.question`` is recorded as the input text, ``active_profile`` as
    the profile used and ``sources_json`` as the sources. The remaining
    arguments are passed through to ``_record_history`` unchanged.
    """
    coro = None
    try:
        coro = _record_history(
            history_service, request.question, extracted_questions,
            decompose_prompt, decomposer_time_ms, retriever_time_ms,
            chunks_retrieved_count, chunks_retrieved, filter_prompt,
            filter_time_ms, chunks_filtered_count, chunks_filtered,
            generate_prompt, generator_time_ms, active_profile,
            final_answer, sources_json, total_time_ms
        )
        task = asyncio.create_task(coro)
    except Exception:
        # Close the unscheduled coroutine so it is not reported as "never awaited".
        if coro is not None:
            coro.close()
        logger.warning("Failed to schedule history recording", exc_info=True)
        return

    # Keep the task alive until it completes (see _background_tasks above).
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
|
|
|
|
|
|
def _elapsed_ms(start: float) -> int:
    """Return whole milliseconds elapsed since ``start`` (a ``time.perf_counter()`` reading)."""
    return int((time.perf_counter() - start) * 1000)


def _split_prompt(result):
    """Normalize a pipeline-stage result into ``(value, prompt)``.

    The decomposer, relevance filter and generator may return either a
    ``(value, prompt_text)`` tuple or the bare value. In the bare case the
    prompt is reported as ``""``.
    """
    if isinstance(result, tuple):
        value, prompt = result
        return value, prompt
    return result, ""


async def _query_stream(request: QueryRequest):
    """Run the query pipeline and yield its progress as SSE frames.

    The pipeline has four stages: decompose the question, retrieve candidate
    chunks, filter them for relevance, and generate an answer. Frames are
    emitted in this order:

    - "decomposed" (with ``extracted_questions``) once decomposition finishes;
    - "retrieving", "filtering" and "generating" as each later stage starts;
    - a single "completed" frame carrying ``answer`` and ``sources``.

    If retrieval or filtering leaves no chunks, the stream completes early with
    ``NO_RESULTS_ANSWER``. Every completed run is handed to
    ``_schedule_history`` for background recording.

    ``HTTPException`` is re-raised. Any other exception is logged with its
    traceback and reported to the client as a final ``{"phase": "error"}``
    frame.
    """
    overall_start = time.perf_counter()

    try:
        # Setup lives inside the try so configuration/service failures reach the
        # client as an SSE error frame instead of aborting the stream.
        settings = get_settings()
        prompt_service = PromptService(db_path=settings.prompts_db_path)
        history_service = HistoryService(db_path=settings.history_db_path)

        llm_client = LLMClient(settings)
        rag = RAGService(llm_client=llm_client, settings=settings, prompt_service=prompt_service)
        active_profile = prompt_service.get_active_profile_name()

        logger.info("Query: %s. Active prompt profile: %s", request.question, active_profile)

        decomposer = QueryDecomposer(llm_client, prompt_service=prompt_service)

        # Stage 1: Decompose
        stage_start = time.perf_counter()
        extracted_questions, decompose_prompt = _split_prompt(
            await decomposer.decompose(request.question)
        )
        decomposer_time_ms = _elapsed_ms(stage_start)
        logger.info("Extracted questions: %s", extracted_questions)

        yield _format_sse({
            "phase": "decomposed",
            "extracted_questions": extracted_questions,
        })

        # Stage 2: Retrieve. Announce the phase before doing the work, as the
        # later stages do, so clients see "retrieving" while retrieval runs.
        yield _format_sse({"phase": "retrieving"})

        stage_start = time.perf_counter()
        # NOTE(review): retrieve() is called without await, so it runs on the
        # event-loop thread. If it does blocking I/O, consider asyncio.to_thread
        # once RAGService thread-safety is confirmed.
        chunks = rag.retrieve(extracted_questions, n_results=settings.retrieval_n_results)
        retriever_time_ms = _elapsed_ms(stage_start)
        chunks_retrieved_count = len(chunks)
        chunks_retrieved = format_chunks_retrieved_xml(chunks)

        if not chunks:
            _schedule_history(
                history_service, request, extracted_questions,
                decompose_prompt=decompose_prompt,
                decomposer_time_ms=decomposer_time_ms,
                retriever_time_ms=retriever_time_ms,
                chunks_retrieved_count=chunks_retrieved_count,
                chunks_retrieved=chunks_retrieved,
                filter_prompt="",
                filter_time_ms=0,
                chunks_filtered_count=0,
                chunks_filtered="",
                generate_prompt="",
                generator_time_ms=0,
                active_profile=active_profile,
                final_answer=NO_RESULTS_ANSWER,
                sources_json="[]",
                total_time_ms=_elapsed_ms(overall_start),
            )
            yield _format_sse({
                "phase": "completed",
                "answer": NO_RESULTS_ANSWER,
                "sources": [],
            })
            return

        # Stage 3: Filter
        chunks_for_filter = [(text, meta) for text, meta, _dist in chunks]
        relevance_filter = RelevanceFilter(llm_client, prompt_service=prompt_service)

        yield _format_sse({"phase": "filtering"})

        # Restart the timer so filter_time_ms measures only the filter call.
        stage_start = time.perf_counter()
        filtered, filter_prompt = _split_prompt(
            await relevance_filter.filter(
                request.question, chunks_for_filter, threshold=settings.relevance_threshold
            )
        )
        filter_time_ms = _elapsed_ms(stage_start)
        chunks_filtered_count = len(filtered)
        chunks_filtered = format_chunks_filtered_xml(filtered)

        if not filtered:
            _schedule_history(
                history_service, request, extracted_questions,
                decompose_prompt=decompose_prompt,
                decomposer_time_ms=decomposer_time_ms,
                retriever_time_ms=retriever_time_ms,
                chunks_retrieved_count=chunks_retrieved_count,
                chunks_retrieved=chunks_retrieved,
                filter_prompt=filter_prompt,
                filter_time_ms=filter_time_ms,
                chunks_filtered_count=chunks_filtered_count,
                chunks_filtered=chunks_filtered,
                generate_prompt="",
                generator_time_ms=0,
                active_profile=active_profile,
                final_answer=NO_RESULTS_ANSWER,
                sources_json="[]",
                total_time_ms=_elapsed_ms(overall_start),
            )
            yield _format_sse({
                "phase": "completed",
                "answer": NO_RESULTS_ANSWER,
                "sources": [],
            })
            return

        # Stage 4: Generate
        chunk_texts = [chunk for chunk, _meta in filtered]
        chunk_metadata = [meta for _chunk, meta in filtered]

        yield _format_sse({"phase": "generating"})

        stage_start = time.perf_counter()
        answer, generate_prompt = _split_prompt(
            await rag.generate_response(request.question, chunk_texts, chunk_metadata)
        )
        generator_time_ms = _elapsed_ms(stage_start)
        logger.info("Answer generated: %d chars, %d sources", len(answer), len(filtered))

        total_time_ms = _elapsed_ms(overall_start)

        sources = [
            SourceMetadata(
                filename=meta.get("filename", "unknown"),
                upload_date=meta.get("upload_date", ""),
                content_summary=meta.get("content_summary", ""),
                chunk_index=meta.get("chunk_index", 0),
                page_number=meta.get("page_number"),
                chunk_file_path=meta.get("chunk_file_path"),
            )
            for meta in chunk_metadata
        ]
        # Dump once; the same payload feeds both the history record and the client.
        sources_payload = [s.model_dump() for s in sources]

        _schedule_history(
            history_service, request, extracted_questions,
            decompose_prompt=decompose_prompt,
            decomposer_time_ms=decomposer_time_ms,
            retriever_time_ms=retriever_time_ms,
            chunks_retrieved_count=chunks_retrieved_count,
            chunks_retrieved=chunks_retrieved,
            filter_prompt=filter_prompt,
            filter_time_ms=filter_time_ms,
            chunks_filtered_count=chunks_filtered_count,
            chunks_filtered=chunks_filtered,
            generate_prompt=generate_prompt,
            generator_time_ms=generator_time_ms,
            active_profile=active_profile,
            final_answer=answer,
            sources_json=json.dumps(sources_payload),
            total_time_ms=total_time_ms,
        )

        yield _format_sse({
            "phase": "completed",
            "answer": answer,
            "sources": sources_payload,
        })

    except HTTPException:
        raise
    except Exception as e:
        # logger.exception records the traceback, not just the message.
        logger.exception("Query stream failed: %s", e)
        yield _format_sse({
            "phase": "error",
            "message": f"Query failed: {str(e)}",
        })
|
|
|
|
|
|
@router.post("/query")
async def query(request: QueryRequest):
    """Stream the answer to ``request.question`` as Server-Sent Events.

    The response body is produced by ``_query_stream``.

    Raises:
        HTTPException: 400 when the question is missing, empty, or only
            whitespace.
    """
    question_text = (request.question or "").strip()
    if not question_text:
        raise HTTPException(status_code=400, detail="Question is required")

    sse_headers = {
        "Cache-Control": "no-cache",
        # Ask reverse proxies such as nginx not to buffer the event stream.
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(
        _query_stream(request),
        media_type="text/event-stream",
        headers=sse_headers,
    )
|