Skip to content

Commit

Permalink
lazily load API versions (segmentio#375)
Browse files Browse the repository at this point in the history
* lazily load API versions

* fix deadline management in API Version request
  • Loading branch information
Achille authored Nov 8, 2019
1 parent ae89ccc commit 3a767ae
Showing 1 changed file with 87 additions and 51 deletions.
138 changes: 87 additions & 51 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,30 @@ type Conn struct {
correlationID int32

// number of replica acks required when publishing to a partition
requiredAcks int32
apiVersions map[apiKey]ApiVersion
fetchVersion apiVersion
produceVersion apiVersion
requiredAcks int32

// lazily loaded API versions used by this connection
apiVersions atomic.Value // apiVersions

transactionalID *string
}

type apiVersions map[apiKey]ApiVersion

func (v apiVersions) negotiate(key apiKey, sortedSupportedVersions ...apiVersion) apiVersion {
x := v[key]

for i := len(sortedSupportedVersions) - 1; i >= 0; i-- {
s := sortedSupportedVersions[i]

if apiVersion(x.MaxVersion) >= s {
return s
}
}

return -1
}

// ConnConfig is a configuration object used to create new instances of Conn.
type ConnConfig struct {
ClientID string
Expand Down Expand Up @@ -178,41 +194,41 @@ func NewConnWith(conn net.Conn, config ConnConfig) *Conn {
}},
}},
}).size()
c.selectVersions()
c.fetchMaxBytes = math.MaxInt32 - c.fetchMinSize
return c
}

