Skip to content

Commit

Permalink
Allow enum unions (#84)
Browse files Browse the repository at this point in the history
* Allow enum unions

* Little refactor

* Avoid to change the order of the generate elements

* Add nested enums
  • Loading branch information
spinillos authored Mar 30, 2023
1 parent 4950d4e commit 325fc6c
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 53 deletions.
163 changes: 112 additions & 51 deletions generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,15 +214,15 @@ func (g *generator) genType(name string, v cue.Value) []ts.Decl {
switch op {
case cue.OrOp:
for _, dv := range dvals {
tok, err := tsprintField(dv, true)
tok, err := g.tsprintField(dv, true)
if err != nil {
g.addErr(err)
return nil
}
tokens = append(tokens, tok)
}
case cue.NoOp, cue.RegexMatchOp:
tok, err := tsprintField(v, true)
tok, err := g.tsprintField(v, true)
if err != nil {
g.addErr(err)
return nil
Expand All @@ -246,7 +246,7 @@ func (g *generator) genType(name string, v cue.Value) []ts.Decl {
return ret[:1]
}

val, err := tsprintField(d, false)
val, err := g.tsprintField(d, false)
g.addErr(err)

def := tsast.VarDecl{
Expand Down Expand Up @@ -659,13 +659,13 @@ func (g *generator) genInterfaceField(v cue.Value) (*typeRef, error) {
var err error
// One path for when there's a ref to a cuetsy node, and a separate one otherwise
if !containsCuetsyReference(v) {
tref.T, err = tsprintField(v, true)
tref.T, err = g.tsprintField(v, true)
if err != nil {
g.addErr(valError(v, "could not generate field: %w", err))
return nil, err
}
} else {
expr, err := tsprintField(v, true)
expr, err := g.tsprintField(v, true)
if err != nil {
g.addErr(err)
return nil, nil
Expand Down Expand Up @@ -694,7 +694,7 @@ func (g *generator) genInterfaceField(v cue.Value) (*typeRef, error) {
// g.addErr(valError(v, "invalid value"))
// return nil
// }
// expr, err = tsprintField(disjuncts[0])
// expr, err = g.tsprintField(disjuncts[0])
// if err != nil {
// g.addErr(valError(v, "invalid value"))
// return nil
Expand All @@ -705,7 +705,7 @@ func (g *generator) genInterfaceField(v cue.Value) (*typeRef, error) {
// }
}

exists, defExpr, err := tsPrintDefault(v)
exists, defExpr, err := g.tsPrintDefault(v)
if exists {
tref.D = defExpr
}
Expand Down Expand Up @@ -749,30 +749,15 @@ func hasEnumReference(v cue.Value) bool {
func (g *generator) genEnumReference(v cue.Value) (*typeRef, error) {
var lit *cue.Value

findIdent := func(ev, tv cue.Value) (*tsast.Ident, error) {
if ev.Subsume(tv) != nil {
err := valError(v, "may only apply values to an enum that are members of that enum; %#v is not a member of %#v", tv, ev)
g.addErr(err)
return nil, err
}
pairs, err := enumPairs(ev)
if err != nil {
return nil, err
}
for _, pair := range pairs {
if veq(pair.val, tv) {
return &tsast.Ident{Name: pair.name}, nil
}
}

panic(fmt.Sprintf("unreachable - %#v not equal to any member of %#v, but should have been caught by subsume check", tv, ev))
}

conjuncts := appendSplit(nil, cue.AndOp, v)
var enumUnions map[cue.Value]cue.Value
switch len(conjuncts) {
case 0:
panic("unreachable")
case 1:
// This case is when we have a union of enums which we need to iterate them to get their values or has a default value.
// It retrieves a list of literals with their references.
enumUnions = g.findEnumUnions(v)
case 2:
var err error
conjuncts[1] = getDefaultEnumValue(conjuncts[1])
Expand Down Expand Up @@ -809,15 +794,15 @@ func (g *generator) genEnumReference(v cue.Value) (*typeRef, error) {
// Search the expr tree for the actual enum. This approach is uncomfortable
// without having the assurance that there aren't more than one possible match/a
// guarantee from the CUE API of a stable, deterministic search order, etc.
ev, referrer, has := findRefWithKind(v, TypeEnum)
enumValues, referrer, has := findRefWithKind(v, TypeEnum)
if !has {
ve := valError(v, "does not reference a field with a cuetsy enum attribute")
g.addErr(ve)
return nil, fmt.Errorf("no enum attr in %s", v)
}

var err error
decls := g.genEnum("foo", ev)
decls := g.genEnum("foo", enumValues)
ref := &typeRef{}

// Construct the type component of the reference
Expand All @@ -827,7 +812,7 @@ func (g *generator) genEnumReference(v cue.Value) (*typeRef, error) {
g.addErr(ve)
return nil, ve
case 1, 2:
ref.T, err = referenceValueAs(referrer)
ref.T, err = referenceValueAs(referrer, TypeEnum)
if err != nil {
panic(err)
}
Expand All @@ -838,19 +823,34 @@ func (g *generator) genEnumReference(v cue.Value) (*typeRef, error) {
switch len(conjuncts) {
case 1:
if defv, hasdef := v.Default(); hasdef {
if defaultIdent, err := findIdent(ev, defv); err == nil {
ref.D = tsast.SelectorExpr{Expr: ref.T, Sel: *defaultIdent}
} else {
return nil, err
}
err = g.findIdent(v, enumValues, defv, func(expr tsast.Ident) {
ref.D = tsast.SelectorExpr{Expr: ref.T, Sel: expr}
})
}
if len(enumUnions) == 0 {
break
}
var elements []tsast.Expr
for lit, enumValues := range enumUnions {
err = g.findIdent(v, enumValues, lit, func(ident tsast.Ident) {
elements = append(elements, tsast.SelectorExpr{
Expr: ref.T,
Sel: ident,
})
})
}

// To avoid to change the order of the elements everytime that we generate the code.
sort.Slice(elements, func(i, j int) bool {
return elements[i].String() < elements[j].String()
})

ref.T = ts.Union(elements...)
case 2, 3:
var rr tsast.Expr
if defaultIdent, err := findIdent(ev, *lit); err == nil {
rr = tsast.SelectorExpr{Expr: ref.T, Sel: *defaultIdent}
} else {
return nil, err
}
err = g.findIdent(v, enumValues, *lit, func(ident tsast.Ident) {
rr = tsast.SelectorExpr{Expr: ref.T, Sel: ident}
})

op, args := v.Expr()
hasInnerDefault := false
Expand All @@ -865,7 +865,64 @@ func (g *generator) genEnumReference(v cue.Value) (*typeRef, error) {
}
}

return ref, nil
return ref, err
}

// findEnumUnions find the unions between enums like (#Enum & "a") | (#Enum & "b")
func (g generator) findEnumUnions(v cue.Value) map[cue.Value]cue.Value {
op, values := v.Expr()
if op != cue.OrOp {
return nil
}

enumsWithUnions := make(map[cue.Value]cue.Value, len(values))
for _, val := range values {
conjuncts := appendSplit(nil, cue.AndOp, val)
if len(conjuncts) != 2 {
return nil
}
cr, lit := conjuncts[0], conjuncts[1]
if cr.Subsume(lit) != nil {
return nil
}

switch val.Kind() {
case cue.StringKind, cue.IntKind:
enumValues, _, has := findRefWithKind(v, TypeEnum)
if !has {
return nil
}
enumsWithUnions[lit] = enumValues
default:
_, vals := val.Expr()
if len(vals) > 1 {
panic(fmt.Sprintf("%s.%s isn't a valid enum value", val.Path().String(), vals[1]))
}
panic(fmt.Sprintf("Invalid value in path %s", val.Path().String()))
}
}

return enumsWithUnions
}

func (g generator) findIdent(v, ev, tv cue.Value, fn func(tsast.Ident)) error {
if ev.Subsume(tv) != nil {
err := valError(v, "may only apply values to an enum that are members of that enum; %#v is not a member of %#v", tv, ev)
g.addErr(err)
return err
}
pairs, err := enumPairs(ev)
if err != nil {
return err
}
for _, pair := range pairs {
if veq(pair.val, tv) {
fn(tsast.Ident{Name: pair.name})
return nil
}
}

panic(fmt.Sprintf("unreachable - %#v not equal to any member of %#v, but should have been caught by subsume check", tv, ev))
}

func getEnumLiteral(conjuncts []cue.Value) (*cue.Value, error) {
Expand Down Expand Up @@ -924,7 +981,7 @@ type typeRef struct {
D ts.Expr
}

func tsPrintDefault(v cue.Value) (bool, ts.Expr, error) {
func (g generator) tsPrintDefault(v cue.Value) (bool, ts.Expr, error) {
d, ok := v.Default()
// [...number] results in [], which is a fake default, we need to correct it here.
// if ok && d.Kind() == cue.ListKind {
Expand Down Expand Up @@ -953,7 +1010,7 @@ func tsPrintDefault(v cue.Value) (bool, ts.Expr, error) {
// }

if ok {
expr, err := tsprintField(d, false)
expr, err := g.tsprintField(d, false)
if err != nil {
return false, nil, err
}
Expand All @@ -978,14 +1035,19 @@ func tsPrintDefault(v cue.Value) (bool, ts.Expr, error) {

// Render a string containing a Typescript semantic equivalent to the provided
// Value for placement in a single field, if possible.
func tsprintField(v cue.Value, isType bool) (ts.Expr, error) {
func (g generator) tsprintField(v cue.Value, isType bool) (ts.Expr, error) {
// Let the forceText attribute supersede everything.
if ft := getForceText(v); ft != "" {
return ts.Raw(ft), nil
}

if hasEnumReference(v) {
ref, err := g.genEnumReference(v)
return ref.T, err
}

// References are orthogonal to the Kind system. Handle them first.
if containsCuetsyReference(v, TypeAlias, TypeInterface) || hasEnumReference(v) {
if containsCuetsyReference(v, TypeAlias, TypeInterface) {
ref, err := referenceValueAs(v)
if err != nil {
return nil, err
Expand Down Expand Up @@ -1015,7 +1077,7 @@ func tsprintField(v cue.Value, isType bool) (ts.Expr, error) {
// It skips structs like {...} (cue.TopKind) to avoid undesired results.
val := v.LookupPath(cue.MakePath(cue.AnyString))
if val.Exists() && val.IncompleteKind() != cue.TopKind {
expr, err := tsprintField(val, isType)
expr, err := g.tsprintField(val, isType)
if err != nil {
return nil, valError(v, err.Error())
}
Expand All @@ -1036,7 +1098,7 @@ func tsprintField(v cue.Value, isType bool) (ts.Expr, error) {
size, _ := v.Len().Int64()
kvs := make([]tsast.KeyValueExpr, 0, size)
for iter.Next() {
expr, err := tsprintField(iter.Value(), isType)
expr, err := g.tsprintField(iter.Value(), isType)
if err != nil {
return nil, valError(v, err.Error())
}
Expand Down Expand Up @@ -1066,7 +1128,7 @@ func tsprintField(v cue.Value, isType bool) (ts.Expr, error) {
iter, _ := v.List()
var elems []ts.Expr
for iter.Next() {
e, err := tsprintField(iter.Value(), isType)
e, err := g.tsprintField(iter.Value(), isType)
if err != nil {
return nil, err
}
Expand All @@ -1078,12 +1140,11 @@ func tsprintField(v cue.Value, isType bool) (ts.Expr, error) {
case cue.BytesKind:
return nil, valError(v, "bytes have no equivalent in Typescript; use double-quotes (string) instead")
}

// Handler for disjunctions
disj := func(dvals []cue.Value) (ts.Expr, error) {
parts := make([]ts.Expr, 0, len(dvals))
for _, dv := range dvals {
p, err := tsprintField(dv, isType)
p, err := g.tsprintField(dv, isType)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1132,7 +1193,7 @@ func tsprintField(v cue.Value, isType bool) (ts.Expr, error) {

e := v.LookupPath(cue.MakePath(cue.AnyIndex))
if e.Exists() {
expr, err := tsprintField(e, isType)
expr, err := g.tsprintField(e, isType)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1171,7 +1232,7 @@ func tsprintField(v cue.Value, isType bool) (ts.Expr, error) {
switch op {
case cue.OrOp:
if len(dvals) == 2 && dvals[0].Kind() == cue.NullKind {
return tsprintField(dvals[1], isType)
return g.tsprintField(dvals[1], isType)
}
return disj(dvals)
case cue.NoOp, cue.AndOp:
Expand Down
4 changes: 2 additions & 2 deletions testdata/imports/compose_enums.txtar
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ export interface Compose {
localstr: LocalEnum;
localstrd: LocalEnumD;
localstrover: LocalEnumD;
union: dep.DepEnumNumericD;
unionStrings: LocalEnum;
union: (dep.DepEnumNumericD.Three | dep.DepEnumNumericD.Two);
unionStrings: (LocalEnum.Bar | LocalEnum.Baz | LocalEnum.Foo);
}

export const defaultCompose: Partial<Compose> = {
Expand Down
47 changes: 47 additions & 0 deletions testdata/union_enum_types.txtar
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
-- cue --

#StringEnum: "a" | "b" | "c" @cuetsy(kind="enum")
#StringEnumWithMemberTypes: "a" | "b" | "c" @cuetsy(kind="enum",memberNames="First|Second|Third")
#IntEnum: 1 | 2 | 3 @cuetsy(kind="enum",memberNames="First|Second|Third")

#Expressions: {
sEnum: (#StringEnum & "a") | (#StringEnum & "b")
sEnumMem: (#StringEnumWithMemberTypes & "a") | (#StringEnumWithMemberTypes & "b")
iEnum: (#IntEnum & 1) | (#IntEnum & 2) | (#IntEnum & 3)
normal: #StringEnum & "a"
nested: {
nestedEnum: #StringEnum & "a"
nestedUnionEnum: (#StringEnum & "a") | (#StringEnum & "b")
}
} @cuetsy(kind="interface")

-- ts --

export enum StringEnum {
A = 'a',
B = 'b',
C = 'c',
}

export enum StringEnumWithMemberTypes {
First = 'a',
Second = 'b',
Third = 'c',
}

export enum IntEnum {
First = 1,
Second = 2,
Third = 3,
}

export interface Expressions {
iEnum: (IntEnum.First | IntEnum.Second | IntEnum.Third);
nested: {
nestedEnum: StringEnum.A;
nestedUnionEnum: (StringEnum.A | StringEnum.B);
};
normal: StringEnum.A;
sEnum: (StringEnum.A | StringEnum.B);
sEnumMem: (StringEnumWithMemberTypes.First | StringEnumWithMemberTypes.Second);
}

0 comments on commit 325fc6c

Please sign in to comment.