legco_ai_assistant/backend/app/utils/qa_chunking.py

362 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Q&A-pair chunking utilities for Package 8.
Provides section detection (LLM + regex), text preprocessing,
and chunk building for LegCo documents with Q&A structure.
"""
from __future__ import annotations

import json
import logging
import re
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
@dataclass
class Section:
    """One structural unit of a LegCo document.

    ``type`` is one of "qa", "narrative", "speaking_notes", "table", "toc" or
    "heading_only". Q&A sections use ``qa_id``/``question``/``answer``; the
    remaining kinds keep their text in ``content``. Page numbers default to 1.
    """

    type: str
    heading: str = field(default="")
    qa_id: Optional[str] = field(default=None)
    question: Optional[str] = field(default=None)
    answer: Optional[str] = field(default=None)
    content: str = field(default="")
    start_page: int = field(default=1)
    end_page: int = field(default=1)
    has_table: bool = field(default=False)
    parent_topic: str = field(default="")  # broader topic, e.g. "排水系統"
_FOOTER_RE = re.compile(r"^[A-Z]-\d+\s*$", re.MULTILINE)
_FOOTER_DATE_RE = re.compile(r"^[A-Z]-\d+\s*\n\d{4}-\d{2}-\d{2}$", re.MULTILINE)
_HEADER_LETTER_RE = re.compile(r"^(\([A-Z]\))\s*$", re.MULTILINE)
_FULLWIDTH_COLON_RE = re.compile("[]")
def preprocess_text(pages: List[Tuple[int, str]]) -> str:
"""Concatenate pages, strip footers/headers, normalize colons, insert [PAGE_BREAK: N] markers."""
parts: List[str] = []
for idx, (page_num, page_text) in enumerate(pages):
text = _FOOTER_DATE_RE.sub("", page_text)
text = _FOOTER_RE.sub("", text)
if idx > 0:
text = _HEADER_LETTER_RE.sub("", text)
text = _FULLWIDTH_COLON_RE.sub(":", text)
parts.append(f"[PAGE_BREAK: {page_num}]\n{text}")
return "\n".join(parts)
# Prompt for LLM section classification. Literal JSON braces are doubled for
# str.format; {document_text} is the only placeholder. Non-qa sections must
# carry their text in "content" because chunk building skips sections whose
# content is empty.
_STRUCTURE_PROMPT_TEMPLATE = """You are analyzing a Hong Kong Legislative Council document.
The text has page markers like [PAGE_BREAK: N] showing where pages begin.
For each distinct section in this document, identify:
1. The section type:
- "qa": a question-and-answer pair (問/答 or Q1/Q2 format)
- "narrative": policy text, explanatory paragraphs, section content with bullets
- "speaking_notes": briefing points (發言要點) with bullet markers
- "table": standalone data tables (not embedded in answers)
- "toc": table of contents
- "heading_only": a section heading with no following content
2. For "qa" sections:
- The question text (exact)
- The answer text (exact, including tables, bullet lists, and [內部參考] content)
- The question ID if present (e.g. "A1", "Q3")
- The start page and end page
3. For "narrative", "speaking_notes" and "table" sections:
- The full section text (exact) as "content"
4. For all sections:
- The section heading (e.g. "(A) 排水系統", "(1) 住戶的安置補償")
- The start page and end page
- Whether the section contains tables
- The parent topic the section belongs to, if any (e.g. "排水系統")
Return only a JSON object in this format, with no other text:
{{
"sections": [
{{
"type": "qa",
"heading": "(A) 排水系統",
"qa_id": "A1",
"question": "...",
"answer": "...",
"start_page": 2,
"end_page": 3,
"has_table": true,
"parent_topic": "排水系統"
}},
{{
"type": "narrative",
"heading": "(1) 住戶的安置補償",
"content": "...",
"start_page": 2,
"end_page": 5,
"has_table": false
}}
]
}}
DOCUMENT TEXT:
{document_text}"""


def build_structure_detection_prompt(text: str) -> str:
    """Return the section-classification prompt with ``text`` appended.

    ``text`` is normally the output of ``preprocess_text`` (with
    ``[PAGE_BREAK: N]`` markers). Braces inside ``text`` are safe: str.format
    only interprets braces in the template, not in substituted values.
    """
    return _STRUCTURE_PROMPT_TEMPLATE.format(document_text=text)
_MARKDOWN_FENCE_RE = re.compile(r"```(?:json)?\s*\n?(.*?)\n?```", re.DOTALL)
def parse_llm_structure_response(response_text: str) -> List[Section]:
"""Parse the JSON returned by the LLM. Handle markdown code fences.
Raises ValueError if response is not valid JSON.
"""
cleaned = response_text.strip()
fence_match = _MARKDOWN_FENCE_RE.search(cleaned)
if fence_match:
cleaned = fence_match.group(1).strip()
try:
data = json.loads(cleaned)
except json.JSONDecodeError as exc:
raise ValueError(f"Invalid JSON from LLM structure detection: {exc}") from exc
sections_raw = data.get("sections", [])
sections: List[Section] = []
for raw in sections_raw:
sections.append(Section(
type=raw.get("type", "narrative"),
heading=raw.get("heading", ""),
qa_id=raw.get("qa_id"),
question=raw.get("question"),
answer=raw.get("answer"),
content=raw.get("content", ""),
start_page=raw.get("start_page", 1),
end_page=raw.get("end_page", 1),
has_table=raw.get("has_table", False),
parent_topic=raw.get("parent_topic", ""),
))
return sections
_CN_QA_RE = re.compile(
r"\s*([A-Z]\d+)\s*[:]\s*(.*?)\s*"
r"(?:\n\s*答\s*\1\s*[:]\s*(.*?)\s*)"
r"(?=\n\s*(?:問\s*[A-Z]\d+\s*[:]|$))",
re.DOTALL,
)
def split_chinese_qa(text: str) -> List[Section]:
"""Regex fast-pass for 問/答 format. Returns empty list if no matches found."""
sections: List[Section] = []
for m in _CN_QA_RE.finditer(text):
qa_id = m.group(1)
question = m.group(2).strip()
answer = (m.group(3) or "").strip()
sections.append(Section(
type="qa",
qa_id=qa_id,
question=question,
answer=answer,
))
return sections
_EN_QA_RE = re.compile(
r"^(Q\d+)\s+(.*?)\s*$\n((?:(?!^Q\d+).+(?:\n|$))*)",
re.MULTILINE,
)
def split_english_qa(text: str) -> List[Section]:
"""Regex fast-pass for Q-number format. Returns empty list if no matches found."""
sections: List[Section] = []
for m in _EN_QA_RE.finditer(text):
qa_id = m.group(1)
question = m.group(2).strip()
answer = m.group(3).strip()
sections.append(Section(
type="qa",
qa_id=qa_id,
question=question,
answer=answer,
))
return sections
def _estimate_tokens(text: str) -> int:
"""Rough token estimate: ~1.3 tokens per CJK char, ~1 token per 4 chars for Latin."""
cjk_count = 0
latin_len = 0
for ch in text:
if "\u4e00" <= ch <= "\u9fff":
cjk_count += 1
else:
latin_len += 1
return int(cjk_count * 1.3 + latin_len / 4)
def _split_oversized_qa(
question: str, answer: str, page: int, heading: str,
qa_id: Optional[str], question_index: int, has_table: bool,
parent_topic: str, start_page: int, end_page: int,
max_tokens: int,
) -> List[Tuple[str, int, dict]]:
"""Recursively split an oversized Q&A answer with question prepended to each sub-chunk."""
# Try paragraph boundaries first
parts = answer.split("\n\n")
if len(parts) <= 1:
parts = answer.split("\n")
# Group parts into sub-chunks that fit within max_tokens
sub_chunks: List[str] = []
current = ""
for part in parts:
candidate = (current + "\n\n" + part) if current else part
if _estimate_tokens(f"Question: {question}\n\nAnswer (part 1/N): {candidate}") > max_tokens and current:
sub_chunks.append(current)
current = part
else:
current = candidate
if current:
sub_chunks.append(current)
total = len(sub_chunks)
results: List[Tuple[str, int, dict]] = []
for i, sub in enumerate(sub_chunks):
chunk_text = f"Question: {question}\n\nAnswer (part {i + 1}/{total}): {sub}"
meta = {
"strategy_type": "question",
"section_type": "qa",
"question_index": question_index,
"question_id": qa_id,
"question_text": question,
"section_heading": heading,
"answer_contains_table": has_table,
"source_page_range": [start_page, end_page],
"parent_topic": parent_topic,
}
results.append((chunk_text, page, meta))
return results
def build_chunks_from_sections(
sections: List[Section], max_tokens: int = 3000,
) -> List[Tuple[str, int, dict]]:
"""Build chunk texts + page refs + metadata from sections.
Returns List[(chunk_text, page_number, metadata_dict)].
"""
chunks: List[Tuple[str, int, dict]] = []
qa_index = 0
for section in sections:
if section.type in ("toc", "heading_only"):
continue
if section.type == "qa":
question_text = section.question or ""
answer_text = section.answer or ""
chunk_text = f"Question: {question_text}\n\nAnswer: {answer_text}"
if section.heading:
chunk_text = f"[{section.heading}]\n{chunk_text}"
page = section.start_page
meta: Dict = {
"strategy_type": "question",
"section_type": "qa",
"question_index": qa_index,
"question_id": section.qa_id,
"question_text": question_text,
"section_heading": section.heading,
"answer_contains_table": section.has_table,
"source_page_range": [section.start_page, section.end_page],
"parent_topic": section.parent_topic,
}
if _estimate_tokens(chunk_text) > max_tokens:
chunks.extend(_split_oversized_qa(
question=question_text,
answer=answer_text,
page=page,
heading=section.heading,
qa_id=section.qa_id,
question_index=qa_index,
has_table=section.has_table,
parent_topic=section.parent_topic,
start_page=section.start_page,
end_page=section.end_page,
max_tokens=max_tokens,
))
else:
chunks.append((chunk_text, page, meta))
qa_index += 1
elif section.type == "narrative":
content = section.content
if not content.strip():
continue
prefix = f"[{section.heading}]\n" if section.heading else ""
chunk_text = f"{prefix}{content}"
meta = {
"strategy_type": "question",
"section_type": "narrative",
"section_heading": section.heading,
"source_page_range": [section.start_page, section.end_page],
}
if _estimate_tokens(chunk_text) <= max_tokens:
chunks.append((chunk_text, section.start_page, meta))
else:
paragraphs = content.split("\n\n")
current = ""
for para in paragraphs:
candidate = (current + "\n\n" + para) if current else para
full = f"{prefix}{candidate}"
if _estimate_tokens(full) > max_tokens and current:
chunks.append((f"{prefix}{current}", section.start_page, dict(meta)))
current = para
else:
current = candidate
if current:
chunks.append((f"{prefix}{current}", section.start_page, dict(meta)))
elif section.type == "speaking_notes":
content = section.content
if not content.strip():
continue
bullets = re.split(r"(?=⚫)", content)
bullets = [b.strip() for b in bullets if b.strip()]
if not bullets:
bullets = [content]
prefix = f"[{section.heading}]\n" if section.heading else ""
for bullet in bullets:
chunk_text = f"{prefix}{bullet}"
meta = {
"strategy_type": "question",
"section_type": "speaking_notes",
"section_heading": section.heading,
"source_page_range": [section.start_page, section.end_page],
}
chunks.append((chunk_text, section.start_page, meta))
elif section.type == "table":
content = section.content
if not content.strip():
continue
chunk_text = f"[{section.heading}]\n{content}" if section.heading else content
meta = {
"strategy_type": "question",
"section_type": "table",
"section_heading": section.heading,
"answer_contains_table": True,
"source_page_range": [section.start_page, section.end_page],
}
chunks.append((chunk_text, section.start_page, meta))
return chunks