Skip to content

Commit

Permalink
add port to SASL metadata (segmentio#780)
Browse files Browse the repository at this point in the history
  • Loading branch information
Achille authored Nov 5, 2021
1 parent f7dd036 commit e1f5945
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 26 deletions.
9 changes: 1 addition & 8 deletions address.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,9 @@ func makeNetAddr(network string, addresses []string) net.Addr {
}

func makeAddr(network, address string) net.Addr {
host, port, _ := net.SplitHostPort(address)
if port == "" {
port = "9092"
}
if host == "" {
host = address
}
return &networkAddress{
network: network,
address: net.JoinHostPort(host, port),
address: canonicalAddress(address),
}
}

Expand Down
6 changes: 2 additions & 4 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"os"
"path/filepath"
"runtime"
"strconv"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -241,11 +240,10 @@ func (c *Conn) loadVersions() (apiVersionMap, error) {
// connection was established to.
func (c *Conn) Broker() Broker {
addr := c.conn.RemoteAddr()
host, port, _ := net.SplitHostPort(addr.String())
portNumber, _ := strconv.Atoi(port)
host, port, _ := splitHostPortNumber(addr.String())
return Broker{
Host: host,
Port: portNumber,
Port: port,
ID: int(c.broker),
Rack: c.rack,
}
Expand Down
26 changes: 21 additions & 5 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package kafka
import (
"context"
"crypto/tls"
"fmt"
"io"
"net"
"strconv"
Expand Down Expand Up @@ -281,8 +282,13 @@ func (d *Dialer) connect(ctx context.Context, network, address string, connCfg C
conn := NewConnWith(c, connCfg)

if d.SASLMechanism != nil {
host, port, err := splitHostPortNumber(address)
if err != nil {
return nil, err
}
metadata := &sasl.Metadata{
Host: address,
Host: host,
Port: port,
}
if err := d.authenticateSASL(sasl.WithMetadata(ctx, metadata), conn); err != nil {
_ = conn.Close()
Expand Down Expand Up @@ -435,14 +441,28 @@ func backoff(attempt int, min time.Duration, max time.Duration) time.Duration {
return d
}

func canonicalAddress(s string) string {
return net.JoinHostPort(splitHostPort(s))
}

func splitHostPort(s string) (host string, port string) {
host, port, _ = net.SplitHostPort(s)
if len(host) == 0 && len(port) == 0 {
host = s
port = "9092"
}
return
}

func splitHostPortNumber(s string) (host string, portNumber int, err error) {
host, port := splitHostPort(s)
portNumber, err = strconv.Atoi(port)
if err != nil {
return host, 0, fmt.Errorf("%s: %w", s, err)
}
return host, portNumber, nil
}

func lookupHost(ctx context.Context, address string, resolver Resolver) (string, error) {
host, port := splitHostPort(address)

Expand All @@ -468,9 +488,5 @@ func lookupHost(ctx context.Context, address string, resolver Resolver) (string,
}
}

if port == "" {
port = "9092"
}

return net.JoinHostPort(host, port), nil
}
1 change: 1 addition & 0 deletions sasl/sasl.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ type Metadata struct {
// Host is the address of the broker the authentication will be
// performed on.
Host string
Port int
}

// WithMetadata returns a copy of the context with associated Metadata.
Expand Down
19 changes: 10 additions & 9 deletions transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -972,16 +972,12 @@ func (g *connGroup) grabConnOrConnect(ctx context.Context) (*conn, error) {
broker := g.broker

if broker.ID < 0 {
host, port, err := net.SplitHostPort(addr.String())
host, port, err := splitHostPortNumber(addr.String())
if err != nil {
return nil, fmt.Errorf("%s: %w", addr, err)
}
portNumber, err := strconv.Atoi(port)
if err != nil {
return nil, fmt.Errorf("%s: %w", addr, err)
return nil, err
}
broker.Host = host
broker.Port = portNumber
broker.Port = port
}

ipAddrs, err := rslv.LookupBrokerIPAddr(ctx, broker)
Expand Down Expand Up @@ -1167,7 +1163,7 @@ func (g *connGroup) connect(ctx context.Context, addr net.Addr) (*conn, error) {

if tlsConfig := g.pool.tls; tlsConfig != nil {
if tlsConfig.ServerName == "" && !tlsConfig.InsecureSkipVerify {
host, _, _ := net.SplitHostPort(netAddr.String())
host, _ := splitHostPort(netAddr.String())
tlsConfig = tlsConfig.Clone()
tlsConfig.ServerName = host
}
Expand Down Expand Up @@ -1197,8 +1193,13 @@ func (g *connGroup) connect(ctx context.Context, addr net.Addr) (*conn, error) {
pc.SetDeadline(time.Time{})

if g.pool.sasl != nil {
host, port, err := splitHostPortNumber(netAddr.String())
if err != nil {
return nil, err
}
metadata := &sasl.Metadata{
Host: netAddr.String(),
Host: host,
Port: port,
}
if err := authenticateSASL(sasl.WithMetadata(ctx, metadata), pc, g.pool.sasl); err != nil {
return nil, err
Expand Down

0 comments on commit e1f5945

Please sign in to comment.