diff --git a/src/InferredLang/TypeChecker.hs b/src/InferredLang/TypeChecker.hs index 3d76efd..5de175c 100644 --- a/src/InferredLang/TypeChecker.hs +++ b/src/InferredLang/TypeChecker.hs @@ -49,7 +49,6 @@ checkType expr expect tenv = do actual <- typeOf expr tenv unifyTypes expect actual expr - unifyTypes :: Type -> Type -> Expression -> TypeStateTry () unifyTypes typ1 typ2 expr = do subst <- getSubst @@ -68,15 +67,17 @@ unifyTypes typ1 typ2 expr = do unifyTypes' t1@(TypeProc params1 res1) t2@(TypeProc params2 res2) = do paramPairs <- safeZip params1 params2 - unifyAll $ mappend paramPairs [(res1, res2)] + unifyAllTypes (mappend paramPairs [(res1, res2)]) expr where safeZip :: [a] -> [b] -> TypeStateTry [(a, b)] safeZip p1 p2 = if length p1 == length p2 then return $ zip p1 p2 else throwError $ TypeUnifyError t1 t2 expr unifyTypes' t1 t2 = throwError $ TypeUnifyError t1 t2 expr - unifyAll :: [(Type, Type)] -> TypeStateTry () - unifyAll = foldl (\acc (t1, t2) -> acc >> unifyTypes' t1 t2) (return ()) + +unifyAllTypes :: [(Type, Type)] -> Expression -> TypeStateTry () +unifyAllTypes lst expr = foldl func (return ()) lst + where func acc (t1, t2) = acc >> unifyTypes t1 t2 expr nextVar :: TypeStateTry TypeVariable nextVar = do @@ -222,26 +223,31 @@ typeOfCallExpr ratorE argEs tenv = do typeOfLetRecExpr :: [(Maybe Type, String, [(String, Maybe Type)], Expression)] -> Expression -> TypeEnvironment -> TypeResult -typeOfLetRecExpr binds body tenv = undefined - --- typeOfLetRecExpr binds body tenv = do - -- checkAllBinds binds - -- typeOf body bodyEnv - -- where - -- getRecBinds :: [(Maybe Type, String, [(String, Maybe Type)] - -- -> TypeStateTry [(Type, String, [(String, Type)] - -- getRecBinds [] = return [] - -- getRecBinds ((mayResT, name, params, expr) : remain) = do - -- resT <- ensureType mayResT - -- let mayTs = fmap snd params - -- paramTs <- ensureAllTypes mayTs - -- let params' = zip (fmap fst params) paramTs - -- ((resT, name, params') :) <$> getRecBinds remain - -- getBodyEnv :: TypeStateTry TypeEnvironment - -- getBodyEnv = do - -- binds' <- getBodyEnv binds - -- return $ extendMany binds' tenv - -- checkAllBinds [] = return () - -- checkAllBinds ((res, name, params, body) : remain) = - -- checkType body res (extendMany params bodyEnv) >> checkAllBinds remain --- +typeOfLetRecExpr mayBinds body tenv = do + binds <- ensureRecBinds mayBinds + let recBinds = allRecBinds binds + unifyAllRecBinds binds (extendMany recBinds tenv) + subst <- getSubst + let recBinds' = fmap (second (applySubst subst)) recBinds + typeOf body (extendMany recBinds' tenv) + where + ensureRecBind (mayResT, name, mayParams, resBody) = do + params <- ensureAllBinds mayParams + resT <- ensureType mayResT + return (resT, name, params, resBody) + ensureRecBinds = foldr func (return []) + func recBind acc = do + bind <- ensureRecBind recBind + binds <- acc + return $ bind : binds + allRecBinds binds = + let func (t, name, ps, _) = (name, TypeProc (fmap snd ps) t) + in fmap func binds + unifyAllRecBinds :: [(Type, String, [(String, Type)], Expression)] + -> TypeEnvironment + -> TypeStateTry () + unifyAllRecBinds [] _ = return () + unifyAllRecBinds ((resT, _, params, body) : remain) tenv = do + resT' <- typeOf body (extendMany params tenv) + unifyTypes resT resT' body + unifyAllRecBinds remain tenv diff --git a/test/InferredLang/TypeCheckerSuite.hs b/test/InferredLang/TypeCheckerSuite.hs index 6164895..cb64b43 100644 --- a/test/InferredLang/TypeCheckerSuite.hs +++ b/test/InferredLang/TypeCheckerSuite.hs @@ -38,20 +38,94 @@ testOp = TestList testLet :: Test testLet = TestList [ testEq "Type of var" TypeInt "let x = 1 in x" - -- , testEq "Type of letrec expression 1" - -- TypeInt - -- $ unlines - -- [ "letrec int f(x: int, y: int) = 3" - -- , " bool g(x: bool) = x in" - -- , "(f 1 2)" - -- ] - -- , testEq "Type of letrec expression 2" - -- (TypeProc [TypeInt, TypeInt] TypeInt) - -- $ unlines - -- [ "letrec int f(x: int, y: int) = 3" - -- , " bool g(x: bool) = x in" - -- , "f" - -- ] + , testEq "Simple type for applying letrec expression" + TypeInt + "letrec int f(x: int, y: int) = 3 in (f 1 2)" + , testEq "Simple infer type for applying letrec expression" + TypeInt + "letrec ? f(x: ?, y: ?) = +(x, y) in (f 1 2)" + , testEq "Polymorphic infer type for applying letrec expression" + TypeInt + "letrec ? f(x: ?, y: ?) = x in (f 1 2)" + , testEq "Type of function in letrec expression" + (TypeProc [TypeInt] TypeInt) + "letrec int f(x: int) = 3 in f" + , testEq "Type of polymorphic function in letrec expression" + (TypeProc [TypeInt] TypeBool) + "letrec ? f(x: ?) = zero?(x) in f" + , testEq "Type of letrec expression (recursive)" + TypeInt + $ unlines + [ "letrec int double(x: int)" + , " = if zero?(x) then 0 else -((double -(x,1)), -2)" + , "in (double 6)" + ] + , testEq "Inferr type of letrec expression (recursive)" + TypeInt + $ unlines + [ "letrec ? double(x: ?)" + , " = if zero?(x) then 0 else -((double -(x,1)), -2)" + , "in (double 6)" + ] + , testEq "Inferr type of function letrec expression (recursive)" + (TypeProc [TypeInt] TypeInt) + $ unlines + [ "letrec ? double(x: ?)" + , " = if zero?(x) then 0 else -((double -(x,1)), -2)" + , "in double" + ] + , testEq "Inferr type of function letrec expression (recursive)" + (TypeProc [TypeInt] TypeInt) + $ unlines + [ "letrec ? double(x: ?)" + , " = if zero?(x) then 0 else -((double -(x,1)), -2)" + , "in double" + ] + , testEq "Type of letrec with multi parameters" + TypeInt + $ unlines + [ "letrec int double(x: int, dummy: int)" + , " = if zero?(x) then 0 else -((double -(x,1) dummy), -2)" + , "in (double 6 10000)" + ] + , testEq "Infer type of letrec with multi parameters" + TypeInt + $ unlines + [ "letrec ? double(x: ?, dummy: ?)" + , " = if zero?(x) then 0 else -((double -(x,1) dummy), -2)" + , "in (double 6 10000)" + ] + , testEq "Infer type of function with multi parameters in letrec expression" + (TypeProc [TypeInt, TypeVar 1] TypeInt) + $ unlines + [ "letrec ? double(x: ?, dummy: ?)" + , " = if zero?(x) then 0 else -((double -(x,1) dummy), -2)" + , "in double" + ] + , testEq "Type of co-recursion body in letrec expression" + TypeInt + $ unlines + [ "letrec" + , " int even(x: int) = if zero?(x) then 1 else (odd -(x,1))" + , " int odd(x: int) = if zero?(x) then 0 else (even -(x,1))" + , "in (odd 13)" + ] + , testEq "Infer type of co-recursion body in letrec expression" + TypeInt + $ unlines + [ "letrec" + , " ? even(x: ?) = if zero?(x) then 1 else (odd -(x,1))" + , " ? odd(x: ?) = if zero?(x) then 0 else (even -(x,1))" + , "in (odd 13)" + ] + , testEq "Infer type of co-recursion function in letrec expression" + (TypeProc [TypeInt] TypeInt) + $ unlines + [ "letrec" + , " ? even(x: ?) = if zero?(x) then 1 else (odd -(x,1))" + , " ? odd(x: ?) = if zero?(x) then 0 else (even -(x,1))" + , "in odd" + ] ] testProc :: Test