Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

plugin/timeouts - Allow ability to configure listening server timeouts #5784

Merged
merged 9 commits into from
Dec 28, 2022
Merged
Next Next commit
CoreDNS server timeouts pass 1
Signed-off-by: Rich <git0@bitservices.io>
  • Loading branch information
rlees85 committed Dec 6, 2022
commit 1fc3a69e48ce8d38e52ff9d1e9715d0b764fffba
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# only add build artifacts concerning coredns - no editor related files
coredns
coredns.exe
Corefile
build/
release/
vendor/
4 changes: 4 additions & 0 deletions core/dnsserver/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/tls"
"fmt"
"net/http"
"time"

"github.com/coredns/caddy"
"github.com/coredns/coredns/plugin"
Expand Down Expand Up @@ -53,6 +54,9 @@ type Config struct {
// TLSConfig when listening for encrypted connections (gRPC, DNS-over-TLS).
TLSConfig *tls.Config

// Timeouts for TCP, TLS and HTTPS servers.
Timeouts map[string]time.Duration
rlees85 marked this conversation as resolved.
Show resolved Hide resolved

// TSIG secrets, [name]key.
TsigSecret map[string]string

Expand Down
4 changes: 3 additions & 1 deletion core/dnsserver/register.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ func (h *dnsContext) MakeServers() ([]caddy.Server, error) {

// Fork TLSConfig for each encrypted connection
c.TLSConfig = c.firstConfigInBlock.TLSConfig.Clone()
c.Timeouts = c.firstConfigInBlock.Timeouts
c.TsigSecret = c.firstConfigInBlock.TsigSecret
}

Expand Down Expand Up @@ -223,7 +224,8 @@ func (c *Config) AddPlugin(m plugin.Plugin) {
}

// registerHandler adds a handler to a site's handler registration. Handlers
// use this to announce that they exist to other plugin.
//
// use this to announce that they exist to other plugin.
func (c *Config) registerHandler(h plugin.Handler) {
if c.registry == nil {
c.registry = make(map[string]plugin.Handler)
Expand Down
41 changes: 36 additions & 5 deletions core/dnsserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ type Server struct {
debug bool // disable recover()
stacktrace bool // enable stacktrace in recover error log
classChaos bool // allow non-INET class queries
idleTimeout time.Duration // Idle timeout for TCP
readTimeout time.Duration // Read timeout for TCP
writeTimeout time.Duration // Write timeout for TCP

tsigSecret map[string]string
}
Expand All @@ -60,6 +63,9 @@ func NewServer(addr string, group []*Config) (*Server, error) {
Addr: addr,
zones: make(map[string][]*Config),
graceTimeout: 5 * time.Second,
idleTimeout: 10 * time.Second,
readTimeout: 3 * time.Second,
writeTimeout: 5 * time.Second,
tsigSecret: make(map[string]string),
}

Expand All @@ -81,6 +87,18 @@ func NewServer(addr string, group []*Config) (*Server, error) {
// append the config to the zone's configs
s.zones[site.Zone] = append(s.zones[site.Zone], site)

// set timeouts
for key, timeout := range site.Timeouts {
switch key {
case "idle":
rlees85 marked this conversation as resolved.
Show resolved Hide resolved
s.idleTimeout = timeout
case "read":
s.readTimeout = timeout
case "write":
s.writeTimeout = timeout
}
}

// copy tsig secrets
rlees85 marked this conversation as resolved.
Show resolved Hide resolved
for key, secret := range site.TsigSecret {
s.tsigSecret[key] = secret
Expand Down Expand Up @@ -130,11 +148,22 @@ var _ caddy.GracefulServer = &Server{}
// This implements caddy.TCPServer interface.
func (s *Server) Serve(l net.Listener) error {
s.m.Lock()
s.server[tcp] = &dns.Server{Listener: l, Net: "tcp", Handler: dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
ctx := context.WithValue(context.Background(), Key{}, s)
ctx = context.WithValue(ctx, LoopKey{}, 0)
s.ServeDNS(ctx, w, r)
}), TsigSecret: s.tsigSecret}

s.server[tcp] = &dns.Server{Listener: l,
Net: "tcp",
TsigSecret: s.tsigSecret,
MaxTCPQueries: tcpMaxQueries,
ReadTimeout: s.readTimeout,
WriteTimeout: s.writeTimeout,
IdleTimeout: func() time.Duration {
return s.idleTimeout
},
Handler: dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
ctx := context.WithValue(context.Background(), Key{}, s)
ctx = context.WithValue(ctx, LoopKey{}, 0)
s.ServeDNS(ctx, w, r)
})}

s.m.Unlock()

return s.server[tcp].ActivateAndServe()
Expand Down Expand Up @@ -404,6 +433,8 @@ func errorAndMetricsFunc(server string, w dns.ResponseWriter, r *dns.Msg, rc int
const (
tcp = 0
udp = 1

tcpMaxQueries = -1
)

type (
Expand Down
6 changes: 3 additions & 3 deletions core/dnsserver/server_https.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,9 @@ func NewServerHTTPS(addr string, group []*Config) (*ServerHTTPS, error) {
}

srv := &http.Server{
ReadTimeout: 5 * time.Second,
WriteTimeout: 10 * time.Second,
IdleTimeout: 120 * time.Second,
ReadTimeout: s.readTimeout,
WriteTimeout: s.writeTimeout,
IdleTimeout: s.idleTimeout,
ErrorLog: stdlog.New(&loggerAdapter{}, "", 0),
}
sh := &ServerHTTPS{
Expand Down
46 changes: 41 additions & 5 deletions core/dnsserver/server_tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/tls"
"fmt"
"net"
"time"

"github.com/coredns/caddy"
"github.com/coredns/coredns/plugin/pkg/reuseport"
Expand Down Expand Up @@ -49,12 +50,43 @@ func (s *ServerTLS) Serve(l net.Listener) error {
l = tls.NewListener(l, s.tlsConfig)
}

var (
TLSIdleTimeout time.Duration
TLSReadTimeout time.Duration
TLSWriteTimeout time.Duration
)

if s.idleTimeout == time.Duration(0) {
TLSIdleTimeout = DefaultTLSIdleTimeout
} else {
TLSIdleTimeout = s.idleTimeout
}
if s.readTimeout == time.Duration(0) {
TLSReadTimeout = DefaultTLSReadTimeout
} else {
TLSReadTimeout = s.readTimeout
}
if s.readTimeout == time.Duration(0) {
TLSWriteTimeout = DefaultTLSWriteTimeout
} else {
TLSWriteTimeout = s.writeTimeout
}

// Only fill out the TCP server for this one.
s.server[tcp] = &dns.Server{Listener: l, Net: "tcp-tls", Handler: dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
ctx := context.WithValue(context.Background(), Key{}, s.Server)
ctx = context.WithValue(ctx, LoopKey{}, 0)
s.ServeDNS(ctx, w, r)
})}
s.server[tcp] = &dns.Server{Listener: l,
Net: "tcp-tls",
MaxTCPQueries: TLSMaxQueries,
ReadTimeout: TLSReadTimeout,
WriteTimeout: TLSWriteTimeout,
IdleTimeout: func() time.Duration {
return TLSIdleTimeout
},
Handler: dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
ctx := context.WithValue(context.Background(), Key{}, s.Server)
ctx = context.WithValue(ctx, LoopKey{}, 0)
s.ServeDNS(ctx, w, r)
})}

s.m.Unlock()

return s.server[tcp].ActivateAndServe()
Expand Down Expand Up @@ -87,3 +119,7 @@ func (s *ServerTLS) OnStartupComplete() {
fmt.Print(out)
}
}

