191 lines
6.1 KiB
Python
191 lines
6.1 KiB
Python
import asyncio
import base64
import logging
import threading
from abc import ABC, abstractmethod

import httpx
import zhconv
from openai import OpenAI
from tenacity import (
    retry,
    retry_if_exception,
    retry_if_exception_type,
    stop_after_attempt,
    wait_random_exponential,
)
|
|
|
|
# Module-level logger; the providers and the factory in this module all log through it.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _to_traditional(text: str) -> str:
|
|
if not text:
|
|
return text
|
|
return zhconv.convert(text, "zh-hant")
|
|
|
|
|
|
class ASRError(Exception):
    """Raised when speech recognition cannot produce a transcription.

    Covers failed provider requests, empty transcription results, and
    provider misconfiguration (e.g. a missing API key).
    """
|
|
|
|
|
|
class ASRProvider(ABC):
    """Abstract interface every speech-to-text backend implements."""

    @abstractmethod
    async def transcribe(self, audio_bytes: bytes, language: str) -> str:
        """Return the transcription of *audio_bytes*.

        Args:
            audio_bytes: Raw audio payload to transcribe.
            language: Language hint; the concrete providers in this module
                treat ``"auto"`` as "let the service detect it".
        """
|
|
|
|
|
|
class DashScopeASRProvider(ASRProvider):
    """ASR provider backed by DashScope's OpenAI-compatible chat-completions API.

    Audio is sent inline as a base64 ``data:audio/wav`` URL inside an
    ``input_audio`` message part, and the returned text is converted to
    Traditional Chinese.

    A single synchronous OpenAI client is created on first use and reused
    for every request, so its connection pool is shared. Call :meth:`close`
    to release it.
    """

    _DEFAULT_BASE_URL = "https://dashscope-intl.aliyuncs.com/compatible-mode/v1"

    def __init__(self, api_key: str, model: str, base_url: str = _DEFAULT_BASE_URL):
        """
        Args:
            api_key: DashScope API key handed to the OpenAI client.
            model: DashScope ASR model name.
            base_url: OpenAI-compatible endpoint root; defaults to the
                international DashScope endpoint.
        """
        self._api_key = api_key
        self._model = model
        self._base_url = base_url
        self._client: OpenAI | None = None
        # _transcribe_sync runs on executor threads, so the lazy client
        # creation below must not race between concurrent first calls.
        self._client_lock = threading.Lock()

    async def transcribe(self, audio_bytes: bytes, language: str) -> str:
        """Transcribe WAV *audio_bytes* without blocking the event loop.

        The blocking OpenAI SDK call runs in the loop's default executor.

        Args:
            audio_bytes: Audio payload, sent as ``audio/wav``.
            language: Language hint; ``"auto"`` or empty omits the hint.

        Returns:
            The transcription converted to Traditional Chinese ("" when the
            service returns no content).

        Raises:
            ASRError: The response contained no choices.
        """
        loop = asyncio.get_running_loop()
        logger.info(
            "asr-transcribe-start provider=dashscope model=%s audio_bytes=%d language=%s",
            self._model, len(audio_bytes), language,
        )
        return await loop.run_in_executor(
            None, self._transcribe_sync, audio_bytes, language
        )

    def _get_client(self) -> OpenAI:
        """Return the shared OpenAI client, creating it on first use."""
        with self._client_lock:
            if self._client is None:
                self._client = OpenAI(api_key=self._api_key, base_url=self._base_url)
            return self._client

    def _transcribe_sync(self, audio_bytes: bytes, language: str) -> str:
        """Blocking DashScope request; intended to run on an executor thread."""
        audio_b64 = base64.b64encode(audio_bytes).decode()
        data_url = f"data:audio/wav;base64,{audio_b64}"

        asr_options: dict = {}
        # Empty or "auto" means no explicit hint; an empty language value
        # carries no information, so it is not sent.
        if language and language != "auto":
            asr_options["language"] = language

        resp = self._get_client().chat.completions.create(
            model=self._model,
            messages=[{  # type: ignore[list-item]
                "role": "user",
                "content": [{
                    "type": "input_audio",
                    "input_audio": {"data": data_url},
                }],
            }],
            extra_body={"asr_options": asr_options} if asr_options else None,
        )

        if not resp.choices:
            raise ASRError("DashScope ASR returned no choices")
        result = resp.choices[0].message.content or ""
        logger.info(
            "asr-transcribe-complete provider=dashscope text_len=%d",
            len(result),
        )
        return _to_traditional(result)

    async def close(self) -> None:
        """Close the shared OpenAI client, if one was created."""
        with self._client_lock:
            client, self._client = self._client, None
        if client is not None:
            client.close()
