-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathapi.go
173 lines (153 loc) · 5.36 KB
/
api.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
/*
Package api contains the API server.
HTTP is the only transport supported at the moment.
The design package contains the Goa design, while the gen package contains
all the code generated from it with goa gen.
*/
package api
import (
	"context"
	"encoding/json"
	"errors"
	"net/http"
	"net/url"
	"os"
	"time"
	"unicode/utf8"

	"github.com/go-logr/logr"
	"github.com/gorilla/websocket"
	"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
	"go.opentelemetry.io/otel/trace"
	goahttp "goa.design/goa/v3/http"
	goahttpmwr "goa.design/goa/v3/http/middleware"
	"goa.design/goa/v3/middleware"

	"github.com/artefactual-labs/enduro/internal/api/gen/batch"
	"github.com/artefactual-labs/enduro/internal/api/gen/collection"
	batchsvr "github.com/artefactual-labs/enduro/internal/api/gen/http/batch/server"
	collectionsvr "github.com/artefactual-labs/enduro/internal/api/gen/http/collection/server"
	pipelinesvr "github.com/artefactual-labs/enduro/internal/api/gen/http/pipeline/server"
	swaggersvr "github.com/artefactual-labs/enduro/internal/api/gen/http/swagger/server"
	"github.com/artefactual-labs/enduro/internal/api/gen/pipeline"
	intbatch "github.com/artefactual-labs/enduro/internal/batch"
	intcol "github.com/artefactual-labs/enduro/internal/collection"
	intpipe "github.com/artefactual-labs/enduro/internal/pipeline"
	"github.com/artefactual-labs/enduro/ui"
)
// HTTPServer builds the API *http.Server.
//
// It mounts the pipeline, batch, collection and swagger Goa servers and the
// web UI (served for "/" and "/{*filename}") on a single Goa muxer. It then
// wraps the muxer with the global middlewares: OpenTelemetry tracing, request
// IDs and the version header, plus request logging and the Goa debug
// middleware when config.Debug is set. The returned server listens on
// config.Listen with 5s read/write timeouts and a 120s idle timeout.
func HTTPServer(
	logger logr.Logger,
	tp trace.TracerProvider,
	config *Config,
	pipesvc intpipe.Service,
	batchsvc intbatch.Service,
	colsvc intcol.Service,
) *http.Server {
	var (
		decoder               = goahttp.RequestDecoder
		encoder               = goahttp.ResponseEncoder
		mux     goahttp.Muxer = goahttp.NewMuxer()
	)

	upgrader := &websocket.Upgrader{
		HandshakeTimeout: time.Second,
		CheckOrigin:      sameOriginChecker(logger),
	}

	// Pipeline service.
	pipelineServer := pipelinesvr.New(
		pipeline.NewEndpoints(pipesvc), mux, decoder, encoder,
		errorHandler(logger, "Pipeline error."), nil,
	)
	pipelinesvr.Mount(mux, pipelineServer)

	// Batch service.
	batchServer := batchsvr.New(
		batch.NewEndpoints(batchsvc), mux, decoder, encoder,
		errorHandler(logger, "Batch error."), nil,
	)
	batchsvr.Mount(mux, batchServer)

	// Collection service (the only one given the WebSocket upgrader).
	collectionServer := collectionsvr.New(
		collection.NewEndpoints(colsvc.Goa()), mux, decoder, encoder,
		errorHandler(logger, "Collection error."), nil, upgrader, nil,
	)
	// NOTE(review): presumably lifts the server-wide WriteTimeout for the
	// Download handler (0 = no deadline) — confirm against writeTimeout.
	collectionServer.Download = writeTimeout(collectionServer.Download, 0)
	collectionsvr.Mount(mux, collectionServer)

	// Swagger service.
	swaggersvr.Mount(mux, swaggersvr.New(nil, nil, nil, nil, nil, nil, nil))

	// Web UI: the same SPA handler answers the root and every other path.
	spa := ui.SPAHandler()
	for _, pattern := range []string{"/", "/{*filename}"} {
		mux.Handle("GET", pattern, spa)
	}

	// Global middlewares, innermost first.
	var handler http.Handler = otelhttp.NewHandler(mux, "enduro/internal/api", otelhttp.WithTracerProvider(tp))
	handler = goahttpmwr.RequestID()(handler)
	handler = versionHeaderMiddleware(config.AppVersion)(handler)
	if config.Debug {
		handler = goahttpmwr.Log(loggerAdapter(logger))(handler)
		handler = goahttpmwr.Debug(mux, os.Stdout)(handler)
	}

	return &http.Server{
		Addr:         config.Listen,
		Handler:      handler,
		ReadTimeout:  5 * time.Second,
		WriteTimeout: 5 * time.Second,
		IdleTimeout:  120 * time.Second,
	}
}
// errorMessage is the JSON body written to the client by errorHandler.
type errorMessage struct {
	// RequestID lets the client correlate the response with the server log.
	RequestID string
	// Error is not set by errorHandler, so it is encoded as JSON null.
	Error error
}
// errorHandler returns a Goa error handler that writes and logs err.
//
// The client receives a JSON errorMessage that carries only the request ID,
// so it can be correlated with the log entry. When the connection has been
// hijacked (e.g. upgraded to a WebSocket) nothing is written. The error is
// always logged with msg as the log message, together with the request ID
// ("unknown" when the context carries none) and whether the connection was
// hijacked.
func errorHandler(logger logr.Logger, msg string) func(context.Context, http.ResponseWriter, error) {
	return func(ctx context.Context, w http.ResponseWriter, err error) {
		reqID, ok := ctx.Value(middleware.RequestIDKey).(string)
		if !ok {
			reqID = "unknown"
		}

		// Probe the writer with an empty write: a hijacked connection reports
		// http.ErrHijacked and must not be written to. errors.Is also matches
		// when a wrapping ResponseWriter returns the sentinel wrapped.
		// NOTE(review): if the header has not been written yet, this Write
		// commits an implicit 200 status (net/http ResponseWriter contract).
		_, writeErr := w.Write(nil)
		ws := errors.Is(writeErr, http.ErrHijacked)
		if !ws {
			// Best effort: if encoding fails the client is likely gone, and
			// the original error is still logged below.
			_ = json.NewEncoder(w).Encode(&errorMessage{RequestID: reqID})
		}

		logger.Error(err, msg, "reqID", reqID, "ws", ws)
	}
}
// sameOriginChecker returns a function for websocket.Upgrader.CheckOrigin.
//
// A request without an Origin header is accepted. Otherwise the first Origin
// value is parsed as a URL, and its host must equal the request Host under
// ASCII case folding. Each rejection is logged at verbosity level 1 with the
// reason.
func sameOriginChecker(logger logr.Logger) func(r *http.Request) bool {
	return func(r *http.Request) bool {
		origins := r.Header.Values("Origin")
		if len(origins) == 0 {
			return true
		}

		parsed, err := url.Parse(origins[0])
		switch {
		case err != nil:
			logger.V(1).Info("WebSocket client rejected (origin parse error)", "err", err)
			return false
		case !equalASCIIFold(parsed.Host, r.Host):
			logger.V(1).Info("WebSocket client rejected (origin and host not equal)", "origin-host", parsed.Host, "request-host", r.Host)
			return false
		default:
			return true
		}
	}
}
// equalASCIIFold reports whether s and t are equal under ASCII case folding,
// as defined by the RFC 4790 "i;ascii-casemap" collation.
//
// Only the letters A-Z and a-z compare case-insensitively. Every other
// character, including non-ASCII characters and invalid UTF-8 bytes, must
// match exactly.
func equalASCIIFold(s, t string) bool {
	for s != "" && t != "" {
		sr, sSize := utf8.DecodeRuneInString(s)
		tr, tSize := utf8.DecodeRuneInString(t)
		sEnc, tEnc := s[:sSize], t[:tSize]
		s, t = s[sSize:], t[tSize:]

		if sEnc == tEnc {
			continue
		}
		// The encodings differ, so only ASCII case folding can make them
		// equal. utf8.RuneError stands for every invalid byte (and for U+FFFD
		// itself); matching on it would equate distinct bytes, so reject.
		if sr == utf8.RuneError {
			return false
		}
		if 'A' <= sr && sr <= 'Z' {
			sr += 'a' - 'A'
		}
		if 'A' <= tr && tr <= 'Z' {
			tr += 'a' - 'A'
		}
		if sr != tr {
			return false
		}
	}
	// Equal only if both strings were consumed completely.
	return s == t
}