-
-
Notifications
You must be signed in to change notification settings - Fork 436
/
Copy pathconst.go
134 lines (119 loc) · 4.2 KB
/
const.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
package gorgonia
import (
"math"
"gorgonia.org/tensor"
)
// Package-wide string constants: identifiers used when rendering graphs with
// graphviz, followed by the error message templates shared across the package.
const (
	// fullGraphName is the graphviz name for a full graph.
	fullGraphName = "fullGraph"

	// Graphviz group (cluster) names.
	exprgraphClust = "expressionGraph"
	constantsClust = "constants"
	inputsClust    = "inputs"
	gradClust      = "gradients"
	strayClust     = "undifferentiated nodes"

	// Subgraphs to rank the same.
	outsideSubG = "outsides"
	inputConsts = "inputConsts"

	// Special nodes for graphviz hacking.
	outsideRoot   = "outsideRoot"
	outsideInputs = "outsideInputs"
	insideInputs  = "insideInputs"
	outsideConsts = "outsideConsts"
	insideConsts  = "insideConsts"
	outsideExprG  = "outsideExprG"
	insideExprG   = "insideExprG"
	outsideGrads  = "outsideGrads"
	insideGrads   = "insideGrads"

	// Error messages. Entries containing formatting verbs (%v, %T, %d, %s)
	// are templates; the arguments are supplied by callers outside this
	// block, presumably through a printf-style error constructor.
	sortFail            = "Failed to sort"
	cloneFail           = "Failed to carry clone(%v)"
	clone0Fail          = "Failed to carry clone0()"
	nyiTypeFail         = "%s not yet implemented for %T"
	nyiFail             = "%s not yet implemented for %v"
	dtypeOfFail         = "Failed to carry dtypeOf()"
	mulFail             = "Failed to carry Mul()"
	applyOpFail         = "Failed to carry ApplyOp()"
	opDoFail            = "Failed to carry op.Do()"
	binOpDoFail         = "Failed to carry binOp.Do()"
	binOpNodeFail       = "Failed to carry binary operation %T"
	applyFail           = "Failed to carry Apply()"
	binOpFail           = "Binary operator received %d arguments"
	hadamardProdFail    = "Failed to carry hadamardProd()"
	hadamardDivFail     = "Failed to carry hadamardDiv()"
	cubeFail            = "Failed to carry cube()"
	negFail             = "Failed to carry Neg()"
	invFail             = "Failed to carry Inv()"
	pointWiseMulFail    = "Failed to carry PointWiseMul()"
	pointWiseSquareFail = "Failed to carry PointWiseSquare()"
	clampFail           = "Failed to carry Clamp()"
	invSqrtFail         = "Failed to carry InvSqrt()"
	subFail             = "Failed to carry Sub()"
	addFail             = "Failed to carry Add()"
	signFail            = "Failed to carry Sign()"
	softplusFail        = "Failed to carry Softplus()"
	incrErr             = "increment couldn't be done. Safe op was performed instead"
	bindFail            = "Failed to bind"
	anyToValueFail      = "Failed to convert %v(%T) into a Value"
	dtypeExtractionFail = "Failed to extract dtype from %v"
	operationError      = "Operation failed"
	doFail              = "Doing %v failed"
	unsafeDoFail        = "UnsafeDoing %v failed."
	tFail               = "Failed to transpose Tensor"
	repFail             = "Failed to repeat Tensor along %d %d times"
	reshapeFail         = "Failed to reshape Tensor into %v. DataSize was: %d"
	sliceFail           = "Failed to slice Tensor with %v"
	execFail            = "Failed to execute %v in node %v"
	autodiffFail        = "Failed to differentiate %v"
	undefinedOnShape    = "%v undefined on shape %v"
	unsupportedDtype    = "dtype %v is not yet supported"
	gradOnDeviceFail    = "Cannot get gradient of %v on %v"
	makeValueFail       = "Unable to make value of %v with shape %v"
	allocFail           = "Unable to allocate %v bytes on %v"
	shapeMismatchErr    = "Shape Mismatch. Expected %v. Got %v instead."
)
// empty is a zero-size struct{} value. Its users are outside this view;
// presumably it serves as a placeholder value (e.g. for set-like maps).
var empty = struct{}{}
// Package-level constant nodes, the constant ops extracted from some of them,
// and the name-to-node lookup table constmap (populated in init).
var (
// Scalar constant nodes, each created by NewConstant from a float32 or
// float64 value; the f32/f64 suffix names the element type.
onef32 = NewConstant(float32(1.0))
onef64 = NewConstant(float64(1.0))
// NOTE(review): 1e-16 is smaller than half the float64 spacing at 1.0
// (2^-53, about 1.11e-16), so both conversions below appear to round to
// exactly 1.0 — confirm whether a value strictly greater than one was intended.
oneMoref32 = NewConstant(float32(1.0 + 1e-16))
oneMoref64 = NewConstant(float64(1.0 + 1e-16))
zerof32 = NewConstant(float32(0.0))
zerof64 = NewConstant(float64(0.0))
twof64 = NewConstant(float64(2.0))
twof32 = NewConstant(float32(2.0))
threef64 = NewConstant(float64(3.0))
threef32 = NewConstant(float32(3.0))
// Natural logarithm of 2 (math.Ln2) in both precisions.
ln2f64 = NewConstant(math.Ln2)
ln2f32 = NewConstant(float32(math.Ln2))
// The op of each one/zero node, asserted to the constant type. These are
// single-value type assertions, so package initialization panics if a
// node's op is not a constant.
onef32ConstOp = onef32.op.(constant)
onef64ConstOp = onef64.op.(constant)
zerof32ConstOp = zerof32.op.(constant)
zerof64ConstOp = zerof64.op.(constant)
// constmap maps a constant's name, then a tensor.Dtype, to the matching
// node. It is nil here and assigned in init.
constmap map[string]map[tensor.Dtype]*Node
)
// oneone is the tensor.Shape literal {1, 1} (presumably the shape of a 1x1
// tensor — confirm against tensor.Shape's definition).
var oneone = tensor.Shape{1, 1}
func init() {
constmap = map[string]map[tensor.Dtype]*Node{
"zero": {
Float32: zerof32,
Float64: zerof64,
},
"one": {
Float32: onef32,
Float64: onef64,
},
"two": {
Float32: twof32,
Float64: twof64,
},
"three": {
Float32: threef32,
Float64: threef64,
},
"log2": {
Float32: ln2f32,
Float64: ln2f64,
},
}
}