first commit asr model
This commit is contained in:
16
.gitignore
vendored
Normal file
16
.gitignore
vendored
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
# Python-generated files
|
||||||
|
__pycache__/
|
||||||
|
*.py[oc]
|
||||||
|
build/
|
||||||
|
dist/
|
||||||
|
wheels/
|
||||||
|
*.egg-info
|
||||||
|
|
||||||
|
# Virtual environments
|
||||||
|
.venv
|
||||||
|
|
||||||
|
# data
|
||||||
|
.data
|
||||||
|
data
|
||||||
|
runs
|
||||||
|
config
|
||||||
1
.python-version
Normal file
1
.python-version
Normal file
@@ -0,0 +1 @@
|
|||||||
|
3.11
|
||||||
28
pyproject.toml
Normal file
28
pyproject.toml
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
[project]
|
||||||
|
name = "study-asr"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Add your description here"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.11"
|
||||||
|
dependencies = [
|
||||||
|
"ipython>=9.10.1",
|
||||||
|
"librosa>=0.11.0",
|
||||||
|
"matplotlib>=3.10.8",
|
||||||
|
"numpy>=2.4.4",
|
||||||
|
"pandas>=3.0.2",
|
||||||
|
"pillow>=12.2.0",
|
||||||
|
"pyrubberband>=0.4.0",
|
||||||
|
"setuptools<82",
|
||||||
|
"tensorboard>=2.20.0",
|
||||||
|
"tensorboardx>=2.6.5",
|
||||||
|
"torch==2.8.0",
|
||||||
|
"torchaudio==2.8.0",
|
||||||
|
"torchcodec==0.7.0",
|
||||||
|
"tqdm>=4.67.3",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Package index used for dependency resolution.
# In pyproject.toml uv reads indexes from [[tool.uv.index]]; the bare
# [[index]] table is only recognised in uv.toml and is ignored here.
[[tool.uv.index]]
url = "https://mirrors.aliyun.com/pypi/simple/"
default = true
# url = "https://pypi.tuna.tsinghua.edu.cn/simple/"
|
||||||
|
|
||||||
286
src/dataset.py
Normal file
286
src/dataset.py
Normal file
@@ -0,0 +1,286 @@
|
|||||||
|
import random
|
||||||
|
import unicodedata
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
import torchaudio
|
||||||
|
from torchaudio.transforms import FrequencyMasking, MelSpectrogram, AmplitudeToDB, Resample, TimeMasking, TimeStretch
|
||||||
|
from pathlib import Path
|
||||||
|
import torchaudio.functional as F
|
||||||
|
import pandas as pd
|
||||||
|
from typing import Dict, List, TypedDict
|
||||||
|
from handle.text_normalizer import collapse_spaces, normalize_extended_uyghur_characters
|
||||||
|
from tokenizer import ASRTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
# 单个样本的数据结构(Dataset.__getitem__ 返回)
|
||||||
|
class BatchItem(TypedDict):
    """A single sample, as returned by ``Dataset.__getitem__``."""

    mel_spec: Tensor  # [n_mels, time] Mel spectrogram
    target_ids: Tensor  # [seq_len] token IDs of the target text
    target_text: str  # target text
    audio_path: str  # path of the audio file
|
||||||
|
|
||||||
|
|
||||||
|
# 批量数据的数据结构(collate_fn 返回,DataLoader 输出)
|
||||||
|
class Batch(TypedDict):
    """A collated batch, as returned by ``collate_fn`` and yielded by the DataLoader."""

    mel_specs: Tensor  # [batch, n_mels, time] Mel spectrograms after padding
    targets: Tensor  # [batch, max_len] target IDs after padding
    mel_lengths: Tensor  # [batch] unpadded Mel length of every sample
    target_lengths: Tensor  # [batch] unpadded target length of every sample
    target_texts: List[str]  # [batch] target texts
    audio_paths: List[str]  # [batch] audio file paths
|
||||||
|
|
||||||
|
class TsvFormat(TypedDict):
    """Columns of one row of the TSV manifest.

    Only ``path`` and ``sentence`` are read by the dataset code in this file.
    """

    client_id: str
    path: str  # audio file name, joined onto the clips directory
    sentence: str  # transcript of the clip
    up_votes: int
    down_votes: int
    age: str
    gender: str
    locale: str
|
||||||
|
|
||||||
|
class CommonVoiceDataset(Dataset[BatchItem]):
    """Dataset pairing audio clips from a TSV manifest with tokenized transcripts.

    The manifest is read with pandas; rows whose audio file does not exist
    under ``audio_dir`` are dropped when the dataset is constructed.
    """

    def __init__(
        self,
        tsv_path: Path,
        audio_dir: Path,
        tokenizer: ASRTokenizer,
        sample_rate: int = 16000,
        n_mels: int = 80,
        max_audio_len: int = 480000,  # 30 s @ 16 kHz
        augment: bool = True,  # whether waveform augmentation is enabled
        augment_prob: float = 0.5,  # probability of augmenting a sample
    ) -> None:
        """Read the manifest, drop rows without audio and build the transforms.

        Args:
            tsv_path: Tab-separated manifest with at least ``path`` and
                ``sentence`` columns.
            audio_dir: Directory that the ``path`` column is relative to.
            tokenizer: Object whose ``encode(text)`` result becomes the
                target IDs.
            sample_rate: Rate every clip is resampled to.
            n_mels: Number of Mel bins.
            max_audio_len: Clips are truncated to this many samples.
            augment: Gate used by ``_augment_waveform``.
            augment_prob: Probability used by ``_augment_waveform``.
        """
        super().__init__()

        self.audio_dir = Path(audio_dir)
        self.tokenizer = tokenizer
        self.sample_rate = sample_rate
        self.max_audio_len = max_audio_len
        self.augment = augment
        self.augment_prob = augment_prob

        self.data: pd.DataFrame = pd.read_csv(tsv_path, sep='\t')

        # Keep only the rows whose clip is actually present on disk.
        exists = [(self.audio_dir / rel_path).exists() for rel_path in self.data['path']]
        self.data = self.data.loc[exists].reset_index(drop=True)

        # Mel spectrogram conversion
        self.mel_transform = MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=400,
            win_length=400,
            hop_length=80,
            n_mels=n_mels,
            f_min=0,
            f_max=8000,
            power=2.0
        )
        self.amplitude_to_db = AmplitudeToDB()

        # SpecAugment transforms (constructed here; not applied anywhere in this class)
        self.time_masking = TimeMasking(time_mask_param=30)  # mask up to 30 frames
        self.freq_masking = FrequencyMasking(freq_mask_param=15)  # mask up to 15 frequency bins

    def __len__(self) -> int:
        """Return the number of manifest rows that have an audio file."""
        return len(self.data)

    def _load_audio(self, audio_path: Path) -> Tensor:
        """Load a clip, then resample, downmix, peak-normalize and truncate it.

        Returns:
            The waveform with a single channel and at most
            ``self.max_audio_len`` samples.
        """
        waveform, sample_rate = torchaudio.load_with_torchcodec(audio_path)

        if sample_rate != self.sample_rate:
            waveform = Resample(sample_rate, self.sample_rate)(waveform)

        # Downmix multi-channel audio by averaging the channels.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Peak-normalize; an all-zero clip is left untouched to avoid 0/0.
        max_val = waveform.abs().max()
        if max_val > 0:
            waveform = waveform / max_val

        if waveform.shape[1] > self.max_audio_len:
            waveform = waveform[:, :self.max_audio_len]

        return waveform

    def _extract_features(self, waveform: Tensor) -> Tensor:
        """Extract the Mel spectrogram in dB and drop the leading dimension."""
        mel_spec: Tensor = self.mel_transform(waveform)
        log_mel_spec: Tensor = self.amplitude_to_db(mel_spec)
        return log_mel_spec.squeeze(0)  # [n_mels, time]

    def _augment_waveform(self, waveform: Tensor) -> Tensor:
        """Randomly apply time stretching, sample dropping and additive noise.

        Nothing is applied when ``self.augment`` is false or when a random
        draw exceeds ``self.augment_prob``. Otherwise each augmentation is
        applied independently with probability 0.5, 0.3 and 0.4 respectively.
        """
        if not self.augment or random.random() > self.augment_prob:
            return waveform

        # 1. Stretch / compress in time
        if random.random() < 0.5:
            waveform = self._voice_stretch_or_compress(waveform=waveform)

        # 2. Remove a contiguous chunk of samples
        if random.random() < 0.3:
            waveform = self._drop_frames(waveform)

        # 3. Add Gaussian noise
        if random.random() < 0.4:
            waveform = self._add_noise(waveform)

        return waveform

    def _voice_stretch_or_compress(self, waveform: Tensor) -> Tensor:
        """Change the speed of the waveform by a random factor in [0.6, 1.4].

        The waveform is converted to a complex STFT, stretched with
        ``TimeStretch`` and converted back with the inverse STFT.
        """
        speed_factor = random.uniform(0.6, 1.4)  # speed change: 0.6x - 1.4x

        # One window shared by the forward and inverse transforms.
        window = torch.hann_window(400).to(waveform.device)

        spec = torch.stft(
            waveform.squeeze(0),
            n_fft=400,
            hop_length=160,
            window=window,
            return_complex=True
        )

        # Time stretch (does not change pitch)
        stretch = TimeStretch(
            hop_length=160,
            n_freq=201,
            fixed_rate=speed_factor
        )
        stretched_spec = stretch(spec)

        # Back to a waveform
        waveform_stretched = torch.istft(
            stretched_spec,
            n_fft=400,
            hop_length=160,
            window=window
        ).unsqueeze(0)

        return waveform_stretched

    def _drop_frames(self, waveform: Tensor) -> Tensor:
        """Cut one contiguous chunk covering 5-15% of the samples out of the waveform."""
        audio_len = waveform.shape[1]
        drop_ratio = random.uniform(0.05, 0.15)
        drop_len = int(audio_len * drop_ratio)

        if audio_len > drop_len:
            start_pos = random.randint(0, audio_len - drop_len)
            # Splice the parts before and after the dropped chunk together.
            waveform = torch.cat([waveform[:, :start_pos], waveform[:, start_pos + drop_len:]], dim=1)

        return waveform

    def _add_noise(self, waveform: Tensor, snr_db: float | None = None) -> Tensor:
        """Add white Gaussian noise at the requested signal-to-noise ratio.

        Args:
            waveform: Signal to corrupt.
            snr_db: Target SNR in dB; drawn uniformly from [15, 25] when None.

        Returns:
            The noisy waveform, peak-normalized when its peak is non-zero.
        """
        if snr_db is None:
            snr_db = random.uniform(15, 25)

        # SNR in dB is a power ratio, so the amplitude (RMS) ratio is
        # 10 ** (snr_db / 20). Dividing an amplitude by 10 ** (snr_db / 10)
        # would yield twice the requested SNR in dB.
        signal_rms = waveform.pow(2).mean().sqrt()
        noise_rms = signal_rms / (10 ** (snr_db / 20))

        # Generate Gaussian noise with the computed standard deviation
        noise = torch.randn_like(waveform) * noise_rms

        noisy_waveform: Tensor = waveform + noise

        max_val = noisy_waveform.abs().max()
        if max_val > 0:
            noisy_waveform = noisy_waveform / max_val

        return noisy_waveform

    def __getitem__(self, index) -> BatchItem:
        """Load, featurize and tokenize the sample at ``index``."""
        row: pd.Series = self.data.iloc[index]
        audio_path: Path = self.audio_dir / row['path']
        # text: str = unicodedata.normalize('NFC', row['sentence'].strip())
        text: str = normalize_extended_uyghur_characters(collapse_spaces(row['sentence'].strip()))

        waveform = self._load_audio(audio_path=audio_path)

        # NOTE: waveform augmentation is currently disabled.
        # waveform = self._augment_waveform(waveform)

        mel_spec = self._extract_features(waveform=waveform)

        return BatchItem(
            mel_spec=mel_spec,
            target_ids=torch.tensor(self.tokenizer.encode(text), dtype=torch.long),
            target_text=text,
            audio_path=str(audio_path)
        )
