Add noise augmentation and set num_layers=6
This commit is contained in:
143
src/dataset.py
143
src/dataset.py
@@ -4,7 +4,7 @@ import torch
|
|||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.utils.data import Dataset, DataLoader
|
from torch.utils.data import Dataset, DataLoader
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from torch_audiomentations import ApplyImpulseResponse, Gain, PitchShift, LowPassFilter, HighPassFilter
|
from torch_audiomentations import ApplyImpulseResponse, Gain, PitchShift, LowPassFilter, HighPassFilter, PolarityInversion
|
||||||
from torchaudio.transforms import Resample, TimeStretch
|
from torchaudio.transforms import Resample, TimeStretch
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
@@ -94,6 +94,7 @@ class CommonVoiceDataset(Dataset[BatchItem]):
|
|||||||
tsv_path: Path,
|
tsv_path: Path,
|
||||||
audio_dir: Path,
|
audio_dir: Path,
|
||||||
noise_dir: Path,
|
noise_dir: Path,
|
||||||
|
corridor_noise_dir: Path,
|
||||||
tokenizer: ASRTokenizer,
|
tokenizer: ASRTokenizer,
|
||||||
sample_rate: int = 16000,
|
sample_rate: int = 16000,
|
||||||
max_audio_len: int = 480000, # 30秒 @ 16kHz
|
max_audio_len: int = 480000, # 30秒 @ 16kHz
|
||||||
@@ -123,9 +124,10 @@ class CommonVoiceDataset(Dataset[BatchItem]):
|
|||||||
self.gain_down = Gain(min_gain_in_db=-15, max_gain_in_db=-8, p=1.0, output_type='tensor')
|
self.gain_down = Gain(min_gain_in_db=-15, max_gain_in_db=-8, p=1.0, output_type='tensor')
|
||||||
self.pitch_up = PitchShift(min_transpose_semitones=1, max_transpose_semitones=4, p=1.0, sample_rate=self.sample_rate, output_type='tensor')
|
self.pitch_up = PitchShift(min_transpose_semitones=1, max_transpose_semitones=4, p=1.0, sample_rate=self.sample_rate, output_type='tensor')
|
||||||
self.pitch_down = PitchShift(min_transpose_semitones=-4, max_transpose_semitones=-1, p=1.0, sample_rate=self.sample_rate, output_type='tensor')
|
self.pitch_down = PitchShift(min_transpose_semitones=-4, max_transpose_semitones=-1, p=1.0, sample_rate=self.sample_rate, output_type='tensor')
|
||||||
self.lowpass = LowPassFilter(min_cutoff_freq=600, max_cutoff_freq=2000, p=1.0, output_type='tensor')
|
self.lowpass = LowPassFilter(min_cutoff_freq=100, max_cutoff_freq=2000, p=1.0, output_type='tensor')
|
||||||
self.highpass = HighPassFilter(min_cutoff_freq=800, max_cutoff_freq=2000, p=1.0, output_type='tensor')
|
self.highpass = HighPassFilter(min_cutoff_freq=1000, max_cutoff_freq=2000, p=1.0, output_type='tensor')
|
||||||
self.apply_ir = ApplyImpulseResponse(ir_paths=noise_dir,convolve_mode='same', p=1, output_type="tensor")
|
self.apply_ir = ApplyImpulseResponse(ir_paths=corridor_noise_dir, convolve_mode='same', p=1, output_type="tensor")
|
||||||
|
self.polarity_inversion = PolarityInversion(p=1.0, output_type="tensor")
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.data)
|
return len(self.data)
|
||||||
@@ -157,7 +159,7 @@ class CommonVoiceDataset(Dataset[BatchItem]):
|
|||||||
if random.random() < 0.6:
|
if random.random() < 0.6:
|
||||||
waveform = self._stretch_or_compress(waveform=waveform)
|
waveform = self._stretch_or_compress(waveform=waveform)
|
||||||
|
|
||||||
if random.random() < 0.4:
|
if random.random() < 0.7:
|
||||||
waveform = self.noise_augmentor.apply_real_noise(waveform)
|
waveform = self.noise_augmentor.apply_real_noise(waveform)
|
||||||
|
|
||||||
if random.random() < 0.3:
|
if random.random() < 0.3:
|
||||||
@@ -166,29 +168,30 @@ class CommonVoiceDataset(Dataset[BatchItem]):
|
|||||||
# torch_audiomentations: [1, time] -> [1, 1, time]
|
# torch_audiomentations: [1, time] -> [1, 1, time]
|
||||||
if waveform.dim() == 2:
|
if waveform.dim() == 2:
|
||||||
waveform_3d = waveform.unsqueeze(0)
|
waveform_3d = waveform.unsqueeze(0)
|
||||||
# 随机选择一种频谱增强
|
# 随机选择一种物理特性增强 (互斥区)
|
||||||
choice = random.random()
|
choice = random.random()
|
||||||
if choice < 0.15:
|
if choice < 0.25: # [0.00 - 0.25] 25% 概率:增益
|
||||||
# 增益变化(上或下)
|
|
||||||
if random.random() < 0.5:
|
if random.random() < 0.5:
|
||||||
waveform_3d = self.gain_up(waveform_3d, sample_rate=self.sample_rate)
|
waveform_3d = self.gain_up(waveform_3d, sample_rate=self.sample_rate)
|
||||||
else:
|
else:
|
||||||
waveform_3d = self.gain_down(waveform_3d, sample_rate=self.sample_rate)
|
waveform_3d = self.gain_down(waveform_3d, sample_rate=self.sample_rate)
|
||||||
elif choice < 0.25:
|
elif choice < 0.50: # [0.25 - 0.50] 25% 概率:音高
|
||||||
# 音高变化(上或下)
|
|
||||||
if random.random() < 0.5:
|
if random.random() < 0.5:
|
||||||
waveform_3d = self.pitch_up(waveform_3d, sample_rate=self.sample_rate)
|
waveform_3d = self.pitch_up(waveform_3d, sample_rate=self.sample_rate)
|
||||||
else:
|
else:
|
||||||
waveform_3d = self.pitch_down(waveform_3d, sample_rate=self.sample_rate)
|
waveform_3d = self.pitch_down(waveform_3d, sample_rate=self.sample_rate)
|
||||||
elif choice < 0.30:
|
elif choice < 0.70: # [0.50 - 0.70] 20% 概率:低通
|
||||||
# 低通滤波(声音发闷)
|
|
||||||
waveform_3d = self.lowpass(waveform_3d, sample_rate=self.sample_rate)
|
waveform_3d = self.lowpass(waveform_3d, sample_rate=self.sample_rate)
|
||||||
# elif choice < 0.32:
|
elif choice < 0.85: # [0.70 - 0.85] 15% 概率:高通
|
||||||
# # 低通滤波(声音发闷)
|
|
||||||
# waveform_3d = self.apply_ir(waveform_3d, sample_rate=self.sample_rate)
|
|
||||||
elif choice < 0.35:
|
|
||||||
# 高通滤波(电话效果)
|
|
||||||
waveform_3d = self.highpass(waveform_3d, sample_rate=self.sample_rate)
|
waveform_3d = self.highpass(waveform_3d, sample_rate=self.sample_rate)
|
||||||
|
elif choice < 0.95: # [0.85 - 0.95] 10% 概率:走廊混响 (IR)
|
||||||
|
# 使用你测试过最好的 0.8/0.2 比例
|
||||||
|
dry = waveform_3d.clone()
|
||||||
|
wet = self.apply_ir(waveform_3d, sample_rate=self.sample_rate)
|
||||||
|
waveform_3d = 0.8 * dry + 0.2 * wet
|
||||||
|
else: # [0.95 - 1.00] 5% 概率:极性翻转
|
||||||
|
waveform_3d = self.polarity_inversion(waveform_3d, sample_rate=self.sample_rate)
|
||||||
|
|
||||||
|
|
||||||
# [1, 1, time] -> [1, time]
|
# [1, 1, time] -> [1, time]
|
||||||
waveform = waveform_3d.squeeze(0)
|
waveform = waveform_3d.squeeze(0)
|
||||||
@@ -201,7 +204,7 @@ class CommonVoiceDataset(Dataset[BatchItem]):
|
|||||||
return waveform
|
return waveform
|
||||||
|
|
||||||
def _stretch_or_compress(self, waveform: Tensor) -> Tensor:
|
def _stretch_or_compress(self, waveform: Tensor) -> Tensor:
|
||||||
speed_factor = random.uniform(0.85, 1.4) # (Speed Change: 0.85x - 1.4x)
|
speed_factor = random.uniform(0.80, 1.6) # (Speed Change: 0.80x - 1.6x)
|
||||||
spec = torch.stft(
|
spec = torch.stft(
|
||||||
waveform.squeeze(0),
|
waveform.squeeze(0),
|
||||||
n_fft=400,
|
n_fft=400,
|
||||||
@@ -246,7 +249,6 @@ class CommonVoiceDataset(Dataset[BatchItem]):
|
|||||||
def __getitem__(self, index) -> BatchItem:
|
def __getitem__(self, index) -> BatchItem:
|
||||||
row: TsvFormat = self.data.iloc[index]
|
row: TsvFormat = self.data.iloc[index]
|
||||||
audio_path: Path = self.audio_dir / row['path']
|
audio_path: Path = self.audio_dir / row['path']
|
||||||
# text: str = unicodedata.normalize('NFC', row['sentence'].strip())
|
|
||||||
text: str = normalize_extended_uyghur_characters(collapse_spaces(row['sentence'].strip()))
|
text: str = normalize_extended_uyghur_characters(collapse_spaces(row['sentence'].strip()))
|
||||||
|
|
||||||
waveform = self._load_audio(audio_path=audio_path)
|
waveform = self._load_audio(audio_path=audio_path)
|
||||||
@@ -297,8 +299,8 @@ def collate_fn(items: List[BatchItem]) -> Batch:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_dataloader(tsv_path: Path, audio_dir: Path, noise_dir: Path, tokenizer: ASRTokenizer, batch_size: int = 8, shuffle: bool = True, augment: bool = True, augment_prob: int = 0.5) -> DataLoader:
|
def create_dataloader(tsv_path: Path, audio_dir: Path, noise_dir: Path, corridor_noise_dir: Path, tokenizer: ASRTokenizer, batch_size: int = 8, shuffle: bool = True, augment: bool = True, augment_prob: int = 0.5) -> DataLoader:
|
||||||
dataset = CommonVoiceDataset(tsv_path=tsv_path, audio_dir=audio_dir, noise_dir=noise_dir, tokenizer=tokenizer, augment=augment, augment_prob=augment_prob)
|
dataset = CommonVoiceDataset(tsv_path=tsv_path, audio_dir=audio_dir, noise_dir=noise_dir, corridor_noise_dir=corridor_noise_dir, tokenizer=tokenizer, augment=augment, augment_prob=augment_prob)
|
||||||
|
|
||||||
return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn, pin_memory=True, num_workers=8, prefetch_factor=8, persistent_workers=True)
|
return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn, pin_memory=True, num_workers=8, prefetch_factor=8, persistent_workers=True)
|
||||||
|
|
||||||
@@ -331,101 +333,4 @@ if __name__ == "__main__":
|
|||||||
print(f"Target lengths: {batch['target_lengths']}")
|
print(f"Target lengths: {batch['target_lengths']}")
|
||||||
print(f"Target texts: {batch['target_texts']}")
|
print(f"Target texts: {batch['target_texts']}")
|
||||||
print(f"Audio paths: {batch['audio_paths']}")
|
print(f"Audio paths: {batch['audio_paths']}")
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torchaudio
|
|
||||||
from silero_vad import load_silero_vad, read_audio, get_speech_timestamps
|
|
||||||
|
|
||||||
# --- 1. 加载 VAD 模型 ---
|
|
||||||
# 推荐使用 GPU (cuda) 如果可用,否则使用 CPU (cpu)
|
|
||||||
vad_model = load_silero_vad(source="local", force_onnx=False)
|
|
||||||
|
|
||||||
def process_audio_with_vad_and_asr(audio_path, inference_module):
|
|
||||||
"""
|
|
||||||
使用 Silero VAD 分割音频,然后对每个语音片段进行 ASR 转录。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
audio_path (str): 输入音频文件的路径。
|
|
||||||
inference_module: 您封装了 transcribe 方法的模块或对象。
|
|
||||||
需要确保其 _load_audio 和 transcribe 方法可用。
|
|
||||||
"""
|
|
||||||
print(f"正在加载音频: {audio_path}")
|
|
||||||
|
|
||||||
# --- 2. 加载音频用于 VAD 检测 ---
|
|
||||||
# Silero VAD 推荐使用 16kHz 采样率
|
|
||||||
waveform_vad, sample_rate_vad = torchaudio.load(audio_path)
|
|
||||||
if sample_rate_vad != 16000:
|
|
||||||
# 如果采样率不是16kHz,需要重采样以供VAD使用
|
|
||||||
resampler = torchaudio.transforms.Resample(orig_freq=sample_rate_vad, new_freq=16000)
|
|
||||||
waveform_vad = resampler(waveform_vad)
|
|
||||||
sample_rate_vad = 16000
|
|
||||||
|
|
||||||
# --- 3. 获取语音时间段 ---
|
|
||||||
# get_speech_timestamps 返回一个列表,每个元素是 {'start': start_sample, 'end': end_sample} 的字典
|
|
||||||
# 采样率是16000,所以时间戳单位是 1/16000 秒
|
|
||||||
speech_timestamps = get_speech_timestamps(
|
|
||||||
waveform_vad,
|
|
||||||
vad_model,
|
|
||||||
sampling_rate=sample_rate_vad,
|
|
||||||
threshold=0.5, # 可以根据需要调整阈值
|
|
||||||
min_speech_duration_ms=250, # 最小语音持续时间,防止短噪音被误判
|
|
||||||
max_speech_duration_s=float('inf'), # 最大语音持续时间,float('inf') 表示不限制
|
|
||||||
min_silence_duration_ms=100, # 最小静音间隔,用于分割语音块
|
|
||||||
window_size_samples=1536, # VAD窗口大小
|
|
||||||
speech_pad_ms=30 # 在语音块前后添加的填充时间
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"检测到 {len(speech_timestamps)} 个语音片段。")
|
|
||||||
|
|
||||||
full_text = ""
|
|
||||||
for i, ts in enumerate(speech_timestamps):
|
|
||||||
start_sample = int(ts['start'])
|
|
||||||
end_sample = int(ts['end'])
|
|
||||||
|
|
||||||
# 计算时间戳(秒)
|
|
||||||
start_time = start_sample / sample_rate_vad
|
|
||||||
end_time = end_sample / sample_rate_vad
|
|
||||||
|
|
||||||
print(f"\n处理第 {i+1} 个片段: 时间范围 [{start_time:.2f}s - {end_time:.2f}s]")
|
|
||||||
|
|
||||||
# --- 4. 从原始音频中提取此片段 ---
|
|
||||||
# 注意:这里假设您的 transcribe 函数可以接受 waveform 张量。
|
|
||||||
# 我们需要从原始可能不同采样率的音频中提取片段,或者用VAD处理过的waveform。
|
|
||||||
# 为了匹配您原来的 _load_audio 方式,我们用 torchaudio 再次精确加载片段。
|
|
||||||
|
|
||||||
# 计算原始音频中的样本索引(如果原始音频采样率与VAD不同)
|
|
||||||
original_waveform, original_sr = torchaudio.load(audio_path)
|
|
||||||
if sample_rate_vad != original_sr:
|
|
||||||
# 如果VAD和原始音频采样率不同,需要重新映射索引
|
|
||||||
start_idx_orig = int(start_sample * (original_sr / sample_rate_vad))
|
|
||||||
end_idx_orig = int(end_sample * (original_sr / sample_rate_vad))
|
|
||||||
else:
|
|
||||||
start_idx_orig = start_sample
|
|
||||||
end_idx_orig = end_sample
|
|
||||||
|
|
||||||
segment_waveform = original_waveform[:, start_idx_orig:end_idx_orig]
|
|
||||||
|
|
||||||
# --- 5. 对该片段进行 ASR 转录 ---
|
|
||||||
# 这里调用您原有的 transcribe 方法
|
|
||||||
try:
|
|
||||||
# 假设 transcribe 方法接受一个 waveform tensor
|
|
||||||
segment_text = inference_module.transcribe(waveform=segment_waveform)
|
|
||||||
|
|
||||||
print(f" -> 转录结果: {segment_text}")
|
|
||||||
full_text += f"[{start_time:.2f}-{end_time:.2f}s] {segment_text}\n"
|
|
||||||
except Exception as e:
|
|
||||||
print(f" -> 转录第 {i+1} 个片段时出错: {e}")
|
|
||||||
|
|
||||||
print("\n--- 完整转录结果 ---")
|
|
||||||
print(full_text)
|
|
||||||
|
|
||||||
|
|
||||||
# --- 使用示例 ---
|
|
||||||
# 假设您有一个名为 'inference' 的模块对象,它有 _load_audio 和 transcribe 方法
|
|
||||||
# audio_path = "your_audio_file.wav"
|
|
||||||
# process_audio_with_vad_and_asr(audio_path, inference)
|
|
||||||
File diff suppressed because one or more lines are too long
@@ -35,7 +35,7 @@ def analyze_mp3_files(directory: Path, ):
|
|||||||
"""分析目录中的所有MP3文件"""
|
"""分析目录中的所有MP3文件"""
|
||||||
results = []
|
results = []
|
||||||
total_duration = 0
|
total_duration = 0
|
||||||
df = pd.read_csv(workspace_dir / '.data/ug/train_new.tsv', sep='\t')
|
df = pd.read_csv(workspace_dir / '.data/ug/val_new.tsv', sep='\t')
|
||||||
|
|
||||||
for _, row in tqdm(df.iterrows(), total=len(df), desc="anlayze audio"):
|
for _, row in tqdm(df.iterrows(), total=len(df), desc="anlayze audio"):
|
||||||
mp3_file: Path = workspace_dir / '.data/ug/clips' / row['path']
|
mp3_file: Path = workspace_dir / '.data/ug/clips' / row['path']
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -3,7 +3,6 @@ from pathlib import Path
|
|||||||
import re
|
import re
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from handle.text_handle import process_text
|
from handle.text_handle import process_text
|
||||||
from handle.text_normalizer import UYGHUR_LETTERS
|
from handle.text_normalizer import UYGHUR_LETTERS
|
||||||
|
|
||||||
|
|||||||
72
src/train.py
72
src/train.py
@@ -31,11 +31,11 @@ CONFIG = {
|
|||||||
'input_dim': 256,
|
'input_dim': 256,
|
||||||
'num_heads': 8,
|
'num_heads': 8,
|
||||||
'ffn_dim': 2048,
|
'ffn_dim': 2048,
|
||||||
'num_layers': 8,
|
'num_layers': 6,
|
||||||
'dropout': 0.1,
|
'dropout': 0.15,
|
||||||
|
|
||||||
# 保存和评估
|
# 保存和评估
|
||||||
'early_stopping_patience': 12,
|
'early_stopping_patience': 10,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -168,17 +168,13 @@ def validate(model: ASRModel, dataloader: DataLoader, criterion: CTCLoss, device
|
|||||||
avg_cer = total_cer / num_samples
|
avg_cer = total_cer / num_samples
|
||||||
avg_wer = total_wer / num_samples
|
avg_wer = total_wer / num_samples
|
||||||
|
|
||||||
writer.add_scalar('Val/Loss', avg_loss, global_step)
|
|
||||||
writer.add_scalar('Val/CER', avg_cer, global_step)
|
|
||||||
writer.add_scalar('Val/WER', avg_wer, global_step)
|
|
||||||
|
|
||||||
for index, (true_text, pred_text, cer, wer) in enumerate(examples):
|
for index, (true_text, pred_text, cer, wer) in enumerate(examples):
|
||||||
writer.add_text(f'Val/Example_{index}', f'True: {true_text}\nPred: {pred_text}\nCER: {cer:.4f} | WER: {wer:.4f}', global_step)
|
writer.add_text(f'Val/Example_{index}', f'True: {true_text}\nPred: {pred_text}\nCER: {cer:.4f} | WER: {wer:.4f}', global_step)
|
||||||
|
|
||||||
model.train()
|
model.train()
|
||||||
return avg_loss, avg_cer, avg_wer
|
return avg_loss, avg_cer, avg_wer
|
||||||
|
|
||||||
def save_checkpoint(model: ASRModel, optimizer: Optimizer, scheduler: OneCycleLR, global_step: int, epoch: int, train_loss: float, val_loss: float, cer: float, wer: float, save_path: Path):
|
def save_checkpoint(model: ASRModel, optimizer: Optimizer, scheduler: OneCycleLR, global_step: int, epoch: int, train_loss: float, val_loss: float, cer: float, wer: float, current_prob: float, save_path: Path):
|
||||||
checkpoint = {
|
checkpoint = {
|
||||||
'model_state_dict': model.state_dict(),
|
'model_state_dict': model.state_dict(),
|
||||||
'optimizer_state_dict': optimizer.state_dict(),
|
'optimizer_state_dict': optimizer.state_dict(),
|
||||||
@@ -189,6 +185,7 @@ def save_checkpoint(model: ASRModel, optimizer: Optimizer, scheduler: OneCycleLR
|
|||||||
'epoch': epoch,
|
'epoch': epoch,
|
||||||
'cer': cer,
|
'cer': cer,
|
||||||
'wer': wer,
|
'wer': wer,
|
||||||
|
'current_prob': current_prob,
|
||||||
}
|
}
|
||||||
torch.save(checkpoint, save_path)
|
torch.save(checkpoint, save_path)
|
||||||
|
|
||||||
@@ -206,14 +203,15 @@ def load_checkpoint(file_path: Path, model: ASRModel, optimizer: AdamW, schedule
|
|||||||
global_step = checkpoint['global_step']
|
global_step = checkpoint['global_step']
|
||||||
best_cer = checkpoint['cer']
|
best_cer = checkpoint['cer']
|
||||||
best_wer = checkpoint['wer']
|
best_wer = checkpoint['wer']
|
||||||
return epoch, global_step, best_cer, best_wer
|
current_prob = checkpoint['current_prob']
|
||||||
|
return epoch, global_step, best_cer, best_wer, current_prob
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
workspace_dir = Path(__file__).parent.parent
|
workspace_dir = Path(__file__).parent.parent
|
||||||
device = torch.device('cuda:0')
|
device = torch.device('cuda:0')
|
||||||
tokenizer = ASRTokenizer(workspace_dir / 'config/asr_vocab.json')
|
tokenizer = ASRTokenizer(workspace_dir / 'config/asr_vocab.json')
|
||||||
final_prob = 0.5
|
final_prob = 0.8
|
||||||
warmup_epochs = 8
|
warmup_epochs = 12
|
||||||
current_prob = 0.0
|
current_prob = 0.0
|
||||||
|
|
||||||
# ============ 创建数据加载器 ============
|
# ============ 创建数据加载器 ============
|
||||||
@@ -221,6 +219,7 @@ def main():
|
|||||||
tsv_path=workspace_dir / '.data/ug/train_new.tsv',
|
tsv_path=workspace_dir / '.data/ug/train_new.tsv',
|
||||||
audio_dir=workspace_dir / '.data/ug/clips',
|
audio_dir=workspace_dir / '.data/ug/clips',
|
||||||
noise_dir='/mnt/dataset/dataset/audio/noise',
|
noise_dir='/mnt/dataset/dataset/audio/noise',
|
||||||
|
corridor_noise_dir= workspace_dir / 'data/corridor',
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
batch_size=CONFIG['batch_size'],
|
batch_size=CONFIG['batch_size'],
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
@@ -228,15 +227,28 @@ def main():
|
|||||||
augment_prob=current_prob
|
augment_prob=current_prob
|
||||||
)
|
)
|
||||||
|
|
||||||
val_loader = create_dataloader(
|
val_loader_clean = create_dataloader(
|
||||||
tsv_path=workspace_dir / '.data/ug/val_new.tsv',
|
tsv_path=workspace_dir / '.data/ug/val_new.tsv',
|
||||||
audio_dir=workspace_dir / '.data/ug/clips',
|
audio_dir=workspace_dir / '.data/ug/clips',
|
||||||
noise_dir='/mnt/dataset/dataset/audio/noise',
|
noise_dir='/mnt/dataset/dataset/audio/noise',
|
||||||
|
corridor_noise_dir= workspace_dir / 'data/corridor',
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
batch_size=CONFIG['batch_size'],
|
batch_size=CONFIG['batch_size'],
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
augment=False
|
augment=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
val_loader_noisy = create_dataloader(
|
||||||
|
tsv_path=workspace_dir / '.data/ug/val_new.tsv',
|
||||||
|
audio_dir=workspace_dir / '.data/ug/clips',
|
||||||
|
noise_dir='/mnt/dataset/dataset/audio/noise',
|
||||||
|
corridor_noise_dir= workspace_dir / 'data/corridor',
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
batch_size=CONFIG['batch_size'],
|
||||||
|
shuffle=False,
|
||||||
|
augment=True,
|
||||||
|
augment_prob=0.5
|
||||||
|
)
|
||||||
|
|
||||||
# ============ 初始化模型 ============
|
# ============ 初始化模型 ============
|
||||||
model = ASRModel(
|
model = ASRModel(
|
||||||
@@ -259,9 +271,9 @@ def main():
|
|||||||
max_lr=CONFIG['learning_rate'],
|
max_lr=CONFIG['learning_rate'],
|
||||||
epochs=CONFIG['num_epochs'],
|
epochs=CONFIG['num_epochs'],
|
||||||
steps_per_epoch=len(train_loader),
|
steps_per_epoch=len(train_loader),
|
||||||
pct_start=0.15,
|
pct_start=0.2,
|
||||||
anneal_strategy='cos',
|
anneal_strategy='cos',
|
||||||
div_factor=10.0, # 初始 lr = max_lr / 10
|
div_factor=25.0, # initial lr = max_lr / 25
|
||||||
final_div_factor=1e4, # 最终 lr = max_lr / 10000
|
final_div_factor=1e4, # 最终 lr = max_lr / 10000
|
||||||
)
|
)
|
||||||
scaler = GradScaler()
|
scaler = GradScaler()
|
||||||
@@ -286,7 +298,7 @@ def main():
|
|||||||
|
|
||||||
latest = find_latest_checkpoint(checkpoint_dir)
|
latest = find_latest_checkpoint(checkpoint_dir)
|
||||||
if latest:
|
if latest:
|
||||||
start_epoch, global_step, best_cer, best_wer = load_checkpoint(latest, model, optimizer, scheduler)
|
start_epoch, global_step, best_cer, best_wer, current_prob = load_checkpoint(latest, model, optimizer, scheduler)
|
||||||
start_epoch += 1 # 从下一个 epoch 开始
|
start_epoch += 1 # 从下一个 epoch 开始
|
||||||
print(f"✅ 恢复训练: 从 epoch={start_epoch} 开始, global_step={global_step}")
|
print(f"✅ 恢复训练: 从 epoch={start_epoch} 开始, global_step={global_step}")
|
||||||
else:
|
else:
|
||||||
@@ -308,25 +320,26 @@ def main():
|
|||||||
scaler=scaler,
|
scaler=scaler,
|
||||||
)
|
)
|
||||||
|
|
||||||
val_loss, val_cer, val_wer = validate(model=model, dataloader=val_loader, criterion=criterion, device=device, tokenizer=tokenizer, writer=writer, global_step=global_step)
|
val_loss_c, val_cer_c, val_wer_c = validate(model=model, dataloader=val_loader_clean, criterion=criterion, device=device, tokenizer=tokenizer, writer=writer, global_step=global_step)
|
||||||
print(f"\n📊 Step {global_step} | Val Loss: {val_loss:.4f} | Val CER: {val_cer:.4f} | Val WER: {val_wer:.4f}")
|
val_loss_n, val_cer_n, val_wer_n = validate(model=model, dataloader=val_loader_noisy, criterion=criterion, device=device, tokenizer=tokenizer, writer=writer, global_step=global_step)
|
||||||
|
print(f"\n📊 Step {global_step} | Clean Val Loss: {val_loss_c:.4f} | Clean Val CER: {val_cer_c:.4f} | Clean Val WER: {val_wer_c:.4f}")
|
||||||
|
|
||||||
# 保存常规 checkpoint
|
# 保存常规 checkpoint
|
||||||
checkpoint_path = checkpoint_dir / f'checkpoint_epoch_{epoch}.pt'
|
checkpoint_path = checkpoint_dir / f'checkpoint_epoch_{epoch}.pt'
|
||||||
save_checkpoint(model=model, optimizer=optimizer, scheduler=scheduler, epoch=epoch, global_step=global_step, train_loss=train_loss, val_loss=val_loss, cer=val_cer, wer=val_wer, save_path=checkpoint_path)
|
save_checkpoint(model=model, optimizer=optimizer, scheduler=scheduler, epoch=epoch, global_step=global_step, train_loss=train_loss, val_loss=val_loss_c, cer=val_cer_c, wer=val_wer_c, current_prob=current_prob, save_path=checkpoint_path)
|
||||||
|
|
||||||
# 分别保存最佳 CER 和 WER 模型
|
# 分别保存最佳 CER 和 WER 模型
|
||||||
improved = False
|
improved = False
|
||||||
if val_cer < best_cer:
|
if val_cer_c < best_cer:
|
||||||
best_cer = val_cer
|
best_cer = val_cer_c
|
||||||
best_cer_path = checkpoint_dir / 'best_cer_model.pt'
|
best_cer_path = checkpoint_dir / 'best_cer_model.pt'
|
||||||
save_checkpoint(model=model, optimizer=optimizer, scheduler=scheduler, epoch=epoch, global_step=global_step, train_loss=train_loss, val_loss=val_loss, cer=val_cer, wer=val_wer, save_path=best_cer_path)
|
save_checkpoint(model=model, optimizer=optimizer, scheduler=scheduler, epoch=epoch, global_step=global_step, train_loss=train_loss, val_loss=val_loss_c, cer=val_cer_c, wer=val_wer_c, current_prob=current_prob, save_path=best_cer_path)
|
||||||
improved = True
|
improved = True
|
||||||
|
|
||||||
if val_wer < best_wer:
|
if val_wer_c < best_wer:
|
||||||
best_wer = val_wer
|
best_wer = val_wer_c
|
||||||
best_wer_path = checkpoint_dir / 'best_wer_model.pt'
|
best_wer_path = checkpoint_dir / 'best_wer_model.pt'
|
||||||
save_checkpoint(model=model, optimizer=optimizer, scheduler=scheduler, epoch=epoch, global_step=global_step, train_loss=train_loss, val_loss=val_loss, cer=val_cer, wer=val_wer, save_path=best_wer_path)
|
save_checkpoint(model=model, optimizer=optimizer, scheduler=scheduler, epoch=epoch, global_step=global_step, train_loss=train_loss, val_loss=val_loss_c, cer=val_cer_c, wer=val_wer_c, current_prob=current_prob, save_path=best_wer_path)
|
||||||
improved = True
|
improved = True
|
||||||
|
|
||||||
# Early Stopping 逻辑
|
# Early Stopping 逻辑
|
||||||
@@ -343,9 +356,16 @@ def main():
|
|||||||
|
|
||||||
cuda.empty_cache()
|
cuda.empty_cache()
|
||||||
|
|
||||||
writer.add_scalar('Val/EpochLoss', val_loss, epoch)
|
|
||||||
writer.add_scalar('Train/EpochLoss', train_loss, epoch)
|
writer.add_scalar('Train/EpochLoss', train_loss, epoch)
|
||||||
print(f"✅ Epoch {epoch} 完成 | Train Avg Loss: {train_loss:.4f} | Val Avg Loss: {val_loss:.4f} | Best CER: {best_cer:.4f} | Best WER: {best_wer:.4f}\n")
|
writer.add_scalar('Val/EpochLoss', val_loss_c, epoch)
|
||||||
|
writer.add_scalar('Val/Loss', val_loss_c, global_step)
|
||||||
|
writer.add_scalar('Val/CER', val_cer_c, global_step)
|
||||||
|
writer.add_scalar('Val/WER', val_wer_c, global_step)
|
||||||
|
writer.add_scalar('Val_Noisy/EpochLoss', val_loss_n, epoch)
|
||||||
|
writer.add_scalar('Val_Noisy/Loss', val_loss_n, global_step)
|
||||||
|
writer.add_scalar('Val_Noisy/CER', val_cer_n, global_step)
|
||||||
|
writer.add_scalar('Val_Noisy/WER', val_wer_n, global_step)
|
||||||
|
print(f"✅ Epoch {epoch} 完成 | Train Avg Loss: {train_loss:.4f} | Val Avg Loss: {val_loss_c:.4f} | Best CER: {best_cer:.4f} | Best WER: {best_wer:.4f}\n")
|
||||||
|
|
||||||
# Early Stopping 检查
|
# Early Stopping 检查
|
||||||
if patience_counter >= CONFIG['early_stopping_patience']:
|
if patience_counter >= CONFIG['early_stopping_patience']:
|
||||||
|
|||||||
Reference in New Issue
Block a user