feat: change waveform
This commit is contained in:
17
.gitignore
vendored
Normal file
17
.gitignore
vendored
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
# Python-generated files
|
||||||
|
__pycache__/
|
||||||
|
*.py[oc]
|
||||||
|
build/
|
||||||
|
dist/
|
||||||
|
wheels/
|
||||||
|
*.egg-info
|
||||||
|
|
||||||
|
# Virtual environments
|
||||||
|
.venv
|
||||||
|
|
||||||
|
# data
|
||||||
|
.checkpoints
|
||||||
|
.data
|
||||||
|
data
|
||||||
|
runs
|
||||||
|
config
|
||||||
1
.python-version
Normal file
1
.python-version
Normal file
@@ -0,0 +1 @@
|
|||||||
|
3.11
|
||||||
16
.vscode/launch.json
vendored
Normal file
16
.vscode/launch.json
vendored
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
{
|
||||||
|
// Use IntelliSense to learn about possible attributes.
|
||||||
|
// Hover to view descriptions of existing attributes.
|
||||||
|
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||||
|
"version": "0.2.0",
|
||||||
|
"configurations": [
|
||||||
|
|
||||||
|
{
|
||||||
|
"name": "Python Debugger: Current File",
|
||||||
|
"type": "debugpy",
|
||||||
|
"request": "launch",
|
||||||
|
"program": "src/train.py",
|
||||||
|
"console": "integratedTerminal"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
33
pyproject.toml
Normal file
33
pyproject.toml
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
[project]
|
||||||
|
name = "study-asr"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Add your description here"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.11"
|
||||||
|
dependencies = [
|
||||||
|
"ipython>=9.10.1",
|
||||||
|
"librosa>=0.11.0",
|
||||||
|
"matplotlib>=3.10.8",
|
||||||
|
"mutagen>=1.47.0",
|
||||||
|
"numpy>=2.4.4",
|
||||||
|
"pandas>=3.0.2",
|
||||||
|
"pillow>=12.2.0",
|
||||||
|
"pydub>=0.25.1",
|
||||||
|
"pyrubberband>=0.4.0",
|
||||||
|
"setuptools<82",
|
||||||
|
"silero-vad>=6.2.1",
|
||||||
|
"tensorboard>=2.20.0",
|
||||||
|
"tensorboardx>=2.6.5",
|
||||||
|
"torch==2.8.0",
|
||||||
|
"torch-audiomentations>=0.12.0",
|
||||||
|
"torchaudio==2.8.0",
|
||||||
|
"torchcodec==0.7.0",
|
||||||
|
"tqdm>=4.67.3",
|
||||||
|
"webrtcvad>=2.0.10",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[index]]
|
||||||
|
url = "https://mirrors.aliyun.com/pypi/simple/"
|
||||||
|
default = true
|
||||||
|
# url = "https://pypi.tuna.tsinghua.edu.cn/simple/"
|
||||||
|
|
||||||
431
src/dataset.py
Normal file
431
src/dataset.py
Normal file
@@ -0,0 +1,431 @@
|
|||||||
|
import random
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
import torchaudio
|
||||||
|
from torch_audiomentations import ApplyImpulseResponse, Gain, PitchShift, LowPassFilter, HighPassFilter
|
||||||
|
from torchaudio.transforms import Resample, TimeStretch
|
||||||
|
from pathlib import Path
|
||||||
|
import pandas as pd
|
||||||
|
from typing import List, TypedDict
|
||||||
|
from handle.text_normalizer import collapse_spaces, normalize_extended_uyghur_characters
|
||||||
|
from tokenizer import ASRTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
# 单个样本的数据结构(Dataset.__getitem__ 返回)
|
||||||
|
class BatchItem(TypedDict):
|
||||||
|
waveform: Tensor # [time]
|
||||||
|
target_ids: Tensor # [seq_len] 目标文本的token IDs
|
||||||
|
target_text: str # 原始文本
|
||||||
|
audio_path: str # 音频文件路径
|
||||||
|
|
||||||
|
|
||||||
|
# 批量数据的数据结构(collate_fn 返回,DataLoader 输出)
|
||||||
|
class Batch(TypedDict):
|
||||||
|
waveforms: Tensor # [batch, time]
|
||||||
|
targets: Tensor # [batch, max_len] padding后的目标IDs
|
||||||
|
waveform_lengths: Tensor # [batch] 每个样本的实际Waveform长度
|
||||||
|
target_lengths: Tensor # [batch] 每个样本的实际目标长度
|
||||||
|
target_texts: List[str] # [batch] 原始文本列表
|
||||||
|
audio_paths: List[str] # [batch] 音频路径列表
|
||||||
|
|
||||||
|
class TsvFormat(TypedDict):
|
||||||
|
client_id: str
|
||||||
|
path: str
|
||||||
|
sentence: str
|
||||||
|
up_votes: int
|
||||||
|
down_votes: int
|
||||||
|
age: str
|
||||||
|
gender: str
|
||||||
|
locale: str
|
||||||
|
|
||||||
|
class NoiseAugmentor:
|
||||||
|
def __init__(self, noise_root: Path, sample_rate: int=16000):
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
self.noise_files = list(Path(noise_root).rglob("*.wav"))
|
||||||
|
|
||||||
|
def apply_real_noise(self, waveform: Tensor):
|
||||||
|
# 1. 随机选一个噪音文件
|
||||||
|
noise_path = random.choice(self.noise_files)
|
||||||
|
noise_waveform, sr = torchaudio.load_with_torchcodec(noise_path)
|
||||||
|
|
||||||
|
# Resample to target sample rate.
|
||||||
|
if sr != self.sample_rate:
|
||||||
|
noise_waveform = Resample(sr, self.sample_rate)(noise_waveform)
|
||||||
|
|
||||||
|
# Convert to mono if it is setro.
|
||||||
|
if waveform.shape[0] > 1:
|
||||||
|
waveform = waveform.mean(dim=0, keepdim=True)
|
||||||
|
|
||||||
|
# 3. 截取或填充,使其长度与语音一致
|
||||||
|
sig_len = waveform.shape[1]
|
||||||
|
noise_len = noise_waveform.shape[1]
|
||||||
|
|
||||||
|
if noise_len >= sig_len:
|
||||||
|
# 随机截取一段
|
||||||
|
start = random.randint(0, noise_len - sig_len)
|
||||||
|
noise_waveform = noise_waveform[:, start:start + sig_len]
|
||||||
|
else:
|
||||||
|
full_noise = torch.zeros_like(waveform)
|
||||||
|
start = random.randint(0, sig_len - noise_len)
|
||||||
|
full_noise[:, start : start + noise_len] = noise_waveform
|
||||||
|
noise_waveform = full_noise
|
||||||
|
|
||||||
|
# 4. 设定随机信噪比 SNR (5dB 到 20dB)
|
||||||
|
snr_db = random.uniform(5, 20)
|
||||||
|
|
||||||
|
# 5. 混合
|
||||||
|
return self._mix_at_snr(waveform, noise_waveform, snr_db)
|
||||||
|
|
||||||
|
def _mix_at_snr(self, signal: Tensor, noise: Tensor, snr_db: float):
|
||||||
|
s_p = signal.pow(2).mean()
|
||||||
|
n_p = noise.pow(2).mean()
|
||||||
|
snr_linear = 10**(snr_db / 10)
|
||||||
|
scale = torch.sqrt(s_p / (n_p * snr_linear + 1e-8))
|
||||||
|
|
||||||
|
noisy = signal + scale * noise
|
||||||
|
# 归一化,防止溢出
|
||||||
|
return noisy / (noisy.abs().max() + 1e-8)
|
||||||
|
|
||||||
|
class CommonVoiceDataset(Dataset[BatchItem]):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tsv_path: Path,
|
||||||
|
audio_dir: Path,
|
||||||
|
noise_dir: Path,
|
||||||
|
tokenizer: ASRTokenizer,
|
||||||
|
sample_rate: int = 16000,
|
||||||
|
max_audio_len: int = 480000, # 30秒 @ 16kHz
|
||||||
|
augment: bool = True,
|
||||||
|
augment_prob: float = 0.5,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.noise_augmentor = NoiseAugmentor(noise_root=noise_dir, sample_rate=sample_rate)
|
||||||
|
self.audio_dir = Path(audio_dir)
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
self.max_audio_len = max_audio_len
|
||||||
|
self.augment = augment
|
||||||
|
self.augment_prob = augment_prob
|
||||||
|
|
||||||
|
self.data: pd.DataFrame = pd.read_csv(tsv_path, sep='\t')
|
||||||
|
|
||||||
|
valid_indices = []
|
||||||
|
for index, row in self.data.iterrows():
|
||||||
|
audio_path: Path = self.audio_dir / row['path']
|
||||||
|
if audio_path.exists():
|
||||||
|
valid_indices.append(index)
|
||||||
|
|
||||||
|
self.data = self.data.loc[valid_indices].reset_index(drop=True)
|
||||||
|
|
||||||
|
self.gain_up = Gain(min_gain_in_db=4, max_gain_in_db=8, p=1.0, output_type='tensor')
|
||||||
|
self.gain_down = Gain(min_gain_in_db=-15, max_gain_in_db=-8, p=1.0, output_type='tensor')
|
||||||
|
self.pitch_up = PitchShift(min_transpose_semitones=1, max_transpose_semitones=4, p=1.0, sample_rate=self.sample_rate, output_type='tensor')
|
||||||
|
self.pitch_down = PitchShift(min_transpose_semitones=-4, max_transpose_semitones=-1, p=1.0, sample_rate=self.sample_rate, output_type='tensor')
|
||||||
|
self.lowpass = LowPassFilter(min_cutoff_freq=600, max_cutoff_freq=2000, p=1.0, output_type='tensor')
|
||||||
|
self.highpass = HighPassFilter(min_cutoff_freq=800, max_cutoff_freq=2000, p=1.0, output_type='tensor')
|
||||||
|
self.apply_ir = ApplyImpulseResponse(ir_paths=noise_dir,convolve_mode='same', p=1, output_type="tensor")
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.data)
|
||||||
|
|
||||||
|
def _load_audio(self, audio_path: Path) -> Tensor:
|
||||||
|
waveform, sample_rate = torchaudio.load_with_torchcodec(audio_path)
|
||||||
|
|
||||||
|
# Resample to target sample rate.
|
||||||
|
if sample_rate != self.sample_rate:
|
||||||
|
waveform = Resample(sample_rate, self.sample_rate)(waveform)
|
||||||
|
|
||||||
|
# Convert to mono if it is setro.
|
||||||
|
if waveform.shape[0] > 1:
|
||||||
|
waveform = waveform.mean(dim=0, keepdim=True)
|
||||||
|
|
||||||
|
# Normalization
|
||||||
|
waveform = waveform / waveform.abs().max()
|
||||||
|
|
||||||
|
# Clip waveform exceeds from max length.
|
||||||
|
if waveform.shape[1] > self.max_audio_len:
|
||||||
|
waveform = waveform[:, :self.max_audio_len]
|
||||||
|
return waveform
|
||||||
|
|
||||||
|
def _augment_waveform(self, waveform: Tensor) -> Tensor:
|
||||||
|
if not self.augment or random.random() > self.augment_prob:
|
||||||
|
return waveform
|
||||||
|
|
||||||
|
# 1. voice Stretch/Compress
|
||||||
|
if random.random() < 0.6:
|
||||||
|
waveform = self._stretch_or_compress(waveform=waveform)
|
||||||
|
|
||||||
|
if random.random() < 0.4:
|
||||||
|
waveform = self.noise_augmentor.apply_real_noise(waveform)
|
||||||
|
|
||||||
|
if random.random() < 0.3:
|
||||||
|
waveform = self._time_mask_waveform(waveform=waveform)
|
||||||
|
|
||||||
|
# torch_audiomentations: [1, time] -> [1, 1, time]
|
||||||
|
if waveform.dim() == 2:
|
||||||
|
waveform_3d = waveform.unsqueeze(0)
|
||||||
|
# 随机选择一种频谱增强
|
||||||
|
choice = random.random()
|
||||||
|
if choice < 0.15:
|
||||||
|
# 增益变化(上或下)
|
||||||
|
if random.random() < 0.5:
|
||||||
|
waveform_3d = self.gain_up(waveform_3d, sample_rate=self.sample_rate)
|
||||||
|
else:
|
||||||
|
waveform_3d = self.gain_down(waveform_3d, sample_rate=self.sample_rate)
|
||||||
|
elif choice < 0.25:
|
||||||
|
# 音高变化(上或下)
|
||||||
|
if random.random() < 0.5:
|
||||||
|
waveform_3d = self.pitch_up(waveform_3d, sample_rate=self.sample_rate)
|
||||||
|
else:
|
||||||
|
waveform_3d = self.pitch_down(waveform_3d, sample_rate=self.sample_rate)
|
||||||
|
elif choice < 0.30:
|
||||||
|
# 低通滤波(声音发闷)
|
||||||
|
waveform_3d = self.lowpass(waveform_3d, sample_rate=self.sample_rate)
|
||||||
|
# elif choice < 0.32:
|
||||||
|
# # 低通滤波(声音发闷)
|
||||||
|
# waveform_3d = self.apply_ir(waveform_3d, sample_rate=self.sample_rate)
|
||||||
|
elif choice < 0.35:
|
||||||
|
# 高通滤波(电话效果)
|
||||||
|
waveform_3d = self.highpass(waveform_3d, sample_rate=self.sample_rate)
|
||||||
|
|
||||||
|
# [1, 1, time] -> [1, time]
|
||||||
|
waveform = waveform_3d.squeeze(0)
|
||||||
|
|
||||||
|
# 防止多次 augment 后振幅溢出,最后归一化
|
||||||
|
max_amp = waveform.abs().max()
|
||||||
|
if max_amp > 1.0:
|
||||||
|
waveform = waveform / max_amp
|
||||||
|
|
||||||
|
return waveform
|
||||||
|
|
||||||
|
def _stretch_or_compress(self, waveform: Tensor) -> Tensor:
|
||||||
|
speed_factor = random.uniform(0.85, 1.4) # (Speed Change: 0.85x - 1.4x)
|
||||||
|
spec = torch.stft(
|
||||||
|
waveform.squeeze(0),
|
||||||
|
n_fft=400,
|
||||||
|
hop_length=160,
|
||||||
|
window=torch.hann_window(400).to(waveform.device),
|
||||||
|
return_complex=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# 时间拉伸(不改变音高)
|
||||||
|
stretched_spec = TimeStretch(hop_length=160, n_freq=spec.shape[-2], fixed_rate=speed_factor)(spec)
|
||||||
|
|
||||||
|
# 转回波形
|
||||||
|
waveform_stretched = torch.istft(stretched_spec, n_fft=400, hop_length=160, window=torch.hann_window(400).to(waveform.device)).unsqueeze(0)
|
||||||
|
|
||||||
|
return waveform_stretched
|
||||||
|
|
||||||
|
def _time_mask_waveform(self, waveform: Tensor) -> Tensor:
|
||||||
|
audio_len = waveform.shape[1]
|
||||||
|
sr = self.sample_rate # 16000
|
||||||
|
|
||||||
|
# 设置参数:单次遮盖最长 0.4 秒 (6400个点)
|
||||||
|
max_mask_time = 0.4
|
||||||
|
max_mask_samples = int(sr * max_mask_time)
|
||||||
|
|
||||||
|
# 根据音频长度决定遮盖次数:
|
||||||
|
# 比如每 3 秒钟允许遮盖 1 次
|
||||||
|
num_masks = max(1, audio_len // (sr * 3))
|
||||||
|
|
||||||
|
for _ in range(num_masks):
|
||||||
|
# 每次随机遮盖 0.1s 到 0.4s
|
||||||
|
current_mask_len = random.randint(int(sr * 0.1), max_mask_samples)
|
||||||
|
|
||||||
|
if audio_len > current_mask_len:
|
||||||
|
start_pos = random.randint(0, audio_len - current_mask_len)
|
||||||
|
|
||||||
|
# 填充微小噪音(模拟环境底噪)
|
||||||
|
noise = torch.randn(1, current_mask_len).to(waveform.device) * 0.002
|
||||||
|
waveform[:, start_pos : start_pos + current_mask_len] = noise
|
||||||
|
|
||||||
|
return waveform
|
||||||
|
|
||||||
|
def __getitem__(self, index) -> BatchItem:
|
||||||
|
row: TsvFormat = self.data.iloc[index]
|
||||||
|
audio_path: Path = self.audio_dir / row['path']
|
||||||
|
# text: str = unicodedata.normalize('NFC', row['sentence'].strip())
|
||||||
|
text: str = normalize_extended_uyghur_characters(collapse_spaces(row['sentence'].strip()))
|
||||||
|
|
||||||
|
waveform = self._load_audio(audio_path=audio_path)
|
||||||
|
waveform = self._augment_waveform(waveform)
|
||||||
|
waveform = waveform.squeeze(0)
|
||||||
|
|
||||||
|
return BatchItem(
|
||||||
|
waveform=waveform,
|
||||||
|
target_ids=torch.tensor(self.tokenizer.encode(text), dtype=torch.long),
|
||||||
|
target_text=text,
|
||||||
|
audio_path=str(audio_path)
|
||||||
|
)
|
||||||
|
|
||||||
|
def collate_fn(items: List[BatchItem]) -> Batch:
|
||||||
|
max_waveform_len = max(item['waveform'].shape[0] for item in items)
|
||||||
|
max_target_len = max(len(item['target_ids']) for item in items)
|
||||||
|
|
||||||
|
batch_size = len(items)
|
||||||
|
|
||||||
|
waveforms = torch.zeros(batch_size, max_waveform_len)
|
||||||
|
targets = torch.zeros(batch_size, max_target_len, dtype=torch.long)
|
||||||
|
waveform_lengths = torch.zeros(batch_size, dtype=torch.long)
|
||||||
|
target_lengths = torch.zeros(batch_size, dtype=torch.long)
|
||||||
|
|
||||||
|
target_texts = []
|
||||||
|
audio_paths = []
|
||||||
|
|
||||||
|
|
||||||
|
for i, item in enumerate(items):
|
||||||
|
waveform_len = item['waveform'].shape[0]
|
||||||
|
target_len = len(item['target_ids'])
|
||||||
|
|
||||||
|
waveforms[i, :waveform_len] = item['waveform']
|
||||||
|
targets[i, :target_len] = item['target_ids']
|
||||||
|
waveform_lengths[i] = waveform_len
|
||||||
|
target_lengths[i] = target_len
|
||||||
|
|
||||||
|
target_texts.append(item['target_text'])
|
||||||
|
audio_paths.append(item['audio_path'])
|
||||||
|
|
||||||
|
return Batch(
|
||||||
|
waveforms=waveforms,
|
||||||
|
targets=targets,
|
||||||
|
waveform_lengths=waveform_lengths,
|
||||||
|
target_lengths=target_lengths,
|
||||||
|
target_texts=target_texts,
|
||||||
|
audio_paths=audio_paths
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_dataloader(tsv_path: Path, audio_dir: Path, noise_dir: Path, tokenizer: ASRTokenizer, batch_size: int = 8, shuffle: bool = True, augment: bool = True, augment_prob: int = 0.5) -> DataLoader:
|
||||||
|
dataset = CommonVoiceDataset(tsv_path=tsv_path, audio_dir=audio_dir, noise_dir=noise_dir, tokenizer=tokenizer, augment=augment, augment_prob=augment_prob)
|
||||||
|
|
||||||
|
return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn, pin_memory=True, num_workers=8, prefetch_factor=8, persistent_workers=True)
|
||||||
|
|
||||||
|
|
||||||
|
# ============ 测试代码 ============
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
workspace_dir = Path(__file__).parent.parent
|
||||||
|
|
||||||
|
# 初始化tokenizer
|
||||||
|
tokenizer = ASRTokenizer(workspace_dir / 'config' / 'asr_vocab.json')
|
||||||
|
|
||||||
|
# 创建数据加载器
|
||||||
|
dataloader = create_dataloader(
|
||||||
|
tsv_path=workspace_dir / 'data' / 'ug' / 'train.tsv',
|
||||||
|
audio_dir=workspace_dir / 'data' / 'ug' / 'clips',
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
batch_size=2,
|
||||||
|
shuffle=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 测试加载一个batch
|
||||||
|
print("测试数据加载:")
|
||||||
|
for batch in dataloader:
|
||||||
|
batch: Batch
|
||||||
|
print(f"Mel specs shape: {batch['mel_specs'].shape}")
|
||||||
|
print(f"Targets shape: {batch['targets'].shape}")
|
||||||
|
print(f"Mel lengths: {batch['mel_lengths']}")
|
||||||
|
print(f"Target lengths: {batch['target_lengths']}")
|
||||||
|
print(f"Target texts: {batch['target_texts']}")
|
||||||
|
print(f"Audio paths: {batch['audio_paths']}")
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
from silero_vad import load_silero_vad, read_audio, get_speech_timestamps
|
||||||
|
|
||||||
|
# --- 1. 加载 VAD 模型 ---
|
||||||
|
# 推荐使用 GPU (cuda) 如果可用,否则使用 CPU (cpu)
|
||||||
|
vad_model = load_silero_vad(source="local", force_onnx=False)
|
||||||
|
|
||||||
|
def process_audio_with_vad_and_asr(audio_path, inference_module):
|
||||||
|
"""
|
||||||
|
使用 Silero VAD 分割音频,然后对每个语音片段进行 ASR 转录。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio_path (str): 输入音频文件的路径。
|
||||||
|
inference_module: 您封装了 transcribe 方法的模块或对象。
|
||||||
|
需要确保其 _load_audio 和 transcribe 方法可用。
|
||||||
|
"""
|
||||||
|
print(f"正在加载音频: {audio_path}")
|
||||||
|
|
||||||
|
# --- 2. 加载音频用于 VAD 检测 ---
|
||||||
|
# Silero VAD 推荐使用 16kHz 采样率
|
||||||
|
waveform_vad, sample_rate_vad = torchaudio.load(audio_path)
|
||||||
|
if sample_rate_vad != 16000:
|
||||||
|
# 如果采样率不是16kHz,需要重采样以供VAD使用
|
||||||
|
resampler = torchaudio.transforms.Resample(orig_freq=sample_rate_vad, new_freq=16000)
|
||||||
|
waveform_vad = resampler(waveform_vad)
|
||||||
|
sample_rate_vad = 16000
|
||||||
|
|
||||||
|
# --- 3. 获取语音时间段 ---
|
||||||
|
# get_speech_timestamps 返回一个列表,每个元素是 {'start': start_sample, 'end': end_sample} 的字典
|
||||||
|
# 采样率是16000,所以时间戳单位是 1/16000 秒
|
||||||
|
speech_timestamps = get_speech_timestamps(
|
||||||
|
waveform_vad,
|
||||||
|
vad_model,
|
||||||
|
sampling_rate=sample_rate_vad,
|
||||||
|
threshold=0.5, # 可以根据需要调整阈值
|
||||||
|
min_speech_duration_ms=250, # 最小语音持续时间,防止短噪音被误判
|
||||||
|
max_speech_duration_s=float('inf'), # 最大语音持续时间,float('inf') 表示不限制
|
||||||
|
min_silence_duration_ms=100, # 最小静音间隔,用于分割语音块
|
||||||
|
window_size_samples=1536, # VAD窗口大小
|
||||||
|
speech_pad_ms=30 # 在语音块前后添加的填充时间
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"检测到 {len(speech_timestamps)} 个语音片段。")
|
||||||
|
|
||||||
|
full_text = ""
|
||||||
|
for i, ts in enumerate(speech_timestamps):
|
||||||
|
start_sample = int(ts['start'])
|
||||||
|
end_sample = int(ts['end'])
|
||||||
|
|
||||||
|
# 计算时间戳(秒)
|
||||||
|
start_time = start_sample / sample_rate_vad
|
||||||
|
end_time = end_sample / sample_rate_vad
|
||||||
|
|
||||||
|
print(f"\n处理第 {i+1} 个片段: 时间范围 [{start_time:.2f}s - {end_time:.2f}s]")
|
||||||
|
|
||||||
|
# --- 4. 从原始音频中提取此片段 ---
|
||||||
|
# 注意:这里假设您的 transcribe 函数可以接受 waveform 张量。
|
||||||
|
# 我们需要从原始可能不同采样率的音频中提取片段,或者用VAD处理过的waveform。
|
||||||
|
# 为了匹配您原来的 _load_audio 方式,我们用 torchaudio 再次精确加载片段。
|
||||||
|
|
||||||
|
# 计算原始音频中的样本索引(如果原始音频采样率与VAD不同)
|
||||||
|
original_waveform, original_sr = torchaudio.load(audio_path)
|
||||||
|
if sample_rate_vad != original_sr:
|
||||||
|
# 如果VAD和原始音频采样率不同,需要重新映射索引
|
||||||
|
start_idx_orig = int(start_sample * (original_sr / sample_rate_vad))
|
||||||
|
end_idx_orig = int(end_sample * (original_sr / sample_rate_vad))
|
||||||
|
else:
|
||||||
|
start_idx_orig = start_sample
|
||||||
|
end_idx_orig = end_sample
|
||||||
|
|
||||||
|
segment_waveform = original_waveform[:, start_idx_orig:end_idx_orig]
|
||||||
|
|
||||||
|
# --- 5. 对该片段进行 ASR 转录 ---
|
||||||
|
# 这里调用您原有的 transcribe 方法
|
||||||
|
try:
|
||||||
|
# 假设 transcribe 方法接受一个 waveform tensor
|
||||||
|
segment_text = inference_module.transcribe(waveform=segment_waveform)
|
||||||
|
|
||||||
|
print(f" -> 转录结果: {segment_text}")
|
||||||
|
full_text += f"[{start_time:.2f}-{end_time:.2f}s] {segment_text}\n"
|
||||||
|
except Exception as e:
|
||||||
|
print(f" -> 转录第 {i+1} 个片段时出错: {e}")
|
||||||
|
|
||||||
|
print("\n--- 完整转录结果 ---")
|
||||||
|
print(full_text)
|
||||||
|
|
||||||
|
|
||||||
|
# --- 使用示例 ---
|
||||||
|
# 假设您有一个名为 'inference' 的模块对象,它有 _load_audio 和 transcribe 方法
|
||||||
|
# audio_path = "your_audio_file.wav"
|
||||||
|
# process_audio_with_vad_and_asr(audio_path, inference)
|
||||||
160
src/handle/audio_agement.ipynb
Normal file
160
src/handle/audio_agement.ipynb
Normal file
File diff suppressed because one or more lines are too long
70
src/handle/audio_analyze.py
Normal file
70
src/handle/audio_analyze.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from tqdm import tqdm
|
||||||
|
from mutagen.mp3 import MP3
|
||||||
|
from mutagen import MutagenError
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
workspace_dir = Path(__file__).parent.parent.parent
|
||||||
|
|
||||||
|
def format_duration(seconds):
|
||||||
|
"""将秒数格式化为 时:分:秒 的格式"""
|
||||||
|
hours = int(seconds // 3600)
|
||||||
|
minutes = int((seconds % 3600) // 60)
|
||||||
|
secs = int(seconds % 60)
|
||||||
|
|
||||||
|
if hours > 0:
|
||||||
|
return f"{hours}小时{minutes}分{secs}秒"
|
||||||
|
elif minutes > 0:
|
||||||
|
return f"{minutes}分{secs}秒"
|
||||||
|
else:
|
||||||
|
return f"{secs}秒"
|
||||||
|
|
||||||
|
def get_mp3_duration(file_path: Path):
|
||||||
|
"""获取单个MP3文件的时长(秒)"""
|
||||||
|
try:
|
||||||
|
audio = MP3(file_path)
|
||||||
|
return audio.info.length
|
||||||
|
except MutagenError as e:
|
||||||
|
print(f"错误:无法读取文件 {file_path} - {e}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
print(f"错误:处理文件 {file_path} 时出错 - {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def analyze_mp3_files(directory: Path, ):
|
||||||
|
"""分析目录中的所有MP3文件"""
|
||||||
|
results = []
|
||||||
|
total_duration = 0
|
||||||
|
df = pd.read_csv(workspace_dir / '.data/ug/train_new.tsv', sep='\t')
|
||||||
|
|
||||||
|
for _, row in tqdm(df.iterrows(), total=len(df), desc="anlayze audio"):
|
||||||
|
mp3_file: Path = workspace_dir / '.data/ug/clips' / row['path']
|
||||||
|
duration = get_mp3_duration(mp3_file)
|
||||||
|
|
||||||
|
if duration is not None:
|
||||||
|
total_duration += duration
|
||||||
|
duration_str = format_duration(duration)
|
||||||
|
file_size = mp3_file.stat().st_size / (1024 * 1024) # 转换为MB
|
||||||
|
|
||||||
|
result_line = f"{mp3_file.name:<50} {duration_str:>20} ({file_size:.2f} MB)"
|
||||||
|
results.append(result_line)
|
||||||
|
else:
|
||||||
|
error_line = f"{mp3_file.name:<50} {'读取失败':>20}"
|
||||||
|
print(error_line)
|
||||||
|
results.append(error_line)
|
||||||
|
|
||||||
|
mp3_files: list[Path] = []
|
||||||
|
for ext in ['*.mp3', '*.MP3']:
|
||||||
|
mp3_files.extend(directory.rglob(ext))
|
||||||
|
|
||||||
|
print(f"tsv 找到 {len(results)} 个MP3文件\n")
|
||||||
|
print(f"找到 {len(mp3_files)} 个MP3文件\n")
|
||||||
|
print(f"\n总时长: {format_duration(total_duration)}")
|
||||||
|
print(f"总时长(秒): {total_duration:.2f} 秒")
|
||||||
|
print(f"总时长(分钟): {total_duration/60:.2f} 分钟")
|
||||||
|
print(f"总时长(小时): {total_duration/3600:.2f} 小时")
|
||||||
|
print(f"文件总数: {len(mp3_files)}")
|
||||||
|
|
||||||
|
|
||||||
|
audio_directory = Path(workspace_dir / '.data/ug/clips')
|
||||||
|
analyze_mp3_files(directory=audio_directory)
|
||||||
13
src/handle/export_model_state_dict.py
Normal file
13
src/handle/export_model_state_dict.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
device = "cuda:0"
|
||||||
|
workspace_dir = Path(__file__).parent.parent.parent
|
||||||
|
|
||||||
|
input_checkpoint = workspace_dir.joinpath('.checkpoints/best_wer_model.pt')
|
||||||
|
output_checkpoint = workspace_dir.joinpath('.checkpoints/prodect_best_wer_model.pt')
|
||||||
|
|
||||||
|
checkpoint = torch.load(input_checkpoint, map_location=device)
|
||||||
|
torch.save(checkpoint['model_state_dict'], output_checkpoint)
|
||||||
37
src/handle/export_new_tsv.py
Normal file
37
src/handle/export_new_tsv.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
workspace_dir = Path(__file__).parent.parent.parent
|
||||||
|
|
||||||
|
# 读取你合并后的总表 (253,430条)
|
||||||
|
df = pd.read_csv(workspace_dir / ".data/ug/validated.tsv", sep='\t')
|
||||||
|
|
||||||
|
# 1. 获取所有唯一的说话人
|
||||||
|
all_speakers = df['client_id'].unique()
|
||||||
|
|
||||||
|
# 2. 随机打乱说话人顺序
|
||||||
|
import random
|
||||||
|
random.seed(42) # 固定随机种子,保证实验可重复
|
||||||
|
random.shuffle(all_speakers)
|
||||||
|
|
||||||
|
# 3. 挑选验证集说话人,直到录音总数达到 ~8000 条
|
||||||
|
val_indices = []
|
||||||
|
val_count = 0
|
||||||
|
target_val_size = 8000
|
||||||
|
|
||||||
|
for speaker in all_speakers:
|
||||||
|
speaker_data = df[df['client_id'] == speaker]
|
||||||
|
val_indices.extend(speaker_data.index.tolist())
|
||||||
|
val_count += len(speaker_data)
|
||||||
|
if val_count >= target_val_size:
|
||||||
|
break
|
||||||
|
|
||||||
|
# 4. 划分文件
|
||||||
|
df_val = df.loc[val_indices]
|
||||||
|
df_train = df.drop(val_indices)
|
||||||
|
|
||||||
|
print(f"训练集条数: {len(df_train)}")
|
||||||
|
print(f"验证集条数: {len(df_val)}")
|
||||||
|
|
||||||
|
df_train.to_csv(workspace_dir / ".data/ug/train_new.tsv", sep='\t', index=False)
|
||||||
|
df_val.to_csv(workspace_dir / ".data/ug/val_new.tsv", sep='\t', index=False)
|
||||||
92
src/handle/export_vocab.py
Normal file
92
src/handle/export_vocab.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Counter
|
||||||
|
import pandas as pd
|
||||||
|
from tqdm import tqdm
|
||||||
|
from text_handle import process_text
|
||||||
|
|
||||||
|
workspace_dir = Path(__file__).parent.parent.parent
|
||||||
|
|
||||||
|
data_dir = Path("/home/blacksheep/projekts/study")
|
||||||
|
input_files = [
|
||||||
|
data_dir / "data/multilingual/hanziDB_translated-simplified.jsonl",
|
||||||
|
data_dir / "data/multilingual/hanziDB_translated_validationset-simplified.jsonl",
|
||||||
|
data_dir / "data/multilingual/tatoeba-tr-en-ug-uz-kz-zh-simplified.jsonl",
|
||||||
|
data_dir / "data/multilingual/export_csv_database.jsonl",
|
||||||
|
data_dir / "data/multilingual/tatoeba_sentences.jsonl",
|
||||||
|
]
|
||||||
|
output_file = workspace_dir / "config/asr_vocab_1.json"
|
||||||
|
syllabizes = []
|
||||||
|
|
||||||
|
# langs = ["uig_Arab"]
|
||||||
|
# for input_file in input_files:
|
||||||
|
# with open(input_file, 'r', encoding='utf-8') as f:
|
||||||
|
# total_lines = sum(1 for _ in f)
|
||||||
|
# f.seek(0)
|
||||||
|
# for line in tqdm(f, total=total_lines, desc=f"Processing lines {input_file}"):
|
||||||
|
# data: dict[str, str] = json.loads(line)
|
||||||
|
# for lang in langs:
|
||||||
|
# if lang in data:
|
||||||
|
# syllabizes.extend(export_syllabize(data[lang]))
|
||||||
|
|
||||||
|
# print(f'data_dir syllabize len: {len(syllabizes):,}, set: {len(set(syllabizes)):,}')
|
||||||
|
|
||||||
|
|
||||||
|
tsv_files = [
|
||||||
|
# workspace_dir / "data/ug/test.tsv",
|
||||||
|
# workspace_dir / "data/ug/invalidated.tsv",
|
||||||
|
# workspace_dir / "data/ug/train.tsv",
|
||||||
|
# workspace_dir / "data/ug/validated.tsv",
|
||||||
|
# workspace_dir / "data/ug/reported.tsv",
|
||||||
|
# workspace_dir / "data/ug/dev.tsv",
|
||||||
|
# workspace_dir / "data/ug/other.tsv",
|
||||||
|
# workspace_dir / ".data/ug/invalidated.tsv",
|
||||||
|
workspace_dir / ".data/ug/train.tsv",
|
||||||
|
# workspace_dir / ".data/ug/clip_durations.tsv", # not sentence
|
||||||
|
# workspace_dir / ".data/ug/test.tsv",
|
||||||
|
# workspace_dir / ".data/ug/validated_sentences.tsv",
|
||||||
|
# workspace_dir / ".data/ug/other.tsv",
|
||||||
|
# workspace_dir / ".data/ug/validated.tsv",
|
||||||
|
# workspace_dir / ".data/ug/dev.tsv",
|
||||||
|
# workspace_dir / ".data/ug/unvalidated_sentences.tsv",
|
||||||
|
# workspace_dir / ".data/ug/reported.tsv" # Lacking sentence
|
||||||
|
]
|
||||||
|
|
||||||
|
for tsv_file in tsv_files:
|
||||||
|
data = pd.read_csv(tsv_file, sep='\t')
|
||||||
|
# 带进度条处理每行数据
|
||||||
|
for index, row in tqdm(data.iterrows(), total=len(data), desc=f"Processing {tsv_file}"):
|
||||||
|
syllabizes.extend(process_text(row['sentence'].strip()))
|
||||||
|
|
||||||
|
|
||||||
|
# 统计所有音节出现次数
|
||||||
|
syllable_counter = Counter(syllabizes)
|
||||||
|
# 过滤出出现100次以上的音节
|
||||||
|
freq_100_plus = {k: v for k, v in syllable_counter.items() if v >= 150}
|
||||||
|
freq_100_minus = {k: v for k, v in syllable_counter.items() if v <= 100}
|
||||||
|
|
||||||
|
# 保存100次以上的音节列表(只有音节,排序)
|
||||||
|
vocab = sorted(list(freq_100_plus.keys()), key=len)
|
||||||
|
with open(output_file, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(vocab, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
# # # 保存100次以上的音节及次数
|
||||||
|
sorted_freq_100_plus = dict(sorted(freq_100_plus.items(), key=lambda x: x[1], reverse=True))
|
||||||
|
with open(workspace_dir / 'config/syllables_freq_100_plus.json', 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(sorted_freq_100_plus, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
# # 统计信息
|
||||||
|
print(f"总音节数: {len(syllabizes):,}")
|
||||||
|
print(f"唯一音节数: {len(syllable_counter):,}")
|
||||||
|
print(f"出现100次以上的音节数: {len(freq_100_plus):,}")
|
||||||
|
print(f"出现100次以下的音节数: {len(freq_100_minus):,}")
|
||||||
|
|
||||||
|
# 区间统计(不累积)
|
||||||
|
print("\n=== 音节使用次数统计(区间) ===")
|
||||||
|
for low in range(0, 100, 10):
|
||||||
|
high = low + 9
|
||||||
|
count = sum(1 for freq in syllable_counter.values() if low <= freq <= high)
|
||||||
|
print(f"出现 {low}-{high} 次: {count:,} 个音节")
|
||||||
|
# 100次以上
|
||||||
|
count_100plus = sum(1 for freq in syllable_counter.values() if freq >= 100)
|
||||||
|
print(f"出现 100+ 次: {count_100plus} 个音节")
|
||||||
490
src/handle/test.ipynb
Normal file
490
src/handle/test.ipynb
Normal file
File diff suppressed because one or more lines are too long
220
src/handle/test.py
Normal file
220
src/handle/test.py
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
import os
|
||||||
|
import random
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import torchaudio
|
||||||
|
from pathlib import Path
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
workspace_dir = Path(__file__).parent.parent.parent
|
||||||
|
sys.path.append(str(workspace_dir.joinpath('src')))
|
||||||
|
# sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
from handle.text_handle import process_text
|
||||||
|
from tokenizer import ASRTokenizer
|
||||||
|
|
||||||
|
# 读取 TSV
|
||||||
|
# df = pd.read_csv(workspace_dir / '.data/ug/train.tsv', sep='\t')
|
||||||
|
|
||||||
|
# # 计算每个音频的时长(秒)
|
||||||
|
# durations = []
|
||||||
|
# for audio_path in tqdm(workspace_dir / '.data/ug/clips' / df['path'], desc="计算音频时长"):
|
||||||
|
# try:
|
||||||
|
# # 获取音频信息(不加载整个文件,速度快)
|
||||||
|
# info = torchaudio.info(audio_path)
|
||||||
|
# duration = info.num_frames / info.sample_rate
|
||||||
|
# durations.append(duration)
|
||||||
|
# except Exception as e:
|
||||||
|
# print(f"读取失败: {audio_path}, 错误: {e}")
|
||||||
|
# durations.append(None)
|
||||||
|
|
||||||
|
# # 添加到 DataFrame
|
||||||
|
# df['duration'] = durations
|
||||||
|
|
||||||
|
# # 统计
|
||||||
|
# print(df['duration'].describe())
|
||||||
|
# print(f"超过20秒的样本: {(df['duration'] > 20).sum()}")
|
||||||
|
|
||||||
|
# 保存结果(可选)
|
||||||
|
# df.to_csv('data/audio/train_with_duration.tsv', sep='\t', index=False)
|
||||||
|
|
||||||
|
# audio_path = workspace_dir / '.data/ug/clips/common_voice_ug_40252722.mp3'
|
||||||
|
# info = torchaudio.info(audio_path)
|
||||||
|
# print("info:", info)
|
||||||
|
import pandas as pd
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
import numpy as np
|
||||||
|
from pathlib import Path
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
def analyze_audio_quality(audio_path, sample_rate_target=16000):
|
||||||
|
"""
|
||||||
|
评估音频质量,返回多个指标
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 加载音频
|
||||||
|
waveform, sr = torchaudio.load(str(audio_path))
|
||||||
|
|
||||||
|
# 重采样 + 单声道
|
||||||
|
if sr != sample_rate_target:
|
||||||
|
resampler = torchaudio.transforms.Resample(sr, sample_rate_target)
|
||||||
|
waveform = resampler(waveform)
|
||||||
|
if waveform.shape[0] > 1:
|
||||||
|
waveform = waveform.mean(dim=0, keepdim=True)
|
||||||
|
|
||||||
|
# 归一化
|
||||||
|
max_val = waveform.abs().max()
|
||||||
|
if max_val > 0:
|
||||||
|
waveform = waveform / max_val
|
||||||
|
|
||||||
|
# ===== 指标 1: 静音比例 =====
|
||||||
|
silence_threshold = 0.01
|
||||||
|
is_silence = waveform.abs() < silence_threshold
|
||||||
|
silence_ratio = is_silence.float().mean().item()
|
||||||
|
|
||||||
|
# ===== 指标 2: 动态范围 =====
|
||||||
|
frame_len = int(0.025 * sample_rate_target) # 25ms
|
||||||
|
hop_len = int(0.010 * sample_rate_target) # 10ms
|
||||||
|
|
||||||
|
num_frames = (waveform.shape[1] - frame_len) // hop_len + 1
|
||||||
|
if num_frames <= 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
energies = []
|
||||||
|
for i in range(num_frames):
|
||||||
|
start = i * hop_len
|
||||||
|
frame = waveform[:, start:start + frame_len]
|
||||||
|
rms = frame.pow(2).mean().sqrt().item()
|
||||||
|
energies.append(rms)
|
||||||
|
|
||||||
|
energies = np.array(energies)
|
||||||
|
energies = np.clip(energies, 1e-10, None)
|
||||||
|
|
||||||
|
log_energies = 20 * np.log10(energies)
|
||||||
|
dynamic_range = log_energies.max() - log_energies.min()
|
||||||
|
|
||||||
|
# ===== 指标 3: 语音活跃度 =====
|
||||||
|
energy_threshold = np.percentile(energies, 30)
|
||||||
|
voice_ratio = (energies > energy_threshold).mean()
|
||||||
|
|
||||||
|
# ===== 指标 4: 频谱中心频率 =====
|
||||||
|
spec_transform = torchaudio.transforms.Spectrogram(n_fft=512)
|
||||||
|
spec = spec_transform(waveform) # [1, freq, time]
|
||||||
|
|
||||||
|
# 用 torch 计算,避免 numpy axis 问题
|
||||||
|
freqs = torch.fft.rfftfreq(512, d=1.0/sample_rate_target)
|
||||||
|
spec_mean = spec.mean(dim=-1).squeeze() # [freq_bins] torch 张量
|
||||||
|
|
||||||
|
# 确保类型一致,用 torch 计算
|
||||||
|
spectral_centroid = (freqs * spec_mean).sum() / (spec_mean.sum() + 1e-10)
|
||||||
|
spectral_centroid = spectral_centroid.item()
|
||||||
|
|
||||||
|
# ===== 指标 5: 频谱通量 =====
|
||||||
|
spec_np = spec.squeeze().numpy() # [freq, time]
|
||||||
|
flux = np.abs(np.diff(spec_np, axis=1)).mean()
|
||||||
|
|
||||||
|
return {
|
||||||
|
'duration': waveform.shape[1] / sample_rate_target,
|
||||||
|
'silence_ratio': silence_ratio,
|
||||||
|
'dynamic_range_db': float(dynamic_range),
|
||||||
|
'voice_ratio': float(voice_ratio),
|
||||||
|
'spectral_centroid': spectral_centroid,
|
||||||
|
'spectral_flux': float(flux),
|
||||||
|
'energy_std': float(energies.std()),
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"失败: {audio_path}, {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ===== 批量处理 =====
|
||||||
|
# df = pd.read_csv(workspace_dir / '.data/ug/train.tsv', sep='\t')
|
||||||
|
|
||||||
|
# results = []
|
||||||
|
# for _, row in tqdm(df.iterrows(), total=len(df), desc="评估音频质量"):
|
||||||
|
# audio_path = workspace_dir / '.data/ug/clips' / row['path']
|
||||||
|
# metrics = analyze_audio_quality(audio_path)
|
||||||
|
# if metrics:
|
||||||
|
# metrics['path'] = row['path']
|
||||||
|
# results.append(metrics)
|
||||||
|
|
||||||
|
# # 汇总
|
||||||
|
# quality_df = pd.DataFrame(results)
|
||||||
|
# print(quality_df.describe())
|
||||||
|
|
||||||
|
|
||||||
|
# # ===== 质量评分 =====
|
||||||
|
# def score_quality(row):
|
||||||
|
# score = 100
|
||||||
|
|
||||||
|
# if row['silence_ratio'] > 0.5:
|
||||||
|
# score -= 30
|
||||||
|
# elif row['silence_ratio'] > 0.3:
|
||||||
|
# score -= 15
|
||||||
|
|
||||||
|
# if row['dynamic_range_db'] < 10:
|
||||||
|
# score -= 25
|
||||||
|
# elif row['dynamic_range_db'] < 20:
|
||||||
|
# score -= 10
|
||||||
|
|
||||||
|
# if row['voice_ratio'] < 0.3:
|
||||||
|
# score -= 20
|
||||||
|
|
||||||
|
# if row['spectral_centroid'] > 4000:
|
||||||
|
# score -= 15
|
||||||
|
|
||||||
|
# return max(0, score)
|
||||||
|
|
||||||
|
# quality_df['quality_score'] = quality_df.apply(score_quality, axis=1)
|
||||||
|
|
||||||
|
# quality_df['grade'] = pd.cut(
|
||||||
|
# quality_df['quality_score'],
|
||||||
|
# bins=[0, 40, 60, 80, 100],
|
||||||
|
# labels=['差(弃用)', '较差', '一般', '良好']
|
||||||
|
# )
|
||||||
|
|
||||||
|
# print("\n质量分级统计:")
|
||||||
|
# print(quality_df['grade'].value_counts())
|
||||||
|
|
||||||
|
# # 保存结果
|
||||||
|
# quality_df.to_csv('audio_quality_analysis.csv', index=False)
|
||||||
|
# print("\n结果已保存到 audio_quality_analysis.csv")
|
||||||
|
|
||||||
|
|
||||||
|
# Test
|
||||||
|
# text = "مەن مەكتەپكە باردىم"
|
||||||
|
# text = "ئاۋۋال كومپىيوتىرنى بىر كۆرسەم شۇنىڭغا قاراپ ماسلاشتۇرسام بوللاتتى"
|
||||||
|
text = "قاپاقنى پۇلغا ئالماي، باراڭنى پۇلغا ئاپتۇ"
|
||||||
|
# text = "ئامېىقى جاھان گۈرلىكىە قارشتۇرۇپ چوكشەڭگە ياردەم بېرىش ئۇرۇشىنىڭ ئلۇغ غەلبىسى جۇڭگو خەلقى ئورندىن دەستۇرغاندىكىن، دۇنيانىڭ شەرقىدە قەت كۆتۈرگەنلكىنى خىتاب لامىسى."
|
||||||
|
# text = "ئاۋسترالىيە"
|
||||||
|
|
||||||
|
text = " ھەر قانداق ئىشتا كىشىلەر بىلەن كېڭىشىش كېرەك. لېكىن كۆڭۈل تارتقان ئىشنى قىلىۋېرىش كېرەك."
|
||||||
|
text = "\" 15 يىل بۇرۇن مەن بىلەن سەي چۇڭشىن (ئالى بابانىڭ قۇرغۇچىسىدىن بىرى) ئامېرىكىغا بېرىپ 30 نەچچە مەبلەغ سالغۇچى شىركەتلەر بىلەن كۆرۈشكىنىمىزدە ئۇلار بىزنى رەت قىلىپ ئىشىك سىرتىدا قالدۇرغاندى. ھېچكىم بىزگە ۋە كەلگۈسىمىزگە ئىشەنمىگەن ئىدى، ھېچكىممۇ بىزنىڭ ھازىر 300 مىليارد سودىنى تاماملىيالايدىغانلىقىمىزنى ئويلاپ خىيالىغىمۇ كەلتۈرمىگەن ئىدى.\" ئالى بابانىنىڭ ئورگىنى تەرىپىدىن ئىشلەنگەن ئىگىلىك تىكلەش ھۆججەتلىك فىلىمى \" بۇ چۈش ئەمەس\" نىڭ باش قىسمىدا مۇنداق بىر داڭلىق سۆز چىقىدۇ:\" مەن تاغدىن ئۆتكەن ۋاقتىمدا تاغ ماڭا گەپ قىلمىدى، مەن دېڭىزنى كېچىپ ئۆتكىنىمدە دېڭىز ماڭا گەپ قىلمىدى.\" بۇ فىلىمنىڭ قوشۇمچە ئىسمى بولسا \" مايۈن ۋە ئۇنىڭ مەڭگۈلۈك ' ياش ئالى'سى\" بولۇپ، ھەرخىل خەتەر ۋە قىيىنلىقنى بىرمۇ بىر يەڭگەن مايۈن ۋە ئۇنىڭ ئالى بابا قوشۇنىدىكى ھەمكارلاشقۇچىلىرىنىڭ قەيسەر كەچمىشى جانلىق بايان قىلىنغان."
|
||||||
|
# text = "يۆ جيەنتاۋ يېقىنقى بىر مەزگىلدە، خىزمەتداشلىرى بىلەن نۇرغۇن قىيىنچىلىققا ئۇچرىغان شوپۇرغا ياردەم قىلغانلىقىنى، شۇنداقلا يېمەكلىك ۋە سۇ يەتكۈزۈپ بەرگەنلىكىنى، ئەمما ئاياغ سوۋغا قىلىشى تۇنجى قېتىم ئىكەنلىكىنى، بۈگۈنكىسى 20 يىللىق ساقچىلىق جەريانىدا تۇنجى قېتىم شوپۇرغا ئاياغ سوۋغا قىلىشى ئىكەنلىكىنى ئېيتتى."
|
||||||
|
text = "يۆ جۇڭمىڭ مۇنداق دېدى: 2019- يىلى 12-ئايدا، مەملىكەتلىك خەلق قۇرۇلتىيى دائىمىي كومىتېتىنىڭ 44- قېتىملىق كومىتېت باشلىقلىرى يىغىنى مەملىكەتلىك خەلق قۇرۇلتىيى دائىمىي كومىتېتىنىڭ 2020-يىللىق قانۇن چىقىرىش خىزمىتى پىلانىنى پىرىنسىپ جەھەتتىن ماقۇللىدى، خىزمەت تەرتىپى بويىچە، 13-نۆۋەتلىك مەملىكەتلىك خەلق قۇرۇلتىيى 3-يىغىنىنىڭ روھى ۋە ۋەكىللەرنىڭ تەكلىپ-تەۋسىيەلىرىگە ئاساسەن پىلاننى تەڭشەش كېرەك. بۇ يىل 6-ئاينىڭ 1-كۈنى، 58-قېتىملىق كومىتېت باشلىقلىرى يىغىنى تەڭشەلگەندىن كېيىنكى يىللىق قانۇن چىقىرىش خىزمىتى پىلانىنى قاراپ چىقىپ ماقۇللىدى."
|
||||||
|
text = " ئاشقازان-ئۈچەينىڭ لۆمۈلدىشىنى ئىلگىرى سۈرىدىغان دورىنى تاماقتىن بۇرۇن ئىستېمال قىلىش كېرەك."
|
||||||
|
text = "مەن مەكتەپكە باردىم"
|
||||||
|
# text = "غەرىپئەللىرى"
|
||||||
|
# text = "ئىكەنلىكىنى، بۈگۈنكىسى 20 يىللىق ساقچىلىق جەريانىدا تۇنجى"
|
||||||
|
# text = " يېزىدىكى كەڭ، ئازادە ئۆي، باغلىرىنى تاشلاپ، خەقنىڭ ھويلىسىدا قورۇنۇپ-ئەيمىنىپ يەر دەسسەپ يۈردى. "
|
||||||
|
# text = "بۇ يىللىق ئەمگەكچىلەر بايرىمىدا، 1-مايدىن 5-مايغىچە جەمئىي بەش كۈن دەم ئېلىشقا قويۇپ بېرىلىدۇ. 9-ماي (شەنبە) نورمال خىزمەت قىلىنىدۇ. شۇنىڭ بىلەن بىر ۋاقىتتا، ئۈرۈمچى شەھىرى تىيانشان رايونى، سايباغ رايونى، داۋانچىڭ رايونى، ئۈرۈمچى ناھىيەسىدىكى ئوتتۇرا، باشلانغۇچ مەكتەپلەر 29-ئاپرېلدىن 30-ئاپرېلغىچە ئەتىيازلىق دەم ئېلىشقا قويۇپ بېرىدۇ؛ سانجى ئوبلاستىدىكى ئوتتۇرا، باشلانغۇچ مەكتەپلەرنىڭ ئەتىيازلىق دەم ئېلىشى ئۈچ كۈن بولۇپ، بۇنىڭ ئىچىدە ئىككى كۈن 29-، 30-ئاپرېلغا ئورۇنلاشتۇرۇلىدۇ، قالغان بىر كۈن 1-ئىيۇنغا ئورۇنلاشتۇرۇلىدۇ، قۇربان ھېيتلىق دەم ئېلىش بىلەن بىرلەشتۈرۈلۈپ، ئۇدا ئالتە كۈن دەم ئېلىنىدۇ."
|
||||||
|
# text = "ئەڭ يېڭى دەم ئېلىشقا قويۇپ بېرىش ئۇقتۇرۇشى! تەقەززالىق بىلەن كۈتكەن يەنە بىر دەم ئېلىش كېلەي دەپ قالدى! گوۋۇيۈەن بەنگۇڭتىڭىنىڭ «2025-يىللىق قىسمەن بايرام، دەم ئېلىش كۈنلىرىنى ئورۇنلاشتۇرۇش توغرىسىدىكى ئۇقتۇرۇشى»غا ئاساسەن، دۈەنۋۇ بايرىمىلىق دەم ئېلىشقا قويۇپ بېرىش ئورۇنلاشتۇرۇلۇشى تۆۋەندىكىچە: 5-ئاينىڭ 31-كۈنى (شەنبە)دىن 6-ئاينىڭ 2-كۈنىگىچە جەمئىي ئۈچ كۈن دەم ئېلىشقا قويۇپ بېرىلىدۇ! بۇ قېتىمقى دۈەنۋۇ بايرىمىلىق دەم ئېلىش ۋاقتى تەڭشەلمەيدۇ!"
|
||||||
|
# text = "مەن شۇ چاغدا چۈشۈمدىمۇ كۆرۈپ باقمىغان مۇنچىۋالا جىق قېرىنداشلىرىم، جىگەرلىرىمنىڭ بارلىقىدىن سۆيۈندۈم."
|
||||||
|
# text = "ئامېرىكا جاھان گىرلىكىگە قارشى تۇرۇپ، چاۋشەنگە ياردەم بىرىش ئۇرۇشىنىڭ ئۇلۇغ غەلبىسى، جۇڭگو خەلقى ئورۇندىن دەس تۇرغاندىنكىن دۇنيانىڭ شەرقىدە قەد كۆتۈرگەنلىكىنىڭ خىتاپنامىس"
|
||||||
|
|
||||||
|
|
||||||
|
# result = export_syllabize(text)
|
||||||
|
result = process_text(text)
|
||||||
|
print(f"Original: {text}")
|
||||||
|
print(f"Syllables: {result}")
|
||||||
|
|
||||||
|
|
||||||
|
# tokenizer = ASRTokenizer(vocab_path=workspace_dir / 'config/asr_vocab.json')
|
||||||
|
|
||||||
|
# print(text)
|
||||||
|
# ids = tokenizer.encode(text=text)
|
||||||
|
# # print(ids)
|
||||||
|
# print('|'.join(tokenizer.decode([id]) for id in ids))
|
||||||
195
src/handle/text_handle.py
Normal file
195
src/handle/text_handle.py
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
|
||||||
|
from .text_normalizer import UYGHUR_LETTERS, clean_input_text
|
||||||
|
|
||||||
|
# def uighur_syllabize(text: str):
|
||||||
|
# # Uyghur Vowels
|
||||||
|
# vowels = "ئاەئەوۆئۇئۈئىئېاەوۆۇۈىې"
|
||||||
|
|
||||||
|
# def split_word(word):
|
||||||
|
# syllables = []
|
||||||
|
# current_syllable = ""
|
||||||
|
# i = 0
|
||||||
|
|
||||||
|
# # Helper to check if a character is a vowel
|
||||||
|
# is_vowel = lambda char: char in vowels
|
||||||
|
|
||||||
|
# while i < len(word):
|
||||||
|
# current_syllable += word[i]
|
||||||
|
|
||||||
|
# # If we find a vowel, look ahead to decide where to split
|
||||||
|
# if is_vowel(word[i]):
|
||||||
|
# # Check next characters
|
||||||
|
# remaining = word[i+1:]
|
||||||
|
|
||||||
|
# # Rule 1: V-V (e.g., 'ائ') -> Split after first vowel
|
||||||
|
# if len(remaining) >= 1 and is_vowel(remaining[0]):
|
||||||
|
# syllables.append(current_syllable)
|
||||||
|
# current_syllable = ""
|
||||||
|
|
||||||
|
# # Rule 2: V-C-V (e.g., 'ba-ra') -> Split after first vowel
|
||||||
|
# elif len(remaining) >= 2 and not is_vowel(remaining[0]) and is_vowel(remaining[1]):
|
||||||
|
# syllables.append(current_syllable)
|
||||||
|
# current_syllable = ""
|
||||||
|
|
||||||
|
# # Rule 3: V-C-C-V (e.g., 'mek-tep') -> Split after first consonant
|
||||||
|
# elif len(remaining) >= 3 and not is_vowel(remaining[0]) and not is_vowel(remaining[1]) and is_vowel(remaining[2]):
|
||||||
|
# current_syllable += remaining[0]
|
||||||
|
# syllables.append(current_syllable)
|
||||||
|
# current_syllable = ""
|
||||||
|
# i += 1 # Skip the consonant we just added
|
||||||
|
|
||||||
|
# # Rule 4: End of word or C-C-C clusters (rare in native words)
|
||||||
|
# # We keep going until we find a clear boundary or end of word
|
||||||
|
# i += 1
|
||||||
|
|
||||||
|
# if current_syllable:
|
||||||
|
# syllables.append(current_syllable)
|
||||||
|
# return syllables
|
||||||
|
|
||||||
|
# # Clean text and process word by word
|
||||||
|
# words = text.split()
|
||||||
|
# all_syllables = []
|
||||||
|
# for w in words:
|
||||||
|
# all_syllables.extend(split_word(w))
|
||||||
|
|
||||||
|
# #merge ئ to next syllable. single ئ no any semantic meaning and no it's dedicated sound.
|
||||||
|
# original_ayllables = [*all_syllables]
|
||||||
|
# all_syllables.clear()
|
||||||
|
# handled = True
|
||||||
|
# for s in original_ayllables:
|
||||||
|
# if s == "ئ":
|
||||||
|
# handled = False
|
||||||
|
# continue
|
||||||
|
# if not handled:
|
||||||
|
# s = "ئ" + s
|
||||||
|
# handled = True
|
||||||
|
# all_syllables.append(s)
|
||||||
|
|
||||||
|
# return all_syllables
|
||||||
|
|
||||||
|
# test_words = ['تەكلىپتەۋسىيەلىرىگە', 'تەكلىپ-تەۋسىيەلىرىگە']
|
||||||
|
# for tw in test_words:
|
||||||
|
# print(f"{tw} -> {uighur_syllabize(tw.strip())}")
|
||||||
|
# تەكلىپتەۋسىيەلىرىگە -> ['تەك', 'لىپ', 'تەۋ', 'سى', 'يە', 'لى', 'رى', 'گە']
|
||||||
|
# تەكلىپ-تەۋسىيەلىرىگە -> ['تەك', 'لىپ-تەۋ', 'سى', 'يە', 'لى', 'رى', 'گە']
|
||||||
|
|
||||||
|
|
||||||
|
def uighur_syllabize(text: str):
|
||||||
|
# Uyghur Vowels (ASU)
|
||||||
|
vowels = "ئاەئەوۆئۇئۈئىئېاەوۆۇۈىې"
|
||||||
|
|
||||||
|
def is_vowel(char):
|
||||||
|
return char in vowels
|
||||||
|
|
||||||
|
def split_word(word):
|
||||||
|
syllables = []
|
||||||
|
i = 0
|
||||||
|
last_split = 0
|
||||||
|
|
||||||
|
while i < len(word):
|
||||||
|
# Look for vowel patterns to determine boundaries
|
||||||
|
if is_vowel(word[i]):
|
||||||
|
rem = word[i+1:]
|
||||||
|
|
||||||
|
# Rule: V-V -> Split after first vowel
|
||||||
|
if len(rem) >= 1 and is_vowel(rem[0]):
|
||||||
|
syllables.append(word[last_split:i+1])
|
||||||
|
last_split = i + 1
|
||||||
|
|
||||||
|
# Rule: V-C-V -> Split after vowel (e.g., ba-ra)
|
||||||
|
elif len(rem) >= 2 and not is_vowel(rem[0]) and is_vowel(rem[1]):
|
||||||
|
syllables.append(word[last_split:i+1])
|
||||||
|
last_split = i + 1
|
||||||
|
|
||||||
|
# Rule: V-C-C-V -> Split between consonants (e.g., mek-tep)
|
||||||
|
elif len(rem) >= 3 and not is_vowel(rem[0]) and not is_vowel(rem[1]) and is_vowel(rem[2]):
|
||||||
|
syllables.append(word[last_split:i+2])
|
||||||
|
last_split = i + 2
|
||||||
|
i += 1 # Skip first C
|
||||||
|
|
||||||
|
# Rule: V-C-C-C-V (VCCCV) -> Split after second consonant (e.g., gert-mek)
|
||||||
|
# This fixes "gertmek", "eytqan", "partlap", "dostlar"
|
||||||
|
elif len(rem) >= 4 and not is_vowel(rem[0]) and not is_vowel(rem[1]) and not is_vowel(rem[2]) and is_vowel(rem[3]):
|
||||||
|
syllables.append(word[last_split:i+3])
|
||||||
|
last_split = i + 3
|
||||||
|
i += 2 # Skip first two Cs
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
# Add the remaining part of the word
|
||||||
|
if last_split < len(word):
|
||||||
|
syllables.append(word[last_split:])
|
||||||
|
|
||||||
|
return syllables
|
||||||
|
|
||||||
|
# Process word by word
|
||||||
|
words = text.split()
|
||||||
|
final_output = []
|
||||||
|
|
||||||
|
for w in words:
|
||||||
|
raw_syllables = split_word(w)
|
||||||
|
|
||||||
|
# Handle Hemze (ئ) re-merging
|
||||||
|
processed = []
|
||||||
|
skip_next = False
|
||||||
|
for j in range(len(raw_syllables)):
|
||||||
|
if skip_next:
|
||||||
|
skip_next = False
|
||||||
|
continue
|
||||||
|
|
||||||
|
s = raw_syllables[j]
|
||||||
|
# If the current syllable is JUST "ئ" or ends in "ئ", merge it with next
|
||||||
|
if s == "ئ" and j + 1 < len(raw_syllables):
|
||||||
|
processed.append("ئ" + raw_syllables[j+1])
|
||||||
|
skip_next = True
|
||||||
|
else:
|
||||||
|
processed.append(s)
|
||||||
|
final_output.extend(processed)
|
||||||
|
|
||||||
|
return final_output
|
||||||
|
|
||||||
|
# Verification
|
||||||
|
# test_words = ["گەرتمەك", "ئېيتقان", "دوستلار", "مەكتەپكە", ' كومپېيۇتېر تورى زادى قانداق نەرسىدۇ؟ ']
|
||||||
|
# for tw in test_words:
|
||||||
|
# print(f"{tw} -> {uighur_syllabize(tw.strip())}")
|
||||||
|
|
||||||
|
# test_words = ['تەكلىپتەۋسىيەلىرىگە', 'تەكلىپ-تەۋسىيەلىرىگە']
|
||||||
|
# for tw in test_words:
|
||||||
|
# print(f"{tw} -> {uighur_syllabize(tw.strip())}")
|
||||||
|
|
||||||
|
# تەكلىپتەۋسىيەلىرىگە -> ['تەك', 'لىپ', 'تەۋ', 'سى', 'يە', 'لى', 'رى', 'گە']
|
||||||
|
# تەكلىپ-تەۋسىيەلىرىگە -> ['تەك', 'لىپ-', 'تەۋ', 'سى', 'يە', 'لى', 'رى', 'گە']
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
text = " ھەر قانداق ئىشتا كىشىلەر بىلەن كېڭىشىش كېرەك. لېكىن كۆڭۈل تارتقان ئىشنى قىلىۋېرىش كېرەك."
|
||||||
|
text = "\" 15 يىل بۇرۇن مەن بىلەن سەي چۇڭشىن (ئالى بابانىڭ قۇرغۇچىسىدىن بىرى) ئامېرىكىغا بېرىپ 30 نەچچە مەبلەغ سالغۇچى شىركەتلەر بىلەن كۆرۈشكىنىمىزدە ئۇلار بىزنى رەت قىلىپ ئىشىك سىرتىدا قالدۇرغاندى. ھېچكىم بىزگە ۋە كەلگۈسىمىزگە ئىشەنمىگەن ئىدى، ھېچكىممۇ بىزنىڭ ھازىر 300 مىليارد سودىنى تاماملىيالايدىغانلىقىمىزنى ئويلاپ خىيالىغىمۇ كەلتۈرمىگەن ئىدى.\" ئالى بابانىنىڭ ئورگىنى تەرىپىدىن ئىشلەنگەن ئىگىلىك تىكلەش ھۆججەتلىك فىلىمى \" بۇ چۈش ئەمەس\" نىڭ باش قىسمىدا مۇنداق بىر داڭلىق سۆز چىقىدۇ:\" مەن تاغدىن ئۆتكەن ۋاقتىمدا تاغ ماڭا گەپ قىلمىدى، مەن دېڭىزنى كېچىپ ئۆتكىنىمدە دېڭىز ماڭا گەپ قىلمىدى.\" بۇ فىلىمنىڭ قوشۇمچە ئىسمى بولسا \" مايۈن ۋە ئۇنىڭ مەڭگۈلۈك ' ياش ئالى'سى\" بولۇپ، ھەرخىل خەتەر ۋە قىيىنلىقنى بىرمۇ بىر يەڭگەن مايۈن ۋە ئۇنىڭ ئالى بابا قوشۇنىدىكى ھەمكارلاشقۇچىلىرىنىڭ قەيسەر كەچمىشى جانلىق بايان قىلىنغان."
|
||||||
|
text = "يۆ جيەنتاۋ يېقىنقى بىر مەزگىلدە، خىزمەتداشلىرى بىلەن نۇرغۇن قىيىنچىلىققا ئۇچرىغان شوپۇرغا ياردەم قىلغانلىقىنى، شۇنداقلا يېمەكلىك ۋە سۇ يەتكۈزۈپ بەرگەنلىكىنى، ئەمما ئاياغ سوۋغا قىلىشى تۇنجى قېتىم ئىكەنلىكىنى، بۈگۈنكىسى 20 يىللىق ساقچىلىق جەريانىدا تۇنجى قېتىم شوپۇرغا ئاياغ سوۋغا قىلىشى ئىكەنلىكىنى ئېيتتى."
|
||||||
|
text = "يۆ جۇڭمىڭ مۇنداق دېدى: 2019- يىلى 12-ئايدا، مەملىكەتلىك خەلق قۇرۇلتىيى دائىمىي كومىتېتىنىڭ 44- قېتىملىق كومىتېت باشلىقلىرى يىغىنى مەملىكەتلىك خەلق قۇرۇلتىيى دائىمىي كومىتېتىنىڭ 2020-يىللىق قانۇن چىقىرىش خىزمىتى پىلانىنى پىرىنسىپ جەھەتتىن ماقۇللىدى، خىزمەت تەرتىپى بويىچە، 13-نۆۋەتلىك مەملىكەتلىك خەلق قۇرۇلتىيى 3-يىغىنىنىڭ روھى ۋە ۋەكىللەرنىڭ تەكلىپ-تەۋسىيەلىرىگە ئاساسەن پىلاننى تەڭشەش كېرەك. بۇ يىل 6-ئاينىڭ 1-كۈنى، 58-قېتىملىق كومىتېت باشلىقلىرى يىغىنى تەڭشەلگەندىن كېيىنكى يىللىق قانۇن چىقىرىش خىزمىتى پىلانىنى قاراپ چىقىپ ماقۇللىدى."
|
||||||
|
text = "مەن مەكتەپكە باردىم"
|
||||||
|
|
||||||
|
# text = normalize_extended_uyghur_characters(text=text)
|
||||||
|
# print(clean_uyghur(text=text))
|
||||||
|
|
||||||
|
def process_text(text: str) -> list[str]:
|
||||||
|
text = clean_input_text(text=text)
|
||||||
|
result = []
|
||||||
|
current_word_chars = []
|
||||||
|
|
||||||
|
for char in text:
|
||||||
|
if char in UYGHUR_LETTERS:
|
||||||
|
current_word_chars.append(char)
|
||||||
|
else:
|
||||||
|
# 遇到非字母字符,先处理缓存的单词
|
||||||
|
if current_word_chars:
|
||||||
|
word = ''.join(current_word_chars)
|
||||||
|
result.extend(uighur_syllabize(word))
|
||||||
|
current_word_chars.clear()
|
||||||
|
# 非字母字符直接加入
|
||||||
|
result.append(char)
|
||||||
|
|
||||||
|
# 处理末尾的单词
|
||||||
|
if current_word_chars:
|
||||||
|
word = ''.join(current_word_chars)
|
||||||
|
result.extend(uighur_syllabize(word))
|
||||||
|
|
||||||
|
return result
|
||||||
126
src/handle/text_normalizer.py
Normal file
126
src/handle/text_normalizer.py
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
UYGHUR_LETTERS = {
|
||||||
|
'ا', 'ە', 'ب', 'پ', 'ت', 'ج', 'چ', 'خ', 'د', 'ر', 'ز', 'ژ', 'س', 'ش', 'غ', 'ف',
|
||||||
|
'ق', 'ك', 'گ', 'ڭ', 'ل', 'م', 'ن', 'ھ', 'و', 'ۇ', 'ۆ', 'ۈ', 'ۋ', 'ې', 'ى', 'ي',
|
||||||
|
"ئ"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
uyghur_symbols = {
|
||||||
|
"ھەرىپلەر": [
|
||||||
|
'ا', 'ە', 'ب', 'پ', 'ت', 'ج', 'چ', 'خ', 'د', 'ر', 'ز', 'ژ', 'س', 'ش', 'غ', 'ف',
|
||||||
|
'ق', 'ك', 'گ', 'ڭ', 'ل', 'م', 'ن', 'ھ', 'و', 'ۇ', 'ۆ', 'ۈ', 'ۋ', 'ې', 'ى', 'ي',
|
||||||
|
"ئ"
|
||||||
|
],
|
||||||
|
"تىنىش_بەلگىلىرى": [
|
||||||
|
"۔", "،", "؟", "!", "-", "«", "»", "؛", ":", "'", "\"", "]", "[", " ", "›", "‹"
|
||||||
|
],
|
||||||
|
"سانلار": [
|
||||||
|
"0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
english_symbols = {
|
||||||
|
"characters": [
|
||||||
|
'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
|
||||||
|
'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'
|
||||||
|
],
|
||||||
|
"symbols": [
|
||||||
|
".", "!", "?", "-", "<", ">", "=", "+", "*", "/", "|", "\\", "~", "_", "^", "@", "{", "}", "[", "]", "(", ")", "#", "%", "$", "€", "£", "¥", " "
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
# Export all symbols:
|
||||||
|
symbols_list = uyghur_symbols["ھەرىپلەر"] + uyghur_symbols["تىنىش_بەلگىلىرى"] + uyghur_symbols["سانلار"] + english_symbols["characters"] + english_symbols["symbols"]
|
||||||
|
symbols = {i: i for i in symbols_list}
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_extended_uyghur_characters(text: str) -> str:
|
||||||
|
characters = [
|
||||||
|
#["standard", "individual", "beginning", "middle", "end"]
|
||||||
|
["ئ","ﺋ","ﺋ","ﺌ","ﺌ"],
|
||||||
|
["ا","ﺍ","ﺍ","ﺎ","ﺎ"],
|
||||||
|
["ە","ﻩ","ﻩ","ﻪ","ﻪ"],
|
||||||
|
["ب","ﺏ","ﺑ","ﺒ","ﺐ"],
|
||||||
|
["پ","ﭖ","ﭘ","ﭙ","ﭗ"],
|
||||||
|
["ت","ﺕ","ﺗ","ﺘ","ﺖ"],
|
||||||
|
["ج","ﺝ","ﺟ","ﺠ","ﺞ"],
|
||||||
|
["چ","ﭺ","ﭼ","ﭽ","ﭻ"],
|
||||||
|
["خ","ﺥ","ﺧ","ﺨ","ﺦ"],
|
||||||
|
["د","ﺩ","ﺩ","ﺪ","ﺪ"],
|
||||||
|
["ر","ﺭ","ﺭ","ﺮ","ﺮ"],
|
||||||
|
["ز","ﺯ","ﺯ","ﺰ","ﺰ"],
|
||||||
|
["ژ","ﮊ","ﮊ","ﮋ","ﮋ"],
|
||||||
|
["س","ﺱ","ﺳ","ﺴ","ﺲ"],
|
||||||
|
["ش","ﺵ","ﺷ","ﺸ","ﺶ"],
|
||||||
|
["غ","ﻍ","ﻏ","ﻐ","ﻎ"],
|
||||||
|
["ف","ﻑ","ﻓ","ﻔ","ﻒ"],
|
||||||
|
["ق","ﻕ","ﻗ","ﻘ","ﻖ"],
|
||||||
|
["ك","ﻙ","ﻛ","ﻜ","ﻚ"],
|
||||||
|
["گ","ﮒ","ﮔ","ﮕ","ﮓ"],
|
||||||
|
["ڭ","ﯓ","ﯕ","ﯖ","ﯔ"],
|
||||||
|
["ل","ﻝ","ﻟ","ﻠ","ﻞ"],
|
||||||
|
["م","ﻡ","ﻣ","ﻤ","ﻢ"],
|
||||||
|
["ن","ﻥ","ﻧ","ﻨ","ﻦ"],
|
||||||
|
["ھ","ﮪ","ﮬ","ﮭ","ﮫ"],
|
||||||
|
["و","ﻭ","ﻭ","ﻮ","ﻮ"],
|
||||||
|
["ۇ","ﯗ","ﯗ","ﯘ","ﯘ"],
|
||||||
|
["ۆ","ﯙ","ﯙ","ﯚ","ﯚ"],
|
||||||
|
["ۈ","ﯛ","ﯛ","ﯜ","ﯜ"],
|
||||||
|
["ۋ","ﯞ","ﯞ","ﯟ","ﯟ"],
|
||||||
|
["ې","ﯤ","ﯦ","ﯧ","ﯥ"],
|
||||||
|
["ى","ﻯ","ﯨ","ﯩ","ﻰ"],
|
||||||
|
["ي","ﻱ","ﻳ","ﻴ","ﻲ"],
|
||||||
|
["ۅ","ﯠ","ﯠ","ﯡ","ﯡ"],
|
||||||
|
["ۉ","ﯢ","ﯢ","ﯣ","ﯣ"],
|
||||||
|
["ح","ﺡ","ﺣ","ﺤ","ﺢ"],
|
||||||
|
["ع","ﻉ","ﻋ","ﻌ","ﻊ"]
|
||||||
|
]
|
||||||
|
replacement_table: dict[str, str] = {}
|
||||||
|
#Create a replacement table.
|
||||||
|
for char_map in characters:
|
||||||
|
for char in char_map[1:]:
|
||||||
|
replacement_table[char] = char_map[0]
|
||||||
|
|
||||||
|
replacement_table['ﻼ'] = "لا" #add some additional exceptional symbols
|
||||||
|
|
||||||
|
text = text.replace('ئ', 'ئ')
|
||||||
|
clean = ""
|
||||||
|
#Replace the extended characters with their standard unicode ones.
|
||||||
|
for char in text:
|
||||||
|
if char in replacement_table:
|
||||||
|
char = replacement_table[char]
|
||||||
|
clean += char
|
||||||
|
|
||||||
|
return clean
|
||||||
|
|
||||||
|
|
||||||
|
def is_uyghur_char(char: str) -> bool:
|
||||||
|
"""判断是否是维吾尔语字母"""
|
||||||
|
return all(c in UYGHUR_LETTERS for c in char)
|
||||||
|
|
||||||
|
|
||||||
|
def is_uyghur_text(text: str) -> bool:
|
||||||
|
"""清洗后检查是否全为维吾尔语"""
|
||||||
|
return all(c in symbols for c in text)
|
||||||
|
|
||||||
|
def clean_unknown_symbols(text: str) -> str:
|
||||||
|
return ''.join(c for c in text if c in symbols)
|
||||||
|
|
||||||
|
import re
|
||||||
|
def collapse_spaces(text: str) -> str:
|
||||||
|
return re.sub(r'\s+', ' ', text)
|
||||||
|
|
||||||
|
def clean_english_text(text: str) -> str:
|
||||||
|
return re.sub(r'[a-zA-Z]', ' ', text)
|
||||||
|
|
||||||
|
def clean_chinese_text(text: str) -> str:
|
||||||
|
return re.sub(r'[\u4e00-\u9fff]', ' ', text)
|
||||||
|
|
||||||
|
def clean_http_links(text: str) -> str:
|
||||||
|
return re.sub(r'https?://[^\s]+', ' ', text)
|
||||||
|
|
||||||
|
def clean_input_text(text: str) -> str:
|
||||||
|
text = collapse_spaces(text)
|
||||||
|
text = clean_http_links(text)
|
||||||
|
text = normalize_extended_uyghur_characters(text)
|
||||||
|
text = clean_unknown_symbols(text)
|
||||||
|
return text
|
||||||
580
src/inference.ipynb
Normal file
580
src/inference.ipynb
Normal file
@@ -0,0 +1,580 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "f72992c5",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import torch\n",
|
||||||
|
"import torchaudio\n",
|
||||||
|
"import librosa\n",
|
||||||
|
"import pyrubberband as pyrb\n",
|
||||||
|
"from torch import Tensor, no_grad, device\n",
|
||||||
|
"from torchaudio.transforms import FrequencyMasking, MelSpectrogram, AmplitudeToDB, Resample, TimeMasking, TimeStretch\n",
|
||||||
|
"from torch_audiomentations import Gain, PitchShift, LowPassFilter, HighPassFilter\n",
|
||||||
|
"from pathlib import Path\n",
|
||||||
|
"from librosa import effects\n",
|
||||||
|
"import torchaudio.functional as F\n",
|
||||||
|
"from IPython.display import Audio\n",
|
||||||
|
"import matplotlib.pyplot as plt\n",
|
||||||
|
"import random\n",
|
||||||
|
"from typing import TypedDict\n",
|
||||||
|
"\n",
|
||||||
|
"from tokenizer import ASRTokenizer\n",
|
||||||
|
"from model import ASRModel\n",
|
||||||
|
"\n",
|
||||||
|
"CONFIG = {\n",
|
||||||
|
" # 模型配置\n",
|
||||||
|
" 'input_dim': 256,\n",
|
||||||
|
" 'num_heads': 8,\n",
|
||||||
|
" 'ffn_dim': 2048,\n",
|
||||||
|
" 'num_layers': 8,\n",
|
||||||
|
" 'dropout': 0.1,\n",
|
||||||
|
"}\n",
|
||||||
|
"workspace_dir = Path.cwd().parent"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "15a9a926",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"class AdvancedAugment:\n",
|
||||||
|
" def __init__(self, noise_folder=None):\n",
|
||||||
|
" # 如果你有噪音文件夹,可以预加载噪音路径列表\n",
|
||||||
|
" # 场景建议:1.汽车 2.餐厅 3.街道 4.办公室 5.下雨 6.风声 7.键盘声 8.婴儿哭声 9.音乐 10.走廊回声 11.白噪音 12.粉红噪音\n",
|
||||||
|
" self.noise_files = [] # 这里存放你的 .wav 噪音文件路径\n",
|
||||||
|
"\n",
|
||||||
|
" def _add_noise(self, waveform: Tensor) -> Tensor:\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" 更丰富的加噪:随机选择 白噪音、粉红噪音 或 真实环境音\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" # 1. 随机决定噪音类型\n",
|
||||||
|
" noise_type = random.choice(['white', 'pink', 'environmental'])\n",
|
||||||
|
" \n",
|
||||||
|
" # 2. 随机设定信噪比 (SNR) - 模拟各种清晰度\n",
|
||||||
|
" snr_db = random.uniform(5, 25) \n",
|
||||||
|
" \n",
|
||||||
|
" if noise_type == 'white':\n",
|
||||||
|
" noise = torch.randn_like(waveform)\n",
|
||||||
|
" elif noise_type == 'pink':\n",
|
||||||
|
" \n",
|
||||||
|
" noise = self._generate_pink_noise(waveform)\n",
|
||||||
|
" else:\n",
|
||||||
|
" # 模拟环境音 (如果你有噪音库)\n",
|
||||||
|
" # 这里演示如果没有噪音库,就用多种不同频率的随机噪音组合\n",
|
||||||
|
" noise = self._generate_simulated_env_noise(waveform)\n",
|
||||||
|
"\n",
|
||||||
|
" # 3. 计算能量并叠加\n",
|
||||||
|
" return self._mix_signal_noise(waveform, noise, snr_db)\n",
|
||||||
|
"\n",
|
||||||
|
" def _generate_pink_noise(self, waveform):\n",
|
||||||
|
" \"\"\"粉红噪音(比白噪音更像自然界的声音,低频能量更高)\"\"\"\n",
|
||||||
|
" # 简易实现:对白噪音做低通滤波\n",
|
||||||
|
" white = torch.randn_like(waveform)\n",
|
||||||
|
" return torch.cumsum(white, dim=-1) / 10.0 # 简单的积分近似\n",
|
||||||
|
"\n",
|
||||||
|
" def _generate_simulated_env_noise(self, waveform):\n",
|
||||||
|
" \"\"\"模拟环境噪音(通过叠加不同频率的波形)\"\"\"\n",
|
||||||
|
" # 模拟 12 种以上变化:随机叠加正弦波或窄带噪音\n",
|
||||||
|
" noise = torch.zeros_like(waveform)\n",
|
||||||
|
" for _ in range(random.randint(3, 8)):\n",
|
||||||
|
" freq = random.uniform(50, 4000)\n",
|
||||||
|
" t = torch.arange(waveform.shape[-1]).to(waveform.device)\n",
|
||||||
|
" noise += torch.sin(2 * 3.14159 * freq * t / 16000)\n",
|
||||||
|
" return noise + torch.randn_like(waveform) * 0.5\n",
|
||||||
|
"\n",
|
||||||
|
" def _mix_signal_noise(self, signal, noise, snr_db):\n",
|
||||||
|
" \"\"\"核心:根据 SNR 混合信号和噪音\"\"\"\n",
|
||||||
|
" s_p = signal.pow(2).mean()\n",
|
||||||
|
" n_p = noise.pow(2).mean()\n",
|
||||||
|
" \n",
|
||||||
|
" # 防止除以 0\n",
|
||||||
|
" if n_p == 0: return signal\n",
|
||||||
|
" \n",
|
||||||
|
" snr_linear = 10**(snr_db/10)\n",
|
||||||
|
" scale = torch.sqrt(s_p / (n_p * snr_linear))\n",
|
||||||
|
" \n",
|
||||||
|
" noisy = signal + scale * noise\n",
|
||||||
|
" return noisy / (noisy.abs().max() + 1e-7) # 归一化防止爆音\n",
|
||||||
|
"\n",
|
||||||
|
" def _simulate_reverb(self, waveform):\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" 模拟走廊/房间回声 (Simple Reverb)\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" # if random.random() < 0.3: # 30% 概率添加回声\n",
|
||||||
|
" # 模拟简单的延迟反馈(模拟大走廊)\n",
|
||||||
|
" delay_samples = random.randint(500, 2000) # 30ms - 125ms 延迟\n",
|
||||||
|
" decay = random.uniform(0.3, 0.6)\n",
|
||||||
|
" \n",
|
||||||
|
" reverb_signal = torch.zeros_like(waveform)\n",
|
||||||
|
" reverb_signal[:, delay_samples:] = waveform[:, :-delay_samples] * decay\n",
|
||||||
|
" return (waveform + reverb_signal) / (1 + decay)\n",
|
||||||
|
" return waveform\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "c0635d68",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"class NoiseAugmentor:\n",
|
||||||
|
" def __init__(self, noise_root: Path, sample_rate=16000):\n",
|
||||||
|
" self.sample_rate = sample_rate\n",
|
||||||
|
" # 递归找到目录下所有的 wav 文件\n",
|
||||||
|
" self.noise_files = list(Path(noise_root).rglob(\"*.wav\"))\n",
|
||||||
|
" print(f\"成功加载 {len(self.noise_files)} 个噪音文件\")\n",
|
||||||
|
"\n",
|
||||||
|
" def apply_real_noise(self, waveform: Tensor):\n",
|
||||||
|
" # 1. 随机选一个噪音文件\n",
|
||||||
|
" noise_path = random.choice(self.noise_files)\n",
|
||||||
|
" noise_waveform, sr = torchaudio.load_with_torchcodec(noise_path)\n",
|
||||||
|
" \n",
|
||||||
|
" # 2. 统一采样率\n",
|
||||||
|
" if sr != self.sample_rate:\n",
|
||||||
|
" resampler = Resample(sr, self.sample_rate)\n",
|
||||||
|
" noise_waveform = resampler(noise_waveform)\n",
|
||||||
|
" \n",
|
||||||
|
" # 3. 截取或填充,使其长度与语音一致\n",
|
||||||
|
" sig_len = waveform.shape[1]\n",
|
||||||
|
" noise_len = noise_waveform.shape[1]\n",
|
||||||
|
" \n",
|
||||||
|
" if noise_len >= sig_len:\n",
|
||||||
|
" # 随机截取一段\n",
|
||||||
|
" start = random.randint(0, noise_len - sig_len)\n",
|
||||||
|
" noise_waveform = noise_waveform[:, start:start + sig_len]\n",
|
||||||
|
" else:\n",
|
||||||
|
" full_noise = torch.zeros_like(waveform)\n",
|
||||||
|
" start = random.randint(0, sig_len - noise_len)\n",
|
||||||
|
" full_noise[:, start : start + noise_len] = noise_waveform\n",
|
||||||
|
" noise_waveform = full_noise\n",
|
||||||
|
" # 如果噪音太短,循环填充\n",
|
||||||
|
" # repeats = (sig_len // noise_len) + 1\n",
|
||||||
|
" # noise_waveform = noise_waveform.repeat(1, repeats)[:, :sig_len]\n",
|
||||||
|
" \n",
|
||||||
|
" # 4. 设定随机信噪比 SNR (5dB 到 20dB)\n",
|
||||||
|
" snr_db = random.uniform(5, 20)\n",
|
||||||
|
" \n",
|
||||||
|
" # 5. 混合\n",
|
||||||
|
" return self._mix_at_snr(waveform, noise_waveform, snr_db)\n",
|
||||||
|
"\n",
|
||||||
|
" def _mix_at_snr(self, signal: Tensor, noise: Tensor, snr_db: float):\n",
|
||||||
|
" s_p = signal.pow(2).mean()\n",
|
||||||
|
" n_p = noise.pow(2).mean()\n",
|
||||||
|
" snr_linear = 10**(snr_db / 10)\n",
|
||||||
|
" scale = torch.sqrt(s_p / (n_p * snr_linear + 1e-8))\n",
|
||||||
|
" \n",
|
||||||
|
" noisy = signal + scale * noise\n",
|
||||||
|
" # 归一化,防止溢出\n",
|
||||||
|
" return noisy / (noisy.abs().max() + 1e-8)\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "e77899f1",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"class ASRInference:\n",
|
||||||
|
" def __init__(self, model_path: Path, vocab_path: Path, noise_dir: Path, device: device, augment: bool = True, augment_prob: float = 0.5) -> None:\n",
|
||||||
|
" self.device = device\n",
|
||||||
|
" self.augment: bool = augment\n",
|
||||||
|
" self.augment_prob: float = augment_prob\n",
|
||||||
|
" self.noise_augmentor = NoiseAugmentor(noise_root=noise_dir)\n",
|
||||||
|
" self.tokenizer = ASRTokenizer(vocab_path=vocab_path)\n",
|
||||||
|
" self.model = ASRModel(vocab_size=self.tokenizer.vocab_size(), **CONFIG).to(device=device)\n",
|
||||||
|
" self.model.load_state_dict(torch.load(model_path, map_location=device)['model_state_dict'])\n",
|
||||||
|
" self.model.eval()\n",
|
||||||
|
"\n",
|
||||||
|
" print(f\"params params: {self.model.get_num_params():,}\",)\n",
|
||||||
|
"\n",
|
||||||
|
" self.sample_rate = 16000\n",
|
||||||
|
"\n",
|
||||||
|
" # self.gain_up = Gain(min_gain_in_db=5, max_gain_in_db=10, p=1.0, output_type='tensor')\n",
|
||||||
|
" # self.gain_down = Gain(min_gain_in_db=-20, max_gain_in_db=-10, p=1.0, output_type='tensor')\n",
|
||||||
|
" # self.pitch_up = PitchShift(min_transpose_semitones=3, max_transpose_semitones=5, p=1.0, sample_rate=self.sample_rate, output_type='tensor')\n",
|
||||||
|
" # self.pitch_down = PitchShift(min_transpose_semitones=-5, max_transpose_semitones=-3, p=1.0, sample_rate=self.sample_rate, output_type='tensor')\n",
|
||||||
|
" # self.lowpass = LowPassFilter(min_cutoff_freq=400, max_cutoff_freq=2400, p=1.0, output_type='tensor')\n",
|
||||||
|
" # self.highpass = HighPassFilter(min_cutoff_freq=1400, max_cutoff_freq=3400, p=1.0, output_type='tensor')\n",
|
||||||
|
"\n",
|
||||||
|
" self.gain_up = Gain(min_gain_in_db=4, max_gain_in_db=8, p=1.0, output_type='tensor')\n",
|
||||||
|
" self.gain_down = Gain(min_gain_in_db=-15, max_gain_in_db=-8, p=1.0, output_type='tensor')\n",
|
||||||
|
" self.pitch_up = PitchShift(min_transpose_semitones=1, max_transpose_semitones=4, p=1.0, sample_rate=self.sample_rate, output_type='tensor')\n",
|
||||||
|
" self.pitch_down = PitchShift(min_transpose_semitones=-4, max_transpose_semitones=-1, p=1.0, sample_rate=self.sample_rate, output_type='tensor')\n",
|
||||||
|
" self.lowpass = LowPassFilter(min_cutoff_freq=600, max_cutoff_freq=2000, p=1.0, output_type='tensor')\n",
|
||||||
|
" self.highpass = HighPassFilter(min_cutoff_freq=800, max_cutoff_freq=2000, p=1.0, output_type='tensor')\n",
|
||||||
|
" \n",
|
||||||
|
" def _load_audio(self, audio_path: Path) -> Tensor:\n",
|
||||||
|
" waveform, sample_rate = torchaudio.load_with_torchcodec(audio_path)\n",
|
||||||
|
"\n",
|
||||||
|
" if sample_rate != self.sample_rate:\n",
|
||||||
|
" waveform = Resample(sample_rate, self.sample_rate)(waveform)\n",
|
||||||
|
"\n",
|
||||||
|
" if waveform.shape[0] > 1:\n",
|
||||||
|
" waveform = waveform.mean(dim=0, keepdim=True)\n",
|
||||||
|
"\n",
|
||||||
|
" waveform = waveform / (waveform.abs().max() + 1e-8)\n",
|
||||||
|
" return waveform\n",
|
||||||
|
" \n",
|
||||||
|
" def augment_waveform(self, waveform: Tensor) -> Tensor:\n",
|
||||||
|
" # if not self.augment or random.random() > self.augment_prob:\n",
|
||||||
|
" # return waveform\n",
|
||||||
|
" \n",
|
||||||
|
" # 1. voice Stretch/Compress \n",
|
||||||
|
" if random.random() < 0.5:\n",
|
||||||
|
" waveform = self._voice_stretch_or_compress(waveform=waveform)\n",
|
||||||
|
"\n",
|
||||||
|
" if random.random() < 0.4:\n",
|
||||||
|
" waveform = self.noise_augmentor.apply_real_noise(waveform)\n",
|
||||||
|
"\n",
|
||||||
|
" if random.random() < 0.3:\n",
|
||||||
|
" waveform = self._time_mask_waveform(waveform=waveform)\n",
|
||||||
|
" \n",
|
||||||
|
" # torch_audiomentations: [1, time] -> [1, 1, time]\n",
|
||||||
|
" if waveform.dim() == 2:\n",
|
||||||
|
" waveform_3d = waveform.unsqueeze(0)\n",
|
||||||
|
" # 随机选择一种频谱增强\n",
|
||||||
|
" choice = random.random()\n",
|
||||||
|
" if choice < 0.15:\n",
|
||||||
|
" # 增益变化(上或下)\n",
|
||||||
|
" if random.random() < 0.5:\n",
|
||||||
|
" waveform_3d = self.gain_up(waveform_3d, sample_rate=self.sample_rate)\n",
|
||||||
|
" else:\n",
|
||||||
|
" waveform_3d = self.gain_down(waveform_3d, sample_rate=self.sample_rate)\n",
|
||||||
|
" elif choice < 0.25:\n",
|
||||||
|
" # 音高变化(上或下)\n",
|
||||||
|
" if random.random() < 0.5:\n",
|
||||||
|
" waveform_3d = self.pitch_up(waveform_3d, sample_rate=self.sample_rate)\n",
|
||||||
|
" else:\n",
|
||||||
|
" waveform_3d = self.pitch_down(waveform_3d, sample_rate=self.sample_rate)\n",
|
||||||
|
" elif choice < 0.30:\n",
|
||||||
|
" # 低通滤波(声音发闷)\n",
|
||||||
|
" waveform_3d = self.lowpass(waveform_3d, sample_rate=self.sample_rate)\n",
|
||||||
|
" elif choice < 0.35:\n",
|
||||||
|
" # 高通滤波(电话效果)\n",
|
||||||
|
" waveform_3d = self.highpass(waveform_3d, sample_rate=self.sample_rate)\n",
|
||||||
|
" \n",
|
||||||
|
" # [1, 1, time] -> [1, time]\n",
|
||||||
|
" waveform = waveform_3d.squeeze(0)\n",
|
||||||
|
" \n",
|
||||||
|
" # 防止多次 augment 后振幅溢出,最后归一化\n",
|
||||||
|
" max_amp = waveform.abs().max()\n",
|
||||||
|
" if max_amp > 1.0:\n",
|
||||||
|
" print(max_amp)\n",
|
||||||
|
" waveform = waveform / max_amp\n",
|
||||||
|
" \n",
|
||||||
|
" return waveform\n",
|
||||||
|
" \n",
|
||||||
|
" def _voice_stretch_or_compress1(self, waveform: Tensor) -> Tensor:\n",
|
||||||
|
" speed_factor = random.uniform(0.8, 1.4) # (Speed Change: 0.6x - 1.4x)\n",
|
||||||
|
" waveform_np = waveform.squeeze(0).cpu().numpy()\n",
|
||||||
|
" y_stretch = librosa.effects.time_stretch(waveform_np, rate=speed_factor)\n",
|
||||||
|
" waveform = torch.from_numpy(y_stretch).float().to(waveform.device).unsqueeze(0)\n",
|
||||||
|
" return waveform\n",
|
||||||
|
" \n",
|
||||||
|
" def _voice_stretch_or_compress(self, waveform: Tensor) -> Tensor:\n",
|
||||||
|
" speed_factor = random.uniform(0.8, 1.25) # (Speed Change: 0.6x - 1.2x)\n",
|
||||||
|
" spec = torch.stft(\n",
|
||||||
|
" waveform.squeeze(0),\n",
|
||||||
|
" n_fft=400,\n",
|
||||||
|
" hop_length=160,\n",
|
||||||
|
" window=torch.hann_window(400).to(waveform.device),\n",
|
||||||
|
" return_complex=True\n",
|
||||||
|
" )\n",
|
||||||
|
" \n",
|
||||||
|
" # 时间拉伸(不改变音高)\n",
|
||||||
|
" stretch = TimeStretch(hop_length=160, n_freq=spec.shape[-2], fixed_rate=1.4)\n",
|
||||||
|
" stretched_spec = stretch(spec)\n",
|
||||||
|
" \n",
|
||||||
|
" # 转回波形\n",
|
||||||
|
" waveform_stretched = torch.istft(\n",
|
||||||
|
" stretched_spec,\n",
|
||||||
|
" n_fft=400,\n",
|
||||||
|
" hop_length=160,\n",
|
||||||
|
" window=torch.hann_window(400).to(waveform.device)\n",
|
||||||
|
" ).unsqueeze(0)\n",
|
||||||
|
" \n",
|
||||||
|
" return waveform_stretched\n",
|
||||||
|
"\n",
|
||||||
|
" def _voice_stretch_or_compress2(self, waveform: Tensor) -> Tensor:\n",
|
||||||
|
" speed_factor = random.uniform(0.6, 1.4) # (Speed Change: 0.6x - 1.4x)\n",
|
||||||
|
" waveform_np = waveform.squeeze(0).cpu().numpy()\n",
|
||||||
|
" y_stretch = pyrb.time_stretch(waveform_np, self.sample_rate, speed_factor)\n",
|
||||||
|
" return torch.from_numpy(y_stretch).float().to(waveform.device).unsqueeze(0)\n",
|
||||||
|
" \n",
|
||||||
|
" def _time_mask_waveform(self, waveform: Tensor) -> Tensor:\n",
|
||||||
|
" audio_len = waveform.shape[1]\n",
|
||||||
|
" sr = self.sample_rate # 16000\n",
|
||||||
|
" \n",
|
||||||
|
" # 设置参数:单次遮盖最长 0.4 秒 (6400个点)\n",
|
||||||
|
" max_mask_time = 0.4\n",
|
||||||
|
" max_mask_samples = int(sr * max_mask_time)\n",
|
||||||
|
" \n",
|
||||||
|
" # 根据音频长度决定遮盖次数:\n",
|
||||||
|
" # 比如每 3 秒钟允许遮盖 1 次\n",
|
||||||
|
" num_masks = max(1, audio_len // (sr * 3)) \n",
|
||||||
|
" \n",
|
||||||
|
" for _ in range(num_masks):\n",
|
||||||
|
" # 每次随机遮盖 0.1s 到 0.4s\n",
|
||||||
|
" current_mask_len = random.randint(int(sr * 0.1), max_mask_samples)\n",
|
||||||
|
" \n",
|
||||||
|
" if audio_len > current_mask_len:\n",
|
||||||
|
" start_pos = random.randint(0, audio_len - current_mask_len)\n",
|
||||||
|
" \n",
|
||||||
|
" # 填充微小噪音(模拟环境底噪)\n",
|
||||||
|
" noise = torch.randn(1, current_mask_len).to(waveform.device) * 0.002\n",
|
||||||
|
" waveform[:, start_pos : start_pos + current_mask_len] = noise\n",
|
||||||
|
" \n",
|
||||||
|
" return waveform\n",
|
||||||
|
"\n",
|
||||||
|
" def transcribe(self, waveform: Tensor) -> str:\n",
|
||||||
|
" waveform = waveform.to(device=self.device)\n",
|
||||||
|
" waveform_length = torch.tensor([waveform.shape[1]], dtype=torch.long, device=self.device)\n",
|
||||||
|
"\n",
|
||||||
|
" with no_grad():\n",
|
||||||
|
" log_probs, _ = self.model(waveforms=waveform, waveform_lengths=waveform_length)\n",
|
||||||
|
" \n",
|
||||||
|
" text = self.tokenizer.ctc_greedy_decode(log_probs=log_probs[0])\n",
|
||||||
|
" return text"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "5af7c836",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"device = torch.device('cuda:1')\n",
|
||||||
|
"\n",
|
||||||
|
"# checkpoint = sorted(workspace_dir.glob('.checkpoints/checkpoint_epoch_*.pt'), key=lambda p: int(p.stem.split('_')[-1]))[-1]\n",
|
||||||
|
"checkpoint = workspace_dir / \".checkpoints/best_wer_model.pt\"\n",
|
||||||
|
"# checkpoint = workspace_dir / \".checkpoints/best_cer_model.pt\"\n",
|
||||||
|
"checkpoint = workspace_dir / \".checkpoints/prodect/checkpoint_epoch_24 copy.pt\"\n",
|
||||||
|
"print(f\"Load Checkpoint: {checkpoint}\")\n",
|
||||||
|
"inference = ASRInference(model_path=checkpoint, vocab_path=workspace_dir / 'config/asr_vocab.json', noise_dir='/mnt/dataset/dataset/audio/noise', device=device)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "967e23cb",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# audio_path = workspace_dir / '.data/ug/clips/common_voice_ug_42681424.mp3'\n",
|
||||||
|
"# audio_path = workspace_dir / '.data/ug/clips/common_voice_ug_26188445.mp3'\n",
|
||||||
|
"# audio_path = workspace_dir / '.data/ug/clips/common_voice_ug_43349581.mp3'\n",
|
||||||
|
"# audio_path = workspace_dir / '.data/ug/clips/common_voice_ug_42640878.mp3'\n",
|
||||||
|
"# audio_path = workspace_dir / '.data/ug/clips/common_voice_ug_26245615.mp3'\n",
|
||||||
|
"# audio_path = workspace_dir / '.data/ug/clips/common_voice_ug_40794549.mp3'\n",
|
||||||
|
"# audio_path = workspace_dir / 'data/ug/clips/common_voice_ug_26245614.mp3'\n",
|
||||||
|
"audio_path = '/mnt/train/audio/Dilnaz/only_vocals/3272800876_100006098_(vocals)_melband_roformer_big_beta5e.m4a'\n",
|
||||||
|
"# audio_path = workspace_dir / 'data/test/how_are_you.mp3'\n",
|
||||||
|
"# audio_path = workspace_dir / 'data/test/split_1.m4a'\n",
|
||||||
|
"# audio_path = workspace_dir / 'data/test/let_me_see_the_computer.mp3'\n",
|
||||||
|
"# audio_path = workspace_dir / 'data/test/introduce_myself.mp3'\n",
|
||||||
|
"# audio_path = workspace_dir / 'data/test/F001_001.wav'\n",
|
||||||
|
"# audio_path = workspace_dir / 'data/test/download.wav'\n",
|
||||||
|
"# audio_path = workspace_dir / 'data/test/test_voise_1.m4a'\n",
|
||||||
|
"orginal_waveform = inference._load_audio(audio_path=audio_path)\n",
|
||||||
|
"print('load audio:', orginal_waveform.shape)\n",
|
||||||
|
"# orginal_waveform = orginal_waveform[:, :]\n",
|
||||||
|
"\n",
|
||||||
|
"def simple_corridor_echo(waveform, sample_rate=16000, delay_ms=80, decay=0.3):\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" 模拟一次明显的走廊回声:原始声音 + 一个衰减延迟的拷贝\n",
|
||||||
|
" delay_ms: 延迟毫秒数,80ms 模拟中等走廊\n",
|
||||||
|
" decay: 回声的衰减系数,0.2~0.4 比较自然\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" delay_samples = int(sample_rate * delay_ms / 1000)\n",
|
||||||
|
" echo = torch.zeros_like(waveform)\n",
|
||||||
|
" echo[..., delay_samples:] = waveform[..., :-delay_samples] * decay\n",
|
||||||
|
" return waveform + echo\n",
|
||||||
|
"\n",
|
||||||
|
"print(\"Orginal audio\")\n",
|
||||||
|
"display(Audio(orginal_waveform, rate=inference.sample_rate))\n",
|
||||||
|
"\n",
|
||||||
|
"# augment_waveform = inference.augment_waveform(waveform=orginal_waveform)\n",
|
||||||
|
"# augment_waveform = simple_corridor_echo(waveform=orginal_waveform)\n",
|
||||||
|
"# augment_waveform = inference._voice_stretch_or_compress(waveform=orginal_waveform)\n",
|
||||||
|
"# noise_augmentor_waveform = inference.noise_augmentor.apply_real_noise(waveform=orginal_waveform)\n",
|
||||||
|
"# display(Audio(augment_waveform, rate=inference.sample_rate))\n",
|
||||||
|
"# text = inference.transcribe(waveform=orginal_waveform)\n",
|
||||||
|
"# print(f\"\\n{text}\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "16195280",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from silero_vad import load_silero_vad, read_audio, get_speech_timestamps\n",
|
||||||
|
"class TimestampsType(TypedDict):\n",
|
||||||
|
" start: int\n",
|
||||||
|
" end: int\n",
|
||||||
|
"\n",
|
||||||
|
"THRESHOLD_MS = 2000\n",
|
||||||
|
"SAMPLE_RATE = inference.sample_rate\n",
|
||||||
|
"\n",
|
||||||
|
"def merge_timestamps(timestamps: int, threshold_ms: int, sample_rate: int):\n",
|
||||||
|
" if not timestamps:\n",
|
||||||
|
" return []\n",
|
||||||
|
" threshold_samples = (threshold_ms / 1000) * sample_rate\n",
|
||||||
|
" merged = []\n",
|
||||||
|
" curr_start = timestamps[0]['start']\n",
|
||||||
|
" curr_end = timestamps[0]['end']\n",
|
||||||
|
" \n",
|
||||||
|
" for i in range(1, len(timestamps)):\n",
|
||||||
|
" # 如果当前积攒的长度不到 2 秒,就一直合并到当前的 end 上\n",
|
||||||
|
" if (curr_end - curr_start) < threshold_samples:\n",
|
||||||
|
" curr_end = timestamps[i]['end']\n",
|
||||||
|
" else:\n",
|
||||||
|
" merged.append({'start': curr_start, 'end': curr_end})\n",
|
||||||
|
" curr_start = timestamps[i]['start']\n",
|
||||||
|
" curr_end = timestamps[i]['end']\n",
|
||||||
|
" merged.append({'start': curr_start, 'end': curr_end})\n",
|
||||||
|
" return merged\n",
|
||||||
|
"\n",
|
||||||
|
"vad_model = load_silero_vad(onnx=False)\n",
|
||||||
|
"waveform_vad = inference._load_audio(audio_path=audio_path)\n",
|
||||||
|
"speech_timestamps: list[TimestampsType] = get_speech_timestamps(\n",
|
||||||
|
" waveform_vad, \n",
|
||||||
|
" vad_model, \n",
|
||||||
|
" sampling_rate=inference.sample_rate,\n",
|
||||||
|
" threshold=0.4, # 可以根据需要调整阈值\n",
|
||||||
|
" min_speech_duration_ms=100, # 最小语音持续时间,防止短噪音被误判\n",
|
||||||
|
" min_silence_duration_ms=200, # 最小静音间隔,用于分割语音块\n",
|
||||||
|
" speech_pad_ms=100, # 在语音块前后添加的填充时间\n",
|
||||||
|
")\n",
|
||||||
|
"print('speech_timestamps:', speech_timestamps)\n",
|
||||||
|
"print('speech_timestamps:', len(speech_timestamps))\n",
|
||||||
|
"final_timestamps = merge_timestamps(speech_timestamps, THRESHOLD_MS, SAMPLE_RATE)\n",
|
||||||
|
"print(f'VAD原始片段: {len(speech_timestamps)} -> 合并后片段: {len(final_timestamps)}')\n",
|
||||||
|
"for ts in final_timestamps:\n",
|
||||||
|
" ts: TimestampsType\n",
|
||||||
|
" start_sample = int(ts['start'])\n",
|
||||||
|
" end_sample = int(ts['end'])\n",
|
||||||
|
"\n",
|
||||||
|
" segment_waveform = waveform_vad[:, start_sample:end_sample]\n",
|
||||||
|
" display(Audio(segment_waveform, rate=inference.sample_rate))\n",
|
||||||
|
"\n",
|
||||||
|
" segment_text = inference.transcribe(waveform=segment_waveform)\n",
|
||||||
|
" print(segment_text, end=\" \", flush=True)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "3671d334",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# mel_spec = inference._extract_features(waveform=orginal_waveform)\n",
|
||||||
|
"# ################################\n",
|
||||||
|
"# from PIL import Image, ImageFilter\n",
|
||||||
|
"# import numpy as np\n",
|
||||||
|
"# import math\n",
|
||||||
|
"\n",
|
||||||
|
"# img = (mel_spec.numpy())\n",
|
||||||
|
"# print(img.max(), img.min())\n",
|
||||||
|
"# img = img + np.abs(img.min())\n",
|
||||||
|
"# img = img / img.max()\n",
|
||||||
|
"# print(img.max(), img.min())\n",
|
||||||
|
"# img = Image.fromarray((img * 255).astype(np.uint8))\n",
|
||||||
|
"# shapened = img.filter(ImageFilter.UnsharpMask(radius=1, percent=150, threshold=3))\n",
|
||||||
|
"# img.save(\"original_hd.jpg\")\n",
|
||||||
|
"# # img.save(\"sharpen_hd.jpg\")\n",
|
||||||
|
"\n",
|
||||||
|
"# fig, axes = plt.subplots(2, 1, figsize=(30, 12))\n",
|
||||||
|
"\n",
|
||||||
|
"# axes[0].imshow(\n",
|
||||||
|
"# img,\n",
|
||||||
|
"# cmap=\"gray\",\n",
|
||||||
|
"# origin=\"lower\",\n",
|
||||||
|
"# vmin=0,\n",
|
||||||
|
"# vmax=255,\n",
|
||||||
|
"# aspect=\"auto\"\n",
|
||||||
|
"# )\n",
|
||||||
|
"# axes[0].set_title(\"Original\")\n",
|
||||||
|
"# axes[1].imshow(\n",
|
||||||
|
"# shapened,\n",
|
||||||
|
"# cmap=\"gray\",\n",
|
||||||
|
"# origin=\"lower\",\n",
|
||||||
|
"# vmin=0,\n",
|
||||||
|
"# vmax=255,\n",
|
||||||
|
"# aspect=\"auto\"\n",
|
||||||
|
"# )\n",
|
||||||
|
"# axes[1].set_title(\"Sharpened\")\n",
|
||||||
|
"# plt.tight_layout()\n",
|
||||||
|
"# plt.show()\n",
|
||||||
|
"\n",
|
||||||
|
"# ################################\n",
|
||||||
|
"# mel_spec = inference._augment_spec(mel_spec=mel_spec)\n",
|
||||||
|
"# print(mel_spec.shape)\n",
|
||||||
|
"# mel_spec = mel_spec.unsqueeze(0)\n",
|
||||||
|
"# print(mel_spec.shape)\n",
|
||||||
|
"# # 创建可视化\n",
|
||||||
|
"# fig, axes = plt.subplots(2, 1, figsize=(18, 12))\n",
|
||||||
|
"\n",
|
||||||
|
"# # 原始波形\n",
|
||||||
|
"# axes[0].plot(orginal_waveform[0].numpy())\n",
|
||||||
|
"# axes[0].set_title(\"Original waveform\")\n",
|
||||||
|
"# axes[0].set_xlabel(\"sample\")\n",
|
||||||
|
"# axes[0].set_ylabel(\"Amplitude\")\n",
|
||||||
|
"# axes[0].grid(True)\n",
|
||||||
|
"\n",
|
||||||
|
"# # Mel 频谱图\n",
|
||||||
|
"# im = axes[1].imshow(\n",
|
||||||
|
"# mel_spec[0].numpy(),\n",
|
||||||
|
"# cmap='gray',\n",
|
||||||
|
"# aspect='auto',\n",
|
||||||
|
"# origin='lower',\n",
|
||||||
|
"# # vmin=-150,\n",
|
||||||
|
"# # vmax=1,\n",
|
||||||
|
"# extent=[0, orginal_waveform.shape[-1] / inference.sample_rate, 0, 80]\n",
|
||||||
|
"# )\n",
|
||||||
|
"# axes[1].set_title(\"Log Mel Spectrogram\")\n",
|
||||||
|
"# axes[1].set_xlabel(\"Time (second)\")\n",
|
||||||
|
"# axes[1].set_ylabel(\"Mels\")\n",
|
||||||
|
"# # plt.colorbar(im, ax=axes[1], format='%+2.0f dB')\n",
|
||||||
|
"\n",
|
||||||
|
"# plt.tight_layout()\n",
|
||||||
|
"# plt.show()\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"# print(f\"Mel 频谱图形状: {mel_spec.shape}\") # (channel, n_mels, time_frames)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "study-asr (3.11.12)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.12"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
||||||
123
src/inference.py
Normal file
123
src/inference.py
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
from torch import Tensor, no_grad, device
|
||||||
|
from torchaudio.transforms import Resample
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TypedDict
|
||||||
|
from silero_vad import load_silero_vad, get_speech_timestamps
|
||||||
|
from tokenizer import ASRTokenizer
|
||||||
|
from model import ASRModel
|
||||||
|
|
||||||
|
CONFIG = {
|
||||||
|
# 模型配置
|
||||||
|
'input_dim': 256,
|
||||||
|
'num_heads': 8,
|
||||||
|
'ffn_dim': 2048,
|
||||||
|
'num_layers': 8,
|
||||||
|
'dropout': 0.1,
|
||||||
|
}
|
||||||
|
|
||||||
|
class TimestampsType(TypedDict):
|
||||||
|
start: int
|
||||||
|
end: int
|
||||||
|
|
||||||
|
class ASRInference:
|
||||||
|
def __init__(self, model_path: Path, vocab_path: Path, device: device, sample_rate: int = 16000) -> None:
|
||||||
|
self.device = device
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
self.tokenizer = ASRTokenizer(vocab_path=vocab_path)
|
||||||
|
self.model = ASRModel(vocab_size=self.tokenizer.vocab_size(), **CONFIG).to(device=device)
|
||||||
|
self.model.load_state_dict(torch.load(model_path, map_location=device)['model_state_dict'])
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
|
print(f"params params: {self.model.get_num_params():,}",)
|
||||||
|
|
||||||
|
def _load_audio(self, audio_path: Path) -> Tensor:
|
||||||
|
waveform, sample_rate = torchaudio.load_with_torchcodec(audio_path)
|
||||||
|
|
||||||
|
if sample_rate != self.sample_rate:
|
||||||
|
waveform = Resample(sample_rate, self.sample_rate)(waveform)
|
||||||
|
|
||||||
|
if waveform.shape[0] > 1:
|
||||||
|
waveform = waveform.mean(dim=0, keepdim=True)
|
||||||
|
|
||||||
|
waveform = waveform / (waveform.abs().max() + 1e-8)
|
||||||
|
return waveform
|
||||||
|
|
||||||
|
def transcribe(self, waveform: Tensor) -> str:
|
||||||
|
waveform = waveform.to(device=self.device) # [1, time]
|
||||||
|
waveform_length = torch.tensor([waveform.shape[1]], dtype=torch.long, device=self.device)
|
||||||
|
|
||||||
|
with no_grad():
|
||||||
|
log_probs, _ = self.model(waveforms=waveform, waveform_lengths=waveform_length) # [1, T, vocab]
|
||||||
|
|
||||||
|
text = self.tokenizer.ctc_greedy_decode(log_probs=log_probs[0])
|
||||||
|
return text
|
||||||
|
|
||||||
|
def transcribe_batch(self, audio_paths: list[Path]) -> list[str]:
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for audio_path in audio_paths:
|
||||||
|
text = self.transcribe(audio_path=audio_path)
|
||||||
|
results.append(text)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def create(self, audio_path: Path, threshold_ms: int = 20000):
|
||||||
|
def merge_timestamps(timestamps: int, threshold_ms: int, sample_rate: int):
|
||||||
|
if not timestamps:
|
||||||
|
return []
|
||||||
|
threshold_samples = (threshold_ms / 1000) * sample_rate
|
||||||
|
merged = []
|
||||||
|
curr_start = timestamps[0]['start']
|
||||||
|
curr_end = timestamps[0]['end']
|
||||||
|
|
||||||
|
for i in range(1, len(timestamps)):
|
||||||
|
# 如果当前积攒的长度不到 2 秒,就一直合并到当前的 end 上
|
||||||
|
if (curr_end - curr_start) < threshold_samples:
|
||||||
|
curr_end = timestamps[i]['end']
|
||||||
|
else:
|
||||||
|
merged.append({'start': curr_start, 'end': curr_end})
|
||||||
|
curr_start = timestamps[i]['start']
|
||||||
|
curr_end = timestamps[i]['end']
|
||||||
|
merged.append({'start': curr_start, 'end': curr_end})
|
||||||
|
return merged
|
||||||
|
|
||||||
|
vad_model = load_silero_vad(onnx=False)
|
||||||
|
waveform = self._load_audio(audio_path=audio_path)
|
||||||
|
speech_timestamps: list[TimestampsType] = get_speech_timestamps(
|
||||||
|
waveform,
|
||||||
|
vad_model,
|
||||||
|
sampling_rate=self.sample_rate,
|
||||||
|
threshold=0.5, # 可以根据需要调整阈值
|
||||||
|
min_speech_duration_ms=100, # 最小语音持续时间,防止短噪音被误判
|
||||||
|
min_silence_duration_ms=200, # 最小静音间隔,用于分割语音块
|
||||||
|
speech_pad_ms=100, # 在语音块前后添加的填充时间
|
||||||
|
)
|
||||||
|
final_timestamps = merge_timestamps(timestamps=speech_timestamps, threshold_ms=threshold_ms, sample_rate=self.sample_rate)
|
||||||
|
for ts in final_timestamps:
|
||||||
|
ts: TimestampsType
|
||||||
|
start_sample = int(ts['start'])
|
||||||
|
end_sample = int(ts['end'])
|
||||||
|
|
||||||
|
segment_waveform = waveform[:, start_sample:end_sample]
|
||||||
|
|
||||||
|
segment_text = self.transcribe(waveform=segment_waveform)
|
||||||
|
print(segment_text, end=" ", flush=True)
|
||||||
|
print('\n')
|
||||||
|
|
||||||
|
def main():
|
||||||
|
workspace_dir = Path(__file__).parent.parent
|
||||||
|
device = torch.device('cuda:1')
|
||||||
|
|
||||||
|
checkpoint = workspace_dir / ".checkpoints/best_wer_model.pt"
|
||||||
|
inference = ASRInference(model_path=checkpoint, vocab_path=workspace_dir / 'config/asr_vocab.json' , device=device)
|
||||||
|
|
||||||
|
audio_path = workspace_dir / 'data/test/F001_001.wav'
|
||||||
|
waveform = inference._load_audio(audio_path=audio_path)
|
||||||
|
text = inference.transcribe(waveform=waveform)
|
||||||
|
print("transcribe:", text)
|
||||||
|
inference.create(audio_path=audio_path)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
84
src/model.py
Normal file
84
src/model.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
from torch.nn import Module, Linear, functional, GELU, Conv1d, GroupNorm, ModuleList
|
||||||
|
from torch import Tensor
|
||||||
|
from torchaudio.models import Conformer
|
||||||
|
|
||||||
|
class WaveformFilter(Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.filters = ModuleList([
|
||||||
|
Conv1d(in_channels=1, out_channels=512, kernel_size=10, stride=5),
|
||||||
|
GroupNorm(num_channels=512, num_groups=32),
|
||||||
|
GELU(),
|
||||||
|
Conv1d(in_channels=512, out_channels=512, kernel_size=3, stride=2),
|
||||||
|
GroupNorm(num_channels=512, num_groups=32),
|
||||||
|
GELU(),
|
||||||
|
Conv1d(in_channels=512, out_channels=512, kernel_size=3, stride=2),
|
||||||
|
GroupNorm(num_channels=512, num_groups=32),
|
||||||
|
GELU(),
|
||||||
|
Conv1d(in_channels=512, out_channels=512, kernel_size=3, stride=2),
|
||||||
|
GroupNorm(num_channels=512, num_groups=32),
|
||||||
|
GELU(),
|
||||||
|
Conv1d(in_channels=512, out_channels=512, kernel_size=3, stride=2),
|
||||||
|
GroupNorm(num_channels=512, num_groups=32),
|
||||||
|
GELU(),
|
||||||
|
Conv1d(in_channels=512, out_channels=512, kernel_size=2, stride=2),
|
||||||
|
GroupNorm(num_channels=512, num_groups=32),
|
||||||
|
GELU(),
|
||||||
|
Conv1d(in_channels=512, out_channels=512, kernel_size=2, stride=2),
|
||||||
|
GroupNorm(num_channels=512, num_groups=32),
|
||||||
|
GELU(),
|
||||||
|
])
|
||||||
|
|
||||||
|
def compute_lengths(self, waveform_lengths: Tensor) -> Tensor:
|
||||||
|
"""Accurately compute output lengths after all convolutions."""
|
||||||
|
lengths = waveform_lengths
|
||||||
|
for module in self.filters:
|
||||||
|
if isinstance(module, Conv1d):
|
||||||
|
# Conv1d with padding=0, dilation=1:
|
||||||
|
# output = floor((input - kernel_size) / stride) + 1
|
||||||
|
lengths = (lengths - module.kernel_size[0]) // module.stride[0] + 1
|
||||||
|
return lengths
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
for filter in self.filters:
|
||||||
|
x = filter(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ASRModel(Module):
|
||||||
|
def __init__(self, vocab_size: int, input_dim: int = 256, num_heads: int = 8, ffn_dim: int = 2048, num_layers: int = 6, dropout: float = 0.1) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.feature_extractor = WaveformFilter()
|
||||||
|
self.proj = Linear(in_features=512, out_features=input_dim, bias=False)
|
||||||
|
|
||||||
|
self.encoder = Conformer(input_dim=input_dim, num_heads=num_heads, ffn_dim=ffn_dim, num_layers=num_layers, depthwise_conv_kernel_size=31, dropout=dropout)
|
||||||
|
|
||||||
|
self.ctc_head = Linear(in_features=input_dim, out_features=vocab_size, bias=False)
|
||||||
|
|
||||||
|
def forward(self, waveforms: Tensor, waveform_lengths: Tensor) -> tuple[Tensor, Tensor]:
|
||||||
|
assert len(waveform_lengths.shape) == 1, "The waveform_lengths tensor must be shape of [B] tensor."
|
||||||
|
assert len(waveforms.shape) == 2, "The waveform tensor must be [B, time] tensor"
|
||||||
|
# waveforms: [B, time]
|
||||||
|
x: Tensor = waveforms.unsqueeze(1) # [B, 1, time]
|
||||||
|
#x: [B, 1, time]
|
||||||
|
x: Tensor = self.feature_extractor(x)
|
||||||
|
#x: [B, 512, time]
|
||||||
|
x = x.permute(0, 2, 1)
|
||||||
|
#x: [B, time, 512]
|
||||||
|
x = self.proj(x)
|
||||||
|
#x: [B, time, 256]
|
||||||
|
# lengths = torch.tensor([time] * batch, dtype=torch.long, device=x.device)
|
||||||
|
lengths = self.feature_extractor.compute_lengths(waveform_lengths)
|
||||||
|
|
||||||
|
x, lengths = self.encoder(x, lengths)
|
||||||
|
|
||||||
|
logits = self.ctc_head(x) # [B, time, vocab_size]
|
||||||
|
log_probs = functional.log_softmax(logits, dim=-1)
|
||||||
|
|
||||||
|
return log_probs, lengths
|
||||||
|
|
||||||
|
def get_num_params(self) -> int:
|
||||||
|
return sum(p.numel() for p in self.parameters())
|
||||||
107
src/tokenizer.py
Normal file
107
src/tokenizer.py
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
import re
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from handle.text_handle import process_text
|
||||||
|
from handle.text_normalizer import UYGHUR_LETTERS
|
||||||
|
|
||||||
|
class ASRTokenizer:
|
||||||
|
def __init__(self, vocab_path: Path) -> None:
|
||||||
|
self.vocabs: list[str] = []
|
||||||
|
self.vocabs.extend([
|
||||||
|
"<BLANK>",
|
||||||
|
"<UNK>",
|
||||||
|
"<SPACE>"
|
||||||
|
])
|
||||||
|
for i in range(len(self.vocabs), 64):
|
||||||
|
self.vocabs.append(f"REVERSED_{i}")
|
||||||
|
|
||||||
|
with open(vocab_path, 'r', encoding='utf-8') as file:
|
||||||
|
data: list[str] = json.load(file)
|
||||||
|
|
||||||
|
for char in data:
|
||||||
|
if char not in [' ', '\t', '\n'] and char not in self.vocabs:
|
||||||
|
self.vocabs.append(char)
|
||||||
|
|
||||||
|
self.vocab_to_id: dict[str, int] = {vocab:i for i, vocab in enumerate(self.vocabs)}
|
||||||
|
self.id_to_vocab: dict[int, str] = {i:vocab for i, vocab in enumerate(self.vocabs)}
|
||||||
|
|
||||||
|
self._max_token_len = max((len(v) for v in self.vocabs if not v.startswith("<")), default=0)
|
||||||
|
|
||||||
|
def vocab_size(self):
|
||||||
|
return len(self.vocabs)
|
||||||
|
|
||||||
|
def _split_long_syllable(self, syll: str) -> list[str]:
|
||||||
|
"""将不在词表中的长音节拆分为词表内的子词序列(最大正向匹配)"""
|
||||||
|
pieces = []
|
||||||
|
start = 0
|
||||||
|
n = len(syll)
|
||||||
|
while start < n:
|
||||||
|
# 从最大可能长度开始尝试匹配
|
||||||
|
for length in range(min(self._max_token_len, n - start), 0, -1):
|
||||||
|
sub = syll[start:start + length]
|
||||||
|
if sub in self.vocab_to_id:
|
||||||
|
pieces.append(sub)
|
||||||
|
start += length
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# 理论上不会发生,因为单字母一定在词表
|
||||||
|
pieces.append("<UNK>")
|
||||||
|
start += 1
|
||||||
|
return pieces
|
||||||
|
|
||||||
|
def encode(self, text: str) -> list[int]:
|
||||||
|
tokens = process_text(text=text)
|
||||||
|
result: list[int] = []
|
||||||
|
|
||||||
|
for token in tokens:
|
||||||
|
if token == ' ':
|
||||||
|
result.append(self.vocab_to_id['<SPACE>'])
|
||||||
|
else:
|
||||||
|
token_id = self.vocab_to_id.get(token)
|
||||||
|
if token_id is not None:
|
||||||
|
result.append(token_id)
|
||||||
|
elif token in UYGHUR_LETTERS:
|
||||||
|
result.append(self.vocab_to_id.get(token, self.vocab_to_id['<UNK>']))
|
||||||
|
else:
|
||||||
|
sub_pieces = self._split_long_syllable(token)
|
||||||
|
for piece in sub_pieces:
|
||||||
|
result.append(self.vocab_to_id.get(piece, self.vocab_to_id['<UNK>']))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def decode(self, ids: list[int] | Tensor, remove_blank: bool = True, remove_repeat: bool = True) -> str:
|
||||||
|
if isinstance(ids, Tensor):
|
||||||
|
ids = ids.tolist()
|
||||||
|
|
||||||
|
result = []
|
||||||
|
prev_id = None
|
||||||
|
|
||||||
|
for token_id in ids:
|
||||||
|
# 跳过blank token
|
||||||
|
if remove_blank and token_id == self.vocab_to_id['<BLANK>']:
|
||||||
|
prev_id = None
|
||||||
|
continue
|
||||||
|
|
||||||
|
# CTC解码:跳过连续重复的字符
|
||||||
|
if remove_repeat and token_id == prev_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
char = self.id_to_vocab.get(token_id, "<UNK>")
|
||||||
|
|
||||||
|
if char == '<SPACE>':
|
||||||
|
char = ' '
|
||||||
|
|
||||||
|
result.append(char)
|
||||||
|
prev_id = token_id
|
||||||
|
|
||||||
|
return ''.join(result)
|
||||||
|
|
||||||
|
def get_special_token_id(self, token: str) -> int:
|
||||||
|
return self.vocab_to_id.get(token, self.vocab_to_id['<UNK>'])
|
||||||
|
|
||||||
|
def ctc_greedy_decode(self, log_probs: Tensor) -> str:
|
||||||
|
""" CTC贪心解码 (log_probs: [seq_len, vocab_size] 的log概率) """
|
||||||
|
return self.decode(torch.argmax(log_probs, dim=-1))
|
||||||
367
src/train.py
Normal file
367
src/train.py
Normal file
@@ -0,0 +1,367 @@
|
|||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor, device, no_grad, cuda
|
||||||
|
from torch.optim import AdamW, Optimizer
|
||||||
|
from torch.optim.lr_scheduler import OneCycleLR
|
||||||
|
from torch.nn import CTCLoss, utils
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
from pathlib import Path
|
||||||
|
from tqdm import tqdm
|
||||||
|
from datetime import datetime
|
||||||
|
from torch.amp.grad_scaler import GradScaler
|
||||||
|
from tokenizer import ASRTokenizer
|
||||||
|
from dataset import Batch, create_dataloader
|
||||||
|
from model import ASRModel
|
||||||
|
|
||||||
|
|
||||||
|
# ============ 全局配置 ============
|
||||||
|
CONFIG = {
|
||||||
|
# 数据配置
|
||||||
|
'batch_size': 32,
|
||||||
|
|
||||||
|
# 训练配置
|
||||||
|
'num_epochs': 50,
|
||||||
|
'learning_rate': 2e-4,
|
||||||
|
'weight_decay': 1e-4,
|
||||||
|
'grad_clip_norm': 1.0,
|
||||||
|
|
||||||
|
# 模型配置
|
||||||
|
'input_dim': 256,
|
||||||
|
'num_heads': 8,
|
||||||
|
'ffn_dim': 2048,
|
||||||
|
'num_layers': 8,
|
||||||
|
'dropout': 0.1,
|
||||||
|
|
||||||
|
# 保存和评估
|
||||||
|
'early_stopping_patience': 12,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_cer(pred: str, target: str) -> float:
|
||||||
|
""" 计算字符错误率 (Character Error Rate) 使用编辑距离算法"""
|
||||||
|
if len(target) == 0:
|
||||||
|
return 0.0 if len(pred) == 0 else 1.0
|
||||||
|
|
||||||
|
d = [[0] * (len(target) + 1) for _ in range(len(pred) + 1)]
|
||||||
|
|
||||||
|
for i in range(len(pred) + 1):
|
||||||
|
d[i][0] = i
|
||||||
|
for j in range(len(target) + 1):
|
||||||
|
d[0][j] = j
|
||||||
|
|
||||||
|
for i in range(1, len(pred) + 1):
|
||||||
|
for j in range(1, len(target) + 1):
|
||||||
|
if pred[i-1] == target[j-1]:
|
||||||
|
d[i][j] = d[i-1][j-1]
|
||||||
|
else:
|
||||||
|
d[i][j] = min(d[i-1][j], d[i][j-1], d[i-1][j-1]) + 1
|
||||||
|
|
||||||
|
return d[len(pred)][len(target)] / len(target)
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_wer(pred: str, target: str) -> float:
|
||||||
|
""" 计算词错误率 (Word Error Rate) """
|
||||||
|
pred_words = pred.split()
|
||||||
|
target_words = target.split()
|
||||||
|
|
||||||
|
if len(target_words) == 0:
|
||||||
|
return 0.0 if len(pred_words) == 0 else 1.0
|
||||||
|
|
||||||
|
d = [[0] * (len(target_words) + 1) for _ in range(len(pred_words) + 1)]
|
||||||
|
|
||||||
|
for i in range(len(pred_words) + 1):
|
||||||
|
d[i][0] = i
|
||||||
|
for j in range(len(target_words) + 1):
|
||||||
|
d[0][j] = j
|
||||||
|
|
||||||
|
for i in range(1, len(pred_words) + 1):
|
||||||
|
for j in range(1, len(target_words) + 1):
|
||||||
|
if pred_words[i-1] == target_words[j-1]:
|
||||||
|
d[i][j] = d[i-1][j-1]
|
||||||
|
else:
|
||||||
|
d[i][j] = min(d[i-1][j], d[i][j-1], d[i-1][j-1]) + 1
|
||||||
|
|
||||||
|
return d[len(pred_words)][len(target_words)] / len(target_words)
|
||||||
|
|
||||||
|
def train_one_epoch(model: ASRModel, dataloader: DataLoader, criterion: CTCLoss, optimizer: Optimizer, scheduler: OneCycleLR, device: device, epoch: int, writer: SummaryWriter, global_step: int, scaler: GradScaler) -> tuple[float, int]:
|
||||||
|
model.train()
|
||||||
|
num_batches = 0
|
||||||
|
total_train_loss = 0
|
||||||
|
|
||||||
|
progress_bar = tqdm(dataloader, desc=f"Epoch {epoch}")
|
||||||
|
for batch_index, batch in enumerate(progress_bar):
|
||||||
|
batch: Batch
|
||||||
|
waveforms = batch['waveforms'].to(device)
|
||||||
|
targets = batch['targets'].to(device)
|
||||||
|
waveform_lengths = batch['waveform_lengths'].to(device)
|
||||||
|
target_lengths = batch['target_lengths'].to(device)
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
|
||||||
|
log_probs, lengths = model(waveforms=waveforms, waveform_lengths=waveform_lengths)
|
||||||
|
log_probs: Tensor
|
||||||
|
log_probs_ctc = log_probs.permute(1, 0, 2)
|
||||||
|
loss: Tensor = criterion(log_probs=log_probs_ctc, targets=targets, input_lengths=lengths, target_lengths=target_lengths)
|
||||||
|
|
||||||
|
scaler.scale(loss).backward()
|
||||||
|
scaler.unscale_(optimizer)
|
||||||
|
utils.clip_grad_norm_(model.parameters(), max_norm=CONFIG['grad_clip_norm'])
|
||||||
|
scaler.step(optimizer)
|
||||||
|
scaler.update()
|
||||||
|
scheduler.step()
|
||||||
|
|
||||||
|
total_train_loss += loss.item()
|
||||||
|
num_batches += 1
|
||||||
|
global_step += 1
|
||||||
|
|
||||||
|
writer.add_scalar('Train/Loss', loss.item(), global_step)
|
||||||
|
writer.add_scalar('Train/LearningRate', optimizer.param_groups[0]['lr'], global_step)
|
||||||
|
progress_bar.set_postfix({'epoch': f"{epoch}/{CONFIG['num_epochs']}",'loss': f'{loss.item():.4f}', 'step': global_step})
|
||||||
|
|
||||||
|
train_avg_loss = total_train_loss / num_batches
|
||||||
|
return train_avg_loss, global_step
|
||||||
|
|
||||||
|
def validate(model: ASRModel, dataloader: DataLoader, criterion: CTCLoss, device: device, tokenizer: ASRTokenizer, writer: SummaryWriter, global_step: int) -> tuple[float, float, float]:
|
||||||
|
model.eval()
|
||||||
|
total_loss = 0
|
||||||
|
total_cer = 0
|
||||||
|
total_wer = 0
|
||||||
|
num_samples = 0
|
||||||
|
num_batches = 0
|
||||||
|
examples = []
|
||||||
|
|
||||||
|
with no_grad():
|
||||||
|
progress_bar = tqdm(dataloader, desc="Validate", leave=False)
|
||||||
|
for batch in progress_bar:
|
||||||
|
batch: Batch
|
||||||
|
waveforms = batch['waveforms'].to(device)
|
||||||
|
targets = batch['targets'].to(device)
|
||||||
|
waveform_lengths = batch['waveform_lengths'].to(device)
|
||||||
|
target_lengths = batch['target_lengths'].to(device)
|
||||||
|
|
||||||
|
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
|
||||||
|
log_probs, lengths = model(waveforms=waveforms, waveform_lengths=waveform_lengths)
|
||||||
|
log_probs: Tensor
|
||||||
|
log_probs_ctc = log_probs.permute(1, 0, 2)
|
||||||
|
loss: Tensor = criterion(log_probs=log_probs_ctc, targets=targets, input_lengths=lengths, target_lengths=target_lengths)
|
||||||
|
|
||||||
|
total_loss += loss.item()
|
||||||
|
|
||||||
|
for i in range(log_probs.shape[0]):
|
||||||
|
pred_text = tokenizer.ctc_greedy_decode(log_probs=log_probs[i])
|
||||||
|
true_text = batch['target_texts'][i]
|
||||||
|
cer = calculate_cer(pred=pred_text, target=true_text)
|
||||||
|
wer = calculate_wer(pred=pred_text, target=true_text)
|
||||||
|
total_cer += cer
|
||||||
|
total_wer += wer
|
||||||
|
num_samples += 1
|
||||||
|
|
||||||
|
if len(examples) < 3:
|
||||||
|
examples.append((true_text, pred_text, cer, wer))
|
||||||
|
|
||||||
|
num_batches += 1
|
||||||
|
|
||||||
|
avg_loss = total_loss / num_batches
|
||||||
|
avg_cer = total_cer / num_samples
|
||||||
|
avg_wer = total_wer / num_samples
|
||||||
|
|
||||||
|
writer.add_scalar('Val/Loss', avg_loss, global_step)
|
||||||
|
writer.add_scalar('Val/CER', avg_cer, global_step)
|
||||||
|
writer.add_scalar('Val/WER', avg_wer, global_step)
|
||||||
|
|
||||||
|
for index, (true_text, pred_text, cer, wer) in enumerate(examples):
|
||||||
|
writer.add_text(f'Val/Example_{index}', f'True: {true_text}\nPred: {pred_text}\nCER: {cer:.4f} | WER: {wer:.4f}', global_step)
|
||||||
|
|
||||||
|
model.train()
|
||||||
|
return avg_loss, avg_cer, avg_wer
|
||||||
|
|
||||||
|
def save_checkpoint(model: ASRModel, optimizer: Optimizer, scheduler: OneCycleLR, global_step: int, epoch: int, train_loss: float, val_loss: float, cer: float, wer: float, save_path: Path):
|
||||||
|
checkpoint = {
|
||||||
|
'model_state_dict': model.state_dict(),
|
||||||
|
'optimizer_state_dict': optimizer.state_dict(),
|
||||||
|
'scheduler_state_dict': scheduler.state_dict(),
|
||||||
|
'global_step': global_step,
|
||||||
|
'train_loss': train_loss,
|
||||||
|
'val_loss': val_loss,
|
||||||
|
'epoch': epoch,
|
||||||
|
'cer': cer,
|
||||||
|
'wer': wer,
|
||||||
|
}
|
||||||
|
torch.save(checkpoint, save_path)
|
||||||
|
|
||||||
|
def find_latest_checkpoint(checkpoint_dir: Path) -> Path | None:
|
||||||
|
checkpoints = sorted(checkpoint_dir.glob('checkpoint_epoch_*.pt'), key=lambda p: int(p.stem.split('_')[-1]))
|
||||||
|
return checkpoints[-1] if checkpoints else None
|
||||||
|
|
||||||
|
def load_checkpoint(file_path: Path, model: ASRModel, optimizer: AdamW, scheduler: OneCycleLR):
|
||||||
|
checkpoint = torch.load(file_path, weights_only=False)
|
||||||
|
model.load_state_dict(checkpoint['model_state_dict'])
|
||||||
|
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||||
|
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
||||||
|
|
||||||
|
epoch = checkpoint['epoch']
|
||||||
|
global_step = checkpoint['global_step']
|
||||||
|
best_cer = checkpoint['cer']
|
||||||
|
best_wer = checkpoint['wer']
|
||||||
|
return epoch, global_step, best_cer, best_wer
|
||||||
|
|
||||||
|
def main():
|
||||||
|
workspace_dir = Path(__file__).parent.parent
|
||||||
|
device = torch.device('cuda:0')
|
||||||
|
tokenizer = ASRTokenizer(workspace_dir / 'config/asr_vocab.json')
|
||||||
|
final_prob = 0.5
|
||||||
|
warmup_epochs = 8
|
||||||
|
current_prob = 0.0
|
||||||
|
|
||||||
|
# ============ 创建数据加载器 ============
|
||||||
|
train_loader = create_dataloader(
|
||||||
|
tsv_path=workspace_dir / '.data/ug/train_new.tsv',
|
||||||
|
audio_dir=workspace_dir / '.data/ug/clips',
|
||||||
|
noise_dir='/mnt/dataset/dataset/audio/noise',
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
batch_size=CONFIG['batch_size'],
|
||||||
|
shuffle=True,
|
||||||
|
augment=True,
|
||||||
|
augment_prob=current_prob
|
||||||
|
)
|
||||||
|
|
||||||
|
val_loader = create_dataloader(
|
||||||
|
tsv_path=workspace_dir / '.data/ug/val_new.tsv',
|
||||||
|
audio_dir=workspace_dir / '.data/ug/clips',
|
||||||
|
noise_dir='/mnt/dataset/dataset/audio/noise',
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
batch_size=CONFIG['batch_size'],
|
||||||
|
shuffle=False,
|
||||||
|
augment=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# ============ 初始化模型 ============
|
||||||
|
model = ASRModel(
|
||||||
|
vocab_size=tokenizer.vocab_size(),
|
||||||
|
input_dim=CONFIG['input_dim'],
|
||||||
|
num_heads=CONFIG['num_heads'],
|
||||||
|
ffn_dim=CONFIG['ffn_dim'],
|
||||||
|
num_layers=CONFIG['num_layers'],
|
||||||
|
dropout=CONFIG['dropout'],
|
||||||
|
).to(device)
|
||||||
|
print(f"🤖 模型参数量: {model.get_num_params() / 1e6:.2f}M")
|
||||||
|
|
||||||
|
|
||||||
|
criterion = CTCLoss(blank=tokenizer.get_special_token_id('<BLANK>'), zero_infinity=True)
|
||||||
|
optimizer = AdamW(model.parameters(), lr=CONFIG['learning_rate'], weight_decay=CONFIG['weight_decay'], foreach=True)
|
||||||
|
|
||||||
|
# 使用 OneCycleLR:自动处理 warmup 和衰减
|
||||||
|
scheduler = OneCycleLR(
|
||||||
|
optimizer,
|
||||||
|
max_lr=CONFIG['learning_rate'],
|
||||||
|
epochs=CONFIG['num_epochs'],
|
||||||
|
steps_per_epoch=len(train_loader),
|
||||||
|
pct_start=0.15,
|
||||||
|
anneal_strategy='cos',
|
||||||
|
div_factor=10.0, # 初始 lr = max_lr / 10
|
||||||
|
final_div_factor=1e4, # 最终 lr = max_lr / 10000
|
||||||
|
)
|
||||||
|
scaler = GradScaler()
|
||||||
|
|
||||||
|
# ============ TensorBoard ============
|
||||||
|
log_dir = workspace_dir / 'runs' / datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||||
|
writer = SummaryWriter(log_dir)
|
||||||
|
|
||||||
|
config_text = '\n'.join([f'{k}: {v}' for k, v in CONFIG.items()])
|
||||||
|
writer.add_text('Config', config_text, 0)
|
||||||
|
print(f"📊 TensorBoard 日志: {log_dir}")
|
||||||
|
|
||||||
|
checkpoint_dir = workspace_dir / '.checkpoints'
|
||||||
|
checkpoint_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
# ============ 训练循环 ============
|
||||||
|
best_cer = float('inf')
|
||||||
|
best_wer = float('inf')
|
||||||
|
patience_counter = 0
|
||||||
|
global_step = 0
|
||||||
|
start_epoch = 0
|
||||||
|
|
||||||
|
latest = find_latest_checkpoint(checkpoint_dir)
|
||||||
|
if latest:
|
||||||
|
start_epoch, global_step, best_cer, best_wer = load_checkpoint(latest, model, optimizer, scheduler)
|
||||||
|
start_epoch += 1 # 从下一个 epoch 开始
|
||||||
|
print(f"✅ 恢复训练: 从 epoch={start_epoch} 开始, global_step={global_step}")
|
||||||
|
else:
|
||||||
|
start_epoch = 0
|
||||||
|
print("🆕 从头开始训练\n")
|
||||||
|
|
||||||
|
for epoch in range(start_epoch, CONFIG['num_epochs']):
|
||||||
|
current_prob = final_prob * (1 - math.cos(math.pi * epoch / warmup_epochs)) / 2
|
||||||
|
train_loss, global_step = train_one_epoch(
|
||||||
|
model=model,
|
||||||
|
dataloader=train_loader,
|
||||||
|
criterion=criterion,
|
||||||
|
optimizer=optimizer,
|
||||||
|
scheduler=scheduler,
|
||||||
|
device=device,
|
||||||
|
epoch=epoch,
|
||||||
|
writer=writer,
|
||||||
|
global_step=global_step,
|
||||||
|
scaler=scaler,
|
||||||
|
)
|
||||||
|
|
||||||
|
val_loss, val_cer, val_wer = validate(model=model, dataloader=val_loader, criterion=criterion, device=device, tokenizer=tokenizer, writer=writer, global_step=global_step)
|
||||||
|
print(f"\n📊 Step {global_step} | Val Loss: {val_loss:.4f} | Val CER: {val_cer:.4f} | Val WER: {val_wer:.4f}")
|
||||||
|
|
||||||
|
# 保存常规 checkpoint
|
||||||
|
checkpoint_path = checkpoint_dir / f'checkpoint_epoch_{epoch}.pt'
|
||||||
|
save_checkpoint(model=model, optimizer=optimizer, scheduler=scheduler, epoch=epoch, global_step=global_step, train_loss=train_loss, val_loss=val_loss, cer=val_cer, wer=val_wer, save_path=checkpoint_path)
|
||||||
|
|
||||||
|
# 分别保存最佳 CER 和 WER 模型
|
||||||
|
improved = False
|
||||||
|
if val_cer < best_cer:
|
||||||
|
best_cer = val_cer
|
||||||
|
best_cer_path = checkpoint_dir / 'best_cer_model.pt'
|
||||||
|
save_checkpoint(model=model, optimizer=optimizer, scheduler=scheduler, epoch=epoch, global_step=global_step, train_loss=train_loss, val_loss=val_loss, cer=val_cer, wer=val_wer, save_path=best_cer_path)
|
||||||
|
improved = True
|
||||||
|
|
||||||
|
if val_wer < best_wer:
|
||||||
|
best_wer = val_wer
|
||||||
|
best_wer_path = checkpoint_dir / 'best_wer_model.pt'
|
||||||
|
save_checkpoint(model=model, optimizer=optimizer, scheduler=scheduler, epoch=epoch, global_step=global_step, train_loss=train_loss, val_loss=val_loss, cer=val_cer, wer=val_wer, save_path=best_wer_path)
|
||||||
|
improved = True
|
||||||
|
|
||||||
|
# Early Stopping 逻辑
|
||||||
|
if improved:
|
||||||
|
patience_counter = 0
|
||||||
|
else:
|
||||||
|
patience_counter += 1
|
||||||
|
print(f"⚠️ 验证指标未改善,patience: {patience_counter}/{CONFIG['early_stopping_patience']}")
|
||||||
|
|
||||||
|
# 删除旧的 checkpoint(保留最近3个)
|
||||||
|
old_checkpoints = sorted(checkpoint_dir.glob('checkpoint_epoch_*.pt'), key=lambda p: int(p.stem.split('_')[-1]))
|
||||||
|
for old in old_checkpoints[:-3]:
|
||||||
|
old.unlink()
|
||||||
|
|
||||||
|
cuda.empty_cache()
|
||||||
|
|
||||||
|
writer.add_scalar('Val/EpochLoss', val_loss, epoch)
|
||||||
|
writer.add_scalar('Train/EpochLoss', train_loss, epoch)
|
||||||
|
print(f"✅ Epoch {epoch} 完成 | Train Avg Loss: {train_loss:.4f} | Val Avg Loss: {val_loss:.4f} | Best CER: {best_cer:.4f} | Best WER: {best_wer:.4f}\n")
|
||||||
|
|
||||||
|
# Early Stopping 检查
|
||||||
|
if patience_counter >= CONFIG['early_stopping_patience']:
|
||||||
|
print(f"\n🛑 Early Stopping: 验证指标连续 {CONFIG['early_stopping_patience']} 次没有改善")
|
||||||
|
print(f"🏆 最佳 CER: {best_cer:.4f}")
|
||||||
|
print(f"🏆 最佳 WER: {best_wer:.4f}")
|
||||||
|
break
|
||||||
|
|
||||||
|
writer.close()
|
||||||
|
print(f"\n{'='*60}")
|
||||||
|
print(f"🎉 训练完成!")
|
||||||
|
print(f"🏆 最佳 CER: {best_cer:.4f}")
|
||||||
|
print(f"🏆 最佳 WER: {best_wer:.4f}")
|
||||||
|
print(f"📊 TensorBoard: tensorboard --logdir={workspace_dir / 'runs'}")
|
||||||
|
print(f"{'='*60}\n")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user