refactor(backend): update query decomposer, relevance filter, and RAG service
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
This commit is contained in:
parent
38f4c70762
commit
f4d78b0b77
|
|
@ -2,19 +2,22 @@
|
||||||
|
|
||||||
This module provides a lightweight QueryDecomposer that delegates the
|
This module provides a lightweight QueryDecomposer that delegates the
|
||||||
translation of a natural language question into a list of keyword search
|
translation of a natural language question into a list of keyword search
|
||||||
terms to an LLM client. The interface is intentionally minimal to support
|
terms to an LLM client.
|
||||||
test-driven development for Phase 1.3.
|
|
||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class QueryDecomposer:
|
class QueryDecomposer:
|
||||||
"""Decompose a natural language question into a list of keywords.
|
"""Decompose a natural language question into a list of keywords.
|
||||||
|
|
||||||
The class expects an object that exposes a ``complete(prompt: str) -> str``
|
The class expects an object that exposes an ``async complete(prompt: str) -> str``
|
||||||
method (an LLM client). The ``decompose`` method builds a prompt, asks the
|
method (an LLM client). The ``decompose`` method builds a prompt, asks the
|
||||||
LLM to return a JSON array of strings, and parses that JSON into a Python
|
LLM to return a JSON array of strings, and parses that JSON into a Python
|
||||||
list of strings. Edge cases are handled gracefully.
|
list of strings. Edge cases are handled gracefully.
|
||||||
|
|
@ -23,7 +26,7 @@ class QueryDecomposer:
|
||||||
def __init__(self, llm_client) -> None:
|
def __init__(self, llm_client) -> None:
|
||||||
self.llm_client = llm_client
|
self.llm_client = llm_client
|
||||||
|
|
||||||
def decompose(self, question: str) -> List[str]:
|
async def decompose(self, question: str) -> List[str]:
|
||||||
"""Return a list of keywords extracted for the given question.
|
"""Return a list of keywords extracted for the given question.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -40,9 +43,9 @@ class QueryDecomposer:
|
||||||
prompt = f"Given question: '{question}', extract key search keywords as JSON array"
|
prompt = f"Given question: '{question}', extract key search keywords as JSON array"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = self.llm_client.complete(prompt)
|
response = await self.llm_client.complete(prompt)
|
||||||
except Exception:
|
except Exception as exc:
|
||||||
# If the LLM call fails for any reason, defensively return no keywords
|
logger.warning("LLM decomposition failed: %s", exc)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
if not isinstance(response, str):
|
if not isinstance(response, str):
|
||||||
|
|
@ -51,15 +54,13 @@ class QueryDecomposer:
|
||||||
try:
|
try:
|
||||||
data = json.loads(response)
|
data = json.loads(response)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
# Invalid JSON – treat as no keywords
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
if not isinstance(data, list):
|
if not isinstance(data, list):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# If all items are strings, return as-is. Otherwise, coerce to strings.
|
|
||||||
if len(data) == 0:
|
if len(data) == 0:
|
||||||
return []
|
return []
|
||||||
if all(isinstance(item, str) for item in data):
|
if all(isinstance(item, str) for item in data):
|
||||||
return data # type: ignore[return-value]
|
return data
|
||||||
return [str(item) for item in data]
|
return [str(item) for item in data]
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,15 @@
|
||||||
"""RAG service for embedding, retrieval, and response generation."""
|
"""RAG service for embedding, retrieval, and response generation."""
|
||||||
import uuid
|
import uuid
|
||||||
from typing import List, Tuple, Dict, Any, Optional
|
from typing import List, Tuple, Dict, Any, Optional
|
||||||
|
import logging
|
||||||
import httpx
|
|
||||||
|
|
||||||
from app.core.config import Settings
|
from app.core.config import Settings
|
||||||
from app.core.database import get_chroma_client
|
from app.core.database import get_chroma_client
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class RAGService:
|
class RAGService:
|
||||||
"""Service for document ingestion, retrieval, and response generation."""
|
"""Service for document ingestion, retrieval, and response generation."""
|
||||||
|
|
||||||
|
|
@ -25,10 +27,14 @@ class RAGService:
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def collection(self):
|
def collection(self):
|
||||||
"""Lazy-load the ChromaDB collection."""
|
|
||||||
if self._collection is None:
|
if self._collection is None:
|
||||||
from app.core.database import get_or_create_collection
|
from app.core.database import get_or_create_collection, get_embedding_function_settings
|
||||||
self._collection = get_or_create_collection(self.chroma_client, "documents")
|
embedding_fn = None
|
||||||
|
if self.settings is not None:
|
||||||
|
embedding_fn = get_embedding_function_settings(self.settings)
|
||||||
|
self._collection = get_or_create_collection(
|
||||||
|
self.chroma_client, "documents", embedding_function=embedding_fn
|
||||||
|
)
|
||||||
return self._collection
|
return self._collection
|
||||||
|
|
||||||
def ingest_document(
|
def ingest_document(
|
||||||
|
|
@ -37,16 +43,6 @@ class RAGService:
|
||||||
chunks: List[str],
|
chunks: List[str],
|
||||||
metadata_list: List[Dict[str, Any]],
|
metadata_list: List[Dict[str, Any]],
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Ingest document chunks into ChromaDB.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_path: Path to the source file.
|
|
||||||
chunks: List of text chunks.
|
|
||||||
metadata_list: List of metadata dicts matching chunk count.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Document ID (UUID) for the ingestion batch.
|
|
||||||
"""
|
|
||||||
if not chunks:
|
if not chunks:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
@ -66,15 +62,6 @@ class RAGService:
|
||||||
query_keywords: List[str],
|
query_keywords: List[str],
|
||||||
n_results: int = 10,
|
n_results: int = 10,
|
||||||
) -> List[Tuple[str, Dict[str, Any], float]]:
|
) -> List[Tuple[str, Dict[str, Any], float]]:
|
||||||
"""Retrieve relevant chunks from ChromaDB.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query_keywords: List of keywords from query decomposition.
|
|
||||||
n_results: Maximum number of results to retrieve.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of (chunk_text, metadata, distance) tuples.
|
|
||||||
"""
|
|
||||||
query_text = " ".join(query_keywords)
|
query_text = " ".join(query_keywords)
|
||||||
|
|
||||||
results = self.collection.query(
|
results = self.collection.query(
|
||||||
|
|
@ -91,22 +78,12 @@ class RAGService:
|
||||||
|
|
||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
def generate_response(
|
async def generate_response(
|
||||||
self,
|
self,
|
||||||
question: str,
|
question: str,
|
||||||
chunks: List[str],
|
chunks: List[str],
|
||||||
metadata_list: List[Dict[str, Any]],
|
metadata_list: List[Dict[str, Any]],
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Generate a bullet-point response using only provided chunks.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
question: The user's question.
|
|
||||||
chunks: List of relevant document chunks.
|
|
||||||
metadata_list: List of metadata for each chunk.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Bullet-point formatted answer string.
|
|
||||||
"""
|
|
||||||
if not chunks:
|
if not chunks:
|
||||||
return "I could not find any relevant information to answer your question."
|
return "I could not find any relevant information to answer your question."
|
||||||
|
|
||||||
|
|
@ -135,4 +112,4 @@ class RAGService:
|
||||||
f"Answer:"
|
f"Answer:"
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.llm_client.complete(prompt=prompt, temperature=0.3)
|
return await self.llm_client.complete(prompt=prompt, temperature=0.3)
|
||||||
|
|
|
||||||
|
|
@ -1,66 +1,58 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
from typing import List, Tuple, Dict
|
from typing import List, Tuple, Dict
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class RelevanceFilter:
|
class RelevanceFilter:
|
||||||
"""RelevanceFilter batches chunk texts to an LLM and selects those with
|
"""RelevanceFilter batches chunk texts to an LLM and selects those with
|
||||||
relevance scores above a threshold.
|
relevance scores above a threshold.
|
||||||
|
|
||||||
The constructor expects an llm_client-like object with a `complete(prompt: str, temperature: float = 0.7) -> str` method.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, llm_client):
|
def __init__(self, llm_client):
|
||||||
self.llm_client = llm_client
|
self.llm_client = llm_client
|
||||||
|
|
||||||
def _build_prompt(self, question: str, chunks: List[Tuple[str, Dict]]) -> str:
|
def _build_prompt(self, question: str, chunks: List[Tuple[str, Dict]]) -> str:
|
||||||
"""Build the single prompt that asks the LLM to score all chunks.
|
|
||||||
|
|
||||||
The prompt format is designed to be simple and deterministic for tests:
|
|
||||||
- Include the question
|
|
||||||
- List the chunk texts in order
|
|
||||||
- Ask for a JSON array of scores corresponding to each chunk
|
|
||||||
"""
|
|
||||||
texts = [chunk_text for (chunk_text, _meta) in chunks]
|
texts = [chunk_text for (chunk_text, _meta) in chunks]
|
||||||
# Keep the prompt readable and deterministic
|
lines = []
|
||||||
|
for idx, t in enumerate(texts, start=1):
|
||||||
|
lines.append(f"Chunk {idx}: {t}")
|
||||||
|
chunks_formatted = "\n".join(lines)
|
||||||
prompt = (
|
prompt = (
|
||||||
f"Given question '{question}' and these document chunks, rate each 0-10 for relevance. "
|
f"Given question '{question}' and these document chunks, rate each 0-10 for relevance. "
|
||||||
f"Return JSON array of scores. Chunks: {texts}"
|
f"Return JSON array of scores.\n{chunks_formatted}\n"
|
||||||
)
|
)
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
def filter(
|
async def filter(
|
||||||
self, question: str, chunks: List[Tuple[str, Dict]], threshold: float = 7.0
|
self, question: str, chunks: List[Tuple[str, Dict]], threshold: float = 7.0
|
||||||
) -> List[Tuple[str, Dict]]:
|
) -> List[Tuple[str, Dict]]:
|
||||||
"""Return only chunks whose relevance score exceeds the threshold.
|
|
||||||
|
|
||||||
- Chunks are sent to the LLM in a single batch call.
|
|
||||||
- Expects the LLM to respond with a JSON array of numbers, in the same
|
|
||||||
order as the provided chunks.
|
|
||||||
- If input is empty, returns an empty list.
|
|
||||||
- If the LLM response cannot be parsed or the length mismatches, returns an empty list.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if not chunks:
|
if not chunks:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
prompt = self._build_prompt(question, chunks)
|
prompt = self._build_prompt(question, chunks)
|
||||||
response = self.llm_client.complete(prompt, temperature=0.0)
|
try:
|
||||||
|
response = await self.llm_client.complete(prompt, temperature=0.0)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("RelevanceFilter LLM call failed: %s", exc)
|
||||||
|
return []
|
||||||
|
|
||||||
scores: List[float] = []
|
scores: List[float] = []
|
||||||
try:
|
try:
|
||||||
parsed = json.loads(response)
|
parsed = json.loads(response)
|
||||||
if not isinstance(parsed, list):
|
if not isinstance(parsed, list):
|
||||||
return []
|
return []
|
||||||
# Ensure all values are numeric
|
|
||||||
for v in parsed:
|
for v in parsed:
|
||||||
if isinstance(v, (int, float)):
|
if isinstance(v, (int, float)):
|
||||||
scores.append(float(v))
|
scores.append(float(v))
|
||||||
else:
|
else:
|
||||||
return []
|
return []
|
||||||
except Exception:
|
except Exception as exc:
|
||||||
# Gracefully handle invalid JSON or unexpected formats
|
logger.error("RelevanceFilter JSON parse failed: %s", exc)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
if len(scores) != len(chunks):
|
if len(scores) != len(chunks):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue