# legco_ai_assistant/backend/app/routers/ws_asr.py
#
# (source listing metadata: 381 lines, 14 KiB, Python)

import asyncio
import base64
import json
import logging
import os
import struct
import time

from fastapi import APIRouter, WebSocket, WebSocketDisconnect

from app.core.config import get_settings
from app.services.asr_client import float32_to_s16le, build_display_text, _to_traditional
from app.services.asr_providers import OpenRouterASRProvider
logger = logging.getLogger(__name__)

# Dedicated, non-propagating logger that records every partial ("stash")
# transcription event from the DashScope path into its own file.
_STASH_LOG_PATH = "app/log/stash.log"
_stash_logger = logging.getLogger("ws_asr.stash")
_stash_logger.propagate = False
_stash_logger.setLevel(logging.DEBUG)
if _stash_logger.handlers:
    # Module code executed again (e.g. importlib.reload): the logger object is
    # global to the logging registry, so reuse its handler instead of stacking
    # a duplicate that would write every line twice.
    _stash_handler = _stash_logger.handlers[0]
else:
    try:
        # The path is relative to the process CWD; create its directory so a
        # fresh checkout does not fail at import with FileNotFoundError.
        os.makedirs(os.path.dirname(_STASH_LOG_PATH), exist_ok=True)
        _stash_handler = logging.FileHandler(_STASH_LOG_PATH)
    except OSError as exc:
        # The stash log is diagnostics only; never let it break the router import.
        logger.warning("stash-log-unavailable path=%s error=%s", _STASH_LOG_PATH, exc)
        _stash_handler = logging.NullHandler()
    _stash_handler.setFormatter(logging.Formatter("%(asctime)s %(message)s"))
    _stash_logger.addHandler(_stash_handler)
# Router holding the ASR WebSocket endpoint(s) declared in this module.
router = APIRouter(tags=["asr"])

# The DashScope realtime SDK is an optional dependency. If it cannot be
# imported, the names are bound to placeholders so this module still imports;
# only the DashScope proxy path is then unusable (it calls
# OmniRealtimeConversation, which would be None).
try:
    from dashscope.audio.qwen_omni.omni_realtime import (
        OmniRealtimeConversation,
        OmniRealtimeCallback,
        TranscriptionParams,
        MultiModality,
    )
except ImportError:
    OmniRealtimeConversation = None
    # ``object`` keeps ``class DashScopeCallback(OmniRealtimeCallback)`` valid.
    OmniRealtimeCallback = object
    TranscriptionParams = None
    MultiModality = None
class DashScopeCallback(OmniRealtimeCallback):
    """Forward DashScope realtime events onto an asyncio queue.

    Events are handed to the event loop with ``call_soon_threadsafe``, so the
    SDK may invoke these callbacks from a thread other than the loop's
    (presumably its own websocket thread -- not visible here).

    Args:
        event_queue: Queue drained by the session's reader coroutine.
        loop: Event loop that owns ``event_queue``.
    """

    def __init__(self, event_queue: asyncio.Queue, loop: asyncio.AbstractEventLoop):
        super().__init__()
        self._queue = event_queue
        self._loop = loop

    def on_open(self):
        """Log that the DashScope connection is open."""
        logger.info("dashscope-connection-opened")

    def on_event(self, message):
        """Decode *message* (JSON text or an already-decoded object) and enqueue it.

        Only dict events are enqueued: the consumer reads them with
        ``event.get(...)``, so any other payload would raise inside the reader
        task and silently end transcription for the session.
        """
        try:
            event = json.loads(message) if isinstance(message, str) else message
            if not isinstance(event, dict):
                logger.warning(
                    "dashscope-event-ignored payload_type=%s", type(event).__name__
                )
                return
            logger.debug("dashscope-event-received type=%s", event.get("type", ""))
            self._loop.call_soon_threadsafe(self._queue.put_nowait, event)
        except Exception as e:
            # Never let an exception escape into the SDK's callback machinery.
            logger.error("dashscope-callback-error error=%s", e)

    def on_close(self, code, msg):
        """Log the close code and message reported by DashScope."""
        logger.info("dashscope-connection-closed code=%s msg=%s", code, msg)
