feat: change waveform

This commit is contained in:
2026-05-07 11:29:21 +06:00
commit d31233a79a
21 changed files with 5330 additions and 0 deletions

107
src/tokenizer.py Normal file
View File

@@ -0,0 +1,107 @@
import json
from pathlib import Path
import re
import torch
from torch import Tensor
from handle.text_handle import process_text
from handle.text_normalizer import UYGHUR_LETTERS
class ASRTokenizer:
def __init__(self, vocab_path: Path) -> None:
self.vocabs: list[str] = []
self.vocabs.extend([
"<BLANK>",
"<UNK>",
"<SPACE>"
])
for i in range(len(self.vocabs), 64):
self.vocabs.append(f"REVERSED_{i}")
with open(vocab_path, 'r', encoding='utf-8') as file:
data: list[str] = json.load(file)
for char in data:
if char not in [' ', '\t', '\n'] and char not in self.vocabs:
self.vocabs.append(char)
self.vocab_to_id: dict[str, int] = {vocab:i for i, vocab in enumerate(self.vocabs)}
self.id_to_vocab: dict[int, str] = {i:vocab for i, vocab in enumerate(self.vocabs)}
self._max_token_len = max((len(v) for v in self.vocabs if not v.startswith("<")), default=0)
def vocab_size(self):
return len(self.vocabs)
def _split_long_syllable(self, syll: str) -> list[str]:
"""将不在词表中的长音节拆分为词表内的子词序列(最大正向匹配)"""
pieces = []
start = 0
n = len(syll)
while start < n:
# 从最大可能长度开始尝试匹配
for length in range(min(self._max_token_len, n - start), 0, -1):
sub = syll[start:start + length]
if sub in self.vocab_to_id:
pieces.append(sub)
start += length
break
else:
# 理论上不会发生,因为单字母一定在词表
pieces.append("<UNK>")
start += 1
return pieces
def encode(self, text: str) -> list[int]:
tokens = process_text(text=text)
result: list[int] = []
for token in tokens:
if token == ' ':
result.append(self.vocab_to_id['<SPACE>'])
else:
token_id = self.vocab_to_id.get(token)
if token_id is not None:
result.append(token_id)
elif token in UYGHUR_LETTERS:
result.append(self.vocab_to_id.get(token, self.vocab_to_id['<UNK>']))
else:
sub_pieces = self._split_long_syllable(token)
for piece in sub_pieces:
result.append(self.vocab_to_id.get(piece, self.vocab_to_id['<UNK>']))
return result
def decode(self, ids: list[int] | Tensor, remove_blank: bool = True, remove_repeat: bool = True) -> str:
if isinstance(ids, Tensor):
ids = ids.tolist()
result = []
prev_id = None
for token_id in ids:
# 跳过blank token
if remove_blank and token_id == self.vocab_to_id['<BLANK>']:
prev_id = None
continue
# CTC解码跳过连续重复的字符
if remove_repeat and token_id == prev_id:
continue
char = self.id_to_vocab.get(token_id, "<UNK>")
if char == '<SPACE>':
char = ' '
result.append(char)
prev_id = token_id
return ''.join(result)
def get_special_token_id(self, token: str) -> int:
return self.vocab_to_id.get(token, self.vocab_to_id['<UNK>'])
def ctc_greedy_decode(self, log_probs: Tensor) -> str:
""" CTC贪心解码 (log_probs: [seq_len, vocab_size] 的log概率) """
return self.decode(torch.argmax(log_probs, dim=-1))