|
||||||
|
|
||||||
|
def collate_fn(items: List[BatchItem]) -> Batch:
    """Zero-pad the samples to common lengths and stack them into one batch.

    Mel spectrograms are padded along their second dimension and target IDs
    along their only dimension; the unpadded lengths are returned alongside.
    """
    mel_lens = [sample['mel_spec'].shape[1] for sample in items]
    longest_mel = max(mel_lens)
    tgt_lens = [len(sample['target_ids']) for sample in items]
    longest_tgt = max(tgt_lens)

    n_items = len(items)
    n_mels = items[0]['mel_spec'].shape[0]

    mel_specs = torch.zeros(n_items, n_mels, longest_mel)
    targets = torch.zeros(n_items, longest_tgt, dtype=torch.long)

    # Copy every sample into the left part of its zero-initialised row.
    for row, (sample, mel_len, tgt_len) in enumerate(zip(items, mel_lens, tgt_lens)):
        mel_specs[row, :, :mel_len] = sample['mel_spec']
        targets[row, :tgt_len] = sample['target_ids']

    return Batch(
        mel_specs=mel_specs,
        targets=targets,
        mel_lengths=torch.tensor(mel_lens, dtype=torch.long),
        target_lengths=torch.tensor(tgt_lens, dtype=torch.long),
        target_texts=[sample['target_text'] for sample in items],
        audio_paths=[sample['audio_path'] for sample in items]
    )
|
||||||
|
|
||||||
|
|
||||||
|
def create_dataloader(
    tsv_path: Path,
    audio_dir: Path,
    tokenizer: ASRTokenizer,
    batch_size: int = 8,
    shuffle: bool = True,
    augment: bool = True,
    num_workers: int = 8,
    prefetch_factor: int = 8,
) -> DataLoader:
    """Build a ``DataLoader`` over a ``CommonVoiceDataset``.

    Args:
        tsv_path: Manifest passed to the dataset.
        audio_dir: Clip directory passed to the dataset.
        tokenizer: Tokenizer passed to the dataset.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader shuffles the samples.
        augment: Forwarded to the dataset.
        num_workers: Number of loader worker processes; 0 loads in the
            calling process.
        prefetch_factor: Batches prefetched per worker; only used when
            ``num_workers`` is positive.

    Returns:
        A loader that collates samples with ``collate_fn``.
    """
    dataset = CommonVoiceDataset(tsv_path=tsv_path, audio_dir=audio_dir, tokenizer=tokenizer, augment=augment)

    # DataLoader rejects prefetch_factor / persistent_workers when there are
    # no worker processes, so both are only enabled for num_workers > 0.
    use_workers = num_workers > 0

    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn,
        pin_memory=True,
        num_workers=num_workers,
        prefetch_factor=prefetch_factor if use_workers else None,
        persistent_workers=use_workers,
    )
|
||||||
|
|
||||||
|
|
||||||
|
# ============ 测试代码 ============
|
||||||
|
# Manual smoke test: build a loader over the training split and print one batch.
if __name__ == "__main__":
    from pathlib import Path

    workspace_dir = Path(__file__).parent.parent
    ug_dir = workspace_dir / 'data' / 'ug'

    # Tokenizer backed by the vocabulary file in the config directory
    tokenizer = ASRTokenizer(workspace_dir / 'config' / 'asr_vocab.json')

    # Data loader over the training manifest and its clips
    dataloader = create_dataloader(
        tsv_path=ug_dir / 'train.tsv',
        audio_dir=ug_dir / 'clips',
        tokenizer=tokenizer,
        batch_size=2,
        shuffle=True,
    )

    # Print the contents of the first batch only
    print("测试数据加载:")
    for batch in dataloader:
        batch: Batch
        print(f"Mel specs shape: {batch['mel_specs'].shape}")
        print(f"Targets shape: {batch['targets'].shape}")
        print(f"Mel lengths: {batch['mel_lengths']}")
        print(f"Target lengths: {batch['target_lengths']}")
        print(f"Target texts: {batch['target_texts']}")
        print(f"Audio paths: {batch['audio_paths']}")
        break
|
||||||
14
src/handle/export_model_state_dict.py
Normal file
14
src/handle/export_model_state_dict.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# Re-serialising weights needs no GPU: map every tensor to the CPU so the
# script also runs on machines without CUDA and the exported file holds
# CPU tensors.
device = "cpu"

# NOTE(review): this script lives in src/handle/, so parent.parent resolves to
# src/ rather than the repository root (other scripts in this directory use
# three .parent hops) — confirm where .checkpoints/ actually lives.
workspace_dir = Path(__file__).parent.parent

input_checkpoint = workspace_dir.joinpath('.checkpoints/checkpoint_step_9500.pt')
output_checkpoint = workspace_dir.joinpath('.checkpoints/checkpoint_step.pt')

# Keep only the model weights from the full training checkpoint.
checkpoint = torch.load(input_checkpoint, map_location=device)
torch.save(checkpoint['model_state_dict'], output_checkpoint)
|
||||||
|
|
||||||
92
src/handle/export_vocab.py
Normal file
92
src/handle/export_vocab.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Counter
|
||||||
|
import pandas as pd
|
||||||
|
from tqdm import tqdm
|
||||||
|
from text_handle import process_text
|
||||||
|
|
||||||
|
workspace_dir = Path(__file__).parent.parent.parent
|
||||||
|
|
||||||
|
data_dir = Path("/home/blacksheep/projekts/study")
|
||||||
|
input_files = [
|
||||||
|
data_dir / "data/multilingual/hanziDB_translated-simplified.jsonl",
|
||||||
|
data_dir / "data/multilingual/hanziDB_translated_validationset-simplified.jsonl",
|
||||||
|
data_dir / "data/multilingual/tatoeba-tr-en-ug-uz-kz-zh-simplified.jsonl",
|
||||||
|
data_dir / "data/multilingual/export_csv_database.jsonl",
|
||||||
|
data_dir / "data/multilingual/tatoeba_sentences.jsonl",
|
||||||
|
]
|
||||||
|
output_file = workspace_dir / "config/asr_vocab_1.json"
|
||||||
|
syllabizes = []
|
||||||
|
|
||||||
|
# langs = ["uig_Arab"]
|
||||||
|
# for input_file in input_files:
|
||||||
|
# with open(input_file, 'r', encoding='utf-8') as f:
|
||||||
|
# total_lines = sum(1 for _ in f)
|
||||||
|
# f.seek(0)
|
||||||
|
# for line in tqdm(f, total=total_lines, desc=f"Processing lines {input_file}"):
|
||||||
|
# data: dict[str, str] = json.loads(line)
|
||||||
|
# for lang in langs:
|
||||||
|
# if lang in data:
|
||||||
|
# syllabizes.extend(export_syllabize(data[lang]))
|
||||||
|
|
||||||
|
# print(f'data_dir syllabize len: {len(syllabizes):,}, set: {len(set(syllabizes)):,}')
|
||||||
|
|
||||||
|
|
||||||
|
tsv_files = [
|
||||||
|
# workspace_dir / "data/ug/test.tsv",
|
||||||
|
# workspace_dir / "data/ug/invalidated.tsv",
|
||||||
|
# workspace_dir / "data/ug/train.tsv",
|
||||||
|
# workspace_dir / "data/ug/validated.tsv",
|
||||||
|
# workspace_dir / "data/ug/reported.tsv",
|
||||||
|
# workspace_dir / "data/ug/dev.tsv",
|
||||||
|
# workspace_dir / "data/ug/other.tsv",
|
||||||
|
# workspace_dir / ".data/ug/invalidated.tsv",
|
||||||
|
workspace_dir / ".data/ug/train.tsv",
|
||||||
|
# workspace_dir / ".data/ug/clip_durations.tsv", # not sentence
|
||||||
|
# workspace_dir / ".data/ug/test.tsv",
|
||||||
|
# workspace_dir / ".data/ug/validated_sentences.tsv",
|
||||||
|
# workspace_dir / ".data/ug/other.tsv",
|
||||||
|
# workspace_dir / ".data/ug/validated.tsv",
|
||||||
|
# workspace_dir / ".data/ug/dev.tsv",
|
||||||
|
# workspace_dir / ".data/ug/unvalidated_sentences.tsv",
|
||||||
|
# workspace_dir / ".data/ug/reported.tsv" # Lacking sentence
|
||||||
|
]
|
||||||
|
|
||||||
|
# Minimum number of occurrences a syllable needs to enter the vocabulary.
min_vocab_freq = 150

for tsv_file in tsv_files:
    data = pd.read_csv(tsv_file, sep='\t')
    # Missing sentences are read as NaN (a float) and would crash on .strip(),
    # so they are skipped.
    sentences = data['sentence'].dropna()
    # Process every sentence with a progress bar
    for sentence in tqdm(sentences, total=len(sentences), desc=f"Processing {tsv_file}"):
        syllabizes.extend(process_text(sentence.strip()))


# Count the occurrences of every syllable
syllable_counter = Counter(syllabizes)
# Syllables at or above the vocabulary threshold, and those seen at most 100 times
freq_100_plus = {k: v for k, v in syllable_counter.items() if v >= min_vocab_freq}
freq_100_minus = {k: v for k, v in syllable_counter.items() if v <= 100}

# Save the vocabulary: syllables only, ordered by length
vocab = sorted(freq_100_plus, key=len)
with open(output_file, 'w', encoding='utf-8') as f:
    json.dump(vocab, f, ensure_ascii=False, indent=2)

# Save the vocabulary syllables with their counts, most frequent first
sorted_freq_100_plus = dict(sorted(freq_100_plus.items(), key=lambda x: x[1], reverse=True))
with open(workspace_dir / 'config/syllables_freq_100_plus.json', 'w', encoding='utf-8') as f:
    json.dump(sorted_freq_100_plus, f, ensure_ascii=False, indent=2)

# Summary statistics
print(f"总音节数: {len(syllabizes):,}")
print(f"唯一音节数: {len(syllable_counter):,}")
# The message reports the threshold actually used by the filter above.
print(f"出现{min_vocab_freq}次以上的音节数: {len(freq_100_plus):,}")
print(f"出现100次以下的音节数: {len(freq_100_minus):,}")

# Per-interval statistics (not cumulative)
print("\n=== 音节使用次数统计(区间) ===")
for low in range(0, 100, 10):
    high = low + 9
    count = sum(1 for freq in syllable_counter.values() if low <= freq <= high)
    print(f"出现 {low}-{high} 次: {count:,} 个音节")
