diff --git a/backend/app/routers/ws_asr.py b/backend/app/routers/ws_asr.py index 39025cd..c3a7f92 100644 --- a/backend/app/routers/ws_asr.py +++ b/backend/app/routers/ws_asr.py @@ -2,12 +2,14 @@ import json import asyncio import base64 import logging +import struct import time from fastapi import APIRouter, WebSocket, WebSocketDisconnect from app.core.config import get_settings from app.services.asr_client import float32_to_s16le, build_display_text, _to_traditional +from app.services.asr_providers import OpenRouterASRProvider logger = logging.getLogger(__name__) @@ -83,6 +85,120 @@ def format_transcription_event(event: dict, accumulated: str) -> dict | None: return None +def pcm_to_wav(pcm_bytes: bytes, sample_rate: int = 16000, channels: int = 1, bits_per_sample: int = 16) -> bytes: + byte_rate = sample_rate * channels * bits_per_sample // 8 + block_align = channels * bits_per_sample // 8 + data_size = len(pcm_bytes) + header = struct.pack( + "<4sI4s4sIHHIIHH4sI", + b"RIFF", + 36 + data_size, + b"WAVE", + b"fmt ", + 16, + 1, # PCM + channels, + sample_rate, + byte_rate, + block_align, + bits_per_sample, + b"data", + data_size, + ) + return header + pcm_bytes + + +async def _ws_proxy_openrouter(client_ws: WebSocket, language: str = "yue"): + settings = get_settings() + session_start = time.monotonic() + + provider = OpenRouterASRProvider( + api_key=settings.openrouter_api_key, + base_url=settings.llm_base_url, + model=settings.asr_openrouter_model, + ) + logger.info( + "openrouter-ws-started model=%s url=%s language=%s", + settings.asr_openrouter_model, + provider._stt_url, + language, + ) + + accumulated_text = "" + audio_buffer = bytearray() + chunk_count = 0 + last_flush = time.monotonic() + flush_lock = asyncio.Lock() + + async def flush_chunk(): + nonlocal audio_buffer, accumulated_text, chunk_count, last_flush + if not audio_buffer: + return + + pcm_snapshot = bytes(audio_buffer) + audio_buffer.clear() + last_flush = time.monotonic() + chunk_count += 1 + + try: + wav_bytes = pcm_to_wav(pcm_snapshot) + logger.debug( + "openrouter-chunk-sending chunk=%d pcm_bytes=%d wav_bytes=%d", + chunk_count, len(pcm_snapshot), len(wav_bytes), + ) + text = await provider.transcribe(wav_bytes, language) + if text.strip(): + accumulated_text = build_display_text(accumulated_text, text) + await client_ws.send_json({ + "delta": "", + "full_text": _to_traditional(accumulated_text), + "language": language, + "is_final": True, + }) + logger.info( + "openrouter-chunk-completed chunk=%d text_len=%d total_len=%d", + chunk_count, len(text), len(accumulated_text), + ) + except Exception as e: + logger.error( + "openrouter-chunk-failed chunk=%d pcm_bytes=%d error=%s", + chunk_count, len(pcm_snapshot), e, + ) + + async def chunk_timer(): + while True: + await asyncio.sleep(3.0) + async with flush_lock: + if audio_buffer and (time.monotonic() - last_flush >= 3.0): + await flush_chunk() + + timer_task = asyncio.create_task(chunk_timer()) + + try: + while True: + float32_bytes = await client_ws.receive_bytes() + s16_bytes = float32_to_s16le(float32_bytes) + audio_buffer.extend(s16_bytes) + except WebSocketDisconnect: + logger.info( + "openrouter-client-disconnected chunks=%d accumulated_len=%d", + chunk_count, len(accumulated_text), + ) + finally: + timer_task.cancel() + try: + async with flush_lock: + await flush_chunk() + except Exception: + pass + await provider.close() + duration = time.monotonic() - session_start + logger.info( + "openrouter-ws-closed chunks=%d text_len=%d duration=%.1fs", + chunk_count, len(accumulated_text), duration, + ) + + async def _ws_proxy_dashscope(client_ws: WebSocket, loop: asyncio.AbstractEventLoop, language: str = "yue"): event_queue: asyncio.Queue = asyncio.Queue() callback = DashScopeCallback(event_queue, loop) @@ -213,13 +329,6 @@ async def ws_asr_endpoint(websocket: WebSocket, video_id: str, language: str = " settings = get_settings() client_host = websocket.client.host if websocket.client else "unknown" - if not settings.dashscope_api_key: - await websocket.accept() - await websocket.send_json({"error": "DASHSCOPE_API_KEY is not configured"}) - await websocket.close(code=1011, reason="DASHSCOPE_API_KEY not set") - logger.warning("ws-rejected-no-apikey video_id=%s client=%s", video_id, client_host) - return - if source == "system-audio" and not settings.system_audio_enabled: await websocket.accept() await websocket.send_json({"error": "System audio capture is disabled"}) @@ -234,12 +343,33 @@ async def ws_asr_endpoint(websocket: WebSocket, video_id: str, language: str = " logger.warning("ws-rejected-mic-disabled video_id=%s client=%s", video_id, client_host) return + if settings.asr_provider == "openrouter": + if not settings.openrouter_api_key: + await websocket.accept() + await websocket.send_json({"error": "OPENROUTER_API_KEY is not configured"}) + await websocket.close(code=1011, reason="OPENROUTER_API_KEY not set") + logger.warning("ws-rejected-no-openrouter-key video_id=%s client=%s", video_id, client_host) + return + else: + if not settings.dashscope_api_key: + await websocket.accept() + await websocket.send_json({"error": "DASHSCOPE_API_KEY is not configured"}) + await websocket.close(code=1011, reason="DASHSCOPE_API_KEY not set") + logger.warning("ws-rejected-no-apikey video_id=%s client=%s", video_id, client_host) + return + await websocket.accept() loop = asyncio.get_event_loop() - logger.info("ws-connect video_id=%s lang=%s source=%s client=%s", video_id, language, source, client_host) + logger.info( + "ws-connect video_id=%s lang=%s source=%s client=%s provider=%s", + video_id, language, source, client_host, settings.asr_provider, + ) try: - await _ws_proxy_dashscope(websocket, loop, language) + if settings.asr_provider == "openrouter": + await _ws_proxy_openrouter(websocket, language) + else: + await _ws_proxy_dashscope(websocket, loop, language) except Exception as e: logger.error("ws-asr-error video_id=%s error=%s", video_id, e) try: