legco_ai_assistant/backend/app/services/cer_wer.py

157 lines
4.1 KiB
Python

def _levenshtein_distance(s1: str, s2: str) -> tuple:
"""Compute Levenshtein distance and return edit operation counts.
Returns (substitutions, deletions, insertions, hits).
"""
if not s1 and not s2:
return 0, 0, 0, 0
if not s1:
return 0, len(s2), 0, 0
if not s2:
return 0, 0, len(s1), 0
m, n = len(s1), len(s2)
dp = [[0] * (n + 1) for _ in range(m + 1)]
for i in range(m + 1):
dp[i][0] = i
for j in range(n + 1):
dp[0][j] = j
for i in range(1, m + 1):
for j in range(1, n + 1):
if s1[i - 1] == s2[j - 1]:
dp[i][j] = dp[i - 1][j - 1]
else:
dp[i][j] = 1 + min(
dp[i - 1][j],
dp[i][j - 1],
dp[i - 1][j - 1],
)
i, j = m, n
substitutions = 0
deletions = 0
insertions = 0
hits = 0
while i > 0 or j > 0:
if i > 0 and j > 0 and s1[i - 1] == s2[j - 1]:
hits += 1
i -= 1
j -= 1
elif i > 0 and j > 0 and dp[i][j] == dp[i - 1][j - 1] + 1:
substitutions += 1
i -= 1
j -= 1
elif i > 0 and dp[i][j] == dp[i - 1][j] + 1:
deletions += 1
i -= 1
elif j > 0:
insertions += 1
j -= 1
return substitutions, deletions, insertions, hits
def _tokenize_words(text: str) -> list:
"""Simple word tokenizer for mixed Chinese/English text.
Splits on whitespace. For character-level CER, use the raw string.
For word-level WER, this gives reasonable results for space-separated text.
"""
return text.split()
def calculate_cer(reference: str, hypothesis: str) -> dict:
    """Calculate the Character Error Rate (CER) of *hypothesis* against *reference*.

    CER = (substitutions + deletions + insertions) / len(reference), computed
    over the raw strings (every character, including spaces and punctuation,
    counts).

    Returns:
        Dict with keys ``cer`` (rounded to 6 decimals), ``reference_length``,
        ``transcribed_length``, ``substitutions``, ``deletions``,
        ``insertions`` and ``hits``.

    Edge cases:
        - Empty reference: the rate is undefined (division by zero), so
          ``cer`` is reported as 0.0. Every hypothesis character is still
          counted as an insertion.
        - Empty hypothesis (non-empty reference): every reference character
          is a deletion, so ``cer`` is 1.0.
    """
    ref_len = len(reference)
    hyp_len = len(hypothesis)
    if ref_len == 0:
        return {
            "cer": 0.0,
            "reference_length": 0,
            "transcribed_length": hyp_len,
            "substitutions": 0,
            "deletions": 0,
            "insertions": hyp_len,
            "hits": 0,
        }
    if hyp_len == 0:
        return {
            "cer": 1.0,
            "reference_length": ref_len,
            "transcribed_length": 0,
            "substitutions": 0,
            "deletions": ref_len,
            "insertions": 0,
            "hits": 0,
        }
    subs, dels, inss, hits = _levenshtein_distance(reference, hypothesis)
    # ref_len > 0 is guaranteed by the guard above.
    cer = (subs + dels + inss) / ref_len
    return {
        "cer": round(cer, 6),
        "reference_length": ref_len,
        "transcribed_length": hyp_len,
        "substitutions": subs,
        "deletions": dels,
        "insertions": inss,
        "hits": hits,
    }
def calculate_wer(reference: str, hypothesis: str) -> dict:
    """Calculate the Word Error Rate (WER) of *hypothesis* against *reference*.

    Both strings are tokenized with ``_tokenize_words`` (a whitespace split),
    then WER = (substitutions + deletions + insertions) / number of reference
    words. All lengths and counts in the result are in words.

    Returns:
        Dict with keys ``wer`` (rounded to 6 decimals), ``reference_length``,
        ``transcribed_length``, ``substitutions``, ``deletions``,
        ``insertions`` and ``hits``.

    Edge cases:
        - No reference words: the rate is undefined (division by zero), so
          ``wer`` is reported as 0.0. Every hypothesis word is still counted
          as an insertion.
        - No hypothesis words (non-empty reference): every reference word is
          a deletion, so ``wer`` is 1.0.
    """
    ref_words = _tokenize_words(reference)
    hyp_words = _tokenize_words(hypothesis)
    ref_len = len(ref_words)
    hyp_len = len(hyp_words)
    if ref_len == 0:
        return {
            "wer": 0.0,
            "reference_length": 0,
            "transcribed_length": hyp_len,
            "substitutions": 0,
            "deletions": 0,
            "insertions": hyp_len,
            "hits": 0,
        }
    if hyp_len == 0:
        return {
            "wer": 1.0,
            "reference_length": ref_len,
            "transcribed_length": 0,
            "substitutions": 0,
            "deletions": ref_len,
            "insertions": 0,
            "hits": 0,
        }
    subs, dels, inss, hits = _levenshtein_distance(ref_words, hyp_words)
    # ref_len > 0 is guaranteed by the guard above.
    wer = (subs + dels + inss) / ref_len
    return {
        "wer": round(wer, 6),
        "reference_length": ref_len,
        "transcribed_length": hyp_len,
        "substitutions": subs,
        "deletions": dels,
        "insertions": inss,
        "hits": hits,
    }