func (c *Conn) selectVersions() {
var err error
apiVersions, err := c.ApiVersions()
func (c *Conn) negotiateVersion(key apiKey, sortedSupportedVersions ...apiVersion) (apiVersion, error) {
v, err := c.loadVersions()
if err != nil {
c.apiVersions = defaultApiVersions
} else {
c.apiVersions = make(map[apiKey]ApiVersion)
for _, v := range apiVersions {
c.apiVersions[apiKey(v.ApiKey)] = v
}
return -1, err
}
for _, v := range c.apiVersions {
if apiKey(v.ApiKey) == fetchRequest {
switch version := v.MaxVersion; {
case version >= 10:
c.fetchVersion = 10
case version >= 5:
c.fetchVersion = 5
default:
c.fetchVersion = 2
}
}
if apiKey(v.ApiKey) == produceRequest {
if v.MaxVersion >= 7 {
c.produceVersion = 7
} else {
c.produceVersion = 2
}
}
a := v.negotiate(key, sortedSupportedVersions...)
if a < 0 {
return -1, fmt.Errorf("no matching versions were found between the client and the broker for API key %d", key)
}
return a, nil
}

func (c *Conn) loadVersions() (apiVersions, error) {
v, _ := c.apiVersions.Load().(apiVersions)
if v != nil {
return v, nil
}

brokerVersions, err := c.ApiVersions()
if err != nil {
return nil, err
}

v = make(apiVersions, len(brokerVersions))

for _, a := range brokerVersions {
v[apiKey(a.ApiKey)] = a
}

c.apiVersions.Store(v)
return v, nil
}

// Controller requests kafka for the current controller and returns its URL
Expand Down Expand Up @@ -767,10 +783,15 @@ func (c *Conn) ReadBatchWith(cfg ReadBatchConfig) *Batch {
return &Batch{err: dontExpectEOF(err)}
}

fetchVersion, err := c.negotiateVersion(fetchRequest, v2, v5, v10)
if err != nil {
return &Batch{err: dontExpectEOF(err)}
}

id, err := c.doRequest(&c.rdeadline, func(deadline time.Time, id int32) error {
now := time.Now()
deadline = adjustDeadlineForRTT(deadline, now, defaultRTT)
switch c.fetchVersion {
switch fetchVersion {
case v10:
return c.wb.writeFetchRequestV10(
id,
Expand Down Expand Up @@ -821,7 +842,7 @@ func (c *Conn) ReadBatchWith(cfg ReadBatchConfig) *Batch {
var highWaterMark int64
var remain int

switch c.fetchVersion {
switch fetchVersion {
case v10:
throttle, highWaterMark, remain, err = readFetchResponseHeaderV10(&c.rbuf, size)
case v5:
Expand Down Expand Up @@ -1033,7 +1054,6 @@ func (c *Conn) WriteCompressedMessagesAt(codec CompressionCodec, msgs ...Message
}

func (c *Conn) writeCompressedMessages(codec CompressionCodec, msgs ...Message) (nbytes int, partition int32, offset int64, appendTime time.Time, err error) {

if len(msgs) == 0 {
return
}
Expand All @@ -1058,12 +1078,17 @@ func (c *Conn) writeCompressedMessages(codec CompressionCodec, msgs ...Message)
nbytes += len(msg.Key) + len(msg.Value)
}

var produceVersion apiVersion
if produceVersion, err = c.negotiateVersion(produceRequest, v2, v3, v7); err != nil {
return
}

err = c.writeOperation(
func(deadline time.Time, id int32) error {
now := time.Now()
deadline = adjustDeadlineForRTT(deadline, now, defaultRTT)
switch version := c.apiVersions[produceRequest].MaxVersion; {
case version >= 7:
switch produceVersion {
case v7:
recordBatch, err :=
newRecordBatch(
codec,
Expand All @@ -1082,7 +1107,7 @@ func (c *Conn) writeCompressedMessages(codec CompressionCodec, msgs ...Message)
c.transactionalID,
recordBatch,
)
case version >= 3:
case v3:
recordBatch, err :=
newRecordBatch(
codec,
Expand Down Expand Up @@ -1126,7 +1151,7 @@ func (c *Conn) writeCompressedMessages(codec CompressionCodec, msgs ...Message)
// Read the list of partitions, there should be only one since
// we've produced a message to a single partition.
size, err = readArrayWith(r, size, func(r *bufio.Reader, size int) (int, error) {
switch c.produceVersion {
switch produceVersion {
case v7:
var p produceResponsePartitionV7
size, err := p.readFrom(r, size)
Expand Down Expand Up @@ -1373,26 +1398,33 @@ var defaultApiVersions map[apiKey]ApiVersion = map[apiKey]ApiVersion{
}

func (c *Conn) ApiVersions() ([]ApiVersion, error) {
id, err := c.doRequest(&c.rdeadline, func(deadline time.Time, id int32) error {
now := time.Now()
deadline = adjustDeadlineForRTT(deadline, now, defaultRTT)
deadline := &c.rdeadline

if deadline.deadline().IsZero() {
// ApiVersions is called automatically when API version negotiation
// needs to happen, so we are not garanteed that a read deadline has
// been set yet. Fallback to use the write deadline in case it was
// set, for example when version negotiation is initiated during a
// produce request.
deadline = &c.wdeadline
}

id, err := c.doRequest(deadline, func(_ time.Time, id int32) error {
h := requestHeader{
ApiKey: int16(apiVersionsRequest),
ApiVersion: int16(v0),
CorrelationID: id,
ClientID: c.clientID,
}
h.Size = (h.size() - 4)

h.writeTo(&c.wb)
return c.wbuf.Flush()
})
if err != nil {
return nil, err
}

_, size, lock, err := c.waitResponse(&c.rdeadline, id)
_, size, lock, err := c.waitResponse(deadline, id)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1503,12 +1535,13 @@ func (c *Conn) saslHandshake(mechanism string) error {
// number will affect how the SASL authentication
// challenge/responses are sent
var resp saslHandshakeResponseV0
version := v0
if c.apiVersions[saslHandshakeRequest].MaxVersion >= 1 {
version = v1

version, err := c.negotiateVersion(saslHandshakeRequest, v0, v1)
if err != nil {
return err
}

err := c.writeOperation(
err = c.writeOperation(
func(deadline time.Time, id int32) error {
return c.writeRequest(saslHandshakeRequest, version, id, &saslHandshakeRequestV0{Mechanism: mechanism})
},
Expand All @@ -1532,7 +1565,11 @@ func (c *Conn) saslAuthenticate(data []byte) ([]byte, error) {
// if we sent a v1 handshake, then we must encapsulate the authentication
// request in a saslAuthenticateRequest. otherwise, we read and write raw
// bytes.
if c.apiVersions[saslHandshakeRequest].MaxVersion >= 1 {
version, err := c.negotiateVersion(saslHandshakeRequest, v0, v1)
if err != nil {
return nil, err
}
if version == v1 {
var request = saslAuthenticateRequestV0{Data: data}
var response saslAuthenticateResponseV0

Expand Down Expand Up @@ -1563,8 +1600,7 @@ func (c *Conn) saslAuthenticate(data []byte) ([]byte, error) {
}

var respLen int32
_, err := readInt32(&c.rbuf, 4, &respLen)
if err != nil {
if _, err := readInt32(&c.rbuf, 4, &respLen); err != nil {
return nil, err
}

Expand Down

0 comments on commit 3a767ae

Please sign in to comment.