Skip to content

Commit

Permalink
support saslAuthenticate v0 in protocol package (segmentio#869)
Browse files Browse the repository at this point in the history
  • Loading branch information
rhansen2 authored Apr 1, 2022
1 parent ae86f55 commit 4296f73
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 1 deletion.
4 changes: 4 additions & 0 deletions protocol/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ func (c *Conn) RoundTrip(msg Message) (Message, error) {
p.Prepare(apiVersion)
}

if raw, ok := msg.(RawExchanger); ok && raw.Required(versions) {
return raw.RawExchange(c)
}

return RoundTrip(c, apiVersion, correlationID, c.clientID, msg)
}

Expand Down
16 changes: 16 additions & 0 deletions protocol/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,22 @@ type Partition struct {
Offline []int32
}

// RawExchanger is an extention to the Message interface to allow messages
// to control the request response cycle for the message. This is currently
// only used to facilitate v0 SASL Authenticate requests being written in
// a non-standard fashion when the SASL Handshake was done at v0 but not
// when done at v1.
type RawExchanger interface {
// Required should return true when a RawExchange is needed.
// The passed in versions are the negotiated versions for the connection
// performing the request.
Required(versions map[ApiKey]int16) bool
// RawExchange is given the raw connection to the broker and the Message
// is responsible for writing itself to the connection as well as reading
// the response.
RawExchange(rw io.ReadWriter) (Message, error)
}

// BrokerMessage is an extension of the Message interface implemented by some
// request types to customize the broker assignment logic.
type BrokerMessage interface {
Expand Down
46 changes: 45 additions & 1 deletion protocol/saslauthenticate/saslauthenticate.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
package saslauthenticate

import "github.com/segmentio/kafka-go/protocol"
import (
"encoding/binary"
"io"

"github.com/segmentio/kafka-go/protocol"
)

func init() {
protocol.Register(&Request{}, &Response{})
Expand All @@ -10,6 +15,43 @@ type Request struct {
AuthBytes []byte `kafka:"min=v0,max=v1"`
}

func (r *Request) RawExchange(rw io.ReadWriter) (protocol.Message, error) {
if err := r.writeTo(rw); err != nil {
return nil, err
}
return r.readResp(rw)
}

func (*Request) Required(versions map[protocol.ApiKey]int16) bool {
const v0 = 0
return versions[protocol.SaslHandshake] == v0
}

func (r *Request) writeTo(w io.Writer) error {
size := len(r.AuthBytes) + 4
buf := make([]byte, size, size)
binary.BigEndian.PutUint32(buf[:4], uint32(len(r.AuthBytes)))
copy(buf[4:], r.AuthBytes)
_, err := w.Write(buf)
return err
}

func (r *Request) readResp(read io.Reader) (protocol.Message, error) {
var lenBuf [4]byte
if _, err := io.ReadFull(read, lenBuf[:]); err != nil {
return nil, err
}
respLen := int32(binary.BigEndian.Uint32(lenBuf[:]))
data := make([]byte, respLen)

if _, err := io.ReadFull(read, data[:]); err != nil {
return nil, err
}
return &Response{
AuthBytes: data,
}, nil
}

func (r *Request) ApiKey() protocol.ApiKey { return protocol.SaslAuthenticate }

type Response struct {
Expand All @@ -20,3 +62,5 @@ type Response struct {
}

func (r *Response) ApiKey() protocol.ApiKey { return protocol.SaslAuthenticate }

var _ protocol.RawExchanger = (*Request)(nil)
50 changes: 50 additions & 0 deletions writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"sync"
"testing"
"time"

"github.com/segmentio/kafka-go/sasl/plain"
)

func TestBatchQueue(t *testing.T) {
Expand Down Expand Up @@ -164,6 +166,10 @@ func TestWriter(t *testing.T) {
scenario: "terminates on an attempt to write a message to a nonexistent topic",
function: testWriterTerminateMissingTopic,
},
{
scenario: "writing a message with SASL Plain authentication",
function: testWriterSasl,
},
}

for _, test := range tests {
Expand Down Expand Up @@ -766,6 +772,50 @@ func testWriterTerminateMissingTopic(t *testing.T) {
}
}

func testWriterSasl(t *testing.T) {
topic := makeTopic()
defer deleteTopic(t, topic)
dialer := &Dialer{
Timeout: 10 * time.Second,
SASLMechanism: plain.Mechanism{
Username: "adminplain",
Password: "admin-secret",
},
}

w := newTestWriter(WriterConfig{
Dialer: dialer,
Topic: topic,
Brokers: []string{"localhost:9093"},
})

w.AllowAutoTopicCreation = true

defer w.Close()

msg := Message{Key: []byte("key"), Value: []byte("Hello World")}

var err error
const retries = 5
for i := 0; i < retries; i++ {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
err = w.WriteMessages(ctx, msg)
if errors.Is(err, LeaderNotAvailable) || errors.Is(err, context.DeadlineExceeded) {
time.Sleep(time.Millisecond * 250)
continue
}

if err != nil {
t.Errorf("unexpected error %v", err)
return
}
}
if err != nil {
t.Errorf("unable to create topic %v", err)
}
}

type staticBalancer struct {
partition int
}
Expand Down

0 comments on commit 4296f73

Please sign in to comment.