diff --git a/conn.go b/conn.go
index 1a7d00101..f8d185c0a 100644
--- a/conn.go
+++ b/conn.go
@@ -9,7 +9,6 @@ import (
 	"path/filepath"
 	"runtime"
 	"sync"
-	"sync/atomic"
 	"time"
 )
 
@@ -59,7 +58,7 @@ type Conn struct {
 	fetchMaxBytes int32
 	fetchMinSize  int32
 
-	// correlation ID generator (accessed through atomic operations)
+	// correlation ID generator (synchronized on wlock)
 	correlationID int32
 }
 
@@ -342,7 +341,7 @@ func (c *Conn) ReadBatch(minBytes int, maxBytes int) *Batch {
 		return &Batch{err: err}
 	}
 
-	throttle, remain, err := readFetchResponseHeader(&c.rbuf, int(size))
+	throttle, remain, err := readFetchResponseHeader(&c.rbuf, size)
 	return &Batch{
 		conn:   c,
 		reader: &c.rbuf,
@@ -358,7 +357,7 @@ func (c *Conn) ReadBatch(minBytes int, maxBytes int) *Batch {
 // ReadOffset returns the offset of the first message with a timestamp equal or
 // greater to t.
 func (c *Conn) ReadOffset(t time.Time) (int64, error) {
-	return c.readOffset(timeToTimestamp(t))
+	return c.readOffset(timestamp(t))
 }
 
 // ReadOffsets returns the absolute first and last offsets of the topic used by
@@ -498,30 +497,24 @@ func (c *Conn) WriteMessages(msgs ...Message) (int, error) {
 		return 0, nil
 	}
 
-	var n int
-	var set = make(messageSet, len(msgs))
-
-	for i, msg := range msgs {
+	n := 0
+	for _, msg := range msgs {
 		n += len(msg.Key) + len(msg.Value)
-		set[i] = msg.item()
 	}
 
 	err := c.writeOperation(
 		func(deadline time.Time, id int32) error {
 			now := time.Now()
 			deadline = adjustDeadlineForRTT(deadline, now, defaultRTT)
-			return c.writeRequest(produceRequest, v2, id, produceRequestV2{
-				RequiredAcks: -1,
-				Timeout:      milliseconds(deadlineToTimeout(deadline, now)),
-				Topics: []produceRequestTopicV2{{
-					TopicName: c.topic,
-					Partitions: []produceRequestPartitionV2{{
-						Partition:      c.partition,
-						MessageSetSize: set.size(),
-						MessageSet:     set,
-					}},
-				}},
-			})
+			return writeProduceRequest(
+				&c.wbuf,
+				id,
+				c.clientID,
+				c.topic,
+				c.partition,
+				deadlineToTimeout(deadline, now),
+				msgs...,
+			)
 		},
 		func(deadline time.Time, size int) error {
 			return expectZeroSize(readArrayWith(&c.rbuf, size, func(r *bufio.Reader, size int) (int, error) {
@@ -536,7 +529,7 @@ func (c *Conn) WriteMessages(msgs ...Message) (int, error) {
 				// we've produced a message to a single partition.
 				size, err = readArrayWith(r, size, func(r *bufio.Reader, size int) (int, error) {
 					var p produceResponsePartitionV2
-					size, err := read(r, size, &p)
+					size, err := p.readFrom(r, size)
 					if err == nil && p.ErrorCode != 0 {
 						err = Error(p.ErrorCode)
 					}
@@ -560,6 +553,12 @@ func (c *Conn) WriteMessages(msgs ...Message) (int, error) {
 	return n, err
 }
 
+func (c *Conn) writeRequestHeader(apiKey apiKey, apiVersion apiVersion, correlationID int32, size int32) {
+	hdr := c.requestHeader(apiKey, apiVersion, correlationID)
+	hdr.Size = (hdr.size() + size) - 4
+	hdr.writeTo(&c.wbuf)
+}
+
 func (c *Conn) writeRequest(apiKey apiKey, apiVersion apiVersion, correlationID int32, req request) error {
 	hdr := c.requestHeader(apiKey, apiVersion, correlationID)
 	hdr.Size = (hdr.size() + req.size()) - 4
@@ -594,7 +593,8 @@ func (c *Conn) skipResponseSizeAndID() {
 }
 
 func (c *Conn) generateCorrelationID() int32 {
-	return atomic.AddInt32(&c.correlationID, 1)
+	c.correlationID++
+	return c.correlationID
 }
 
 func (c *Conn) readDeadline() time.Time {
@@ -605,15 +605,15 @@ func (c *Conn) writeDeadline() time.Time {
 	return c.wdeadline.deadline()
 }
 
-func (c *Conn) readOperation(write writeFunc, read readFunc) error {
+func (c *Conn) readOperation(write func(time.Time, int32) error, read func(time.Time, int) error) error {
 	return c.do(&c.rdeadline, write, read)
 }
 
-func (c *Conn) writeOperation(write writeFunc, read readFunc) error {
+func (c *Conn) writeOperation(write func(time.Time, int32) error, read func(time.Time, int) error) error {
 	return c.do(&c.wdeadline, write, read)
 }
 
-func (c *Conn) do(d *connDeadline, write writeFunc, read readFunc) error {
+func (c *Conn) do(d *connDeadline, write func(time.Time, int32) error, read func(time.Time, int) error) error {
 	id, err := c.doRequest(d, write)
 	if err != nil {
 		return err
@@ -624,7 +624,7 @@ func (c *Conn) do(d *connDeadline, write writeFunc, read readFunc) error {
 		return err
 	}
 
-	if err = read(deadline, int(size)); err != nil {
+	if err = read(deadline, size); err != nil {
 		switch err.(type) {
 		case Error:
 		default:
@@ -637,13 +637,11 @@ func (c *Conn) do(d *connDeadline, write writeFunc, read readFunc) error {
 	return err
 }
 
-func (c *Conn) doRequest(d *connDeadline, write writeFunc) (id int32, err error) {
-	id = c.generateCorrelationID()
-
+func (c *Conn) doRequest(d *connDeadline, write func(time.Time, int32) error) (id int32, err error) {
 	c.wlock.Lock()
+	id = c.generateCorrelationID()
 	err = write(d.setConnWriteDeadline(c.conn), id)
 	d.unsetConnWriteDeadline()
-	c.wlock.Unlock()
 
 	if err != nil {
 		// When an error occurs there's no way to know if the connection is in a
@@ -652,10 +650,11 @@ func (c *Conn) doRequest(d *connDeadline, write writeFunc) (id int32, err error)
 		c.conn.Close()
 	}
 
+	c.wlock.Unlock()
 	return
 }
 
-func (c *Conn) waitResponse(d *connDeadline, id int32) (deadline time.Time, size int32, lock *sync.Mutex, err error) {
+func (c *Conn) waitResponse(d *connDeadline, id int32) (deadline time.Time, size int, lock *sync.Mutex, err error) {
 	for {
 		var rsz int32
 		var rid int32
@@ -672,7 +671,7 @@ func (c *Conn) waitResponse(d *connDeadline, id int32) (deadline time.Time, size
 
 		if id == rid {
 			c.skipResponseSizeAndID()
-			size, lock = rsz-4, &c.rlock
+			size, lock = int(rsz-4), &c.rlock
 			return
 		}
 
@@ -752,11 +751,3 @@ func (d *connDeadline) unsetConnWriteDeadline() {
 	d.wconn = nil
 	d.mutex.Unlock()
 }
-
-// writeFunc is the type of functions passed to (*Conn).do to write a request
-// to the kafka connection.
-type writeFunc func(deadline time.Time, id int32) error
-
-// readFunc is the type of functions passed to (*Conn).do to read a response
-// from the kafka connection.
-type readFunc func(deadline time.Time, size int) error
diff --git a/crc32.go b/crc32.go
index a257bd086..f1a617f02 100644
--- a/crc32.go
+++ b/crc32.go
@@ -7,6 +7,20 @@ import (
 	"sync"
 )
 
+func crc32OfMessage(magicByte int8, attributes int8, timestamp int64, key []byte, value []byte) uint32 {
+	b := acquireCrc32Buffer()
+	b.writeInt8(magicByte)
+	b.writeInt8(attributes)
+	if magicByte != 0 {
+		b.writeInt64(timestamp)
+	}
+	b.writeBytes(key)
+	b.writeBytes(value)
+	sum := b.sum
+	releaseCrc32Buffer(b)
+	return sum
+}
+
 type crc32Buffer struct {
 	sum uint32
 	buf bytes.Buffer
diff --git a/discard.go b/discard.go
index ebd486626..bbb56b620 100644
--- a/discard.go
+++ b/discard.go
@@ -1,10 +1,6 @@
 package kafka
 
-import (
-	"bufio"
-	"fmt"
-	"reflect"
-)
+import "bufio"
 
 func discardN(r *bufio.Reader, sz int, n int) (int, error) {
 	n, err := r.Discard(n)
@@ -44,56 +40,3 @@ func discardBytes(r *bufio.Reader, sz int) (int, error) {
 		return discardN(r, sz, n)
 	})
 }
-
-func discard(r *bufio.Reader, sz int, a interface{}) (int, error) {
-	switch a.(type) {
-	case int8:
-		return discardInt8(r, sz)
-	case int16:
-		return discardInt16(r, sz)
-	case int32:
-		return discardInt32(r, sz)
-	case int64:
-		return discardInt64(r, sz)
-	case string:
-		return discardString(r, sz)
-	case []byte:
-		return discardBytes(r, sz)
-	}
-	switch v := reflect.ValueOf(a); v.Kind() {
-	case reflect.Struct:
-		return discardStruct(r, sz, v)
-	case reflect.Slice:
-		return discardSlice(r, sz, v)
-	default:
-		panic(fmt.Sprintf("unsupported type: %T", a))
-	}
-}
-
-func discardStruct(r *bufio.Reader, sz int, v reflect.Value) (int, error) {
-	var err error
-	for i, n := 0, v.NumField(); i != n; i++ {
-		if sz, err = discard(r, sz, v.Field(i)); err != nil {
-			break
-		}
-	}
-	return sz, err
-}
-
-func discardSlice(r *bufio.Reader, sz int, v reflect.Value) (int, error) {
-	var zero = reflect.Zero(v.Type().Elem())
-	var err error
-	var len int32
-
-	if sz, err = readInt32(r, sz, &len); err != nil {
-		return sz, err
-	}
-
-	for n := int(len); n > 0; n-- {
-		if sz, err = discard(r, sz, zero); err != nil {
-			break
-		}
-	}
-
-	return sz, err
-}
diff --git a/message.go b/message.go
index 44308f396..5f17d8567 100644
--- a/message.go
+++ b/message.go
@@ -27,7 +27,7 @@ func (msg Message) message() message {
 		MagicByte: 1,
 		Key:       msg.Key,
 		Value:     msg.Value,
-		Timestamp: timeToTimestamp(msg.Time),
+		Timestamp: timestamp(msg.Time),
 	}
 	m.CRC = m.crc32()
 	return m
@@ -43,17 +43,7 @@ type message struct {
 }
 
 func (m message) crc32() int32 {
-	b := acquireCrc32Buffer()
-	b.writeInt8(m.MagicByte)
-	b.writeInt8(m.Attributes)
-	if m.MagicByte != 0 {
-		b.writeInt64(m.Timestamp)
-	}
-	b.writeBytes(m.Key)
-	b.writeBytes(m.Value)
-	sum := b.sum
-	releaseCrc32Buffer(b)
-	return int32(sum)
+	return int32(crc32OfMessage(m.MagicByte, m.Attributes, m.Timestamp, m.Key, m.Value))
 }
 
 func (m message) size() int32 {
diff --git a/produce.go b/produce.go
index 93840e882..377f6ebe7 100644
--- a/produce.go
+++ b/produce.go
@@ -95,3 +95,17 @@ func (p produceResponsePartitionV2) writeTo(w *bufio.Writer) {
 	writeInt64(w, p.Offset)
 	writeInt64(w, p.Timestamp)
 }
+
+func (p *produceResponsePartitionV2) readFrom(r *bufio.Reader, sz int) (remain int, err error) {
+	if remain, err = readInt32(r, sz, &p.Partition); err != nil {
+		return
+	}
+	if remain, err = readInt16(r, remain, &p.ErrorCode); err != nil {
+		return
+	}
+	if remain, err = readInt64(r, remain, &p.Offset); err != nil {
+		return
+	}
+	remain, err = readInt64(r, remain, &p.Timestamp)
+	return
+}
diff --git a/read.go b/read.go
index a593c3bea..52de6bb92 100644
--- a/read.go
+++ b/read.go
@@ -8,9 +8,9 @@ import (
 	"reflect"
 )
 
-type readBytesFunc func(r *bufio.Reader, sz int, n int) (remain int, err error)
-
-type readArrayFunc func(r *bufio.Reader, sz int) (remain int, err error)
+type readable interface {
+	readFrom(*bufio.Reader, int) (int, error)
+}
 
 var errShortRead = errors.New("not enough bytes available to load the response")
 
@@ -49,7 +49,7 @@ func readString(r *bufio.Reader, sz int, v *string) (int, error) {
 	})
 }
 
-func readStringWith(r *bufio.Reader, sz int, cb readBytesFunc) (int, error) {
+func readStringWith(r *bufio.Reader, sz int, cb func(*bufio.Reader, int, int) (int, error)) (int, error) {
 	var err error
 	var len int16
 
@@ -77,7 +77,7 @@ func readBytes(r *bufio.Reader, sz int, v *[]byte) (int, error) {
 	})
 }
 
-func readBytesWith(r *bufio.Reader, sz int, cb readBytesFunc) (int, error) {
+func readBytesWith(r *bufio.Reader, sz int, cb func(*bufio.Reader, int, int) (int, error)) (int, error) {
 	var err error
 	var len int32
 
@@ -107,7 +107,7 @@ func readNewBytes(r *bufio.Reader, sz int, n int) ([]byte, int, error) {
 	return b, sz, err
 }
 
-func readArrayWith(r *bufio.Reader, sz int, cb readArrayFunc) (int, error) {
+func readArrayWith(r *bufio.Reader, sz int, cb func(*bufio.Reader, int) (int, error)) (int, error) {
 	var err error
 	var len int32
 
@@ -149,6 +149,39 @@ func read(r *bufio.Reader, sz int, a interface{}) (int, error) {
 	}
 }
 
+func readAll(r *bufio.Reader, sz int, ptrs ...interface{}) (int, error) {
+	var err error
+
+	for _, ptr := range ptrs {
+		if sz, err = readPtr(r, sz, ptr); err != nil {
+			break
+		}
+	}
+
+	return sz, err
+}
+
+func readPtr(r *bufio.Reader, sz int, ptr interface{}) (int, error) {
+	switch v := ptr.(type) {
+	case *int8:
+		return readInt8(r, sz, v)
+	case *int16:
+		return readInt16(r, sz, v)
+	case *int32:
+		return readInt32(r, sz, v)
+	case *int64:
+		return readInt64(r, sz, v)
+	case *string:
+		return readString(r, sz, v)
+	case *[]byte:
+		return readBytes(r, sz, v)
+	case readable:
+		return v.readFrom(r, sz)
+	default:
+		panic(fmt.Sprintf("unsupported type: %T", v))
+	}
+}
+
 func readStruct(r *bufio.Reader, sz int, v reflect.Value) (int, error) {
 	var err error
 	for i, n := 0, v.NumField(); i != n; i++ {
@@ -283,7 +316,10 @@ func readMessageHeader(r *bufio.Reader, sz int) (offset int64, attributes int8,
 	return
 }
 
-func readMessage(r *bufio.Reader, sz int, min int64, key readBytesFunc, val readBytesFunc) (offset int64, timestamp int64, remain int, err error) {
+func readMessage(r *bufio.Reader, sz int, min int64,
+	key func(*bufio.Reader, int, int) (int, error),
+	val func(*bufio.Reader, int, int) (int, error),
+) (offset int64, timestamp int64, remain int, err error) {
 	for {
 		// TODO: read attributes and decompress the message
 		if offset, _, timestamp, remain, err = readMessageHeader(r, sz); err != nil {
diff --git a/time.go b/time.go
index 0bbaa32a9..26f33afd0 100644
--- a/time.go
+++ b/time.go
@@ -11,11 +11,7 @@ const (
 	defaultRTT = 1 * time.Second
 )
 
-func timestamp() int64 {
-	return timeToTimestamp(time.Now())
-}
-
-func timeToTimestamp(t time.Time) int64 {
+func timestamp(t time.Time) int64 {
 	if t.IsZero() {
 		return 0
 	}
diff --git a/write.go b/write.go
index 271e3d873..d852e790a 100644
--- a/write.go
+++ b/write.go
@@ -4,6 +4,7 @@ import (
 	"bufio"
 	"encoding/binary"
 	"fmt"
+	"time"
 )
 
 type writable interface {
@@ -57,8 +58,12 @@ func writeBytes(w *bufio.Writer, b []byte) {
 	w.Write(b)
 }
 
-func writeArray(w *bufio.Writer, n int, f func(int)) {
+func writeArrayLen(w *bufio.Writer, n int) {
 	writeInt32(w, int32(n))
+}
+
+func writeArray(w *bufio.Writer, n int, f func(int)) {
+	writeArrayLen(w, n)
 	for i := 0; i != n; i++ {
 		f(i)
 	}
@@ -92,3 +97,75 @@ func write(w *bufio.Writer, a interface{}) {
 		panic(fmt.Sprintf("unsupported type: %T", a))
 	}
 }
+
+// This function is used as an optimization to avoid dynamic memory allocations
+// in the common case of sending a batch of messages to a kafka server for a
+// single topic and partition.
+func writeProduceRequest(w *bufio.Writer, correlationID int32, clientID string, topic string, partition int32, timeout time.Duration, msgs ...Message) error {
+	var size int32
+
+	for _, msg := range msgs {
+		size += 8 + // offset
+			4 + // message size
+			4 + // crc
+			1 + // magic byte
+			1 + // attributes
+			8 + // timestamp
+			sizeofBytes(msg.Key) +
+			sizeofBytes(msg.Value)
+	}
+
+	h := requestHeader{
+		ApiKey:        int16(produceRequest),
+		ApiVersion:    int16(v2),
+		CorrelationID: correlationID,
+		ClientID:      clientID,
+	}
+	h.Size = (h.size() - 4) +
+		2 + // required acks
+		4 + // timeout
+		4 + // topic array length
+		sizeofString(topic) + // topic
+		4 + // partition array length
+		4 + // partition
+		4 + // message set size
+		size
+
+	h.writeTo(w)
+	writeInt16(w, -1) // required acks
+	writeInt32(w, milliseconds(timeout))
+
+	// topic array
+	writeArrayLen(w, 1)
+	writeString(w, topic)
+
+	// partition array
+	writeArrayLen(w, 1)
+	writeInt32(w, partition)
+	writeInt32(w, size)
+
+	const magicByte = 1
+	const attributes = 0
+
+	for _, msg := range msgs {
+		timestamp := timestamp(msg.Time)
+		crc32 := crc32OfMessage(magicByte, attributes, timestamp, msg.Key, msg.Value)
+		size := 4 + // crc
+			1 + // magic byte
+			1 + // attributes
+			8 + // timestamp
+			sizeofBytes(msg.Key) +
+			sizeofBytes(msg.Value)
+
+		writeInt64(w, msg.Offset)
+		writeInt32(w, int32(size))
+		writeInt32(w, int32(crc32))
+		writeInt8(w, magicByte)
+		writeInt8(w, attributes)
+		writeInt64(w, timestamp)
+		writeBytes(w, msg.Key)
+		writeBytes(w, msg.Value)
+	}
+
+	return w.Flush()
+}