import json
from pathlib import Path
import re

import torch
from torch import Tensor

from handle.text_handle import process_text
from handle.text_normalizer import UYGHUR_LETTERS


class ASRTokenizer:
    """Character/sub-word tokenizer for CTC-based ASR.

    The vocabulary is laid out as:

    * ids 0..2   - the special tokens (blank, unknown, space);
    * ids 3..63  - placeholder entries that keep the first 64 ids reserved;
    * ids 64..   - entries loaded from the JSON vocabulary file, in file order,
      with whitespace entries and duplicates dropped.
    """

    # NOTE(review): the three special-token literals were reconstructed as
    # distinct names in the conventional order (CTC blank at id 0). Confirm the
    # exact strings and their order against any checkpoint trained with this
    # tokenizer before deploying.
    BLANK_TOKEN = "<blank>"
    UNK_TOKEN = "<unk>"
    SPACE_TOKEN = "<space>"

    # Number of leading ids reserved for special and placeholder tokens.
    _NUM_RESERVED_IDS = 64

    # Vocabulary-file entries that are never added as tokens.
    _WHITESPACE = frozenset({' ', '\t', '\n'})

    def __init__(self, vocab_path: Path) -> None:
        """Build the vocabulary tables.

        Args:
            vocab_path: Path to a UTF-8 JSON file containing a list of token
                strings.
        """
        self.vocabs: list[str] = [
            self.BLANK_TOKEN,
            self.UNK_TOKEN,
            self.SPACE_TOKEN,
        ]
        # Pad the reserved range so file-loaded tokens always start at id 64.
        # The placeholder spelling is part of the vocabulary and is kept as is.
        for i in range(len(self.vocabs), self._NUM_RESERVED_IDS):
            self.vocabs.append(f"REVERSED_{i}")

        with open(vocab_path, 'r', encoding='utf-8') as file:
            data: list[str] = json.load(file)

        # A companion set keeps de-duplication O(1) per entry while the list
        # preserves insertion order (and therefore the ids).
        seen = set(self.vocabs)
        for char in data:
            if char in self._WHITESPACE or char in seen:
                continue
            seen.add(char)
            self.vocabs.append(char)

        self.vocab_to_id: dict[str, int] = {
            vocab: i for i, vocab in enumerate(self.vocabs)
        }
        self.id_to_vocab: dict[int, str] = dict(enumerate(self.vocabs))

        # Longest token considered by the greedy splitter; "<...>" special
        # tokens are excluded from the length computation.
        self._max_token_len = max(
            (len(v) for v in self.vocabs if not v.startswith("<")),
            default=0,
        )

        # Cache the special ids used on every encode/decode call.
        self._blank_id = self.vocab_to_id[self.BLANK_TOKEN]
        self._unk_id = self.vocab_to_id[self.UNK_TOKEN]
        self._space_id = self.vocab_to_id[self.SPACE_TOKEN]

    def vocab_size(self) -> int:
        """Return the total number of vocabulary entries."""
        return len(self.vocabs)

    def _split_long_syllable(self, syll: str) -> list[str]:
        """Split an out-of-vocabulary syllable into in-vocabulary pieces.

        Uses greedy forward maximum matching: at each position the longest
        substring found in the vocabulary is taken. A position where nothing
        matches yields the unknown token and advances by one character.
        """
        pieces: list[str] = []
        start = 0
        n = len(syll)
        while start < n:
            # Try the longest candidate first, shrinking down to one character.
            for length in range(min(self._max_token_len, n - start), 0, -1):
                sub = syll[start:start + length]
                if sub in self.vocab_to_id:
                    pieces.append(sub)
                    start += length
                    break
            else:
                # Not expected when every single letter is in the vocabulary.
                pieces.append(self.UNK_TOKEN)
                start += 1
        return pieces

    def encode(self, text: str) -> list[int]:
        """Convert text into a list of token ids.

        The text is first tokenized by ``process_text``. A ``' '`` token maps
        to the space id, an in-vocabulary token to its own id, and anything
        else is either mapped to the unknown id (single letters) or split into
        in-vocabulary pieces.
        """
        unk_id = self._unk_id
        result: list[int] = []
        for token in process_text(text=text):
            if token == ' ':
                result.append(self._space_id)
                continue

            token_id = self.vocab_to_id.get(token)
            if token_id is not None:
                result.append(token_id)
            elif token in UYGHUR_LETTERS:
                # Already known to be absent from the vocabulary at this point.
                result.append(unk_id)
            else:
                result.extend(
                    self.vocab_to_id.get(piece, unk_id)
                    for piece in self._split_long_syllable(token)
                )
        return result

    def decode(self, ids: list[int] | Tensor, remove_blank: bool = True, remove_repeat: bool = True) -> str:
        """Convert token ids back into text.

        Args:
            ids: Token ids as a list or a tensor (converted with ``tolist()``).
            remove_blank: Drop blank tokens from the output.
            remove_repeat: Collapse runs of the same id into one token
                (CTC decoding).

        Returns:
            The decoded string, with the space token rendered as ``' '``.
        """
        if isinstance(ids, Tensor):
            ids = ids.tolist()

        blank_id = self._blank_id
        result: list[str] = []
        prev_id = None
        for token_id in ids:
            # Skip the blank token; it also ends the current run so that the
            # same symbol on both sides of a blank is emitted twice.
            if remove_blank and token_id == blank_id:
                prev_id = None
                continue
            # CTC decoding: skip consecutive repeats of the same id.
            if remove_repeat and token_id == prev_id:
                continue
            char = self.id_to_vocab.get(token_id, self.UNK_TOKEN)
            if char == self.SPACE_TOKEN:
                char = ' '
            result.append(char)
            prev_id = token_id
        return ''.join(result)

    def get_special_token_id(self, token: str) -> int:
        """Return the id of ``token``, or the unknown id if it is not in the vocabulary."""
        return self.vocab_to_id.get(token, self._unk_id)

    def ctc_greedy_decode(self, log_probs: Tensor) -> str:
        """Greedy CTC decoding.

        Takes the argmax over the last dimension of ``log_probs`` (documented
        by the original author as ``[seq_len, vocab_size]`` log-probabilities)
        and decodes the resulting ids with the default ``decode`` settings.
        """
        return self.decode(torch.argmax(log_probs, dim=-1))