From a4e067822b68122175c394e4c4e0b48d4104042c Mon Sep 17 00:00:00 2001 From: Woody Date: Wed, 6 May 2026 13:41:24 +0800 Subject: [PATCH] feat: Phase 2.3 ASR proxy + full transcript and 2.4 frontend hooks - Backend: DashScope WebSocket proxy (/ws/asr/{video_id}), DashScopeCallback sync-to-async bridge, ffmpeg audio extraction, POST /video/{id}/transcribe - Frontend: useVideoASR hook (auto on play), useFullTranscript hook, QueryInput partialText prop, VideoUploadResponse types, uploadVideo API - Tests: 41 backend + 26 frontend = 67 new tests, all passing --- backend/app/routers/video.py | 35 ++- backend/app/routers/ws_asr.py | 148 ++++++++++++- backend/app/services/asr_client.py | 38 +++- backend/app/services/video_service.py | 30 +++ backend/app/test/conftest.py | 3 + backend/app/test/test_phase2_asr_client.py | 201 +++++++++++++++-- .../app/test/test_phase2_full_transcript.py | 192 +++++++++++++++++ backend/app/test/test_phase2_ws_asr.py | 83 +++++-- backend/app/test/test_phase2_ws_protocol.py | 152 +++++++++++++ frontend/src/components/QueryInput.tsx | 18 +- frontend/src/hooks/useFullTranscript.ts | 35 +++ frontend/src/hooks/useVideoASR.ts | 136 ++++++++++++ frontend/src/lib/api.ts | 23 +- frontend/src/lib/queries.tsx | 10 +- ...est_phase2_QueryInput_integration.test.tsx | 204 ++++++++++++++++++ .../test_phase2_useFullTranscript.test.ts | 201 +++++++++++++++++ .../src/test/test_phase2_useVideoASR.test.ts | 156 ++++++++++++++ frontend/src/types/index.ts | 24 +++ 18 files changed, 1641 insertions(+), 48 deletions(-) create mode 100644 backend/app/test/test_phase2_full_transcript.py create mode 100644 backend/app/test/test_phase2_ws_protocol.py create mode 100644 frontend/src/hooks/useFullTranscript.ts create mode 100644 frontend/src/hooks/useVideoASR.ts create mode 100644 frontend/src/test/test_phase2_QueryInput_integration.test.tsx create mode 100644 frontend/src/test/test_phase2_useFullTranscript.test.ts create mode 100644 frontend/src/test/test_phase2_useVideoASR.test.ts diff --git a/backend/app/routers/video.py b/backend/app/routers/video.py index c1cf050..0d91a6d 100644 --- a/backend/app/routers/video.py +++ b/backend/app/routers/video.py @@ -6,8 +6,9 @@ import aiofiles from fastapi import APIRouter, UploadFile, File, HTTPException from fastapi.responses import FileResponse -from app.models.video import VideoUploadResponse +from app.models.video import VideoUploadResponse, FullTranscriptResponse from app.services.video_service import VideoService +from app.services.asr_client import ASRClient logger = logging.getLogger(__name__) router = APIRouter(tags=["video"]) @@ -75,3 +76,35 @@ async def serve_video(video_id: str): ".mkv": "video/x-matroska", } return FileResponse(str(video_path), media_type=media_types.get(ext, "application/octet-stream")) + + +@router.post("/video/{video_id}/transcribe", response_model=FullTranscriptResponse) +async def transcribe_video(video_id: str, language: str = "yue"): + from app.core.config import get_settings + settings = get_settings() + + if not settings.dashscope_api_key: + raise HTTPException( + status_code=500, + detail="DASHSCOPE_API_KEY is not configured. Set it in .env to enable transcription.", + ) + + service = _get_video_service() + wav_path = await service.extract_audio(video_id) + + try: + audio_bytes = wav_path.read_bytes() + asr = ASRClient(settings) + text = asr.transcribe_full(audio_bytes, language=language) + except Exception as e: + logger.error("Transcription failed for video_id=%s: %s", video_id, e) + raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}") + finally: + if wav_path.exists(): + wav_path.unlink(missing_ok=True) + + return FullTranscriptResponse( + text=text, + language=language, + duration_seconds=None, + ) diff --git a/backend/app/routers/ws_asr.py b/backend/app/routers/ws_asr.py index e0a729a..1a5d7c3 100644 --- a/backend/app/routers/ws_asr.py +++ b/backend/app/routers/ws_asr.py @@ -1,16 +1,156 @@ +import json +import asyncio +import base64 import logging from fastapi import APIRouter, WebSocket, WebSocketDisconnect +from app.core.config import get_settings +from app.services.asr_client import float32_to_s16le, build_display_text, _to_traditional + logger = logging.getLogger(__name__) router = APIRouter(tags=["asr"]) +try: + from dashscope.audio.qwen_omni.omni_realtime import ( + OmniRealtimeConversation, + OmniRealtimeCallback, + TranscriptionParams, + MultiModality, + ) +except ImportError: + OmniRealtimeConversation = None + OmniRealtimeCallback = object + TranscriptionParams = None + MultiModality = None + + +class DashScopeCallback(OmniRealtimeCallback): + def __init__(self, event_queue: asyncio.Queue, loop: asyncio.AbstractEventLoop): + super().__init__() + self._queue = event_queue + self._loop = loop + + def on_open(self): + logger.info("DashScope realtime connection opened") + + def on_event(self, message): + try: + event = json.loads(message) if isinstance(message, str) else message + self._loop.call_soon_threadsafe(self._queue.put_nowait, event) + except Exception as e: + logger.error("DashScope callback error: %s", e) + + def on_close(self, code, msg): + logger.info("DashScope realtime closed: code=%s msg=%s", code, msg) + + +def format_transcription_event(event: dict, accumulated: str) -> dict | None: + event_type = event.get("type", "") + + if event_type == "conversation.item.input_audio_transcription.text": + stash = event.get("stash", "") + display = build_display_text(accumulated, stash) if stash else accumulated + return { + "delta": "", + "full_text": _to_traditional(display), + "language": event.get("language", "yue"), + "is_final": False, + } + + if event_type == "conversation.item.input_audio_transcription.completed": + transcript = event.get("transcript", "") + new_accumulated = build_display_text(accumulated, transcript) if transcript and transcript.strip() else accumulated + return { + "delta": "", + "full_text": _to_traditional(new_accumulated), + "language": event.get("language", "yue"), + "is_final": True, + } + + return None + + +async def _ws_proxy_dashscope(client_ws: WebSocket, loop: asyncio.AbstractEventLoop, language: str = "yue"): + event_queue: asyncio.Queue = asyncio.Queue() + callback = DashScopeCallback(event_queue, loop) + + conversation = OmniRealtimeConversation( + model=get_settings().asr_realtime_model_name, + url="wss://dashscope-intl.aliyuncs.com/api-ws/v1/realtime", + callback=callback, + ) + + await loop.run_in_executor(None, conversation.connect) + + transcription_kwargs: dict = { + "sample_rate": 16000, + "input_audio_format": "pcm", + } + if language != "auto": + transcription_kwargs["language"] = language + + transcription_params = TranscriptionParams(**transcription_kwargs) + conversation.update_session( + output_modalities=[MultiModality.TEXT], + enable_input_audio_transcription=True, + transcription_params=transcription_params, + ) + logger.info("dashscope-session-updated lang=%s", language) + + accumulated_text = "" + + async def read_events(): + nonlocal accumulated_text + while True: + event = await event_queue.get() + result = format_transcription_event(event, accumulated_text) + if result is not None: + if result["is_final"]: + event_type = event.get("type", "") + if event_type == "conversation.item.input_audio_transcription.completed": + transcript = event.get("transcript", "") + if transcript and transcript.strip(): + accumulated_text = build_display_text(accumulated_text, transcript) + result["full_text"] = _to_traditional(accumulated_text) + await client_ws.send_json(result) + + read_task = asyncio.create_task(read_events()) + + try: + while True: + float32_bytes = await client_ws.receive_bytes() + s16_bytes = float32_to_s16le(float32_bytes) + audio_b64 = base64.b64encode(s16_bytes).decode("ascii") + conversation.append_audio(audio_b64) + except WebSocketDisconnect: + pass + finally: + read_task.cancel() + try: + conversation.close() + except Exception: + pass + logger.info("dashscope-session-closed text_len=%d", len(accumulated_text)) + @router.websocket("/ws/asr/{video_id}") async def ws_asr_endpoint(websocket: WebSocket, video_id: str, language: str = "yue"): + settings = get_settings() + + if not settings.dashscope_api_key: + await websocket.accept() + await websocket.send_json({"error": "DASHSCOPE_API_KEY is not configured"}) + await websocket.close(code=1011, reason="DASHSCOPE_API_KEY not set") + return + await websocket.accept() + loop = asyncio.get_event_loop() + logger.info("ws-connect video_id=%s lang=%s", video_id, language) + try: - while True: - await websocket.receive_bytes() - except WebSocketDisconnect: - pass + await _ws_proxy_dashscope(websocket, loop, language) + except Exception as e: + logger.error("ws-asr error: %s", e) + finally: + logger.info("ws-disconnect video_id=%s", video_id) diff --git a/backend/app/services/asr_client.py b/backend/app/services/asr_client.py index f04ee5d..af2863a 100644 --- a/backend/app/services/asr_client.py +++ b/backend/app/services/asr_client.py @@ -2,6 +2,9 @@ import struct import base64 import logging +import zhconv +from openai import OpenAI + logger = logging.getLogger(__name__) @@ -17,9 +20,40 @@ def build_display_text(accumulated: str, current: str) -> str: return " ".join(parts) +def _to_traditional(text: str) -> str: + if not text: + return text + return zhconv.convert(text, "zh-hant") + + class ASRClient: def __init__(self, settings): self.settings = settings - async def transcribe_full(self, audio_bytes: bytes, language: str = "yue") -> str: - raise NotImplementedError + def transcribe_full(self, audio_bytes: bytes, language: str = "yue") -> str: + audio_b64 = base64.b64encode(audio_bytes).decode() + data_url = f"data:;base64,{audio_b64}" + + client = OpenAI( + api_key=self.settings.dashscope_api_key, + base_url="https://dashscope-intl.aliyuncs.com/compatible-mode/v1", + ) + + resp = client.chat.completions.create( + model=self.settings.asr_model_name, + messages=[{ # type: ignore[list-item] + "role": "user", + "content": [{ + "type": "input_audio", + "input_audio": {"data": data_url}, + }], + }], + extra_body={ + "asr_options": { + "language": language if language != "auto" else None, + } + }, + ) + + result = resp.choices[0].message.content or "" + return _to_traditional(result) diff --git a/backend/app/services/video_service.py b/backend/app/services/video_service.py index bdb6b1b..81a9ef1 100644 --- a/backend/app/services/video_service.py +++ b/backend/app/services/video_service.py @@ -1,6 +1,11 @@ +import asyncio from pathlib import Path from fastapi import HTTPException +import logging + +logger = logging.getLogger(__name__) + class VideoService: def __init__(self, upload_dir: str, max_size_mb: int, supported_formats: list[str]): @@ -33,3 +38,28 @@ class VideoService: def delete_video(self, video_id: str) -> None: for p in self.upload_dir.glob(f"{video_id}.*"): p.unlink() + + async def extract_audio(self, video_id: str) -> Path: + video_path = self.get_video_path(video_id) + output_path = self.upload_dir / f"{video_id}_audio.wav" + + proc = await asyncio.create_subprocess_exec( + "ffmpeg", "-i", str(video_path), + "-vn", "-acodec", "pcm_s16le", + "-ar", "16000", "-ac", "1", + "-f", "wav", str(output_path), + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, stderr = await proc.communicate() + + if proc.returncode != 0: + if output_path.exists(): + output_path.unlink(missing_ok=True) + logger.error("ffmpeg failed for video_id=%s: %s", video_id, stderr.decode(errors="replace")) + raise HTTPException( + status_code=500, + detail=f"Audio extraction failed: {stderr.decode(errors='replace')[:200]}", + ) + + return output_path diff --git a/backend/app/test/conftest.py b/backend/app/test/conftest.py index 71ad424..0b12aab 100644 --- a/backend/app/test/conftest.py +++ b/backend/app/test/conftest.py @@ -23,6 +23,9 @@ def mock_asr_client(monkeypatch): async def transcribe(self, audio_bytes): # type: ignore return "" + def transcribe_full(self, audio_bytes, language="yue"): # type: ignore + return "mock full transcript" + return _Mock() diff --git a/backend/app/test/test_phase2_asr_client.py b/backend/app/test/test_phase2_asr_client.py index 7e5545c..418333e 100644 --- a/backend/app/test/test_phase2_asr_client.py +++ b/backend/app/test/test_phase2_asr_client.py @@ -1,25 +1,194 @@ -"""Phase 2 tests: ASR transcription client. +"""Phase 2 tests: ASR client utilities and batch transcription. Covers: -- Integration with Qwen/Qwen3-ASR-1.7B -- File upload vs audio content input -- Error handling for transcription failures -- Mocked responses in test mode +- float32_to_s16le() conversion correctness +- build_display_text() multi-utterance assembly +- _to_traditional() simplified→traditional Chinese conversion +- ASRClient.transcribe_full() batch transcription (mocked OpenAI client) """ +import struct +from unittest.mock import MagicMock, patch + import pytest -class TestASRClient: - """ASR client tests (all external calls mocked).""" +class TestFloat32ToS16Le: + def test_converts_silence(self): + from app.services.asr_client import float32_to_s16le - def test_asr_transcribe_audio(self, mock_asr_client): - """Should return transcript from mocked ASR.""" - pass # TODO: implement + samples = struct.pack("<4f", 0.0, 0.0, 0.0, 0.0) + result = float32_to_s16le(samples) + expected = struct.pack("<4h", 0, 0, 0, 0) + assert result == expected - def test_asr_file_upload_mode(self): - """Should support file path input.""" - pass # TODO: implement + def test_converts_positive_peak(self): + from app.services.asr_client import float32_to_s16le - def test_asr_audio_content_mode(self): - """Should support raw audio bytes input.""" - pass # TODO: implement + samples = struct.pack("<1f", 1.0) + result = float32_to_s16le(samples) + expected = struct.pack("<1h", 32767) + assert result == expected + + def test_converts_negative_peak(self): + from app.services.asr_client import float32_to_s16le + + samples = struct.pack("<1f", -1.0) + result = float32_to_s16le(samples) + expected = struct.pack("<1h", max(-32768, min(32767, int(-1.0 * 32767.0)))) + assert result == expected + + def test_clips_overflow(self): + from app.services.asr_client import float32_to_s16le + + samples = struct.pack("<1f", 2.0) + result = float32_to_s16le(samples) + expected = struct.pack("<1h", 32767) + assert result == expected + + def test_clips_underflow(self): + from app.services.asr_client import float32_to_s16le + + samples = struct.pack("<1f", -2.0) + result = float32_to_s16le(samples) + expected = struct.pack("<1h", -32768) + assert result == expected + + def test_multiple_samples(self): + from app.services.asr_client import float32_to_s16le + + floats = [0.5, -0.5, 0.25, -0.25] + samples = struct.pack(f"<{len(floats)}f", *floats) + result = float32_to_s16le(samples) + expected_ints = [int(f * 32767) for f in floats] + expected = struct.pack(f"<{len(expected_ints)}h", *expected_ints) + assert result == expected + + def test_empty_input(self): + from app.services.asr_client import float32_to_s16le + + assert float32_to_s16le(b"") == b"" + + +class TestBuildDisplayText: + def test_both_parts_present(self): + from app.services.asr_client import build_display_text + + assert build_display_text("hello", "world") == "hello world" + + def test_empty_accumulated(self): + from app.services.asr_client import build_display_text + + assert build_display_text("", "world") == "world" + + def test_empty_current(self): + from app.services.asr_client import build_display_text + + assert build_display_text("hello", "") == "hello" + + def test_both_empty(self): + from app.services.asr_client import build_display_text + + assert build_display_text("", "") == "" + + def test_whitespace_only_ignored(self): + from app.services.asr_client import build_display_text + + assert build_display_text("hello", " ") == "hello" + + +class TestToTraditional: + def test_converts_simplified(self): + from app.services.asr_client import _to_traditional + + result = _to_traditional("中国") + assert "國" in result + + def test_empty_string(self): + from app.services.asr_client import _to_traditional + + assert _to_traditional("") == "" + + def test_already_traditional(self): + from app.services.asr_client import _to_traditional + + text = "測試" + assert _to_traditional(text) == text + + def test_mixed_text(self): + from app.services.asr_client import _to_traditional + + result = _to_traditional("hello 中国 world") + assert "國" in result + assert "hello" in result + + +class TestTranscribeFull: + def test_returns_traditional_chinese_text(self, monkeypatch): + from app.services.asr_client import ASRClient + + settings = MagicMock() + settings.dashscope_api_key = "sk-test-key" + settings.asr_model_name = "qwen3-asr-flash" + + client = ASRClient(settings) + + mock_resp = MagicMock() + mock_resp.choices = [MagicMock()] + mock_resp.choices[0].message.content = "测试结果" + + mock_openai_client = MagicMock() + mock_openai_client.chat.completions.create.return_value = mock_resp + + with patch("app.services.asr_client.OpenAI", return_value=mock_openai_client): + result = client.transcribe_full(b"fake-audio-bytes", language="yue") + + assert result == "測試結果" + mock_openai_client.chat.completions.create.assert_called_once() + call_kwargs = mock_openai_client.chat.completions.create.call_args + assert call_kwargs.kwargs["model"] == "qwen3-asr-flash" + assert call_kwargs.kwargs["extra_body"]["asr_options"]["language"] == "yue" + + def test_uses_correct_api_endpoint(self, monkeypatch): + from app.services.asr_client import ASRClient + + settings = MagicMock() + settings.dashscope_api_key = "sk-test-key" + settings.asr_model_name = "qwen3-asr-flash" + + client = ASRClient(settings) + + mock_resp = MagicMock() + mock_resp.choices = [MagicMock()] + mock_resp.choices[0].message.content = "text" + + mock_openai_client = MagicMock() + mock_openai_client.chat.completions.create.return_value = mock_resp + + with patch("app.services.asr_client.OpenAI", return_value=mock_openai_client) as mock_openai_cls: + client.transcribe_full(b"audio", language="yue") + mock_openai_cls.assert_called_once_with( + api_key="sk-test-key", + base_url="https://dashscope-intl.aliyuncs.com/compatible-mode/v1", + ) + + def test_auto_language_omits_language_param(self, monkeypatch): + from app.services.asr_client import ASRClient + + settings = MagicMock() + settings.dashscope_api_key = "sk-test-key" + settings.asr_model_name = "qwen3-asr-flash" + + client = ASRClient(settings) + + mock_resp = MagicMock() + mock_resp.choices = [MagicMock()] + mock_resp.choices[0].message.content = "text" + + mock_openai_client = MagicMock() + mock_openai_client.chat.completions.create.return_value = mock_resp + + with patch("app.services.asr_client.OpenAI", return_value=mock_openai_client): + client.transcribe_full(b"audio", language="auto") + + call_kwargs = mock_openai_client.chat.completions.create.call_args + assert call_kwargs.kwargs["extra_body"]["asr_options"]["language"] is None diff --git a/backend/app/test/test_phase2_full_transcript.py b/backend/app/test/test_phase2_full_transcript.py new file mode 100644 index 0000000..c69a06c --- /dev/null +++ b/backend/app/test/test_phase2_full_transcript.py @@ -0,0 +1,192 @@ +"""Phase 2 tests: Full transcript endpoint (POST /api/v1/video/{video_id}/transcribe). + +Covers: +- Successful transcription after video upload +- 404 for missing video +- ffmpeg audio extraction (mocked subprocess) +- Missing API key error handling +""" +import os +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch, PropertyMock + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from app.routers.video import router + + +@pytest.fixture +def video_client(tmp_path, monkeypatch): + upload_dir = tmp_path / "test_uploads" + upload_dir.mkdir() + monkeypatch.setenv("VIDEO_UPLOAD_DIR", str(upload_dir)) + monkeypatch.setenv("MAX_VIDEO_SIZE_MB", "50") + monkeypatch.setenv("DASHSCOPE_API_KEY", "sk-test-key") + + from app.core.config import get_settings + get_settings.cache_clear() + app = FastAPI() + app.include_router(router, prefix="/api/v1") + return TestClient(app), upload_dir + + +def _upload_video(client, filename="test.mp4", content=b"\x00" * 1024): + """Helper to upload a video and return the video_id.""" + resp = client.post( + "/api/v1/video/upload", + files={"file": (filename, content, "video/mp4")}, + ) + assert resp.status_code == 200 + return resp.json()["video_id"] + + +class TestTranscribeSuccess: + @patch("app.routers.video.VideoService.extract_audio") + @patch("app.services.asr_client.OpenAI") + def test_transcribe_returns_response(self, mock_openai_cls, mock_extract, video_client): + """POST transcribe should return FullTranscriptResponse.""" + client, upload_dir = video_client + video_id = _upload_video(client) + + # Mock extract_audio to return a fake WAV path + fake_wav = upload_dir / "extracted.wav" + fake_wav.write_bytes(b"RIFF" + b"\x00" * 100) + mock_extract.return_value = fake_wav + + # Mock OpenAI client + mock_resp = MagicMock() + mock_resp.choices = [MagicMock()] + mock_resp.choices[0].message.content = "测试转录结果" + + mock_openai_instance = MagicMock() + mock_openai_instance.chat.completions.create.return_value = mock_resp + mock_openai_cls.return_value = mock_openai_instance + + resp = client.post(f"/api/v1/video/{video_id}/transcribe") + assert resp.status_code == 200 + data = resp.json() + assert "text" in data + assert "language" in data + assert data["language"] == "yue" + # Text should be traditional Chinese + assert "測" in data["text"] or "試" in data["text"] + + @patch("app.routers.video.VideoService.extract_audio") + @patch("app.services.asr_client.OpenAI") + def test_transcribe_custom_language(self, mock_openai_cls, mock_extract, video_client): + """POST transcribe with language param should pass it through.""" + client, upload_dir = video_client + video_id = _upload_video(client) + + fake_wav = upload_dir / "extracted.wav" + fake_wav.write_bytes(b"RIFF" + b"\x00" * 100) + mock_extract.return_value = fake_wav + + mock_resp = MagicMock() + mock_resp.choices = [MagicMock()] + mock_resp.choices[0].message.content = "hello world" + + mock_openai_instance = MagicMock() + mock_openai_instance.chat.completions.create.return_value = mock_resp + mock_openai_cls.return_value = mock_openai_instance + + resp = client.post(f"/api/v1/video/{video_id}/transcribe?language=en") + assert resp.status_code == 200 + assert resp.json()["language"] == "en" + + +class TestTranscribeMissingVideo: + def test_404_for_unknown_video(self, video_client): + """POST transcribe for non-existent video should return 404.""" + client, _ = video_client + resp = client.post("/api/v1/video/nonexistent-video-id/transcribe") + assert resp.status_code == 404 + + +class TestTranscribeExtractsAudio: + @patch("app.services.video_service.asyncio.create_subprocess_exec") + async def test_extract_audio_calls_ffmpeg(self, mock_subprocess, tmp_path): + """extract_audio should call ffmpeg with correct arguments.""" + from app.services.video_service import VideoService + + # Setup: create a fake video file + upload_dir = tmp_path / "uploads" + upload_dir.mkdir() + video_file = upload_dir / "test-video.mp4" + video_file.write_bytes(b"fake-video-content") + + service = VideoService( + upload_dir=str(upload_dir), + max_size_mb=300, + supported_formats=[".mp4"], + ) + + # Mock the subprocess + mock_proc = AsyncMock() + mock_proc.returncode = 0 + mock_proc.communicate.return_value = (b"ffmpeg output", b"") + mock_subprocess.return_value = mock_proc + + result = await service.extract_audio("test-video") + assert result is not None + # Verify ffmpeg was called + mock_subprocess.assert_called_once() + call_args = mock_subprocess.call_args[0] + assert call_args[0] == "ffmpeg" + assert "-i" in call_args + + @patch("app.services.video_service.asyncio.create_subprocess_exec") + async def test_extract_audio_fails_gracefully(self, mock_subprocess, tmp_path): + """extract_audio should raise on ffmpeg failure.""" + from app.services.video_service import VideoService + + upload_dir = tmp_path / "uploads" + upload_dir.mkdir() + video_file = upload_dir / "test-fail.mp4" + video_file.write_bytes(b"bad-content") + + service = VideoService( + upload_dir=str(upload_dir), + max_size_mb=300, + supported_formats=[".mp4"], + ) + + mock_proc = AsyncMock() + mock_proc.returncode = 1 + mock_proc.communicate.return_value = (b"", b"Error: Invalid data") + mock_subprocess.return_value = mock_proc + + with pytest.raises(Exception): + await service.extract_audio("test-fail") + + +class TestTranscribeMissingApiKey: + def test_missing_api_key_returns_500(self, monkeypatch, tmp_path): + """Empty DASHSCOPE_API_KEY should return 500 with descriptive message.""" + upload_dir = tmp_path / "uploads" + upload_dir.mkdir() + monkeypatch.setenv("VIDEO_UPLOAD_DIR", str(upload_dir)) + monkeypatch.setenv("MAX_VIDEO_SIZE_MB", "50") + monkeypatch.setenv("DASHSCOPE_API_KEY", "") + + from app.core.config import get_settings + get_settings.cache_clear() + + app = FastAPI() + app.include_router(router, prefix="/api/v1") + client = TestClient(app) + + # Upload a video first + resp = client.post( + "/api/v1/video/upload", + files={"file": ("test.mp4", b"\x00" * 512, "video/mp4")}, + ) + assert resp.status_code == 200 + video_id = resp.json()["video_id"] + + # Try to transcribe + resp = client.post(f"/api/v1/video/{video_id}/transcribe") + assert resp.status_code == 500 + assert "DASHSCOPE_API_KEY" in resp.json()["detail"] or "API key" in resp.json()["detail"] diff --git a/backend/app/test/test_phase2_ws_asr.py b/backend/app/test/test_phase2_ws_asr.py index 713a226..451b0fd 100644 --- a/backend/app/test/test_phase2_ws_asr.py +++ b/backend/app/test/test_phase2_ws_asr.py @@ -1,28 +1,73 @@ -"""Phase 2 tests: WebSocket ASR streaming. +"""Phase 2 tests: WebSocket ASR endpoint. Covers: -- /ws/asr/{video_id} connection lifecycle -- Real-time audio chunk streaming -- Transcript accumulation -- Connection cleanup on disconnect +- WebSocket connection accepted +- Language parameter defaults and customization +- Client disconnect handled cleanly +- Missing API key returns error """ import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient -class TestWebSocketASR: - """WebSocket ASR streaming tests.""" +@pytest.fixture +def ws_app(monkeypatch): + monkeypatch.setenv("DASHSCOPE_API_KEY", "sk-test-key") + from app.core.config import get_settings + from app.routers.ws_asr import router + get_settings.cache_clear() + app = FastAPI() + app.include_router(router) + return app - @pytest.mark.asyncio - async def test_ws_connection_established(self): - """Should accept WebSocket connection with valid video_id.""" - pass # TODO: implement - @pytest.mark.asyncio - async def test_ws_audio_chunk_streaming(self): - """Should process audio chunks and return transcripts.""" - pass # TODO: implement +@pytest.fixture +def ws_client(ws_app): + return TestClient(ws_app) - @pytest.mark.asyncio - async def test_ws_disconnect_cleanup(self): - """Should clean up resources on client disconnect.""" - pass # TODO: implement + +class TestWSEndpointAcceptsConnection: + def test_connect_success(self, ws_client): + with ws_client.websocket_connect("/ws/asr/test-video-123") as ws: + pass + + def test_connect_with_video_id(self, ws_client): + with ws_client.websocket_connect("/ws/asr/my-video-id") as ws: + pass + + +class TestWSEndpointLanguageParam: + def test_default_language_is_yue(self, ws_client): + with ws_client.websocket_connect("/ws/asr/test-vid") as ws: + pass + + def test_custom_language_param(self, ws_client): + with ws_client.websocket_connect("/ws/asr/test-vid?language=zh") as ws: + pass + + def test_english_language_param(self, ws_client): + with ws_client.websocket_connect("/ws/asr/test-vid?language=en") as ws: + pass + + +class TestWSEndpointHandlesDisconnect: + def test_clean_disconnect(self, ws_client): + with ws_client.websocket_connect("/ws/asr/test-vid") as ws: + pass + + +class TestWSEndpointMissingApiKey: + def test_missing_api_key_sends_error(self, monkeypatch): + monkeypatch.setenv("DASHSCOPE_API_KEY", "") + from app.core.config import get_settings + from app.routers.ws_asr import router + get_settings.cache_clear() + + app = FastAPI() + app.include_router(router) + client = TestClient(app) + + with client.websocket_connect("/ws/asr/test-vid") as ws: + msg = ws.receive_json() + assert "error" in msg or "detail" in msg diff --git a/backend/app/test/test_phase2_ws_protocol.py b/backend/app/test/test_phase2_ws_protocol.py new file mode 100644 index 0000000..4a5a185 --- /dev/null +++ b/backend/app/test/test_phase2_ws_protocol.py @@ -0,0 +1,152 @@ +"""Phase 2 tests: DashScope WebSocket protocol — callback bridge and event formatting. + +Covers: +- DashScopeCallback sync→async queue bridge +- Transcription text event formatting (partial) +- Transcription completed event formatting (final) +""" +import asyncio +import json +from unittest.mock import AsyncMock, MagicMock + +import pytest + + +class TestDashScopeCallback: + def test_puts_events_on_queue(self): + """on_event should push parsed JSON onto the asyncio queue.""" + from app.routers.ws_asr import DashScopeCallback + + queue: asyncio.Queue = asyncio.Queue() + loop = asyncio.new_event_loop() + callback = DashScopeCallback(queue, loop) + + test_event = json.dumps({"type": "test", "data": "hello"}) + callback.on_event(test_event) + + # Give the loop a chance to process call_soon_threadsafe + loop.run_until_complete(asyncio.sleep(0.05)) + + assert not queue.empty() + event = queue.get_nowait() + assert event["type"] == "test" + assert event["data"] == "hello" + loop.close() + + def test_handles_dict_event(self): + """on_event should accept dict as well as string.""" + from app.routers.ws_asr import DashScopeCallback + + queue: asyncio.Queue = asyncio.Queue() + loop = asyncio.new_event_loop() + callback = DashScopeCallback(queue, loop) + + callback.on_event({"type": "test_dict", "key": "value"}) + + loop.run_until_complete(asyncio.sleep(0.05)) + + assert not queue.empty() + event = queue.get_nowait() + assert event["type"] == "test_dict" + loop.close() + + def test_handles_invalid_json_gracefully(self): + """on_event should not crash on invalid JSON.""" + from app.routers.ws_asr import DashScopeCallback + + queue: asyncio.Queue = asyncio.Queue() + loop = asyncio.new_event_loop() + callback = DashScopeCallback(queue, loop) + + # Should not raise + callback.on_event("not-valid-json{{{") + loop.close() + + def test_on_open_and_close(self): + """on_open and on_close should not crash.""" + from app.routers.ws_asr import DashScopeCallback + + queue: asyncio.Queue = asyncio.Queue() + loop = asyncio.new_event_loop() + callback = DashScopeCallback(queue, loop) + + callback.on_open() + callback.on_close(1000, "normal") + loop.close() + + +class TestProxyFormatsTranscriptionTextEvent: + def test_partial_event_format(self): + """Partial transcription event should format as ASRTranscriptEvent with is_final=False.""" + from app.routers.ws_asr import format_transcription_event + + event = { + "type": "conversation.item.input_audio_transcription.text", + "stash": "你好", + "language": "yue", + } + accumulated = "" + + result = format_transcription_event(event, accumulated) + assert result is not None + assert result["is_final"] is False + assert result["language"] == "yue" + assert result["delta"] == "" + assert "你好" in result["full_text"] + + def test_partial_with_accumulated(self): + """Partial event should combine accumulated + current stash.""" + from app.routers.ws_asr import format_transcription_event + + event = { + "type": "conversation.item.input_audio_transcription.text", + "stash": "世界", + "language": "yue", + } + accumulated = "你好" + + result = format_transcription_event(event, accumulated) + assert "你好" in result["full_text"] + assert "世界" in result["full_text"] + + +class TestProxyFormatsTranscriptionCompletedEvent: + def test_completed_event_format(self): + """Completed event should format as ASRTranscriptEvent with is_final=True.""" + from app.routers.ws_asr import format_transcription_event + + event = { + "type": "conversation.item.input_audio_transcription.completed", + "transcript": "你好世界", + "language": "yue", + } + accumulated = "" + + result = format_transcription_event(event, accumulated) + assert result is not None + assert result["is_final"] is True + assert result["language"] == "yue" + assert "你好" in result["full_text"] + + def test_completed_updates_accumulated(self): + """Completed event should return updated accumulated text.""" + from app.routers.ws_asr import format_transcription_event + + event = { + "type": "conversation.item.input_audio_transcription.completed", + "transcript": "世界", + "language": "yue", + } + accumulated = "你好" + + result = format_transcription_event(event, accumulated) + assert "你好" in result["full_text"] + assert "世界" in result["full_text"] + + def test_unknown_event_type_returns_none(self): + """Unknown event types should return None.""" + from app.routers.ws_asr import format_transcription_event + + event = {"type": "unknown.event"} + result = format_transcription_event(event, "") + assert result is None diff --git a/frontend/src/components/QueryInput.tsx b/frontend/src/components/QueryInput.tsx index 792b6a2..20fdeba 100644 --- a/frontend/src/components/QueryInput.tsx +++ b/frontend/src/components/QueryInput.tsx @@ -3,11 +3,16 @@ import React, { useState, type FormEvent, type KeyboardEvent } from 'react' export interface QueryInputProps { onSubmit: (question: string) => void isLoading: boolean + partialText?: string } -export const QueryInput: React.FC = ({ onSubmit, isLoading }) => { +export const QueryInput: React.FC = ({ onSubmit, isLoading, partialText }) => { const [question, setQuestion] = useState('') const [submittedQuestion, setSubmittedQuestion] = useState(null) + const [hasUserInput, setHasUserInput] = useState(false) + + const displayValue = hasUserInput ? question : (partialText ?? question) + const showPartialStyle = !hasUserInput && !!partialText const handleSubmit = (e: FormEvent): void => { e.preventDefault() @@ -16,6 +21,7 @@ export const QueryInput: React.FC = ({ onSubmit, isLoading }) = onSubmit(trimmed) setSubmittedQuestion(trimmed) setQuestion('') + setHasUserInput(false) } } @@ -28,6 +34,7 @@ export const QueryInput: React.FC = ({ onSubmit, isLoading }) = const handleChange = (e: React.ChangeEvent): void => { setQuestion(e.target.value) + setHasUserInput(true) if (e.target.value.trim() !== '') { setSubmittedQuestion(null) } @@ -35,16 +42,21 @@ export const QueryInput: React.FC = ({ onSubmit, isLoading }) = const isDisabled = isLoading || question.trim() === '' + const textareaClassName = [ + 'w-full rounded border border-gray-300 px-3 py-2 focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-transparent disabled:bg-gray-100 disabled:cursor-not-allowed', + showPartialStyle ? 'text-gray-400 italic' : '', + ].filter(Boolean).join(' ') + return (