diff --git a/backend/app/test/test_phase1_page_aware_chunking.py b/backend/app/test/test_phase1_page_aware_chunking.py new file mode 100644 index 0000000..ff53886 --- /dev/null +++ b/backend/app/test/test_phase1_page_aware_chunking.py @@ -0,0 +1,201 @@ +"""Phase 1.5.4: Page-aware chunking tests. + +Tests for TokenChunkingStrategy.chunk_pages() which creates one chunk per page +with overlap context from adjacent pages. +""" + +import importlib.util +from pathlib import Path +import pytest + +# Dynamically load the chunking module directly from the filesystem to avoid +# import path issues in the test environment. +CHUNKING_PATH = Path(__file__).resolve().parents[1] / "utils" / "chunking.py" +spec = importlib.util.spec_from_file_location("legco_chunking", str(CHUNKING_PATH)) +chunking_module = importlib.util.module_from_spec(spec) # type: ignore +assert spec and spec.loader +spec.loader.exec_module(chunking_module) # type: ignore +TokenChunkingStrategy = chunking_module.TokenChunkingStrategy + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_strategy() -> TokenChunkingStrategy: + return TokenChunkingStrategy(chunk_size=1000, overlap=200) + + +def _long_text(topic: str, min_tokens: int = 300) -> str: + """Generate text with a unique topic marker and enough tokens to exceed min_tokens.""" + # Each word is roughly 1 token; add plenty of margin. + return f"[{topic}] " + " ".join(f"{topic}-word{i}" for i in range(min_tokens)) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_chunk_pages_basic(): + """3 pages → 3 chunks, one per page, each contains main page text.""" + strat = _make_strategy() + pages = [ + (1, _long_text("alpha")), + (2, _long_text("beta")), + (3, _long_text("gamma")), + ] + result = strat.chunk_pages(pages) + + assert len(result) == 3 + # Each result is (chunk_text, page_number) + for i, (chunk_text, page_num) in enumerate(result): + assert isinstance(chunk_text, str) + assert page_num == pages[i][0] + # Main page content must be present in the chunk + assert pages[i][1] in chunk_text + + +def test_chunk_pages_single_page(): + """Single page returns single chunk with no overlap.""" + strat = _make_strategy() + text = _long_text("solo") + pages = [(1, text)] + result = strat.chunk_pages(pages) + + assert len(result) == 1 + chunk_text, page_num = result[0] + assert page_num == 1 + assert text in chunk_text + # No overlap content — chunk should be the original text (no extra newlines from joining) + assert chunk_text.strip() == text.strip() + + +def test_chunk_pages_first_page(): + """First page gets overlap_after from page 2 but no overlap_before.""" + strat = _make_strategy() + pages = [ + (1, _long_text("first")), + (2, _long_text("second")), + (3, _long_text("third")), + ] + result = strat.chunk_pages(pages) + + chunk_text, page_num = result[0] + assert page_num == 1 + # Main content present + assert pages[0][1] in chunk_text + # Overlap from page 2 present + assert "second" in chunk_text + + +def test_chunk_pages_last_page(): + """Last page gets overlap_before from page N-1 but no overlap_after.""" + strat = _make_strategy() + pages = [ + (1, _long_text("first")), + (2, _long_text("second")), + (3, _long_text("third")), + ] + result = strat.chunk_pages(pages) + + chunk_text, page_num = result[-1] + assert page_num == 3 + # Main content present + assert pages[2][1] in chunk_text + # Overlap from page 2 present + assert "second" in chunk_text + + +def test_chunk_pages_empty_input(): + """Empty list returns empty list.""" + strat = _make_strategy() + result = strat.chunk_pages([]) + assert result == [] + + +def test_chunk_pages_overlap_content(): + """Verify overlap content comes from the correct adjacent pages. + + Use distinct, recognizable text per page so we can assert that page N's + chunk includes tokens from pages N-1 and N+1. + """ + strat = _make_strategy() + pages = [ + (1, _long_text("page_one")), + (2, _long_text("page_two")), + (3, _long_text("page_three")), + ] + result = strat.chunk_pages(pages) + + # Page 2 chunk should contain overlap from both neighbors + middle_chunk, middle_page = result[1] + assert middle_page == 2 + assert "page_one" in middle_chunk + assert "page_two" in middle_chunk + assert "page_three" in middle_chunk + + # Page 1 chunk: no page_one overlap before (it IS page 1), but has page_two overlap after + first_chunk, _ = result[0] + assert "page_two" in first_chunk + # Should NOT contain page_three (that's two pages away) + assert "page_three" not in first_chunk + + # Page 3 chunk: has page_two overlap before, but no page_four after + last_chunk, _ = result[2] + assert "page_two" in last_chunk + # Should NOT contain page_one (that's two pages away) + assert "page_one" not in last_chunk + + +def test_chunk_pages_returns_page_numbers(): + """Verify page numbers are correctly preserved in output.""" + strat = _make_strategy() + pages = [ + (5, _long_text("five")), + (10, _long_text("ten")), + (99, _long_text("ninety_nine")), + ] + result = strat.chunk_pages(pages) + + assert len(result) == 3 + output_pages = [pn for _, pn in result] + assert output_pages == [5, 10, 99] + + +def test_chunk_pages_custom_overlap(): + """Test with non-default overlap_tokens value.""" + strat = _make_strategy() + # Use very small overlap to verify it's respected + pages = [ + (1, _long_text("aaa")), + (2, _long_text("bbb")), + ] + result = strat.chunk_pages(pages, overlap_tokens=5) + + assert len(result) == 2 + # Both pages present + assert result[0][1] == 1 + assert result[1][1] == 2 + # Page 1 should still have some overlap from page 2 + assert "bbb" in result[0][0] + # Page 2 should still have some overlap from page 1 + assert "aaa" in result[1][0] + + # Verify with zero overlap + result_zero = strat.chunk_pages(pages, overlap_tokens=0) + # Page 1 chunk should NOT contain page 2 content + assert "bbb" not in result_zero[0][0] + # Page 2 chunk should NOT contain page 1 content + assert "aaa" not in result_zero[1][0] + + +def test_chunk_pages_output_format(): + """Each result element is a (str, int) tuple.""" + strat = _make_strategy() + pages = [(1, "Short text one."), (2, "Short text two.")] + result = strat.chunk_pages(pages) + + for chunk_text, page_num in result: + assert isinstance(chunk_text, str) + assert isinstance(page_num, int) diff --git a/backend/app/utils/chunking.py b/backend/app/utils/chunking.py index 30ce8dc..8118bda 100644 --- a/backend/app/utils/chunking.py +++ b/backend/app/utils/chunking.py @@ -7,7 +7,7 @@ token-based windows. from __future__ import annotations from abc import ABC, abstractmethod -from typing import List +from typing import List, Tuple class ChunkingStrategy(ABC): @@ -71,3 +71,49 @@ class TokenChunkingStrategy(ChunkingStrategy): break return chunks + + def chunk_pages( + self, pages: List[Tuple[int, str]], overlap_tokens: int = 200 + ) -> List[Tuple[str, int]]: + """Chunk page-segmented text with overlap from adjacent pages. + + For each page, creates one chunk containing: + [last overlap_tokens of previous page] + [full current page] + [first overlap_tokens of next page] + + One chunk per page — never splits a page even if oversized. + The page_number metadata always refers to the main page (N), not overlap pages. + + Args: + pages: List of (page_number, page_text) tuples. 1-indexed. + overlap_tokens: Number of tokens to include from adjacent pages. + + Returns: + List of (chunk_text, page_number) tuples. One chunk per page. + """ + if not pages: + return [] + + tokenized: List[List[int]] = [ + self._encoding.encode(text) for _, text in pages + ] + + results: List[Tuple[str, int]] = [] + + for i, (page_num, page_text) in enumerate(pages): + parts: List[str] = [] + + if i > 0 and overlap_tokens > 0: + prev_tokens = tokenized[i - 1] + overlap_before = prev_tokens[-overlap_tokens:] + parts.append(self._encoding.decode(overlap_before)) + + parts.append(page_text) + + if i < len(pages) - 1 and overlap_tokens > 0: + next_tokens = tokenized[i + 1] + overlap_after = next_tokens[:overlap_tokens] + parts.append(self._encoding.decode(overlap_after)) + + results.append(("\n".join(parts), page_num)) + + return results