Skip to content

Commit

Permalink
add letrec to type inferred lang
Browse files Browse the repository at this point in the history
  • Loading branch information
Kurisu.Ti.Na committed Sep 30, 2016
1 parent e0027c1 commit 05dfa7f
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 41 deletions.
60 changes: 33 additions & 27 deletions src/InferredLang/TypeChecker.hs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ checkType expr expect tenv = do
actual <- typeOf expr tenv
unifyTypes expect actual expr


unifyTypes :: Type -> Type -> Expression -> TypeStateTry ()
unifyTypes typ1 typ2 expr = do
subst <- getSubst
Expand All @@ -68,15 +67,17 @@ unifyTypes typ1 typ2 expr = do
unifyTypes' t1@(TypeProc params1 res1)
t2@(TypeProc params2 res2) = do
paramPairs <- safeZip params1 params2
unifyAll $ mappend paramPairs [(res1, res2)]
unifyAllTypes (mappend paramPairs [(res1, res2)]) expr
where
safeZip :: [a] -> [b] -> TypeStateTry [(a, b)]
safeZip p1 p2 = if length p1 == length p2
then return $ zip p1 p2
else throwError $ TypeUnifyError t1 t2 expr
unifyTypes' t1 t2 = throwError $ TypeUnifyError t1 t2 expr
unifyAll :: [(Type, Type)] -> TypeStateTry ()
unifyAll = foldl (\acc (t1, t2) -> acc >> unifyTypes' t1 t2) (return ())

unifyAllTypes :: [(Type, Type)] -> Expression -> TypeStateTry ()
unifyAllTypes lst expr = foldl func (return ()) lst
where func acc (t1, t2) = acc >> unifyTypes t1 t2 expr

nextVar :: TypeStateTry TypeVariable
nextVar = do
Expand Down Expand Up @@ -222,26 +223,31 @@ typeOfCallExpr ratorE argEs tenv = do
typeOfLetRecExpr :: [(Maybe Type, String, [(String, Maybe Type)], Expression)]
-> Expression -> TypeEnvironment
-> TypeResult
typeOfLetRecExpr binds body tenv = undefined

-- typeOfLetRecExpr binds body tenv = do
-- checkAllBinds binds
-- typeOf body bodyEnv
-- where
-- getRecBinds :: [(Maybe Type, String, [(String, Maybe Type)]
-- -> TypeStateTry [(Type, String, [(String, Type)]
-- getRecBinds [] = return []
-- getRecBinds ((mayResT, name, params, expr) : remain) = do
-- resT <- ensureType mayResT
-- let mayTs = fmap snd params
-- paramTs <- ensureAllTypes mayTs
-- let params' = zip (fmap fst params) paramTs
-- ((resT, name, params') :) <$> getRecBinds remain
-- getBodyEnv :: TypeStateTry TypeEnvironment
-- getBodyEnv = do
-- binds' <- getBodyEnv binds
-- return $ extendMany binds' tenv
-- checkAllBinds [] = return ()
-- checkAllBinds ((res, name, params, body) : remain) =
-- checkType body res (extendMany params bodyEnv) >> checkAllBinds remain
--
typeOfLetRecExpr mayBinds body tenv = do
binds <- ensureRecBinds mayBinds
let recBinds = allRecBinds binds
unifyAllRecBinds binds (extendMany recBinds tenv)
subst <- getSubst
let recBinds' = fmap (second (applySubst subst)) recBinds
typeOf body (extendMany recBinds' tenv)
where
ensureRecBind (mayResT, name, mayParams, resBody) = do
params <- ensureAllBinds mayParams
resT <- ensureType mayResT
return (resT, name, params, resBody)
ensureRecBinds = foldr func (return [])
func recBind acc = do
bind <- ensureRecBind recBind
binds <- acc
return $ bind : binds
allRecBinds binds =
let func (t, name, ps, _) = (name, TypeProc (fmap snd ps) t)
in fmap func binds
unifyAllRecBinds :: [(Type, String, [(String, Type)], Expression)]
-> TypeEnvironment
-> TypeStateTry ()
unifyAllRecBinds [] _ = return ()
unifyAllRecBinds ((resT, _, params, body) : remain) tenv = do
resT' <- typeOf body (extendMany params tenv)
unifyTypes resT resT' body
unifyAllRecBinds remain tenv
102 changes: 88 additions & 14 deletions test/InferredLang/TypeCheckerSuite.hs
Original file line number Diff line number Diff line change
Expand Up @@ -38,20 +38,94 @@ testOp = TestList
testLet :: Test
testLet = TestList
[ testEq "Type of var" TypeInt "let x = 1 in x"
-- , testEq "Type of letrec expression 1"
-- TypeInt
-- $ unlines
-- [ "letrec int f(x: int, y: int) = 3"
-- , " bool g(x: bool) = x in"
-- , "(f 1 2)"
-- ]
-- , testEq "Type of letrec expression 2"
-- (TypeProc [TypeInt, TypeInt] TypeInt)
-- $ unlines
-- [ "letrec int f(x: int, y: int) = 3"
-- , " bool g(x: bool) = x in"
-- , "f"
-- ]
, testEq "Simple type for applying letrec expression"
TypeInt
"letrec int f(x: int, y: int) = 3 in (f 1 2)"
, testEq "Simple infer type for applying letrec expression"
TypeInt
"letrec ? f(x: ?, y: ?) = +(x, y) in (f 1 2)"
, testEq "Polymorphic infer type for applying letrec expression"
TypeInt
"letrec ? f(x: ?, y: ?) = x in (f 1 2)"
, testEq "Type of function in letrec expression"
(TypeProc [TypeInt] TypeInt)
"letrec int f(x: int) = 3 in f"
, testEq "Type of polymorphic function in letrec expression"
(TypeProc [TypeInt] TypeBool)
"letrec ? f(x: ?) = zero?(x) in f"
, testEq "Type of letrec expression (recursive)"
TypeInt
$ unlines
[ "letrec int double(x: int)"
, " = if zero?(x) then 0 else -((double -(x,1)), -2)"
, "in (double 6)"
]
, testEq "Inferr type of letrec expression (recursive)"
TypeInt
$ unlines
[ "letrec ? double(x: ?)"
, " = if zero?(x) then 0 else -((double -(x,1)), -2)"
, "in (double 6)"
]
, testEq "Inferr type of function letrec expression (recursive)"
(TypeProc [TypeInt] TypeInt)
$ unlines
[ "letrec ? double(x: ?)"
, " = if zero?(x) then 0 else -((double -(x,1)), -2)"
, "in double"
]
, testEq "Inferr type of function letrec expression (recursive)"
(TypeProc [TypeInt] TypeInt)
$ unlines
[ "letrec ? double(x: ?)"
, " = if zero?(x) then 0 else -((double -(x,1)), -2)"
, "in double"
]
, testEq "Type of letrec with multi parameters"
TypeInt
$ unlines
[ "letrec int double(x: int, dummy: int)"
, " = if zero?(x) then 0 else -((double -(x,1) dummy), -2)"
, "in (double 6 10000)"
]
, testEq "Infer type of letrec with multi parameters"
TypeInt
$ unlines
[ "letrec ? double(x: ?, dummy: ?)"
, " = if zero?(x) then 0 else -((double -(x,1) dummy), -2)"
, "in (double 6 10000)"
]
, testEq "Infer type of function with multi parameters in letrec expression"
(TypeProc [TypeInt, TypeVar 1] TypeInt)
$ unlines
[ "letrec ? double(x: ?, dummy: ?)"
, " = if zero?(x) then 0 else -((double -(x,1) dummy), -2)"
, "in double"
]
, testEq "Type of co-recursion body in letrec expression"
TypeInt
$ unlines
[ "letrec"
, " int even(x: int) = if zero?(x) then 1 else (odd -(x,1))"
, " int odd(x: int) = if zero?(x) then 0 else (even -(x,1))"
, "in (odd 13)"
]
, testEq "Infer type of co-recursion body in letrec expression"
TypeInt
$ unlines
[ "letrec"
, " ? even(x: ?) = if zero?(x) then 1 else (odd -(x,1))"
, " ? odd(x: ?) = if zero?(x) then 0 else (even -(x,1))"
, "in (odd 13)"
]
, testEq "Infer type of co-recursion function in letrec expression"
(TypeProc [TypeInt] TypeInt)
$ unlines
[ "letrec"
, " ? even(x: ?) = if zero?(x) then 1 else (odd -(x,1))"
, " ? odd(x: ?) = if zero?(x) then 0 else (even -(x,1))"
, "in odd"
]
]

testProc :: Test
Expand Down

0 comments on commit 05dfa7f

Please sign in to comment.