201 lines
6.8 KiB
Python
201 lines
6.8 KiB
Python
"""Phase 2 tests: ASR client utilities and batch transcription.
|
|
|
|
Covers:
|
|
- float32_to_s16le() conversion correctness
|
|
- build_display_text() multi-utterance assembly
|
|
- _to_traditional() simplified→traditional Chinese conversion
|
|
- ASRClient.transcribe_full() batch transcription (mocked OpenAI client)
|
|
"""
|
|
import struct
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
|
|
class TestFloat32ToS16Le:
    """Checks for float32_to_s16le(): packed float32 bytes in, packed int16 bytes out."""

    def test_converts_silence(self) -> None:
        """Four zero-valued floats come back as four zero-valued int16 samples."""
        from app.services.asr_client import float32_to_s16le

        silence = struct.pack("<4f", 0.0, 0.0, 0.0, 0.0)

        assert float32_to_s16le(silence) == struct.pack("<4h", 0, 0, 0, 0)

    def test_converts_positive_peak(self) -> None:
        """A sample of 1.0 is expected to become 32767."""
        from app.services.asr_client import float32_to_s16le

        peak = struct.pack("<1f", 1.0)

        assert float32_to_s16le(peak) == struct.pack("<1h", 32767)

    def test_converts_negative_peak(self) -> None:
        """A sample of -1.0 is expected to become the scaled, clamped int16 value."""
        from app.services.asr_client import float32_to_s16le

        # Scale by 32767, truncate, then clamp into the signed 16-bit range.
        scaled = int(-1.0 * 32767.0)
        clamped = max(-32768, min(32767, scaled))
        trough = struct.pack("<1f", -1.0)

        assert float32_to_s16le(trough) == struct.pack("<1h", clamped)

    def test_clips_overflow(self) -> None:
        """A sample above 1.0 is expected to clip to 32767."""
        from app.services.asr_client import float32_to_s16le

        too_loud = struct.pack("<1f", 2.0)

        assert float32_to_s16le(too_loud) == struct.pack("<1h", 32767)

    def test_clips_underflow(self) -> None:
        """A sample below -1.0 is expected to clip to -32768."""
        from app.services.asr_client import float32_to_s16le

        too_low = struct.pack("<1f", -2.0)

        assert float32_to_s16le(too_low) == struct.pack("<1h", -32768)

    def test_multiple_samples(self) -> None:
        """Several in-range samples are each scaled by 32767 and truncated."""
        from app.services.asr_client import float32_to_s16le

        floats = [0.5, -0.5, 0.25, -0.25]
        expected_ints = []
        for f in floats:
            expected_ints.append(int(f * 32767))

        packed_floats = struct.pack(f"<{len(floats)}f", *floats)
        packed_ints = struct.pack(f"<{len(expected_ints)}h", *expected_ints)

        assert float32_to_s16le(packed_floats) == packed_ints

    def test_empty_input(self) -> None:
        """Empty bytes in, empty bytes out."""
        from app.services.asr_client import float32_to_s16le

        converted = float32_to_s16le(b"")

        assert converted == b""
|
|
|
|
|
|
class TestBuildDisplayText:
    """Checks for build_display_text(accumulated, current) string assembly."""

    def test_both_parts_present(self) -> None:
        """Two non-empty parts are expected to be joined with a single space."""
        from app.services.asr_client import build_display_text

        shown = build_display_text("hello", "world")

        assert shown == "hello world"

    def test_empty_accumulated(self) -> None:
        """With an empty first part, only the second part is expected back."""
        from app.services.asr_client import build_display_text

        shown = build_display_text("", "world")

        assert shown == "world"

    def test_empty_current(self) -> None:
        """With an empty second part, only the first part is expected back."""
        from app.services.asr_client import build_display_text

        shown = build_display_text("hello", "")

        assert shown == "hello"

    def test_both_empty(self) -> None:
        """Two empty parts are expected to yield an empty string."""
        from app.services.asr_client import build_display_text

        shown = build_display_text("", "")

        assert shown == ""

    def test_whitespace_only_ignored(self) -> None:
        """A whitespace-only second part is expected to be treated as empty."""
        from app.services.asr_client import build_display_text

        shown = build_display_text("hello", " ")

        assert shown == "hello"
