Files
audio_model/src/tokenizer.py

106 lines
3.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
from pathlib import Path
import re
import torch
from torch import Tensor
from handle.text_handle import process_text
from handle.text_normalizer import UYGHUR_LETTERS
class ASRTokenizer:
def __init__(self, vocab_path: Path) -> None:
self.vocabs: list[str] = []
self.vocabs.extend([
"<BLANK>",
"<UNK>",
"<SPACE>"
])
for i in range(len(self.vocabs), 64):
self.vocabs.append(f"REVERSED_{i}")
with open(vocab_path, 'r', encoding='utf-8') as file:
data: list[str] = json.load(file)
for char in data:
if char not in [' ', '\t', '\n'] and char not in self.vocabs:
self.vocabs.append(char)
self.vocab_to_id: dict[str, int] = {vocab:i for i, vocab in enumerate(self.vocabs)}
self.id_to_vocab: dict[int, str] = {i:vocab for i, vocab in enumerate(self.vocabs)}
self._max_token_len = max((len(v) for v in self.vocabs if not v.startswith("<")), default=0)
def vocab_size(self):
return len(self.vocabs)
def _split_long_syllable(self, syll: str) -> list[str]:
"""将不在词表中的长音节拆分为词表内的子词序列(最大正向匹配)"""
pieces = []
start = 0
n = len(syll)
while start < n:
# 从最大可能长度开始尝试匹配
for length in range(min(self._max_token_len, n - start), 0, -1):
sub = syll[start:start + length]
if sub in self.vocab_to_id:
pieces.append(sub)
start += length
break
else:
# 理论上不会发生,因为单字母一定在词表
pieces.append("<UNK>")
start += 1
return pieces
def encode(self, text: str) -> list[int]:
tokens = process_text(text=text)
result: list[int] = []
for token in tokens:
if token == ' ':
result.append(self.vocab_to_id['<SPACE>'])
else:
token_id = self.vocab_to_id.get(token)
if token_id is not None:
result.append(token_id)
elif token in UYGHUR_LETTERS:
result.append(self.vocab_to_id.get(token, self.vocab_to_id['<UNK>']))
else:
sub_pieces = self._split_long_syllable(token)
for piece in sub_pieces:
result.append(self.vocab_to_id.get(piece, self.vocab_to_id['<UNK>']))
return result
def decode(self, ids: list[int] | Tensor, remove_blank: bool = True, remove_repeat: bool = True) -> str:
if isinstance(ids, Tensor):
ids = ids.tolist()
result = []
prev_id = None
for token_id in ids:
# 跳过blank token
if remove_blank and token_id == self.vocab_to_id['<BLANK>']:
prev_id = None
continue
# CTC解码跳过连续重复的字符
if remove_repeat and token_id == prev_id:
continue
char = self.id_to_vocab.get(token_id, "<UNK>")
if char == '<SPACE>':
char = ' '
result.append(char)
prev_id = token_id
return ''.join(result)
def get_special_token_id(self, token: str) -> int:
return self.vocab_to_id.get(token, self.vocab_to_id['<UNK>'])
def ctc_greedy_decode(self, log_probs: Tensor) -> str:
""" CTC贪心解码 (log_probs: [seq_len, vocab_size] 的log概率) """
return self.decode(torch.argmax(log_probs, dim=-1))