13 lines
398 B
Python
13 lines
398 B
Python
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
|
|
# Extract the 'model_state_dict' entry from a saved checkpoint and write it
# out as a standalone file next to the original.

# Prefer the first GPU, but fall back to CPU so the script also runs on
# machines without CUDA (torch.load would otherwise fail when remapping
# tensors onto a device that does not exist).
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Three directory levels above this file.
# NOTE(review): presumably the repository/workspace root — confirm.
workspace_dir = Path(__file__).parent.parent.parent

input_checkpoint = workspace_dir.joinpath('.checkpoints/best_wer_model.pt')
# NOTE(review): 'prodect' looks like a typo (product/production?); the name is
# kept unchanged because other tooling may rely on this exact filename.
output_checkpoint = workspace_dir.joinpath('.checkpoints/prodect_best_wer_model.pt')

# NOTE(review): torch.load unpickles the file; only run this on checkpoints
# from a trusted source.
checkpoint = torch.load(input_checkpoint, map_location=device)

# Persist only the 'model_state_dict' entry; any other entries in the
# checkpoint are not written to the output file.
torch.save(checkpoint['model_state_dict'], output_checkpoint)