Skip to content

Commit

Permalink
Refactor using buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
Pryz committed Jul 3, 2018
1 parent 1b8bebd commit 262746a
Show file tree
Hide file tree
Showing 9 changed files with 137 additions and 51 deletions.
12 changes: 8 additions & 4 deletions compression.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ const (
// the messages.
type CompressionCodec struct {
str func() string
encode func(src []byte, level int) ([]byte, error)
decode func(src []byte) ([]byte, error)
encode func(dst, src []byte) (int, error)
decode func(dst, src []byte) (int, error)
}

const compressionCodecMask int8 = 0x03
Expand All @@ -21,8 +21,12 @@ const defaultCompressionLevel int = -1
func init() {
RegisterCompressionCodec(0,
func() string { return "none" },
func(src []byte, level int) ([]byte, error) { return src, nil },
func(src []byte) ([]byte, error) { return src, nil },
func(dst, src []byte) (int, error) {
return copy(dst, src), nil
},
func(dst, src []byte) (int, error) {
return copy(dst, src), nil
},
)
}

Expand Down
5 changes: 2 additions & 3 deletions compression_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,15 @@ func testEncodeDecode(t *testing.T, m kafka.Message, codec int8) {
t.Error(err)
}
})

t.Run("encode with "+codecToStr(codec), func(t *testing.T) {
r2, err = r1.Decode()
if err != nil {
t.Error(err)
}
if string(r2.Value) != "message" {
t.Error("bad message")
t.Log("got: ", string(m.Value))
t.Log("expected: message")
t.Log("got: ", r2.Value)
t.Log("expected: ", []byte("message"))
}
})
}
Expand Down
39 changes: 32 additions & 7 deletions gzip/gzip.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,49 @@ func String() string {
return "gzip"
}