def format_transcription_event(event: dict, accumulated: str) -> dict | None:
event_type = event.get("type", "")
if event_type == "conversation.item.input_audio_transcription.text":
text = event.get("text", "")
stash = event.get("stash", "")
return {
"delta": "",
"text": text,
"stash": stash,
"language": event.get("language", "yue"),
"is_final": False,
}
if event_type == "conversation.item.input_audio_transcription.completed":
transcript = event.get("transcript", "")
new_accumulated = build_display_text(accumulated, transcript) if transcript and transcript.strip() else accumulated
return {
"delta": "",
"full_text": _to_traditional(new_accumulated),
"language": event.get("language", "yue"),
"is_final": True,
}
return None
def pcm_to_wav(pcm_bytes: bytes, sample_rate: int = 16000, channels: int = 1, bits_per_sample: int = 16) -> bytes:
    """Prefix raw PCM samples with a canonical 44-byte RIFF/WAVE header.

    Args:
        pcm_bytes: PCM sample data, appended verbatim after the header
            (callers in this module pass output of ``float32_to_s16le``).
        sample_rate: Samples per second per channel.
        channels: Number of interleaved channels.
        bits_per_sample: Bit depth of one sample.

    Returns:
        ``bytes`` holding the RIFF header, the ``fmt `` chunk (format 1 = PCM),
        the ``data`` chunk header and then the samples.
    """
    payload_len = len(pcm_bytes)
    frame_size = channels * bits_per_sample // 8
    bytes_per_second = sample_rate * channels * bits_per_sample // 8
    # 36 = "WAVE" (4) + full fmt chunk (8 + 16) + data chunk header (8).
    riff_chunk = b"RIFF" + struct.pack("<I", 36 + payload_len) + b"WAVE"
    fmt_chunk = b"fmt " + struct.pack(
        "<IHHIIHH",
        16,  # size of the fmt chunk body
        1,  # audio format: uncompressed PCM
        channels,
        sample_rate,
        bytes_per_second,
        frame_size,
        bits_per_sample,
    )
    data_header = b"data" + struct.pack("<I", payload_len)
    return b"".join((riff_chunk, fmt_chunk, data_header, pcm_bytes))
