266 lines
9.4 KiB
Python
266 lines
9.4 KiB
Python
"""
|
|
Reference: Alibaba Cloud DashScope Real-Time ASR Proxy (FastAPI WebSocket).
|
|
|
|
Extracted and adapted from:
|
|
/mnt/c/Users/woody/Documents/projects/voice input/backend/app.py
|
|
|
|
Architecture:
|
|
Browser (Float32 PCM, 16 kHz mono) → FastAPI WebSocket → DashScope Real-Time WebSocket
|
|
API key NEVER leaves the server. Backend proxies audio to Alibaba Cloud.
|
|
|
|
Dependencies:
|
|
pip install "dashscope>=0.4.0" "openai>=1.52.0" zhconv fastapi python-dotenv
|
|
"""
|
|
import os
|
|
import json
|
|
import struct
|
|
import asyncio
|
|
import base64
|
|
import logging
|
|
|
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
|
from dotenv import load_dotenv
|
|
import zhconv # Optional: simplified → traditional Chinese conversion
|
|
|
|
from dashscope.audio.qwen_omni.omni_realtime import (
|
|
OmniRealtimeConversation,
|
|
OmniRealtimeCallback,
|
|
TranscriptionParams,
|
|
MultiModality,
|
|
)
|
|
|
|
logger = logging.getLogger("asr-proxy")

# Populate os.environ from a local .env file before any setting is read below.
load_dotenv()

# Server-side secret: it authenticates this backend to DashScope and is never
# forwarded to the browser.
DASHSCOPE_API_KEY = os.environ.get("DASHSCOPE_API_KEY")
|
|
|
|
# ─── Audio Conversion: Float32 PCM → S16_LE ─────────────────────────────────
|
|
|
|
def float32_to_s16le(float32_bytes: bytes) -> bytes:
    """Convert browser Float32 PCM bytes to the S16_LE PCM DashScope expects.

    The input is the raw bytes of a browser ``Float32Array``, read as
    little-endian float32 samples (nominally in [-1.0, 1.0]). Each sample is
    scaled by 32767, truncated toward zero and clamped to the int16 range, then
    packed as little-endian int16. Base64 encoding is left to the caller.

    No resampling or channel mixing happens here: the browser must already send
    audio in the format the DashScope session is configured for (the proxy
    declares 16 kHz PCM). NOTE(review): mono input is assumed; confirm on the
    client side.

    Robustness:
        * A trailing partial sample (length not a multiple of 4) is ignored
          instead of raising ``struct.error`` and tearing down the session.
        * NaN samples become 0 and +/-inf clamp to the int16 limits, instead of
          ``int()`` raising ``ValueError`` / ``OverflowError``.

    Args:
        float32_bytes: Raw little-endian float32 sample bytes.

    Returns:
        Little-endian int16 sample bytes, 2 bytes per whole input sample.
    """
    num_samples = len(float32_bytes) // 4
    # unpack_from (unlike unpack) tolerates bytes beyond the last whole sample.
    floats = struct.unpack_from(f"<{num_samples}f", float32_bytes)
    return struct.pack(f"<{num_samples}h", *map(_float_sample_to_s16, floats))


def _float_sample_to_s16(sample: float) -> int:
    """Map one float sample to a clamped int16 value (NaN maps to 0)."""
    scaled = sample * 32767.0
    if scaled != scaled:  # NaN is the only float that is not equal to itself.
        return 0
    # Clamp before int() so +/-inf never reaches int(). Because both bounds are
    # integers, clamp-then-truncate gives the same result as the previous
    # truncate-then-clamp for every finite sample.
    return int(max(-32768.0, min(32767.0, scaled)))
