diff --git a/compress/compress.go b/compress/compress.go index 52006077e..6e92968f2 100644 --- a/compress/compress.go +++ b/compress/compress.go @@ -1,7 +1,11 @@ package compress import ( + "encoding" + "fmt" "io" + "strconv" + "strings" "github.com/segmentio/kafka-go/compress/gzip" "github.com/segmentio/kafka-go/compress/lz4" @@ -13,6 +17,7 @@ import ( type Compression int8 const ( + None Compression = 0 Gzip Compression = 1 Snappy Compression = 2 Lz4 Compression = 3 @@ -33,6 +38,50 @@ func (c Compression) String() string { return "uncompressed" } +func (c Compression) MarshalText() ([]byte, error) { + return []byte(c.String()), nil +} + +func (c *Compression) UnmarshalText(b []byte) error { + switch string(b) { + case "none", "uncompressed": + *c = None + return nil + } + + for _, codec := range Codecs[None+1:] { + if codec.Name() == string(b) { + *c = Compression(codec.Code()) + return nil + } + } + + i, err := strconv.ParseInt(string(b), 10, 64) + if err == nil && i >= 0 && i < int64(len(Codecs)) { + *c = Compression(i) + return nil + } + + s := &strings.Builder{} + s.WriteString("none, uncompressed") + + for i, codec := range Codecs[None+1:] { + if i < (len(Codecs) - 1) { + s.WriteString(", ") + } else { + s.WriteString(", or ") + } + s.WriteString(codec.Name()) + } + + return fmt.Errorf("compression format must be one of %s, not %q", s, b) +} + +var ( + _ encoding.TextMarshaler = Compression(0) + _ encoding.TextUnmarshaler = (*Compression)(nil) +) + // Codec represents a compression codec to encode and decode the messages. // See : https://cwiki.apache.org/confluence/display/KAFKA/Compression // @@ -66,6 +115,7 @@ var ( // The global table of compression codecs supported by the kafka protocol. Codecs = [...]Codec{ + None: nil, Gzip: &GzipCodec, Snappy: &SnappyCodec, Lz4: &Lz4Codec, diff --git a/compress/compress_test.go b/compress/compress_test.go index 77f70f663..2fb3d76b9 100644 --- a/compress/compress_test.go +++ b/compress/compress_test.go @@ -85,6 +85,31 @@ func testEncodeDecode(t *testing.T, m kafka.Message, codec pkg.Codec) { var r1, r2 []byte var err error + t.Run("text format of "+codec.Name(), func(t *testing.T) { + c := pkg.Compression(codec.Code()) + a := strconv.Itoa(int(c)) + x := pkg.Compression(-1) + y := pkg.Compression(-1) + b, err := c.MarshalText() + if err != nil { + t.Fatal(err) + } + + if err := x.UnmarshalText([]byte(a)); err != nil { + t.Fatal(err) + } + if err := y.UnmarshalText(b); err != nil { + t.Fatal(err) + } + + if x != c { + t.Errorf("compression mismatch after marshal/unmarshal: want=%s got=%s", c, x) + } + if y != c { + t.Errorf("compression mismatch after marshal/unmarshal: want=%s got=%s", c, y) + } + }) + t.Run("encode with "+codec.Name(), func(t *testing.T) { r1, err = compress(codec, m.Value) if err != nil { diff --git a/produce.go b/produce.go index bbf34b7fa..63882b4fc 100644 --- a/produce.go +++ b/produce.go @@ -3,9 +3,11 @@ package kafka import ( "bufio" "context" + "encoding" "errors" "fmt" "net" + "strconv" "time" "github.com/segmentio/kafka-go/protocol" @@ -33,6 +35,34 @@ func (acks RequiredAcks) String() string { } } +func (acks RequiredAcks) MarshalText() ([]byte, error) { + return []byte(acks.String()), nil +} + +func (acks *RequiredAcks) UnmarshalText(b []byte) error { + switch string(b) { + case "none": + *acks = RequireNone + case "one": + *acks = RequireOne + case "all": + *acks = RequireAll + default: + x, err := strconv.ParseInt(string(b), 10, 64) + parsed := RequiredAcks(x) + if err != nil || (parsed != RequireNone && parsed != RequireOne && parsed != RequireAll) { + return fmt.Errorf("required acks must be one of none, one, or all, not %q", b) + } + *acks = parsed + } + return nil +} + +var ( + _ encoding.TextMarshaler = RequiredAcks(0) + _ encoding.TextUnmarshaler = (*RequiredAcks)(nil) +) + // ProduceRequest represents a request sent to a kafka broker to produce records // to a topic partition. type ProduceRequest struct { diff --git a/produce_test.go b/produce_test.go index 68347a1d7..49437781d 100644 --- a/produce_test.go +++ b/produce_test.go @@ -2,12 +2,45 @@ package kafka import ( "context" + "strconv" "testing" "time" "github.com/segmentio/kafka-go/compress" ) +func TestRequiredAcks(t *testing.T) { + for _, acks := range []RequiredAcks{ + RequireNone, + RequireOne, + RequireAll, + } { + t.Run(acks.String(), func(t *testing.T) { + a := strconv.Itoa(int(acks)) + x := RequiredAcks(-2) + y := RequiredAcks(-2) + b, err := acks.MarshalText() + if err != nil { + t.Fatal(err) + } + + if err := x.UnmarshalText([]byte(a)); err != nil { + t.Fatal(err) + } + if err := y.UnmarshalText(b); err != nil { + t.Fatal(err) + } + + if x != acks { + t.Errorf("required acks mismatch after marshal/unmarshal text: want=%s got=%s", acks, x) + } + if y != acks { + t.Errorf("required acks mismatch after marshal/unmarshal value: want=%s got=%s", acks, y) + } + }) + } +} + func TestClientProduce(t *testing.T) { client, topic, shutdown := newLocalClientAndTopic() defer shutdown()