"""Phase 9 tests: Text generation endpoint integration (Sub-Phase 9.1).

Covers:
- POST /api/v1/test/generate/text with valid request returns 200
- Response includes all pipeline stages (retrieval, filtered, response, timing)
- Invalid profile returns validation error
- Empty question returns validation error
- GET /api/v1/test/results lists saved results
- GET /api/v1/test/results/{id} retrieves specific result
- DELETE /api/v1/test/results/{id} deletes result
"""

import json
from contextlib import closing
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient

from app.routers.test_generate import router


@pytest.fixture
def client(tmp_path, monkeypatch):
    """Yield a TestClient for an app that mounts only the test-generate router.

    Storage locations (results, evaluations, prompts DB, history DB) are
    pointed into ``tmp_path`` via environment variables, and the LLM /
    embedding settings are set to dummy values. The prompts DB is
    initialised and seeded with the default profiles; the history DB is
    initialised. The client is closed and the settings cache cleared on
    teardown.
    """
    results_dir = str(tmp_path / "test_results")
    evals_dir = str(tmp_path / "test_evaluations")
    prompts_path = str(tmp_path / "prompts.db")
    history_path = str(tmp_path / "history.db")

    monkeypatch.setenv("TEST_RESULTS_DIR", results_dir)
    monkeypatch.setenv("TEST_EVALUATIONS_DIR", evals_dir)
    monkeypatch.setenv("PROMPTS_DB_PATH", prompts_path)
    monkeypatch.setenv("HISTORY_DB_PATH", history_path)
    monkeypatch.setenv("LLM_API_KEY", "test-key")
    monkeypatch.setenv("LLM_BASE_URL", "https://test.example.com/v1")
    monkeypatch.setenv("LLM_MODEL_NAME", "test-model")
    monkeypatch.setenv("DP_API_KEY", "test-dp-key")
    monkeypatch.setenv("EMBEDDING_MODEL", "test-embedding")

    from app.core.config import get_settings

    # get_settings is cached; clear it so settings are presumably rebuilt
    # from the environment variables set above.
    get_settings.cache_clear()

    from app.core.sqlite_db import (
        _get_db,
        init_history_db,
        init_prompts_db,
        seed_default_profiles,
    )

    # closing() releases each connection even if init/seed raises.
    with closing(_get_db(prompts_path)) as conn:
        init_prompts_db(conn)
        seed_default_profiles(conn)
    with closing(_get_db(history_path)) as hconn:
        init_history_db(hconn)

    test_app = FastAPI()
    test_app.include_router(router, prefix="/api/v1")

    test_client = TestClient(test_app)
    try:
        yield test_client
    finally:
        test_client.close()
        # Don't leak the temporary settings into later tests.
        get_settings.cache_clear()


@pytest.fixture
def mock_pipeline(monkeypatch):
    """Mock LLM and RAG to avoid real API calls.

    Replaces the decompose, retrieve, filter and generate stages with
    deterministic fakes:
    - decompose returns two sub-questions;
    - retrieve returns one chunk from ``test.pdf`` per sub-question;
    - filter tags every chunk with ``relevance_score`` 8.5;
    - generate returns a canned markdown answer.
    """

    async def _mock_decompose(self, question):
        return (["sub question 1", "sub question 2"], "mocked decompose prompt")

    def _mock_retrieve(self, sub_questions, n_results=10):
        # A fresh metadata dict per sub-question, so the pipeline cannot
        # accidentally share mutations between them.
        return [
            (
                sq,
                [
                    (
                        "chunk text content",
                        {
                            "filename": "test.pdf",
                            "upload_date": "2026-01-01",
                            "content_summary": "test chunk",
                            "chunk_index": 0,
                            "page_number": 1,
                            "document_id": "doc-1",
                        },
                        0.15,
                    )
                ],
            )
            for sq in sub_questions
        ]

    async def _mock_filter(self, sub_questions, chunks_by_subq, threshold=7.0):
        # Keep every chunk and attach a score that is above the default threshold.
        result = [
            (sq, [(text, {**meta, "relevance_score": 8.5}) for text, meta in chunks])
            for sq, chunks in zip(sub_questions, chunks_by_subq)
        ]
        return (result, "mocked filter prompt")

    async def _mock_generate(self, sub_questions, sub_chunks, sub_metadata):
        answer = (
            "## Sub-question 0: sub question 1\n\n"
            "- Test answer with citation [test.pdf, page 1]\n\n"
            "## Sub-question 1: sub question 2\n\n"
            "- More answer content"
        )
        if sub_metadata:
            # Cite at most the first source per sub-question. Slicing (rather
            # than [0]) tolerates a sub-question whose metadata list is empty.
            grouped_sources = [list(meta_list[:1]) for meta_list in sub_metadata]
        else:
            grouped_sources = [[] for _ in sub_questions]
        return (answer, "mocked generate prompt", grouped_sources)

    monkeypatch.setattr(
        "app.services.query_decomposer.QueryDecomposer.decompose", _mock_decompose
    )
    monkeypatch.setattr(
        "app.services.rag.RAGService.retrieve_per_subquestion", _mock_retrieve
    )
    monkeypatch.setattr(
        "app.services.relevance_filter.RelevanceFilter.filter_per_subquestion",
        _mock_filter,
    )
    monkeypatch.setattr(
        "app.services.rag.RAGService.generate_response_per_subquestion",
        _mock_generate,
    )


