Skip to content

Commit

Permalink
Replace Zstandard wrapper with native Go implementation (segmentio#396)
Browse files Browse the repository at this point in the history
* Replace Zstandard wrapper with native Go implementation.

* Use literal for compression levels. The EncoderLevel uses its own scale.

* Upgrade compress (Zstandard) library.

* Apply Achille's code review.
  • Loading branch information
pascaldekloe authored and Achille committed Jan 26, 2020
1 parent 9a956db commit 16d85b1
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 141 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ module github.com/segmentio/kafka-go
go 1.11

require (
github.com/DataDog/zstd v1.4.0
github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21
github.com/golang/snappy v0.0.1
github.com/klauspost/compress v1.9.8
github.com/pierrec/lz4 v2.0.5+incompatible
github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c
github.com/xdg/stringprep v1.0.0 // indirect
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
github.com/DataDog/zstd v1.4.0 h1:vhoV+DUHnRZdKW1i5UMjAk2G4JY8wN4ayRfYDNdEhwo=
github.com/DataDog/zstd v1.4.0/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo=
github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21 h1:YEetp8/yCZMuEPMUDHG0CW/brkkEp8mzqk2+ODEitlw=
github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU=
github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4=
github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/klauspost/compress v1.9.8 h1:VMAMUUOh+gaxKTMk+zqbjsSjsIcUcL/LF4o63i82QyA=
github.com/klauspost/compress v1.9.8/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A=
github.com/pierrec/lz4 v2.0.5+incompatible h1:2xWsjqPFWcplujydGg4WmhC/6fZqK42wMM8aXeqhl0I=
github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY=
github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c h1:u40Z8hqBAAQyv+vATcGgV0YCnDjqSL7/q/JyPhhJSPk=
Expand Down
205 changes: 67 additions & 138 deletions zstd/zstd.go
Original file line number Diff line number Diff line change
@@ -1,33 +1,30 @@
// +build cgo

// Package zstd implements Zstandard compression.
package zstd

import (
"io"
"sync"

"github.com/DataDog/zstd"
zstdlib "github.com/klauspost/compress/zstd"
kafka "github.com/segmentio/kafka-go"
)

// init registers the Zstandard codec returned by NewCompressionCodec with the
// kafka package when this package is imported.
func init() {
kafka.RegisterCompressionCodec(NewCompressionCodec())
}

const (
Code = 4
const Code = 4

DefaultCompressionLevel = zstd.DefaultCompression
)
const DefaultCompressionLevel = 3

type CompressionCodec struct{ level int }
type CompressionCodec struct{ level zstdlib.EncoderLevel }

// NewCompressionCodec returns a codec configured with DefaultCompressionLevel.
func NewCompressionCodec() *CompressionCodec {
return NewCompressionCodecWith(DefaultCompressionLevel)
}

func NewCompressionCodecWith(level int) *CompressionCodec {
return &CompressionCodec{level}
return &CompressionCodec{zstdlib.EncoderLevelFromZstd(level)}
}

// Code implements the kafka.CompressionCodec interface.
Expand All @@ -38,161 +35,93 @@ func (c *CompressionCodec) Name() string { return "zstd" }

// NewReader implements the kafka.CompressionCodec interface.
func (c *CompressionCodec) NewReader(r io.Reader) io.ReadCloser {
return &reader{
reader: r,
buffer: bufferPool.Get().(*buffer),
p := new(reader)
if cached := decPool.Get(); cached == nil {
p.dec, p.err = zstdlib.NewReader(r)
} else {
p.dec = cached.(*zstdlib.Decoder)
p.err = p.dec.Reset(r)
}
return p
}

// NewWriter implements the kafka.CompressionCodec interface.
func (c *CompressionCodec) NewWriter(w io.Writer) io.WriteCloser {
return &writer{
writer: w,
buffer: bufferPool.Get().(*buffer),
level: c.level,
}
}

// =============================================================================
// The DataDog/zstd package exposes io.Writer and io.Reader implementations that
// encode and decode streams; however, there are no APIs to reuse the values like
// other compression formats have (usually through a Reset method).
//
// I first tried using these abstractions, but the amount of state that gets
// recreated and destroyed was so large that it was slower than using the
// zstd.Compress and zstd.Decompress functions directly. Knowing that, I changed
// the implementation to be a buffer-management layer on top of these instead.
// =============================================================================
// decPool holds *zstdlib.Decoder values handed back by reader.Close so that
// NewReader can Reset and reuse them instead of constructing a new decoder.
var decPool sync.Pool

type reader struct {
reader io.Reader
buffer *buffer
offset int
dec *zstdlib.Decoder
err error
}

func (r *reader) Read(b []byte) (int, error) {
if err := r.decompress(); err != nil {
return 0, err
}

if r.offset >= len(r.buffer.output) {
return 0, io.EOF
// Close implements the io.Closer interface.
func (r *reader) Close() error {
if r.dec != nil {
decPool.Put(r.dec)
r.dec = nil
r.err = io.ErrClosedPipe
}

n := copy(b, r.buffer.output[r.offset:])
r.offset += n
return n, nil
return nil
}

func (r *reader) WriteTo(w io.Writer) (int64, error) {
if err := r.decompress(); err != nil {
return 0, err
}

if r.offset >= len(r.buffer.output) {
return 0, nil
// Read implements the io.Reader interface.
func (r *reader) Read(p []byte) (n int, err error) {
if r.err != nil {
return 0, r.err
}

n, err := w.Write(r.buffer.output[r.offset:])
r.offset += n
return int64(n), err
return r.dec.Read(p)
}

func (r *reader) Close() (err error) {
if b := r.buffer; b != nil {
r.buffer = nil
b.reset()
bufferPool.Put(b)
// WriteTo implements the io.WriterTo interface.
func (r *reader) WriteTo(w io.Writer) (n int64, err error) {
if r.err != nil {
return 0, r.err
}
return
return r.dec.WriteTo(w)
}

func (r *reader) decompress() (err error) {
if r.reader == nil {
return
}

b := r.buffer

if _, err = b.readFrom(r.reader); err != nil {
return
// NewWriter implements the kafka.CompressionCodec interface.
func (c *CompressionCodec) NewWriter(w io.Writer) io.WriteCloser {
p := new(writer)
if cached := encPool.Get(); cached == nil {
p.enc, p.err = zstdlib.NewWriter(w,
zstdlib.WithEncoderLevel(c.level))
} else {
p.enc = cached.(*zstdlib.Encoder)
p.enc.Reset(w)
}

r.reader = nil
b.output, err = zstd.Decompress(b.output[:cap(b.output)], b.input)
return
}

type writer struct {
writer io.Writer
buffer *buffer
level int
return p
}

func (w *writer) Write(b []byte) (int, error) {
return w.buffer.write(b)
}
// encPool holds *zstdlib.Encoder values handed back by writer.Close so that
// NewWriter can Reset and reuse them instead of constructing a new encoder.
//
// NOTE(review): the pool is package-level, so it is shared by codecs of every
// compression level, and NewWriter applies c.level only when it constructs a
// new encoder. A reused encoder presumably keeps the level it was created
// with — confirm against the zstdlib Encoder.Reset documentation.
var encPool sync.Pool

func (w *writer) ReadFrom(r io.Reader) (int64, error) {
return w.buffer.readFrom(r)
type writer struct {
enc *zstdlib.Encoder
err error
}

func (w *writer) Close() (err error) {
if b := w.buffer; b != nil {
w.buffer = nil

b.output, err = zstd.CompressLevel(b.output[:cap(b.output)], b.input, w.level)
if err == nil {
_, err = w.writer.Write(b.output)
}

b.reset()
bufferPool.Put(b)
// Close implements the io.Closer interface.
func (w *writer) Close() error {
if w.enc == nil {
return nil // already closed
}
return
}

type buffer struct {
input []byte
output []byte
}

func (b *buffer) reset() {
b.input = b.input[:0]
b.output = b.output[:0]
err := w.enc.Close()
encPool.Put(w.enc)
w.enc = nil
w.err = io.ErrClosedPipe
return err
}

func (b *buffer) readFrom(r io.Reader) (int64, error) {
prefix := len(b.input)

for {
if len(b.input) == cap(b.input) {
tmp := make([]byte, len(b.input), 2*cap(b.input))
copy(tmp, b.input)
b.input = tmp
}

n, err := r.Read(b.input[len(b.input):cap(b.input)])
b.input = b.input[:len(b.input)+n]
if err != nil {
if err == io.EOF {
err = nil
}
return int64(len(b.input) - prefix), err
}
// Write implements the io.Writer interface.
//
// It returns the stored error without writing when the writer failed to
// initialize or has already been closed; otherwise it forwards p to the
// underlying encoder.
func (w *writer) Write(p []byte) (n int, err error) {
if w.err != nil {
return 0, w.err
}
return w.enc.Write(p)
}

func (b *buffer) write(data []byte) (int, error) {
b.input = append(b.input, data...)
return len(data), nil
}

var bufferPool = sync.Pool{
New: func() interface{} {
return &buffer{
input: make([]byte, 0, 32*1024),
output: make([]byte, 0, 32*1024),
}
},
// ReadFrom implements the io.ReaderFrom interface.
//
// It returns the stored error without reading when the writer failed to
// initialize or has already been closed; otherwise it delegates to the
// underlying encoder's ReadFrom.
func (w *writer) ReadFrom(r io.Reader) (n int64, err error) {
if w.err != nil {
return 0, w.err
}
return w.enc.ReadFrom(r)
}

0 comments on commit 16d85b1

Please sign in to comment.