"""Query decomposer service.

This module provides a lightweight QueryDecomposer that delegates the
decomposition of a natural language question into simplified sub-questions
to an LLM client. Prompt templates are fetched from PromptService when
available; otherwise, a built-in default is used.

Uses LangChain structured output via LLMClient.complete_structured()
for guaranteed valid JSON, with a legacy json.loads() fallback path.
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import re
|
|
from typing import TYPE_CHECKING, List, Tuple
|
|
|
|
if TYPE_CHECKING:
|
|
from app.services.prompt_service import PromptService
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Default decomposition prompt, used when no PromptService is supplied
# (e.g. tests or standalone use). "{question}" is a literal placeholder that
# is filled in with str.replace, so no other braces need escaping.
_BUILTIN_DECOMPOSE_TEMPLATE = (
    "Given this question: '{question}'\n"
    "\n"
    "Break it down into 2-5 simplified sub-questions that would help search "
    "for relevant information. Each sub-question should be short and focused "
    "on one aspect. Return as a JSON array of strings."
)
|
|
|
|
|
|
def _extract_json_from_markdown(response: str) -> str:
    """Return the contents of the first Markdown code fence in *response*.

    LLMs frequently wrap JSON answers in a fenced block such as
    ```` ```json ... ``` ````. This helper strips that wrapper so the payload
    can be handed to ``json.loads``.

    Args:
        response: Raw LLM output. Non-string values are converted with
            ``str()`` and returned without further processing.

    Returns:
        The stripped text inside the first fenced block if one is present,
        otherwise the stripped input.
    """
    if not isinstance(response, str):
        return str(response)

    # The language tag is matched case-insensitively: models sometimes emit
    # ```JSON or ```Json, and a case-sensitive match would leave the tag
    # inside the captured payload and break the subsequent json.loads().
    # DOTALL lets the payload span multiple lines.
    fence = re.search(
        r"```(?:json)?\s*\n?(.*?)\n?```",
        response,
        re.DOTALL | re.IGNORECASE,
    )
    if fence:
        return fence.group(1).strip()
    return response.strip()
|
|
|
|
|
|
def _parse_legacy_json(response: str) -> List[str]:
    """Parse a raw LLM reply into a list of sub-question strings.

    Any Markdown code fence is removed first, then the remainder is decoded
    as JSON. Only a non-empty JSON array is accepted; its elements are
    returned as strings, with non-string elements converted via ``str()``.

    Args:
        response: Raw text returned by the LLM.

    Returns:
        The list of sub-questions, or an empty list when the text is not
        valid JSON, is not a JSON array, or is an empty array.
    """
    payload = _extract_json_from_markdown(response)
    preview = payload[:300] if payload else "(empty)"
    logger.info("Legacy JSON parse: extracted text (first 300 chars): %s", preview)

    try:
        parsed = json.loads(payload)
    except json.JSONDecodeError:
        logger.warning("Legacy JSON parse: json.loads failed on extracted text")
        return []

    # Anything other than a non-empty JSON array counts as "no result".
    if not isinstance(parsed, list) or not parsed:
        return []

    # Coerce only when needed; an all-string array is returned unchanged.
    if any(not isinstance(entry, str) for entry in parsed):
        return [str(entry) for entry in parsed]
    return parsed
|
|
|
|
|
|
class QueryDecomposer:
    """Decompose a natural language question into simplified sub-questions.

    The class expects an LLM client that exposes
    ``async complete(prompt, step_name=...)`` returning the raw reply and
    ``async complete_structured(prompt=..., pydantic_model=..., step_name=...)``
    returning an object with a ``questions`` attribute, plus an optional
    ``PromptService`` for templated prompts. When ``prompt_service`` is
    ``None`` or returns an empty template, a built-in default template is used.
    """

    def __init__(self, llm_client, prompt_service: "PromptService | None" = None) -> None:
        """Store the LLM client and the optional prompt service."""
        self.llm_client = llm_client
        self._prompt_service = prompt_service

    def _render_prompt(self, question: str) -> str:
        """Build the decomposition prompt for *question*.

        The template comes from the prompt service when one is configured.
        If the service returns nothing usable (``None`` or an empty string),
        the built-in template is used instead of failing on ``None.replace``
        or sending a blank prompt to the LLM.
        """
        template = None
        if self._prompt_service is not None:
            template = self._prompt_service.get_prompt_template("decompose")
            logger.info(
                "Decompose prompt template (first 200 chars): %s",
                template[:200] if template else "(empty)",
            )
            if not template:
                logger.warning(
                    "PromptService returned an empty 'decompose' template; "
                    "using the built-in default."
                )

        if not template:
            template = _BUILTIN_DECOMPOSE_TEMPLATE

        # Plain substitution (not str.format) so other braces in the
        # template, e.g. JSON examples, need no escaping.
        return template.replace("{question}", question)

    async def decompose(self, question: str) -> Tuple[List[str], str]:
        """Return a list of sub-questions and the prompt used for decomposition.

        Uses structured output via ``complete_structured`` as the primary
        path. If that raises, falls back to a plain completion whose text is
        parsed with the legacy JSON parser.

        Args:
            question: The natural language question to decompose.

        Returns:
            A tuple of (sub-questions, prompt). sub-questions is a list of
            strings; prompt is the rendered prompt string. For a ``None`` or
            blank question the result is ``([], "")`` and no LLM call is
            made. If both structured and legacy paths fail, sub-questions
            is an empty list.
        """
        if question is None or not question.strip():
            return [], ""

        prompt = self._render_prompt(question)

        # NOTE(review): imported at call time rather than at module top -
        # presumably to avoid an import cycle; confirm before hoisting.
        from app.models.decompose import SubQuestions

        try:
            result = await self.llm_client.complete_structured(
                prompt=prompt,
                pydantic_model=SubQuestions,
                step_name="QueryDecomposer",
            )
            return result.questions, prompt
        except Exception as exc:
            # Deliberate best-effort: any structured-output failure routes
            # to the legacy path below instead of propagating.
            logger.warning(
                "Structured decomposition failed: %s. Falling back to legacy parse.",
                exc,
            )

        try:
            response = await self.llm_client.complete(prompt, step_name="QueryDecomposer")
        except Exception as exc:
            logger.warning("Legacy LLM decomposition also failed: %s", exc)
            return [], prompt

        if not isinstance(response, str):
            response = str(response)

        questions = _parse_legacy_json(response)

        if not questions:
            logger.warning(
                "Legacy decompose JSON parse failed. Raw response (first 500 chars): %s",
                response[:500],
            )
        else:
            logger.info(
                "Legacy decompose succeeded after structured output failure. "
                "Consider investigating why structured output failed."
            )

        return questions, prompt
|