@pytest.mark.usefixtures("mock_pipeline")
class TestGenerateTextEndpoint:
    """Integration tests for text generation and the saved-results endpoints."""

    def test_valid_request_returns_200(self, client):
        resp = client.post("/api/v1/test/generate/text", json={
            "question": "test question",
            "profile": "A",
            "label": "my label",
        })
        assert resp.status_code == 200
        data = resp.json()
        assert data["input_type"] == "text"
        assert data["profile"] == "A"
        assert data["label"] == "my label"
        assert "result_id" in data

    def test_result_contains_all_stages(self, client):
        resp = client.post("/api/v1/test/generate/text", json={
            "question": "test",
            "profile": "B",
        })
        assert resp.status_code == 200
        data = resp.json()
        assert len(data["extracted_key_questions"]) == 2
        assert data["retrieval"]["total_chunks_retrieved"] > 0
        assert data["filtered"]["total_chunks_filtered"] > 0
        assert len(data["response"]["final_answer"]) > 0
        assert data["timing"]["total_time_ms"] >= 0

    def test_invalid_profile_rejected(self, client):
        resp = client.post("/api/v1/test/generate/text", json={
            "question": "test",
            "profile": "D",
        })
        assert resp.status_code == 422

    def test_empty_question_rejected(self, client):
        resp = client.post("/api/v1/test/generate/text", json={
            "question": "",
            "profile": "A",
        })
        assert resp.status_code == 422

    def test_result_saved_and_retrievable(self, client):
        resp = client.post("/api/v1/test/generate/text", json={
            "question": "save test",
            "profile": "A",
        })
        assert resp.status_code == 200
        result_id = resp.json()["result_id"]

        get_resp = client.get(f"/api/v1/test/results/{result_id}")
        assert get_resp.status_code == 200
        assert get_resp.json()["result_id"] == result_id

    def test_list_results(self, client):
        first = client.post("/api/v1/test/generate/text", json={
            "question": "list test 1",
            "profile": "A",
        })
        second = client.post("/api/v1/test/generate/text", json={
            "question": "list test 2",
            "profile": "B",
        })
        # Fail loudly here if generation broke, rather than on the count below.
        assert first.status_code == 200
        assert second.status_code == 200

        resp = client.get("/api/v1/test/results?limit=10")
        assert resp.status_code == 200
        data = resp.json()
        assert len(data) >= 2

    def test_get_nonexistent_result(self, client):
        resp = client.get("/api/v1/test/results/no-such-id")
        assert resp.status_code == 404

    def test_delete_result(self, client):
        resp = client.post("/api/v1/test/generate/text", json={
            "question": "delete test",
            "profile": "A",
        })
        # Without this, a failed POST surfaces as a confusing KeyError.
        assert resp.status_code == 200
        result_id = resp.json()["result_id"]

        del_resp = client.delete(f"/api/v1/test/results/{result_id}")
        assert del_resp.status_code == 200

        get_resp = client.get(f"/api/v1/test/results/{result_id}")
        assert get_resp.status_code == 404