add augment noise and num layer=6

This commit is contained in:
2026-05-09 14:11:58 +06:00
parent d31233a79a
commit 96cd0a20cb
6 changed files with 155 additions and 211 deletions

View File

@@ -4,7 +4,7 @@ import torch
from torch import Tensor from torch import Tensor
from torch.utils.data import Dataset, DataLoader from torch.utils.data import Dataset, DataLoader
import torchaudio import torchaudio
from torch_audiomentations import ApplyImpulseResponse, Gain, PitchShift, LowPassFilter, HighPassFilter from torch_audiomentations import ApplyImpulseResponse, Gain, PitchShift, LowPassFilter, HighPassFilter, PolarityInversion
from torchaudio.transforms import Resample, TimeStretch from torchaudio.transforms import Resample, TimeStretch
from pathlib import Path from pathlib import Path
import pandas as pd import pandas as pd
@@ -94,6 +94,7 @@ class CommonVoiceDataset(Dataset[BatchItem]):
tsv_path: Path, tsv_path: Path,
audio_dir: Path, audio_dir: Path,
noise_dir: Path, noise_dir: Path,
corridor_noise_dir: Path,
tokenizer: ASRTokenizer, tokenizer: ASRTokenizer,
sample_rate: int = 16000, sample_rate: int = 16000,
max_audio_len: int = 480000, # 30秒 @ 16kHz max_audio_len: int = 480000, # 30秒 @ 16kHz
@@ -123,9 +124,10 @@ class CommonVoiceDataset(Dataset[BatchItem]):
self.gain_down = Gain(min_gain_in_db=-15, max_gain_in_db=-8, p=1.0, output_type='tensor') self.gain_down = Gain(min_gain_in_db=-15, max_gain_in_db=-8, p=1.0, output_type='tensor')
self.pitch_up = PitchShift(min_transpose_semitones=1, max_transpose_semitones=4, p=1.0, sample_rate=self.sample_rate, output_type='tensor') self.pitch_up = PitchShift(min_transpose_semitones=1, max_transpose_semitones=4, p=1.0, sample_rate=self.sample_rate, output_type='tensor')
self.pitch_down = PitchShift(min_transpose_semitones=-4, max_transpose_semitones=-1, p=1.0, sample_rate=self.sample_rate, output_type='tensor') self.pitch_down = PitchShift(min_transpose_semitones=-4, max_transpose_semitones=-1, p=1.0, sample_rate=self.sample_rate, output_type='tensor')
self.lowpass = LowPassFilter(min_cutoff_freq=600, max_cutoff_freq=2000, p=1.0, output_type='tensor') self.lowpass = LowPassFilter(min_cutoff_freq=100, max_cutoff_freq=2000, p=1.0, output_type='tensor')
self.highpass = HighPassFilter(min_cutoff_freq=800, max_cutoff_freq=2000, p=1.0, output_type='tensor') self.highpass = HighPassFilter(min_cutoff_freq=1000, max_cutoff_freq=2000, p=1.0, output_type='tensor')
self.apply_ir = ApplyImpulseResponse(ir_paths=noise_dir,convolve_mode='same', p=1, output_type="tensor") self.apply_ir = ApplyImpulseResponse(ir_paths=corridor_noise_dir, convolve_mode='same', p=1, output_type="tensor")
self.polarity_inversion = PolarityInversion(p=1.0, output_type="tensor")
def __len__(self): def __len__(self):
return len(self.data) return len(self.data)
@@ -157,7 +159,7 @@ class CommonVoiceDataset(Dataset[BatchItem]):
if random.random() < 0.6: if random.random() < 0.6:
waveform = self._stretch_or_compress(waveform=waveform) waveform = self._stretch_or_compress(waveform=waveform)
if random.random() < 0.4: if random.random() < 0.7:
waveform = self.noise_augmentor.apply_real_noise(waveform) waveform = self.noise_augmentor.apply_real_noise(waveform)
if random.random() < 0.3: if random.random() < 0.3:
@@ -166,29 +168,30 @@ class CommonVoiceDataset(Dataset[BatchItem]):
# torch_audiomentations: [1, time] -> [1, 1, time] # torch_audiomentations: [1, time] -> [1, 1, time]
if waveform.dim() == 2: if waveform.dim() == 2:
waveform_3d = waveform.unsqueeze(0) waveform_3d = waveform.unsqueeze(0)
# 随机选择一种频谱增强 # 随机选择一种物理特性增强 (互斥区)
choice = random.random() choice = random.random()
if choice < 0.15: if choice < 0.25: # [0.00 - 0.25] 25% 概率:增益
# 增益变化(上或下)
if random.random() < 0.5: if random.random() < 0.5:
waveform_3d = self.gain_up(waveform_3d, sample_rate=self.sample_rate) waveform_3d = self.gain_up(waveform_3d, sample_rate=self.sample_rate)
else: else:
waveform_3d = self.gain_down(waveform_3d, sample_rate=self.sample_rate) waveform_3d = self.gain_down(waveform_3d, sample_rate=self.sample_rate)
elif choice < 0.25: elif choice < 0.50: # [0.25 - 0.50] 25% 概率:音高
# 音高变化(上或下)
if random.random() < 0.5: if random.random() < 0.5:
waveform_3d = self.pitch_up(waveform_3d, sample_rate=self.sample_rate) waveform_3d = self.pitch_up(waveform_3d, sample_rate=self.sample_rate)
else: else:
waveform_3d = self.pitch_down(waveform_3d, sample_rate=self.sample_rate) waveform_3d = self.pitch_down(waveform_3d, sample_rate=self.sample_rate)
elif choice < 0.30: elif choice < 0.70: # [0.50 - 0.70] 20% 概率:低通
# 低通滤波(声音发闷)
waveform_3d = self.lowpass(waveform_3d, sample_rate=self.sample_rate) waveform_3d = self.lowpass(waveform_3d, sample_rate=self.sample_rate)
# elif choice < 0.32: elif choice < 0.85: # [0.70 - 0.85] 15% 概率:高通
# # 低通滤波(声音发闷)
# waveform_3d = self.apply_ir(waveform_3d, sample_rate=self.sample_rate)
elif choice < 0.35:
# 高通滤波(电话效果)
waveform_3d = self.highpass(waveform_3d, sample_rate=self.sample_rate) waveform_3d = self.highpass(waveform_3d, sample_rate=self.sample_rate)
elif choice < 0.95: # [0.85 - 0.95] 10% 概率:走廊混响 (IR)
# 使用你测试过最好的 0.8/0.2 比例
dry = waveform_3d.clone()
wet = self.apply_ir(waveform_3d, sample_rate=self.sample_rate)
waveform_3d = 0.8 * dry + 0.2 * wet
else: # [0.95 - 1.00] 5% 概率:极性翻转
waveform_3d = self.polarity_inversion(waveform_3d, sample_rate=self.sample_rate)
# [1, 1, time] -> [1, time] # [1, 1, time] -> [1, time]
waveform = waveform_3d.squeeze(0) waveform = waveform_3d.squeeze(0)
@@ -201,7 +204,7 @@ class CommonVoiceDataset(Dataset[BatchItem]):
return waveform return waveform
def _stretch_or_compress(self, waveform: Tensor) -> Tensor: def _stretch_or_compress(self, waveform: Tensor) -> Tensor:
speed_factor = random.uniform(0.85, 1.4) # (Speed Change: 0.85x - 1.4x) speed_factor = random.uniform(0.80, 1.6) # (Speed Change: 0.80x - 1.6x)
spec = torch.stft( spec = torch.stft(
waveform.squeeze(0), waveform.squeeze(0),
n_fft=400, n_fft=400,
@@ -246,7 +249,6 @@ class CommonVoiceDataset(Dataset[BatchItem]):
def __getitem__(self, index) -> BatchItem: def __getitem__(self, index) -> BatchItem:
row: TsvFormat = self.data.iloc[index] row: TsvFormat = self.data.iloc[index]
audio_path: Path = self.audio_dir / row['path'] audio_path: Path = self.audio_dir / row['path']
# text: str = unicodedata.normalize('NFC', row['sentence'].strip())
text: str = normalize_extended_uyghur_characters(collapse_spaces(row['sentence'].strip())) text: str = normalize_extended_uyghur_characters(collapse_spaces(row['sentence'].strip()))
waveform = self._load_audio(audio_path=audio_path) waveform = self._load_audio(audio_path=audio_path)
@@ -297,8 +299,8 @@ def collate_fn(items: List[BatchItem]) -> Batch:
) )
def create_dataloader(tsv_path: Path, audio_dir: Path, noise_dir: Path, tokenizer: ASRTokenizer, batch_size: int = 8, shuffle: bool = True, augment: bool = True, augment_prob: int = 0.5) -> DataLoader: def create_dataloader(tsv_path: Path, audio_dir: Path, noise_dir: Path, corridor_noise_dir: Path, tokenizer: ASRTokenizer, batch_size: int = 8, shuffle: bool = True, augment: bool = True, augment_prob: int = 0.5) -> DataLoader:
dataset = CommonVoiceDataset(tsv_path=tsv_path, audio_dir=audio_dir, noise_dir=noise_dir, tokenizer=tokenizer, augment=augment, augment_prob=augment_prob) dataset = CommonVoiceDataset(tsv_path=tsv_path, audio_dir=audio_dir, noise_dir=noise_dir, corridor_noise_dir=corridor_noise_dir, tokenizer=tokenizer, augment=augment, augment_prob=augment_prob)
return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn, pin_memory=True, num_workers=8, prefetch_factor=8, persistent_workers=True) return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn, pin_memory=True, num_workers=8, prefetch_factor=8, persistent_workers=True)
@@ -332,100 +334,3 @@ if __name__ == "__main__":
print(f"Target texts: {batch['target_texts']}") print(f"Target texts: {batch['target_texts']}")
print(f"Audio paths: {batch['audio_paths']}") print(f"Audio paths: {batch['audio_paths']}")
break break
import torch
import torchaudio
from silero_vad import load_silero_vad, read_audio, get_speech_timestamps
# --- 1. Load the VAD model ---
# load_silero_vad() from the silero_vad package takes `onnx` (and
# `opset_version`); `source` / `force_onnx` are not among its parameters, so
# passing them fails with TypeError as soon as this module is imported.
# NOTE: no explicit device placement is done here.
vad_model = load_silero_vad(onnx=False)
def process_audio_with_vad_and_asr(audio_path, inference_module):
    """Split an audio file into speech segments with Silero VAD and transcribe each.

    The file is loaded once. A 16 kHz copy is handed to the VAD; every detected
    segment is then cut from the waveform at its original sample rate and
    passed to ``inference_module.transcribe``. Progress and the combined
    transcript are printed; nothing is returned.

    Args:
        audio_path (str): Path of the input audio file.
        inference_module: Object exposing ``transcribe(waveform=...)``; it is
            called once per segment with a slice of the loaded waveform.
    """
    print(f"正在加载音频: {audio_path}")

    # Load once and reuse for both VAD and segment extraction; re-reading the
    # file for every detected segment is loop-invariant disk I/O.
    original_waveform, original_sr = torchaudio.load(audio_path)

    # --- Prepare the VAD input (16 kHz) ---
    sample_rate_vad = 16000
    waveform_vad = original_waveform
    if waveform_vad.dim() > 1 and waveform_vad.size(0) > 1:
        # NOTE(review): Silero VAD is understood to reject tensors with more
        # than one channel - confirm. Only the VAD copy is down-mixed; the
        # segments given to transcribe() keep all channels.
        waveform_vad = waveform_vad.mean(dim=0, keepdim=True)
    if original_sr != sample_rate_vad:
        resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=sample_rate_vad)
        waveform_vad = resampler(waveform_vad)

    # --- Detect speech regions ---
    # Each entry is a dict with 'start' / 'end' sample offsets at sample_rate_vad.
    speech_timestamps = get_speech_timestamps(
        waveform_vad,
        vad_model,
        sampling_rate=sample_rate_vad,
        threshold=0.5,  # speech probability threshold
        min_speech_duration_ms=250,  # drop very short bursts
        max_speech_duration_s=float('inf'),  # no upper bound on segment length
        min_silence_duration_ms=100,  # silence gap that splits two segments
        window_size_samples=1536,  # VAD window size
        speech_pad_ms=30,  # padding added on both sides of a segment
    )

    print(f"检测到 {len(speech_timestamps)} 个语音片段。")

    # Factor that maps VAD sample offsets back onto the original waveform
    # (exactly 1.0 when no resampling happened).
    sr_ratio = original_sr / sample_rate_vad
    transcript_lines = []

    for i, ts in enumerate(speech_timestamps):
        start_sample = int(ts['start'])
        end_sample = int(ts['end'])

        # Segment boundaries in seconds, for logging and the transcript.
        start_time = start_sample / sample_rate_vad
        end_time = end_sample / sample_rate_vad

        print(f"\n处理第 {i+1} 个片段: 时间范围 [{start_time:.2f}s - {end_time:.2f}s]")

        # --- Cut the segment from the original-rate waveform ---
        start_idx_orig = int(start_sample * sr_ratio)
        end_idx_orig = int(end_sample * sr_ratio)
        segment_waveform = original_waveform[:, start_idx_orig:end_idx_orig]

        # --- Transcribe the segment ---
        # Best effort: a failing segment is reported and skipped so the
        # remaining segments are still processed.
        try:
            segment_text = inference_module.transcribe(waveform=segment_waveform)
        except Exception as e:
            print(f" -> 转录第 {i+1} 个片段时出错: {e}")
        else:
            print(f" -> 转录结果: {segment_text}")
            transcript_lines.append(f"[{start_time:.2f}-{end_time:.2f}s] {segment_text}\n")

    full_text = "".join(transcript_lines)

    print("\n--- 完整转录结果 ---")
    print(full_text)
