Skip to content

Commit

Permalink
adjusting
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhangAoCanada committed May 23, 2022
1 parent f86fb81 commit 9a8f4e1
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 9 deletions.
8 changes: 4 additions & 4 deletions custom2onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ def preprocessImage(input_img):
if not ret:
break
sample_image = frame
sample_image = cv2.resize(frame, (960, 540))
# sample_image = cv2.resize(frame, (640, 360))
# sample_image = cv2.resize(frame, (960, 540))
sample_image = cv2.resize(frame, (640, 360))
break

if sample_image is not None:
Expand All @@ -113,8 +113,8 @@ def preprocessImage(input_img):
input_img = input_img.unsqueeze(0)
# input_img = input_img.to(device)

# torch.onnx.export(net, input_img, "./ckpt/transweather.onnx", verbose=True, input_names=['input'], output_names=['output'], opset_version=11)
torch.onnx.export(net, input_img, "./ckpt/transweather.onnx", verbose=True, input_names=['input'], output_names=['output'], opset_version=11, dynamic_axes={'input': {0, 'batch_size'}, 'output': {0, 'batch_size'}})
torch.onnx.export(net, input_img, "./ckpt/transweather.onnx", verbose=True, input_names=['input'], output_names=['output'], opset_version=11)
# torch.onnx.export(net, input_img, "./ckpt/transweather.onnx", verbose=True, input_names=['input'], output_names=['output'], opset_version=11, dynamic_axes={'input': {0, 'batch_size'}, 'output': {0, 'batch_size'}})

print("[FINISHED] onnx model exported")

4 changes: 2 additions & 2 deletions optimize_simplied.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import onnxruntime as ort


model = onnx.load("./ckpt/transweather.onnx")
model = onnx.load("./ckpt/transweather_sim.onnx")
model_opt = onnxoptimizer.optimize(model)
onnx.save(model_opt, "./ckpt/transweather_opt.onnx")
onnx.save(model_opt, "./ckpt/transweather.onnx")
print("[INFO] finished.")
7 changes: 4 additions & 3 deletions test_custom_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,15 @@ def preprocessImage(input_img):
cap = cv2.VideoCapture(video_path)

# ort_session = ort.InferenceSession("./ckpt/transweather.onnx")
ort_session = ort.InferenceSession("./ckpt/transweather_quant.onnx")
# ort_session = ort.InferenceSession("./ckpt/transweather_quant.onnx")
ort_session = ort.InferenceSession("./ckpt/transweather.quant.onnx")

while True:
ret, frame = cap.read()
if not ret:
break
frame = cv2.resize(frame, (960, 540))
# frame = cv2.resize(frame, (640, 360))
# frame = cv2.resize(frame, (960, 540))
frame = cv2.resize(frame, (640, 360))
input_img = preprocessImage(frame)
outputs = ort_session.run(
None,
Expand Down

0 comments on commit 9a8f4e1

Please sign in to comment.