# 100 occurrences and above
count_100plus = sum(1 for freq in syllable_counter.values() if freq >= 100)
print(f"出现 100+ 次: {count_100plus} 个音节")
|
||||||
356
src/handle/test.ipynb
Normal file
356
src/handle/test.ipynb
Normal file
File diff suppressed because one or more lines are too long
220
src/handle/test.py
Normal file
220
src/handle/test.py
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
import os
|
||||||
|
import random
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import torchaudio
|
||||||
|
from pathlib import Path
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
workspace_dir = Path(__file__).parent.parent.parent
|
||||||
|
sys.path.append(str(workspace_dir.joinpath('src')))
|
||||||
|
# sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
from handle.text_handle import process_text
|
||||||
|
from tokenizer import ASRTokenizer
|
||||||
|
|
||||||
|
# 读取 TSV
|
||||||
|
# df = pd.read_csv(workspace_dir / '.data/ug/train.tsv', sep='\t')
|
||||||
|
|
||||||
|
# # 计算每个音频的时长(秒)
|
||||||
|
# durations = []
|
||||||
|
# for audio_path in tqdm(workspace_dir / '.data/ug/clips' / df['path'], desc="计算音频时长"):
|
||||||
|
# try:
|
||||||
|
# # 获取音频信息(不加载整个文件,速度快)
|
||||||
|
# info = torchaudio.info(audio_path)
|
||||||
|
# duration = info.num_frames / info.sample_rate
|
||||||
|
# durations.append(duration)
|
||||||
|
# except Exception as e:
|
||||||
|
# print(f"读取失败: {audio_path}, 错误: {e}")
|
||||||
|
# durations.append(None)
|
||||||
|
|
||||||
|
# # 添加到 DataFrame
|
||||||
|
# df['duration'] = durations
|
||||||
|
|
||||||
|
# # 统计
|
||||||
|
# print(df['duration'].describe())
|
||||||
|
# print(f"超过20秒的样本: {(df['duration'] > 20).sum()}")
|
||||||
|
|
||||||
|
# 保存结果(可选)
|
||||||
|
# df.to_csv('data/audio/train_with_duration.tsv', sep='\t', index=False)
|
||||||
|
|
||||||
|
# audio_path = workspace_dir / '.data/ug/clips/common_voice_ug_40252722.mp3'
|
||||||
|
# info = torchaudio.info(audio_path)
|
||||||
|
# print("info:", info)
|
||||||
|
import pandas as pd
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
import numpy as np
|
||||||
|
from pathlib import Path
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
def analyze_audio_quality(audio_path, sample_rate_target=16000):
    """Compute several quality metrics for one audio file.

    The clip is loaded, resampled to ``sample_rate_target``, averaged to one
    channel and peak-normalized before the metrics are computed.

    Args:
        audio_path: Path of the audio file (converted with ``str``).
        sample_rate_target: Sample rate the clip is resampled to.

    Returns:
        A dict with ``duration``, ``silence_ratio``, ``dynamic_range_db``,
        ``voice_ratio``, ``spectral_centroid``, ``spectral_flux`` and
        ``energy_std``; or ``None`` when the clip is shorter than one 25 ms
        frame or when any step raises (the error is printed).
    """
    try:
        # Load the audio
        waveform, sr = torchaudio.load(str(audio_path))

        # Resample + downmix to mono
        if sr != sample_rate_target:
            resampler = torchaudio.transforms.Resample(sr, sample_rate_target)
            waveform = resampler(waveform)
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Peak normalization
        max_val = waveform.abs().max()
        if max_val > 0:
            waveform = waveform / max_val

        # ===== Metric 1: silence ratio =====
        silence_threshold = 0.01
        is_silence = waveform.abs() < silence_threshold
        silence_ratio = is_silence.float().mean().item()

        # ===== Metric 2: dynamic range =====
        frame_len = int(0.025 * sample_rate_target)  # 25ms
        hop_len = int(0.010 * sample_rate_target)    # 10ms

        num_frames = (waveform.shape[1] - frame_len) // hop_len + 1
        if num_frames <= 0:
            return None

        # Per-frame RMS computed in one vectorized pass instead of a Python
        # loop; unfold yields exactly num_frames windows of frame_len samples.
        frames = waveform.unfold(1, frame_len, hop_len)  # [channels, num_frames, frame_len]
        rms = frames.pow(2).mean(dim=(0, 2)).sqrt()

        energies = rms.to(torch.float64).numpy()
        # Clip before the logarithm so silent frames do not produce -inf.
        energies = np.clip(energies, 1e-10, None)

        log_energies = 20 * np.log10(energies)
        dynamic_range = log_energies.max() - log_energies.min()

        # ===== Metric 3: voice activity =====
        # NOTE(review): the threshold is the 30th percentile of the same
        # energies, so roughly 70% of the frames exceed it for any clip —
        # confirm that this metric is meant to be relative.
        energy_threshold = np.percentile(energies, 30)
        voice_ratio = (energies > energy_threshold).mean()

        # ===== Metric 4: spectral centroid =====
        spec_transform = torchaudio.transforms.Spectrogram(n_fft=512)
        spec = spec_transform(waveform)  # [1, freq, time]

        # Computed with torch to avoid numpy axis handling
        freqs = torch.fft.rfftfreq(512, d=1.0/sample_rate_target)
        spec_mean = spec.mean(dim=-1).squeeze()  # [freq_bins] torch tensor

        # Small epsilon guards against an all-zero spectrum
        spectral_centroid = (freqs * spec_mean).sum() / (spec_mean.sum() + 1e-10)
        spectral_centroid = spectral_centroid.item()

        # ===== Metric 5: spectral flux =====
        spec_np = spec.squeeze().numpy()  # [freq, time]
        flux = np.abs(np.diff(spec_np, axis=1)).mean()

        return {
            'duration': waveform.shape[1] / sample_rate_target,
            'silence_ratio': silence_ratio,
            'dynamic_range_db': float(dynamic_range),
            'voice_ratio': float(voice_ratio),
            'spectral_centroid': spectral_centroid,
            'spectral_flux': float(flux),
            'energy_std': float(energies.std()),
        }

    except Exception as e:
        # Deliberate best-effort: a file that cannot be analysed is reported and skipped.
        print(f"失败: {audio_path}, {e}")
        return None
|
||||||
|
|
||||||
|
|
||||||
|
# ===== 批量处理 =====
|
||||||
|
# df = pd.read_csv(workspace_dir / '.data/ug/train.tsv', sep='\t')
|
||||||
|
|
||||||
|
# results = []
|
||||||
|
# for _, row in tqdm(df.iterrows(), total=len(df), desc="评估音频质量"):
|
||||||
|
# audio_path = workspace_dir / '.data/ug/clips' / row['path']
|
||||||
|
# metrics = analyze_audio_quality(audio_path)
|
||||||
|
# if metrics:
|
||||||
|
# metrics['path'] = row['path']
|
||||||
|
# results.append(metrics)
|
||||||
|
|
||||||
|
# # 汇总
|
||||||
|
# quality_df = pd.DataFrame(results)
|
||||||
|
# print(quality_df.describe())
|
||||||
|
|
||||||
|
|
||||||
|
# # ===== 质量评分 =====
|
||||||
|
# def score_quality(row):
|
||||||
|
# score = 100
|
||||||
|
|
||||||
|
# if row['silence_ratio'] > 0.5:
|
||||||
|
# score -= 30
|
||||||
|
# elif row['silence_ratio'] > 0.3:
|
||||||
|
# score -= 15
|
||||||
|
|
||||||
|
# if row['dynamic_range_db'] < 10:
|
||||||
|
# score -= 25
|
||||||
|
# elif row['dynamic_range_db'] < 20:
|
||||||
|
# score -= 10
|
||||||
|
|
||||||
|
# if row['voice_ratio'] < 0.3:
|
||||||
|
# score -= 20
|
||||||
|
|
||||||
|
# if row['spectral_centroid'] > 4000:
|
||||||
|
# score -= 15
|
||||||
|
|
||||||
|
# return max(0, score)
|
||||||
|
|
||||||
|
# quality_df['quality_score'] = quality_df.apply(score_quality, axis=1)
|
||||||
|
|
||||||
|
# quality_df['grade'] = pd.cut(
|
||||||
|
# quality_df['quality_score'],
|
||||||
|
# bins=[0, 40, 60, 80, 100],
|
||||||
|
# labels=['差(弃用)', '较差', '一般', '良好']
|
||||||
|
# )
|
||||||
|
|
||||||
|
# print("\n质量分级统计:")
|
||||||
|
# print(quality_df['grade'].value_counts())
|
||||||
|
|
||||||
|
# # 保存结果
|
||||||
|
# quality_df.to_csv('audio_quality_analysis.csv', index=False)
|
||||||
|
# print("\n结果已保存到 audio_quality_analysis.csv")
|
||||||
|
|
||||||
|
|
||||||
|
# Test
|
||||||
|
# text = "مەن مەكتەپكە باردىم"
|
||||||
|
# text = "ئاۋۋال كومپىيوتىرنى بىر كۆرسەم شۇنىڭغا قاراپ ماسلاشتۇرسام بوللاتتى"
|
||||||
|
text = "قاپاقنى پۇلغا ئالماي، باراڭنى پۇلغا ئاپتۇ"
|
||||||
|
# text = "ئامېىقى جاھان گۈرلىكىە قارشتۇرۇپ چوكشەڭگە ياردەم بېرىش ئۇرۇشىنىڭ ئلۇغ غەلبىسى جۇڭگو خەلقى ئورندىن دەستۇرغاندىكىن، دۇنيانىڭ شەرقىدە قەت كۆتۈرگەنلكىنى خىتاب لامىسى."
|
||||||
|
# text = "ئاۋسترالىيە"
|
||||||
|
|
||||||
|
text = " ھەر قانداق ئىشتا كىشىلەر بىلەن كېڭىشىش كېرەك. لېكىن كۆڭۈل تارتقان ئىشنى قىلىۋېرىش كېرەك."
|
||||||
|
text = "\" 15 يىل بۇرۇن مەن بىلەن سەي چۇڭشىن (ئالى بابانىڭ قۇرغۇچىسىدىن بىرى) ئامېرىكىغا بېرىپ 30 نەچچە مەبلەغ سالغۇچى شىركەتلەر بىلەن كۆرۈشكىنىمىزدە ئۇلار بىزنى رەت قىلىپ ئىشىك سىرتىدا قالدۇرغاندى. ھېچكىم بىزگە ۋە كەلگۈسىمىزگە ئىشەنمىگەن ئىدى، ھېچكىممۇ بىزنىڭ ھازىر 300 مىليارد سودىنى تاماملىيالايدىغانلىقىمىزنى ئويلاپ خىيالىغىمۇ كەلتۈرمىگەن ئىدى.\" ئالى بابانىنىڭ ئورگىنى تەرىپىدىن ئىشلەنگەن ئىگىلىك تىكلەش ھۆججەتلىك فىلىمى \" بۇ چۈش ئەمەس\" نىڭ باش قىسمىدا مۇنداق بىر داڭلىق سۆز چىقىدۇ:\" مەن تاغدىن ئۆتكەن ۋاقتىمدا تاغ ماڭا گەپ قىلمىدى، مەن دېڭىزنى كېچىپ ئۆتكىنىمدە دېڭىز ماڭا گەپ قىلمىدى.\" بۇ فىلىمنىڭ قوشۇمچە ئىسمى بولسا \" مايۈن ۋە ئۇنىڭ مەڭگۈلۈك ' ياش ئالى'سى\" بولۇپ، ھەرخىل خەتەر ۋە قىيىنلىقنى بىرمۇ بىر يەڭگەن مايۈن ۋە ئۇنىڭ ئالى بابا قوشۇنىدىكى ھەمكارلاشقۇچىلىرىنىڭ قەيسەر كەچمىشى جانلىق بايان قىلىنغان."
|
||||||
|
# text = "يۆ جيەنتاۋ يېقىنقى بىر مەزگىلدە، خىزمەتداشلىرى بىلەن نۇرغۇن قىيىنچىلىققا ئۇچرىغان شوپۇرغا ياردەم قىلغانلىقىنى، شۇنداقلا يېمەكلىك ۋە سۇ يەتكۈزۈپ بەرگەنلىكىنى، ئەمما ئاياغ سوۋغا قىلىشى تۇنجى قېتىم ئىكەنلىكىنى، بۈگۈنكىسى 20 يىللىق ساقچىلىق جەريانىدا تۇنجى قېتىم شوپۇرغا ئاياغ سوۋغا قىلىشى ئىكەنلىكىنى ئېيتتى."
|
||||||
|
text = "يۆ جۇڭمىڭ مۇنداق دېدى: 2019- يىلى 12-ئايدا، مەملىكەتلىك خەلق قۇرۇلتىيى دائىمىي كومىتېتىنىڭ 44- قېتىملىق كومىتېت باشلىقلىرى يىغىنى مەملىكەتلىك خەلق قۇرۇلتىيى دائىمىي كومىتېتىنىڭ 2020-يىللىق قانۇن چىقىرىش خىزمىتى پىلانىنى پىرىنسىپ جەھەتتىن ماقۇللىدى، خىزمەت تەرتىپى بويىچە، 13-نۆۋەتلىك مەملىكەتلىك خەلق قۇرۇلتىيى 3-يىغىنىنىڭ روھى ۋە ۋەكىللەرنىڭ تەكلىپ-تەۋسىيەلىرىگە ئاساسەن پىلاننى تەڭشەش كېرەك. بۇ يىل 6-ئاينىڭ 1-كۈنى، 58-قېتىملىق كومىتېت باشلىقلىرى يىغىنى تەڭشەلگەندىن كېيىنكى يىللىق قانۇن چىقىرىش خىزمىتى پىلانىنى قاراپ چىقىپ ماقۇللىدى."
|
||||||
|
text = " ئاشقازان-ئۈچەينىڭ لۆمۈلدىشىنى ئىلگىرى سۈرىدىغان دورىنى تاماقتىن بۇرۇن ئىستېمال قىلىش كېرەك."
|
||||||
|
# text = "مەن مەكتەپكە باردىم"
|
||||||
|
# text = "غەرىپئەللىرى"
|
||||||
|
# text = "ئىكەنلىكىنى، بۈگۈنكىسى 20 يىللىق ساقچىلىق جەريانىدا تۇنجى"
|
||||||
|
# text = " يېزىدىكى كەڭ، ئازادە ئۆي، باغلىرىنى تاشلاپ، خەقنىڭ ھويلىسىدا قورۇنۇپ-ئەيمىنىپ يەر دەسسەپ يۈردى. "
|
||||||
|
# text = "بۇ يىللىق ئەمگەكچىلەر بايرىمىدا، 1-مايدىن 5-مايغىچە جەمئىي بەش كۈن دەم ئېلىشقا قويۇپ بېرىلىدۇ. 9-ماي (شەنبە) نورمال خىزمەت قىلىنىدۇ. شۇنىڭ بىلەن بىر ۋاقىتتا، ئۈرۈمچى شەھىرى تىيانشان رايونى، سايباغ رايونى، داۋانچىڭ رايونى، ئۈرۈمچى ناھىيەسىدىكى ئوتتۇرا، باشلانغۇچ مەكتەپلەر 29-ئاپرېلدىن 30-ئاپرېلغىچە ئەتىيازلىق دەم ئېلىشقا قويۇپ بېرىدۇ؛ سانجى ئوبلاستىدىكى ئوتتۇرا، باشلانغۇچ مەكتەپلەرنىڭ ئەتىيازلىق دەم ئېلىشى ئۈچ كۈن بولۇپ، بۇنىڭ ئىچىدە ئىككى كۈن 29-، 30-ئاپرېلغا ئورۇنلاشتۇرۇلىدۇ، قالغان بىر كۈن 1-ئىيۇنغا ئورۇنلاشتۇرۇلىدۇ، قۇربان ھېيتلىق دەم ئېلىش بىلەن بىرلەشتۈرۈلۈپ، ئۇدا ئالتە كۈن دەم ئېلىنىدۇ."
|
||||||
|
# text = "ئەڭ يېڭى دەم ئېلىشقا قويۇپ بېرىش ئۇقتۇرۇشى! تەقەززالىق بىلەن كۈتكەن يەنە بىر دەم ئېلىش كېلەي دەپ قالدى! گوۋۇيۈەن بەنگۇڭتىڭىنىڭ «2025-يىللىق قىسمەن بايرام، دەم ئېلىش كۈنلىرىنى ئورۇنلاشتۇرۇش توغرىسىدىكى ئۇقتۇرۇشى»غا ئاساسەن، دۈەنۋۇ بايرىمىلىق دەم ئېلىشقا قويۇپ بېرىش ئورۇنلاشتۇرۇلۇشى تۆۋەندىكىچە: 5-ئاينىڭ 31-كۈنى (شەنبە)دىن 6-ئاينىڭ 2-كۈنىگىچە جەمئىي ئۈچ كۈن دەم ئېلىشقا قويۇپ بېرىلىدۇ! بۇ قېتىمقى دۈەنۋۇ بايرىمىلىق دەم ئېلىش ۋاقتى تەڭشەلمەيدۇ!"
|
||||||
|
# text = "مەن شۇ چاغدا چۈشۈمدىمۇ كۆرۈپ باقمىغان مۇنچىۋالا جىق قېرىنداشلىرىم، جىگەرلىرىمنىڭ بارلىقىدىن سۆيۈندۈم."
|
||||||
|
# text = "ئامېرىكا جاھان گىرلىكىگە قارشى تۇرۇپ، چاۋشەنگە ياردەم بىرىش ئۇرۇشىنىڭ ئۇلۇغ غەلبىسى، جۇڭگو خەلقى ئورۇندىن دەس تۇرغاندىنكىن دۇنيانىڭ شەرقىدە قەد كۆتۈرگەنلىكىنىڭ خىتاپنامىس"
|
||||||
|
|
||||||
|
|
||||||
|
# result = export_syllabize(text)
|
||||||
|
# Syllabify the sample text and show the result next to the input.
result = process_text(text)
print(f"Original: {text}")
print(f"Syllables: {result}")


