Skip to content

Commit

Permalink
look at the body to determine if it has content or not
Browse files Browse the repository at this point in the history
Signed-off-by: Ivan Porto Carrero <ivan@flanders.co.nz>
  • Loading branch information
casualjim committed Jul 10, 2019
1 parent 58872d9 commit 09f01ee
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 27 deletions.
3 changes: 2 additions & 1 deletion middleware/validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"bytes"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/go-openapi/errors"
Expand Down Expand Up @@ -81,7 +82,7 @@ func TestContentTypeValidation(t *testing.T) {
assert.Equal(t, "application/json", recorder.Header().Get("content-type"))

recorder = httptest.NewRecorder()
request, _ = http.NewRequest("POST", "/api/pets", nil)
request, _ = http.NewRequest("POST", "/api/pets", strings.NewReader(`{"name":"dog"}`))
request.Header.Add("Accept", "application/json")
request.Header.Add("content-type", "text/html")
request.TransferEncoding = []string{"chunked"}
Expand Down
64 changes: 63 additions & 1 deletion request.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package runtime

import (
"bufio"
"io"
"net/http"
"strings"
Expand Down Expand Up @@ -42,7 +43,68 @@ func AllowsBody(r *http.Request) bool {

// HasBody returns true if this method needs a content-type
func HasBody(r *http.Request) bool {
return len(r.TransferEncoding) > 0 || r.ContentLength > 0
// happy case: we have a content length set
if r.ContentLength > 0 {
return true
}

if r.Header.Get(http.CanonicalHeaderKey("content-length")) != "" {
// in this case, no Transfer-Encoding should be present
// we have a header set but it was explicitly set to 0, so we assume no body
return false
}

rdr := newPeekingReader(r.Body)
r.Body = rdr
return rdr.HasContent()
}

func newPeekingReader(r io.ReadCloser) *peekingReader {
if r == nil {
return nil
}
return &peekingReader{
underlying: bufio.NewReader(r),
orig: r,
}
}

type peekingReader struct {
underlying interface {
Buffered() int
Peek(int) ([]byte, error)
Read([]byte) (int, error)
}
orig io.ReadCloser
}

func (p *peekingReader) HasContent() bool {
if p == nil {
return false
}
if p.underlying.Buffered() > 0 {
return true
}
b, err := p.underlying.Peek(1)
if err != nil {
return false
}
return len(b) > 0
}

func (p *peekingReader) Read(d []byte) (int, error) {
if p == nil {
return 0, io.EOF
}
return p.underlying.Read(d)
}

func (p *peekingReader) Close() error {
p.underlying = nil
if p.orig != nil {
return p.orig.Close()
}
return nil
}

// JSONRequest creates a new http request with json headers set
Expand Down
122 changes: 97 additions & 25 deletions request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,58 +15,130 @@
package runtime

import (
"bufio"
"bytes"
"io"
"io/ioutil"
"net/url"
"strings"
"testing"

"github.com/stretchr/testify/require"

"github.com/stretchr/testify/assert"
)

/*
type tstreadcloser struct {
closed bool
type eofReader struct{}

func (e *eofReader) Read(d []byte) (int, error) {
return 0, io.EOF
}

func closeReader(rdr io.Reader) *closeCounting {
return &closeCounting{
rdr: rdr,
}
}

type closeCounting struct {
rdr io.Reader
closed int
}

func (c *closeCounting) Read(d []byte) (int, error) {
return c.rdr.Read(d)
}

func (t *tstreadcloser) Read(p []byte) (int, error) { return 0, nil }
func (t *tstreadcloser) Close() error {
t.closed = true
func (c *closeCounting) Close() error {
c.closed++
if cr, ok := c.rdr.(io.ReadCloser); ok {
return cr.Close()
}
return nil
}

type countingBufioReader struct {
buffereds int
peeks int
reads int

br interface {
Buffered() int
Peek(int) ([]byte, error)
Read([]byte) (int, error)
}
}

func (c *countingBufioReader) Buffered() int {
c.buffereds++
return c.br.Buffered()
}

func (c *countingBufioReader) Peek(v int) ([]byte, error) {
c.peeks++
return c.br.Peek(v)
}

func (c *countingBufioReader) Read(p []byte) (int, error) {
c.reads++
return c.br.Read(p)
}

func TestPeekingReader(t *testing.T) {
// just passes to original reader when nothing called
exp1 := []byte("original")
pr1 := &peekingReader{rdr: ioutil.NopCloser(bytes.NewReader(exp1))}
pr1 := newPeekingReader(closeReader(bytes.NewReader(exp1)))
b1, err := ioutil.ReadAll(pr1)
if assert.NoError(t, err) {
assert.Equal(t, exp1, b1)
}

// uses actual when there was some buffering
exp2 := []byte("actual")
pt1, pt2 := []byte("a"), []byte("ctual")
pr2 := &peekingReader{
rdr: ioutil.NopCloser(bytes.NewReader(exp1)),
actual: io.MultiReader(bytes.NewReader(pt1), bytes.NewReader(pt2)),
peeked: pt1,
}
pr2 := newPeekingReader(closeReader(bytes.NewReader(exp2)))
peeked, err := pr2.underlying.Peek(1)
require.NoError(t, err)
require.Equal(t, "a", string(peeked))
b2, err := ioutil.ReadAll(pr2)
if assert.NoError(t, err) {
assert.Equal(t, exp2, b2)
assert.Equal(t, string(exp2), string(b2))
}

// closes original reader
tr := new(tstreadcloser)
pr3 := &peekingReader{
rdr: tr,
actual: ioutil.NopCloser(bytes.NewBuffer(nil)),
peeked: pt1,
// passes close call through to original reader
cr := closeReader(closeReader(bytes.NewReader(exp2)))
pr3 := newPeekingReader(cr)
require.NoError(t, pr3.Close())
require.Equal(t, 1, cr.closed)

// returns false when the stream is empty
pr4 := newPeekingReader(closeReader(&eofReader{}))
require.False(t, pr4.HasContent())

// returns true when the stream has content
rdr := closeReader(strings.NewReader("hello"))
pr := newPeekingReader(rdr)
cbr := &countingBufioReader{
br: bufio.NewReader(rdr),
}
// returns true when peeked previously with data
// returns true when peeked with data
pr.underlying = cbr

require.True(t, pr.HasContent())
require.Equal(t, 1, cbr.buffereds)
require.Equal(t, 1, cbr.peeks)
require.Equal(t, 0, cbr.reads)
require.True(t, pr.HasContent())
require.Equal(t, 2, cbr.buffereds)
require.Equal(t, 1, cbr.peeks)
require.Equal(t, 0, cbr.reads)

b, err := ioutil.ReadAll(pr)
require.NoError(t, err)
require.Equal(t, "hello", string(b))
require.Equal(t, 2, cbr.buffereds)
require.Equal(t, 1, cbr.peeks)
require.Equal(t, 2, cbr.reads)
require.Equal(t, 0, cbr.br.Buffered())
}
*/

func TestJSONRequest(t *testing.T) {
req, err := JSONRequest("GET", "/swagger.json", nil)
Expand Down

0 comments on commit 09f01ee

Please sign in to comment.