Commit 549cd3a

Add files via upload
Shiweiliuiiiiiii authored Oct 9, 2021
1 parent 49f871a commit 549cd3a
Showing 1 changed file with 9 additions and 13 deletions.
22 changes: 9 additions & 13 deletions CIFAR/sparselearning/core.py
@@ -11,8 +11,9 @@ def add_sparse_args(parser):
parser.add_argument('--redistribution', type=str, default='none', help='Redistribution mode. Choose from: momentum, magnitude, nonzeros, or none.')
parser.add_argument('--death-rate', type=float, default=0.50, help='The pruning rate / death rate for Zero-Cost Neuroregeneration.')
parser.add_argument('--pruning-rate', type=float, default=0.50, help='The pruning rate / death rate.')
- parser.add_argumeant('--sparse', action='store_true', help='Enable sparse mode. Default: True.')
+ parser.add_argument('--sparse', action='store_true', help='Enable sparse mode. Default: True.')
parser.add_argument('--fix', action='store_true', help='Fix topology during training. Default: True.')
+ parser.add_argument('--update-frequency', type=int, default=100, metavar='N', help='how many iterations to train between mask update')
parser.add_argument('--sparse-init', type=str, default='ERK, uniform, uniform_structured for sparse training', help='sparse initialization')
# hyperparameters for gradually pruning
parser.add_argument('--method', type=str, default='GraNet', help='method name: DST, GraNet, GraNet_uniform, GMP, GMO_uniform')
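
For context on the `--sparse-init` choices: ERK (Erdős–Rényi-Kernel, introduced in Evci et al., "Rigging the Lottery") allocates each layer a density proportional to the sum of its dimensions divided by its parameter count, then rescales so the whole network matches the global density. A minimal sketch of that allocation, assuming NumPy; the function name is mine and the clip-at-1 refinement loop is omitted, so this is an illustration rather than this repo's implementation:

    import numpy as np

    def erk_layer_densities(shapes, global_density):
        # shapes: list of weight shapes, e.g. [(64, 3, 3, 3), (10, 64)]
        raw = np.array([sum(s) / np.prod(s) for s in shapes])       # ERK score per layer
        sizes = np.array([np.prod(s) for s in shapes], dtype=float)
        # Pick eps so that sum(eps * raw * sizes) == global_density * sum(sizes);
        # a full implementation re-solves whenever eps * raw exceeds 1 for a layer.
        eps = global_density * sizes.sum() / (raw * sizes).sum()
        return np.minimum(eps * raw, 1.0)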
@@ -435,7 +436,7 @@ def step(self):


def pruning(self, step):
- # prune_rate = 1 - self.args.final_density - self.args.ini_density
+ # prune_rate = 1 - self.args.final_density - self.args.init_density
curr_prune_iter = int(step / self.prune_every_k_steps)
final_iter = int((self.args.final_prune_epoch * len(self.loader)*self.args.multiplier) / self.prune_every_k_steps)
ini_iter = int((self.args.init_prune_epoch * len(self.loader)*self.args.multiplier) / self.prune_every_k_steps)
@@ -447,7 +448,7 @@ def pruning(self, step):

if curr_prune_iter >= ini_iter and curr_prune_iter <= final_iter:
prune_decay = (1 - ((curr_prune_iter - ini_iter) / total_prune_iter)) ** 3
- curr_prune_rate = (1 - self.args.ini_density) + (self.args.ini_density - self.args.final_density) * (
+ curr_prune_rate = (1 - self.args.init_density) + (self.args.init_density - self.args.final_density) * (
1 - prune_decay)

weight_abs = []
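
The `curr_prune_rate` formula above is the cubic gradual-pruning schedule in the style of Zhu & Gupta's GMP, generalized to start from `init_density` rather than from a dense network. Restated as a standalone function that mirrors the diff (the argument names are illustrative):

    def gradual_prune_rate(curr_iter, ini_iter, final_iter, init_density, final_density):
        total_prune_iter = final_iter - ini_iter
        prune_decay = (1 - (curr_iter - ini_iter) / total_prune_iter) ** 3
        return (1 - init_density) + (init_density - final_density) * (1 - prune_decay)

At curr_iter == ini_iter this gives 1 - init_density (the sparsity already present from initialization); at curr_iter == final_iter it reaches 1 - final_density (the target sparsity), and the cubic decay front-loads the pruning.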
@@ -466,7 +467,7 @@ def pruning(self, step):
for module in self.modules:
for name, weight in module.named_parameters():
if name not in self.masks: continue
- self.masks[name] = ((torch.abs(weight)) >= acceptable_score).float()
+ self.masks[name] = ((torch.abs(weight)) > acceptable_score).float()  # must be strictly >: a zero acceptable_score with >= would keep all weights, leaving dense tensors

self.apply_mask()
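
The change from `>=` to `>` is the substantive fix in this hunk. The collapsed lines presumably reduce the collected `weight_abs` magnitudes to a global threshold `acceptable_score` (e.g., the k-th largest magnitude via `torch.topk`); at high sparsity that k-th magnitude can be exactly 0.0 because most weights are already zeroed, and `>= 0.0` would then keep every weight and silently densify the masks. A tiny demonstration:

    import torch

    w = torch.tensor([0.0, 0.0, 0.3, -0.5])
    acceptable_score = 0.0                        # threshold landed on an already-pruned weight
    print((w.abs() >= acceptable_score).float())  # tensor([1., 1., 1., 1.]) -> dense mask
    print((w.abs() > acceptable_score).float())   # tensor([0., 0., 1., 1.]) -> stays sparse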

@@ -483,7 +484,7 @@ def pruning(self, step):
(total_size-sparse_size) / total_size))

def pruning_uniform(self, step):
- # prune_rate = 1 - self.args.final_density - self.args.ini_density
+ # prune_rate = 1 - self.args.final_density - self.args.init_density
curr_prune_iter = int(step / self.prune_every_k_steps)
final_iter = (self.args.final_prune_epoch * len(self.loader)*self.args.multiplier) / self.prune_every_k_steps
ini_iter = (self.args.init_prune_epoch * len(self.loader)*self.args.multiplier) / self.prune_every_k_steps
@@ -494,7 +495,7 @@ def pruning_uniform(self, step):

if curr_prune_iter >= ini_iter and curr_prune_iter <= final_iter:
prune_decay = (1 - ((curr_prune_iter - ini_iter) / total_prune_iter)) ** 3
- curr_prune_rate = (1 - self.args.ini_density) + (self.args.ini_density - self.args.final_density) * (
+ curr_prune_rate = (1 - self.args.init_density) + (self.args.init_density - self.args.final_density) * (
1 - prune_decay)
# keep the density of the last layer at 0.2 if sparsity is larger than 0.8
if curr_prune_rate >= 0.8:
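
The hunk is collapsed right after the `curr_prune_rate >= 0.8` check, but the comment states the intent: in the uniform variant, the last layer is held at density 0.2 once the target sparsity passes 0.8. A hedged sketch of per-layer magnitude pruning with such a cap; the function and argument names are mine, not the repo's:

    import torch

    def uniform_prune_masks(weights, prune_rate, last_name, last_density=0.2):
        # weights: dict of name -> tensor. Each layer is pruned to `prune_rate`
        # sparsity except `last_name`, which keeps at least `last_density` density.
        masks = {}
        for name, w in weights.items():
            rate = min(prune_rate, 1.0 - last_density) if name == last_name else prune_rate
            k = max(int(w.numel() * (1.0 - rate)), 1)        # number of weights to keep
            keep = torch.topk(w.abs().flatten(), k).indices  # largest-magnitude entries
            mask = torch.zeros(w.numel())
            mask[keep] = 1.0
            masks[name] = mask.view_as(w)
        return masks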
@@ -554,7 +555,7 @@ def add_module(self, module, sparse_init='ERK', grad_dic=None):
self.masks.pop(name)
print(f"pop out {name}")

- self.init(mode=self.args.sparse_init, density=self.args.ini_density, grad_dict=grad_dic)
+ self.init(mode=self.args.sparse_init, density=self.args.init_density, grad_dict=grad_dic)


def remove_weight(self, name):
@@ -720,12 +721,7 @@ def print_nonzero_counts(self):
val = '{0}: {1}->{2}, density: {3:.3f}'.format(name, self.name2nonzeros[name], num_nonzeros, num_nonzeros/float(mask.numel()))
print(val)


- for module in self.modules:
-     for name, tensor in module.named_parameters():
-         if name not in self.masks: continue
-         print('Death rate: {0}\n'.format(self.name2death_rate[name]))
-         break
+ print('Death rate: {0}\n'.format(self.death_rate))

def reset_momentum(self):
"""
