From 4ff26d9fffc560570cc5ae4798b3d55c332cb2ae Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Fri, 23 Aug 2024 01:31:30 -0700 Subject: [PATCH] Add pointer receiver directive (#357) Adds `//msgp:pointer` file directive. This will generate all functions with pointer receivers. Tested with base types. Not tested with various other directives, so there may be quirks. Fixes #332 --- _generated/pointer.go | 198 ++++++++++++++++++++++++++++++++++++++++++ gen/elem.go | 14 ++- gen/encode.go | 11 ++- gen/marshal.go | 13 ++- gen/size.go | 10 ++- parse/directives.go | 10 ++- parse/getast.go | 2 + printer/print.go | 5 ++ 8 files changed, 255 insertions(+), 8 deletions(-) create mode 100644 _generated/pointer.go diff --git a/_generated/pointer.go b/_generated/pointer.go new file mode 100644 index 00000000..9860004d --- /dev/null +++ b/_generated/pointer.go @@ -0,0 +1,198 @@ +package _generated + +import ( + "fmt" + "time" + + "github.com/tinylib/msgp/msgp" +) + +//go:generate msgp $GOFILE$ + +// Generate only pointer receivers: + +//msgp:pointer + +var mustNoInterf = []interface{}{ + Pointer0{}, + NamedBoolPointer(true), + NamedIntPointer(0), + NamedFloat64Pointer(0), + NamedStringPointer(""), + NamedMapStructPointer(nil), + NamedMapStructPointer2(nil), + NamedMapStringPointer(nil), + NamedMapStringPointer2(nil), + EmbeddableStructPointer{}, + EmbeddableStruct2Pointer{}, + PointerHalfFull{}, + PointerNoName{}, +} + +var mustHaveInterf = []interface{}{ + &Pointer0{}, + mustPtr(NamedBoolPointer(true)), + mustPtr(NamedIntPointer(0)), + mustPtr(NamedFloat64Pointer(0)), + mustPtr(NamedStringPointer("")), + mustPtr(NamedMapStructPointer(nil)), + mustPtr(NamedMapStructPointer2(nil)), + mustPtr(NamedMapStringPointer(nil)), + mustPtr(NamedMapStringPointer2(nil)), + &EmbeddableStructPointer{}, + &EmbeddableStruct2Pointer{}, + &PointerHalfFull{}, + &PointerNoName{}, +} + +func mustPtr[T any](v T) *T { + return &v +} + +func init() { + for _, v := range mustNoInterf { + if _, ok := v.(msgp.Marshaler); ok { + panic(fmt.Sprintf("type %T supports interface", v)) + } + if _, ok := v.(msgp.Encodable); ok { + panic(fmt.Sprintf("type %T supports interface", v)) + } + } + for _, v := range mustHaveInterf { + if _, ok := v.(msgp.Marshaler); !ok { + panic(fmt.Sprintf("type %T does not support interface", v)) + } + if _, ok := v.(msgp.Encodable); !ok { + panic(fmt.Sprintf("type %T does not support interface", v)) + } + } +} + +type Pointer0 struct { + ABool bool `msg:"abool"` + AInt int `msg:"aint"` + AInt8 int8 `msg:"aint8"` + AInt16 int16 `msg:"aint16"` + AInt32 int32 `msg:"aint32"` + AInt64 int64 `msg:"aint64"` + AUint uint `msg:"auint"` + AUint8 uint8 `msg:"auint8"` + AUint16 uint16 `msg:"auint16"` + AUint32 uint32 `msg:"auint32"` + AUint64 uint64 `msg:"auint64"` + AFloat32 float32 `msg:"afloat32"` + AFloat64 float64 `msg:"afloat64"` + AComplex64 complex64 `msg:"acomplex64"` + AComplex128 complex128 `msg:"acomplex128"` + + ANamedBool bool `msg:"anamedbool"` + ANamedInt int `msg:"anamedint"` + ANamedFloat64 float64 `msg:"anamedfloat64"` + + AMapStrStr map[string]string `msg:"amapstrstr"` + + APtrNamedStr *NamedString `msg:"aptrnamedstr"` + + AString string `msg:"astring"` + ANamedString string `msg:"anamedstring"` + AByteSlice []byte `msg:"abyteslice"` + + ASliceString []string `msg:"aslicestring"` + ASliceNamedString []NamedString `msg:"aslicenamedstring"` + + ANamedStruct NamedStruct `msg:"anamedstruct"` + APtrNamedStruct *NamedStruct `msg:"aptrnamedstruct"` + + AUnnamedStruct struct { + A string `msg:"a"` + } `msg:"aunnamedstruct"` // omitempty not supported on unnamed struct + + EmbeddableStruct `msg:",flatten"` // embed flat + + EmbeddableStruct2 `msg:"embeddablestruct2"` // embed non-flat + + AArrayInt [5]int `msg:"aarrayint"` // not supported + + ATime time.Time `msg:"atime"` +} + +type ( + NamedBoolPointer bool + NamedIntPointer int + NamedFloat64Pointer float64 + NamedStringPointer string + NamedMapStructPointer map[string]Pointer0 + NamedMapStructPointer2 map[string]*Pointer0 + NamedMapStringPointer map[string]NamedStringPointer + NamedMapStringPointer2 map[string]*NamedStringPointer +) + +type EmbeddableStructPointer struct { + SomeEmbed string `msg:"someembed"` +} + +type EmbeddableStruct2Pointer struct { + SomeEmbed2 string `msg:"someembed2"` +} + +type NamedStructPointer struct { + A string `msg:"a"` + B string `msg:"b"` +} + +type PointerHalfFull struct { + Field00 string `msg:"field00"` + Field01 string `msg:"field01"` + Field02 string `msg:"field02"` + Field03 string `msg:"field03"` +} + +type PointerNoName struct { + ABool bool `msg:""` + AInt int `msg:""` + AInt8 int8 `msg:""` + AInt16 int16 `msg:""` + AInt32 int32 `msg:""` + AInt64 int64 `msg:""` + AUint uint `msg:""` + AUint8 uint8 `msg:""` + AUint16 uint16 `msg:""` + AUint32 uint32 `msg:""` + AUint64 uint64 `msg:""` + AFloat32 float32 `msg:""` + AFloat64 float64 `msg:""` + AComplex64 complex64 `msg:""` + AComplex128 complex128 `msg:""` + + ANamedBool bool `msg:""` + ANamedInt int `msg:""` + ANamedFloat64 float64 `msg:""` + + AMapStrF map[string]NamedFloat64Pointer `msg:""` + AMapStrStruct map[string]PointerHalfFull `msg:""` + AMapStrStruct2 map[string]*PointerHalfFull `msg:""` + + APtrNamedStr *NamedStringPointer `msg:""` + + AString string `msg:""` + AByteSlice []byte `msg:""` + + ASliceString []string `msg:""` + ASliceNamedString []NamedStringPointer `msg:""` + + ANamedStruct NamedStructPointer `msg:""` + APtrNamedStruct *NamedStructPointer `msg:""` + + AUnnamedStruct struct { + A string `msg:""` + } `msg:""` // omitempty not supported on unnamed struct + + EmbeddableStructPointer `msg:",flatten"` // embed flat + + EmbeddableStruct2Pointer `msg:""` // embed non-flat + + AArrayInt [5]int `msg:""` // not supported + + ATime time.Time `msg:""` + ADur time.Duration `msg:""` +} diff --git a/gen/elem.go b/gen/elem.go index c2c84b65..d397cbed 100644 --- a/gen/elem.go +++ b/gen/elem.go @@ -134,13 +134,22 @@ var builtins = map[string]struct{}{ } // common data/methods for every Elem -type common struct{ vname, alias string } +type common struct { + vname, alias string + ptrRcv bool +} func (c *common) SetVarname(s string) { c.vname = s } func (c *common) Varname() string { return c.vname } func (c *common) Alias(typ string) { c.alias = typ } func (c *common) hidden() {} func (c *common) AllowNil() bool { return false } +func (c *common) AlwaysPtr(set *bool) bool { + if c != nil && set != nil { + c.ptrRcv = *set + } + return c.ptrRcv +} func IsPrintable(e Elem) bool { if be, ok := e.(*BaseElem); ok && !be.Printable() { @@ -191,6 +200,9 @@ type Elem interface { // This is true for slices and maps. AllowNil() bool + // AlwaysPtr will return true if receiver should always be a pointer. + AlwaysPtr(set *bool) bool + // IfZeroExpr returns the expression to compare to an empty value // for this type, per the rules of the `omitempty` feature. // It is meant to be used in an if statement diff --git a/gen/encode.go b/gen/encode.go index 900c847b..af83e456 100644 --- a/gen/encode.go +++ b/gen/encode.go @@ -62,9 +62,16 @@ func (e *encodeGen) Execute(p Elem) error { e.ctx = &Context{} e.p.comment("EncodeMsg implements msgp.Encodable") - - e.p.printf("\nfunc (%s %s) EncodeMsg(en *msgp.Writer) (err error) {", p.Varname(), imutMethodReceiver(p)) + rcv := imutMethodReceiver(p) + ogVar := p.Varname() + if p.AlwaysPtr(nil) { + rcv = methodReceiver(p) + } + e.p.printf("\nfunc (%s %s) EncodeMsg(en *msgp.Writer) (err error) {", ogVar, rcv) next(e, p) + if p.AlwaysPtr(nil) { + p.SetVarname(ogVar) + } e.p.nakedReturn() return e.p.err } diff --git a/gen/marshal.go b/gen/marshal.go index 66f280eb..5b94ff39 100644 --- a/gen/marshal.go +++ b/gen/marshal.go @@ -47,10 +47,18 @@ func (m *marshalGen) Execute(p Elem) error { // calling methodReceiver so // that z.Msgsize() is printed correctly c := p.Varname() - - m.p.printf("\nfunc (%s %s) MarshalMsg(b []byte) (o []byte, err error) {", p.Varname(), imutMethodReceiver(p)) + rcv := imutMethodReceiver(p) + ogVar := p.Varname() + if p.AlwaysPtr(nil) { + rcv = methodReceiver(p) + } + m.p.printf("\nfunc (%s %s) MarshalMsg(b []byte) (o []byte, err error) {", ogVar, rcv) m.p.printf("\no = msgp.Require(b, %s.Msgsize())", c) next(m, p) + if p.AlwaysPtr(nil) { + p.SetVarname(ogVar) + } + m.p.nakedReturn() return m.p.err } @@ -280,7 +288,6 @@ func (m *marshalGen) gBase(b *BaseElem) { } m.fuseHook() vname := b.Varname() - if b.Convert { if b.ShimMode == Cast { vname = tobaseConvert(b) diff --git a/gen/size.go b/gen/size.go index e96e0319..3798532f 100644 --- a/gen/size.go +++ b/gen/size.go @@ -86,9 +86,17 @@ func (s *sizeGen) Execute(p Elem) error { s.p.comment("Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message") - s.p.printf("\nfunc (%s %s) Msgsize() (s int) {", p.Varname(), imutMethodReceiver(p)) + rcv := imutMethodReceiver(p) + ogVar := p.Varname() + if p.AlwaysPtr(nil) { + rcv = methodReceiver(p) + } + s.p.printf("\nfunc (%s %s) Msgsize() (s int) {", ogVar, rcv) s.state = assign next(s, p) + if p.AlwaysPtr(nil) { + p.SetVarname(ogVar) + } s.p.nakedReturn() return s.p.err } diff --git a/parse/directives.go b/parse/directives.go index ca565171..620589d6 100644 --- a/parse/directives.go +++ b/parse/directives.go @@ -34,7 +34,8 @@ var directives = map[string]directive{ // to add an early directive, define a func([]string, *FileSet) error // and then add it to this list. var earlyDirectives = map[string]directive{ - "tag": tag, + "tag": tag, + "pointer": pointer, } var passDirectives = map[string]passDirective{ @@ -120,6 +121,7 @@ func replace(text []string, f *FileSet) error { return err } e := f.parseExpr(expr) + e.AlwaysPtr(&f.pointerRcv) if be, ok := e.(*gen.BaseElem); ok { be.Convert = true @@ -178,3 +180,9 @@ func tag(text []string, f *FileSet) error { f.tagName = strings.TrimSpace(text[1]) return nil } + +//msgp:pointer +func pointer(text []string, f *FileSet) error { + f.pointerRcv = true + return nil +} diff --git a/parse/getast.go b/parse/getast.go index f767601c..ef9782f0 100644 --- a/parse/getast.go +++ b/parse/getast.go @@ -22,6 +22,7 @@ type FileSet struct { Directives []string // raw preprocessor directives Imports []*ast.ImportSpec // imports tagName string // tag to read field names from + pointerRcv bool // generate with pointer receivers. } // File parses a file at the relative path @@ -199,6 +200,7 @@ parse: popstate() continue parse } + el.AlwaysPtr(&f.pointerRcv) // push unresolved identities into // the graph of links and resolve after // we've handled every possible named type. diff --git a/printer/print.go b/printer/print.go index e9f0334d..0f31e15c 100644 --- a/printer/print.go +++ b/printer/print.go @@ -42,6 +42,11 @@ func PrintFile(file string, f *parse.FileSet, mode gen.Method) error { } err = <-res if err != nil { + os.WriteFile(file+".broken", out.Bytes(), os.ModePerm) + if Logf != nil { + Logf("Error: %s. Wrote broken output to %s\n", err, file+".broken") + } + return err } return nil