Skip to content

Commit

Permalink
Add pointer receiver directive (#357)
Browse files Browse the repository at this point in the history
Adds `//msgp:pointer` file directive. This will generate all functions with pointer receivers.

Tested with base types. Not tested with various other directives, so there may be quirks.

Fixes #332
  • Loading branch information
klauspost authored Aug 23, 2024
1 parent f80292a commit 4ff26d9
Show file tree
Hide file tree
Showing 8 changed files with 255 additions and 8 deletions.
198 changes: 198 additions & 0 deletions _generated/pointer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
package _generated

import (
"fmt"
"time"

"github.com/tinylib/msgp/msgp"
)

//go:generate msgp $GOFILE$

// Generate only pointer receivers:

//msgp:pointer

var mustNoInterf = []interface{}{
Pointer0{},
NamedBoolPointer(true),
NamedIntPointer(0),
NamedFloat64Pointer(0),
NamedStringPointer(""),
NamedMapStructPointer(nil),
NamedMapStructPointer2(nil),
NamedMapStringPointer(nil),
NamedMapStringPointer2(nil),
EmbeddableStructPointer{},
EmbeddableStruct2Pointer{},
PointerHalfFull{},
PointerNoName{},
}

var mustHaveInterf = []interface{}{
&Pointer0{},
mustPtr(NamedBoolPointer(true)),
mustPtr(NamedIntPointer(0)),
mustPtr(NamedFloat64Pointer(0)),
mustPtr(NamedStringPointer("")),
mustPtr(NamedMapStructPointer(nil)),
mustPtr(NamedMapStructPointer2(nil)),
mustPtr(NamedMapStringPointer(nil)),
mustPtr(NamedMapStringPointer2(nil)),
&EmbeddableStructPointer{},
&EmbeddableStruct2Pointer{},
&PointerHalfFull{},
&PointerNoName{},
}

func mustPtr[T any](v T) *T {
return &v
}

func init() {
for _, v := range mustNoInterf {
if _, ok := v.(msgp.Marshaler); ok {
panic(fmt.Sprintf("type %T supports interface", v))
}
if _, ok := v.(msgp.Encodable); ok {
panic(fmt.Sprintf("type %T supports interface", v))
}
}
for _, v := range mustHaveInterf {
if _, ok := v.(msgp.Marshaler); !ok {
panic(fmt.Sprintf("type %T does not support interface", v))
}
if _, ok := v.(msgp.Encodable); !ok {
panic(fmt.Sprintf("type %T does not support interface", v))
}
}
}

type Pointer0 struct {
ABool bool `msg:"abool"`
AInt int `msg:"aint"`
AInt8 int8 `msg:"aint8"`
AInt16 int16 `msg:"aint16"`
AInt32 int32 `msg:"aint32"`
AInt64 int64 `msg:"aint64"`
AUint uint `msg:"auint"`
AUint8 uint8 `msg:"auint8"`
AUint16 uint16 `msg:"auint16"`
AUint32 uint32 `msg:"auint32"`
AUint64 uint64 `msg:"auint64"`
AFloat32 float32 `msg:"afloat32"`
AFloat64 float64 `msg:"afloat64"`
AComplex64 complex64 `msg:"acomplex64"`
AComplex128 complex128 `msg:"acomplex128"`

ANamedBool bool `msg:"anamedbool"`
ANamedInt int `msg:"anamedint"`
ANamedFloat64 float64 `msg:"anamedfloat64"`

AMapStrStr map[string]string `msg:"amapstrstr"`

APtrNamedStr *NamedString `msg:"aptrnamedstr"`

AString string `msg:"astring"`
ANamedString string `msg:"anamedstring"`
AByteSlice []byte `msg:"abyteslice"`

ASliceString []string `msg:"aslicestring"`
ASliceNamedString []NamedString `msg:"aslicenamedstring"`

ANamedStruct NamedStruct `msg:"anamedstruct"`
APtrNamedStruct *NamedStruct `msg:"aptrnamedstruct"`

AUnnamedStruct struct {
A string `msg:"a"`
} `msg:"aunnamedstruct"` // omitempty not supported on unnamed struct

EmbeddableStruct `msg:",flatten"` // embed flat

EmbeddableStruct2 `msg:"embeddablestruct2"` // embed non-flat

AArrayInt [5]int `msg:"aarrayint"` // not supported

ATime time.Time `msg:"atime"`
}

type (
NamedBoolPointer bool
NamedIntPointer int
NamedFloat64Pointer float64
NamedStringPointer string
NamedMapStructPointer map[string]Pointer0
NamedMapStructPointer2 map[string]*Pointer0
NamedMapStringPointer map[string]NamedStringPointer
NamedMapStringPointer2 map[string]*NamedStringPointer
)

type EmbeddableStructPointer struct {
SomeEmbed string `msg:"someembed"`
}

type EmbeddableStruct2Pointer struct {
SomeEmbed2 string `msg:"someembed2"`
}

type NamedStructPointer struct {
A string `msg:"a"`
B string `msg:"b"`
}

type PointerHalfFull struct {
Field00 string `msg:"field00"`
Field01 string `msg:"field01"`
Field02 string `msg:"field02"`
Field03 string `msg:"field03"`
}

type PointerNoName struct {
ABool bool `msg:""`
AInt int `msg:""`
AInt8 int8 `msg:""`
AInt16 int16 `msg:""`
AInt32 int32 `msg:""`
AInt64 int64 `msg:""`
AUint uint `msg:""`
AUint8 uint8 `msg:""`
AUint16 uint16 `msg:""`
AUint32 uint32 `msg:""`
AUint64 uint64 `msg:""`
AFloat32 float32 `msg:""`
AFloat64 float64 `msg:""`
AComplex64 complex64 `msg:""`
AComplex128 complex128 `msg:""`

ANamedBool bool `msg:""`
ANamedInt int `msg:""`
ANamedFloat64 float64 `msg:""`

AMapStrF map[string]NamedFloat64Pointer `msg:""`
AMapStrStruct map[string]PointerHalfFull `msg:""`
AMapStrStruct2 map[string]*PointerHalfFull `msg:""`

APtrNamedStr *NamedStringPointer `msg:""`

AString string `msg:""`
AByteSlice []byte `msg:""`

ASliceString []string `msg:""`
ASliceNamedString []NamedStringPointer `msg:""`

ANamedStruct NamedStructPointer `msg:""`
APtrNamedStruct *NamedStructPointer `msg:""`

AUnnamedStruct struct {
A string `msg:""`
} `msg:""` // omitempty not supported on unnamed struct

EmbeddableStructPointer `msg:",flatten"` // embed flat

EmbeddableStruct2Pointer `msg:""` // embed non-flat

AArrayInt [5]int `msg:""` // not supported

ATime time.Time `msg:""`
ADur time.Duration `msg:""`
}
14 changes: 13 additions & 1 deletion gen/elem.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,22 @@ var builtins = map[string]struct{}{
}

// common data/methods for every Elem
type common struct{ vname, alias string }
type common struct {
vname, alias string
ptrRcv bool
}

func (c *common) SetVarname(s string) { c.vname = s }
func (c *common) Varname() string { return c.vname }
func (c *common) Alias(typ string) { c.alias = typ }
func (c *common) hidden() {}
func (c *common) AllowNil() bool { return false }
func (c *common) AlwaysPtr(set *bool) bool {
if c != nil && set != nil {
c.ptrRcv = *set
}
return c.ptrRcv
}

func IsPrintable(e Elem) bool {
if be, ok := e.(*BaseElem); ok && !be.Printable() {
Expand Down Expand Up @@ -191,6 +200,9 @@ type Elem interface {
// This is true for slices and maps.
AllowNil() bool

// AlwaysPtr will return true if receiver should always be a pointer.
AlwaysPtr(set *bool) bool

// IfZeroExpr returns the expression to compare to an empty value
// for this type, per the rules of the `omitempty` feature.
// It is meant to be used in an if statement
Expand Down
11 changes: 9 additions & 2 deletions gen/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,16 @@ func (e *encodeGen) Execute(p Elem) error {
e.ctx = &Context{}

e.p.comment("EncodeMsg implements msgp.Encodable")

e.p.printf("\nfunc (%s %s) EncodeMsg(en *msgp.Writer) (err error) {", p.Varname(), imutMethodReceiver(p))
rcv := imutMethodReceiver(p)
ogVar := p.Varname()
if p.AlwaysPtr(nil) {
rcv = methodReceiver(p)
}
e.p.printf("\nfunc (%s %s) EncodeMsg(en *msgp.Writer) (err error) {", ogVar, rcv)
next(e, p)
if p.AlwaysPtr(nil) {
p.SetVarname(ogVar)
}
e.p.nakedReturn()
return e.p.err
}
Expand Down
13 changes: 10 additions & 3 deletions gen/marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,18 @@ func (m *marshalGen) Execute(p Elem) error {
// calling methodReceiver so
// that z.Msgsize() is printed correctly
c := p.Varname()

m.p.printf("\nfunc (%s %s) MarshalMsg(b []byte) (o []byte, err error) {", p.Varname(), imutMethodReceiver(p))
rcv := imutMethodReceiver(p)
ogVar := p.Varname()
if p.AlwaysPtr(nil) {
rcv = methodReceiver(p)
}
m.p.printf("\nfunc (%s %s) MarshalMsg(b []byte) (o []byte, err error) {", ogVar, rcv)
m.p.printf("\no = msgp.Require(b, %s.Msgsize())", c)
next(m, p)
if p.AlwaysPtr(nil) {
p.SetVarname(ogVar)
}

m.p.nakedReturn()
return m.p.err
}
Expand Down Expand Up @@ -280,7 +288,6 @@ func (m *marshalGen) gBase(b *BaseElem) {
}
m.fuseHook()
vname := b.Varname()

if b.Convert {
if b.ShimMode == Cast {
vname = tobaseConvert(b)
Expand Down
10 changes: 9 additions & 1 deletion gen/size.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,17 @@ func (s *sizeGen) Execute(p Elem) error {

s.p.comment("Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message")

s.p.printf("\nfunc (%s %s) Msgsize() (s int) {", p.Varname(), imutMethodReceiver(p))
rcv := imutMethodReceiver(p)
ogVar := p.Varname()
if p.AlwaysPtr(nil) {
rcv = methodReceiver(p)
}
s.p.printf("\nfunc (%s %s) Msgsize() (s int) {", ogVar, rcv)
s.state = assign
next(s, p)
if p.AlwaysPtr(nil) {
p.SetVarname(ogVar)
}
s.p.nakedReturn()
return s.p.err
}
Expand Down
10 changes: 9 additions & 1 deletion parse/directives.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ var directives = map[string]directive{
// to add an early directive, define a func([]string, *FileSet) error
// and then add it to this list.
var earlyDirectives = map[string]directive{
"tag": tag,
"tag": tag,
"pointer": pointer,
}

var passDirectives = map[string]passDirective{
Expand Down Expand Up @@ -120,6 +121,7 @@ func replace(text []string, f *FileSet) error {
return err
}
e := f.parseExpr(expr)
e.AlwaysPtr(&f.pointerRcv)

if be, ok := e.(*gen.BaseElem); ok {
be.Convert = true
Expand Down Expand Up @@ -178,3 +180,9 @@ func tag(text []string, f *FileSet) error {
f.tagName = strings.TrimSpace(text[1])
return nil
}

//msgp:pointer
func pointer(text []string, f *FileSet) error {
f.pointerRcv = true
return nil
}
2 changes: 2 additions & 0 deletions parse/getast.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ type FileSet struct {
Directives []string // raw preprocessor directives
Imports []*ast.ImportSpec // imports
tagName string // tag to read field names from
pointerRcv bool // generate with pointer receivers.
}

// File parses a file at the relative path
Expand Down Expand Up @@ -199,6 +200,7 @@ parse:
popstate()
continue parse
}
el.AlwaysPtr(&f.pointerRcv)
// push unresolved identities into
// the graph of links and resolve after
// we've handled every possible named type.
Expand Down
5 changes: 5 additions & 0 deletions printer/print.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ func PrintFile(file string, f *parse.FileSet, mode gen.Method) error {
}
err = <-res
if err != nil {
os.WriteFile(file+".broken", out.Bytes(), os.ModePerm)
if Logf != nil {
Logf("Error: %s. Wrote broken output to %s\n", err, file+".broken")
}

return err
}
return nil
Expand Down

0 comments on commit 4ff26d9

Please sign in to comment.