Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
Shiweiliuiiiiiii authored Feb 1, 2022
1 parent 83cae5a commit 7260d67
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions ImageNet/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def _worker_init_fn(id):

# add binary masks to the dense model
if args.sparse:
decay = CosineDecay(args.prune_rate, len(train_loader)*args.epochs*args.multiplier)
decay = CosineDecay(args.prune_rate, int(len(train_loader)*args.epochs*args.multiplier))
mask = Masking(optimizer,train_loader=train_loader, prune_mode=args.prune, prune_rate_decay=decay, growth_mode=args.growth, redistribution_mode=args.redistribution, args=args)
mask.add_module(model_and_loss.model)
model_and_loss.mask = mask
Expand All @@ -248,7 +248,7 @@ def _worker_init_fn(id):
return

logger = logger_cls(train_loader_len, val_loader_len, args)
train_loop(args, model_and_loss, optimizer, adjust_learning_rate(args), train_loader, val_loader, args.epochs*args.multiplier,
train_loop(args, model_and_loss, optimizer, adjust_learning_rate(args), train_loader, val_loader, int(args.epochs*args.multiplier),
args.fp16, logger, should_backup_checkpoint(args),
start_epoch = args.start_epoch, best_prec1 = best_prec1, prof=args.prof)

Expand Down Expand Up @@ -365,7 +365,7 @@ def train_loop(args, model_and_loss, optimizer, lr_scheduler, train_loader, val_
Fore.RESET)

save_path = './save/granet-st/'
save_subfolder = os.path.join(save_path, 'M=' + str(args.multiplier) + '_ini_sparsity' + str(1 - args.ini_density) + '_final_sparsity' + str(1 - args.final_density))
save_subfolder = os.path.join(save_path, 'M=' + str(args.multiplier) + '_ini_sparsity' + str(1 - args.init_density) + '_final_sparsity' + str(1 - args.final_density))
if not os.path.exists(save_subfolder): os.makedirs(save_subfolder)
if should_backup_checkpoint(epoch):
backup_filename = args.save + 'checkpoint-{}.pth.tar'.format(epoch + 1)
Expand All @@ -390,16 +390,17 @@ def fast_collate(batch):
tensor = torch.zeros( (len(imgs), 3, h, w), dtype=torch.uint8 )
for i, img in enumerate(imgs):
nump_array = np.asarray(img, dtype=np.uint8)
tens = torch.from_numpy(nump_array)
# tens = torch.from_numpy(nump_array)
if(nump_array.ndim < 3):
nump_array = np.expand_dims(nump_array, axis=-1)
nump_array = np.rollaxis(nump_array, 2)

tensor[i] += torch.from_numpy(nump_array)
nump_array_copy = np.copy(nump_array)
tensor[i] += torch.from_numpy(nump_array_copy)

return tensor, targets



def prefetched_loader(loader, fp16):
mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)
std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)
Expand Down Expand Up @@ -433,7 +434,7 @@ def prefetched_loader(loader, fp16):


def get_train_loader(data_path, batch_size, workers=5, _worker_init_fn=None):
traindir = os.path.join(data_path, 'ILSVRC2012_img_train')
traindir = os.path.join(data_path, 'train')
train_dataset = datasets.ImageFolder(
traindir,
transforms.Compose([
Expand All @@ -455,7 +456,7 @@ def get_train_loader(data_path, batch_size, workers=5, _worker_init_fn=None):
return train_loader

def get_val_loader(data_path, batch_size, workers=5, _worker_init_fn=None):
valdir = os.path.join(data_path, 'ILSVRC2012_img_val')
valdir = os.path.join(data_path, 'val')

val_loader = torch.utils.data.DataLoader(
datasets.ImageFolder(valdir, transforms.Compose([
Expand Down Expand Up @@ -783,11 +784,11 @@ def _alr(optimizer, epoch):
lr = args.lr * (epoch + 1) / (args.warmup + 1)

else:
if epoch < 30*args.multiplier:
if epoch < int(30*args.multiplier):
p = 0
elif epoch < 60*args.multiplier:
elif epoch < int(60*args.multiplier):
p = 1
elif epoch < 90*args.multiplier:
elif epoch < int(90*args.multiplier):
p = 2
else:
p = 3
Expand Down

0 comments on commit 7260d67

Please sign in to comment.