Add noise augmentation and set num_layers=6
This commit is contained in:
143
src/dataset.py
143
src/dataset.py
@@ -4,7 +4,7 @@ import torch
|
||||
from torch import Tensor
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import torchaudio
|
||||
from torch_audiomentations import ApplyImpulseResponse, Gain, PitchShift, LowPassFilter, HighPassFilter
|
||||
from torch_audiomentations import ApplyImpulseResponse, Gain, PitchShift, LowPassFilter, HighPassFilter, PolarityInversion
|
||||
from torchaudio.transforms import Resample, TimeStretch
|
||||
from pathlib import Path
|
||||
import pandas as pd
|
||||
@@ -94,6 +94,7 @@ class CommonVoiceDataset(Dataset[BatchItem]):
|
||||
tsv_path: Path,
|
||||
audio_dir: Path,
|
||||
noise_dir: Path,
|
||||
corridor_noise_dir: Path,
|
||||
tokenizer: ASRTokenizer,
|
||||
sample_rate: int = 16000,
|
||||
max_audio_len: int = 480000, # 30秒 @ 16kHz
|
||||
@@ -123,9 +124,10 @@ class CommonVoiceDataset(Dataset[BatchItem]):
|
||||
self.gain_down = Gain(min_gain_in_db=-15, max_gain_in_db=-8, p=1.0, output_type='tensor')
|
||||
self.pitch_up = PitchShift(min_transpose_semitones=1, max_transpose_semitones=4, p=1.0, sample_rate=self.sample_rate, output_type='tensor')
|
||||
self.pitch_down = PitchShift(min_transpose_semitones=-4, max_transpose_semitones=-1, p=1.0, sample_rate=self.sample_rate, output_type='tensor')
|
||||
self.lowpass = LowPassFilter(min_cutoff_freq=600, max_cutoff_freq=2000, p=1.0, output_type='tensor')
|
||||
self.highpass = HighPassFilter(min_cutoff_freq=800, max_cutoff_freq=2000, p=1.0, output_type='tensor')
|
||||
self.apply_ir = ApplyImpulseResponse(ir_paths=noise_dir,convolve_mode='same', p=1, output_type="tensor")
|
||||
self.lowpass = LowPassFilter(min_cutoff_freq=100, max_cutoff_freq=2000, p=1.0, output_type='tensor')
|
||||
self.highpass = HighPassFilter(min_cutoff_freq=1000, max_cutoff_freq=2000, p=1.0, output_type='tensor')
|
||||
self.apply_ir = ApplyImpulseResponse(ir_paths=corridor_noise_dir, convolve_mode='same', p=1, output_type="tensor")
|
||||
self.polarity_inversion = PolarityInversion(p=1.0, output_type="tensor")
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
@@ -157,7 +159,7 @@ class CommonVoiceDataset(Dataset[BatchItem]):
|
||||
if random.random() < 0.6:
|
||||
waveform = self._stretch_or_compress(waveform=waveform)
|
||||
|
||||
if random.random() < 0.4:
|
||||
if random.random() < 0.7:
|
||||
waveform = self.noise_augmentor.apply_real_noise(waveform)
|
||||
|
||||
if random.random() < 0.3:
|
||||
@@ -166,29 +168,30 @@ class CommonVoiceDataset(Dataset[BatchItem]):
|
||||
# torch_audiomentations: [1, time] -> [1, 1, time]
|
||||
if waveform.dim() == 2:
|
||||
waveform_3d = waveform.unsqueeze(0)
|
||||
# 随机选择一种频谱增强
|
||||
# 随机选择一种物理特性增强 (互斥区)
|
||||
choice = random.random()
|
||||
if choice < 0.15:
|
||||
# 增益变化(上或下)
|
||||
if choice < 0.25: # [0.00 - 0.25] 25% 概率:增益
|
||||
if random.random() < 0.5:
|
||||
waveform_3d = self.gain_up(waveform_3d, sample_rate=self.sample_rate)
|
||||
else:
|
||||
waveform_3d = self.gain_down(waveform_3d, sample_rate=self.sample_rate)
|
||||
elif choice < 0.25:
|
||||
# 音高变化(上或下)
|
||||
elif choice < 0.50: # [0.25 - 0.50] 25% 概率:音高
|
||||
if random.random() < 0.5:
|
||||
waveform_3d = self.pitch_up(waveform_3d, sample_rate=self.sample_rate)
|
||||
else:
|
||||
waveform_3d = self.pitch_down(waveform_3d, sample_rate=self.sample_rate)
|
||||
elif choice < 0.30:
|
||||
# 低通滤波(声音发闷)
|
||||
elif choice < 0.70: # [0.50 - 0.70] 20% 概率:低通
|
||||
waveform_3d = self.lowpass(waveform_3d, sample_rate=self.sample_rate)
|
||||
# elif choice < 0.32:
|
||||
# # 低通滤波(声音发闷)
|
||||
# waveform_3d = self.apply_ir(waveform_3d, sample_rate=self.sample_rate)
|
||||
elif choice < 0.35:
|
||||
# 高通滤波(电话效果)
|
||||
elif choice < 0.85: # [0.70 - 0.85] 15% 概率:高通
|
||||
waveform_3d = self.highpass(waveform_3d, sample_rate=self.sample_rate)
|
||||
elif choice < 0.95: # [0.85 - 0.95] 10% 概率:走廊混响 (IR)
|
||||
# 使用你测试过最好的 0.8/0.2 比例
|
||||
dry = waveform_3d.clone()
|
||||
wet = self.apply_ir(waveform_3d, sample_rate=self.sample_rate)
|
||||
waveform_3d = 0.8 * dry + 0.2 * wet
|
||||
else: # [0.95 - 1.00] 5% 概率:极性翻转
|
||||
waveform_3d = self.polarity_inversion(waveform_3d, sample_rate=self.sample_rate)
|
||||
|
||||
|
||||
# [1, 1, time] -> [1, time]
|
||||
waveform = waveform_3d.squeeze(0)
|
||||
@@ -201,7 +204,7 @@ class CommonVoiceDataset(Dataset[BatchItem]):
|
||||
return waveform
|
||||
|
||||
def _stretch_or_compress(self, waveform: Tensor) -> Tensor:
|
||||
speed_factor = random.uniform(0.85, 1.4) # (Speed Change: 0.85x - 1.4x)
|
||||
speed_factor = random.uniform(0.80, 1.6) # (Speed Change: 0.80x - 1.6x)
|
||||
spec = torch.stft(
|
||||
waveform.squeeze(0),
|
||||
n_fft=400,
|
||||
@@ -246,7 +249,6 @@ class CommonVoiceDataset(Dataset[BatchItem]):
|
||||
def __getitem__(self, index) -> BatchItem:
|
||||
row: TsvFormat = self.data.iloc[index]
|
||||
audio_path: Path = self.audio_dir / row['path']
|
||||
# text: str = unicodedata.normalize('NFC', row['sentence'].strip())
|
||||
text: str = normalize_extended_uyghur_characters(collapse_spaces(row['sentence'].strip()))
|
||||
|
||||
waveform = self._load_audio(audio_path=audio_path)
|
||||
@@ -297,8 +299,8 @@ def collate_fn(items: List[BatchItem]) -> Batch:
|
||||
)
|
||||
|
||||
|
||||
def create_dataloader(tsv_path: Path, audio_dir: Path, noise_dir: Path, tokenizer: ASRTokenizer, batch_size: int = 8, shuffle: bool = True, augment: bool = True, augment_prob: int = 0.5) -> DataLoader:
|
||||
dataset = CommonVoiceDataset(tsv_path=tsv_path, audio_dir=audio_dir, noise_dir=noise_dir, tokenizer=tokenizer, augment=augment, augment_prob=augment_prob)
|
||||
def create_dataloader(
    tsv_path: Path,
    audio_dir: Path,
    noise_dir: Path,
    corridor_noise_dir: Path,
    tokenizer: ASRTokenizer,
    batch_size: int = 8,
    shuffle: bool = True,
    augment: bool = True,
    augment_prob: float = 0.5,
) -> DataLoader:
    """Build a DataLoader over a CommonVoiceDataset.

    Args:
        tsv_path: Forwarded to ``CommonVoiceDataset`` as ``tsv_path``.
        audio_dir: Forwarded to ``CommonVoiceDataset`` as ``audio_dir``.
        noise_dir: Forwarded to ``CommonVoiceDataset`` as ``noise_dir``.
        corridor_noise_dir: Forwarded to ``CommonVoiceDataset`` as
            ``corridor_noise_dir``.
        tokenizer: Forwarded to ``CommonVoiceDataset`` as ``tokenizer``.
        batch_size: Number of items per batch.
        shuffle: Whether the loader shuffles the dataset.
        augment: Forwarded to ``CommonVoiceDataset`` as ``augment``.
        augment_prob: Forwarded to ``CommonVoiceDataset`` as ``augment_prob``;
            a fractional value (default 0.5), hence annotated as ``float``.

    Returns:
        A ``DataLoader`` that batches with ``collate_fn`` and uses pinned
        memory, 8 persistent worker processes and a prefetch factor of 8.
    """
    dataset = CommonVoiceDataset(
        tsv_path=tsv_path,
        audio_dir=audio_dir,
        noise_dir=noise_dir,
        corridor_noise_dir=corridor_noise_dir,
        tokenizer=tokenizer,
        augment=augment,
        augment_prob=augment_prob,
    )

    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn,
        pin_memory=True,
        num_workers=8,
        prefetch_factor=8,
        persistent_workers=True,
    )
