105 lines
3.1 KiB
Python
105 lines
3.1 KiB
Python
import io
|
|
import logging
|
|
|
|
from fastapi import APIRouter, File, Form, HTTPException, Query, UploadFile
|
|
|
|
from app.core.config import get_settings
|
|
from app.models.testing import GenerateTextRequest
|
|
from app.services.prompt_service import PromptService
|
|
from app.services.test_runner_service import TestRunnerService
|
|
from app.services.test_storage_service import TestStorageService
|
|
|
|
# Router that every test-harness endpoint in this module registers on,
# grouped under the "test" tag.
router = APIRouter(tags=["test"])

# Module-level logger named after this module (not used by the handlers
# visible in this file).
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _get_prompt_service() -> PromptService:
    """Create a PromptService bound to the configured prompts database.

    A new instance is built on every call, using ``prompts_db_path`` read
    from the current application settings.
    """
    db_path = get_settings().prompts_db_path
    return PromptService(db_path=db_path)
|
|
|
|
|
|
def _get_storage_service() -> TestStorageService:
    """Create a TestStorageService for the configured result directories.

    Reads ``test_results_dir`` and ``test_evaluations_dir`` from the current
    application settings; a fresh instance is built on every call.
    """
    cfg = get_settings()
    results_dir = cfg.test_results_dir
    evaluations_dir = cfg.test_evaluations_dir
    return TestStorageService(results_dir=results_dir, evaluations_dir=evaluations_dir)
|
|
|
|
|
|
@router.post("/test/generate/text")
async def generate_text(request: GenerateTextRequest):
    """Run a text test, persist the outcome, and return it.

    Args:
        request: Body supplying ``question``, ``profile`` and ``label``.

    Returns:
        The result object produced by ``TestRunnerService.run_text_test``,
        serialized with ``model_dump()`` after it has been saved.
    """
    app_settings = get_settings()
    prompts = _get_prompt_service()

    test_result = await TestRunnerService(app_settings).run_text_test(
        question=request.question,
        profile=request.profile,
        prompt_service=prompts,
        label=request.label,
    )

    # Persist before responding so the result is listable afterwards.
    _get_storage_service().save_result(test_result)
    return test_result.model_dump()
|
|
|
|
|
|
@router.post("/test/generate/audio")
async def generate_audio(
    audio_file: UploadFile = File(...),
    profile: str = Form(...),
    reference_transcript: str = Form(""),
    label: str = Form(""),
    language: str = Form("yue"),
):
    """Run an audio test on an uploaded file, persist the outcome, and return it.

    Args:
        audio_file: Uploaded audio; its full contents are read into memory.
        profile: Test profile; only "A", "B" or "C" are accepted.
        reference_transcript: Optional expected transcript (defaults to "").
        label: Optional label for the run (defaults to "").
        language: Language code forwarded to the runner (defaults to "yue").

    Returns:
        The saved result serialized with ``model_dump()``.

    Raises:
        HTTPException: 400 if ``profile`` is not A/B/C or the upload is empty.
    """
    allowed_profiles = ("A", "B", "C")
    if profile not in allowed_profiles:
        raise HTTPException(status_code=400, detail="profile must be A, B, or C")

    app_settings = get_settings()
    prompts = _get_prompt_service()

    payload = await audio_file.read()
    if not payload:
        raise HTTPException(status_code=400, detail="Audio file is empty")

    runner = TestRunnerService(app_settings)
    # Uploads without a client-supplied filename are recorded as "unknown".
    source_name = audio_file.filename or "unknown"
    outcome = await runner.run_audio_test(
        audio_bytes=payload,
        reference_transcript=reference_transcript,
        profile=profile,
        prompt_service=prompts,
        language=language,
        label=label,
        audio_filename=source_name,
    )

    _get_storage_service().save_result(outcome)
    return outcome.model_dump()
|
|
|
|
|
|
@router.get("/test/results")
async def list_results(
    limit: int = Query(50, ge=1, le=200),
    offset: int = Query(0, ge=0),
):
    """Return one page of stored test results.

    Args:
        limit: Page size, validated to 1..200 (default 50).
        offset: Number of results to skip, validated to be >= 0 (default 0).

    Returns:
        Whatever ``TestStorageService.list_results`` returns for that page.
    """
    page = _get_storage_service().list_results(limit=limit, offset=offset)
    return page
|
|
|
|
|
|
@router.get("/test/results/{result_id}")
async def get_result(result_id: str):
    """Fetch a single stored result by its id.

    Args:
        result_id: Identifier passed straight to ``TestStorageService.load_result``.

    Returns:
        The stored result serialized with ``model_dump()``.

    Raises:
        HTTPException: 404 when storage returns ``None`` for the id.
    """
    stored = _get_storage_service().load_result(result_id)
    if stored is not None:
        return stored.model_dump()
    raise HTTPException(status_code=404, detail="Result not found")
|
|
|
|
|
|
@router.delete("/test/results/{result_id}")
async def delete_result(result_id: str):
    """Delete a stored result by its id.

    Args:
        result_id: Identifier passed straight to ``TestStorageService.delete_result``.

    Returns:
        A confirmation dict echoing the deleted ``result_id``.

    Raises:
        HTTPException: 404 when storage reports nothing was deleted.
    """
    was_deleted = _get_storage_service().delete_result(result_id)
    if was_deleted:
        return {"status": "deleted", "result_id": result_id}
    raise HTTPException(status_code=404, detail="Result not found")
|