async def _ws_proxy_openrouter(client_ws: WebSocket, language: str = "yue", *, flush_interval: float = 3.0):
    """Proxy client audio to OpenRouter speech-to-text in fixed-interval chunks.

    Float32 audio frames received from ``client_ws`` are converted to s16le
    PCM and buffered. A background timer flushes the buffer roughly every
    ``flush_interval`` seconds: the chunk is wrapped as WAV, sent to
    ``provider.transcribe``, merged into the running transcript with
    ``build_display_text`` and pushed to the client as an ``is_final``
    message whose ``full_text`` is the whole transcript so far.

    Args:
        client_ws: Already-accepted client WebSocket sending float32 audio as
            binary frames.
        language: Language code passed to the provider and echoed back.
        flush_interval: Seconds between chunk flushes (keyword-only).

    Returns normally when the client disconnects. Any other error propagates
    to the caller after the timer is stopped and the provider is closed.
    """
    settings = get_settings()
    session_start = time.monotonic()
    provider = OpenRouterASRProvider(
        api_key=settings.openrouter_api_key,
        base_url=settings.llm_base_url,
        model=settings.asr_openrouter_model,
    )
    logger.info(
        "openrouter-ws-started model=%s url=%s language=%s",
        settings.asr_openrouter_model,
        provider._stt_url,  # NOTE(review): private provider attribute, logging only
        language,
    )

    accumulated_text = ""
    audio_buffer = bytearray()
    chunk_count = 0
    last_flush = time.monotonic()
    # Serializes timer flushes with the final flush at shutdown.
    flush_lock = asyncio.Lock()
    # Cleared when the client disconnects: results can no longer be delivered,
    # so the final flush (a paid transcription call) is skipped.
    client_connected = True

    async def flush_chunk():
        """Transcribe the buffered audio (if any) and push the merged transcript."""
        nonlocal accumulated_text, chunk_count, last_flush
        if not audio_buffer:
            return
        # Snapshot and clear before the first await so frames that arrive while
        # the request is in flight go into the next chunk.
        pcm_snapshot = bytes(audio_buffer)
        audio_buffer.clear()
        last_flush = time.monotonic()
        chunk_count += 1
        try:
            wav_bytes = pcm_to_wav(pcm_snapshot)
            logger.debug(
                "openrouter-chunk-sending chunk=%d pcm_bytes=%d wav_bytes=%d",
                chunk_count, len(pcm_snapshot), len(wav_bytes),
            )
            text = await provider.transcribe(wav_bytes, language)
            if text.strip():
                accumulated_text = build_display_text(accumulated_text, text)
                await client_ws.send_json({
                    "delta": "",
                    "full_text": _to_traditional(accumulated_text),
                    "language": language,
                    "is_final": True,
                })
            logger.info(
                "openrouter-chunk-completed chunk=%d text_len=%d total_len=%d",
                chunk_count, len(text), len(accumulated_text),
            )
        except Exception as e:
            # Best effort: a failed chunk is logged and dropped so streaming continues.
            logger.error(
                "openrouter-chunk-failed chunk=%d pcm_bytes=%d error=%s",
                chunk_count, len(pcm_snapshot), e,
            )

    async def chunk_timer():
        """Flush the buffer every ``flush_interval`` seconds until cancelled."""
        while True:
            await asyncio.sleep(flush_interval)
            async with flush_lock:
                if audio_buffer and (time.monotonic() - last_flush >= flush_interval):
                    await flush_chunk()

    timer_task = asyncio.create_task(chunk_timer())
    try:
        while True:
            float32_bytes = await client_ws.receive_bytes()
            audio_buffer.extend(float32_to_s16le(float32_bytes))
    except WebSocketDisconnect:
        client_connected = False
        logger.info(
            "openrouter-client-disconnected chunks=%d accumulated_len=%d",
            chunk_count, len(accumulated_text),
        )
    finally:
        timer_task.cancel()
        try:
            # Wait until the timer has really stopped (gather absorbs its
            # CancelledError) so no timer flush overlaps the shutdown below.
            await asyncio.gather(timer_task, return_exceptions=True)
            if client_connected:
                # flush_chunk logs its own failures; nothing to catch here.
                async with flush_lock:
                    await flush_chunk()
            elif audio_buffer:
                logger.debug(
                    "openrouter-final-flush-skipped pcm_bytes=%d", len(audio_buffer),
                )
        finally:
            await provider.close()
        duration = time.monotonic() - session_start
        logger.info(
            "openrouter-ws-closed chunks=%d text_len=%d duration=%.1fs",
            chunk_count, len(accumulated_text), duration,
        )
