121 lines
3.7 KiB
Python
121 lines
3.7 KiB
Python
import json
|
|
import logging
|
|
|
|
from fastapi import APIRouter, HTTPException
|
|
from fastapi.responses import StreamingResponse
|
|
|
|
from app.core.config import get_settings
|
|
from app.models.query import QueryRequest
|
|
from app.models.common import SourceMetadata
|
|
from app.services.llm_client import LLMClient
|
|
from app.services.query_decomposer import QueryDecomposer
|
|
from app.services.relevance_filter import RelevanceFilter
|
|
from app.services.rag import RAGService
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
# Router that the endpoint(s) in this module are registered on; tagged "query".
router = APIRouter(tags=["query"])
|
|
|
|
# Answer text sent in the final "completed" event when retrieval or relevance
# filtering leaves no chunks to answer from.
NO_RESULTS_ANSWER = "I could not find any relevant information to answer your question."
|
|
|
|
|
|
def _format_sse(data: dict) -> str:
|
|
return f"data: {json.dumps(data)}\n\n"
|
|
|
|
|
|
def _no_results_event() -> str:
    """Build the terminal frame sent when no usable chunks are available."""
    return _format_sse({
        "phase": "completed",
        "answer": NO_RESULTS_ANSWER,
        "sources": [],
    })


async def _query_stream(request: QueryRequest):
    """Run the query pipeline for *request* and yield progress as event frames.

    Each yielded value is a string produced by ``_format_sse``. The ``phase``
    field of the payload takes these values, in order:

    * ``"decomposed"`` - carries ``extracted_questions``.
    * ``"retrieving"`` - emitted just before chunk retrieval starts.
    * ``"filtering"`` - emitted just before relevance filtering starts.
    * ``"generating"`` - emitted just before the answer is generated.
    * ``"completed"`` - carries ``answer`` and ``sources``; also used with
      ``NO_RESULTS_ANSWER`` and an empty source list when retrieval or
      filtering returns nothing.
    * ``"error"`` - carries ``message``; emitted instead of raising when any
      non-``HTTPException`` error occurs inside the pipeline.

    Raises:
        HTTPException: re-raised unchanged if one is raised by the pipeline.
    """
    settings = get_settings()

    try:
        llm_client = LLMClient(settings)
        rag = RAGService(llm_client=llm_client, settings=settings)

        logger.info("Query: %s", request.question)

        decomposer = QueryDecomposer(llm_client)
        extracted_questions = await decomposer.decompose(request.question)
        logger.info("Extracted questions: %s", extracted_questions)

        yield _format_sse({
            "phase": "decomposed",
            "extracted_questions": extracted_questions,
        })

        # Announce the phase *before* doing the work, like the "filtering" and
        # "generating" phases below, so the client sees progress while the
        # retrieval is actually running rather than after it has finished.
        yield _format_sse({"phase": "retrieving"})

        # NOTE(review): retrieve() is called synchronously inside this async
        # generator; if it performs blocking I/O it will stall the event loop
        # for its duration - confirm against RAGService.
        chunks = rag.retrieve(extracted_questions, n_results=settings.retrieval_n_results)

        if not chunks:
            yield _no_results_event()
            return

        # Each retrieved item is unpacked as a 3-tuple; the third element
        # (presumably a distance score - confirm) is not needed by the filter.
        chunks_for_filter = [(text, meta) for text, meta, _dist in chunks]
        relevance_filter = RelevanceFilter(llm_client)

        yield _format_sse({"phase": "filtering"})

        filtered = await relevance_filter.filter(
            request.question, chunks_for_filter, threshold=settings.relevance_threshold
        )

        if not filtered:
            yield _no_results_event()
            return

        chunk_texts = [chunk for chunk, _meta in filtered]
        chunk_metadata = [meta for _chunk, meta in filtered]

        yield _format_sse({"phase": "generating"})

        answer = await rag.generate_response(request.question, chunk_texts, chunk_metadata)
        logger.info("Answer generated: %d chars, %d sources", len(answer), len(filtered))

        # Build one SourceMetadata per surviving chunk, substituting defaults
        # for metadata keys that are absent.
        sources = [
            SourceMetadata(
                filename=meta.get("filename", "unknown"),
                upload_date=meta.get("upload_date", ""),
                content_summary=meta.get("content_summary", ""),
                chunk_index=meta.get("chunk_index", 0),
                page_number=meta.get("page_number"),
                chunk_file_path=meta.get("chunk_file_path"),
            )
            for meta in chunk_metadata
        ]

        yield _format_sse({
            "phase": "completed",
            "answer": answer,
            "sources": [s.model_dump() for s in sources],
        })

    except HTTPException:
        # Deliberately not converted into an "error" frame.
        raise
    except Exception as e:
        # logger.exception records the traceback as well as the message, which
        # a plain logger.error(str(e)) would lose.
        logger.exception("Query stream failed: %s", e)
        yield _format_sse({
            "phase": "error",
            "message": f"Query failed: {e}",
        })
|
|
|
|
|
|
@router.post("/query")
async def query(request: QueryRequest):
    """Validate the question and stream the query pipeline's events back.

    Returns a ``StreamingResponse`` with media type ``text/event-stream`` whose
    body is produced by ``_query_stream(request)``.

    Raises:
        HTTPException: status 400 when the question is missing, empty, or
            consists only of whitespace.
    """
    question = request.question
    if not (question and question.strip()):
        raise HTTPException(status_code=400, detail="Question is required")

    # Presumably meant to keep caches and buffering reverse proxies from
    # holding back events - confirm against the deployment setup.
    stream_headers = {
        "Cache-Control": "no-cache",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(
        _query_stream(request),
        media_type="text/event-stream",
        headers=stream_headers,
    )
|