Skip to content

Commit

Permalink
add Transport config options to limit the number of handshakes
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Jan 19, 2024
1 parent 594440b commit 6dfc56d
Show file tree
Hide file tree
Showing 4 changed files with 404 additions and 106 deletions.
143 changes: 142 additions & 1 deletion integrationtests/self/handshake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@ import (
"fmt"
"io"
"net"
"sync/atomic"
"time"

"github.com/quic-go/quic-go"
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/qtls"
"github.com/quic-go/quic-go/logging"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
Expand Down Expand Up @@ -301,7 +303,7 @@ var _ = Describe("Handshake tests", func() {
})
})

Context("rate limiting", func() {
Context("queuening and accepting connections", func() {
var (
server *quic.Listener
pconn net.PacketConn
Expand Down Expand Up @@ -448,6 +450,145 @@ var _ = Describe("Handshake tests", func() {
})
})

Context("limiting handshakes", func() {
var conn *net.UDPConn

BeforeEach(func() {
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
conn, err = net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
})

AfterEach(func() { conn.Close() })

It("sends a Retry when the number of handshakes reaches MaxUnvalidatedHandshakes", func() {
const limit = 3
tr := quic.Transport{
Conn: conn,
MaxUnvalidatedHandshakes: limit,
}
defer tr.Close()

// Block all handshakes.
handshakes := make(chan struct{})
var tlsConf tls.Config
tlsConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
handshakes <- struct{}{}
return getTLSConfig(), nil
}
ln, err := tr.Listen(&tlsConf, getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())
defer ln.Close()

const additional = 2
results := make([]struct{ retry, closed atomic.Bool }, limit+additional)
// Dial the server from multiple clients. All handshakes will get blocked on the handshakes channel.
// Since we're dialing limit+2 times, we expect limit handshakes to go through with a Retry, and
// exactly 2 to experience a Retry.
for i := 0; i < limit+additional; i++ {
go func(index int) {
defer GinkgoRecover()
quicConf := getQuicConfig(&quic.Config{
Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
return &logging.ConnectionTracer{
ReceivedRetry: func(*logging.Header) { results[index].retry.Store(true) },
ClosedConnection: func(error) { results[index].closed.Store(true) },
}
},
})
conn, err := quic.DialAddr(context.Background(), ln.Addr().String(), getTLSClientConfig(), quicConf)
Expect(err).ToNot(HaveOccurred())
conn.CloseWithError(0, "")
}(i)
}
numRetries := func() (n int) {
for i := 0; i < limit+additional; i++ {
if results[i].retry.Load() {
n++
}
}
return
}
numClosed := func() (n int) {
for i := 0; i < limit+2; i++ {
if results[i].closed.Load() {
n++
}
}
return
}
Eventually(numRetries).Should(Equal(additional))
// allow the handshakes to complete
for i := 0; i < limit+additional; i++ {
Eventually(handshakes).Should(Receive())
}
Eventually(numClosed).Should(Equal(limit + additional))
Expect(numRetries()).To(Equal(additional)) // just to be on the safe side
})

It("rejects connections when the number of handshakes reaches MaxHandshakes", func() {
const limit = 3
tr := quic.Transport{
Conn: conn,
MaxHandshakes: limit,
}
defer tr.Close()

// Block all handshakes.
handshakes := make(chan struct{})
var tlsConf tls.Config
tlsConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
handshakes <- struct{}{}
return getTLSConfig(), nil
}
ln, err := tr.Listen(&tlsConf, getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())
defer ln.Close()

const additional = 2
// Dial the server from multiple clients. All handshakes will get blocked on the handshakes channel.
// Since we're dialing limit+2 times, we expect limit handshakes to go through with a Retry, and
// exactly 2 to experience a Retry.
var numSuccessful, numFailed atomic.Int32
for i := 0; i < limit+additional; i++ {
go func() {
defer GinkgoRecover()
quicConf := getQuicConfig(&quic.Config{
Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
return &logging.ConnectionTracer{
ReceivedRetry: func(*logging.Header) { Fail("didn't expect any Retry") },
}
},
})
conn, err := quic.DialAddr(context.Background(), ln.Addr().String(), getTLSClientConfig(), quicConf)
if err != nil {
var transportErr *quic.TransportError
if !errors.As(err, &transportErr) || transportErr.ErrorCode != qerr.ConnectionRefused {
Fail(fmt.Sprintf("expected CONNECTION_REFUSED error, got %v", err))
}
numFailed.Add(1)
return
}
numSuccessful.Add(1)
conn.CloseWithError(0, "")
}()
}
Eventually(func() int { return int(numFailed.Load()) }).Should(Equal(additional))
// allow the handshakes to complete
for i := 0; i < limit; i++ {
Eventually(handshakes).Should(Receive())
}
Eventually(func() int { return int(numSuccessful.Load()) }).Should(Equal(limit))

// make sure that the server is reachable again after these handshakes have completed
go func() { <-handshakes }() // allow this handshake to complete immediately
conn, err := quic.DialAddr(context.Background(), ln.Addr().String(), getTLSClientConfig(), getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())
conn.CloseWithError(0, "")
})
})

