organizational changes
3283 committed Jul 7, 2023
1 parent f1ce8bd commit 95ada13
Showing 3 changed files with 145 additions and 45 deletions.
150 changes: 114 additions & 36 deletions pytorch/neurops/metrics.py
@@ -1,40 +1,12 @@
import torch
import torch.nn as nn

"""
Measure l1 or l2 (etc.) sum of weights for each neuron in a layer
L1 sum of weights is used by Li et al (2019) to measure neuron importance
weights: weights of a layer
p: p-norm to use (int, inf, -inf, "fro", "nuc")
fanin: whether to measure w.r.t. fan-in weights (so output length is # of output neurons) or fan-out weights
Whole layer metrics
"""
def weight_sum(weights: torch.Tensor = None, p = 1, fanin: bool = True, conversion_factor: int = -1):
if weights is None:
return None
if not fanin:
weights = torch.transpose(weights, 0, 1)
if len(weights.shape) > 2:
weights = weights.reshape(weights.shape[0], -1)
if conversion_factor != -1:
weights = weights.reshape(-1, conversion_factor, *weights.shape[1:])
weights = weights.reshape(weights.shape[0], -1)
return torch.norm(weights, p=p, dim=1)

"""
Measure variance of activations for each neuron in a layer, used by Polyak
and Wolf (2015) to measure neuron importance
"""
def activation_variance(activations: torch.Tensor = None):
if activations is None:
return None
if len(activations.shape) > 2:
activations = torch.transpose(torch.transpose(activations, 0, 1).reshape(activations.shape[1], -1), 0, 1)
return torch.var(activations, dim=0)

"""
Measure effective rank of whole layer via thresholding singular values of
activations (or weights)
activations (or weights) (high score = low redundancy)
"""
def effective_rank(tensor: torch.Tensor = None, threshold: float = 0.01,
partial: bool = False, scale: bool = True, limit_ratio = -1):
@@ -43,7 +15,8 @@ def effective_rank(tensor: torch.Tensor = None, threshold: float = 0.01,
if len(tensor.shape) > 2:
tensor = torch.transpose(torch.transpose(tensor, 0, 1).reshape(tensor.shape[1], -1), 0, 1)
if limit_ratio > 0 and tensor.shape[0]/tensor.shape[1] > limit_ratio:
tensor = tensor[:tensor.shape[1]*limit_ratio]
sampleindices = torch.randperm(tensor.shape[0])[:tensor.shape[1]*limit_ratio]
tensor = tensor[sampleindices]
if scale:
tensor = tensor.clone() / tensor.shape[1]**0.5
_, S, _ = torch.svd(tensor, compute_uv=False)
@@ -54,21 +27,61 @@

"""
Measure orthogonality gap of activations. Score of 0 means completely orthogonal, score of 1 means completely redundant
Used by Daneshmand et al. (2021)
Used by Daneshmand et al. (2021) (theoretical version is covariance-based, implementation is SVD-based)
"""
def orthogonality_gap(activations: torch.Tensor = None):
def orthogonality_gap(activations: torch.Tensor = None, svd = True, norm_neurons: bool = False):
if activations is None:
return None
if norm_neurons:
activations = activations / torch.norm(activations, dim=1, keepdim=True) #TEST
if len(activations.shape) > 2:
activations = activations.reshape(activations.shape[0], -1)
cov = activations @ activations.t()
return torch.norm(cov/(torch.norm(activations)**2) - torch.eye(activations.shape[0]).to(cov.device)/activations.shape[0], p='fro')
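
For illustration, a minimal usage sketch of this metric (the `from neurops.metrics import ...` path and tensor shapes below are assumptions):

import torch
from neurops.metrics import orthogonality_gap  # import path is an assumption

acts = torch.randn(64, 128)                   # 64 samples x 128 features
gap = orthogonality_gap(acts)                 # scalar tensor, small for near-orthogonal activations
collapsed = acts[:1].repeat(64, 1)            # every sample identical -> fully redundant batch
gap_collapsed = orthogonality_gap(collapsed)  # approaches 1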





"""
Per-neuron metrics
"""

"""
Measure l1 or l2 (etc.) sum of weights for each neuron in a layer
L1 sum of weights is used by Li et al (2019) to measure neuron importance (high score = high importance)
weights: weights of a layer
p: p-norm to use (int, inf, -inf, "fro", "nuc")
fanin: whether to measure w.r.t. fan-in weights (so output length is # of output neurons) or fan-out weights
"""
def weight_sum(weights: torch.Tensor = None, p = 1, fanin: bool = True, conversion_factor: int = -1):
if weights is None:
return None
if not fanin:
weights = torch.transpose(weights, 0, 1)
if len(weights.shape) > 2:
weights = weights.reshape(weights.shape[0], -1)
if conversion_factor != -1:
weights = weights.reshape(-1, conversion_factor, *weights.shape[1:])
weights = weights.reshape(weights.shape[0], -1)
return torch.norm(weights, p=p, dim=1)
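
A usage sketch (import path assumed) showing per-neuron weight norms for fully connected and convolutional layers:

import torch
import torch.nn as nn
from neurops.metrics import weight_sum  # import path is an assumption

fc = nn.Linear(64, 32)
conv = nn.Conv2d(16, 8, kernel_size=3)

fc_scores = weight_sum(fc.weight, p=1)                    # shape [32]: one L1 score per output neuron
conv_scores = weight_sum(conv.weight, p=2)                # shape [8]: one L2 score per output filter
fanout_scores = weight_sum(fc.weight, p=1, fanin=False)   # shape [64]: scored over fan-out weights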

"""
Measure variance of activations for each neuron in a layer, used by Polyak
and Wolf (2015) to measure neuron importance (high score = high importance)
"""
def activation_variance(activations: torch.Tensor = None):
if activations is None:
return None
if len(activations.shape) > 2:
activations = torch.transpose(torch.transpose(activations, 0, 1).reshape(activations.shape[1], -1), 0, 1)
return torch.var(activations, dim=0)
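
A short sketch of the expected shapes (import path assumed):

import torch
from neurops.metrics import activation_variance  # import path is an assumption

fc_acts = torch.randn(256, 32)                 # batch x neurons
conv_acts = torch.randn(256, 32, 7, 7)         # batch x channels x H x W
var_fc = activation_variance(fc_acts)          # shape [32]
var_conv = activation_variance(conv_acts)      # shape [32]: variance pooled over batch and spatial positions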

"""
Measure effective rank per neuron when that neuron is left out of the
computation
Used by Maile et al. (2022) for selection of neurogenesis initialization candidates
computation (high score = high redundancy)
Variation of method used by Maile et al. (2022) for selection of neurogenesis initialization candidates
"""
def svd_score(tensor: torch.Tensor = None, threshold: float = 0.01, addwhole: bool = False,
scale: bool = True, difference: bool = False, limit_ratio = -1):
@@ -87,7 +100,8 @@ def svd_score(tensor: torch.Tensor = None, threshold: float = 0.01, addwhole: bo
if len(prunedtensor.shape) > 2:
prunedtensor = prunedtensor.reshape(tensor.shape[0], -1)
if limit_ratio > 0 and prunedtensor.shape[1]/tensor.shape[0] > limit_ratio:
prunedtensor = prunedtensor[:, :tensor.shape[0]*limit_ratio]
sampleindices = torch.randperm(prunedtensor.shape[1])[:prunedtensor.shape[0]*limit_ratio]
prunedtensor = prunedtensor[:, sampleindices]
if scale:
prunedtensor /= prunedtensor.shape[1]**0.5
_, S, _ = torch.svd(prunedtensor, compute_uv=False)
@@ -118,6 +132,70 @@ def nuclear_score(activations: torch.Tensor = None, average: bool = False):
scores[neuron] = torch.mean(torch.norm(pruned_activations, p='nuc', dim=(1,2)))
return scores

"""
Measure correlation between activations of each neuron (dim 1) and the rest of the layer
(high score = high redundancy)
Used by Suau et al 2018
"""
def correlation_score(activations: torch.Tensor = None, crosscorr: bool = True):
if activations is None:
return None
if len(activations.shape) > 2:
activations = torch.transpose(torch.transpose(activations, 0, 1).reshape(activations.shape[1], -1), 0, 1)
if crosscorr:
corr = torch.nan_to_num(torch.corrcoef(activations.t()),0)
else:
corr = activations.t() @ activations
return torch.sum(torch.square(corr), dim=1)
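
A usage sketch (import path assumed); higher scores flag neurons that are strongly correlated with the rest of the layer:

import torch
from neurops.metrics import correlation_score  # import path is an assumption

acts = torch.randn(512, 64)                            # batch x neurons
scores = correlation_score(acts)                       # shape [64]: sum of squared correlations with every neuron
raw_scores = correlation_score(acts, crosscorr=False)  # same idea on the unnormalized Gram matrix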

"""
Measure correlation between activations when each neuron is left out of the computation:
(high score = low redundancy)
"""
def dropped_corr_score(activations: torch.Tensor = None):
if activations is None:
return None
if len(activations.shape) > 2:
activations = torch.transpose(torch.transpose(activations, 0, 1).reshape(activations.shape[1], -1), 0, 1)
scores = torch.zeros(activations.shape[1])
for neuron in range(activations.shape[1]):
pruned_activations = torch.cat((activations[:, :neuron], activations[:, neuron+1:]), dim=1)
corr = pruned_activations.t() @ pruned_activations
scores[neuron] = torch.sum(torch.sum(corr,dim=1)/(torch.diagonal(corr)+1e-8))
return scores

"""
Measure Average Percentage of Zeros (APoZ) per neuron (high score = low importance)
Used by Hu et al 2016
"""
def apoz_score(activations: torch.Tensor = None):
if activations is None:
return None
if len(activations.shape) > 2:
activations = torch.transpose(torch.transpose(activations, 0, 1).reshape(activations.shape[1], -1), 0, 1)
return torch.mean((torch.abs(activations) < 1e-8).float(), dim=0)
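
A usage sketch on post-ReLU activations (import path assumed), where APoZ is typically applied:

import torch
from neurops.metrics import apoz_score  # import path is an assumption

post_relu = torch.relu(torch.randn(256, 32, 7, 7))  # batch x channels x H x W after ReLU
apoz = apoz_score(post_relu)                        # shape [32]: fraction of (near-)zero activations per channel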

"""
Measure reconstruction error of each neuron when that neuron is projected onto the rest of the layer.
(high score = low redundancy)
Used by Berg 2022
"""
def reconstruction_score(activations: torch.Tensor = None, limit_ratio = -1):
if activations is None:
return None
if len(activations.shape) > 2:
activations = torch.transpose(torch.transpose(activations, 0, 1).reshape(activations.shape[1], -1), 0, 1)
if limit_ratio > 0 and activations.shape[0]/activations.shape[1] > limit_ratio:
sampleindices = torch.randperm(activations.shape[0])[:activations.shape[1]*limit_ratio]
activations = activations[sampleindices]
scores = torch.zeros(activations.shape[1])
for neuron in range(activations.shape[1]):
pruned_activations = torch.cat((activations[:, :neuron], activations[:, neuron+1:]), dim=1)
scores[neuron] = torch.norm(activations[:, neuron] - pruned_activations @ torch.pinverse(pruned_activations) @ activations[:, neuron])
return scores
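
A usage sketch (import path assumed); a neuron lying in the span of the others gets a score near zero:

import torch
from neurops.metrics import reconstruction_score  # import path is an assumption

acts = torch.randn(512, 64)
acts[:, 0] = acts[:, 1] + acts[:, 2]   # neuron 0 is a linear combination of neurons 1 and 2
scores = reconstruction_score(acts)    # shape [64]; scores[0] is near 0 (redundant), others are larger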


"""
Measure fisher information of mask gradients: assume 0th dim is batch dim and rest are weight dims
36 changes: 29 additions & 7 deletions pytorch/neurops/models.py
@@ -21,6 +21,8 @@ def __init__(self, *args, track_activations: False, track_auxiliary_gradients:

super(ModSequential, self).__init__(*args)

self.test = False

self.track_activations = track_activations
self.track_auxiliary_gradients = track_auxiliary_gradients

@@ -69,17 +71,19 @@ def parameters(self, recurse: bool = True, include_mask = False):
Saves the activations of a layer to the activations dictionary.
"""
def _act_hook(self, name, module, input, output):
self.activations[name] = torch.cat((self.activations[name], output.cpu()), dim=0)
if self.activations[name].shape[0] > self.multiplier*self.activations[name].shape[1]:
self.activations[name] = self.activations[name][-self.multiplier*self.activations[name].shape[1]:]
if not self.test:
self.activations[name] = torch.cat((self.activations[name], output.cpu()), dim=0)
if self.activations[name].shape[0] > self.multiplier*self.activations[name].shape[1]:
self.activations[name] = self.activations[name][-self.multiplier*self.activations[name].shape[1]:]

"""
Saves the input to the first layer to the activations dictionary.
"""
def _input_hook(self, module, input):
self.activations["-1"] = torch.cat((self.activations["-1"], input[0].cpu()), dim=0)
if self.activations["-1"].shape[0] > self.multiplier*self.activations["-1"].shape[1]:
self.activations["-1"] = self.activations["-1"][-self.multiplier*self.activations["-1"].shape[1]:]
if not self.test:
self.activations["-1"] = torch.cat((self.activations["-1"], input[0].cpu()), dim=0)
if self.activations["-1"].shape[0] > self.multiplier*self.activations["-1"].shape[1]:
self.activations["-1"] = self.activations["-1"][-self.multiplier*self.activations["-1"].shape[1]:]

def _act_shape_hook(self, module, input, output):
self.conv_output_shape = output.shape[1:]
@@ -97,6 +101,24 @@ def parameter_count(self, masked: bool = False):
else:
count += sum(p.numel() for p in self[i].parameters())
return count

def FLOPs_count(self, input = None, masked: bool = False, verbose: bool = False):
count = 0
x = torch.zeros_like(input)
for i in range(len(self)):
if isinstance(self[i], ModLinear) or isinstance(self[i], ModConv2d):
FLOPs, x = self[i].FLOPs_count(x, masked=masked, previous_mask = None if i == 0 or not self[i-1].masked else
self[i-1].mask_vector if i-1 != self.conversion_layer else
self[i-1].mask_vector.view(1,-1).tile(self.conversion_factor,1).view(-1))
count += FLOPs
if verbose:
print(f"Layer {i}: {FLOPs} FLOPs")
else:
x = self[i](x)
return count
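
A hypothetical usage sketch of the new FLOPs_count method (the ModSequential/ModLinear construction, layer sizes, and keyword arguments below are assumptions, not taken from this diff):

import torch
from neurops import ModSequential, ModLinear  # import paths are assumptions

model = ModSequential(
    ModLinear(784, 128),
    ModLinear(128, 10),
    track_activations=True,
    track_auxiliary_gradients=False,
)
total_flops = model.FLOPs_count(torch.zeros(1, 784), verbose=True)  # prints per-layer FLOPs, returns the total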

def clear_activations(self):
self.activations = defaultdict(torch.Tensor)

def forward(self, x, auxiliaries: list = None, layer_index: int = -1):
old_x = x
@@ -160,7 +182,7 @@ def prune(self, layer_index: int, neurons: list = [], optimizer=None, clear_acti
elif index == layer_index and self.track_activations and len(self.activations[str(index)].shape) >= 2:
neurons_to_keep = range(self.activations[str(index)].shape[1])
neurons_to_keep = [
ntk for ntk in neurons if ntk not in neurons]
ntk for ntk in neurons_to_keep if ntk not in neurons]
self.activations[str(index)] = self.activations[str(index)][:, neurons_to_keep]
if self.track_auxiliary_gradients:
neurons_to_keep = range(self.auxiliaries[layer_index-1].shape[1])
4 changes: 2 additions & 2 deletions tutorial.ipynb
@@ -266,7 +266,7 @@
"for iter in range(5):\n",
" for i in range(len(modded_model_grow)-1):\n",
" #score = orthogonality_gap(modded_model_grow.activations[str(i)])\n",
" max_rank = modded_model_grow[i].out_features if i > modded_model_grow.conversion_layer else modded_model_grow[i].out_channels\n",
" max_rank = modded_model_grow[i].width()\n",
" score = effective_rank(modded_model_grow.activations[str(i)])\n",
" to_add = max(score-int(0.95*max_rank), 0)\n",
" print(\"Layer {} score: {}/{}, neurons to add: {}\".format(i, score, max_rank, to_add))\n",
@@ -297,7 +297,7 @@
"modded_optimizer_masked.load_state_dict(optimizer.state_dict())\n",
"\n",
"for i in range(len(modded_model_masked)-1):\n",
" neurons = modded_model_masked[i].out_features if i > modded_model_masked.conversion_layer else modded_model_masked[i].out_channels\n",
" neurons = modded_model_masked[i].width()\n",
" modded_model_masked.grow(i, neurons, fanin_weights=\"kaiming\", fanout_weights=\"kaiming\", optimizer=modded_optimizer_masked)\n",
" modded_model_masked.mask(i, list(range(neurons, 2*neurons)))\n",
"\n",
