feat: change waveform

This commit is contained in:
2026-05-07 11:29:21 +06:00
commit d31233a79a
21 changed files with 5330 additions and 0 deletions

17
.gitignore vendored Normal file
View File

@@ -0,0 +1,17 @@
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info
# Virtual environments
.venv
# data
.checkpoints
.data
data
runs
config

1
.python-version Normal file
View File

@@ -0,0 +1 @@
3.11

16
.vscode/launch.json vendored Normal file
View File

@@ -0,0 +1,16 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python Debugger: Current File",
"type": "debugpy",
"request": "launch",
"program": "src/train.py",
"console": "integratedTerminal"
}
]
}

0
README.md Normal file
View File

33
pyproject.toml Normal file
View File

@@ -0,0 +1,33 @@
[project]
name = "study-asr"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
"ipython>=9.10.1",
"librosa>=0.11.0",
"matplotlib>=3.10.8",
"mutagen>=1.47.0",
"numpy>=2.4.4",
"pandas>=3.0.2",
"pillow>=12.2.0",
"pydub>=0.25.1",
"pyrubberband>=0.4.0",
"setuptools<82",
"silero-vad>=6.2.1",
"tensorboard>=2.20.0",
"tensorboardx>=2.6.5",
"torch==2.8.0",
"torch-audiomentations>=0.12.0",
"torchaudio==2.8.0",
"torchcodec==0.7.0",
"tqdm>=4.67.3",
"webrtcvad>=2.0.10",
]
[[index]]
url = "https://mirrors.aliyun.com/pypi/simple/"
default = true
# url = "https://pypi.tuna.tsinghua.edu.cn/simple/"

431
src/dataset.py Normal file
View File

