Skip to content

Commit

Permalink
* added support for nested alternating patterns, not just at the top …
Browse files Browse the repository at this point in the history
…level inside match and catch

* added diagnostics about unhandled variant cases
* added "from Math import *" into preamble; updated math functions' implementations to use recently introduced intrinsics
* cleaned up Hashset & Hashmap implementations
  • Loading branch information
vpisarev committed Mar 30, 2021
1 parent 9038c54 commit cafcc57
Show file tree
Hide file tree
Showing 59 changed files with 13,150 additions and 12,556 deletions.
9 changes: 6 additions & 3 deletions compiler/Ast.fx
Original file line number Diff line number Diff line change
Expand Up @@ -303,8 +303,8 @@ type exp_t =
| ExpDoWhile: (exp_t, exp_t, loc_t)
| ExpFor: ((pat_t, exp_t) list, pat_t, exp_t, for_flags_t, loc_t)
| ExpMap: (((pat_t, exp_t) list, pat_t) list, exp_t, for_flags_t, ctx_t)
| ExpTryCatch: (exp_t, (pat_t list, exp_t) list, ctx_t)
| ExpMatch: (exp_t, (pat_t list, exp_t) list, ctx_t)
| ExpTryCatch: (exp_t, (pat_t, exp_t) list, ctx_t)
| ExpMatch: (exp_t, (pat_t, exp_t) list, ctx_t)
| ExpCast: (exp_t, typ_t, ctx_t)
| ExpTyped: (exp_t, typ_t, ctx_t)
| ExpCCode: (string, ctx_t)
Expand All @@ -329,6 +329,7 @@ type pat_t =
| PatAs: (pat_t, id_t, loc_t)
| PatTyped: (pat_t, typ_t, loc_t)
| PatWhen: (pat_t, exp_t, loc_t)
| PatAlt: (pat_t list, loc_t)
| PatRef: (pat_t, loc_t)

type env_entry_t =
Expand Down Expand Up @@ -703,6 +704,7 @@ fun get_pat_loc(p: pat_t) {
| PatTyped(_, _, l) => l
| PatRef(_, l) => l
| PatWhen(_, _, l) => l
| PatAlt(_, l) => l
}

fun pat_skip_typed(p: pat_t) {
Expand Down Expand Up @@ -1372,7 +1374,7 @@ fun walk_exp(e: exp_t, callb: ast_callb_t) {
fun walk_plist_(pl: pat_t list) = check_n_walk_plist(pl, callb)
fun walk_pe_l_(pe_l: (pat_t, exp_t) list) = [: for (p, e) <- pe_l { (walk_pat_(p), walk_exp_(e)) } :]
fun walk_ne_l_(ne_l: (id_t, exp_t) list) = [: for (n, e) <- ne_l { (n, walk_exp_(e)) } :]
fun walk_cases_(ple_l: (pat_t list, exp_t) list) = [: for (pl, e) <- ple_l { (walk_plist_(pl), walk_exp_(e)) } :]
fun walk_cases_(pe_l: (pat_t, exp_t) list) = [: for (p, e) <- pe_l { (walk_pat_(p), walk_exp_(e)) } :]
fun walk_exp_opt_(e_opt: exp_t?) {
| Some(e) => Some(walk_exp_(e))
| _ => None
Expand Down Expand Up @@ -1463,6 +1465,7 @@ fun walk_pat(p: pat_t, callb: ast_callb_t) {
| PatAs(p, n, loc) => PatAs(walk_pat_(p), n, loc)
| PatTyped(p, t, loc) => PatTyped(walk_pat_(p), walk_typ_(t), loc)
| PatWhen(p, e, loc) => PatWhen(walk_pat_(p), check_n_walk_exp(e, callb), loc)
| PatAlt(pl, loc) => PatAlt(walk_pl_(pl), loc)
| PatRef(p, loc) => PatRef(walk_pat_(p), loc)
}
}
Expand Down
23 changes: 16 additions & 7 deletions compiler/Ast_pp.fx
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,10 @@ fun pprint_for_flags(pp: PP.t, flags: for_flags_t)

fun pprint_exp(pp: PP.t, e: exp_t): void
{
fun ppcases(pe_l: (pat_t list, exp_t) list) {
fun ppcases(pe_l: (pat_t, exp_t) list) {
pp.str("{"); pp.cut(); pp.begin()
for (pl, e) <- pe_l {
for p <- pl {
pp.space(); pp.str("|"); pp.space(); pprint_pat(pp, p)
}
for (p, e) <- pe_l {
pp.space(); pp.str("| "); pprint_pat(pp, p);
pp.space(); pp.str("=>"); pp.space(); pprint_exp_as_seq(pp, e)
}
pp.cut(); pp.end(); pp.str("}")
Expand Down Expand Up @@ -552,11 +550,22 @@ fun pprint_pat(pp: PP.t, p: pat_t)
pp.end()
| PatTyped(p, t, loc) =>
pp.begin(); pppat(p); pp.str(":"); pp.space(); pprint_typ(pp, t, loc); pp.end()
| PatRef(p, loc) =>
| PatRef(p, _) =>
pp.begin(); pp.str("ref ("); pp.cut(); pppat(p); pp.cut(); pp.str(")"); pp.end()
| PatWhen(p, e, loc) =>
| PatWhen(p, e, _) =>
pp.begin(); pppat(p); pp.space(); pp.str("when")
pp.space(); pprint_exp(pp, e); pp.end()
| PatAlt(pl, _) =>
match pl {
| p :: [] => pppat(p)
| _ =>
pp.beginv(0); pp.str("(");
for p@i <- pl {
if i > 0 {pp.space()}
pp.str("| "); pppat(p)
}
pp.str(")"); pp.end()
}
}
pppat(p)
}
Expand Down
66 changes: 21 additions & 45 deletions compiler/Ast_typecheck.fx
Original file line number Diff line number Diff line change
Expand Up @@ -1337,13 +1337,16 @@ fun check_exp(e: exp_t, env: env_t, sc: scope_t list) {
val argtyps = [: for a <- args {get_exp_typ(a)} :]
val fstr = pp(f)
val t = match (fstr, args, argtyps) {
| ("__intrin_atan2__", y :: x :: [], yt :: xt :: []) =>
unify(yt, xt, eloc, "arguments of atan2() must have the same type")
yt
| ("atan2", y :: x :: [], yt :: xt :: []) =>
unify(xt, yt, eloc, "arguments of atan2() must have the same type")
xt
| ("pow", x :: y :: [], xt :: yt :: []) =>
unify(xt, yt, eloc, "arguments of pow() must have the same type")
xt
| (_, x :: [], xt :: []) => xt
| _ =>
val nargs_expected = if fstr == "__intrin_atan2__" {2} else {1}
throw compile_err(eloc, f"incorrect number of arguments in {fstr}, expect {nargs_expected}")
val nargs_expected = if fstr == "atan2" || fstr == "pow" {2} else {1}
throw compile_err(eloc, f"incorrect number of arguments in __intrin_{fstr}__, {nargs_expected} expected")
}
unify(etyp, t, eloc, f"the input and output of {fstr} should have the same types")
val args = [: for a <- args {check_exp(a, env, sc)} :]
Expand Down Expand Up @@ -2644,7 +2647,7 @@ fun instantiate_fun_body(inst_name: id_t, inst_ftyp: typ_t, inst_args: pat_t lis
}
| _ => throw compile_err(inst_loc, "variant is expected here")
}
val fold complex_cases = ([]: (pat_t list, exp_t) list)
val fold complex_cases = ([]: (pat_t, exp_t) list)
for n <- var_ctors, (n_orig, t_orig) <- proto_cases {
val t = deref_typ_rec(t_orig)
match t {
Expand All @@ -2668,7 +2671,7 @@ fun instantiate_fun_body(inst_name: id_t, inst_ftyp: typ_t, inst_args: pat_t lis
val a_case_pat = PatRecord(Some(n), al.rev(), body_loc)
val b_case_pat = PatRecord(Some(n), bl.rev(), body_loc)
val ab_case_pat = PatTuple([: a_case_pat, b_case_pat :], body_loc)
(ab_case_pat :: [], cmp_code) :: complex_cases
(ab_case_pat, cmp_code) :: complex_cases
| _ =>
val args = match t { | TypTuple(tl) => tl | _ => t :: [] }
val fold (al, bl, cmp_code) = ([], [], ExpNop(body_loc))
Expand All @@ -2687,7 +2690,7 @@ fun instantiate_fun_body(inst_name: id_t, inst_ftyp: typ_t, inst_args: pat_t lis
val a_case_pat = PatVariant(n, al.rev(), body_loc)
val b_case_pat = PatVariant(n, bl.rev(), body_loc)
val ab_case_pat = PatTuple([: a_case_pat, b_case_pat :], body_loc)
(ab_case_pat :: [], cmp_code) :: complex_cases
(ab_case_pat, cmp_code) :: complex_cases
}
}
val a = ExpIdent(get_id(astr), (argtyp, body_loc))
Expand All @@ -2700,7 +2703,7 @@ fun instantiate_fun_body(inst_name: id_t, inst_ftyp: typ_t, inst_args: pat_t lis
match complex_cases {
| [] => cmp_tags
| _ =>
val default_case = (PatAny(body_loc) :: [], cmp_tags)
val default_case = (PatAny(body_loc), cmp_tags)
val ab = ExpMkTuple([: a, b :], (TypTuple([: argtyp, argtyp :]), body_loc))
ExpMatch(ab, (default_case :: complex_cases).rev(), (TypBool, body_loc))
}
Expand Down Expand Up @@ -2943,50 +2946,23 @@ fun check_pat(pat: pat_t, typ: typ_t, env: env_t, idset: idset_t, typ_vars: idse
unify(etyp, TypBool, loc, "'when' clause should have boolean type")
val e1 = check_exp(e1, r_env, sc)
(PatWhen(p1, e1, loc), false)
| PatAlt(pl, loc) =>
if simple_pat_mode { throw compile_err(loc, "'|' pattern is not allowed here") }
val pl1 = [: for p <- pl { check_pat_(p, t).0 } :]
(PatAlt(pl1, loc), false)
}
val (pat_new, typed) = check_pat_(pat, typ)
(pat_new, r_env, r_idset, *r_typ_vars, typed)
}

fun check_cases(cases: (pat_t list, exp_t) list, inptyp: typ_t, outtyp: typ_t, env: env_t, sc: scope_t list, loc: loc_t) =
[: for (plist, e) <- cases {
fun check_cases(cases: (pat_t, exp_t) list, inptyp: typ_t, outtyp: typ_t, env: env_t, sc: scope_t list, loc: loc_t) =
[: for (p, e) <- cases {
val case_sc = new_block_scope() :: sc
val fold (plist1, env1, capt1) = ([], env, empty_idset) for p <- plist {
val (p2, env2, capt2, _, _) = check_pat(p, inptyp, env1, capt1, empty_idset,
case_sc, false, false, false)
(p2 :: plist1, env2, capt2)
}

val (plist1, env1) =
if plist1.length() == 1 {
(plist1, env1)
} else {
if !capt1.empty() {
throw compile_err(loc, "captured variables may not be used in the case of multiple alternatives ('|' pattern)")
}
val temp_id = gen_temp_id("p")
val temp_dv = defval_t {dv_name=temp_id, dv_typ=inptyp, dv_flags=default_tempval_flags(), dv_scope=sc, dv_loc=loc}
set_id_entry(temp_id, IdDVal(temp_dv))
//val capt1 = capt1.add(temp_id)
val env1 = add_id_to_env(temp_id, temp_id, env1)
val temp_pat = PatIdent(temp_id, loc)
val temp_id_exp = ExpIdent(temp_id, (inptyp, loc))
val bool_ctx = (TypBool, loc)
val when_cases =
[: for p <- plist1 {
match p {
| PatAny _ | PatIdent(_, _) =>
throw compile_err(loc, "in the case of multiple alternatives ('|' pattern) '_' or indent cannot be used")
| _ => (p :: [], ExpLit(LitBool(true), bool_ctx))
}
} :]
val when_cases = when_cases + [: (PatAny(loc) :: [], ExpLit(LitBool(false), bool_ctx)) :]
val when_pat = PatWhen(temp_pat, ExpMatch(temp_id_exp, when_cases, bool_ctx), loc)
(when_pat :: [], env1)
}
val (p1, env1, _, _, _) = check_pat(p, inptyp, env, empty_idset, empty_idset,
case_sc, false, false, false)
val (e1_typ, e1_loc) = get_exp_ctx(e)
unify(e1_typ, outtyp, e1_loc, "the case expression type does not match the whole expression type (or the type of previous case(s))")
(plist1.rev(), check_exp(e, env1, case_sc))
(p1, check_exp(e, env1, case_sc))
} :]

fun check_mod(m: id_t) {
Expand Down
4 changes: 2 additions & 2 deletions compiler/C_gen_code.fx
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ fun find_single_use_vals(topcode: kcode_t)
| AtomId i =>
if decl_const_vals.mem(i) {
val idx = count_map.find_idx_or_insert(i)
count_map._state->table[idx].2 += 1
count_map.r->table[idx].data += 1
}
| _ => {}
}
Expand Down Expand Up @@ -1378,7 +1378,7 @@ fun gen_ccode(cmods: cmodule_t list, kmod: kmodule_t, c_fdecls: ccode_t, mod_ini
| OpBitwiseXor => COpBitwiseXor
| OpCmp(cmpop) => COpCmp(cmpop)
| OpCons | OpPow | OpMod | OpLogicAnd | OpLogicOr | OpSpaceship | OpDotSpaceship
| OpDotMul | OpDotDiv | OpDotMod | OpDotPow | OpDotCmp _ =>
| OpDotMul | OpDotDiv | OpDotMod | OpDotPow | OpDotCmp _ | OpSame =>
throw compile_err(kloc, f"cgen: unsupported op '{bop}' at this stage")
}
match (c_bop, get_cexp_typ(ce1)) {
Expand Down
4 changes: 2 additions & 2 deletions compiler/C_post_rename_locals.fx
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ fun rename_locals(cmods: cmodule_t list)
| IdTemp (i, j) => i
}
val idx = prefix_hash.find_idx_or_insert(prefix)
val j1 = prefix_hash._state->table[idx].2 + 1
prefix_hash._state->table[idx].2 = j1
val j1 = prefix_hash.r->table[idx].data + 1
prefix_hash.r->table[idx].data = j1
val prefix = dynvec_get(all_strings, prefix)
f"{prefix}_{j1}"
}
Expand Down
11 changes: 2 additions & 9 deletions compiler/Compiler.fx
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ fun get_preamble(mfname: string): Lexer.token_t list {
val bare_name = Filename.remove_extension(Filename.basename(mfname))
val (preamble, _) = fold (preamble, found) = ([], false)
for (mname, from_import) <- [: ("Builtins", true), ("List", false),
("Char", false), ("String", false) :] {
("Char", false), ("String", false),
("Math", true) :] {
if found {
(preamble, found)
} else if bare_name == mname {
Expand All @@ -37,14 +38,6 @@ fun get_preamble(mfname: string): Lexer.token_t list {
(preamble + [: Lexer.IMPORT(true), Lexer.IDENT(true, mname), Lexer.SEMICOLON :], false)
}
}
val preamble =
if bare_name != "Builtins" { preamble }
else {
// [TODO] insert proper git hash
[: //Lexer.IMPORT(true), Lexer.IDENT(true, "Config"), Lexer.SEMICOLON,
Lexer.VAL, Lexer.IDENT(true, "__ficus_git_commit__"), Lexer.EQUAL,
Lexer.LITERAL(Ast.LitString("123456789")), Lexer.SEMICOLON :] + preamble
}
preamble
} else { [] }
}
Expand Down
8 changes: 4 additions & 4 deletions compiler/K_inline.fx
Original file line number Diff line number Diff line change
Expand Up @@ -247,10 +247,10 @@ fun inline_some(kmods: kmodule_t list)
match kinfo_(f, loc) {
| KFun _ =>
val idx = all_funcs_info.find_idx_or_insert(f)
var r_fi = all_funcs_info._state->table[idx].2
var r_fi = all_funcs_info.r->table[idx].data
if r_fi->fi_nrefs == -1 {
r_fi = gen_default_func_info(0)
all_funcs_info._state->table[idx].2 = r_fi
all_funcs_info.r->table[idx].data = r_fi
}
r_fi->fi_nrefs += 1
| _ => {}
Expand All @@ -267,10 +267,10 @@ fun inline_some(kmods: kmodule_t list)
val idx = all_funcs_info.find_idx_or_insert(kf_name)
val can_inline = !kf_flags.fun_flag_recursive && !kf_flags.fun_flag_ccode && kf_flags.fun_flag_ctor == CtorNone
val fsize = calc_exp_size(kf_body)
var r_fi = all_funcs_info._state->table[idx].2
var r_fi = all_funcs_info.r->table[idx].data
if r_fi->fi_nrefs == -1 {
r_fi = gen_default_func_info(0)
all_funcs_info._state->table[idx].2 = r_fi
all_funcs_info.r->table[idx].data = r_fi
}
*r_fi = r_fi->{fi_name=kf_name, fi_can_inline=can_inline, fi_size=fsize, fi_flags=kf_flags}
curr_fi = r_fi
Expand Down
38 changes: 24 additions & 14 deletions compiler/K_normalize.fx
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ fun exp2kexp(e: exp_t, code: kcode_t, tref: bool, sc: scope_t list)
val (e1, body_code) = exp2kexp(e1, [], false, try_sc)
val try_body = rcode2kexp(e1 :: body_code, e1loc)
val exn_loc = match cases {
| (p :: _, _) :: _ => get_pat_loc(p)
| (p, _) :: _ => get_pat_loc(p)
| _ => eloc
}
val exn_n = gen_temp_idk("exn")
Expand Down Expand Up @@ -605,6 +605,7 @@ fun pat_have_vars(p: pat_t): bool
| PatRecord(_, ip_l, _) => ip_l.exists(fun ((_, pi)) {pat_have_vars(pi)})
| PatRef(p, _) => pat_have_vars(p)
| PatWhen(p, _, _) => pat_have_vars(p)
| PatAlt(pl, _) => pl.exists(pat_have_vars)
}

/* version of Ast_typecheck.get_record_elems, but for already transformed types */
Expand Down Expand Up @@ -732,6 +733,8 @@ fun pat_need_checks(p: pat_t, ptyp: ktyp_t)
}
pat_need_checks(p, t)
| PatWhen(_, _, _) => true
| PatAlt(pl, _) =>
exists(for pi <- pl {pat_need_checks(pi, ptyp)})
}
}

Expand Down Expand Up @@ -874,7 +877,7 @@ type pat_info_t = {pinfo_p: pat_t; pinfo_typ: ktyp_t; pinfo_e: kexp_t; pinfo_tag

We dynamically maintain 3 lists of the sub-patterns to consider next.
Each new sub-pattern occurring during recursive processing of the top-level pattern
is classified and is then either discarded or added to one of the 3 lists:
is classified and then is either discarded or added into one of the 3 lists:
* pl_c - the patterns that need some checks to pass, but have no captured variables
* pl_uc - need checks and have variables to capture
* pl_u - need no checks, but have variables to capture.
Expand All @@ -890,9 +893,9 @@ type pat_info_t = {pinfo_p: pat_t; pinfo_typ: ktyp_t; pinfo_e: kexp_t; pinfo_tag
We do such dispatching in order to minimize the number of read operations from a complex structure.
That is why we do not capture a variable until all the checks are complete and we know we have a match.
The algorithm does not always produce the most optimal sequence of operations
(e.g. some checks are easier to do than the others etc., but it's probably good enough approximation)
(e.g. some checks are easier to do than the others etc.), but probably it's a good-enough approximation
*/
fun transform_pat_matching(a: atom_t, cases: (pat_t list, exp_t) list,
fun transform_pat_matching(a: atom_t, cases: (pat_t, exp_t) list,
code: kcode_t, sc: scope_t list, loc: loc_t, catch_mode: bool)
{
var match_var_cases = empty_idset
Expand Down Expand Up @@ -1121,6 +1124,18 @@ fun transform_pat_matching(a: atom_t, cases: (pat_t list, exp_t) list,
val (ke, code) = exp2kexp(e, code, true, sc)
val c_exp = rcode2kexp(ke :: code, loc)
(([], [], []), c_exp :: checks, [])
| PatAlt(pl, _) =>
if pat_have_vars(p) { throw compile_err(loc, "alt-pattern cannot contain captured values") }
// build alt_checks, a list of expression lists, which are supposed to be combined with the || operator.
val fold alt_cases = ([], KExpAtom(AtomLit(KLitBool(false)), (KTypBool, loc)))::[] for p <- pl.rev() {
val pinfo = pat_info_t {pinfo_p=p, pinfo_typ=ptyp, pinfo_e=KExpAtom(AtomId(n), (ptyp, loc)), pinfo_tag=var_tag0}
val plists_ = dispatch_pat(pinfo, ([], [], []))
val (checks_, code_) = process_next_subpat(plists_, ([], []), case_sc)
val e = rcode2kexp(KExpAtom(AtomLit(KLitBool(true)), (KTypBool, loc)) :: code, loc)
(checks_.rev(), e) :: alt_cases
}
val alt_check = rcode2kexp(KExpMatch(alt_cases, (KTypBool, loc)) :: code, loc)
(plists, alt_check :: checks, [])
| _ => throw compile_err(loc, "this type of pattern is not supported yet")
}
process_next_subpat(plists, (checks, code), case_sc)
Expand Down Expand Up @@ -1155,14 +1170,9 @@ fun transform_pat_matching(a: atom_t, cases: (pat_t list, exp_t) list,
}
var have_else = false
val k_cases =
[: for (pl, e) <- cases {
val ncases = pl.length()
val p0 = pl.hd()
val ploc = get_pat_loc(p0)
if ncases != 1 {
throw compile_err(ploc, "multiple alternative patterns are not supported yet")
}
val pinfo = pat_info_t {pinfo_p=p0, pinfo_typ=atyp, pinfo_e=KExpAtom(a, (atyp, loc)), pinfo_tag=var_tag0}
[: for (p, e) <- cases {
val ploc = get_pat_loc(p)
val pinfo = pat_info_t {pinfo_p=p, pinfo_typ=atyp, pinfo_e=KExpAtom(a, (atyp, loc)), pinfo_tag=var_tag0}
if have_else {
throw compile_err(ploc, "unreacheable pattern matching case")
}
Expand All @@ -1175,10 +1185,10 @@ fun transform_pat_matching(a: atom_t, cases: (pat_t list, exp_t) list,
if checks == [] { have_else = true }
(checks.rev(), ke)
} :]
/*if is_variant && !have_else && !match_var_cases.empty() {
if is_variant && !have_else && !match_var_cases.empty() {
val idlist = ", ".join(match_var_cases.map(fun (n) {f"'{n}'"}))
throw compile_err(loc, f"the case(s) {idlist} are not covered; add '| _ => ...' clause to suppress this error")
}*/
}
val k_cases =
if have_else {
k_cases
Expand Down
Loading

0 comments on commit cafcc57

Please sign in to comment.