|
||||
|
||||
@@ -331,101 +333,4 @@ if __name__ == "__main__":
|
||||
print(f"Target lengths: {batch['target_lengths']}")
|
||||
print(f"Target texts: {batch['target_texts']}")
|
||||
print(f"Audio paths: {batch['audio_paths']}")
|
||||
break
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
from silero_vad import load_silero_vad, read_audio, get_speech_timestamps
|
||||
|
||||
# --- 1. 加载 VAD 模型 ---
|
||||
# 推荐使用 GPU (cuda) 如果可用,否则使用 CPU (cpu)
|
||||
vad_model = load_silero_vad(source="local", force_onnx=False)
|
||||
|
||||
def process_audio_with_vad_and_asr(audio_path, inference_module):
    """
    Split an audio file into speech segments with Silero VAD and transcribe each one.

    The file is decoded once. A 16 kHz copy is used for voice-activity
    detection; every detected segment is then cut from the waveform at its
    original sample rate and passed to ``inference_module.transcribe``.
    Per-segment results and the combined transcript are printed; nothing is
    returned.

    Args:
        audio_path (str): Path of the input audio file.
        inference_module: Object exposing ``transcribe(waveform=...)``.
    """
    print(f"正在加载音频: {audio_path}")

    # --- Load the audio once; it is reused for both VAD and segment extraction ---
    original_waveform, original_sr = torchaudio.load(audio_path)

    # --- Prepare a 16 kHz copy for VAD ---
    waveform_vad = original_waveform
    sample_rate_vad = original_sr
    if sample_rate_vad != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate_vad, new_freq=16000)
        waveform_vad = resampler(waveform_vad)
        sample_rate_vad = 16000

    # --- Detect speech regions ---
    # Each entry is a dict with 'start' and 'end' expressed in samples of the
    # VAD waveform (they are divided by sample_rate_vad below to get seconds).
    # NOTE(review): the tensor is passed exactly as loaded (channel dimension
    # included) -- confirm get_speech_timestamps accepts that for
    # multi-channel files.
    speech_timestamps = get_speech_timestamps(
        waveform_vad,
        vad_model,
        sampling_rate=sample_rate_vad,
        threshold=0.5,  # detection threshold; tune as needed
        min_speech_duration_ms=250,  # shortest span kept as speech
        max_speech_duration_s=float('inf'),  # no upper limit on segment length
        min_silence_duration_ms=100,  # shortest silence that splits segments
        window_size_samples=1536,  # VAD window size
        speech_pad_ms=30  # padding added around each speech chunk
    )

    print(f"检测到 {len(speech_timestamps)} 个语音片段。")

    # Factor that maps VAD sample indices back onto the original waveform;
    # it is exactly 1.0 when no resampling took place.
    index_scale = original_sr / sample_rate_vad

    transcript_lines = []
    for i, ts in enumerate(speech_timestamps):
        start_sample = int(ts['start'])
        end_sample = int(ts['end'])

        # Segment boundaries in seconds
        start_time = start_sample / sample_rate_vad
        end_time = end_sample / sample_rate_vad

        print(f"\n处理第 {i+1} 个片段: 时间范围 [{start_time:.2f}s - {end_time:.2f}s]")

        # --- Cut this segment out of the original-rate waveform ---
        start_idx_orig = int(start_sample * index_scale)
        end_idx_orig = int(end_sample * index_scale)
        segment_waveform = original_waveform[:, start_idx_orig:end_idx_orig]

        # --- Transcribe the segment ---
        # Best effort: a failure on one segment is reported and the remaining
        # segments are still processed.
        try:
            segment_text = inference_module.transcribe(waveform=segment_waveform)
        except Exception as e:
            print(f" -> 转录第 {i+1} 个片段时出错: {e}")
        else:
            print(f" -> 转录结果: {segment_text}")
            transcript_lines.append(f"[{start_time:.2f}-{end_time:.2f}s] {segment_text}\n")

    full_text = "".join(transcript_lines)

    print("\n--- 完整转录结果 ---")
    print(full_text)
