Commit d18d2539 authored by Jan-Hendrik Matthes's avatar Jan-Hendrik Matthes 😄

Improve the subsumption in the type check

parent 4e0d33c1
......@@ -60,7 +60,7 @@ import qualified Data.Map as Map (Map, empty, insert, lookup)
import Data.Maybe (fromJust, isJust)
import qualified Data.Set.Extra as Set (Set, concatMap, deleteMin, empty,
filter, fromList, insert,
isSubsetOf, member, notMember,
isSubsetOf, map, member, notMember,
partition, singleton, toList,
union, unions)
......@@ -672,7 +672,7 @@ lambdaVar v = do
unifyDecl :: HasPosition p => p -> String -> Doc -> PredSet -> Type -> PredSet
-> Type -> TCM PredSet
unifyDecl p what doc psLhs tyLhs psRhs tyRhs = do
ps <- unify p what doc psLhs tyLhs psRhs tyRhs
ps <- unifyHelp True p what doc psRhs tyRhs psLhs tyLhs
fvs <- computeFvEnv
applyDefaultsDecl p what doc fvs ps tyLhs
......@@ -1472,7 +1472,11 @@ tcBinary p what doc ty = tcArrow p what doc ty >>= uncurry binaryArrow
unify :: HasPosition p => p -> String -> Doc -> PredSet -> Type -> PredSet
-> Type -> TCM PredSet
unify p what doc ps1 ty1 ps2 ty2 = do
unify = unifyHelp False
unifyHelp :: HasPosition p => Bool -> p -> String -> Doc -> PredSet -> Type
-> PredSet -> Type -> TCM PredSet
unifyHelp b p what doc ps1 ty1 ps2 ty2 = do
theta <- getTypeSubst
let ty1' = subst theta ty1
ty2' = subst theta ty2
......@@ -1481,8 +1485,46 @@ unify p what doc ps1 ty1 ps2 ty2 = do
case res of
Left reason -> report $ errTypeMismatch p what doc m ty1' ty2' reason
Right sigma -> modifyTypeSubst (compose sigma)
theta' <- getTypeSubst
let ty1'' = subst theta' ty1'
ty2'' = subst theta' ty2'
unlessM (subsumCheck b emptyPredSet emptyPredSet ty1'' ty2'') $ do
report $ errSubsumption p what doc m ty2' ty1'
reducePredSet p what doc (ps1 `Set.union` ps2)
subsumCheck :: Bool -> PredSet -> PredSet -> Type -> Type -> TCM Bool
subsumCheck True ps1 ps2 (TypeArrow ty11@(TypeForall _ _) ty12) (TypeArrow ty21@(TypeForall _ _) ty22)
= subsumCheck' ps1 ps2 ty11 ty21 &&^ subsumCheck' ps1 ps2 ty12 ty22
subsumCheck _ ps1 ps2 ty1 ty2 = subsumCheck' ps1 ps2 ty1 ty2
subsumCheck' :: PredSet -> PredSet -> Type -> Type -> TCM Bool
subsumCheck' _ _ (TypeConstructor _) (TypeConstructor _) = return True
subsumCheck' ps1 ps2 (TypeVariable tv1) (TypeVariable tv2) = do
clsEnv <- getClassEnv
let ps1' = maxPredSet clsEnv ps1
preds = Set.map (\(Pred qid _) -> Pred qid (TypeVariable tv1))
(Set.filter (\(Pred _ (TypeVariable v)) -> v == tv2) ps2)
return $ all (`elem` ps1') preds
subsumCheck' _ _ _ (TypeVariable _) = return True
subsumCheck' _ _ (TypeConstrained _ _) (TypeConstrained _ _) = return True
subsumCheck' ps1 ps2 (TypeApply ty11 ty12) (TypeApply ty21 ty22)
= subsumCheck' ps1 ps2 ty11 ty21 &&^ subsumCheck' ps1 ps2 ty12 ty22
subsumCheck' ps1 ps2 ty@(TypeApply _ _) (TypeArrow ty21 ty22)
= subsumCheck' ps1 ps2 ty (TypeApply (TypeApply (TypeConstructor qArrowId) ty21) ty22)
subsumCheck' ps1 ps2 (TypeArrow ty11 ty12) ty@(TypeApply _ _)
= subsumCheck' ps1 ps2 (TypeApply (TypeApply (TypeConstructor qArrowId) ty11) ty12) ty
subsumCheck' ps1 ps2 (TypeArrow ty11@(TypeForall _ _) ty12) (TypeArrow ty21@(TypeForall _ _) ty22)
= subsumCheck' ps2 ps1 ty21 ty11 &&^ subsumCheck' ps1 ps2 ty12 ty22
subsumCheck' ps1 ps2 (TypeArrow ty11 ty12) (TypeArrow ty21 ty22)
= subsumCheck' ps1 ps2 ty11 ty21 &&^ subsumCheck' ps1 ps2 ty12 ty22
subsumCheck' ps1 ps2 (TypeContext ps ty1) ty2
= subsumCheck' (ps1 `Set.union` ps) ps2 ty1 ty2
subsumCheck' ps1 ps2 ty1 (TypeContext ps ty2)
= subsumCheck' ps1 (ps2 `Set.union` ps) ty1 ty2
subsumCheck' ps1 ps2 (TypeForall _ ty1) ty2 = subsumCheck' ps1 ps2 ty1 ty2
subsumCheck' ps1 ps2 ty1 (TypeForall _ ty2) = subsumCheck' ps1 ps2 ty1 ty2
subsumCheck' _ _ _ _ = return False
unifyTypes :: ModuleIdent -> Type -> Type -> TCM (Either Doc TypeSubst)
unifyTypes _ (TypeVariable tv1) ty@(TypeVariable tv2)
| tv1 == tv2 = return $ Right idSubst
......@@ -1538,16 +1580,13 @@ unifyTypes m ty1@(TypeForall _ _) ty2@(TypeForall _ _)
Left x -> return $ Left x
Right s -> do
let (_, tys) = unzip $ substToList $ restrictSubstTo (vs1 ++ vs2) s
case all isVarType tys of
True -> do
let vars = typeVars ty1 ++ typeVars ty2
let tvs = concatMap typeVars $ snd $ unzip $ substToList
$ restrictSubstTo vars s
let tys' = map (\(TypeVariable tv) -> tv) tys
case filter (`elem` tvs) (vs1 ++ vs2 ++ tys') of
[] -> return $ Right s
ev:_ -> return $ Left $ errEscapingTypeVariable m ev ty1 ty2
False -> return $ Left (errIncompatibleTypes m ty1 ty2)
vars = typeVars ty1 ++ typeVars ty2
tvs = concatMap typeVars $ snd $ unzip $ substToList
$ restrictSubstTo vars s
tys' = concatMap typeVars tys
case filter (`elem` tvs) (vs1 ++ vs2 ++ tys') of
[] -> return $ Right s
ev:_ -> return $ Left $ errEscapingTypeVariable m ev ty1 ty2
unifyTypes m ty1@(TypeForall _ _) ty2
= do (vs, _, ty1') <- skolemise ty1
res <- unifyTypes m ty1' ty2
......@@ -1555,15 +1594,12 @@ unifyTypes m ty1@(TypeForall _ _) ty2
Left x -> return $ Left x
Right s -> do
let (_, tys) = unzip $ substToList $ restrictSubstTo vs s
case all isVarType tys of
True -> do
let tvs = concatMap typeVars $ snd $ unzip $ substToList
$ restrictSubstTo (typeVars ty1) s
let tys' = map (\(TypeVariable tv) -> tv) tys
case filter (`elem` tvs) (vs ++ tys') of
[] -> return $ Right s
ev:_ -> return $ Left $ errEscapingTypeVariable m ev ty1 ty2
False -> return $ Left (errIncompatibleTypes m ty1 ty2)
tvs = concatMap typeVars $ snd $ unzip $ substToList
$ restrictSubstTo (typeVars ty1) s
tys' = concatMap typeVars tys
case filter (`elem` tvs) (vs ++ tys') of
[] -> return $ Right s
ev:_ -> return $ Left $ errEscapingTypeVariable m ev ty1 ty2
unifyTypes m ty1 ty2@(TypeForall _ _)
= do (vs, _, ty2') <- skolemise ty2
res <- unifyTypes m ty1 ty2'
......@@ -1571,15 +1607,12 @@ unifyTypes m ty1 ty2@(TypeForall _ _)
Left x -> return $ Left x
Right s -> do
let (_, tys) = unzip $ substToList $ restrictSubstTo vs s
case all isVarType tys of
True -> do
let tvs = concatMap typeVars $ snd $ unzip $ substToList
$ restrictSubstTo (typeVars ty2) s
let tys' = map (\(TypeVariable tv) -> tv) tys
case filter (`elem` tvs) (vs ++ tys') of
[] -> return $ Right s
ev:_ -> return $ Left $ errEscapingTypeVariable m ev ty1 ty2
False -> return $ Left (errIncompatibleTypes m ty1 ty2)
tvs = concatMap typeVars $ snd $ unzip $ substToList
$ restrictSubstTo (typeVars ty2) s
tys' = concatMap typeVars tys
case filter (`elem` tvs) (vs ++ tys') of
[] -> return $ Right s
ev:_ -> return $ Left $ errEscapingTypeVariable m ev ty1 ty2
unifyTypes m ty1 ty2
= return $ Left (errIncompatibleTypes m ty1 ty2)
......@@ -1912,6 +1945,14 @@ errTypeMismatch p what doc m ety ity reason = posMessage p $ vcat
, text "Expected type:" <+> ppType m ety
, reason ]
errSubsumption :: HasPosition a => a -> String -> Doc -> ModuleIdent -> Type
-> Type -> Message
errSubsumption p what doc m ety ity = posMessage p $ vcat
[ text "Type error in" <+> text what, doc
, text "The type" <+> ppType m ity
, text "is not as polymorphic as"
, text "the expected type" <+> ppType m ety ]
errRecursiveType :: ModuleIdent -> Int -> Type -> Doc
errRecursiveType m tv = errIncompatibleTypes m (TypeVariable tv)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment