change model and MelSpectrogram parameter

This commit is contained in:
2026-05-01 10:20:30 +06:00
parent 39e6270d6e
commit 78519a42b8
2 changed files with 7 additions and 3 deletions

View File

@@ -48,7 +48,7 @@ class CommonVoiceDataset(Dataset[BatchItem]):
audio_dir: Path, audio_dir: Path,
tokenizer: ASRTokenizer, tokenizer: ASRTokenizer,
sample_rate: int = 16000, sample_rate: int = 16000,
n_mels: int = 80, n_mels: int = 80 * 4,
max_audio_len: int = 480000, # 30秒 @ 16kHz max_audio_len: int = 480000, # 30秒 @ 16kHz
augment: bool = True, # 是否启用数据增强 augment: bool = True, # 是否启用数据增强
augment_prob: float = 0.5, # 数据增强的概率 augment_prob: float = 0.5, # 数据增强的概率
@@ -77,11 +77,11 @@ class CommonVoiceDataset(Dataset[BatchItem]):
sample_rate=sample_rate, sample_rate=sample_rate,
n_fft=400, n_fft=400,
win_length=400, win_length=400,
hop_length=80, hop_length=160,
n_mels=n_mels, n_mels=n_mels,
f_min=0, f_min=0,
f_max=8000, f_max=8000,
power=2.0 power=3.0
) )
self.amplitude_to_db = AmplitudeToDB() self.amplitude_to_db = AmplitudeToDB()

View File

@@ -9,6 +9,8 @@ class ASRModel(Module):
self.conv1 = Conv2d(in_channels=1, out_channels=16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) self.conv1 = Conv2d(in_channels=1, out_channels=16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
self.conv2 = Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) self.conv2 = Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
self.conv3 = Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 3), stride=(2, 1), padding=(1, 1))
self.proj = Linear(in_features=1280, out_features=640)
self.relu = ReLU() self.relu = ReLU()
self.encoder = Conformer(input_dim=input_dim, num_heads=num_heads, ffn_dim=ffn_dim, num_layers=num_layers, depthwise_conv_kernel_size=31, dropout=dropout) self.encoder = Conformer(input_dim=input_dim, num_heads=num_heads, ffn_dim=ffn_dim, num_layers=num_layers, depthwise_conv_kernel_size=31, dropout=dropout)
@@ -21,10 +23,12 @@ class ASRModel(Module):
x = self.relu(self.conv1(x)) # [batch, 16, n_mels/2, time/2] x = self.relu(self.conv1(x)) # [batch, 16, n_mels/2, time/2]
x = self.relu(self.conv2(x)) # [batch, 32, n_mels/4, time/4] x = self.relu(self.conv2(x)) # [batch, 32, n_mels/4, time/4]
x = self.relu(self.conv3(x)) # [batch, 32, n_mels/8, time/4]
# [B, channels, freq, time] → [B, time, channels*freq] # [B, channels, freq, time] → [B, time, channels*freq]
batch, channels, freq, time = x.shape batch, channels, freq, time = x.shape
x = x.permute(0, 3, 1, 2).reshape(batch, time, channels * freq) x = x.permute(0, 3, 1, 2).reshape(batch, time, channels * freq)
x = self.proj(x)
# lengths = torch.tensor([time] * batch, dtype=torch.long, device=x.device) # lengths = torch.tensor([time] * batch, dtype=torch.long, device=x.device)
lengths = ((mel_lengths + 1) // 2 + 1) // 2 # 两层 stride=2 lengths = ((mel_lengths + 1) // 2 + 1) // 2 # 两层 stride=2