Context("ALPN", func() {
It("negotiates an application protocol", func() {
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
Expand Down
103 changes: 67 additions & 36 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"net"
"sync"
"sync/atomic"
"time"

"github.com/quic-go/quic-go/internal/handshake"
Expand Down Expand Up @@ -110,6 +111,11 @@ type baseServer struct {
connectionRefusedQueue chan rejectedPacket
retryQueue chan rejectedPacket

maxNumHandshakesUnvalidated int
maxNumHandshakesTotal int
numHandshakesUnvalidated atomic.Int64
numHandshakesValidated atomic.Int64

connQueue chan quicConn

tracer *logging.Tracer
Expand Down Expand Up @@ -238,31 +244,34 @@ func newServer(
onClose func(),
tokenGeneratorKey TokenGeneratorKey,
maxTokenAge time.Duration,
maxNumHandshakesUnvalidated, maxNumHandshakesTotal int,
disableVersionNegotiation bool,
acceptEarly bool,
) *baseServer {
s := &baseServer{
conn: conn,
tlsConf: tlsConf,
config: config,
tokenGenerator: handshake.NewTokenGenerator(tokenGeneratorKey),
maxTokenAge: maxTokenAge,
connIDGenerator: connIDGenerator,
connHandler: connHandler,
connQueue: make(chan quicConn, protocol.MaxAcceptQueueSize),
errorChan: make(chan struct{}),
running: make(chan struct{}),
receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets),
versionNegotiationQueue: make(chan receivedPacket, 4),
invalidTokenQueue: make(chan rejectedPacket, 4),
connectionRefusedQueue: make(chan rejectedPacket, 4),
retryQueue: make(chan rejectedPacket, 8),
newConn: newConnection,
tracer: tracer,
logger: utils.DefaultLogger.WithPrefix("server"),
acceptEarlyConns: acceptEarly,
disableVersionNegotiation: disableVersionNegotiation,
onClose: onClose,
conn: conn,
tlsConf: tlsConf,
config: config,
tokenGenerator: handshake.NewTokenGenerator(tokenGeneratorKey),
maxTokenAge: maxTokenAge,
maxNumHandshakesUnvalidated: maxNumHandshakesUnvalidated,
maxNumHandshakesTotal: maxNumHandshakesTotal,
connIDGenerator: connIDGenerator,
connHandler: connHandler,
connQueue: make(chan quicConn, protocol.MaxAcceptQueueSize),
errorChan: make(chan struct{}),
running: make(chan struct{}),
receivedPackets: make(chan receivedPacket, protocol.MaxServerUnprocessedPackets),
versionNegotiationQueue: make(chan receivedPacket, 4),
invalidTokenQueue: make(chan rejectedPacket, 4),
connectionRefusedQueue: make(chan rejectedPacket, 4),
retryQueue: make(chan rejectedPacket, 8),
newConn: newConnection,
tracer: tracer,
logger: utils.DefaultLogger.WithPrefix("server"),
acceptEarlyConns: acceptEarly,
disableVersionNegotiation: disableVersionNegotiation,
onClose: onClose,
}
if acceptEarly {
s.zeroRTTQueues = map[protocol.ConnectionID]*zeroRTTQueue{}
Expand Down Expand Up @@ -570,8 +579,8 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
}
}

clientAddrIsValid := s.validateToken(token, p.remoteAddr)
if token != nil && !clientAddrIsValid {
clientAddrValidated := s.validateToken(token, p.remoteAddr)
if token != nil && !clientAddrValidated {
// For invalid and expired non-retry tokens, we don't send an INVALID_TOKEN error.
// We just ignore them, and act as if there was no token on this packet at all.
// This also means we might send a Retry later.
Expand All @@ -590,24 +599,31 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
return nil
}
}
if token == nil && s.config.RequireAddressValidation(p.remoteAddr) {
// Retry invalidates all 0-RTT packets sent.

// Until the next call to handleInitialImpl, these numbers are guaranteed to not increase.
// They might decrease if another connection completes the handshake.
numHandshakesUnvalidated := s.numHandshakesUnvalidated.Load()
numHandshakesValidated := s.numHandshakesValidated.Load()

// Check the total handshake limit first. It's better to reject than to initiate a retry.
if total := numHandshakesUnvalidated + numHandshakesValidated; total >= int64(s.maxNumHandshakesTotal) {
s.logger.Debugf("Rejecting new connection. Server currently busy. Currently handshaking: %d (max %d)", total, s.maxNumHandshakesTotal)
delete(s.zeroRTTQueues, hdr.DestConnectionID)
select {
case s.retryQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
case s.connectionRefusedQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
default:
// drop packet if we can't send out Retry packets fast enough
// drop packet if we can't send out the CONNECTION_REFUSED fast enough
p.buffer.Release()
}
return nil
}

if queueLen := len(s.connQueue); queueLen >= protocol.MaxAcceptQueueSize {
s.logger.Debugf("Rejecting new connection. Server currently busy. Accept queue length: %d (max %d)", queueLen, protocol.MaxAcceptQueueSize)
if token == nil && (s.config.RequireAddressValidation(p.remoteAddr) || numHandshakesUnvalidated >= int64(s.maxNumHandshakesUnvalidated)) {
// Retry invalidates all 0-RTT packets sent.
delete(s.zeroRTTQueues, hdr.DestConnectionID)
select {
case s.connectionRefusedQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
case s.retryQueue <- rejectedPacket{receivedPacket: p, hdr: hdr}:
default:
// drop packet if we can't send out the CONNECTION_REFUSED fast enough
// drop packet if we can't send out Retry packets fast enough
p.buffer.Release()
}
return nil
Expand Down Expand Up @@ -652,7 +668,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
config,
s.tlsConf,
s.tokenGenerator,
clientAddrIsValid,
clientAddrValidated,
tracer,
tracingID,
s.logger,
Expand All @@ -677,16 +693,21 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
}
return nil
}
if clientAddrValidated {
s.numHandshakesValidated.Add(1)
} else {
s.numHandshakesUnvalidated.Add(1)
}
go conn.run()
go s.handleNewConn(conn)
go s.handleNewConn(conn, clientAddrValidated)
if conn == nil {
p.buffer.Release()
return nil
}
return nil
}

func (s *baseServer) handleNewConn(conn quicConn) {
func (s *baseServer) handleNewConn(conn quicConn, clientAddrValidated bool) {
connCtx := conn.Context()
if s.acceptEarlyConns {
// wait until the early connection is ready, the handshake fails, or the server is closed
Expand All @@ -710,10 +731,20 @@ func (s *baseServer) handleNewConn(conn quicConn) {
}
}

if clientAddrValidated {
if s.numHandshakesValidated.Add(-1) < 0 {
panic("server BUG: number of validated handshakes negative")
}
} else {
if s.numHandshakesUnvalidated.Add(-1) < 0 {
panic("server BUG: number of unvalidated handshakes negative")
}
}

select {
case s.connQueue <- conn:
default:
conn.destroy(&qerr.TransportError{ErrorCode: ConnectionRefused})
conn.closeWithTransportError(ConnectionRefused)
}
}

Expand Down
Loading

0 comments on commit 6dfc56d

Please sign in to comment.