feat: add ASR provider config, abstraction layer, and OpenRouter provider
Add ASR_PROVIDER env var (dashscope|openrouter), OPENROUTER_API_KEY, and ASR_OPENROUTER_MODEL to Settings. Create ASRProvider ABC with DashScopeASRProvider (wraps existing OpenAI-based DashScope calls via run_in_executor) and OpenRouterASRProvider (httpx + tenacity retry for batch STT). Add tenacity>=8.0.0 dependency. Realtime WebSocket stays DashScope-only. Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
This commit is contained in:
parent
67d2bddeb6
commit
39525a2344
|
|
@ -27,12 +27,27 @@ HISTORY_DB_PATH=./data/history.db
|
|||
|
||||
CORS_ORIGINS=["http://localhost:5173","http://localhost:3000"]
|
||||
|
||||
# Alibaba Cloud DashScope ASR (Phase 2)
|
||||
# -------- ASR Configuration (Phase 2 + Phase 5) --------
|
||||
|
||||
# ASR provider: "dashscope" or "openrouter"
|
||||
# dashscope: Alibaba Cloud DashScope – batch + realtime (WebSocket) Cantonese ASR
|
||||
# openrouter: OpenRouter STT – batch-only Cantonese ASR via REST API
|
||||
# NOTE: "openrouter" only affects batch (Full Transcript) transcription.
|
||||
# Realtime streaming always uses DashScope (OpenRouter has no WebSocket STT).
|
||||
ASR_PROVIDER=dashscope
|
||||
|
||||
# --- DashScope ASR (used when ASR_PROVIDER=dashscope, or for realtime) ---
|
||||
# Get your key from: https://modelstudio.console.alibabacloud.com
|
||||
DASHSCOPE_API_KEY=sk-your-dashscope-key-here
|
||||
ASR_MODEL_NAME=qwen3-asr-flash
|
||||
ASR_REALTIME_MODEL_NAME=qwen3-asr-flash-realtime
|
||||
|
||||
# --- OpenRouter STT (used when ASR_PROVIDER=openrouter) ---
|
||||
# Get your key from: https://openrouter.ai/keys
|
||||
# Separate key for independent accounting/billing
|
||||
OPENROUTER_API_KEY=
|
||||
ASR_OPENROUTER_MODEL=google/gemini-3.1-flash-lite
|
||||
|
||||
# Video upload (Phase 2)
|
||||
VIDEO_UPLOAD_DIR=./uploads
|
||||
MAX_VIDEO_SIZE_MB=300
|
||||
|
|
|
|||
|
|
@ -52,10 +52,16 @@ class Settings(BaseSettings):
|
|||
qa_include_internal_refs: bool = True
|
||||
qa_cache_vision_results: bool = True
|
||||
|
||||
# Alibaba Cloud DashScope ASR (Phase 2)
|
||||
# ASR Configuration (Phase 2 + Phase 5)
|
||||
# Provider: "dashscope" (batch + realtime) or "openrouter" (batch-only)
|
||||
asr_provider: str = "dashscope"
|
||||
# DashScope ASR (used when asr_provider=dashscope, or for realtime WebSocket)
|
||||
dashscope_api_key: str = ""
|
||||
asr_model_name: str = "qwen3-asr-flash"
|
||||
asr_realtime_model_name: str = "qwen3-asr-flash-realtime"
|
||||
# OpenRouter STT (used when asr_provider=openrouter)
|
||||
openrouter_api_key: str = ""
|
||||
asr_openrouter_model: str = "google/gemini-3.1-flash-lite"
|
||||
|
||||
# Video upload (Phase 2)
|
||||
video_upload_dir: str = "./uploads"
|
||||
|
|
@ -70,8 +76,16 @@ class Settings(BaseSettings):
|
|||
model_config = {"env_file": ".env", "env_file_encoding": "utf-8"}
|
||||
|
||||
|
||||
# Closed set of values accepted for the ASR_PROVIDER setting.
VALID_ASR_PROVIDERS = frozenset({"dashscope", "openrouter"})


@lru_cache
def get_settings() -> Settings:
    """Build the application settings, validate them, and memoize the result.

    The ``lru_cache`` decorator means a successfully validated ``Settings``
    instance is constructed once and reused by later calls. A call that raises
    is not cached, so a later call re-runs construction and validation.

    Returns:
        The validated ``Settings`` instance.

    Raises:
        ValueError: If ``asr_provider`` is not one of ``VALID_ASR_PROVIDERS``.
    """
    s = Settings()
    logger.info(
        "Settings loaded: llm_model=%s embedding_model=%s",
        s.llm_model_name,
        s.embedding_model,
    )

    # Happy path first: a recognised provider needs no further work.
    if s.asr_provider in VALID_ASR_PROVIDERS:
        return s

    raise ValueError(
        f"Invalid ASR_PROVIDER '{s.asr_provider}'. "
        f"Must be one of: {', '.join(sorted(VALID_ASR_PROVIDERS))}"
    )
