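Add augment: enable waveform augmentation in CommonVoiceDataset, add time/frequency masking on the mel spectrogram, and remove frame dropping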
This commit is contained in:
@@ -50,7 +50,7 @@ class CommonVoiceDataset(Dataset[BatchItem]):
|
|||||||
sample_rate: int = 16000,
|
sample_rate: int = 16000,
|
||||||
n_mels: int = 80 * 4,
|
n_mels: int = 80 * 4,
|
||||||
max_audio_len: int = 480000, # 30秒 @ 16kHz
|
max_audio_len: int = 480000, # 30秒 @ 16kHz
|
||||||
augment: bool = True, # 是否启用数据增强
|
augment: bool = True,
|
||||||
augment_prob: float = 0.5, # 数据增强的概率
|
augment_prob: float = 0.5, # 数据增强的概率
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -124,9 +124,6 @@ class CommonVoiceDataset(Dataset[BatchItem]):
|
|||||||
if random.random() < 0.5:
|
if random.random() < 0.5:
|
||||||
waveform = self._voice_stretch_or_compress(waveform=waveform)
|
waveform = self._voice_stretch_or_compress(waveform=waveform)
|
||||||
|
|
||||||
if random.random() < 0.3:
|
|
||||||
waveform = self._drop_frames(waveform)
|
|
||||||
|
|
||||||
if random.random() < 0.4:
|
if random.random() < 0.4:
|
||||||
waveform = self._add_noise(waveform)
|
waveform = self._add_noise(waveform)
|
||||||
|
|
||||||
@@ -160,18 +157,6 @@ class CommonVoiceDataset(Dataset[BatchItem]):
|
|||||||
|
|
||||||
return waveform_stretched
|
return waveform_stretched
|
||||||
|
|
||||||
def _drop_frames(self, waveform: Tensor) -> Tensor:
|
|
||||||
audio_len = waveform.shape[1]
|
|
||||||
drop_ratio = random.uniform(0.05, 0.15)
|
|
||||||
drop_len = int(audio_len * drop_ratio)
|
|
||||||
|
|
||||||
if audio_len > drop_len:
|
|
||||||
start_pos = random.randint(0, audio_len - drop_len)
|
|
||||||
# clean
|
|
||||||
waveform = torch.cat([waveform[:, :start_pos], waveform[:, start_pos + drop_len:]], dim=1)
|
|
||||||
|
|
||||||
return waveform
|
|
||||||
|
|
||||||
def _add_noise(self, waveform: Tensor, snr_db: float = None) -> Tensor:
|
def _add_noise(self, waveform: Tensor, snr_db: float = None) -> Tensor:
|
||||||
if snr_db is None:
|
if snr_db is None:
|
||||||
snr_db = random.uniform(15, 25)
|
snr_db = random.uniform(15, 25)
|
||||||
@@ -192,6 +177,13 @@ class CommonVoiceDataset(Dataset[BatchItem]):
|
|||||||
|
|
||||||
return noisy_waveform
|
return noisy_waveform
|
||||||
|
|
||||||
|
def _augment_spec(self, mel_spec: Tensor) -> Tensor:
|
||||||
|
if not self.augment or random.random() > self.augment_prob:
|
||||||
|
return mel_spec
|
||||||
|
mel_spec = self.time_masking(mel_spec)
|
||||||
|
mel_spec = self.freq_masking(mel_spec)
|
||||||
|
return mel_spec
|
||||||
|
|
||||||
def __getitem__(self, index) -> BatchItem:
|
def __getitem__(self, index) -> BatchItem:
|
||||||
row: TsvFormat = self.data.iloc[index]
|
row: TsvFormat = self.data.iloc[index]
|
||||||
audio_path: Path = self.audio_dir / row['path']
|
audio_path: Path = self.audio_dir / row['path']
|
||||||
@@ -199,10 +191,9 @@ class CommonVoiceDataset(Dataset[BatchItem]):
|
|||||||
text: str = normalize_extended_uyghur_characters(collapse_spaces(row['sentence'].strip()))
|
text: str = normalize_extended_uyghur_characters(collapse_spaces(row['sentence'].strip()))
|
||||||
|
|
||||||
waveform = self._load_audio(audio_path=audio_path)
|
waveform = self._load_audio(audio_path=audio_path)
|
||||||
|
waveform = self._augment_waveform(waveform)
|
||||||
# waveform = self._augment_waveform(waveform)
|
|
||||||
|
|
||||||
mel_spec = self._extract_features(waveform=waveform)
|
mel_spec = self._extract_features(waveform=waveform)
|
||||||
|
mel_spec = self._augment_spec(mel_spec=mel_spec)
|
||||||
|
|
||||||
return BatchItem(
|
return BatchItem(
|
||||||
mel_spec=mel_spec,
|
mel_spec=mel_spec,
|
||||||
|
|||||||
@@ -213,16 +213,17 @@ def main():
|
|||||||
|
|
||||||
# ============ 创建数据加载器 ============
|
# ============ 创建数据加载器 ============
|
||||||
train_loader = create_dataloader(
|
train_loader = create_dataloader(
|
||||||
tsv_path=workspace_dir / '.data/train.tsv',
|
tsv_path=workspace_dir / '.data/ug/train.tsv',
|
||||||
audio_dir=workspace_dir / '.data/clips',
|
audio_dir=workspace_dir / '.data/ug/clips',
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
batch_size=CONFIG['batch_size'],
|
batch_size=CONFIG['batch_size'],
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
|
augment=True
|
||||||
)
|
)
|
||||||
|
|
||||||
val_loader = create_dataloader(
|
val_loader = create_dataloader(
|
||||||
tsv_path=workspace_dir / '.data/dev.tsv',
|
tsv_path=workspace_dir / '.data/ug/dev.tsv',
|
||||||
audio_dir=workspace_dir / '.data/clips',
|
audio_dir=workspace_dir / '.data/ug/clips',
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
batch_size=CONFIG['batch_size'],
|
batch_size=CONFIG['batch_size'],
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
|
|||||||
Reference in New Issue
Block a user