# --- 使用示例 ---
# 假设您有一个名为 'inference' 的模块对象,它有 _load_audio 和 transcribe 方法
# audio_path = "your_audio_file.wav"
# process_audio_with_vad_and_asr(audio_path, inference)

File diff suppressed because one or more lines are too long

View File

@@ -35,7 +35,7 @@ def analyze_mp3_files(directory: Path, ):
"""分析目录中的所有MP3文件""" """分析目录中的所有MP3文件"""
results = [] results = []
total_duration = 0 total_duration = 0
df = pd.read_csv(workspace_dir / '.data/ug/train_new.tsv', sep='\t') df = pd.read_csv(workspace_dir / '.data/ug/val_new.tsv', sep='\t')
for _, row in tqdm(df.iterrows(), total=len(df), desc="anlayze audio"): for _, row in tqdm(df.iterrows(), total=len(df), desc="anlayze audio"):
mp3_file: Path = workspace_dir / '.data/ug/clips' / row['path'] mp3_file: Path = workspace_dir / '.data/ug/clips' / row['path']

File diff suppressed because one or more lines are too long

View File

@@ -3,7 +3,6 @@ from pathlib import Path
import re import re
import torch import torch
from torch import Tensor from torch import Tensor
from handle.text_handle import process_text from handle.text_handle import process_text
from handle.text_normalizer import UYGHUR_LETTERS from handle.text_normalizer import UYGHUR_LETTERS

