# legco_ai_assistant/.examples/alibaba_asr_backend.py
#
# 266 lines
# 9.4 KiB
# Python

"""
Reference: Alibaba Cloud DashScope Real-Time ASR Proxy (FastAPI WebSocket).
Extracted and adapted from:
/mnt/c/Users/woody/Documents/projects/voice input/backend/app.py
Architecture:
Browser (Float32 PCM) → FastAPI WebSocket → DashScope Real-Time WebSocket
API key NEVER leaves the server. Backend proxies audio to Alibaba Cloud.
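Install dependencies (quote the version specifiers so the shell does not treat ">=" as a redirection):
pip install 'dashscope>=0.4.0' 'openai>=1.52.0' zhconv fastapi python-dotenv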
"""
import os
import json
import struct
import asyncio
import base64
import logging
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from dotenv import load_dotenv
import zhconv # Optional: simplified → traditional Chinese conversion
from dashscope.audio.qwen_omni.omni_realtime import (
OmniRealtimeConversation,
OmniRealtimeCallback,
TranscriptionParams,
MultiModality,
)
# Load environment variables via python-dotenv before reading the API key below.
# NOTE(review): this runs after the dashscope import above; if the SDK reads
# DASHSCOPE_API_KEY at import time, a key that lives only in .env would be
# missed — confirm.
load_dotenv()
# Module-wide logger for the ASR proxy.
logger = logging.getLogger("asr-proxy")
# DashScope API key taken from the environment; None when the variable is unset.
DASHSCOPE_API_KEY = os.getenv("DASHSCOPE_API_KEY")
# ─── Audio Conversion: Float32 PCM → S16_LE ─────────────────────────────────
def float32_to_s16le(float32_bytes: bytes) -> bytes:
    """Convert little-endian Float32 PCM to the S16_LE PCM sent to DashScope.

    Each 4-byte float sample is scaled by 32767, truncated toward zero and
    saturated to the signed 16-bit range ``[-32768, 32767]``.

    Args:
        float32_bytes: Raw little-endian 32-bit float samples (nominally in
            ``[-1.0, 1.0]``), e.g. the bytes of a browser ``Float32Array``.
            Trailing bytes that do not form a whole sample are ignored.

    Returns:
        The samples packed as little-endian signed 16-bit integers, or
        ``b""`` when the input holds no complete sample.
    """
    num_samples = len(float32_bytes) // 4
    if num_samples == 0:
        return b""

    # unpack_from accepts a buffer longer than the format, so a frame whose
    # length is not a multiple of 4 is truncated instead of raising
    # struct.error.
    floats = struct.unpack_from(f"<{num_samples}f", float32_bytes)

    def _quantize(sample: float) -> int:
        """Map one float sample to a saturated signed 16-bit integer."""
        if sample != sample:
            # NaN is the only value unequal to itself; int(NaN) would raise,
            # so emit silence for it.
            return 0
        scaled = sample * 32767.0
        # Saturate before int() so that +/-inf cannot raise OverflowError.
        if scaled >= 32767.0:
            return 32767
        if scaled <= -32768.0:
            return -32768
        return int(scaled)

    int16_samples = [_quantize(sample) for sample in floats]
    return struct.pack(f"<{num_samples}h", *int16_samples)
def _to_traditional(text: str) -> str:
    """Return *text* converted with ``zhconv`` using the ``"zh-hant"`` locale.

    Falsy input (empty string or ``None``) is handed back unchanged without
    calling ``zhconv``.
    """
    return zhconv.convert(text, "zh-hant") if text else text
def build_display_text(accumulated: str, current: str) -> str:
    """Join the finished text and the in-progress utterance for display.

    Fragments that are empty, ``None`` or whitespace-only are dropped; the
    remaining ones are joined, unmodified, with a single space.
    """
    pieces = []
    for fragment in (accumulated, current):
        if not fragment:
            continue
        if fragment.strip():
            pieces.append(fragment)
    return " ".join(pieces)
# ─── DashScope Callback Bridge (Sync → Async) ───────────────────────────────
class DashScopeCallback(OmniRealtimeCallback):
    """Bridge synchronous DashScope SDK callbacks to an asyncio queue.

    The DashScope SDK fires callbacks from a background thread, so events are
    handed to the event loop with ``call_soon_threadsafe`` instead of touching
    the ``asyncio.Queue`` directly from that thread.
    """

    def __init__(self, event_queue: asyncio.Queue, loop: asyncio.AbstractEventLoop):
        """Store the target queue and the event loop that owns it.

        Args:
            event_queue: Queue that receives every forwarded event.
            loop: Event loop on which ``event_queue.put_nowait`` is scheduled.
        """
        super().__init__()
        self._queue = event_queue
        self._loop = loop

    def on_open(self) -> None:
        """Log that the realtime connection opened."""
        logger.info("DashScope realtime connection opened")

    def on_event(self, message) -> None:
        """Forward one SDK event to the asyncio queue.

        Called from the SDK background thread. ``str`` messages are decoded
        as JSON; any other message is forwarded as-is.
        """
        try:
            event = json.loads(message) if isinstance(message, str) else message
            self._loop.call_soon_threadsafe(self._queue.put_nowait, event)
        except Exception as e:
            # Boundary of an SDK-owned thread: never let an exception escape
            # into the SDK, but keep the traceback so failures (bad JSON,
            # closed loop) can be diagnosed.
            logger.exception("DashScope callback error: %s", e)

    def on_close(self, code, msg) -> None:
        """Log the close code and message reported by the SDK."""
        logger.info("DashScope realtime closed: code=%s, msg=%s", code, msg)
# ─── WebSocket Proxy Handler ────────────────────────────────────────────────
async def ws_proxy_dashscope(
    client_ws: WebSocket,
    loop: asyncio.AbstractEventLoop,
    language: str = "yue",  # "yue" for Cantonese, "zh" for Mandarin, "en" for English
):
    """Proxy browser audio to DashScope real-time ASR until the client disconnects.

    Flow:
        1. Browser sends Float32 PCM bytes via WebSocket.
        2. Backend converts them to S16_LE, then base64.
        3. The audio is appended to the DashScope conversation.
        4. DashScope SDK events arrive via callback -> asyncio.Queue.
        5. Events are read from the queue, formatted as JSON and sent to the
           browser.

    Protocol (backend -> browser):
        Partial: {"delta": "", "full_text": "...", "language": "yue", "is_final": false}
        Final:   {"delta": "", "full_text": "...", "language": "yue", "is_final": true}

    Args:
        client_ws: Accepted browser WebSocket that delivers binary audio frames.
        loop: Event loop used for the executor call and by the SDK callback
            bridge.
        language: Language hint passed to the transcription session; the
            value ``"auto"`` sends no hint.

    Raises:
        Exception: Errors from connecting, configuring the session or
            forwarding audio propagate to the caller. ``WebSocketDisconnect``
            from the browser is treated as a normal end of session. Once
            connected, the DashScope conversation is always closed on exit.
    """
    event_queue: asyncio.Queue = asyncio.Queue()
    callback = DashScopeCallback(event_queue, loop)

    # Initialize DashScope real-time conversation.
    # NOTE(review): no API key is passed explicitly here; presumably the SDK
    # picks it up from the environment — confirm.
    conversation = OmniRealtimeConversation(
        model="qwen3-asr-flash-realtime",
        url="wss://dashscope-intl.aliyuncs.com/api-ws/v1/realtime",
        callback=callback,
    )
    # connect() is synchronous — run in executor to avoid blocking the loop.
    await loop.run_in_executor(None, conversation.connect)
    logger.info("dashscope-session-connected")

    accumulated_text = ""
    current_lang = language

    async def read_events():
        """Read DashScope events from the queue and relay them to the browser."""
        nonlocal accumulated_text, current_lang
        while True:
            event = await event_queue.get()
            if not isinstance(event, dict):
                # The callback forwards non-str messages unparsed and JSON can
                # decode to a non-object; skip anything without .get() instead
                # of letting it kill the reader.
                logger.warning(
                    "dashscope-event-ignored type=%s", type(event).__name__
                )
                continue
            event_type = event.get("type", "")
            if event_type == "conversation.item.input_audio_transcription.text":
                # Partial result (in-progress utterance).
                stash = event.get("stash", "")
                display = (
                    build_display_text(accumulated_text, stash)
                    if stash
                    else accumulated_text
                )
                await client_ws.send_json({
                    "delta": "",
                    "full_text": _to_traditional(display),
                    "language": event.get("language", current_lang),
                    "is_final": False,
                })
            elif event_type == "conversation.item.input_audio_transcription.completed":
                # Final utterance completed. `or ""` also covers an explicit
                # null transcript, which would otherwise break len() below.
                transcript = event.get("transcript") or ""
                if transcript.strip():
                    accumulated_text = build_display_text(accumulated_text, transcript)
                current_lang = event.get("language", current_lang)
                logger.info(
                    "dashscope-utterance lang=%s text_len=%d accumulated_len=%d",
                    current_lang, len(transcript), len(accumulated_text),
                )
                await client_ws.send_json({
                    "delta": "",
                    "full_text": _to_traditional(accumulated_text),
                    "language": current_lang,
                    "is_final": True,
                })

    read_task = None
    try:
        # Session setup lives inside the try so a failure here still closes
        # the already-connected conversation in the finally block.
        transcription_kwargs: dict = {
            "sample_rate": 16000,
            "input_audio_format": "pcm",
        }
        if language != "auto":
            transcription_kwargs["language"] = language
        transcription_params = TranscriptionParams(**transcription_kwargs)
        # Configure session: text output + audio transcription.
        conversation.update_session(
            output_modalities=[MultiModality.TEXT],
            enable_input_audio_transcription=True,
            transcription_params=transcription_params,
        )
        logger.info("dashscope-session-updated lang=%s", language)

        read_task = asyncio.create_task(read_events())

        # Main loop: receive Float32 PCM from browser -> convert -> DashScope.
        while True:
            float32_bytes = await client_ws.receive_bytes()
            s16_bytes = float32_to_s16le(float32_bytes)
            audio_b64 = base64.b64encode(s16_bytes).decode("ascii")
            conversation.append_audio(audio_b64)
    except WebSocketDisconnect:
        pass
    finally:
        try:
            if read_task is not None:
                read_task.cancel()
                # Await the reader so its outcome is retrieved; with
                # return_exceptions=True its own cancellation comes back as a
                # value instead of being raised here.
                (outcome,) = await asyncio.gather(read_task, return_exceptions=True)
                if isinstance(outcome, Exception):
                    logger.error("dashscope-reader-failed", exc_info=outcome)
        finally:
            conversation.close()
            logger.info(
                "dashscope-session-closed text_len=%d",
                len(accumulated_text),
            )
# ─── FastAPI App Setup ──────────────────────────────────────────────────────
app = FastAPI()


@app.websocket("/ws/asr/{video_id}")
async def ws_asr_endpoint(websocket: WebSocket, video_id: str, language: str = "yue"):
    """WebSocket endpoint for real-time ASR.

    Accepts the connection and hands it to ``ws_proxy_dashscope`` for the
    lifetime of the session.

    Path params:
        video_id: Identifier used only in the connect/disconnect log lines.

    Query params:
        language: "yue" (Cantonese), "zh" (Mandarin), "en" (English), "auto"
    """
    await websocket.accept()
    # Inside a coroutine get_running_loop() is the preferred way to obtain
    # the loop that is executing this handler.
    loop = asyncio.get_running_loop()
    logger.info("ws-connect video_id=%s lang=%s", video_id, language)
    try:
        await ws_proxy_dashscope(websocket, loop, language)
    finally:
        # Log the disconnect even when the proxy exits with an exception.
        logger.info("ws-disconnect video_id=%s", video_id)
# ─── Non-Streaming Fallback (POST) ──────────────────────────────────────────
from fastapi import UploadFile, File
from openai import OpenAI
# Sync client for non-streaming fallback: an OpenAI-compatible client pointed
# at DashScope's compatible-mode endpoint. It is constructed at import time.
# NOTE(review): DASHSCOPE_API_KEY is None when the env var is unset — confirm
# how the OpenAI client behaves in that case.
sync_client = OpenAI(
api_key=DASHSCOPE_API_KEY,
base_url="https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
)
@app.post("/api/v1/asr/transcribe")
async def transcribe_file(file: UploadFile = File(...), language: str = "yue"):
    """Non-streaming fallback: transcribe an uploaded audio file.

    Args:
        file: Uploaded audio; its bytes are sent base64-encoded in a data URL.
        language: Language hint for the ASR model; ``"auto"`` sends no hint.

    Returns:
        A dict with ``"text"`` (the transcript passed through
        ``_to_traditional``) and ``"language"`` (the requested language,
        echoed back).
    """
    audio = await file.read()
    # Build data URL: data:;base64,<base64_audio>
    audio_b64 = base64.b64encode(audio).decode()
    data_url = f"data:;base64,{audio_b64}"

    # Only include a language hint when one was requested, as the streaming
    # path does, rather than sending an explicit null.
    asr_options = {}
    if language != "auto":
        asr_options["language"] = language

    def _request_transcription():
        """Issue the blocking chat-completions request."""
        return sync_client.chat.completions.create(
            model="qwen3-asr-flash",
            messages=[{
                "role": "user",
                "content": [{
                    "type": "input_audio",
                    "input_audio": {"data": data_url},
                }],
            }],
            extra_body={"asr_options": asr_options},
        )

    # sync_client is synchronous; run the request in the default executor so
    # the event loop keeps serving other connections while it is in flight.
    resp = await asyncio.get_running_loop().run_in_executor(
        None, _request_transcription
    )
    result = resp.choices[0].message.content
    return {
        "text": _to_traditional(result),
        "language": language,
    }