# Round-trip the same text through the tokenizer, one token at a time.
tokenizer = ASRTokenizer(vocab_path=workspace_dir / 'config/asr_vocab.json')

print(text)
ids = tokenizer.encode(text=text)
# print(ids)
decoded_tokens = [tokenizer.decode([token_id]) for token_id in ids]
print('|'.join(decoded_tokens))
|
||||||
195
src/handle/text_handle.py
Normal file
195
src/handle/text_handle.py
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
|
||||||
|
from .text_normalizer import UYGHUR_LETTERS, clean_input_text
|
||||||
|
|
||||||
|
# def uighur_syllabize(text: str):
|
||||||
|
# # Uyghur Vowels
|
||||||
|
# vowels = "ئاەئەوۆئۇئۈئىئېاەوۆۇۈىې"
|
||||||
|
|
||||||
|
# def split_word(word):
|
||||||
|
# syllables = []
|
||||||
|
# current_syllable = ""
|
||||||
|
# i = 0
|
||||||
|
|
||||||
|
# # Helper to check if a character is a vowel
|
||||||
|
# is_vowel = lambda char: char in vowels
|
||||||
|
|
||||||
|
# while i < len(word):
|
||||||
|
# current_syllable += word[i]
|
||||||
|
|
||||||
|
# # If we find a vowel, look ahead to decide where to split
|
||||||
|
# if is_vowel(word[i]):
|
||||||
|
# # Check next characters
|
||||||
|
# remaining = word[i+1:]
|
||||||
|
|
||||||
|
# # Rule 1: V-V (e.g., 'ائ') -> Split after first vowel
|
||||||
|
# if len(remaining) >= 1 and is_vowel(remaining[0]):
|
||||||
|
# syllables.append(current_syllable)
|
||||||
|
# current_syllable = ""
|
||||||
|
|
||||||
|
# # Rule 2: V-C-V (e.g., 'ba-ra') -> Split after first vowel
|
||||||
|
# elif len(remaining) >= 2 and not is_vowel(remaining[0]) and is_vowel(remaining[1]):
|
||||||
|
# syllables.append(current_syllable)
|
||||||
|
# current_syllable = ""
|
||||||
|
|
||||||
|
# # Rule 3: V-C-C-V (e.g., 'mek-tep') -> Split after first consonant
|
||||||
|
# elif len(remaining) >= 3 and not is_vowel(remaining[0]) and not is_vowel(remaining[1]) and is_vowel(remaining[2]):
|
||||||
|
# current_syllable += remaining[0]
|
||||||
|
# syllables.append(current_syllable)
|
||||||
|
# current_syllable = ""
|
||||||
|
# i += 1 # Skip the consonant we just added
|
||||||
|
|
||||||
|
# # Rule 4: End of word or C-C-C clusters (rare in native words)
|
||||||
|
# # We keep going until we find a clear boundary or end of word
|
||||||
|
# i += 1
|
||||||
|
|
||||||
|
# if current_syllable:
|
||||||
|
# syllables.append(current_syllable)
|
||||||
|
# return syllables
|
||||||
|
|
||||||
|
# # Clean text and process word by word
|
||||||
|
# words = text.split()
|
||||||
|
# all_syllables = []
|
||||||
|
# for w in words:
|
||||||
|
# all_syllables.extend(split_word(w))
|
||||||
|
|
||||||
|
# #merge ئ to next syllable. single ئ no any semantic meaning and no it's dedicated sound.
|
||||||
|
# original_ayllables = [*all_syllables]
|
||||||
|
# all_syllables.clear()
|
||||||
|
# handled = True
|
||||||
|
# for s in original_ayllables:
|
||||||
|
# if s == "ئ":
|
||||||
|
# handled = False
|
||||||
|
# continue
|
||||||
|
# if not handled:
|
||||||
|
# s = "ئ" + s
|
||||||
|
# handled = True
|
||||||
|
# all_syllables.append(s)
|
||||||
|
|
||||||
|
# return all_syllables
|
||||||
|
|
||||||
|
# test_words = ['تەكلىپتەۋسىيەلىرىگە', 'تەكلىپ-تەۋسىيەلىرىگە']
|
||||||
|
# for tw in test_words:
|
||||||
|
# print(f"{tw} -> {uighur_syllabize(tw.strip())}")
|
||||||
|
# تەكلىپتەۋسىيەلىرىگە -> ['تەك', 'لىپ', 'تەۋ', 'سى', 'يە', 'لى', 'رى', 'گە']
|
||||||
|
# تەكلىپ-تەۋسىيەلىرىگە -> ['تەك', 'لىپ-تەۋ', 'سى', 'يە', 'لى', 'رى', 'گە']
|
||||||
|
|
||||||
|
|
||||||
|
def uighur_syllabize(text: str):
    """Split whitespace-separated words of *text* into syllable strings.

    A boundary is placed after a vowel according to how many non-vowel
    characters separate it from the next vowel:

    * 0 or 1  -> cut right after the vowel        (V-V, V-CV)
    * 2       -> cut after the first consonant    (VC-CV)
    * 3       -> cut after the second consonant   (VCC-CV)
    * 4+ / no following vowel -> no cut at this vowel

    Every character that is not in the vowel string (including punctuation
    such as "-") is treated as a consonant. Returns a flat list of syllables
    for all words, in order.
    """
    # Uyghur vowels (ASU). The hamza carrier is part of this string, so it is
    # classified as a vowel as well; that is why a leading hamza gets split
    # off on its own and has to be re-attached further down.
    vowels = "ئاەئەوۆئۇئۈئىئېاەوۆۇۈىې"
    hamza = "ئ"

    def cut_syllables(word):
        pieces = []
        size = len(word)
        start = 0  # index where the syllable being built begins
        pos = 0
        while pos < size:
            if word[pos] not in vowels:
                pos += 1
                continue

            # Count the consonants between this vowel and the next one;
            # anything beyond three means "no rule applies".
            gap = 0
            probe = pos + 1
            while probe < size and gap <= 3 and word[probe] not in vowels:
                gap += 1
                probe += 1

            if probe < size and gap <= 3:
                # word[probe] is the next vowel: keep all but the last
                # consonant of the cluster with the current syllable.
                cut = pos + max(1, gap)
                pieces.append(word[start:cut])
                start = pos = cut
            else:
                pos += 1

        # Whatever is left after the last boundary forms the final syllable.
        if start < size:
            pieces.append(word[start:])
        return pieces

    output = []
    for word in text.split():
        pending = iter(cut_syllables(word))
        for syllable in pending:
            if syllable == hamza:
                # A bare hamza is glued onto the syllable that follows it;
                # if it is the last piece of the word it stays as it is.
                follower = next(pending, None)
                output.append(syllable if follower is None else hamza + follower)
            else:
                output.append(syllable)
    return output
|
||||||
|
|
||||||
|
# Verification
|
||||||
|
# test_words = ["گەرتمەك", "ئېيتقان", "دوستلار", "مەكتەپكە", ' كومپېيۇتېر تورى زادى قانداق نەرسىدۇ؟ ']
|
||||||
|
# for tw in test_words:
|
||||||
|
# print(f"{tw} -> {uighur_syllabize(tw.strip())}")
|
||||||
|
|
||||||
|
# test_words = ['تەكلىپتەۋسىيەلىرىگە', 'تەكلىپ-تەۋسىيەلىرىگە']
|
||||||
|
# for tw in test_words:
|
||||||
|
# print(f"{tw} -> {uighur_syllabize(tw.strip())}")
|
||||||
|
|
||||||
|
# تەكلىپتەۋسىيەلىرىگە -> ['تەك', 'لىپ', 'تەۋ', 'سى', 'يە', 'لى', 'رى', 'گە']
|
||||||
|
# تەكلىپ-تەۋسىيەلىرىگە -> ['تەك', 'لىپ-', 'تەۋ', 'سى', 'يە', 'لى', 'رى', 'گە']
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
text = " ھەر قانداق ئىشتا كىشىلەر بىلەن كېڭىشىش كېرەك. لېكىن كۆڭۈل تارتقان ئىشنى قىلىۋېرىش كېرەك."
|
||||||
|
text = "\" 15 يىل بۇرۇن مەن بىلەن سەي چۇڭشىن (ئالى بابانىڭ قۇرغۇچىسىدىن بىرى) ئامېرىكىغا بېرىپ 30 نەچچە مەبلەغ سالغۇچى شىركەتلەر بىلەن كۆرۈشكىنىمىزدە ئۇلار بىزنى رەت قىلىپ ئىشىك سىرتىدا قالدۇرغاندى. ھېچكىم بىزگە ۋە كەلگۈسىمىزگە ئىشەنمىگەن ئىدى، ھېچكىممۇ بىزنىڭ ھازىر 300 مىليارد سودىنى تاماملىيالايدىغانلىقىمىزنى ئويلاپ خىيالىغىمۇ كەلتۈرمىگەن ئىدى.\" ئالى بابانىنىڭ ئورگىنى تەرىپىدىن ئىشلەنگەن ئىگىلىك تىكلەش ھۆججەتلىك فىلىمى \" بۇ چۈش ئەمەس\" نىڭ باش قىسمىدا مۇنداق بىر داڭلىق سۆز چىقىدۇ:\" مەن تاغدىن ئۆتكەن ۋاقتىمدا تاغ ماڭا گەپ قىلمىدى، مەن دېڭىزنى كېچىپ ئۆتكىنىمدە دېڭىز ماڭا گەپ قىلمىدى.\" بۇ فىلىمنىڭ قوشۇمچە ئىسمى بولسا \" مايۈن ۋە ئۇنىڭ مەڭگۈلۈك ' ياش ئالى'سى\" بولۇپ، ھەرخىل خەتەر ۋە قىيىنلىقنى بىرمۇ بىر يەڭگەن مايۈن ۋە ئۇنىڭ ئالى بابا قوشۇنىدىكى ھەمكارلاشقۇچىلىرىنىڭ قەيسەر كەچمىشى جانلىق بايان قىلىنغان."
|
||||||
|
text = "يۆ جيەنتاۋ يېقىنقى بىر مەزگىلدە، خىزمەتداشلىرى بىلەن نۇرغۇن قىيىنچىلىققا ئۇچرىغان شوپۇرغا ياردەم قىلغانلىقىنى، شۇنداقلا يېمەكلىك ۋە سۇ يەتكۈزۈپ بەرگەنلىكىنى، ئەمما ئاياغ سوۋغا قىلىشى تۇنجى قېتىم ئىكەنلىكىنى، بۈگۈنكىسى 20 يىللىق ساقچىلىق جەريانىدا تۇنجى قېتىم شوپۇرغا ئاياغ سوۋغا قىلىشى ئىكەنلىكىنى ئېيتتى."
|
||||||
|
text = "يۆ جۇڭمىڭ مۇنداق دېدى: 2019- يىلى 12-ئايدا، مەملىكەتلىك خەلق قۇرۇلتىيى دائىمىي كومىتېتىنىڭ 44- قېتىملىق كومىتېت باشلىقلىرى يىغىنى مەملىكەتلىك خەلق قۇرۇلتىيى دائىمىي كومىتېتىنىڭ 2020-يىللىق قانۇن چىقىرىش خىزمىتى پىلانىنى پىرىنسىپ جەھەتتىن ماقۇللىدى، خىزمەت تەرتىپى بويىچە، 13-نۆۋەتلىك مەملىكەتلىك خەلق قۇرۇلتىيى 3-يىغىنىنىڭ روھى ۋە ۋەكىللەرنىڭ تەكلىپ-تەۋسىيەلىرىگە ئاساسەن پىلاننى تەڭشەش كېرەك. بۇ يىل 6-ئاينىڭ 1-كۈنى، 58-قېتىملىق كومىتېت باشلىقلىرى يىغىنى تەڭشەلگەندىن كېيىنكى يىللىق قانۇن چىقىرىش خىزمىتى پىلانىنى قاراپ چىقىپ ماقۇللىدى."
|
||||||
|
text = "مەن مەكتەپكە باردىم"
|
||||||
|
|
||||||
|
# text = normalize_extended_uyghur_characters(text=text)
|
||||||
|
# print(clean_uyghur(text=text))
|
||||||
|
|
||||||
|
def process_text(text: str) -> list[str]:
    """Clean *text* and turn it into a flat token list.

    Maximal runs of characters found in ``UYGHUR_LETTERS`` are syllabized
    with ``uighur_syllabize``; every other character that survives
    ``clean_input_text`` is emitted as a single-character token in place.
    """
    cleaned = clean_input_text(text=text)
    tokens = []
    run_start = None  # index where the current run of letters began

    for index, char in enumerate(cleaned):
        if char in UYGHUR_LETTERS:
            if run_start is None:
                run_start = index
            continue

        # A non-letter ends the current word: flush it first.
        if run_start is not None:
            tokens.extend(uighur_syllabize(cleaned[run_start:index]))
            run_start = None
        # Non-letter characters are passed through unchanged.
        tokens.append(char)

    # Flush a word that runs up to the end of the text.
    if run_start is not None:
        tokens.extend(uighur_syllabize(cleaned[run_start:]))

    return tokens
