# legco_ai_assistant/backend/app/test/test_phase2_ws_protocol.py (153 lines, 5.1 KiB, Python)

"""Phase 2 tests: DashScope WebSocket protocol — callback bridge and event formatting.
Covers:
- DashScopeCallback sync→async queue bridge
- Transcription text event formatting (partial)
- Transcription completed event formatting (final)
"""
import asyncio
import json
from unittest.mock import AsyncMock, MagicMock
import pytest
class TestDashScopeCallback:
    """Tests for the DashScopeCallback sync-to-async queue bridge.

    Every test creates its own event loop and closes it in a ``finally``
    clause, so a failing assertion or an exception raised by the callback
    does not leave an unclosed loop behind.
    """

    def test_puts_events_on_queue(self):
        """on_event should push parsed JSON onto the asyncio queue."""
        from app.routers.ws_asr import DashScopeCallback

        queue: asyncio.Queue = asyncio.Queue()
        loop = asyncio.new_event_loop()
        try:
            callback = DashScopeCallback(queue, loop)
            test_event = json.dumps({"type": "test", "data": "hello"})
            callback.on_event(test_event)
            # Run the loop briefly so work that on_event scheduled on it
            # (e.g. via call_soon_threadsafe) gets processed.
            loop.run_until_complete(asyncio.sleep(0.05))
            assert not queue.empty()
            event = queue.get_nowait()
            assert event["type"] == "test"
            assert event["data"] == "hello"
        finally:
            loop.close()

    def test_handles_dict_event(self):
        """on_event should accept dict as well as string."""
        from app.routers.ws_asr import DashScopeCallback

        queue: asyncio.Queue = asyncio.Queue()
        loop = asyncio.new_event_loop()
        try:
            callback = DashScopeCallback(queue, loop)
            callback.on_event({"type": "test_dict", "key": "value"})
            # Let the loop run so the scheduled hand-off reaches the queue.
            loop.run_until_complete(asyncio.sleep(0.05))
            assert not queue.empty()
            event = queue.get_nowait()
            assert event["type"] == "test_dict"
        finally:
            loop.close()

    def test_handles_invalid_json_gracefully(self):
        """on_event should not crash on invalid JSON."""
        from app.routers.ws_asr import DashScopeCallback

        queue: asyncio.Queue = asyncio.Queue()
        loop = asyncio.new_event_loop()
        try:
            callback = DashScopeCallback(queue, loop)
            # Should not raise
            callback.on_event("not-valid-json{{{")
        finally:
            loop.close()

    def test_on_open_and_close(self):
        """on_open and on_close should not crash."""
        from app.routers.ws_asr import DashScopeCallback

        queue: asyncio.Queue = asyncio.Queue()
        loop = asyncio.new_event_loop()
        try:
            callback = DashScopeCallback(queue, loop)
            callback.on_open()
            callback.on_close(1000, "normal")
        finally:
            loop.close()
class TestProxyFormatsTranscriptionTextEvent:
    """Tests for format_transcription_event on partial ("...text") events."""

    def test_partial_event_format(self):
        """Partial transcription event should format as ASRTranscriptEvent with is_final=False."""
        from app.routers.ws_asr import format_transcription_event

        event = {
            "type": "conversation.item.input_audio_transcription.text",
            "stash": "你好",
            "language": "yue",
        }
        accumulated = ""
        result = format_transcription_event(event, accumulated)
        assert result is not None
        assert result["is_final"] is False
        assert result["language"] == "yue"
        assert result["delta"] == ""
        assert "你好" in result["full_text"]

    def test_partial_with_accumulated(self):
        """Partial event should combine accumulated + current stash."""
        from app.routers.ws_asr import format_transcription_event

        event = {
            "type": "conversation.item.input_audio_transcription.text",
            "stash": "世界",
            "language": "yue",
        }
        accumulated = "你好"
        result = format_transcription_event(event, accumulated)
        # Guard first so a None result fails with a clear assertion instead
        # of a TypeError from subscripting None.
        assert result is not None
        assert "你好" in result["full_text"]
        assert "世界" in result["full_text"]
class TestProxyFormatsTranscriptionCompletedEvent:
    """Tests for format_transcription_event on completed events and unknown types."""

    def test_completed_event_format(self):
        """Completed event should format as ASRTranscriptEvent with is_final=True."""
        from app.routers.ws_asr import format_transcription_event

        event = {
            "type": "conversation.item.input_audio_transcription.completed",
            "transcript": "你好世界",
            "language": "yue",
        }
        accumulated = ""
        result = format_transcription_event(event, accumulated)
        assert result is not None
        assert result["is_final"] is True
        assert result["language"] == "yue"
        assert "你好" in result["full_text"]

    def test_completed_updates_accumulated(self):
        """Completed event should return updated accumulated text."""
        from app.routers.ws_asr import format_transcription_event

        event = {
            "type": "conversation.item.input_audio_transcription.completed",
            "transcript": "世界",
            "language": "yue",
        }
        accumulated = "你好"
        result = format_transcription_event(event, accumulated)
        # Guard first so a None result fails with a clear assertion instead
        # of a TypeError from subscripting None.
        assert result is not None
        assert "你好" in result["full_text"]
        assert "世界" in result["full_text"]

    def test_unknown_event_type_returns_none(self):
        """Unknown event types should return None."""
        from app.routers.ws_asr import format_transcription_event

        event = {"type": "unknown.event"}
        result = format_transcription_event(event, "")
        assert result is None