|
||||
|
|
|
|||
|
|
@ -0,0 +1,156 @@
|
|||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import httpx
|
||||
import zhconv
|
||||
from openai import OpenAI
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _to_traditional(text: str) -> str:
|
||||
if not text:
|
||||
return text
|
||||
return zhconv.convert(text, "zh-hant")
|
||||
|
||||
|
||||
class ASRError(Exception):
    """Error raised by the ASR provider layer in this module."""
|
||||
|
||||
|
||||
class ASRProvider(ABC):
    """Abstract interface for batch speech-to-text backends."""

    @abstractmethod
    async def transcribe(self, audio_bytes: bytes, language: str) -> str:
        """Transcribe *audio_bytes* and return the recognised text.

        Args:
            audio_bytes: Raw audio payload to transcribe.
            language: Language hint; the concrete providers in this module
                treat ``"auto"`` as "no hint".

        Returns:
            The transcription as a string.
        """
|
||||
|
||||
|
||||
class DashScopeASRProvider(ASRProvider):
    """Batch ASR backed by DashScope's OpenAI-compatible chat endpoint.

    The blocking ``openai`` client call is run in the event loop's default
    executor so ``transcribe`` does not block the loop.
    """

    # Endpoint used when the caller does not supply one.
    DEFAULT_BASE_URL = "https://dashscope-intl.aliyuncs.com/compatible-mode/v1"

    def __init__(self, api_key: str, model: str, base_url: str = DEFAULT_BASE_URL):
        """Store the connection parameters.

        Args:
            api_key: DashScope API key.
            model: ASR model name sent with every request.
            base_url: OpenAI-compatible endpoint; defaults to
                ``DEFAULT_BASE_URL``.
        """
        self._api_key = api_key
        self._model = model
        self._base_url = base_url

    async def transcribe(self, audio_bytes: bytes, language: str) -> str:
        """Transcribe *audio_bytes* without blocking the running event loop.

        Args:
            audio_bytes: Audio payload; it is sent labelled as ``audio/wav``.
            language: Language hint. ``"auto"`` or an empty value sends no hint.

        Returns:
            The transcription converted to Traditional Chinese (an empty
            string when the model returns no content).

        Raises:
            ASRError: If the response contains no choices.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, self._transcribe_sync, audio_bytes, language
        )

    def _transcribe_sync(self, audio_bytes: bytes, language: str) -> str:
        """Blocking implementation of ``transcribe``; runs in a worker thread."""
        audio_b64 = base64.b64encode(audio_bytes).decode("ascii")
        data_url = f"data:audio/wav;base64,{audio_b64}"

        # Same rule as the OpenRouter provider: an empty/None language or
        # "auto" means "send no language hint".
        asr_options: dict = {}
        if language and language != "auto":
            asr_options["language"] = language

        # The context manager closes the client's HTTP connection pool after
        # the request instead of leaking one pool per transcription.
        with OpenAI(api_key=self._api_key, base_url=self._base_url) as client:
            resp = client.chat.completions.create(
                model=self._model,
                messages=[{  # type: ignore[list-item]
                    "role": "user",
                    "content": [{
                        "type": "input_audio",
                        "input_audio": {"data": data_url},
                    }],
                }],
                extra_body={"asr_options": asr_options} if asr_options else None,
            )

        if not resp.choices:
            # Surface a descriptive error instead of a bare IndexError.
            raise ASRError("DashScope ASR returned no choices")

        result = resp.choices[0].message.content or ""
        return _to_traditional(result)
|
||||
|
||||
|
||||
class OpenRouterASRProvider(ASRProvider):
    """Batch ASR that POSTs base64 audio to ``{base_url}/audio/transcriptions``.

    A single ``httpx.AsyncClient`` is created lazily and reused; call
    ``close()`` to release it.
    """

    # 4xx statuses that are still worth retrying (request timeout, rate limit).
    # Every 5xx status is also treated as transient.
    _RETRYABLE_CLIENT_STATUSES = frozenset({408, 429})

    def __init__(self, api_key: str, base_url: str, model: str):
        """Store credentials and derive the transcription endpoint URL.

        Args:
            api_key: Bearer token sent in the ``Authorization`` header.
            base_url: API root; a trailing slash is tolerated.
            model: Model identifier sent in the request payload.
        """
        self._api_key = api_key
        self._stt_url = f"{base_url.rstrip('/')}/audio/transcriptions"
        self._model = model
        self._client: httpx.AsyncClient | None = None

    async def _get_client(self) -> httpx.AsyncClient:
        """Return the shared HTTP client, creating it on first use."""
        if self._client is None:
            self._client = httpx.AsyncClient(
                timeout=httpx.Timeout(120.0),
                headers={
                    "Authorization": f"Bearer {self._api_key}",
                    "Content-Type": "application/json",
                },
            )
        return self._client

    async def transcribe(self, audio_bytes: bytes, language: str) -> str:
        """Transcribe *audio_bytes* and return Traditional Chinese text.

        Args:
            audio_bytes: Audio payload; it is sent with ``format="wav"``.
            language: Language hint. ``"auto"`` or an empty value sends no hint.

        Returns:
            The transcription converted to Traditional Chinese.

        Raises:
            ASRError: On transport failure, an HTTP error status, a malformed
                response body, or an empty transcription.
        """
        audio_b64 = base64.b64encode(audio_bytes).decode("ascii")

        payload: dict = {
            "model": self._model,
            "input_audio": {
                "data": audio_b64,
                "format": "wav",
            },
        }
        if language and language != "auto":
            payload["language"] = language

        try:
            result = await self._call_stt_api(payload)
        except (httpx.TransportError, httpx.HTTPStatusError) as e:
            # Raised only after the retry budget is exhausted (reraise=True).
            raise ASRError(f"OpenRouter STT request failed: {e}") from e

        if not isinstance(result, dict):
            # Valid JSON that is not an object has no "text" field to read.
            raise ASRError("OpenRouter STT returned an unexpected response body")

        text = result.get("text", "")
        if not text:
            raise ASRError("OpenRouter STT returned empty transcription")

        return _to_traditional(text)

    @retry(
        reraise=True,
        stop=stop_after_attempt(4),
        wait=wait_random_exponential(multiplier=0.2, max=3.0),
        retry=retry_if_exception_type((httpx.TransportError, httpx.HTTPStatusError)),
    )
    async def _call_stt_api(self, payload: dict) -> dict:
        """POST *payload* to the STT endpoint and return the decoded JSON.

        Transport errors and transient HTTP statuses (408, 429, 5xx) propagate
        as ``httpx`` exceptions so the retry policy re-attempts them. Other
        error statuses and undecodable bodies raise ``ASRError``, which the
        retry policy does not match, so they fail immediately.
        """
        client = await self._get_client()
        response = await client.post(self._stt_url, json=payload)

        try:
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            status = response.status_code
            if status >= 500 or status in self._RETRYABLE_CLIENT_STATUSES:
                raise  # transient: let tenacity retry
            # Permanent failure (bad request, bad key, ...): retrying cannot help.
            raise ASRError(f"OpenRouter STT request failed: {exc}") from exc

        try:
            return response.json()
        except ValueError as exc:
            # json.JSONDecodeError is a ValueError subclass.
            raise ASRError("OpenRouter STT returned a non-JSON response") from exc

    async def close(self) -> None:
        """Close the shared HTTP client, if one was created."""
        if self._client is not None:
            await self._client.aclose()
            self._client = None
|
||||
|
||||
|
||||
def create_asr_provider(settings) -> ASRProvider:
    """Instantiate the batch ASR provider selected by ``settings.asr_provider``.

    Args:
        settings: Object exposing ``asr_provider`` plus the fields of the
            selected backend (``dashscope_api_key`` / ``asr_model_name``, or
            ``openrouter_api_key`` / ``llm_base_url`` /
            ``asr_openrouter_model``).

    Returns:
        A ``DashScopeASRProvider`` or an ``OpenRouterASRProvider``.

    Raises:
        ASRError: If OpenRouter is selected but no API key is configured.
        ValueError: If the provider name is not recognised.
    """
    provider_name = settings.asr_provider

    if provider_name == "openrouter":
        if settings.openrouter_api_key:
            # NOTE(review): the OpenRouter base URL is taken from the LLM
            # setting; this presumably assumes the LLM endpoint is OpenRouter.
            return OpenRouterASRProvider(
                api_key=settings.openrouter_api_key,
                base_url=settings.llm_base_url,
                model=settings.asr_openrouter_model,
            )
        raise ASRError(
            "OPENROUTER_API_KEY is not configured. "
            "Set it in .env to use OpenRouter ASR."
        )

    if provider_name == "dashscope":
        return DashScopeASRProvider(
            api_key=settings.dashscope_api_key,
            model=settings.asr_model_name,
        )

    raise ValueError(f"Unknown ASR provider: {provider_name}")
|
||||
|
|
@ -8,6 +8,7 @@ python-docx>=1.1.0
|
|||
pypdf>=4.0.2
|
||||
python-dotenv>=1.0.0
|
||||
httpx>=0.26.0
|
||||
tenacity>=8.0.0
|
||||
openai>=2.26.0,<3.0.0
|
||||
pytest==7.4.4
|
||||
pytest-asyncio==0.23.4
|
||||
|
|
|
|||
Loading…
Reference in New Issue