"""Strip a full training checkpoint down to just its model weights.

Reads the training checkpoint ``.checkpoints/best_wer_model.pt`` under the
workspace root (three directory levels above this file), and writes only
its ``model_state_dict`` entry to ``.checkpoints/prodect_best_wer_model.pt``.
"""
from pathlib import Path

import torch

# Prefer the GPU when present, but fall back to CPU so the conversion also
# works on CPU-only machines (a weights-only extraction never needs a GPU).
device = "cuda:0" if torch.cuda.is_available() else "cpu"

workspace_dir = Path(__file__).parent.parent.parent
input_checkpoint = workspace_dir.joinpath('.checkpoints/best_wer_model.pt')
# NOTE(review): "prodect" looks like a typo (product/production?); the name is
# kept as-is because other code may load this exact path — confirm.
output_checkpoint = workspace_dir.joinpath('.checkpoints/prodect_best_wer_model.pt')


def extract_model_state_dict(src, dst, map_location=device):
    """Save only the ``model_state_dict`` entry of checkpoint *src* to *dst*.

    Args:
        src: Path of the full checkpoint (a dict saved with ``torch.save``).
        dst: Path the extracted state dict is written to.
        map_location: Passed to ``torch.load``. Defaults to the module-level
            ``device`` ("cuda:0" when CUDA is available, otherwise "cpu").

    Returns:
        The extracted state dict object that was saved.

    Raises:
        KeyError: If the loaded checkpoint is not a dict containing a
            ``'model_state_dict'`` key. Nothing is written in that case.
    """
    # torch.load unpickles the file: only run this on trusted checkpoints.
    checkpoint = torch.load(src, map_location=map_location)
    if not isinstance(checkpoint, dict) or 'model_state_dict' not in checkpoint:
        raise KeyError(f"{src} does not contain a 'model_state_dict' entry")
    state_dict = checkpoint['model_state_dict']
    torch.save(state_dict, dst)
    return state_dict


def main():
    """Convert the workspace's best-WER checkpoint into a weights-only file."""
    extract_model_state_dict(input_checkpoint, output_checkpoint)


if __name__ == "__main__":
    main()