Commit

Update core.py
Shiweiliuiiiiiii authored Jan 7, 2022
1 parent 2829c84 commit 1bf6fbe
Showing 1 changed file with 2 additions and 70 deletions.
72 changes: 2 additions & 70 deletions CIFAR/sparselearning/core.py
@@ -14,7 +14,7 @@ def add_sparse_args(parser):
     parser.add_argument('--sparse', action='store_true', help='Enable sparse mode. Default: True.')
     parser.add_argument('--fix', action='store_true', help='Fix topology during training. Default: True.')
     parser.add_argument('--update-frequency', type=int, default=100, metavar='N', help='how many iterations to train between mask update')
-    parser.add_argument('--sparse-init', type=str, default='ERK, uniform, uniform_structured for sparse training', help='sparse initialization')
+    parser.add_argument('--sparse-init', type=str, default='ERK, uniform distributions for sparse training, global pruning and uniform pruning for pruning', help='sparse initialization')
     # hyperparameters for gradually pruning
     parser.add_argument('--method', type=str, default='GraNet', help='method name: DST, GraNet, GraNet_uniform, GMP, GMO_uniform')
     parser.add_argument('--init-density', type=float, default=0.50, help='The pruning rate / death rate.')
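
For orientation, a minimal sketch of how these flags might be wired into a training script; the import path is assumed from this file's location (CIFAR/sparselearning/core.py, run from CIFAR/) and the argument values are illustrative, not taken from the repository:

    import argparse
    from sparselearning.core import add_sparse_args  # path assumed

    parser = argparse.ArgumentParser(description='illustrative sparse-training driver')
    add_sparse_args(parser)
    args = parser.parse_args(['--sparse', '--method', 'GraNet', '--init-density', '0.5'])
    print(args.sparse, args.method, args.init_density)  # -> True GraNet 0.5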
@@ -243,63 +243,6 @@ def init(self, mode='ER', density=0.05, erk_power_scale=1.0, grad_dict=None):
                     self.baseline_nonzero += (self.masks[name] != 0).sum().int().item()
             self.apply_mask()

-        # structured pruning
-        elif self.sparse_init == 'prune_structured':
-            # uniformly structured pruning
-            print('initialized by pruning structured')
-
-            self.baseline_nonzero = 0
-            for module in self.modules:
-                for name, weight in module.named_parameters():
-                    if name not in self.masks: continue
-                    self.masks[name] = (weight != 0).cuda()
-                    nunits = weight.size(0)
-
-                    criteria_for_layer = weight.data.abs().view(nunits, -1).sum(dim=1)
-                    num_zeros = (criteria_for_layer == 0).sum().item()
-                    num_nonzeros = nunits-num_zeros
-                    num_remove = self.args.pruning_rate * num_nonzeros
-                    k = int(num_zeros + num_remove)
-                    x, idx = torch.sort(criteria_for_layer)
-                    self.masks[name][idx[:k]] = 0.0
-            self.apply_mask()
-
-        elif self.sparse_init == 'prune_and_grow_structured':
-            # # uniformly structured pruning
-            print('initialized by prune_and_grow_structured')
-
-            self.baseline_nonzero = 0
-            for module in self.modules:
-                for name, weight in module.named_parameters():
-                    if name not in self.masks: continue
-                    self.masks[name] = (weight != 0).cuda()
-                    nunits = weight.size(0)
-
-                    # prune
-                    criteria_for_layer = weight.data.abs().view(nunits, -1).sum(dim=1)
-                    num_zeros = (criteria_for_layer == 0).sum().item()
-                    num_nonzeros = nunits-num_zeros
-                    num_remove = self.args.pruning_rate * num_nonzeros
-                    print(f"number of removed channels is {num_remove}")
-                    k = int(num_zeros + num_remove)
-                    x, idx = torch.sort(criteria_for_layer)
-                    self.masks[name][idx[:k]] = 0.0
-
-                    # set the pruned weights to zero
-                    weight.data = weight.data * self.masks[name]
-                    if 'momentum_buffer' in self.optimizer.state[weight]:
-                        self.optimizer.state[weight]['momentum_buffer'] = self.optimizer.state[weight][
-                            'momentum_buffer'] * self.masks[name]
-                    # grow
-                    num_remove = num_nonzeros - (weight.data.view(nunits, -1).sum(dim=1) != 0).sum().item()
-                    print(f"number of removed channels is {num_remove}")
-                    grad = grad_dict[name]
-                    grad = grad * (self.masks[name] == 0).float()
-                    grad_criteria_for_layer = grad.data.abs().view(nunits, -1).sum(dim=1)
-                    y, idx = torch.sort(grad_criteria_for_layer, descending=True)
-                    self.masks[name][idx[:num_remove]] = 1.0
-            self.apply_mask()
-
         elif self.sparse_init == 'uniform':
             self.baseline_nonzero = 0
             for module in self.modules:
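
For reference, a self-contained sketch of the channel-level criteria used by the two removed branches above: prune_structured ranks output channels by the L1 norm of their weights and zeroes the weakest fraction, and prune_and_grow_structured then reactivates the same number of channels with the largest gradient L1 norm. Function names and shapes here are illustrative, not part of the repository:

    import torch

    def prune_channels_by_magnitude(weight, pruning_rate):
        # One score per output channel: L1 norm over all remaining dims.
        nunits = weight.size(0)
        scores = weight.detach().abs().view(nunits, -1).sum(dim=1)
        num_zeros = (scores == 0).sum().item()
        num_remove = int(pruning_rate * (nunits - num_zeros))
        # Keep already-dead channels dead and drop the weakest live ones.
        _, idx = torch.sort(scores)
        mask = torch.ones_like(weight)
        mask[idx[:num_zeros + num_remove]] = 0.0
        return mask, num_remove

    def grow_channels_by_gradient(mask, grad, num_grow):
        # Score only currently-masked channels by gradient L1 norm.
        nunits = grad.size(0)
        scores = (grad * (mask == 0).float()).abs().view(nunits, -1).sum(dim=1)
        _, idx = torch.sort(scores, descending=True)
        mask = mask.clone()
        mask[idx[:num_grow]] = 1.0  # reactivate the highest-gradient channels
        return mask

    w = torch.randn(8, 3, 3, 3)  # e.g. a conv weight: 8 output channels
    g = torch.randn_like(w)      # gradient snapshot (grad_dict in the diff)
    mask, n = prune_channels_by_magnitude(w, 0.25)
    mask = grow_channels_by_gradient(mask, g, n)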
@@ -310,17 +253,6 @@ def init(self, mode='ER', density=0.05, erk_power_scale=1.0, grad_dict=None):
                     self.baseline_nonzero += weight.numel() * density
             self.apply_mask()

-        elif self.sparse_init == 'uniform_structured':
-            self.baseline_nonzero = 0
-            for module in self.modules:
-                for name, weight in module.named_parameters():
-                    if name not in self.masks: continue
-                    nunits = weight.size(0)
-                    num_zeros = int(nunits * (1-density))
-                    zero_idx = np.random.choice(range(nunits), num_zeros, replace=False)
-                    self.masks[name][zero_idx] = 0.0
-            self.apply_mask()
-
         elif self.sparse_init == 'ERK':
             print('initialize by ERK')
             for name, weight in self.masks.items():
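
Similarly, a compact sketch of what the removed uniform_structured branch did: zero a random (1 - density) fraction of each layer's output channels. torch.randperm stands in here for the original np.random.choice, and the density value is illustrative:

    import torch

    def random_channel_mask(weight, density):
        nunits = weight.size(0)
        num_zeros = int(nunits * (1 - density))
        zero_idx = torch.randperm(nunits)[:num_zeros]  # channels to disable
        mask = torch.ones_like(weight)
        mask[zero_idx] = 0.0
        return mask

    mask = random_channel_mask(torch.randn(8, 3, 3, 3), density=0.5)  # 4 of 8 channels kept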
@@ -763,4 +695,4 @@ def fired_masks_update(self):
                 print('Layerwise percentage of the fired weights of', name, 'is:', layer_fired_weights[name])
         total_fired_weights = ntotal_fired_weights/ntotal_weights
         print('The percentage of the total fired weights is:', total_fired_weights)
-        return layer_fired_weights, total_fired_weights
\ No newline at end of file
+        return layer_fired_weights, total_fired_weights
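
Finally, a toy sketch of the "fired weights" statistic printed above: a weight counts as fired once it has been nonzero under any mask so far, so the percentage measures how much of the parameter space the dynamic topology has explored. The names and the two-step mask history are illustrative:

    import torch

    # Union of all masks seen so far: a weight 'fires' once it is ever active.
    mask_t0 = (torch.rand(4, 4) < 0.3).float()
    mask_t1 = (torch.rand(4, 4) < 0.3).float()
    fired = ((mask_t0 + mask_t1) > 0).float()

    total_fired_weights = fired.sum().item() / fired.numel()
    print('The percentage of the total fired weights is:', total_fired_weights)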
