Skip to content

Commit

Permalink
Improving some internal error-handling (segmentio#846)
Browse files Browse the repository at this point in the history
* refactor: use more errors.Is to allow errors to be wrapped more safely

* refactor(dialer): wrap returned errors with context

* refactor: decorate more errors

* revert some error wrap changes to reduce PR noise

* revert conn changes
  • Loading branch information
dominicbarnes authored Apr 1, 2022
1 parent ec59669 commit ae86f55
Show file tree
Hide file tree
Showing 15 changed files with 66 additions and 53 deletions.
14 changes: 8 additions & 6 deletions batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package kafka

import (
"bufio"
"errors"
"io"
"sync"
"time"
Expand Down Expand Up @@ -82,7 +83,7 @@ func (batch *Batch) close() (err error) {
batch.msgs.discard()
}

if err = batch.err; err == io.EOF {
if err = batch.err; errors.Is(batch.err, io.EOF) {
err = nil
}

Expand All @@ -93,7 +94,8 @@ func (batch *Batch) close() (err error) {
conn.mutex.Unlock()

if err != nil {
if _, ok := err.(Error); !ok && err != io.ErrShortBuffer {
var kafkaError Error
if !errors.As(err, &kafkaError) && !errors.Is(err, io.ErrShortBuffer) {
conn.Close()
}
}
Expand Down Expand Up @@ -238,11 +240,11 @@ func (batch *Batch) readMessage(

var lastOffset int64
offset, lastOffset, timestamp, headers, err = batch.msgs.readMessage(batch.offset, key, val)
switch err {
case nil:
switch {
case err == nil:
batch.offset = offset + 1
batch.lastOffset = lastOffset
case errShortRead:
case errors.Is(err, errShortRead):
// As an "optimization" kafka truncates the returned response after
// producing MaxBytes, which could then cause the code to return
// errShortRead.
Expand Down Expand Up @@ -272,7 +274,7 @@ func (batch *Batch) readMessage(
// to MaxBytes truncation
// - `batch.lastOffset` to ensure that the message format contains
// `lastOffset`
if batch.err == io.EOF && batch.msgs.lengthRemain == 0 && batch.lastOffset != -1 {
if errors.Is(batch.err, io.EOF) && batch.msgs.lengthRemain == 0 && batch.lastOffset != -1 {
// Log compaction can create batches that end with compacted
// records so the normal strategy that increments the "next"
// offset as records are read doesn't work as the compacted
Expand Down
5 changes: 3 additions & 2 deletions batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package kafka

import (
"context"
"errors"
"io"
"net"
"strconv"
Expand Down Expand Up @@ -30,11 +31,11 @@ func TestBatchDontExpectEOF(t *testing.T) {

batch := conn.ReadBatch(1024, 8192)

if _, err := batch.ReadMessage(); err != io.ErrUnexpectedEOF {
if _, err := batch.ReadMessage(); !errors.Is(err, io.ErrUnexpectedEOF) {
t.Error("bad error when reading message:", err)
}

if err := batch.Close(); err != io.ErrUnexpectedEOF {
if err := batch.Close(); !errors.Is(err, io.ErrUnexpectedEOF) {
t.Error("bad error when closing the batch:", err)
}
}
5 changes: 3 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package kafka
import (
"context"
"errors"
"fmt"
"net"
"time"

Expand Down Expand Up @@ -67,7 +68,7 @@ func (c *Client) ConsumerOffsets(ctx context.Context, tg TopicAndGroup) (map[int
})

if err != nil {
return nil, err
return nil, fmt.Errorf("failed to get topic metadata :%w", err)
}

topic := metadata.Topics[0]
Expand All @@ -85,7 +86,7 @@ func (c *Client) ConsumerOffsets(ctx context.Context, tg TopicAndGroup) (map[int
})

if err != nil {
return nil, err
return nil, fmt.Errorf("failed to get offsets: %w", err)
}

topicOffsets := offsets.Topics[topic.Name]
Expand Down
7 changes: 4 additions & 3 deletions compress/snappy/xerial.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package snappy
import (
"bytes"
"encoding/binary"
"errors"
"io"

"github.com/klauspost/compress/snappy"
Expand Down Expand Up @@ -64,7 +65,7 @@ func (x *xerialReader) WriteTo(w io.Writer) (int64, error) {
}

if _, err := x.readChunk(nil); err != nil {
if err == io.EOF {
if errors.Is(err, io.EOF) {
err = nil
}
return wn, err
Expand Down Expand Up @@ -128,7 +129,7 @@ func (x *xerialReader) readChunk(dst []byte) (int, error) {
n, err := x.read(x.input[len(x.input):cap(x.input)])
x.input = x.input[:len(x.input)+n]
if err != nil {
if err == io.EOF && len(x.input) > 0 {
if errors.Is(err, io.EOF) && len(x.input) > 0 {
break
}
return 0, err
Expand Down Expand Up @@ -212,7 +213,7 @@ func (x *xerialWriter) ReadFrom(r io.Reader) (int64, error) {
}

if err != nil {
if err == io.EOF {
if errors.Is(err, io.EOF) {
err = nil
}
return wn, err
Expand Down
5 changes: 3 additions & 2 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -853,7 +853,7 @@ func (c *Conn) ReadBatchWith(cfg ReadBatchConfig) *Batch {
default:
throttle, highWaterMark, remain, err = readFetchResponseHeaderV2(&c.rbuf, size)
}
if err == errShortRead {
if errors.Is(err, errShortRead) {
err = checkTimeoutErr(adjustedDeadline)
}

Expand All @@ -865,9 +865,10 @@ func (c *Conn) ReadBatchWith(cfg ReadBatchConfig) *Batch {
msgs, err = newMessageSetReader(&c.rbuf, remain)
}
}
if err == errShortRead {
if errors.Is(err, errShortRead) {
err = checkTimeoutErr(adjustedDeadline)
}

return &Batch{
conn: c,
msgs: msgs,
Expand Down
26 changes: 15 additions & 11 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package kafka
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"math/rand"
Expand Down Expand Up @@ -640,10 +641,13 @@ func testConnReadBatchWithMaxWait(t *testing.T, conn *Conn) {
conn.Seek(0, SeekAbsolute)
conn.SetDeadline(time.Now().Add(50 * time.Millisecond))
batch = conn.ReadBatchWith(cfg)
var netErr net.Error
if err := batch.Err(); err == nil {
t.Fatal("should have timed out, but got no error")
} else if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() {
t.Fatalf("should have timed out, but got: %v", err)
} else if errors.As(err, &netErr) {
if !netErr.Timeout() {
t.Fatalf("should have timed out, but got: %v", err)
}
}
}

Expand Down Expand Up @@ -761,7 +765,7 @@ func testConnFindCoordinator(t *testing.T, conn *Conn) {

func testConnJoinGroupInvalidGroupID(t *testing.T, conn *Conn) {
_, err := conn.joinGroup(joinGroupRequestV1{})
if err != InvalidGroupId && err != NotCoordinatorForGroup {
if !errors.Is(err, InvalidGroupId) && !errors.Is(err, NotCoordinatorForGroup) {
t.Fatalf("expected %v or %v; got %v", InvalidGroupId, NotCoordinatorForGroup, err)
}
}
Expand All @@ -773,7 +777,7 @@ func testConnJoinGroupInvalidSessionTimeout(t *testing.T, conn *Conn) {
_, err := conn.joinGroup(joinGroupRequestV1{
GroupID: groupID,
})
if err != InvalidSessionTimeout && err != NotCoordinatorForGroup {
if !errors.Is(err, InvalidSessionTimeout) && !errors.Is(err, NotCoordinatorForGroup) {
t.Fatalf("expected %v or %v; got %v", InvalidSessionTimeout, NotCoordinatorForGroup, err)
}
}
Expand All @@ -786,7 +790,7 @@ func testConnJoinGroupInvalidRefreshTimeout(t *testing.T, conn *Conn) {
GroupID: groupID,
SessionTimeout: int32(3 * time.Second / time.Millisecond),
})
if err != InvalidSessionTimeout && err != NotCoordinatorForGroup {
if !errors.Is(err, InvalidSessionTimeout) && !errors.Is(err, NotCoordinatorForGroup) {
t.Fatalf("expected %v or %v; got %v", InvalidSessionTimeout, NotCoordinatorForGroup, err)
}
}
Expand All @@ -798,7 +802,7 @@ func testConnHeartbeatErr(t *testing.T, conn *Conn) {
_, err := conn.syncGroup(syncGroupRequestV0{
GroupID: groupID,
})
if err != UnknownMemberId && err != NotCoordinatorForGroup {
if !errors.Is(err, UnknownMemberId) && !errors.Is(err, NotCoordinatorForGroup) {
t.Fatalf("expected %v or %v; got %v", UnknownMemberId, NotCoordinatorForGroup, err)
}
}
Expand All @@ -810,7 +814,7 @@ func testConnLeaveGroupErr(t *testing.T, conn *Conn) {
_, err := conn.leaveGroup(leaveGroupRequestV0{
GroupID: groupID,
})
if err != UnknownMemberId && err != NotCoordinatorForGroup {
if !errors.Is(err, UnknownMemberId) && !errors.Is(err, NotCoordinatorForGroup) {
t.Fatalf("expected %v or %v; got %v", UnknownMemberId, NotCoordinatorForGroup, err)
}
}
Expand All @@ -822,7 +826,7 @@ func testConnSyncGroupErr(t *testing.T, conn *Conn) {
_, err := conn.syncGroup(syncGroupRequestV0{
GroupID: groupID,
})
if err != UnknownMemberId && err != NotCoordinatorForGroup {
if !errors.Is(err, UnknownMemberId) && !errors.Is(err, NotCoordinatorForGroup) {
t.Fatalf("expected %v or %v; got %v", UnknownMemberId, NotCoordinatorForGroup, err)
}
}
Expand Down Expand Up @@ -985,7 +989,7 @@ func testConnReadShortBuffer(t *testing.T, conn *Conn) {
b[3] = 0

n, err := conn.Read(b)
if err != io.ErrShortBuffer {
if !errors.Is(err, io.ErrShortBuffer) {
t.Error("bad error:", i, err)
}
if n != 4 {
Expand Down Expand Up @@ -1061,7 +1065,7 @@ func testDeleteTopicsInvalidTopic(t *testing.T, conn *Conn) {
}
conn.SetDeadline(time.Now().Add(5 * time.Second))
err = conn.DeleteTopics("invalid-topic", topic)
if err != UnknownTopicOrPartition {
if !errors.Is(err, UnknownTopicOrPartition) {
t.Fatalf("expected UnknownTopicOrPartition error, but got %v", err)
}
partitions, err := conn.ReadPartitions(topic)
Expand Down Expand Up @@ -1154,7 +1158,7 @@ func TestUnsupportedSASLMechanism(t *testing.T) {
}
defer conn.Close()

if err := conn.saslHandshake("FOO"); err != UnsupportedSASLMechanism {
if err := conn.saslHandshake("FOO"); !errors.Is(err, UnsupportedSASLMechanism) {
t.Errorf("Expected UnsupportedSASLMechanism but got %v", err)
}
}
Expand Down
2 changes: 1 addition & 1 deletion consumergroup.go
Original file line number Diff line number Diff line change
Expand Up @@ -1026,7 +1026,7 @@ func (cg *ConsumerGroup) assignTopicPartitions(conn coordinator, group joinGroup
// assignments for the topic. this matches the behavior of the official
// clients: java, python, and librdkafka.
// a topic watcher can trigger a rebalance when the topic comes into being.
if err != nil && err != UnknownTopicOrPartition {
if err != nil && !errors.Is(err, UnknownTopicOrPartition) {
return nil, err
}

Expand Down
12 changes: 6 additions & 6 deletions consumergroup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ func TestConsumerGroup(t *testing.T) {
if gen != nil {
t.Errorf("expected generation to be nil")
}
if err != context.Canceled {
if !errors.Is(err, context.Canceled) {
t.Errorf("expected context.Canceled, but got %+v", err)
}
},
Expand All @@ -301,7 +301,7 @@ func TestConsumerGroup(t *testing.T) {
if gen != nil {
t.Errorf("expected generation to be nil")
}
if err != ErrGroupClosed {
if !errors.Is(err, ErrGroupClosed) {
t.Errorf("expected ErrGroupClosed, but got %+v", err)
}
},
Expand Down Expand Up @@ -398,7 +398,7 @@ func TestConsumerGroupErrors(t *testing.T) {
gen, err := group.Next(ctx)
if err == nil {
t.Errorf("expected an error")
} else if err != NotCoordinatorForGroup {
} else if !errors.Is(err, NotCoordinatorForGroup) {
t.Errorf("got wrong error: %+v", err)
}
if gen != nil {
Expand Down Expand Up @@ -460,7 +460,7 @@ func TestConsumerGroupErrors(t *testing.T) {
gen, err := group.Next(ctx)
if err == nil {
t.Errorf("expected an error")
} else if err != InvalidTopic {
} else if !errors.Is(err, InvalidTopic) {
t.Errorf("got wrong error: %+v", err)
}
if gen != nil {
Expand Down Expand Up @@ -540,7 +540,7 @@ func TestConsumerGroupErrors(t *testing.T) {
gen, err := group.Next(ctx)
if err == nil {
t.Errorf("expected an error")
} else if err != InvalidTopic {
} else if !errors.Is(err, InvalidTopic) {
t.Errorf("got wrong error: %+v", err)
}
if gen != nil {
Expand Down Expand Up @@ -672,7 +672,7 @@ func TestGenerationStartsFunctionAfterClosed(t *testing.T) {
case <-time.After(time.Second):
t.Fatal("timed out waiting for func to run")
case err := <-ch:
if err != ErrGenerationEnded {
if !errors.Is(err, ErrGenerationEnded) {
t.Fatalf("expected %v but got %v", ErrGenerationEnded, err)
}
}
Expand Down
Loading

0 comments on commit ae86f55

Please sign in to comment.