Commit 5d088454 authored by Kai-Oliver Prott's avatar Kai-Oliver Prott

Merge remote-tracking branch 'origin/data-class' into version3

parents 63f7f3f8 f8de5f36
......@@ -43,6 +43,7 @@ checkDerivable m tcEnv cs cls
| ocls == qEnumId && not (isEnum cs) = [errNotEnum cls]
| ocls == qBoundedId && not (isBounded cs) = [errNotBounded cls]
| ocls `notElem` derivableClasses = [errNotDerivable ocls]
| ocls == qDataId = [errNoDataDerive ocls]
| otherwise = []
where ocls = getOrigName m cls tcEnv
......@@ -53,7 +54,7 @@ derivableClasses = [qEqId, qOrdId, qEnumId, qBoundedId, qReadId, qShowId]
-- where all data constructors are constants.
isEnum :: [ConstrDecl] -> Bool
isEnum cs = all ((0 ==) . constrArity) cs
isEnum = all ((0 ==) . constrArity)
-- Instances of 'Bounded' can be derived only for enumerations and for single
-- constructor types.
......@@ -88,6 +89,12 @@ errNotDerivable :: QualIdent -> Message
errNotDerivable cls = posMessage cls $ hsep $ map text
["Instances of type class", escQualName cls, "cannot be derived"]
errNoDataDerive :: QualIdent -> Message
errNoDataDerive qcls = posMessage qcls $ hsep $ map text
[ "Instances of type class"
, escQualName qcls
, "are automatically derived if possible"]
errNotEnum :: HasPosition a => a -> Message
errNotEnum p = posMessage p $
text "Instances for Enum can be derived only for enumeration types"
......
......@@ -88,7 +88,7 @@ checkImpredDecl (InstanceDecl _ _ _ ty ds) = do
mapM_ checkImpredDecl ds
checkImpredDecl (DefaultDecl _ tys) = mapM_ checkType tys
where
checkType te = unlessM (checkSimpleType te) $ do
checkType te = unlessM (checkSimpleType te) $
report $ errIllegalDefaultType (getPosition te) te
checkImpredDecl _ = ok
......@@ -108,17 +108,17 @@ checkImpredNewConsDecl (NewRecordDecl _ _ (_, ty)) = checkImpredType ty
checkImpredType :: TypeExpr -> ICM ()
checkImpredType (ConstructorType _ _) = ok
checkImpredType te@(ApplyType spi ty1 ty2) = do
unlessM (checkSimpleType ty2) $ do
unlessM (checkSimpleType ty2) $
report $ errIllegalPolymorphicType (getPosition spi) te
checkImpredType ty1
checkImpredType ty2
checkImpredType (VariableType _ _) = ok
checkImpredType te@(TupleType spi tys) = do
unlessM (allM checkSimpleType tys) $ do
unlessM (allM checkSimpleType tys) $
report $ errIllegalPolymorphicType (getPosition spi) te
mapM_ checkImpredType tys
checkImpredType te@(ListType spi ty) = do
unlessM (checkSimpleType ty) $ do
unlessM (checkSimpleType ty) $
report $ errIllegalPolymorphicType (getPosition spi) te
checkImpredType ty
checkImpredType (ArrowType _ ty1 ty2) = do
......
......@@ -19,9 +19,10 @@
-}
module Checks.InstanceCheck (instanceCheck) where
import Control.Monad.Extra (concatMapM, whileM)
import Control.Monad.Extra (concatMapM, whileM, when)
import qualified Control.Monad.State as S (State, execState, gets, modify)
import Data.List (nub, partition, sortBy)
import Data.Maybe (catMaybes)
import qualified Data.Map as Map
import qualified Data.Set.Extra as Set
......@@ -52,7 +53,7 @@ instanceCheck m tcEnv clsEnv inEnv ds =
iss -> (inEnv, map (errMultipleInstances tcEnv) iss)
where
local = map (flip InstSource m) $ concatMap (genInstIdents m tcEnv) ds
imported = map (uncurry InstSource) $ map (fmap fst3) $ Map.toList inEnv
imported = map (uncurry InstSource . fmap fst3) $ Map.toList inEnv
state = INCState m inEnv []
-- In order to provide better error messages, we use the following data type
......@@ -95,6 +96,8 @@ ok = return ()
checkDecls :: TCEnv -> ClassEnv -> [Decl a] -> INCM ()
checkDecls tcEnv clsEnv ds = do
mapM_ (bindInstance tcEnv clsEnv) ids
mapM (declDeriveDataInfo tcEnv clsEnv) (filter isDataDecl tds) >>=
mapM_ (bindDerivedInstances clsEnv) . groupDeriveInfos
mapM (declDeriveInfo tcEnv clsEnv) (filter hasDerivedInstances tds) >>=
mapM_ (bindDerivedInstances clsEnv) . groupDeriveInfos
mapM_ (checkInstance tcEnv clsEnv) ids
......@@ -102,6 +105,9 @@ checkDecls tcEnv clsEnv ds = do
where (tds, ods) = partition isTypeDecl ds
ids = filter isInstanceDecl ods
dds = filter isDefaultDecl ods
isDataDecl (DataDecl _ _ _ _ _) = True
isDataDecl (NewtypeDecl _ _ _ _ _) = True
isDataDecl _ = False
-- First, the compiler adds all explicit instance declarations to the
-- instance environment.
......@@ -151,8 +157,21 @@ declDeriveInfo tcEnv clsEnv (NewtypeDecl p tc tvs nc clss) =
declDeriveInfo _ _ _ =
internalError "InstanceCheck.declDeriveInfo: no data or newtype declaration"
mkDeriveInfo :: TCEnv -> ClassEnv -> SpanInfo -> Ident -> [Ident] -> [TypeExpr]
-> [QualIdent] -> INCM DeriveInfo
declDeriveDataInfo :: TCEnv -> ClassEnv -> Decl a -> INCM DeriveInfo
declDeriveDataInfo tcEnv clsEnv (DataDecl p tc tvs cs _) =
mkDeriveDataInfo tcEnv clsEnv p tc tvs (concat tyss)
where tyss = map constrDeclTypes cs
constrDeclTypes (ConstrDecl _ _ tys) = tys
constrDeclTypes (ConOpDecl _ ty1 _ ty2) = [ty1, ty2]
constrDeclTypes (RecordDecl _ _ fs) = tys
where tys = [ty | FieldDecl _ ls ty <- fs, _ <- ls]
declDeriveDataInfo tcEnv clsEnv (NewtypeDecl p tc tvs nc _) =
mkDeriveDataInfo tcEnv clsEnv p tc tvs [nconstrType nc]
declDeriveDataInfo _ _ _ = internalError
"InstanceCheck.declDeriveDataInfo: no data or newtype declaration"
mkDeriveInfo :: TCEnv -> ClassEnv -> SpanInfo -> Ident -> [Ident]
-> [TypeExpr] -> [QualIdent] -> INCM DeriveInfo
mkDeriveInfo tcEnv clsEnv spi tc tvs tys clss = do
m <- getModuleIdent
let otc = qualifyWith m tc
......@@ -162,34 +181,59 @@ mkDeriveInfo tcEnv clsEnv spi tc tvs tys clss = do
return $ DeriveInfo p otc (TypeContext ps ty') tys' $ sortClasses clsEnv oclss
where p = spanInfo2Pos spi
mkDeriveDataInfo :: TCEnv -> ClassEnv -> SpanInfo -> Ident -> [Ident]
-> [TypeExpr] -> INCM DeriveInfo
mkDeriveDataInfo tcEnv clsEnv spi tc tvs tys = do
m <- getModuleIdent
let otc = qualifyWith m tc
TypeContext ps ty = expandConstrType m tcEnv clsEnv otc tvs tys
(tys', ty') = arrowUnapply ty
return $ DeriveInfo p otc (TypeContext ps ty') tys' [qDataId]
where p = spanInfo2Pos spi
sortClasses :: ClassEnv -> [QualIdent] -> [QualIdent]
sortClasses clsEnv clss = map fst $ sortBy compareDepth $ map adjoinDepth clss
where (_, d1) `compareDepth` (_, d2) = d1 `compare` d2
adjoinDepth cls = (cls, length $ allSuperClasses cls clsEnv)
groupDeriveInfos :: [DeriveInfo] -> [[DeriveInfo]]
groupDeriveInfos ds = scc bound free ds
groupDeriveInfos = scc bound free
where bound (DeriveInfo _ tc _ _ _) = [tc]
free (DeriveInfo _ _ _ tys _) = concatMap typeConstrs tys
bindDerivedInstances :: ClassEnv -> [DeriveInfo] -> INCM ()
bindDerivedInstances clsEnv dis = do
mapM_ (enterInitialPredSet clsEnv) dis
whileM $ concatMapM (inferPredSets clsEnv) dis >>= updatePredSets
-- If any registration of initial pred sets failed, return immediately, as
-- there are no other (Data-)Instances that might succeed.
bs <- mapM (enterInitialPredSet clsEnv) dis
when (any or bs) $
whileM $ concatMapM (inferPredSets clsEnv) dis >>= updatePredSets
enterInitialPredSet :: ClassEnv -> DeriveInfo -> INCM ()
enterInitialPredSet clsEnv (DeriveInfo p tc pty _ clss) =
mapM_ (bindDerivedInstance clsEnv p tc pty []) clss
enterInitialPredSet :: ClassEnv -> DeriveInfo -> INCM [Bool]
enterInitialPredSet clsEnv (DeriveInfo p tc pty tys clss) =
mapM (bindDerivedInstance clsEnv p tc pty tys) clss
-- Note: The methods and arities entered into the instance environment have
-- to match methods and arities of the later generated instance declarations.
bindDerivedInstance :: ClassEnv -> Position -> QualIdent -> Type -> [Type]
-> QualIdent -> INCM ()
-> QualIdent -> INCM Bool
bindDerivedInstance clsEnv p tc pty tys cls = do
m <- getModuleIdent
(i, ps) <- inferPredSet clsEnv p tc pty tys cls
modifyInstEnv $ bindInstInfo i (m, ps, impls)
-- immediately return if asked to derive Data for functional Datatype
if any isFunType tys && cls == qDataId
then return False
else do
-- bindDerivedInstances normally infers the PredSet with empty `tys`
-- in order to always bind the instance in a first step.
-- For DataDeriving, this leads to problems.
let tys' = if cls == qDataId then tys else []
mps <- inferPredSet clsEnv p tc pty tys' cls
case mps of
Just (i, ps) -> modifyInstEnv (bindInstInfo i (m, ps, impls)) >>
return True
-- encountered unsatisfied DataClass constraint -> dont derive it here
Nothing -> return False
where impls | cls == qEqId = [(eqOpId, 2)]
| cls == qOrdId = [(leqOpId, 2)]
| cls == qEnumId = [ (succId, 1), (predId, 1), (toEnumId, 1)
......@@ -199,15 +243,16 @@ bindDerivedInstance clsEnv p tc pty tys cls = do
| cls == qBoundedId = [(maxBoundId, 0), (minBoundId, 0)]
| cls == qReadId = [(readsPrecId, 2)]
| cls == qShowId = [(showsPrecId, 2)]
| cls == qDataId = [(dataEqId, 2), (aValueId, 0)]
| otherwise =
internalError "InstanceCheck.bindDerivedInstance.impls"
inferPredSets :: ClassEnv -> DeriveInfo -> INCM [(InstIdent, PredSet)]
inferPredSets clsEnv (DeriveInfo p tc pty tys clss) =
mapM (inferPredSet clsEnv p tc pty tys) clss
catMaybes <$> mapM (inferPredSet clsEnv p tc pty tys) clss
inferPredSet :: ClassEnv -> Position -> QualIdent -> Type -> [Type]
-> QualIdent -> INCM (InstIdent, PredSet)
-> QualIdent -> INCM (Maybe (InstIdent, PredSet))
inferPredSet clsEnv p tc (TypeContext ps inst) tys cls = do
m <- getModuleIdent
let doc = ppPred m $ Pred cls inst
......@@ -215,13 +260,21 @@ inferPredSet clsEnv p tc (TypeContext ps inst) tys cls = do
ps' = Set.fromList [Pred cls ty | ty <- tys]
ps'' = Set.fromList [Pred scls inst | scls <- sclss]
ps''' = ps `Set.union` ps' `Set.union` ps''
ps'''' <- reducePredSet p "derived instance" doc clsEnv ps'''
mapM_ (reportUndecidable p "derived instance" doc) $ Set.toList ps''''
return ((cls, tc), ps'''')
(ps4, novarps) <-
reducePredSet (cls == qDataId) p "derived instance" doc clsEnv ps'''
let ps5 = filter noPolyPred $ Set.toList ps4
if any (isDataPred m) (Set.toList novarps ++ ps5) && cls == qDataId
then return Nothing
else mapM_ (reportUndecidable p "derived instance" doc) ps5
>> return (Just ((cls, tc), ps4))
where
noPolyPred (Pred _ (TypeVariable _)) = False
noPolyPred (Pred _ _ ) = True
isDataPred _ (Pred qid _) = qid == qDataId
inferPredSet _ _ _ _ _ _ = internalError "InstanceCheck.inferPredSet"
updatePredSets :: [(InstIdent, PredSet)] -> INCM Bool
updatePredSets = (=<<) (return . or) . mapM (uncurry updatePredSet)
updatePredSets = fmap or . mapM (uncurry updatePredSet)
updatePredSet :: InstIdent -> PredSet -> INCM Bool
updatePredSet i ps = do
......@@ -260,9 +313,9 @@ checkInstance tcEnv clsEnv (InstanceDecl spi cx cls inst _) = do
ps' = Set.fromList [ Pred scls ty | scls <- superClasses ocls clsEnv ]
doc = ppPred m $ Pred cls ty
what = "instance declaration"
ps'' <- reducePredSet p what doc clsEnv ps'
(ps'', _) <- reducePredSet False p what doc clsEnv ps'
Set.mapM_ (report . errMissingInstance m p what doc) $
ps'' `Set.difference` (maxPredSet clsEnv ps)
ps'' `Set.difference` maxPredSet clsEnv ps
where p = spanInfo2Pos spi
checkInstance _ _ _ = ok
......@@ -281,7 +334,8 @@ checkDefaultType p tcEnv clsEnv ty = do
m <- getModuleIdent
let TypeContext _ ty' = expandPolyType m tcEnv clsEnv $
ContextType NoSpanInfo [] ty
ps <- reducePredSet p what empty clsEnv (Set.singleton $ Pred qNumId ty')
(ps, _) <- reducePredSet False p what empty clsEnv
(Set.singleton $ Pred qNumId ty')
Set.mapM_ (report . errMissingInstance m p what empty) ps
where what = "default declaration"
......@@ -291,16 +345,20 @@ checkDefaultType p tcEnv clsEnv ty = do
-- a type variable. An error is reported if the predicate set cannot
-- be transformed into this form. In addition, we remove all predicates
-- that are implied by others within the same set.
-- When the flag is set, all missing Data preds are ignored
reducePredSet :: Position -> String -> Doc -> ClassEnv -> PredSet
-> INCM PredSet
reducePredSet p what doc clsEnv ps = do
reducePredSet :: Bool -> Position -> String -> Doc -> ClassEnv -> PredSet
-> INCM (PredSet, PredSet)
reducePredSet b p what doc clsEnv ps = do
m <- getModuleIdent
inEnv <- getInstEnv
let (ps1, ps2) = partitionPredSet $ minPredSet clsEnv $ reducePreds inEnv ps
Set.mapM_ (report . errMissingInstance m p what doc) ps2
return ps1
ps2' = if b then Set.filter (isNotDataPred m) ps2 else ps2
Set.mapM_ (reportMissing m) ps2' >> return (ps1, ps2)
where
isNotDataPred _ (Pred qid _) = qid /= qDataId
reportMissing m pr@(Pred _ _) =
report $ errMissingInstance m p what doc pr
reducePreds inEnv = Set.concatMap $ reducePred inEnv
reducePred inEnv predicate = maybe (Set.singleton predicate)
(reducePreds inEnv)
......@@ -345,6 +403,15 @@ unqualInstIdent tcEnv (qcls, tc) = (unqual qcls, unqual tc)
where
unqual = head . flip reverseLookupByOrigName tcEnv
isFunType :: Type -> Bool
isFunType (TypeArrow _ _) = True
isFunType (TypeApply t1 t2) = isFunType t1 || isFunType t2
isFunType (TypeForall _ ty) = isFunType ty
isFunType (TypeContext _ ty) = isFunType ty
isFunType (TypeConstructor _) = False
isFunType (TypeVariable _) = False
isFunType (TypeConstrained tys _) = any isFunType tys
-- ---------------------------------------------------------------------------
-- Error messages
-- ---------------------------------------------------------------------------
......
......@@ -361,7 +361,7 @@ bindRecordLabels cs =
bindRecordLabel :: (Ident, [Ident]) -> SCM ()
bindRecordLabel (l, cs) = do
m <- getModuleIdent
new <- (null . lookupVar l) <$> getRenameEnv
new <- null . lookupVar l <$> getRenameEnv
unless new $ report $ errDuplicateDefinition l
modifyRenameEnv $ bindGlobal False m l $
RecordLabel (qualifyWith m l) (map (qualifyWith m) cs)
......@@ -378,7 +378,7 @@ bindFuncDecl tcc m (FunctionDecl _ _ f (eq:_)) env
bindFuncDecl tcc m (TypeSig spi fs (ContextType _ _ ty)) env
= bindFuncDecl tcc m (TypeSig spi fs ty) env
bindFuncDecl tcc m (TypeSig _ fs ty) env
= foldr bindTS env $ map (qualifyWith m) fs
= foldr (bindTS . qualifyWith m) env fs
where
bindTS qf env'
| null $ qualLookupVar qf env'
......@@ -408,8 +408,7 @@ bindVarDecl (FunctionDecl _ _ f eqs) env
| otherwise = let arty = length $ snd $ getFlatLhs $ head eqs
in bindLocal (unRenameIdent f) (LocalVar f arty) env
bindVarDecl (PatternDecl _ t _) env = foldr bindVar env (bv t)
bindVarDecl (FreeDecl _ vs) env =
foldr bindVar env (map varIdent vs)
bindVarDecl (FreeDecl _ vs) env = foldr (bindVar . varIdent) env vs
bindVarDecl _ env = env
bindVar :: Ident -> RenameEnv -> RenameEnv
......@@ -470,7 +469,7 @@ checkFuncPatDeps = do
fps <- getFuncPats
deps <- getGlobalDeps
let levels = scc (:[])
(\k -> Set.toList (Map.findWithDefault (Set.empty) k deps))
(\k -> Set.toList (Map.findWithDefault Set.empty k deps))
(Map.keys deps)
levelMap = Map.fromList [ (f, l) | (fs, l) <- zip levels [1 ..], f <- fs ]
level f = Map.findWithDefault (0 :: Int) f levelMap
......@@ -763,7 +762,7 @@ checkPattern _ (LiteralPattern spi a l) =
checkPattern _ (NegativePattern spi a l) =
return $ NegativePattern spi a l
checkPattern p (VariablePattern spi a v)
| isAnonId v = (VariablePattern spi a . renameIdent v) <$> newId
| isAnonId v = VariablePattern spi a . renameIdent v <$> newId
| otherwise = checkConstructorPattern p spi (qualify v) []
checkPattern p (ConstructorPattern spi _ c ts) =
checkConstructorPattern p spi c ts
......@@ -783,9 +782,9 @@ checkPattern p (LazyPattern spi t) = do
t' <- checkPattern p t
banFPTerm "lazy pattern" p t'
return (LazyPattern spi t')
checkPattern _ (FunctionPattern _ _ _ _) = internalError $
checkPattern _ (FunctionPattern _ _ _ _) = internalError
"SyntaxCheck.checkPattern: function pattern not defined"
checkPattern _ (InfixFuncPattern _ _ _ _ _) = internalError $
checkPattern _ (InfixFuncPattern _ _ _ _ _) = internalError
"SyntaxCheck.checkPattern: infix function pattern not defined"
checkConstructorPattern :: SpanInfo -> SpanInfo -> QualIdent -> [Pattern ()]
......@@ -953,7 +952,7 @@ checkVariable spi a v
-- anonymous free variable
| isAnonId (unqualify v) = do
checkAnonFreeVarsExtension $ getPosition v
(\n -> Variable spi a $ updQualIdent id (flip renameIdent n) v) <$> newId
(\n -> Variable spi a $ updQualIdent id (`renameIdent` n) v) <$> newId
-- return $ Variable v
-- normal variable
| otherwise = do
......@@ -1151,7 +1150,7 @@ recLabels _ = []
-- it is necessary to sort the list of declarations.
sortFuncDecls :: [Decl a] -> [Decl a]
sortFuncDecls decls = sortFD Set.empty [] decls
sortFuncDecls = sortFD Set.empty []
where
sortFD _ res [] = reverse res
sortFD env res (decl : decls') = case decl of
......
......@@ -49,8 +49,8 @@ import Prelude hiding ((<>))
#if __GLASGOW_HASKELL__ < 710
import Control.Applicative ((<$>), (<*>))
#endif
import Control.Monad.Extra (allM, eitherM, filterM, foldM, liftM,
notM, replicateM, unless, unlessM, (&&^))
import Control.Monad.Extra (allM, eitherM, filterM, foldM, (&&^),
notM, replicateM, when, unless, unlessM)
import qualified Control.Monad.State as S (State, gets, modify, runState)
import Data.Foldable (foldrM)
import Data.Function (on)
......@@ -148,7 +148,7 @@ data TcState = TcState
(&&>) :: TCM () -> TCM () -> TCM ()
pre &&> suf = do
errs <- pre >> S.gets errors
if null errs then suf else return ()
when (null errs) suf
(>>-) :: TCM (a, b, c) -> (a -> b -> TCM a) -> TCM (a, c)
m >>- f = do
......@@ -361,8 +361,8 @@ bindLabels' m tcEnv vEnv = foldr (bindData . snd) vEnv $ localBindings tcEnv
n = kindArity k
labels = zip (concatMap recLabels cs) (concatMap recLabelTypes cs)
clabels = [(l, constr l, ty) | (l, ty) <- labels]
constr l = map (qualifyLike tc)
[constrIdent c | c <- cs, l `elem` recLabels c]
constr l = [qualifyLike tc (constrIdent c)
| c <- cs, l `elem` recLabels c]
sameLabel (l1, _, _) (l2, _, _) = l1 == l2
bindData (RenamingType tc k (RecordConstr c [l] [lty])) vEnv'
= bindLabel m n (constrType' tc n) (l, [qc], lty) vEnv'
......@@ -437,7 +437,7 @@ bindTypeSig :: Ident -> TypeExpr -> SigEnv -> SigEnv
bindTypeSig = Map.insert
bindTypeSigs :: Decl a -> SigEnv -> SigEnv
bindTypeSigs (TypeSig _ vs ty) env = foldr (flip bindTypeSig ty) env vs
bindTypeSigs (TypeSig _ vs ty) env = foldr (`bindTypeSig` ty) env vs
bindTypeSigs _ env = env
lookupTypeSig :: Ident -> SigEnv -> Maybe TypeExpr
......@@ -474,7 +474,7 @@ toCheckModeList tys = map Check tys ++ repeat Infer
-- that group as well.
tcDecls :: [Decl a] -> TCM (PredSet, [Decl Type])
tcDecls = liftM (fmap fromPDecls) . tcPDecls . toPDecls
tcDecls = fmap (fmap fromPDecls) . tcPDecls . toPDecls
tcPDecls :: [PDecl a] -> TCM (PredSet, [PDecl Type])
tcPDecls pds = withLocalSigEnv $ do
......@@ -492,8 +492,16 @@ tcPDeclGroup ps [(i, ExternalDecl p fs)] = do
tcPDeclGroup ps [(i, FreeDecl p fvs)] = do
vs <- mapM (tcDeclVar False) (bv fvs)
m <- getModuleIdent
modifyValueEnv $ flip (bindVars m) vs
return (ps, [(i, FreeDecl p (map (\(v, _, TypeForall _ ty) -> Var ty v) vs))])
(vs', ps') <- unzip <$> mapM addDataPred vs
modifyValueEnv $ flip (bindVars m) vs'
let d = FreeDecl p (map (\(v, _, TypeForall _ ty) -> Var ty v) vs')
return (ps `Set.union` Set.unions ps', [(i, d)])
where
addDataPred (idt, n, TypeForall ids ty1) = do
(ps2, ty2) <- freshDataType
ps' <- unify idt "free variable" (ppIdent idt) emptyPredSet ty1 ps2 ty2
return ((idt, n, TypeForall ids ty1), ps')
addDataPred _ = internalError "TypeCheck.addDataPred"
tcPDeclGroup ps pds = do
vEnv <- getValueEnv
vss <- mapM (tcDeclVars . snd) pds
......@@ -503,7 +511,7 @@ tcPDeclGroup ps pds = do
let (impPds, expPds) = partitionPDecls sigs pds
(ps', impPds') <- mapAccumM tcPDecl ps impPds
theta <- getTypeSubst
tvs <- liftM (concatMap $ typeVars . subst theta . fst) $
tvs <- concatMap (typeVars . subst theta . fst) <$>
filterM (notM . isNonExpansive . snd . snd) impPds'
let fvs = foldr Set.insert (fvEnv (subst theta vEnv)) tvs
(gps, lps) = splitPredSet fvs ps'
......@@ -624,7 +632,7 @@ bindPatternVars (Check ty) (VariablePattern _ _ v) = do
bindPatternVars _ (ConstructorPattern _ _ c ps) = do
m <- getModuleIdent
vEnv <- getValueEnv
tys <- (fst . arrowUnapply . snd) <$> inst (constrType m c vEnv)
tys <- fst . arrowUnapply . snd <$> inst (constrType m c vEnv)
mapM_ (uncurry bindPatternVars) $ zip (toCheckModeList tys) ps
bindPatternVars cm (InfixPattern spi a p1 op p2)
= bindPatternVars cm (ConstructorPattern spi a op [p1, p2])
......@@ -640,11 +648,11 @@ bindPatternVars cm (LazyPattern _ p) = bindPatternVars cm p
bindPatternVars _ (FunctionPattern _ _ f ps) = do
m <- getModuleIdent
vEnv <- getValueEnv
tys <- (fst . arrowUnapply . snd) <$> inst (funType m f vEnv)
tys <- fst . arrowUnapply . snd <$> inst (funType m f vEnv)
mapM_ (uncurry bindPatternVars) $ zip (toCheckModeList tys) ps
bindPatternVars cm (InfixFuncPattern spi a p1 op p2)
= bindPatternVars cm (FunctionPattern spi a op [p1, p2])
bindPatternVars _ (RecordPattern _ _ _ fs) = do
bindPatternVars _ (RecordPattern _ _ _ fs) =
mapM_ bindFieldVars fs
bindPatternVars _ _ = ok
......@@ -652,7 +660,7 @@ bindFieldVars :: Field (Pattern a) -> TCM ()
bindFieldVars (Field _ l p) = do
m <- getModuleIdent
vEnv <- getValueEnv
ty <- (arrowBase . snd) <$> inst (labelType m l vEnv)
ty <- arrowBase . snd <$> inst (labelType m l vEnv)
bindPatternVars (Check ty) p
lambdaVar :: Ident -> TCM (Ident, Int, Type)
......@@ -867,7 +875,7 @@ isNonExpansive' n (Typed _ e _) = isNonExpansive' n e
isNonExpansive' _ (Record _ _ c fs) = do
m <- getModuleIdent
vEnv <- getValueEnv
liftM ((length (constrLabels m c vEnv) == length fs) &&) (isNonExpansive fs)
fmap ((length (constrLabels m c vEnv) == length fs) &&) (isNonExpansive fs)
isNonExpansive' _ (Tuple _ es) = isNonExpansive es
isNonExpansive' _ (List _ _ es) = isNonExpansive es
isNonExpansive' n (Apply _ f e) = isNonExpansive' (n + 1) f
......@@ -879,8 +887,8 @@ isNonExpansive' n (LeftSection _ e op) = isNonExpansive' (n + 1) (infixOp op)
&&^ isNonExpansive e
isNonExpansive' n (Lambda _ ts e) = withLocalValueEnv $ do
modifyValueEnv $ flip (foldr bindVarArity) (bv ts)
liftM ((n < length ts) ||) (liftM ((all isVariablePattern ts) &&)
(isNonExpansive' (n - length ts) e))
fmap (((n < length ts) ||) . (all isVariablePattern ts &&))
(isNonExpansive' (n - length ts) e)
isNonExpansive' n (Let _ ds e) = withLocalValueEnv $ do
m <- getModuleIdent
tcEnv <- getTyConsEnv
......@@ -1063,15 +1071,15 @@ tcLiteral :: Bool -> Literal -> TCM (PredSet, Type)
tcLiteral _ (Char _) = return (emptyPredSet, charType)
tcLiteral poly (Int _)
| poly = freshNumType
| otherwise = liftM ((,) emptyPredSet) (freshConstrained numTypes)
| otherwise = fmap ((,) emptyPredSet) (freshConstrained numTypes)
tcLiteral poly (Float _)
| poly = freshFractionalType
| otherwise = liftM ((,) emptyPredSet) (freshConstrained fractionalTypes)
| otherwise = fmap ((,) emptyPredSet) (freshConstrained fractionalTypes)
tcLiteral _ (String _) = return (emptyPredSet, stringType)
tcLhs :: HasPosition p => p -> Lhs a -> TCM (PredSet, [Type], Lhs Type)
tcLhs p (FunLhs spi f ts) = do
(pss, tys, ts') <- liftM unzip3 $ mapM (tcPattern p) ts
(pss, tys, ts') <- unzip3 <$> mapM (tcPattern p) ts
return (Set.unions pss, tys, FunLhs spi f ts')
tcLhs p (OpLhs spi t1 op t2) = do
(ps1, ty1, t1') <- tcPattern p t1
......@@ -1079,7 +1087,7 @@ tcLhs p (OpLhs spi t1 op t2) = do
return (ps1 `Set.union` ps2, [ty1, ty2], OpLhs spi t1' op t2')
tcLhs p (ApLhs spi lhs ts) = do
(ps, tys1, lhs') <- tcLhs p lhs
(pss, tys2, ts') <- liftM unzip3 $ mapM (tcPattern p) ts
(pss, tys2, ts') <- unzip3 <$> mapM (tcPattern p) ts
return (Set.unions (ps:pss), tys1 ++ tys2, ApLhs spi lhs' ts')
-- When computing the type of a variable in a pattern, we ignore the
......@@ -1104,7 +1112,7 @@ tcPattern _ (VariablePattern spi _ v) = do
tcPattern p t@(ConstructorPattern spi _ c ts) = do
m <- getModuleIdent
vEnv <- getValueEnv
(ps, (tys, ty')) <- liftM (fmap arrowUnapply) (inst (constrType m c vEnv))
(ps, (tys, ty')) <- fmap (fmap arrowUnapply) (inst (constrType m c vEnv))
(ps', ts') <- mapAccumM (uncurry . tcPatternArg p "pattern" (pPrintPrec 0 t))
ps (zip tys ts)
return (ps', ty', ConstructorPattern spi ty' c ts')
......@@ -1118,12 +1126,12 @@ tcPattern p (ParenPattern spi t) = do
tcPattern _ t@(RecordPattern spi _ c fs) = do
m <- getModuleIdent
vEnv <- getValueEnv
(ps, ty) <- liftM (fmap arrowBase) (inst (constrType m c vEnv))
(ps, ty) <- fmap (fmap arrowBase) (inst (constrType m c vEnv))
(ps', fs') <- mapAccumM (tcField tcPattern "pattern"
(\t' -> pPrintPrec 0 t $-$ text "Term:" <+> pPrintPrec 0 t') ty) ps fs
return (ps', ty, RecordPattern spi ty c fs')
tcPattern p (TuplePattern spi ts) = do
(pss, tys, ts') <- liftM unzip3 $ mapM (tcPattern p) ts
(pss, tys, ts') <- unzip3 <$> mapM (tcPattern p) ts
return (Set.unions pss, tupleType tys, TuplePattern spi ts')
tcPattern p t@(ListPattern spi _ ts) = do
ty <- freshTypeVar
......@@ -1225,7 +1233,7 @@ tcExpr _ p (Typed spi e qty) = do
tcExpr _ _ e@(Record spi _ c fs) = do
m <- getModuleIdent
vEnv <- getValueEnv
(ps, ty) <- liftM (fmap arrowBase) (inst (constrType m c vEnv))
(ps, ty) <- fmap (fmap arrowBase) (inst (constrType m c vEnv))
(ps', fs') <- mapAccumM (tcField (tcExpr Infer) "construction"
(\e' -> pPrintPrec 0 e $-$ text "Term:" <+> pPrintPrec 0 e') ty) ps fs
return (ps', ty, Record spi ty c fs')
......@@ -1235,7 +1243,7 @@ tcExpr _ p e@(RecordUpdate spi e1 fs) = do
(\e' -> pPrintPrec 0 e $-$ text "Term:" <+> pPrintPrec 0 e') ty) ps fs
return (ps', ty, RecordUpdate spi e1' fs')
tcExpr _ p (Tuple spi es) = do
(pss, tys, es') <- liftM unzip3 $ mapM (tcExpr Infer p) es
(pss, tys, es') <- unzip3 <$> mapM (tcExpr Infer p) es
return (Set.unions pss, tupleType tys, Tuple spi es')
tcExpr _ p e@(List spi _ es) = do
ty <- freshTypeVar
......@@ -1302,7 +1310,7 @@ tcExpr cm p (Lambda spi ts e) = do
Infer -> toCheckModeList []
Check ty -> toCheckModeList $ fst $ arrowUnapply ty
mapM_ (uncurry bindPatternVars) $ zip cmList ts
(pss, tys, ts') <- liftM unzip3 $ mapM (tcPattern p) ts
(pss, tys, ts') <- unzip3 <$> mapM (tcPattern p) ts
(ps, ty, e') <- tcExpr Infer p e
return (pss, tys, ts', ps, ty, e')
ps' <- reducePredSet p "expression" (pPrintPrec 0 e') (Set.unions $ ps : pss)
......@@ -1318,7 +1326,7 @@ tcExpr _ p (Do spi sts e) = do
(sts', ty, ps', e') <- withLocalValueEnv $ do
((ps, mTy), sts') <-
mapAccumM (uncurry (tcStmt p)) (emptyPredSet, Nothing) sts
ty <- liftM (maybe id TypeApply mTy) freshTypeVar
ty <- fmap (maybe id TypeApply mTy) freshTypeVar
(ps', e') <- tcExpr Infer p e >>- unify p "statement" (pPrintPrec 0 e) ps ty
return (sts', ty, ps', e')
return (ps', ty, Do spi sts' e')
......@@ -1526,14 +1534,14 @@ unifyTypes _ (TypeConstrained tys1 tv1) ty@(TypeConstrained tys2 tv2)
| tv1 == tv2 = return $ Right idSubst
| tys1 == tys2 = return $ Right (singleSubst tv1 ty)
unifyTypes m (TypeConstrained tys tv) ty
= foldrM (\ty' s -> liftM (`choose` s) (unifyTypes m ty ty'))
= foldrM (\ty' s -> fmap (`choose` s) (unifyTypes m ty ty'))
(Left (errIncompatibleTypes m ty (head tys)))
tys
where
choose (Left _) theta' = theta'
choose (Right theta) _ = Right (bindSubst tv ty theta)
unifyTypes m ty (TypeConstrained tys tv)
= foldrM (\ty' s -> liftM (`choose` s) (unifyTypes m ty ty'))
= foldrM (\ty' s -> fmap (`choose` s) (unifyTypes m ty ty'))
(Left (errIncompatibleTypes m ty (head tys)))
tys
where
......@@ -1557,16 +1565,16 @@ unifyTypes m ty1@(TypeForall _ _) ty2@(TypeForall _ _)
Left x -> return $ Left x
Right s -> do
let (_, tys) = unzip $ substToList $ restrictSubstTo (vs1 ++ vs2) s
case all isVarType tys of