Skip to content

Commit

Permalink
fix: generate __is-aware code for is on unions
Browse files Browse the repository at this point in the history
This expands an `is` applied to a union (even in behind an alias)
into a chain of `or` tests with individual `is` checks on the
types of the union. If the individual entries have __is metamethods,
expands their code as well.

This implementation is not recursive (i.e. it does not handle
unions of unions), but it is a major improvement over the previous
behavior.

Fixes teal-language#742.
  • Loading branch information
hishamhm committed Oct 16, 2024
1 parent 908ce16 commit e240869
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 8 deletions.
118 changes: 118 additions & 0 deletions spec/lang/operator/is_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,124 @@ end]]))
end
end
]]))

it("generates type checks expanding unions (#742)", util.gen([[
global record Foo
bar: string
end
global function repro(x:Foo | string | nil): integer
local y = x
if y is string | Foo then
return 1
elseif y is nil then
return 2
end
return 3
end
]], [[
Foo = {}
function repro(x)
local y = x
if type(y) == "string" or type(y) == "table" then
return 1
elseif y == nil then
return 2
end
return 3
end
]]))

it("generates type checks applying __is to discriminated records in unions", util.gen([[
local interface Type
typename: string
end
local record FooType is Type where self.typename == "foo"
end
local record BarType is Type where self.typename == "bar"
end
global function repro(x:Type | string | nil): integer
local y = x
if y is FooType | BarType then
return 1
elseif y is nil then
return 2
end
return 3
end
]], [[
function repro(x)
local y = x
if y.typename == "foo" or y.typename == "bar" then
return 1
elseif y == nil then
return 2
end
return 3
end
]]))

it("generates type checks applying __is to discriminated records in unions expanding alias", util.gen([[
local interface Type
typename: string
end
local record FooType is Type where self.typename == "foo"
end
local record BarType is Type where self.typename == "bar"
end
local type FooBar = FooType | BarType
global function repro(x:Type | string | nil): integer
local y = x
if y is FooBar then
return 1
elseif y is nil then
return 2
end
return 3
end
]], [[
function repro(x)
local y = x
if y.typename == "foo" or y.typename == "bar" then
return 1
elseif y == nil then
return 2
end
return 3
end
]]))
end)

end)
44 changes: 40 additions & 4 deletions tl.lua
Original file line number Diff line number Diff line change
Expand Up @@ -8063,7 +8063,7 @@ do

local immediate, found = find_nominal_type_decl(self, nom)

if type(immediate) == "table" then
if immediate and (immediate.typename == "invalid" or immediate.typename == "typedecl") then
return immediate
end

Expand Down Expand Up @@ -9670,6 +9670,36 @@ a.types[i], b.types[i]), }
end
end

local function make_is_node(self, var, v, t)
local node = node_at(var, { kind = "op", op = { op = "is", arity = 2, prec = 3 } })
node.e1 = var
node.e2 = node_at(var, { kind = "cast", casttype = self:infer_at(var, t) })
self:check_metamethod(node, "__is", self:to_structural(v), self:to_structural(t), v, t)
if node.expanded then
apply_macroexp(node)
end
node.known = IsFact({ var = var.tk, typ = t, w = node })
return node
end

local function convert_is_of_union_to_or_of_is(self, node, v, u)
local var = node.e1
node.op.op = "or"
node.op.arity = 2
node.op.prec = 1
node.e1 = make_is_node(self, var, v, u.types[1])
local at = node
local n = #u.types
for i = 2, n - 1 do
at.e2 = node_at(var, { kind = "op", op = { op = "or", arity = 2, prec = 1 } })
at.e2.e1 = make_is_node(self, var, v, u.types[i])
node.known = OrFact({ f1 = at.e1.known, f2 = at.e2.known, w = node })
at = at.e2
end
at.e2 = make_is_node(self, var, v, u.types[n])
node.known = OrFact({ f1 = at.e1.known, f2 = at.e2.known, w = node })
end

