Skip to content

Commit

Permalink
Add SkipTLSHandshake small interface to CustomDialer
Browse files Browse the repository at this point in the history
Signed-off-by: Waldemar Quevedo <wally@nats.io>

Co-authored-by: Piotr Piotrowski <piotr@synadia.com>
  • Loading branch information
wallyqs and piotrpio committed Nov 29, 2022
1 parent c9b2fd8 commit 65b7870
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 19 deletions.
35 changes: 35 additions & 0 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"context"
"fmt"
"log"
"net"
"time"

"github.com/nats-io/nats.go"
Expand Down Expand Up @@ -44,6 +45,40 @@ func ExampleConnect() {
nc.Close()
}

type skipTLSDialer struct {
dialer *net.Dialer
skipTLS bool
}

func (sd *skipTLSDialer) Dial(network, address string) (net.Conn, error) {
return sd.dialer.Dial(network, address)
}

func (sd *skipTLSDialer) SkipTLSHandshake() bool {
return sd.skipTLS
}

func ExampleCustomDialer() {
// Given the following CustomDialer implementation:
//
// type skipTLSDialer struct {
// dialer *net.Dialer
// skipTLS bool
// }
//
// func (sd *skipTLSDialer) Dial(network, address string) (net.Conn, error) {
// return sd.dialer.Dial(network, address)
// }
//
// func (sd *skipTLSDialer) SkipTLSHandshake() bool {
// return true
// }
//
sd := &skipTLSDialer{dialer: &net.Dialer{Timeout: 2 * time.Second}, skipTLS: true}
nc, _ := nats.Connect("demo.nats.io", nats.SetCustomDialer(sd))
defer nc.Close()
}

// This Example shows an asynchronous subscriber.
func ExampleConn_Subscribe() {
nc, _ := nats.Connect(nats.DefaultURL)
Expand Down
30 changes: 12 additions & 18 deletions nats.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,9 @@ type asyncCallbacksHandler struct {
// Option is a function on the options for a connection.
type Option func(*Options) error

// CustomDialer can be used to specify any dialer, not necessarily
// a *net.Dialer.
// CustomDialer can be used to specify any dialer, not necessarily a
// *net.Dialer. A CustomDialer may also implement `SkipTLSHandshake() bool`
// in order to skip the TLS handshake in case not required.
type CustomDialer interface {
Dial(network, address string) (net.Conn, error)
}
Expand Down Expand Up @@ -303,10 +304,6 @@ type Options struct {
// transports.
TLSConfig *tls.Config

// SkipTLSWrapper does not upgrade the connection to TLS and is
// meant to be used if the custom dialer does handle TLS itself
SkipTLSWrapper bool

// AllowReconnect enables reconnection logic to be used when we
// encounter a disconnect from the current server.
AllowReconnect bool
Expand Down Expand Up @@ -1189,16 +1186,6 @@ func SetCustomDialer(dialer CustomDialer) Option {
}
}

// SetSkipTLSWrapper is an Option to be used with the CustomDialer which
// will not wrap the connection with TLS. Use it if the CustomDialer did
// already handle TLS
func SetSkipTLSWrapper(skip bool) Option {
return func(o *Options) error {
o.SkipTLSWrapper = skip
return nil
}
}

// UseOldRequestStyle is an Option to force usage of the old Request style.
func UseOldRequestStyle() Option {
return func(o *Options) error {
Expand Down Expand Up @@ -1906,11 +1893,18 @@ func (nc *Conn) createConn() (err error) {
return nil
}

type skipTLSDialer interface {
SkipTLSHandshake() bool
}

// makeTLSConn will wrap an existing Conn using TLS
func (nc *Conn) makeTLSConn() error {
if nc.Opts.SkipTLSWrapper {
if nc.Opts.CustomDialer != nil {
// we do nothing when asked to skip the TLS wrapper
return nil
sd, ok := nc.Opts.CustomDialer.(skipTLSDialer)
if ok && sd.SkipTLSHandshake() {
return nil
}
}
// Allow the user to configure their own tls.Config structure.
var tlsCopy *tls.Config
Expand Down
2 changes: 1 addition & 1 deletion services/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import (

type (

// Service is an interface for sevice management.
// Service is an interface for service management.
// It exposes methods to stop/reset a service, as well as get information on a service.
Service interface {
ID() string
Expand Down
54 changes: 54 additions & 0 deletions ws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -868,6 +868,60 @@ func TestWSWithTLS(t *testing.T) {
}
}

type testSkipTLSDialer struct {
dialer *net.Dialer
skipTLS bool
}

func (sd *testSkipTLSDialer) Dial(network, address string) (net.Conn, error) {
return sd.dialer.Dial(network, address)
}

func (sd *testSkipTLSDialer) SkipTLSHandshake() bool {
return sd.skipTLS
}

func TestWSWithTLSCustomDialer(t *testing.T) {
sopts := testWSGetDefaultOptions(t, true)
s := RunServerWithOptions(sopts)
defer s.Shutdown()

sd := &testSkipTLSDialer{
dialer: &net.Dialer{
Timeout: 2 * time.Second,
},
skipTLS: true,
}

// Connect with CustomDialer that fails since TLSHandshake is disabled.
copts := make([]Option, 0)
copts = append(copts, Secure(&tls.Config{InsecureSkipVerify: true}))
copts = append(copts, SetCustomDialer(sd))
_, err := Connect(fmt.Sprintf("wss://localhost:%d", sopts.Websocket.Port), copts...)
if err == nil {
t.Fatalf("Expected error on connect: %v", err)
}
if err.Error() != `invalid websocket connection` {
t.Logf("Expected invalid websocket connection: %v", err)
}

// Retry with the dialer.
copts = make([]Option, 0)
sd = &testSkipTLSDialer{
dialer: &net.Dialer{
Timeout: 2 * time.Second,
},
skipTLS: false,
}
copts = append(copts, Secure(&tls.Config{InsecureSkipVerify: true}))
copts = append(copts, SetCustomDialer(sd))
nc, err := Connect(fmt.Sprintf("wss://localhost:%d", sopts.Websocket.Port), copts...)
if err != nil {
t.Fatalf("Unexpected error on connect: %v", err)
}
defer nc.Close()
}

func TestWSTlsNoConfig(t *testing.T) {
opts := GetDefaultOptions()
opts.Servers = []string{"wss://localhost:443"}
Expand Down

0 comments on commit 65b7870

Please sign in to comment.