feat: change waveform
This commit is contained in:
13
src/handle/export_model_state_dict.py
Normal file
13
src/handle/export_model_state_dict.py
Normal file
@@ -0,0 +1,13 @@
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
device = "cuda:0"
|
||||
workspace_dir = Path(__file__).parent.parent.parent
|
||||
|
||||
input_checkpoint = workspace_dir.joinpath('.checkpoints/best_wer_model.pt')
|
||||
output_checkpoint = workspace_dir.joinpath('.checkpoints/prodect_best_wer_model.pt')
|
||||
|
||||
checkpoint = torch.load(input_checkpoint, map_location=device)
|
||||
torch.save(checkpoint['model_state_dict'], output_checkpoint)
|
||||
Reference in New Issue
Block a user