legco_ai_assistant/backend/app/routers/test_generate.py

105 lines
3.1 KiB
Python

import io
import logging
from fastapi import APIRouter, File, Form, HTTPException, Query, UploadFile
from app.core.config import get_settings
from app.models.testing import GenerateTextRequest
from app.services.prompt_service import PromptService
from app.services.test_runner_service import TestRunnerService
from app.services.test_storage_service import TestStorageService
# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)
# Router collecting the endpoints declared below; every route is tagged "test".
router = APIRouter(tags=["test"])
def _get_prompt_service() -> PromptService:
    """Build a PromptService pointed at the configured prompts database.

    The database path is read from ``prompts_db_path`` on the object
    returned by ``get_settings()``. A new service instance is constructed
    on every call.
    """
    return PromptService(db_path=get_settings().prompts_db_path)
def _get_storage_service() -> TestStorageService:
    """Build a TestStorageService from the configured directories.

    Reads ``test_results_dir`` and ``test_evaluations_dir`` from the object
    returned by ``get_settings()`` and passes them to the service
    constructor. A new service instance is constructed on every call.
    """
    config = get_settings()
    results_dir = config.test_results_dir
    evaluations_dir = config.test_evaluations_dir
    return TestStorageService(
        results_dir=results_dir,
        evaluations_dir=evaluations_dir,
    )
@router.post("/test/generate/text")
async def generate_text(request: GenerateTextRequest):
    """Run a text test for the request, store the result, and return it.

    Args:
        request: Request body; its ``question``, ``profile`` and ``label``
            fields are forwarded to the test runner.

    Returns:
        Whatever ``model_dump()`` yields for the runner's result.
    """
    config = get_settings()
    prompts = _get_prompt_service()
    outcome = await TestRunnerService(config).run_text_test(
        question=request.question,
        profile=request.profile,
        prompt_service=prompts,
        label=request.label,
    )
    # Persist before responding so the result can be fetched later by id.
    _get_storage_service().save_result(outcome)
    return outcome.model_dump()
@router.post("/test/generate/audio")
async def generate_audio(
    audio_file: UploadFile = File(...),
    profile: str = Form(...),
    reference_transcript: str = Form(""),
    label: str = Form(""),
    language: str = Form("yue"),
):
    """Run an audio test on an uploaded file, store the result, and return it.

    Args:
        audio_file: Uploaded audio; its full content is read into memory.
        profile: Must be one of "A", "B" or "C".
        reference_transcript: Optional form field forwarded to the runner.
        label: Optional form field forwarded to the runner.
        language: Form field forwarded to the runner; defaults to "yue".

    Returns:
        Whatever ``model_dump()`` yields for the runner's result.

    Raises:
        HTTPException: 400 when ``profile`` is not an allowed value or the
            uploaded file has no content.
    """
    allowed_profiles = ("A", "B", "C")
    if profile not in allowed_profiles:
        raise HTTPException(status_code=400, detail="profile must be A, B, or C")

    config = get_settings()
    prompts = _get_prompt_service()

    payload = await audio_file.read()
    if not payload:
        raise HTTPException(status_code=400, detail="Audio file is empty")

    outcome = await TestRunnerService(config).run_audio_test(
        audio_bytes=payload,
        reference_transcript=reference_transcript,
        profile=profile,
        prompt_service=prompts,
        language=language,
        label=label,
        # Fall back to a placeholder when the upload carries no filename.
        audio_filename=audio_file.filename or "unknown",
    )
    # Persist before responding so the result can be fetched later by id.
    _get_storage_service().save_result(outcome)
    return outcome.model_dump()
@router.get("/test/results")
async def list_results(
    limit: int = Query(50, ge=1, le=200),
    offset: int = Query(0, ge=0),
):
    """Return a page of stored test results.

    Args:
        limit: Page size; constrained to 1..200, default 50.
        offset: Starting position; must be >= 0, default 0.

    Returns:
        The value returned by the storage service's ``list_results``.
    """
    return _get_storage_service().list_results(limit=limit, offset=offset)
@router.get("/test/results/{result_id}")
async def get_result(result_id: str):
    """Return one stored test result by id.

    Args:
        result_id: Identifier passed to the storage service's ``load_result``.

    Returns:
        Whatever ``model_dump()`` yields for the loaded result.

    Raises:
        HTTPException: 404 when the storage service returns ``None``.
    """
    record = _get_storage_service().load_result(result_id)
    if record is not None:
        return record.model_dump()
    raise HTTPException(status_code=404, detail="Result not found")
@router.delete("/test/results/{result_id}")
async def delete_result(result_id: str):
    """Delete one stored test result by id.

    Args:
        result_id: Identifier passed to the storage service's
            ``delete_result``.

    Returns:
        A dict with ``status`` set to "deleted" and the echoed ``result_id``.

    Raises:
        HTTPException: 404 when the storage service reports a falsy outcome.
    """
    if _get_storage_service().delete_result(result_id):
        return {"status": "deleted", "result_id": result_id}
    raise HTTPException(status_code=404, detail="Result not found")