|
||||
|
||||
|
||||
# --- 使用示例 ---
|
||||
# 假设您有一个名为 'inference' 的模块对象,它有 _load_audio 和 transcribe 方法
|
||||
# audio_path = "your_audio_file.wav"
|
||||
# process_audio_with_vad_and_asr(audio_path, inference)
|
||||
break
|
||||
File diff suppressed because one or more lines are too long
@@ -35,7 +35,7 @@ def analyze_mp3_files(directory: Path, ):
|
||||
"""分析目录中的所有MP3文件"""
|
||||
results = []
|
||||
total_duration = 0
|
||||
df = pd.read_csv(workspace_dir / '.data/ug/train_new.tsv', sep='\t')
|
||||
df = pd.read_csv(workspace_dir / '.data/ug/val_new.tsv', sep='\t')
|
||||
|
||||
for _, row in tqdm(df.iterrows(), total=len(df), desc="anlayze audio"):
|
||||
mp3_file: Path = workspace_dir / '.data/ug/clips' / row['path']
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -3,7 +3,6 @@ from pathlib import Path
|
||||
import re
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from handle.text_handle import process_text
|
||||
from handle.text_normalizer import UYGHUR_LETTERS
|
||||
|
||||
|
||||
72
src/train.py
72
src/train.py
@@ -31,11 +31,11 @@ CONFIG = {
|
||||
'input_dim': 256,
|
||||
'num_heads': 8,
|
||||
'ffn_dim': 2048,
|
||||
'num_layers': 8,
|
||||
'dropout': 0.1,
|
||||
'num_layers': 6,
|
||||
'dropout': 0.15,
|
||||
|
||||
# 保存和评估
|
||||
'early_stopping_patience': 12,
|
||||
'early_stopping_patience': 10,
|
||||
}
|
||||
|
||||
|
||||
@@ -168,17 +168,13 @@ def validate(model: ASRModel, dataloader: DataLoader, criterion: CTCLoss, device
|
||||
avg_cer = total_cer / num_samples
|
||||
avg_wer = total_wer / num_samples
|
||||
|
||||
writer.add_scalar('Val/Loss', avg_loss, global_step)
|
||||
writer.add_scalar('Val/CER', avg_cer, global_step)
|
||||
writer.add_scalar('Val/WER', avg_wer, global_step)
|
||||
|
||||
for index, (true_text, pred_text, cer, wer) in enumerate(examples):
|
||||
writer.add_text(f'Val/Example_{index}', f'True: {true_text}\nPred: {pred_text}\nCER: {cer:.4f} | WER: {wer:.4f}', global_step)
|
||||
|
||||
model.train()
|
||||
return avg_loss, avg_cer, avg_wer
|
||||
|
||||
def save_checkpoint(model: ASRModel, optimizer: Optimizer, scheduler: OneCycleLR, global_step: int, epoch: int, train_loss: float, val_loss: float, cer: float, wer: float, save_path: Path):
|
||||
def save_checkpoint(model: ASRModel, optimizer: Optimizer, scheduler: OneCycleLR, global_step: int, epoch: int, train_loss: float, val_loss: float, cer: float, wer: float, current_prob: float, save_path: Path):
|
||||
checkpoint = {
|
||||
'model_state_dict': model.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
@@ -189,6 +185,7 @@ def save_checkpoint(model: ASRModel, optimizer: Optimizer, scheduler: OneCycleLR
|
||||
'epoch': epoch,
|
||||
'cer': cer,
|
||||
'wer': wer,
|
||||
'current_prob': current_prob,
|
||||
}
|
||||
torch.save(checkpoint, save_path)
|
||||
|
||||
@@ -206,14 +203,15 @@ def load_checkpoint(file_path: Path, model: ASRModel, optimizer: AdamW, schedule
|
||||
global_step = checkpoint['global_step']
|
||||
best_cer = checkpoint['cer']
|
||||
best_wer = checkpoint['wer']
|
||||
return epoch, global_step, best_cer, best_wer
|
||||
current_prob = checkpoint['current_prob']
|
||||
return epoch, global_step, best_cer, best_wer, current_prob
|
||||
|
||||
def main():
|
||||
workspace_dir = Path(__file__).parent.parent
|
||||
device = torch.device('cuda:0')
|
||||
tokenizer = ASRTokenizer(workspace_dir / 'config/asr_vocab.json')
|
||||
final_prob = 0.5
|
||||
warmup_epochs = 8
|
||||
final_prob = 0.8
|
||||
warmup_epochs = 12
|
||||
current_prob = 0.0
|
||||
|
||||
# ============ 创建数据加载器 ============
|
||||
@@ -221,6 +219,7 @@ def main():
|
||||
tsv_path=workspace_dir / '.data/ug/train_new.tsv',
|
||||
audio_dir=workspace_dir / '.data/ug/clips',
|
||||
noise_dir='/mnt/dataset/dataset/audio/noise',
|
||||
corridor_noise_dir= workspace_dir / 'data/corridor',
|
||||
tokenizer=tokenizer,
|
||||
batch_size=CONFIG['batch_size'],
|
||||
shuffle=True,
|
||||
@@ -228,15 +227,28 @@ def main():
|
||||
augment_prob=current_prob
|
||||
)
|
||||
|
||||
val_loader = create_dataloader(
|
||||
val_loader_clean = create_dataloader(
|
||||
tsv_path=workspace_dir / '.data/ug/val_new.tsv',
|
||||
audio_dir=workspace_dir / '.data/ug/clips',
|
||||
noise_dir='/mnt/dataset/dataset/audio/noise',
|
||||
corridor_noise_dir= workspace_dir / 'data/corridor',
|
||||
tokenizer=tokenizer,
|
||||
batch_size=CONFIG['batch_size'],
|
||||
shuffle=False,
|
||||
augment=False
|
||||
)
|
||||
|
||||
val_loader_noisy = create_dataloader(
|
||||
tsv_path=workspace_dir / '.data/ug/val_new.tsv',
|
||||
audio_dir=workspace_dir / '.data/ug/clips',
|
||||
noise_dir='/mnt/dataset/dataset/audio/noise',
|
||||
corridor_noise_dir= workspace_dir / 'data/corridor',
|
||||
tokenizer=tokenizer,
|
||||
batch_size=CONFIG['batch_size'],
|
||||
shuffle=False,
|
||||
augment=True,
|
||||
augment_prob=0.5
|
||||
)
|
||||
|
||||
# ============ 初始化模型 ============
|
||||
model = ASRModel(
|
||||
@@ -259,9 +271,9 @@ def main():
|
||||
max_lr=CONFIG['learning_rate'],
|
||||
epochs=CONFIG['num_epochs'],
|
||||
steps_per_epoch=len(train_loader),
|
||||
pct_start=0.15,
|
||||
pct_start=0.2,
|
||||
anneal_strategy='cos',
|
||||
div_factor=10.0, # 初始 lr = max_lr / 10
|
||||
div_factor=25.0, # initial lr = max_lr / 25
|
||||
final_div_factor=1e4, # 最终 lr = max_lr / 10000
|
||||
)
|
||||
scaler = GradScaler()
|
||||
@@ -286,7 +298,7 @@ def main():
|
||||
|
||||
latest = find_latest_checkpoint(checkpoint_dir)
|
||||
if latest:
|
||||
start_epoch, global_step, best_cer, best_wer = load_checkpoint(latest, model, optimizer, scheduler)
|
||||
start_epoch, global_step, best_cer, best_wer, current_prob = load_checkpoint(latest, model, optimizer, scheduler)
|
||||
start_epoch += 1 # 从下一个 epoch 开始
|
||||
print(f"✅ 恢复训练: 从 epoch={start_epoch} 开始, global_step={global_step}")
|
||||
else:
|
||||
@@ -308,25 +320,26 @@ def main():
|
||||
scaler=scaler,
|
||||
)
|
||||
|
||||
val_loss, val_cer, val_wer = validate(model=model, dataloader=val_loader, criterion=criterion, device=device, tokenizer=tokenizer, writer=writer, global_step=global_step)
|
||||
print(f"\n📊 Step {global_step} | Val Loss: {val_loss:.4f} | Val CER: {val_cer:.4f} | Val WER: {val_wer:.4f}")
|
||||
val_loss_c, val_cer_c, val_wer_c = validate(model=model, dataloader=val_loader_clean, criterion=criterion, device=device, tokenizer=tokenizer, writer=writer, global_step=global_step)
|
||||
val_loss_n, val_cer_n, val_wer_n = validate(model=model, dataloader=val_loader_noisy, criterion=criterion, device=device, tokenizer=tokenizer, writer=writer, global_step=global_step)
|
||||
print(f"\n📊 Step {global_step} | Clean Val Loss: {val_loss_c:.4f} | Clean Val CER: {val_cer_c:.4f} | Clean Val WER: {val_wer_c:.4f}")
|
||||
|
||||
# 保存常规 checkpoint
|
||||
checkpoint_path = checkpoint_dir / f'checkpoint_epoch_{epoch}.pt'
|
||||
save_checkpoint(model=model, optimizer=optimizer, scheduler=scheduler, epoch=epoch, global_step=global_step, train_loss=train_loss, val_loss=val_loss, cer=val_cer, wer=val_wer, save_path=checkpoint_path)
|
||||
save_checkpoint(model=model, optimizer=optimizer, scheduler=scheduler, epoch=epoch, global_step=global_step, train_loss=train_loss, val_loss=val_loss_c, cer=val_cer_c, wer=val_wer_c, current_prob=current_prob, save_path=checkpoint_path)
|
||||
|
||||
# 分别保存最佳 CER 和 WER 模型
|
||||
improved = False
|
||||
if val_cer < best_cer:
|
||||
best_cer = val_cer
|
||||
if val_cer_c < best_cer:
|
||||
best_cer = val_cer_c
|
||||
best_cer_path = checkpoint_dir / 'best_cer_model.pt'
|
||||
save_checkpoint(model=model, optimizer=optimizer, scheduler=scheduler, epoch=epoch, global_step=global_step, train_loss=train_loss, val_loss=val_loss, cer=val_cer, wer=val_wer, save_path=best_cer_path)
|
||||
save_checkpoint(model=model, optimizer=optimizer, scheduler=scheduler, epoch=epoch, global_step=global_step, train_loss=train_loss, val_loss=val_loss_c, cer=val_cer_c, wer=val_wer_c, current_prob=current_prob, save_path=best_cer_path)
|
||||
improved = True
|
||||
|
||||
if val_wer < best_wer:
|
||||
best_wer = val_wer
|
||||
if val_wer_c < best_wer:
|
||||
best_wer = val_wer_c
|
||||
best_wer_path = checkpoint_dir / 'best_wer_model.pt'
|
||||
save_checkpoint(model=model, optimizer=optimizer, scheduler=scheduler, epoch=epoch, global_step=global_step, train_loss=train_loss, val_loss=val_loss, cer=val_cer, wer=val_wer, save_path=best_wer_path)
|
||||
save_checkpoint(model=model, optimizer=optimizer, scheduler=scheduler, epoch=epoch, global_step=global_step, train_loss=train_loss, val_loss=val_loss_c, cer=val_cer_c, wer=val_wer_c, current_prob=current_prob, save_path=best_wer_path)
|
||||
improved = True
|
||||
|
||||
# Early Stopping 逻辑
|
||||
@@ -343,9 +356,16 @@ def main():
|
||||
|
||||
cuda.empty_cache()
|
||||
|
||||
writer.add_scalar('Val/EpochLoss', val_loss, epoch)
|
||||
writer.add_scalar('Train/EpochLoss', train_loss, epoch)
|
||||
print(f"✅ Epoch {epoch} 完成 | Train Avg Loss: {train_loss:.4f} | Val Avg Loss: {val_loss:.4f} | Best CER: {best_cer:.4f} | Best WER: {best_wer:.4f}\n")
|
||||
writer.add_scalar('Val/EpochLoss', val_loss_c, epoch)
|
||||
writer.add_scalar('Val/Loss', val_loss_c, global_step)
|
||||
writer.add_scalar('Val/CER', val_cer_c, global_step)
|
||||
writer.add_scalar('Val/WER', val_wer_c, global_step)
|
||||
writer.add_scalar('Val_Noisy/EpochLoss', val_loss_n, epoch)
|
||||
writer.add_scalar('Val_Noisy/Loss', val_loss_n, global_step)
|
||||
writer.add_scalar('Val_Noisy/CER', val_cer_n, global_step)
|
||||
writer.add_scalar('Val_Noisy/WER', val_wer_n, global_step)
|
||||
print(f"✅ Epoch {epoch} 完成 | Train Avg Loss: {train_loss:.4f} | Val Avg Loss: {val_loss_c:.4f} | Best CER: {best_cer:.4f} | Best WER: {best_wer:.4f}\n")
|
||||
|
||||
# Early Stopping 检查
|
||||
if patience_counter >= CONFIG['early_stopping_patience']:
|
||||
|
||||
Reference in New Issue
Block a user