This tutorial shows a quick recipe for turning a PyTorch checkpoint file trained in Python 2.X into a Python 3.X compatible format. It resolves an error message similar to the following, which appears when you call torch.load() on such a file in Python 3.X:
UnicodeDecodeError: 'ascii' codec can't decode byte 0x8c in position 16: ordinal not in range(128)
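To see where this error comes from, here is a small sketch that uses only the standard library pickle module. Python 2 str objects are pickled as 8-bit strings, and Python 3's unpickler decodes them as text (ASCII by default for the pickle module), so any byte above 0x7f fails; passing encoding='bytes' keeps them as bytes instead. Since torch.load is built on pickle, this is presumably the same mechanism behind the message above. PY2_PICKLE is a hand-assembled, hypothetical approximation of a Python 2 protocol-2 pickle, not data taken from the DSB2017 checkpoint.

```python
import pickle

# Hypothetical bytes approximating a Python 2 pickle (protocol 2) of an 8-bit
# str holding the non-ASCII byte 0x8c: PROTO 2, SHORT_BINSTRING (length 2), STOP.
PY2_PICKLE = b"\x80\x02U\x02\x8c\x01."


def load_default(data):
    """Unpickle with Python 3's defaults (encoding='ASCII', errors='strict')."""
    return pickle.loads(data)


def load_as_bytes(data):
    """Unpickle, keeping Python 2 8-bit str objects as bytes."""
    return pickle.loads(data, encoding="bytes")


try:
    load_default(PY2_PICKLE)
except UnicodeDecodeError as exc:
    print(exc)  # 'ascii' codec can't decode byte 0x8c in position 0: ordinal not in range(128)

print(load_as_bytes(PY2_PICKLE))  # b'\x8c\x01'
```

The encoding='bytes' keyword is also what the Python 3.X step of the full conversion relies on; the trade-off is that every dictionary key then arrives as bytes and has to be decoded afterwards.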
Load and save state_dict
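In the following example, we use the winning model of the Kaggle Data Science Bowl 2017 for demonstration purposes; it can be found at https://github.com/lfz/DSB2017/tree/master/model. The first step runs in Python 2.X: load the full checkpoint there and save only its state_dict.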
```python
# Download the checkpoint trained in Python 2.X
!wget "https://github.com/lfz/DSB2017/blob/master/model/classifier.ckpt?raw=true" -O classifier.ckpt
!ls
# Install PyTorch
!pip install torch torchvision
import torch
# Load the full checkpoint (this works because we are still in Python 2.X).
filename = 'classifier.ckpt'
checkpoint = torch.load(filename)
# Only save the `state_dict` object.
torch.save(checkpoint['state_dict'], 'classifier_state_dict.ckpt')
# Download the file if you are using Google Colab.
from google.colab import files
files.download('classifier_state_dict.ckpt')
```
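Then, in your Python 3.X environment, load the state_dict with the load_state_dict method of casenet, the model object, which is a subclass of torch.nn.Module.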
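```python
import torch
# Assumes net_classifier.py from the DSB2017 repository is on the import path.
import net_classifier as casemodel
# Build the model.
casenet = casemodel.CaseNet(topk=5)
# Load the state_dict saved from Python 2.X and copy it into the model.
state_dict = torch.load('./model/classifier_state_dict.ckpt')
casenet.load_state_dict(state_dict)
```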
Optionally, you can convert the entire checkpoint file to be Python 3.X compatible.
1. In Python 2.X, load the checkpoint and pickle it to a binary file.
2. In Python 3.X, load the pickled checkpoint with encoding='bytes'.
3. Iteratively decode all bytes dictionary keys to str.
Here is a complete example showing how it is done.
```python
"""
Do this from Python 2.X
"""
import pickle

import torch

filename = 'classifier.ckpt'
checkpoint = torch.load(filename)
# Pickle the checkpoint to a binary file in Python 2.X.
with open("classifier_py2.pkl", "wb") as outfile:
    pickle.dump(checkpoint, outfile)

"""
Do the following from Python 3.X
"""
import pickle

import torch

# Load the pickle file in Python 3.X, keeping Python 2 str objects as bytes.
with open('classifier_py2.pkl', 'rb') as f:
    data_dict = pickle.load(f, encoding='bytes')
# View the keys; this prints bytes.
print(data_dict.keys())
# Convert the first-level keys into a new plain dict. Building a new dict
# avoids adding and removing keys while iterating over the same dict.
data_dict = {
    (key.decode() if isinstance(key, bytes) else key): value
    for key, value in data_dict.items()
}
# This should print strs.
print(data_dict.keys())
# Convert the second-level 'state_dict' keys the same way.
data_dict['state_dict'] = {
    (key.decode() if isinstance(key, bytes) else key): value
    for key, value in data_dict['state_dict'].items()
}
torch.save(data_dict, "./model/checkpoint_py3.ckpt")
# This should load without the UnicodeDecodeError.
checkpoint = torch.load("./model/checkpoint_py3.ckpt")
```
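The conversion above only handles two levels: the top-level keys and the keys of 'state_dict'. If a checkpoint stores other nested dicts, a recursive helper is one option. The sketch below is a hypothetical decode_keys function, not part of PyTorch; it assumes the keys are ASCII/UTF-8 (true for typical parameter names), and the checkpoint contents in the demo are made up.

```python
from collections import OrderedDict


def decode_keys(obj):
    """Hypothetical helper: return a copy of obj with every bytes dict key decoded to str.

    Recurses into dict values, lists and plain tuples; everything else
    (tensors, numbers, bytes *values*) is returned unchanged.
    Assumes keys are ASCII/UTF-8 encoded.
    """
    if isinstance(obj, dict):
        pairs = [
            (k.decode() if isinstance(k, bytes) else k, decode_keys(v))
            for k, v in obj.items()
        ]
        # Keep OrderedDict (the usual state_dict type) as OrderedDict; other dicts become plain dicts.
        return OrderedDict(pairs) if isinstance(obj, OrderedDict) else dict(pairs)
    if isinstance(obj, list):
        return [decode_keys(v) for v in obj]
    if type(obj) is tuple:
        return tuple(decode_keys(v) for v in obj)
    return obj


# Made-up stand-in for what pickle.load(f, encoding='bytes') might return.
raw = {
    b"epoch": 100,
    b"state_dict": OrderedDict([(b"fc.weight", [0.5]), (b"fc.bias", [0.1])]),
}
clean = decode_keys(raw)
print(list(clean))                # ['epoch', 'state_dict']
print(list(clean["state_dict"]))  # ['fc.weight', 'fc.bias']
```

In the Python 3.X script, data_dict = decode_keys(data_dict) could replace both conversion steps. Unlike the two-level version, it keeps 'state_dict' as an OrderedDict, which load_state_dict accepts either way.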
Here is the Python 2.x notebook on Google Colab for your convenience.