# This file is part of Zennit
# Copyright (C) 2019-2021 Christopher J. Anders
#
# zennit/core.py
#
# Zennit is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 3 of the License, or (at your option) any
# later version.
#
# Zennit is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for
# more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this library. If not, see <https://www.gnu.org/licenses/>.
'''Core functions and classes'''
import functools
import weakref
from contextlib import contextmanager
import torch
class Stabilizer:
'''Class to create a stabilizer callable.
Parameters
----------
epsilon: float, optional
Value by which to shift/clip elements of ``input``.
clip: bool, optional
If ``False`` (default), add ``epsilon`` multiplied by each entry's sign (+1 for 0). If ``True``, instead clip
the absolute value of ``input`` and multiply it by each entry's original sign.
norm_scale: bool, optional
If ``False`` (default), ``epsilon`` is added to/used to clip ``input``. If ``True``, scale ``epsilon`` by the
square root of the mean over the squared elements of the specified dimensions ``dim``.
dim: tuple[int], optional
If ``norm_scale`` is ``True``, specifies the dimension over which the scaled norm should be computed (all
except dimension 0 by default).
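Example
-------
A minimal usage sketch for safe division; the tensor values are arbitrary illustrations:
>>> import torch
>>> stabilizer = Stabilizer(epsilon=1e-6)
>>> denominator = torch.tensor([0., -2., 3.])
>>> safe = torch.ones(3) / stabilizer(denominator)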
'''
def __init__(self, epsilon=1e-6, clip=False, norm_scale=False, dim=None):
self.epsilon = epsilon
self.clip = clip
self.norm_scale = norm_scale
self.dim = dim
def __call__(self, input):
'''Stabilize input for safe division. This shifts zero-elements by ``+ epsilon``. For the sake of the
*epsilon rule*, this also shifts positive values by ``+ epsilon`` and negative values by ``- epsilon``.
Parameters
----------
input: :py:obj:`torch.Tensor`
Tensor to stabilize.
Returns
-------
:py:obj:`torch.Tensor`
Stabilized ``input``.
'''
return stabilize(input, self.epsilon, self.clip, self.norm_scale, self.dim)
@classmethod
def ensure(cls, value):
'''Given a value, return a stabilizer. If ``value`` is a float, a Stabilizer with that epsilon ``value`` is
returned. If ``value`` is callable, it will be used directly as a stabilizer. Otherwise a TypeError will be
raised.
Parameters
----------
value: float, int, or callable
The value used to produce a valid stabilizer function.
Returns
-------
callable or Stabilizer
A callable to be used as a stabilizer.
Raises
------
TypeError
If no valid stabilizer could be produced from ``value``.
'''
if isinstance(value, (float, int)):
return cls(epsilon=float(value))
if callable(value):
return value
raise TypeError(f'Value {value} is not a valid stabilizer!')
def stabilize(input, epsilon=1e-6, clip=False, norm_scale=False, dim=None):
'''Stabilize input for safe division.
Parameters
----------
input: :py:obj:`torch.Tensor`
Tensor to stabilize.
epsilon: float, optional
Value by which to shift/clip elements of ``input``.
clip: bool, optional
If ``False`` (default), add ``epsilon`` multiplied by each entry's sign (+1 for 0). If ``True``, instead clip
the absolute value of ``input`` and multiply it by each entry's original sign.
norm_scale: bool, optional
If ``False`` (default), ``epsilon`` is added to/used to clip ``input``. If ``True``, scale ``epsilon`` by the
square root of the mean over the squared elements of the specified dimensions ``dim``.
dim: tuple[int], optional
If ``norm_scale`` is ``True``, specifies the dimension over which the scaled norm should be computed. Defaults
to all except dimension 0.
Returns
-------
:py:obj:`torch.Tensor`
New Tensor copied from `input` with values shifted by epsilon.
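Example
-------
A minimal sketch with ``norm_scale`` enabled, so ``epsilon`` is scaled per sample by the root mean square over the
non-batch dimensions; the values are arbitrary illustrations:
>>> import torch
>>> dense = torch.tensor([[0., 1., -2.], [3., 0., -1.]])
>>> stabilized = stabilize(dense, epsilon=1e-3, norm_scale=True)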
'''
sign = ((input == 0.).to(input) + input.sign())
if norm_scale:
if dim is None:
dim = tuple(range(1, input.ndim))
epsilon = epsilon * ((input ** 2).mean(dim=dim, keepdim=True) ** .5)
if clip:
return sign * input.abs().clip(min=epsilon)
return input + sign * epsilon
def expand(tensor, shape, cut_batch_dim=False):
'''Expand a scalar value or tensor to a shape. In addition to torch.Tensor.expand, this will also accept
non-torch.tensor objects, which will be used to create a new tensor. If ``tensor`` has fewer dimensions than
``shape``, singleton dimensions will be appended to match the size of ``shape`` before expanding.
Parameters
----------
tensor : int, float or :py:obj:`torch.Tensor`
Scalar or tensor to expand to the size of ``shape``.
shape : tuple[int]
Shape to which ``tensor`` will be expanded.
cut_batch_dim : bool, optional
If True, take only the first ``shape[0]`` entries along dimension 0 of the expanded ``tensor``, if it has more
entries in dimension 0 than ``shape``. Default (False) is not to cut, which will instead cause a
``RuntimeError`` due to the size mismatch.
Returns
-------
:py:obj:`torch.Tensor`
A new tensor expanded from ``tensor`` with shape ``shape``.
Raises
------
RuntimeError
If ``tensor`` could not be expanded to ``shape`` due to incompatible shapes.
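Example
-------
A minimal sketch expanding a per-channel tensor to a larger shape by appending a singleton dimension; the shapes are
arbitrary illustrations:
>>> import torch
>>> expanded = expand(torch.tensor([1., 2., 3.]), (3, 4))
>>> expanded.shape
torch.Size([3, 4])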
'''
if not isinstance(tensor, torch.Tensor):
# cast non-tensor scalar to 0-dim tensor
tensor = torch.tensor(tensor)
if tensor.ndim == 0:
# expand scalar tensors
return tensor.expand(shape)
if tensor.ndim < len(shape) and all(left in (1, right) for left, right in zip(tensor.shape, shape)):
# append singleton dimensions if tensor has fewer dimensions, and the existing ones match with shape
tensor = tensor[(...,) + (None,) * (len(shape) - len(tensor.shape))]
if tensor.ndim == len(shape):
# if the dims match completely (lengths match and zipped match), expand normally
if all(left in (1, right) for left, right in zip(tensor.shape, shape)):
return tensor.expand(shape)
# if `cut_batch_dim` and the dims match except for the first, which is larger than in shape, cut the first dim and expand
if (
cut_batch_dim
and all(left in (1, right) for left, right in zip(tensor.shape[1:], shape[1:]))
and tensor.shape[0] > shape[0]
):
return tensor[:shape[0]].expand(shape)
raise RuntimeError(f'Invalid shape! Target: {shape}; Source: {tensor.shape}')
def zero_wrap(zero_params):
'''Create a function wrapper factory (i.e. a decorator) for a single function argument ``(param, name) -> tensor``,
such that the wrapped function is only called if ``name`` does not match ``zero_params`` (by equality if
``zero_params`` is a string, by membership otherwise). Otherwise, `torch.zeros_like` of that tensor is returned.
Parameters
----------
zero_params: str or list[str]
String or list of strings compared to `name`.
Returns
-------
function
The function wrapper to be called on the function.
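Example
-------
A minimal sketch that negates parameters but zeroes the one named ``'bias'``; the names follow ``torch.nn``
conventions and the values are arbitrary:
>>> import torch
>>> modifier = zero_wrap('bias')(lambda param, name: -param)
>>> modifier(torch.ones(2), 'weight')
tensor([-1., -1.])
>>> modifier(torch.ones(2), 'bias')
tensor([0., 0.])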
'''
def zero_params_wrapper(modifier):
'''Wrap a function ``(param, name) -> tensor`` such that the wrapped function is only called if ``name`` does not
match the closure value ``zero_params`` (by equality if ``zero_params`` is a string, by membership otherwise).
Otherwise, `torch.zeros_like` of that tensor is returned.
Parameters
----------
modifier: function
Function to wrap.
Returns
-------
function
The wrapped function.
'''
if not zero_params:
return modifier
@functools.wraps(modifier)
def modifier_wrapper(input, name):
'''Wrapped function ``(input, name) -> tensor``, where the original function is only called if ``name`` does not
match the closure value ``zero_params`` (by equality if ``zero_params`` is a string, by membership otherwise).
Otherwise, `torch.zeros_like` of that tensor is returned.
Parameters
----------
input: :py:obj:`torch.Tensor`
The input tensor modified by the original function.
name: str
The name associated with the input tensor (e.g. the parameter name).
Returns
-------
:py:obj:`torch.Tensor`
The modified tensor.
'''
if isinstance(zero_params, str) and name == zero_params or name in zero_params:
return torch.zeros_like(input)
return modifier(input, name)
return modifier_wrapper
return zero_params_wrapper
class ParamMod:
'''Class to produce a context manager to temporarily modify parameter attributes (all by default) of a module.
Parameters
----------
modifier: function
A function used to modify parameter attributes. If `param_keys` is empty, this is not used.
param_keys: list[str], optional
A list of parameter names that shall be modified. If `None` (default), all parameters are modified (which may
be none). If `[]`, no parameters are modified and `modifier` is ignored.
zero_params: list[str], optional
A list of parameter names that shall be set to zero. If `None` (default), no parameters are set to zero.
require_params: bool, optional
Whether existence of `module`'s params is mandatory (True by default). If the attribute exists but is `None`,
it is not considered missing, and the modifier is not applied.
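Example
-------
A minimal sketch that temporarily negates the weight of a linear layer while setting its bias to zero; the layer and
input are arbitrary illustrations:
>>> import torch
>>> module = torch.nn.Linear(2, 2)
>>> param_mod = ParamMod(lambda param, name: -param, zero_params=['bias'])
>>> with param_mod(module) as modified:
...     output = modified(torch.ones(1, 2))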
'''
def __init__(self, modifier, param_keys=None, zero_params=None, require_params=True):
self.modifier = modifier
self.param_keys = param_keys
self.zero_params = zero_params
self.require_params = require_params
@classmethod
def ensure(cls, modifier):
'''If ``modifier`` is an instance of ParamMod, return it as-is; if it is callable, create a new instance with
``modifier`` as the ParamMod's function; otherwise, raise a TypeError.
Parameters
----------
modifier : :py:obj:`ParamMod` or callable
The modifier which, if necessary, will be used to construct a ParamMod.
Returns
-------
:py:obj:`ParamMod`
Either ``modifier`` as is, or a :py:obj:`ParamMod` constructed using ``modifier``.
Raises
------
TypeError
If ``modifier`` is neither an instance of :py:obj:`ParamMod`, nor callable.
'''
if isinstance(modifier, cls):
return modifier
if callable(modifier):
return cls(modifier)
raise TypeError(f'{modifier} is neither an instance of {cls}, nor callable!')
@contextmanager
def __call__(self, module):
'''Context manager to temporarily modify parameter attributes (all by default) of a module.
Parameters
----------
module: :py:obj:`torch.nn.Module`
Module of which to modify parameters. If `self.require_params` is `True`, it must have all elements given
in `self.param_keys` as attributes (attributes are allowed to be `None`, in which case they are ignored).
Raises
------
RuntimeError
If `self.require_params` is `True` and `module` is missing an attribute listed in `self.param_keys`.
Yields
------
module: :py:obj:`torch.nn.Module`
The `module` with appropriate parameters temporarily modified.
'''
try:
stored_params = {}
param_keys = self.param_keys
zero_params = self.zero_params
if param_keys is None:
param_keys = [name for name, _ in module.named_parameters(recurse=False)]
if zero_params is None:
zero_params = []
missing = [key for key in param_keys if not hasattr(module, key)]
if self.require_params and missing:
missing_str = '\', \''.join(missing)
raise RuntimeError(f'Module {module} requires missing parameters: \'{missing_str}\'')
modifier = zero_wrap(zero_params)(self.modifier)
for key in param_keys:
if key not in missing:
param = getattr(module, key)
if param is not None:
stored_params[key] = param
object.__setattr__(module, key, modifier(param.data, key))
yield module
finally:
for key, value in stored_params.items():
object.__setattr__(module, key, value)
def collect_leaves(module):
'''Generator function to collect all leaf modules of a module.
Parameters
----------
module: :py:obj:`torch.nn.Module`
A module for which the leaves will be collected.
Yields
------
leaf: :py:obj:`torch.nn.Module`
Either a leaf of the module structure, or the module itself if it has no children.
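Example
-------
A small sketch listing the leaf modules of a sequential model; the architecture is an arbitrary illustration:
>>> import torch
>>> model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
>>> [type(leaf).__name__ for leaf in collect_leaves(model)]
['Linear', 'ReLU']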
'''
is_leaf = True
children = module.children()
for child in children:
is_leaf = False
yield from collect_leaves(child)
if is_leaf: # pragma: no branch
yield module
class Identity(torch.autograd.Function):
'''Identity to add a grad_fn to a tensor, so a backward hook can be applied.'''
@staticmethod
def forward(ctx, *inputs):
'''Forward identity.'''
return inputs
@staticmethod
def backward(ctx, *grad_outputs):
'''Backward identity.'''
return grad_outputs
class Hook:
'''Base class for hooks to be used to compute layer-wise attributions.'''
def __init__(self):
self.stored_tensors = {}
self.active = True
self.tensor_handles = RemovableHandleList()
def pre_forward(self, module, input):
'''Apply an Identity to the input before the module to register a backward hook.'''
hook_ref = weakref.ref(self)
@functools.wraps(self.backward)
def wrapper(grad_input, grad_output):
hook = hook_ref()
if hook is not None and hook.active:
return hook.backward(module, grad_input, hook.stored_tensors['grad_output'])
return None
if not isinstance(input, tuple):
input = (input,)
# only if gradient required
if input[0].requires_grad:
# add identity to ensure .grad_fn exists
post_input = Identity.apply(*input)
# register the input tensor gradient hook
self.tensor_handles.append(
post_input[0].grad_fn.register_hook(wrapper)
)
# workaround to support in-place operations
post_input = tuple(elem.clone() for elem in post_input)
else:
# no gradient required
post_input = input
return post_input[0] if len(post_input) == 1 else post_input
def post_forward(self, module, input, output):
'''Register a backward-hook to the resulting tensor right after the forward.'''
hook_ref = weakref.ref(self)
@functools.wraps(self.pre_backward)
def wrapper(grad_input, grad_output):
hook = hook_ref()
if hook is not None and hook.active:
return hook.pre_backward(module, grad_input, grad_output)
return None
if not isinstance(output, tuple):
output = (output,)
# only if gradient required
if output[0].grad_fn is not None:
# register the output tensor gradient hook
self.tensor_handles.append(
output[0].grad_fn.register_hook(wrapper)
)
return output[0] if len(output) == 1 else output
def pre_backward(self, module, grad_input, grad_output):
'''Store the grad_output for the backward hook'''
self.stored_tensors['grad_output'] = grad_output
def forward(self, module, input, output):
'''Hook applied during forward-pass'''
def backward(self, module, grad_input, grad_output):
'''Hook applied during backward-pass'''
def copy(self):
'''Return a copy of this hook.
This is used to describe hooks of different modules by a single hook instance.
'''
return self.__class__()
def remove(self):
'''When removing hooks, remove all references to stored tensors'''
self.stored_tensors.clear()
self.tensor_handles.remove()
def register(self, module):
'''Register this instance by registering all hooks to the supplied module.'''
return RemovableHandleList([
RemovableHandle(self),
module.register_forward_pre_hook(self.pre_forward),
module.register_forward_hook(self.post_forward),
module.register_forward_hook(self.forward),
])
class BasicHook(Hook):
'''A hook to compute the layer-wise attribution of the module it is attached to.
A BasicHook instance may only be registered with a single module.
Parameters
----------
input_modifiers: list[callable], optional
A list of functions ``(input: torch.Tensor) -> torch.Tensor`` to produce multiple inputs. Default is a single
input which is the identity.
param_modifiers: list[:py:obj:`~zennit.core.ParamMod` or callable], optional
A list of ParamMod instances or functions ``(obj: torch.Tensor, name: str) -> torch.Tensor``, with parameter
tensor ``obj``, registered in the root model as ``name``, to temporarily modify the parameters of the attached
module for each input produced with `input_modifiers`. Default is unmodified parameters for each input. Use a
:py:obj:`~zennit.core.ParamMod` instance to specify which parameters should be modified, whether they are
required, and which should be set to zero.
output_modifiers: list[callable], optional
A list of functions ``(input: torch.Tensor) -> torch.Tensor`` to modify the module's output computed using the
modified parameters before gradient computation for each input produced with ``input_modifiers``. Default is the
identity for each output.
gradient_mapper: callable, optional
Function ``(out_grad: torch.Tensor, outputs: list[torch.Tensor]) -> list[torch.Tensor]`` to modify upper
relevance. A list or tuple of the same size as ``outputs`` is expected to be returned. ``outputs`` has the same
size as ``input_modifiers`` and ``param_modifiers``. Default is a stabilized normalization by each of the
outputs, multiplied with the output gradient.
reducer: callable, optional
Function ``(inputs: list[torch.Tensor], gradients: list[torch.Tensor]) -> torch.Tensor`` to reduce all the
inputs and gradients produced through ``input_modifiers`` and ``param_modifiers``. ``inputs`` and ``gradients``
have the same size as ``input_modifiers`` and ``param_modifiers``. Default is the sum of the multiplications of each
input and its corresponding gradient.
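Example
-------
A minimal sketch using the default modifiers on a single linear layer; the layer and input are arbitrary
illustrations, and the gradient with respect to ``input`` is the attributed relevance:
>>> import torch
>>> module = torch.nn.Linear(4, 4)
>>> hook = BasicHook()
>>> handles = hook.register(module)
>>> input = torch.randn(1, 4, requires_grad=True)
>>> relevance, = torch.autograd.grad(module(input).sum(), input)
>>> handles.remove()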
'''
def __init__(
self,
input_modifiers=None,
param_modifiers=None,
output_modifiers=None,
gradient_mapper=None,
reducer=None,
stabilizer=1e-6,
):
super().__init__()
modifiers = {
'in': input_modifiers,
'param': param_modifiers,
'out': output_modifiers,
}
supplied = {key for key, val in modifiers.items() if val is not None}
num_mods = len(modifiers[next(iter(supplied))]) if supplied else 1
modifiers.update({key: (self._default_modifier,) * num_mods for key in set(modifiers) - supplied})
if gradient_mapper is None:
gradient_mapper = self._default_gradient_mapper
if reducer is None:
reducer = self._default_reducer
self.input_modifiers = modifiers['in']
self.param_modifiers = modifiers['param']
self.output_modifiers = modifiers['out']
self.gradient_mapper = gradient_mapper
self.reducer = reducer
def forward(self, module, input, output):
'''Forward hook to save module in-/outputs.'''
self.stored_tensors['input'] = input
def backward(self, module, grad_input, grad_output):
'''Backward hook to compute LRP based on the class attributes.'''
original_input = self.stored_tensors['input'][0].clone()
inputs = []
outputs = []
for in_mod, param_mod, out_mod in zip(self.input_modifiers, self.param_modifiers, self.output_modifiers):
input = in_mod(original_input).requires_grad_()
with ParamMod.ensure(param_mod)(module) as modified, torch.autograd.enable_grad():
output = modified.forward(input)
output = out_mod(output)
inputs.append(input)
outputs.append(output)
grad_outputs = self.gradient_mapper(grad_output[0], outputs)
gradients = torch.autograd.grad(
outputs,
inputs,
grad_outputs=grad_outputs,
create_graph=grad_output[0].requires_grad
)
relevance = self.reducer(inputs, gradients)
return tuple(relevance if original.shape == relevance.shape else None for original in grad_input)
def copy(self):
'''Return a copy of this hook.
This is used to describe hooks of different modules by a single hook instance.
'''
copy = BasicHook.__new__(type(self))
BasicHook.__init__(
copy,
self.input_modifiers,
self.param_modifiers,
self.output_modifiers,
self.gradient_mapper,
self.reducer,
)
return copy
@staticmethod
def _default_modifier(obj, name=None):
return obj
@staticmethod
def _default_gradient_mapper(out_grad, outputs):
return tuple(out_grad / stabilize(output) for output in outputs)
@staticmethod
def _default_reducer(inputs, gradients):
return sum(input * gradient for input, gradient in zip(inputs, gradients))
class RemovableHandle:
'''Create weak reference to call .remove on some instance.'''
def __init__(self, instance):
self.instance_ref = weakref.ref(instance)
def remove(self):
'''Call remove on weakly reference instance if it still exists.'''
instance = self.instance_ref()
if instance is not None:
instance.remove()
class RemovableHandleList(list):
'''A list to hold handles, with the ability to call remove on all of its members.'''
def remove(self):
'''Call remove on all members, effectively removing handles from modules, or reverting canonizers.'''
for handle in self:
handle.remove()
self.clear()
class CompositeContext:
'''A context object to register a composite in a context and remove the associated hooks and canonizers afterwards.
Parameters
----------
module: :py:class:`torch.nn.Module`
The module to which `composite` should be registered.
composite: :py:class:`zennit.core.Composite`
The composite which shall be registered to `module`.
'''
def __init__(self, module, composite):
self.module = module
self.composite = composite
def __enter__(self):
self.composite.register(self.module)
return self.module
def __exit__(self, exc_type, exc_value, traceback):
self.composite.remove()
return False
class Composite:
'''A Composite to apply canonizers and register hooks to modules.
One Composite instance may only be applied to a single module at a time.
Parameters
----------
module_map: callable, optional
A function ``(ctx: dict, name: str, module: torch.nn.Module) -> Hook or None`` which maps a context, name and
module to a matching :py:class:`~zennit.core.Hook`, or ``None`` if there is no matching
:py:class:`~zennit.core.Hook`.
canonizers: list[:py:class:`zennit.canonizers.Canonizer`], optional
List of canonizer instances to be applied before applying hooks.
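Example
-------
A minimal sketch of a ``module_map`` which attaches a fresh ``BasicHook`` to every linear layer; the model and input
are arbitrary illustrations:
>>> import torch
>>> def module_map(ctx, name, module):
...     if isinstance(module, torch.nn.Linear):
...         return BasicHook()
...     return None
>>> model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
>>> composite = Composite(module_map=module_map)
>>> with composite.context(model) as modified:
...     input = torch.randn(1, 4, requires_grad=True)
...     relevance, = torch.autograd.grad(modified(input).sum(), input)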
'''
def __init__(self, module_map=None, canonizers=None):
if module_map is None:
module_map = self._empty_module_map
if canonizers is None:
canonizers = []
self.module_map = module_map
self.canonizers = canonizers
self.handles = RemovableHandleList()
self.hook_refs = weakref.WeakSet()
def register(self, module):
'''Apply all canonizers and register all hooks to a module (and its recursive children).
Previous canonizers of this composite are reverted and all hooks registered by this composite are removed.
The module or any of its children (recursively) may still have other hooks attached.
Parameters
----------
module: :py:class:`torch.nn.Module`
Hooks and canonizers will be applied to this module recursively according to ``module_map`` and
``canonizers``.
'''
self.remove()
for canonizer in self.canonizers:
self.handles += canonizer.apply(module)
ctx = {}
for name, child in module.named_modules():
template = self.module_map(ctx, name, child)
if template is not None:
hook = template.copy()
self.hook_refs.add(hook)
self.handles.append(hook.register(child))
def remove(self):
'''Remove all handles for hooks and canonizers.
Hooks will simply be removed from their corresponding Modules.
Canonizers will revert the state of the modules they changed.
'''
self.handles.remove()
self.hook_refs.clear()
def context(self, module):
'''Return a CompositeContext object with this instance and the supplied module.
Parameters
----------
module: :py:class:`torch.nn.Module`
Module for which to register this composite in the context.
Returns
-------
:py:class:`zennit.core.CompositeContext`
A context object which registers the composite to ``module`` on entering, and removes it on exiting.
'''
return CompositeContext(module, self)
@contextmanager
def inactive(self):
'''Context manager to temporarily deactivate the gradient modification. This can be used to compute the
gradient of the modified gradient.
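Example
-------
A minimal sketch; while ``inactive``, the registered hooks stay attached, but the plain gradient is computed instead
of the modified one. The model and input are arbitrary illustrations:
>>> import torch
>>> model = torch.nn.Linear(4, 4)
>>> composite = Composite(module_map=lambda ctx, name, module: BasicHook())
>>> composite.register(model)
>>> input = torch.randn(1, 4, requires_grad=True)
>>> with composite.inactive():
...     plain_gradient, = torch.autograd.grad(model(input).sum(), input)
>>> composite.remove()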
'''
try:
for hook in self.hook_refs:
hook.active = False
yield self
finally:
for hook in self.hook_refs:
hook.active = True
@staticmethod
def _empty_module_map(ctx, name, module):
'''Empty module_map, does not assign any rules.'''
return None