Skip to content

Commit

Permalink
add encoding.TextMarshaler and encoding.TextUnmarshaler implementatio…
Browse files Browse the repository at this point in the history
…ns (segmentio#754)

* add encoding.TextMarshaler and encoding.TextUnmarshaler implementations

* support numeric codes as well

* PR feedback
  • Loading branch information
Achille authored Oct 5, 2021
1 parent c03923d commit 2d04c4b
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 0 deletions.
50 changes: 50 additions & 0 deletions compress/compress.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
package compress

import (
"encoding"
"fmt"
"io"
"strconv"
"strings"

"github.com/segmentio/kafka-go/compress/gzip"
"github.com/segmentio/kafka-go/compress/lz4"
Expand All @@ -13,6 +17,7 @@ import (
type Compression int8

const (
None Compression = 0
Gzip Compression = 1
Snappy Compression = 2
Lz4 Compression = 3
Expand All @@ -33,6 +38,50 @@ func (c Compression) String() string {
return "uncompressed"
}

func (c Compression) MarshalText() ([]byte, error) {
return []byte(c.String()), nil
}

func (c *Compression) UnmarshalText(b []byte) error {
switch string(b) {
case "none", "uncompressed":
*c = None
return nil
}

for _, codec := range Codecs[None+1:] {
if codec.Name() == string(b) {
*c = Compression(codec.Code())
return nil
}
}

i, err := strconv.ParseInt(string(b), 10, 64)
if err == nil && i >= 0 && i < int64(len(Codecs)) {
*c = Compression(i)
return nil
}

s := &strings.Builder{}
s.WriteString("none, uncompressed")

for i, codec := range Codecs[None+1:] {
if i < (len(Codecs) - 1) {
s.WriteString(", ")
} else {
s.WriteString(", or ")
}
s.WriteString(codec.Name())
}

return fmt.Errorf("compression format must be one of %s, not %q", s, b)
}

var (
_ encoding.TextMarshaler = Compression(0)
_ encoding.TextUnmarshaler = (*Compression)(nil)
)

// Codec represents a compression codec to encode and decode the messages.
// See : https://cwiki.apache.org/confluence/display/KAFKA/Compression
//
Expand Down Expand Up @@ -66,6 +115,7 @@ var (

// The global table of compression codecs supported by the kafka protocol.
Codecs = [...]Codec{
None: nil,
Gzip: &GzipCodec,
Snappy: &SnappyCodec,
Lz4: &Lz4Codec,
Expand Down
25 changes: 25 additions & 0 deletions compress/compress_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,31 @@ func testEncodeDecode(t *testing.T, m kafka.Message, codec pkg.Codec) {
var r1, r2 []byte
var err error

t.Run("text format of "+codec.Name(), func(t *testing.T) {
c := pkg.Compression(codec.Code())
a := strconv.Itoa(int(c))
x := pkg.Compression(-1)
y := pkg.Compression(-1)
b, err := c.MarshalText()
if err != nil {
t.Fatal(err)
}

if err := x.UnmarshalText([]byte(a)); err != nil {
t.Fatal(err)
}
if err := y.UnmarshalText(b); err != nil {
t.Fatal(err)
}

if x != c {
t.Errorf("compression mismatch after marshal/unmarshal: want=%s got=%s", c, x)
}
if y != c {
t.Errorf("compression mismatch after marshal/unmarshal: want=%s got=%s", c, y)
}
})

t.Run("encode with "+codec.Name(), func(t *testing.T) {
r1, err = compress(codec, m.Value)
if err != nil {
Expand Down
30 changes: 30 additions & 0 deletions produce.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ package kafka
import (
"bufio"
"context"
"encoding"
"errors"
"fmt"
"net"
"strconv"
"time"

"github.com/segmentio/kafka-go/protocol"
Expand Down Expand Up @@ -33,6 +35,34 @@ func (acks RequiredAcks) String() string {
}
}

func (acks RequiredAcks) MarshalText() ([]byte, error) {
return []byte(acks.String()), nil
}

func (acks *RequiredAcks) UnmarshalText(b []byte) error {
switch string(b) {
case "none":
*acks = RequireNone
case "one":
*acks = RequireOne
case "all":
*acks = RequireAll
default:
x, err := strconv.ParseInt(string(b), 10, 64)
parsed := RequiredAcks(x)
if err != nil || (parsed != RequireNone && parsed != RequireOne && parsed != RequireAll) {
return fmt.Errorf("required acks must be one of none, one, or all, not %q", b)
}
*acks = parsed
}
return nil
}

var (
_ encoding.TextMarshaler = RequiredAcks(0)
_ encoding.TextUnmarshaler = (*RequiredAcks)(nil)
)

// ProduceRequest represents a request sent to a kafka broker to produce records
// to a topic partition.
type ProduceRequest struct {
Expand Down
33 changes: 33 additions & 0 deletions produce_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,45 @@ package kafka

import (
"context"
"strconv"
"testing"
"time"

"github.com/segmentio/kafka-go/compress"
)

func TestRequiredAcks(t *testing.T) {
for _, acks := range []RequiredAcks{
RequireNone,
RequireOne,
RequireAll,
} {
t.Run(acks.String(), func(t *testing.T) {
a := strconv.Itoa(int(acks))
x := RequiredAcks(-2)
y := RequiredAcks(-2)
b, err := acks.MarshalText()
if err != nil {
t.Fatal(err)
}

if err := x.UnmarshalText([]byte(a)); err != nil {
t.Fatal(err)
}
if err := y.UnmarshalText(b); err != nil {
t.Fatal(err)
}

if x != acks {
t.Errorf("required acks mismatch after marshal/unmarshal text: want=%s got=%s", acks, x)
}
if y != acks {
t.Errorf("required acks mismatch after marshal/unmarshal value: want=%s got=%s", acks, y)
}
})
}
}

func TestClientProduce(t *testing.T) {
client, topic, shutdown := newLocalClientAndTopic()
defer shutdown()
Expand Down

0 comments on commit 2d04c4b

Please sign in to comment.