|
|
|
|
|
|
def _is_retryable_http_error(exc: BaseException) -> bool:
    """Return True if *exc* is an HTTP status error worth retrying.

    Only 429 (rate limited) and 5xx responses are treated as transient.
    Other non-success statuses (e.g. a malformed request or a rejected API
    key) are not expected to succeed on retry, so retrying them only adds
    latency before the same failure.
    """
    if not isinstance(exc, httpx.HTTPStatusError):
        return False
    status = exc.response.status_code
    return status == 429 or status >= 500


class OpenRouterASRProvider(ASRProvider):
    """ASR provider that POSTs base64 WAV audio to an OpenRouter STT endpoint.

    Requests go to ``{base_url}/audio/transcriptions`` through a lazily
    created ``httpx.AsyncClient`` that is reused across calls. Call
    :meth:`close` to release it.
    """

    def __init__(self, api_key: str, base_url: str, model: str):
        """
        Args:
            api_key: Bearer token sent in the ``Authorization`` header.
            base_url: API root; trailing slashes are stripped before
                ``/audio/transcriptions`` is appended.
            model: STT model identifier sent in the request body.
        """
        self._api_key = api_key
        self._stt_url = f"{base_url.rstrip('/')}/audio/transcriptions"
        self._model = model
        self._client: httpx.AsyncClient | None = None

    async def _get_client(self) -> httpx.AsyncClient:
        """Return the shared HTTP client, creating it on first use.

        There is no ``await`` between the check and the assignment, so
        coroutines on one event loop cannot create two clients.
        """
        if self._client is None:
            self._client = httpx.AsyncClient(
                timeout=httpx.Timeout(120.0),
                headers={
                    "Authorization": f"Bearer {self._api_key}",
                    "Content-Type": "application/json",
                },
            )
        return self._client

    async def transcribe(self, audio_bytes: bytes, language: str) -> str:
        """Transcribe WAV *audio_bytes* via the OpenRouter STT endpoint.

        Args:
            audio_bytes: Audio payload, sent base64-encoded with format ``wav``.
            language: Language hint; forwarded only when it is a 2-letter
                code, otherwise omitted so the service auto-detects.

        Returns:
            The transcription converted to Traditional Chinese.

        Raises:
            ASRError: The request failed (after retries), the response was
                not valid JSON, or the transcription was empty.
        """
        audio_b64 = base64.b64encode(audio_bytes).decode("ascii")
        logger.info(
            "asr-transcribe-start provider=openrouter model=%s url=%s audio_bytes=%d language=%s",
            self._model, self._stt_url, len(audio_bytes), language,
        )

        payload: dict = {
            "model": self._model,
            "input_audio": {
                "data": audio_b64,
                "format": "wav",
            },
        }
        # OpenRouter STT expects ISO-639-1 (2-letter) codes. DashScope-style
        # values such as "yue" (Cantonese, ISO-639-3) and the "auto" sentinel
        # are not valid here, so anything that is not a 2-letter code is
        # omitted and auto-detection handles it.
        if language and len(language) == 2 and language.isalpha():
            payload["language"] = language

        try:
            result = await self._call_stt_api(payload)
        except (httpx.TransportError, httpx.HTTPStatusError) as e:
            raise ASRError(f"OpenRouter STT request failed: {e}") from e
        except ValueError as e:
            # response.json() raises json.JSONDecodeError (a ValueError) on a
            # non-JSON body.
            raise ASRError(f"OpenRouter STT returned invalid JSON: {e}") from e

        text = result.get("text", "") if isinstance(result, dict) else ""
        if not text:
            raise ASRError("OpenRouter STT returned empty transcription")

        logger.info(
            "asr-transcribe-complete provider=openrouter text_len=%d",
            len(text),
        )
        return _to_traditional(text)

    @retry(
        reraise=True,
        stop=stop_after_attempt(4),
        wait=wait_random_exponential(multiplier=0.2, max=3.0),
        # Retry network failures and transient statuses (429/5xx) only; other
        # HTTP errors are raised immediately.
        retry=(
            retry_if_exception_type(httpx.TransportError)
            | retry_if_exception(_is_retryable_http_error)
        ),
    )
    async def _call_stt_api(self, payload: dict) -> dict:
        """POST *payload* to the STT endpoint and return the decoded JSON body.

        Raises:
            httpx.TransportError: Network failure, after retries are exhausted.
            httpx.HTTPStatusError: Non-success status; raised after retries for
                429/5xx, immediately otherwise.
            ValueError: The response body is not valid JSON.
        """
        client = await self._get_client()
        response = await client.post(self._stt_url, json=payload)
        if response.status_code >= 400:
            logger.error(
                "openrouter-stt-error status=%d body=%s",
                response.status_code,
                response.text[:500],
            )
        response.raise_for_status()
        return response.json()

    async def close(self) -> None:
        """Close the shared HTTP client, if one was created."""
        if self._client is not None:
            await self._client.aclose()
            self._client = None
