add augment
This commit is contained in:
@@ -50,7 +50,7 @@ class CommonVoiceDataset(Dataset[BatchItem]):
|
||||
sample_rate: int = 16000,
|
||||
n_mels: int = 80 * 4,
|
||||
max_audio_len: int = 480000, # 30秒 @ 16kHz
|
||||
augment: bool = True, # 是否启用数据增强
|
||||
augment: bool = True,
|
||||
augment_prob: float = 0.5, # 数据增强的概率
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -124,9 +124,6 @@ class CommonVoiceDataset(Dataset[BatchItem]):
|
||||
if random.random() < 0.5:
|
||||
waveform = self._voice_stretch_or_compress(waveform=waveform)
|
||||
|
||||
if random.random() < 0.3:
|
||||
waveform = self._drop_frames(waveform)
|
||||
|
||||
if random.random() < 0.4:
|
||||
waveform = self._add_noise(waveform)
|
||||
|
||||
@@ -160,18 +157,6 @@ class CommonVoiceDataset(Dataset[BatchItem]):
|
||||
|
||||
return waveform_stretched
|
||||
|
||||
def _drop_frames(self, waveform: Tensor) -> Tensor:
|
||||
audio_len = waveform.shape[1]
|
||||
drop_ratio = random.uniform(0.05, 0.15)
|
||||
drop_len = int(audio_len * drop_ratio)
|
||||
|
||||
if audio_len > drop_len:
|
||||
start_pos = random.randint(0, audio_len - drop_len)
|
||||
# clean
|
||||
waveform = torch.cat([waveform[:, :start_pos], waveform[:, start_pos + drop_len:]], dim=1)
|
||||
|
||||
return waveform
|
||||
|
||||
def _add_noise(self, waveform: Tensor, snr_db: float = None) -> Tensor:
|
||||
if snr_db is None:
|
||||
snr_db = random.uniform(15, 25)
|
||||
@@ -192,6 +177,13 @@ class CommonVoiceDataset(Dataset[BatchItem]):
|
||||
|
||||
return noisy_waveform
|
||||
|
||||
def _augment_spec(self, mel_spec: Tensor) -> Tensor:
|
||||
if not self.augment or random.random() > self.augment_prob:
|
||||
return mel_spec
|
||||
mel_spec = self.time_masking(mel_spec)
|
||||
mel_spec = self.freq_masking(mel_spec)
|
||||
return mel_spec
|
||||
|
||||
def __getitem__(self, index) -> BatchItem:
|
||||
row: TsvFormat = self.data.iloc[index]
|
||||
audio_path: Path = self.audio_dir / row['path']
|
||||
@@ -199,10 +191,9 @@ class CommonVoiceDataset(Dataset[BatchItem]):
|
||||
text: str = normalize_extended_uyghur_characters(collapse_spaces(row['sentence'].strip()))
|
||||
|
||||
waveform = self._load_audio(audio_path=audio_path)
|
||||
|
||||
# waveform = self._augment_waveform(waveform)
|
||||
|
||||
waveform = self._augment_waveform(waveform)
|
||||
mel_spec = self._extract_features(waveform=waveform)
|
||||
mel_spec = self._augment_spec(mel_spec=mel_spec)
|
||||
|
||||
return BatchItem(
|
||||
mel_spec=mel_spec,
|
||||
|
||||
@@ -213,16 +213,17 @@ def main():
|
||||
|
||||
# ============ 创建数据加载器 ============
|
||||
train_loader = create_dataloader(
|
||||
tsv_path=workspace_dir / '.data/train.tsv',
|
||||
audio_dir=workspace_dir / '.data/clips',
|
||||
tsv_path=workspace_dir / '.data/ug/train.tsv',
|
||||
audio_dir=workspace_dir / '.data/ug/clips',
|
||||
tokenizer=tokenizer,
|
||||
batch_size=CONFIG['batch_size'],
|
||||
shuffle=True,
|
||||
augment=True
|
||||
)
|
||||
|
||||
val_loader = create_dataloader(
|
||||
tsv_path=workspace_dir / '.data/dev.tsv',
|
||||
audio_dir=workspace_dir / '.data/clips',
|
||||
tsv_path=workspace_dir / '.data/ug/dev.tsv',
|
||||
audio_dir=workspace_dir / '.data/ug/clips',
|
||||
tokenizer=tokenizer,
|
||||
batch_size=CONFIG['batch_size'],
|
||||
shuffle=False,
|
||||
|
||||
Reference in New Issue
Block a user