"""Phase 9 tests: text generation endpoint integration (Sub-Phase 9.1).

Covers:
- POST /api/v1/test/generate/text with a valid request returns 200
- The response includes every pipeline stage (retrieval, filtered, response, timing)
- An invalid profile returns a validation error (422)
- An empty question returns a validation error (422)
- GET /api/v1/test/results lists saved results
- GET /api/v1/test/results/{id} retrieves a specific result (404 for an unknown id)
- DELETE /api/v1/test/results/{id} deletes a result
"""

import json
from contextlib import closing
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient

from app.routers.test_generate import router


@pytest.fixture
def client(tmp_path, monkeypatch):
    """Yield a TestClient for a minimal app that mounts only the test-generate router.

    All storage locations (the result and evaluation directories and the
    prompts and history SQLite databases) are redirected into this test's
    ``tmp_path``. The LLM, DP and embedding settings are set to dummy values
    through environment variables.

    The prompts database is initialised and seeded with the default profiles.
    The history database is only initialised.
    """
    results_dir = str(tmp_path / "test_results")
    evals_dir = str(tmp_path / "test_evaluations")
    prompts_path = str(tmp_path / "prompts.db")
    history_path = str(tmp_path / "history.db")

    monkeypatch.setenv("TEST_RESULTS_DIR", results_dir)
    monkeypatch.setenv("TEST_EVALUATIONS_DIR", evals_dir)
    monkeypatch.setenv("PROMPTS_DB_PATH", prompts_path)
    monkeypatch.setenv("HISTORY_DB_PATH", history_path)
    monkeypatch.setenv("LLM_API_KEY", "test-key")
    monkeypatch.setenv("LLM_BASE_URL", "https://test.example.com/v1")
    monkeypatch.setenv("LLM_MODEL_NAME", "test-model")
    monkeypatch.setenv("DP_API_KEY", "test-dp-key")
    monkeypatch.setenv("EMBEDDING_MODEL", "test-embedding")

    from app.core.config import get_settings

    # get_settings is memoised (it exposes cache_clear). Drop any cached
    # instance so settings are rebuilt from the patched environment.
    get_settings.cache_clear()
    try:
        from app.core.sqlite_db import (
            _get_db,
            init_history_db,
            init_prompts_db,
            seed_default_profiles,
        )

        # closing() releases each connection even if schema creation or
        # seeding raises part-way through.
        with closing(_get_db(prompts_path)) as conn:
            init_prompts_db(conn)
            seed_default_profiles(conn)
        with closing(_get_db(history_path)) as hconn:
            init_history_db(hconn)

        test_app = FastAPI()
        test_app.include_router(router, prefix="/api/v1")
        test_client = TestClient(test_app)
        try:
            yield test_client
        finally:
            test_client.close()
    finally:
        # This also runs when the setup above fails. It stops settings built
        # from this test's temporary environment leaking into later tests.
        get_settings.cache_clear()


@pytest.fixture
def mock_pipeline(monkeypatch):
    """Mock LLM and RAG to avoid real API calls.

    Patches the four pipeline stages for the duration of one test:

    * ``QueryDecomposer.decompose`` returns two fixed sub-questions.
    * ``RAGService.retrieve_per_subquestion`` returns one chunk from
      ``test.pdf`` (distance 0.15) per sub-question.
    * ``RelevanceFilter.filter_per_subquestion`` keeps every chunk and tags it
      with ``relevance_score`` 8.5.
    * ``RAGService.generate_response_per_subquestion`` returns a canned
      markdown answer and up to one source per sub-question.

    Decompose, filter and generate also return a placeholder prompt string
    alongside their result.
    """

    async def _mock_decompose(self, question):
        return (["sub question 1", "sub question 2"], "mocked decompose prompt")

    def _mock_retrieve(self, sub_questions, n_results=10):
        # A fresh metadata dict is built per sub-question so that no two
        # entries share the same mutable object.
        return [
            (
                sq,
                [
                    (
                        "chunk text content",
                        {
                            "filename": "test.pdf",
                            "upload_date": "2026-01-01",
                            "content_summary": "test chunk",
                            "chunk_index": 0,
                            "page_number": 1,
                            "document_id": "doc-1",
                        },
                        0.15,
                    )
                ],
            )
            for sq in sub_questions
        ]

    async def _mock_filter(self, sub_questions, chunks_by_subq, threshold=7.0):
        result = [
            (sq, [(text, {**meta, "relevance_score": 8.5}) for text, meta in chunks])
            for sq, chunks in zip(sub_questions, chunks_by_subq)
        ]
        return (result, "mocked filter prompt")

    async def _mock_generate(self, sub_questions, sub_chunks, sub_metadata):
        answer = (
            "## Sub-question 0: sub question 1\n\n"
            "- Test answer with citation [test.pdf, page 1]\n\n"
            "## Sub-question 1: sub question 2\n\n"
            "- More answer content"
        )
        if sub_metadata:
            # Cite at most the first source of each sub-question. Slicing
            # (rather than [0]) avoids an IndexError when a sub-question's
            # chunks were all filtered out.
            grouped_sources = [list(meta_list[:1]) for meta_list in sub_metadata]
        else:
            grouped_sources = [[] for _ in sub_questions]
        return (answer, "mocked generate prompt", grouped_sources)

    monkeypatch.setattr("app.services.query_decomposer.QueryDecomposer.decompose", _mock_decompose)
    monkeypatch.setattr("app.services.rag.RAGService.retrieve_per_subquestion", _mock_retrieve)
    monkeypatch.setattr("app.services.relevance_filter.RelevanceFilter.filter_per_subquestion", _mock_filter)
    monkeypatch.setattr("app.services.rag.RAGService.generate_response_per_subquestion", _mock_generate)


