forked from hashicorp/vault
-
Notifications
You must be signed in to change notification settings - Fork 0
/
shamir.go
246 lines (212 loc) · 6.36 KB
/
shamir.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package shamir
import (
"crypto/rand"
"crypto/subtle"
"fmt"
mathrand "math/rand"
"time"
)
// ShareOverhead is the number of extra bytes every share carries compared to
// the secret it was produced from by Split. The extra byte is the tag (the x
// coordinate) stored at the end of each share.
const ShareOverhead = 1
// polynomial is a polynomial of arbitrary degree whose coefficients are
// single bytes. coefficients[i] is the coefficient of x^i, so index 0 holds
// the intercept (the value at x = 0).
type polynomial struct {
	coefficients []byte
}
// makePolynomial builds a random polynomial of the given degree whose
// constant term (the value at x = 0) is intercept.
//
// The returned polynomial has degree+1 coefficients: index 0 is intercept and
// the remaining ones are filled from crypto/rand. If the random source fails,
// the zero polynomial and the error are returned.
func makePolynomial(intercept, degree uint8) (polynomial, error) {
	// Compute the length in int: degree+1 evaluated in uint8 wraps to 0 for
	// degree == 255, which would leave no room for the intercept.
	coefficients := make([]byte, int(degree)+1)

	// The intercept is fixed; only the higher-order terms are random.
	coefficients[0] = intercept
	if _, err := rand.Read(coefficients[1:]); err != nil {
		// Do not hand back a half-initialised polynomial holding the secret.
		return polynomial{}, err
	}

	return polynomial{coefficients: coefficients}, nil
}
// evaluate computes the polynomial's value at x, using the add and mult
// helpers for all arithmetic.
func (p *polynomial) evaluate(x uint8) uint8 {
	coeffs := p.coefficients

	// At the origin only the constant term contributes.
	if x == 0 {
		return coeffs[0]
	}

	// Horner's method: start from the highest-order coefficient and fold in
	// one lower-order coefficient per step.
	top := len(coeffs) - 1
	acc := coeffs[top]
	for k := top; k > 0; k-- {
		acc = add(mult(acc, x), coeffs[k-1])
	}
	return acc
}
// interpolatePolynomial takes N sample points and returns
// the value at a given x using a lagrange interpolation.
func interpolatePolynomial(x_samples, y_samples []uint8, x uint8) uint8 {
limit := len(x_samples)
var result, basis uint8
for i := 0; i < limit; i++ {
basis = 1
for j := 0; j < limit; j++ {
if i == j {
continue
}
num := add(x, x_samples[j])
denom := add(x_samples[i], x_samples[j])
term := div(num, denom)
basis = mult(basis, term)
}
group := mult(y_samples[i], basis)
result = add(result, group)
}
return result
}
// div returns a divided by b in GF(2^8), computed as a times the
// multiplicative inverse of b. It panics with "divide by zero" when b is 0.
func div(a, b uint8) uint8 {
	if b == 0 {
		// This branch leaks timing information, but a zero divisor is not
		// expected to occur at all, hence the panic.
		panic("divide by zero")
	}

	quotient := mult(a, inverse(b))

	// Force a zero result for a zero dividend without branching on a, so the
	// choice is not observable through timing.
	dividendIsZero := subtle.ConstantTimeByteEq(a, 0)
	return uint8(subtle.ConstantTimeSelect(dividendIsZero, 0, int(quotient)))
}
// inverse returns the multiplicative inverse of a in GF(2^8).
//
// It raises a to the 254th power with a fixed sequence of multiplications
// that does not depend on the value of a. Each intermediate below is named
// after the power of a it holds. For a == 0 the result is 0.
func inverse(a uint8) uint8 {
	a2 := mult(a, a)
	a3 := mult(a, a2)
	a6 := mult(a3, a3)
	a12 := mult(a6, a6)
	a15 := mult(a12, a3)
	a24 := mult(a12, a12)
	a48 := mult(a24, a24)
	a63 := mult(a48, a15)
	a126 := mult(a63, a63)
	a127 := mult(a, a126)
	return mult(a127, a127)
}
// mult returns the product of a and b in GF(2^8).
//
// The bits of b are consumed from most to least significant. Each step
// doubles the accumulator, XORs in 0x1B when the bit shifted out of the
// accumulator was set (reduction modulo x^8 + x^4 + x^3 + x + 1), and XORs in
// a when the current bit of b is set. Masks are used instead of branches, so
// the same operations run for every input.
func mult(a, b uint8) uint8 {
	var acc uint8
	for bit := 7; bit >= 0; bit-- {
		// 0xFF when the current bit of b is set, 0x00 otherwise.
		takeA := -((b >> uint(bit)) & 1)
		// 0xFF when doubling acc overflows eight bits, 0x00 otherwise.
		overflow := -(acc >> 7)
		acc = (takeA & a) ^ (overflow & 0x1B) ^ (acc << 1)
	}
	return acc
}
// add returns the sum of a and b in GF(2^8), which is their bitwise XOR.
// Because XOR is its own inverse, the same function also performs
// subtraction.
func add(a, b uint8) uint8 {
	sum := a
	sum ^= b
	return sum
}
// Split takes an arbitrarily long secret and generates `parts` shares,
// `threshold` of which are required to reconstruct the secret.
//
// threshold must be at least 2, parts must be at least threshold, and neither
// may exceed 255. The secret must not be empty. Each returned share is one
// byte longer than the secret: it holds one y value per secret byte followed
// by the share's x coordinate, i.e. {y1, y2, .., yN, x}.
//
// An error is returned when the arguments are out of range or when a random
// polynomial cannot be generated.
func Split(secret []byte, parts, threshold int) ([][]byte, error) {
	// Sanity check the input
	switch {
	case parts < threshold:
		return nil, fmt.Errorf("parts cannot be less than threshold")
	case parts > 255:
		return nil, fmt.Errorf("parts cannot exceed 255")
	case threshold < 2:
		return nil, fmt.Errorf("threshold must be at least 2")
	case threshold > 255:
		return nil, fmt.Errorf("threshold cannot exceed 255")
	case len(secret) == 0:
		return nil, fmt.Errorf("cannot split an empty secret")
	}

	// Pick the x coordinates from a random permutation of 0..254, shifted by
	// one so they are distinct values in 1..255; x = 0 is never used because
	// the polynomial's value there is the secret byte itself.
	//
	// A generator local to this call is used instead of reseeding the
	// process-wide math/rand source: the top-level Seed is deprecated and
	// reseeding it would affect every other user of math/rand in the process.
	rng := mathrand.New(mathrand.NewSource(time.Now().UnixNano()))
	perm := rng.Perm(255)

	// Allocate the output shares and store each share's x coordinate in its
	// final byte, so that it only needs to be stored once per share.
	xCoordinates := make([]uint8, parts)
	out := make([][]byte, parts)
	for i := range out {
		x := uint8(perm[i]) + 1
		xCoordinates[i] = x
		out[i] = make([]byte, len(secret)+1)
		out[i][len(secret)] = x
	}

	// Construct a random polynomial for each byte of the secret.
	// Because we are using a field of size 256, we can only represent
	// a single byte as the intercept of the polynomial, so we must
	// use a new polynomial for each byte.
	for idx, val := range secret {
		p, err := makePolynomial(val, uint8(threshold-1))
		if err != nil {
			return nil, fmt.Errorf("failed to generate polynomial: %w", err)
		}

		// Evaluate the polynomial at every share's x coordinate.
		for i, x := range xCoordinates {
			out[i][idx] = p.evaluate(x)
		}
	}

	// Return the encoded secrets
	return out, nil
}
// Combine reverses a Split: given shares of a secret it reconstructs the
// secret by interpolating, for every byte position, the value at x = 0.
//
// At least two parts are required, every part must be at least two bytes
// long, all parts must have the same length, and no two parts may carry the
// same x coordinate (their final byte). Violating any of these yields an
// error. The returned secret is one byte shorter than the parts.
func Combine(parts [][]byte) ([]byte, error) {
	// Verify enough parts provided
	if len(parts) < 2 {
		return nil, fmt.Errorf("less than two parts cannot be used to reconstruct the secret")
	}

	// Verify the parts are long enough and all the same length
	partLen := len(parts[0])
	if partLen < 2 {
		return nil, fmt.Errorf("parts must be at least two bytes")
	}
	for _, part := range parts[1:] {
		if len(part) != partLen {
			return nil, fmt.Errorf("all parts must be the same length")
		}
	}

	// The final byte of each part is its x coordinate. Collect them and
	// reject repeats, which would otherwise make div() panic on a zero
	// denominator during interpolation.
	tagIdx := partLen - 1
	xSamples := make([]uint8, 0, len(parts))
	seen := make(map[byte]struct{}, len(parts))
	for _, part := range parts {
		tag := part[tagIdx]
		if _, dup := seen[tag]; dup {
			return nil, fmt.Errorf("duplicate part detected")
		}
		seen[tag] = struct{}{}
		xSamples = append(xSamples, tag)
	}

	// Reconstruct each byte: gather that position's y value from every part
	// and interpolate at 0 to obtain the polynomial's intercept.
	secret := make([]byte, tagIdx)
	ySamples := make([]uint8, len(parts))
	for idx := range secret {
		for i, part := range parts {
			ySamples[i] = part[idx]
		}
		secret[idx] = interpolatePolynomial(xSamples, ySamples, 0)
	}

	return secret, nil
}