|
|
|
|
|
|
def create_asr_provider(settings) -> ASRProvider:
    """Build the ASR provider selected by ``settings.asr_provider``.

    String provider names are matched case-insensitively and ignoring
    surrounding whitespace, so values such as ``"OpenRouter"`` or
    ``" dashscope"`` are accepted.

    Args:
        settings: Application settings object; reads ``asr_provider``,
            ``dashscope_api_key``, ``openrouter_api_key``, ``llm_base_url``,
            ``asr_model_name`` and ``asr_openrouter_model``.

    Returns:
        The configured :class:`ASRProvider` implementation.

    Raises:
        ASRError: OpenRouter was selected but ``openrouter_api_key`` is empty.
        ValueError: The provider name is not recognized.
    """
    provider_name = settings.asr_provider
    # Normalize only real strings; any other value is compared unchanged.
    normalized = (
        provider_name.strip().lower() if isinstance(provider_name, str) else provider_name
    )
    logger.info(
        "asr-provider-selected provider=%s dashscope_key=%s openrouter_key=%s llm_base_url=%s",
        provider_name,
        "set" if settings.dashscope_api_key else "empty",
        "set" if settings.openrouter_api_key else "empty",
        settings.llm_base_url,
    )

    if normalized == "dashscope":
        if not settings.dashscope_api_key:
            # Not fatal, but surface the misconfiguration when the provider
            # is created rather than only when a request is made.
            logger.warning("asr-provider-init provider=dashscope dashscope_key=empty")
        logger.info("asr-provider-init provider=dashscope model=%s", settings.asr_model_name)
        return DashScopeASRProvider(
            api_key=settings.dashscope_api_key,
            model=settings.asr_model_name,
        )

    if normalized == "openrouter":
        if not settings.openrouter_api_key:
            raise ASRError(
                "OPENROUTER_API_KEY is not configured. "
                "Set it in .env to use OpenRouter ASR."
            )
        logger.info(
            "asr-provider-init provider=openrouter model=%s url=%s",
            settings.asr_openrouter_model,
            f"{settings.llm_base_url.rstrip('/')}/audio/transcriptions",
        )
        return OpenRouterASRProvider(
            api_key=settings.openrouter_api_key,
            base_url=settings.llm_base_url,
            model=settings.asr_openrouter_model,
        )

    raise ValueError(f"Unknown ASR provider: {provider_name}")
|