How to load a Python 2 PyTorch checkpoint in Python 3



This tutorial shows a quick recipe for converting a PyTorch checkpoint file saved in Python 2.X into a Python 3.X compatible format. It resolves errors like the following, which appear when you call torch.load() on such a file in Python 3.X:

UnicodeDecodeError: 'ascii' codec can't decode byte 0x8c in position 16: ordinal not in range(128)
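The error comes from how pickle treats Python 2 `str` objects. In Python 2 they are raw byte strings, and Python 3's unpickler decodes them to `str` using ASCII by default, so any byte above 127 fails. The stdlib-only sketch below reproduces the same kind of error using a tiny hand-written protocol-2 pickle (an assumption standing in for a real checkpoint) that holds the single byte 0x8c. It also shows the two `encoding` options that avoid the error.

```python
import pickle

# Hand-written protocol-2 pickle of a Python 2 str holding one non-ASCII byte:
# PROTO 2, SHORT_BINSTRING of length 1 containing b'\x8c', then STOP.
PY2_PICKLE = b"\x80\x02U\x01\x8c."


def try_load(data, **kwargs):
    """Unpickle data, returning the result or the UnicodeDecodeError raised."""
    try:
        return pickle.loads(data, **kwargs)
    except UnicodeDecodeError as err:
        return err


# Default encoding is ASCII, so byte 0x8c cannot be decoded.
print(try_load(PY2_PICKLE))
# encoding='bytes' keeps Python 2 str objects as Python 3 bytes.
print(try_load(PY2_PICKLE, encoding="bytes"))
# encoding='latin1' maps every byte 0-255 to one str character.
print(try_load(PY2_PICKLE, encoding="latin1"))
```

Step 3 below relies on `encoding='bytes'`. The torch.load documentation also describes passing an `encoding` argument through to pickle for Python 2 files; check whether your PyTorch version supports it.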

Step 1

Load and save the state_dict in Python 2.X.

In the following example, we use the winning model of the Kaggle Data Science Bowl 2017 for demonstration purposes. It can be found at https://github.com/lfz/DSB2017/tree/master/model.

# Download the checkpoint trained in Python 2.X
!wget https://github.com/lfz/DSB2017/blob/master/model/classifier.ckpt?raw=true -O classifier.ckpt
!ls
# Install PyTorch
!pip install torch torchvision

import torch
# Load the checkpoint.
filename = 'classifier.ckpt'
checkpoint = torch.load(filename)
# Only save the `state_dict` object.
torch.save(checkpoint['state_dict'], 'classifier_state_dict.ckpt')

# Download the file if you are using Google Colab.
from google.colab import files
files.download('classifier_state_dict.ckpt')

Step 2

Call load_state_dict in Python 3.X.

casenet is an instance of a torch.nn.Module subclass, created from the net_classifier module of the DSB2017 repository.

import torch
import net_classifier as casemodel

# Build the model.
casenet = casemodel.CaseNet(topk=5)
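# Load the state_dict saved in Step 1 (this path assumes the file was moved into ./model/).
state_dict = torch.load('./model/classifier_state_dict.ckpt')
# Copy the saved parameters into the model.
casenet.load_state_dict(state_dict)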
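load_state_dict expects the checkpoint's key names to match the model's parameter names exactly. If it complains, comparing the two key sets shows what differs. The helper below is a hypothetical, stdlib-only sketch; with PyTorch you would pass it `casenet.state_dict().keys()` and `state_dict.keys()`. The parameter names in the usage example are made up.

```python
def compare_state_dict_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists.

    missing: keys the model expects but the checkpoint lacks.
    unexpected: keys in the checkpoint that the model does not have.
    Sorting uses repr() so a mix of str and bytes keys cannot raise TypeError.
    """
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys, key=repr)
    unexpected = sorted(checkpoint_keys - model_keys, key=repr)
    return missing, unexpected


# Hypothetical key names; one key is still bytes, as left over from Python 2.
model = ["fc1.weight", "fc1.bias", "fc2.weight"]
ckpt = ["fc1.weight", "fc1.bias", b"fc2.weight"]
missing, unexpected = compare_state_dict_keys(model, ckpt)
print("missing:", missing)
print("unexpected:", unexpected)
```

A bytes key such as `b"fc2.weight"` showing up as unexpected is a sign that the keys still need decoding, as in the optional step below.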

Optional step 3

Optionally, you can convert the entire checkpoint file to be Python 3.X compatible.

1. In Python 2.X, load the checkpoint file and dump it to a file with pickle.

2. In Python 3.X, load the pickled checkpoint with encoding='bytes', so Python 2 strings come back as bytes instead of failing to decode.

3. Decode every bytes dictionary key to str, level by level.
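Step 3 of the list can be generalized. Instead of converting a fixed number of levels by hand, a recursive function can walk the whole structure and decode every bytes key it finds. This is a sketch that assumes the keys are UTF-8 (or plain ASCII) bytes; `decode_keys` is a hypothetical name, and values are never decoded, so tensors and other data are left as they are.

```python
from collections import OrderedDict


def decode_keys(obj, encoding="utf-8"):
    """Recursively convert bytes dict keys to str.

    An OrderedDict stays an OrderedDict, so a state_dict keeps its parameter
    order; other dicts become plain dicts. Lists and tuples are walked so that
    dicts nested inside them are converted too. Every other value is returned
    unchanged.
    """
    if isinstance(obj, dict):
        converted = (
            (key.decode(encoding) if isinstance(key, bytes) else key,
             decode_keys(value, encoding))
            for key, value in obj.items()
        )
        if isinstance(obj, OrderedDict):
            return OrderedDict(converted)
        return dict(converted)
    if isinstance(obj, list):
        return [decode_keys(item, encoding) for item in obj]
    if isinstance(obj, tuple):
        return tuple(decode_keys(item, encoding) for item in obj)
    return obj


# Hypothetical checkpoint shaped like one unpickled with encoding='bytes'.
py2_style = {
    b"epoch": 10,
    b"state_dict": OrderedDict([(b"fc.weight", [0.1]), (b"fc.bias", [0.0])]),
}
converted = decode_keys(py2_style)
print(converted.keys())
print(list(converted["state_dict"]))
```

With a helper like this, the per-level conversion in the complete example could be replaced by a single call, `data_dict = decode_keys(data_dict)`.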

Here is a complete example to show how it is done.

"""
Do this from Python 2.X

"""

import torch
filename = 'classifier.ckpt'
checkpoint = torch.load(filename)

# Pickle the checkpoint with protocol 2, the highest (binary) pickle protocol
# Python 2 supports; Python 3 can read it.
import pickle
with open("classifier_py2.pkl", "wb") as outfile:
    pickle.dump(checkpoint, outfile, protocol=2)

"""
Do the following from Python 3.X

"""

# Load the pickle file in Python 3.X.
import pickle
from collections import OrderedDict

import torch

with open('classifier_py2.pkl', 'rb') as f:
    # encoding='bytes' keeps Python 2 str objects (including dict keys) as bytes.
    data_dict = pickle.load(f, encoding='bytes')

# View the keys; they are printed as bytes, e.g. b'state_dict'.
print(data_dict.keys())


def decode_key(key):
    """Decode a bytes key to str; leave other keys unchanged."""
    return key.decode() if isinstance(key, bytes) else key


# Convert the first-level keys into a plain dict. Building a new dict avoids
# popping keys while iterating, which raises "RuntimeError: dictionary
# changed size during iteration" in Python 3.
data_dict = {decode_key(key): value for key, value in data_dict.items()}

# This should print strs.
print(data_dict.keys())

# Convert the second-level 'state_dict' keys, keeping the parameter order.
data_dict['state_dict'] = OrderedDict(
    (decode_key(key), value) for key, value in data_dict['state_dict'].items()
)

torch.save(data_dict, "./model/checkpoint_py3.ckpt")

# Loading the converted file should now work without errors.
checkpoint = torch.load("./model/checkpoint_py3.ckpt")
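The example converts only the first level and the 'state_dict' level. If your checkpoint nests further dicts (an optimizer state is one possible example; this is an assumption about your file), bytes keys can survive there. A small hypothetical checker can list where they remain, so you can run it on `data_dict` before calling torch.save:

```python
def find_bytes_keys(obj, path=()):
    """Return the paths (tuples of keys/indices) of every bytes dict key in obj."""
    found = []
    if isinstance(obj, dict):
        for key, value in obj.items():
            if isinstance(key, bytes):
                found.append(path + (key,))
            # Recurse into the value whether or not its own key was bytes.
            found.extend(find_bytes_keys(value, path + (key,)))
    elif isinstance(obj, (list, tuple)):
        for index, item in enumerate(obj):
            found.extend(find_bytes_keys(item, path + (index,)))
    return found


# Hypothetical, partially converted checkpoint.
partial = {
    "state_dict": {"fc.weight": [0.1]},
    "optimizer": {b"state": {}, "param_groups": [{b"lr": 0.01}]},
}
for location in find_bytes_keys(partial):
    print(location)
```

An empty result means every dict key in the structure is already a str.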

Here is the Python 2.x notebook on Google Colab for your convenience.