function TypeChecker:match_record_key(tbl, rec, key)
assert(type(tbl) == "table")
assert(type(rec) == "table")
Expand Down Expand Up @@ -12320,9 +12350,15 @@ self:expand_type(node, values, elements) })
if rb.typename == "integer" then
self.all_needs_compat["math"] = true
end
if node.e1.kind == "variable" then
self:check_metamethod(node, "__is", ra, resolve_typedecl(rb), ua, ub)
node.known = IsFact({ var = node.e1.tk, typ = ub, w = node })
if ra.typename == "typedecl" then
self.errs:add(node, "can only use 'is' on variables, not types")
elseif node.e1.kind == "variable" then
if rb.typename == "union" then
convert_is_of_union_to_or_of_is(self, node, ra, rb)
else
self:check_metamethod(node, "__is", ra, resolve_typedecl(rb), ua, ub)
node.known = IsFact({ var = node.e1.tk, typ = ub, w = node })
end
else
self.errs:add(node, "can only use 'is' on variables")
end
Expand Down
44 changes: 40 additions & 4 deletions tl.tl
Original file line number Diff line number Diff line change
Expand Up @@ -8063,7 +8063,7 @@ do

local immediate, found = find_nominal_type_decl(self, nom)
-- if it was previously resolved (or a circular require, or an error), return that;
if immediate is InvalidOrTypeDeclType then
if immediate and immediate is InvalidOrTypeDeclType then
return immediate
end

Expand Down Expand Up @@ -9670,6 +9670,36 @@ do
end
end

local function make_is_node(self: TypeChecker, var: Node, v: Type, t: Type): Node
local node = node_at(var, { kind = "op", op = { op = "is", arity = 2, prec = 3 } })
node.e1 = var
node.e2 = node_at(var, { kind = "cast", casttype = self:infer_at(var, t) })
self:check_metamethod(node, "__is", self:to_structural(v), self:to_structural(t), v, t)
if node.expanded then
apply_macroexp(node)
end
node.known = IsFact { var = var.tk, typ = t, w = node }
return node
end

local function convert_is_of_union_to_or_of_is(self: TypeChecker, node: Node, v: Type, u: UnionType)
local var = node.e1
node.op.op = "or"
node.op.arity = 2
node.op.prec = 1
node.e1 = make_is_node(self, var, v, u.types[1])
local at = node
local n = #u.types
for i = 2, n - 1 do
at.e2 = node_at(var, { kind = "op", op = { op = "or", arity = 2, prec = 1 } })
at.e2.e1 = make_is_node(self, var, v, u.types[i])
node.known = OrFact { f1 = at.e1.known, f2 = at.e2.known, w = node }
at = at.e2
end
at.e2 = make_is_node(self, var, v, u.types[n])
node.known = OrFact { f1 = at.e1.known, f2 = at.e2.known, w = node }
end

function TypeChecker:match_record_key(tbl: Type, rec: Node, key: string): Type, string
assert(type(tbl) == "table")
assert(type(rec) == "table")
Expand Down Expand Up @@ -12320,9 +12350,15 @@ do
if rb.typename == "integer" then
self.all_needs_compat["math"] = true
end
if node.e1.kind == "variable" then
self:check_metamethod(node, "__is", ra, resolve_typedecl(rb), ua, ub)
node.known = IsFact { var = node.e1.tk, typ = ub, w = node }
if ra is TypeDeclType then
self.errs:add(node, "can only use 'is' on variables, not types")
elseif node.e1.kind == "variable" then
if rb is UnionType then
convert_is_of_union_to_or_of_is(self, node, ra, rb)
else
self:check_metamethod(node, "__is", ra, resolve_typedecl(rb), ua, ub)
node.known = IsFact { var = node.e1.tk, typ = ub, w = node }
end
else
self.errs:add(node, "can only use 'is' on variables")
end
Expand Down

0 comments on commit e240869

Please sign in to comment.