diff --git a/src/dataset.py b/src/dataset.py index b359a93..6b05fc0 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -50,7 +50,7 @@ class CommonVoiceDataset(Dataset[BatchItem]): sample_rate: int = 16000, n_mels: int = 80 * 4, max_audio_len: int = 480000, # 30秒 @ 16kHz - augment: bool = True, # 是否启用数据增强 + augment: bool = True, augment_prob: float = 0.5, # 数据增强的概率 ) -> None: super().__init__() @@ -124,9 +124,6 @@ class CommonVoiceDataset(Dataset[BatchItem]): if random.random() < 0.5: waveform = self._voice_stretch_or_compress(waveform=waveform) - if random.random() < 0.3: - waveform = self._drop_frames(waveform) - if random.random() < 0.4: waveform = self._add_noise(waveform) @@ -160,18 +157,6 @@ class CommonVoiceDataset(Dataset[BatchItem]): return waveform_stretched - def _drop_frames(self, waveform: Tensor) -> Tensor: - audio_len = waveform.shape[1] - drop_ratio = random.uniform(0.05, 0.15) - drop_len = int(audio_len * drop_ratio) - - if audio_len > drop_len: - start_pos = random.randint(0, audio_len - drop_len) - # clean - waveform = torch.cat([waveform[:, :start_pos], waveform[:, start_pos + drop_len:]], dim=1) - - return waveform - def _add_noise(self, waveform: Tensor, snr_db: float = None) -> Tensor: if snr_db is None: snr_db = random.uniform(15, 25) @@ -192,6 +177,13 @@ class CommonVoiceDataset(Dataset[BatchItem]): return noisy_waveform + def _augment_spec(self, mel_spec: Tensor) -> Tensor: + if not self.augment or random.random() > self.augment_prob: + return mel_spec + mel_spec = self.time_masking(mel_spec) + mel_spec = self.freq_masking(mel_spec) + return mel_spec + def __getitem__(self, index) -> BatchItem: row: TsvFormat = self.data.iloc[index] audio_path: Path = self.audio_dir / row['path'] @@ -199,10 +191,9 @@ class CommonVoiceDataset(Dataset[BatchItem]): text: str = normalize_extended_uyghur_characters(collapse_spaces(row['sentence'].strip())) waveform = self._load_audio(audio_path=audio_path) - - # waveform = self._augment_waveform(waveform) - + waveform = self._augment_waveform(waveform) mel_spec = self._extract_features(waveform=waveform) + mel_spec = self._augment_spec(mel_spec=mel_spec) return BatchItem( mel_spec=mel_spec, diff --git a/src/train.py b/src/train.py index 687859e..6235fd1 100644 --- a/src/train.py +++ b/src/train.py @@ -213,16 +213,17 @@ def main(): # ============ 创建数据加载器 ============ train_loader = create_dataloader( - tsv_path=workspace_dir / '.data/train.tsv', - audio_dir=workspace_dir / '.data/clips', + tsv_path=workspace_dir / '.data/ug/train.tsv', + audio_dir=workspace_dir / '.data/ug/clips', tokenizer=tokenizer, batch_size=CONFIG['batch_size'], shuffle=True, + augment=True ) val_loader = create_dataloader( - tsv_path=workspace_dir / '.data/dev.tsv', - audio_dir=workspace_dir / '.data/clips', + tsv_path=workspace_dir / '.data/ug/dev.tsv', + audio_dir=workspace_dir / '.data/ug/clips', tokenizer=tokenizer, batch_size=CONFIG['batch_size'], shuffle=False,