Skip to content

Commit

Permalink
Expand messages and add tests (segmentio#737)
Browse files Browse the repository at this point in the history
  • Loading branch information
rhansen2 authored Sep 14, 2021
1 parent 433e383 commit d8cd82a
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 2 deletions.
14 changes: 13 additions & 1 deletion initproducerid.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,17 @@ type InitProducerIDRequest struct {

// Time after which a transaction should time out
TransactionTimeoutMs int

// The Producer ID (PID).
// This is used to disambiguate requests if a transactional id is reused following its expiration.
// Only supported in version >=3 of the request, will be ignore otherwise.
ProducerID int

// The producer's current epoch.
// This will be checked against the producer epoch on the broker,
// and the request will return an error if they do not match.
// Only supported in version >=3 of the request, will be ignore otherwise.
ProducerEpoch int
}

// ProducerSession contains useful information about the producer session from the broker's response
Expand Down Expand Up @@ -48,10 +59,11 @@ type InitProducerIDResponse struct {
// InitProducerID sends a initProducerId request to a kafka broker and returns the
// response.
func (c *Client) InitProducerID(ctx context.Context, req *InitProducerIDRequest) (*InitProducerIDResponse, error) {

m, err := c.roundTrip(ctx, req.Addr, &initproducerid.Request{
TransactionalID: req.TransactionalID,
TransactionTimeoutMs: int32(req.TransactionTimeoutMs),
ProducerID: int64(req.ProducerID),
ProducerEpoch: int16(req.ProducerEpoch),
})
if err != nil {
return nil, fmt.Errorf("kafka.(*Client).InitProducerId: %w", err)
Expand Down
11 changes: 10 additions & 1 deletion initproducerid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ func TestClientInitProducerId(t *testing.T) {
t.Fatal(err)
}

if resp.Error != nil {
t.Fatal(resp.Error)
}

epoch1 := resp.Producer.ProducerEpoch
pid1 := resp.Producer.ProducerID

Expand All @@ -58,11 +62,16 @@ func TestClientInitProducerId(t *testing.T) {
if err != nil {
t.Fatal(err)
}

if resp.Error != nil {
t.Fatal(resp.Error)
}

epoch2 := resp.Producer.ProducerEpoch
pid2 := resp.Producer.ProducerID

if pid1 != pid2 {
t.Fatal("PID should stay the same across producer sessions")
t.Fatalf("PID should stay the same across producer sessions; expected: %v got: %v", pid1, pid2)
}

if epoch2-epoch1 <= 0 {
Expand Down
40 changes: 40 additions & 0 deletions protocol/initproducerid/initproducerid_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package initproducerid_test

import (
"testing"

"github.com/segmentio/kafka-go/protocol/initproducerid"
"github.com/segmentio/kafka-go/protocol/prototest"
)

func TestInitProducerIDRequest(t *testing.T) {
for _, version := range []int16{0, 1, 2} {
prototest.TestRequest(t, version, &initproducerid.Request{
TransactionalID: "transactional-id-0",
TransactionTimeoutMs: 1000,
})
}

// Version 2 added:
// ProducerID
// ProducerEpoch
for _, version := range []int16{3, 4} {
prototest.TestRequest(t, version, &initproducerid.Request{
TransactionalID: "transactional-id-0",
TransactionTimeoutMs: 1000,
ProducerID: 10,
ProducerEpoch: 5,
})
}
}

func TestInitProducerIDResponse(t *testing.T) {
for _, version := range []int16{0, 1, 2, 3, 4} {
prototest.TestResponse(t, version, &initproducerid.Response{
ThrottleTimeMs: 1000,
ErrorCode: 9,
ProducerID: 10,
ProducerEpoch: 1000,
})
}
}

0 comments on commit d8cd82a

Please sign in to comment.