Skip to content

Commit

Permalink
Merge pull request #61 from nikhita/fix-recursive-types
Browse files Browse the repository at this point in the history
Fix defaulter-gen for recursive types
  • Loading branch information
lavalamp authored Jul 6, 2017
2 parents cc8100b + 96fe5d0 commit d354881
Showing 1 changed file with 40 additions and 11 deletions.
51 changes: 40 additions & 11 deletions examples/defaulter-gen/generators/defaulter.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ func Packages(context *generator.Context, arguments *args.GeneratorArgs) generat
if d.object != nil {
continue
}
if buildCallTreeForType(t, true, existingDefaulters, newDefaulters) != nil {
if newCallTreeForType(existingDefaulters, newDefaulters).build(t, true) != nil {
args := defaultingArgsFromType(t)
sw.Do("$.inType|objectdefaultfn$", args)
newDefaulters[t] = defaults{
Expand Down Expand Up @@ -387,7 +387,22 @@ func Packages(context *generator.Context, arguments *args.GeneratorArgs) generat
return packages
}

// buildCallTreeForType creates a tree of paths to fields (based on how they would be accessed in Go - pointer, elem,
// callTreeForType contains fields necessary to build a tree for types.
type callTreeForType struct {
existingDefaulters defaulterFuncMap
newDefaulters defaulterFuncMap
currentlyBuildingTypes map[*types.Type]bool
}

func newCallTreeForType(existingDefaulters, newDefaulters defaulterFuncMap) *callTreeForType {
return &callTreeForType{
existingDefaulters: existingDefaulters,
newDefaulters: newDefaulters,
currentlyBuildingTypes: make(map[*types.Type]bool),
}
}

// build creates a tree of paths to fields (based on how they would be accessed in Go - pointer, elem,
// slice, or key) and the functions that should be invoked on each field. An in-order traversal of the resulting tree
// can be used to generate a Go function that invokes each nested function on the appropriate type. The return
// value may be nil if there are no functions to call on type or the type is a primitive (Defaulters can only be
Expand All @@ -396,16 +411,16 @@ func Packages(context *generator.Context, arguments *args.GeneratorArgs) generat
// that could be or will be generated. If newDefaulters has an entry for a type, but the 'object' field is nil,
// this function skips adding that defaulter - this allows us to avoid generating object defaulter functions for
// list types that call empty defaulters.
func buildCallTreeForType(t *types.Type, root bool, existingDefaulters, newDefaulters defaulterFuncMap) *callNode {
func (c *callTreeForType) build(t *types.Type, root bool) *callNode {
parent := &callNode{}

if root {
// the root node is always a pointer
parent.elem = true
}

defaults, _ := existingDefaulters[t]
newDefaults, generated := newDefaulters[t]
defaults, _ := c.existingDefaulters[t]
newDefaults, generated := c.newDefaulters[t]
switch {
case !root && generated && newDefaults.object != nil:
parent.call = append(parent.call, newDefaults.object)
Expand All @@ -432,19 +447,33 @@ func buildCallTreeForType(t *types.Type, root bool, existingDefaulters, newDefau
// base has been added already, now add any additional defaulters defined for this object
parent.call = append(parent.call, defaults.additional...)

// if the type already exists, don't build the tree for it and don't generate anything.
// This is used to avoid recursion for nested recursive types.
if c.currentlyBuildingTypes[t] {
return nil
}
// if type doesn't exist, mark it as existing
c.currentlyBuildingTypes[t] = true

defer func() {
// The type will now acts as a parent, not a nested recursive type.
// We can now build the tree for it safely.
c.currentlyBuildingTypes[t] = false
}()

switch t.Kind {
case types.Pointer:
if child := buildCallTreeForType(t.Elem, false, existingDefaulters, newDefaulters); child != nil {
if child := c.build(t.Elem, false); child != nil {
child.elem = true
parent.children = append(parent.children, *child)
}
case types.Slice, types.Array:
if child := buildCallTreeForType(t.Elem, false, existingDefaulters, newDefaulters); child != nil {
if child := c.build(t.Elem, false); child != nil {
child.index = true
parent.children = append(parent.children, *child)
}
case types.Map:
if child := buildCallTreeForType(t.Elem, false, existingDefaulters, newDefaulters); child != nil {
if child := c.build(t.Elem, false); child != nil {
child.key = true
parent.children = append(parent.children, *child)
}
Expand All @@ -458,13 +487,13 @@ func buildCallTreeForType(t *types.Type, root bool, existingDefaulters, newDefau
name = field.Type.Name.Name
}
}
if child := buildCallTreeForType(field.Type, false, existingDefaulters, newDefaulters); child != nil {
if child := c.build(field.Type, false); child != nil {
child.field = name
parent.children = append(parent.children, *child)
}
}
case types.Alias:
if child := buildCallTreeForType(t.Underlying, false, existingDefaulters, newDefaulters); child != nil {
if child := c.build(t.Underlying, false); child != nil {
parent.children = append(parent.children, *child)
}
}
Expand Down Expand Up @@ -571,7 +600,7 @@ func (g *genDefaulter) GenerateType(c *generator.Context, t *types.Type, w io.Wr

glog.V(5).Infof("generating for type %v", t)

callTree := buildCallTreeForType(t, true, g.existingDefaulters, g.newDefaulters)
callTree := newCallTreeForType(g.existingDefaulters, g.newDefaulters).build(t, true)
if callTree == nil {
glog.V(5).Infof(" no defaulters defined")
return nil
Expand Down

0 comments on commit d354881

Please sign in to comment.