106 lines
3.7 KiB
Python
106 lines
3.7 KiB
Python
import json
|
||
from pathlib import Path
|
||
import re
|
||
import torch
|
||
from torch import Tensor
|
||
from handle.text_handle import process_text
|
||
from handle.text_normalizer import UYGHUR_LETTERS
|
||
|
||
class ASRTokenizer:
|
||
def __init__(self, vocab_path: Path) -> None:
|
||
self.vocabs: list[str] = []
|
||
self.vocabs.extend([
|
||
"<BLANK>",
|
||
"<UNK>",
|
||
"<SPACE>"
|
||
])
|
||
for i in range(len(self.vocabs), 64):
|
||
self.vocabs.append(f"REVERSED_{i}")
|
||
|
||
with open(vocab_path, 'r', encoding='utf-8') as file:
|
||
data: list[str] = json.load(file)
|
||
|
||
for char in data:
|
||
if char not in [' ', '\t', '\n'] and char not in self.vocabs:
|
||
self.vocabs.append(char)
|
||
|
||
self.vocab_to_id: dict[str, int] = {vocab:i for i, vocab in enumerate(self.vocabs)}
|
||
self.id_to_vocab: dict[int, str] = {i:vocab for i, vocab in enumerate(self.vocabs)}
|
||
|
||
self._max_token_len = max((len(v) for v in self.vocabs if not v.startswith("<")), default=0)
|
||
|
||
def vocab_size(self):
|
||
return len(self.vocabs)
|
||
|
||
def _split_long_syllable(self, syll: str) -> list[str]:
|
||
"""将不在词表中的长音节拆分为词表内的子词序列(最大正向匹配)"""
|
||
pieces = []
|
||
start = 0
|
||
n = len(syll)
|
||
while start < n:
|
||
# 从最大可能长度开始尝试匹配
|
||
for length in range(min(self._max_token_len, n - start), 0, -1):
|
||
sub = syll[start:start + length]
|
||
if sub in self.vocab_to_id:
|
||
pieces.append(sub)
|
||
start += length
|
||
break
|
||
else:
|
||
# 理论上不会发生,因为单字母一定在词表
|
||
pieces.append("<UNK>")
|
||
start += 1
|
||
return pieces
|
||
|
||
def encode(self, text: str) -> list[int]:
|
||
tokens = process_text(text=text)
|
||
result: list[int] = []
|
||
|
||
for token in tokens:
|
||
if token == ' ':
|
||
result.append(self.vocab_to_id['<SPACE>'])
|
||
else:
|
||
token_id = self.vocab_to_id.get(token)
|
||
if token_id is not None:
|
||
result.append(token_id)
|
||
elif token in UYGHUR_LETTERS:
|
||
result.append(self.vocab_to_id.get(token, self.vocab_to_id['<UNK>']))
|
||
else:
|
||
sub_pieces = self._split_long_syllable(token)
|
||
for piece in sub_pieces:
|
||
result.append(self.vocab_to_id.get(piece, self.vocab_to_id['<UNK>']))
|
||
|
||
return result
|
||
|
||
def decode(self, ids: list[int] | Tensor, remove_blank: bool = True, remove_repeat: bool = True) -> str:
|
||
if isinstance(ids, Tensor):
|
||
ids = ids.tolist()
|
||
|
||
result = []
|
||
prev_id = None
|
||
|
||
for token_id in ids:
|
||
# 跳过blank token
|
||
if remove_blank and token_id == self.vocab_to_id['<BLANK>']:
|
||
prev_id = None
|
||
continue
|
||
|
||
# CTC解码:跳过连续重复的字符
|
||
if remove_repeat and token_id == prev_id:
|
||
continue
|
||
|
||
char = self.id_to_vocab.get(token_id, "<UNK>")
|
||
|
||
if char == '<SPACE>':
|
||
char = ' '
|
||
|
||
result.append(char)
|
||
prev_id = token_id
|
||
|
||
return ''.join(result)
|
||
|
||
def get_special_token_id(self, token: str) -> int:
|
||
return self.vocab_to_id.get(token, self.vocab_to_id['<UNK>'])
|
||
|
||
def ctc_greedy_decode(self, log_probs: Tensor) -> str:
|
||
""" CTC贪心解码 (log_probs: [seq_len, vocab_size] 的log概率) """
|
||
return self.decode(torch.argmax(log_probs, dim=-1)) |