-
Notifications
You must be signed in to change notification settings - Fork 2
/
export.py
114 lines (97 loc) · 3.08 KB
/
export.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import argparse
import warnings
warnings.filterwarnings("ignore", module="onnxconverter_common.float16")
import onnx
import torch
from onnxconverter_common import float16
from DocShadow.models import DocShadow
from DocShadow.utils import load_checkpoint
def parse_args() -> argparse.Namespace:
    """Define and parse the command-line options for the ONNX export script."""
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--img_size",
        nargs=2,
        type=int,
        default=[256, 256],
        required=False,
        help="Sample image size for ONNX tracing. Please provide two integers (height width). Ensure that you have enough memory to run the export.",
    )
    cli.add_argument(
        "--weights",
        type=str,
        default="sd7k",
        required=False,
        help="DocShadow has been trained on these datasets: ['sd7k', 'jung', 'kliger']. Defaults to 'sd7k' weights. You can also specify a local path to the weights.",
    )
    cli.add_argument(
        "--onnx_path",
        type=str,
        default=None,
        required=False,
        help="Path to save the exported ONNX model.",
    )
    # The remaining options are all plain boolean switches; register them
    # from a small table instead of repeating the store_true boilerplate.
    for flag, description in (
        ("--dynamic_img_size", "Whether to allow dynamic image sizes."),
        ("--dynamic_batch", "Whether to allow dynamic batch size."),
        ("--fp16", "Whether to also export float16 (half) ONNX model (CUDA only)."),
    ):
        cli.add_argument(flag, action="store_true", help=description)
    return cli.parse_args()
def export_onnx(
    img_size=(256, 256),
    weights="sd7k",
    onnx_path=None,
    dynamic_img_size=False,
    dynamic_batch=False,
    fp16=False,
):
    """Trace DocShadow with a dummy image and export it to ONNX.

    Args:
        img_size: (height, width) of the dummy tracing input. Fix: the
            original default was a mutable list ``[256, 256]``; a tuple
            unpacks identically and avoids the shared-mutable-default
            pitfall.
        weights: Dataset tag ('sd7k', 'jung', 'kliger') or a local path to
            a checkpoint, forwarded to ``load_checkpoint``.
        onnx_path: Output file path. When None, a path is derived from the
            weights tag (and the static size, unless dynamic_img_size).
        dynamic_img_size: Mark the height/width axes as dynamic in the graph.
        dynamic_batch: Mark the batch axis as dynamic in the graph.
        fp16: Also write a float16-converted copy via ``convert_fp16``.
    """
    # Handle args.
    H, W = img_size
    if onnx_path is None:
        # Embed the static HxW in the filename only for fixed-size exports.
        onnx_path = (
            f"weights/docshadow_{weights}"
            f"{f'_{H}x{W}' if not dynamic_img_size else ''}"
            ".onnx"
        )

    # Load inputs and models.
    device = torch.device("cpu")  # Device on which to export.
    img = torch.rand(1, 3, H, W, dtype=torch.float32, device=device)
    docshadow = DocShadow()
    load_checkpoint(docshadow, weights, device)
    docshadow.eval().to(device)

    # Export. Dynamic axes are opt-in per flag; an empty mapping for a
    # tensor name leaves every axis static.
    opset_version = 12
    dynamic_axes = {"image": {}, "result": {}}
    if dynamic_batch:
        dynamic_axes["image"].update({0: "batch_size"})
        dynamic_axes["result"].update({0: "batch_size"})
    if dynamic_img_size:
        dynamic_axes["image"].update({2: "height", 3: "width"})
        dynamic_axes["result"].update({2: "height", 3: "width"})
    torch.onnx.export(
        docshadow,
        img,
        onnx_path,
        input_names=["image"],
        output_names=["result"],
        opset_version=opset_version,
        dynamic_axes=dynamic_axes,
    )

    if fp16:
        convert_fp16(onnx_path)
def convert_fp16(onnx_model_path: str):
    """Load an ONNX model, convert its floats to float16, and save the
    result next to the original with a ``_fp16`` suffix."""
    fp16_path = onnx_model_path.replace(".onnx", "_fp16.onnx")
    model_fp32 = onnx.load(onnx_model_path)
    model_fp16 = float16.convert_float_to_float16(model_fp32)
    onnx.save(model_fp16, fp16_path)
if __name__ == "__main__":
    # Forward every parsed CLI option straight into the export routine.
    cli_args = parse_args()
    export_onnx(**vars(cli_args))