async def _ws_proxy_dashscope(client_ws: WebSocket, loop: asyncio.AbstractEventLoop, language: str = "yue"):
    """Relay client audio to DashScope realtime ASR and stream transcripts back.

    Opens an ``OmniRealtimeConversation`` (``connect`` runs in the default
    executor so it does not block the event loop) and configures input-audio
    transcription. Each float32 frame from the client is converted to s16le,
    base64-encoded and appended to the conversation. A reader task drains the
    events queued by ``DashScopeCallback``:

    * partial events become incremental ``delta`` messages (plus ``stash``)
      and are recorded in the stash log;
    * completed events become ``is_final`` messages whose ``full_text`` is the
      whole accumulated transcript.

    Args:
        client_ws: Already-accepted client WebSocket sending float32 audio.
        loop: Running event loop, used for the executor call and by the
            callback to hand events back thread-safely.
        language: Transcription language; ``"auto"`` omits the language
            parameter from the session configuration.

    Raises:
        RuntimeError: if the optional ``dashscope`` SDK is not installed.
        Any error from connecting/configuring the session or from receiving
        client audio, other than a client disconnect, propagates after the
        conversation is closed.
    """
    if OmniRealtimeConversation is None:
        # The module-level optional import failed; report that clearly instead
        # of "'NoneType' object is not callable".
        raise RuntimeError("dashscope SDK is not installed")

    event_queue: asyncio.Queue = asyncio.Queue()
    callback = DashScopeCallback(event_queue, loop)
    session_start = time.monotonic()
    settings = get_settings()
    conversation = OmniRealtimeConversation(
        model=settings.asr_realtime_model_name,
        api_key=settings.dashscope_api_key,
        url="wss://dashscope-intl.aliyuncs.com/api-ws/v1/realtime",
        callback=callback,
    )

    accumulated_text = ""
    prev_display = ""
    current_item_id = ""
    chunk_count = 0
    stash_seq = 0

    async def read_events():
        """Turn queued DashScope events into client messages until cancelled."""
        nonlocal accumulated_text, prev_display, current_item_id, stash_seq
        while True:
            event = await event_queue.get()
            result = format_transcription_event(event, accumulated_text)
            if result is None:
                continue
            if result["is_final"]:
                transcript = event.get("transcript", "")
                if transcript and transcript.strip():
                    accumulated_text = build_display_text(accumulated_text, transcript)
                prev_display = ""
                result["delta"] = ""
                result["full_text"] = _to_traditional(accumulated_text)
                logger.info("dashscope-utterance-completed text_len=%d lang=%s", len(accumulated_text), result.get("language", "yue"))
            else:
                text = result.pop("text", "")
                stash = result.pop("stash", "")
                elapsed_ms = int((time.monotonic() - session_start) * 1000)
                stash_seq += 1
                _stash_logger.info(
                    "seq=%d elapsed_ms=%d stash_len=%d text_len=%d stash=%r text=%r lang=%s event=%s",
                    stash_seq,
                    elapsed_ms,
                    len(stash),
                    len(text),
                    stash,
                    text,
                    result.get("language", "?"),
                    json.dumps(event, ensure_ascii=False),
                )
                # New utterance: item_id changes, text resets to empty
                item_id = event.get("item_id", "")
                if item_id and item_id != current_item_id:
                    if prev_display:
                        prev_display = " "  # prepend space for next utterance
                    current_item_id = item_id
                # text is monotonically growing within one utterance
                if text.strip():
                    new_delta = ""
                    if text != prev_display:
                        if prev_display and text.startswith(prev_display):
                            new_delta = text[len(prev_display):]
                        else:
                            new_delta = text
                        prev_display = text
                    result["delta"] = _to_traditional(new_delta) if new_delta else ""
                    result["full_text"] = ""
                    result["stash"] = _to_traditional(stash) if stash.strip() else ""
                else:
                    continue
            if result["delta"] or result["is_final"]:
                await client_ws.send_json(result)

    read_task: asyncio.Task | None = None
    # The try starts before connect/update_session so the conversation is
    # closed even when session setup itself fails.
    try:
        await loop.run_in_executor(None, conversation.connect)
        transcription_kwargs: dict = {
            "sample_rate": 16000,
            "input_audio_format": "pcm",
        }
        if language != "auto":
            transcription_kwargs["language"] = language
        conversation.update_session(
            output_modalities=[MultiModality.TEXT],
            enable_input_audio_transcription=True,
            transcription_params=TranscriptionParams(**transcription_kwargs),
        )
        logger.info("dashscope-session-updated lang=%s", language)

        read_task = asyncio.create_task(read_events())
        while True:
            float32_bytes = await client_ws.receive_bytes()
            s16_bytes = float32_to_s16le(float32_bytes)
            audio_b64 = base64.b64encode(s16_bytes).decode("ascii")
            conversation.append_audio(audio_b64)
            chunk_count += 1
            logger.debug(
                "audio-chunk-received size_bytes=%d sample_count=%d chunk_num=%d",
                len(float32_bytes),
                len(float32_bytes) // 4,
                chunk_count,
            )
    except WebSocketDisconnect:
        logger.warning(
            "client-disconnected-mid-session chunks=%d accumulated_len=%d",
            chunk_count,
            len(accumulated_text),
        )
    finally:
        if read_task is not None:
            read_task.cancel()
        try:
            conversation.close()
        except Exception as e:
            # Best effort: the connection may never have opened.
            logger.debug("dashscope-close-failed error=%s", e)
        if read_task is not None:
            # Reap the reader so an exception it died with is logged, not lost.
            # CancelledError is not an Exception subclass, so it is ignored.
            for outcome in await asyncio.gather(read_task, return_exceptions=True):
                if isinstance(outcome, Exception):
                    logger.error("dashscope-reader-failed error=%s", outcome)
        duration = time.monotonic() - session_start
        logger.info(
            "dashscope-session-closed text_len=%d chunks=%d duration=%.1fs",
            len(accumulated_text),
            chunk_count,
            duration,
        )
