legco_ai_assistant/backend/app/test/test_phase9_generate_text.py

174 lines
6.6 KiB
Python

"""Phase 9 tests: Text generation endpoint integration (Sub-Phase 9.1).
Covers:
- POST /api/v1/test/generate/text with valid request returns 200
- Response includes all pipeline stages (retrieval, filtered, response, timing)
- Invalid profile returns validation error
- Empty question returns validation error
- GET /api/v1/test/results lists saved results
- GET /api/v1/test/results/{id} retrieves specific result
- DELETE /api/v1/test/results/{id} deletes result
"""
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from app.routers.test_generate import router
@pytest.fixture
def client(tmp_path, monkeypatch):
    """Yield a ``TestClient`` for the test-generate router, isolated in ``tmp_path``.

    Setup:
      * points the result/evaluation directories and both SQLite database
        paths at locations under ``tmp_path`` via environment variables;
      * sets placeholder LLM, DP and embedding settings;
      * clears the ``get_settings`` cache (presumably so the patched
        environment is picked up -- confirm against ``app.core.config``);
      * initialises the prompts database (including default profiles) and
        the history database;
      * mounts ``router`` under ``/api/v1`` on a fresh ``FastAPI`` app.

    Teardown clears the ``get_settings`` cache again.
    """
    # Function-scope import, matching the other local imports in this fixture.
    from contextlib import closing

    prompts_path = str(tmp_path / "prompts.db")
    history_path = str(tmp_path / "history.db")

    # Insertion order matches the order the variables were originally set in.
    env_overrides = {
        "TEST_RESULTS_DIR": str(tmp_path / "test_results"),
        "TEST_EVALUATIONS_DIR": str(tmp_path / "test_evaluations"),
        "PROMPTS_DB_PATH": prompts_path,
        "HISTORY_DB_PATH": history_path,
        "LLM_API_KEY": "test-key",
        "LLM_BASE_URL": "https://test.example.com/v1",
        "LLM_MODEL_NAME": "test-model",
        "DP_API_KEY": "test-dp-key",
        "EMBEDDING_MODEL": "test-embedding",
    }
    for var_name, var_value in env_overrides.items():
        monkeypatch.setenv(var_name, var_value)

    # Imported only after the environment has been patched, as in the original.
    from app.core.config import get_settings

    get_settings.cache_clear()

    from app.core.sqlite_db import (
        _get_db,
        init_history_db,
        init_prompts_db,
        seed_default_profiles,
    )

    # closing() guarantees each connection is closed even when schema
    # initialisation or seeding raises, instead of leaking the handle.
    with closing(_get_db(prompts_path)) as conn:
        init_prompts_db(conn)
        seed_default_profiles(conn)
    with closing(_get_db(history_path)) as hconn:
        init_history_db(hconn)

    test_app = FastAPI()
    test_app.include_router(router, prefix="/api/v1")
    yield TestClient(test_app)

    # Teardown: drop the settings cached while the patched environment was live.
    get_settings.cache_clear()
@pytest.fixture
def mock_pipeline(monkeypatch):
    """Swap the decompose, retrieve, filter and generate steps for canned stubs.

    Each stub is patched onto its service class by dotted path, so no real
    LLM or retrieval call is made while the fixture is active.
    """

    async def _fake_decompose(self, question):
        # Always yields the same two sub-questions, whatever the question is.
        sub_questions = ["sub question 1", "sub question 2"]
        return (sub_questions, "mocked decompose prompt")

    def _fake_retrieve(self, sub_questions, n_results=10):
        # One fixed chunk per sub-question; n_results is accepted but unused.
        retrieved = []
        for sub_question in sub_questions:
            # A fresh metadata dict per sub-question, as in the original.
            metadata = {
                "filename": "test.pdf",
                "upload_date": "2026-01-01",
                "content_summary": "test chunk",
                "chunk_index": 0,
                "page_number": 1,
                "document_id": "doc-1",
            }
            chunk = ("chunk text content", metadata, 0.15)
            retrieved.append((sub_question, [chunk]))
        return retrieved

    async def _fake_filter(self, sub_questions, chunks_by_subq, threshold=7.0):
        # Keeps every chunk and stamps a fixed score; threshold is unused.
        scored = []
        for sub_question, chunks in zip(sub_questions, chunks_by_subq):
            kept = []
            for text, meta in chunks:
                kept.append((text, {**meta, "relevance_score": 8.5}))
            scored.append((sub_question, kept))
        return (scored, "mocked filter prompt")

    async def _fake_generate(self, sub_questions, sub_chunks, sub_metadata):
        answer = "## Sub-question 0: sub question 1\n\n- Test answer with citation [test.pdf, page 1]\n\n## Sub-question 1: sub question 2\n\n- More answer content"
        if sub_metadata:
            # First metadata entry of each sub-question becomes its source group.
            grouped_sources = [[meta_list[0]] for meta_list in sub_metadata]
        else:
            grouped_sources = [[] for _ in sub_questions]
        return (answer, "mocked generate prompt", grouped_sources)

    # Applied in the same order as the original individual setattr calls.
    replacements = (
        ("app.services.query_decomposer.QueryDecomposer.decompose", _fake_decompose),
        ("app.services.rag.RAGService.retrieve_per_subquestion", _fake_retrieve),
        ("app.services.relevance_filter.RelevanceFilter.filter_per_subquestion", _fake_filter),
        ("app.services.rag.RAGService.generate_response_per_subquestion", _fake_generate),
    )
    for dotted_target, stub in replacements:
        monkeypatch.setattr(dotted_target, stub)