|
||||||
126
src/handle/text_normalizer.py
Normal file
126
src/handle/text_normalizer.py
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
UYGHUR_LETTERS = {
|
||||||
|
'ا', 'ە', 'ب', 'پ', 'ت', 'ج', 'چ', 'خ', 'د', 'ر', 'ز', 'ژ', 'س', 'ش', 'غ', 'ف',
|
||||||
|
'ق', 'ك', 'گ', 'ڭ', 'ل', 'م', 'ن', 'ھ', 'و', 'ۇ', 'ۆ', 'ۈ', 'ۋ', 'ې', 'ى', 'ي',
|
||||||
|
"ئ"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
uyghur_symbols = {
|
||||||
|
"ھەرىپلەر": [
|
||||||
|
'ا', 'ە', 'ب', 'پ', 'ت', 'ج', 'چ', 'خ', 'د', 'ر', 'ز', 'ژ', 'س', 'ش', 'غ', 'ف',
|
||||||
|
'ق', 'ك', 'گ', 'ڭ', 'ل', 'م', 'ن', 'ھ', 'و', 'ۇ', 'ۆ', 'ۈ', 'ۋ', 'ې', 'ى', 'ي',
|
||||||
|
"ئ"
|
||||||
|
],
|
||||||
|
"تىنىش_بەلگىلىرى": [
|
||||||
|
"۔", "،", "؟", "!", "-", "«", "»", "؛", ":", "'", "\"", "]", "[", " ", "›", "‹"
|
||||||
|
],
|
||||||
|
"سانلار": [
|
||||||
|
"0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
english_symbols = {
|
||||||
|
"characters": [
|
||||||
|
'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
|
||||||
|
'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'
|
||||||
|
],
|
||||||
|
"symbols": [
|
||||||
|
".", "!", "?", "-", "<", ">", "=", "+", "*", "/", "|", "\\", "~", "_", "^", "@", "{", "}", "[", "]", "(", ")", "#", "%", "$", "€", "£", "¥", " "
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
# Export all symbols:
|
||||||
|
symbols_list = uyghur_symbols["ھەرىپلەر"] + uyghur_symbols["تىنىش_بەلگىلىرى"] + uyghur_symbols["سانلار"] + english_symbols["characters"] + english_symbols["symbols"]
|
||||||
|
symbols = {i: i for i in symbols_list}
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_extended_uyghur_characters(text: str) -> str:
|
||||||
|
characters = [
|
||||||
|
#["standard", "individual", "beginning", "middle", "end"]
|
||||||
|
["ئ","ﺋ","ﺋ","ﺌ","ﺌ"],
|
||||||
|
["ا","ﺍ","ﺍ","ﺎ","ﺎ"],
|
||||||
|
["ە","ﻩ","ﻩ","ﻪ","ﻪ"],
|
||||||
|
["ب","ﺏ","ﺑ","ﺒ","ﺐ"],
|
||||||
|
["پ","ﭖ","ﭘ","ﭙ","ﭗ"],
|
||||||
|
["ت","ﺕ","ﺗ","ﺘ","ﺖ"],
|
||||||
|
["ج","ﺝ","ﺟ","ﺠ","ﺞ"],
|
||||||
|
["چ","ﭺ","ﭼ","ﭽ","ﭻ"],
|
||||||
|
["خ","ﺥ","ﺧ","ﺨ","ﺦ"],
|
||||||
|
["د","ﺩ","ﺩ","ﺪ","ﺪ"],
|
||||||
|
["ر","ﺭ","ﺭ","ﺮ","ﺮ"],
|
||||||
|
["ز","ﺯ","ﺯ","ﺰ","ﺰ"],
|
||||||
|
["ژ","ﮊ","ﮊ","ﮋ","ﮋ"],
|
||||||
|
["س","ﺱ","ﺳ","ﺴ","ﺲ"],
|
||||||
|
["ش","ﺵ","ﺷ","ﺸ","ﺶ"],
|
||||||
|
["غ","ﻍ","ﻏ","ﻐ","ﻎ"],
|
||||||
|
["ف","ﻑ","ﻓ","ﻔ","ﻒ"],
|
||||||
|
["ق","ﻕ","ﻗ","ﻘ","ﻖ"],
|
||||||
|
["ك","ﻙ","ﻛ","ﻜ","ﻚ"],
|
||||||
|
["گ","ﮒ","ﮔ","ﮕ","ﮓ"],
|
||||||
|
["ڭ","ﯓ","ﯕ","ﯖ","ﯔ"],
|
||||||
|
["ل","ﻝ","ﻟ","ﻠ","ﻞ"],
|
||||||
|
["م","ﻡ","ﻣ","ﻤ","ﻢ"],
|
||||||
|
["ن","ﻥ","ﻧ","ﻨ","ﻦ"],
|
||||||
|
["ھ","ﮪ","ﮬ","ﮭ","ﮫ"],
|
||||||
|
["و","ﻭ","ﻭ","ﻮ","ﻮ"],
|
||||||
|
["ۇ","ﯗ","ﯗ","ﯘ","ﯘ"],
|
||||||
|
["ۆ","ﯙ","ﯙ","ﯚ","ﯚ"],
|
||||||
|
["ۈ","ﯛ","ﯛ","ﯜ","ﯜ"],
|
||||||
|
["ۋ","ﯞ","ﯞ","ﯟ","ﯟ"],
|
||||||
|
["ې","ﯤ","ﯦ","ﯧ","ﯥ"],
|
||||||
|
["ى","ﻯ","ﯨ","ﯩ","ﻰ"],
|
||||||
|
["ي","ﻱ","ﻳ","ﻴ","ﻲ"],
|
||||||
|
["ۅ","ﯠ","ﯠ","ﯡ","ﯡ"],
|
||||||
|
["ۉ","ﯢ","ﯢ","ﯣ","ﯣ"],
|
||||||
|
["ح","ﺡ","ﺣ","ﺤ","ﺢ"],
|
||||||
|
["ع","ﻉ","ﻋ","ﻌ","ﻊ"]
|
||||||
|
]
|
||||||
|
replacement_table: dict[str, str] = {}
|
||||||
|
#Create a replacement table.
|
||||||
|
for char_map in characters:
|
||||||
|
for char in char_map[1:]:
|
||||||
|
replacement_table[char] = char_map[0]
|
||||||
|
|
||||||
|
replacement_table['ﻼ'] = "لا" #add some additional exceptional symbols
|
||||||
|
|
||||||
|
text = text.replace('ئ', 'ئ')
|
||||||
|
clean = ""
|
||||||
|
#Replace the extended characters with their standard unicode ones.
|
||||||
|
for char in text:
|
||||||
|
if char in replacement_table:
|
||||||
|
char = replacement_table[char]
|
||||||
|
clean += char
|
||||||
|
|
||||||
|
return clean
|
||||||
|
|
||||||
|
|
||||||
|
def is_uyghur_char(char: str) -> bool:
    """Return True when every character of *char* is in ``UYGHUR_LETTERS``.

    An empty string yields True because there is nothing to reject.
    """
    return UYGHUR_LETTERS.issuperset(char)
|
||||||
|
|
||||||
|
|
||||||
|
def is_uyghur_text(text: str) -> bool:
    """Return True when every character of *text* is a key of ``symbols``.

    ``symbols`` is built from the Uyghur letters, punctuation and digits
    plus the English letters and symbols, so text containing Latin
    characters also passes this check.
    """
    for char in text:
        if char not in symbols:
            return False
    return True
|
||||||
|
|
||||||
|
def clean_unknown_symbols(text: str) -> str:
    """Drop every character of *text* that is not a key of ``symbols``."""
    return ''.join(filter(symbols.__contains__, text))
|
||||||
|
|
||||||
|
import re
|
||||||
|
def collapse_spaces(text: str) -> str:
    """Replace each run of whitespace in *text* with a single space."""
    whitespace_run = re.compile(r'\s+')
    return whitespace_run.sub(' ', text)
|
||||||
|
|
||||||
|
def clean_english_text(text: str) -> str:
    """Replace every ASCII letter in *text* with a space (one for one)."""
    ascii_letter = re.compile(r'[a-zA-Z]')
    return ascii_letter.sub(' ', text)
|
||||||
|
|
||||||
|
def clean_chinese_text(text: str) -> str:
    """Replace every character in U+4E00..U+9FFF with a space (one for one)."""
    cjk_ideograph = re.compile(r'[\u4e00-\u9fff]')
    return cjk_ideograph.sub(' ', text)
|
||||||
|
|
||||||
|
def clean_http_links(text: str) -> str:
    """Replace each http:// or https:// link (up to the next whitespace) with a space."""
    link = re.compile(r'https?://[^\s]+')
    return link.sub(' ', text)
|
||||||
|
|
||||||
|
def clean_input_text(text: str) -> str:
    """Normalize raw text before it is tokenized.

    Steps, in order:

    1. collapse whitespace runs, so tabs and newlines become plain spaces
       before unknown characters are filtered out;
    2. replace http(s) links with a space;
    3. map extended/presentation-form characters to their standard forms;
    4. drop every character that is not in the known symbol table;
    5. collapse whitespace again.

    The final pass is needed because steps 2 and 4 can leave several spaces
    next to each other (a removed link or symbol surrounded by spaces), and
    each of those spaces would otherwise become its own token downstream.
    """
    text = collapse_spaces(text)
    text = clean_http_links(text)
    text = normalize_extended_uyghur_characters(text)
    text = clean_unknown_symbols(text)
    return collapse_spaces(text)
|
||||||
433
src/inference.ipynb
Normal file
433
src/inference.ipynb
Normal file
File diff suppressed because one or more lines are too long
183
src/inference.py
Normal file
183
src/inference.py
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
import random
|
||||||
|
|
||||||
|
import librosa
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
from torch import Tensor, no_grad, device
|
||||||
|
from torchaudio.transforms import MelSpectrogram, AmplitudeToDB, Resample, TimeStretch
|
||||||
|
from pathlib import Path
|
||||||
|
import torchaudio.functional as F
|
||||||
|
|
||||||
|
from tokenizer import ASRTokenizer
|
||||||
|
from model import ASRModel
|
||||||
|
|
||||||
|
CONFIG = {
|
||||||
|
# 模型配置
|
||||||
|
'input_dim': 640,
|
||||||
|
'num_heads': 8,
|
||||||
|
'ffn_dim': 2048,
|
||||||
|
'num_layers': 8,
|
||||||
|
'dropout': 0.1,
|
||||||
|
}
|
||||||
|
|
||||||
|
class ASRInference:
    """Load a trained ``ASRModel`` checkpoint and transcribe audio files.

    NOTE(review): waveform augmentation is enabled by default
    (``augment=True``), so ``transcribe`` applies random stretching, frame
    dropping and noise and is therefore not deterministic unless the caller
    passes ``augment=False``.
    """

    def __init__(self, model_path: Path, vocab_path: Path, device: device, augment: bool = True, augment_prob: float = 0.5) -> None:
        """Build tokenizer, model and feature transforms.

        Args:
            model_path: checkpoint file containing a ``model_state_dict`` entry.
            vocab_path: vocabulary JSON passed to ``ASRTokenizer``.
            device: device the model and its inputs are placed on.
            augment: whether ``transcribe`` may augment the waveform.
            augment_prob: probability that augmentation is attempted at all.
        """
        self.device = device
        self.augment: bool = augment
        self.augment_prob: float = augment_prob
        self.tokenizer = ASRTokenizer(vocab_path=vocab_path)

        architecture = {
            key: CONFIG[key]
            for key in ('input_dim', 'num_heads', 'ffn_dim', 'num_layers', 'dropout')
        }
        self.model = ASRModel(vocab_size=self.tokenizer.vocab_size(), **architecture).to(device)

        checkpoint = torch.load(model_path, map_location=device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()

        print(f"params params: {self.model.get_num_params():,}")

        # Feature extraction: 80-bin mel spectrogram of 16 kHz audio,
        # 400-sample window and 160-sample hop, converted to decibels.
        self.sample_rate = 16000
        self.mel_transform = MelSpectrogram(
            sample_rate=self.sample_rate,
            n_fft=400,
            win_length=400,
            hop_length=160,
            n_mels=80,
            f_min=0,
            f_max=8000,
            power=2.0,
        )
        self.amplitude_to_db = AmplitudeToDB()

    def _load_audio(self, audio_path: Path) -> Tensor:
        """Read a file, resample to ``self.sample_rate``, downmix and peak-normalize."""
        waveform, source_rate = torchaudio.load_with_torchcodec(audio_path)

        if source_rate != self.sample_rate:
            resampler = Resample(source_rate, self.sample_rate)
            waveform = resampler(waveform)

        # Average the channels when the file has more than one.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # The epsilon keeps an all-zero file from dividing by zero.
        peak = waveform.abs().max() + 1e-8
        return waveform / peak

    def _extract_features(self, waveform: Tensor) -> Tensor:
        """Return the dB-scaled mel spectrogram with the leading dimension squeezed."""
        log_mel_spec: Tensor = self.amplitude_to_db(self.mel_transform(waveform))
        return log_mel_spec.squeeze(0)  # [n_mels, time]

    def _augment_waveform(self, waveform: Tensor) -> Tensor:
        """Randomly apply stretch, frame dropping and noise, in that order."""
        if not self.augment:
            return waveform
        if random.random() > self.augment_prob:
            return waveform

        # (probability, transform) pairs; each one is drawn independently
        # and applied in sequence.
        stages = (
            (0.5, self._voice_stretch_or_compress),
            (0.3, self._drop_frames),
            (0.4, self._add_noise),
        )
        for chance, transform in stages:
            if random.random() < chance:
                waveform = transform(waveform)

        return waveform

    def _voice_stretch_or_compress(self, waveform: Tensor) -> Tensor:
        """Change the speaking rate by a random factor in [0.6, 1.4] via STFT time stretching."""
        speed_factor = random.uniform(0.6, 1.4)  # (Speed Change: 0.6x - 1.4x)
        window = torch.hann_window(400).to(waveform.device)

        spectrum = torch.stft(
            waveform.squeeze(0),
            n_fft=400,
            hop_length=160,
            window=window,
            return_complex=True
        )

        # Stretch along time without changing the pitch.
        stretcher = TimeStretch(hop_length=160, n_freq=201, fixed_rate=speed_factor)
        stretched = stretcher(spectrum)

        # Back to a waveform, restoring the channel dimension.
        restored = torch.istft(stretched, n_fft=400, hop_length=160, window=window)
        return restored.unsqueeze(0)

    def _drop_frames(self, waveform: Tensor) -> Tensor:
        """Cut one contiguous span covering 5-15% of the samples out of the waveform."""
        total = waveform.shape[1]
        drop_len = int(total * random.uniform(0.05, 0.15))

        if total <= drop_len:
            return waveform

        start_pos = random.randint(0, total - drop_len)
        kept_head = waveform[:, :start_pos]
        kept_tail = waveform[:, start_pos + drop_len:]
        return torch.cat([kept_head, kept_tail], dim=1)

    def _add_noise(self, waveform: Tensor) -> Tensor:
        """Add Gaussian noise at a random signal-to-noise ratio between 10 and 20 dB."""
        snr_db = random.uniform(10, 20)
        signal_power = torch.mean(waveform ** 2)

        # Convert the dB ratio to a linear one to get the noise power.
        snr_linear = 10 ** (snr_db / 10)
        noise_power = signal_power / snr_linear

        return waveform + torch.randn_like(waveform) * torch.sqrt(noise_power)

    def transcribe(self, audio_path: Path) -> str:
        """Transcribe one audio file with greedy CTC decoding."""
        waveform = self._augment_waveform(waveform=self._load_audio(audio_path=audio_path))
        features = self._extract_features(waveform=waveform)

        batch = features.unsqueeze(0).to(self.device)  # [1, n_mels, time]
        frame_count = torch.tensor([batch.shape[2]], dtype=torch.long, device=self.device)

        with no_grad():
            log_probs, _ = self.model(mel_specs=batch, mel_lengths=frame_count)  # [1, T, vocab]

        return self.tokenizer.ctc_greedy_decode(log_probs=log_probs[0])

    def transcribe_batch(self, audio_paths: list[Path]) -> list[str]:
        """Transcribe the given files one after another and return the texts in order."""
        return [self.transcribe(audio_path=audio_path) for audio_path in audio_paths]
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Transcribe one bundled test clip with a fixed checkpoint and print the result.

    Paths are resolved relative to the repository root (the parent of the
    directory containing this file).
    """
    workspace_dir = Path(__file__).parent.parent

    # Prefer the first GPU, but stay usable on machines without CUDA; the
    # checkpoint is loaded with map_location, so a CPU run works as well.
    run_device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    checkpoint = workspace_dir / ".checkpoints/checkpoint_step_9500.pt"
    print(f"Load Checkpoint: {checkpoint}")

    # NOTE(review): the training script loads 'config/asr_vocab.json' while
    # this path is 'config/uig_vocab.json'. The vocabulary must be the one the
    # checkpoint was trained with - confirm which file is correct.
    inference = ASRInference(
        model_path=checkpoint,
        vocab_path=workspace_dir / 'config/uig_vocab.json',
        device=run_device,
    )

    audio_path = workspace_dir / 'data/test/F001_001.wav'
    print(f"\n转录音频: {audio_path}")
    text = inference.transcribe(audio_path=audio_path)
    print(f"\n识别结果: {text}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
# workspace_dir = Path(__file__).parent.parent
|
||||||
|
# checkpoint = sorted(workspace_dir.glob('.checkpoints/checkpoint_step_*.pt'), key=lambda p: int(p.stem.split('_')[-1]))[-1]
|
||||||
|
|
||||||
|
# print(checkpoint)
|
||||||
39
src/model.py
Normal file
39
src/model.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
from torch.nn import Conv2d, Module, ReLU, Linear, functional
|
||||||
|
from torch import Tensor
|
||||||
|
from torchaudio.models import Conformer
|
||||||
|
|
||||||
|
|
||||||
|
class ASRModel(Module):
    """Convolutional front end, Conformer encoder and linear CTC head.

    Two stride-2 convolutions halve both the mel and the time axis twice;
    the 32 output channels are then folded into the feature axis, so
    ``input_dim`` has to equal ``32 * (downsampled mel bins)``.
    """

    def __init__(self, vocab_size: int, input_dim: int = 640, num_heads: int = 8, ffn_dim: int = 2048, num_layers: int = 6, dropout: float = 0.1) -> None:
        super().__init__()

        # Both convolutions share the same 3x3 / stride 2 / padding 1 geometry.
        downsample = dict(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        self.conv1 = Conv2d(in_channels=1, out_channels=16, **downsample)
        self.conv2 = Conv2d(in_channels=16, out_channels=32, **downsample)
        self.relu = ReLU()

        self.encoder = Conformer(
            input_dim=input_dim,
            num_heads=num_heads,
            ffn_dim=ffn_dim,
            num_layers=num_layers,
            depthwise_conv_kernel_size=31,
            dropout=dropout,
        )

        self.ctc_head = Linear(in_features=input_dim, out_features=vocab_size, bias=False)

    def forward(self, mel_specs: Tensor, mel_lengths: Tensor) -> tuple[Tensor, Tensor]:
        """Return per-frame log-probabilities and the downsampled frame counts.

        Args:
            mel_specs: ``[B, n_mels, time]`` spectrogram batch.
            mel_lengths: frame count of each sample before downsampling.

        Returns:
            ``(log_probs, lengths)``: ``log_probs`` is ``[B, time', vocab_size]``
            and ``lengths`` is what the encoder returns for the downsampled
            frame counts.
        """
        features: Tensor = mel_specs.unsqueeze(1)  # [B, 1, n_mels, time]

        # Each pass halves the mel and time axes: 1 -> 16 -> 32 channels.
        for conv in (self.conv1, self.conv2):
            features = self.relu(conv(features))

        # [B, channels, freq, time] -> [B, time, channels * freq]
        features = features.permute(0, 3, 1, 2).flatten(start_dim=2)

        # With kernel 3 / stride 2 / padding 1 a length L becomes (L + 1) // 2;
        # apply that once per convolution.
        lengths = mel_lengths
        for _ in range(2):
            lengths = (lengths + 1) // 2

        encoded, lengths = self.encoder(features, lengths)

        logits = self.ctc_head(encoded)  # [B, time, vocab_size]
        return functional.log_softmax(logits, dim=-1), lengths

    def get_num_params(self) -> int:
        """Return the total number of elements over all parameters."""
        return sum(parameter.numel() for parameter in self.parameters())
|
||||||
107
src/tokenizer.py
Normal file
107
src/tokenizer.py
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
import re
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from handle.text_handle import process_text
|
||||||
|
from handle.text_normalizer import UYGHUR_LETTERS
|
||||||
|
|
||||||
|
class ASRTokenizer:
|
||||||
|
def __init__(self, vocab_path: Path) -> None:
|
||||||
|
self.vocabs: list[str] = []
|
||||||
|
self.vocabs.extend([
|
||||||
|
"<BLANK>",
|
||||||
|
"<UNK>",
|
||||||
|
"<SPACE>"
|
||||||
|
])
|
||||||
|
for i in range(len(self.vocabs), 64):
|
||||||
|
self.vocabs.append(f"REVERSED_{i}")
|
||||||
|
|
||||||
|
with open(vocab_path, 'r', encoding='utf-8') as file:
|
||||||
|
data: list[str] = json.load(file)
|
||||||
|
|
||||||
|
for char in data:
|
||||||
|
if char not in [' ', '\t', '\n'] and char not in self.vocabs:
|
||||||
|
self.vocabs.append(char)
|
||||||
|
|
||||||
|
self.vocab_to_id: dict[str, int] = {vocab:i for i, vocab in enumerate(self.vocabs)}
|
||||||
|
self.id_to_vocab: dict[int, str] = {i:vocab for i, vocab in enumerate(self.vocabs)}
|
||||||
|
|
||||||
|
self._max_token_len = max((len(v) for v in self.vocabs if not v.startswith("<")), default=0)
|
||||||
|
|
||||||
|
def vocab_size(self):
|
||||||
|
return len(self.vocabs)
|
||||||
|
|
||||||
|
def _split_long_syllable(self, syll: str) -> list[str]:
|
||||||
|
"""将不在词表中的长音节拆分为词表内的子词序列(最大正向匹配)"""
|
||||||
|
pieces = []
|
||||||
|
start = 0
|
||||||
|
n = len(syll)
|
||||||
|
while start < n:
|
||||||
|
# 从最大可能长度开始尝试匹配
|
||||||
|
for length in range(min(self._max_token_len, n - start), 0, -1):
|
||||||
|
sub = syll[start:start + length]
|
||||||
|
if sub in self.vocab_to_id:
|
||||||
|
pieces.append(sub)
|
||||||
|
start += length
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# 理论上不会发生,因为单字母一定在词表
|
||||||
|
pieces.append("<UNK>")
|
||||||
|
start += 1
|
||||||
|
return pieces
|
||||||
|
|
||||||
|
def encode(self, text: str) -> list[int]:
|
||||||
|
tokens = process_text(text=text)
|
||||||
|
result: list[int] = []
|
||||||
|
|
||||||
|
for token in tokens:
|
||||||
|
if token == ' ':
|
||||||
|
result.append(self.vocab_to_id['<SPACE>'])
|
||||||
|
else:
|
||||||
|
token_id = self.vocab_to_id.get(token)
|
||||||
|
if token_id is not None:
|
||||||
|
result.append(token_id)
|
||||||
|
elif token in UYGHUR_LETTERS:
|
||||||
|
result.append(self.vocab_to_id.get(token, self.vocab_to_id['<UNK>']))
|
||||||
|
else:
|
||||||
|
sub_pieces = self._split_long_syllable(token)
|
||||||
|
for piece in sub_pieces:
|
||||||
|
result.append(self.vocab_to_id.get(piece, self.vocab_to_id['<UNK>']))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def decode(self, ids: list[int] | Tensor, remove_blank: bool = True, remove_repeat: bool = True) -> str:
|
||||||
|
if isinstance(ids, Tensor):
|
||||||
|
ids = ids.tolist()
|
||||||
|
|
||||||
|
result = []
|
||||||
|
prev_id = None
|
||||||
|
|
||||||
|
for token_id in ids:
|
||||||
|
# 跳过blank token
|
||||||
|
if remove_blank and token_id == self.vocab_to_id['<BLANK>']:
|
||||||
|
prev_id = None
|
||||||
|
continue
|
||||||
|
|
||||||
|
# CTC解码:跳过连续重复的字符
|
||||||
|
if remove_repeat and token_id == prev_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
char = self.id_to_vocab.get(token_id, "<UNK>")
|
||||||
|
|
||||||
|
if char == '<SPACE>':
|
||||||
|
char = ' '
|
||||||
|
|
||||||
|
result.append(char)
|
||||||
|
prev_id = token_id
|
||||||
|
|
||||||
|
return ''.join(result)
|
||||||
|
|
||||||
|
def get_special_token_id(self, token: str) -> int:
|
||||||
|
return self.vocab_to_id.get(token, self.vocab_to_id['<UNK>'])
|
||||||
|
|
||||||
|
def ctc_greedy_decode(self, log_probs: Tensor) -> str:
|
||||||
|
""" CTC贪心解码 (log_probs: [seq_len, vocab_size] 的log概率) """
|
||||||
|
return self.decode(torch.argmax(log_probs, dim=-1))
|
||||||
357
src/train.py
Normal file
357
src/train.py
Normal file
@@ -0,0 +1,357 @@
|
|||||||
|
import torch
|
||||||
|
from torch import Tensor, device, no_grad, cuda
|
||||||
|
from torch.optim import AdamW, Optimizer
|
||||||
|
from torch.optim.lr_scheduler import OneCycleLR
|
||||||
|
from torch.nn import CTCLoss, utils
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
from pathlib import Path
|
||||||
|
from tqdm import tqdm
|
||||||
|
from datetime import datetime
|
||||||
|
from torch.amp.grad_scaler import GradScaler
|
||||||
|
from tokenizer import ASRTokenizer
|
||||||
|
from dataset import Batch, create_dataloader
|
||||||
|
from model import ASRModel
|
||||||
|
|
||||||
|
|
||||||
|
# ============ 全局配置 ============
|
||||||
|
CONFIG = {
|
||||||
|
# 数据配置
|
||||||
|
'batch_size': 96,
|
||||||
|
|
||||||
|
# 训练配置
|
||||||
|
'num_epochs': 50,
|
||||||
|
'learning_rate': 1e-4,
|
||||||
|
'weight_decay': 1e-4,
|
||||||
|
'grad_clip_norm': 1.0,
|
||||||
|
|
||||||
|
# 模型配置
|
||||||
|
'input_dim': 640,
|
||||||
|
'num_heads': 8,
|
||||||
|
'ffn_dim': 2048,
|
||||||
|
'num_layers': 8,
|
||||||
|
'dropout': 0.1,
|
||||||
|
|
||||||
|
# 保存和评估
|
||||||
|
'early_stopping_patience': 12,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_cer(pred: str, target: str) -> float:
    """Character error rate: edit distance between *pred* and *target* divided by ``len(target)``.

    For an empty target the result is 0.0 when the prediction is empty too
    and 1.0 otherwise.
    """
    if not target:
        return 1.0 if pred else 0.0

    # Rolling-row Levenshtein: `previous` holds the distances for the
    # prediction prefix processed so far against every target prefix.
    previous = list(range(len(target) + 1))
    for row_index, pred_char in enumerate(pred, start=1):
        current = [row_index]
        for col_index, target_char in enumerate(target, start=1):
            if pred_char == target_char:
                current.append(previous[col_index - 1])
            else:
                cheapest = min(previous[col_index], current[col_index - 1], previous[col_index - 1])
                current.append(cheapest + 1)
        previous = current

    return previous[-1] / len(target)
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_wer(pred: str, target: str) -> float:
    """Word error rate: word-level edit distance divided by the number of target words.

    Both strings are split on whitespace. For a target without words the
    result is 0.0 when the prediction has no words either and 1.0 otherwise.
    """
    hypothesis = pred.split()
    reference = target.split()

    if not reference:
        return 1.0 if hypothesis else 0.0

    # Rolling-row Levenshtein over words instead of characters.
    previous = list(range(len(reference) + 1))
    for row_index, hyp_word in enumerate(hypothesis, start=1):
        current = [row_index]
        for col_index, ref_word in enumerate(reference, start=1):
            if hyp_word == ref_word:
                current.append(previous[col_index - 1])
            else:
                cheapest = min(previous[col_index], current[col_index - 1], previous[col_index - 1])
                current.append(cheapest + 1)
        previous = current

    return previous[-1] / len(reference)
|
||||||
|
|
||||||
|
def train_one_epoch(model: ASRModel, dataloader: DataLoader, criterion: CTCLoss, optimizer: Optimizer, scheduler: OneCycleLR, device: device, epoch: int, writer: SummaryWriter, global_step: int, scaler: GradScaler) -> tuple[float, int]:
    """Run one pass over *dataloader* and update the model after every batch.

    Per batch: bfloat16-autocast forward pass, CTC loss, scaled backward
    pass, gradient clipping to ``CONFIG['grad_clip_norm']``, optimizer and
    scheduler step, then TensorBoard logging of loss and learning rate.

    Returns:
        ``(average training loss over the epoch, updated global_step)``.

    Raises:
        ZeroDivisionError: if *dataloader* yields no batches.
    """
    model.train()
    num_batches = 0
    total_train_loss = 0.0
    # Autocast follows the device the caller passed in instead of assuming CUDA.
    autocast_device = torch.device(device).type

    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch}")
    for batch in progress_bar:
        batch: Batch
        mel_specs = batch['mel_specs'].to(device)
        targets = batch['targets'].to(device)
        mel_lengths = batch['mel_lengths'].to(device)
        target_lengths = batch['target_lengths'].to(device)

        optimizer.zero_grad()

        with torch.autocast(device_type=autocast_device, dtype=torch.bfloat16):
            log_probs, lengths = model(mel_specs=mel_specs, mel_lengths=mel_lengths)
            log_probs: Tensor
            # The loss is called with time-major [T, B, vocab] log-probabilities.
            log_probs_ctc = log_probs.permute(1, 0, 2)
            loss: Tensor = criterion(log_probs=log_probs_ctc, targets=targets, input_lengths=lengths, target_lengths=target_lengths)

        scaler.scale(loss).backward()
        # Unscale before clipping so the norm is measured on the real gradients.
        scaler.unscale_(optimizer)
        utils.clip_grad_norm_(model.parameters(), max_norm=CONFIG['grad_clip_norm'])
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        # Read the loss back once per step and reuse the Python float.
        loss_value = loss.item()
        total_train_loss += loss_value
        num_batches += 1
        global_step += 1

        writer.add_scalar('Train/Loss', loss_value, global_step)
        writer.add_scalar('Train/LearningRate', optimizer.param_groups[0]['lr'], global_step)
        progress_bar.set_postfix({'epoch': f"{epoch}/{CONFIG['num_epochs']}", 'loss': f'{loss_value:.4f}', 'step': global_step})

    train_avg_loss = total_train_loss / num_batches
    return train_avg_loss, global_step
|
||||||
|
|
||||||
|
def validate(model: ASRModel, dataloader: DataLoader, criterion: CTCLoss, device: device, tokenizer: ASRTokenizer, writer: SummaryWriter, global_step: int) -> tuple[float, float, float]:
    """Evaluate the model on *dataloader* and log the results to TensorBoard.

    Computes the mean CTC loss per batch and the mean CER / WER per sample
    from greedy CTC decoding, and logs up to three example transcriptions.
    The model is switched back to training mode before returning.

    Returns:
        ``(average loss, average CER, average WER)``.

    Raises:
        ZeroDivisionError: if *dataloader* yields no batches.
    """
    model.eval()
    total_loss = 0.0
    total_cer = 0.0
    total_wer = 0.0
    num_samples = 0
    num_batches = 0
    examples = []
    # Autocast follows the device the caller passed in instead of assuming CUDA.
    autocast_device = torch.device(device).type

    with no_grad():
        progress_bar = tqdm(dataloader, desc="Validate", leave=False)
        for batch in progress_bar:
            batch: Batch
            mel_specs = batch['mel_specs'].to(device)
            targets = batch['targets'].to(device)
            mel_lengths = batch['mel_lengths'].to(device)
            target_lengths = batch['target_lengths'].to(device)

            with torch.autocast(device_type=autocast_device, dtype=torch.bfloat16):
                log_probs, lengths = model(mel_specs=mel_specs, mel_lengths=mel_lengths)
                log_probs: Tensor
                log_probs_ctc = log_probs.permute(1, 0, 2)
                loss: Tensor = criterion(log_probs=log_probs_ctc, targets=targets, input_lengths=lengths, target_lengths=target_lengths)

            total_loss += loss.item()

            # Decode only the valid frames of each sample. Frames past
            # lengths[i] come from batch padding and would otherwise add
            # spurious tokens to the prediction and inflate CER / WER.
            valid_frames = lengths.tolist()
            for i in range(log_probs.shape[0]):
                pred_text = tokenizer.ctc_greedy_decode(log_probs=log_probs[i, :valid_frames[i]])
                true_text = batch['target_texts'][i]
                cer = calculate_cer(pred=pred_text, target=true_text)
                wer = calculate_wer(pred=pred_text, target=true_text)
                total_cer += cer
                total_wer += wer
                num_samples += 1

                # Keep the first few samples as qualitative examples.
                if len(examples) < 3:
                    examples.append((true_text, pred_text, cer, wer))

            num_batches += 1

    avg_loss = total_loss / num_batches
    avg_cer = total_cer / num_samples
    avg_wer = total_wer / num_samples

    writer.add_scalar('Val/Loss', avg_loss, global_step)
    writer.add_scalar('Val/CER', avg_cer, global_step)
    writer.add_scalar('Val/WER', avg_wer, global_step)

    for index, (true_text, pred_text, cer, wer) in enumerate(examples):
        writer.add_text(f'Val/Example_{index}', f'True: {true_text}\nPred: {pred_text}\nCER: {cer:.4f} | WER: {wer:.4f}', global_step)

    model.train()
    return avg_loss, avg_cer, avg_wer
|
||||||
|
|
||||||
|
def save_checkpoint(model: ASRModel, optimizer: Optimizer, scheduler: OneCycleLR, global_step: int, epoch: int, train_loss: float, val_loss: float, cer: float, wer: float, save_path: Path):
    """Serialize the full training state to ``save_path``.

    The payload is first written to a sibling ``<name>.tmp`` file and then
    renamed over ``save_path``. An interrupted save therefore never leaves a
    truncated file under the final name, which matters because the newest
    ``checkpoint_epoch_*.pt`` is what a resumed run loads.

    Args:
        model: Model whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        scheduler: LR scheduler whose ``state_dict()`` is stored.
        global_step: Step counter stored alongside the weights.
        epoch: Epoch index stored alongside the weights.
        train_loss: Training loss recorded for this checkpoint.
        val_loss: Validation loss recorded for this checkpoint.
        cer: Validation CER recorded for this checkpoint.
        wer: Validation WER recorded for this checkpoint.
        save_path: Destination file.
    """
    save_path = Path(save_path)  # tolerate plain string paths as torch.save did
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'global_step': global_step,
        'train_loss': train_loss,
        'val_loss': val_loss,
        'epoch': epoch,
        'cer': cer,
        'wer': wer,
    }

    # The ".tmp" suffix keeps the partial file out of 'checkpoint_epoch_*.pt' globs.
    tmp_path = save_path.with_name(save_path.name + '.tmp')
    torch.save(checkpoint, tmp_path)
    tmp_path.replace(save_path)
|
||||||
|
|
||||||
|
def find_latest_checkpoint(checkpoint_dir: Path) -> Path | None:
|
||||||
|
checkpoints = sorted(checkpoint_dir.glob('checkpoint_epoch_*.pt'), key=lambda p: int(p.stem.split('_')[-1]))
|
||||||
|
return checkpoints[-1] if checkpoints else None
|
||||||
|
|
||||||
|
def load_checkpoint(file_path: Path, model: ASRModel, optimizer: AdamW, scheduler: OneCycleLR) -> tuple[int, int, float, float]:
    """Restore model, optimizer and scheduler state from a checkpoint file.

    Args:
        file_path: Checkpoint file containing the ``*_state_dict`` entries and
            the ``epoch`` / ``global_step`` / ``cer`` / ``wer`` values.
        model: Model that receives ``model_state_dict``.
        optimizer: Optimizer that receives ``optimizer_state_dict``.
        scheduler: Scheduler that receives ``scheduler_state_dict``.

    Returns:
        ``(epoch, global_step, cer, wer)`` exactly as stored in the file. The
        two metrics are the values recorded with this checkpoint.
    """
    # Deserialize onto the CPU so the file can be read regardless of which
    # device it was saved from; the load_state_dict calls below then copy the
    # values into the already-constructed objects.
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoint files from a trusted source.
    checkpoint = torch.load(file_path, map_location='cpu', weights_only=False)

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    return (
        checkpoint['epoch'],
        checkpoint['global_step'],
        checkpoint['cer'],
        checkpoint['wer'],
    )
|
||||||
|
|
||||||
|
def _best_metric_on_disk(checkpoint_path: Path, key: str, fallback: float) -> float:
    """Return the smaller of ``fallback`` and the metric ``key`` stored in ``checkpoint_path``.

    ``fallback`` is returned unchanged when the file does not exist.
    """
    if not checkpoint_path.exists():
        return fallback
    # NOTE(security): weights_only=False unpickles arbitrary objects; these
    # files are expected to be ones this script wrote itself.
    stored = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    return min(fallback, stored[key])


def main():
    """Train the ASR model with CTC loss, per-epoch validation and checkpointing.

    Per epoch: train, validate, write ``checkpoint_epoch_<epoch>.pt``, refresh
    ``best_cer_model.pt`` / ``best_wer_model.pt`` when the corresponding metric
    improves, keep only the three newest epoch checkpoints, and stop early once
    neither metric has improved for ``CONFIG['early_stopping_patience']``
    consecutive epochs. If an epoch checkpoint already exists in
    ``.checkpoints`` training resumes from the epoch after it.
    """
    workspace_dir = Path(__file__).parent.parent
    device = torch.device('cuda:0')
    tokenizer = ASRTokenizer(workspace_dir / 'config/asr_vocab.json')

    # ============ Data loaders ============
    train_loader = create_dataloader(
        tsv_path=workspace_dir / '.data/train.tsv',
        audio_dir=workspace_dir / '.data/clips',
        tokenizer=tokenizer,
        batch_size=CONFIG['batch_size'],
        shuffle=True,
    )

    val_loader = create_dataloader(
        tsv_path=workspace_dir / '.data/dev.tsv',
        audio_dir=workspace_dir / '.data/clips',
        tokenizer=tokenizer,
        batch_size=CONFIG['batch_size'],
        shuffle=False,
        augment=False,
    )

    # ============ Model ============
    model = ASRModel(
        vocab_size=tokenizer.vocab_size(),
        input_dim=CONFIG['input_dim'],
        num_heads=CONFIG['num_heads'],
        ffn_dim=CONFIG['ffn_dim'],
        num_layers=CONFIG['num_layers'],
        dropout=CONFIG['dropout'],
    ).to(device)
    print(f"🤖 模型参数量: {model.get_num_params() / 1e6:.2f}M")

    criterion = CTCLoss(blank=tokenizer.get_special_token_id('<BLANK>'), zero_infinity=True)
    optimizer = AdamW(model.parameters(), lr=CONFIG['learning_rate'], weight_decay=CONFIG['weight_decay'], foreach=True)

    # OneCycleLR takes care of both the warmup and the decay phase.
    scheduler = OneCycleLR(
        optimizer,
        max_lr=CONFIG['learning_rate'],
        epochs=CONFIG['num_epochs'],
        steps_per_epoch=len(train_loader),
        pct_start=0.1,  # first 10% of the steps are warmup
        anneal_strategy='cos',
        div_factor=25.0,  # initial lr = max_lr / 25
        final_div_factor=1e4,  # final lr = initial lr / 1e4 (= max_lr / 25 / 1e4)
    )
    scaler = GradScaler()

    # ============ TensorBoard ============
    log_dir = workspace_dir / 'runs' / datetime.now().strftime('%Y%m%d_%H%M%S')
    writer = SummaryWriter(log_dir)

    config_text = '\n'.join([f'{k}: {v}' for k, v in CONFIG.items()])
    writer.add_text('Config', config_text, 0)
    print(f"📊 TensorBoard 日志: {log_dir}")

    checkpoint_dir = workspace_dir / '.checkpoints'
    checkpoint_dir.mkdir(exist_ok=True)
    best_cer_path = checkpoint_dir / 'best_cer_model.pt'
    best_wer_path = checkpoint_dir / 'best_wer_model.pt'

    # ============ Training loop ============
    best_cer = float('inf')
    best_wer = float('inf')
    patience_counter = 0
    global_step = 0
    start_epoch = 0

    latest = find_latest_checkpoint(checkpoint_dir)
    if latest:
        start_epoch, global_step, best_cer, best_wer = load_checkpoint(latest, model, optimizer, scheduler)
        # The epoch checkpoint stores that epoch's metrics, not the best seen
        # so far. Recover the true best values from the best-model files so a
        # later, merely-better-than-last epoch cannot overwrite them.
        best_cer = _best_metric_on_disk(best_cer_path, 'cer', best_cer)
        best_wer = _best_metric_on_disk(best_wer_path, 'wer', best_wer)
        start_epoch += 1  # continue with the epoch after the saved one
        print(f"✅ 恢复训练: 从 epoch={start_epoch} 开始, global_step={global_step}")
    else:
        print("🆕 从头开始训练\n")

    try:
        for epoch in range(start_epoch, CONFIG['num_epochs']):
            train_loss, global_step = train_one_epoch(
                model=model,
                dataloader=train_loader,
                criterion=criterion,
                optimizer=optimizer,
                scheduler=scheduler,
                device=device,
                epoch=epoch,
                writer=writer,
                global_step=global_step,
                scaler=scaler,
            )

            val_loss, val_cer, val_wer = validate(model=model, dataloader=val_loader, criterion=criterion, device=device, tokenizer=tokenizer, writer=writer, global_step=global_step)
            print(f"\n📊 Step {global_step} | Val Loss: {val_loss:.4f} | Val CER: {val_cer:.4f} | Val WER: {val_wer:.4f}")

            # Regular per-epoch checkpoint.
            checkpoint_path = checkpoint_dir / f'checkpoint_epoch_{epoch}.pt'
            save_checkpoint(model=model, optimizer=optimizer, scheduler=scheduler, epoch=epoch, global_step=global_step, train_loss=train_loss, val_loss=val_loss, cer=val_cer, wer=val_wer, save_path=checkpoint_path)

            # Best-CER and best-WER models are tracked and saved separately.
            improved = False
            if val_cer < best_cer:
                best_cer = val_cer
                save_checkpoint(model=model, optimizer=optimizer, scheduler=scheduler, epoch=epoch, global_step=global_step, train_loss=train_loss, val_loss=val_loss, cer=val_cer, wer=val_wer, save_path=best_cer_path)
                improved = True

            if val_wer < best_wer:
                best_wer = val_wer
                save_checkpoint(model=model, optimizer=optimizer, scheduler=scheduler, epoch=epoch, global_step=global_step, train_loss=train_loss, val_loss=val_loss, cer=val_cer, wer=val_wer, save_path=best_wer_path)
                improved = True

            # Early-stopping bookkeeping: an improvement in either metric resets the counter.
            if improved:
                patience_counter = 0
            else:
                patience_counter += 1
                print(f"⚠️ 验证指标未改善,patience: {patience_counter}/{CONFIG['early_stopping_patience']}")

            # Drop old epoch checkpoints, keeping the three most recent.
            old_checkpoints = sorted(checkpoint_dir.glob('checkpoint_epoch_*.pt'), key=lambda p: int(p.stem.split('_')[-1]))
            for old in old_checkpoints[:-3]:
                old.unlink()

            cuda.empty_cache()

            writer.add_scalar('Val/EpochLoss', val_loss, epoch)
            writer.add_scalar('Train/EpochLoss', train_loss, epoch)
            print(f"✅ Epoch {epoch} 完成 | Train Avg Loss: {train_loss:.4f} | Val Avg Loss: {val_loss:.4f} | Best CER: {best_cer:.4f} | Best WER: {best_wer:.4f}\n")

            # Early-stopping check.
            if patience_counter >= CONFIG['early_stopping_patience']:
                print(f"\n🛑 Early Stopping: 验证指标连续 {CONFIG['early_stopping_patience']} 次没有改善")
                print(f"🏆 最佳 CER: {best_cer:.4f}")
                print(f"🏆 最佳 WER: {best_wer:.4f}")
                break
    finally:
        # Close the writer even when training is interrupted or raises.
        writer.close()

    print(f"\n{'='*60}")
    print("🎉 训练完成!")
    print(f"🏆 最佳 CER: {best_cer:.4f}")
    print(f"🏆 最佳 WER: {best_wer:.4f}")
    print(f"📊 TensorBoard: tensorboard --logdir={workspace_dir / 'runs'}")
    print(f"{'='*60}\n")
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: start training only when this file is executed directly.
if __name__ == "__main__":
    main()
|
||||||
Reference in New Issue
Block a user