|
|
|
|
|
|
def _to_traditional(text: str) -> str:
|
|
"""Convert Simplified Chinese to Traditional (for Cantonese display)."""
|
|
if not text:
|
|
return text
|
|
return zhconv.convert(text, "zh-hant")
|
|
|
|
|
|
def build_display_text(accumulated: str, current: str) -> str:
    """Join the finished transcript and the in-progress segment with one space.

    Either part may be empty, whitespace-only or ``None``; such parts are
    skipped. Kept parts are not stripped, so any whitespace they carry is
    preserved around the joining space.
    """
    nonblank = [
        segment
        for segment in (accumulated, current)
        if segment and not segment.isspace()
    ]
    return " ".join(nonblank)
|
|
|
|
|
|
# ─── DashScope Callback Bridge (Sync → Async) ───────────────────────────────
|
|
|
|
class DashScopeCallback(OmniRealtimeCallback):
    """Bridges sync DashScope SDK callbacks to async WebSocket messages.

    The DashScope SDK fires callbacks from a background thread. Each decoded
    event is handed to the event loop with ``call_soon_threadsafe`` and pushed
    onto an ``asyncio.Queue`` there. asyncio queues are not thread-safe, so the
    queue is only ever touched from the loop's own thread.

    Args:
        event_queue: Queue that receives the decoded event dicts.
        loop: Event loop that owns ``event_queue``.
    """

    def __init__(self, event_queue: asyncio.Queue, loop: asyncio.AbstractEventLoop):
        super().__init__()
        self._queue = event_queue
        self._loop = loop

    def on_open(self) -> None:
        logger.info("DashScope realtime connection opened")

    def on_event(self, message) -> None:
        """Decode one SDK event and forward it to the event loop.

        Called from the SDK background thread. ``message`` may arrive as a JSON
        string or as an already-decoded object. Only dict events are forwarded,
        because the reader on the loop side calls ``event.get(...)`` on each
        one. Failures are logged with a traceback instead of being raised into
        the SDK thread.
        """
        try:
            event = json.loads(message) if isinstance(message, str) else message
            if not isinstance(event, dict):
                logger.warning("DashScope callback ignored non-object event: %r", event)
                return
            # Schedule the put on the loop's thread (Queue is not thread-safe).
            self._loop.call_soon_threadsafe(self._queue.put_nowait, event)
        except Exception as e:
            logger.exception("DashScope callback error: %s", e)

    def on_close(self, code, msg) -> None:
        logger.info("DashScope realtime closed: code=%s, msg=%s", code, msg)
