Commit f86fb81: converting models
ZhangAoCanada committed May 20, 2022 (1 parent: db61460)

Showing 12 changed files with 1,534 additions and 30 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -4,4 +4,4 @@ testt.py
logs
ckpt
imgs

videos
26 changes: 15 additions & 11 deletions convert2onnx.py
@@ -29,20 +29,26 @@
def preprocessImage(input_img):
# Resize image so both sides are a multiple of 16
wd_new, ht_new, _ = input_img.shape
if ht_new>wd_new and ht_new>1024:
wd_new = int(np.ceil(wd_new*1024/ht_new))
ht_new = 1024
elif ht_new<=wd_new and wd_new>1024:
ht_new = int(np.ceil(ht_new*1024/wd_new))
wd_new = 1024
if ht_new>wd_new and ht_new>2048:
wd_new = int(np.ceil(wd_new*2048/ht_new))
ht_new = 2048
elif ht_new<=wd_new and wd_new>2048:
ht_new = int(np.ceil(ht_new*2048/wd_new))
wd_new = 2048
wd_new = int(16*np.ceil(wd_new/16.0))
ht_new = int(16*np.ceil(ht_new/16.0))
# input_img = input_img.resize((wd_new,ht_new), Image.ANTIALIAS)
input_img = cv2.resize(input_img, (wd_new, ht_new), interpolation=cv2.INTER_AREA)
# input_img = cv2.resize(input_img, (wd_new, ht_new), interpolation=cv2.INTER_AREA)
input_img = cv2.resize(input_img, (ht_new, wd_new), interpolation=cv2.INTER_AREA)

# --- Transform to tensor --- #
transform_input = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
input_im = transform_input(input_img)

# input_img = input_img / 255.0
# input_img = (input_img - 0.5) / 0.5
# transform_input = Compose([ToTensor()])
# input_im = transform_input(input_img)
return input_im


@@ -64,7 +70,7 @@ def preprocessImage(input_img):
model_path = "ckpt/best_psnr+lambda0.01"

video = cv2.VideoCapture(video_path)
video_saving = cv2.VideoWriter(output_video_path,cv2.VideoWriter_fourcc('M','J','P','G'),30,(2040,720))
# video_saving = cv2.VideoWriter(output_video_path,cv2.VideoWriter_fourcc('M','J','P','G'),30,(2040,720))

device_ids = [Id for Id in range(torch.cuda.device_count())]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -84,15 +90,13 @@ def preprocessImage(input_img):

net = net.module




sample_image = None
while True:
ret, frame = video.read()
if not ret:
break
sample_image = frame
sample_image = cv2.resize(frame, (960, 540))
break

if sample_image is not None:
120 changes: 120 additions & 0 deletions custom2onnx.py
@@ -0,0 +1,120 @@
import sys
from tabnanny import verbose
# sys.path.append("/content/drive/MyDrive/DERAIN/TransWeather")

import time
import torch
import argparse
import torch.nn as nn
from torch.utils.data import DataLoader
from val_data_functions import ValData
from utils import validation, validation_val, calc_psnr, calc_ssim
import os
import numpy as np
import random
from transweather_model_extra import Transweather

from PIL import Image
from torchvision.transforms import Compose, ToTensor, Normalize
from random import randrange
import torchvision.utils as utils
import cv2
import re
from tqdm import tqdm
from skimage import img_as_ubyte

from torchinfo import summary


def preprocessImage(input_img):
# Resize image so both sides are a multiple of 16
wd_new, ht_new, _ = input_img.shape
if ht_new>wd_new and ht_new>2048:
wd_new = int(np.ceil(wd_new*2048/ht_new))
ht_new = 2048
elif ht_new<=wd_new and wd_new>2048:
ht_new = int(np.ceil(ht_new*2048/wd_new))
wd_new = 2048
wd_new = int(16*np.ceil(wd_new/16.0))
ht_new = int(16*np.ceil(ht_new/16.0))
# input_img = input_img.resize((wd_new,ht_new), Image.ANTIALIAS)
# input_img = cv2.resize(input_img, (wd_new, ht_new), interpolation=cv2.INTER_AREA)
input_img = cv2.resize(input_img, (ht_new, wd_new), interpolation=cv2.INTER_AREA)

# --- Transform to tensor --- #
# transform_input = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# input_im = transform_input(input_img)

# input_img = input_img / 255.0
# input_img = (input_img - 0.5) / 0.5
# transform_input = Compose([ToTensor()])
# input_im = transform_input(input_img)

input_im = torch.from_numpy(input_img.astype(np.float32))
return input_im




val_batch_size = 1
exp_name = "ckpt"
#set seed
seed = 19
if seed is not None:
np.random.seed(seed)
torch.manual_seed(seed)
# torch.cuda.manual_seed(seed)
random.seed(seed)
print('Seed:\t{}'.format(seed))

video_path = "/home/ao/tmp/clip_videos/h97cam_water_video.mp4"
output_video_path = "./videos/h97cam_water_lambda00_video.avi"
model_path = "ckpt/best_psnr+lambda0.01"

video = cv2.VideoCapture(video_path)
# video_saving = cv2.VideoWriter(output_video_path,cv2.VideoWriter_fourcc('M','J','P','G'),30,(2040,720))

device_ids = [Id for Id in range(torch.cuda.device_count())]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

net = Transweather()
net = nn.DataParallel(net)

if device == torch.device("cpu"):
net.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
print("====> model ", model_path, " loaded")
else:
net.load_state_dict(torch.load(model_path))
net.to(device)
print("====> model ", model_path, " loaded")

net.eval()

net = net.module

sample_image = None
while True:
ret, frame = video.read()
if not ret:
break
sample_image = frame
sample_image = cv2.resize(frame, (960, 540))
# sample_image = cv2.resize(frame, (640, 360))
break

if sample_image is not None:
print("[INFO] image shape: ", sample_image.shape)
else:
print("[INFO] image is None")


input_img = cv2.cvtColor(sample_image, cv2.COLOR_BGR2RGB)
input_img = preprocessImage(input_img)
input_img = input_img.unsqueeze(0)
# input_img = input_img.to(device)

# torch.onnx.export(net, input_img, "./ckpt/transweather.onnx", verbose=True, input_names=['input'], output_names=['output'], opset_version=11)
torch.onnx.export(net, input_img, "./ckpt/transweather.onnx", verbose=True, input_names=['input'], output_names=['output'], opset_version=11, dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})

print("[FINISHED] onnx model exported")

6 changes: 3 additions & 3 deletions inference_video_test.py
@@ -58,9 +58,9 @@ def preprocessImage(input_img):
# video_path = "/content/drive/MyDrive/DERAIN/DATA_captured/something_else/dusty_video1.mp4"
# output_video_path = "./videos/dusty_video1_result.avi"
# model_path = "ckpt/best_512"
video_path = "/content/drive/MyDrive/DERAIN/DATA_captured/something_else/dust_with_water2_video.mp4"
output_video_path = "./videos/dust_with_water2_video_result.avi"
model_path = "ckpt/best_aug3_CutMix"
video_path = "/content/drive/MyDrive/DERAIN/video_data/h97cam_water_video.mp4"
output_video_path = "./videos/h97cam_water_lambda00_video.avi"
model_path = "ckpt/best_lambda_0.0"

video = cv2.VideoCapture(video_path)
video_saving = cv2.VideoWriter(output_video_path,cv2.VideoWriter_fourcc('M','J','P','G'),30,(2040,720))
142 changes: 142 additions & 0 deletions jitTransfer.py
@@ -0,0 +1,142 @@
import sys
from tabnanny import verbose
# sys.path.append("/content/drive/MyDrive/DERAIN/TransWeather")

import time
import torch
import argparse
import torch.nn as nn
from torch.utils.data import DataLoader
from val_data_functions import ValData
from utils import validation, validation_val, calc_psnr, calc_ssim
import os
import numpy as np
import random
from transweather_model import Transweather

from PIL import Image
from torchvision.transforms import Compose, ToTensor, Normalize
from random import randrange
import torchvision.utils as utils
import cv2
import re
from tqdm import tqdm
from skimage import img_as_ubyte

from torchinfo import summary


def preprocessImage(input_img):
# Resize image so both sides are a multiple of 16
wd_new, ht_new, _ = input_img.shape
if ht_new>wd_new and ht_new>1024:
wd_new = int(np.ceil(wd_new*1024/ht_new))
ht_new = 1024
elif ht_new<=wd_new and wd_new>1024:
ht_new = int(np.ceil(ht_new*1024/wd_new))
wd_new = 1024
wd_new = int(16*np.ceil(wd_new/16.0))
ht_new = int(16*np.ceil(ht_new/16.0))
# input_img = input_img.resize((wd_new,ht_new), Image.ANTIALIAS)
input_img = cv2.resize(input_img, (wd_new, ht_new), interpolation=cv2.INTER_AREA)

# --- Transform to tensor --- #
transform_input = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
input_im = transform_input(input_img)

# input_img = input_img / 255.0
# input_img = (input_img - 0.5) / 0.5
# transform_input = Compose([ToTensor()])
# input_im = transform_input(input_img)
return input_im




val_batch_size = 1
exp_name = "ckpt"
#set seed
seed = 19
if seed is not None:
np.random.seed(seed)
torch.manual_seed(seed)
# torch.cuda.manual_seed(seed)
random.seed(seed)
print('Seed:\t{}'.format(seed))

video_path = "/home/ao/tmp/clip_videos/h97cam_water_video.mp4"
output_video_path = "./videos/h97cam_water_lambda00_video.avi"
model_path = "ckpt/best_psnr+lambda0.01"

video = cv2.VideoCapture(video_path)
# video_saving = cv2.VideoWriter(output_video_path,cv2.VideoWriter_fourcc('M','J','P','G'),30,(2040,720))

device_ids = [Id for Id in range(torch.cuda.device_count())]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

net = Transweather()
net = nn.DataParallel(net)

if device == torch.device("cpu"):
net.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
print("====> model ", model_path, " loaded")
else:
net.load_state_dict(torch.load(model_path))
net.to(device)
print("====> model ", model_path, " loaded")

net.eval()

net = net.module

sample_image = None
while True:
ret, frame = video.read()
if not ret:
break
# sample_image = frame
sample_image = cv2.resize(frame, (640, 480))
break

if sample_image is not None:
print("[INFO] image shape: ", sample_image.shape)
else:
print("[INFO] image is None")


input_img = cv2.cvtColor(sample_image, cv2.COLOR_BGR2RGB)
input_img = preprocessImage(input_img)
input_img = input_img.unsqueeze(0)
# input_img = input_img.to(device)

torch.onnx.export(net, input_img, "./ckpt/transweather.onnx", verbose=True, input_names=['input'], output_names=['output'])

print("[FINISHED] onnx model exported")




# ### NOTE: start evaluation ###
# with torch.no_grad():
# while True:
# ret, frame = video.read()
# if not ret:
# break
# frame = frame[:, 180:1200, :]
# # pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
# input_img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# input_img = preprocessImage(input_img)
# input_img = input_img.to(device)
# input_img = input_img.unsqueeze(0)
# print("[INFO] ", input_img.shape)
# pred_image = net(input_img)
# pred_image_cpu = pred_image[0].permute(1,2,0).cpu().numpy()
# pred_image_cpu = img_as_ubyte(pred_image_cpu)
# pred_image_cpu = cv2.resize(pred_image_cpu, (frame.shape[1],frame.shape[0]))
# image = np.concatenate((frame, pred_image_cpu[..., ::-1]), axis=1)
# # video_saving.write(image)
# cv2.imshow("image", image)
# if cv2.waitKey(1) == 27:
# break
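
Despite its name, jitTransfer.py above only exports ONNX; a minimal TorchScript sketch (not part of this commit) could look like the following, reusing net, device, and input_img from the script. Tracing bakes in the frame size used here, so the archive would need re-tracing for other resolutions.

# Move the example input to the same device as the model before tracing.
input_img = input_img.to(device)
traced = torch.jit.trace(net, input_img)
traced.save("./ckpt/transweather_jit.pt")

# Quick consistency check between the eager and traced outputs.
with torch.no_grad():
    diff = (net(input_img) - traced(input_img)).abs().max()
print("[CHECK] max abs diff (eager vs traced):", diff.item())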


22 changes: 22 additions & 0 deletions onnxQuantize.py
@@ -0,0 +1,22 @@
# import onnx
# from onnxruntime.quantization import quantize_dynamic, QuantType

# model_fp32 = './ckpt/transweather.onnx'
# model_quant = './ckpt/transweather.quant.onnx'
# quantized_model = quantize_dynamic(model_fp32, model_quant)

# model = onnx.load("./ckpt/transweather.quant.onnx")
# onnx.checker.check_model(model)
# print(onnx.helper.printable_graph(model.graph))


import onnx
from onnxruntime.quantization import quantize, QuantizationMode

model = onnx.load('./ckpt/transweather.onnx')
quantized_model = quantize(model, quantization_mode=QuantizationMode.IntegerOps)
onnx.save(quantized_model, './ckpt/transweather.quant.onnx')

model = onnx.load("./ckpt/transweather.quant.onnx")
onnx.checker.check_model(model)
print(onnx.helper.printable_graph(model.graph))
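
Not part of this commit: a rough size/latency comparison between the FP32 and quantized graphs on CPU via onnxruntime. The input name and the dummy shape (one 960x544 RGB frame in NCHW) are assumptions based on the export scripts above; adjust them to the actual exported graph.

import os
import time
import numpy as np
import onnxruntime as ort

for path in ("./ckpt/transweather.onnx", "./ckpt/transweather.quant.onnx"):
    print(path, "size (MB):", round(os.path.getsize(path) / 1e6, 1))
    sess = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
    dummy = np.random.rand(1, 3, 544, 960).astype(np.float32)  # assumed input layout
    start = time.time()
    sess.run(None, {"input": dummy})  # single run, so only a rough latency figure
    print(path, "latency (s):", round(time.time() - start, 3))
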
9 changes: 9 additions & 0 deletions optimize_simplied.py
@@ -0,0 +1,9 @@
import onnx
import onnxoptimizer
import onnxruntime as ort


model = onnx.load("./ckpt/transweather.onnx")
model_opt = onnxoptimizer.optimize(model)
onnx.save(model_opt, "./ckpt/transweather_opt.onnx")
print("[INFO] finished.")
