Skip to content

Commit

Permalink
Fix bug in readNewBytes so it doesn't attempt to read beyond sz (segm…
Browse files Browse the repository at this point in the history
  • Loading branch information
stevevls authored Feb 21, 2019
1 parent fcb01ba commit ecca8c3
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 0 deletions.
10 changes: 10 additions & 0 deletions read.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,22 @@ func readBytesWith(r *bufio.Reader, sz int, cb func(*bufio.Reader, int, int) (in
func readNewBytes(r *bufio.Reader, sz int, n int) ([]byte, int, error) {
var err error
var b []byte
var shortRead bool

if n > 0 {
if sz < n {
n = sz
shortRead = true
}

b = make([]byte, n)
n, err = io.ReadFull(r, b)
b = b[:n]
sz -= n

if err == nil && shortRead {
err = errShortRead
}
}

return b, sz, err
Expand Down
52 changes: 52 additions & 0 deletions read_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,55 @@ func TestReadMapStringInt32(t *testing.T) {
})
}
}

func TestReadNewBytes(t *testing.T) {

t.Run("reads new bytes", func(t *testing.T) {
r := bufio.NewReader(bytes.NewReader([]byte("foobar")))

b, remain, err := readNewBytes(r, 6, 3)
if string(b) != "foo" {
t.Error("should have returned 3 bytes")
}
if remain != 3 {
t.Error("should have calculated remaining correctly")
}
if err != nil {
t.Error("should not have errored")
}

b, remain, err = readNewBytes(r, remain, 3)
if string(b) != "bar" {
t.Error("should have returned 3 bytes")
}
if remain != 0 {
t.Error("should have calculated remaining correctly")
}
if err != nil {
t.Error("should not have errored")
}

b, err = r.Peek(0)
if len(b) > 0 {
t.Error("not all bytes were consumed")
}
})

t.Run("discards bytes when insufficient", func(t *testing.T) {
r := bufio.NewReader(bytes.NewReader([]byte("foo")))
b, remain, err := readNewBytes(bufio.NewReader(r), 3, 4)
if string(b) != "foo" {
t.Error("should have returned available bytes")
}
if remain != 0 {
t.Error("all bytes should have been consumed")
}
if err != errShortRead {
t.Error("should have returned errShortRead")
}
b, err = r.Peek(0)
if len(b) > 0 {
t.Error("not all bytes were consumed")
}
})
}

0 comments on commit ecca8c3

Please sign in to comment.