package kafka import ( "context" "crypto/tls" "io" "net" "strconv" "strings" "time" "github.com/segmentio/kafka-go/sasl" ) // The Dialer type mirrors the net.Dialer API but is designed to open kafka // connections instead of raw network connections. type Dialer struct { // Unique identifier for client connections established by this Dialer. ClientID string // Optionally specifies the function that the dialer uses to establish // network connections. If nil, net.(*Dialer).DialContext is used instead. // // When DialFunc is set, LocalAddr, DualStack, FallbackDelay, and KeepAlive // are ignored. DialFunc func(ctx context.Context, network string, address string) (net.Conn, error) // Timeout is the maximum amount of time a dial will wait for a connect to // complete. If Deadline is also set, it may fail earlier. // // The default is no timeout. // // When dialing a name with multiple IP addresses, the timeout may be // divided between them. // // With or without a timeout, the operating system may impose its own // earlier timeout. For instance, TCP timeouts are often around 3 minutes. Timeout time.Duration // Deadline is the absolute point in time after which dials will fail. // If Timeout is set, it may fail earlier. // Zero means no deadline, or dependent on the operating system as with the // Timeout option. Deadline time.Time // LocalAddr is the local address to use when dialing an address. // The address must be of a compatible type for the network being dialed. // If nil, a local address is automatically chosen. LocalAddr net.Addr // DualStack enables RFC 6555-compliant "Happy Eyeballs" dialing when the // network is "tcp" and the destination is a host name with both IPv4 and // IPv6 addresses. This allows a client to tolerate networks where one // address family is silently broken. DualStack bool // FallbackDelay specifies the length of time to wait before spawning a // fallback connection, when DualStack is enabled. // If zero, a default delay of 300ms is used. FallbackDelay time.Duration // KeepAlive specifies the keep-alive period for an active network // connection. // If zero, keep-alives are not enabled. Network protocols that do not // support keep-alives ignore this field. KeepAlive time.Duration // Resolver optionally gives a hook to convert the broker address into an // alternate host or IP address which is useful for custom service discovery. // If a custom resolver returns any possible hosts, the first one will be // used and the original discarded. If a port number is included with the // resolved host, it will only be used if a port number was not previously // specified. If no port is specified or resolved, the default of 9092 will be // used. Resolver Resolver // TLS enables Dialer to open secure connections. If nil, standard net.Conn // will be used. TLS *tls.Config // SASLMechanism configures the Dialer to use SASL authentication. If nil, // no authentication will be performed. SASLMechanism sasl.Mechanism // The transactional id to use for transactional delivery. Idempotent // deliver should be enabled if transactional id is configured. // For more details look at transactional.id description here: http://kafka.apache.org/documentation.html#producerconfigs // Empty string means that the connection will be non-transactional. TransactionalID string } // Dial connects to the address on the named network. func (d *Dialer) Dial(network string, address string) (*Conn, error) { return d.DialContext(context.Background(), network, address) } // DialContext connects to the address on the named network using the provided // context. // // The provided Context must be non-nil. If the context expires before the // connection is complete, an error is returned. Once successfully connected, // any expiration of the context will not affect the connection. // // When using TCP, and the host in the address parameter resolves to multiple // network addresses, any dial timeout (from d.Timeout or ctx) is spread over // each consecutive dial, such that each is given an appropriate fraction of the // time to connect. For example, if a host has 4 IP addresses and the timeout is // 1 minute, the connect to each single address will be given 15 seconds to // complete before trying the next one. func (d *Dialer) DialContext(ctx context.Context, network string, address string) (*Conn, error) { return d.connect( ctx, network, address, ConnConfig{ ClientID: d.ClientID, TransactionalID: d.TransactionalID, }, ) } // DialLeader opens a connection to the leader of the partition for a given // topic. // // The address given to the DialContext method may not be the one that the // connection will end up being established to, because the dialer will lookup // the partition leader for the topic and return a connection to that server. // The original address is only used as a mechanism to discover the // configuration of the kafka cluster that we're connecting to. func (d *Dialer) DialLeader(ctx context.Context, network string, address string, topic string, partition int) (*Conn, error) { p, err := d.LookupPartition(ctx, network, address, topic, partition) if err != nil { return nil, err } return d.DialPartition(ctx, network, address, p) } // DialPartition opens a connection to the leader of the partition specified by partition // descriptor. It's strongly advised to use descriptor of the partition that comes out of // functions LookupPartition or LookupPartitions. func (d *Dialer) DialPartition(ctx context.Context, network string, address string, partition Partition) (*Conn, error) { return d.connect(ctx, network, net.JoinHostPort(partition.Leader.Host, strconv.Itoa(partition.Leader.Port)), ConnConfig{ ClientID: d.ClientID, Topic: partition.Topic, Partition: partition.ID, TransactionalID: d.TransactionalID, }) } // LookupLeader searches for the kafka broker that is the leader of the // partition for a given topic, returning a Broker value representing it. func (d *Dialer) LookupLeader(ctx context.Context, network string, address string, topic string, partition int) (Broker, error) { p, err := d.LookupPartition(ctx, network, address, topic, partition) return p.Leader, err } // LookupPartition searches for the description of specified partition id. func (d *Dialer) LookupPartition(ctx context.Context, network string, address string, topic string, partition int) (Partition, error) { c, err := d.DialContext(ctx, network, address) if err != nil { return Partition{}, err } defer c.Close() brkch := make(chan Partition, 1) errch := make(chan error, 1) go func() { for attempt := 0; true; attempt++ { if attempt != 0 { if !sleep(ctx, backoff(attempt, 100*time.Millisecond, 10*time.Second)) { errch <- ctx.Err() return } } partitions, err := c.ReadPartitions(topic) if err != nil { if isTemporary(err) { continue } errch <- err return } for _, p := range partitions { if p.ID == partition { brkch <- p return } } } errch <- UnknownTopicOrPartition }() var prt Partition select { case prt = <-brkch: case err = <-errch: case <-ctx.Done(): err = ctx.Err() } return prt, err } // LookupPartitions returns the list of partitions that exist for the given topic. func (d *Dialer) LookupPartitions(ctx context.Context, network string, address string, topic string) ([]Partition, error) { conn, err := d.DialContext(ctx, network, address) if err != nil { return nil, err } defer conn.Close() prtch := make(chan []Partition, 1) errch := make(chan error, 1) go func() { if prt, err := conn.ReadPartitions(topic); err != nil { errch <- err } else { prtch <- prt } }() var prt []Partition select { case prt = <-prtch: case err = <-errch: case <-ctx.Done(): err = ctx.Err() } return prt, err } // connectTLS returns a tls.Conn that has already completed the Handshake func (d *Dialer) connectTLS(ctx context.Context, conn net.Conn, config *tls.Config) (tlsConn *tls.Conn, err error) { tlsConn = tls.Client(conn, config) errch := make(chan error) go func() { defer close(errch) errch <- tlsConn.Handshake() }() select { case <-ctx.Done(): conn.Close() tlsConn.Close() <-errch // ignore possible error from Handshake err = ctx.Err() case err = <-errch: } return } // connect opens a socket connection to the broker, wraps it to create a // kafka connection, and performs SASL authentication if configured to do so. func (d *Dialer) connect(ctx context.Context, network, address string, connCfg ConnConfig) (*Conn, error) { if d.Timeout != 0 { var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, d.Timeout) defer cancel() } if !d.Deadline.IsZero() { var cancel context.CancelFunc ctx, cancel = context.WithDeadline(ctx, d.Deadline) defer cancel() } c, err := d.dialContext(ctx, network, address) if err != nil { return nil, err } conn := NewConnWith(c, connCfg) if d.SASLMechanism != nil { if err := d.authenticateSASL(ctx, conn); err != nil { _ = conn.Close() return nil, err } } return conn, nil } // authenticateSASL performs all of the required requests to authenticate this // connection. If any step fails, this function returns with an error. A nil // error indicates successful authentication. // // In case of error, this function *does not* close the connection. That is the // responsibility of the caller. func (d *Dialer) authenticateSASL(ctx context.Context, conn *Conn) error { if err := conn.saslHandshake(d.SASLMechanism.Name()); err != nil { return err } sess, state, err := d.SASLMechanism.Start(ctx) if err != nil { return err } for completed := false; !completed; { challenge, err := conn.saslAuthenticate(state) switch err { case nil: case io.EOF: // the broker may communicate a failed exchange by closing the // connection (esp. in the case where we're passing opaque sasl // data over the wire since there's no protocol info). return SASLAuthenticationFailed default: return err } completed, state, err = sess.Next(ctx, challenge) if err != nil { return err } } return nil } func (d *Dialer) dialContext(ctx context.Context, network string, addr string) (net.Conn, error) { address, err := lookupHost(ctx, addr, d.Resolver) if err != nil { return nil, err } dial := d.DialFunc if dial == nil { dial = (&net.Dialer{ LocalAddr: d.LocalAddr, DualStack: d.DualStack, FallbackDelay: d.FallbackDelay, KeepAlive: d.KeepAlive, }).DialContext } conn, err := dial(ctx, network, address) if err != nil { return nil, err } if d.TLS != nil { c := d.TLS // If no ServerName is set, infer the ServerName // from the hostname we're connecting to. if c.ServerName == "" { c = d.TLS.Clone() // Copied from tls.go in the standard library. colonPos := strings.LastIndex(address, ":") if colonPos == -1 { colonPos = len(address) } hostname := address[:colonPos] c.ServerName = hostname } return d.connectTLS(ctx, conn, c) } return conn, nil } // DefaultDialer is the default dialer used when none is specified. var DefaultDialer = &Dialer{ Timeout: 10 * time.Second, DualStack: true, } // Dial is a convenience wrapper for DefaultDialer.Dial. func Dial(network string, address string) (*Conn, error) { return DefaultDialer.Dial(network, address) } // DialContext is a convenience wrapper for DefaultDialer.DialContext. func DialContext(ctx context.Context, network string, address string) (*Conn, error) { return DefaultDialer.DialContext(ctx, network, address) } // DialLeader is a convenience wrapper for DefaultDialer.DialLeader. func DialLeader(ctx context.Context, network string, address string, topic string, partition int) (*Conn, error) { return DefaultDialer.DialLeader(ctx, network, address, topic, partition) } // DialPartition is a convenience wrapper for DefaultDialer.DialPartition. func DialPartition(ctx context.Context, network string, address string, partition Partition) (*Conn, error) { return DefaultDialer.DialPartition(ctx, network, address, partition) } // LookupPartition is a convenience wrapper for DefaultDialer.LookupPartition. func LookupPartition(ctx context.Context, network string, address string, topic string, partition int) (Partition, error) { return DefaultDialer.LookupPartition(ctx, network, address, topic, partition) } // LookupPartitions is a convenience wrapper for DefaultDialer.LookupPartitions. func LookupPartitions(ctx context.Context, network string, address string, topic string) ([]Partition, error) { return DefaultDialer.LookupPartitions(ctx, network, address, topic) } func sleep(ctx context.Context, duration time.Duration) bool { if duration == 0 { select { default: return true case <-ctx.Done(): return false } } timer := time.NewTimer(duration) defer timer.Stop() select { case <-timer.C: return true case <-ctx.Done(): return false } } func backoff(attempt int, min time.Duration, max time.Duration) time.Duration { d := time.Duration(attempt*attempt) * min if d > max { d = max } return d } func splitHostPort(s string) (host string, port string) { host, port, _ = net.SplitHostPort(s) if len(host) == 0 && len(port) == 0 { host = s } return } func lookupHost(ctx context.Context, address string, resolver Resolver) (string, error) { host, port := splitHostPort(address) if resolver != nil { resolved, err := resolver.LookupHost(ctx, host) if err != nil { return "", err } // if the resolver doesn't return anything, we'll fall back on the provided // address instead if len(resolved) > 0 { resolvedHost, resolvedPort := splitHostPort(resolved[0]) // we'll always prefer the resolved host host = resolvedHost // in the case of port though, the provided address takes priority, and we // only use the resolved address to set the port when not specified if port == "" { port = resolvedPort } } } if port == "" { port = "9092" } return net.JoinHostPort(host, port), nil }