Skip to content

Commit

Permalink
refactor(detect): create readUntilSafeBoundary + add tests (#1676)
Browse files Browse the repository at this point in the history
  • Loading branch information
rgmz authored Dec 28, 2024
1 parent dbe3746 commit 4c3da6e
Show file tree
Hide file tree
Showing 5 changed files with 190 additions and 62 deletions.
147 changes: 93 additions & 54 deletions detect/directory.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package detect

import (
"bufio"
"bytes"
"io"
"os"
Expand Down Expand Up @@ -49,64 +50,32 @@ func (d *Detector) DetectFiles(paths <-chan sources.ScanTarget) ([]report.Findin
}
}

// Buffer to hold file chunks
buf := make([]byte, chunkSize)
totalLines := 0
var (
// Buffer to hold file chunks
reader = bufio.NewReaderSize(f, chunkSize)
buf = make([]byte, chunkSize)
totalLines = 0
)
for {
n, err := f.Read(buf)
if n > 0 {
// TODO: optimization could be introduced here
if mimetype, err := filetype.Match(buf[:n]); err != nil {
return nil
} else if mimetype.MIME.Type == "application" {
return nil // skip binary files
}

// If the chunk doesn't end in a newline, peek |maxPeekSize| until we find one.
// This hopefully avoids splitting
// See: https://github.com/gitleaks/gitleaks/issues/1651
var (
peekBuf = bytes.NewBuffer(buf[:n])
tempBuf = make([]byte, 1)
newlineCount = 0 // Tracks consecutive newlines
)
for {
data := peekBuf.Bytes()
if len(data) == 0 {
break
}

// Check if the last character is a newline.
lastChar := data[len(data)-1]
if lastChar == '\n' || lastChar == '\r' {
newlineCount++

// Stop if two consecutive newlines are found
if newlineCount >= 2 {
break
}
} else {
newlineCount = 0 // Reset if a non-newline character is found
}

// Stop growing the buffer if it reaches maxSize
if (peekBuf.Len() - n) >= maxPeekSize {
break
}
n, err := reader.Read(buf)

// Read additional data into a temporary buffer
m, readErr := f.Read(tempBuf)
if m > 0 {
peekBuf.Write(tempBuf[:m])
// "Callers should always process the n > 0 bytes returned before considering the error err."
// https://pkg.go.dev/io#Reader
if n > 0 {
// Only check the filetype at the start of file.
if totalLines == 0 {
// TODO: could other optimizations be introduced here?
if mimetype, err := filetype.Match(buf[:n]); err != nil {
return nil
} else if mimetype.MIME.Type == "application" {
return nil // skip binary files
}
}

// Stop if EOF is reached
if readErr != nil {
if readErr == io.EOF {
break
}
return readErr
}
// Try to split chunks across large areas of whitespace, if possible.
peekBuf := bytes.NewBuffer(buf[:n])
if readErr := readUntilSafeBoundary(reader, n, maxPeekSize, peekBuf); readErr != nil {
return readErr
}

// Count the number of newlines in this chunk
Expand Down Expand Up @@ -145,3 +114,73 @@ func (d *Detector) DetectFiles(paths <-chan sources.ScanTarget) ([]report.Findin

return d.findings, nil
}

// readUntilSafeBoundary grows peekBuf by reading single bytes from r until the
// buffered data ends in a blank line, i.e. two `\n` characters separated only by
// other whitespace (`\r`, ` `, `\t`). Ending a chunk on such a boundary hopefully
// avoids splitting it in the middle of related lines.
// (https://github.com/gitleaks/gitleaks/issues/1651)
//
// n is the size of the original chunk (callers pass the length peekBuf had on
// entry). Reading stops once peekBuf has grown by maxPeekSize bytes beyond n,
// even if no blank line was found.
//
// Reaching io.EOF on r is not an error: the function returns nil and peekBuf
// keeps whatever was read. Any other read error is returned as-is.
func readUntilSafeBoundary(r *bufio.Reader, n int, maxPeekSize int, peekBuf *bytes.Buffer) error {
	if peekBuf.Len() == 0 {
		return nil
	}

	// Count the newlines in the run of whitespace that ends the current data.
	// If there are already two, the chunk ends on a safe boundary.
	var (
		data         = peekBuf.Bytes()
		newlineCount = 0 // Newlines seen in the current run of trailing whitespace.
	)
trailing:
	for i := len(data) - 1; i >= 0; i-- {
		switch data[i] {
		case '\n':
			newlineCount++
			if newlineCount >= 2 {
				return nil
			}
		case '\r', ' ', '\t':
			// Other whitespace neither counts as a newline nor ends the run.
		default:
			break trailing
		}
	}

	// Otherwise read ahead until we (hopefully) find a blank line. The count from
	// the trailing run is deliberately kept: a blank line may straddle the chunk
	// boundary (e.g. the chunk ends in "\r\n\r" and the next byte is '\n').
	for peekBuf.Len()-n < maxPeekSize {
		b, err := r.ReadByte()
		if err != nil {
			if err == io.EOF {
				return nil
			}
			return err
		}
		peekBuf.WriteByte(b)

		switch b {
		case '\n':
			newlineCount++
			if newlineCount >= 2 {
				return nil
			}
		case '\r', ' ', '\t':
			// The presence of other whitespace characters shouldn't reset the count.
		default:
			newlineCount = 0 // A non-whitespace byte ends the run.
		}
	}
	return nil
}
72 changes: 72 additions & 0 deletions detect/directory_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package detect

import (
"bufio"
"bytes"
"io"
"strings"
"testing"

"github.com/stretchr/testify/require"
)

// Test_readUntilSafeBoundary feeds readUntilSafeBoundary a 5-byte initial chunk
// taken from each input, allows it to peek up to 20 further bytes, and checks
// the resulting buffer contents.
func Test_readUntilSafeBoundary(t *testing.T) {
	// Arrange
	cases := []struct {
		name     string
		r        io.Reader
		expected string
	}{
		// Current split is fine, exit early.
		{
			name:     "safe original split - LF",
			r:        strings.NewReader("abc\n\ndefghijklmnop\n\nqrstuvwxyz"),
			expected: "abc\n\n",
		},
		{
			name:     "safe original split - CRLF",
			r:        strings.NewReader("a\r\n\r\nbcdefghijklmnop\n"),
			expected: "a\r\n\r\n",
		},
		// Current split is bad, look for a better one.
		{
			name:     "safe split - LF",
			r:        strings.NewReader("abcdefg\nhijklmnop\n\nqrstuvwxyz"),
			expected: "abcdefg\nhijklmnop\n\n",
		},
		{
			name:     "safe split - CRLF",
			r:        strings.NewReader("abcdefg\r\nhijklmnop\r\n\r\nqrstuvwxyz"),
			expected: "abcdefg\r\nhijklmnop\r\n\r\n",
		},
		{
			name:     "safe split - blank line",
			r:        strings.NewReader("abcdefg\nhijklmnop\n\t \t\nqrstuvwxyz"),
			expected: "abcdefg\nhijklmnop\n\t \t\n",
		},
		// Current split is bad, exhaust options.
		{
			name:     "no safe split",
			r:        strings.NewReader("abcdefg\nhijklmnopqrstuvwxyz"),
			expected: "abcdefg\nhijklmnopqrstuvwx",
		},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			// Read the initial chunk straight from the source; the bufio.Reader
			// below then continues from where this read stopped.
			buf := make([]byte, 5)
			n, err := c.r.Read(buf)
			require.NoError(t, err)

			// Act
			reader := bufio.NewReader(c.r)
			peekBuf := bytes.NewBuffer(buf[:n])
			err = readUntilSafeBoundary(reader, n, 20, peekBuf)
			require.NoError(t, err)

			// Assert
			// Never pass data as the format string; %q also makes the
			// whitespace under test visible in the log.
			t.Logf("%q", peekBuf.String())
			require.Equal(t, c.expected, peekBuf.String())
		})
	}
}
20 changes: 13 additions & 7 deletions detect/reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package detect

import (
"bufio"
"bytes"
"io"

"github.com/zricethezav/gitleaks/v8/report"
Expand All @@ -10,18 +11,23 @@ import (
// DetectReader accepts an io.Reader and a buffer size for the reader in KB
func (d *Detector) DetectReader(r io.Reader, bufSize int) ([]report.Finding, error) {
reader := bufio.NewReader(r)
buf := make([]byte, 0, 1000*bufSize)
buf := make([]byte, 1000*bufSize)
findings := []report.Finding{}

for {
n, err := reader.Read(buf[:cap(buf)])
n, err := reader.Read(buf)

// "Callers should always process the n > 0 bytes returned before considering the error err."
// https://pkg.go.dev/io#Reader
if n > 0 {
buf = buf[:n]
// Try to split chunks across large areas of whitespace, if possible.
peekBuf := bytes.NewBuffer(buf[:n])
if readErr := readUntilSafeBoundary(reader, n, maxPeekSize, peekBuf); readErr != nil {
return findings, readErr
}

fragment := Fragment{
Raw: string(buf),
Raw: peekBuf.String(),
}
for _, finding := range d.Detect(fragment) {
findings = append(findings, finding)
Expand All @@ -32,10 +38,10 @@ func (d *Detector) DetectReader(r io.Reader, bufSize int) ([]report.Finding, err
}

if err != nil {
if err != io.EOF {
return findings, err
if err == io.EOF {
break
}
break
return findings, err
}
}

Expand Down
9 changes: 8 additions & 1 deletion detect/reader_test.go
Original file line number Diff line number Diff line change
@@ -1,23 +1,30 @@
package detect

import (
"github.com/stretchr/testify/require"
"io"
"strings"
"testing"

"github.com/stretchr/testify/require"

"github.com/stretchr/testify/assert"
)

const secret = "AKIAIRYLJVKMPEGZMPJS"

// mockReader is a test reader whose Read copies data into the caller's buffer
// on the first call and returns io.EOF together with the byte count.
type mockReader struct {
	// data is the payload copied into the buffer passed to Read.
	data []byte
	// read is set once Read has delivered data; later calls return (0, io.EOF).
	read bool
}

func (r *mockReader) Read(p []byte) (n int, err error) {
if r.read {
return 0, io.EOF
}

// Copy data to the provided buffer.
n = copy(p, r.data)
r.read = true

// Return io.EOF along with the bytes.
return n, io.EOF
Expand Down
4 changes: 4 additions & 0 deletions detect/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,7 @@ func containsDigit(s string) bool {
}
return false
}

// isWhitespace reports whether ch is a space, horizontal tab, line feed, or
// carriage return.
func isWhitespace(ch byte) bool {
	switch ch {
	case ' ', '\t', '\n', '\r':
		return true
	default:
		return false
	}
}

0 comments on commit 4c3da6e

Please sign in to comment.