Skip to content

Commit

Permalink
track the number of byts consumed in a response
Browse files Browse the repository at this point in the history
  • Loading branch information
Achille Roussel committed May 31, 2017
1 parent 68dc7e8 commit 1ae0b59
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 72 deletions.
16 changes: 8 additions & 8 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ func (c *Conn) ReadOffsets() (first int64, last int64, err error) {
}},
})
},
func(size int32) error {
func(size int) error {
var res []listOffsetResponse
if err := c.readResponse(size, &res); err != nil {
return err
Expand Down Expand Up @@ -381,7 +381,7 @@ func (c *Conn) ReadPartitions(topics ...string) (partitions []Partition, err err
func(id int32) error {
return c.writeRequest(metadataRequest, v0, id, topicMetadataRequest(topics))
},
func(size int32) error {
func(size int) error {
var res metadataResponse
if err := c.readResponse(size, &res); err != nil {
return err
Expand Down Expand Up @@ -461,7 +461,7 @@ func (c *Conn) Write(b []byte) (int, error) {
}},
})
},
func(size int32) error {
func(size int) error {
var res produceResponse
if err := c.readResponse(size, &res); err != nil {
return err
Expand Down Expand Up @@ -494,7 +494,7 @@ func (c *Conn) writeRequest(apiKey apiKey, apiVersion apiVersion, correlationID
return writeRequest(&c.wbuf, c.requestHeader(apiKey, apiVersion, correlationID), req)
}

func (c *Conn) readResponse(size int32, res interface{}) error {
func (c *Conn) readResponse(size int, res interface{}) error {
return readResponse(&c.rbuf, size, res)
}

Expand Down Expand Up @@ -529,15 +529,15 @@ func (c *Conn) writeDeadline() time.Time {
return t
}

func (c *Conn) readOperation(write func(int32) error, read func(int32) error) error {
func (c *Conn) readOperation(write func(int32) error, read func(int) error) error {
return c.do(c.readDeadline(), write, read)
}

func (c *Conn) writeOperation(write func(int32) error, read func(int32) error) error {
func (c *Conn) writeOperation(write func(int32) error, read func(int) error) error {
return c.do(c.writeDeadline(), write, read)
}

func (c *Conn) do(deadline time.Time, write func(int32) error, read func(int32) error) error {
func (c *Conn) do(deadline time.Time, write func(int32) error, read func(int) error) error {
id := c.generateCorrelationID()

c.wlock.Lock()
Expand Down Expand Up @@ -566,7 +566,7 @@ func (c *Conn) do(deadline time.Time, write func(int32) error, read func(int32)

if id == opid {
c.skipResponseSizeAndID()
err := read(opsize - 4)
err := read(int(opsize) - 4)
switch err.(type) {
case nil, Error:
default:
Expand Down
2 changes: 1 addition & 1 deletion conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func TestConn(t *testing.T) {
topic := fmt.Sprintf("kafka-go-%02d", atomic.AddInt32(&id, 1))

conn, err := (&Dialer{
Resolver: &net.Resolver{PreferGo: true},
Resolver: &net.Resolver{},
}).DialLeader(ctx, "tcp", "localhost:9092", topic, 0)
if err != nil {
t.Fatal("failed to open a new kafka connection:", err)
Expand Down
150 changes: 88 additions & 62 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package kafka
import (
"bufio"
"encoding/binary"
"errors"
"fmt"
"hash/crc32"
"io"
Expand Down Expand Up @@ -190,6 +191,10 @@ type partitionOffset struct {
Offsets []int64
}

var (
errShortRead = errors.New("not enough bytes available to load the response")
)

func makeInt8(b []byte) int8 {
return int8(b[0])
}
Expand All @@ -206,89 +211,97 @@ func makeInt64(b []byte) int64 {
return int64(binary.BigEndian.Uint64(b))
}

func peekRead(r *bufio.Reader, n int, f func([]byte)) error {
func peekRead(r *bufio.Reader, sz int, n int, f func([]byte)) (int, error) {
if n > sz {
return sz, errShortRead
}
b, err := r.Peek(n)
if err != nil {
return err
return sz, err
}
f(b)
_, err = r.Discard(n)
return err
r.Discard(n)
return sz - n, nil
}

func readInt8(r *bufio.Reader, v *int8) error {
return peekRead(r, 1, func(b []byte) { *v = makeInt8(b) })
func readInt8(r *bufio.Reader, sz int, v *int8) (int, error) {
return peekRead(r, sz, 1, func(b []byte) { *v = makeInt8(b) })
}

func readInt16(r *bufio.Reader, v *int16) error {
return peekRead(r, 2, func(b []byte) { *v = makeInt16(b) })
func readInt16(r *bufio.Reader, sz int, v *int16) (int, error) {
return peekRead(r, sz, 2, func(b []byte) { *v = makeInt16(b) })
}

func readInt32(r *bufio.Reader, v *int32) error {
return peekRead(r, 4, func(b []byte) { *v = makeInt32(b) })
func readInt32(r *bufio.Reader, sz int, v *int32) (int, error) {
return peekRead(r, sz, 4, func(b []byte) { *v = makeInt32(b) })
}

func readInt64(r *bufio.Reader, v *int64) error {
return peekRead(r, 8, func(b []byte) { *v = makeInt64(b) })
func readInt64(r *bufio.Reader, sz int, v *int64) (int, error) {
return peekRead(r, sz, 8, func(b []byte) { *v = makeInt64(b) })
}

func readString(r *bufio.Reader, v *string) error {
func readString(r *bufio.Reader, sz int, v *string) (int, error) {
var err error
var len int16
var b []byte
var n int16

if err := readInt16(r, &n); err != nil {
return err
if sz, err = readInt16(r, sz, &len); err != nil {
return sz, err
}

if n >= 0 {
b = make([]byte, int(n))
if _, err := io.ReadFull(r, b); err != nil {
return err
if n := int(len); n >= 0 {
if n > sz {
return sz, errShortRead
}
b = make([]byte, n)
if _, err = io.ReadFull(r, b); err != nil {
return sz, err
}
sz -= n
}

*v = string(b)
return nil
return sz, nil
}

func readBytes(r *bufio.Reader, v *[]byte) error {
func readBytes(r *bufio.Reader, sz int, v *[]byte) (int, error) {
var err error
var len int32
var b []byte
var n int32

if err := readInt32(r, &n); err != nil {
return err
if sz, err = readInt32(r, sz, &len); err != nil {
return sz, err
}

if n >= 0 {
b = make([]byte, int(n))
if _, err := io.ReadFull(r, b); err != nil {
return err
if n := int(len); n >= 0 {
if n > sz {
return sz, errShortRead
}
b = make([]byte, n)
if _, err = io.ReadFull(r, b); err != nil {
return sz, err
}
sz -= n
}

*v = b
return nil
}

func readResponse(r *bufio.Reader, size int32, res interface{}) error {
// TODO: check size?
return read(r, res)
return sz, nil
}

func read(r *bufio.Reader, a interface{}) error {
func read(r *bufio.Reader, sz int, a interface{}) (int, error) {
switch v := a.(type) {
case *int8:
return readInt8(r, v)
return readInt8(r, sz, v)
case *int16:
return readInt16(r, v)
return readInt16(r, sz, v)
case *int32:
return readInt32(r, v)
return readInt32(r, sz, v)
case *int64:
return readInt64(r, v)
return readInt64(r, sz, v)
case *string:
return readString(r, v)
return readString(r, sz, v)
case *[]byte:
return readBytes(r, v)
return readBytes(r, sz, v)
}

v := reflect.ValueOf(a)
Expand All @@ -298,39 +311,52 @@ func read(r *bufio.Reader, a interface{}) error {

switch v = v.Elem(); v.Kind() {
case reflect.Struct:
return readStruct(r, v)
return readStruct(r, sz, v)
case reflect.Slice:
return readSlice(r, v)
return readSlice(r, sz, v)
default:
panic(fmt.Sprintf("unsupported type: %T", a))
}
}

func readStruct(r *bufio.Reader, v reflect.Value) error {
func readStruct(r *bufio.Reader, sz int, v reflect.Value) (int, error) {
var err error
for i, n := 0, v.NumField(); i != n; i++ {
if err := read(r, v.Field(i).Addr().Interface()); err != nil {
return err
if sz, err = read(r, sz, v.Field(i).Addr().Interface()); err != nil {
return sz, err
}
}
return nil
return sz, nil
}

func readSlice(r *bufio.Reader, v reflect.Value) error {
size := int32(0)
func readSlice(r *bufio.Reader, sz int, v reflect.Value) (int, error) {
var err error
var len int32

if err := readInt32(r, &size); err != nil {
return err
if sz, err = readInt32(r, sz, &len); err != nil {
return sz, err
}

n := int(size)
n := int(len)
v.Set(reflect.MakeSlice(v.Type(), n, n))

for i := 0; i != n; i++ {
if err := read(r, v.Index(i).Addr().Interface()); err != nil {
return err
if sz, err = read(r, sz, v.Index(i).Addr().Interface()); err != nil {
return sz, err
}
}

return sz, nil
}

func readResponse(r *bufio.Reader, sz int, res interface{}) error {
n, err := read(r, sz, res)
if err != nil {
return err
}
if n != 0 {
return fmt.Errorf("reading a response of size %d left %d unread bytes", sz, n)
}
return nil
}

Expand Down Expand Up @@ -397,13 +423,6 @@ func writeSmallBuffer(w *bufio.Writer, b []byte) error {
return nil
}

func writeRequest(w *bufio.Writer, hdr requestHeader, req interface{}) error {
hdr.Size = (sizeof(hdr) + sizeof(req)) - 4
write(w, hdr)
write(w, req)
return w.Flush()
}

func write(w *bufio.Writer, a interface{}) error {
switch v := a.(type) {
case int8:
Expand Down Expand Up @@ -453,6 +472,13 @@ func writeSlice(w *bufio.Writer, v reflect.Value) error {
return nil
}

func writeRequest(w *bufio.Writer, hdr requestHeader, req interface{}) error {
hdr.Size = (sizeof(hdr) + sizeof(req)) - 4
write(w, hdr)
write(w, req)
return w.Flush()
}

func sizeof(a interface{}) int32 {
switch v := a.(type) {
case int8:
Expand Down
2 changes: 1 addition & 1 deletion protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func TestProtocol(t *testing.T) {

v := reflect.New(reflect.TypeOf(test))

if err := read(r, v.Interface()); err != nil {
if _, err := read(r, b.Len(), v.Interface()); err != nil {
t.Fatal(err)
}

Expand Down

0 comments on commit 1ae0b59

Please sign in to comment.