diff --git a/src/dataset.py b/src/dataset.py index 68c9f81..c88c3dc 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -4,7 +4,7 @@ import torch from torch import Tensor from torch.utils.data import Dataset, DataLoader import torchaudio -from torch_audiomentations import ApplyImpulseResponse, Gain, PitchShift, LowPassFilter, HighPassFilter +from torch_audiomentations import ApplyImpulseResponse, Gain, PitchShift, LowPassFilter, HighPassFilter, PolarityInversion from torchaudio.transforms import Resample, TimeStretch from pathlib import Path import pandas as pd @@ -94,6 +94,7 @@ class CommonVoiceDataset(Dataset[BatchItem]): tsv_path: Path, audio_dir: Path, noise_dir: Path, + corridor_noise_dir: Path, tokenizer: ASRTokenizer, sample_rate: int = 16000, max_audio_len: int = 480000, # 30秒 @ 16kHz @@ -123,9 +124,10 @@ class CommonVoiceDataset(Dataset[BatchItem]): self.gain_down = Gain(min_gain_in_db=-15, max_gain_in_db=-8, p=1.0, output_type='tensor') self.pitch_up = PitchShift(min_transpose_semitones=1, max_transpose_semitones=4, p=1.0, sample_rate=self.sample_rate, output_type='tensor') self.pitch_down = PitchShift(min_transpose_semitones=-4, max_transpose_semitones=-1, p=1.0, sample_rate=self.sample_rate, output_type='tensor') - self.lowpass = LowPassFilter(min_cutoff_freq=600, max_cutoff_freq=2000, p=1.0, output_type='tensor') - self.highpass = HighPassFilter(min_cutoff_freq=800, max_cutoff_freq=2000, p=1.0, output_type='tensor') - self.apply_ir = ApplyImpulseResponse(ir_paths=noise_dir,convolve_mode='same', p=1, output_type="tensor") + self.lowpass = LowPassFilter(min_cutoff_freq=100, max_cutoff_freq=2000, p=1.0, output_type='tensor') + self.highpass = HighPassFilter(min_cutoff_freq=1000, max_cutoff_freq=2000, p=1.0, output_type='tensor') + self.apply_ir = ApplyImpulseResponse(ir_paths=corridor_noise_dir, convolve_mode='same', p=1, output_type="tensor") + self.polarity_inversion = PolarityInversion(p=1.0, output_type="tensor") def __len__(self): return len(self.data) @@ -157,7 +159,7 @@ class CommonVoiceDataset(Dataset[BatchItem]): if random.random() < 0.6: waveform = self._stretch_or_compress(waveform=waveform) - if random.random() < 0.4: + if random.random() < 0.7: waveform = self.noise_augmentor.apply_real_noise(waveform) if random.random() < 0.3: @@ -166,29 +168,30 @@ class CommonVoiceDataset(Dataset[BatchItem]): # torch_audiomentations: [1, time] -> [1, 1, time] if waveform.dim() == 2: waveform_3d = waveform.unsqueeze(0) - # 随机选择一种频谱增强 + # 随机选择一种物理特性增强 (互斥区) choice = random.random() - if choice < 0.15: - # 增益变化(上或下) + if choice < 0.25: # [0.00 - 0.25] 25% 概率:增益 if random.random() < 0.5: waveform_3d = self.gain_up(waveform_3d, sample_rate=self.sample_rate) else: waveform_3d = self.gain_down(waveform_3d, sample_rate=self.sample_rate) - elif choice < 0.25: - # 音高变化(上或下) + elif choice < 0.50: # [0.25 - 0.50] 25% 概率:音高 if random.random() < 0.5: waveform_3d = self.pitch_up(waveform_3d, sample_rate=self.sample_rate) else: waveform_3d = self.pitch_down(waveform_3d, sample_rate=self.sample_rate) - elif choice < 0.30: - # 低通滤波(声音发闷) + elif choice < 0.70: # [0.50 - 0.70] 20% 概率:低通 waveform_3d = self.lowpass(waveform_3d, sample_rate=self.sample_rate) - # elif choice < 0.32: - # # 低通滤波(声音发闷) - # waveform_3d = self.apply_ir(waveform_3d, sample_rate=self.sample_rate) - elif choice < 0.35: - # 高通滤波(电话效果) + elif choice < 0.85: # [0.70 - 0.85] 15% 概率:高通 waveform_3d = self.highpass(waveform_3d, sample_rate=self.sample_rate) + elif choice < 0.95: # [0.85 - 0.95] 10% 概率:走廊混响 (IR) + # 使用你测试过最好的 0.8/0.2 比例 + dry = waveform_3d.clone() + wet = self.apply_ir(waveform_3d, sample_rate=self.sample_rate) + waveform_3d = 0.8 * dry + 0.2 * wet + else: # [0.95 - 1.00] 5% 概率:极性翻转 + waveform_3d = self.polarity_inversion(waveform_3d, sample_rate=self.sample_rate) + # [1, 1, time] -> [1, time] waveform = waveform_3d.squeeze(0) @@ -201,7 +204,7 @@ class CommonVoiceDataset(Dataset[BatchItem]): return waveform def _stretch_or_compress(self, waveform: Tensor) -> Tensor: - speed_factor = random.uniform(0.85, 1.4) # (Speed Change: 0.85x - 1.4x) + speed_factor = random.uniform(0.80, 1.6) # (Speed Change: 0.85x - 1.4x) spec = torch.stft( waveform.squeeze(0), n_fft=400, @@ -246,7 +249,6 @@ class CommonVoiceDataset(Dataset[BatchItem]): def __getitem__(self, index) -> BatchItem: row: TsvFormat = self.data.iloc[index] audio_path: Path = self.audio_dir / row['path'] - # text: str = unicodedata.normalize('NFC', row['sentence'].strip()) text: str = normalize_extended_uyghur_characters(collapse_spaces(row['sentence'].strip())) waveform = self._load_audio(audio_path=audio_path) @@ -297,8 +299,8 @@ def collate_fn(items: List[BatchItem]) -> Batch: ) -def create_dataloader(tsv_path: Path, audio_dir: Path, noise_dir: Path, tokenizer: ASRTokenizer, batch_size: int = 8, shuffle: bool = True, augment: bool = True, augment_prob: int = 0.5) -> DataLoader: - dataset = CommonVoiceDataset(tsv_path=tsv_path, audio_dir=audio_dir, noise_dir=noise_dir, tokenizer=tokenizer, augment=augment, augment_prob=augment_prob) +def create_dataloader(tsv_path: Path, audio_dir: Path, noise_dir: Path, corridor_noise_dir: Path, tokenizer: ASRTokenizer, batch_size: int = 8, shuffle: bool = True, augment: bool = True, augment_prob: int = 0.5) -> DataLoader: + dataset = CommonVoiceDataset(tsv_path=tsv_path, audio_dir=audio_dir, noise_dir=noise_dir, corridor_noise_dir=corridor_noise_dir, tokenizer=tokenizer, augment=augment, augment_prob=augment_prob) return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn, pin_memory=True, num_workers=8, prefetch_factor=8, persistent_workers=True) @@ -331,101 +333,4 @@ if __name__ == "__main__": print(f"Target lengths: {batch['target_lengths']}") print(f"Target texts: {batch['target_texts']}") print(f"Audio paths: {batch['audio_paths']}") - break - - - - - -import torch -import torchaudio -from silero_vad import load_silero_vad, read_audio, get_speech_timestamps - -# --- 1. 加载 VAD 模型 --- -# 推荐使用 GPU (cuda) 如果可用,否则使用 CPU (cpu) -vad_model = load_silero_vad(source="local", force_onnx=False) - -def process_audio_with_vad_and_asr(audio_path, inference_module): - """ - 使用 Silero VAD 分割音频,然后对每个语音片段进行 ASR 转录。 - - Args: - audio_path (str): 输入音频文件的路径。 - inference_module: 您封装了 transcribe 方法的模块或对象。 - 需要确保其 _load_audio 和 transcribe 方法可用。 - """ - print(f"正在加载音频: {audio_path}") - - # --- 2. 加载音频用于 VAD 检测 --- - # Silero VAD 推荐使用 16kHz 采样率 - waveform_vad, sample_rate_vad = torchaudio.load(audio_path) - if sample_rate_vad != 16000: - # 如果采样率不是16kHz,需要重采样以供VAD使用 - resampler = torchaudio.transforms.Resample(orig_freq=sample_rate_vad, new_freq=16000) - waveform_vad = resampler(waveform_vad) - sample_rate_vad = 16000 - - # --- 3. 获取语音时间段 --- - # get_speech_timestamps 返回一个列表,每个元素是 {'start': start_sample, 'end': end_sample} 的字典 - # 采样率是16000,所以时间戳单位是 1/16000 秒 - speech_timestamps = get_speech_timestamps( - waveform_vad, - vad_model, - sampling_rate=sample_rate_vad, - threshold=0.5, # 可以根据需要调整阈值 - min_speech_duration_ms=250, # 最小语音持续时间,防止短噪音被误判 - max_speech_duration_s=float('inf'), # 最大语音持续时间,float('inf') 表示不限制 - min_silence_duration_ms=100, # 最小静音间隔,用于分割语音块 - window_size_samples=1536, # VAD窗口大小 - speech_pad_ms=30 # 在语音块前后添加的填充时间 - ) - - print(f"检测到 {len(speech_timestamps)} 个语音片段。") - - full_text = "" - for i, ts in enumerate(speech_timestamps): - start_sample = int(ts['start']) - end_sample = int(ts['end']) - - # 计算时间戳(秒) - start_time = start_sample / sample_rate_vad - end_time = end_sample / sample_rate_vad - - print(f"\n处理第 {i+1} 个片段: 时间范围 [{start_time:.2f}s - {end_time:.2f}s]") - - # --- 4. 从原始音频中提取此片段 --- - # 注意:这里假设您的 transcribe 函数可以接受 waveform 张量。 - # 我们需要从原始可能不同采样率的音频中提取片段,或者用VAD处理过的waveform。 - # 为了匹配您原来的 _load_audio 方式,我们用 torchaudio 再次精确加载片段。 - - # 计算原始音频中的样本索引(如果原始音频采样率与VAD不同) - original_waveform, original_sr = torchaudio.load(audio_path) - if sample_rate_vad != original_sr: - # 如果VAD和原始音频采样率不同,需要重新映射索引 - start_idx_orig = int(start_sample * (original_sr / sample_rate_vad)) - end_idx_orig = int(end_sample * (original_sr / sample_rate_vad)) - else: - start_idx_orig = start_sample - end_idx_orig = end_sample - - segment_waveform = original_waveform[:, start_idx_orig:end_idx_orig] - - # --- 5. 对该片段进行 ASR 转录 --- - # 这里调用您原有的 transcribe 方法 - try: - # 假设 transcribe 方法接受一个 waveform tensor - segment_text = inference_module.transcribe(waveform=segment_waveform) - - print(f" -> 转录结果: {segment_text}") - full_text += f"[{start_time:.2f}-{end_time:.2f}s] {segment_text}\n" - except Exception as e: - print(f" -> 转录第 {i+1} 个片段时出错: {e}") - - print("\n--- 完整转录结果 ---") - print(full_text) - - -# --- 使用示例 --- -# 假设您有一个名为 'inference' 的模块对象,它有 _load_audio 和 transcribe 方法 -# audio_path = "your_audio_file.wav" -# process_audio_with_vad_and_asr(audio_path, inference) \ No newline at end of file + break \ No newline at end of file diff --git a/src/handle/audio_agement.ipynb b/src/handle/audio_agement.ipynb index 51aa62a..ac95b1a 100644 --- a/src/handle/audio_agement.ipynb +++ b/src/handle/audio_agement.ipynb @@ -2,24 +2,25 @@ "cells": [ { "cell_type": "code", - "execution_count": 3, + "execution_count": 1, "id": "f6132018", "metadata": {}, "outputs": [], "source": [ "import torch\n", + "import random\n", "import torchaudio\n", "from torch import Tensor\n", "from pathlib import Path\n", "from IPython.display import Audio\n", - "from torch_audiomentations import Gain, PitchShift, LowPassFilter, HighPassFilter, ApplyImpulseResponse\n", + "from torch_audiomentations import Gain, PitchShift, LowPassFilter, HighPassFilter, ApplyImpulseResponse, Compose, BandPassFilter, PolarityInversion\n", "workspace_dir = Path.cwd().parent.parent\n", "sample_rate = 16000" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 13, "id": "a238e4ab", "metadata": {}, "outputs": [ @@ -54,7 +55,7 @@ "text/html": [ "\n", " \n", " " @@ -70,7 +71,7 @@ "source": [ "transforms = {\n", " \"gain_up\": Gain(min_gain_in_db=4, max_gain_in_db=8, p=1.0, output_type='tensor'),\n", - " \"gain_down\": Gain(min_gain_in_db=-15, max_gain_in_db=-8, p=1.0),\n", + " \"gain_down\": Gain(min_gain_in_db=-15, max_gain_in_db=-8, p=1.0, output_type='tensor'),\n", " \"pitch_up\": PitchShift(min_transpose_semitones=1, max_transpose_semitones=4, p=1.0, sample_rate=sample_rate, output_type='tensor'),\n", " \"pitch_down\": PitchShift(min_transpose_semitones=-4, max_transpose_semitones=-1, p=1.0, sample_rate=sample_rate, output_type='tensor'),\n", " \"lowpass\": LowPassFilter(min_cutoff_freq=600, max_cutoff_freq=2000, p=1.0, output_type='tensor'),\n", @@ -88,10 +89,13 @@ "print('orginal')\n", "display(Audio(waveform.squeeze(0), rate=sample_rate))\n", "\n", - "test_audio = workspace_dir / 'data/test/637ae13bc4a7e3.wav'\n", - "# test_transform = ApplyImpulseResponse(ir_paths=test_audio,convolve_mode='same', p=1, output_type=\"tensor\")\n", - "test_transform = PitchShift(min_transpose_semitones=1, max_transpose_semitones=4, p=1.0, sample_rate=sample_rate, output_type='tensor')\n", - "out: Tensor = test_transform(samples=waveform, sample_rate=sample_rate)\n", + "noise_audio = workspace_dir / 'data/test/output.wav'\n", + "noise_audio = workspace_dir / 'data/corridor'\n", + "test_transform = ApplyImpulseResponse(ir_paths=noise_audio,convolve_mode='same', p=1, output_type=\"tensor\")\n", + "dry_waveform = waveform.clone()\n", + "wet_waveform = test_transform(waveform, sample_rate=sample_rate)\n", + "out: Tensor = 0.8 * dry_waveform + 0.2 * wet_waveform\n", + "\n", "max_amp = out.abs().max()\n", "if max_amp > 1.0:\n", " print(max_amp)\n", @@ -110,28 +114,25 @@ }, { "cell_type": "code", - "execution_count": 5, - "id": "6f47127f", + "execution_count": 3, + "id": "aadb05aa", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "0.019030116872178315\n" + "0.19999999999999996\n" ] } ], "source": [ "import math\n", - "epoch = 1\n", - "final_prob = 0.5\n", - "warmup_epochs = 8\n", - "# current_prob = min(final_prob, final_prob * epoch / warmup_epochs)\n", - "\n", + "final_prob = 0.8\n", + "warmup_epochs = 12\n", + "current_prob = 0.0\n", + "epoch = 4\n", "current_prob = final_prob * (1 - math.cos(math.pi * epoch / warmup_epochs)) / 2\n", - "if epoch >= warmup_epochs:\n", - " current_prob = final_prob\n", "print(current_prob)" ] } diff --git a/src/handle/audio_analyze.py b/src/handle/audio_analyze.py index a4c1117..a9d11d2 100644 --- a/src/handle/audio_analyze.py +++ b/src/handle/audio_analyze.py @@ -35,7 +35,7 @@ def analyze_mp3_files(directory: Path, ): """分析目录中的所有MP3文件""" results = [] total_duration = 0 - df = pd.read_csv(workspace_dir / '.data/ug/train_new.tsv', sep='\t') + df = pd.read_csv(workspace_dir / '.data/ug/val_new.tsv', sep='\t') for _, row in tqdm(df.iterrows(), total=len(df), desc="anlayze audio"): mp3_file: Path = workspace_dir / '.data/ug/clips' / row['path'] diff --git a/src/inference.ipynb b/src/inference.ipynb index b72ef1c..d18fa70 100644 --- a/src/inference.ipynb +++ b/src/inference.ipynb @@ -11,9 +11,9 @@ "import torchaudio\n", "import librosa\n", "import pyrubberband as pyrb\n", - "from torch import Tensor, no_grad, device\n", + "from torch import Tensor, no_grad, device, cuda\n", "from torchaudio.transforms import FrequencyMasking, MelSpectrogram, AmplitudeToDB, Resample, TimeMasking, TimeStretch\n", - "from torch_audiomentations import Gain, PitchShift, LowPassFilter, HighPassFilter\n", + "from torch_audiomentations import Gain, PitchShift, LowPassFilter, HighPassFilter, ApplyImpulseResponse, PolarityInversion\n", "from pathlib import Path\n", "from librosa import effects\n", "import torchaudio.functional as F\n", @@ -172,7 +172,7 @@ " \n", " noisy = signal + scale * noise\n", " # 归一化,防止溢出\n", - " return noisy / (noisy.abs().max() + 1e-8)\n" + " return noisy / (noisy.abs().max() + 1e-8)" ] }, { @@ -194,22 +194,18 @@ " self.model.eval()\n", "\n", " print(f\"params params: {self.model.get_num_params():,}\",)\n", + " print(f\"Vocabulary Size: {self.tokenizer.vocab_size():,}\",)\n", "\n", " self.sample_rate = 16000\n", "\n", - " # self.gain_up = Gain(min_gain_in_db=5, max_gain_in_db=10, p=1.0, output_type='tensor')\n", - " # self.gain_down = Gain(min_gain_in_db=-20, max_gain_in_db=-10, p=1.0, output_type='tensor')\n", - " # self.pitch_up = PitchShift(min_transpose_semitones=3, max_transpose_semitones=5, p=1.0, sample_rate=self.sample_rate, output_type='tensor')\n", - " # self.pitch_down = PitchShift(min_transpose_semitones=-5, max_transpose_semitones=-3, p=1.0, sample_rate=self.sample_rate, output_type='tensor')\n", - " # self.lowpass = LowPassFilter(min_cutoff_freq=400, max_cutoff_freq=2400, p=1.0, output_type='tensor')\n", - " # self.highpass = HighPassFilter(min_cutoff_freq=1400, max_cutoff_freq=3400, p=1.0, output_type='tensor')\n", - "\n", " self.gain_up = Gain(min_gain_in_db=4, max_gain_in_db=8, p=1.0, output_type='tensor')\n", " self.gain_down = Gain(min_gain_in_db=-15, max_gain_in_db=-8, p=1.0, output_type='tensor')\n", " self.pitch_up = PitchShift(min_transpose_semitones=1, max_transpose_semitones=4, p=1.0, sample_rate=self.sample_rate, output_type='tensor')\n", " self.pitch_down = PitchShift(min_transpose_semitones=-4, max_transpose_semitones=-1, p=1.0, sample_rate=self.sample_rate, output_type='tensor')\n", - " self.lowpass = LowPassFilter(min_cutoff_freq=600, max_cutoff_freq=2000, p=1.0, output_type='tensor')\n", - " self.highpass = HighPassFilter(min_cutoff_freq=800, max_cutoff_freq=2000, p=1.0, output_type='tensor')\n", + " self.lowpass = LowPassFilter(min_cutoff_freq=1000, max_cutoff_freq=3000, p=1.0, output_type='tensor')\n", + " self.highpass = HighPassFilter(min_cutoff_freq=1000, max_cutoff_freq=3000, p=1.0, output_type='tensor')\n", + " self.apply_ir = ApplyImpulseResponse(ir_paths=noise_dir,convolve_mode='same', p=1, output_type=\"tensor\")\n", + " self.polarity_inversion = PolarityInversion(p=1.0, output_type=\"tensor\")\n", " \n", " def _load_audio(self, audio_path: Path) -> Tensor:\n", " waveform, sample_rate = torchaudio.load_with_torchcodec(audio_path)\n", @@ -224,63 +220,71 @@ " return waveform\n", " \n", " def augment_waveform(self, waveform: Tensor) -> Tensor:\n", - " # if not self.augment or random.random() > self.augment_prob:\n", - " # return waveform\n", + " if not self.augment or random.random() > self.augment_prob:\n", + " return waveform\n", " \n", " # 1. voice Stretch/Compress \n", - " if random.random() < 0.5:\n", - " waveform = self._voice_stretch_or_compress(waveform=waveform)\n", - "\n", - " if random.random() < 0.4:\n", + " if random.random() < 0.6:\n", + " waveform = self._stretch_or_compress(waveform=waveform)\n", + " \n", + " if random.random() < 0.7:\n", " waveform = self.noise_augmentor.apply_real_noise(waveform)\n", - "\n", + " \n", " if random.random() < 0.3:\n", " waveform = self._time_mask_waveform(waveform=waveform)\n", - " \n", + "\n", " # torch_audiomentations: [1, time] -> [1, 1, time]\n", " if waveform.dim() == 2:\n", " waveform_3d = waveform.unsqueeze(0)\n", " # 随机选择一种频谱增强\n", " choice = random.random()\n", - " if choice < 0.15:\n", + " if choice < 0.30:\n", " # 增益变化(上或下)\n", " if random.random() < 0.5:\n", + " print('gain_up')\n", " waveform_3d = self.gain_up(waveform_3d, sample_rate=self.sample_rate)\n", " else:\n", " waveform_3d = self.gain_down(waveform_3d, sample_rate=self.sample_rate)\n", - " elif choice < 0.25:\n", + " print('gain_down')\n", + " elif choice < 0.60:\n", " # 音高变化(上或下)\n", " if random.random() < 0.5:\n", + " print('pitch_up')\n", " waveform_3d = self.pitch_up(waveform_3d, sample_rate=self.sample_rate)\n", " else:\n", + " print('pitch_down')\n", " waveform_3d = self.pitch_down(waveform_3d, sample_rate=self.sample_rate)\n", - " elif choice < 0.30:\n", + " elif choice < 0.80:\n", " # 低通滤波(声音发闷)\n", + " print('lowpass')\n", " waveform_3d = self.lowpass(waveform_3d, sample_rate=self.sample_rate)\n", - " elif choice < 0.35:\n", + " elif choice < 0.95:\n", + " print('highpass')\n", " # 高通滤波(电话效果)\n", " waveform_3d = self.highpass(waveform_3d, sample_rate=self.sample_rate)\n", + " else:\n", + " print('polarity_inversion')\n", + " self.polarity_inversion(waveform_3d, sample_rate=self.sample_rate)\n", " \n", " # [1, 1, time] -> [1, time]\n", " waveform = waveform_3d.squeeze(0)\n", - " \n", + "\n", " # 防止多次 augment 后振幅溢出,最后归一化\n", " max_amp = waveform.abs().max()\n", " if max_amp > 1.0:\n", - " print(max_amp)\n", " waveform = waveform / max_amp\n", " \n", " return waveform\n", " \n", - " def _voice_stretch_or_compress1(self, waveform: Tensor) -> Tensor:\n", + " def _stretch_or_compress1(self, waveform: Tensor) -> Tensor:\n", " speed_factor = random.uniform(0.8, 1.4) # (Speed Change: 0.6x - 1.4x)\n", " waveform_np = waveform.squeeze(0).cpu().numpy()\n", " y_stretch = librosa.effects.time_stretch(waveform_np, rate=speed_factor)\n", " waveform = torch.from_numpy(y_stretch).float().to(waveform.device).unsqueeze(0)\n", " return waveform\n", " \n", - " def _voice_stretch_or_compress(self, waveform: Tensor) -> Tensor:\n", - " speed_factor = random.uniform(0.8, 1.25) # (Speed Change: 0.6x - 1.2x)\n", + " def _stretch_or_compress(self, waveform: Tensor) -> Tensor:\n", + " speed_factor = random.uniform(0.8, 1.6) # (Speed Change: 0.8x - 1.6x)\n", " spec = torch.stft(\n", " waveform.squeeze(0),\n", " n_fft=400,\n", @@ -290,7 +294,7 @@ " )\n", " \n", " # 时间拉伸(不改变音高)\n", - " stretch = TimeStretch(hop_length=160, n_freq=spec.shape[-2], fixed_rate=1.4)\n", + " stretch = TimeStretch(hop_length=160, n_freq=spec.shape[-2], fixed_rate=speed_factor)\n", " stretched_spec = stretch(spec)\n", " \n", " # 转回波形\n", @@ -303,7 +307,7 @@ " \n", " return waveform_stretched\n", "\n", - " def _voice_stretch_or_compress2(self, waveform: Tensor) -> Tensor:\n", + " def _stretch_or_compress2(self, waveform: Tensor) -> Tensor:\n", " speed_factor = random.uniform(0.6, 1.4) # (Speed Change: 0.6x - 1.4x)\n", " waveform_np = waveform.squeeze(0).cpu().numpy()\n", " y_stretch = pyrb.time_stretch(waveform_np, self.sample_rate, speed_factor)\n", @@ -355,11 +359,11 @@ "device = torch.device('cuda:1')\n", "\n", "# checkpoint = sorted(workspace_dir.glob('.checkpoints/checkpoint_epoch_*.pt'), key=lambda p: int(p.stem.split('_')[-1]))[-1]\n", - "checkpoint = workspace_dir / \".checkpoints/best_wer_model.pt\"\n", + "# checkpoint = workspace_dir / \".checkpoints/best_wer_model.pt\"\n", "# checkpoint = workspace_dir / \".checkpoints/best_cer_model.pt\"\n", - "checkpoint = workspace_dir / \".checkpoints/prodect/checkpoint_epoch_24 copy.pt\"\n", + "checkpoint = workspace_dir / \".checkpoints/prodect/best_wer_model.pt\"\n", "print(f\"Load Checkpoint: {checkpoint}\")\n", - "inference = ASRInference(model_path=checkpoint, vocab_path=workspace_dir / 'config/asr_vocab.json', noise_dir='/mnt/dataset/dataset/audio/noise', device=device)" + "inference = ASRInference(model_path=checkpoint, vocab_path=workspace_dir / 'config/asr_vocab.json', noise_dir='/mnt/dataset/dataset/audio/noise', device=device, augment_prob=0.8)" ] }, { @@ -376,14 +380,17 @@ "# audio_path = workspace_dir / '.data/ug/clips/common_voice_ug_26245615.mp3'\n", "# audio_path = workspace_dir / '.data/ug/clips/common_voice_ug_40794549.mp3'\n", "# audio_path = workspace_dir / 'data/ug/clips/common_voice_ug_26245614.mp3'\n", - "audio_path = '/mnt/train/audio/Dilnaz/only_vocals/3272800876_100006098_(vocals)_melband_roformer_big_beta5e.m4a'\n", + "# audio_path = '/mnt/train/audio/Dilnaz/only_vocals/3272800876_100006098_(vocals)_melband_roformer_big_beta5e.m4a'\n", "# audio_path = workspace_dir / 'data/test/how_are_you.mp3'\n", - "# audio_path = workspace_dir / 'data/test/split_1.m4a'\n", + "# audio_path = workspace_dir / 'data/test/split_2.m4a'\n", "# audio_path = workspace_dir / 'data/test/let_me_see_the_computer.mp3'\n", "# audio_path = workspace_dir / 'data/test/introduce_myself.mp3'\n", "# audio_path = workspace_dir / 'data/test/F001_001.wav'\n", + "# audio_path = workspace_dir / 'data/test/test_voise.m4a'\n", + "# audio_path = workspace_dir / 'data/test/radio_low.m4a'\n", + "# audio_path = workspace_dir / 'data/test/radio_high.m4a'\n", + "audio_path = workspace_dir / 'data/test/3272800876_100006098_(vocals)_melband_roformer_big_beta5e.m4a'\n", "# audio_path = workspace_dir / 'data/test/download.wav'\n", - "# audio_path = workspace_dir / 'data/test/test_voise_1.m4a'\n", "orginal_waveform = inference._load_audio(audio_path=audio_path)\n", "print('load audio:', orginal_waveform.shape)\n", "# orginal_waveform = orginal_waveform[:, :]\n", @@ -403,8 +410,6 @@ "display(Audio(orginal_waveform, rate=inference.sample_rate))\n", "\n", "# augment_waveform = inference.augment_waveform(waveform=orginal_waveform)\n", - "# augment_waveform = simple_corridor_echo(waveform=orginal_waveform)\n", - "# augment_waveform = inference._voice_stretch_or_compress(waveform=orginal_waveform)\n", "# noise_augmentor_waveform = inference.noise_augmentor.apply_real_noise(waveform=orginal_waveform)\n", "# display(Audio(augment_waveform, rate=inference.sample_rate))\n", "# text = inference.transcribe(waveform=orginal_waveform)\n", @@ -413,10 +418,22 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "id": "16195280", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 13815687])\n", + "speech_timestamps: [{'start': 122240, 'end': 221312}, {'start': 223104, 'end': 288000}, {'start': 288000, 'end': 346368}, {'start': 346368, 'end': 416256}, {'start': 416256, 'end': 476416}, {'start': 476416, 'end': 502912}, {'start': 606080, 'end': 632960}, {'start': 645504, 'end': 710784}, {'start': 714112, 'end': 777856}, {'start': 781696, 'end': 819840}, {'start': 823168, 'end': 937088}, {'start': 941440, 'end': 1111680}, {'start': 1122176, 'end': 1211520}, {'start': 1214848, 'end': 1334400}, {'start': 1345408, 'end': 1455744}, {'start': 1479552, 'end': 1648768}, {'start': 1654656, 'end': 1779840}, {'start': 1783168, 'end': 1990784}, {'start': 2034560, 'end': 2108032}, {'start': 2108288, 'end': 2198144}, {'start': 2201984, 'end': 2236416}, {'start': 2236416, 'end': 2267264}, {'start': 2277248, 'end': 2329216}, {'start': 2331520, 'end': 2377344}, {'start': 2383232, 'end': 2403840}, {'start': 2403840, 'end': 2496768}, {'start': 2496768, 'end': 2594432}, {'start': 2595712, 'end': 2626688}, {'start': 2636160, 'end': 2659328}, {'start': 2659328, 'end': 2711168}, {'start': 2719104, 'end': 2837120}, {'start': 2890112, 'end': 3007744}, {'start': 3007744, 'end': 3055232}, {'start': 3059072, 'end': 3223168}, {'start': 3230080, 'end': 3289216}, {'start': 3296640, 'end': 3616384}, {'start': 3617152, 'end': 3677312}, {'start': 3683712, 'end': 3720832}, {'start': 3728256, 'end': 3771520}, {'start': 3795840, 'end': 3837056}, {'start': 3839872, 'end': 3906176}, {'start': 3924864, 'end': 3961984}, {'start': 3964288, 'end': 3986944}, {'start': 3986944, 'end': 4022912}, {'start': 4037504, 'end': 4070016}, {'start': 4080512, 'end': 4135040}, {'start': 4148608, 'end': 4189952}, {'start': 4189952, 'end': 4213888}, {'start': 4225920, 'end': 4264064}, {'start': 4274560, 'end': 4308096}, {'start': 4323200, 'end': 4425344}, {'start': 4427136, 'end': 4553344}, {'start': 4558720, 'end': 4670080}, {'start': 4672384, 'end': 4744832}, {'start': 4785024, 'end': 4913792}, {'start': 4915584, 'end': 5035648}, {'start': 5037952, 'end': 5106944}, {'start': 5106944, 'end': 5140480}, {'start': 5140480, 'end': 5216384}, {'start': 5223296, 'end': 5340800}, {'start': 5344128, 'end': 5365888}, {'start': 5366144, 'end': 5384320}, {'start': 5385088, 'end': 5458560}, {'start': 5515136, 'end': 5683840}, {'start': 5690240, 'end': 5913216}, {'start': 5919616, 'end': 5988480}, {'start': 5993344, 'end': 6111872}, {'start': 6118784, 'end': 6172800}, {'start': 6179712, 'end': 6260352}, {'start': 6264704, 'end': 6431872}, {'start': 6582144, 'end': 6690816}, {'start': 6690816, 'end': 6728320}, {'start': 6735744, 'end': 6766208}, {'start': 6769536, 'end': 6826624}, {'start': 6827904, 'end': 6892160}, {'start': 6900608, 'end': 6927360}, {'start': 6927360, 'end': 6948992}, {'start': 6950272, 'end': 6973568}, {'start': 6978944, 'end': 7029888}, {'start': 7035264, 'end': 7069312}, {'start': 7076736, 'end': 7127168}, {'start': 7135616, 'end': 7221376}, {'start': 7223680, 'end': 7336064}, {'start': 7345536, 'end': 7395456}, {'start': 7395712, 'end': 7474304}, {'start': 7478656, 'end': 7541376}, {'start': 7544704, 'end': 7626368}, {'start': 7628672, 'end': 7639168}, {'start': 7639424, 'end': 7666816}, {'start': 7667072, 'end': 7749760}, {'start': 7751040, 'end': 7841792}, {'start': 7841792, 'end': 7859840}, {'start': 7862144, 'end': 7929344}, {'start': 7929344, 'end': 8049152}, {'start': 8049152, 'end': 8105984}, {'start': 8105984, 'end': 8155776}, {'start': 8161664, 'end': 8202880}, {'start': 8207744, 'end': 8290432}, {'start': 8297856, 'end': 8312448}, {'start': 8313216, 'end': 8377088}, {'start': 8377088, 'end': 8436864}, {'start': 8438144, 'end': 8516736}, {'start': 8524672, 'end': 8583296}, {'start': 8595840, 'end': 8629248}, {'start': 8629248, 'end': 8696448}, {'start': 8706944, 'end': 8724096}, {'start': 8725888, 'end': 8810112}, {'start': 8817536, 'end': 8841856}, {'start': 8842112, 'end': 8900224}, {'start': 8904576, 'end': 9025664}, {'start': 9035648, 'end': 9063168}, {'start': 9063168, 'end': 9096832}, {'start': 9100672, 'end': 9170176}, {'start': 9170176, 'end': 9180160}, {'start': 9180160, 'end': 9364608}, {'start': 9379712, 'end': 9417856}, {'start': 9419648, 'end': 9505408}, {'start': 9510272, 'end': 9552896}, {'start': 9552896, 'end': 9580160}, {'start': 9584512, 'end': 9611392}, {'start': 9619840, 'end': 9638400}, {'start': 9638400, 'end': 9733760}, {'start': 9748864, 'end': 9810048}, {'start': 9814400, 'end': 9889920}, {'start': 9901952, 'end': 9919488}, {'start': 9919488, 'end': 9964672}, {'start': 9966464, 'end': 10050176}, {'start': 10057088, 'end': 10199680}, {'start': 10204032, 'end': 10253824}, {'start': 10253824, 'end': 10315904}, {'start': 10319744, 'end': 10379904}, {'start': 10388864, 'end': 10407040}, {'start': 10411392, 'end': 10464896}, {'start': 10484096, 'end': 10563712}, {'start': 10567552, 'end': 10594048}, {'start': 10594048, 'end': 10691200}, {'start': 10697088, 'end': 10708096}, {'start': 10719616, 'end': 10790528}, {'start': 10803072, 'end': 10835072}, {'start': 10837888, 'end': 10887808}, {'start': 10893184, 'end': 10924800}, {'start': 10924800, 'end': 11035264}, {'start': 11042688, 'end': 11074688}, {'start': 11076992, 'end': 11086976}, {'start': 11100032, 'end': 11177600}, {'start': 11182464, 'end': 11251840}, {'start': 11255680, 'end': 11424384}, {'start': 11612032, 'end': 11735680}, {'start': 11743104, 'end': 11810944}, {'start': 11811712, 'end': 11872896}, {'start': 11873664, 'end': 11917952}, {'start': 11920768, 'end': 12016768}, {'start': 12025216, 'end': 12056704}, {'start': 12073856, 'end': 12133504}, {'start': 12138880, 'end': 12241024}, {'start': 12248448, 'end': 12281984}, {'start': 12283264, 'end': 12316288}, {'start': 12318080, 'end': 12345472}, {'start': 12350848, 'end': 12390400}, {'start': 12390400, 'end': 12428928}, {'start': 12431744, 'end': 12473472}, {'start': 12477824, 'end': 12519552}, {'start': 12519808, 'end': 12553344}, {'start': 12559744, 'end': 12625664}, {'start': 12625664, 'end': 12749312}, {'start': 12749312, 'end': 12800128}, {'start': 12807040, 'end': 12837504}, {'start': 12838784, 'end': 12859008}, {'start': 12863360, 'end': 12910720}, {'start': 12912000, 'end': 12976768}, {'start': 12986752, 'end': 13062784}, {'start': 13068160, 'end': 13200000}, {'start': 13204864, 'end': 13276800}, {'start': 13401472, 'end': 13432960}, {'start': 13444992, 'end': 13470848}, {'start': 13478784, 'end': 13506176}, {'start': 13515648, 'end': 13542016}, {'start': 13555584, 'end': 13579392}, {'start': 13589376, 'end': 13616256}, {'start': 13630336, 'end': 13652608}, {'start': 13666688, 'end': 13696128}, {'start': 13706624, 'end': 13730432}]\n", + "speech_timestamps: 182\n", + "VAD原始片段: 182 -> 合并后片段: 150\n", + "ھۇزۇرۇڭلاردا بولۇۋاتقىنى «ئايشەم مايىل تەبەسسۇم قەلىمىدە پۈتكەن «ئايرىلىش ھەققى» ناملىق نادىر ئەسەر| بۇ ئەسەر شىنجاڭ خەزىنە مەدەنىيەت تارقىتىش شىركىتىدە ئاۋازغا ئېلىنىپ| ئېسىم فىنىم سالونى ۋە دىلناز ئەپچىسىدىن تارقىتىلدى.| باشقا ھەرقانداق سالامن ۋە ئەپچىلەرنىڭ كۆچۈرۈپ تارقىتىشى بىردەك مەى قىلىنىدۇ.| خىلاپلىق قىلغۇچىلار ئەسەرگە كەتكەن بارلىق چىقىمنى ئۆز ئۈستىگە ئالىدۇ| -گىكىمىزگە ھۆرمەت قىلىڭ.تۇز ئۈچىنچى قىسىم| قىزىنىڭ ئۇچۇرىنى ئالالماي ئەندىشىگە چۈشكەن ئايالنىڭ كۆڭلى پەرىشان ئىدى.| قايتا قايتىلاپ بورى ۋەگەن تېلې فوننىڭ ھەممىسى ئېتىش بولۇپ چىقتى.| بۇ قىز تېلېفوننى ھېچ ئاچچاي دېمەك.| ئاخى بولماي خەمەتنىڭ تېلېفوننومۇرىنى تېپىپ، ئۇنىڭغا تېلېفون قىلغان بولسىمۇ، خەمىمۇ ئوخشاشلا ئۇنىڭ ئۇچۇرىنى بىلمەيدىغان بو چىقتى.| بىر كېچە- كۈندۈزنى پۇتى كۆيگەن توخۇدەك جىددىيلىشىپ، پىكپىرلاپ مىڭ تەستە ئۆتكۈزگەن ئايال، ئاخىرى قىززىنىڭ يوقاپ كەتكەنلىكىنى ئېرىگە ئېيتىپ، ئۇنىڭدىن بىر ئېلا قىلىپ قىزىنى تېپىشىنى ئۆتۈندى.| رۇستەمباي ھېلىغىچە قىدىن رەنجى بارغان پىكىردە بولغاچقا، ئايالىنىڭ گەپلىرىگە كۆڭۈل بۆلمەك.| لېكىن بىچارە ئاپىسى بالا دەرجىدە پۇچۇلىنىپ، قاق بىر كۈن ئېرىنىڭ قۇلاق-مۇنى يەپ زادىلا ئاراملىق تەمى.| شۇنىڭدىمۇ ئۇ ئادەمنى ئېرىتىشكە كۆزى يەتمەي، ئاخىرى بولدىلا» دېگىنىچە ئۆز يالغۇز ساقچىخانىغا يۈگۈردى.| ئايال: «ھاسىراپ-ھۆمى دەپ يېتىپ كەلگەندە: خەمىت مۇشۇ يەردە بىرەيلەننى يوقاپ كەتتى» دەپ دېلو مەلۇم قىلىۋاتقان بولۇپ، ھەر ئىككىلىسى يوقلۇقىنى مەلۇم قىلغان ئادەم دەل تېلا قىزبولۇپ چىقتى.| خەمىت ئۇ يەردە بىر ياقتىن ساقچىغا ئەھۋالنى مەلۇم قىلىپ، يەنە بىر ياقتىن پۈتۈن ئىشخانىنى بېشىغا كىيىپ يىغلاۋاتقان ئايالغا تەسەللى بېرىپ ئاۋارە ئى.| ساقچىلار قىزنى پۈتۈن كۈچى بىلەن ئىزدەيدىغانلىقىنى، ئۇنىڭغىچە بۇلارنىڭمۇ ئۇرۇق -تۇغقانلىرىنىڭ ئۆيلىرىدىن ۋە قىزنىڭ بېرىش مۇمكىنچىلىكى بولغان ھەممەيەرلەردىن ئېسىگە كەلگەنلىكى يەرلەرگىچە سۈرۈش قىپ بېقىشنى تاپىلاپ، ئۇلارنى يولغا سالدى.| خەمەك تىللا قىزنىڭ ئاپىسىنى ئۆييىگە ئاپىرىپ قويۇ ئۈچۈن ماشىنىغا چىقىرىپ ئېيت ماڭدى.| يولدا ئايال قارر يامغۇر يېغىلىغىنىچە بار ئەندىشىلىرىنى خەمىپ كۈتۆكىۋاتاتتى.| بەك ئەنسىرەپ قالسام جېنىم ئۇكام.| بىرەر ئىش بولمىغاندۇ قىزىمغا؟ بۇ بالا ئەزەلدىن بۇنداق جىممىدە يوقاپ كېتىپ باققان ئەمەس.| دادىسىدىنبەكلا كۆڭلى ئاغرىدى، بىچارە قىزىمنىڭ.| شۇ ئادەممۇ ئۆمرۈمنىڭ تەڭدىن تولىسىنى ياشاپ بوپتىمەن، ئەمدى قىزىم پوسىنى بەختىڭ تاپسۇن» دېگەن بولسا| ئاچچىقى يامان قىلىپ، تاشدىشىنىپ تۇۋاغىچە شۇ بانە كۆڭلى چۈشكەن ئادەم بىلەن ئۆتكىلى قويغان بولسا| ئۇنداق قىسمۇ بولماستى يە، ئەسىررى مەڭ ھەدە.| لا قىزنى تاپىمىز، مەن ئۇنى چوقۇم تاپىمەن.| -دېدى خەت يۈرىكىنىڭ ئېچىشىپ تۇرغانلىقىنى ئايالغا بىلدۈرمەسلىككە تىرىشىپ كۆزلىرىدە لىغىرلاپ قاغان ياشنى ئۇنىڭدىن يوشۇر| پەمەتت ئايالنى ئۆيىگە ئاپىرىپ قويغاندىن كېيىن، تىلا قىز بىلەن بىللە بېرىپ باققانلىكى جايلارغا بېرىپ، ھەممەيەردىن ئۇنى ئالنى -قويماي ئىززەتدى.| مەن تونۇش بىلىشلەردىن سۈرۈش تۈقىپ چىقتى.| ھەتتا قىز بېرىشنى ياخشى كۆرىدىغان قەھۋەخانا، ساتىراشخانا، ھۆسن تۈزەش سالونلىرى قاتارلىق جايلارغىچە بېرىپ يادىغا كەلگەنلىكى يەرلەرنىڭ ھەممىسىدىن يىپئۇچى ئىكەنلىك.| ئەپسۇسكى، ھېچبىر يەردىن قىزنىڭ دېرىگە بولمايۋاتاتتى.| ئاتىسى بىر ياقتا خەمەت بىرياقتا، ئەنە شۇنداق جان پىدالىق بىلەن ئۇنى ئىزدەپ يۈرگەندە، ساقچىلاردىن يېڭى ئۇچۇر كېلىپ، قىزنىڭ يوقاپ كېتىشىنىڭ ئالدىنقى ئاخشىمى تورسۇپۇسىدا بىر مېھمانخانىدىن ياتاقسەپەز قىلغانلىقى، ئەڭ ئاخىرقى قېتىم شۇ يەردە ياتاقتىزىملاش ئۇچۇرىنى قالدۇرغاندىن كېيىن، ھېچقانداق بىر تىجارەت سورۇنى ياكى يوللاردا پەيدا بولمىغانلىقى دەلىللەندى.| بۇ خەۋەرنى ئاڭلىغان خىزمىنىڭ كۆڭلىدىن كەچمىگەن خىياللار قىممىدى.| ئەجەباب قىز مېھمانقانىڭياتاققېچەر.| ئاشۇ يەردە جىممىداق ھاججەتتىن ئاخىرلاشتۇرغانمىدۇ؟| يا ياكى ئۆزىدىنبەك ئۈمىدسىزلىنىپ كېتىپ| ئۇھەببەت ئىزھار قىلىپ يۈرگەن بىرەر ئوقۇمى بىلەن جىمۋىدىلا ياتاققا كىرىۋالغانمىدۇ؟| ئاتا-ئانىسىنىڭ ئۆيىگە قايتقىسى بولمىسا| كۆڭۈلدىكىنى ئېيتىپ، ئۆزىنىڭ يېنىدا ھاسىمۇ تامامەن بولاتتى.| ئۇ زاتە نېمىشقا ياتاق ئېچىپ قالغان| ۋەيەنە نېمىشقا ھەممە ئادەمدىن ئۆزىنى قاچۇرىدىغاندۇ؟| شۇ تاپتا ئۇ ياتاقتا يالغۇزمىدۇ ياكى| ئىمنە رەسەپ بالىدۇ؟ ياپوننىسى ئاللىقاچان نەپەستىن توختىتى.| جەسىتە چىرىپ كېتىشكە باشلىغانمۇ؟| شۇ خىياللار بىلەن كۆڭلىنى بىردەملا ئارام تاپقۇزانلمىغان خەمىت ساقچىلاردىن دەرھال شۇ ياتاققا بىللە بېرىشنى ئۆتۈندى| ئاڭغىچە قىزنى ئۇچۇ بولغانلىقىدىن خەۋەر تاپقان ئايالمۇ ھاسىراپ-ئۆمۈدەپ پالاقشىپ يۈگۈرگەن پېتى، ساقچىخانىغا يېتىپ كەلگەنى.| قايىل بويىچە ئائىلە تەۋەلىرى ساقچىخانىدا قېلىپ، خەۋەر كۈتۈپ، ساقچىلار ئۆزىنى نەق مەيدانغا بېرىپ ئۇچۇر ئىگىلىسى توغرا بولاتتى.| ئەمما ئايال يىغلاپ يالۋۇرۇپ تۇرۇمىغاچقا، ئۇنىڭ بىللە ئېبېرىشىنى لايىق تاپتى.| مېھمانخانىنا مۇلازىمەت ئولىيايىدىكى خادىملار ئەھۋالنى ئۇققاندىكې، ساقچىلارنىڭ قولىدىكى مۇناسىۋەتلىك ئۇقتۇرۇشىقا ئاساسەن ئىشىكنى ئېچىپ بېرىشكە قوشۇلدى.| قادە بويىچە ئالدى بىلەن ئىشىكنى چېكىپ، ئىچىگە كىرىش كىرمەسلىك ھەققىدە ياتاق ئىچىدىكى مۇئامىلىدارلارنىڭ رۇخسىتىنى ئېلىشقا توغرا كېلەتتى.| شۇڭا مۇلازىمەت دىرېكتورى: ئىشىكنى نەچچە قېتىم ئۇرۇپ ئىچىدىن سادا كۈتتى.| ئۈنلۈك ئاۋازى چاقىرىپمۇ باقتى.| ئەمما ياتاقئىچىدىن چىۋىننىڭ ئۇچقۇنىغا چاغلىق بىرەر ئۈنتىگۈچىلەر كېلەر ئەمەس ئىدى.| بۇ قەدەر جىمجىتلىقنىڭ سەۋەبى، ھەممەيلەننى قىزىقتۇرۇپ، يۈرەكلەرنى سۇ قىلىۋېتىپ بارغاندا، بۇنىڭدىن ئارتۇق كۈتۈشكە ئىمكان بولمىدى.| ماسۇل ساقچىنىڭ چۈشىدە جىق| دېگەن بۇيرۇقىدىن كېيىن، مۇلازىمەت زېرىكتورى ياتاق ئىشىكىگە ئاچقۇ سېلىپ ئاچتى.| ئۇلار ياتاققا كىرگەندە، ياتاققىچى يىللىق قىدە پىۋا، قىزىللار بوتۇلكىلىرى بىلەن تولغان بولۇپ، ھەمىيە سېسىقچىلىق ۋەيرانە ھالدا يېمەكلىك قالدۇقلىرى ۋە ئەخلەتلەر بىلەن توشقانىدى.| دېرىزە پەردىلىرى تولۇق چۈشۈرۈلۈپ، ياتاغقىچى قاپ-قاراڭغۇ قىلىۋېتىلگەن بولۇپ، قاپقارا ئۇزۇن چاچلىرى چۇۋۇلۇپ، يۈزلىرىنى توسۇۋالغان ئورۇقلاپ يېدەكچىلىك قالغان بىر قىز يېرىم بەدىنى كارىۋاتتىن يەرگە ساڭگىلىغان پېتىدا قەداق قىلمايتو ياتاتتى.| چاتاقنىڭ چىرىغى ياندۇرۇلۇشى بىلەنلا بۇ ھالدىن ھەممەيلەننىڭ يۈرىكى ئاتى..| ئەڭ بېشىدا خەمىت يېنىدىكىلەرنى ئىتتىرىپ سۈرۈۋېتىپ، ئۇچقاندەك يۈگۈرگىنىچە قىزنىڭ يېنىغا بېرىپ، ئۇنىڭ بېشىنى كۆتۈرۈپ يۈزىگە قارىتى.| تەھەقىقەت ئالدىدىكى بۇ قىز تىللا قىز شۇ ئىدى.| بىزشۇنى بىلمەيدەك ھالدا مەست بولۇپ، زادى قانچىلىك ئىچكەنلىكىنى بىلگىلى بولمايتتى.| ئۇ تۇيۇقسىز يېنىدا بىر توپ كىشىلەرنىڭ پەيدا بولۇپب قاغانلىقىدىن ھەم ئۆيدىكىلەرنىڭ ئۆزىنى ئىزدەپ تەكپارا بولۇپ كەتكەنلىكىكى بىخەۋەر ھالدا شۇ پېتى كارىۋاتتىن ساڭگىلاپ سوزۇلۇپ ياتاتتى| خەلمىتتىن كېيىنلا قىنىڭ ئاپىسى ئىرغاڭلاپ يۈگۈرۈپ ئۇنىڭ يېنىغا پادى-دە، قىزىنى يۆلەپ، ئۇنىڭ ئىسمىنى چاقىرىپ| يۈزلىرىنى سىلاپ- يىغلاشقا باشلىدى| دادادامغا مەن كېرەك ئەمەس، دەپ گېپىنى باشلىدى تالا قىز ھوشىغا كەلگەندىن كېيىن| نېمە ئۈچۈن بۇ يەرگە بېكىننىۋالغانلىقى ھەققىدە سورالغان سوئالغا جاۋاب بېرىپ| ئۇنىڭغا پۇللىرى، يۈز ئابرۇيىيى.| جەمئىيەتتىكى ئورنى ھەم تالادا تۇتۇپ قويغان سانسىز ئاشنىلىرىلا كېرەك.| ئۇ ئەزەلدىن مېنى ئويلاپ باققان ئەسىز؟| بەلكى، ئىنە ساڭىللىرىمنىمۇ ئويلاپ باققان ئەمەس.| ئۇنىڭغا پەقەتتىن ئۆزى بولسىلا، كۆڭۈللۇق خۇشلۇقى، ئۆزى خالىغان ئىشلىرى بولسىلا بولدى.| ھېچقاچان باللىلىرىمنى نېمىنى خالايدۇ، قانداق بولسا خۇشال بولىدۇ، مەن قانداق دادا بولۇشۇم كېرەك دېگەنلەرنى ئويلاپمۇ باقمايدۇ.| دادام ئېھتىياجلىق چېغىدا، مەن ئۇنىڭغا يېنىدا ئىدىم.| ئۇنىڭ بەختلىك بولۇشى ئۈچۈن، ئائىلىسىنىڭ پۈتۈن بولۇشى ئۈچۈن بارلىق كۈچۈمنى چىقارغانىدىم.| لېكىنمەن ئېھتىياجلىق بولغاندا، دادام ئەزەلدىن يېنىمدا مۇ باقمىدى.| كۆڭلۈم يېرىم بولغاندا قولۇمنى تۇتۇپ، يىغلىسام يېشىمنى سىرتىپ باققان ئادەم ئەمەس سۇ.| ئىسىللەپ كۆيگىدىن تارتىپ| دادامنىڭ يۈكىنى يەڭگىنىپتىمەن، دەپ تىرىشىپ: جان-جەھلىم بىلەن ئۆگىنىپ ئىشلەپ كەلدى| ئالىي مەكتەپتە ئو قويۇمەنمىسەن، ئېرىكىنى تاپشۇرۇپ بار دېسەم، ئوخقۇشۇڭلار ھاجىتى يوقتىپ ئوقۇتمىدى.| قالدىدىڭ، ئەمما مەن چوڭ دۇرغەمدا قىز با نىڭ ئىگىلىككىنى ياراشمايدۇ| مېنىڭ ئو غۇلبالا بولغاندىكىن دىلىكنىڭ ئۇنىڭغا تاپشۇرىمەن. كەييىن چۇساق يەر-يۆلەك بولىدۇ، ھازىر ساغا ئايرىم دۇكان ئېچىپ بېرەيدۇ.| دوپپارچىلىك بىر دۇكاننى تۇتقۇزۇپ قويۇپ بوللۇق| شۇنىڭ بىلەن ئوقۇششاش بانىمۇ يوق بولدى.| مەن بۇلار ئۈچۈن دادامدىن ئاغرىنىپ ئۇخمىدىم| شۇ دۇكاننىڭ پۇلىنى يورغىلىتىپ تېرىپ، تارقالمايدىنىڭ ھەممىسىنى دادامنىڭ چىندىكى سالدى.| باشتادام كارغا كەلگەنلىكىمنى كەتسە، پۇل تاپالايدىغانلىقىمنى دېۋىسە| بابارا بارا دۇكانئىقساددىنى ئۆزۈمدە ئوچىلاق بېرىدىغۇ دەپ ئويلىغانىدىم.| ئەمما ئۇ كېيىنكى ۋاقىتلاردىمۇ ئىزچىل دۇكاننىڭ پايدىسىنى ئۆزىگە تاپشۇرۇشۇمنى ئېيتىپ كەلدى.| ھېسابتا، مەن ئۇنىڭغا بىكارلىق ئىشلەپ رىدىغان مەبىكار بولدۇم.| ئۇ ھەتتا مەن خەجلىگەن| ئىنسانلىرىمغا بەرگەن ئازغىنە پۇلنىڭ ھېسابىنىمۇ مەندىن تولۇقى بىلەن ئالدى.| مەن تېخى ھاياتلىق مۇشۇنداق بولىدىغان ئوخشايدۇ، ئىقتىسادنى مۇشۇنداق تۇتىدىغان ئوخشايدۇ دەپ ئويلاپ كەتتىم.| ئەمما كىيىم بىللىسەم. ئۇ ئاشنىلىرىغا پۇل خەجلىگەندە ئەزەلدىن پۇلنىڭ كۆزىگە قارىماي دىكەن.| ھەتتاييېنەدا بومىسا باشقىلاردىن ئېلىپ، مەندىن ئېلىپ ئاپامغا ساقلاشقا خەجلەشكى بەرگەنلىنىمۇ تاتىلىپ تۇرۇپ ئاشنىلىرىغا خەجلەيدىكەن.| ئاپامنىڭ تۇغۇلغان كۈنىنى ئەزەلدىن ئىسىدا تۇتۇپ باقمىغان.| بالىلىرىنىڭ ھېچقانداق تۇغۇلغان كۈننى، خاتۈر كۈنلىرى يادىدا تۇرمايدىغان دادا| ئاشنىلىرىنىڭ ھېيت-بايراملىرىدىن تەتىپ تۇغۇلغان كۈنلىرىگىچە ھەتتا ئۇلارنىڭ تۇغقانلىرىنىڭ تۇغۇلغان كىلىرىگىچە مەھكەم ئېسىدە ساقلاپ، ھەممىسىنىڭ ئۇرۇققۇنى چۈشۈرۈپ مەي تولۇق ئۆتكۈزۈپ بېرىپ، سوۋغا -سولۇلارنى بېرىپ ماڭىدىكەن.| لېكىن بىز پادىسەك يوقدەي.| ئۆينىڭ چىقىملىرىغىمۇ مەڭلىك تەستە ئىنتايىن قاقشاپ تۇرۇپ كەتبۇ بېبېرىدىكەن| مۇ شۇنچىۋالا پەرقلەرنى بىلگەندىن كېيىپ| دادامدىن، رەنجىپ، خېلىلى قاتتىق رەنجى| شۇنداقتۇر. ئۇ دادام بولۋامدى ئۈچۈن رەنجىگەرلىكىمنى يېلىگە دېيەلمىدىم ياكى ئادانى ساقلاپ قارىماي قويالمايتىم| كېيىن، بۇ تەرەپلىمۇ كۆڭلۈمدىن چىقىرىۋېتىپ رەنجى مەسلا بولدۇم.| چۈنكى ھەرقانچە رەنجىسەممۇ دادامنىڭ بۇ رەنجىشلىرى بىلەن پەرۋايى پەلەك ئىدى.| كېيىن دادام، تاجىمىلىگەن ئايالنى تېپىۋېلىپ ئاپامنى تاشلىۋەتتى.| ئاجرىشىمەن دەپ غەلبە قىلىپ ھەقىقەتەن ئاپامغا تۆتتەڭگە پۇل تۇتقۇزۇپ قويۇپلاجرىشىپ كەتتى.| شۇ چاغدا مەن دادام تەرەپتە تۇرۇپ ئاپامنى بەگىزلىدىم. دادامدىن كەلمىگەن مۇھەببەتنى ئاپامغان مەن يەتكۈزۈپ، مانا مەن شۇ ئائىلىق قايتا بىر پۈتۈنلۈككى ئىگە قىلدىم.| ئەمما دادامنىڭ بۇلار بىلەن ئىككى كارى بولمى..| خۇددى بۇ ئىشلارنى قىلىشىم يوللۇقتەك، بۇ مېنىڭ مەسئۇلى تەندەك| دادام ھەممە ئىشنى ماڭا تاشلاپ قويۇپ، ئۆزى بىر چەتكە چىقىپ تۇرۇۋال.| بوپتۇ دېدىم، ئائىلە بەخلىك بول، مەنمۇ خاتىرجەم بولاتتىم| ئىنسىڭىلىرىمنىڭ كىرىزىدىكى ياخشى قۇرىتى. مەن شۇنىڭ ئۈچۈن تەبىلنى كېشى دېدىم| ئىشلاتىگەن. نۆۋەت مېنىڭ بىربىرنى تېپىشمە سىلىسىگە كەلگەندە، دادام يەنىڭلا ئۆرىزىنىڭ مەنپەلىرىنى لىدى.| ئىسسىننىڭ پرايىنى قوللاش ئۈچۈن خەمىنىڭ مەقتەمنى قۇرنىڭ كۆڭلىغا سېپىلدى.| ئېھ، ئېيقانلارچان| ئۇ نېمىشقا مېنى ياخشى كەلگەن ئادىمىدىن ئايرىيېتەردۇ| جاھاندىكى ھەممەتە ئۆزگەرتىش ئېگى ئادەم بىلەن تويقىسى بولىدىكەن، نېمىشقا بىر مىلەرناممۇنداق ئالامبۇنىڭ يېتى نېمىش| بىر نېمىش؟| تالاقىز ئەنە شۇدانە دەپ ئاپىسىغا ئۆزىنى تاشلاپ بۇكۇداپ يىغلاپ كەتتى.| ئۇنىڭ كۆز ياشلىرى تە يىغىسى ھەممەيلەننىڭ يۈرىكىنى ئېزىۋەتتۇ.| بولۇپمۇ خەمىت كۆزلىرىدىن تائنايلا دەپ قالغان ياشلىرىنى تۇتۇۋېلىش ئۈچۈن شۇنچىلىك زور كۈچ سەرپ قىلىپ ئولتۇراتتىكى، ئەمما يۈرىكىدىكى زەرداپنى بىر ئۆزى بىلەتتى.| جېنىم بالى بىچارە قىزسەن مېنىڭ؟ شۇنداق بولسىمۇ بىزلەرنى تاشلاپ يەپ باق قالماي ئۆيدىن چىقىپ كەتكەن بارمۇ| نېمە گەپ بوسۇن دېيىشىپ، مەسلىھەتلىشىپ ھەل قىلىمىز ئەمەسمۇمۇ| داداڭدىن رەنجىسەڭ مېنىنى تاشلىۋەتسەڭ قانداق بولىدۇ| جېنىم بالام، مەن ئىزچىل ئەتتەرەپتە| ئۆز بەرگەن مۇشۇ ئىشلارنىڭ ئالدى-كەينىدىنمۇ، مەن داداڭنى قايىل قىلىشى ئۈچۈن شۇنچىلىك تىرئىچتىم.| ھېلىھەممە تىرىشىمەن بالام، جېنىم تېنىمىزلا بولىدىكەن، سېنى ھەرگىز يىغلاتمايمەن| داداڭ دېگەن ئۇ شورى قۇر-غۇر، ئۆزىنىڭ ئۆتكۈزگەن خاتالىقلىرى ئېسىدە يوق، ئەمدى سېنىڭ ئىشىغا كىرىشىۋالدى.| لېكىن مەن بۇ قېتىپ يول قويمايمەن.| تويانقى قولاسىمۇ قورشۇلىدۇ.| قوشۇلمىسىمۇ قوشۇلىدۇ، ئەگەر سائەتلا قوشۇلمايدىغان بولسا| سېنى ئېلىپ ئۇنداق ئۆيدىن چىقىپ كېتىمەن| ئۆمرۈمنىڭ تەڭدىن تولىسىنى ياشاپ بوپتىمەن.| بۇنىڭدىن بۇرۇن ئۇنىڭ گېپىنى ئاڭلاپلا كەلدىم| ئەمدى بۇنىڭدىن كېيىن ئۇنداق تالمايمەن.| بىۋاشلارچە گېپىنى ئاڭلاپ ئولتۇرسام ماڭا كۆرسەتكەن كۈنلىرى شۇ بولدى| سې ئۇنىڭ كېپىنى ئاڭلاپ دېگىنى بوچە قىپ بەرسەم، كىم بولىدۇ، تۆت كۈننە توختاپ، يەنە قايسى بىر جىمپۇچلار بىلەن تېپىشىپ قېلىپ| ئانا-بالا تۆتىمىزنىڭمۇ يۈرىكىنى تاتىلاپ قان قىلىۋېتىدۇ.| كۆڭلۈڭ توقتۇق جېنىم ما، يېنىڭدا مەن بار| بۇ ئىشتا داداڭنىڭ دېگەنلىرىگە ھەرگىز كۆنمەيمىز.| كۆڭلۈم كىمنى خالىغان بولسا، شۇنىڭ بىلەن تۇرمۇش قۇرۇپ بەخقىتىڭنى تاتىسەن.| ئايال ئەنە شۇلارنى دەپ نۆل-مىڭ يىغلىغىنىچە قىزىغا تەسەللى بەر.| ئازابتىن يۈرىكى پۇلىنىپ، نەپەس ئاغىدەك ھالى قالمىغان قىز ئاپىسىنىڭ تەسەللىسىدىن تېخىمۇ ئېزىلىپ، تېخىمۇ ئۆچسۈپ يىغلاپ كەتتى.| ئۇلارنىڭ بۇ ھالىغا قاراپ خەمىمۇ ئۆزىنى تۇتۇۋالالمىغىلى تاس قالدى.| ئۈن خىزمىتىگە قاتناشقۇچىلار، گۈلزار مۇدىيىم بەردى| گۈلنىگار ئوبۇل قاسىم، مۇستەسا ئابدۇرېھىم| ئەنۋەر ئىبراھىم، ئامىنە ئابدۇكېرەم| پەرھا تاش پولات، ئانىنە بەكلى قارلىغاچ| تەڭمۈر تۇرسۇن| " + ] + } + ], "source": [ "from silero_vad import load_silero_vad, read_audio, get_speech_timestamps\n", "class TimestampsType(TypedDict):\n", @@ -447,14 +464,15 @@ "\n", "vad_model = load_silero_vad(onnx=False)\n", "waveform_vad = inference._load_audio(audio_path=audio_path)\n", + "print(waveform_vad.shape)\n", "speech_timestamps: list[TimestampsType] = get_speech_timestamps(\n", " waveform_vad, \n", " vad_model, \n", " sampling_rate=inference.sample_rate,\n", - " threshold=0.4, # 可以根据需要调整阈值\n", + " threshold=0.5, # 可以根据需要调整阈值\n", " min_speech_duration_ms=100, # 最小语音持续时间,防止短噪音被误判\n", " min_silence_duration_ms=200, # 最小静音间隔,用于分割语音块\n", - " speech_pad_ms=100, # 在语音块前后添加的填充时间\n", + " speech_pad_ms=200, # 在语音块前后添加的填充时间\n", ")\n", "print('speech_timestamps:', speech_timestamps)\n", "print('speech_timestamps:', len(speech_timestamps))\n", @@ -466,10 +484,11 @@ " end_sample = int(ts['end'])\n", "\n", " segment_waveform = waveform_vad[:, start_sample:end_sample]\n", - " display(Audio(segment_waveform, rate=inference.sample_rate))\n", + " # display(Audio(segment_waveform, rate=inference.sample_rate))\n", "\n", " segment_text = inference.transcribe(waveform=segment_waveform)\n", - " print(segment_text, end=\" \", flush=True)" + " print(segment_text, end=\"| \", flush=True)\n", + "cuda.empty_cache()" ] }, { diff --git a/src/tokenizer.py b/src/tokenizer.py index 5159bc2..2df1fc4 100644 --- a/src/tokenizer.py +++ b/src/tokenizer.py @@ -3,7 +3,6 @@ from pathlib import Path import re import torch from torch import Tensor - from handle.text_handle import process_text from handle.text_normalizer import UYGHUR_LETTERS diff --git a/src/train.py b/src/train.py index 68ffa01..f1fa557 100644 --- a/src/train.py +++ b/src/train.py @@ -31,11 +31,11 @@ CONFIG = { 'input_dim': 256, 'num_heads': 8, 'ffn_dim': 2048, - 'num_layers': 8, - 'dropout': 0.1, + 'num_layers': 6, + 'dropout': 0.15, # 保存和评估 - 'early_stopping_patience': 12, + 'early_stopping_patience': 10, } @@ -168,17 +168,13 @@ def validate(model: ASRModel, dataloader: DataLoader, criterion: CTCLoss, device avg_cer = total_cer / num_samples avg_wer = total_wer / num_samples - writer.add_scalar('Val/Loss', avg_loss, global_step) - writer.add_scalar('Val/CER', avg_cer, global_step) - writer.add_scalar('Val/WER', avg_wer, global_step) - for index, (true_text, pred_text, cer, wer) in enumerate(examples): writer.add_text(f'Val/Example_{index}', f'True: {true_text}\nPred: {pred_text}\nCER: {cer:.4f} | WER: {wer:.4f}', global_step) model.train() return avg_loss, avg_cer, avg_wer -def save_checkpoint(model: ASRModel, optimizer: Optimizer, scheduler: OneCycleLR, global_step: int, epoch: int, train_loss: float, val_loss: float, cer: float, wer: float, save_path: Path): +def save_checkpoint(model: ASRModel, optimizer: Optimizer, scheduler: OneCycleLR, global_step: int, epoch: int, train_loss: float, val_loss: float, cer: float, wer: float, current_prob: float, save_path: Path): checkpoint = { 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), @@ -189,6 +185,7 @@ def save_checkpoint(model: ASRModel, optimizer: Optimizer, scheduler: OneCycleLR 'epoch': epoch, 'cer': cer, 'wer': wer, + 'current_prob': current_prob, } torch.save(checkpoint, save_path) @@ -206,14 +203,15 @@ def load_checkpoint(file_path: Path, model: ASRModel, optimizer: AdamW, schedule global_step = checkpoint['global_step'] best_cer = checkpoint['cer'] best_wer = checkpoint['wer'] - return epoch, global_step, best_cer, best_wer + current_prob = checkpoint['current_prob'] + return epoch, global_step, best_cer, best_wer, current_prob def main(): workspace_dir = Path(__file__).parent.parent device = torch.device('cuda:0') tokenizer = ASRTokenizer(workspace_dir / 'config/asr_vocab.json') - final_prob = 0.5 - warmup_epochs = 8 + final_prob = 0.8 + warmup_epochs = 12 current_prob = 0.0 # ============ 创建数据加载器 ============ @@ -221,6 +219,7 @@ def main(): tsv_path=workspace_dir / '.data/ug/train_new.tsv', audio_dir=workspace_dir / '.data/ug/clips', noise_dir='/mnt/dataset/dataset/audio/noise', + corridor_noise_dir= workspace_dir / 'data/corridor', tokenizer=tokenizer, batch_size=CONFIG['batch_size'], shuffle=True, @@ -228,15 +227,28 @@ def main(): augment_prob=current_prob ) - val_loader = create_dataloader( + val_loader_clean = create_dataloader( tsv_path=workspace_dir / '.data/ug/val_new.tsv', audio_dir=workspace_dir / '.data/ug/clips', noise_dir='/mnt/dataset/dataset/audio/noise', + corridor_noise_dir= workspace_dir / 'data/corridor', tokenizer=tokenizer, batch_size=CONFIG['batch_size'], shuffle=False, augment=False ) + + val_loader_noisy = create_dataloader( + tsv_path=workspace_dir / '.data/ug/val_new.tsv', + audio_dir=workspace_dir / '.data/ug/clips', + noise_dir='/mnt/dataset/dataset/audio/noise', + corridor_noise_dir= workspace_dir / 'data/corridor', + tokenizer=tokenizer, + batch_size=CONFIG['batch_size'], + shuffle=False, + augment=True, + augment_prob=0.5 + ) # ============ 初始化模型 ============ model = ASRModel( @@ -259,9 +271,9 @@ def main(): max_lr=CONFIG['learning_rate'], epochs=CONFIG['num_epochs'], steps_per_epoch=len(train_loader), - pct_start=0.15, + pct_start=0.2, anneal_strategy='cos', - div_factor=10.0, # 初始 lr = max_lr / 10 + div_factor=25.0, # 初始 lr = max_lr / 10 final_div_factor=1e4, # 最终 lr = max_lr / 10000 ) scaler = GradScaler() @@ -286,7 +298,7 @@ def main(): latest = find_latest_checkpoint(checkpoint_dir) if latest: - start_epoch, global_step, best_cer, best_wer = load_checkpoint(latest, model, optimizer, scheduler) + start_epoch, global_step, best_cer, best_wer, current_prob = load_checkpoint(latest, model, optimizer, scheduler) start_epoch += 1 # 从下一个 epoch 开始 print(f"✅ 恢复训练: 从 epoch={start_epoch} 开始, global_step={global_step}") else: @@ -308,25 +320,26 @@ def main(): scaler=scaler, ) - val_loss, val_cer, val_wer = validate(model=model, dataloader=val_loader, criterion=criterion, device=device, tokenizer=tokenizer, writer=writer, global_step=global_step) - print(f"\n📊 Step {global_step} | Val Loss: {val_loss:.4f} | Val CER: {val_cer:.4f} | Val WER: {val_wer:.4f}") + val_loss_c, val_cer_c, val_wer_c = validate(model=model, dataloader=val_loader_clean, criterion=criterion, device=device, tokenizer=tokenizer, writer=writer, global_step=global_step) + val_loss_n, val_cer_n, val_wer_n = validate(model=model, dataloader=val_loader_noisy, criterion=criterion, device=device, tokenizer=tokenizer, writer=writer, global_step=global_step) + print(f"\n📊 Step {global_step} | Clean Val Loss: {val_loss_c:.4f} | Clean Val CER: {val_cer_c:.4f} | Clean Val WER: {val_wer_c:.4f}") # 保存常规 checkpoint checkpoint_path = checkpoint_dir / f'checkpoint_epoch_{epoch}.pt' - save_checkpoint(model=model, optimizer=optimizer, scheduler=scheduler, epoch=epoch, global_step=global_step, train_loss=train_loss, val_loss=val_loss, cer=val_cer, wer=val_wer, save_path=checkpoint_path) + save_checkpoint(model=model, optimizer=optimizer, scheduler=scheduler, epoch=epoch, global_step=global_step, train_loss=train_loss, val_loss=val_loss_c, cer=val_cer_c, wer=val_wer_c, current_prob=current_prob, save_path=checkpoint_path) # 分别保存最佳 CER 和 WER 模型 improved = False - if val_cer < best_cer: - best_cer = val_cer + if val_cer_c < best_cer: + best_cer = val_cer_c best_cer_path = checkpoint_dir / 'best_cer_model.pt' - save_checkpoint(model=model, optimizer=optimizer, scheduler=scheduler, epoch=epoch, global_step=global_step, train_loss=train_loss, val_loss=val_loss, cer=val_cer, wer=val_wer, save_path=best_cer_path) + save_checkpoint(model=model, optimizer=optimizer, scheduler=scheduler, epoch=epoch, global_step=global_step, train_loss=train_loss, val_loss=val_loss_c, cer=val_cer_c, wer=val_wer_c, current_prob=current_prob, save_path=best_cer_path) improved = True - if val_wer < best_wer: - best_wer = val_wer + if val_wer_c < best_wer: + best_wer = val_wer_c best_wer_path = checkpoint_dir / 'best_wer_model.pt' - save_checkpoint(model=model, optimizer=optimizer, scheduler=scheduler, epoch=epoch, global_step=global_step, train_loss=train_loss, val_loss=val_loss, cer=val_cer, wer=val_wer, save_path=best_wer_path) + save_checkpoint(model=model, optimizer=optimizer, scheduler=scheduler, epoch=epoch, global_step=global_step, train_loss=train_loss, val_loss=val_loss_c, cer=val_cer_c, wer=val_wer_c, current_prob=current_prob, save_path=best_wer_path) improved = True # Early Stopping 逻辑 @@ -343,9 +356,16 @@ def main(): cuda.empty_cache() - writer.add_scalar('Val/EpochLoss', val_loss, epoch) writer.add_scalar('Train/EpochLoss', train_loss, epoch) - print(f"✅ Epoch {epoch} 完成 | Train Avg Loss: {train_loss:.4f} | Val Avg Loss: {val_loss:.4f} | Best CER: {best_cer:.4f} | Best WER: {best_wer:.4f}\n") + writer.add_scalar('Val/EpochLoss', val_loss_c, epoch) + writer.add_scalar('Val/Loss', val_loss_c, global_step) + writer.add_scalar('Val/CER', val_cer_c, global_step) + writer.add_scalar('Val/WER', val_wer_c, global_step) + writer.add_scalar('Val_Noisy/EpochLoss', val_loss_n, epoch) + writer.add_scalar('Val_Noisy/Loss', val_loss_n, global_step) + writer.add_scalar('Val_Noisy/CER', val_cer_n, global_step) + writer.add_scalar('Val_Noisy/WER', val_wer_n, global_step) + print(f"✅ Epoch {epoch} 完成 | Train Avg Loss: {train_loss:.4f} | Val Avg Loss: {val_loss_c:.4f} | Best CER: {best_cer:.4f} | Best WER: {best_wer:.4f}\n") # Early Stopping 检查 if patience_counter >= CONFIG['early_stopping_patience']: