Skip to content

Commit

Permalink
UPD: refactor conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
vqdang committed Dec 18, 2020
1 parent 12ed12b commit baef4d1
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 727 deletions.
74 changes: 15 additions & 59 deletions convert_chkpt_tf2pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,69 +8,25 @@
import numpy as np
import torch

model_mode = "Pytorch-Fast" # or 'Pytorch-Fast' vs 'Pytorch-Original'
mapping = pd.read_csv("tf_to_pytorch_variable_mapping.csv", sep="\t", index_col=False)
# mapping = {v["Tensorflow"]: v[model_mode] for k, v in mapping.T.to_dict().items()}
# mapping

# tf_path = "../pretrained/ImageNet-ResNet50-Preact.npz"
# pt_path = "../pretrained/ImageNet-ResNet50-Preact_pytorch.tar"

# tf_path = '../pretrained/hover_seg_Kumar.npz'
# pt_path = "../pretrained/hovernet_original_kumar_pytorch.tar"

# tf_path = '../pretrained/hover_seg_CoNSeP.npz'
# pt_path = "../pretrained/hovernet_original_consep_pytorch.tar"

# tf_path = '../pretrained/hover_seg_CPM17.npz'
# pt_path = "../pretrained/hovernet_original_cpm17_pytorch.tar"

# tf_path = "../pretrained/hover_seg_&_class_CoNSeP.npz"
# pt_path = "../pretrained/hovernet_original_consep_type-pytorch.tar"

# tf_path = "../pretrained/pecan-hover-net.npz"
# pt_path = "../pretrained/hovernet_fast_pannuke_pytorch.tar"

# tf_path = "../pretrained/hovernet_fast_monusac_tf.npz"
# pt_path = "../pretrained/hovernet_fast_monusac_pytorch.tar"
mapping = pd.read_csv("variables_tf2pytorch.csv", sep="\t", index_col=False)
mapping = {v["Tensorflow"]: v['Pytorch'] for k, v in mapping.T.to_dict().items()}

# pt_path = 'dumped_pytorch_chkpt.tar'

# pt = {}
# tf = np.load(tf_path)

# for tf_k, tf_v in tf.items():
# if 'linear' in tf_k: continue # should only be for pretrained model
# pt_k = mapping[tf_k]
# if "conv" in pt_k and "bn" not in pt_k and "bias" not in pt_k:
# tf_v = np.transpose(tf_v, [3, 2, 0, 1])
# if "shortcut" in pt_k:
# tf_v = np.transpose(tf_v, [3, 2, 0, 1])
# pt[pt_k] = torch.from_numpy(tf_v)
# # make compatible with repo structure
# pt["upsample2x.unpool_mat"] = torch.from_numpy(np.ones((2, 2), dtype="float32"))
# pt = {"desc": pt}
# torch.save(pt, pt_path)

# pt_path = "../pretrained/consep_net_epoch=40.npz"

tf_path = "../pretrained/consep_net_epoch=40.tar"
pt_path = "../pretrained/hovernet_original_consep_native.tar"

# tf_path = "../pretrained/kumar_net_epoch=44.tar"
# pt_path = "../pretrained/hovernet_original_kumar_native.tar"

mapping = {v["Pytorch-Original"]: v['Pytorch-Fast'] for k, v in mapping.T.to_dict().items()}
# mapping
tf_path = "" # to original tensorflow chkpt ends with .npz
pt_path = "" # to convert pytorch chkpt ends with .tar

pt = {}
ptx = torch.load(tf_path)['desc']
for tf_k, tf_v in ptx.items():
if 'num_batches_tracked' in tf_k: continue
if tf_k == 'upsample2x.unpool_mat':
pt[tf_k] = tf_v
continue
tf = np.load(tf_path)

for tf_k, tf_v in tf.items():
if 'linear' in tf_k: continue # should only be for pretrained model
pt_k = mapping[tf_k]
pt[pt_k] = tf_v
if "conv" in pt_k and "bn" not in pt_k and "bias" not in pt_k:
tf_v = np.transpose(tf_v, [3, 2, 0, 1])
if "shortcut" in pt_k:
tf_v = np.transpose(tf_v, [3, 2, 0, 1])
pt[pt_k] = torch.from_numpy(tf_v)
# make compatible with repo structure
pt["upsample2x.unpool_mat"] = torch.from_numpy(np.ones((2, 2), dtype="float32"))
pt = {"desc": pt}
torch.save(pt, pt_path)
Loading

0 comments on commit baef4d1

Please sign in to comment.