Skip to content

Commit

Permalink
Add AddPartitionsToTxn and EndTxn functions to client. (segmentio#736)
Browse files Browse the repository at this point in the history
  • Loading branch information
rhansen2 authored Sep 14, 2021
1 parent d8cd82a commit a63af32
Show file tree
Hide file tree
Showing 10 changed files with 497 additions and 10 deletions.
108 changes: 108 additions & 0 deletions addpartitionstotxn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package kafka

import (
"context"
"fmt"
"net"
"time"

"github.com/segmentio/kafka-go/protocol/addpartitionstotxn"
)

// AddPartitionToTxn represents a partition to be added
// to a transaction.
type AddPartitionToTxn struct {
// Partition is the ID of a partition to add to the transaction.
Partition int
}

// AddPartitionsToTxnRequest is the request structure fo the AddPartitionsToTxn function.
type AddPartitionsToTxnRequest struct {
// Address of the kafka broker to send the request to.
Addr net.Addr

// The transactional id key
TransactionalID string

// The Producer ID (PID) for the current producer session;
// received from an InitProducerID request.
ProducerID int

// The epoch associated with the current producer session for the given PID
ProducerEpoch int

// Mappings of topic names to lists of partitions.
Topics map[string][]AddPartitionToTxn
}

// AddPartitioinsToTxnResponse is the response structure for the AddPartitioinsToTxn function.
type AddPartitioinsToTxnResponse struct {
// The amount of time that the broker throttled the request.
Throttle time.Duration

// Mappings of topic names to partitions being added to a transactions.
Topics map[string][]AddPartitionToTxnPartition
}

// AddPartitionToTxnPartition represents the state of a single partition
// in response to adding to a transaction.
type AddPartitionToTxnPartition struct {
// The ID of the partition.
Partition int

// An error that may have occured when attempting to add the partition
// to a transaction.
//
// The errors contain the kafka error code. Programs may use the standard
// errors.Is function to test the error against kafka error codes.
Error error
}

// AddPartitionsToTnx sends an add partitions to txn request to a kafka broker and returns the response.
func (c *Client) AddPartitionsToTxn(
ctx context.Context,
req *AddPartitionsToTxnRequest,
) (*AddPartitioinsToTxnResponse, error) {
protoReq := &addpartitionstotxn.Request{
TransactionalID: req.TransactionalID,
ProducerID: int64(req.ProducerID),
ProducerEpoch: int16(req.ProducerEpoch),
}
protoReq.Topics = make([]addpartitionstotxn.RequestTopic, 0, len(req.Topics))

for topic, partitions := range req.Topics {
reqTopic := addpartitionstotxn.RequestTopic{
Name: topic,
Partitions: make([]int32, len(partitions)),
}
for i, partition := range partitions {
reqTopic.Partitions[i] = int32(partition.Partition)
}
protoReq.Topics = append(protoReq.Topics, reqTopic)
}

m, err := c.roundTrip(ctx, req.Addr, protoReq)
if err != nil {
return nil, fmt.Errorf("kafka.(*Client).AddPartitionsToTxn: %w", err)
}

r := m.(*addpartitionstotxn.Response)

res := &AddPartitioinsToTxnResponse{
Throttle: makeDuration(r.ThrottleTimeMs),
Topics: make(map[string][]AddPartitionToTxnPartition, len(r.Results)),
}

for _, result := range r.Results {
partitions := make([]AddPartitionToTxnPartition, 0, len(result.Results))
for _, rp := range result.Results {
partitions = append(partitions, AddPartitionToTxnPartition{
Partition: int(rp.PartitionIndex),
Error: makeError(rp.ErrorCode, ""),
})
}
res.Topics[result.Name] = partitions
}

return res, nil
}
126 changes: 126 additions & 0 deletions addpartitionstotxn_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package kafka

import (
"context"
"net"
"strconv"
"testing"
"time"

ktesting "github.com/segmentio/kafka-go/testing"
)

func TestClientAddPartitionsToTxn(t *testing.T) {
if !ktesting.KafkaIsAtLeast("0.11.0") {
t.Skip("Skipping test because kafka version is not high enough.")
}
topic1 := makeTopic()
topic2 := makeTopic()

client, shutdown := newLocalClient()
defer shutdown()

err := clientCreateTopic(client, topic1, 3)
if err != nil {
t.Fatal(err)
}

err = clientCreateTopic(client, topic2, 3)
if err != nil {
t.Fatal(err)
}

transactionalID := makeTransactionalID()

ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
defer cancel()
respc, err := waitForCoordinatorIndefinitely(ctx, client, &FindCoordinatorRequest{
Addr: client.Addr,
Key: transactionalID,
KeyType: CoordinatorKeyTypeTransaction,
})
if err != nil {
t.Fatal(err)
}

transactionCoordinator := TCP(net.JoinHostPort(respc.Coordinator.Host, strconv.Itoa(int(respc.Coordinator.Port))))
client, shutdown = newClient(transactionCoordinator)
defer shutdown()

ipResp, err := client.InitProducerID(ctx, &InitProducerIDRequest{
TransactionalID: transactionalID,
TransactionTimeoutMs: 10000,
})
if err != nil {
t.Fatal(err)
}

if ipResp.Error != nil {
t.Fatal(ipResp.Error)
}

defer func() {
err := clientEndTxn(client, &EndTxnRequest{
TransactionalID: transactionalID,
ProducerID: ipResp.Producer.ProducerID,
ProducerEpoch: ipResp.Producer.ProducerID,
Committed: false,
})
if err != nil {
t.Fatal(err)
}
}()

ctx, cancel = context.WithTimeout(context.Background(), time.Second*30)
defer cancel()
resp, err := client.AddPartitionsToTxn(ctx, &AddPartitionsToTxnRequest{
TransactionalID: transactionalID,
ProducerID: ipResp.Producer.ProducerID,
ProducerEpoch: ipResp.Producer.ProducerEpoch,
Topics: map[string][]AddPartitionToTxn{
topic1: {
{
Partition: 0,
},
{
Partition: 1,
},
{
Partition: 2,
},
},
topic2: {
{
Partition: 0,
},
{
Partition: 2,
},
},
},
})
if err != nil {
t.Fatal(err)
}

if len(resp.Topics) != 2 {
t.Errorf("expected responses for 2 topics; got: %d", len(resp.Topics))
}
for topic, partitions := range resp.Topics {
if topic == topic1 {
if len(partitions) != 3 {
t.Errorf("expected 3 partitions in response for topic %s; got: %d", topic, len(partitions))
}
}
if topic == topic2 {
if len(partitions) != 2 {
t.Errorf("expected 2 partitions in response for topic %s; got: %d", topic, len(partitions))
}
}
for _, partition := range partitions {
if partition.Error != nil {
t.Error(partition.Error)
}
}
}
}
11 changes: 11 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,17 @@ func clientCreateTopic(client *Client, topic string, partitions int) error {
return nil
}

func clientEndTxn(client *Client, req *EndTxnRequest) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
defer cancel()
resp, err := client.EndTxn(ctx, req)
if err != nil {
return err
}

return resp.Error
}

func newLocalClient() (*Client, func()) {
return newClient(TCP("localhost"))
}
Expand Down
18 changes: 10 additions & 8 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ func makeGroupID() string {
return fmt.Sprintf("kafka-go-group-%016x", rand.Int63())
}

func makeTransactionalID() string {
return fmt.Sprintf("kafka-go-transactional-id-%016x", rand.Int63())
}

func TestConn(t *testing.T) {
tests := []struct {
scenario string
Expand Down Expand Up @@ -324,13 +328,13 @@ func TestConn(t *testing.T) {
t.Parallel()

nettest.TestConn(t, func() (c1 net.Conn, c2 net.Conn, stop func(), err error) {
var topic1 = makeTopic()
var topic2 = makeTopic()
topic1 := makeTopic()
topic2 := makeTopic()
var t1Reader *Conn
var t2Reader *Conn
var t1Writer *Conn
var t2Writer *Conn
var dialer = &Dialer{}
dialer := &Dialer{}

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
Expand Down Expand Up @@ -378,7 +382,6 @@ func testConnFirstOffset(t *testing.T, conn *Conn) {
func testConnWrite(t *testing.T, conn *Conn) {
b := []byte("Hello World!")
n, err := conn.Write(b)

if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -952,11 +955,10 @@ func testConnFetchAndCommitOffsets(t *testing.T, conn *Conn) {
}

func testConnWriteReadConcurrently(t *testing.T, conn *Conn) {

const N = 1000
var msgs = make([]string, N)
var done = make(chan struct{})
var written = make(chan struct{}, N/10)
msgs := make([]string, N)
done := make(chan struct{})
written := make(chan struct{}, N/10)

for i := 0; i != N; i++ {
msgs[i] = strconv.Itoa(i)
Expand Down
61 changes: 61 additions & 0 deletions endtxn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package kafka

import (
"context"
"fmt"
"net"
"time"

"github.com/segmentio/kafka-go/protocol/endtxn"
)

// EndTxnRequest represets a request sent to a kafka broker to end a transaction.
type EndTxnRequest struct {
// Address of the kafka broker to send the request to.
Addr net.Addr

// The transactional id key.
TransactionalID string

// The Producer ID (PID) for the current producer session
ProducerID int

// The epoch associated with the current producer session for the given PID
ProducerEpoch int

// Committed should be set to true if the transaction was commited, false otherwise.
Committed bool
}

// EndTxnResponse represents a resposne from a kafka broker to an end transaction request.
type EndTxnResponse struct {
// The amount of time that the broker throttled the request.
Throttle time.Duration

// Error is non-nil if an error occureda and contains the kafka error code.
// Programs may use the standard errors.Is function to test the error
// against kafka error codes.
Error error
}

// EndTxn sends an EndTxn request to a kafka broker and returns its response.
func (c *Client) EndTxn(ctx context.Context, req *EndTxnRequest) (*EndTxnResponse, error) {
m, err := c.roundTrip(ctx, req.Addr, &endtxn.Request{
TransactionalID: req.TransactionalID,
ProducerID: int64(req.ProducerID),
ProducerEpoch: int16(req.ProducerEpoch),
Committed: req.Committed,
})
if err != nil {
return nil, fmt.Errorf("kafka.(*Client).EndTxn: %w", err)
}

r := m.(*endtxn.Response)

res := &EndTxnResponse{
Throttle: makeDuration(r.ThrottleTimeMs),
Error: makeError(r.ErrorCode, ""),
}

return res, nil
}
6 changes: 4 additions & 2 deletions initproducerid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func TestClientInitProducerId(t *testing.T) {
resp, err := client.InitProducerID(context.Background(), &InitProducerIDRequest{
Addr: transactionCoordinator,
TransactionalID: tid,
TransactionTimeoutMs: 3000,
TransactionTimeoutMs: 30000,
})
if err != nil {
t.Fatal(err)
Expand All @@ -57,7 +57,9 @@ func TestClientInitProducerId(t *testing.T) {
resp, err = client.InitProducerID(context.Background(), &InitProducerIDRequest{
Addr: transactionCoordinator,
TransactionalID: tid,
TransactionTimeoutMs: 3000,
TransactionTimeoutMs: 30000,
ProducerID: pid1,
ProducerEpoch: epoch1,
})
if err != nil {
t.Fatal(err)
Expand Down
Loading

0 comments on commit a63af32

Please sign in to comment.