From 59f58f09c9d478be8062bbafc95265c199280b24 Mon Sep 17 00:00:00 2001 From: Achille Date: Sun, 21 Jul 2019 23:13:53 -0700 Subject: [PATCH] Stream compression (#306) * implement stream compression * fix zstd + pass benchmarks * add comments and remove old API * add kafka.CompressionCodec.Name() * avoid a copy when the output buffer for decoding snappy compressed data is large enough * fix xerial framing + validate against go-xerial-snappy * remove unused function * PR feedback * cleanup APIs + improve benchmarks * optimize gzip reader * fix snappy compression codec * support running decompression benchmarks alone --- compression.go | 35 +++-- compression_test.go | 215 +++++++++++++++++---------- go.mod | 1 + go.sum | 2 + gzip/gzip.go | 143 ++++++++++-------- lz4/lz4.go | 103 +++++++------ message.go | 57 +++++--- snappy/snappy.go | 110 ++++++++------ snappy/xerial.go | 327 ++++++++++++++++++++++++++++++++++++++++++ snappy/xerial_test.go | 165 +++++++++++++++++++++ write.go | 126 ++++++++-------- zstd/zstd.go | 198 +++++++++++++++++++++---- 12 files changed, 1129 insertions(+), 353 deletions(-) create mode 100644 snappy/xerial.go create mode 100644 snappy/xerial_test.go diff --git a/compression.go b/compression.go index dd62c1309..db89e7f39 100644 --- a/compression.go +++ b/compression.go @@ -2,19 +2,28 @@ package kafka import ( "errors" + "io" "sync" ) -var errUnknownCodec = errors.New("the compression code is invalid or its codec has not been imported") +const ( + CompressionNoneCode = 0 -var codecs = make(map[int8]CompressionCodec) -var codecsMutex sync.RWMutex + compressionCodecMask = 0x07 +) + +var ( + errUnknownCodec = errors.New("the compression code is invalid or its codec has not been imported") + + codecs = make(map[int8]CompressionCodec) + codecsMutex sync.RWMutex +) // RegisterCompressionCodec registers a compression codec so it can be used by a Writer. -func RegisterCompressionCodec(codec func() CompressionCodec) { - c := codec() +func RegisterCompressionCodec(codec CompressionCodec) { + code := codec.Code() codecsMutex.Lock() - codecs[c.Code()] = c + codecs[code] = codec codecsMutex.Unlock() } @@ -40,12 +49,12 @@ type CompressionCodec interface { // Code returns the compression codec code Code() int8 - // Encode encodes the src data - Encode(src []byte) ([]byte, error) + // Human-readable name for the codec. + Name() string - // Decode decodes the src data - Decode(src []byte) ([]byte, error) -} + // Constructs a new reader which decompresses data from r. + NewReader(r io.Reader) io.ReadCloser -const compressionCodecMask int8 = 0x07 -const CompressionNoneCode = 0 + // Constructs a new writer which writes compressed data to w. 
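// An illustrative sketch (hypothetical helper, not defined in this package) of
// how the streaming interface above is meant to be consumed; any registered
// codec works, and Close is checked on the writer because it flushes buffered
// data:
//
//	func roundTrip(codec CompressionCodec, src []byte) ([]byte, error) {
//		var compressed bytes.Buffer
//		w := codec.NewWriter(&compressed)
//		if _, err := w.Write(src); err != nil {
//			w.Close()
//			return nil, err
//		}
//		if err := w.Close(); err != nil {
//			return nil, err
//		}
//		var out bytes.Buffer
//		r := codec.NewReader(&compressed)
//		if _, err := out.ReadFrom(r); err != nil {
//			r.Close()
//			return nil, err
//		}
//		return out.Bytes(), r.Close()
//	}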
+ NewWriter(w io.Writer) io.WriteCloser +} diff --git a/compression_test.go b/compression_test.go index 4bb017793..a006e70e4 100644 --- a/compression_test.go +++ b/compression_test.go @@ -1,14 +1,20 @@ package kafka_test import ( + "bytes" + compressGzip "compress/gzip" "context" "fmt" - "math/rand" + "io" + "io/ioutil" + "os" + "path/filepath" "strconv" "testing" + "text/tabwriter" "time" - "github.com/segmentio/kafka-go" + kafka "github.com/segmentio/kafka-go" "github.com/segmentio/kafka-go/gzip" "github.com/segmentio/kafka-go/lz4" "github.com/segmentio/kafka-go/snappy" @@ -29,24 +35,46 @@ func TestCompression(t *testing.T) { } } +func compress(codec kafka.CompressionCodec, src []byte) ([]byte, error) { + b := new(bytes.Buffer) + r := bytes.NewReader(src) + w := codec.NewWriter(b) + if _, err := io.Copy(w, r); err != nil { + w.Close() + return nil, err + } + if err := w.Close(); err != nil { + return nil, err + } + return b.Bytes(), nil +} + +func decompress(codec kafka.CompressionCodec, src []byte) ([]byte, error) { + b := new(bytes.Buffer) + r := codec.NewReader(bytes.NewReader(src)) + if _, err := io.Copy(b, r); err != nil { + r.Close() + return nil, err + } + if err := r.Close(); err != nil { + return nil, err + } + return b.Bytes(), nil +} + func testEncodeDecode(t *testing.T, m kafka.Message, codec kafka.CompressionCodec) { var r1, r2 []byte var err error - var code int8 - - if codec != nil { - code = codec.Code() - } - t.Run("encode with "+codecToStr(code), func(t *testing.T) { - r1, err = codec.Encode(m.Value) + t.Run("encode with "+codec.Name(), func(t *testing.T) { + r1, err = compress(codec, m.Value) if err != nil { t.Error(err) } }) - t.Run("decode with "+codecToStr(code), func(t *testing.T) { - r2, err = codec.Decode(r1) + t.Run("decode with "+codec.Name(), func(t *testing.T) { + r2, err = decompress(codec, r1) if err != nil { t.Error(err) } @@ -58,23 +86,6 @@ func testEncodeDecode(t *testing.T, m kafka.Message, codec kafka.CompressionCode }) } -func codecToStr(codec int8) string { - switch codec { - case kafka.CompressionNoneCode: - return "none" - case gzip.Code: - return "gzip" - case snappy.Code: - return "snappy" - case lz4.Code: - return "lz4" - case zstd.Code: - return "zstd" - default: - return "unknown" - } -} - func TestCompressedMessages(t *testing.T) { testCompressedMessages(t, gzip.NewCompressionCodec()) testCompressedMessages(t, snappy.NewCompressionCodec()) @@ -86,7 +97,7 @@ func TestCompressedMessages(t *testing.T) { } func testCompressedMessages(t *testing.T, codec kafka.CompressionCodec) { - t.Run("produce/consume with"+codecToStr(codec.Code()), func(t *testing.T) { + t.Run("produce/consume with"+codec.Name(), func(t *testing.T) { t.Parallel() topic := kafka.CreateTopic(t, 1) @@ -232,98 +243,148 @@ func (noopCodec) Code() int8 { return 0 } -func (noopCodec) Encode(src []byte) ([]byte, error) { - return src, nil +func (noopCodec) Name() string { + return "none" } -func (noopCodec) Decode(src []byte) ([]byte, error) { - return src, nil +func (noopCodec) NewReader(r io.Reader) io.ReadCloser { + return ioutil.NopCloser(r) } +func (noopCodec) NewWriter(w io.Writer) io.WriteCloser { + return nopWriteCloser{w} +} + +type nopWriteCloser struct{ io.Writer } + +func (nopWriteCloser) Close() error { return nil } + func BenchmarkCompression(b *testing.B) { benchmarks := []struct { - scenario string codec kafka.CompressionCodec - function func(*testing.B, kafka.CompressionCodec, int, map[int][]byte) + function func(*testing.B, kafka.CompressionCodec, *bytes.Buffer, 
[]byte) float64 }{ { - scenario: "None", codec: &noopCodec{}, function: benchmarkCompression, }, { - scenario: "GZIP", codec: gzip.NewCompressionCodec(), function: benchmarkCompression, }, { - scenario: "Snappy", codec: snappy.NewCompressionCodec(), function: benchmarkCompression, }, { - scenario: "LZ4", codec: lz4.NewCompressionCodec(), function: benchmarkCompression, }, { - scenario: "zstd", codec: zstd.NewCompressionCodec(), function: benchmarkCompression, }, } - payload := map[int][]byte{ - 1024: randomPayload(1024), - 4096: randomPayload(4096), - 8192: randomPayload(8192), - 16384: randomPayload(16384), + f, err := os.Open(filepath.Join(os.Getenv("GOROOT"), "src/encoding/json/testdata/code.json.gz")) + if err != nil { + b.Fatal(err) } + defer f.Close() - for _, benchmark := range benchmarks { - b.Run(benchmark.scenario+"1024", func(b *testing.B) { - benchmark.function(b, benchmark.codec, 1024, payload) - }) - b.Run(benchmark.scenario+"4096", func(b *testing.B) { - benchmark.function(b, benchmark.codec, 4096, payload) - }) - b.Run(benchmark.scenario+"8192", func(b *testing.B) { - benchmark.function(b, benchmark.codec, 8192, payload) - }) - b.Run(benchmark.scenario+"16384", func(b *testing.B) { - benchmark.function(b, benchmark.codec, 16384, payload) - }) + z, err := compressGzip.NewReader(f) + if err != nil { + b.Fatal(err) } -} + payload, err := ioutil.ReadAll(z) + if err != nil { + b.Fatal(err) + } -func benchmarkCompression(b *testing.B, codec kafka.CompressionCodec, payloadSize int, payload map[int][]byte) { - msg := kafka.Message{ - Value: payload[payloadSize], + buffer := bytes.Buffer{} + buffer.Grow(len(payload)) + + ts := &bytes.Buffer{} + tw := tabwriter.NewWriter(ts, 0, 8, 0, '\t', 0) + defer func() { + tw.Flush() + fmt.Printf("input => %.2f MB\n", float64(len(payload))/(1024*1024)) + fmt.Println(ts) + }() + + b.ResetTimer() + + for i := range benchmarks { + benchmark := &benchmarks[i] + ratio := 0.0 + + b.Run(fmt.Sprintf("%s", benchmark.codec.Name()), func(b *testing.B) { + ratio = benchmark.function(b, benchmark.codec, &buffer, payload) + }) + + fmt.Fprintf(tw, " %s:\t%.2f%%\n", benchmark.codec.Name(), 100*ratio) } +} - for i := 0; i < b.N; i++ { - m1, err := codec.Encode(msg.Value) - if err != nil { - b.Fatal(err) +func benchmarkCompression(b *testing.B, codec kafka.CompressionCodec, buf *bytes.Buffer, payload []byte) float64 { + // In case only the decompression benchmark are run, we use this flags to + // detect whether we have to compress the payload before the decompression + // benchmarks. 
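// For instance (illustrative), only the decompression side of the gzip
// benchmark can be selected by its sub-benchmark name:
//
//	go test -bench 'BenchmarkCompression/gzip/decompress'
//
// in which case the "compress" sub-benchmark never runs and the payload has
// to be compressed once before the decompression loop below.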
+ compressed := false + + b.Run("compress", func(b *testing.B) { + compressed = true + r := bytes.NewReader(payload) + + for i := 0; i < b.N; i++ { + buf.Reset() + r.Reset(payload) + w := codec.NewWriter(buf) + + _, err := io.Copy(w, r) + if err != nil { + b.Fatal(err) + } + if err := w.Close(); err != nil { + b.Fatal(err) + } } - b.SetBytes(int64(len(m1))) + b.SetBytes(int64(buf.Len())) + }) + + if !compressed { + r := bytes.NewReader(payload) + w := codec.NewWriter(buf) - _, err = codec.Decode(m1) + _, err := io.Copy(w, r) if err != nil { b.Fatal(err) } - + if err := w.Close(); err != nil { + b.Fatal(err) + } } -} -const dataset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890" + b.Run("decompress", func(b *testing.B) { + c := bytes.NewReader(buf.Bytes()) -func randomPayload(n int) []byte { - b := make([]byte, n) - for i := range b { - b[i] = dataset[rand.Intn(len(dataset))] - } - return b + for i := 0; i < b.N; i++ { + c.Reset(buf.Bytes()) + r := codec.NewReader(c) + + n, err := io.Copy(ioutil.Discard, r) + if err != nil { + b.Fatal(err) + } + if err := r.Close(); err != nil { + b.Fatal(err) + } + + b.SetBytes(n) + } + }) + + return 1 - (float64(buf.Len()) / float64(len(payload))) } diff --git a/go.mod b/go.mod index e08dded80..a5557c338 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.11 require ( github.com/DataDog/zstd v1.4.0 + github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21 github.com/golang/snappy v0.0.1 github.com/pierrec/lz4 v2.0.5+incompatible github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c diff --git a/go.sum b/go.sum index 651d75197..14f3034f2 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/DataDog/zstd v1.4.0 h1:vhoV+DUHnRZdKW1i5UMjAk2G4JY8wN4ayRfYDNdEhwo= github.com/DataDog/zstd v1.4.0/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo= +github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21 h1:YEetp8/yCZMuEPMUDHG0CW/brkkEp8mzqk2+ODEitlw= +github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/pierrec/lz4 v2.0.5+incompatible h1:2xWsjqPFWcplujydGg4WmhC/6fZqK42wMM8aXeqhl0I= diff --git a/gzip/gzip.go b/gzip/gzip.go index ddc414a4c..5743d963e 100644 --- a/gzip/gzip.go +++ b/gzip/gzip.go @@ -3,108 +3,129 @@ package gzip import ( "bytes" "compress/gzip" + "io" "io/ioutil" "sync" - "github.com/segmentio/kafka-go" + kafka "github.com/segmentio/kafka-go" ) var ( - readerPool sync.Pool - writerPool sync.Pool - // emptyGzipBytes is the binary value for an empty file that has been // gzipped. It is used to initialize gzip.Reader before adding it to the // readerPool. - emptyGzipBytes = []byte{ + emptyGzipBytes = [...]byte{ 0x1f, 0x8b, 0x08, 0x08, 0x0d, 0x0c, 0x67, 0x5c, 0x00, 0x03, 0x66, 0x6f, 0x6f, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, } -) -func init() { readerPool = sync.Pool{ New: func() interface{} { // if the reader doesn't get valid gzip at initialization time, // it will not be valid and will fail on Reset. 
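// Illustrative of why the pool is primed this way: (*gzip.Reader).Reset, like
// gzip.NewReader, reads the gzip header immediately and errors out if the
// stream is not valid gzip, so pooled readers are pointed at the gzipped
// empty file above until they are handed real input:
//
//	var zr gzipReader
//	zr.Reset(nil)    // falls back to emptyGzipBytes, leaving zr usable
//	zr.Reset(stream) // "stream" is a hypothetical io.Reader of real gzip data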
- reader, _ := gzip.NewReader(bytes.NewBuffer(emptyGzipBytes)) + reader := &gzipReader{} + reader.Reset(nil) return reader }, } - writerPool = sync.Pool{ - New: func() interface{} { - return gzip.NewWriter(bytes.NewBuffer(nil)) - }, - } +) + +type gzipReader struct { + gzip.Reader + emptyGzipFile bytes.Reader +} - kafka.RegisterCompressionCodec(func() kafka.CompressionCodec { - return NewCompressionCodec() - }) +func (z *gzipReader) Reset(r io.Reader) { + if r == nil { + z.emptyGzipFile.Reset(emptyGzipBytes[:]) + r = &z.emptyGzipFile + } + z.Reader.Reset(r) } -type CompressionCodec struct { - // CompressionLevel is the level of compression to use on messages. - CompressionLevel int +func init() { + kafka.RegisterCompressionCodec(NewCompressionCodec()) } const ( - Code int8 = 1 - DefaultCompressionLevel int = -1 + Code = 1 + + DefaultCompressionLevel = gzip.DefaultCompression ) -func NewCompressionCodec() CompressionCodec { - return NewCompressionCodecWith(DefaultCompressionLevel) +type CompressionCodec struct{ writerPool sync.Pool } + +func NewCompressionCodec() *CompressionCodec { + return NewCompressionCodecLevel(DefaultCompressionLevel) } -func NewCompressionCodecWith(level int) CompressionCodec { - return CompressionCodec{ - CompressionLevel: level, +func NewCompressionCodecLevel(level int) *CompressionCodec { + return &CompressionCodec{ + writerPool: sync.Pool{ + New: func() interface{} { + w, err := gzip.NewWriterLevel(ioutil.Discard, level) + if err != nil { + return err + } + return w + }, + }, } } // Code implements the kafka.CompressionCodec interface. -func (c CompressionCodec) Code() int8 { - return Code +func (c *CompressionCodec) Code() int8 { return Code } + +// Name implements the kafka.CompressionCodec interface. +func (c *CompressionCodec) Name() string { return "gzip" } + +// NewReader implements the kafka.CompressionCodec interface. +func (c *CompressionCodec) NewReader(r io.Reader) io.ReadCloser { + z := readerPool.Get().(*gzipReader) + z.Reset(r) + return &reader{z} } -// Encode implements the kafka.CompressionCodec interface. -func (c CompressionCodec) Encode(src []byte) ([]byte, error) { - buf := bytes.Buffer{} - buf.Grow(len(src)) // guess a size to avoid repeat allocations. - writer := writerPool.Get().(*gzip.Writer) - writer.Reset(&buf) - - _, err := writer.Write(src) - if err != nil { - // don't return writer to pool on error. - return nil, err +// NewWriter implements the kafka.CompressionCodec interface. +func (c *CompressionCodec) NewWriter(w io.Writer) io.WriteCloser { + x := c.writerPool.Get() + z, _ := x.(*gzip.Writer) + if z == nil { + return errorWriter{err: x.(error)} } + z.Reset(w) + return &writer{c, z} +} - // note that the gzip reader must be closed in order for it to write - // out trailing contents. Flush is insufficient. it is okay to re-use - // the writer even after it's closed by Resetting it. - err = writer.Close() - if err != nil { - // don't return writer to pool on error. - return nil, err - } +type reader struct{ *gzipReader } - writerPool.Put(writer) +func (r *reader) Close() (err error) { + if z := r.gzipReader; z != nil { + r.gzipReader = nil + err = z.Close() + z.Reset(nil) + readerPool.Put(z) + } + return +} - return buf.Bytes(), err +type writer struct { + c *CompressionCodec + *gzip.Writer } -// Decode implements the kafka.CompressionCodec interface. 
-func (c CompressionCodec) Decode(src []byte) ([]byte, error) { - reader := readerPool.Get().(*gzip.Reader) - err := reader.Reset(bytes.NewReader(src)) - if err != nil { - return nil, err - } - res, err := ioutil.ReadAll(reader) - // only return the reader to pool if the read was a success. - if err == nil { - readerPool.Put(reader) +func (w *writer) Close() (err error) { + if z := w.Writer; z != nil { + w.Writer = nil + err = z.Close() + z.Reset(nil) + w.c.writerPool.Put(z) } - return res, err + return } + +type errorWriter struct{ err error } + +func (w errorWriter) Close() error { return w.err } + +func (w errorWriter) Write(b []byte) (int, error) { return 0, w.err } diff --git a/lz4/lz4.go b/lz4/lz4.go index e7ab05fa4..140c7f9cb 100644 --- a/lz4/lz4.go +++ b/lz4/lz4.go @@ -1,81 +1,74 @@ package lz4 import ( - "bytes" - "io/ioutil" + "io" "sync" "github.com/pierrec/lz4" - "github.com/segmentio/kafka-go" + kafka "github.com/segmentio/kafka-go" ) -var ( - readerPool sync.Pool - writerPool sync.Pool +func init() { + kafka.RegisterCompressionCodec(NewCompressionCodec()) +} + +const ( + Code = 3 ) -func init() { - readerPool = sync.Pool{ - New: func() interface{} { - return lz4.NewReader(nil) - }, - } - writerPool = sync.Pool{ - New: func() interface{} { - return lz4.NewWriter(nil) - }, - } +type CompressionCodec struct{} - kafka.RegisterCompressionCodec(func() kafka.CompressionCodec { - return NewCompressionCodec() - }) +func NewCompressionCodec() *CompressionCodec { + return &CompressionCodec{} } -type CompressionCodec struct{} +// Code implements the kafka.CompressionCodec interface. +func (CompressionCodec) Code() int8 { return Code } -const Code = 3 +// Name implements the kafka.CompressionCodec interface. +func (CompressionCodec) Name() string { return "lz4" } -func NewCompressionCodec() CompressionCodec { - return CompressionCodec{} +// NewReader implements the kafka.CompressionCodec interface. +func (CompressionCodec) NewReader(r io.Reader) io.ReadCloser { + z := readerPool.Get().(*lz4.Reader) + z.Reset(r) + return &reader{z} } -// Code implements the kafka.CompressionCodec interface. -func (c CompressionCodec) Code() int8 { - return Code +// NewWriter implements the kafka.CompressionCodec interface. +func (CompressionCodec) NewWriter(w io.Writer) io.WriteCloser { + z := writerPool.Get().(*lz4.Writer) + z.Reset(w) + return &writer{z} } -// Encode implements the kafka.CompressionCodec interface. -func (c CompressionCodec) Encode(src []byte) ([]byte, error) { - buf := bytes.Buffer{} - buf.Grow(len(src)) // guess a size to avoid repeat allocations. - writer := writerPool.Get().(*lz4.Writer) - writer.Reset(&buf) - - _, err := writer.Write(src) - if err != nil { - // don't return writer to pool on error. - return nil, err - } +type reader struct{ *lz4.Reader } - err = writer.Close() - if err != nil { - // don't return writer to pool on error. - return nil, err +func (r *reader) Close() (err error) { + if z := r.Reader; z != nil { + r.Reader = nil + z.Reset(nil) + readerPool.Put(z) } + return +} - writerPool.Put(writer) +type writer struct{ *lz4.Writer } - return buf.Bytes(), err +func (w *writer) Close() (err error) { + if z := w.Writer; z != nil { + w.Writer = nil + err = z.Close() + z.Reset(nil) + writerPool.Put(z) + } + return } -// Decode implements the kafka.CompressionCodec interface. 
-func (c CompressionCodec) Decode(src []byte) ([]byte, error) { - reader := readerPool.Get().(*lz4.Reader) - reader.Reset(bytes.NewReader(src)) - res, err := ioutil.ReadAll(reader) - // only return the reader to pool if the read was a success. - if err == nil { - readerPool.Put(reader) - } - return res, err +var readerPool = sync.Pool{ + New: func() interface{} { return lz4.NewReader(nil) }, +} + +var writerPool = sync.Pool{ + New: func() interface{} { return lz4.NewWriter(nil) }, } diff --git a/message.go b/message.go index 0d083475b..c7b0b1fe2 100644 --- a/message.go +++ b/message.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "fmt" + "io" "time" ) @@ -237,13 +238,19 @@ func (r *messageSetReaderV1) readMessage(min int64, } // read and decompress the contained message set. - var decompressed []byte + var decompressed bytes.Buffer + if r.remain, err = readBytesWith(r.reader, r.remain, func(r *bufio.Reader, sz, n int) (remain int, err error) { - var value []byte - if value, remain, err = readNewBytes(r, sz, n); err != nil { - return - } - decompressed, err = codec.Decode(value) + // x4 as a guess that the average compression ratio is near 75% + decompressed.Grow(4 * n) + + l := io.LimitedReader{R: r, N: int64(n)} + d := codec.NewReader(&l) + + _, err = decompressed.ReadFrom(d) + remain = sz - (n - int(l.N)) + + d.Close() return }); err != nil { return @@ -256,13 +263,16 @@ func (r *messageSetReaderV1) readMessage(min int64, // messages at offsets 10-13, then the container message will have // offset 13 and the contained messages will be 0,1,2,3. the base // offset for the container, then is 13-3=10. - if offset, err = extractOffset(offset, decompressed); err != nil { + if offset, err = extractOffset(offset, decompressed.Bytes()); err != nil { return } r.readerStack = &readerStack{ - reader: bufio.NewReader(bytes.NewReader(decompressed)), - remain: len(decompressed), + // Allocate a buffer of size 0, which gets capped at 16 bytes + // by the bufio package. We are already reading buffered data + // here, no need to reserve another 4KB buffer. 
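// Illustrative: bufio silently raises very small sizes to its 16-byte
// minimum, so asking for 0 avoids the default 4KB buffer without any
// special casing:
//
//	br := bufio.NewReaderSize(bytes.NewReader(nil), 0)
//	_ = br.Size() // 16 with the current standard library implementation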
+ reader: bufio.NewReaderSize(&decompressed, 0), + remain: decompressed.Len(), base: offset, parent: r.readerStack, } @@ -458,33 +468,40 @@ func (r *messageSetReaderV2) readMessage(min int64, r.readerStack = r.parent } } + if err = r.readHeader(); err != nil { return } - code := r.header.compression() - var decompressed []byte - if code != 0 { + + if code := r.header.compression(); code != 0 { var codec CompressionCodec if codec, err = resolveCodec(code); err != nil { return } - batchRemain := int(r.header.length - 49) + + var batchRemain = int(r.header.length - 49) if batchRemain > r.remain { err = errShortRead return } - var compressed []byte - compressed, r.remain, err = readNewBytes(r.reader, r.remain, batchRemain) + + var decompressed bytes.Buffer + decompressed.Grow(4 * batchRemain) + + l := io.LimitedReader{R: r.reader, N: int64(batchRemain)} + d := codec.NewReader(&l) + + _, err = decompressed.ReadFrom(d) + r.remain = r.remain - (batchRemain - int(l.N)) + d.Close() + if err != nil { return } - if decompressed, err = codec.Decode(compressed); err != nil { - return - } r.readerStack = &readerStack{ - reader: bufio.NewReader(bytes.NewReader(decompressed)), - remain: len(decompressed), + reader: bufio.NewReaderSize(&decompressed, 0), + remain: decompressed.Len(), base: -1, // base is unused here parent: r.readerStack, } diff --git a/snappy/snappy.go b/snappy/snappy.go index a6ca858cc..ff77cb228 100644 --- a/snappy/snappy.go +++ b/snappy/snappy.go @@ -1,68 +1,92 @@ package snappy import ( - "bytes" - "encoding/binary" + "io" + "sync" "github.com/golang/snappy" - "github.com/segmentio/kafka-go" + kafka "github.com/segmentio/kafka-go" ) func init() { - kafka.RegisterCompressionCodec(func() kafka.CompressionCodec { - return NewCompressionCodec() - }) + kafka.RegisterCompressionCodec(NewCompressionCodec()) } -type CompressionCodec struct{} +// Framing is an enumeration type used to enable or disable xerial framing of +// snappy messages. +type Framing int -const Code = 2 +const ( + Framed Framing = iota + Unframed +) + +const ( + Code = 2 +) + +type CompressionCodec struct{ framing Framing } -func NewCompressionCodec() CompressionCodec { - return CompressionCodec{} +func NewCompressionCodec() *CompressionCodec { + return NewCompressionCodecFraming(Framed) } -// Code implements the kafka.CompressionCodec interface. -func (c CompressionCodec) Code() int8 { - return Code +func NewCompressionCodecFraming(framing Framing) *CompressionCodec { + return &CompressionCodec{framing} } -// Encode implements the kafka.CompressionCodec interface. -func (c CompressionCodec) Encode(src []byte) ([]byte, error) { - // NOTE : passing a nil dst means snappy will allocate it. - return snappy.Encode(nil, src), nil +// Code implements the kafka.CompressionCodec interface. +func (c *CompressionCodec) Code() int8 { return Code } + +// Name implements the kafka.CompressionCodec interface. +func (c *CompressionCodec) Name() string { return "snappy" } + +// NewReader implements the kafka.CompressionCodec interface. +func (c *CompressionCodec) NewReader(r io.Reader) io.ReadCloser { + x := readerPool.Get().(*xerialReader) + x.Reset(r) + return &reader{x} } -// Decode implements the kafka.CompressionCodec interface. -func (c CompressionCodec) Decode(src []byte) ([]byte, error) { - return decode(src) +// NewWriter implements the kafka.CompressionCodec interface. 
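// Illustrative, from an importing package: the default codec produces
// xerial-framed snappy, and a raw (unframed) snappy stream can be requested
// explicitly:
//
//	framed := snappy.NewCompressionCodec()
//	unframed := snappy.NewCompressionCodecFraming(snappy.Unframed)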
+func (c *CompressionCodec) NewWriter(w io.Writer) io.WriteCloser { + x := writerPool.Get().(*xerialWriter) + x.Reset(w) + x.framed = c.framing == Framed + return &writer{x} } -var xerialHeader = []byte{130, 83, 78, 65, 80, 80, 89, 0} +type reader struct{ *xerialReader } -// From github.com/eapache/go-xerial-snappy -func decode(src []byte) ([]byte, error) { - if !bytes.Equal(src[:8], xerialHeader) { - return snappy.Decode(nil, src) +func (r *reader) Close() (err error) { + if x := r.xerialReader; x != nil { + r.xerialReader = nil + x.Reset(nil) + readerPool.Put(x) } + return +} - var ( - pos = uint32(16) - max = uint32(len(src)) - dst = make([]byte, 0, len(src)) - chunk []byte - err error - ) - for pos < max { - size := binary.BigEndian.Uint32(src[pos : pos+4]) - pos += 4 - - chunk, err = snappy.Decode(chunk, src[pos:pos+size]) - if err != nil { - return nil, err - } - pos += size - dst = append(dst, chunk...) +type writer struct{ *xerialWriter } + +func (w *writer) Close() (err error) { + if x := w.xerialWriter; x != nil { + w.xerialWriter = nil + err = x.Flush() + x.Reset(nil) + writerPool.Put(x) } - return dst, nil + return +} + +var readerPool = sync.Pool{ + New: func() interface{} { + return &xerialReader{decode: snappy.Decode} + }, +} + +var writerPool = sync.Pool{ + New: func() interface{} { + return &xerialWriter{encode: snappy.Encode} + }, } diff --git a/snappy/xerial.go b/snappy/xerial.go new file mode 100644 index 000000000..23d523ccf --- /dev/null +++ b/snappy/xerial.go @@ -0,0 +1,327 @@ +package snappy + +import ( + "bytes" + "encoding/binary" + "io" + + "github.com/golang/snappy" +) + +const defaultBufferSize = 32 * 1024 + +// An implementation of io.Reader which consumes a stream of xerial-framed +// snappy-encoeded data. The framing is optional, if no framing is detected +// the reader will simply forward the bytes from its underlying stream. 
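// For reference, the framing handled below looks like this on the wire:
//
//	8-byte magic (130 'S' 'N' 'A' 'P' 'P' 'Y' 0)
//	8-byte version/compatibility info
//	repeated frames of: 4-byte big-endian length, then one snappy block
//
// When the 16-byte header is absent, the input is treated as a single raw
// snappy block (or passed through unchanged when no decode function is set).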
+type xerialReader struct { + reader io.Reader + header [16]byte + input []byte + output []byte + offset int64 + nbytes int64 + decode func([]byte, []byte) ([]byte, error) +} + +func (x *xerialReader) Reset(r io.Reader) { + x.reader = r + x.input = x.input[:0] + x.output = x.output[:0] + x.offset = 0 + x.nbytes = 0 +} + +func (x *xerialReader) Read(b []byte) (int, error) { + for { + if x.offset < int64(len(x.output)) { + n := copy(b, x.output[x.offset:]) + x.offset += int64(n) + return n, nil + } + + n, err := x.readChunk(b) + if err != nil { + return 0, err + } + if n > 0 { + return n, nil + } + } +} + +func (x *xerialReader) WriteTo(w io.Writer) (int64, error) { + wn := int64(0) + + for { + for x.offset < int64(len(x.output)) { + n, err := w.Write(x.output[x.offset:]) + wn += int64(n) + x.offset += int64(n) + if err != nil { + return wn, err + } + } + + if _, err := x.readChunk(nil); err != nil { + if err == io.EOF { + err = nil + } + return wn, err + } + } +} + +func (x *xerialReader) readChunk(dst []byte) (int, error) { + x.offset = 0 + prefix := 0 + + if x.nbytes == 0 { + n, err := x.readFull(x.header[:]) + if err != nil && n == 0 { + return 0, err + } + prefix = n + } + + if isXerialHeader(x.header[:]) { + if cap(x.input) < 4 { + x.input = make([]byte, 4, defaultBufferSize) + } else { + x.input = x.input[:4] + } + + _, err := x.readFull(x.input) + if err != nil { + return 0, err + } + + frame := int(binary.BigEndian.Uint32(x.input)) + if cap(x.input) < frame { + x.input = make([]byte, frame, align(frame, defaultBufferSize)) + } else { + x.input = x.input[:frame] + } + + if _, err := x.readFull(x.input); err != nil { + return 0, err + } + } else { + if cap(x.input) == 0 { + x.input = make([]byte, 0, defaultBufferSize) + } else { + x.input = x.input[:0] + } + + if prefix > 0 { + x.input = append(x.input, x.header[:prefix]...) + } + + for { + if len(x.input) == cap(x.input) { + b := make([]byte, len(x.input), 2*cap(x.input)) + copy(b, x.input) + x.input = b + } + + n, err := x.read(x.input[len(x.input):cap(x.input)]) + x.input = x.input[:len(x.input)+n] + if err != nil { + if err == io.EOF && len(x.input) > 0 { + break + } + return 0, err + } + } + } + + var n int + var err error + + if x.decode == nil { + x.output, x.input, err = x.input, x.output, nil + } else if n, err = snappy.DecodedLen(x.input); n <= len(dst) && err == nil { + // If the output buffer is large enough to hold the decode value, + // write it there directly instead of using the intermediary output + // buffer. + _, err = x.decode(dst, x.input) + } else { + var b []byte + n = 0 + b, err = x.decode(x.output[:cap(x.output)], x.input) + if err == nil { + x.output = b + } + } + + return n, err +} + +func (x *xerialReader) read(b []byte) (int, error) { + n, err := x.reader.Read(b) + x.nbytes += int64(n) + return n, err +} + +func (x *xerialReader) readFull(b []byte) (int, error) { + n, err := io.ReadFull(x.reader, b) + x.nbytes += int64(n) + return n, err +} + +// An implementation of a xerial-framed snappy-encoded output stream. +// Each Write made to the writer is framed with a xerial header. 
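// More precisely, when framing is enabled: the 16-byte header goes out ahead
// of the first frame, each Flush emits one frame (a 4-byte big-endian length
// followed by one snappy block), and Write only buffers data, triggering a
// Flush once the internal buffer is nearly full.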
+type xerialWriter struct { + writer io.Writer + header [16]byte + input []byte + output []byte + nbytes int64 + framed bool + encode func([]byte, []byte) []byte +} + +func (x *xerialWriter) Reset(w io.Writer) { + x.writer = w + x.input = x.input[:0] + x.output = x.output[:0] + x.nbytes = 0 +} + +func (x *xerialWriter) ReadFrom(r io.Reader) (int64, error) { + wn := int64(0) + + if cap(x.input) == 0 { + x.input = make([]byte, 0, defaultBufferSize) + } + + for { + if x.full() { + x.grow() + } + + n, err := r.Read(x.input[len(x.input):cap(x.input)]) + wn += int64(n) + x.input = x.input[:len(x.input)+n] + + if x.fullEnough() { + if err := x.Flush(); err != nil { + return wn, err + } + } + + if err != nil { + if err == io.EOF { + err = nil + } + return wn, err + } + } +} + +func (x *xerialWriter) Write(b []byte) (int, error) { + wn := 0 + + if cap(x.input) == 0 { + x.input = make([]byte, 0, defaultBufferSize) + } + + for len(b) > 0 { + if x.full() { + x.grow() + } + + n := copy(x.input[len(x.input):cap(x.input)], b) + b = b[n:] + wn += n + x.input = x.input[:len(x.input)+n] + + if x.fullEnough() { + if err := x.Flush(); err != nil { + return wn, err + } + } + } + + return wn, nil +} + +func (x *xerialWriter) Flush() error { + if len(x.input) == 0 { + return nil + } + + var b []byte + if x.encode == nil { + b = x.input + } else { + x.output = x.encode(x.output[:cap(x.output)], x.input) + b = x.output + } + + x.input = x.input[:0] + x.output = x.output[:0] + + if x.framed && x.nbytes == 0 { + writeXerialHeader(x.header[:]) + _, err := x.write(x.header[:]) + if err != nil { + return err + } + } + + if x.framed { + writeXerialFrame(x.header[:4], len(b)) + _, err := x.write(x.header[:4]) + if err != nil { + return err + } + } + + _, err := x.write(b) + return err +} + +func (x *xerialWriter) write(b []byte) (int, error) { + n, err := x.writer.Write(b) + x.nbytes += int64(n) + return n, err +} + +func (x *xerialWriter) full() bool { + return len(x.input) == cap(x.input) +} + +func (x *xerialWriter) fullEnough() bool { + return x.framed && (cap(x.input)-len(x.input)) < 1024 +} + +func (x *xerialWriter) grow() { + tmp := make([]byte, len(x.input), 2*cap(x.input)) + copy(tmp, x.input) + x.input = tmp +} + +func align(n, a int) int { + if (n % a) == 0 { + return n + } + return ((n / a) + 1) * a +} + +var ( + xerialHeader = [...]byte{130, 83, 78, 65, 80, 80, 89, 0} + xerialVersionInfo = [...]byte{0, 0, 0, 1, 0, 0, 0, 1} +) + +func isXerialHeader(src []byte) bool { + return len(src) >= 16 && bytes.Equal(src[:8], xerialHeader[:]) +} + +func writeXerialHeader(b []byte) { + copy(b[:8], xerialHeader[:]) + copy(b[8:], xerialVersionInfo[:]) +} + +func writeXerialFrame(b []byte, n int) { + binary.BigEndian.PutUint32(b, uint32(n)) +} diff --git a/snappy/xerial_test.go b/snappy/xerial_test.go new file mode 100644 index 000000000..4545bfe2e --- /dev/null +++ b/snappy/xerial_test.go @@ -0,0 +1,165 @@ +package snappy + +import ( + "bytes" + "crypto/rand" + "io" + "testing" + + goxerialsnappy "github.com/eapache/go-xerial-snappy" + "github.com/golang/snappy" +) + +// Wrap an io.Reader or io.Writer to disable all copy optimizations like +// io.WriterTo or io.ReaderFrom. +// We use this to ensure writes are chunked by io.Copy's internal buffer +// in the tests. 
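// Illustrative: io.Copy type-asserts its arguments against io.WriterTo and
// io.ReaderFrom and takes those fast paths when available; wrapping a value
// in a struct that only embeds the basic interface hides the extra methods,
// so io.Copy falls back to its internal chunked buffer:
//
//	var w io.Writer = simpleWriter{&xerialWriter{writer: new(bytes.Buffer)}}
//	_, ok := w.(io.ReaderFrom) // ok == false, the fast path is disabled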
+type simpleReader struct{ io.Reader } +type simpleWriter struct{ io.Writer } + +func TestXerialReaderSnappy(t *testing.T) { + rawData := new(bytes.Buffer) + rawData.Grow(1024 * 1024) + io.CopyN(rawData, rand.Reader, 1024*1024) + + compressedRawData := bytes.NewReader(snappy.Encode(nil, rawData.Bytes())) + + decompressedData := new(bytes.Buffer) + io.Copy(decompressedData, + &xerialReader{reader: compressedRawData, decode: snappy.Decode}) + + b0 := rawData.Bytes() + b1 := decompressedData.Bytes() + + if !bytes.Equal(b0, b1) { + t.Error("data mismatch") + } +} + +func TestXerialReaderWriter(t *testing.T) { + rawData := new(bytes.Buffer) + rawData.Grow(1024 * 1024) + io.CopyN(rawData, rand.Reader, 1024*1024) + + framedData := new(bytes.Buffer) + framedData.Grow(rawData.Len() + 1024) + w := simpleWriter{&xerialWriter{writer: framedData}} + r := simpleReader{bytes.NewReader(rawData.Bytes())} + io.Copy(w, r) + w.Writer.(*xerialWriter).Flush() + + unframedData := new(bytes.Buffer) + unframedData.Grow(rawData.Len()) + io.Copy(unframedData, &xerialReader{reader: framedData}) + + b0 := rawData.Bytes() + b1 := unframedData.Bytes() + + if !bytes.Equal(b0, b1) { + t.Error("data mismatch") + } +} + +func TestXerialFramedCompression(t *testing.T) { + rawData := new(bytes.Buffer) + rawData.Grow(1024 * 1024) + io.CopyN(rawData, rand.Reader, 1024*1024) + + framedAndCompressedData := new(bytes.Buffer) + framedAndCompressedData.Grow(rawData.Len()) + w := simpleWriter{&xerialWriter{writer: framedAndCompressedData, framed: true, encode: snappy.Encode}} + r := simpleReader{bytes.NewReader(rawData.Bytes())} + io.Copy(w, r) + w.Writer.(*xerialWriter).Flush() + + unframedAndDecompressedData := new(bytes.Buffer) + unframedAndDecompressedData.Grow(rawData.Len()) + io.Copy(unframedAndDecompressedData, + simpleReader{&xerialReader{reader: framedAndCompressedData, decode: snappy.Decode}}) + + b0 := rawData.Bytes() + b1 := unframedAndDecompressedData.Bytes() + + if !bytes.Equal(b0, b1) { + t.Error("data mismatch") + } +} + +func TestXerialFramedCompressionOptimized(t *testing.T) { + rawData := new(bytes.Buffer) + rawData.Grow(1024 * 1024) + io.CopyN(rawData, rand.Reader, 1024*1024) + + framedAndCompressedData := new(bytes.Buffer) + framedAndCompressedData.Grow(rawData.Len()) + w := &xerialWriter{writer: framedAndCompressedData, framed: true, encode: snappy.Encode} + r := simpleReader{bytes.NewReader(rawData.Bytes())} + io.Copy(w, r) + w.Flush() + + unframedAndDecompressedData := new(bytes.Buffer) + unframedAndDecompressedData.Grow(rawData.Len()) + io.Copy(unframedAndDecompressedData, + &xerialReader{reader: framedAndCompressedData, decode: snappy.Decode}) + + b0 := rawData.Bytes() + b1 := unframedAndDecompressedData.Bytes() + + if !bytes.Equal(b0, b1) { + t.Error("data mismatch") + } +} + +func TestXerialReaderAgainstGoXerialSnappy(t *testing.T) { + rawData := new(bytes.Buffer) + rawData.Grow(1024 * 1024) + io.CopyN(rawData, rand.Reader, 1024*1024) + rawBytes := rawData.Bytes() + + framedAndCompressedData := []byte{} + const chunkSize = 999 + for i := 0; i < len(rawBytes); i += chunkSize { + j := i + chunkSize + if j > len(rawBytes) { + j = len(rawBytes) + } + framedAndCompressedData = goxerialsnappy.EncodeStream(framedAndCompressedData, rawBytes[i:j]) + } + + unframedAndDecompressedData := new(bytes.Buffer) + unframedAndDecompressedData.Grow(rawData.Len()) + io.Copy(unframedAndDecompressedData, + &xerialReader{reader: bytes.NewReader(framedAndCompressedData), decode: snappy.Decode}) + + b0 := rawBytes + b1 := 
unframedAndDecompressedData.Bytes() + + if !bytes.Equal(b0, b1) { + t.Error("data mismatch") + } +} + +func TestXerialWriterAgainstGoXerialSnappy(t *testing.T) { + rawData := new(bytes.Buffer) + rawData.Grow(1024 * 1024) + io.CopyN(rawData, rand.Reader, 1024*1024) + + framedAndCompressedData := new(bytes.Buffer) + framedAndCompressedData.Grow(rawData.Len()) + w := &xerialWriter{writer: framedAndCompressedData, framed: true, encode: snappy.Encode} + r := simpleReader{bytes.NewReader(rawData.Bytes())} + io.Copy(w, r) + w.Flush() + + unframedAndDecompressedData, err := goxerialsnappy.Decode(framedAndCompressedData.Bytes()) + if err != nil { + t.Error(err) + } + + b0 := rawData.Bytes() + b1 := unframedAndDecompressedData + + if !bytes.Equal(b0, b1) { + t.Error("data mismatch") + } +} diff --git a/write.go b/write.go index 2027857c4..4eb847500 100644 --- a/write.go +++ b/write.go @@ -307,7 +307,7 @@ func writeProduceRequestV2(w *bufio.Writer, codec CompressionCodec, correlationI attributes := int8(CompressionNoneCode) if codec != nil { - if msgs, err = compress(codec, msgs...); err != nil { + if msgs, err = compressMessageSet(codec, msgs...); err != nil { return err } attributes = codec.Code() @@ -354,24 +354,13 @@ func writeProduceRequestV3(w *bufio.Writer, codec CompressionCodec, correlationI var compressed []byte var attributes int16 - if codec != nil { - attributes = int16(codec.Code()) - recordBuf := &bytes.Buffer{} - recordBuf.Grow(int(recordBatchSize(msgs...))) - compressedWriter := bufio.NewWriter(recordBuf) - for i, msg := range msgs { - writeRecord(compressedWriter, 0, msgs[0].Time, int64(i), msg) - } - compressedWriter.Flush() - - compressed, err = codec.Encode(recordBuf.Bytes()) + if codec == nil { + size = recordBatchSize(msgs...) + } else { + compressed, attributes, size, err = compressRecordBatch(codec, msgs...) if err != nil { return } - attributes = int16(codec.Code()) - size = recordBatchHeaderSize() + int32(len(compressed)) - } else { - size = recordBatchSize(msgs...) } h := requestHeader{ @@ -380,6 +369,7 @@ func writeProduceRequestV3(w *bufio.Writer, codec CompressionCodec, correlationI CorrelationID: correlationID, ClientID: clientID, } + h.Size = (h.size() - 4) + sizeofNullableString(transactionalID) + 2 + // required acks @@ -424,28 +414,17 @@ func writeProduceRequestV3(w *bufio.Writer, codec CompressionCodec, correlationI } func writeProduceRequestV7(w *bufio.Writer, codec CompressionCodec, correlationID int32, clientID, topic string, partition int32, timeout time.Duration, requiredAcks int16, transactionalID *string, msgs ...Message) (err error) { - var size int32 var compressed []byte var attributes int16 - if codec != nil { - attributes = int16(codec.Code()) - recordBuf := &bytes.Buffer{} - recordBuf.Grow(int(recordBatchSize(msgs...))) - compressedWriter := bufio.NewWriter(recordBuf) - for i, msg := range msgs { - writeRecord(compressedWriter, 0, msgs[0].Time, int64(i), msg) - } - compressedWriter.Flush() - compressed, err = codec.Encode(recordBuf.Bytes()) + if codec == nil { + size = recordBatchSize(msgs...) + } else { + compressed, attributes, size, err = compressRecordBatch(codec, msgs...) if err != nil { return } - attributes = int16(codec.Code()) - size = recordBatchHeaderSize() + int32(len(compressed)) - } else { - size = recordBatchSize(msgs...) 
} h := requestHeader{ @@ -511,33 +490,30 @@ func messageSetSize(msgs ...Message) (size int32) { return } -func recordBatchHeaderSize() int32 { - return 8 + // base offset - 4 + // batch length - 4 + // partition leader epoch - 1 + // magic - 4 + // crc - 2 + // attributes - 4 + // last offset delta - 8 + // first timestamp - 8 + // max timestamp - 8 + // producer id - 2 + // producer epoch - 4 + // base sequence - 4 // msg count -} +const recordBatchHeaderSize int32 = 0 + + 8 + // base offset + 4 + // batch length + 4 + // partition leader epoch + 1 + // magic + 4 + // crc + 2 + // attributes + 4 + // last offset delta + 8 + // first timestamp + 8 + // max timestamp + 8 + // producer id + 2 + // producer epoch + 4 + // base sequence + 4 // msg count func recordBatchSize(msgs ...Message) (size int32) { - size = recordBatchHeaderSize() - + size = recordBatchHeaderSize baseTime := msgs[0].Time for i, msg := range msgs { - sz := recordSize(&msg, msg.Time.Sub(baseTime), int64(i)) - size += int32(sz + varIntLen(int64(sz))) } + return } @@ -602,25 +578,57 @@ func recordSize(msg *Message, timestampDelta time.Duration, offsetDelta int64) ( return } -func compress(codec CompressionCodec, msgs ...Message) ([]Message, error) { +func compressMessageSet(codec CompressionCodec, msgs ...Message) ([]Message, error) { estimatedLen := 0 + for _, msg := range msgs { estimatedLen += int(msgSize(msg.Key, msg.Value)) } - buf := &bytes.Buffer{} - buf.Grow(estimatedLen) - bufWriter := bufio.NewWriter(buf) + + buffer := &bytes.Buffer{} + buffer.Grow(estimatedLen / 2) + compressor := codec.NewWriter(buffer) + compressedWriter := bufio.NewWriterSize(compressor, 512) + for offset, msg := range msgs { - writeMessage(bufWriter, int64(offset), CompressionNoneCode, msg.Time, msg.Key, msg.Value) + writeMessage(compressedWriter, int64(offset), CompressionNoneCode, msg.Time, msg.Key, msg.Value) } - bufWriter.Flush() - compressed, err := codec.Encode(buf.Bytes()) - if err != nil { + if err := compressedWriter.Flush(); err != nil { + compressor.Close() + return nil, err + } + + if err := compressor.Close(); err != nil { return nil, err } - return []Message{{Value: compressed}}, nil + return []Message{{Value: buffer.Bytes()}}, nil +} + +func compressRecordBatch(codec CompressionCodec, msgs ...Message) (compressed []byte, attributes int16, size int32, err error) { + recordBuf := new(bytes.Buffer) + recordBuf.Grow(int(recordBatchSize(msgs...)) / 2) + compressor := codec.NewWriter(recordBuf) + compressedWriter := bufio.NewWriterSize(compressor, 512) + + for i, msg := range msgs { + writeRecord(compressedWriter, 0, msgs[0].Time, int64(i), msg) + } + + if err = compressedWriter.Flush(); err != nil { + compressor.Close() + return + } + + if err = compressor.Close(); err != nil { + return + } + + compressed = recordBuf.Bytes() + attributes = int16(codec.Code()) + size = recordBatchHeaderSize + int32(len(compressed)) + return } const magicByte = 1 // compatible with kafka 0.10.0.0+ diff --git a/zstd/zstd.go b/zstd/zstd.go index 58fa0145a..3a14d41e5 100644 --- a/zstd/zstd.go +++ b/zstd/zstd.go @@ -3,48 +3,196 @@ package zstd import ( + "io" + "sync" + "github.com/DataDog/zstd" - "github.com/segmentio/kafka-go" + kafka "github.com/segmentio/kafka-go" ) func init() { - kafka.RegisterCompressionCodec(func() kafka.CompressionCodec { - return NewCompressionCodec() - }) -} - -type CompressionCodec struct { - // CompressionLevel is the level of compression to use on messages. 
- CompressionLevel int + kafka.RegisterCompressionCodec(NewCompressionCodec()) } const ( - Code int8 = 4 - // https://github.com/DataDog/zstd/blob/1e382f59b41eebd6f592c5db4fd1958ec38a0eba/zstd.go#L33 - DefaultCompressionLevel int = 5 + Code = 4 + + DefaultCompressionLevel = zstd.DefaultCompression ) -func NewCompressionCodec() CompressionCodec { +type CompressionCodec struct{ level int } + +func NewCompressionCodec() *CompressionCodec { return NewCompressionCodecWith(DefaultCompressionLevel) } -func NewCompressionCodecWith(level int) CompressionCodec { - return CompressionCodec{ - CompressionLevel: level, - } +func NewCompressionCodecWith(level int) *CompressionCodec { + return &CompressionCodec{level} } // Code implements the kafka.CompressionCodec interface. -func (c CompressionCodec) Code() int8 { - return Code +func (c *CompressionCodec) Code() int8 { return Code } + +// Name implements the kafka.CompressionCodec interface. +func (c *CompressionCodec) Name() string { return "zstd" } + +// NewReader implements the kafka.CompressionCodec interface. +func (c *CompressionCodec) NewReader(r io.Reader) io.ReadCloser { + return &reader{ + reader: r, + buffer: bufferPool.Get().(*buffer), + } +} + +// NewWriter implements the kafka.CompressionCodec interface. +func (c *CompressionCodec) NewWriter(w io.Writer) io.WriteCloser { + return &writer{ + writer: w, + buffer: bufferPool.Get().(*buffer), + level: c.level, + } +} + +// ============================================================================= +// The DataDog/zstd package exposes io.Writer and io.Reader implementations that +// encode and decode streams, however there are no APIs to reuse the values like +// other compression format have (through a Reset method usually). +// +// I first tried using these abstractions but the amount of state that gets +// recreated and destroyed was so large that it was slower than using the +// zstd.Compress and zstd.Decompress functions directly. Knowing that, I changed +// the implementation to be more of a buffer management on top of these instead. 
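// Illustrative of the one-shot calls used below: both take a destination
// slice and return the (possibly reallocated) result, which is what makes
// pooling the buffers worthwhile ("payload" and the two buffers here are
// hypothetical stand-ins for the pooled input/output slices):
//
//	encBuf := make([]byte, 0, 32*1024)
//	decBuf := make([]byte, 0, 32*1024)
//	enc, err := zstd.CompressLevel(encBuf[:cap(encBuf)], payload, zstd.DefaultCompression)
//	dec, err := zstd.Decompress(decBuf[:cap(decBuf)], enc)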
+// ============================================================================= + +type reader struct { + reader io.Reader + buffer *buffer + offset int +} + +func (r *reader) Read(b []byte) (int, error) { + if err := r.decompress(); err != nil { + return 0, err + } + + if r.offset >= len(r.buffer.output) { + return 0, io.EOF + } + + n := copy(b, r.buffer.output[r.offset:]) + r.offset += n + return n, nil +} + +func (r *reader) WriteTo(w io.Writer) (int64, error) { + if err := r.decompress(); err != nil { + return 0, err + } + + if r.offset >= len(r.buffer.output) { + return 0, nil + } + + n, err := w.Write(r.buffer.output[r.offset:]) + r.offset += n + return int64(n), err +} + +func (r *reader) Close() (err error) { + if b := r.buffer; b != nil { + r.buffer = nil + b.reset() + bufferPool.Put(b) + } + return +} + +func (r *reader) decompress() (err error) { + if r.reader == nil { + return + } + + b := r.buffer + + if _, err = b.readFrom(r.reader); err != nil { + return + } + + r.reader = nil + b.output, err = zstd.Decompress(b.output[:cap(b.output)], b.input) + return +} + +type writer struct { + writer io.Writer + buffer *buffer + level int +} + +func (w *writer) Write(b []byte) (int, error) { + return w.buffer.write(b) +} + +func (w *writer) ReadFrom(r io.Reader) (int64, error) { + return w.buffer.readFrom(r) +} + +func (w *writer) Close() (err error) { + if b := w.buffer; b != nil { + w.buffer = nil + + b.output, err = zstd.CompressLevel(b.output[:cap(b.output)], b.input, w.level) + if err == nil { + _, err = w.writer.Write(b.output) + } + + b.reset() + bufferPool.Put(b) + } + return +} + +type buffer struct { + input []byte + output []byte +} + +func (b *buffer) reset() { + b.input = b.input[:0] + b.output = b.output[:0] +} + +func (b *buffer) readFrom(r io.Reader) (int64, error) { + prefix := len(b.input) + + for { + if len(b.input) == cap(b.input) { + tmp := make([]byte, len(b.input), 2*cap(b.input)) + copy(tmp, b.input) + b.input = tmp + } + + n, err := r.Read(b.input[len(b.input):cap(b.input)]) + b.input = b.input[:len(b.input)+n] + if err != nil { + if err == io.EOF { + err = nil + } + return int64(len(b.input) - prefix), err + } + } } -// Encode implements the kafka.CompressionCodec interface. -func (c CompressionCodec) Encode(src []byte) ([]byte, error) { - return zstd.CompressLevel(nil, src, c.CompressionLevel) +func (b *buffer) write(data []byte) (int, error) { + b.input = append(b.input, data...) + return len(data), nil } -// Decode implements the kafka.CompressionCodec interface. -func (c CompressionCodec) Decode(src []byte) ([]byte, error) { - return zstd.Decompress(nil, src) +var bufferPool = sync.Pool{ + New: func() interface{} { + return &buffer{ + input: make([]byte, 0, 32*1024), + output: make([]byte, 0, 32*1024), + } + }, }
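// An illustrative end-to-end sketch: importing a codec package registers it
// through its init function, and a kafka-go Writer can then be configured
// with it; the WriterConfig field name below follows the project README of
// this era and should be checked against the version in use:
//
//	import (
//		kafka "github.com/segmentio/kafka-go"
//		"github.com/segmentio/kafka-go/snappy"
//	)
//
//	w := kafka.NewWriter(kafka.WriterConfig{
//		Brokers:          []string{"localhost:9092"},
//		Topic:            "example",
//		CompressionCodec: snappy.NewCompressionCodec(),
//	})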