Skip to content

Commit

Permalink
add Transport config options to limit the number of handshakes
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Jan 19, 2024
1 parent 594440b commit 7951e15
Show file tree
Hide file tree
Showing 4 changed files with 427 additions and 117 deletions.
177 changes: 165 additions & 12 deletions integrationtests/self/handshake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@ import (
"fmt"
"io"
"net"
"sync/atomic"
"time"

"github.com/quic-go/quic-go"
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/qtls"
"github.com/quic-go/quic-go/logging"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
Expand Down Expand Up @@ -301,7 +303,7 @@ var _ = Describe("Handshake tests", func() {
})
})

Context("rate limiting", func() {
Context("queuening and accepting connections", func() {
var (
server *quic.Listener
pconn net.PacketConn
Expand Down Expand Up @@ -343,8 +345,11 @@ var _ = Describe("Handshake tests", func() {
}
time.Sleep(25 * time.Millisecond) // wait a bit for the connection to be queued

_, err := dial()
Expect(err).To(HaveOccurred())
conn, err := dial()
Expect(err).ToNot(HaveOccurred())
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
_, err = conn.AcceptStream(ctx)
var transportErr *quic.TransportError
Expect(errors.As(err, &transportErr)).To(BeTrue())
Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused))
Expand All @@ -353,18 +358,21 @@ var _ = Describe("Handshake tests", func() {
_, err = server.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
// dial again, and expect that this dial succeeds
conn, err := dial()
conn2, err := dial()
Expect(err).ToNot(HaveOccurred())
defer conn.CloseWithError(0, "")
defer conn2.CloseWithError(0, "")
time.Sleep(25 * time.Millisecond) // wait a bit for the connection to be queued

_, err = dial()
Expect(err).To(HaveOccurred())
conn3, err := dial()
Expect(err).ToNot(HaveOccurred())
ctx, cancel = context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
_, err = conn3.AcceptStream(ctx)
Expect(errors.As(err, &transportErr)).To(BeTrue())
Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused))
})

It("removes closed connections from the accept queue", func() {
It("also returns closed connections from the accept queue", func() {
firstConn, err := dial()
Expect(err).ToNot(HaveOccurred())

Expand All @@ -375,8 +383,11 @@ var _ = Describe("Handshake tests", func() {
}
time.Sleep(scaleDuration(20 * time.Millisecond)) // wait a bit for the connection to be queued

_, err = dial()
Expect(err).To(HaveOccurred())
conn, err := dial()
Expect(err).ToNot(HaveOccurred())
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
_, err = conn.AcceptStream(ctx)
var transportErr *quic.TransportError
Expect(errors.As(err, &transportErr)).To(BeTrue())
Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused))
Expand All @@ -388,8 +399,11 @@ var _ = Describe("Handshake tests", func() {
time.Sleep(scaleDuration(200 * time.Millisecond))

// dial again, and expect that this fails again
_, err = dial()
Expect(err).To(HaveOccurred())
conn2, err := dial()
Expect(err).ToNot(HaveOccurred())
ctx, cancel = context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
_, err = conn2.AcceptStream(ctx)
Expect(errors.As(err, &transportErr)).To(BeTrue())
Expect(transportErr.ErrorCode).To(Equal(quic.ConnectionRefused))

Expand Down Expand Up @@ -448,6 +462,145 @@ var _ = Describe("Handshake tests", func() {
})
})

Context("limiting handshakes", func() {
var conn *net.UDPConn

BeforeEach(func() {
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
conn, err = net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
})

AfterEach(func() { conn.Close() })

It("sends a Retry when the number of handshakes reaches MaxUnvalidatedHandshakes", func() {
const limit = 3
tr := quic.Transport{
Conn: conn,
MaxUnvalidatedHandshakes: limit,
}
defer tr.Close()

// Block all handshakes.
handshakes := make(chan struct{})
var tlsConf tls.Config
tlsConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
handshakes <- struct{}{}
return getTLSConfig(), nil
}
ln, err := tr.Listen(&tlsConf, getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())
defer ln.Close()

const additional = 2
results := make([]struct{ retry, closed atomic.Bool }, limit+additional)
// Dial the server from multiple clients. All handshakes will get blocked on the handshakes channel.
// Since we're dialing limit+2 times, we expect limit handshakes to go through with a Retry, and
// exactly 2 to experience a Retry.
for i := 0; i < limit+additional; i++ {
go func(index int) {
defer GinkgoRecover()
quicConf := getQuicConfig(&quic.Config{
Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
return &logging.ConnectionTracer{
ReceivedRetry: func(*logging.Header) { results[index].retry.Store(true) },
ClosedConnection: func(error) { results[index].closed.Store(true) },
}
},
})
conn, err := quic.DialAddr(context.Background(), ln.Addr().String(), getTLSClientConfig(), quicConf)
Expect(err).ToNot(HaveOccurred())
conn.CloseWithError(0, "")
}(i)
}
numRetries := func() (n int) {
for i := 0; i < limit+additional; i++ {
if results[i].retry.Load() {
n++
}
}
return
}
numClosed := func() (n int) {
for i := 0; i < limit+2; i++ {
if results[i].closed.Load() {
n++
}
}
return
}
Eventually(numRetries).Should(Equal(additional))
// allow the handshakes to complete
for i := 0; i < limit+additional; i++ {
Eventually(handshakes).Should(Receive())
}
Eventually(numClosed).Should(Equal(limit + additional))
Expect(numRetries()).To(Equal(additional)) // just to be on the safe side
})

It("rejects connections when the number of handshakes reaches MaxHandshakes", func() {
const limit = 3
tr := quic.Transport{
Conn: conn,
MaxHandshakes: limit,
}
defer tr.Close()

// Block all handshakes.
handshakes := make(chan struct{})
var tlsConf tls.Config
tlsConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
handshakes <- struct{}{}
return getTLSConfig(), nil
}
ln, err := tr.Listen(&tlsConf, getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())
defer ln.Close()

const additional = 2
// Dial the server from multiple clients. All handshakes will get blocked on the handshakes channel.
// Since we're dialing limit+2 times, we expect limit handshakes to go through with a Retry, and
// exactly 2 to experience a Retry.
var numSuccessful, numFailed atomic.Int32
for i := 0; i < limit+additional; i++ {
go func() {
defer GinkgoRecover()
quicConf := getQuicConfig(&quic.Config{
Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
return &logging.ConnectionTracer{
ReceivedRetry: func(*logging.Header) { Fail("didn't expect any Retry") },
}
},
})
conn, err := quic.DialAddr(context.Background(), ln.Addr().String(), getTLSClientConfig(), quicConf)
if err != nil {
var transportErr *quic.TransportError
if !errors.As(err, &transportErr) || transportErr.ErrorCode != qerr.ConnectionRefused {
Fail(fmt.Sprintf("expected CONNECTION_REFUSED error, got %v", err))
}
numFailed.Add(1)
return
}
numSuccessful.Add(1)
conn.CloseWithError(0, "")
}()
}
Eventually(func() int { return int(numFailed.Load()) }).Should(Equal(additional))
// allow the handshakes to complete
for i := 0; i < limit; i++ {
Eventually(handshakes).Should(Receive())
}
Eventually(func() int { return int(numSuccessful.Load()) }).Should(Equal(limit))

// make sure that the server is reachable again after these handshakes have completed
go func() { <-handshakes }() // allow this handshake to complete immediately
conn, err := quic.DialAddr(context.Background(), ln.Addr().String(), getTLSClientConfig(), getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())
conn.CloseWithError(0, "")
})
})

Context("ALPN", func() {
It("negotiates an application protocol", func() {
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
Expand Down
Loading

0 comments on commit 7951e15

Please sign in to comment.