const (
tlsMaxQueries = -1
rlees85 marked this conversation as resolved.
Show resolved Hide resolved
)
1 change: 1 addition & 0 deletions core/dnsserver/zdirectives.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ var Directives = []string{
"geoip",
"cancel",
"tls",
"timeouts",
"reload",
"nsid",
"bufsize",
Expand Down
1 change: 1 addition & 0 deletions core/plugin/zplugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ import (
_ "github.com/coredns/coredns/plugin/secondary"
_ "github.com/coredns/coredns/plugin/sign"
_ "github.com/coredns/coredns/plugin/template"
_ "github.com/coredns/coredns/plugin/timeouts"
_ "github.com/coredns/coredns/plugin/tls"
_ "github.com/coredns/coredns/plugin/trace"
_ "github.com/coredns/coredns/plugin/transfer"
Expand Down
1 change: 1 addition & 0 deletions plugin.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ metadata:metadata
geoip:geoip
cancel:cancel
tls:tls
timeouts:timeouts
reload:reload
nsid:nsid
bufsize:bufsize
Expand Down
45 changes: 45 additions & 0 deletions plugin/pkg/timeouts/timeouts.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package timeouts

import (
"fmt"
"strconv"
"time"
)

func NewTimeoutsConfigFromArgs(args ...string) (map[string]time.Duration, error) {
rlees85 marked this conversation as resolved.
Show resolved Hide resolved
c := make(map[string]time.Duration)

for i := 0; i < len(args); i++ {
t, err := validateTimeout(args[i])

if err != nil {
return c, err
}

switch i {
case 0:
c["read"] = t
case 1:
c["write"] = t
case 2:
c["idle"] = t
default:
return c, fmt.Errorf("maximum of three arguments allowed for timeouts config, found %d", len(args))
}
}

return c, nil
}

func validateTimeout(t string) (time.Duration, error) {
i, err := strconv.Atoi(t)
if err != nil {
return time.Duration(0), fmt.Errorf("timeout provided '%s' does not appear to be numeric", t)
}

if i < 1 || i > 86400 {
return time.Duration(0), fmt.Errorf("timeout provided '%d' needs to be between 1 and 86400 second(s)", i)
}

return time.Duration(i) * time.Second, nil
}
101 changes: 101 additions & 0 deletions plugin/pkg/timeouts/timeouts_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package timeouts

import (
"testing"
)

func TestNewTimeoutsConfigFromArgs(t *testing.T) {
var validIdleTimeoutArg = "300" // Args from configuration are always strings
var validReadTimeoutArg = "30"
var validWriteTimeoutArg = "60"

var invalidTimeoutString = "twenty seconds"
var invalidTimeoutLow = "0"
var invalidTimeoutHigh = "86401"

// No Arguments
to, err := NewTimeoutsConfigFromArgs()
if err != nil {
t.Errorf("Failed to create timeouts map when no arguments specified: %s", err)
}

if len(to) != 0 {
t.Error("Timeouts map with no arguments should be empty")
}

// Read Timeout only
to, err = NewTimeoutsConfigFromArgs(validReadTimeoutArg)
if err != nil {
t.Errorf("Failed to create timeouts map given just a read timeout: %s", err)
}

if _, ok := to["read"]; !ok {
t.Error("Timeouts map given just a read timeout did not return a read timeout")
}

if _, ok := to["write"]; ok {
t.Error("Timeouts map given just a read timeout also returned a write timeout")
}

if _, ok := to["idle"]; ok {
t.Error("Timeouts map given just a read timeout also returned an idle timeout")
}

// Read and Write Timeouts (no Idle)
to, err = NewTimeoutsConfigFromArgs(validReadTimeoutArg, validWriteTimeoutArg)
if err != nil {
t.Errorf("Failed to create timeouts map given a read and write timeout: %s", err)
}

if _, ok := to["read"]; !ok {
t.Error("Timeouts map given a read and write timeout did not return a read timeout")
}

if _, ok := to["write"]; !ok {
t.Error("Timeouts map given a read and write timeout did not return a write timeout")
}

if _, ok := to["idle"]; ok {
t.Error("Timeouts map given a read and write timeout also returned an idle timeout")
}

// All Timeouts
to, err = NewTimeoutsConfigFromArgs(validReadTimeoutArg, validWriteTimeoutArg, validIdleTimeoutArg)
if err != nil {
t.Errorf("Failed to create timeouts map given all timeouts: %s", err)
}

if _, ok := to["read"]; !ok {
t.Error("Timeouts map given all timeouts did not return a read timeout")
}

if _, ok := to["write"]; !ok {
t.Error("Timeouts map given all timeouts did not return a write timeout")
}

if _, ok := to["idle"]; !ok {
t.Error("Timeouts map given all timeouts did not return an idle idle timeout")
}

// Too Many Timeouts
to, err = NewTimeoutsConfigFromArgs(validReadTimeoutArg, validWriteTimeoutArg, validIdleTimeoutArg, "100")
if err == nil {
t.Error("Attempt to create timeouts with too many arguments was successful")
}

// Timeout Validation
to, err = NewTimeoutsConfigFromArgs(invalidTimeoutString)
if err == nil {
t.Error("Attempt to create timeouts with non-numeric value was successful")
}

to, err = NewTimeoutsConfigFromArgs(invalidTimeoutLow)
if err == nil {
t.Error("Attempt to create a timeout of less than 1 second was successful")
}

to, err = NewTimeoutsConfigFromArgs(invalidTimeoutHigh)
if err == nil {
t.Error("Attempt to create a timeout of more than 1 day was successful")
}
}
Loading