@@ -0,0 +1,431 @@
import random
import torch
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
import torchaudio
from torch_audiomentations import ApplyImpulseResponse, Gain, PitchShift, LowPassFilter, HighPassFilter
from torchaudio.transforms import Resample, TimeStretch
from pathlib import Path
import pandas as pd
from typing import List, TypedDict
from handle.text_normalizer import collapse_spaces, normalize_extended_uyghur_characters
from tokenizer import ASRTokenizer
# 单个样本的数据结构Dataset.__getitem__ 返回)
class BatchItem(TypedDict):
waveform: Tensor # [time]
target_ids: Tensor # [seq_len] 目标文本的token IDs
target_text: str # 原始文本
audio_path: str # 音频文件路径
# 批量数据的数据结构collate_fn 返回DataLoader 输出)
class Batch(TypedDict):
waveforms: Tensor # [batch, time]
targets: Tensor # [batch, max_len] padding后的目标IDs
waveform_lengths: Tensor # [batch] 每个样本的实际Waveform长度
target_lengths: Tensor # [batch] 每个样本的实际目标长度
target_texts: List[str] # [batch] 原始文本列表
audio_paths: List[str] # [batch] 音频路径列表
class TsvFormat(TypedDict):
client_id: str
path: str
sentence: str
up_votes: int
down_votes: int
age: str
gender: str
locale: str
class NoiseAugmentor:
def __init__(self, noise_root: Path, sample_rate: int=16000):
self.sample_rate = sample_rate
self.noise_files = list(Path(noise_root).rglob("*.wav"))
def apply_real_noise(self, waveform: Tensor):
# 1. 随机选一个噪音文件
noise_path = random.choice(self.noise_files)
noise_waveform, sr = torchaudio.load_with_torchcodec(noise_path)
# Resample to target sample rate.
if sr != self.sample_rate:
noise_waveform = Resample(sr, self.sample_rate)(noise_waveform)
# Convert to mono if it is setro.
if waveform.shape[0] > 1:
waveform = waveform.mean(dim=0, keepdim=True)
# 3. 截取或填充,使其长度与语音一致
sig_len = waveform.shape[1]
noise_len = noise_waveform.shape[1]
if noise_len >= sig_len:
# 随机截取一段
start = random.randint(0, noise_len - sig_len)
noise_waveform = noise_waveform[:, start:start + sig_len]
else:
full_noise = torch.zeros_like(waveform)
start = random.randint(0, sig_len - noise_len)
full_noise[:, start : start + noise_len] = noise_waveform
noise_waveform = full_noise
# 4. 设定随机信噪比 SNR (5dB 到 20dB)
snr_db = random.uniform(5, 20)
# 5. 混合
return self._mix_at_snr(waveform, noise_waveform, snr_db)
def _mix_at_snr(self, signal: Tensor, noise: Tensor, snr_db: float):
s_p = signal.pow(2).mean()
n_p = noise.pow(2).mean()
snr_linear = 10**(snr_db / 10)
scale = torch.sqrt(s_p / (n_p * snr_linear + 1e-8))
noisy = signal + scale * noise
# 归一化,防止溢出
return noisy / (noisy.abs().max() + 1e-8)
class CommonVoiceDataset(Dataset[BatchItem]):
def __init__(
self,
tsv_path: Path,
audio_dir: Path,
noise_dir: Path,
tokenizer: ASRTokenizer,
sample_rate: int = 16000,
max_audio_len: int = 480000, # 30秒 @ 16kHz
augment: bool = True,
augment_prob: float = 0.5,
) -> None:
super().__init__()
self.noise_augmentor = NoiseAugmentor(noise_root=noise_dir, sample_rate=sample_rate)
self.audio_dir = Path(audio_dir)
self.tokenizer = tokenizer
self.sample_rate = sample_rate
self.max_audio_len = max_audio_len
self.augment = augment
self.augment_prob = augment_prob
self.data: pd.DataFrame = pd.read_csv(tsv_path, sep='\t')
valid_indices = []
for index, row in self.data.iterrows():
audio_path: Path = self.audio_dir / row['path']
if audio_path.exists():
valid_indices.append(index)
self.data = self.data.loc[valid_indices].reset_index(drop=True)
self.gain_up = Gain(min_gain_in_db=4, max_gain_in_db=8, p=1.0, output_type='tensor')
self.gain_down = Gain(min_gain_in_db=-15, max_gain_in_db=-8, p=1.0, output_type='tensor')
self.pitch_up = PitchShift(min_transpose_semitones=1, max_transpose_semitones=4, p=1.0, sample_rate=self.sample_rate, output_type='tensor')
self.pitch_down = PitchShift(min_transpose_semitones=-4, max_transpose_semitones=-1, p=1.0, sample_rate=self.sample_rate, output_type='tensor')
self.lowpass = LowPassFilter(min_cutoff_freq=600, max_cutoff_freq=2000, p=1.0, output_type='tensor')
self.highpass = HighPassFilter(min_cutoff_freq=800, max_cutoff_freq=2000, p=1.0, output_type='tensor')
self.apply_ir = ApplyImpulseResponse(ir_paths=noise_dir,convolve_mode='same', p=1, output_type="tensor")
def __len__(self):
return len(self.data)
def _load_audio(self, audio_path: Path) -> Tensor:
waveform, sample_rate = torchaudio.load_with_torchcodec(audio_path)
# Resample to target sample rate.
if sample_rate != self.sample_rate:
waveform = Resample(sample_rate, self.sample_rate)(waveform)
# Convert to mono if it is setro.
if waveform.shape[0] > 1:
waveform = waveform.mean(dim=0, keepdim=True)
# Normalization
waveform = waveform / waveform.abs().max()
# Clip waveform exceeds from max length.
if waveform.shape[1] > self.max_audio_len:
waveform = waveform[:, :self.max_audio_len]
return waveform
def _augment_waveform(self, waveform: Tensor) -> Tensor:
if not self.augment or random.random() > self.augment_prob:
return waveform
# 1. voice Stretch/Compress
if random.random() < 0.6:
waveform = self._stretch_or_compress(waveform=waveform)
if random.random() < 0.4:
waveform = self.noise_augmentor.apply_real_noise(waveform)
if random.random() < 0.3:
waveform = self._time_mask_waveform(waveform=waveform)
# torch_audiomentations: [1, time] -> [1, 1, time]
if waveform.dim() == 2:
waveform_3d = waveform.unsqueeze(0)
# 随机选择一种频谱增强
choice = random.random()
if choice < 0.15:
# 增益变化(上或下)
if random.random() < 0.5:
waveform_3d = self.gain_up(waveform_3d, sample_rate=self.sample_rate)
else:
waveform_3d = self.gain_down(waveform_3d, sample_rate=self.sample_rate)
elif choice < 0.25:
# 音高变化(上或下)
if random.random() < 0.5:
waveform_3d = self.pitch_up(waveform_3d, sample_rate=self.sample_rate)
else:
waveform_3d = self.pitch_down(waveform_3d, sample_rate=self.sample_rate)
elif choice < 0.30:
# 低通滤波(声音发闷)
waveform_3d = self.lowpass(waveform_3d, sample_rate=self.sample_rate)
# elif choice < 0.32:
# # 低通滤波(声音发闷)
# waveform_3d = self.apply_ir(waveform_3d, sample_rate=self.sample_rate)
elif choice < 0.35:
# 高通滤波(电话效果)
waveform_3d = self.highpass(waveform_3d, sample_rate=self.sample_rate)
# [1, 1, time] -> [1, time]
waveform = waveform_3d.squeeze(0)
# 防止多次 augment 后振幅溢出,最后归一化
max_amp = waveform.abs().max()
if max_amp > 1.0:
waveform = waveform / max_amp
return waveform
def _stretch_or_compress(self, waveform: Tensor) -> Tensor:
speed_factor = random.uniform(0.85, 1.4) # (Speed Change: 0.85x - 1.4x)
spec = torch.stft(
waveform.squeeze(0),
n_fft=400,
hop_length=160,
window=torch.hann_window(400).to(waveform.device),
return_complex=True
)
# 时间拉伸(不改变音高)
stretched_spec = TimeStretch(hop_length=160, n_freq=spec.shape[-2], fixed_rate=speed_factor)(spec)
# 转回波形
waveform_stretched = torch.istft(stretched_spec, n_fft=400, hop_length=160, window=torch.hann_window(400).to(waveform.device)).unsqueeze(0)
return waveform_stretched
def _time_mask_waveform(self, waveform: Tensor) -> Tensor:
audio_len = waveform.shape[1]
sr = self.sample_rate # 16000
# 设置参数:单次遮盖最长 0.4 秒 (6400个点)
max_mask_time = 0.4
max_mask_samples = int(sr * max_mask_time)
# 根据音频长度决定遮盖次数:
# 比如每 3 秒钟允许遮盖 1 次
num_masks = max(1, audio_len // (sr * 3))
for _ in range(num_masks):
# 每次随机遮盖 0.1s 到 0.4s
current_mask_len = random.randint(int(sr * 0.1), max_mask_samples)
if audio_len > current_mask_len:
start_pos = random.randint(0, audio_len - current_mask_len)
# 填充微小噪音(模拟环境底噪)
noise = torch.randn(1, current_mask_len).to(waveform.device) * 0.002
waveform[:, start_pos : start_pos + current_mask_len] = noise
return waveform
def __getitem__(self, index) -> BatchItem:
row: TsvFormat = self.data.iloc[index]
audio_path: Path = self.audio_dir / row['path']
# text: str = unicodedata.normalize('NFC', row['sentence'].strip())
text: str = normalize_extended_uyghur_characters(collapse_spaces(row['sentence'].strip()))
waveform = self._load_audio(audio_path=audio_path)
waveform = self._augment_waveform(waveform)
waveform = waveform.squeeze(0)
return BatchItem(
waveform=waveform,
target_ids=torch.tensor(self.tokenizer.encode(text), dtype=torch.long),
target_text=text,
audio_path=str(audio_path)
)
def collate_fn(items: List[BatchItem]) -> Batch:
max_waveform_len = max(item['waveform'].shape[0] for item in items)
max_target_len = max(len(item['target_ids']) for item in items)
batch_size = len(items)
waveforms = torch.zeros(batch_size, max_waveform_len)
targets = torch.zeros(batch_size, max_target_len, dtype=torch.long)
waveform_lengths = torch.zeros(batch_size, dtype=torch.long)
target_lengths = torch.zeros(batch_size, dtype=torch.long)
target_texts = []
audio_paths = []
for i, item in enumerate(items):
waveform_len = item['waveform'].shape[0]
target_len = len(item['target_ids'])
waveforms[i, :waveform_len] = item['waveform']
targets[i, :target_len] = item['target_ids']
waveform_lengths[i] = waveform_len
target_lengths[i] = target_len
target_texts.append(item['target_text'])
audio_paths.append(item['audio_path'])
return Batch(
waveforms=waveforms,
targets=targets,
waveform_lengths=waveform_lengths,
target_lengths=target_lengths,
target_texts=target_texts,
audio_paths=audio_paths
)
def create_dataloader(tsv_path: Path, audio_dir: Path, noise_dir: Path, tokenizer: ASRTokenizer, batch_size: int = 8, shuffle: bool = True, augment: bool = True, augment_prob: int = 0.5) -> DataLoader:
dataset = CommonVoiceDataset(tsv_path=tsv_path, audio_dir=audio_dir, noise_dir=noise_dir, tokenizer=tokenizer, augment=augment, augment_prob=augment_prob)
return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn, pin_memory=True, num_workers=8, prefetch_factor=8, persistent_workers=True)
# ============ 测试代码 ============
if __name__ == "__main__":
from pathlib import Path
workspace_dir = Path(__file__).parent.parent
# 初始化tokenizer
tokenizer = ASRTokenizer(workspace_dir / 'config' / 'asr_vocab.json')
# 创建数据加载器
dataloader = create_dataloader(
tsv_path=workspace_dir / 'data' / 'ug' / 'train.tsv',
audio_dir=workspace_dir / 'data' / 'ug' / 'clips',
tokenizer=tokenizer,
batch_size=2,
shuffle=True,
)
# 测试加载一个batch
print("测试数据加载:")
for batch in dataloader:
batch: Batch
print(f"Mel specs shape: {batch['mel_specs'].shape}")
print(f"Targets shape: {batch['targets'].shape}")
print(f"Mel lengths: {batch['mel_lengths']}")
print(f"Target lengths: {batch['target_lengths']}")
print(f"Target texts: {batch['target_texts']}")
print(f"Audio paths: {batch['audio_paths']}")
break
import torch
import torchaudio
from silero_vad import load_silero_vad, read_audio, get_speech_timestamps
# --- 1. 加载 VAD 模型 ---
# 推荐使用 GPU (cuda) 如果可用,否则使用 CPU (cpu)
vad_model = load_silero_vad(source="local", force_onnx=False)
def process_audio_with_vad_and_asr(audio_path, inference_module):
"""
使用 Silero VAD 分割音频,然后对每个语音片段进行 ASR 转录。
Args:
audio_path (str): 输入音频文件的路径。
inference_module: 您封装了 transcribe 方法的模块或对象。
需要确保其 _load_audio 和 transcribe 方法可用。
"""
print(f"正在加载音频: {audio_path}")
# --- 2. 加载音频用于 VAD 检测 ---
# Silero VAD 推荐使用 16kHz 采样率
waveform_vad, sample_rate_vad = torchaudio.load(audio_path)
if sample_rate_vad != 16000:
# 如果采样率不是16kHz需要重采样以供VAD使用
resampler = torchaudio.transforms.Resample(orig_freq=sample_rate_vad, new_freq=16000)
waveform_vad = resampler(waveform_vad)
sample_rate_vad = 16000
# --- 3. 获取语音时间段 ---
# get_speech_timestamps 返回一个列表,每个元素是 {'start': start_sample, 'end': end_sample} 的字典
# 采样率是16000所以时间戳单位是 1/16000 秒
speech_timestamps = get_speech_timestamps(
waveform_vad,
vad_model,
sampling_rate=sample_rate_vad,
threshold=0.5, # 可以根据需要调整阈值
min_speech_duration_ms=250, # 最小语音持续时间,防止短噪音被误判
max_speech_duration_s=float('inf'), # 最大语音持续时间float('inf') 表示不限制
min_silence_duration_ms=100, # 最小静音间隔,用于分割语音块
window_size_samples=1536, # VAD窗口大小
speech_pad_ms=30 # 在语音块前后添加的填充时间
)
print(f"检测到 {len(speech_timestamps)} 个语音片段。")
full_text = ""
for i, ts in enumerate(speech_timestamps):
start_sample = int(ts['start'])
end_sample = int(ts['end'])
# 计算时间戳(秒)
start_time = start_sample / sample_rate_vad
end_time = end_sample / sample_rate_vad
print(f"\n处理第 {i+1} 个片段: 时间范围 [{start_time:.2f}s - {end_time:.2f}s]")
# --- 4. 从原始音频中提取此片段 ---
# 注意:这里假设您的 transcribe 函数可以接受 waveform 张量。
# 我们需要从原始可能不同采样率的音频中提取片段或者用VAD处理过的waveform。
# 为了匹配您原来的 _load_audio 方式,我们用 torchaudio 再次精确加载片段。
# 计算原始音频中的样本索引如果原始音频采样率与VAD不同
original_waveform, original_sr = torchaudio.load(audio_path)
if sample_rate_vad != original_sr:
# 如果VAD和原始音频采样率不同需要重新映射索引
start_idx_orig = int(start_sample * (original_sr / sample_rate_vad))
end_idx_orig = int(end_sample * (original_sr / sample_rate_vad))
else:
start_idx_orig = start_sample
end_idx_orig = end_sample
segment_waveform = original_waveform[:, start_idx_orig:end_idx_orig]
# --- 5. 对该片段进行 ASR 转录 ---
# 这里调用您原有的 transcribe 方法
try:
# 假设 transcribe 方法接受一个 waveform tensor
segment_text = inference_module.transcribe(waveform=segment_waveform)
print(f" -> 转录结果: {segment_text}")
full_text += f"[{start_time:.2f}-{end_time:.2f}s] {segment_text}\n"
except Exception as e:
print(f" -> 转录第 {i+1} 个片段时出错: {e}")
print("\n--- 完整转录结果 ---")
print(full_text)
# --- 使用示例 ---
# 假设您有一个名为 'inference' 的模块对象,它有 _load_audio 和 transcribe 方法
# audio_path = "your_audio_file.wav"
# process_audio_with_vad_and_asr(audio_path, inference)

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,70 @@
from pathlib import Path
from tqdm import tqdm
from mutagen.mp3 import MP3
from mutagen import MutagenError
import pandas as pd
workspace_dir = Path(__file__).parent.parent.parent
def format_duration(seconds):
"""将秒数格式化为 时:分:秒 的格式"""
hours = int(seconds // 3600)
minutes = int((seconds % 3600) // 60)
secs = int(seconds % 60)
if hours > 0:
return f"{hours}小时{minutes}{secs}"
elif minutes > 0:
return f"{minutes}{secs}"
else:
return f"{secs}"
def get_mp3_duration(file_path: Path):
"""获取单个MP3文件的时长"""
try:
audio = MP3(file_path)
return audio.info.length
except MutagenError as e:
print(f"错误:无法读取文件 {file_path} - {e}")
return None
except Exception as e:
print(f"错误:处理文件 {file_path} 时出错 - {e}")
return None
def analyze_mp3_files(directory: Path, ):
"""分析目录中的所有MP3文件"""
results = []
total_duration = 0
df = pd.read_csv(workspace_dir / '.data/ug/train_new.tsv', sep='\t')
for _, row in tqdm(df.iterrows(), total=len(df), desc="anlayze audio"):
mp3_file: Path = workspace_dir / '.data/ug/clips' / row['path']
duration = get_mp3_duration(mp3_file)
if duration is not None:
total_duration += duration
duration_str = format_duration(duration)
file_size = mp3_file.stat().st_size / (1024 * 1024) # 转换为MB
result_line = f"{mp3_file.name:<50} {duration_str:>20} ({file_size:.2f} MB)"
results.append(result_line)
else:
error_line = f"{mp3_file.name:<50} {'读取失败':>20}"
print(error_line)
results.append(error_line)
mp3_files: list[Path] = []
for ext in ['*.mp3', '*.MP3']:
mp3_files.extend(directory.rglob(ext))
print(f"tsv 找到 {len(results)} 个MP3文件\n")
print(f"找到 {len(mp3_files)} 个MP3文件\n")
print(f"\n总时长: {format_duration(total_duration)}")
print(f"总时长(秒): {total_duration:.2f}")
print(f"总时长(分钟): {total_duration/60:.2f} 分钟")
print(f"总时长(小时): {total_duration/3600:.2f} 小时")
print(f"文件总数: {len(mp3_files)}")
audio_directory = Path(workspace_dir / '.data/ug/clips')
analyze_mp3_files(directory=audio_directory)

View File

@@ -0,0 +1,13 @@
from pathlib import Path
import torch
device = "cuda:0"
workspace_dir = Path(__file__).parent.parent.parent
input_checkpoint = workspace_dir.joinpath('.checkpoints/best_wer_model.pt')
output_checkpoint = workspace_dir.joinpath('.checkpoints/prodect_best_wer_model.pt')
checkpoint = torch.load(input_checkpoint, map_location=device)
torch.save(checkpoint['model_state_dict'], output_checkpoint)

View File

@@ -0,0 +1,37 @@
from pathlib import Path
import pandas as pd
workspace_dir = Path(__file__).parent.parent.parent
# 读取你合并后的总表 (253,430条)
df = pd.read_csv(workspace_dir / ".data/ug/validated.tsv", sep='\t')
# 1. 获取所有唯一的说话人
all_speakers = df['client_id'].unique()
# 2. 随机打乱说话人顺序
import random
random.seed(42) # 固定随机种子,保证实验可重复
random.shuffle(all_speakers)
# 3. 挑选验证集说话人,直到录音总数达到 ~8000 条
val_indices = []
val_count = 0
target_val_size = 8000
for speaker in all_speakers:
speaker_data = df[df['client_id'] == speaker]
val_indices.extend(speaker_data.index.tolist())
val_count += len(speaker_data)
if val_count >= target_val_size:
break
# 4. 划分文件
df_val = df.loc[val_indices]
df_train = df.drop(val_indices)
print(f"训练集条数: {len(df_train)}")
print(f"验证集条数: {len(df_val)}")
df_train.to_csv(workspace_dir / ".data/ug/train_new.tsv", sep='\t', index=False)
df_val.to_csv(workspace_dir / ".data/ug/val_new.tsv", sep='\t', index=False)

View File

@@ -0,0 +1,92 @@
import json
from pathlib import Path
from typing import Counter
import pandas as pd
from tqdm import tqdm
from text_handle import process_text
workspace_dir = Path(__file__).parent.parent.parent
data_dir = Path("/home/blacksheep/projekts/study")
input_files = [
data_dir / "data/multilingual/hanziDB_translated-simplified.jsonl",
data_dir / "data/multilingual/hanziDB_translated_validationset-simplified.jsonl",
data_dir / "data/multilingual/tatoeba-tr-en-ug-uz-kz-zh-simplified.jsonl",
data_dir / "data/multilingual/export_csv_database.jsonl",
data_dir / "data/multilingual/tatoeba_sentences.jsonl",
]
output_file = workspace_dir / "config/asr_vocab_1.json"
syllabizes = []
# langs = ["uig_Arab"]
# for input_file in input_files:
# with open(input_file, 'r', encoding='utf-8') as f:
# total_lines = sum(1 for _ in f)
# f.seek(0)
# for line in tqdm(f, total=total_lines, desc=f"Processing lines {input_file}"):
# data: dict[str, str] = json.loads(line)
# for lang in langs:
# if lang in data:
# syllabizes.extend(export_syllabize(data[lang]))
# print(f'data_dir syllabize len: {len(syllabizes):,}, set: {len(set(syllabizes)):,}')
tsv_files = [
# workspace_dir / "data/ug/test.tsv",
# workspace_dir / "data/ug/invalidated.tsv",
# workspace_dir / "data/ug/train.tsv",
# workspace_dir / "data/ug/validated.tsv",
# workspace_dir / "data/ug/reported.tsv",
# workspace_dir / "data/ug/dev.tsv",
# workspace_dir / "data/ug/other.tsv",
# workspace_dir / ".data/ug/invalidated.tsv",
workspace_dir / ".data/ug/train.tsv",
# workspace_dir / ".data/ug/clip_durations.tsv", # not sentence
# workspace_dir / ".data/ug/test.tsv",
# workspace_dir / ".data/ug/validated_sentences.tsv",
# workspace_dir / ".data/ug/other.tsv",
# workspace_dir / ".data/ug/validated.tsv",
# workspace_dir / ".data/ug/dev.tsv",
# workspace_dir / ".data/ug/unvalidated_sentences.tsv",
# workspace_dir / ".data/ug/reported.tsv" # Lacking sentence
]
for tsv_file in tsv_files:
data = pd.read_csv(tsv_file, sep='\t')
# 带进度条处理每行数据
for index, row in tqdm(data.iterrows(), total=len(data), desc=f"Processing {tsv_file}"):
syllabizes.extend(process_text(row['sentence'].strip()))
# 统计所有音节出现次数
syllable_counter = Counter(syllabizes)
# 过滤出出现100次以上的音节
freq_100_plus = {k: v for k, v in syllable_counter.items() if v >= 150}
freq_100_minus = {k: v for k, v in syllable_counter.items() if v <= 100}
# 保存100次以上的音节列表只有音节排序
vocab = sorted(list(freq_100_plus.keys()), key=len)
with open(output_file, 'w', encoding='utf-8') as f:
json.dump(vocab, f, ensure_ascii=False, indent=2)
# # # 保存100次以上的音节及次数
sorted_freq_100_plus = dict(sorted(freq_100_plus.items(), key=lambda x: x[1], reverse=True))
with open(workspace_dir / 'config/syllables_freq_100_plus.json', 'w', encoding='utf-8') as f:
json.dump(sorted_freq_100_plus, f, ensure_ascii=False, indent=2)
# # 统计信息
print(f"总音节数: {len(syllabizes):,}")
print(f"唯一音节数: {len(syllable_counter):,}")
print(f"出现100次以上的音节数: {len(freq_100_plus):,}")
print(f"出现100次以下的音节数: {len(freq_100_minus):,}")
# 区间统计(不累积)
print("\n=== 音节使用次数统计(区间) ===")
for low in range(0, 100, 10):
high = low + 9
count = sum(1 for freq in syllable_counter.values() if low <= freq <= high)
print(f"出现 {low}-{high} 次: {count:,} 个音节")
# 100次以上
count_100plus = sum(1 for freq in syllable_counter.values() if freq >= 100)
print(f"出现 100+ 次: {count_100plus} 个音节")

490
src/handle/test.ipynb Normal file

File diff suppressed because one or more lines are too long

220
src/handle/test.py Normal file
View File

@@ -0,0 +1,220 @@
import os
import random
import sys
import pandas as pd
import torchaudio
from pathlib import Path
from tqdm import tqdm
workspace_dir = Path(__file__).parent.parent.parent
sys.path.append(str(workspace_dir.joinpath('src')))
# sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from handle.text_handle import process_text
from tokenizer import ASRTokenizer
# 读取 TSV
# df = pd.read_csv(workspace_dir / '.data/ug/train.tsv', sep='\t')
# # 计算每个音频的时长(秒)
# durations = []
# for audio_path in tqdm(workspace_dir / '.data/ug/clips' / df['path'], desc="计算音频时长"):
# try:
# # 获取音频信息(不加载整个文件,速度快)
# info = torchaudio.info(audio_path)
# duration = info.num_frames / info.sample_rate
# durations.append(duration)
# except Exception as e:
# print(f"读取失败: {audio_path}, 错误: {e}")
# durations.append(None)
# # 添加到 DataFrame
# df['duration'] = durations
# # 统计
# print(df['duration'].describe())
# print(f"超过20秒的样本: {(df['duration'] > 20).sum()}")
# 保存结果(可选)
# df.to_csv('data/audio/train_with_duration.tsv', sep='\t', index=False)
# audio_path = workspace_dir / '.data/ug/clips/common_voice_ug_40252722.mp3'
# info = torchaudio.info(audio_path)
# print("info:", info)
import pandas as pd
import torch
import torchaudio
import numpy as np
from pathlib import Path
from tqdm import tqdm
def analyze_audio_quality(audio_path, sample_rate_target=16000):
"""
评估音频质量,返回多个指标
"""
try:
# 加载音频
waveform, sr = torchaudio.load(str(audio_path))
# 重采样 + 单声道
if sr != sample_rate_target:
resampler = torchaudio.transforms.Resample(sr, sample_rate_target)
waveform = resampler(waveform)
if waveform.shape[0] > 1:
waveform = waveform.mean(dim=0, keepdim=True)
# 归一化
max_val = waveform.abs().max()
if max_val > 0:
waveform = waveform / max_val
# ===== 指标 1: 静音比例 =====
silence_threshold = 0.01
is_silence = waveform.abs() < silence_threshold
silence_ratio = is_silence.float().mean().item()
# ===== 指标 2: 动态范围 =====
frame_len = int(0.025 * sample_rate_target) # 25ms
hop_len = int(0.010 * sample_rate_target) # 10ms
num_frames = (waveform.shape[1] - frame_len) // hop_len + 1
if num_frames <= 0:
return None
energies = []
for i in range(num_frames):
start = i * hop_len
frame = waveform[:, start:start + frame_len]
rms = frame.pow(2).mean().sqrt().item()
energies.append(rms)
energies = np.array(energies)
energies = np.clip(energies, 1e-10, None)
log_energies = 20 * np.log10(energies)
dynamic_range = log_energies.max() - log_energies.min()
# ===== 指标 3: 语音活跃度 =====
energy_threshold = np.percentile(energies, 30)
voice_ratio = (energies > energy_threshold).mean()
# ===== 指标 4: 频谱中心频率 =====
spec_transform = torchaudio.transforms.Spectrogram(n_fft=512)
spec = spec_transform(waveform) # [1, freq, time]
# 用 torch 计算,避免 numpy axis 问题
freqs = torch.fft.rfftfreq(512, d=1.0/sample_rate_target)
spec_mean = spec.mean(dim=-1).squeeze() # [freq_bins] torch 张量
# 确保类型一致,用 torch 计算
spectral_centroid = (freqs * spec_mean).sum() / (spec_mean.sum() + 1e-10)
spectral_centroid = spectral_centroid.item()
# ===== 指标 5: 频谱通量 =====
spec_np = spec.squeeze().numpy() # [freq, time]
flux = np.abs(np.diff(spec_np, axis=1)).mean()
return {
'duration': waveform.shape[1] / sample_rate_target,
'silence_ratio': silence_ratio,
'dynamic_range_db': float(dynamic_range),
'voice_ratio': float(voice_ratio),
'spectral_centroid': spectral_centroid,
'spectral_flux': float(flux),
'energy_std': float(energies.std()),
}
except Exception as e:
print(f"失败: {audio_path}, {e}")
return None
# ===== 批量处理 =====
# df = pd.read_csv(workspace_dir / '.data/ug/train.tsv', sep='\t')
# results = []
# for _, row in tqdm(df.iterrows(), total=len(df), desc="评估音频质量"):
# audio_path = workspace_dir / '.data/ug/clips' / row['path']
# metrics = analyze_audio_quality(audio_path)
# if metrics:
# metrics['path'] = row['path']
# results.append(metrics)
# # 汇总
# quality_df = pd.DataFrame(results)
# print(quality_df.describe())
# # ===== 质量评分 =====
# def score_quality(row):
# score = 100
# if row['silence_ratio'] > 0.5:
# score -= 30
# elif row['silence_ratio'] > 0.3:
# score -= 15
# if row['dynamic_range_db'] < 10:
# score -= 25
# elif row['dynamic_range_db'] < 20:
# score -= 10
# if row['voice_ratio'] < 0.3:
# score -= 20
# if row['spectral_centroid'] > 4000:
# score -= 15
# return max(0, score)
# quality_df['quality_score'] = quality_df.apply(score_quality, axis=1)
# quality_df['grade'] = pd.cut(
# quality_df['quality_score'],
# bins=[0, 40, 60, 80, 100],
# labels=['差(弃用)', '较差', '一般', '良好']
# )
# print("\n质量分级统计:")
# print(quality_df['grade'].value_counts())
# # 保存结果
# quality_df.to_csv('audio_quality_analysis.csv', index=False)
# print("\n结果已保存到 audio_quality_analysis.csv")
# Test
# text = "مەن مەكتەپكە باردىم"
# text = "ئاۋۋال كومپىيوتىرنى بىر كۆرسەم شۇنىڭغا قاراپ ماسلاشتۇرسام بوللاتتى"
text = "قاپاقنى پۇلغا ئالماي، باراڭنى پۇلغا ئاپتۇ"
# text = "ئامېىقى جاھان گۈرلىكىە قارشتۇرۇپ چوكشەڭگە ياردەم بېرىش ئۇرۇشىنىڭ ئلۇغ غەلبىسى جۇڭگو خەلقى ئورندىن دەستۇرغاندىكىن، دۇنيانىڭ شەرقىدە قەت كۆتۈرگەنلكىنى خىتاب لامىسى."
# text = "ئاۋسترالىيە"
text = " ھەر قانداق ئىشتا كىشىلەر بىلەن كېڭىشىش كېرەك. لېكىن كۆڭۈل تارتقان ئىشنى قىلىۋېرىش كېرەك."
text = "\" 15 يىل بۇرۇن مەن بىلەن سەي چۇڭشىن (ئالى بابانىڭ قۇرغۇچىسىدىن بىرى) ئامېرىكىغا بېرىپ 30 نەچچە مەبلەغ سالغۇچى شىركەتلەر بىلەن كۆرۈشكىنىمىزدە ئۇلار بىزنى رەت قىلىپ ئىشىك سىرتىدا قالدۇرغاندى. ھېچكىم بىزگە ۋە كەلگۈسىمىزگە ئىشەنمىگەن ئىدى، ھېچكىممۇ بىزنىڭ ھازىر 300 مىليارد سودىنى تاماملىيالايدىغانلىقىمىزنى ئويلاپ خىيالىغىمۇ كەلتۈرمىگەن ئىدى.\" ئالى بابانىنىڭ ئورگىنى تەرىپىدىن ئىشلەنگەن ئىگىلىك تىكلەش ھۆججەتلىك فىلىمى \" بۇ چۈش ئەمەس\" نىڭ باش قىسمىدا مۇنداق بىر داڭلىق سۆز چىقىدۇ:\" مەن تاغدىن ئۆتكەن ۋاقتىمدا تاغ ماڭا گەپ قىلمىدى، مەن دېڭىزنى كېچىپ ئۆتكىنىمدە دېڭىز ماڭا گەپ قىلمىدى.\" بۇ فىلىمنىڭ قوشۇمچە ئىسمى بولسا \" مايۈن ۋە ئۇنىڭ مەڭگۈلۈك ' ياش ئالى'سى\" بولۇپ، ھەرخىل خەتەر ۋە قىيىنلىقنى بىرمۇ بىر يەڭگەن مايۈن ۋە ئۇنىڭ ئالى بابا قوشۇنىدىكى ھەمكارلاشقۇچىلىرىنىڭ قەيسەر كەچمىشى جانلىق بايان قىلىنغان."
# text = "يۆ جيەنتاۋ يېقىنقى بىر مەزگىلدە، خىزمەتداشلىرى بىلەن نۇرغۇن قىيىنچىلىققا ئۇچرىغان شوپۇرغا ياردەم قىلغانلىقىنى، شۇنداقلا يېمەكلىك ۋە سۇ يەتكۈزۈپ بەرگەنلىكىنى، ئەمما ئاياغ سوۋغا قىلىشى تۇنجى قېتىم ئىكەنلىكىنى، بۈگۈنكىسى 20 يىللىق ساقچىلىق جەريانىدا تۇنجى قېتىم شوپۇرغا ئاياغ سوۋغا قىلىشى ئىكەنلىكىنى ئېيتتى."
text = "يۆ جۇڭمىڭ مۇنداق دېدى: 2019- يىلى 12-ئايدا، مەملىكەتلىك خەلق قۇرۇلتىيى دائىمىي كومىتېتىنىڭ 44- قېتىملىق كومىتېت باشلىقلىرى يىغىنى مەملىكەتلىك خەلق قۇرۇلتىيى دائىمىي كومىتېتىنىڭ 2020-يىللىق قانۇن چىقىرىش خىزمىتى پىلانىنى پىرىنسىپ جەھەتتىن ماقۇللىدى، خىزمەت تەرتىپى بويىچە، 13-نۆۋەتلىك مەملىكەتلىك خەلق قۇرۇلتىيى 3-يىغىنىنىڭ روھى ۋە ۋەكىللەرنىڭ تەكلىپ-تەۋسىيەلىرىگە ئاساسەن پىلاننى تەڭشەش كېرەك. بۇ يىل 6-ئاينىڭ 1-كۈنى، 58-قېتىملىق كومىتېت باشلىقلىرى يىغىنى تەڭشەلگەندىن كېيىنكى يىللىق قانۇن چىقىرىش خىزمىتى پىلانىنى قاراپ چىقىپ ماقۇللىدى."
text = " ئاشقازان-ئۈچەينىڭ لۆمۈلدىشىنى ئىلگىرى سۈرىدىغان دورىنى تاماقتىن بۇرۇن ئىستېمال قىلىش كېرەك."
text = "مەن مەكتەپكە باردىم"
# text = "غەرىپئەللىرى"
# text = "ئىكەنلىكىنى، بۈگۈنكىسى 20 يىللىق ساقچىلىق جەريانىدا تۇنجى"
# text = " يېزىدىكى كەڭ، ئازادە ئۆي، باغلىرىنى تاشلاپ، خەقنىڭ ھويلىسىدا قورۇنۇپ-ئەيمىنىپ يەر دەسسەپ يۈردى. "
# text = "بۇ يىللىق ئەمگەكچىلەر بايرىمىدا، 1-مايدىن 5-مايغىچە جەمئىي بەش كۈن دەم ئېلىشقا قويۇپ بېرىلىدۇ. 9-ماي (شەنبە) نورمال خىزمەت قىلىنىدۇ. شۇنىڭ بىلەن بىر ۋاقىتتا، ئۈرۈمچى شەھىرى تىيانشان رايونى، سايباغ رايونى، داۋانچىڭ رايونى، ئۈرۈمچى ناھىيەسىدىكى ئوتتۇرا، باشلانغۇچ مەكتەپلەر 29-ئاپرېلدىن 30-ئاپرېلغىچە ئەتىيازلىق دەم ئېلىشقا قويۇپ بېرىدۇ؛ سانجى ئوبلاستىدىكى ئوتتۇرا، باشلانغۇچ مەكتەپلەرنىڭ ئەتىيازلىق دەم ئېلىشى ئۈچ كۈن بولۇپ، بۇنىڭ ئىچىدە ئىككى كۈن 29-، 30-ئاپرېلغا ئورۇنلاشتۇرۇلىدۇ، قالغان بىر كۈن 1-ئىيۇنغا ئورۇنلاشتۇرۇلىدۇ، قۇربان ھېيتلىق دەم ئېلىش بىلەن بىرلەشتۈرۈلۈپ، ئۇدا ئالتە كۈن دەم ئېلىنىدۇ."
# text = "ئەڭ يېڭى دەم ئېلىشقا قويۇپ بېرىش ئۇقتۇرۇشى! تەقەززالىق بىلەن كۈتكەن يەنە بىر دەم ئېلىش كېلەي دەپ قالدى! گوۋۇيۈەن بەنگۇڭتىڭىنىڭ «2025-يىللىق قىسمەن بايرام، دەم ئېلىش كۈنلىرىنى ئورۇنلاشتۇرۇش توغرىسىدىكى ئۇقتۇرۇشى»غا ئاساسەن، دۈەنۋۇ بايرىمىلىق دەم ئېلىشقا قويۇپ بېرىش ئورۇنلاشتۇرۇلۇشى تۆۋەندىكىچە: 5-ئاينىڭ 31-كۈنى (شەنبە)دىن 6-ئاينىڭ 2-كۈنىگىچە جەمئىي ئۈچ كۈن دەم ئېلىشقا قويۇپ بېرىلىدۇ! بۇ قېتىمقى دۈەنۋۇ بايرىمىلىق دەم ئېلىش ۋاقتى تەڭشەلمەيدۇ!"
# text = "مەن شۇ چاغدا چۈشۈمدىمۇ كۆرۈپ باقمىغان مۇنچىۋالا جىق قېرىنداشلىرىم، جىگەرلىرىمنىڭ بارلىقىدىن سۆيۈندۈم."
# text = "ئامېرىكا جاھان گىرلىكىگە قارشى تۇرۇپ، چاۋشەنگە ياردەم بىرىش ئۇرۇشىنىڭ ئۇلۇغ غەلبىسى، جۇڭگو خەلقى ئورۇندىن دەس تۇرغاندىنكىن دۇنيانىڭ شەرقىدە قەد كۆتۈرگەنلىكىنىڭ خىتاپنامىس"
# result = export_syllabize(text)
result = process_text(text)
print(f"Original: {text}")
print(f"Syllables: {result}")
# tokenizer = ASRTokenizer(vocab_path=workspace_dir / 'config/asr_vocab.json')
# print(text)
# ids = tokenizer.encode(text=text)
# # print(ids)
# print('|'.join(tokenizer.decode([id]) for id in ids))

195
src/handle/text_handle.py Normal file
View File

@@ -0,0 +1,195 @@
from .text_normalizer import UYGHUR_LETTERS, clean_input_text
# def uighur_syllabize(text: str):
# # Uyghur Vowels
# vowels = "ئاەئەوۆئۇئۈئىئېاەوۆۇۈىې"
# def split_word(word):
# syllables = []
# current_syllable = ""
# i = 0
# # Helper to check if a character is a vowel
# is_vowel = lambda char: char in vowels
# while i < len(word):
# current_syllable += word[i]
# # If we find a vowel, look ahead to decide where to split
# if is_vowel(word[i]):
# # Check next characters
# remaining = word[i+1:]
# # Rule 1: V-V (e.g., 'ائ') -> Split after first vowel
# if len(remaining) >= 1 and is_vowel(remaining[0]):
# syllables.append(current_syllable)
# current_syllable = ""
# # Rule 2: V-C-V (e.g., 'ba-ra') -> Split after first vowel
# elif len(remaining) >= 2 and not is_vowel(remaining[0]) and is_vowel(remaining[1]):
# syllables.append(current_syllable)
# current_syllable = ""
# # Rule 3: V-C-C-V (e.g., 'mek-tep') -> Split after first consonant
# elif len(remaining) >= 3 and not is_vowel(remaining[0]) and not is_vowel(remaining[1]) and is_vowel(remaining[2]):
# current_syllable += remaining[0]
# syllables.append(current_syllable)
# current_syllable = ""
# i += 1 # Skip the consonant we just added
# # Rule 4: End of word or C-C-C clusters (rare in native words)
# # We keep going until we find a clear boundary or end of word
# i += 1
# if current_syllable:
# syllables.append(current_syllable)
# return syllables
# # Clean text and process word by word
# words = text.split()
# all_syllables = []
# for w in words:
# all_syllables.extend(split_word(w))
# #merge ئ to next syllable. single ئ no any semantic meaning and no it's dedicated sound.
# original_ayllables = [*all_syllables]
# all_syllables.clear()
# handled = True
# for s in original_ayllables:
# if s == "ئ":
# handled = False
# continue
# if not handled:
# s = "ئ" + s
# handled = True
# all_syllables.append(s)
# return all_syllables
# test_words = ['تەكلىپتەۋسىيەلىرىگە', 'تەكلىپ-تەۋسىيەلىرىگە']
# for tw in test_words:
# print(f"{tw} -> {uighur_syllabize(tw.strip())}")
# تەكلىپتەۋسىيەلىرىگە -> ['تەك', 'لىپ', 'تەۋ', 'سى', 'يە', 'لى', 'رى', 'گە']
# تەكلىپ-تەۋسىيەلىرىگە -> ['تەك', 'لىپ-تەۋ', 'سى', 'يە', 'لى', 'رى', 'گە']
def uighur_syllabize(text: str):
# Uyghur Vowels (ASU)
vowels = "ئاەئەوۆئۇئۈئىئېاەوۆۇۈىې"
def is_vowel(char):
return char in vowels
def split_word(word):
syllables = []
i = 0
last_split = 0
while i < len(word):
# Look for vowel patterns to determine boundaries
if is_vowel(word[i]):
rem = word[i+1:]
# Rule: V-V -> Split after first vowel
if len(rem) >= 1 and is_vowel(rem[0]):
syllables.append(word[last_split:i+1])
last_split = i + 1
# Rule: V-C-V -> Split after vowel (e.g., ba-ra)
elif len(rem) >= 2 and not is_vowel(rem[0]) and is_vowel(rem[1]):
syllables.append(word[last_split:i+1])
last_split = i + 1
# Rule: V-C-C-V -> Split between consonants (e.g., mek-tep)
elif len(rem) >= 3 and not is_vowel(rem[0]) and not is_vowel(rem[1]) and is_vowel(rem[2]):
syllables.append(word[last_split:i+2])
last_split = i + 2
i += 1 # Skip first C
# Rule: V-C-C-C-V (VCCCV) -> Split after second consonant (e.g., gert-mek)
# This fixes "gertmek", "eytqan", "partlap", "dostlar"
elif len(rem) >= 4 and not is_vowel(rem[0]) and not is_vowel(rem[1]) and not is_vowel(rem[2]) and is_vowel(rem[3]):
syllables.append(word[last_split:i+3])
last_split = i + 3
i += 2 # Skip first two Cs
i += 1
# Add the remaining part of the word
if last_split < len(word):
syllables.append(word[last_split:])
return syllables
# Process word by word
words = text.split()
final_output = []
for w in words:
raw_syllables = split_word(w)
# Handle Hemze (ئ) re-merging
processed = []
skip_next = False
for j in range(len(raw_syllables)):
if skip_next:
skip_next = False
continue
s = raw_syllables[j]
# If the current syllable is JUST "ئ" or ends in "ئ", merge it with next
if s == "ئ" and j + 1 < len(raw_syllables):
processed.append("ئ" + raw_syllables[j+1])
skip_next = True
else:
processed.append(s)
final_output.extend(processed)
return final_output
# Verification
# test_words = ["گەرتمەك", "ئېيتقان", "دوستلار", "مەكتەپكە", ' كومپېيۇتېر تورى زادى قانداق نەرسىدۇ؟ ']
# for tw in test_words:
# print(f"{tw} -> {uighur_syllabize(tw.strip())}")
# test_words = ['تەكلىپتەۋسىيەلىرىگە', 'تەكلىپ-تەۋسىيەلىرىگە']
# for tw in test_words:
# print(f"{tw} -> {uighur_syllabize(tw.strip())}")
# تەكلىپتەۋسىيەلىرىگە -> ['تەك', 'لىپ', 'تەۋ', 'سى', 'يە', 'لى', 'رى', 'گە']
# تەكلىپ-تەۋسىيەلىرىگە -> ['تەك', 'لىپ-', 'تەۋ', 'سى', 'يە', 'لى', 'رى', 'گە']
text = " ھەر قانداق ئىشتا كىشىلەر بىلەن كېڭىشىش كېرەك. لېكىن كۆڭۈل تارتقان ئىشنى قىلىۋېرىش كېرەك."
text = "\" 15 يىل بۇرۇن مەن بىلەن سەي چۇڭشىن (ئالى بابانىڭ قۇرغۇچىسىدىن بىرى) ئامېرىكىغا بېرىپ 30 نەچچە مەبلەغ سالغۇچى شىركەتلەر بىلەن كۆرۈشكىنىمىزدە ئۇلار بىزنى رەت قىلىپ ئىشىك سىرتىدا قالدۇرغاندى. ھېچكىم بىزگە ۋە كەلگۈسىمىزگە ئىشەنمىگەن ئىدى، ھېچكىممۇ بىزنىڭ ھازىر 300 مىليارد سودىنى تاماملىيالايدىغانلىقىمىزنى ئويلاپ خىيالىغىمۇ كەلتۈرمىگەن ئىدى.\" ئالى بابانىنىڭ ئورگىنى تەرىپىدىن ئىشلەنگەن ئىگىلىك تىكلەش ھۆججەتلىك فىلىمى \" بۇ چۈش ئەمەس\" نىڭ باش قىسمىدا مۇنداق بىر داڭلىق سۆز چىقىدۇ:\" مەن تاغدىن ئۆتكەن ۋاقتىمدا تاغ ماڭا گەپ قىلمىدى، مەن دېڭىزنى كېچىپ ئۆتكىنىمدە دېڭىز ماڭا گەپ قىلمىدى.\" بۇ فىلىمنىڭ قوشۇمچە ئىسمى بولسا \" مايۈن ۋە ئۇنىڭ مەڭگۈلۈك ' ياش ئالى'سى\" بولۇپ، ھەرخىل خەتەر ۋە قىيىنلىقنى بىرمۇ بىر يەڭگەن مايۈن ۋە ئۇنىڭ ئالى بابا قوشۇنىدىكى ھەمكارلاشقۇچىلىرىنىڭ قەيسەر كەچمىشى جانلىق بايان قىلىنغان."
text = "يۆ جيەنتاۋ يېقىنقى بىر مەزگىلدە، خىزمەتداشلىرى بىلەن نۇرغۇن قىيىنچىلىققا ئۇچرىغان شوپۇرغا ياردەم قىلغانلىقىنى، شۇنداقلا يېمەكلىك ۋە سۇ يەتكۈزۈپ بەرگەنلىكىنى، ئەمما ئاياغ سوۋغا قىلىشى تۇنجى قېتىم ئىكەنلىكىنى، بۈگۈنكىسى 20 يىللىق ساقچىلىق جەريانىدا تۇنجى قېتىم شوپۇرغا ئاياغ سوۋغا قىلىشى ئىكەنلىكىنى ئېيتتى."
text = "يۆ جۇڭمىڭ مۇنداق دېدى: 2019- يىلى 12-ئايدا، مەملىكەتلىك خەلق قۇرۇلتىيى دائىمىي كومىتېتىنىڭ 44- قېتىملىق كومىتېت باشلىقلىرى يىغىنى مەملىكەتلىك خەلق قۇرۇلتىيى دائىمىي كومىتېتىنىڭ 2020-يىللىق قانۇن چىقىرىش خىزمىتى پىلانىنى پىرىنسىپ جەھەتتىن ماقۇللىدى، خىزمەت تەرتىپى بويىچە، 13-نۆۋەتلىك مەملىكەتلىك خەلق قۇرۇلتىيى 3-يىغىنىنىڭ روھى ۋە ۋەكىللەرنىڭ تەكلىپ-تەۋسىيەلىرىگە ئاساسەن پىلاننى تەڭشەش كېرەك. بۇ يىل 6-ئاينىڭ 1-كۈنى، 58-قېتىملىق كومىتېت باشلىقلىرى يىغىنى تەڭشەلگەندىن كېيىنكى يىللىق قانۇن چىقىرىش خىزمىتى پىلانىنى قاراپ چىقىپ ماقۇللىدى."
text = "مەن مەكتەپكە باردىم"
# text = normalize_extended_uyghur_characters(text=text)
# print(clean_uyghur(text=text))
def process_text(text: str) -> list[str]:
text = clean_input_text(text=text)
result = []
current_word_chars = []
for char in text:
if char in UYGHUR_LETTERS:
current_word_chars.append(char)
else:
# 遇到非字母字符,先处理缓存的单词
if current_word_chars:
word = ''.join(current_word_chars)
result.extend(uighur_syllabize(word))
current_word_chars.clear()
# 非字母字符直接加入
result.append(char)
# 处理末尾的单词
if current_word_chars:
word = ''.join(current_word_chars)
result.extend(uighur_syllabize(word))
return result

View File

@@ -0,0 +1,126 @@
UYGHUR_LETTERS = {
'ا', 'ە', 'ب', 'پ', 'ت', 'ج', 'چ', 'خ', 'د', 'ر', 'ز', 'ژ', 'س', 'ش', 'غ', 'ف',
'ق', 'ك', 'گ', 'ڭ', 'ل', 'م', 'ن', 'ھ', 'و', 'ۇ', 'ۆ', 'ۈ', 'ۋ', 'ې', 'ى', 'ي',
"ئ"
}
uyghur_symbols = {
"ھەرىپلەر": [
'ا', 'ە', 'ب', 'پ', 'ت', 'ج', 'چ', 'خ', 'د', 'ر', 'ز', 'ژ', 'س', 'ش', 'غ', 'ف',
'ق', 'ك', 'گ', 'ڭ', 'ل', 'م', 'ن', 'ھ', 'و', 'ۇ', 'ۆ', 'ۈ', 'ۋ', 'ې', 'ى', 'ي',
"ئ"
],
"تىنىش_بەلگىلىرى": [
"۔", "،", "؟", "!", "-", "«", "»", "؛", ":", "'", "\"", "]", "[", " ", "", ""
],
"سانلار": [
"0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
],
}
english_symbols = {
"characters": [
'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'
],
"symbols": [
".", "!", "?", "-", "<", ">", "=", "+", "*", "/", "|", "\\", "~", "_", "^", "@", "{", "}", "[", "]", "(", ")", "#", "%", "$", "", "£", "¥", " "
]
}
# Export all symbols:
symbols_list = uyghur_symbols["ھەرىپلەر"] + uyghur_symbols["تىنىش_بەلگىلىرى"] + uyghur_symbols["سانلار"] + english_symbols["characters"] + english_symbols["symbols"]
symbols = {i: i for i in symbols_list}
def normalize_extended_uyghur_characters(text: str) -> str:
characters = [
#["standard", "individual", "beginning", "middle", "end"]
["ئ","","","",""],
["ا","","","",""],
["ە","","","",""],
["ب","","","",""],
["پ","","","",""],
["ت","","","",""],
["ج","","","",""],
["چ","","","",""],
["خ","","","",""],
["د","","","",""],
["ر","","","",""],
["ز","","","",""],
["ژ","","","",""],
["س","","","",""],
["ش","","","",""],
["غ","","","",""],
["ف","","","",""],
["ق","","","",""],
["ك","","","",""],
["گ","","","",""],
["ڭ","","","",""],
["ل","","","",""],
["م","","","",""],
["ن","","","",""],
["ھ","","","",""],
["و","","","",""],
["ۇ","","","",""],
["ۆ","","","",""],
["ۈ","","","",""],
["ۋ","","","",""],
["ې","","","",""],
["ى","","","",""],
["ي","","","",""],
["ۅ","","","",""],
["ۉ","","","",""],
["ح","","","",""],
["ع","","","",""]
]
replacement_table: dict[str, str] = {}
#Create a replacement table.
for char_map in characters:
for char in char_map[1:]:
replacement_table[char] = char_map[0]
replacement_table[''] = "لا" #add some additional exceptional symbols
text = text.replace('ئ', 'ئ')
clean = ""
#Replace the extended characters with their standard unicode ones.
for char in text:
if char in replacement_table:
char = replacement_table[char]
clean += char
return clean
def is_uyghur_char(char: str) -> bool:
"""判断是否是维吾尔语字母"""
return all(c in UYGHUR_LETTERS for c in char)
def is_uyghur_text(text: str) -> bool:
"""清洗后检查是否全为维吾尔语"""
return all(c in symbols for c in text)
def clean_unknown_symbols(text: str) -> str:
return ''.join(c for c in text if c in symbols)
import re
def collapse_spaces(text: str) -> str:
return re.sub(r'\s+', ' ', text)
def clean_english_text(text: str) -> str:
return re.sub(r'[a-zA-Z]', ' ', text)
def clean_chinese_text(text: str) -> str:
return re.sub(r'[\u4e00-\u9fff]', ' ', text)
def clean_http_links(text: str) -> str:
return re.sub(r'https?://[^\s]+', ' ', text)
def clean_input_text(text: str) -> str:
text = collapse_spaces(text)
text = clean_http_links(text)
text = normalize_extended_uyghur_characters(text)
text = clean_unknown_symbols(text)
return text

580
src/inference.ipynb Normal file
View File

@@ -0,0 +1,580 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "f72992c5",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torchaudio\n",
"import librosa\n",
"import pyrubberband as pyrb\n",
"from torch import Tensor, no_grad, device\n",
"from torchaudio.transforms import FrequencyMasking, MelSpectrogram, AmplitudeToDB, Resample, TimeMasking, TimeStretch\n",
"from torch_audiomentations import Gain, PitchShift, LowPassFilter, HighPassFilter\n",
"from pathlib import Path\n",
"from librosa import effects\n",
"import torchaudio.functional as F\n",
"from IPython.display import Audio\n",
"import matplotlib.pyplot as plt\n",
"import random\n",
"from typing import TypedDict\n",
"\n",
"from tokenizer import ASRTokenizer\n",
"from model import ASRModel\n",
"\n",
"CONFIG = {\n",
" # 模型配置\n",
" 'input_dim': 256,\n",
" 'num_heads': 8,\n",
" 'ffn_dim': 2048,\n",
" 'num_layers': 8,\n",
" 'dropout': 0.1,\n",
"}\n",
"workspace_dir = Path.cwd().parent"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "15a9a926",
"metadata": {},
"outputs": [],
"source": [
"class AdvancedAugment:\n",
" def __init__(self, noise_folder=None):\n",
" # 如果你有噪音文件夹,可以预加载噪音路径列表\n",
" # 场景建议1.汽车 2.餐厅 3.街道 4.办公室 5.下雨 6.风声 7.键盘声 8.婴儿哭声 9.音乐 10.走廊回声 11.白噪音 12.粉红噪音\n",
" self.noise_files = [] # 这里存放你的 .wav 噪音文件路径\n",
"\n",
" def _add_noise(self, waveform: Tensor) -> Tensor:\n",
" \"\"\"\n",
" 更丰富的加噪:随机选择 白噪音、粉红噪音 或 真实环境音\n",
" \"\"\"\n",
" # 1. 随机决定噪音类型\n",
" noise_type = random.choice(['white', 'pink', 'environmental'])\n",
" \n",
" # 2. 随机设定信噪比 (SNR) - 模拟各种清晰度\n",
" snr_db = random.uniform(5, 25) \n",
" \n",
" if noise_type == 'white':\n",
" noise = torch.randn_like(waveform)\n",
" elif noise_type == 'pink':\n",
" \n",
" noise = self._generate_pink_noise(waveform)\n",
" else:\n",
" # 模拟环境音 (如果你有噪音库)\n",
" # 这里演示如果没有噪音库,就用多种不同频率的随机噪音组合\n",
" noise = self._generate_simulated_env_noise(waveform)\n",
"\n",
" # 3. 计算能量并叠加\n",
" return self._mix_signal_noise(waveform, noise, snr_db)\n",
"\n",
" def _generate_pink_noise(self, waveform):\n",
" \"\"\"粉红噪音(比白噪音更像自然界的声音,低频能量更高)\"\"\"\n",
" # 简易实现:对白噪音做低通滤波\n",
" white = torch.randn_like(waveform)\n",
" return torch.cumsum(white, dim=-1) / 10.0 # 简单的积分近似\n",
"\n",
" def _generate_simulated_env_noise(self, waveform):\n",
" \"\"\"模拟环境噪音(通过叠加不同频率的波形)\"\"\"\n",
" # 模拟 12 种以上变化:随机叠加正弦波或窄带噪音\n",
" noise = torch.zeros_like(waveform)\n",
" for _ in range(random.randint(3, 8)):\n",
" freq = random.uniform(50, 4000)\n",
" t = torch.arange(waveform.shape[-1]).to(waveform.device)\n",
" noise += torch.sin(2 * 3.14159 * freq * t / 16000)\n",
" return noise + torch.randn_like(waveform) * 0.5\n",
"\n",
" def _mix_signal_noise(self, signal, noise, snr_db):\n",
" \"\"\"核心:根据 SNR 混合信号和噪音\"\"\"\n",
" s_p = signal.pow(2).mean()\n",
" n_p = noise.pow(2).mean()\n",
" \n",
" # 防止除以 0\n",
" if n_p == 0: return signal\n",
" \n",
" snr_linear = 10**(snr_db/10)\n",
" scale = torch.sqrt(s_p / (n_p * snr_linear))\n",
" \n",
" noisy = signal + scale * noise\n",
" return noisy / (noisy.abs().max() + 1e-7) # 归一化防止爆音\n",
"\n",
" def _simulate_reverb(self, waveform):\n",
" \"\"\"\n",
" 模拟走廊/房间回声 (Simple Reverb)\n",
" \"\"\"\n",
" # if random.random() < 0.3: # 30% 概率添加回声\n",
" # 模拟简单的延迟反馈(模拟大走廊)\n",
" delay_samples = random.randint(500, 2000) # 30ms - 125ms 延迟\n",
" decay = random.uniform(0.3, 0.6)\n",
" \n",
" reverb_signal = torch.zeros_like(waveform)\n",
" reverb_signal[:, delay_samples:] = waveform[:, :-delay_samples] * decay\n",
" return (waveform + reverb_signal) / (1 + decay)\n",
" return waveform\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0635d68",
"metadata": {},
"outputs": [],
"source": [
"class NoiseAugmentor:\n",
" def __init__(self, noise_root: Path, sample_rate=16000):\n",
" self.sample_rate = sample_rate\n",
" # 递归找到目录下所有的 wav 文件\n",
" self.noise_files = list(Path(noise_root).rglob(\"*.wav\"))\n",
" print(f\"成功加载 {len(self.noise_files)} 个噪音文件\")\n",
"\n",
" def apply_real_noise(self, waveform: Tensor):\n",
" # 1. 随机选一个噪音文件\n",
" noise_path = random.choice(self.noise_files)\n",
" noise_waveform, sr = torchaudio.load_with_torchcodec(noise_path)\n",
" \n",
" # 2. 统一采样率\n",
" if sr != self.sample_rate:\n",
" resampler = Resample(sr, self.sample_rate)\n",
" noise_waveform = resampler(noise_waveform)\n",
" \n",
" # 3. 截取或填充,使其长度与语音一致\n",
" sig_len = waveform.shape[1]\n",
" noise_len = noise_waveform.shape[1]\n",
" \n",
" if noise_len >= sig_len:\n",
" # 随机截取一段\n",
" start = random.randint(0, noise_len - sig_len)\n",
" noise_waveform = noise_waveform[:, start:start + sig_len]\n",
" else:\n",
" full_noise = torch.zeros_like(waveform)\n",
" start = random.randint(0, sig_len - noise_len)\n",
" full_noise[:, start : start + noise_len] = noise_waveform\n",
" noise_waveform = full_noise\n",
" # 如果噪音太短,循环填充\n",
" # repeats = (sig_len // noise_len) + 1\n",
" # noise_waveform = noise_waveform.repeat(1, repeats)[:, :sig_len]\n",
" \n",
" # 4. 设定随机信噪比 SNR (5dB 到 20dB)\n",
" snr_db = random.uniform(5, 20)\n",
" \n",
" # 5. 混合\n",
" return self._mix_at_snr(waveform, noise_waveform, snr_db)\n",
"\n",
" def _mix_at_snr(self, signal: Tensor, noise: Tensor, snr_db: float):\n",
" s_p = signal.pow(2).mean()\n",
" n_p = noise.pow(2).mean()\n",
" snr_linear = 10**(snr_db / 10)\n",
" scale = torch.sqrt(s_p / (n_p * snr_linear + 1e-8))\n",
" \n",
" noisy = signal + scale * noise\n",
" # 归一化,防止溢出\n",
" return noisy / (noisy.abs().max() + 1e-8)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e77899f1",
"metadata": {},
"outputs": [],
"source": [
"class ASRInference:\n",
" def __init__(self, model_path: Path, vocab_path: Path, noise_dir: Path, device: device, augment: bool = True, augment_prob: float = 0.5) -> None:\n",
" self.device = device\n",
" self.augment: bool = augment\n",
" self.augment_prob: float = augment_prob\n",
" self.noise_augmentor = NoiseAugmentor(noise_root=noise_dir)\n",
" self.tokenizer = ASRTokenizer(vocab_path=vocab_path)\n",
" self.model = ASRModel(vocab_size=self.tokenizer.vocab_size(), **CONFIG).to(device=device)\n",
" self.model.load_state_dict(torch.load(model_path, map_location=device)['model_state_dict'])\n",
" self.model.eval()\n",
"\n",
" print(f\"params params: {self.model.get_num_params():,}\",)\n",
"\n",
" self.sample_rate = 16000\n",
"\n",
" # self.gain_up = Gain(min_gain_in_db=5, max_gain_in_db=10, p=1.0, output_type='tensor')\n",
" # self.gain_down = Gain(min_gain_in_db=-20, max_gain_in_db=-10, p=1.0, output_type='tensor')\n",
" # self.pitch_up = PitchShift(min_transpose_semitones=3, max_transpose_semitones=5, p=1.0, sample_rate=self.sample_rate, output_type='tensor')\n",
" # self.pitch_down = PitchShift(min_transpose_semitones=-5, max_transpose_semitones=-3, p=1.0, sample_rate=self.sample_rate, output_type='tensor')\n",
" # self.lowpass = LowPassFilter(min_cutoff_freq=400, max_cutoff_freq=2400, p=1.0, output_type='tensor')\n",
" # self.highpass = HighPassFilter(min_cutoff_freq=1400, max_cutoff_freq=3400, p=1.0, output_type='tensor')\n",
"\n",
" self.gain_up = Gain(min_gain_in_db=4, max_gain_in_db=8, p=1.0, output_type='tensor')\n",
" self.gain_down = Gain(min_gain_in_db=-15, max_gain_in_db=-8, p=1.0, output_type='tensor')\n",
" self.pitch_up = PitchShift(min_transpose_semitones=1, max_transpose_semitones=4, p=1.0, sample_rate=self.sample_rate, output_type='tensor')\n",
" self.pitch_down = PitchShift(min_transpose_semitones=-4, max_transpose_semitones=-1, p=1.0, sample_rate=self.sample_rate, output_type='tensor')\n",
" self.lowpass = LowPassFilter(min_cutoff_freq=600, max_cutoff_freq=2000, p=1.0, output_type='tensor')\n",
" self.highpass = HighPassFilter(min_cutoff_freq=800, max_cutoff_freq=2000, p=1.0, output_type='tensor')\n",
" \n",
" def _load_audio(self, audio_path: Path) -> Tensor:\n",
" waveform, sample_rate = torchaudio.load_with_torchcodec(audio_path)\n",
"\n",
" if sample_rate != self.sample_rate:\n",
" waveform = Resample(sample_rate, self.sample_rate)(waveform)\n",
"\n",
" if waveform.shape[0] > 1:\n",
" waveform = waveform.mean(dim=0, keepdim=True)\n",
"\n",
" waveform = waveform / (waveform.abs().max() + 1e-8)\n",
" return waveform\n",
" \n",
" def augment_waveform(self, waveform: Tensor) -> Tensor:\n",
" # if not self.augment or random.random() > self.augment_prob:\n",
" # return waveform\n",
" \n",
" # 1. voice Stretch/Compress \n",
" if random.random() < 0.5:\n",
" waveform = self._voice_stretch_or_compress(waveform=waveform)\n",
"\n",
" if random.random() < 0.4:\n",
" waveform = self.noise_augmentor.apply_real_noise(waveform)\n",
"\n",
" if random.random() < 0.3:\n",
" waveform = self._time_mask_waveform(waveform=waveform)\n",
" \n",
" # torch_audiomentations: [1, time] -> [1, 1, time]\n",
" if waveform.dim() == 2:\n",
" waveform_3d = waveform.unsqueeze(0)\n",
" # 随机选择一种频谱增强\n",
" choice = random.random()\n",
" if choice < 0.15:\n",
" # 增益变化(上或下)\n",
" if random.random() < 0.5:\n",
" waveform_3d = self.gain_up(waveform_3d, sample_rate=self.sample_rate)\n",
" else:\n",
" waveform_3d = self.gain_down(waveform_3d, sample_rate=self.sample_rate)\n",
" elif choice < 0.25:\n",
" # 音高变化(上或下)\n",
" if random.random() < 0.5:\n",
" waveform_3d = self.pitch_up(waveform_3d, sample_rate=self.sample_rate)\n",
" else:\n",
" waveform_3d = self.pitch_down(waveform_3d, sample_rate=self.sample_rate)\n",
" elif choice < 0.30:\n",
" # 低通滤波(声音发闷)\n",
" waveform_3d = self.lowpass(waveform_3d, sample_rate=self.sample_rate)\n",
" elif choice < 0.35:\n",
" # 高通滤波(电话效果)\n",
" waveform_3d = self.highpass(waveform_3d, sample_rate=self.sample_rate)\n",
" \n",
" # [1, 1, time] -> [1, time]\n",
" waveform = waveform_3d.squeeze(0)\n",
" \n",
" # 防止多次 augment 后振幅溢出,最后归一化\n",
" max_amp = waveform.abs().max()\n",
" if max_amp > 1.0:\n",
" print(max_amp)\n",
" waveform = waveform / max_amp\n",
" \n",
" return waveform\n",
" \n",
" def _voice_stretch_or_compress1(self, waveform: Tensor) -> Tensor:\n",
" speed_factor = random.uniform(0.8, 1.4) # (Speed Change: 0.6x - 1.4x)\n",
" waveform_np = waveform.squeeze(0).cpu().numpy()\n",
" y_stretch = librosa.effects.time_stretch(waveform_np, rate=speed_factor)\n",
" waveform = torch.from_numpy(y_stretch).float().to(waveform.device).unsqueeze(0)\n",
" return waveform\n",
" \n",
" def _voice_stretch_or_compress(self, waveform: Tensor) -> Tensor:\n",
" speed_factor = random.uniform(0.8, 1.25) # (Speed Change: 0.6x - 1.2x)\n",
" spec = torch.stft(\n",
" waveform.squeeze(0),\n",
" n_fft=400,\n",
" hop_length=160,\n",
" window=torch.hann_window(400).to(waveform.device),\n",
" return_complex=True\n",
" )\n",
" \n",
" # 时间拉伸(不改变音高)\n",
" stretch = TimeStretch(hop_length=160, n_freq=spec.shape[-2], fixed_rate=1.4)\n",
" stretched_spec = stretch(spec)\n",
" \n",
" # 转回波形\n",
" waveform_stretched = torch.istft(\n",
" stretched_spec,\n",
" n_fft=400,\n",
" hop_length=160,\n",
" window=torch.hann_window(400).to(waveform.device)\n",
" ).unsqueeze(0)\n",
" \n",
" return waveform_stretched\n",
"\n",
" def _voice_stretch_or_compress2(self, waveform: Tensor) -> Tensor:\n",
" speed_factor = random.uniform(0.6, 1.4) # (Speed Change: 0.6x - 1.4x)\n",
" waveform_np = waveform.squeeze(0).cpu().numpy()\n",
" y_stretch = pyrb.time_stretch(waveform_np, self.sample_rate, speed_factor)\n",
" return torch.from_numpy(y_stretch).float().to(waveform.device).unsqueeze(0)\n",
" \n",
" def _time_mask_waveform(self, waveform: Tensor) -> Tensor:\n",
" audio_len = waveform.shape[1]\n",
" sr = self.sample_rate # 16000\n",
" \n",
" # 设置参数:单次遮盖最长 0.4 秒 (6400个点)\n",
" max_mask_time = 0.4\n",
" max_mask_samples = int(sr * max_mask_time)\n",
" \n",
" # 根据音频长度决定遮盖次数:\n",
" # 比如每 3 秒钟允许遮盖 1 次\n",
" num_masks = max(1, audio_len // (sr * 3)) \n",
" \n",
" for _ in range(num_masks):\n",
" # 每次随机遮盖 0.1s 到 0.4s\n",
" current_mask_len = random.randint(int(sr * 0.1), max_mask_samples)\n",
" \n",
" if audio_len > current_mask_len:\n",
" start_pos = random.randint(0, audio_len - current_mask_len)\n",
" \n",
" # 填充微小噪音(模拟环境底噪)\n",
" noise = torch.randn(1, current_mask_len).to(waveform.device) * 0.002\n",
" waveform[:, start_pos : start_pos + current_mask_len] = noise\n",
" \n",
" return waveform\n",
"\n",
" def transcribe(self, waveform: Tensor) -> str:\n",
" waveform = waveform.to(device=self.device)\n",
" waveform_length = torch.tensor([waveform.shape[1]], dtype=torch.long, device=self.device)\n",
"\n",
" with no_grad():\n",
" log_probs, _ = self.model(waveforms=waveform, waveform_lengths=waveform_length)\n",
" \n",
" text = self.tokenizer.ctc_greedy_decode(log_probs=log_probs[0])\n",
" return text"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5af7c836",
"metadata": {},
"outputs": [],
"source": [
"device = torch.device('cuda:1')\n",
"\n",
"# checkpoint = sorted(workspace_dir.glob('.checkpoints/checkpoint_epoch_*.pt'), key=lambda p: int(p.stem.split('_')[-1]))[-1]\n",
"checkpoint = workspace_dir / \".checkpoints/best_wer_model.pt\"\n",
"# checkpoint = workspace_dir / \".checkpoints/best_cer_model.pt\"\n",
"checkpoint = workspace_dir / \".checkpoints/prodect/checkpoint_epoch_24 copy.pt\"\n",
"print(f\"Load Checkpoint: {checkpoint}\")\n",
"inference = ASRInference(model_path=checkpoint, vocab_path=workspace_dir / 'config/asr_vocab.json', noise_dir='/mnt/dataset/dataset/audio/noise', device=device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "967e23cb",
"metadata": {},
"outputs": [],
"source": [
"# audio_path = workspace_dir / '.data/ug/clips/common_voice_ug_42681424.mp3'\n",
"# audio_path = workspace_dir / '.data/ug/clips/common_voice_ug_26188445.mp3'\n",
"# audio_path = workspace_dir / '.data/ug/clips/common_voice_ug_43349581.mp3'\n",
"# audio_path = workspace_dir / '.data/ug/clips/common_voice_ug_42640878.mp3'\n",
"# audio_path = workspace_dir / '.data/ug/clips/common_voice_ug_26245615.mp3'\n",
"# audio_path = workspace_dir / '.data/ug/clips/common_voice_ug_40794549.mp3'\n",
"# audio_path = workspace_dir / 'data/ug/clips/common_voice_ug_26245614.mp3'\n",
"audio_path = '/mnt/train/audio/Dilnaz/only_vocals/3272800876_100006098_(vocals)_melband_roformer_big_beta5e.m4a'\n",
"# audio_path = workspace_dir / 'data/test/how_are_you.mp3'\n",
"# audio_path = workspace_dir / 'data/test/split_1.m4a'\n",
"# audio_path = workspace_dir / 'data/test/let_me_see_the_computer.mp3'\n",
"# audio_path = workspace_dir / 'data/test/introduce_myself.mp3'\n",
"# audio_path = workspace_dir / 'data/test/F001_001.wav'\n",
"# audio_path = workspace_dir / 'data/test/download.wav'\n",
"# audio_path = workspace_dir / 'data/test/test_voise_1.m4a'\n",
"orginal_waveform = inference._load_audio(audio_path=audio_path)\n",
"print('load audio:', orginal_waveform.shape)\n",
"# orginal_waveform = orginal_waveform[:, :]\n",
"\n",
"def simple_corridor_echo(waveform, sample_rate=16000, delay_ms=80, decay=0.3):\n",
" \"\"\"\n",
" 模拟一次明显的走廊回声:原始声音 + 一个衰减延迟的拷贝\n",
" delay_ms: 延迟毫秒数80ms 模拟中等走廊\n",
" decay: 回声的衰减系数0.2~0.4 比较自然\n",
" \"\"\"\n",
" delay_samples = int(sample_rate * delay_ms / 1000)\n",
" echo = torch.zeros_like(waveform)\n",
" echo[..., delay_samples:] = waveform[..., :-delay_samples] * decay\n",
" return waveform + echo\n",
"\n",
"print(\"Orginal audio\")\n",
"display(Audio(orginal_waveform, rate=inference.sample_rate))\n",
"\n",
"# augment_waveform = inference.augment_waveform(waveform=orginal_waveform)\n",
"# augment_waveform = simple_corridor_echo(waveform=orginal_waveform)\n",
"# augment_waveform = inference._voice_stretch_or_compress(waveform=orginal_waveform)\n",
"# noise_augmentor_waveform = inference.noise_augmentor.apply_real_noise(waveform=orginal_waveform)\n",
"# display(Audio(augment_waveform, rate=inference.sample_rate))\n",
"# text = inference.transcribe(waveform=orginal_waveform)\n",
"# print(f\"\\n{text}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "16195280",
"metadata": {},
"outputs": [],
"source": [
"from silero_vad import load_silero_vad, read_audio, get_speech_timestamps\n",
"class TimestampsType(TypedDict):\n",
" start: int\n",
" end: int\n",
"\n",
"THRESHOLD_MS = 2000\n",
"SAMPLE_RATE = inference.sample_rate\n",
"\n",
"def merge_timestamps(timestamps: int, threshold_ms: int, sample_rate: int):\n",
" if not timestamps:\n",
" return []\n",
" threshold_samples = (threshold_ms / 1000) * sample_rate\n",
" merged = []\n",
" curr_start = timestamps[0]['start']\n",
" curr_end = timestamps[0]['end']\n",
" \n",
" for i in range(1, len(timestamps)):\n",
" # 如果当前积攒的长度不到 2 秒,就一直合并到当前的 end 上\n",
" if (curr_end - curr_start) < threshold_samples:\n",
" curr_end = timestamps[i]['end']\n",
" else:\n",
" merged.append({'start': curr_start, 'end': curr_end})\n",
" curr_start = timestamps[i]['start']\n",
" curr_end = timestamps[i]['end']\n",
" merged.append({'start': curr_start, 'end': curr_end})\n",
" return merged\n",
"\n",
"vad_model = load_silero_vad(onnx=False)\n",
"waveform_vad = inference._load_audio(audio_path=audio_path)\n",
"speech_timestamps: list[TimestampsType] = get_speech_timestamps(\n",
" waveform_vad, \n",
" vad_model, \n",
" sampling_rate=inference.sample_rate,\n",
" threshold=0.4, # 可以根据需要调整阈值\n",
" min_speech_duration_ms=100, # 最小语音持续时间,防止短噪音被误判\n",
" min_silence_duration_ms=200, # 最小静音间隔,用于分割语音块\n",
" speech_pad_ms=100, # 在语音块前后添加的填充时间\n",
")\n",
"print('speech_timestamps:', speech_timestamps)\n",
"print('speech_timestamps:', len(speech_timestamps))\n",
"final_timestamps = merge_timestamps(speech_timestamps, THRESHOLD_MS, SAMPLE_RATE)\n",
"print(f'VAD原始片段: {len(speech_timestamps)} -> 合并后片段: {len(final_timestamps)}')\n",
"for ts in final_timestamps:\n",
" ts: TimestampsType\n",
" start_sample = int(ts['start'])\n",
" end_sample = int(ts['end'])\n",
"\n",
" segment_waveform = waveform_vad[:, start_sample:end_sample]\n",
" display(Audio(segment_waveform, rate=inference.sample_rate))\n",
"\n",
" segment_text = inference.transcribe(waveform=segment_waveform)\n",
" print(segment_text, end=\" \", flush=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3671d334",
"metadata": {},
"outputs": [],
"source": [
"# mel_spec = inference._extract_features(waveform=orginal_waveform)\n",
"# ################################\n",
"# from PIL import Image, ImageFilter\n",
"# import numpy as np\n",
"# import math\n",
"\n",
"# img = (mel_spec.numpy())\n",
"# print(img.max(), img.min())\n",
"# img = img + np.abs(img.min())\n",
"# img = img / img.max()\n",
"# print(img.max(), img.min())\n",
"# img = Image.fromarray((img * 255).astype(np.uint8))\n",
"# shapened = img.filter(ImageFilter.UnsharpMask(radius=1, percent=150, threshold=3))\n",
"# img.save(\"original_hd.jpg\")\n",
"# # img.save(\"sharpen_hd.jpg\")\n",
"\n",
"# fig, axes = plt.subplots(2, 1, figsize=(30, 12))\n",
"\n",
"# axes[0].imshow(\n",
"# img,\n",
"# cmap=\"gray\",\n",
"# origin=\"lower\",\n",
"# vmin=0,\n",
"# vmax=255,\n",
"# aspect=\"auto\"\n",
"# )\n",
"# axes[0].set_title(\"Original\")\n",
"# axes[1].imshow(\n",
"# shapened,\n",
"# cmap=\"gray\",\n",
"# origin=\"lower\",\n",
"# vmin=0,\n",
"# vmax=255,\n",
"# aspect=\"auto\"\n",
"# )\n",
"# axes[1].set_title(\"Sharpened\")\n",
"# plt.tight_layout()\n",
"# plt.show()\n",
"\n",
"# ################################\n",
"# mel_spec = inference._augment_spec(mel_spec=mel_spec)\n",
"# print(mel_spec.shape)\n",
"# mel_spec = mel_spec.unsqueeze(0)\n",
"# print(mel_spec.shape)\n",
"# # 创建可视化\n",
"# fig, axes = plt.subplots(2, 1, figsize=(18, 12))\n",
"\n",
"# # 原始波形\n",
"# axes[0].plot(orginal_waveform[0].numpy())\n",
"# axes[0].set_title(\"Original waveform\")\n",
"# axes[0].set_xlabel(\"sample\")\n",
"# axes[0].set_ylabel(\"Amplitude\")\n",
"# axes[0].grid(True)\n",
"\n",
"# # Mel 频谱图\n",
"# im = axes[1].imshow(\n",
"# mel_spec[0].numpy(),\n",
"# cmap='gray',\n",
"# aspect='auto',\n",
"# origin='lower',\n",
"# # vmin=-150,\n",
"# # vmax=1,\n",
"# extent=[0, orginal_waveform.shape[-1] / inference.sample_rate, 0, 80]\n",
"# )\n",
"# axes[1].set_title(\"Log Mel Spectrogram\")\n",
"# axes[1].set_xlabel(\"Time (second)\")\n",
"# axes[1].set_ylabel(\"Mels\")\n",
"# # plt.colorbar(im, ax=axes[1], format='%+2.0f dB')\n",
"\n",
"# plt.tight_layout()\n",
"# plt.show()\n",
"\n",
"\n",
"# print(f\"Mel 频谱图形状: {mel_spec.shape}\") # (channel, n_mels, time_frames)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "study-asr (3.11.12)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

123
src/inference.py Normal file
View File

@@ -0,0 +1,123 @@
import torch
import torchaudio
from torch import Tensor, no_grad, device
from torchaudio.transforms import Resample
from pathlib import Path
from typing import TypedDict
from silero_vad import load_silero_vad, get_speech_timestamps
from tokenizer import ASRTokenizer
from model import ASRModel
CONFIG = {
# 模型配置
'input_dim': 256,
'num_heads': 8,
'ffn_dim': 2048,
'num_layers': 8,
'dropout': 0.1,
}
class TimestampsType(TypedDict):
start: int
end: int
class ASRInference:
def __init__(self, model_path: Path, vocab_path: Path, device: device, sample_rate: int = 16000) -> None:
self.device = device
self.sample_rate = sample_rate
self.tokenizer = ASRTokenizer(vocab_path=vocab_path)
self.model = ASRModel(vocab_size=self.tokenizer.vocab_size(), **CONFIG).to(device=device)
self.model.load_state_dict(torch.load(model_path, map_location=device)['model_state_dict'])
self.model.eval()
print(f"params params: {self.model.get_num_params():,}",)
def _load_audio(self, audio_path: Path) -> Tensor:
waveform, sample_rate = torchaudio.load_with_torchcodec(audio_path)
if sample_rate != self.sample_rate:
waveform = Resample(sample_rate, self.sample_rate)(waveform)
if waveform.shape[0] > 1:
waveform = waveform.mean(dim=0, keepdim=True)
waveform = waveform / (waveform.abs().max() + 1e-8)
return waveform
def transcribe(self, waveform: Tensor) -> str:
waveform = waveform.to(device=self.device) # [1, time]
waveform_length = torch.tensor([waveform.shape[1]], dtype=torch.long, device=self.device)
with no_grad():
log_probs, _ = self.model(waveforms=waveform, waveform_lengths=waveform_length) # [1, T, vocab]
text = self.tokenizer.ctc_greedy_decode(log_probs=log_probs[0])
return text
def transcribe_batch(self, audio_paths: list[Path]) -> list[str]:
results = []
for audio_path in audio_paths:
text = self.transcribe(audio_path=audio_path)
results.append(text)
return results
def create(self, audio_path: Path, threshold_ms: int = 20000):
def merge_timestamps(timestamps: int, threshold_ms: int, sample_rate: int):
if not timestamps:
return []
threshold_samples = (threshold_ms / 1000) * sample_rate
merged = []
curr_start = timestamps[0]['start']
curr_end = timestamps[0]['end']
for i in range(1, len(timestamps)):
# 如果当前积攒的长度不到 2 秒,就一直合并到当前的 end 上
if (curr_end - curr_start) < threshold_samples:
curr_end = timestamps[i]['end']
else:
merged.append({'start': curr_start, 'end': curr_end})
curr_start = timestamps[i]['start']
curr_end = timestamps[i]['end']
merged.append({'start': curr_start, 'end': curr_end})
return merged
vad_model = load_silero_vad(onnx=False)
waveform = self._load_audio(audio_path=audio_path)
speech_timestamps: list[TimestampsType] = get_speech_timestamps(
waveform,
vad_model,
sampling_rate=self.sample_rate,
threshold=0.5, # 可以根据需要调整阈值
min_speech_duration_ms=100, # 最小语音持续时间,防止短噪音被误判
min_silence_duration_ms=200, # 最小静音间隔,用于分割语音块
speech_pad_ms=100, # 在语音块前后添加的填充时间
)
final_timestamps = merge_timestamps(timestamps=speech_timestamps, threshold_ms=threshold_ms, sample_rate=self.sample_rate)
for ts in final_timestamps:
ts: TimestampsType
start_sample = int(ts['start'])
end_sample = int(ts['end'])
segment_waveform = waveform[:, start_sample:end_sample]
segment_text = self.transcribe(waveform=segment_waveform)
print(segment_text, end=" ", flush=True)
print('\n')
def main():
workspace_dir = Path(__file__).parent.parent
device = torch.device('cuda:1')
checkpoint = workspace_dir / ".checkpoints/best_wer_model.pt"
inference = ASRInference(model_path=checkpoint, vocab_path=workspace_dir / 'config/asr_vocab.json' , device=device)
audio_path = workspace_dir / 'data/test/F001_001.wav'
waveform = inference._load_audio(audio_path=audio_path)
text = inference.transcribe(waveform=waveform)
print("transcribe:", text)
inference.create(audio_path=audio_path)
if __name__ == "__main__":
main()

84
src/model.py Normal file
View File

@@ -0,0 +1,84 @@
from torch.nn import Module, Linear, functional, GELU, Conv1d, GroupNorm, ModuleList
from torch import Tensor
from torchaudio.models import Conformer
class WaveformFilter(Module):
def __init__(self):
super().__init__()
self.filters = ModuleList([
Conv1d(in_channels=1, out_channels=512, kernel_size=10, stride=5),
GroupNorm(num_channels=512, num_groups=32),
GELU(),
Conv1d(in_channels=512, out_channels=512, kernel_size=3, stride=2),
GroupNorm(num_channels=512, num_groups=32),
GELU(),
Conv1d(in_channels=512, out_channels=512, kernel_size=3, stride=2),
GroupNorm(num_channels=512, num_groups=32),
GELU(),
Conv1d(in_channels=512, out_channels=512, kernel_size=3, stride=2),
GroupNorm(num_channels=512, num_groups=32),
GELU(),
Conv1d(in_channels=512, out_channels=512, kernel_size=3, stride=2),
GroupNorm(num_channels=512, num_groups=32),
GELU(),
Conv1d(in_channels=512, out_channels=512, kernel_size=2, stride=2),
GroupNorm(num_channels=512, num_groups=32),
GELU(),
Conv1d(in_channels=512, out_channels=512, kernel_size=2, stride=2),
GroupNorm(num_channels=512, num_groups=32),
GELU(),
])
def compute_lengths(self, waveform_lengths: Tensor) -> Tensor:
"""Accurately compute output lengths after all convolutions."""
lengths = waveform_lengths
for module in self.filters:
if isinstance(module, Conv1d):
# Conv1d with padding=0, dilation=1:
# output = floor((input - kernel_size) / stride) + 1
lengths = (lengths - module.kernel_size[0]) // module.stride[0] + 1
return lengths
def forward(self, x: Tensor) -> Tensor:
for filter in self.filters:
x = filter(x)
return x
class ASRModel(Module):
def __init__(self, vocab_size: int, input_dim: int = 256, num_heads: int = 8, ffn_dim: int = 2048, num_layers: int = 6, dropout: float = 0.1) -> None:
super().__init__()
self.feature_extractor = WaveformFilter()
self.proj = Linear(in_features=512, out_features=input_dim, bias=False)
self.encoder = Conformer(input_dim=input_dim, num_heads=num_heads, ffn_dim=ffn_dim, num_layers=num_layers, depthwise_conv_kernel_size=31, dropout=dropout)
self.ctc_head = Linear(in_features=input_dim, out_features=vocab_size, bias=False)
def forward(self, waveforms: Tensor, waveform_lengths: Tensor) -> tuple[Tensor, Tensor]:
assert len(waveform_lengths.shape) == 1, "The waveform_lengths tensor must be shape of [B] tensor."
assert len(waveforms.shape) == 2, "The waveform tensor must be [B, time] tensor"
# waveforms: [B, time]
x: Tensor = waveforms.unsqueeze(1) # [B, 1, time]
#x: [B, 1, time]
x: Tensor = self.feature_extractor(x)
#x: [B, 512, time]
x = x.permute(0, 2, 1)
#x: [B, time, 512]
x = self.proj(x)
#x: [B, time, 256]
# lengths = torch.tensor([time] * batch, dtype=torch.long, device=x.device)
lengths = self.feature_extractor.compute_lengths(waveform_lengths)
x, lengths = self.encoder(x, lengths)
logits = self.ctc_head(x) # [B, time, vocab_size]
log_probs = functional.log_softmax(logits, dim=-1)
return log_probs, lengths
def get_num_params(self) -> int:
return sum(p.numel() for p in self.parameters())

107
src/tokenizer.py Normal file
View File

@@ -0,0 +1,107 @@
import json
from pathlib import Path
import re
import torch
from torch import Tensor
from handle.text_handle import process_text
from handle.text_normalizer import UYGHUR_LETTERS
class ASRTokenizer:
def __init__(self, vocab_path: Path) -> None:
self.vocabs: list[str] = []
self.vocabs.extend([
"<BLANK>",
"<UNK>",
"<SPACE>"
])
for i in range(len(self.vocabs), 64):
self.vocabs.append(f"REVERSED_{i}")
with open(vocab_path, 'r', encoding='utf-8') as file:
data: list[str] = json.load(file)
for char in data:
if char not in [' ', '\t', '\n'] and char not in self.vocabs:
self.vocabs.append(char)
self.vocab_to_id: dict[str, int] = {vocab:i for i, vocab in enumerate(self.vocabs)}
self.id_to_vocab: dict[int, str] = {i:vocab for i, vocab in enumerate(self.vocabs)}
self._max_token_len = max((len(v) for v in self.vocabs if not v.startswith("<")), default=0)
def vocab_size(self):
return len(self.vocabs)
def _split_long_syllable(self, syll: str) -> list[str]:
"""将不在词表中的长音节拆分为词表内的子词序列(最大正向匹配)"""
pieces = []
start = 0
n = len(syll)
while start < n:
# 从最大可能长度开始尝试匹配
for length in range(min(self._max_token_len, n - start), 0, -1):
sub = syll[start:start + length]
if sub in self.vocab_to_id:
pieces.append(sub)
start += length
break
else:
# 理论上不会发生,因为单字母一定在词表
pieces.append("<UNK>")
start += 1
return pieces
def encode(self, text: str) -> list[int]:
tokens = process_text(text=text)
result: list[int] = []
for token in tokens:
if token == ' ':
result.append(self.vocab_to_id['<SPACE>'])
else:
token_id = self.vocab_to_id.get(token)
if token_id is not None:
result.append(token_id)
elif token in UYGHUR_LETTERS:
result.append(self.vocab_to_id.get(token, self.vocab_to_id['<UNK>']))
else:
sub_pieces = self._split_long_syllable(token)
for piece in sub_pieces:
result.append(self.vocab_to_id.get(piece, self.vocab_to_id['<UNK>']))
return result
def decode(self, ids: list[int] | Tensor, remove_blank: bool = True, remove_repeat: bool = True) -> str:
if isinstance(ids, Tensor):
ids = ids.tolist()
result = []
prev_id = None
for token_id in ids:
# 跳过blank token
if remove_blank and token_id == self.vocab_to_id['<BLANK>']:
prev_id = None
continue
# CTC解码跳过连续重复的字符
if remove_repeat and token_id == prev_id:
continue
char = self.id_to_vocab.get(token_id, "<UNK>")
if char == '<SPACE>':
char = ' '
result.append(char)
prev_id = token_id
return ''.join(result)
def get_special_token_id(self, token: str) -> int:
return self.vocab_to_id.get(token, self.vocab_to_id['<UNK>'])
def ctc_greedy_decode(self, log_probs: Tensor) -> str:
""" CTC贪心解码 (log_probs: [seq_len, vocab_size] 的log概率) """
return self.decode(torch.argmax(log_probs, dim=-1))

367
src/train.py Normal file
View File

@@ -0,0 +1,367 @@
import math
import torch
from torch import Tensor, device, no_grad, cuda
from torch.optim import AdamW, Optimizer
from torch.optim.lr_scheduler import OneCycleLR
from torch.nn import CTCLoss, utils
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
from tqdm import tqdm
from datetime import datetime
from torch.amp.grad_scaler import GradScaler
from tokenizer import ASRTokenizer
from dataset import Batch, create_dataloader
from model import ASRModel
# ============ 全局配置 ============
CONFIG = {
# 数据配置
'batch_size': 32,
# 训练配置
'num_epochs': 50,
'learning_rate': 2e-4,
'weight_decay': 1e-4,
'grad_clip_norm': 1.0,
# 模型配置
'input_dim': 256,
'num_heads': 8,
'ffn_dim': 2048,
'num_layers': 8,
'dropout': 0.1,
# 保存和评估
'early_stopping_patience': 12,
}
def calculate_cer(pred: str, target: str) -> float:
""" 计算字符错误率 (Character Error Rate) 使用编辑距离算法"""
if len(target) == 0:
return 0.0 if len(pred) == 0 else 1.0
d = [[0] * (len(target) + 1) for _ in range(len(pred) + 1)]
for i in range(len(pred) + 1):
d[i][0] = i
for j in range(len(target) + 1):
d[0][j] = j
for i in range(1, len(pred) + 1):
for j in range(1, len(target) + 1):
if pred[i-1] == target[j-1]:
d[i][j] = d[i-1][j-1]
else:
d[i][j] = min(d[i-1][j], d[i][j-1], d[i-1][j-1]) + 1
return d[len(pred)][len(target)] / len(target)
def calculate_wer(pred: str, target: str) -> float:
""" 计算词错误率 (Word Error Rate) """
pred_words = pred.split()
target_words = target.split()
if len(target_words) == 0:
return 0.0 if len(pred_words) == 0 else 1.0
d = [[0] * (len(target_words) + 1) for _ in range(len(pred_words) + 1)]
for i in range(len(pred_words) + 1):
d[i][0] = i
for j in range(len(target_words) + 1):
d[0][j] = j
for i in range(1, len(pred_words) + 1):
for j in range(1, len(target_words) + 1):
if pred_words[i-1] == target_words[j-1]:
d[i][j] = d[i-1][j-1]
else:
d[i][j] = min(d[i-1][j], d[i][j-1], d[i-1][j-1]) + 1
return d[len(pred_words)][len(target_words)] / len(target_words)
def train_one_epoch(model: ASRModel, dataloader: DataLoader, criterion: CTCLoss, optimizer: Optimizer, scheduler: OneCycleLR, device: device, epoch: int, writer: SummaryWriter, global_step: int, scaler: GradScaler) -> tuple[float, int]:
model.train()
num_batches = 0
total_train_loss = 0
progress_bar = tqdm(dataloader, desc=f"Epoch {epoch}")
for batch_index, batch in enumerate(progress_bar):
batch: Batch
waveforms = batch['waveforms'].to(device)
targets = batch['targets'].to(device)
waveform_lengths = batch['waveform_lengths'].to(device)
target_lengths = batch['target_lengths'].to(device)
optimizer.zero_grad()
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
log_probs, lengths = model(waveforms=waveforms, waveform_lengths=waveform_lengths)
log_probs: Tensor
log_probs_ctc = log_probs.permute(1, 0, 2)
loss: Tensor = criterion(log_probs=log_probs_ctc, targets=targets, input_lengths=lengths, target_lengths=target_lengths)
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
utils.clip_grad_norm_(model.parameters(), max_norm=CONFIG['grad_clip_norm'])
scaler.step(optimizer)
scaler.update()
scheduler.step()
total_train_loss += loss.item()
num_batches += 1
global_step += 1
writer.add_scalar('Train/Loss', loss.item(), global_step)
writer.add_scalar('Train/LearningRate', optimizer.param_groups[0]['lr'], global_step)
progress_bar.set_postfix({'epoch': f"{epoch}/{CONFIG['num_epochs']}",'loss': f'{loss.item():.4f}', 'step': global_step})
train_avg_loss = total_train_loss / num_batches
return train_avg_loss, global_step
def validate(model: ASRModel, dataloader: DataLoader, criterion: CTCLoss, device: device, tokenizer: ASRTokenizer, writer: SummaryWriter, global_step: int) -> tuple[float, float, float]:
model.eval()
total_loss = 0
total_cer = 0
total_wer = 0
num_samples = 0
num_batches = 0
examples = []
with no_grad():
progress_bar = tqdm(dataloader, desc="Validate", leave=False)
for batch in progress_bar:
batch: Batch
waveforms = batch['waveforms'].to(device)
targets = batch['targets'].to(device)
waveform_lengths = batch['waveform_lengths'].to(device)
target_lengths = batch['target_lengths'].to(device)
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
log_probs, lengths = model(waveforms=waveforms, waveform_lengths=waveform_lengths)
log_probs: Tensor
log_probs_ctc = log_probs.permute(1, 0, 2)
loss: Tensor = criterion(log_probs=log_probs_ctc, targets=targets, input_lengths=lengths, target_lengths=target_lengths)
total_loss += loss.item()
for i in range(log_probs.shape[0]):
pred_text = tokenizer.ctc_greedy_decode(log_probs=log_probs[i])
true_text = batch['target_texts'][i]
cer = calculate_cer(pred=pred_text, target=true_text)
wer = calculate_wer(pred=pred_text, target=true_text)
total_cer += cer
total_wer += wer
num_samples += 1
if len(examples) < 3:
examples.append((true_text, pred_text, cer, wer))
num_batches += 1
avg_loss = total_loss / num_batches
avg_cer = total_cer / num_samples
avg_wer = total_wer / num_samples
writer.add_scalar('Val/Loss', avg_loss, global_step)
writer.add_scalar('Val/CER', avg_cer, global_step)
writer.add_scalar('Val/WER', avg_wer, global_step)
for index, (true_text, pred_text, cer, wer) in enumerate(examples):
writer.add_text(f'Val/Example_{index}', f'True: {true_text}\nPred: {pred_text}\nCER: {cer:.4f} | WER: {wer:.4f}', global_step)
model.train()
return avg_loss, avg_cer, avg_wer
def save_checkpoint(model: ASRModel, optimizer: Optimizer, scheduler: OneCycleLR, global_step: int, epoch: int, train_loss: float, val_loss: float, cer: float, wer: float, save_path: Path):
checkpoint = {
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'global_step': global_step,
'train_loss': train_loss,
'val_loss': val_loss,
'epoch': epoch,
'cer': cer,
'wer': wer,
}
torch.save(checkpoint, save_path)
def find_latest_checkpoint(checkpoint_dir: Path) -> Path | None:
checkpoints = sorted(checkpoint_dir.glob('checkpoint_epoch_*.pt'), key=lambda p: int(p.stem.split('_')[-1]))
return checkpoints[-1] if checkpoints else None
def load_checkpoint(file_path: Path, model: ASRModel, optimizer: AdamW, scheduler: OneCycleLR):
checkpoint = torch.load(file_path, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
epoch = checkpoint['epoch']
global_step = checkpoint['global_step']
best_cer = checkpoint['cer']
best_wer = checkpoint['wer']
return epoch, global_step, best_cer, best_wer
def main():
workspace_dir = Path(__file__).parent.parent
device = torch.device('cuda:0')
tokenizer = ASRTokenizer(workspace_dir / 'config/asr_vocab.json')
final_prob = 0.5
warmup_epochs = 8
current_prob = 0.0
# ============ 创建数据加载器 ============
train_loader = create_dataloader(
tsv_path=workspace_dir / '.data/ug/train_new.tsv',
audio_dir=workspace_dir / '.data/ug/clips',
noise_dir='/mnt/dataset/dataset/audio/noise',
tokenizer=tokenizer,
batch_size=CONFIG['batch_size'],
shuffle=True,
augment=True,
augment_prob=current_prob
)
val_loader = create_dataloader(
tsv_path=workspace_dir / '.data/ug/val_new.tsv',
audio_dir=workspace_dir / '.data/ug/clips',
noise_dir='/mnt/dataset/dataset/audio/noise',
tokenizer=tokenizer,
batch_size=CONFIG['batch_size'],
shuffle=False,
augment=False
)
# ============ 初始化模型 ============
model = ASRModel(
vocab_size=tokenizer.vocab_size(),
input_dim=CONFIG['input_dim'],
num_heads=CONFIG['num_heads'],
ffn_dim=CONFIG['ffn_dim'],
num_layers=CONFIG['num_layers'],
dropout=CONFIG['dropout'],
).to(device)
print(f"🤖 模型参数量: {model.get_num_params() / 1e6:.2f}M")
criterion = CTCLoss(blank=tokenizer.get_special_token_id('<BLANK>'), zero_infinity=True)
optimizer = AdamW(model.parameters(), lr=CONFIG['learning_rate'], weight_decay=CONFIG['weight_decay'], foreach=True)
# 使用 OneCycleLR自动处理 warmup 和衰减
scheduler = OneCycleLR(
optimizer,
max_lr=CONFIG['learning_rate'],
epochs=CONFIG['num_epochs'],
steps_per_epoch=len(train_loader),
pct_start=0.15,
anneal_strategy='cos',
div_factor=10.0, # 初始 lr = max_lr / 10
final_div_factor=1e4, # 最终 lr = max_lr / 10000
)
scaler = GradScaler()
# ============ TensorBoard ============
log_dir = workspace_dir / 'runs' / datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter(log_dir)
config_text = '\n'.join([f'{k}: {v}' for k, v in CONFIG.items()])
writer.add_text('Config', config_text, 0)
print(f"📊 TensorBoard 日志: {log_dir}")
checkpoint_dir = workspace_dir / '.checkpoints'
checkpoint_dir.mkdir(exist_ok=True)
# ============ 训练循环 ============
best_cer = float('inf')
best_wer = float('inf')
patience_counter = 0
global_step = 0
start_epoch = 0
latest = find_latest_checkpoint(checkpoint_dir)
if latest:
start_epoch, global_step, best_cer, best_wer = load_checkpoint(latest, model, optimizer, scheduler)
start_epoch += 1 # 从下一个 epoch 开始
print(f"✅ 恢复训练: 从 epoch={start_epoch} 开始, global_step={global_step}")
else:
start_epoch = 0
print("🆕 从头开始训练\n")
for epoch in range(start_epoch, CONFIG['num_epochs']):
current_prob = final_prob * (1 - math.cos(math.pi * epoch / warmup_epochs)) / 2
train_loss, global_step = train_one_epoch(
model=model,
dataloader=train_loader,
criterion=criterion,
optimizer=optimizer,
scheduler=scheduler,
device=device,
epoch=epoch,
writer=writer,
global_step=global_step,
scaler=scaler,
)
val_loss, val_cer, val_wer = validate(model=model, dataloader=val_loader, criterion=criterion, device=device, tokenizer=tokenizer, writer=writer, global_step=global_step)
print(f"\n📊 Step {global_step} | Val Loss: {val_loss:.4f} | Val CER: {val_cer:.4f} | Val WER: {val_wer:.4f}")
# 保存常规 checkpoint
checkpoint_path = checkpoint_dir / f'checkpoint_epoch_{epoch}.pt'
save_checkpoint(model=model, optimizer=optimizer, scheduler=scheduler, epoch=epoch, global_step=global_step, train_loss=train_loss, val_loss=val_loss, cer=val_cer, wer=val_wer, save_path=checkpoint_path)
# 分别保存最佳 CER 和 WER 模型
improved = False
if val_cer < best_cer:
best_cer = val_cer
best_cer_path = checkpoint_dir / 'best_cer_model.pt'
save_checkpoint(model=model, optimizer=optimizer, scheduler=scheduler, epoch=epoch, global_step=global_step, train_loss=train_loss, val_loss=val_loss, cer=val_cer, wer=val_wer, save_path=best_cer_path)
improved = True
if val_wer < best_wer:
best_wer = val_wer
best_wer_path = checkpoint_dir / 'best_wer_model.pt'
save_checkpoint(model=model, optimizer=optimizer, scheduler=scheduler, epoch=epoch, global_step=global_step, train_loss=train_loss, val_loss=val_loss, cer=val_cer, wer=val_wer, save_path=best_wer_path)
improved = True
# Early Stopping 逻辑
if improved:
patience_counter = 0
else:
patience_counter += 1
print(f"⚠️ 验证指标未改善patience: {patience_counter}/{CONFIG['early_stopping_patience']}")
# 删除旧的 checkpoint保留最近3个
old_checkpoints = sorted(checkpoint_dir.glob('checkpoint_epoch_*.pt'), key=lambda p: int(p.stem.split('_')[-1]))
for old in old_checkpoints[:-3]:
old.unlink()
cuda.empty_cache()
writer.add_scalar('Val/EpochLoss', val_loss, epoch)
writer.add_scalar('Train/EpochLoss', train_loss, epoch)
print(f"✅ Epoch {epoch} 完成 | Train Avg Loss: {train_loss:.4f} | Val Avg Loss: {val_loss:.4f} | Best CER: {best_cer:.4f} | Best WER: {best_wer:.4f}\n")
# Early Stopping 检查
if patience_counter >= CONFIG['early_stopping_patience']:
print(f"\n🛑 Early Stopping: 验证指标连续 {CONFIG['early_stopping_patience']} 次没有改善")
print(f"🏆 最佳 CER: {best_cer:.4f}")
print(f"🏆 最佳 WER: {best_wer:.4f}")
break
writer.close()
print(f"\n{'='*60}")
print(f"🎉 训练完成!")
print(f"🏆 最佳 CER: {best_cer:.4f}")
print(f"🏆 最佳 WER: {best_wer:.4f}")
print(f"📊 TensorBoard: tensorboard --logdir={workspace_dir / 'runs'}")
print(f"{'='*60}\n")
if __name__ == "__main__":
main()

2168
uv.lock generated Normal file

File diff suppressed because it is too large Load Diff