from caffe2.python.core import DataType, BlobReference, ScopedBlobReference
from caffe2.python.modeling.parameter_info import ParameterInfo


class Initializer(object):
    '''
    This class abstracts out parameter creation. One can come up with a new
    Initializer in order to implement more complex parameter initialization
    logic.
    '''

    def __init__(self, operator_name=None, **kwargs):
        self.operator_name = operator_name
        self.operator_kwargs = kwargs

    def update(self, operator_name, kwargs):
        if self.operator_name is not None:
            raise Exception("Operator name overwrites are not allowed")
        self.operator_name = operator_name
        self.operator_kwargs = kwargs

    def create_param(self, param_name, init_net, shape):
        param = init_net.__getattr__(self.operator_name)(
            [], param_name, shape=shape, **self.operator_kwargs)
        return ParameterInfo(
            param_id=None,
            param=param,
            shape=shape,
        )
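
# A minimal usage sketch (comments only, not executed on import). It assumes a
# caffe2 core.Net is available and uses "XavierFill" purely as an illustrative
# fill operator; any caffe2 fill operator name works the same way:
#
#   from caffe2.python import core
#   init_net = core.Net("param_init_net")
#   init = Initializer("XavierFill")
#   w_info = init.create_param("w", init_net, shape=[64, 32])
#   # Running init_net adds an XavierFill op that produces the blob "w";
#   # the returned ParameterInfo wraps that blob together with its shape.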


class ExternalInitializer(object):
    '''
    This class is used in cases when the parameter should not be initialized
    by the initializer, but rather provided in the workspace when
    param_init_net is executed.

    The current version does not do any real sanity checks on the parameter.
    '''

    def create_param(self, param_name, init_net, shape):
        if isinstance(param_name, BlobReference):
            param = BlobReference(str(param_name), init_net)
        elif isinstance(param_name, str):
            param = ScopedBlobReference(param_name, init_net)
        else:
            raise TypeError("Unsupported type for param_name")
        # TODO(amalevich): Add operator that will check param in the workspace
        return ParameterInfo(
            param_id=None,
            param=param,
            shape=shape,
        )
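
# Sketch of the intended use (illustrative only): no initialization op is
# added to init_net, so the blob must already exist in the workspace, e.g.
# fed via workspace.FeedBlob, before the net runs:
#
#   ext = ExternalInitializer()
#   w_info = ext.create_param("pretrained_w", init_net, shape=[64, 32])
#   # "pretrained_w" is only wrapped as a blob reference here; its contents
#   # come from whatever was fed into the workspace beforehand.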


class PseudoFP16Initializer(Initializer):
    '''
    Used in cases when the parameter should be used at half (16-bit) precision
    for compute purposes (i.e. on the forward and backward pass) but
    needs to be stored and optimized at single (32-bit) precision so tiny
    gradients with small learning rates don't underflow FP16 precision.
    A 32-bit copy of the 16-bit blob is stored in the ParameterInfo.
    This is helpful for mixed-precision training, see
    https://arxiv.org/abs/1710.03740 for details.
    '''

    def update(self, operator_name, kwargs):
        if self.operator_name is not None:
            raise Exception("Operator name overwrites are not allowed")
        self.operator_name = operator_name
        self.operator_kwargs = kwargs

    def create_param(self, param_name, init_net, shape):
        # create master fp32 copy
        param_fp32 = init_net.__getattr__(self.operator_name)(
            [], param_name + "_fp32", shape=shape,
            **self.operator_kwargs)
        # cast to fp16 copy
        param = init_net.FloatToHalf(
            param_fp32, param_name)

        return ParameterInfo(
            param_id=None,
            param=param,
            shape=shape,
            blob_copy={DataType.FLOAT: param_fp32}
        )
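
# Illustrative sketch (comments only): with the same hypothetical init_net and
# fill operator as in the example above,
#
#   init = PseudoFP16Initializer("XavierFill")
#   w_info = init.create_param("w", init_net, shape=[64, 32])
#   # init_net now fills "w_fp32" and casts it to the fp16 blob "w".
#   # The blob_copy entry maps DataType.FLOAT to the fp32 master copy, which
#   # an optimizer can update at full precision.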


class ReversePseudoFP16Initializer(Initializer):
    '''
    Like PseudoFP16Initializer above, except the primary blob is taken to
    be the 32-bit precision parameter, and the 16-bit version of the blob
    is stored in blob_copy instead.
    '''

    def update(self, operator_name, kwargs):
        if self.operator_name is not None:
            raise Exception("Operator name overwrites are not allowed")
        self.operator_name = operator_name
        self.operator_kwargs = kwargs

    def create_param(self, param_name, init_net, shape):
        # create master fp32 copy
        param_fp32 = init_net.__getattr__(self.operator_name)(
            [], param_name, shape=shape,
            **self.operator_kwargs)
        # cast to fp16 copy
        param_fp16 = init_net.FloatToHalf(
            param_fp32, param_name + "_fp16")

        return ParameterInfo(
            param_id=None,
            param=param_fp32,
            shape=shape,
            blob_copy={DataType.FLOAT16: param_fp16}
        )
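
# Sketch (illustrative): with ReversePseudoFP16Initializer("XavierFill"),
# create_param("w", init_net, shape=[64, 32]) makes "w" the fp32 primary blob
# and stores the fp16 cast "w_fp16" under DataType.FLOAT16 in blob_copy.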


def update_initializer(initializer_class,
                       operator_name_and_kwargs,
                       default_operator_name_and_kwargs):
    '''
    A helper function to convert from operator_name_and_kwargs to a new
    object of type initializer_class. This function serves two purposes:

    1. Support for custom initialization operators being passed in
    2. Allow the user to specify a custom Initializer without overwriting
       the default operators used for initialization

    If initializer_class is None, creates a default initializer using the
    Initializer class and the operator_name_and_kwargs provided.

    If operator_name_and_kwargs is None, uses default_operator_name_and_kwargs.

    Returns an instantiated Initializer object.
    '''
    def get_initializer_args():
        return (
            operator_name_and_kwargs or
            default_operator_name_and_kwargs
        )

    if initializer_class is not None:
        init = initializer_class(get_initializer_args()[0],
                                 **get_initializer_args()[1])
    else:
        init = Initializer(
            get_initializer_args()[0],
            **get_initializer_args()[1]
        )
    return init
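
# Illustrative sketch of how the helper resolves its arguments (comments only;
# the operator names are just examples):
#
#   # No custom class, no override: fall back to the default fill op.
#   init = update_initializer(None, None, ("ConstantFill", {"value": 0.0}))
#   # -> Initializer("ConstantFill", value=0.0)
#
#   # Custom class and custom operator: both are honored.
#   init = update_initializer(PseudoFP16Initializer,
#                             ("XavierFill", {}),
#                             ("ConstantFill", {}))
#   # -> PseudoFP16Initializer("XavierFill")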