Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhangAoCanada committed Aug 23, 2022
1 parent ce2491e commit eb86fd7
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 5 deletions.
Empty file modified test_data_functions.py
100755 → 100644
Empty file.
4 changes: 4 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="1"

import time
import torch
import argparse
Expand Down
Empty file modified train_data_functions.py
100755 → 100644
Empty file.
19 changes: 14 additions & 5 deletions train_distillation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="2"

import time
import torch
import argparse
Expand Down Expand Up @@ -31,7 +35,7 @@
parser = argparse.ArgumentParser(description='Hyper-parameters for network')
parser.add_argument('-learning_rate', help='Set the learning rate', default=2e-4, type=float)
parser.add_argument('-crop_size', help='Set the crop_size', default=[512, 512], nargs='+', type=int)
parser.add_argument('-train_batch_size', help='Set the training batch size', default=8, type=int)
parser.add_argument('-train_batch_size', help='Set the training batch size', default=10, type=int)
parser.add_argument('-epoch_start', help='Starting epoch number of the training', default=0, type=int)
# parser.add_argument('-lambda_loss', help='Set the lambda in loss function', default=0.04, type=float)
parser.add_argument('-lambda_loss', help='Set the lambda in loss function', default=0.01, type=float)
Expand All @@ -40,6 +44,7 @@
parser.add_argument('-seed', help='set random seed', default=19, type=int)
parser.add_argument('-num_epochs', help='number of epochs', default=1000, type=int)
parser.add_argument('-distillation_scale', help='for model distillation', default=0.9, type=int)
parser.add_argument('-logdir', help='for tensorboard', default="StudentModel6.1", type=str)

args = parser.parse_args()

Expand All @@ -52,6 +57,7 @@
exp_name = args.exp_name
num_epochs = args.num_epochs
distillation_scale = args.distillation_scale
tensorboard_logdir = args.logdir

#set seed
seed = args.seed
Expand All @@ -67,12 +73,15 @@


##################### NOTE: Change the path to the dataset #####################
train_data_dir = "/content/drive/MyDrive/DERAIN/DATA_20220617/train"
validate_data_dir = "/content/drive/MyDrive/DERAIN/DATA_20220617/validate"
test_data_dir = "/content/drive/MyDrive/DERAIN/DATA_20220617/test"
# train_data_dir = "/content/drive/MyDrive/DERAIN/DATA_20220617/train"
# validate_data_dir = "/content/drive/MyDrive/DERAIN/DATA_20220617/validate"
# test_data_dir = "/content/drive/MyDrive/DERAIN/DATA_20220617/test"
# train_data_dir = "/content/drive/MyDrive/DERAIN/DATA_20220531/train"
# validate_data_dir = "/content/drive/MyDrive/DERAIN/DATA_20220531/validate"
# test_data_dir = "/content/drive/MyDrive/DERAIN/DATA_20220531/test"
train_data_dir = "/home/zhangao/DATASET/DATA_20220617/train"
validate_data_dir = "/home/zhangao/DATASET/DATA_20220617/validate"
test_data_dir = "/home/zhangao/DATASET/DATA_20220617/test_specific"
rain_L_dir = "rain_L"
rain_H_dir = "rain_H"
gt_dir = "gt"
Expand Down Expand Up @@ -172,7 +181,7 @@
print("[INFO] Teacher model encoder depths: {}, decoder depths: {}.".format(net_teacher.module.Tenc.depths, net_teacher.module.Tdec.depths[0]))
print("[INFO] Student model encoder depths: {}, decoder depths: {}.".format(net_student.module.Tenc.depths, net_student.module.Tdec.depths[0]))

log_dir = "./logs/images"
log_dir = os.path.join("./logs", tensorboard_logdir)
if not os.path.exists(log_dir):
os.mkdir(log_dir)
writer = SummaryWriter(log_dir)
Expand Down
Empty file modified val_data_functions.py
100755 → 100644
Empty file.

0 comments on commit eb86fd7

Please sign in to comment.