View File

@@ -31,11 +31,11 @@ CONFIG = {
'input_dim': 256, 'input_dim': 256,
'num_heads': 8, 'num_heads': 8,
'ffn_dim': 2048, 'ffn_dim': 2048,
'num_layers': 8, 'num_layers': 6,
'dropout': 0.1, 'dropout': 0.15,
# 保存和评估 # 保存和评估
'early_stopping_patience': 12, 'early_stopping_patience': 10,
} }
@@ -168,17 +168,13 @@ def validate(model: ASRModel, dataloader: DataLoader, criterion: CTCLoss, device
avg_cer = total_cer / num_samples avg_cer = total_cer / num_samples
avg_wer = total_wer / num_samples avg_wer = total_wer / num_samples
writer.add_scalar('Val/Loss', avg_loss, global_step)
writer.add_scalar('Val/CER', avg_cer, global_step)
writer.add_scalar('Val/WER', avg_wer, global_step)
for index, (true_text, pred_text, cer, wer) in enumerate(examples): for index, (true_text, pred_text, cer, wer) in enumerate(examples):
writer.add_text(f'Val/Example_{index}', f'True: {true_text}\nPred: {pred_text}\nCER: {cer:.4f} | WER: {wer:.4f}', global_step) writer.add_text(f'Val/Example_{index}', f'True: {true_text}\nPred: {pred_text}\nCER: {cer:.4f} | WER: {wer:.4f}', global_step)
model.train() model.train()
return avg_loss, avg_cer, avg_wer return avg_loss, avg_cer, avg_wer
def save_checkpoint(model: ASRModel, optimizer: Optimizer, scheduler: OneCycleLR, global_step: int, epoch: int, train_loss: float, val_loss: float, cer: float, wer: float, save_path: Path): def save_checkpoint(model: ASRModel, optimizer: Optimizer, scheduler: OneCycleLR, global_step: int, epoch: int, train_loss: float, val_loss: float, cer: float, wer: float, current_prob: float, save_path: Path):
checkpoint = { checkpoint = {
'model_state_dict': model.state_dict(), 'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(), 'optimizer_state_dict': optimizer.state_dict(),
@@ -189,6 +185,7 @@ def save_checkpoint(model: ASRModel, optimizer: Optimizer, scheduler: OneCycleLR
'epoch': epoch, 'epoch': epoch,
'cer': cer, 'cer': cer,
'wer': wer, 'wer': wer,
'current_prob': current_prob,
} }
torch.save(checkpoint, save_path) torch.save(checkpoint, save_path)
@@ -206,14 +203,15 @@ def load_checkpoint(file_path: Path, model: ASRModel, optimizer: AdamW, schedule
global_step = checkpoint['global_step'] global_step = checkpoint['global_step']
best_cer = checkpoint['cer'] best_cer = checkpoint['cer']
best_wer = checkpoint['wer'] best_wer = checkpoint['wer']
return epoch, global_step, best_cer, best_wer current_prob = checkpoint['current_prob']
return epoch, global_step, best_cer, best_wer, current_prob
def main(): def main():
workspace_dir = Path(__file__).parent.parent workspace_dir = Path(__file__).parent.parent
device = torch.device('cuda:0') device = torch.device('cuda:0')
tokenizer = ASRTokenizer(workspace_dir / 'config/asr_vocab.json') tokenizer = ASRTokenizer(workspace_dir / 'config/asr_vocab.json')
final_prob = 0.5 final_prob = 0.8
warmup_epochs = 8 warmup_epochs = 12
current_prob = 0.0 current_prob = 0.0
# ============ 创建数据加载器 ============ # ============ 创建数据加载器 ============
@@ -221,6 +219,7 @@ def main():
tsv_path=workspace_dir / '.data/ug/train_new.tsv', tsv_path=workspace_dir / '.data/ug/train_new.tsv',
audio_dir=workspace_dir / '.data/ug/clips', audio_dir=workspace_dir / '.data/ug/clips',
noise_dir='/mnt/dataset/dataset/audio/noise', noise_dir='/mnt/dataset/dataset/audio/noise',
corridor_noise_dir= workspace_dir / 'data/corridor',
tokenizer=tokenizer, tokenizer=tokenizer,
batch_size=CONFIG['batch_size'], batch_size=CONFIG['batch_size'],
shuffle=True, shuffle=True,
@@ -228,16 +227,29 @@ def main():
augment_prob=current_prob augment_prob=current_prob
) )
val_loader = create_dataloader( val_loader_clean = create_dataloader(
tsv_path=workspace_dir / '.data/ug/val_new.tsv', tsv_path=workspace_dir / '.data/ug/val_new.tsv',
audio_dir=workspace_dir / '.data/ug/clips', audio_dir=workspace_dir / '.data/ug/clips',
noise_dir='/mnt/dataset/dataset/audio/noise', noise_dir='/mnt/dataset/dataset/audio/noise',
corridor_noise_dir= workspace_dir / 'data/corridor',
tokenizer=tokenizer, tokenizer=tokenizer,
batch_size=CONFIG['batch_size'], batch_size=CONFIG['batch_size'],
shuffle=False, shuffle=False,
augment=False augment=False
) )
val_loader_noisy = create_dataloader(
tsv_path=workspace_dir / '.data/ug/val_new.tsv',
audio_dir=workspace_dir / '.data/ug/clips',
noise_dir='/mnt/dataset/dataset/audio/noise',
corridor_noise_dir= workspace_dir / 'data/corridor',
tokenizer=tokenizer,
batch_size=CONFIG['batch_size'],
shuffle=False,
augment=True,
augment_prob=0.5
)
# ============ 初始化模型 ============ # ============ 初始化模型 ============
model = ASRModel( model = ASRModel(
vocab_size=tokenizer.vocab_size(), vocab_size=tokenizer.vocab_size(),
@@ -259,9 +271,9 @@ def main():
max_lr=CONFIG['learning_rate'], max_lr=CONFIG['learning_rate'],
epochs=CONFIG['num_epochs'], epochs=CONFIG['num_epochs'],
steps_per_epoch=len(train_loader), steps_per_epoch=len(train_loader),
pct_start=0.15, pct_start=0.2,
anneal_strategy='cos', anneal_strategy='cos',
div_factor=10.0, # 初始 lr = max_lr / 10 div_factor=25.0, # 初始 lr = max_lr / 25
final_div_factor=1e4, # 最终 lr = max_lr / 10000 final_div_factor=1e4, # 最终 lr = max_lr / 10000
) )
scaler = GradScaler() scaler = GradScaler()
@@ -286,7 +298,7 @@ def main():
latest = find_latest_checkpoint(checkpoint_dir) latest = find_latest_checkpoint(checkpoint_dir)
if latest: if latest:
start_epoch, global_step, best_cer, best_wer = load_checkpoint(latest, model, optimizer, scheduler) start_epoch, global_step, best_cer, best_wer, current_prob = load_checkpoint(latest, model, optimizer, scheduler)
start_epoch += 1 # 从下一个 epoch 开始 start_epoch += 1 # 从下一个 epoch 开始
print(f"✅ 恢复训练: 从 epoch={start_epoch} 开始, global_step={global_step}") print(f"✅ 恢复训练: 从 epoch={start_epoch} 开始, global_step={global_step}")
else: else:
@@ -308,25 +320,26 @@ def main():
scaler=scaler, scaler=scaler,
) )
val_loss, val_cer, val_wer = validate(model=model, dataloader=val_loader, criterion=criterion, device=device, tokenizer=tokenizer, writer=writer, global_step=global_step) val_loss_c, val_cer_c, val_wer_c = validate(model=model, dataloader=val_loader_clean, criterion=criterion, device=device, tokenizer=tokenizer, writer=writer, global_step=global_step)
print(f"\n📊 Step {global_step} | Val Loss: {val_loss:.4f} | Val CER: {val_cer:.4f} | Val WER: {val_wer:.4f}") val_loss_n, val_cer_n, val_wer_n = validate(model=model, dataloader=val_loader_noisy, criterion=criterion, device=device, tokenizer=tokenizer, writer=writer, global_step=global_step)
print(f"\n📊 Step {global_step} | Clean Val Loss: {val_loss_c:.4f} | Clean Val CER: {val_cer_c:.4f} | Clean Val WER: {val_wer_c:.4f}")
# 保存常规 checkpoint # 保存常规 checkpoint
checkpoint_path = checkpoint_dir / f'checkpoint_epoch_{epoch}.pt' checkpoint_path = checkpoint_dir / f'checkpoint_epoch_{epoch}.pt'
save_checkpoint(model=model, optimizer=optimizer, scheduler=scheduler, epoch=epoch, global_step=global_step, train_loss=train_loss, val_loss=val_loss, cer=val_cer, wer=val_wer, save_path=checkpoint_path) save_checkpoint(model=model, optimizer=optimizer, scheduler=scheduler, epoch=epoch, global_step=global_step, train_loss=train_loss, val_loss=val_loss_c, cer=val_cer_c, wer=val_wer_c, current_prob=current_prob, save_path=checkpoint_path)
# 分别保存最佳 CER 和 WER 模型 # 分别保存最佳 CER 和 WER 模型
improved = False improved = False
if val_cer < best_cer: if val_cer_c < best_cer:
best_cer = val_cer best_cer = val_cer_c
best_cer_path = checkpoint_dir / 'best_cer_model.pt' best_cer_path = checkpoint_dir / 'best_cer_model.pt'
save_checkpoint(model=model, optimizer=optimizer, scheduler=scheduler, epoch=epoch, global_step=global_step, train_loss=train_loss, val_loss=val_loss, cer=val_cer, wer=val_wer, save_path=best_cer_path) save_checkpoint(model=model, optimizer=optimizer, scheduler=scheduler, epoch=epoch, global_step=global_step, train_loss=train_loss, val_loss=val_loss_c, cer=val_cer_c, wer=val_wer_c, current_prob=current_prob, save_path=best_cer_path)
improved = True improved = True
if val_wer < best_wer: if val_wer_c < best_wer:
best_wer = val_wer best_wer = val_wer_c
best_wer_path = checkpoint_dir / 'best_wer_model.pt' best_wer_path = checkpoint_dir / 'best_wer_model.pt'
save_checkpoint(model=model, optimizer=optimizer, scheduler=scheduler, epoch=epoch, global_step=global_step, train_loss=train_loss, val_loss=val_loss, cer=val_cer, wer=val_wer, save_path=best_wer_path) save_checkpoint(model=model, optimizer=optimizer, scheduler=scheduler, epoch=epoch, global_step=global_step, train_loss=train_loss, val_loss=val_loss_c, cer=val_cer_c, wer=val_wer_c, current_prob=current_prob, save_path=best_wer_path)
improved = True improved = True
# Early Stopping 逻辑 # Early Stopping 逻辑
@@ -343,9 +356,16 @@ def main():
cuda.empty_cache() cuda.empty_cache()
writer.add_scalar('Val/EpochLoss', val_loss, epoch)
writer.add_scalar('Train/EpochLoss', train_loss, epoch) writer.add_scalar('Train/EpochLoss', train_loss, epoch)
print(f"✅ Epoch {epoch} 完成 | Train Avg Loss: {train_loss:.4f} | Val Avg Loss: {val_loss:.4f} | Best CER: {best_cer:.4f} | Best WER: {best_wer:.4f}\n") writer.add_scalar('Val/EpochLoss', val_loss_c, epoch)
writer.add_scalar('Val/Loss', val_loss_c, global_step)
writer.add_scalar('Val/CER', val_cer_c, global_step)
writer.add_scalar('Val/WER', val_wer_c, global_step)
writer.add_scalar('Val_Noisy/EpochLoss', val_loss_n, epoch)
writer.add_scalar('Val_Noisy/Loss', val_loss_n, global_step)
writer.add_scalar('Val_Noisy/CER', val_cer_n, global_step)
writer.add_scalar('Val_Noisy/WER', val_wer_n, global_step)
print(f"✅ Epoch {epoch} 完成 | Train Avg Loss: {train_loss:.4f} | Val Avg Loss: {val_loss_c:.4f} | Best CER: {best_cer:.4f} | Best WER: {best_wer:.4f}\n")
# Early Stopping 检查 # Early Stopping 检查
if patience_counter >= CONFIG['early_stopping_patience']: if patience_counter >= CONFIG['early_stopping_patience']: