251 lines
9.4 KiB
Python
251 lines
9.4 KiB
Python
import json
|
|
import asyncio
|
|
import base64
|
|
import logging
|
|
import time
|
|
|
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
|
|
|
from app.core.config import get_settings
|
|
from app.services.asr_client import float32_to_s16le, build_display_text, _to_traditional
|
|
|
|
logger = logging.getLogger(__name__)

# Dedicated logger for interim ("stash") transcription events. It writes to its
# own file and does not propagate, so these per-event records stay out of the
# application's main log.
_STASH_LOG_PATH = "app/log/stash.log"  # relative path: resolved against the process CWD
_STASH_HANDLER_NAME = "ws_asr.stash.file"

_stash_logger = logging.getLogger("ws_asr.stash")
_stash_logger.propagate = False
_stash_logger.setLevel(logging.DEBUG)

# logging.getLogger() returns a process-global object, so re-importing this
# module (e.g. a dev-server reload) would otherwise stack a second file handler
# and duplicate every record. Reuse the one we attached earlier, found by name.
_stash_handler = next(
    (h for h in _stash_logger.handlers if h.get_name() == _STASH_HANDLER_NAME),
    None,
)
if _stash_handler is None:
    # Records contain non-ASCII text (e.g. Cantonese transcripts, and events
    # dumped with ensure_ascii=False), so pin UTF-8 instead of the locale default.
    _stash_handler = logging.FileHandler(_STASH_LOG_PATH, encoding="utf-8")
    _stash_handler.set_name(_STASH_HANDLER_NAME)
    _stash_handler.setFormatter(logging.Formatter("%(asctime)s %(message)s"))
    _stash_logger.addHandler(_stash_handler)

router = APIRouter(tags=["asr"])
|
|
|
|
# The DashScope SDK is an optional dependency. If it cannot be imported, the
# module still loads:
# - the callback base class falls back to plain ``object``, so DashScopeCallback
#   can still be defined;
# - the remaining names become ``None`` placeholders, which are not callable, so
#   a realtime session cannot actually be opened.
try:
    from dashscope.audio.qwen_omni.omni_realtime import (
        MultiModality,
        OmniRealtimeCallback,
        OmniRealtimeConversation,
        TranscriptionParams,
    )
except ImportError:
    OmniRealtimeCallback = object
    OmniRealtimeConversation = TranscriptionParams = MultiModality = None
|
|
|
|
|
|
class DashScopeCallback(OmniRealtimeCallback):
    """Forward DashScope SDK callback events onto an asyncio queue.

    Events are enqueued via ``loop.call_soon_threadsafe`` so the SDK may call
    these methods from outside the event loop's thread (presumably its own
    websocket thread — confirm against the SDK).
    """

    def __init__(self, event_queue: asyncio.Queue, loop: asyncio.AbstractEventLoop):
        """
        Args:
            event_queue: Queue consumed by the async reader; receives event dicts.
            loop: Event loop that owns ``event_queue``.
        """
        super().__init__()
        self._queue = event_queue
        self._loop = loop

    def on_open(self):
        logger.info("dashscope-connection-opened")

    def on_event(self, message):
        """Decode one SDK event and enqueue it on the owning event loop.

        ``message`` may be a JSON string or an already-decoded object. Only
        dicts are forwarded. The consumer calls ``event.get(...)`` on every
        item, so any other shape would raise inside its loop and stop transcript
        delivery for the rest of the session.
        """
        try:
            event = json.loads(message) if isinstance(message, str) else message
            if not isinstance(event, dict):
                logger.warning("dashscope-event-ignored non_dict_type=%s", type(event).__name__)
                return
            logger.debug("dashscope-event-received type=%s", event.get("type", ""))
            self._loop.call_soon_threadsafe(self._queue.put_nowait, event)
        except Exception as e:
            # Log (with traceback) rather than let the error propagate back
            # into the SDK caller.
            logger.exception("dashscope-callback-error error=%s", e)

    def on_close(self, code, msg):
        logger.info("dashscope-connection-closed code=%s msg=%s", code, msg)
|
|
|
|
|
|
def format_transcription_event(event: dict, accumulated: str) -> dict | None:
|
|
event_type = event.get("type", "")
|
|
|
|
if event_type == "conversation.item.input_audio_transcription.text":
|
|
text = event.get("text", "")
|
|
stash = event.get("stash", "")
|
|
return {
|
|
"delta": "",
|
|
"text": text,
|
|
"stash": stash,
|
|
"language": event.get("language", "yue"),
|
|
"is_final": False,
|
|
}
|
|
|
|
if event_type == "conversation.item.input_audio_transcription.completed":
|
|
transcript = event.get("transcript", "")
|
|
new_accumulated = build_display_text(accumulated, transcript) if transcript and transcript.strip() else accumulated
|
|
return {
|
|
"delta": "",
|
|
"full_text": _to_traditional(new_accumulated),
|
|
"language": event.get("language", "yue"),
|
|
"is_final": True,
|
|
}
|
|
|
|
return None
|
|
|
|
|
|
async def _ws_proxy_dashscope(client_ws: WebSocket, loop: asyncio.AbstractEventLoop, language: str = "yue"):
    """Proxy one client audio stream to DashScope realtime ASR and relay transcripts.

    The client sends binary frames of float32 samples. They are converted with
    ``float32_to_s16le``, base64-encoded and appended to the DashScope
    conversation, whose session is configured for 16 kHz PCM.

    A background reader turns DashScope events into client JSON messages:

    * interim: ``delta`` (newly recognised text), ``full_text`` (``""``),
      ``stash``, ``language``, ``is_final=False``. Sent only when ``delta``
      is non-empty.
    * final: ``delta`` (``""``), ``full_text`` (the whole session transcript
      so far), ``language``, ``is_final=True``.

    Args:
        client_ws: Already-accepted client WebSocket.
        loop: Running event loop; SDK callbacks are marshalled onto it.
        language: Transcription language code. ``"auto"`` omits the language
            from the session config (presumably auto-detection — confirm).

    Returns normally when the client disconnects. Setup and streaming errors
    propagate to the caller. In every case the DashScope connection is closed
    and the reader task is stopped and awaited before returning.

    Raises:
        RuntimeError: if the optional DashScope SDK could not be imported.
    """
    if OmniRealtimeConversation is None:
        # The optional SDK import at module load failed; say so explicitly
        # instead of failing with "'NoneType' object is not callable".
        raise RuntimeError("dashscope SDK is not installed")

    settings = get_settings()
    event_queue: asyncio.Queue = asyncio.Queue()
    callback = DashScopeCallback(event_queue, loop)
    session_start = time.monotonic()

    conversation = OmniRealtimeConversation(
        model=settings.asr_realtime_model_name,
        api_key=settings.dashscope_api_key,
        url="wss://dashscope-intl.aliyuncs.com/api-ws/v1/realtime",
        callback=callback,
    )

    # Per-session state, shared with read_events() below.
    accumulated_text = ""  # finalized transcript (before _to_traditional)
    prev_display = ""  # interim text already emitted for the current utterance
    current_item_id = ""  # item_id of the utterance currently being recognised
    chunk_count = 0  # audio frames forwarded to DashScope
    stash_seq = 0  # sequence number for stash-log records
    read_task: asyncio.Task | None = None

    async def read_events():
        """Consume queued DashScope events and push client messages forever."""
        nonlocal accumulated_text, prev_display, current_item_id, stash_seq
        while True:
            event = await event_queue.get()
            result = format_transcription_event(event, accumulated_text)
            if result is None:
                continue
            if result["is_final"]:
                transcript = event.get("transcript", "")
                if transcript and transcript.strip():
                    accumulated_text = build_display_text(accumulated_text, transcript)
                prev_display = ""
                result["delta"] = ""
                result["full_text"] = _to_traditional(accumulated_text)
                logger.info(
                    "dashscope-utterance-completed text_len=%d lang=%s",
                    len(accumulated_text),
                    result.get("language", "yue"),
                )
            else:
                text = result.pop("text", "")
                stash = result.pop("stash", "")
                elapsed_ms = int((time.monotonic() - session_start) * 1000)
                stash_seq += 1
                _stash_logger.info(
                    "seq=%d elapsed_ms=%d stash_len=%d text_len=%d stash=%r text=%r lang=%s event=%s",
                    stash_seq,
                    elapsed_ms,
                    len(stash),
                    len(text),
                    stash,
                    text,
                    result.get("language", "?"),
                    json.dumps(event, ensure_ascii=False),
                )
                # New utterance: item_id changes and the interim text restarts.
                item_id = event.get("item_id", "")
                if item_id and item_id != current_item_id:
                    # NOTE(review): this was meant to "prepend a space for the
                    # next utterance", but " " is only used as the comparison
                    # prefix below and is never emitted. In practice it resets
                    # the baseline: the next delta is the full new text (minus a
                    # leading space, if present). Confirm intent.
                    if prev_display:
                        prev_display = " "
                    current_item_id = item_id
                # Within one utterance the text is expected to grow as a prefix
                # extension. If it is revised instead, the whole text is re-sent
                # as the delta.
                if text.strip():
                    new_delta = ""
                    if text != prev_display:
                        if prev_display and text.startswith(prev_display):
                            new_delta = text[len(prev_display):]
                        else:
                            new_delta = text
                        prev_display = text
                    result["delta"] = _to_traditional(new_delta) if new_delta else ""
                    result["full_text"] = ""
                    result["stash"] = _to_traditional(stash) if stash.strip() else ""
                else:
                    continue
            if result["delta"] or result["is_final"]:
                await client_ws.send_json(result)

    # Everything after construction sits inside try/finally, so a failure in
    # connect/update_session still closes the DashScope connection.
    try:
        # connect() is a blocking SDK call; keep it off the event loop.
        await loop.run_in_executor(None, conversation.connect)

        transcription_kwargs: dict = {
            "sample_rate": 16000,
            "input_audio_format": "pcm",
        }
        if language != "auto":
            transcription_kwargs["language"] = language
        conversation.update_session(
            output_modalities=[MultiModality.TEXT],
            enable_input_audio_transcription=True,
            transcription_params=TranscriptionParams(**transcription_kwargs),
        )
        logger.info("dashscope-session-updated lang=%s", language)

        read_task = asyncio.create_task(read_events())

        while True:
            float32_bytes = await client_ws.receive_bytes()
            s16_bytes = float32_to_s16le(float32_bytes)
            audio_b64 = base64.b64encode(s16_bytes).decode("ascii")
            conversation.append_audio(audio_b64)
            chunk_count += 1
            logger.debug(
                "audio-chunk-received size_bytes=%d sample_count=%d chunk_num=%d",
                len(float32_bytes),
                len(float32_bytes) // 4,  # 4 bytes per float32 sample
                chunk_count,
            )
    except WebSocketDisconnect:
        logger.warning(
            "client-disconnected-mid-session chunks=%d accumulated_len=%d",
            chunk_count,
            len(accumulated_text),
        )
    finally:
        if read_task is not None:
            read_task.cancel()
            # Await the reader so it has really stopped. Any exception it died
            # with is observed here instead of surfacing later as
            # "Task exception was never retrieved".
            (reader_outcome,) = await asyncio.gather(read_task, return_exceptions=True)
            if isinstance(reader_outcome, Exception):
                logger.warning(
                    "dashscope-reader-stopped error=%r",
                    reader_outcome,
                    exc_info=reader_outcome,
                )
        try:
            conversation.close()
        except Exception:
            # Best-effort teardown; the session is ending either way.
            logger.debug("dashscope-close-failed", exc_info=True)
        duration = time.monotonic() - session_start
        logger.info(
            "dashscope-session-closed text_len=%d chunks=%d duration=%.1fs",
            len(accumulated_text),
            chunk_count,
            duration,
        )
|
|
|
|
|
|
async def _reject_ws(websocket: WebSocket, error: str, code: int, reason: str) -> None:
    """Accept the handshake, send ``{"error": error}``, then close with ``code``/``reason``."""
    await websocket.accept()
    await websocket.send_json({"error": error})
    await websocket.close(code=code, reason=reason)


@router.websocket("/ws/asr/{video_id}")
async def ws_asr_endpoint(websocket: WebSocket, video_id: str, language: str = "yue", source: str = "upload"):
    """WebSocket endpoint streaming one client's audio to DashScope realtime ASR.

    Args:
        websocket: Incoming client connection.
        video_id: Path parameter; used here only for log correlation.
        language: Transcription language (default ``"yue"``). ``"auto"`` omits
            it from the DashScope session config.
        source: Audio origin. ``"system-audio"`` and ``"mic"`` are refused when
            disabled in settings; any other value is accepted.

    The connection is rejected (error JSON, then close) when the DashScope API
    key is missing (code 1011) or the requested source is disabled (1008).
    Otherwise it is proxied by ``_ws_proxy_dashscope``. If that raises, the
    client gets ``{"error": "ASR service unavailable", "detail": ...}`` and the
    socket is closed with 1011.
    """
    settings = get_settings()
    client_host = websocket.client.host if websocket.client else "unknown"

    if not settings.dashscope_api_key:
        await _reject_ws(websocket, "DASHSCOPE_API_KEY is not configured", 1011, "DASHSCOPE_API_KEY not set")
        logger.warning("ws-rejected-no-apikey video_id=%s client=%s", video_id, client_host)
        return

    if source == "system-audio" and not settings.system_audio_enabled:
        await _reject_ws(websocket, "System audio capture is disabled", 1008, "System audio disabled")
        logger.warning("ws-rejected-system-audio-disabled video_id=%s client=%s", video_id, client_host)
        return

    if source == "mic" and not settings.mic_enabled:
        await _reject_ws(websocket, "Microphone capture is disabled", 1008, "Mic disabled")
        logger.warning("ws-rejected-mic-disabled video_id=%s client=%s", video_id, client_host)
        return

    await websocket.accept()
    # Inside a coroutine a loop is always running; get_running_loop() states that explicitly.
    loop = asyncio.get_running_loop()
    logger.info("ws-connect video_id=%s lang=%s source=%s client=%s", video_id, language, source, client_host)

    try:
        await _ws_proxy_dashscope(websocket, loop, language)
    except Exception as e:
        logger.exception("ws-asr-error video_id=%s error=%s", video_id, e)
        try:
            await websocket.send_json({"error": "ASR service unavailable", "detail": str(e)})
            # Close explicitly with "internal error", matching the other
            # error paths above, rather than leaving the close code implicit.
            await websocket.close(code=1011, reason="ASR service unavailable")
        except Exception:
            # The client may already be gone; nothing more can be reported.
            logger.debug("ws-error-report-failed video_id=%s", video_id, exc_info=True)
    finally:
        logger.info("ws-disconnect video_id=%s", video_id)
|