|
|
|
|
|
|
# ─── WebSocket Proxy Handler ────────────────────────────────────────────────
|
|
|
|
async def ws_proxy_dashscope(
    client_ws: WebSocket,
    loop: asyncio.AbstractEventLoop,
    language: str = "yue",  # "yue" for Cantonese, "zh" for Mandarin, "en" for English, "auto" to omit it
):
    """Proxy browser audio to DashScope real-time ASR.

    Flow:
        1. Browser sends Float32 PCM bytes via WebSocket binary frames
        2. Backend converts to S16_LE → base64
        3. Append audio to DashScope conversation
        4. DashScope SDK sends events → callback → asyncio.Queue
        5. Read events from queue → format as JSON → send to browser

    Protocol (backend → browser):
        Partial: {"delta": "", "full_text": "...", "language": "yue", "is_final": false}
        Final:   {"delta": "", "full_text": "...", "language": "yue", "is_final": true}

    ``full_text`` is the whole session transcript so far: the finished
    utterances plus, for partials, the in-progress ``stash``. It is converted
    to Traditional Chinese.

    Args:
        client_ws: Browser WebSocket; must already be accepted.
        loop: The running event loop. Blocking SDK calls run in its default
            executor, and the SDK callback thread hands events back to it.
        language: DashScope transcription language. "auto" leaves the
            ``language`` field out of the session's TranscriptionParams.

    Returns once the browser disconnects. Once ``connect()`` succeeds, every
    exit path closes the DashScope conversation and reaps the reader task,
    including a failing ``update_session()``. Text frames from the browser
    carry no audio and are ignored rather than crashing the session.
    """
    event_queue: asyncio.Queue = asyncio.Queue()
    callback = DashScopeCallback(event_queue, loop)

    # Initialize DashScope real-time conversation
    conversation = OmniRealtimeConversation(
        model="qwen3-asr-flash-realtime",
        url="wss://dashscope-intl.aliyuncs.com/api-ws/v1/realtime",
        callback=callback,
    )

    # connect() is synchronous — run in executor to avoid blocking
    await loop.run_in_executor(None, conversation.connect)
    logger.info("dashscope-session-connected")

    accumulated_text = ""
    current_lang = language
    read_task = None

    async def read_events():
        """Async task: read DashScope events from queue → send to browser."""
        nonlocal accumulated_text, current_lang
        while True:
            event = await event_queue.get()
            if not isinstance(event, dict):
                continue  # Only dict events carry a "type" we can dispatch on.
            event_type = event.get("type", "")

            if event_type == "conversation.item.input_audio_transcription.text":
                # Partial result (in-progress utterance)
                stash = event.get("stash", "")
                display = (
                    build_display_text(accumulated_text, stash)
                    if stash else accumulated_text
                )
                await client_ws.send_json({
                    "delta": "",
                    "full_text": _to_traditional(display),
                    "language": event.get("language", current_lang),
                    "is_final": False,
                })

            elif event_type == "conversation.item.input_audio_transcription.completed":
                # Final utterance completed
                transcript = event.get("transcript", "")
                if transcript and transcript.strip():
                    accumulated_text = build_display_text(accumulated_text, transcript)
                current_lang = event.get("language", current_lang)
                logger.info(
                    "dashscope-utterance lang=%s text_len=%d accumulated_len=%d",
                    current_lang, len(transcript), len(accumulated_text),
                )
                await client_ws.send_json({
                    "delta": "",
                    "full_text": _to_traditional(accumulated_text),
                    "language": current_lang,
                    "is_final": True,
                })

            elif event_type == "error":
                # Surface server-side errors instead of dropping them silently.
                # NOTE(review): "error" is assumed to be the DashScope realtime
                # error event type; confirm against the service docs.
                logger.error("dashscope-error event=%s", event)

    try:
        # Configure session: text output + audio transcription
        transcription_kwargs: dict = {
            "sample_rate": 16000,
            "input_audio_format": "pcm",
        }
        if language != "auto":
            transcription_kwargs["language"] = language
        transcription_params = TranscriptionParams(**transcription_kwargs)

        # update_session() is a synchronous SDK call too; keep it off the
        # event loop like connect().
        await loop.run_in_executor(
            None,
            lambda: conversation.update_session(
                output_modalities=[MultiModality.TEXT],
                enable_input_audio_transcription=True,
                transcription_params=transcription_params,
            ),
        )
        logger.info("dashscope-session-updated lang=%s", language)

        read_task = asyncio.create_task(read_events())

        # Main loop: receive Float32 PCM from browser → convert → send to DashScope
        while True:
            message = await client_ws.receive()
            if message["type"] == "websocket.disconnect":
                break
            float32_bytes = message.get("bytes")
            if not float32_bytes:
                # Text/control frames (or empty binary frames) carry no audio.
                continue
            s16_bytes = float32_to_s16le(float32_bytes)
            audio_b64 = base64.b64encode(s16_bytes).decode("ascii")
            # NOTE(review): append_audio() is a synchronous SDK call made on the
            # loop; if it can block on the network, move it to the executor.
            conversation.append_audio(audio_b64)
    except WebSocketDisconnect:
        # Defensive: receive() reports disconnects as a message, but keep the
        # original tolerance for the exception form.
        pass
    finally:
        if read_task is not None:
            read_task.cancel()
            # Reap the task so a failure inside it (e.g. send_json on a closed
            # socket) is logged instead of "Task exception was never retrieved".
            (reader_outcome,) = await asyncio.gather(read_task, return_exceptions=True)
            if isinstance(reader_outcome, Exception):
                logger.error("dashscope-reader-failed", exc_info=reader_outcome)
        try:
            await loop.run_in_executor(None, conversation.close)
        except Exception:
            logger.exception("dashscope-close-failed")
        logger.info(
            "dashscope-session-closed text_len=%d",
            len(accumulated_text),
        )
