From bc51d7f0b7d819740633ab680c902f7f763fce2f Mon Sep 17 00:00:00 2001 From: Vadim Pisarevsky Date: Tue, 16 Mar 2021 18:34:48 +0800 Subject: [PATCH] successfully match 'VariantLabel _' with the proper 'VarianLabel' regardless of the number of attached values: 0, 1, 2, ... --- compiler/Ast.fx | 7 ++ compiler/AstTypeChecker.fx | 176 ++++++++++++++++++++----------------- compiler/KNormalize.fx | 6 +- src/ast_typecheck.ml | 73 ++++++++------- 4 files changed, 146 insertions(+), 116 deletions(-) diff --git a/compiler/Ast.fx b/compiler/Ast.fx index 3f4a52e..237c849 100644 --- a/compiler/Ast.fx +++ b/compiler/Ast.fx @@ -766,6 +766,13 @@ fun get_qualified_name(name: string, sc: scope_t list) = | sc_top :: r => get_qualified_name(name, r) } +// out of 'A.B.C.D' we leave just 'D'. just 'D' stays 'D' +fun get_bare_name(n: id_t): id_t { + val n_str = pp_id2str(n) + val dot_pos = n_str.rfind('.') + get_id(if dot_pos < 0 {n_str} else {n_str[dot_pos+1:]}) +} + fun get_scope(id_info: id_info_t) = match id_info { | IdNone => ScGlobal :: [] diff --git a/compiler/AstTypeChecker.fx b/compiler/AstTypeChecker.fx index 8b6dbf5..1b04561 100644 --- a/compiler/AstTypeChecker.fx +++ b/compiler/AstTypeChecker.fx @@ -503,10 +503,7 @@ fun find_typ_instance(t: typ_t, loc: loc_t): typ_t? { fun get_record_elems(vn_opt: id_t?, t: typ_t, proto_mode: bool, loc: loc_t): (id_t, (id_t, typ_t, lit_t?) list) { val t = deref_typ(t) - val input_vn = match vn_opt { - | Some(vn) => get_id(pp_id2str(vn).split('.').last()) - | _ => noid - } + val input_vn = match vn_opt { | Some(vn) => get_bare_name(vn) | _ => noid } match t { | TypRecord (ref (relems, true)) => (noid, relems) | TypRecord(_) => throw compile_err(loc, "the records, which elements we request, is not finalized") @@ -2296,54 +2293,54 @@ fun check_typ_and_collect_typ_vars(t: typ_t, env: env_t, r_opt_typ_vars: idset_t | TypApp(ty_args, n) => val ty_args = [: for t <- ty_args { check_typ_(t, callb) } :] val ty_args_are_real = ty_args.all(is_real_typ) - val found_typ_opt = find_first(n, r_env, r_env, sc, loc, fun (entry: env_entry_t): typ_t? { - | EnvTyp(t) => - if ty_args.empty() { Some(t) } - else { throw compile_err(loc, f"a concrete type '{pp_id2str(n)}' cannot be further instantiated") } - | EnvId(IdName(_)) => None - | EnvId(i) => - match id_info(i, loc) { - | IdNone | IdDVal(_) | IdFun(_) | IdExn(_) | IdModule(_) => None - | IdInterface(_) => throw compile_err(loc, "classes & interfaces are not supported yet") - | IdTyp(dt) => - val {dt_name, dt_templ_args, dt_typ, dt_scope, dt_finalized, dt_loc} = *dt - if dt_finalized {} - else { - throw compile_err( loc, - f"later declared non-variant type '{pp_id2str(dt_name)}' is referenced; try to reorder the type declarations") - } - if dt_name == n && ty_args.empty() { Some(t) } - else { - val env1 = match_ty_templ_args(ty_args, dt_templ_args, r_env, dt_loc, loc) - Some(check_typ(dup_typ(dt_typ), env1, sc, loc)) - } - | IdVariant(dvar) => - val {dvar_name, dvar_templ_args, dvar_alias, dvar_templ_inst, dvar_loc} = *dvar - Some( - if dvar_templ_args.empty() { - dvar_alias - } else if ty_args_are_real { - val t1 = TypApp(ty_args, dvar_name) - match dvar_templ_inst->find_opt( - fun (inst) { - match id_info(inst, dvar_loc) { - | IdVariant(dvar_inst) => - val {dvar_alias=dvar_inst_alias} = *dvar_inst - maybe_unify(t1, dvar_inst_alias, dvar_loc, true) - | _ => - throw compile_err(loc, f"invalid type of variant instance {id2str(i)} (must be also a variant)") - } - }) { - | Some _ => t1 - | _ => - val (_, inst_app_typ) = instantiate_variant(ty_args, dvar, r_env, sc, loc); - inst_app_typ - } - } else { - TypApp(ty_args, dvar_name) - }) + val found_typ_opt = find_first(n, r_env, r_env, sc, loc, + fun (entry: env_entry_t): typ_t? { + | EnvTyp(t) => + if ty_args.empty() { Some(t) } + else { throw compile_err(loc, f"a concrete type '{pp_id2str(n)}' cannot be further instantiated") } + | EnvId(IdName(_)) => None + | EnvId(i) => + match id_info(i, loc) { + | IdNone | IdDVal(_) | IdFun(_) | IdExn(_) | IdModule(_) => None + | IdInterface(_) => throw compile_err(loc, "classes & interfaces are not supported yet") + | IdTyp(dt) => + val {dt_name, dt_templ_args, dt_typ, dt_scope, dt_finalized, dt_loc} = *dt + if !dt_finalized { + throw compile_err( loc, + f"later declared non-variant type '{pp_id2str(dt_name)}' is referenced; try to reorder the type declarations") } - }) + if dt_name == n && ty_args.empty() { Some(t) } + else { + val env1 = match_ty_templ_args(ty_args, dt_templ_args, r_env, dt_loc, loc) + Some(check_typ(dup_typ(dt_typ), env1, sc, loc)) + } + | IdVariant(dvar) => + val {dvar_name, dvar_templ_args, dvar_alias, dvar_templ_inst, dvar_loc} = *dvar + Some( + if dvar_templ_args.empty() { + dvar_alias + } else if ty_args_are_real { + val t1 = TypApp(ty_args, dvar_name) + match dvar_templ_inst->find_opt( + fun (inst) { + match id_info(inst, dvar_loc) { + | IdVariant(dvar_inst) => + val {dvar_alias=dvar_inst_alias} = *dvar_inst + maybe_unify(t1, dvar_inst_alias, dvar_loc, true) + | _ => + throw compile_err(loc, f"invalid type of variant instance {id2str(i)} (must be also a variant)") + } + }) { + | Some _ => t1 + | _ => + val (_, inst_app_typ) = instantiate_variant(ty_args, dvar, r_env, sc, loc); + inst_app_typ + } + } else { + TypApp(ty_args, dvar_name) + }) + } + }) match (r_opt_typ_vars, ty_args, found_typ_opt) { | (_, _, Some(new_t)) => new_t | (Some(r_typ_vars), [], _) when id2str(n).startswith("'") => @@ -2399,8 +2396,7 @@ fun instantiate_fun_(templ_df: deffun_t ref, inst_ftyp: typ_t, inst_env0: env_t, inst_sc: scope_t list, inst_loc: loc_t, instantiate: bool): deffun_t ref { val {df_name, df_templ_args, df_args, df_body, df_flags, df_scope, df_loc, df_templ_inst} = *templ_df val is_constr = is_fun_ctor(df_flags) - if !is_constr {} - else { + if is_constr { throw compile_err( inst_loc, f"internal error: attempt to instantiate constructor '{id2str(df_name)}'. it should be instantiated in a different way. try to use explicit type specification somewhere") } @@ -2475,7 +2471,8 @@ fun instantiate_fun_body(inst_name: id_t, inst_ftyp: typ_t, inst_args: pat_t lis match pp_id2str(inst_name) { | "__eq_variants__" => match (ftyp, inst_args) { - | (TypFun(TypApp([], n1) :: TypApp([], n2) :: [], TypBool), PatTyped(PatIdent(a, _), _, _) :: PatTyped(PatIdent(b, _), _, _) :: []) + | (TypFun(TypApp([], n1) :: TypApp([], n2) :: [], TypBool), + PatTyped(PatIdent(a, _), _, _) :: PatTyped(PatIdent(b, _), _, _) :: []) when n1 == n2 && (match id_info(n1, inst_loc) { | IdVariant(_) => true | _ => false @@ -2643,6 +2640,16 @@ fun instantiate_variant(ty_args: typ_t list, dvar: defvariant_t ref, env: env_t, (inst_name, inst_app_typ) } +fun get_variant_cases(t: typ_t, loc: loc_t): ((id_t, typ_t) list, id_t list) = + match deref_typ(t) { + | TypApp(_, n) => + match id_info(n, loc) { + | IdVariant (ref {dvar_cases, dvar_ctors}) => (dvar_cases, dvar_ctors) + | _ => ([], []) + } + | _ => ([], []) + } + fun check_pat(pat: pat_t, typ: typ_t, env: env_t, idset: idset_t, typ_vars: idset_t, sc: scope_t list, proto_mode: bool, simple_pat_mode: bool, is_mutable: bool): (pat_t, env_t, idset_t, idset_t, bool) @@ -2651,6 +2658,7 @@ fun check_pat(pat: pat_t, typ: typ_t, env: env_t, idset: idset_t, typ_vars: idse var r_env = env val r_typ_vars = ref typ_vars val captured_val_flags = default_val_flags().{val_flag_mutable=is_mutable} + fun process_id(i, t, loc: loc_t) { val i0 = get_orig_id(i) if r_idset.mem(i0) { @@ -2668,6 +2676,7 @@ fun check_pat(pat: pat_t, typ: typ_t, env: env_t, idset: idset_t, typ_vars: idse j } } + fun check_pat_(p: pat_t, t: typ_t) = match p { | PatAny(_) => (p, false) @@ -2689,10 +2698,23 @@ fun check_pat(pat: pat_t, typ: typ_t, env: env_t, idset: idset_t, typ_vars: idse } (PatTuple(pl_new.rev(), loc), typed) | PatVariant(v, pl, loc) => + val bare_v = get_bare_name(v) if !proto_mode { - /* [TODO] in the ideal case this branch should work fine in the prototype mode as well, - just need to make lookup_id smart enough (maybe add some extra parameters to - avoid preliminary type instantiation) */ + // check special case: | SomeVariantLabel _, + // which we want to match with the proper variant case + // regardless of how many parameters it has + val pl = match pl { + | PatAny(any_loc) :: [] => + match find_opt(for (n, t) <- get_variant_cases(t, loc).0 { + get_orig_id(n) == bare_v}) { + | Some((n, t)) => + match t { + | TypTuple(tl) => [: for _ <- tl {PatAny(any_loc)} :] + | TypVoid => [] + | _ => pl } + | _ => pl } + | _ => pl + } val tl = [: for p <- pl { make_new_typ() } :] val ctyp = match tl { | [] => t | _ => TypFun(tl, t) } val (v_new, _) = lookup_id(v, ctyp, r_env, sc, loc) @@ -2700,30 +2722,25 @@ fun check_pat(pat: pat_t, typ: typ_t, env: env_t, idset: idset_t, typ_vars: idse as explicitly typed, but we set it typed=false for now for simplicity */ (PatVariant(v_new, [: for p <- pl, t <- tl { check_pat_(p, t).0 } :], loc), false) } else { - match deref_typ(t) { - | TypApp(ty_args, n) => - match id_info(n, loc) { - | IdVariant(dv) => - val {dvar_cases} = *dv - if dvar_cases.length() != 1 { - throw compile_err(loc, "a label of multi-case variant may not be used in a formal function parameter") - } else { - val ni = match dvar_cases.assoc_opt(v) { - | Some(TypTuple(tl)) => tl.length() - | Some(TypVoid) => throw compile_err(loc, - f"a variant label '{pp_id2str(v)}' with no arguments may not be used in a formal function parameter") - | Some(_) => 1 - | _ => throw compile_err(loc, f"the variant constructor '{pp_id2str(v)}' is not found") - } - if ni != pl.length() { - throw compile_err( loc, - f"the number of variant pattern arguments does not match to the description of variant case '{pp_id2str(v)}'") + match get_variant_cases(t, loc).0 { + | [] => throw compile_err(loc, "variant pattern is used with non-variant type") + | dvar_cases => + if dvar_cases.length() != 1 { + throw compile_err(loc, "a label of multi-case variant may not be used in a formal function parameter") + } else { + val ni = match dvar_cases.assoc_opt(bare_v) { + | Some(TypTuple(tl)) => tl.length() + | Some(TypVoid) => throw compile_err(loc, + f"a variant label '{pp_id2str(v)}' with no arguments may not be used in a formal function parameter") + | Some(_) => 1 + | _ => throw compile_err(loc, f"the variant constructor '{pp_id2str(v)}' is not found") } - (PatVariant(v, pl, loc), false) + if ni != pl.length() { + throw compile_err( loc, + f"the number of variant pattern arguments does not match to the description of variant case '{pp_id2str(v)}'") } - | _ => throw compile_err(loc, "variant pattern is used with non-variant type") + (PatVariant(v, pl, loc), false) } - | _ => throw compile_err(loc, "variant pattern is used with non-variant type") } } | PatRecord(rn_opt, relems, loc) => @@ -2787,8 +2804,7 @@ fun check_cases(cases: (pat_t list, exp_t) list, inptyp: typ_t, outtyp: typ_t, e if plist1.length() == 1 { (plist1, env1, capt1) } else { - if capt1.empty() {} - else { + if !capt1.empty() { throw compile_err(loc, "captured variables may not be used in the case of multiple alternatives ('|' pattern)") } val temp_id = gen_temp_id("p") diff --git a/compiler/KNormalize.fx b/compiler/KNormalize.fx index 58d4098..ef8fc4a 100644 --- a/compiler/KNormalize.fx +++ b/compiler/KNormalize.fx @@ -615,11 +615,7 @@ fun pat_have_vars(p: pat_t): bool fun get_record_elems_k(vn_opt: id_t?, t: ktyp_t, loc: loc_t) { val t = deref_ktyp(t, loc) - val input_vn = - match vn_opt { - | Some(vn) => get_id(pp_id2str(get_orig_id(vn)).split('.').last()) - | _ => noid - } + val input_vn = match vn_opt { | Some(vn) => get_bare_name(vn) | _ => noid } match t { | KTypRecord(_, relems) => ((noid, t, false), relems) | KTypName(tn) => diff --git a/src/ast_typecheck.ml b/src/ast_typecheck.ml index 6d1f856..70cf5a5 100644 --- a/src/ast_typecheck.ml +++ b/src/ast_typecheck.ml @@ -2330,6 +2330,14 @@ and instantiate_variant ty_args dvar env sc loc = (*printf "variant after instantiation: {{{ "; pprint_exp_x (DefVariant inst_dvar); printf " }}}\n";*) (inst_name, inst_app_typ) +and get_variant_cases t loc = + match (deref_typ t) with + | TypApp(_, n) -> + (match (id_info n) with + | IdVariant {contents={dvar_cases; dvar_ctors}} -> (dvar_cases, dvar_ctors) + | _ -> ([], [])) + | _ -> ([], []) + and check_pat pat typ env idset typ_vars sc proto_mode simple_pat_mode is_mutable = let r_idset = ref idset in let r_env = ref env in @@ -2366,46 +2374,49 @@ and check_pat pat typ env idset typ_vars sc proto_mode simple_pat_mode is_mutabl (PatTuple(List.rev pl_new, loc), typed) | PatVariant(v, pl, loc) -> if not proto_mode then - (* [TODO] in the ideal case this branch should work fine in the prototype mode as well, - just need to make lookup_id smart enough (maybe add some extra parameters to - avoid preliminary type instantiation) *) + let pl = match pl with + | PatAny(any_loc) :: [] -> + let (dvar_cases, _) = get_variant_cases t loc in + (match (List.find_opt (fun (n, t) -> + (get_orig_id n) = (get_orig_id v)) dvar_cases) with + | Some((n, t)) -> (match t with + | TypTuple(tl) -> + List.init (List.length tl) (fun _ -> PatAny(any_loc)) + | _ -> pl) + | _ -> pl) + | _ -> pl + in let tl = List.map (fun p -> make_new_typ()) pl in let ctyp = match tl with [] -> t | _ -> TypFun(tl, t) in - (*let _ = print_env "env @ report_not_found: " env loc in*) let (v_new, _) = lookup_id v ctyp !r_env sc loc in - (*let _ = printf "checking '%s'~'%s' with %d params at %s\n" (id2str v) (id2str v_new) (List.length pl) (loc2str loc) in*) (* in principle, non-template variant with a single case can be considered as explicitly typed, but we set it typed=false for now for simplicity *) (PatVariant(v_new, (List.map2 (fun p t -> let (p, _) = check_pat_ p t in p) pl tl), loc), false) else - (match (deref_typ t) with - | TypApp(ty_args, n) -> - (match (id_info n) with - | IdVariant dv -> - let { dvar_cases } = !dv in - if (List.length dvar_cases) != 1 then - raise_compile_err loc "a label of multi-case variant may not be used in a formal function parameter" - else - (try - let (vi, ti) = List.find (fun (vi, ti) -> vi = v) dvar_cases in - let ni = (match ti with - | TypTuple(tl) -> List.length tl - | TypVoid -> raise_compile_err loc - (sprintf "a variant label '%s' with no arguments may not be used in a formal function parameter" - (pp_id2str vi)) - | _ -> 1) in - if ni != (List.length pl) then - raise_compile_err loc - (sprintf "the number of variant pattern arguments does not match to the description of variant case '%s'" - (pp_id2str vi)) - else (); - (PatVariant(v, pl, loc), false) - with Not_found -> + (match (get_variant_cases t loc) with + | ([], _) -> raise_compile_err loc "variant pattern is used with non-variant type" + | (dvar_cases, _) -> + if (List.length dvar_cases) != 1 then + raise_compile_err loc "a label of multi-case variant may not be used in a formal function parameter" + else + (try + let (vi, ti) = List.find (fun (vi, ti) -> vi = v) dvar_cases in + let ni = (match ti with + | TypTuple(tl) -> List.length tl + | TypVoid -> raise_compile_err loc + (sprintf "a variant label '%s' with no arguments may not be used in a formal function parameter" + (pp_id2str vi)) + | _ -> 1) in + if ni != (List.length pl) then raise_compile_err loc - (sprintf "the variant constructor '%s' is not found" (pp_id2str v))) - | _ -> raise_compile_err loc "variant pattern is used with non-variant type") - | _ -> raise_compile_err loc "variant pattern is used with non-variant type") + (sprintf "the number of variant pattern arguments does not match to the description of variant case '%s'" + (pp_id2str vi)) + else (); + (PatVariant(v, pl, loc), false) + with Not_found -> + raise_compile_err loc + (sprintf "the variant constructor '%s' is not found" (pp_id2str v)))) | PatRecord(rn_opt, relems, loc) -> let (ctor, relems_found) = get_record_elems rn_opt t proto_mode loc in let new_relems = List.fold_left (fun new_relems (n, p) ->