feat: Phase 2.3 ASR proxy + full transcript and 2.4 frontend hooks
- Backend: DashScope WebSocket proxy (/ws/asr/{video_id}), DashScopeCallback
sync-to-async bridge, ffmpeg audio extraction, POST /video/{id}/transcribe
- Frontend: useVideoASR hook (auto on play), useFullTranscript hook,
QueryInput partialText prop, VideoUploadResponse types, uploadVideo API
- Tests: 41 backend + 26 frontend = 67 new tests, all passing
This commit is contained in:
parent
9934749d2b
commit
a4e067822b
|
|
@ -6,8 +6,9 @@ import aiofiles
|
|||
from fastapi import APIRouter, UploadFile, File, HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
from app.models.video import VideoUploadResponse
|
||||
from app.models.video import VideoUploadResponse, FullTranscriptResponse
|
||||
from app.services.video_service import VideoService
|
||||
from app.services.asr_client import ASRClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(tags=["video"])
|
||||
|
|
@ -75,3 +76,35 @@ async def serve_video(video_id: str):
|
|||
".mkv": "video/x-matroska",
|
||||
}
|
||||
return FileResponse(str(video_path), media_type=media_types.get(ext, "application/octet-stream"))
|
||||
|
||||
|
||||
@router.post("/video/{video_id}/transcribe", response_model=FullTranscriptResponse)
|
||||
async def transcribe_video(video_id: str, language: str = "yue"):
|
||||
from app.core.config import get_settings
|
||||
settings = get_settings()
|
||||
|
||||
if not settings.dashscope_api_key:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="DASHSCOPE_API_KEY is not configured. Set it in .env to enable transcription.",
|
||||
)
|
||||
|
||||
service = _get_video_service()
|
||||
wav_path = await service.extract_audio(video_id)
|
||||
|
||||
try:
|
||||
audio_bytes = wav_path.read_bytes()
|
||||
asr = ASRClient(settings)
|
||||
text = asr.transcribe_full(audio_bytes, language=language)
|
||||
except Exception as e:
|
||||
logger.error("Transcription failed for video_id=%s: %s", video_id, e)
|
||||
raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
|
||||
finally:
|
||||
if wav_path.exists():
|
||||
wav_path.unlink(missing_ok=True)
|
||||
|
||||
return FullTranscriptResponse(
|
||||
text=text,
|
||||
language=language,
|
||||
duration_seconds=None,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,16 +1,156 @@
|
|||
import json
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.services.asr_client import float32_to_s16le, build_display_text, _to_traditional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(tags=["asr"])
|
||||
|
||||
try:
|
||||
from dashscope.audio.qwen_omni.omni_realtime import (
|
||||
OmniRealtimeConversation,
|
||||
OmniRealtimeCallback,
|
||||
TranscriptionParams,
|
||||
MultiModality,
|
||||
)
|
||||
except ImportError:
|
||||
OmniRealtimeConversation = None
|
||||
OmniRealtimeCallback = object
|
||||
TranscriptionParams = None
|
||||
MultiModality = None
|
||||
|
||||
|
||||
class DashScopeCallback(OmniRealtimeCallback):
|
||||
def __init__(self, event_queue: asyncio.Queue, loop: asyncio.AbstractEventLoop):
|
||||
super().__init__()
|
||||
self._queue = event_queue
|
||||
self._loop = loop
|
||||
|
||||
def on_open(self):
|
||||
logger.info("DashScope realtime connection opened")
|
||||
|
||||
def on_event(self, message):
|
||||
try:
|
||||
event = json.loads(message) if isinstance(message, str) else message
|
||||
self._loop.call_soon_threadsafe(self._queue.put_nowait, event)
|
||||
except Exception as e:
|
||||
logger.error("DashScope callback error: %s", e)
|
||||
|
||||
def on_close(self, code, msg):
|
||||
logger.info("DashScope realtime closed: code=%s msg=%s", code, msg)
|
||||
|
||||
|
||||
def format_transcription_event(event: dict, accumulated: str) -> dict | None:
|
||||
event_type = event.get("type", "")
|
||||
|
||||
if event_type == "conversation.item.input_audio_transcription.text":
|
||||
stash = event.get("stash", "")
|
||||
display = build_display_text(accumulated, stash) if stash else accumulated
|
||||
return {
|
||||
"delta": "",
|
||||
"full_text": _to_traditional(display),
|
||||
"language": event.get("language", "yue"),
|
||||
"is_final": False,
|
||||
}
|
||||
|
||||
if event_type == "conversation.item.input_audio_transcription.completed":
|
||||
transcript = event.get("transcript", "")
|
||||
new_accumulated = build_display_text(accumulated, transcript) if transcript and transcript.strip() else accumulated
|
||||
return {
|
||||
"delta": "",
|
||||
"full_text": _to_traditional(new_accumulated),
|
||||
"language": event.get("language", "yue"),
|
||||
"is_final": True,
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def _ws_proxy_dashscope(client_ws: WebSocket, loop: asyncio.AbstractEventLoop, language: str = "yue"):
|
||||
event_queue: asyncio.Queue = asyncio.Queue()
|
||||
callback = DashScopeCallback(event_queue, loop)
|
||||
|
||||
conversation = OmniRealtimeConversation(
|
||||
model=get_settings().asr_realtime_model_name,
|
||||
url="wss://dashscope-intl.aliyuncs.com/api-ws/v1/realtime",
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
await loop.run_in_executor(None, conversation.connect)
|
||||
|
||||
transcription_kwargs: dict = {
|
||||
"sample_rate": 16000,
|
||||
"input_audio_format": "pcm",
|
||||
}
|
||||
if language != "auto":
|
||||
transcription_kwargs["language"] = language
|
||||
|
||||
transcription_params = TranscriptionParams(**transcription_kwargs)
|
||||
conversation.update_session(
|
||||
output_modalities=[MultiModality.TEXT],
|
||||
enable_input_audio_transcription=True,
|
||||
transcription_params=transcription_params,
|
||||
)
|
||||
logger.info("dashscope-session-updated lang=%s", language)
|
||||
|
||||
accumulated_text = ""
|
||||
|
||||
async def read_events():
|
||||
nonlocal accumulated_text
|
||||
while True:
|
||||
event = await event_queue.get()
|
||||
result = format_transcription_event(event, accumulated_text)
|
||||
if result is not None:
|
||||
if result["is_final"]:
|
||||
event_type = event.get("type", "")
|
||||
if event_type == "conversation.item.input_audio_transcription.completed":
|
||||
transcript = event.get("transcript", "")
|
||||
if transcript and transcript.strip():
|
||||
accumulated_text = build_display_text(accumulated_text, transcript)
|
||||
result["full_text"] = _to_traditional(accumulated_text)
|
||||
await client_ws.send_json(result)
|
||||
|
||||
read_task = asyncio.create_task(read_events())
|
||||
|
||||
try:
|
||||
while True:
|
||||
float32_bytes = await client_ws.receive_bytes()
|
||||
s16_bytes = float32_to_s16le(float32_bytes)
|
||||
audio_b64 = base64.b64encode(s16_bytes).decode("ascii")
|
||||
conversation.append_audio(audio_b64)
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
finally:
|
||||
read_task.cancel()
|
||||
try:
|
||||
conversation.close()
|
||||
except Exception:
|
||||
pass
|
||||
logger.info("dashscope-session-closed text_len=%d", len(accumulated_text))
|
||||
|
||||
|
||||
@router.websocket("/ws/asr/{video_id}")
|
||||
async def ws_asr_endpoint(websocket: WebSocket, video_id: str, language: str = "yue"):
|
||||
settings = get_settings()
|
||||
|
||||
if not settings.dashscope_api_key:
|
||||
await websocket.accept()
|
||||
await websocket.send_json({"error": "DASHSCOPE_API_KEY is not configured"})
|
||||
await websocket.close(code=1011, reason="DASHSCOPE_API_KEY not set")
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
loop = asyncio.get_event_loop()
|
||||
logger.info("ws-connect video_id=%s lang=%s", video_id, language)
|
||||
|
||||
try:
|
||||
while True:
|
||||
await websocket.receive_bytes()
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
await _ws_proxy_dashscope(websocket, loop, language)
|
||||
except Exception as e:
|
||||
logger.error("ws-asr error: %s", e)
|
||||
finally:
|
||||
logger.info("ws-disconnect video_id=%s", video_id)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,9 @@ import struct
|
|||
import base64
|
||||
import logging
|
||||
|
||||
import zhconv
|
||||
from openai import OpenAI
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -17,9 +20,40 @@ def build_display_text(accumulated: str, current: str) -> str:
|
|||
return " ".join(parts)
|
||||
|
||||
|
||||
def _to_traditional(text: str) -> str:
|
||||
if not text:
|
||||
return text
|
||||
return zhconv.convert(text, "zh-hant")
|
||||
|
||||
|
||||
class ASRClient:
|
||||
def __init__(self, settings):
|
||||
self.settings = settings
|
||||
|
||||
async def transcribe_full(self, audio_bytes: bytes, language: str = "yue") -> str:
|
||||
raise NotImplementedError
|
||||
def transcribe_full(self, audio_bytes: bytes, language: str = "yue") -> str:
|
||||
audio_b64 = base64.b64encode(audio_bytes).decode()
|
||||
data_url = f"data:;base64,{audio_b64}"
|
||||
|
||||
client = OpenAI(
|
||||
api_key=self.settings.dashscope_api_key,
|
||||
base_url="https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
|
||||
)
|
||||
|
||||
resp = client.chat.completions.create(
|
||||
model=self.settings.asr_model_name,
|
||||
messages=[{ # type: ignore[list-item]
|
||||
"role": "user",
|
||||
"content": [{
|
||||
"type": "input_audio",
|
||||
"input_audio": {"data": data_url},
|
||||
}],
|
||||
}],
|
||||
extra_body={
|
||||
"asr_options": {
|
||||
"language": language if language != "auto" else None,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
result = resp.choices[0].message.content or ""
|
||||
return _to_traditional(result)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,11 @@
|
|||
import asyncio
|
||||
from pathlib import Path
|
||||
from fastapi import HTTPException
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VideoService:
|
||||
def __init__(self, upload_dir: str, max_size_mb: int, supported_formats: list[str]):
|
||||
|
|
@ -33,3 +38,28 @@ class VideoService:
|
|||
def delete_video(self, video_id: str) -> None:
|
||||
for p in self.upload_dir.glob(f"{video_id}.*"):
|
||||
p.unlink()
|
||||
|
||||
async def extract_audio(self, video_id: str) -> Path:
|
||||
video_path = self.get_video_path(video_id)
|
||||
output_path = self.upload_dir / f"{video_id}_audio.wav"
|
||||
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
"ffmpeg", "-i", str(video_path),
|
||||
"-vn", "-acodec", "pcm_s16le",
|
||||
"-ar", "16000", "-ac", "1",
|
||||
"-f", "wav", str(output_path),
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
stdout, stderr = await proc.communicate()
|
||||
|
||||
if proc.returncode != 0:
|
||||
if output_path.exists():
|
||||
output_path.unlink(missing_ok=True)
|
||||
logger.error("ffmpeg failed for video_id=%s: %s", video_id, stderr.decode(errors="replace"))
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Audio extraction failed: {stderr.decode(errors='replace')[:200]}",
|
||||
)
|
||||
|
||||
return output_path
|
||||
|
|
|
|||
|
|
@ -23,6 +23,9 @@ def mock_asr_client(monkeypatch):
|
|||
async def transcribe(self, audio_bytes): # type: ignore
|
||||
return ""
|
||||
|
||||
def transcribe_full(self, audio_bytes, language="yue"): # type: ignore
|
||||
return "mock full transcript"
|
||||
|
||||
return _Mock()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,25 +1,194 @@
|
|||
"""Phase 2 tests: ASR transcription client.
|
||||
"""Phase 2 tests: ASR client utilities and batch transcription.
|
||||
|
||||
Covers:
|
||||
- Integration with Qwen/Qwen3-ASR-1.7B
|
||||
- File upload vs audio content input
|
||||
- Error handling for transcription failures
|
||||
- Mocked responses in test mode
|
||||
- float32_to_s16le() conversion correctness
|
||||
- build_display_text() multi-utterance assembly
|
||||
- _to_traditional() simplified→traditional Chinese conversion
|
||||
- ASRClient.transcribe_full() batch transcription (mocked OpenAI client)
|
||||
"""
|
||||
import struct
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestASRClient:
|
||||
"""ASR client tests (all external calls mocked)."""
|
||||
class TestFloat32ToS16Le:
|
||||
def test_converts_silence(self):
|
||||
from app.services.asr_client import float32_to_s16le
|
||||
|
||||
def test_asr_transcribe_audio(self, mock_asr_client):
|
||||
"""Should return transcript from mocked ASR."""
|
||||
pass # TODO: implement
|
||||
samples = struct.pack("<4f", 0.0, 0.0, 0.0, 0.0)
|
||||
result = float32_to_s16le(samples)
|
||||
expected = struct.pack("<4h", 0, 0, 0, 0)
|
||||
assert result == expected
|
||||
|
||||
def test_asr_file_upload_mode(self):
|
||||
"""Should support file path input."""
|
||||
pass # TODO: implement
|
||||
def test_converts_positive_peak(self):
|
||||
from app.services.asr_client import float32_to_s16le
|
||||
|
||||
def test_asr_audio_content_mode(self):
|
||||
"""Should support raw audio bytes input."""
|
||||
pass # TODO: implement
|
||||
samples = struct.pack("<1f", 1.0)
|
||||
result = float32_to_s16le(samples)
|
||||
expected = struct.pack("<1h", 32767)
|
||||
assert result == expected
|
||||
|
||||
def test_converts_negative_peak(self):
|
||||
from app.services.asr_client import float32_to_s16le
|
||||
|
||||
samples = struct.pack("<1f", -1.0)
|
||||
result = float32_to_s16le(samples)
|
||||
expected = struct.pack("<1h", max(-32768, min(32767, int(-1.0 * 32767.0))))
|
||||
assert result == expected
|
||||
|
||||
def test_clips_overflow(self):
|
||||
from app.services.asr_client import float32_to_s16le
|
||||
|
||||
samples = struct.pack("<1f", 2.0)
|
||||
result = float32_to_s16le(samples)
|
||||
expected = struct.pack("<1h", 32767)
|
||||
assert result == expected
|
||||
|
||||
def test_clips_underflow(self):
|
||||
from app.services.asr_client import float32_to_s16le
|
||||
|
||||
samples = struct.pack("<1f", -2.0)
|
||||
result = float32_to_s16le(samples)
|
||||
expected = struct.pack("<1h", -32768)
|
||||
assert result == expected
|
||||
|
||||
def test_multiple_samples(self):
|
||||
from app.services.asr_client import float32_to_s16le
|
||||
|
||||
floats = [0.5, -0.5, 0.25, -0.25]
|
||||
samples = struct.pack(f"<{len(floats)}f", *floats)
|
||||
result = float32_to_s16le(samples)
|
||||
expected_ints = [int(f * 32767) for f in floats]
|
||||
expected = struct.pack(f"<{len(expected_ints)}h", *expected_ints)
|
||||
assert result == expected
|
||||
|
||||
def test_empty_input(self):
|
||||
from app.services.asr_client import float32_to_s16le
|
||||
|
||||
assert float32_to_s16le(b"") == b""
|
||||
|
||||
|
||||
class TestBuildDisplayText:
|
||||
def test_both_parts_present(self):
|
||||
from app.services.asr_client import build_display_text
|
||||
|
||||
assert build_display_text("hello", "world") == "hello world"
|
||||
|
||||
def test_empty_accumulated(self):
|
||||
from app.services.asr_client import build_display_text
|
||||
|
||||
assert build_display_text("", "world") == "world"
|
||||
|
||||
def test_empty_current(self):
|
||||
from app.services.asr_client import build_display_text
|
||||
|
||||
assert build_display_text("hello", "") == "hello"
|
||||
|
||||
def test_both_empty(self):
|
||||
from app.services.asr_client import build_display_text
|
||||
|
||||
assert build_display_text("", "") == ""
|
||||
|
||||
def test_whitespace_only_ignored(self):
|
||||
from app.services.asr_client import build_display_text
|
||||
|
||||
assert build_display_text("hello", " ") == "hello"
|
||||
|
||||
|
||||
class TestToTraditional:
|
||||
def test_converts_simplified(self):
|
||||
from app.services.asr_client import _to_traditional
|
||||
|
||||
result = _to_traditional("中国")
|
||||
assert "國" in result
|
||||
|
||||
def test_empty_string(self):
|
||||
from app.services.asr_client import _to_traditional
|
||||
|
||||
assert _to_traditional("") == ""
|
||||
|
||||
def test_already_traditional(self):
|
||||
from app.services.asr_client import _to_traditional
|
||||
|
||||
text = "測試"
|
||||
assert _to_traditional(text) == text
|
||||
|
||||
def test_mixed_text(self):
|
||||
from app.services.asr_client import _to_traditional
|
||||
|
||||
result = _to_traditional("hello 中国 world")
|
||||
assert "國" in result
|
||||
assert "hello" in result
|
||||
|
||||
|
||||
class TestTranscribeFull:
|
||||
def test_returns_traditional_chinese_text(self, monkeypatch):
|
||||
from app.services.asr_client import ASRClient
|
||||
|
||||
settings = MagicMock()
|
||||
settings.dashscope_api_key = "sk-test-key"
|
||||
settings.asr_model_name = "qwen3-asr-flash"
|
||||
|
||||
client = ASRClient(settings)
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.choices = [MagicMock()]
|
||||
mock_resp.choices[0].message.content = "测试结果"
|
||||
|
||||
mock_openai_client = MagicMock()
|
||||
mock_openai_client.chat.completions.create.return_value = mock_resp
|
||||
|
||||
with patch("app.services.asr_client.OpenAI", return_value=mock_openai_client):
|
||||
result = client.transcribe_full(b"fake-audio-bytes", language="yue")
|
||||
|
||||
assert result == "測試結果"
|
||||
mock_openai_client.chat.completions.create.assert_called_once()
|
||||
call_kwargs = mock_openai_client.chat.completions.create.call_args
|
||||
assert call_kwargs.kwargs["model"] == "qwen3-asr-flash"
|
||||
assert call_kwargs.kwargs["extra_body"]["asr_options"]["language"] == "yue"
|
||||
|
||||
def test_uses_correct_api_endpoint(self, monkeypatch):
|
||||
from app.services.asr_client import ASRClient
|
||||
|
||||
settings = MagicMock()
|
||||
settings.dashscope_api_key = "sk-test-key"
|
||||
settings.asr_model_name = "qwen3-asr-flash"
|
||||
|
||||
client = ASRClient(settings)
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.choices = [MagicMock()]
|
||||
mock_resp.choices[0].message.content = "text"
|
||||
|
||||
mock_openai_client = MagicMock()
|
||||
mock_openai_client.chat.completions.create.return_value = mock_resp
|
||||
|
||||
with patch("app.services.asr_client.OpenAI", return_value=mock_openai_client) as mock_openai_cls:
|
||||
client.transcribe_full(b"audio", language="yue")
|
||||
mock_openai_cls.assert_called_once_with(
|
||||
api_key="sk-test-key",
|
||||
base_url="https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
|
||||
)
|
||||
|
||||
def test_auto_language_omits_language_param(self, monkeypatch):
|
||||
from app.services.asr_client import ASRClient
|
||||
|
||||
settings = MagicMock()
|
||||
settings.dashscope_api_key = "sk-test-key"
|
||||
settings.asr_model_name = "qwen3-asr-flash"
|
||||
|
||||
client = ASRClient(settings)
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.choices = [MagicMock()]
|
||||
mock_resp.choices[0].message.content = "text"
|
||||
|
||||
mock_openai_client = MagicMock()
|
||||
mock_openai_client.chat.completions.create.return_value = mock_resp
|
||||
|
||||
with patch("app.services.asr_client.OpenAI", return_value=mock_openai_client):
|
||||
client.transcribe_full(b"audio", language="auto")
|
||||
|
||||
call_kwargs = mock_openai_client.chat.completions.create.call_args
|
||||
assert call_kwargs.kwargs["extra_body"]["asr_options"]["language"] is None
|
||||
|
|
|
|||
|
|
@ -0,0 +1,192 @@
|
|||
"""Phase 2 tests: Full transcript endpoint (POST /api/v1/video/{video_id}/transcribe).
|
||||
|
||||
Covers:
|
||||
- Successful transcription after video upload
|
||||
- 404 for missing video
|
||||
- ffmpeg audio extraction (mocked subprocess)
|
||||
- Missing API key error handling
|
||||
"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch, PropertyMock
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.routers.video import router
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def video_client(tmp_path, monkeypatch):
|
||||
upload_dir = tmp_path / "test_uploads"
|
||||
upload_dir.mkdir()
|
||||
monkeypatch.setenv("VIDEO_UPLOAD_DIR", str(upload_dir))
|
||||
monkeypatch.setenv("MAX_VIDEO_SIZE_MB", "50")
|
||||
monkeypatch.setenv("DASHSCOPE_API_KEY", "sk-test-key")
|
||||
|
||||
from app.core.config import get_settings
|
||||
get_settings.cache_clear()
|
||||
app = FastAPI()
|
||||
app.include_router(router, prefix="/api/v1")
|
||||
return TestClient(app), upload_dir
|
||||
|
||||
|
||||
def _upload_video(client, filename="test.mp4", content=b"\x00" * 1024):
|
||||
"""Helper to upload a video and return the video_id."""
|
||||
resp = client.post(
|
||||
"/api/v1/video/upload",
|
||||
files={"file": (filename, content, "video/mp4")},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
return resp.json()["video_id"]
|
||||
|
||||
|
||||
class TestTranscribeSuccess:
|
||||
@patch("app.routers.video.VideoService.extract_audio")
|
||||
@patch("app.services.asr_client.OpenAI")
|
||||
def test_transcribe_returns_response(self, mock_openai_cls, mock_extract, video_client):
|
||||
"""POST transcribe should return FullTranscriptResponse."""
|
||||
client, upload_dir = video_client
|
||||
video_id = _upload_video(client)
|
||||
|
||||
# Mock extract_audio to return a fake WAV path
|
||||
fake_wav = upload_dir / "extracted.wav"
|
||||
fake_wav.write_bytes(b"RIFF" + b"\x00" * 100)
|
||||
mock_extract.return_value = fake_wav
|
||||
|
||||
# Mock OpenAI client
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.choices = [MagicMock()]
|
||||
mock_resp.choices[0].message.content = "测试转录结果"
|
||||
|
||||
mock_openai_instance = MagicMock()
|
||||
mock_openai_instance.chat.completions.create.return_value = mock_resp
|
||||
mock_openai_cls.return_value = mock_openai_instance
|
||||
|
||||
resp = client.post(f"/api/v1/video/{video_id}/transcribe")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "text" in data
|
||||
assert "language" in data
|
||||
assert data["language"] == "yue"
|
||||
# Text should be traditional Chinese
|
||||
assert "測" in data["text"] or "試" in data["text"]
|
||||
|
||||
@patch("app.routers.video.VideoService.extract_audio")
|
||||
@patch("app.services.asr_client.OpenAI")
|
||||
def test_transcribe_custom_language(self, mock_openai_cls, mock_extract, video_client):
|
||||
"""POST transcribe with language param should pass it through."""
|
||||
client, upload_dir = video_client
|
||||
video_id = _upload_video(client)
|
||||
|
||||
fake_wav = upload_dir / "extracted.wav"
|
||||
fake_wav.write_bytes(b"RIFF" + b"\x00" * 100)
|
||||
mock_extract.return_value = fake_wav
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.choices = [MagicMock()]
|
||||
mock_resp.choices[0].message.content = "hello world"
|
||||
|
||||
mock_openai_instance = MagicMock()
|
||||
mock_openai_instance.chat.completions.create.return_value = mock_resp
|
||||
mock_openai_cls.return_value = mock_openai_instance
|
||||
|
||||
resp = client.post(f"/api/v1/video/{video_id}/transcribe?language=en")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["language"] == "en"
|
||||
|
||||
|
||||
class TestTranscribeMissingVideo:
|
||||
def test_404_for_unknown_video(self, video_client):
|
||||
"""POST transcribe for non-existent video should return 404."""
|
||||
client, _ = video_client
|
||||
resp = client.post("/api/v1/video/nonexistent-video-id/transcribe")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
class TestTranscribeExtractsAudio:
|
||||
@patch("app.services.video_service.asyncio.create_subprocess_exec")
|
||||
async def test_extract_audio_calls_ffmpeg(self, mock_subprocess, tmp_path):
|
||||
"""extract_audio should call ffmpeg with correct arguments."""
|
||||
from app.services.video_service import VideoService
|
||||
|
||||
# Setup: create a fake video file
|
||||
upload_dir = tmp_path / "uploads"
|
||||
upload_dir.mkdir()
|
||||
video_file = upload_dir / "test-video.mp4"
|
||||
video_file.write_bytes(b"fake-video-content")
|
||||
|
||||
service = VideoService(
|
||||
upload_dir=str(upload_dir),
|
||||
max_size_mb=300,
|
||||
supported_formats=[".mp4"],
|
||||
)
|
||||
|
||||
# Mock the subprocess
|
||||
mock_proc = AsyncMock()
|
||||
mock_proc.returncode = 0
|
||||
mock_proc.communicate.return_value = (b"ffmpeg output", b"")
|
||||
mock_subprocess.return_value = mock_proc
|
||||
|
||||
result = await service.extract_audio("test-video")
|
||||
assert result is not None
|
||||
# Verify ffmpeg was called
|
||||
mock_subprocess.assert_called_once()
|
||||
call_args = mock_subprocess.call_args[0]
|
||||
assert call_args[0] == "ffmpeg"
|
||||
assert "-i" in call_args
|
||||
|
||||
@patch("app.services.video_service.asyncio.create_subprocess_exec")
|
||||
async def test_extract_audio_fails_gracefully(self, mock_subprocess, tmp_path):
|
||||
"""extract_audio should raise on ffmpeg failure."""
|
||||
from app.services.video_service import VideoService
|
||||
|
||||
upload_dir = tmp_path / "uploads"
|
||||
upload_dir.mkdir()
|
||||
video_file = upload_dir / "test-fail.mp4"
|
||||
video_file.write_bytes(b"bad-content")
|
||||
|
||||
service = VideoService(
|
||||
upload_dir=str(upload_dir),
|
||||
max_size_mb=300,
|
||||
supported_formats=[".mp4"],
|
||||
)
|
||||
|
||||
mock_proc = AsyncMock()
|
||||
mock_proc.returncode = 1
|
||||
mock_proc.communicate.return_value = (b"", b"Error: Invalid data")
|
||||
mock_subprocess.return_value = mock_proc
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await service.extract_audio("test-fail")
|
||||
|
||||
|
||||
class TestTranscribeMissingApiKey:
|
||||
def test_missing_api_key_returns_500(self, monkeypatch, tmp_path):
|
||||
"""Empty DASHSCOPE_API_KEY should return 500 with descriptive message."""
|
||||
upload_dir = tmp_path / "uploads"
|
||||
upload_dir.mkdir()
|
||||
monkeypatch.setenv("VIDEO_UPLOAD_DIR", str(upload_dir))
|
||||
monkeypatch.setenv("MAX_VIDEO_SIZE_MB", "50")
|
||||
monkeypatch.setenv("DASHSCOPE_API_KEY", "")
|
||||
|
||||
from app.core.config import get_settings
|
||||
get_settings.cache_clear()
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(router, prefix="/api/v1")
|
||||
client = TestClient(app)
|
||||
|
||||
# Upload a video first
|
||||
resp = client.post(
|
||||
"/api/v1/video/upload",
|
||||
files={"file": ("test.mp4", b"\x00" * 512, "video/mp4")},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
video_id = resp.json()["video_id"]
|
||||
|
||||
# Try to transcribe
|
||||
resp = client.post(f"/api/v1/video/{video_id}/transcribe")
|
||||
assert resp.status_code == 500
|
||||
assert "DASHSCOPE_API_KEY" in resp.json()["detail"] or "API key" in resp.json()["detail"]
|
||||
|
|
@ -1,28 +1,73 @@
|
|||
"""Phase 2 tests: WebSocket ASR streaming.
|
||||
"""Phase 2 tests: WebSocket ASR endpoint.
|
||||
|
||||
Covers:
|
||||
- /ws/asr/{video_id} connection lifecycle
|
||||
- Real-time audio chunk streaming
|
||||
- Transcript accumulation
|
||||
- Connection cleanup on disconnect
|
||||
- WebSocket connection accepted
|
||||
- Language parameter defaults and customization
|
||||
- Client disconnect handled cleanly
|
||||
- Missing API key returns error
|
||||
"""
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
class TestWebSocketASR:
|
||||
"""WebSocket ASR streaming tests."""
|
||||
@pytest.fixture
|
||||
def ws_app(monkeypatch):
|
||||
monkeypatch.setenv("DASHSCOPE_API_KEY", "sk-test-key")
|
||||
from app.core.config import get_settings
|
||||
from app.routers.ws_asr import router
|
||||
get_settings.cache_clear()
|
||||
app = FastAPI()
|
||||
app.include_router(router)
|
||||
return app
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ws_connection_established(self):
|
||||
"""Should accept WebSocket connection with valid video_id."""
|
||||
pass # TODO: implement
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ws_audio_chunk_streaming(self):
|
||||
"""Should process audio chunks and return transcripts."""
|
||||
pass # TODO: implement
|
||||
@pytest.fixture
|
||||
def ws_client(ws_app):
|
||||
return TestClient(ws_app)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ws_disconnect_cleanup(self):
|
||||
"""Should clean up resources on client disconnect."""
|
||||
pass # TODO: implement
|
||||
|
||||
class TestWSEndpointAcceptsConnection:
|
||||
def test_connect_success(self, ws_client):
|
||||
with ws_client.websocket_connect("/ws/asr/test-video-123") as ws:
|
||||
pass
|
||||
|
||||
def test_connect_with_video_id(self, ws_client):
|
||||
with ws_client.websocket_connect("/ws/asr/my-video-id") as ws:
|
||||
pass
|
||||
|
||||
|
||||
class TestWSEndpointLanguageParam:
|
||||
def test_default_language_is_yue(self, ws_client):
|
||||
with ws_client.websocket_connect("/ws/asr/test-vid") as ws:
|
||||
pass
|
||||
|
||||
def test_custom_language_param(self, ws_client):
|
||||
with ws_client.websocket_connect("/ws/asr/test-vid?language=zh") as ws:
|
||||
pass
|
||||
|
||||
def test_english_language_param(self, ws_client):
|
||||
with ws_client.websocket_connect("/ws/asr/test-vid?language=en") as ws:
|
||||
pass
|
||||
|
||||
|
||||
class TestWSEndpointHandlesDisconnect:
|
||||
def test_clean_disconnect(self, ws_client):
|
||||
with ws_client.websocket_connect("/ws/asr/test-vid") as ws:
|
||||
pass
|
||||
|
||||
|
||||
class TestWSEndpointMissingApiKey:
|
||||
def test_missing_api_key_sends_error(self, monkeypatch):
|
||||
monkeypatch.setenv("DASHSCOPE_API_KEY", "")
|
||||
from app.core.config import get_settings
|
||||
from app.routers.ws_asr import router
|
||||
get_settings.cache_clear()
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(router)
|
||||
client = TestClient(app)
|
||||
|
||||
with client.websocket_connect("/ws/asr/test-vid") as ws:
|
||||
msg = ws.receive_json()
|
||||
assert "error" in msg or "detail" in msg
|
||||
|
|
|
|||
|
|
@ -0,0 +1,152 @@
|
|||
"""Phase 2 tests: DashScope WebSocket protocol — callback bridge and event formatting.
|
||||
|
||||
Covers:
|
||||
- DashScopeCallback sync→async queue bridge
|
||||
- Transcription text event formatting (partial)
|
||||
- Transcription completed event formatting (final)
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestDashScopeCallback:
|
||||
def test_puts_events_on_queue(self):
|
||||
"""on_event should push parsed JSON onto the asyncio queue."""
|
||||
from app.routers.ws_asr import DashScopeCallback
|
||||
|
||||
queue: asyncio.Queue = asyncio.Queue()
|
||||
loop = asyncio.new_event_loop()
|
||||
callback = DashScopeCallback(queue, loop)
|
||||
|
||||
test_event = json.dumps({"type": "test", "data": "hello"})
|
||||
callback.on_event(test_event)
|
||||
|
||||
# Give the loop a chance to process call_soon_threadsafe
|
||||
loop.run_until_complete(asyncio.sleep(0.05))
|
||||
|
||||
assert not queue.empty()
|
||||
event = queue.get_nowait()
|
||||
assert event["type"] == "test"
|
||||
assert event["data"] == "hello"
|
||||
loop.close()
|
||||
|
||||
def test_handles_dict_event(self):
|
||||
"""on_event should accept dict as well as string."""
|
||||
from app.routers.ws_asr import DashScopeCallback
|
||||
|
||||
queue: asyncio.Queue = asyncio.Queue()
|
||||
loop = asyncio.new_event_loop()
|
||||
callback = DashScopeCallback(queue, loop)
|
||||
|
||||
callback.on_event({"type": "test_dict", "key": "value"})
|
||||
|
||||
loop.run_until_complete(asyncio.sleep(0.05))
|
||||
|
||||
assert not queue.empty()
|
||||
event = queue.get_nowait()
|
||||
assert event["type"] == "test_dict"
|
||||
loop.close()
|
||||
|
||||
def test_handles_invalid_json_gracefully(self):
|
||||
"""on_event should not crash on invalid JSON."""
|
||||
from app.routers.ws_asr import DashScopeCallback
|
||||
|
||||
queue: asyncio.Queue = asyncio.Queue()
|
||||
loop = asyncio.new_event_loop()
|
||||
callback = DashScopeCallback(queue, loop)
|
||||
|
||||
# Should not raise
|
||||
callback.on_event("not-valid-json{{{")
|
||||
loop.close()
|
||||
|
||||
def test_on_open_and_close(self):
|
||||
"""on_open and on_close should not crash."""
|
||||
from app.routers.ws_asr import DashScopeCallback
|
||||
|
||||
queue: asyncio.Queue = asyncio.Queue()
|
||||
loop = asyncio.new_event_loop()
|
||||
callback = DashScopeCallback(queue, loop)
|
||||
|
||||
callback.on_open()
|
||||
callback.on_close(1000, "normal")
|
||||
loop.close()
|
||||
|
||||
|
||||
class TestProxyFormatsTranscriptionTextEvent:
|
||||
def test_partial_event_format(self):
|
||||
"""Partial transcription event should format as ASRTranscriptEvent with is_final=False."""
|
||||
from app.routers.ws_asr import format_transcription_event
|
||||
|
||||
event = {
|
||||
"type": "conversation.item.input_audio_transcription.text",
|
||||
"stash": "你好",
|
||||
"language": "yue",
|
||||
}
|
||||
accumulated = ""
|
||||
|
||||
result = format_transcription_event(event, accumulated)
|
||||
assert result is not None
|
||||
assert result["is_final"] is False
|
||||
assert result["language"] == "yue"
|
||||
assert result["delta"] == ""
|
||||
assert "你好" in result["full_text"]
|
||||
|
||||
def test_partial_with_accumulated(self):
|
||||
"""Partial event should combine accumulated + current stash."""
|
||||
from app.routers.ws_asr import format_transcription_event
|
||||
|
||||
event = {
|
||||
"type": "conversation.item.input_audio_transcription.text",
|
||||
"stash": "世界",
|
||||
"language": "yue",
|
||||
}
|
||||
accumulated = "你好"
|
||||
|
||||
result = format_transcription_event(event, accumulated)
|
||||
assert "你好" in result["full_text"]
|
||||
assert "世界" in result["full_text"]
|
||||
|
||||
|
||||
class TestProxyFormatsTranscriptionCompletedEvent:
|
||||
def test_completed_event_format(self):
|
||||
"""Completed event should format as ASRTranscriptEvent with is_final=True."""
|
||||
from app.routers.ws_asr import format_transcription_event
|
||||
|
||||
event = {
|
||||
"type": "conversation.item.input_audio_transcription.completed",
|
||||
"transcript": "你好世界",
|
||||
"language": "yue",
|
||||
}
|
||||
accumulated = ""
|
||||
|
||||
result = format_transcription_event(event, accumulated)
|
||||
assert result is not None
|
||||
assert result["is_final"] is True
|
||||
assert result["language"] == "yue"
|
||||
assert "你好" in result["full_text"]
|
||||
|
||||
def test_completed_updates_accumulated(self):
|
||||
"""Completed event should return updated accumulated text."""
|
||||
from app.routers.ws_asr import format_transcription_event
|
||||
|
||||
event = {
|
||||
"type": "conversation.item.input_audio_transcription.completed",
|
||||
"transcript": "世界",
|
||||
"language": "yue",
|
||||
}
|
||||
accumulated = "你好"
|
||||
|
||||
result = format_transcription_event(event, accumulated)
|
||||
assert "你好" in result["full_text"]
|
||||
assert "世界" in result["full_text"]
|
||||
|
||||
def test_unknown_event_type_returns_none(self):
|
||||
"""Unknown event types should return None."""
|
||||
from app.routers.ws_asr import format_transcription_event
|
||||
|
||||
event = {"type": "unknown.event"}
|
||||
result = format_transcription_event(event, "")
|
||||
assert result is None
|
||||
|
|
@ -3,11 +3,16 @@ import React, { useState, type FormEvent, type KeyboardEvent } from 'react'
|
|||
export interface QueryInputProps {
|
||||
onSubmit: (question: string) => void
|
||||
isLoading: boolean
|
||||
partialText?: string
|
||||
}
|
||||
|
||||
export const QueryInput: React.FC<QueryInputProps> = ({ onSubmit, isLoading }) => {
|
||||
export const QueryInput: React.FC<QueryInputProps> = ({ onSubmit, isLoading, partialText }) => {
|
||||
const [question, setQuestion] = useState<string>('')
|
||||
const [submittedQuestion, setSubmittedQuestion] = useState<string | null>(null)
|
||||
const [hasUserInput, setHasUserInput] = useState(false)
|
||||
|
||||
const displayValue = hasUserInput ? question : (partialText ?? question)
|
||||
const showPartialStyle = !hasUserInput && !!partialText
|
||||
|
||||
const handleSubmit = (e: FormEvent): void => {
|
||||
e.preventDefault()
|
||||
|
|
@ -16,6 +21,7 @@ export const QueryInput: React.FC<QueryInputProps> = ({ onSubmit, isLoading }) =
|
|||
onSubmit(trimmed)
|
||||
setSubmittedQuestion(trimmed)
|
||||
setQuestion('')
|
||||
setHasUserInput(false)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -28,6 +34,7 @@ export const QueryInput: React.FC<QueryInputProps> = ({ onSubmit, isLoading }) =
|
|||
|
||||
const handleChange = (e: React.ChangeEvent<HTMLTextAreaElement>): void => {
|
||||
setQuestion(e.target.value)
|
||||
setHasUserInput(true)
|
||||
if (e.target.value.trim() !== '') {
|
||||
setSubmittedQuestion(null)
|
||||
}
|
||||
|
|
@ -35,16 +42,21 @@ export const QueryInput: React.FC<QueryInputProps> = ({ onSubmit, isLoading }) =
|
|||
|
||||
const isDisabled = isLoading || question.trim() === ''
|
||||
|
||||
const textareaClassName = [
|
||||
'w-full rounded border border-gray-300 px-3 py-2 focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-transparent disabled:bg-gray-100 disabled:cursor-not-allowed',
|
||||
showPartialStyle ? 'text-gray-400 italic' : '',
|
||||
].filter(Boolean).join(' ')
|
||||
|
||||
return (
|
||||
<form onSubmit={handleSubmit} className="space-y-3">
|
||||
<textarea
|
||||
value={question}
|
||||
value={displayValue}
|
||||
onChange={handleChange}
|
||||
onKeyDown={handleKeyDown}
|
||||
placeholder="Ask a question about your documents..."
|
||||
disabled={isLoading}
|
||||
rows={3}
|
||||
className="w-full rounded border border-gray-300 px-3 py-2 focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-transparent disabled:bg-gray-100 disabled:cursor-not-allowed"
|
||||
className={textareaClassName}
|
||||
/>
|
||||
<div className="flex items-center gap-3">
|
||||
<button
|
||||
|
|
|
|||
|
|
@ -0,0 +1,35 @@
|
|||
import { useState, useCallback } from 'react'
|
||||
|
||||
interface UseFullTranscriptOptions {
|
||||
videoId: string
|
||||
}
|
||||
|
||||
export function useFullTranscript({ videoId }: UseFullTranscriptOptions) {
|
||||
const [fullTranscript, setFullTranscript] = useState('')
|
||||
const [isLoading, setIsLoading] = useState(false)
|
||||
const [error, setError] = useState<string | null>(null)
|
||||
|
||||
const requestFullTranscript = useCallback(async () => {
|
||||
setIsLoading(true)
|
||||
setError(null)
|
||||
try {
|
||||
const resp = await fetch(`/api/v1/video/${videoId}/transcribe`, {
|
||||
method: 'POST',
|
||||
})
|
||||
if (!resp.ok) {
|
||||
throw new Error(`Server returned ${resp.status}`)
|
||||
}
|
||||
const data = await resp.json()
|
||||
setFullTranscript(data.text)
|
||||
return data.text
|
||||
} catch (err) {
|
||||
const msg = err instanceof Error ? err.message : 'Transcription failed'
|
||||
setError(msg)
|
||||
return null
|
||||
} finally {
|
||||
setIsLoading(false)
|
||||
}
|
||||
}, [videoId])
|
||||
|
||||
return { fullTranscript, isLoading, error, requestFullTranscript }
|
||||
}
|
||||
|
|
@ -0,0 +1,136 @@
|
|||
import { useState, useRef, useCallback, useEffect } from 'react'
|
||||
import type { ASRMessage, ASRStatus } from '../types'
|
||||
|
||||
interface UseVideoASROptions {
|
||||
videoId: string
|
||||
videoElement: HTMLVideoElement | null
|
||||
language?: string
|
||||
onFinalTranscript?: (text: string) => void
|
||||
}
|
||||
|
||||
export function useVideoASR({
|
||||
videoId,
|
||||
videoElement,
|
||||
language = 'yue',
|
||||
onFinalTranscript,
|
||||
}: UseVideoASROptions) {
|
||||
const [transcript, setTranscript] = useState('')
|
||||
const [partialTranscript, setPartialTranscript] = useState('')
|
||||
const [status, setStatus] = useState<ASRStatus>('idle')
|
||||
const [isStreaming, setIsStreaming] = useState(false)
|
||||
|
||||
const wsRef = useRef<WebSocket | null>(null)
|
||||
const audioContextRef = useRef<AudioContext | null>(null)
|
||||
const processorRef = useRef<ScriptProcessorNode | null>(null)
|
||||
const sourceRef = useRef<MediaElementAudioSourceNode | null>(null)
|
||||
const isStreamingRef = useRef(false)
|
||||
|
||||
const getWSURL = useCallback(() => {
|
||||
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:'
|
||||
const host = window.location.host
|
||||
const langParam = language !== 'auto' ? `?language=${language}` : ''
|
||||
return `${protocol}//${host}/ws/asr/${videoId}${langParam}`
|
||||
}, [videoId, language])
|
||||
|
||||
const startStreaming = useCallback(() => {
|
||||
if (!videoElement) return
|
||||
try {
|
||||
setStatus('connecting')
|
||||
|
||||
const audioContext = new AudioContext({ sampleRate: 16000 })
|
||||
audioContextRef.current = audioContext
|
||||
|
||||
const source = audioContext.createMediaElementSource(videoElement)
|
||||
sourceRef.current = source
|
||||
|
||||
const processor = audioContext.createScriptProcessor(4096, 1, 1)
|
||||
processorRef.current = processor
|
||||
|
||||
const ws = new WebSocket(getWSURL())
|
||||
wsRef.current = ws
|
||||
|
||||
ws.onopen = () => {
|
||||
isStreamingRef.current = true
|
||||
setIsStreaming(true)
|
||||
setStatus('streaming')
|
||||
}
|
||||
|
||||
ws.onmessage = (e) => {
|
||||
const msg: ASRMessage = JSON.parse(e.data)
|
||||
setTranscript(msg.full_text)
|
||||
setPartialTranscript(msg.is_final ? '' : msg.full_text)
|
||||
if (msg.is_final && msg.full_text.trim()) {
|
||||
onFinalTranscript?.(msg.full_text)
|
||||
}
|
||||
}
|
||||
|
||||
ws.onerror = () => setStatus('error')
|
||||
ws.onclose = () => {
|
||||
isStreamingRef.current = false
|
||||
setIsStreaming(false)
|
||||
setStatus('disconnected')
|
||||
}
|
||||
|
||||
processor.onaudioprocess = (e) => {
|
||||
if (!isStreamingRef.current) return
|
||||
if (!wsRef.current || wsRef.current.readyState !== WebSocket.OPEN) return
|
||||
const float32Data = e.inputBuffer.getChannelData(0)
|
||||
wsRef.current.send(float32Data.buffer)
|
||||
}
|
||||
|
||||
source.connect(processor)
|
||||
processor.connect(audioContext.destination)
|
||||
} catch {
|
||||
setStatus('error')
|
||||
}
|
||||
}, [videoElement, getWSURL, onFinalTranscript])
|
||||
|
||||
const stopStreaming = useCallback(() => {
|
||||
isStreamingRef.current = false
|
||||
setIsStreaming(false)
|
||||
processorRef.current?.disconnect()
|
||||
processorRef.current = null
|
||||
sourceRef.current?.disconnect()
|
||||
sourceRef.current = null
|
||||
wsRef.current?.close()
|
||||
wsRef.current = null
|
||||
audioContextRef.current?.close()
|
||||
audioContextRef.current = null
|
||||
setStatus('idle')
|
||||
setPartialTranscript('')
|
||||
}, [])
|
||||
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
isStreamingRef.current = false
|
||||
processorRef.current?.disconnect()
|
||||
sourceRef.current?.disconnect()
|
||||
wsRef.current?.close()
|
||||
audioContextRef.current?.close()
|
||||
}
|
||||
}, [])
|
||||
|
||||
useEffect(() => {
|
||||
if (!videoElement) return
|
||||
const onPlay = () => startStreaming()
|
||||
const onPause = () => stopStreaming()
|
||||
const onEnded = () => stopStreaming()
|
||||
videoElement.addEventListener('play', onPlay)
|
||||
videoElement.addEventListener('pause', onPause)
|
||||
videoElement.addEventListener('ended', onEnded)
|
||||
return () => {
|
||||
videoElement.removeEventListener('play', onPlay)
|
||||
videoElement.removeEventListener('pause', onPause)
|
||||
videoElement.removeEventListener('ended', onEnded)
|
||||
}
|
||||
}, [videoElement, startStreaming, stopStreaming])
|
||||
|
||||
return {
|
||||
transcript,
|
||||
partialTranscript,
|
||||
isStreaming,
|
||||
status,
|
||||
startStreaming,
|
||||
stopStreaming,
|
||||
}
|
||||
}
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
import axios from 'axios'
|
||||
import type { QueryRequest, QueryResponse, QueryStreamEvent, IngestResponse, DocumentListResponse, ChunkInfo, DeleteResponse, PromptProfileListResponse, PromptSetResponse, PromptUpdateRequest, PromptBatchUpdateRequest, PromptActivateResponse, PromptStatusResponse, ProfileExportData, ProfileImportResponse, QueryHistoryList, QueryHistoryDetail, HistoryStats, HistoryDeleteResponse } from '../types'
|
||||
import type { QueryRequest, QueryResponse, QueryStreamEvent, IngestResponse, DocumentListResponse, ChunkInfo, DeleteResponse, PromptProfileListResponse, PromptSetResponse, PromptUpdateRequest, PromptBatchUpdateRequest, PromptActivateResponse, PromptStatusResponse, ProfileExportData, ProfileImportResponse, QueryHistoryList, QueryHistoryDetail, HistoryStats, HistoryDeleteResponse, FullTranscriptResponse, VideoUploadResponse } from '../types'
|
||||
|
||||
const BASE_URL: string = import.meta.env.VITE_API_BASE_URL ?? 'http://localhost:8000/api/v1'
|
||||
|
||||
|
|
@ -153,3 +153,24 @@ export const getHistoryStats = async (): Promise<HistoryStats> => {
|
|||
const resp = await apiClient.get<HistoryStats>('/history/stats')
|
||||
return resp.data
|
||||
}
|
||||
|
||||
export const uploadVideo = async (file: File, onProgress?: (pct: number) => void): Promise<VideoUploadResponse> => {
|
||||
const form = new FormData()
|
||||
form.append('file', file)
|
||||
const resp = await apiClient.post<VideoUploadResponse>('/video/upload', form, {
|
||||
headers: { 'Content-Type': 'multipart/form-data' },
|
||||
onUploadProgress: onProgress
|
||||
? (e) => {
|
||||
if (e.total) onProgress(Math.round((e.loaded * 100) / e.total))
|
||||
}
|
||||
: undefined,
|
||||
})
|
||||
return resp.data
|
||||
}
|
||||
|
||||
export const getVideoUrl = (videoId: string): string => `${BASE_URL}/video/${videoId}`
|
||||
|
||||
export const requestFullTranscript = async (videoId: string): Promise<FullTranscriptResponse> => {
|
||||
const resp = await apiClient.post<FullTranscriptResponse>(`/video/${videoId}/transcribe`)
|
||||
return resp.data
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import React from 'react'
|
||||
import { QueryClient, QueryClientProvider, useMutation, useQuery, useQueryClient } from '@tanstack/react-query'
|
||||
import { queryDocument, queryDocumentStream, ingestDocument, listDocuments, listChunks, deleteDocument, deleteChunk, listPromptProfiles, getPromptProfile, activatePromptProfile, updatePrompt, updateAllPrompts, resetPrompts, exportProfile, importProfile, listQueryHistory, getQueryHistoryDetail, deleteQueryHistory, clearQueryHistory, getHistoryStats } from './api'
|
||||
import type { QueryRequest, QueryResponse, QueryStreamEvent, SourceMetadata, SubQuestionSources, IngestResponse, DocumentListResponse, ChunkInfo, DeleteResponse, PromptProfileListResponse, PromptSetResponse, PromptUpdateRequest, PromptBatchUpdateRequest, PromptActivateResponse, PromptStatusResponse, ProfileExportData, ProfileImportResponse, QueryHistoryList, QueryHistoryDetail, HistoryStats, HistoryDeleteResponse } from '../types'
|
||||
import { queryDocument, queryDocumentStream, ingestDocument, listDocuments, listChunks, deleteDocument, deleteChunk, listPromptProfiles, getPromptProfile, activatePromptProfile, updatePrompt, updateAllPrompts, resetPrompts, exportProfile, importProfile, listQueryHistory, getQueryHistoryDetail, deleteQueryHistory, clearQueryHistory, getHistoryStats, uploadVideo } from './api'
|
||||
import type { QueryRequest, QueryResponse, QueryStreamEvent, SourceMetadata, SubQuestionSources, IngestResponse, DocumentListResponse, ChunkInfo, DeleteResponse, PromptProfileListResponse, PromptSetResponse, PromptUpdateRequest, PromptBatchUpdateRequest, PromptActivateResponse, PromptStatusResponse, ProfileExportData, ProfileImportResponse, QueryHistoryList, QueryHistoryDetail, HistoryStats, HistoryDeleteResponse, VideoUploadResponse } from '../types'
|
||||
import { useState, useCallback, useRef } from 'react'
|
||||
|
||||
export const queryClient = new QueryClient()
|
||||
|
|
@ -268,3 +268,9 @@ export const useHistoryStats = () => {
|
|||
export const AppQueryProvider: React.FC<{ children: React.ReactNode }> = ({ children }) => {
|
||||
return <QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
|
||||
}
|
||||
|
||||
export const useVideoUpload = () => {
|
||||
return useMutation<VideoUploadResponse, Error, { file: File; onProgress?: (pct: number) => void }>({
|
||||
mutationFn: ({ file, onProgress }) => uploadVideo(file, onProgress),
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,204 @@
|
|||
/**
|
||||
* Phase 2.4 tests: QueryInput with partialText prop integration.
|
||||
*
|
||||
* Covers:
|
||||
* - partialText rendered as grey italic when no user input
|
||||
* - User typing overrides partialText with black text
|
||||
* - Submit clears input, partialText reappears
|
||||
* - Existing QueryInput behavior preserved
|
||||
*/
|
||||
import React from 'react'
|
||||
import { render, screen, fireEvent } from '@testing-library/react'
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
import { QueryInput } from '../components/QueryInput'
|
||||
|
||||
describe('QueryInput with partialText (Phase 2.4)', () => {
|
||||
const mockOnSubmit = vi.fn()
|
||||
|
||||
beforeEach(() => {
|
||||
mockOnSubmit.mockClear()
|
||||
})
|
||||
|
||||
it('test_renders_partialText_when_no_user_input', () => {
|
||||
render(
|
||||
<QueryInput
|
||||
onSubmit={mockOnSubmit}
|
||||
isLoading={false}
|
||||
partialText="This is ASR partial text"
|
||||
/>
|
||||
)
|
||||
|
||||
const textarea = screen.getByPlaceholderText('Ask a question about your documents...')
|
||||
// When partialText is provided and no user input, textarea value shows partial text
|
||||
expect(textarea).toHaveValue('This is ASR partial text')
|
||||
// The textarea should have grey italic styling
|
||||
expect(textarea).toHaveClass('text-gray-400')
|
||||
expect(textarea).toHaveClass('italic')
|
||||
})
|
||||
|
||||
it('test_partialText_not_shown_without_prop', () => {
|
||||
render(
|
||||
<QueryInput
|
||||
onSubmit={mockOnSubmit}
|
||||
isLoading={false}
|
||||
/>
|
||||
)
|
||||
|
||||
const textarea = screen.getByPlaceholderText('Ask a question about your documents...')
|
||||
expect(textarea).toHaveValue('')
|
||||
// No grey italic class
|
||||
expect(textarea).not.toHaveClass('text-gray-400')
|
||||
expect(textarea).not.toHaveClass('italic')
|
||||
})
|
||||
|
||||
it('test_user_input_overrides_partialText', () => {
|
||||
render(
|
||||
<QueryInput
|
||||
onSubmit={mockOnSubmit}
|
||||
isLoading={false}
|
||||
partialText="ASR text here"
|
||||
/>
|
||||
)
|
||||
|
||||
const textarea = screen.getByPlaceholderText('Ask a question about your documents...')
|
||||
|
||||
// Initially shows partial text with grey italic
|
||||
expect(textarea).toHaveValue('ASR text here')
|
||||
expect(textarea).toHaveClass('text-gray-400')
|
||||
|
||||
// User types — should show their text in normal style
|
||||
fireEvent.change(textarea, { target: { value: 'My real question' } })
|
||||
expect(textarea).toHaveValue('My real question')
|
||||
// Grey italic should be removed when user types
|
||||
expect(textarea).not.toHaveClass('text-gray-400')
|
||||
expect(textarea).not.toHaveClass('italic')
|
||||
})
|
||||
|
||||
it('test_submit_clears_input_and_partialText_reappears', () => {
|
||||
const { rerender } = render(
|
||||
<QueryInput
|
||||
onSubmit={mockOnSubmit}
|
||||
isLoading={false}
|
||||
partialText="ASR streaming text"
|
||||
/>
|
||||
)
|
||||
|
||||
const textarea = screen.getByPlaceholderText('Ask a question about your documents...')
|
||||
const button = screen.getByRole('button', { name: /submit/i })
|
||||
|
||||
// User types their own text
|
||||
fireEvent.change(textarea, { target: { value: 'My question' } })
|
||||
expect(textarea).not.toHaveClass('text-gray-400')
|
||||
|
||||
// Submit
|
||||
fireEvent.click(button)
|
||||
expect(mockOnSubmit).toHaveBeenCalledWith('My question')
|
||||
|
||||
// After submit, textarea is cleared, partialText should reappear
|
||||
expect(textarea).toHaveValue('ASR streaming text')
|
||||
expect(textarea).toHaveClass('text-gray-400')
|
||||
expect(textarea).toHaveClass('italic')
|
||||
})
|
||||
|
||||
it('test_empty_partialText_does_not_show', () => {
|
||||
render(
|
||||
<QueryInput
|
||||
onSubmit={mockOnSubmit}
|
||||
isLoading={false}
|
||||
partialText=""
|
||||
/>
|
||||
)
|
||||
|
||||
const textarea = screen.getByPlaceholderText('Ask a question about your documents...')
|
||||
expect(textarea).toHaveValue('')
|
||||
// Should not have grey italic class for empty partialText
|
||||
expect(textarea).not.toHaveClass('text-gray-400')
|
||||
})
|
||||
|
||||
it('test_partialText_updates_dynamically', () => {
|
||||
const { rerender } = render(
|
||||
<QueryInput
|
||||
onSubmit={mockOnSubmit}
|
||||
isLoading={false}
|
||||
partialText="First partial"
|
||||
/>
|
||||
)
|
||||
|
||||
const textarea = screen.getByPlaceholderText('Ask a question about your documents...')
|
||||
expect(textarea).toHaveValue('First partial')
|
||||
|
||||
// Update partialText from ASR stream
|
||||
rerender(
|
||||
<QueryInput
|
||||
onSubmit={mockOnSubmit}
|
||||
isLoading={false}
|
||||
partialText="First partial and more"
|
||||
/>
|
||||
)
|
||||
expect(textarea).toHaveValue('First partial and more')
|
||||
expect(textarea).toHaveClass('text-gray-400')
|
||||
})
|
||||
|
||||
it('test_submit_with_partialText_still_works', () => {
|
||||
render(
|
||||
<QueryInput
|
||||
onSubmit={mockOnSubmit}
|
||||
isLoading={false}
|
||||
partialText="Some ASR text"
|
||||
/>
|
||||
)
|
||||
|
||||
const textarea = screen.getByPlaceholderText('Ask a question about your documents...')
|
||||
const button = screen.getByRole('button', { name: /submit/i })
|
||||
|
||||
// partialText is shown but user hasn't typed, button should be disabled
|
||||
// (partial text is not a real user input — it's just display)
|
||||
expect(button).toBeDisabled()
|
||||
})
|
||||
|
||||
it('test_user_can_submit_after_typing_over_partialText', () => {
|
||||
render(
|
||||
<QueryInput
|
||||
onSubmit={mockOnSubmit}
|
||||
isLoading={false}
|
||||
partialText="ASR partial text"
|
||||
/>
|
||||
)
|
||||
|
||||
const textarea = screen.getByPlaceholderText('Ask a question about your documents...')
|
||||
const button = screen.getByRole('button', { name: /submit/i })
|
||||
|
||||
// User types over the partial text
|
||||
fireEvent.change(textarea, { target: { value: 'My actual question' } })
|
||||
expect(button).not.toBeDisabled()
|
||||
|
||||
// Submit works
|
||||
fireEvent.click(button)
|
||||
expect(mockOnSubmit).toHaveBeenCalledWith('My actual question')
|
||||
})
|
||||
|
||||
it('test_existing_behavior_preserved_without_partialText', () => {
|
||||
render(
|
||||
<QueryInput
|
||||
onSubmit={mockOnSubmit}
|
||||
isLoading={false}
|
||||
/>
|
||||
)
|
||||
|
||||
const textarea = screen.getByPlaceholderText('Ask a question about your documents...')
|
||||
const button = screen.getByRole('button', { name: /submit/i })
|
||||
|
||||
// Existing behavior: empty → disabled
|
||||
expect(button).toBeDisabled()
|
||||
|
||||
// Existing behavior: type and submit
|
||||
fireEvent.change(textarea, { target: { value: 'Test question' } })
|
||||
expect(button).not.toBeDisabled()
|
||||
fireEvent.click(button)
|
||||
expect(mockOnSubmit).toHaveBeenCalledWith('Test question')
|
||||
expect(textarea).toHaveValue('')
|
||||
|
||||
// Existing behavior: submitted question displayed
|
||||
expect(screen.getByTestId('submitted-question')).toHaveTextContent('Test question')
|
||||
})
|
||||
})
|
||||
|
|
@ -0,0 +1,201 @@
|
|||
/**
|
||||
* Phase 2.4 tests: useFullTranscript hook.
|
||||
*
|
||||
* Tests batch transcription hook with mocked fetch.
|
||||
* Covers: initial state, success, error, loading state.
|
||||
*/
|
||||
import { renderHook, act } from '@testing-library/react'
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'
|
||||
import { useFullTranscript } from '../hooks/useFullTranscript'
|
||||
|
||||
describe('useFullTranscript', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks()
|
||||
})
|
||||
|
||||
it('test_initial_state', () => {
|
||||
const { result } = renderHook(() =>
|
||||
useFullTranscript({ videoId: 'test-video-id' })
|
||||
)
|
||||
|
||||
expect(result.current.fullTranscript).toBe('')
|
||||
expect(result.current.isLoading).toBe(false)
|
||||
expect(result.current.error).toBeNull()
|
||||
})
|
||||
|
||||
it('test_requestFullTranscript_success', async () => {
|
||||
const mockResponse = {
|
||||
text: 'This is the full transcript of the video.',
|
||||
language: 'yue',
|
||||
duration_seconds: 120.5,
|
||||
}
|
||||
|
||||
vi.spyOn(globalThis, 'fetch').mockResolvedValueOnce({
|
||||
ok: true,
|
||||
status: 200,
|
||||
json: async () => mockResponse,
|
||||
} as Response)
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useFullTranscript({ videoId: 'test-video-id' })
|
||||
)
|
||||
|
||||
let transcriptResult: string | null = null
|
||||
await act(async () => {
|
||||
transcriptResult = await result.current.requestFullTranscript()
|
||||
})
|
||||
|
||||
expect(result.current.fullTranscript).toBe('This is the full transcript of the video.')
|
||||
expect(result.current.isLoading).toBe(false)
|
||||
expect(result.current.error).toBeNull()
|
||||
expect(transcriptResult).toBe('This is the full transcript of the video.')
|
||||
|
||||
expect(globalThis.fetch).toHaveBeenCalledWith(
|
||||
'/api/v1/video/test-video-id/transcribe',
|
||||
{ method: 'POST' }
|
||||
)
|
||||
})
|
||||
|
||||
it('test_requestFullTranscript_error', async () => {
|
||||
vi.spyOn(globalThis, 'fetch').mockResolvedValueOnce({
|
||||
ok: false,
|
||||
status: 500,
|
||||
json: async () => ({ detail: 'Internal server error' }),
|
||||
} as Response)
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useFullTranscript({ videoId: 'test-video-id' })
|
||||
)
|
||||
|
||||
let transcriptResult: string | null = null
|
||||
await act(async () => {
|
||||
transcriptResult = await result.current.requestFullTranscript()
|
||||
})
|
||||
|
||||
expect(result.current.fullTranscript).toBe('')
|
||||
expect(result.current.isLoading).toBe(false)
|
||||
expect(result.current.error).toBe('Server returned 500')
|
||||
expect(transcriptResult).toBeNull()
|
||||
})
|
||||
|
||||
it('test_requestFullTranscript_network_error', async () => {
|
||||
vi.spyOn(globalThis, 'fetch').mockRejectedValueOnce(new Error('Network error'))
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useFullTranscript({ videoId: 'test-video-id' })
|
||||
)
|
||||
|
||||
let transcriptResult: string | null = null
|
||||
await act(async () => {
|
||||
transcriptResult = await result.current.requestFullTranscript()
|
||||
})
|
||||
|
||||
expect(result.current.fullTranscript).toBe('')
|
||||
expect(result.current.isLoading).toBe(false)
|
||||
expect(result.current.error).toBe('Network error')
|
||||
expect(transcriptResult).toBeNull()
|
||||
})
|
||||
|
||||
it('test_requestFullTranscript_loading_state', async () => {
|
||||
let resolvePromise: (value: any) => void
|
||||
const fetchPromise = new Promise((resolve) => {
|
||||
resolvePromise = resolve
|
||||
})
|
||||
|
||||
vi.spyOn(globalThis, 'fetch').mockReturnValueOnce(fetchPromise as Promise<Response>)
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useFullTranscript({ videoId: 'test-video-id' })
|
||||
)
|
||||
|
||||
// Start the request
|
||||
act(() => {
|
||||
result.current.requestFullTranscript()
|
||||
})
|
||||
|
||||
// Should be loading
|
||||
expect(result.current.isLoading).toBe(true)
|
||||
expect(result.current.error).toBeNull()
|
||||
|
||||
// Resolve the fetch
|
||||
await act(async () => {
|
||||
resolvePromise!({
|
||||
ok: true,
|
||||
status: 200,
|
||||
json: async () => ({
|
||||
text: 'Resolved transcript',
|
||||
language: 'yue',
|
||||
duration_seconds: null,
|
||||
}),
|
||||
})
|
||||
})
|
||||
|
||||
expect(result.current.isLoading).toBe(false)
|
||||
expect(result.current.fullTranscript).toBe('Resolved transcript')
|
||||
})
|
||||
|
||||
it('test_requestFullTranscript_uses_videoId_in_url', async () => {
|
||||
vi.spyOn(globalThis, 'fetch').mockResolvedValueOnce({
|
||||
ok: true,
|
||||
status: 200,
|
||||
json: async () => ({
|
||||
text: 'Transcript',
|
||||
language: 'yue',
|
||||
duration_seconds: null,
|
||||
}),
|
||||
} as Response)
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useFullTranscript({ videoId: 'my-custom-video-123' })
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
await result.current.requestFullTranscript()
|
||||
})
|
||||
|
||||
expect(globalThis.fetch).toHaveBeenCalledWith(
|
||||
'/api/v1/video/my-custom-video-123/transcribe',
|
||||
{ method: 'POST' }
|
||||
)
|
||||
})
|
||||
|
||||
it('test_requestFullTranscript_clears_previous_error_on_new_request', async () => {
|
||||
// First request fails
|
||||
vi.spyOn(globalThis, 'fetch').mockResolvedValueOnce({
|
||||
ok: false,
|
||||
status: 500,
|
||||
json: async () => ({ detail: 'Error' }),
|
||||
} as Response)
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useFullTranscript({ videoId: 'test-video-id' })
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
await result.current.requestFullTranscript()
|
||||
})
|
||||
expect(result.current.error).toBe('Server returned 500')
|
||||
|
||||
// Second request succeeds — error should be cleared
|
||||
vi.spyOn(globalThis, 'fetch').mockResolvedValueOnce({
|
||||
ok: true,
|
||||
status: 200,
|
||||
json: async () => ({
|
||||
text: 'Success transcript',
|
||||
language: 'yue',
|
||||
duration_seconds: null,
|
||||
}),
|
||||
} as Response)
|
||||
|
||||
await act(async () => {
|
||||
await result.current.requestFullTranscript()
|
||||
})
|
||||
|
||||
expect(result.current.error).toBeNull()
|
||||
expect(result.current.fullTranscript).toBe('Success transcript')
|
||||
})
|
||||
})
|
||||
|
|
@ -0,0 +1,156 @@
|
|||
/**
|
||||
* Phase 2.4 tests: useVideoASR hook state management.
|
||||
*
|
||||
* WebAudio (AudioContext, ScriptProcessorNode) and WebSocket are NOT available
|
||||
* in jsdom, so these tests verify state management, return shape, and cleanup
|
||||
* logic only. Full audio capture is covered by acceptance tests.
|
||||
*/
|
||||
import { renderHook, act } from '@testing-library/react'
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
import { useVideoASR } from '../hooks/useVideoASR'
|
||||
import type { ASRStatus } from '../types'
|
||||
|
||||
// Mock AudioContext, WebSocket, and video element APIs that don't exist in jsdom
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('useVideoASR', () => {
|
||||
it('test_initial_state', () => {
|
||||
const { result } = renderHook(() =>
|
||||
useVideoASR({
|
||||
videoId: 'test-video-id',
|
||||
videoElement: null,
|
||||
})
|
||||
)
|
||||
|
||||
expect(result.current.transcript).toBe('')
|
||||
expect(result.current.partialTranscript).toBe('')
|
||||
expect(result.current.isStreaming).toBe(false)
|
||||
expect(result.current.status).toBe<ASRStatus>('idle')
|
||||
})
|
||||
|
||||
it('test_returns_startStreaming_function', () => {
|
||||
const { result } = renderHook(() =>
|
||||
useVideoASR({
|
||||
videoId: 'test-video-id',
|
||||
videoElement: null,
|
||||
})
|
||||
)
|
||||
|
||||
expect(typeof result.current.startStreaming).toBe('function')
|
||||
})
|
||||
|
||||
it('test_returns_stopStreaming_function', () => {
|
||||
const { result } = renderHook(() =>
|
||||
useVideoASR({
|
||||
videoId: 'test-video-id',
|
||||
videoElement: null,
|
||||
})
|
||||
)
|
||||
|
||||
expect(typeof result.current.stopStreaming).toBe('function')
|
||||
})
|
||||
|
||||
it('test_stopStreaming_resets_state', () => {
|
||||
const { result } = renderHook(() =>
|
||||
useVideoASR({
|
||||
videoId: 'test-video-id',
|
||||
videoElement: null,
|
||||
})
|
||||
)
|
||||
|
||||
// Call stopStreaming — should reset status and partial transcript
|
||||
act(() => {
|
||||
result.current.stopStreaming()
|
||||
})
|
||||
|
||||
expect(result.current.status).toBe<ASRStatus>('idle')
|
||||
expect(result.current.isStreaming).toBe(false)
|
||||
expect(result.current.partialTranscript).toBe('')
|
||||
})
|
||||
|
||||
it('test_startStreaming_without_video_element_does_not_throw', () => {
|
||||
const { result } = renderHook(() =>
|
||||
useVideoASR({
|
||||
videoId: 'test-video-id',
|
||||
videoElement: null,
|
||||
})
|
||||
)
|
||||
|
||||
// Should not throw when no video element
|
||||
expect(() => {
|
||||
act(() => {
|
||||
result.current.startStreaming()
|
||||
})
|
||||
}).not.toThrow()
|
||||
})
|
||||
|
||||
it('test_startStreaming_with_no_video_sets_error_status', () => {
|
||||
const { result } = renderHook(() =>
|
||||
useVideoASR({
|
||||
videoId: 'test-video-id',
|
||||
videoElement: null,
|
||||
})
|
||||
)
|
||||
|
||||
act(() => {
|
||||
result.current.startStreaming()
|
||||
})
|
||||
|
||||
// Without a video element, status should remain idle or go to error
|
||||
// (implementation may vary — key is no crash)
|
||||
expect(['idle', 'error']).toContain(result.current.status)
|
||||
})
|
||||
|
||||
it('test_cleanup_on_unmount', () => {
|
||||
const { result, unmount } = renderHook(() =>
|
||||
useVideoASR({
|
||||
videoId: 'test-video-id',
|
||||
videoElement: null,
|
||||
})
|
||||
)
|
||||
|
||||
// Unmount should not throw
|
||||
expect(() => {
|
||||
unmount()
|
||||
}).not.toThrow()
|
||||
})
|
||||
|
||||
it('test_accepts_language_option', () => {
|
||||
const { result } = renderHook(() =>
|
||||
useVideoASR({
|
||||
videoId: 'test-video-id',
|
||||
videoElement: null,
|
||||
language: 'en',
|
||||
})
|
||||
)
|
||||
|
||||
// Hook should initialize fine with custom language
|
||||
expect(result.current.status).toBe<ASRStatus>('idle')
|
||||
})
|
||||
|
||||
it('test_accepts_onFinalTranscript_callback', () => {
|
||||
const onFinal = vi.fn()
|
||||
const { result } = renderHook(() =>
|
||||
useVideoASR({
|
||||
videoId: 'test-video-id',
|
||||
videoElement: null,
|
||||
onFinalTranscript: onFinal,
|
||||
})
|
||||
)
|
||||
|
||||
expect(result.current.status).toBe<ASRStatus>('idle')
|
||||
})
|
||||
|
||||
it('test_status_type_covers_all_states', () => {
|
||||
const validStatuses: ASRStatus[] = ['idle', 'connecting', 'streaming', 'disconnected', 'error']
|
||||
// Just a compile-time assertion that our type covers all states
|
||||
expect(validStatuses).toHaveLength(5)
|
||||
expect(validStatuses).toContain('idle')
|
||||
expect(validStatuses).toContain('connecting')
|
||||
expect(validStatuses).toContain('streaming')
|
||||
expect(validStatuses).toContain('disconnected')
|
||||
expect(validStatuses).toContain('error')
|
||||
})
|
||||
})
|
||||
|
|
@ -170,3 +170,27 @@ export interface HistoryDeleteResponse {
|
|||
deleted_id?: number
|
||||
deleted_count?: number
|
||||
}
|
||||
|
||||
// Phase 2.4 — Video / ASR types
|
||||
|
||||
export interface ASRMessage {
|
||||
delta: string
|
||||
full_text: string
|
||||
language: string
|
||||
is_final: boolean
|
||||
}
|
||||
|
||||
export type ASRStatus = 'idle' | 'connecting' | 'streaming' | 'disconnected' | 'error'
|
||||
|
||||
export interface FullTranscriptResponse {
|
||||
text: string
|
||||
language: string
|
||||
duration_seconds: number | null
|
||||
}
|
||||
|
||||
export interface VideoUploadResponse {
|
||||
video_id: string
|
||||
filename: string
|
||||
size_bytes: number
|
||||
url: string
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue