legco_ai_assistant/backend/app/routers/ws_asr.py

204 lines
7.1 KiB
Python

import json
import asyncio
import base64
import logging
import time
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from app.core.config import get_settings
from app.services.asr_client import float32_to_s16le, build_display_text, _to_traditional
logger = logging.getLogger(__name__)
router = APIRouter(tags=["asr"])

# The DashScope SDK is optional at import time so this module (and the app that
# mounts the router) can still load without it. The fallbacks keep the class
# definition below valid (``OmniRealtimeCallback = object``); any attempt to open
# a realtime session without the SDK will fail at request time instead.
try:
    from dashscope.audio.qwen_omni.omni_realtime import (
        OmniRealtimeConversation,
        OmniRealtimeCallback,
        TranscriptionParams,
        MultiModality,
    )
except ImportError:
    # Make the degraded state visible instead of failing silently later on.
    logger.warning("dashscope-sdk-unavailable realtime ASR will not work")
    OmniRealtimeConversation = None
    OmniRealtimeCallback = object
    TranscriptionParams = None
    MultiModality = None
class DashScopeCallback(OmniRealtimeCallback):
    """Forward DashScope realtime events onto an asyncio queue.

    Events are scheduled onto *loop* with ``call_soon_threadsafe`` rather than
    put on the queue directly. This stays correct if the SDK invokes these
    hooks from a thread other than the event loop's. That is presumably the
    case, since ``connect`` is run in an executor; confirm against the SDK docs.
    """

    def __init__(self, event_queue: asyncio.Queue, loop: asyncio.AbstractEventLoop):
        super().__init__()
        # Queue consumed by the reader task in ``_ws_proxy_dashscope``.
        self._queue = event_queue
        # Loop that owns ``event_queue``; used to hop threads safely.
        self._loop = loop

    def on_open(self):
        logger.info("dashscope-connection-opened")

    def on_event(self, message):
        """Decode *message* and enqueue it if it is a JSON object.

        String or bytes payloads are parsed as JSON; other values are used
        as-is. Anything that does not end up as a ``dict`` is dropped with a
        warning. The consumer calls ``event.get(...)``, so a non-dict would
        crash the reader task. Errors are logged rather than raised back into
        the SDK.
        """
        try:
            if isinstance(message, (str, bytes, bytearray)):
                event = json.loads(message)
            else:
                event = message
            if not isinstance(event, dict):
                logger.warning("dashscope-event-ignored payload_type=%s", type(event).__name__)
                return
            logger.debug("dashscope-event-received type=%s", event.get("type", ""))
            self._loop.call_soon_threadsafe(self._queue.put_nowait, event)
        except Exception as e:
            # Includes JSON decode errors and RuntimeError from a closed loop.
            logger.exception("dashscope-callback-error error=%s", e)

    def on_close(self, code, msg):
        logger.info("dashscope-connection-closed code=%s msg=%s", code, msg)
def _merge_stash(partial_buffer: str, new_stash: str) -> str:
if not new_stash.strip():
return partial_buffer
if not partial_buffer:
return new_stash
for i in range(min(len(partial_buffer), len(new_stash)), 0, -1):
if partial_buffer[-i:] == new_stash[:i]:
return partial_buffer + new_stash[i:]
return partial_buffer + " " + new_stash
def format_transcription_event(event: dict, accumulated: str) -> dict | None:
event_type = event.get("type", "")
if event_type == "conversation.item.input_audio_transcription.text":
stash = event.get("stash", "")
return {
"delta": "",
"stash": stash,
"language": event.get("language", "yue"),
"is_final": False,
}
if event_type == "conversation.item.input_audio_transcription.completed":
transcript = event.get("transcript", "")
new_accumulated = build_display_text(accumulated, transcript) if transcript and transcript.strip() else accumulated
return {
"delta": "",
"full_text": _to_traditional(new_accumulated),
"language": event.get("language", "yue"),
"is_final": True,
}
return None
async def _ws_proxy_dashscope(client_ws: WebSocket, loop: asyncio.AbstractEventLoop, language: str = "yue"):
    """Relay client audio to DashScope realtime ASR and stream transcripts back.

    Each binary frame from the client is converted with ``float32_to_s16le``.
    The chunk log counts samples as ``len // 4``, so frames are treated as
    float32 samples. Assumed mono; TODO confirm with the frontend. The
    converted audio is base64-encoded and appended to the DashScope
    conversation, whose session is configured for 16 kHz PCM input.

    A background task forwards transcription events to the client as JSON
    with ``full_text`` (passed through ``_to_traditional``), ``language`` and
    ``is_final``.

    Args:
        client_ws: An already-accepted client WebSocket.
        loop: Event loop that the DashScope callback posts events onto.
        language: Transcription language code, or ``"auto"`` to omit it from
            the session parameters.

    Raises:
        RuntimeError: If the DashScope SDK is not installed.
        Exception: Any SDK or socket error other than a client disconnect
            propagates to the caller. Once connected, the DashScope
            conversation is always closed first.
    """
    if OmniRealtimeConversation is None:
        raise RuntimeError("dashscope SDK is not installed")

    event_queue: asyncio.Queue = asyncio.Queue()
    callback = DashScopeCallback(event_queue, loop)
    session_start = time.monotonic()
    settings = get_settings()
    conversation = OmniRealtimeConversation(
        model=settings.asr_realtime_model_name,
        api_key=settings.dashscope_api_key,
        url="wss://dashscope-intl.aliyuncs.com/api-ws/v1/realtime",
        callback=callback,
    )
    # ``connect`` is a blocking SDK call, so keep it off the event loop.
    await loop.run_in_executor(None, conversation.connect)

    accumulated_text = ""  # committed text from completed utterances
    partial_buffer = ""  # merged interim ("stash") text for the current utterance
    chunk_count = 0
    read_task: asyncio.Task | None = None

    async def read_events():
        """Forward formatted DashScope events to the client until cancelled."""
        nonlocal accumulated_text, partial_buffer
        while True:
            event = await event_queue.get()
            if not isinstance(event, dict):
                continue
            result = format_transcription_event(event, accumulated_text)
            if result is None:
                continue
            if result["is_final"]:
                transcript = event.get("transcript") or ""
                if transcript.strip():
                    accumulated_text = build_display_text(accumulated_text, transcript)
                # The utterance is complete, so its interim text is obsolete.
                partial_buffer = ""
                result["full_text"] = _to_traditional(accumulated_text)
                logger.info("dashscope-utterance-completed text_len=%d lang=%s", len(accumulated_text), result.get("language", "yue"))
            else:
                stash = result.pop("stash", "") or ""
                if stash.strip():
                    partial_buffer = _merge_stash(partial_buffer, stash)
                # Interim display = committed text + the current utterance so far.
                display = build_display_text(accumulated_text, partial_buffer)
                result["full_text"] = _to_traditional(display)
            await client_ws.send_json(result)

    # Everything after a successful connect lives inside try/finally, so a
    # failing ``update_session`` still closes the DashScope connection.
    try:
        transcription_kwargs: dict = {
            "sample_rate": 16000,
            "input_audio_format": "pcm",
        }
        if language != "auto":
            transcription_kwargs["language"] = language
        transcription_params = TranscriptionParams(**transcription_kwargs)
        conversation.update_session(
            output_modalities=[MultiModality.TEXT],
            enable_input_audio_transcription=True,
            transcription_params=transcription_params,
        )
        logger.info("dashscope-session-updated lang=%s", language)

        read_task = asyncio.create_task(read_events())
        while True:
            float32_bytes = await client_ws.receive_bytes()
            s16_bytes = float32_to_s16le(float32_bytes)
            audio_b64 = base64.b64encode(s16_bytes).decode("ascii")
            conversation.append_audio(audio_b64)
            chunk_count += 1
            logger.debug(
                "audio-chunk-received size_bytes=%d sample_count=%d chunk_num=%d",
                len(float32_bytes),
                len(float32_bytes) // 4,
                chunk_count,
            )
    except WebSocketDisconnect:
        logger.warning(
            "client-disconnected-mid-session chunks=%d accumulated_len=%d",
            chunk_count,
            len(accumulated_text),
        )
    finally:
        # Synchronous teardown first, so the upstream connection is released
        # even if this coroutine is cancelled while awaiting the reader below.
        if read_task is not None:
            read_task.cancel()
        try:
            conversation.close()
        except Exception as e:
            logger.warning("dashscope-close-failed error=%s", e)
        if read_task is not None:
            # Retrieve the reader's outcome so a crash is reported instead of
            # surfacing later as "Task exception was never retrieved".
            (reader_outcome,) = await asyncio.gather(read_task, return_exceptions=True)
            if isinstance(reader_outcome, Exception):
                logger.warning("dashscope-reader-failed error=%s", reader_outcome, exc_info=reader_outcome)
        duration = time.monotonic() - session_start
        logger.info(
            "dashscope-session-closed text_len=%d chunks=%d duration=%.1fs",
            len(accumulated_text),
            chunk_count,
            duration,
        )
@router.websocket("/ws/asr/{video_id}")
async def ws_asr_endpoint(websocket: WebSocket, video_id: str, language: str = "yue"):
    """WebSocket endpoint that streams client audio to DashScope realtime ASR.

    The connection is always accepted first, so that failures can be reported
    to the client as a JSON ``{"error": ...}`` message:

    - Without a configured ``dashscope_api_key``, an error is sent and the
      socket is closed with code 1011.
    - If the proxy session raises, the error (and its message as ``detail``)
      is sent best-effort and the socket is closed with 1011.

    Args:
        websocket: The incoming client connection.
        video_id: Path parameter; used here only for logging.
        language: Query parameter forwarded to the ASR session (``"auto"``
            omits it).
    """
    settings = get_settings()
    client_host = websocket.client.host if websocket.client else "unknown"
    await websocket.accept()

    if not settings.dashscope_api_key:
        await websocket.send_json({"error": "DASHSCOPE_API_KEY is not configured"})
        await websocket.close(code=1011, reason="DASHSCOPE_API_KEY not set")
        logger.warning("ws-rejected-no-apikey video_id=%s client=%s", video_id, client_host)
        return

    # Inside a coroutine the asyncio docs recommend get_running_loop() over
    # get_event_loop(); this is the loop the DashScope callback posts onto.
    loop = asyncio.get_running_loop()
    logger.info("ws-connect video_id=%s lang=%s client=%s", video_id, language, client_host)
    try:
        await _ws_proxy_dashscope(websocket, loop, language)
    except Exception as e:
        # logger.exception keeps the traceback for upstream/SDK failures.
        logger.exception("ws-asr-error video_id=%s error=%s", video_id, e)
        try:
            await websocket.send_json({"error": "ASR service unavailable", "detail": str(e)})
            await websocket.close(code=1011, reason="ASR service unavailable")
        except Exception:
            # Best effort: the client may already be gone.
            pass
    finally:
        logger.info("ws-disconnect video_id=%s", video_id)