async def _reject_connection(websocket: WebSocket, error: str, code: int, reason: str) -> None:
    """Accept the handshake, send ``{"error": error}``, then close with *code*/*reason*.

    The socket is accepted first so the client receives a JSON explanation
    rather than a bare handshake failure.
    """
    await websocket.accept()
    await websocket.send_json({"error": error})
    await websocket.close(code=code, reason=reason)


@router.websocket("/ws/asr/{video_id}")
async def ws_asr_endpoint(websocket: WebSocket, video_id: str, language: str = "yue", source: str = "upload"):
    """Realtime ASR WebSocket endpoint.

    Rejects the connection (accept, JSON error, close) when the requested
    capture ``source`` is disabled in settings or the selected provider's API
    key is missing. Otherwise proxies the client's audio to OpenRouter
    (``settings.asr_provider == "openrouter"``) or to DashScope (any other
    value) until the client disconnects.

    Args:
        websocket: Incoming client connection.
        video_id: Path parameter; used here only for logging.
        language: Language code forwarded to the selected proxy.
        source: Capture source; ``"system-audio"`` and ``"mic"`` are gated by
            ``settings.system_audio_enabled`` / ``settings.mic_enabled``.
    """
    settings = get_settings()
    client_host = websocket.client.host if websocket.client else "unknown"

    if source == "system-audio" and not settings.system_audio_enabled:
        await _reject_connection(websocket, "System audio capture is disabled", 1008, "System audio disabled")
        logger.warning("ws-rejected-system-audio-disabled video_id=%s client=%s", video_id, client_host)
        return
    if source == "mic" and not settings.mic_enabled:
        await _reject_connection(websocket, "Microphone capture is disabled", 1008, "Mic disabled")
        logger.warning("ws-rejected-mic-disabled video_id=%s client=%s", video_id, client_host)
        return

    use_openrouter = settings.asr_provider == "openrouter"
    if use_openrouter and not settings.openrouter_api_key:
        await _reject_connection(websocket, "OPENROUTER_API_KEY is not configured", 1011, "OPENROUTER_API_KEY not set")
        logger.warning("ws-rejected-no-openrouter-key video_id=%s client=%s", video_id, client_host)
        return
    if not use_openrouter and not settings.dashscope_api_key:
        await _reject_connection(websocket, "DASHSCOPE_API_KEY is not configured", 1011, "DASHSCOPE_API_KEY not set")
        logger.warning("ws-rejected-no-apikey video_id=%s client=%s", video_id, client_host)
        return

    await websocket.accept()
    loop = asyncio.get_running_loop()
    logger.info(
        "ws-connect video_id=%s lang=%s source=%s client=%s provider=%s",
        video_id, language, source, client_host, settings.asr_provider,
    )
    try:
        if use_openrouter:
            await _ws_proxy_openrouter(websocket, language)
        else:
            await _ws_proxy_dashscope(websocket, loop, language)
    except Exception as e:
        # logger.exception keeps the traceback, which the bare message lost.
        logger.exception("ws-asr-error video_id=%s error=%s", video_id, e)
        try:
            # NOTE(review): "detail" exposes internal exception text to the client.
            await websocket.send_json({"error": "ASR service unavailable", "detail": str(e)})
            # 1011 = server error, matching the configuration-error closes above.
            await websocket.close(code=1011, reason="ASR service unavailable")
        except Exception:
            pass  # Client already gone; nothing left to report to.
    finally:
        logger.info("ws-disconnect video_id=%s", video_id)