feat: change waveform
This commit is contained in:
107
src/tokenizer.py
Normal file
107
src/tokenizer.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
import re
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from handle.text_handle import process_text
|
||||
from handle.text_normalizer import UYGHUR_LETTERS
|
||||
|
||||
class ASRTokenizer:
|
||||
def __init__(self, vocab_path: Path) -> None:
|
||||
self.vocabs: list[str] = []
|
||||
self.vocabs.extend([
|
||||
"<BLANK>",
|
||||
"<UNK>",
|
||||
"<SPACE>"
|
||||
])
|
||||
for i in range(len(self.vocabs), 64):
|
||||
self.vocabs.append(f"REVERSED_{i}")
|
||||
|
||||
with open(vocab_path, 'r', encoding='utf-8') as file:
|
||||
data: list[str] = json.load(file)
|
||||
|
||||
for char in data:
|
||||
if char not in [' ', '\t', '\n'] and char not in self.vocabs:
|
||||
self.vocabs.append(char)
|
||||
|
||||
self.vocab_to_id: dict[str, int] = {vocab:i for i, vocab in enumerate(self.vocabs)}
|
||||
self.id_to_vocab: dict[int, str] = {i:vocab for i, vocab in enumerate(self.vocabs)}
|
||||
|
||||
self._max_token_len = max((len(v) for v in self.vocabs if not v.startswith("<")), default=0)
|
||||
|
||||
def vocab_size(self):
|
||||
return len(self.vocabs)
|
||||
|
||||
def _split_long_syllable(self, syll: str) -> list[str]:
|
||||
"""将不在词表中的长音节拆分为词表内的子词序列(最大正向匹配)"""
|
||||
pieces = []
|
||||
start = 0
|
||||
n = len(syll)
|
||||
while start < n:
|
||||
# 从最大可能长度开始尝试匹配
|
||||
for length in range(min(self._max_token_len, n - start), 0, -1):
|
||||
sub = syll[start:start + length]
|
||||
if sub in self.vocab_to_id:
|
||||
pieces.append(sub)
|
||||
start += length
|
||||
break
|
||||
else:
|
||||
# 理论上不会发生,因为单字母一定在词表
|
||||
pieces.append("<UNK>")
|
||||
start += 1
|
||||
return pieces
|
||||
|
||||
def encode(self, text: str) -> list[int]:
|
||||
tokens = process_text(text=text)
|
||||
result: list[int] = []
|
||||
|
||||
for token in tokens:
|
||||
if token == ' ':
|
||||
result.append(self.vocab_to_id['<SPACE>'])
|
||||
else:
|
||||
token_id = self.vocab_to_id.get(token)
|
||||
if token_id is not None:
|
||||
result.append(token_id)
|
||||
elif token in UYGHUR_LETTERS:
|
||||
result.append(self.vocab_to_id.get(token, self.vocab_to_id['<UNK>']))
|
||||
else:
|
||||
sub_pieces = self._split_long_syllable(token)
|
||||
for piece in sub_pieces:
|
||||
result.append(self.vocab_to_id.get(piece, self.vocab_to_id['<UNK>']))
|
||||
|
||||
return result
|
||||
|
||||
def decode(self, ids: list[int] | Tensor, remove_blank: bool = True, remove_repeat: bool = True) -> str:
|
||||
if isinstance(ids, Tensor):
|
||||
ids = ids.tolist()
|
||||
|
||||
result = []
|
||||
prev_id = None
|
||||
|
||||
for token_id in ids:
|
||||
# 跳过blank token
|
||||
if remove_blank and token_id == self.vocab_to_id['<BLANK>']:
|
||||
prev_id = None
|
||||
continue
|
||||
|
||||
# CTC解码:跳过连续重复的字符
|
||||
if remove_repeat and token_id == prev_id:
|
||||
continue
|
||||
|
||||
char = self.id_to_vocab.get(token_id, "<UNK>")
|
||||
|
||||
if char == '<SPACE>':
|
||||
char = ' '
|
||||
|
||||
result.append(char)
|
||||
prev_id = token_id
|
||||
|
||||
return ''.join(result)
|
||||
|
||||
def get_special_token_id(self, token: str) -> int:
|
||||
return self.vocab_to_id.get(token, self.vocab_to_id['<UNK>'])
|
||||
|
||||
def ctc_greedy_decode(self, log_probs: Tensor) -> str:
|
||||
""" CTC贪心解码 (log_probs: [seq_len, vocab_size] 的log概率) """
|
||||
return self.decode(torch.argmax(log_probs, dim=-1))
|
||||
Reference in New Issue
Block a user