148 lines
4.8 KiB
Python
148 lines
4.8 KiB
Python
"""Table extraction utilities for Package 8.
|
|
|
|
Provides vision-based and text-based table detection and markdown conversion
|
|
for LegCo documents. Uses the existing LLM model (vision-capable) for
|
|
table-to-markdown conversion.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
import os
|
|
from pathlib import Path
|
|
from typing import List, Optional
|
|
|
|
logger = logging.getLogger(__name__)

# On-disk cache for vision results. It lives under ".cache/vision_tables" in the
# directory two levels above the one containing this module (three `.parent`
# hops from the resolved file path).
_CACHE_DIR = Path(__file__).resolve().parent.parent.parent.joinpath(".cache", "vision_tables")
|
|
|
|
|
|
async def extract_tables_vision(page_images: List[str], llm_client) -> List[str]:
    """Ask the vision-capable LLM to render each page image as Markdown.

    Args:
        page_images: Base64-encoded PNG payloads, one per page. Each is sent
            as a ``data:image/png;base64,...`` image URL alongside a fixed
            instruction prompt.
        llm_client: Wrapper exposing ``.model`` and an async
            ``._client.chat.completions.create`` (presumably an
            OpenAI-compatible client -- confirm against the wrapper).

    Returns:
        The stripped, non-empty replies in page order. A page whose call
        raises is logged at WARNING and skipped, and blank replies are
        dropped. Result positions therefore do not necessarily line up with
        positions in ``page_images``.
    """
    instructions = (
        "Convert this page to Markdown. For any tables:\n"
        "- Use proper markdown table syntax with |---|---| alignment\n"
        "- Preserve all column headers and row labels\n"
        "- Do not modify or translate the content\n"
        "- If a table spans multiple pages, note it"
    )
    markdown_pages: List[str] = []

    for page_no, encoded_png in enumerate(page_images):
        image_part = {
            "type": "image_url",
            "image_url": {"url": f"data:image/png;base64,{encoded_png}"},
        }
        user_turn = {
            "role": "user",
            "content": [{"type": "text", "text": instructions}, image_part],
        }
        try:
            # The client attribute chain is resolved inside the try so that a
            # misconfigured client is reported per page rather than aborting.
            reply = await llm_client._client.chat.completions.create(
                model=llm_client.model,
                messages=[user_turn],
                temperature=0.1,
            )
            text = (reply.choices[0].message.content or "").strip()
            if text:
                markdown_pages.append(text)
        except Exception:
            logger.warning("Vision table extraction failed for page image %d", page_no, exc_info=True)

    return markdown_pages
|
|
|
|
|
|
# Per-line regexes used by extract_tables_text() to spot table-like lines.
# Each is applied with re.search() to a single line of text that has already
# been split on "\n", so no pattern may rely on matching a newline.
_TABLE_HEURISTIC_RE = [
    # A markdown separator cell such as "|---|" or "| :-- |".
    r"(?:\|[\s\-:]+\|)",
    # An ASCII-art grid border such as "+----+" or "+====+".
    r"(?:\+[-=]+\+)",
    # Whitespace-aligned columns: three or more cells separated by runs of two
    # or more whitespace characters (e.g. "Item  Qty  Price"). An earlier form
    # ended in "\n", which never occurs in a split line, so it could not fire.
    r"(?:(?:\S+\s{2,}){2,}\S+)",
]

# Prompt for converting one detected raw-text table region to markdown;
# filled in with str.format(table_text=...).
_TABLE_REGION_PROMPT = (
    "Convert this raw table text extracted from a PDF into a markdown table.\n"
    "Preserve all data exactly. Detect column boundaries and alignment.\n\n"
    "{table_text}"
)
|
|
|
|
|
|
def _collect_table_regions(text: str) -> List[str]:
    """Split *text* into runs of lines that look like a table.

    A run opens at any line matching one of ``_TABLE_HEURISTIC_RE`` and then
    absorbs every following non-blank line (matching or not); a blank line
    closes it. Lines outside a run that match nothing are skipped. Only runs
    of at least three lines are kept, each re-joined with newlines.
    """
    import re

    regions: List[str] = []
    run: List[str] = []
    for row in text.split("\n"):
        tabular = any(re.search(pattern, row) for pattern in _TABLE_HEURISTIC_RE)
        # While a run is open (``run`` non-empty) it keeps swallowing non-blank
        # lines, e.g. table body rows that carry no separator syntax.
        if tabular or (run and row.strip()):
            run.append(row)
            continue
        if len(run) >= 3:
            regions.append("\n".join(run))
        run = []
    if len(run) >= 3:
        regions.append("\n".join(run))
    return regions


async def extract_tables_text(text: str, llm_client) -> List[str]:
    """Find table-like regions in plain *text* and have the LLM rewrite each as markdown.

    Args:
        text: Raw page text (e.g. text extracted from a PDF).
        llm_client: Object with an async ``complete(prompt, temperature=...,
            step_name=...)`` method whose result is used as a string.

    Returns:
        Stripped, non-empty LLM replies in region order. A region whose call
        raises is logged at WARNING and skipped.
    """
    markdown_tables: List[str] = []
    for region_text in _collect_table_regions(text):
        request = _TABLE_REGION_PROMPT.format(table_text=region_text)
        try:
            reply = await llm_client.complete(request, temperature=0.1, step_name="TableExtraction")
            cleaned = reply.strip()
            if cleaned:
                markdown_tables.append(cleaned)
        except Exception:
            logger.warning("Text-based table extraction failed", exc_info=True)
    return markdown_tables
|
|
|
|
|
|
def inject_tables_into_answer(answer: str, tables_md: List[str]) -> str:
    """Swap raw table header text in *answer* for the corresponding markdown tables.

    For each entry of *tables_md*, its first non-blank line (stripped) serves
    as the anchor. Every occurrence of that anchor in the answer is replaced
    by the whole markdown table. Tables whose anchor is absent, and tables
    that are empty or whitespace-only, are skipped.

    Args:
        answer: Text to rewrite.
        tables_md: Markdown table strings.

    Returns:
        The rewritten answer (returned unchanged if *tables_md* is empty).
    """
    if not tables_md:
        return answer

    result = answer
    for table_md in tables_md:
        # An empty anchor must never be used: "" is contained in every string,
        # and str.replace("", x) inserts x between every character.
        anchor = next((line.strip() for line in table_md.split("\n") if line.strip()), "")
        if not anchor or anchor not in result:
            continue
        # NOTE(review): replacements are applied one after another, so a later
        # anchor can also match text inserted by an earlier table. Only the
        # anchor text is replaced; any raw table body after it is left in place.
        result = result.replace(anchor, table_md)
    return result
|
|
|
|
|
|
def cache_vision_result(page_hash: str) -> Optional[str]:
    """Look up a cached vision result by page hash.

    Despite its name this only reads the cache; writes go through
    ``save_vision_result``.

    Args:
        page_hash: Cache key; the entry is ``<_CACHE_DIR>/<page_hash>.md``.

    Returns:
        The cached markdown (decoded as UTF-8), or None on a cache miss. A
        miss includes a missing cache directory.
    """
    cache_file = _CACHE_DIR / f"{page_hash}.md"
    # EAFP instead of exists()+read: a separate existence check races with a
    # concurrent deletion of the entry and would turn a miss into a crash.
    # NotADirectoryError covers a path component being a regular file, which
    # exists() also reported as "missing".
    try:
        return cache_file.read_text(encoding="utf-8")
    except (FileNotFoundError, NotADirectoryError):
        return None
|
|
|
|
|
|
def save_vision_result(page_hash: str, markdown: str) -> None:
    """Store *markdown* in the disk cache as ``<_CACHE_DIR>/<page_hash>.md``.

    Creates the cache directory if needed and overwrites any existing entry.
    The text is first written to a hidden temporary file next to the target
    and then moved into place with ``Path.replace``. A concurrent reader
    therefore sees either the old entry or the complete new one, never a
    truncated file. A failed write leaves no partial entry behind.

    Raises:
        OSError: If the directory cannot be created or the write/rename fails.
    """
    import uuid

    _CACHE_DIR.mkdir(parents=True, exist_ok=True)
    cache_file = _CACHE_DIR / f"{page_hash}.md"
    # Unique per call, so concurrent writers of the same hash don't clobber
    # each other's temp file; the leading "." keeps it out of "<hash>.md" lookups.
    tmp_file = cache_file.with_name(f".{cache_file.name}.{uuid.uuid4().hex}.tmp")
    try:
        tmp_file.write_text(markdown, encoding="utf-8")
        tmp_file.replace(cache_file)
    except BaseException:
        # Best-effort cleanup; never mask the original error.
        try:
            tmp_file.unlink()
        except OSError:
            pass
        raise
|
|
|
|
|
|
def compute_page_hash(page_image_b64: str) -> str:
    """Derive a short cache key from a base64-encoded page image.

    Returns the first 16 hex characters (64 bits) of the SHA-256 digest of
    the UTF-8 encoding of *page_image_b64*.
    """
    hasher = hashlib.sha256()
    hasher.update(page_image_b64.encode("utf-8"))
    full_hex = hasher.hexdigest()
    return full_hex[:16]
|