Skip to content

Commit

Permalink
Small fix to Dialer to set the ServerName in tls.Config if it is empt…
Browse files Browse the repository at this point in the history
  • Loading branch information
pkedy authored and achille-roussel committed Nov 27, 2018
1 parent 310a8f1 commit 0b3aacc
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
20 changes: 17 additions & 3 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/tls"
"net"
"strconv"
"strings"
"time"
)

Expand Down Expand Up @@ -216,8 +217,8 @@ func (d *Dialer) LookupPartitions(ctx context.Context, network string, address s
}

// connectTLS returns a tls.Conn that has already completed the Handshake
func (d *Dialer) connectTLS(ctx context.Context, conn net.Conn) (tlsConn *tls.Conn, err error) {
tlsConn = tls.Client(conn, d.TLS)
func (d *Dialer) connectTLS(ctx context.Context, conn net.Conn, config *tls.Config) (tlsConn *tls.Conn, err error) {
tlsConn = tls.Client(conn, config)
errch := make(chan error)

go func() {
Expand Down Expand Up @@ -265,7 +266,20 @@ func (d *Dialer) dialContext(ctx context.Context, network string, address string
}

if d.TLS != nil {
return d.connectTLS(ctx, conn)
c := d.TLS
// If no ServerName is set, infer the ServerName
// from the hostname we're connecting to.
if c.ServerName == "" {
c = d.TLS.Clone()
// Copied from tls.go in the standard library.
colonPos := strings.LastIndex(address, ":")
if colonPos == -1 {
colonPos = len(address)
}
hostname := address[:colonPos]
c.ServerName = hostname
}
return d.connectTLS(ctx, conn, c)
}

return conn, nil
Expand Down
2 changes: 1 addition & 1 deletion dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ func TestDialerConnectTLSHonorsContext(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*25)
defer cancel()

_, err := d.connectTLS(ctx, conn)
_, err := d.connectTLS(ctx, conn, d.TLS)
if context.DeadlineExceeded != err {
t.Errorf("expected err to be %v; got %v", context.DeadlineExceeded, err)
t.FailNow()
Expand Down

0 comments on commit 0b3aacc

Please sign in to comment.