@pytest.mark.usefixtures("mock_pipeline")
class TestGenerateTextEndpoint:
    """Integration tests for text generation and the saved-results endpoints.

    ``mock_pipeline`` replaces the decomposer, retrieval, relevance filter and
    generation stages with canned outputs. Each request therefore exercises
    the real router code without network calls.
    """

    GENERATE_URL = "/api/v1/test/generate/text"
    RESULTS_URL = "/api/v1/test/results"

    def _generate(self, client, **payload):
        """POST *payload* as JSON to the text generation endpoint and return the response."""
        return client.post(self.GENERATE_URL, json=payload)

    def _generate_ok(self, client, **payload):
        """Generate a result, assert the request succeeded, and return the JSON body."""
        resp = self._generate(client, **payload)
        # Fail here, showing the server's response body, instead of with an
        # opaque KeyError or length mismatch later in the test.
        assert resp.status_code == 200, resp.text
        return resp.json()

    def test_valid_request_returns_200(self, client):
        data = self._generate_ok(
            client, question="test question", profile="A", label="my label"
        )
        assert data["input_type"] == "text"
        assert data["profile"] == "A"
        assert data["label"] == "my label"
        assert "result_id" in data

    def test_result_contains_all_stages(self, client):
        data = self._generate_ok(client, question="test", profile="B")
        assert len(data["extracted_key_questions"]) == 2
        assert data["retrieval"]["total_chunks_retrieved"] > 0
        assert data["filtered"]["total_chunks_filtered"] > 0
        assert len(data["response"]["final_answer"]) > 0
        assert data["timing"]["total_time_ms"] >= 0

    def test_invalid_profile_rejected(self, client):
        resp = self._generate(client, question="test", profile="D")
        assert resp.status_code == 422, resp.text

    def test_empty_question_rejected(self, client):
        resp = self._generate(client, question="", profile="A")
        assert resp.status_code == 422, resp.text

    def test_result_saved_and_retrievable(self, client):
        result_id = self._generate_ok(client, question="save test", profile="A")["result_id"]

        get_resp = client.get(f"{self.RESULTS_URL}/{result_id}")
        assert get_resp.status_code == 200
        assert get_resp.json()["result_id"] == result_id

    def test_list_results(self, client):
        # Both setup requests must succeed, otherwise the count below is meaningless.
        self._generate_ok(client, question="list test 1", profile="A")
        self._generate_ok(client, question="list test 2", profile="B")

        resp = client.get(f"{self.RESULTS_URL}?limit=10")
        assert resp.status_code == 200
        assert len(resp.json()) >= 2

    def test_get_nonexistent_result(self, client):
        resp = client.get(f"{self.RESULTS_URL}/no-such-id")
        assert resp.status_code == 404

    def test_delete_result(self, client):
        result_id = self._generate_ok(client, question="delete test", profile="A")["result_id"]

        del_resp = client.delete(f"{self.RESULTS_URL}/{result_id}")
        assert del_resp.status_code == 200

        get_resp = client.get(f"{self.RESULTS_URL}/{result_id}")
        assert get_resp.status_code == 404