@pytest.mark.usefixtures("mock_pipeline")
class TestGenerateTextEndpoint:
    """Tests for the text-generation endpoint and the stored-result routes.

    Every test runs with the ``mock_pipeline`` fixture active and receives the
    ``client`` fixture as its HTTP client.
    """

    GENERATE_URL = "/api/v1/test/generate/text"

    def _generate(self, client, question, profile, **extra):
        """POST a generation request and return the raw response.

        ``extra`` carries optional body fields such as ``label``.
        """
        payload = {"question": question, "profile": profile, **extra}
        return client.post(self.GENERATE_URL, json=payload)

    def test_valid_request_returns_200(self, client):
        """A well-formed request succeeds and echoes profile and label."""
        resp = self._generate(client, "test question", "A", label="my label")
        assert resp.status_code == 200
        data = resp.json()
        assert data["input_type"] == "text"
        assert data["profile"] == "A"
        assert data["label"] == "my label"
        assert "result_id" in data

    def test_result_contains_all_stages(self, client):
        """The response carries key questions, retrieval, filtered, response and timing."""
        resp = self._generate(client, "test", "B")
        assert resp.status_code == 200
        data = resp.json()
        assert len(data["extracted_key_questions"]) == 2
        assert data["retrieval"]["total_chunks_retrieved"] > 0
        assert data["filtered"]["total_chunks_filtered"] > 0
        assert len(data["response"]["final_answer"]) > 0
        assert data["timing"]["total_time_ms"] >= 0

    def test_invalid_profile_rejected(self, client):
        """Profile ``"D"`` is rejected with HTTP 422."""
        resp = self._generate(client, "test", "D")
        assert resp.status_code == 422

    def test_empty_question_rejected(self, client):
        """An empty question string is rejected with HTTP 422."""
        resp = self._generate(client, "", "A")
        assert resp.status_code == 422

    def test_result_saved_and_retrievable(self, client):
        """A generated result can be fetched back by its ``result_id``."""
        resp = self._generate(client, "save test", "A")
        assert resp.status_code == 200
        result_id = resp.json()["result_id"]

        get_resp = client.get(f"/api/v1/test/results/{result_id}")
        assert get_resp.status_code == 200
        assert get_resp.json()["result_id"] == result_id

    def test_list_results(self, client):
        """Listing returns at least the two results created by this test."""
        # Check the setup requests so a failed POST is reported as such,
        # not as a confusing length mismatch further down.
        first = self._generate(client, "list test 1", "A")
        assert first.status_code == 200
        second = self._generate(client, "list test 2", "B")
        assert second.status_code == 200

        resp = client.get("/api/v1/test/results?limit=10")
        assert resp.status_code == 200
        data = resp.json()
        assert len(data) >= 2

    def test_get_nonexistent_result(self, client):
        """Fetching an unknown result id yields HTTP 404."""
        resp = client.get("/api/v1/test/results/no-such-id")
        assert resp.status_code == 404

    def test_delete_result(self, client):
        """Deleting a result succeeds and the result then returns 404."""
        resp = self._generate(client, "delete test", "A")
        # Without this check a failed POST would surface as a KeyError on
        # "result_id" instead of a clear status-code failure.
        assert resp.status_code == 200
        result_id = resp.json()["result_id"]

        del_resp = client.delete(f"/api/v1/test/results/{result_id}")
        assert del_resp.status_code == 200

        get_resp = client.get(f"/api/v1/test/results/{result_id}")
        assert get_resp.status_code == 404