Skip to content

Commit

Permalink
Merge pull request moby#45314 from corhere/graceful-shutdown
Browse files Browse the repository at this point in the history
cmd/dockerd: gracefully shut down the API server
  • Loading branch information
thaJeztah authored Apr 28, 2023
2 parents dffad6b + 12bf850 commit e22758b
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 161 deletions.
83 changes: 3 additions & 80 deletions api/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@ package server // import "github.com/docker/docker/api/server"

import (
"context"
"net"
"net/http"
"strings"
"time"

"github.com/docker/docker/api/server/httpstatus"
"github.com/docker/docker/api/server/httputils"
Expand All @@ -23,8 +20,6 @@ const versionMatcher = "/v{version:[0-9.]+}"

// Server contains instance details for the server
type Server struct {
servers []*HTTPServer
routers []router.Router
middlewares []middleware.Middleware
}

Expand All @@ -34,71 +29,6 @@ func (s *Server) UseMiddleware(m middleware.Middleware) {
s.middlewares = append(s.middlewares, m)
}

// Accept registers the given listeners with the server. Every listener gets
// its own http.Server configured with addr; the pair is only recorded here
// and is not started by this method.
func (s *Server) Accept(addr string, listeners ...net.Listener) {
	for i := range listeners {
		s.servers = append(s.servers, &HTTPServer{
			srv: &http.Server{
				Addr:              addr,
				ReadHeaderTimeout: 5 * time.Minute, // "G112: Potential Slowloris Attack (gosec)"; not a real concern for our use, so setting a long timeout.
			},
			l: listeners[i],
		})
	}
}

// Close closes every registered server so that it stops receiving requests.
// A failure to close one server is logged and does not prevent the remaining
// servers from being closed.
func (s *Server) Close() {
	for i := range s.servers {
		err := s.servers[i].Close()
		if err == nil {
			continue
		}
		logrus.Error(err)
	}
}

// Serve starts serving inbound requests on every registered server, each in
// its own goroutine, after giving it a freshly created mux as its handler.
//
// It blocks until every server has returned, or until the first server
// reports a non-nil error, which is returned immediately. An error whose text
// contains "use of closed network connection" is treated as a normal
// shutdown and reported as nil.
func (s *Server) Serve() error {
	// Buffered to the number of servers so that every goroutine can deliver
	// its result even after Serve has already returned on an earlier error.
	results := make(chan error, len(s.servers))

	for _, hs := range s.servers {
		hs.srv.Handler = s.createMux()
		go func(hs *HTTPServer) {
			logrus.Infof("API listen on %s", hs.l.Addr())
			serveErr := hs.Serve()
			if serveErr != nil && strings.Contains(serveErr.Error(), "use of closed network connection") {
				serveErr = nil
			}
			results <- serveErr
		}(hs)
	}

	// Collect exactly one result per server that was started.
	for pending := len(s.servers); pending > 0; pending-- {
		if err := <-results; err != nil {
			return err
		}
	}
	return nil
}

// HTTPServer pairs an http.Server with the listener it serves on.
//
// srv holds the configuration used to create the HTTP server, together with
// the mux router carrying all API endpoints.
// l is a TCP or socket listener that dispatches incoming requests to the
// router.
type HTTPServer struct {
// srv is the HTTP server that Serve runs on l.
srv *http.Server
// l is the listener passed to srv.Serve; Close closes it.
l net.Listener
}

// Serve runs the wrapped http.Server on the wrapped listener and returns
// whatever error that call returns.
func (s *HTTPServer) Serve() error {
	listener := s.l
	return s.srv.Serve(listener)
}

// Close stops the HTTPServer from listening for inbound requests by closing
// its listener. Only the listener is closed here; the error from that close
// is returned to the caller.
func (s *HTTPServer) Close() error {
	closeErr := s.l.Close()
	return closeErr
}

func (s *Server) makeHTTPHandler(handler httputils.APIFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Define the context that we'll pass around to share info
Expand Down Expand Up @@ -130,12 +60,6 @@ func (s *Server) makeHTTPHandler(handler httputils.APIFunc) http.HandlerFunc {
}
}

// InitRouter adds the given routers to the list of routers for the server.
// Routers passed in earlier calls are kept; the new ones are appended after
// them.
// NOTE(review): an earlier comment said this method also enables the Go
// profiler, but nothing in this method does so.
func (s *Server) InitRouter(routers ...router.Router) {
s.routers = append(s.routers, routers...)
}

// pageNotFoundError is an error value that carries no data of its own.
type pageNotFoundError struct{}

func (pageNotFoundError) Error() string {
Expand All @@ -144,12 +68,12 @@ func (pageNotFoundError) Error() string {

// NotFound is an empty marker method.
// NOTE(review): presumably it lets this error be classified as a
// "not found" error through an interface check elsewhere; confirm.
func (pageNotFoundError) NotFound() {}

// createMux initializes the main router the server uses.
func (s *Server) createMux() *mux.Router {
// CreateMux returns a new mux with all the routers registered.
func (s *Server) CreateMux(routers ...router.Router) *mux.Router {
m := mux.NewRouter()

logrus.Debug("Registering routers")
for _, apiRouter := range s.routers {
for _, apiRouter := range routers {
for _, r := range apiRouter.Routes() {
f := s.makeHTTPHandler(r.Handler())

Expand All @@ -160,7 +84,6 @@ func (s *Server) createMux() *mux.Router {
}

debugRouter := debug.NewRouter()
s.routers = append(s.routers, debugRouter)
for _, r := range debugRouter.Routes() {
f := s.makeHTTPHandler(r.Handler())
m.Path("/debug" + r.Path()).Handler(f)
Expand Down
137 changes: 101 additions & 36 deletions cmd/dockerd/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ import (
"crypto/tls"
"fmt"
"net"
"net/http"
"os"
"path/filepath"
"runtime"
"sort"
"strings"
"sync"
"time"

containerddefaults "github.com/containerd/containerd/defaults"
Expand Down Expand Up @@ -65,14 +67,18 @@ type DaemonCli struct {
configFile *string
flags *pflag.FlagSet

api apiserver.Server
d *daemon.Daemon
authzMiddleware *authorization.Middleware // authzMiddleware enables to dynamically reload the authorization plugins

stopOnce sync.Once
apiShutdown chan struct{}
}

// NewDaemonCli returns a daemon CLI
func NewDaemonCli() *DaemonCli {
return &DaemonCli{}
return &DaemonCli{
apiShutdown: make(chan struct{}),
}
}

func (cli *DaemonCli) start(opts *daemonOptions) (err error) {
Expand Down Expand Up @@ -161,7 +167,7 @@ func (cli *DaemonCli) start(opts *daemonOptions) (err error) {
}
}

hosts, err := loadListeners(cli, tlsConfig)
lss, hosts, err := loadListeners(cli.Config, tlsConfig)
if err != nil {
return errors.Wrap(err, "failed to load listeners")
}
Expand All @@ -177,20 +183,51 @@ func (cli *DaemonCli) start(opts *daemonOptions) (err error) {
}
defer cancel()

stopc := make(chan bool)
defer close(stopc)

trap.Trap(func() {
cli.stop()
<-stopc // wait for daemonCli.start() to return
}, logrus.StandardLogger())
httpServer := &http.Server{
ReadHeaderTimeout: 5 * time.Minute, // "G112: Potential Slowloris Attack (gosec)"; not a real concern for our use, so setting a long timeout.
}
apiShutdownCtx, apiShutdownCancel := context.WithCancel(context.Background())
apiShutdownDone := make(chan struct{})
trap.Trap(cli.stop)
go func() {
// Block until cli.stop() has been called.
// It may have already been called, and that's okay.
// Any httpServer.Serve() calls made after
// httpServer.Shutdown() will return immediately,
// which is what we want.
<-cli.apiShutdown
err := httpServer.Shutdown(apiShutdownCtx)
if err != nil {
logrus.WithError(err).Error("Error shutting down http server")
}
close(apiShutdownDone)
}()
defer func() {
select {
case <-cli.apiShutdown:
// cli.stop() has been called and the daemon has completed
// shutting down. Give the HTTP server a little more time to
// finish handling any outstanding requests if needed.
tmr := time.AfterFunc(5*time.Second, apiShutdownCancel)
defer tmr.Stop()
<-apiShutdownDone
default:
// cli.start() has returned without cli.stop() being called,
// e.g. because the daemon failed to start.
// Stop the HTTP server with no grace period.
if closeErr := httpServer.Close(); closeErr != nil {
logrus.WithError(closeErr).Error("Error closing http server")
}
}
}()

// Notify that the API is active, but before daemon is set up.
preNotifyReady()

pluginStore := plugin.NewStore()

cli.authzMiddleware = initMiddlewares(&cli.api, cli.Config, pluginStore)
var apiServer apiserver.Server
cli.authzMiddleware = initMiddlewares(&apiServer, cli.Config, pluginStore)

d, err := daemon.NewDaemon(ctx, cli.Config, pluginStore, cli.authzMiddleware)
if err != nil {
Expand Down Expand Up @@ -229,10 +266,9 @@ func (cli *DaemonCli) start(opts *daemonOptions) (err error) {
if err != nil {
return err
}
routerOptions.api = &cli.api
routerOptions.cluster = c

initRouter(routerOptions)
httpServer.Handler = apiServer.CreateMux(routerOptions.Build()...)

go d.ProcessClusterNotifications(ctx, c.GetWatchStream())

Expand All @@ -243,10 +279,30 @@ func (cli *DaemonCli) start(opts *daemonOptions) (err error) {

// Daemon is fully initialized. Start handling API traffic
// and wait for serve API to complete.
errAPI := cli.api.Serve()
if errAPI != nil {
logrus.WithError(errAPI).Error("ServeAPI error")
var (
apiWG sync.WaitGroup
errAPI = make(chan error, 1)
)
for _, ls := range lss {
apiWG.Add(1)
go func(ls net.Listener) {
defer apiWG.Done()
logrus.Infof("API listen on %s", ls.Addr())
if err := httpServer.Serve(ls); err != http.ErrServerClosed {
logrus.WithFields(logrus.Fields{
logrus.ErrorKey: err,
"listener": ls.Addr(),
}).Error("ServeAPI error")

select {
case errAPI <- err:
default:
}
}
}(ls)
}
apiWG.Wait()
close(errAPI)

c.Cleanup()

Expand All @@ -257,8 +313,8 @@ func (cli *DaemonCli) start(opts *daemonOptions) (err error) {
// Stop notification processing and any background processes
cancel()

if errAPI != nil {
return errors.Wrap(errAPI, "shutting down due to ServeAPI error")
if err, ok := <-errAPI; ok {
return errors.Wrap(err, "shutting down due to ServeAPI error")
}

logrus.Info("Daemon shutdown complete")
Expand All @@ -271,7 +327,6 @@ type routerOptions struct {
features *map[string]bool
buildkit *buildkit.Builder
daemon *daemon.Daemon
api *apiserver.Server
cluster *cluster.Cluster
}

Expand Down Expand Up @@ -357,7 +412,14 @@ func (cli *DaemonCli) reloadConfig() {
}

func (cli *DaemonCli) stop() {
cli.api.Close()
// Signal that the API server should shut down as soon as possible.
// This construct is used rather than directly shutting down the HTTP
// server to avoid any issues if this method is called before the server
// has been instantiated in cli.start(). If this method is called first,
// the HTTP server will be shut down immediately upon instantiation.
cli.stopOnce.Do(func() {
close(cli.apiShutdown)
})
}

// shutdownDaemon just wraps daemon.Shutdown() to handle a timeout in case
Expand Down Expand Up @@ -499,7 +561,7 @@ func normalizeHosts(config *config.Config) error {
return nil
}

func initRouter(opts routerOptions) {
func (opts routerOptions) Build() []router.Router {
decoder := runconfig.ContainerDecoder{
GetSysInfo: func() *sysinfo.SysInfo {
return opts.daemon.RawSysInfo()
Expand Down Expand Up @@ -544,7 +606,7 @@ func initRouter(opts routerOptions) {
}
}

opts.api.InitRouter(routers...)
return routers
}

func initMiddlewares(s *apiserver.Server, cfg *config.Config, pluginStore plugingetter.PluginGetter) *authorization.Middleware {
Expand Down Expand Up @@ -648,17 +710,20 @@ func checkTLSAuthOK(c *config.Config) bool {
return true
}

func loadListeners(cli *DaemonCli, tlsConfig *tls.Config) ([]string, error) {
if len(cli.Config.Hosts) == 0 {
return nil, errors.New("no hosts configured")
func loadListeners(cfg *config.Config, tlsConfig *tls.Config) ([]net.Listener, []string, error) {
if len(cfg.Hosts) == 0 {
return nil, nil, errors.New("no hosts configured")
}
var hosts []string
var (
hosts []string
lss []net.Listener
)

for i := 0; i < len(cli.Config.Hosts); i++ {
protoAddr := cli.Config.Hosts[i]
for i := 0; i < len(cfg.Hosts); i++ {
protoAddr := cfg.Hosts[i]
proto, addr, ok := strings.Cut(protoAddr, "://")
if !ok {
return nil, fmt.Errorf("bad format %s, expected PROTO://ADDR", protoAddr)
return nil, nil, fmt.Errorf("bad format %s, expected PROTO://ADDR", protoAddr)
}

// It's a bad idea to bind to TCP without tlsverify.
Expand All @@ -670,10 +735,10 @@ func loadListeners(cli *DaemonCli, tlsConfig *tls.Config) ([]string, error) {

// If TLSVerify is explicitly set to false we'll take that as "Please let me shoot myself in the foot"
// We do not want to continue to support a default mode where tls verification is disabled, so we do some extra warnings here and eventually remove support
if !checkTLSAuthOK(cli.Config) {
if !checkTLSAuthOK(cfg) {
ipAddr, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, errors.Wrap(err, "error parsing tcp address")
return nil, nil, errors.Wrap(err, "error parsing tcp address")
}

// shortcut all this extra stuff for literal "localhost"
Expand Down Expand Up @@ -703,19 +768,19 @@ func loadListeners(cli *DaemonCli, tlsConfig *tls.Config) ([]string, error) {
// If we're binding to a TCP port, make sure that a container doesn't try to use it.
if proto == "tcp" {
if err := allocateDaemonPort(addr); err != nil {
return nil, err
return nil, nil, err
}
}
ls, err := listeners.Init(proto, addr, cli.Config.SocketGroup, tlsConfig)
ls, err := listeners.Init(proto, addr, cfg.SocketGroup, tlsConfig)
if err != nil {
return nil, err
return nil, nil, err
}
logrus.Debugf("Listener created for HTTP on %s (%s)", proto, addr)
hosts = append(hosts, addr)
cli.api.Accept(addr, ls...)
lss = append(lss, ls...)
}

return hosts, nil
return lss, hosts, nil
}

func createAndStartCluster(cli *DaemonCli, d *daemon.Daemon) (*cluster.Cluster, error) {
Expand Down
Loading

0 comments on commit e22758b

Please sign in to comment.