|
|
|
|
|
|
class TestToTraditional:
    """Checks for _to_traditional() simplified-to-traditional conversion."""

    def test_converts_simplified(self) -> None:
        """Simplified input "中国" is expected to produce output containing "國"."""
        from app.services.asr_client import _to_traditional

        assert "國" in _to_traditional("中国")

    def test_empty_string(self) -> None:
        """An empty string is expected to stay empty."""
        from app.services.asr_client import _to_traditional

        converted = _to_traditional("")

        assert converted == ""

    def test_already_traditional(self) -> None:
        """Text that is already traditional is expected to come back unchanged."""
        from app.services.asr_client import _to_traditional

        assert _to_traditional("測試") == "測試"

    def test_mixed_text(self) -> None:
        """Chinese is converted while the surrounding Latin text is kept."""
        from app.services.asr_client import _to_traditional

        converted = _to_traditional("hello 中国 world")

        for fragment in ("國", "hello"):
            assert fragment in converted
|
|
|
|
|
|
class TestTranscribeFull:
    """Tests for ASRClient.transcribe_full() with the OpenAI client class patched out.

    Every test builds the client from the same mock settings and patches
    ``app.services.asr_providers.OpenAI`` so no real request is made.
    """

    @staticmethod
    def _make_client():
        """Build an ASRClient from the mock settings shared by all tests here.

        The import is deferred to call time, as in the rest of this module.
        """
        from app.services.asr_client import ASRClient

        settings = MagicMock()
        settings.asr_provider = "dashscope"
        settings.dashscope_api_key = "sk-test-key"
        settings.asr_model_name = "qwen3-asr-flash"
        return ASRClient(settings)

    @staticmethod
    def _make_openai_mock(content):
        """Return a mock OpenAI client whose chat completion carries *content*.

        The value is exposed at ``response.choices[0].message.content``.
        """
        response = MagicMock()
        response.choices = [MagicMock()]
        response.choices[0].message.content = content

        openai_client = MagicMock()
        openai_client.chat.completions.create.return_value = response
        return openai_client

    @pytest.mark.asyncio
    async def test_returns_traditional_chinese_text(self, monkeypatch):
        """Simplified text from the API is returned as traditional Chinese.

        Also checks the model name and the language forwarded in ``extra_body``.
        """
        client = self._make_client()
        mock_openai_client = self._make_openai_mock("测试结果")

        with patch("app.services.asr_providers.OpenAI", return_value=mock_openai_client):
            result = await client.transcribe_full(b"fake-audio-bytes", language="yue")

        assert result == "測試結果"

        create = mock_openai_client.chat.completions.create
        create.assert_called_once()
        sent = create.call_args.kwargs
        assert sent["model"] == "qwen3-asr-flash"
        assert sent["extra_body"]["asr_options"]["language"] == "yue"

    @pytest.mark.asyncio
    async def test_uses_correct_api_endpoint(self, monkeypatch):
        """The OpenAI class is constructed once with the API key and base URL."""
        client = self._make_client()
        mock_openai_client = self._make_openai_mock("text")

        with patch(
            "app.services.asr_providers.OpenAI", return_value=mock_openai_client
        ) as mock_openai_cls:
            await client.transcribe_full(b"audio", language="yue")
            mock_openai_cls.assert_called_once_with(
                api_key="sk-test-key",
                base_url="https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
            )

    @pytest.mark.asyncio
    async def test_auto_language_omits_language_param(self, monkeypatch):
        """With language="auto" no ``extra_body`` is sent to the API."""
        client = self._make_client()
        mock_openai_client = self._make_openai_mock("text")

        with patch("app.services.asr_providers.OpenAI", return_value=mock_openai_client):
            await client.transcribe_full(b"audio", language="auto")

        create = mock_openai_client.chat.completions.create
        # Guard first: if create() was never called, call_args is None and the
        # lookup below would raise AttributeError instead of a clear failure.
        create.assert_called_once()
        assert create.call_args.kwargs.get("extra_body") is None
|