diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 973b3ff2..2a26d588 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -21,14 +21,14 @@ jobs: test: strategy: matrix: - go-version: [1.20.x, 1.21.x, 1.22.x] + go-version: [1.21.x, 1.22.x, 1.23.x] os: [ubuntu-latest] runs-on: ${{ matrix.os }} timeout-minutes: 10 steps: - - uses: actions/setup-go@v3 + - uses: actions/setup-go@v5 with: go-version: ${{ matrix.go-version }} - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: test run: make ci diff --git a/.github/workflows/validate.yml b/.github/workflows/validate.yml index 8758876e..e5f9b277 100644 --- a/.github/workflows/validate.yml +++ b/.github/workflows/validate.yml @@ -21,19 +21,19 @@ jobs: linters: strategy: matrix: - go-version: [1.22.x] + go-version: [1.23.x] os: [ubuntu-latest] runs-on: ${{ matrix.os }} timeout-minutes: 10 steps: - - uses: actions/setup-go@v3 + - uses: actions/setup-go@v5 with: go-version: ${{ matrix.go-version }} - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: prepare generated code run: make prepare - name: lint - uses: golangci/golangci-lint-action@v3 + uses: golangci/golangci-lint-action@v6 with: - version: v1.57.2 + version: v1.60.3 args: --print-resources-usage --timeout=10m --verbose diff --git a/.golangci.yml b/.golangci.yml index 4c44c5fa..f038fe43 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,3 +1,9 @@ linters: disable: - errcheck + +linters-settings: + staticcheck: + checks: + - all + - '-SA1019' # We use the ast package. diff --git a/Makefile b/Makefile index 00e528a6..a4eda24b 100644 --- a/Makefile +++ b/Makefile @@ -55,8 +55,8 @@ ci: prepare if [ `arch` == 'x86_64' ]; then \ sudo apt-get -y -q update; \ sudo apt-get -y -q install build-essential; \ - wget -q https://github.com/tinygo-org/tinygo/releases/download/v0.32.0/tinygo_0.32.0_amd64.deb; \ - sudo dpkg -i tinygo_0.32.0_amd64.deb; \ + wget -q https://github.com/tinygo-org/tinygo/releases/download/v0.33.0/tinygo_0.33.0_amd64.deb; \ + sudo dpkg -i tinygo_0.33.0_amd64.deb; \ export PATH=$$PATH:/usr/local/tinygo/bin; \ fi go test -v ./... ./_generated diff --git a/_generated/pointer.go b/_generated/pointer.go new file mode 100644 index 00000000..9860004d --- /dev/null +++ b/_generated/pointer.go @@ -0,0 +1,198 @@ +package _generated + +import ( + "fmt" + "time" + + "github.com/tinylib/msgp/msgp" +) + +//go:generate msgp $GOFILE$ + +// Generate only pointer receivers: + +//msgp:pointer + +var mustNoInterf = []interface{}{ + Pointer0{}, + NamedBoolPointer(true), + NamedIntPointer(0), + NamedFloat64Pointer(0), + NamedStringPointer(""), + NamedMapStructPointer(nil), + NamedMapStructPointer2(nil), + NamedMapStringPointer(nil), + NamedMapStringPointer2(nil), + EmbeddableStructPointer{}, + EmbeddableStruct2Pointer{}, + PointerHalfFull{}, + PointerNoName{}, +} + +var mustHaveInterf = []interface{}{ + &Pointer0{}, + mustPtr(NamedBoolPointer(true)), + mustPtr(NamedIntPointer(0)), + mustPtr(NamedFloat64Pointer(0)), + mustPtr(NamedStringPointer("")), + mustPtr(NamedMapStructPointer(nil)), + mustPtr(NamedMapStructPointer2(nil)), + mustPtr(NamedMapStringPointer(nil)), + mustPtr(NamedMapStringPointer2(nil)), + &EmbeddableStructPointer{}, + &EmbeddableStruct2Pointer{}, + &PointerHalfFull{}, + &PointerNoName{}, +} + +func mustPtr[T any](v T) *T { + return &v +} + +func init() { + for _, v := range mustNoInterf { + if _, ok := v.(msgp.Marshaler); ok { + panic(fmt.Sprintf("type %T supports interface", v)) + } + if _, ok := v.(msgp.Encodable); ok { + panic(fmt.Sprintf("type %T supports interface", v)) + } + } + for _, v := range mustHaveInterf { + if _, ok := v.(msgp.Marshaler); !ok { + panic(fmt.Sprintf("type %T does not support interface", v)) + } + if _, ok := v.(msgp.Encodable); !ok { + panic(fmt.Sprintf("type %T does not support interface", v)) + } + } +} + +type Pointer0 struct { + ABool bool `msg:"abool"` + AInt int `msg:"aint"` + AInt8 int8 `msg:"aint8"` + AInt16 int16 `msg:"aint16"` + AInt32 int32 `msg:"aint32"` + AInt64 int64 `msg:"aint64"` + AUint uint `msg:"auint"` + AUint8 uint8 `msg:"auint8"` + AUint16 uint16 `msg:"auint16"` + AUint32 uint32 `msg:"auint32"` + AUint64 uint64 `msg:"auint64"` + AFloat32 float32 `msg:"afloat32"` + AFloat64 float64 `msg:"afloat64"` + AComplex64 complex64 `msg:"acomplex64"` + AComplex128 complex128 `msg:"acomplex128"` + + ANamedBool bool `msg:"anamedbool"` + ANamedInt int `msg:"anamedint"` + ANamedFloat64 float64 `msg:"anamedfloat64"` + + AMapStrStr map[string]string `msg:"amapstrstr"` + + APtrNamedStr *NamedString `msg:"aptrnamedstr"` + + AString string `msg:"astring"` + ANamedString string `msg:"anamedstring"` + AByteSlice []byte `msg:"abyteslice"` + + ASliceString []string `msg:"aslicestring"` + ASliceNamedString []NamedString `msg:"aslicenamedstring"` + + ANamedStruct NamedStruct `msg:"anamedstruct"` + APtrNamedStruct *NamedStruct `msg:"aptrnamedstruct"` + + AUnnamedStruct struct { + A string `msg:"a"` + } `msg:"aunnamedstruct"` // omitempty not supported on unnamed struct + + EmbeddableStruct `msg:",flatten"` // embed flat + + EmbeddableStruct2 `msg:"embeddablestruct2"` // embed non-flat + + AArrayInt [5]int `msg:"aarrayint"` // not supported + + ATime time.Time `msg:"atime"` +} + +type ( + NamedBoolPointer bool + NamedIntPointer int + NamedFloat64Pointer float64 + NamedStringPointer string + NamedMapStructPointer map[string]Pointer0 + NamedMapStructPointer2 map[string]*Pointer0 + NamedMapStringPointer map[string]NamedStringPointer + NamedMapStringPointer2 map[string]*NamedStringPointer +) + +type EmbeddableStructPointer struct { + SomeEmbed string `msg:"someembed"` +} + +type EmbeddableStruct2Pointer struct { + SomeEmbed2 string `msg:"someembed2"` +} + +type NamedStructPointer struct { + A string `msg:"a"` + B string `msg:"b"` +} + +type PointerHalfFull struct { + Field00 string `msg:"field00"` + Field01 string `msg:"field01"` + Field02 string `msg:"field02"` + Field03 string `msg:"field03"` +} + +type PointerNoName struct { + ABool bool `msg:""` + AInt int `msg:""` + AInt8 int8 `msg:""` + AInt16 int16 `msg:""` + AInt32 int32 `msg:""` + AInt64 int64 `msg:""` + AUint uint `msg:""` + AUint8 uint8 `msg:""` + AUint16 uint16 `msg:""` + AUint32 uint32 `msg:""` + AUint64 uint64 `msg:""` + AFloat32 float32 `msg:""` + AFloat64 float64 `msg:""` + AComplex64 complex64 `msg:""` + AComplex128 complex128 `msg:""` + + ANamedBool bool `msg:""` + ANamedInt int `msg:""` + ANamedFloat64 float64 `msg:""` + + AMapStrF map[string]NamedFloat64Pointer `msg:""` + AMapStrStruct map[string]PointerHalfFull `msg:""` + AMapStrStruct2 map[string]*PointerHalfFull `msg:""` + + APtrNamedStr *NamedStringPointer `msg:""` + + AString string `msg:""` + AByteSlice []byte `msg:""` + + ASliceString []string `msg:""` + ASliceNamedString []NamedStringPointer `msg:""` + + ANamedStruct NamedStructPointer `msg:""` + APtrNamedStruct *NamedStructPointer `msg:""` + + AUnnamedStruct struct { + A string `msg:""` + } `msg:""` // omitempty not supported on unnamed struct + + EmbeddableStructPointer `msg:",flatten"` // embed flat + + EmbeddableStruct2Pointer `msg:""` // embed non-flat + + AArrayInt [5]int `msg:""` // not supported + + ATime time.Time `msg:""` + ADur time.Duration `msg:""` +} diff --git a/gen/elem.go b/gen/elem.go index c2c84b65..d397cbed 100644 --- a/gen/elem.go +++ b/gen/elem.go @@ -134,13 +134,22 @@ var builtins = map[string]struct{}{ } // common data/methods for every Elem -type common struct{ vname, alias string } +type common struct { + vname, alias string + ptrRcv bool +} func (c *common) SetVarname(s string) { c.vname = s } func (c *common) Varname() string { return c.vname } func (c *common) Alias(typ string) { c.alias = typ } func (c *common) hidden() {} func (c *common) AllowNil() bool { return false } +func (c *common) AlwaysPtr(set *bool) bool { + if c != nil && set != nil { + c.ptrRcv = *set + } + return c.ptrRcv +} func IsPrintable(e Elem) bool { if be, ok := e.(*BaseElem); ok && !be.Printable() { @@ -191,6 +200,9 @@ type Elem interface { // This is true for slices and maps. AllowNil() bool + // AlwaysPtr will return true if receiver should always be a pointer. + AlwaysPtr(set *bool) bool + // IfZeroExpr returns the expression to compare to an empty value // for this type, per the rules of the `omitempty` feature. // It is meant to be used in an if statement diff --git a/gen/encode.go b/gen/encode.go index 900c847b..af83e456 100644 --- a/gen/encode.go +++ b/gen/encode.go @@ -62,9 +62,16 @@ func (e *encodeGen) Execute(p Elem) error { e.ctx = &Context{} e.p.comment("EncodeMsg implements msgp.Encodable") - - e.p.printf("\nfunc (%s %s) EncodeMsg(en *msgp.Writer) (err error) {", p.Varname(), imutMethodReceiver(p)) + rcv := imutMethodReceiver(p) + ogVar := p.Varname() + if p.AlwaysPtr(nil) { + rcv = methodReceiver(p) + } + e.p.printf("\nfunc (%s %s) EncodeMsg(en *msgp.Writer) (err error) {", ogVar, rcv) next(e, p) + if p.AlwaysPtr(nil) { + p.SetVarname(ogVar) + } e.p.nakedReturn() return e.p.err } diff --git a/gen/marshal.go b/gen/marshal.go index 66f280eb..5b94ff39 100644 --- a/gen/marshal.go +++ b/gen/marshal.go @@ -47,10 +47,18 @@ func (m *marshalGen) Execute(p Elem) error { // calling methodReceiver so // that z.Msgsize() is printed correctly c := p.Varname() - - m.p.printf("\nfunc (%s %s) MarshalMsg(b []byte) (o []byte, err error) {", p.Varname(), imutMethodReceiver(p)) + rcv := imutMethodReceiver(p) + ogVar := p.Varname() + if p.AlwaysPtr(nil) { + rcv = methodReceiver(p) + } + m.p.printf("\nfunc (%s %s) MarshalMsg(b []byte) (o []byte, err error) {", ogVar, rcv) m.p.printf("\no = msgp.Require(b, %s.Msgsize())", c) next(m, p) + if p.AlwaysPtr(nil) { + p.SetVarname(ogVar) + } + m.p.nakedReturn() return m.p.err } @@ -280,7 +288,6 @@ func (m *marshalGen) gBase(b *BaseElem) { } m.fuseHook() vname := b.Varname() - if b.Convert { if b.ShimMode == Cast { vname = tobaseConvert(b) diff --git a/gen/size.go b/gen/size.go index e96e0319..3798532f 100644 --- a/gen/size.go +++ b/gen/size.go @@ -86,9 +86,17 @@ func (s *sizeGen) Execute(p Elem) error { s.p.comment("Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message") - s.p.printf("\nfunc (%s %s) Msgsize() (s int) {", p.Varname(), imutMethodReceiver(p)) + rcv := imutMethodReceiver(p) + ogVar := p.Varname() + if p.AlwaysPtr(nil) { + rcv = methodReceiver(p) + } + s.p.printf("\nfunc (%s %s) Msgsize() (s int) {", ogVar, rcv) s.state = assign next(s, p) + if p.AlwaysPtr(nil) { + p.SetVarname(ogVar) + } s.p.nakedReturn() return s.p.err } diff --git a/msgp/defs.go b/msgp/defs.go index e265aa4f..47a8c183 100644 --- a/msgp/defs.go +++ b/msgp/defs.go @@ -32,6 +32,10 @@ const ( last5 = 0x1f first3 = 0xe0 last7 = 0x7f + + // recursionLimit is the limit of recursive calls. + // This limits the call depth of dynamic code, like Skip and interface conversions. + recursionLimit = 100000 ) func isfixint(b byte) bool { diff --git a/msgp/errors.go b/msgp/errors.go index 4f19359a..984cca32 100644 --- a/msgp/errors.go +++ b/msgp/errors.go @@ -13,6 +13,10 @@ var ( // contain the contents of the message ErrShortBytes error = errShort{} + // ErrRecursion is returned when the maximum recursion limit is reached for an operation. + // This should only realistically be seen on adversarial data trying to exhaust the stack. + ErrRecursion error = errRecursion{} + // this error is only returned // if we reach code that should // be unreachable @@ -134,6 +138,11 @@ func (f errFatal) Resumable() bool { return false } func (f errFatal) withContext(ctx string) error { f.ctx = addCtx(f.ctx, ctx); return f } +type errRecursion struct{} + +func (e errRecursion) Error() string { return "msgp: recursion limit reached" } +func (e errRecursion) Resumable() bool { return false } + // ArrayError is an error returned // when decoding a fix-sized array // of the wrong size diff --git a/msgp/json.go b/msgp/json.go index 0e11e603..fe570373 100644 --- a/msgp/json.go +++ b/msgp/json.go @@ -109,6 +109,13 @@ func rwMap(dst jsWriter, src *Reader) (n int, err error) { return dst.WriteString("{}") } + // This is potentially a recursive call. + if done, err := src.recursiveCall(); err != nil { + return 0, err + } else { + defer done() + } + err = dst.WriteByte('{') if err != nil { return @@ -162,6 +169,13 @@ func rwArray(dst jsWriter, src *Reader) (n int, err error) { if err != nil { return } + // This is potentially a recursive call. + if done, err := src.recursiveCall(); err != nil { + return 0, err + } else { + defer done() + } + var sz uint32 var nn int sz, err = src.ReadArrayHeader() diff --git a/msgp/json_bytes.go b/msgp/json_bytes.go index e6162d0a..88ec6045 100644 --- a/msgp/json_bytes.go +++ b/msgp/json_bytes.go @@ -9,12 +9,12 @@ import ( "time" ) -var unfuns [_maxtype]func(jsWriter, []byte, []byte) ([]byte, []byte, error) +var unfuns [_maxtype]func(jsWriter, []byte, []byte, int) ([]byte, []byte, error) func init() { // NOTE(pmh): this is best expressed as a jump table, // but gc doesn't do that yet. revisit post-go1.5. - unfuns = [_maxtype]func(jsWriter, []byte, []byte) ([]byte, []byte, error){ + unfuns = [_maxtype]func(jsWriter, []byte, []byte, int) ([]byte, []byte, error){ StrType: rwStringBytes, BinType: rwBytesBytes, MapType: rwMapBytes, @@ -51,7 +51,7 @@ func UnmarshalAsJSON(w io.Writer, msg []byte) ([]byte, error) { dst = bufio.NewWriterSize(w, 512) } for len(msg) > 0 && err == nil { - msg, scratch, err = writeNext(dst, msg, scratch) + msg, scratch, err = writeNext(dst, msg, scratch, 0) } if !cast && err == nil { err = dst.(*bufio.Writer).Flush() @@ -59,7 +59,7 @@ func UnmarshalAsJSON(w io.Writer, msg []byte) ([]byte, error) { return msg, err } -func writeNext(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) { +func writeNext(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) { if len(msg) < 1 { return msg, scratch, ErrShortBytes } @@ -76,10 +76,13 @@ func writeNext(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) { t = TimeType } } - return unfuns[t](w, msg, scratch) + return unfuns[t](w, msg, scratch, depth) } -func rwArrayBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) { +func rwArrayBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) { + if depth >= recursionLimit { + return msg, scratch, ErrRecursion + } sz, msg, err := ReadArrayHeaderBytes(msg) if err != nil { return msg, scratch, err @@ -95,7 +98,7 @@ func rwArrayBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error return msg, scratch, err } } - msg, scratch, err = writeNext(w, msg, scratch) + msg, scratch, err = writeNext(w, msg, scratch, depth+1) if err != nil { return msg, scratch, err } @@ -104,7 +107,10 @@ func rwArrayBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error return msg, scratch, err } -func rwMapBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) { +func rwMapBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) { + if depth >= recursionLimit { + return msg, scratch, ErrRecursion + } sz, msg, err := ReadMapHeaderBytes(msg) if err != nil { return msg, scratch, err @@ -120,7 +126,7 @@ func rwMapBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) return msg, scratch, err } } - msg, scratch, err = rwMapKeyBytes(w, msg, scratch) + msg, scratch, err = rwMapKeyBytes(w, msg, scratch, depth) if err != nil { return msg, scratch, err } @@ -128,7 +134,7 @@ func rwMapBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) if err != nil { return msg, scratch, err } - msg, scratch, err = writeNext(w, msg, scratch) + msg, scratch, err = writeNext(w, msg, scratch, depth+1) if err != nil { return msg, scratch, err } @@ -137,17 +143,17 @@ func rwMapBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) return msg, scratch, err } -func rwMapKeyBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) { - msg, scratch, err := rwStringBytes(w, msg, scratch) +func rwMapKeyBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) { + msg, scratch, err := rwStringBytes(w, msg, scratch, depth) if err != nil { if tperr, ok := err.(TypeError); ok && tperr.Encoded == BinType { - return rwBytesBytes(w, msg, scratch) + return rwBytesBytes(w, msg, scratch, depth) } } return msg, scratch, err } -func rwStringBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) { +func rwStringBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) { str, msg, err := ReadStringZC(msg) if err != nil { return msg, scratch, err @@ -156,7 +162,7 @@ func rwStringBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, erro return msg, scratch, err } -func rwBytesBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) { +func rwBytesBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) { bts, msg, err := ReadBytesZC(msg) if err != nil { return msg, scratch, err @@ -180,7 +186,7 @@ func rwBytesBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error return msg, scratch, err } -func rwNullBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) { +func rwNullBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) { msg, err := ReadNilBytes(msg) if err != nil { return msg, scratch, err @@ -189,7 +195,7 @@ func rwNullBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) return msg, scratch, err } -func rwBoolBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) { +func rwBoolBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) { b, msg, err := ReadBoolBytes(msg) if err != nil { return msg, scratch, err @@ -202,7 +208,7 @@ func rwBoolBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) return msg, scratch, err } -func rwIntBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) { +func rwIntBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) { i, msg, err := ReadInt64Bytes(msg) if err != nil { return msg, scratch, err @@ -212,7 +218,7 @@ func rwIntBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) return msg, scratch, err } -func rwUintBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) { +func rwUintBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) { u, msg, err := ReadUint64Bytes(msg) if err != nil { return msg, scratch, err @@ -222,7 +228,7 @@ func rwUintBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) return msg, scratch, err } -func rwFloat32Bytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) { +func rwFloat32Bytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) { var f float32 var err error f, msg, err = ReadFloat32Bytes(msg) @@ -234,7 +240,7 @@ func rwFloat32Bytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, err return msg, scratch, err } -func rwFloat64Bytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) { +func rwFloat64Bytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) { var f float64 var err error f, msg, err = ReadFloat64Bytes(msg) @@ -246,7 +252,7 @@ func rwFloat64Bytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, err return msg, scratch, err } -func rwTimeBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) { +func rwTimeBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) { var t time.Time var err error t, msg, err = ReadTimeBytes(msg) @@ -261,7 +267,7 @@ func rwTimeBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) return msg, scratch, err } -func rwExtensionBytes(w jsWriter, msg []byte, scratch []byte) ([]byte, []byte, error) { +func rwExtensionBytes(w jsWriter, msg []byte, scratch []byte, depth int) ([]byte, []byte, error) { var err error var et int8 et, err = peekExtension(msg) diff --git a/msgp/purego.go b/msgp/purego.go index 2cd35c3e..fe872341 100644 --- a/msgp/purego.go +++ b/msgp/purego.go @@ -1,5 +1,5 @@ -//go:build purego || appengine -// +build purego appengine +//go:build (purego && !unsafe) || appengine +// +build purego,!unsafe appengine package msgp diff --git a/msgp/read.go b/msgp/read.go index 82501278..0215a5b9 100644 --- a/msgp/read.go +++ b/msgp/read.go @@ -143,8 +143,9 @@ type Reader struct { // is stateless; all the // buffering is done // within R. - R *fwd.Reader - scratch []byte + R *fwd.Reader + scratch []byte + recursionDepth int } // Read implements `io.Reader` @@ -190,6 +191,11 @@ func (m *Reader) CopyNext(w io.Writer) (int64, error) { return n, io.ErrShortWrite } + if done, err := m.recursiveCall(); err != nil { + return n, err + } else { + defer done() + } // for maps and slices, read elements for x := uintptr(0); x < o; x++ { var n2 int64 @@ -202,6 +208,18 @@ func (m *Reader) CopyNext(w io.Writer) (int64, error) { return n, nil } +// recursiveCall will increment the recursion depth and return an error if it is exceeded. +// If a nil error is returned, done must be called to decrement the counter. +func (m *Reader) recursiveCall() (done func(), err error) { + if m.recursionDepth >= recursionLimit { + return func() {}, ErrRecursion + } + m.recursionDepth++ + return func() { + m.recursionDepth-- + }, nil +} + // ReadFull implements `io.ReadFull` func (m *Reader) ReadFull(p []byte) (int, error) { return m.R.ReadFull(p) @@ -332,7 +350,12 @@ func (m *Reader) Skip() error { return err } - // for maps and slices, skip elements + // for maps and slices, skip elements with recursive call + if done, err := m.recursiveCall(); err != nil { + return err + } else { + defer done() + } for x := uintptr(0); x < o; x++ { err = m.Skip() if err != nil { @@ -1333,6 +1356,13 @@ func (m *Reader) ReadIntf() (i interface{}, err error) { return case MapType: + // This can call back here, so treat as recursive call. + if done, err := m.recursiveCall(); err != nil { + return nil, err + } else { + defer done() + } + mp := make(map[string]interface{}) err = m.ReadMapStrIntf(mp) i = mp @@ -1358,6 +1388,13 @@ func (m *Reader) ReadIntf() (i interface{}, err error) { if err != nil { return } + + if done, err := m.recursiveCall(); err != nil { + return nil, err + } else { + defer done() + } + out := make([]interface{}, int(sz)) for j := range out { out[j], err = m.ReadIntf() diff --git a/msgp/read_bytes.go b/msgp/read_bytes.go index a204ac4b..948faf1d 100644 --- a/msgp/read_bytes.go +++ b/msgp/read_bytes.go @@ -1095,6 +1095,15 @@ func ReadTimeBytes(b []byte) (t time.Time, o []byte, err error) { // out of 'b' and returns the map and remaining bytes. // If 'old' is non-nil, the values will be read into that map. func ReadMapStrIntfBytes(b []byte, old map[string]interface{}) (v map[string]interface{}, o []byte, err error) { + return readMapStrIntfBytesDepth(b, old, 0) +} + +func readMapStrIntfBytesDepth(b []byte, old map[string]interface{}, depth int) (v map[string]interface{}, o []byte, err error) { + if depth >= recursionLimit { + err = ErrRecursion + return + } + var sz uint32 o = b sz, o, err = ReadMapHeaderBytes(o) @@ -1123,7 +1132,7 @@ func ReadMapStrIntfBytes(b []byte, old map[string]interface{}) (v map[string]int return } var val interface{} - val, o, err = ReadIntfBytes(o) + val, o, err = readIntfBytesDepth(o, depth) if err != nil { return } @@ -1136,6 +1145,14 @@ func ReadMapStrIntfBytes(b []byte, old map[string]interface{}) (v map[string]int // the next object out of 'b' as a raw interface{} and // return the remaining bytes. func ReadIntfBytes(b []byte) (i interface{}, o []byte, err error) { + return readIntfBytesDepth(b, 0) +} + +func readIntfBytesDepth(b []byte, depth int) (i interface{}, o []byte, err error) { + if depth >= recursionLimit { + err = ErrRecursion + return + } if len(b) < 1 { err = ErrShortBytes return @@ -1145,7 +1162,7 @@ func ReadIntfBytes(b []byte) (i interface{}, o []byte, err error) { switch k { case MapType: - i, o, err = ReadMapStrIntfBytes(b, nil) + i, o, err = readMapStrIntfBytesDepth(b, nil, depth+1) return case ArrayType: @@ -1157,7 +1174,7 @@ func ReadIntfBytes(b []byte) (i interface{}, o []byte, err error) { j := make([]interface{}, int(sz)) i = j for d := range j { - j[d], o, err = ReadIntfBytes(o) + j[d], o, err = readIntfBytesDepth(o, depth+1) if err != nil { return } @@ -1245,7 +1262,15 @@ func ReadIntfBytes(b []byte) (i interface{}, o []byte, err error) { // // - [ErrShortBytes] (not enough bytes in b) // - [InvalidPrefixError] (bad encoding) +// - [ErrRecursion] (too deeply nested data) func Skip(b []byte) ([]byte, error) { + return skipDepth(b, 0) +} + +func skipDepth(b []byte, depth int) ([]byte, error) { + if depth >= recursionLimit { + return b, ErrRecursion + } sz, asz, err := getSize(b) if err != nil { return b, err @@ -1255,7 +1280,7 @@ func Skip(b []byte) ([]byte, error) { } b = b[sz:] for asz > 0 { - b, err = Skip(b) + b, err = skipDepth(b, depth+1) if err != nil { return b, err } diff --git a/msgp/read_test.go b/msgp/read_test.go index 86099c85..6c988408 100644 --- a/msgp/read_test.go +++ b/msgp/read_test.go @@ -2,6 +2,7 @@ package msgp import ( "bytes" + "errors" "fmt" "io" "math" @@ -79,6 +80,130 @@ func TestReadIntf(t *testing.T) { } } +func TestReadIntfRecursion(t *testing.T) { + var buf bytes.Buffer + dec := NewReader(&buf) + enc := NewWriter(&buf) + // Test array recursion... + for i := 0; i < recursionLimit*2; i++ { + enc.WriteArrayHeader(1) + } + enc.Flush() + b := buf.Bytes() + _, err := dec.ReadIntf() + if !errors.Is(err, ErrRecursion) { + t.Errorf("unexpected Reader error: %v", err) + } + _, _, err = ReadIntfBytes(b) + if !errors.Is(err, ErrRecursion) { + t.Errorf("unexpected Bytes error: %v", err) + } + // Test JSON + dec.Reset(bytes.NewReader(b)) + _, err = dec.WriteToJSON(io.Discard) + if !errors.Is(err, ErrRecursion) { + t.Errorf("unexpected Reader error: %v", err) + } + _, err = UnmarshalAsJSON(io.Discard, b) + if !errors.Is(err, ErrRecursion) { + t.Errorf("unexpected Bytes error: %v", err) + } + _, err = CopyToJSON(io.Discard, bytes.NewReader(b)) + if !errors.Is(err, ErrRecursion) { + t.Errorf("unexpected Bytes error: %v", err) + } + + // Test map recursion... + buf.Reset() + for i := 0; i < recursionLimit*2; i++ { + enc.WriteMapHeader(1) + // Write a key... + enc.WriteString("a") + } + enc.Flush() + b = buf.Bytes() + dec.Reset(bytes.NewReader(b)) + _, err = dec.ReadIntf() + if !errors.Is(err, ErrRecursion) { + t.Errorf("unexpected Reader error: %v", err) + } + _, _, err = ReadIntfBytes(b) + if !errors.Is(err, ErrRecursion) { + t.Errorf("unexpected Bytes error: %v", err) + } + + // Test ReadMapStrInt using same input + dec.Reset(bytes.NewReader(b)) + err = dec.ReadMapStrIntf(map[string]interface{}{}) + if !errors.Is(err, ErrRecursion) { + t.Errorf("unexpected Reader error: %v", err) + } + _, _, err = ReadMapStrIntfBytes(b, map[string]interface{}{}) + if !errors.Is(err, ErrRecursion) { + t.Errorf("unexpected Bytes error: %v", err) + } + + // Test CopyNext + dec.Reset(bytes.NewReader(b)) + _, err = dec.CopyNext(io.Discard) + if !errors.Is(err, ErrRecursion) { + t.Errorf("unexpected Reader error: %v", err) + } + + // Test JSON + dec.Reset(bytes.NewReader(b)) + _, err = dec.WriteToJSON(io.Discard) + if !errors.Is(err, ErrRecursion) { + t.Errorf("unexpected Reader error: %v", err) + } + _, err = UnmarshalAsJSON(io.Discard, b) + if !errors.Is(err, ErrRecursion) { + t.Errorf("unexpected Bytes error: %v", err) + } + _, err = CopyToJSON(io.Discard, bytes.NewReader(b)) + if !errors.Is(err, ErrRecursion) { + t.Errorf("unexpected Bytes error: %v", err) + } +} + +func TestSkipRecursion(t *testing.T) { + var buf bytes.Buffer + dec := NewReader(&buf) + enc := NewWriter(&buf) + // Test array recursion... + for i := 0; i < recursionLimit*2; i++ { + enc.WriteArrayHeader(1) + } + enc.Flush() + b := buf.Bytes() + err := dec.Skip() + if !errors.Is(err, ErrRecursion) { + t.Errorf("unexpected Reader error: %v", err) + } + _, err = Skip(b) + if !errors.Is(err, ErrRecursion) { + t.Errorf("unexpected Bytes error: %v", err) + } + buf.Reset() + + // Test map recursion... + for i := 0; i < recursionLimit*2; i++ { + enc.WriteMapHeader(1) + // Write a key... + enc.WriteString("a") + } + enc.Flush() + b = buf.Bytes() + err = dec.Skip() + if !errors.Is(err, ErrRecursion) { + t.Errorf("unexpected Reader error: %v", err) + } + _, err = Skip(b) + if !errors.Is(err, ErrRecursion) { + t.Errorf("unexpected Bytes error: %v", err) + } +} + func TestReadMapHeader(t *testing.T) { tests := []struct { Sz uint32 diff --git a/msgp/unsafe.go b/msgp/unsafe.go index 06e8d843..7d36bfb1 100644 --- a/msgp/unsafe.go +++ b/msgp/unsafe.go @@ -1,5 +1,5 @@ -//go:build !purego && !appengine -// +build !purego,!appengine +//go:build (!purego && !appengine) || (!appengine && purego && unsafe) +// +build !purego,!appengine !appengine,purego,unsafe package msgp diff --git a/parse/directives.go b/parse/directives.go index ca565171..620589d6 100644 --- a/parse/directives.go +++ b/parse/directives.go @@ -34,7 +34,8 @@ var directives = map[string]directive{ // to add an early directive, define a func([]string, *FileSet) error // and then add it to this list. var earlyDirectives = map[string]directive{ - "tag": tag, + "tag": tag, + "pointer": pointer, } var passDirectives = map[string]passDirective{ @@ -120,6 +121,7 @@ func replace(text []string, f *FileSet) error { return err } e := f.parseExpr(expr) + e.AlwaysPtr(&f.pointerRcv) if be, ok := e.(*gen.BaseElem); ok { be.Convert = true @@ -178,3 +180,9 @@ func tag(text []string, f *FileSet) error { f.tagName = strings.TrimSpace(text[1]) return nil } + +//msgp:pointer +func pointer(text []string, f *FileSet) error { + f.pointerRcv = true + return nil +} diff --git a/parse/getast.go b/parse/getast.go index f767601c..ef9782f0 100644 --- a/parse/getast.go +++ b/parse/getast.go @@ -22,6 +22,7 @@ type FileSet struct { Directives []string // raw preprocessor directives Imports []*ast.ImportSpec // imports tagName string // tag to read field names from + pointerRcv bool // generate with pointer receivers. } // File parses a file at the relative path @@ -199,6 +200,7 @@ parse: popstate() continue parse } + el.AlwaysPtr(&f.pointerRcv) // push unresolved identities into // the graph of links and resolve after // we've handled every possible named type. diff --git a/printer/print.go b/printer/print.go index e9f0334d..0f31e15c 100644 --- a/printer/print.go +++ b/printer/print.go @@ -42,6 +42,11 @@ func PrintFile(file string, f *parse.FileSet, mode gen.Method) error { } err = <-res if err != nil { + os.WriteFile(file+".broken", out.Bytes(), os.ModePerm) + if Logf != nil { + Logf("Error: %s. Wrote broken output to %s\n", err, file+".broken") + } + return err } return nil