-
Notifications
You must be signed in to change notification settings - Fork 1
/
client.go
84 lines (74 loc) · 1.96 KB
/
client.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
package bifrost
import (
"context"
"crypto/tls"
"crypto/x509"
"io"
"net/http"
"sync/atomic"
"time"
)
// HTTPClient builds an *http.Client that authenticates itself to servers with
// a TLS client certificate (mTLS).
//
// A certificate for privkey is requested from the bifrost CA at caUrl before
// HTTPClient returns; any error from that initial request is returned and no
// client is built. After that, the certificate is renewed on demand from inside
// the TLS handshake callback.
//
// roots, when non-nil, is the only set of root CAs trusted for server
// certificates; when nil, crypto/tls falls back to the host's root CA set.
// ssllog, when non-nil, receives TLS key material for every connection. This
// compromises connection security and should only be used for debugging.
func HTTPClient(
	caUrl string,
	privkey *PrivateKey,
	roots *x509.CertPool,
	ssllog io.Writer,
) (*http.Client, error) {
	refresher := &certRefresher{url: caUrl, privkey: privkey}

	// Prime the certificate cache eagerly so that an unreachable or failing CA
	// is reported to the caller now rather than on the first request.
	if _, err := refresher.GetClientCertificate(nil); err != nil {
		return nil, err
	}

	// Start from a copy of the default transport so proxy, timeout and pooling
	// settings are inherited; only the TLS configuration is replaced.
	transport := http.DefaultTransport.(*http.Transport).Clone()
	transport.TLSClientConfig = &tls.Config{
		RootCAs:              roots,
		KeyLogWriter:         ssllog,
		GetClientCertificate: refresher.GetClientCertificate,
	}

	client := &http.Client{Transport: transport}
	return client, nil
}
// certRefresher supplies the client certificate for mTLS handshakes. It caches
// the last certificate issued by the bifrost CA and requests a new one when
// none is cached or the cached one is about to expire.
type certRefresher struct {
	url     string      // bifrost CA URL passed to RequestCertificate
	privkey *PrivateKey // key passed to RequestCertificate and paired with the issued certificate
	// cert caches the most recently issued certificate. It is an atomic pointer
	// because one tls.Config (and therefore this refresher) is shared by every
	// connection of the transport, so handshakes may call
	// GetClientCertificate concurrently.
	cert atomic.Pointer[Certificate]
}

// GetClientCertificate is used as tls.Config.GetClientCertificate.
//
// It returns the cached certificate (paired with cr.privkey), first requesting
// a new one from the CA if none is cached or the cached one expires within
// refreshWindow. Errors from the CA request are returned unchanged.
//
// info may be nil; HTTPClient calls it that way to prime the cache. When info
// is non-nil, its context is used for the CA request and logging. The
// certificate is also checked with info.SupportsCertificate, and an
// unsupported certificate is returned as an error, which aborts the handshake.
func (cr *certRefresher) GetClientCertificate(
	info *tls.CertificateRequestInfo,
) (*tls.Certificate, error) {
	// Renew this long before NotAfter, so a handshake never presents a
	// certificate that is expired or about to expire. Note: certificates whose
	// total validity is shorter than this are re-requested on every handshake.
	const refreshWindow = 10 * time.Minute

	ctx := context.Background()
	if info != nil {
		ctx = info.Context()
	}

	cert := cr.cert.Load()
	if cert == nil || time.Until(cert.NotAfter) < refreshWindow {
		Logger().DebugContext(ctx, "refreshing client certificate")
		fresh, err := RequestCertificate(ctx, cr.url, cr.privkey)
		if err != nil {
			return nil, err
		}
		// A plain Store is enough. Concurrent refreshes may each store a
		// certificate, and the last one wins. Every stored certificate was just
		// issued, so any of them is acceptable.
		cr.cert.Store(fresh)
		cert = fresh
		Logger().InfoContext(ctx, "got new client certificate")
	}

	// Build the result from the local cert (not a second Load), so the
	// certificate checked above is the one actually returned.
	tlsCert := X509ToTLSCertificate(cert.Certificate, cr.privkey.PrivateKey)
	if info != nil {
		if err := info.SupportsCertificate(tlsCert); err != nil {
			return nil, err
		}
	}
	return tlsCert, nil
}