|
|
|
|
|
|
# ─── FastAPI App Setup ──────────────────────────────────────────────────────
|
|
|
|
app = FastAPI()


@app.websocket("/ws/asr/{video_id}")
async def ws_asr_endpoint(websocket: WebSocket, video_id: str, language: str = "yue"):
    """WebSocket endpoint for real-time ASR.

    Path params:
        video_id: Identifier used here only to correlate log lines.

    Query params:
        language: "yue" (Cantonese), "zh" (Mandarin), "en" (English), "auto"
    """
    await websocket.accept()
    # Inside a coroutine, get_running_loop() is the preferred way to obtain the
    # loop (get_event_loop() has policy-dependent fallback behavior).
    loop = asyncio.get_running_loop()
    logger.info("ws-connect video_id=%s lang=%s", video_id, language)

    try:
        await ws_proxy_dashscope(websocket, loop, language)
    finally:
        # Log the disconnect even when the proxy raises (e.g. DashScope
        # connect/session failures).
        logger.info("ws-disconnect video_id=%s", video_id)
|
|
|
|
|
|
# ─── Non-Streaming Fallback (POST) ──────────────────────────────────────────
|
|
|
|
from fastapi import UploadFile, File
|
|
from openai import OpenAI
|
|
|
|
# Sync client for the non-streaming fallback (DashScope's OpenAI-compatible endpoint).
#
# Fail fast when DASHSCOPE_API_KEY is missing. Passing api_key=None makes the
# OpenAI client fall back to the OPENAI_API_KEY environment variable. That
# either raises a confusing error here or silently sends an unrelated OpenAI
# credential to DashScope.
if not DASHSCOPE_API_KEY:
    raise RuntimeError(
        "DASHSCOPE_API_KEY is not set; set it in the environment or a .env file "
        "before starting the ASR proxy."
    )

sync_client = OpenAI(
    api_key=DASHSCOPE_API_KEY,
    base_url="https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
)
|
|
|
|
@app.post("/api/v1/asr/transcribe")
async def transcribe_file(file: UploadFile = File(...), language: str = "yue"):
    """Non-streaming fallback: transcribe an uploaded audio file.

    The whole upload is read into memory and sent to ``qwen3-asr-flash`` as a
    base64 data URL through DashScope's OpenAI-compatible endpoint. The
    transcript is returned converted to Traditional Chinese.

    Args:
        file: Uploaded audio file (multipart form field ``file``).
        language: "yue", "zh", "en", or "auto". For "auto" the ``language`` ASR
            option is omitted entirely (not sent as null), matching the
            streaming path.

    Returns:
        ``{"text": <transcript, "" if none>, "language": <requested language>}``.
    """
    audio = await file.read()

    # Build data URL: data:;base64,<base64_audio> (no MIME type is given).
    audio_b64 = base64.b64encode(audio).decode()
    data_url = f"data:;base64,{audio_b64}"

    asr_options = {} if language == "auto" else {"language": language}

    def _call_dashscope():
        """Blocking OpenAI-client request; run off the event loop."""
        return sync_client.chat.completions.create(
            model="qwen3-asr-flash",
            messages=[{
                "role": "user",
                "content": [{
                    "type": "input_audio",
                    "input_audio": {"data": data_url},
                }],
            }],
            extra_body={"asr_options": asr_options},
        )

    # The OpenAI client is synchronous. Calling it directly in this coroutine
    # would block the event loop, and every live ASR WebSocket, for the whole
    # round-trip, so run it in the default executor instead.
    resp = await asyncio.get_running_loop().run_in_executor(None, _call_dashscope)

    # content can be None; normalise so callers always get a string.
    result = resp.choices[0].message.content or ""
    return {
        "text": _to_traditional(result),
        "language": language,
    }
|