func Encode(src []byte, level int) ([]byte, error) {
type buffer struct {
data []byte
size int
}

func (buf *buffer) Write(b []byte) (int, error) {
n := copy(buf.data[buf.size:], b)
buf.size += n
if n != len(b) {
return n, bytes.ErrTooLarge
}
return n, nil
}

func Encode(dst, src []byte) (int, error) {
var buf bytes.Buffer
writer := gzip.NewWriter(&buf)
_, err := writer.Write(src)
if err != nil {
return nil, err
return 0, err
}
err = writer.Close()
if err != nil {
return nil, err
return 0, err
}
return buf.Bytes(), nil

n, err := buf.WriteTo(&buffer{
data: dst,
})
return int(n), err
}

func Decode(src []byte) ([]byte, error) {
func Decode(dst, src []byte) (int, error) {
reader, err := gzip.NewReader(bytes.NewReader(src))
if err != nil {
return nil, err
return 0, err
}
data, err := ioutil.ReadAll(reader)
if err != nil {
return 0, err
}
buf := buffer{
data: dst,
}
return ioutil.ReadAll(reader)
return buf.Write(data)
}
10 changes: 6 additions & 4 deletions gzip/gzip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,28 @@ import (
)

func TestGzip(t *testing.T) {
var r1, r2 []byte
var err error
payload := []byte("message")
r1 := make([]byte, 6*len(payload))
r2 := make([]byte, len(payload))

t.Run("encode", func(t *testing.T) {
r1, err = Encode(payload, 1)
n, err := Encode(r1, payload)
if err != nil {
t.Error(err)
}
r1 = r1[:n]
if bytes.Equal(payload, r1) {
t.Error("failed to encode payload")
t.Log("got: ", r1)
}
})

t.Run("decode", func(t *testing.T) {
r2, err = Decode(r1)
n, err := Decode(r2, r1)
if err != nil {
t.Error(err)
}
r2 = r2[:n]
if !bytes.Equal(payload, r2) {
t.Error("failed to decode payload")
t.Log("expected: ", payload)
Expand Down
37 changes: 31 additions & 6 deletions lz4/lz4.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,47 @@ func String() string {
return "lz4"
}

func Encode(src []byte, level int) ([]byte, error) {
type buffer struct {
data []byte
size int
}

func (buf *buffer) Write(b []byte) (int, error) {
n := copy(buf.data[buf.size:], b)
buf.size += n
if n != len(b) {
return n, bytes.ErrTooLarge
}
return n, nil
}

func Encode(dst, src []byte) (int, error) {
var buf bytes.Buffer
writer := lz4.NewWriter(&buf)
_, err := writer.Write(src)
if err != nil {
return nil, err
return 0, err
}
err = writer.Close()
if err != nil {
return nil, err
return 0, err
}

return buf.Bytes(), nil
n, err := buf.WriteTo(&buffer{
data: dst,
})

return int(n), err
}

func Decode(src []byte) ([]byte, error) {
func Decode(dst, src []byte) (int, error) {
reader := lz4.NewReader(bytes.NewReader(src))
return ioutil.ReadAll(reader)
data, err := ioutil.ReadAll(reader)
if err != nil {
return 0, err
}
buf := buffer{
data: dst,
}
return buf.Write(data)
}
10 changes: 6 additions & 4 deletions lz4/lz4_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,28 @@ import (
)

func TestLZ4(t *testing.T) {
var r1, r2 []byte
var err error
payload := []byte("message")
r1 := make([]byte, 6*len(payload))
r2 := make([]byte, len(payload))

t.Run("encode", func(t *testing.T) {
r1, err = Encode(payload, 1)
n, err := Encode(r1, payload)
if err != nil {
t.Error(err)
}
r1 = r1[:n]
if bytes.Equal(payload, r1) {
t.Error("failed to encode payload")
t.Log("got: ", r1)
}
})

t.Run("decode", func(t *testing.T) {
r2, err = Decode(r1)
n, err := Decode(r2, r1)
if err != nil {
t.Error(err)
}
r2 = r2[:n]
if !bytes.Equal(payload, r2) {
t.Error("failed to decode payload")
t.Log("expected: ", payload)
Expand Down
35 changes: 20 additions & 15 deletions message.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package kafka

import (
"bufio"
"bytes"
"fmt"
"time"
)
Expand All @@ -11,7 +12,7 @@ import (
var codecs map[int8]CompressionCodec

// RegisterCompressionCodec registers a compression codec so it can be used by a Writer.
func RegisterCompressionCodec(code int8, str func() string, encode func(src []byte, level int) ([]byte, error), decode func(src []byte) ([]byte, error)) error {
func RegisterCompressionCodec(code int8, str func() string, encode, decode func(dst, src []byte) (int, error)) error {
if codecs == nil {
codecs = make(map[int8]CompressionCodec)
}
Expand Down Expand Up @@ -69,35 +70,39 @@ func (msg Message) message() message {

// Encode encodes the Message using the CompressionCodec and CompressionLevel.
func (msg Message) Encode() (Message, error) {
var err error
codec, ok := codecs[msg.CompressionCodec]
if !ok {
return msg, fmt.Errorf("codec %s not imported.", codecToStr(msg.CompressionCodec))
}

encodedValue, err := codec.encode(msg.Value, msg.CompressionLevel)
if err != nil {
return msg, err
}

msg.Value = encodedValue
return msg, nil
msg.Value, err = transform(msg.Value, codec.encode)
return msg, err
}

// Decode decodes the Message using the CompressionCodec.
func (msg Message) Decode() (Message, error) {
var err error
c := msg.message().Attributes & compressionCodecMask
codec, ok := codecs[c]
if !ok {
return msg, fmt.Errorf("codec %s not imported.", codecToStr(msg.CompressionCodec))
}
msg.Value, err = transform(msg.Value, codec.decode)
return msg, err
}

decodedValue, err := codec.decode(msg.Value)
if err != nil {
return msg, err
func transform(value []byte, fn func(dst, src []byte) (int, error)) ([]byte, error) {
res := make([]byte, len(value))
n, err := fn(res, value)
for ; err != nil; n, err = fn(res, value) {
switch err {
case bytes.ErrTooLarge:
res = make([]byte, 2*len(res))
default:
return value, err
}
}

msg.Value = decodedValue
return msg, nil
return res[:n], nil
}

type message struct {
Expand Down
30 changes: 26 additions & 4 deletions snappy/snappy.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package snappy

import (
"bytes"

"github.com/eapache/go-xerial-snappy"
kafka "github.com/segmentio/kafka-go"
)
Expand All @@ -13,10 +15,30 @@ func String() string {
return "snappy"
}

func Encode(src []byte, level int) ([]byte, error) {
return snappy.Encode(src), nil
type buffer struct {
data []byte
size int
}

func (buf *buffer) Write(b []byte) (int, error) {
n := copy(buf.data[buf.size:], b)
buf.size += n
if n != len(b) {
return n, bytes.ErrTooLarge
}
return n, nil
}

func Encode(dst, src []byte) (int, error) {
buf := buffer{data: dst}
return buf.Write(snappy.Encode(src))
}

func Decode(src []byte) ([]byte, error) {
return snappy.Decode(src)
func Decode(dst, src []byte) (int, error) {
buf := buffer{data: dst}
data, err := snappy.Decode(src)
if err != nil {
return 0, err
}
return buf.Write(data)
}
10 changes: 6 additions & 4 deletions snappy/snappy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,28 @@ import (
)

func TestSnappy(t *testing.T) {
var r1, r2 []byte
var err error
payload := []byte("message")
r1 := make([]byte, 6*len(payload))
r2 := make([]byte, len(payload))

t.Run("encode", func(t *testing.T) {
r1, err = Encode(payload, 1)
n, err := Encode(r1, payload)
if err != nil {
t.Error(err)
}
r1 = r1[:n]
if bytes.Equal(payload, r1) {
t.Error("failed to encode payload")
t.Log("got: ", r1)
}
})

t.Run("decode", func(t *testing.T) {
r2, err = Decode(r1)
n, err := Decode(r2, r1)
if err != nil {
t.Error(err)
}
r2 = r2[:n]
if !bytes.Equal(payload, r2) {
t.Error("failed to decode payload")
t.Log("expected: ", payload)
Expand Down

0 comments on commit 262746a

Please sign in to comment.