legco_ai_assistant/backend/app/test/test_phase2_ws_protocol.py

189 lines
6.3 KiB
Python

"""Phase 2 tests: DashScope WebSocket protocol — callback bridge and event formatting.
Covers:
- DashScopeCallback sync→async queue bridge
- _merge_stash overlap merging of partial (stash) transcript text
- Transcription text event formatting (partial)
- Transcription completed event formatting (final)
"""
import asyncio
import json
from unittest.mock import AsyncMock, MagicMock
import pytest
class TestDashScopeCallback:
    """Tests for DashScopeCallback, the sync SDK callback -> asyncio.Queue bridge.

    Each test owns a private event loop opened via ``with self._new_loop()``,
    so the loop is closed even when an assertion fails part-way through
    (the original trailing ``loop.close()`` was skipped on failure).
    """

    @staticmethod
    def _new_loop():
        """Return a fresh event loop wrapped so a ``with`` block closes it on exit."""
        from contextlib import closing

        return closing(asyncio.new_event_loop())

    @staticmethod
    def _make_callback(loop):
        """Build a DashScopeCallback bound to *loop*; return ``(callback, queue)``."""
        from app.routers.ws_asr import DashScopeCallback

        queue: asyncio.Queue = asyncio.Queue()
        return DashScopeCallback(queue, loop), queue

    @staticmethod
    def _drain(loop):
        """Run *loop* briefly so work the callback scheduled on it gets to execute.

        NOTE(review): presumably the callback schedules via
        call_soon_threadsafe; the 50 ms sleep leaves headroom if it instead
        needs several loop iterations.
        """
        loop.run_until_complete(asyncio.sleep(0.05))

    def test_puts_events_on_queue(self):
        """on_event should push parsed JSON onto the asyncio queue."""
        with self._new_loop() as loop:
            callback, queue = self._make_callback(loop)
            callback.on_event(json.dumps({"type": "test", "data": "hello"}))
            self._drain(loop)
            assert not queue.empty()
            event = queue.get_nowait()
            assert event["type"] == "test"
            assert event["data"] == "hello"

    def test_handles_dict_event(self):
        """on_event should accept dict as well as string."""
        with self._new_loop() as loop:
            callback, queue = self._make_callback(loop)
            callback.on_event({"type": "test_dict", "key": "value"})
            self._drain(loop)
            assert not queue.empty()
            event = queue.get_nowait()
            assert event["type"] == "test_dict"

    def test_handles_invalid_json_gracefully(self):
        """on_event should not crash on invalid JSON."""
        with self._new_loop() as loop:
            callback, _queue = self._make_callback(loop)
            # Should not raise
            callback.on_event("not-valid-json{{{")

    def test_on_open_and_close(self):
        """on_open and on_close should not crash."""
        with self._new_loop() as loop:
            callback, _queue = self._make_callback(loop)
            callback.on_open()
            callback.on_close(1000, "normal")
class TestMergeStash:
    """Tests for _merge_stash, which folds a partial-result stash into a buffer."""

    @staticmethod
    def _merge(buffer, stash):
        """Lazily import _merge_stash (like the sibling tests do) and apply it."""
        from app.routers.ws_asr import _merge_stash

        return _merge_stash(buffer, stash)

    def test_merge_empty_buffer_returns_stash(self):
        """With nothing buffered, the stash is returned as-is."""
        assert self._merge("", "你好") == "你好"

    def test_merge_overlapping_suffix(self):
        """A multi-character buffer suffix that starts the stash is not repeated."""
        merged = self._merge("系多謝主席", "主席咁咧呢個")
        assert merged == "系多謝主席咁咧呢個"

    def test_merge_overlapping_single_char(self):
        """A single shared boundary character is kept only once."""
        merged = self._merge("abcde", "efgh")
        assert merged == "abcdefgh"

    def test_merge_no_overlap_appends_with_space(self):
        """Disjoint text is appended after a single space separator."""
        merged = self._merge("你好", "世界")
        assert merged == "你好 世界"

    def test_merge_stash_subset_of_buffer(self):
        """A stash overlapping only the buffer's last character merges without duplication."""
        merged = self._merge("系多謝主席咁咧", "咧呢")
        assert merged == "系多謝主席咁咧呢"

    def test_merge_empty_stash_preserves_buffer(self):
        """An empty stash leaves the buffer (even an empty one) unchanged."""
        assert self._merge("你好", "") == "你好"
        assert self._merge("", "") == ""

    def test_merge_whitespace_only_stash_preserves_buffer(self):
        """A whitespace-only stash leaves the buffer unchanged."""
        assert self._merge("你好", " ") == "你好"
class TestProxyFormatsTranscriptionTextEvent:
    """format_transcription_event on partial ``...transcription.text`` events."""

    @staticmethod
    def _text_event(text, stash, language="yue"):
        """Build a partial-transcription event dict with the fields these tests use."""
        return {
            "type": "conversation.item.input_audio_transcription.text",
            "text": text,
            "stash": stash,
            "language": language,
        }

    def test_partial_event_returns_text_and_stash_fields(self):
        """Partial event returns both text (stable prefix) and stash (trailing)."""
        from app.routers.ws_asr import format_transcription_event

        result = format_transcription_event(self._text_event("多謝主席", "席咁啊"), "")
        assert result is not None
        assert result["is_final"] is False
        assert result["language"] == "yue"
        assert result["delta"] == ""
        assert result["text"] == "多謝主席"
        assert result["stash"] == "席咁啊"

    def test_partial_event_ignores_accumulated(self):
        """Partial event returns fields unchanged regardless of accumulated."""
        from app.routers.ws_asr import format_transcription_event

        result = format_transcription_event(self._text_event("世界", "界大同"), "你好")
        # Fail with a clear assertion instead of a TypeError on subscripting None.
        assert result is not None
        # Non-empty accumulated text must not turn a partial event into a final one.
        assert result["is_final"] is False
        assert result["text"] == "世界"
        assert result["stash"] == "界大同"
class TestProxyFormatsTranscriptionCompletedEvent:
    """format_transcription_event on ``...transcription.completed`` and unknown events."""

    @staticmethod
    def _completed_event(transcript, language="yue"):
        """Build a completed-transcription event dict with the fields these tests use."""
        return {
            "type": "conversation.item.input_audio_transcription.completed",
            "transcript": transcript,
            "language": language,
        }

    def test_completed_event_format(self):
        """Completed event should format as ASRTranscriptEvent with is_final=True."""
        from app.routers.ws_asr import format_transcription_event

        result = format_transcription_event(self._completed_event("你好世界"), "")
        assert result is not None
        assert result["is_final"] is True
        assert result["language"] == "yue"
        # The whole transcript, not merely a fragment of it, must reach full_text.
        assert "你好世界" in result["full_text"]

    def test_completed_updates_accumulated(self):
        """Completed event appends transcript to accumulated text."""
        from app.routers.ws_asr import format_transcription_event

        result = format_transcription_event(self._completed_event("世界"), "你好")
        assert result is not None
        assert result["is_final"] is True
        full_text = result["full_text"]
        assert "你好" in full_text
        assert "世界" in full_text
        # "Appends" means the previously accumulated text precedes the new transcript.
        assert full_text.index("你好") < full_text.index("世界")

    def test_unknown_event_type_returns_none(self):
        """Unknown event types should return None."""
        from app.routers.ws_asr import format_transcription_event

        result = format_transcription_event({"type": "unknown.event"}, "")
        assert result is None