diff --git a/src/Checks.hs b/src/Checks.hs index 4671ff592eed4af0eed0e1e9fe329ae56988c01a..477d24a8dad9253cec3b8b41521e2ecbad6ea0fe 100644 --- a/src/Checks.hs +++ b/src/Checks.hs @@ -148,7 +148,8 @@ typeCheck :: Monad m => Options -> CompEnv (Module a) typeCheck _ (env, Module spi ps m es is ds) | null msgs = ok (env { valueEnv = vEnv' }, Module spi ps m es is ds') | otherwise = failMessages msgs - where (ds', vEnv', msgs) = TC.typeCheck (moduleIdent env) (tyConsEnv env) + where (ds', vEnv', msgs) = TC.typeCheck (extensions env) + (moduleIdent env) (tyConsEnv env) (valueEnv env) (classEnv env) (instEnv env) ds diff --git a/src/Checks/TypeCheck.hs b/src/Checks/TypeCheck.hs index 460c071bcaeff119f6d53658d3db494aaca300cf..de66c519e8c32e59143266890a5318947d70f0fd 100644 --- a/src/Checks/TypeCheck.hs +++ b/src/Checks/TypeCheck.hs @@ -78,8 +78,8 @@ import Base.CurryTypes import Base.Expr import Base.Kinds import Base.Messages (Message, internalError, posMessage) +import Base.NestEnv import Base.SCC -import Base.TopEnv import Base.TypeExpansion import Base.Types import Base.TypeSubst @@ -95,12 +95,13 @@ import Env.Value -- constructors, field labels and class methods are entered into the value -- environment and then a type inference for all function and value definitions -- is performed. -typeCheck :: ModuleIdent -> TCEnv -> ValueEnv -> ClassEnv -> InstEnv -> [Decl a] - -> ([Decl Type], ValueEnv, [Message]) -typeCheck m tcEnv vEnv clsEnv inEnv ds = runTCM (checkDecls ds) initState +typeCheck :: [KnownExtension] -> ModuleIdent -> TCEnv -> ValueEnv -> ClassEnv + -> InstEnv -> [Decl a] -> ([Decl Type], ValueEnv, [Message]) +typeCheck exts m tcEnv vEnv clsEnv inEnv ds = runTCM (checkDecls ds) initState where - initState = TcState m tcEnv vEnv clsEnv (inEnv, Map.empty) - [intType, floatType] idSubst emptySigEnv 1 [] [] + initState = TcState m exts tcEnv vEnv clsEnv (inEnv, Map.empty) + [intType, floatType] idSubst emptySigEnv + emptyScopedTyVarsEnv 1 [] [] checkDecls :: [Decl a] -> TCM [Decl Type] checkDecls ds = do @@ -136,18 +137,20 @@ type TCM = S.State TcState type InstEnv' = (InstEnv, Map.Map QualIdent [Type]) data TcState = TcState - { moduleIdent :: ModuleIdent -- read-only - , tyConsEnv :: TCEnv - , valueEnv :: ValueEnv - , classEnv :: ClassEnv - , instEnv :: InstEnv' -- instances (static and dynamic) - , defaultTypes :: [Type] - , typeSubst :: TypeSubst - , sigEnv :: SigEnv - , nextId :: Int -- automatic counter - , errors :: [Message] - , impVars :: [Int] -- type variables that can be instantiated with - -- higher-rank types when necessary + { moduleIdent :: ModuleIdent -- read-only + , extensions :: [KnownExtension] + , tyConsEnv :: TCEnv + , valueEnv :: ValueEnv + , classEnv :: ClassEnv + , instEnv :: InstEnv' -- instances (static and dynamic) + , defaultTypes :: [Type] + , typeSubst :: TypeSubst + , sigEnv :: SigEnv + , scopedTyVarsEnv :: ScopedTyVarsEnv + , nextId :: Int -- automatic counter + , errors :: [Message] + , impVars :: [Int] -- type variables that can be instantiated + -- with higher-rank types when necessary } (&&>) :: TCM () -> TCM () -> TCM () @@ -174,6 +177,9 @@ runTCM tcm s = let (a, s') = S.runState tcm s getModuleIdent :: TCM ModuleIdent getModuleIdent = S.gets moduleIdent +hasExtension :: KnownExtension -> TCM Bool +hasExtension ext = S.gets (elem ext . extensions) + getTyConsEnv :: TCM TCEnv getTyConsEnv = S.gets tyConsEnv @@ -221,6 +227,40 @@ withLocalSigEnv act = do setSigEnv oldSigs return res +-- | This environment contains the scoped type variables of a type signature +-- with its corresponding freshly instantiated type variable. +type ScopedTyVarsEnv = NestEnv Int + +emptyScopedTyVarsEnv :: ScopedTyVarsEnv +emptyScopedTyVarsEnv = globalEnv emptyTopEnv + +-- | Retrieve the 'ScopedTyVarsEnv'. +getScopedTyVarsEnv :: TCM ScopedTyVarsEnv +getScopedTyVarsEnv = S.gets scopedTyVarsEnv + +-- | Modify the 'ScopedTyVarsEnv'. +modifyScopedTyVarsEnv :: (ScopedTyVarsEnv -> ScopedTyVarsEnv) -> TCM () +modifyScopedTyVarsEnv f + = S.modify $ \s -> s { scopedTyVarsEnv = f $ scopedTyVarsEnv s } + +-- | Increase the nesting of the 'ScopedTyVarsEnv' to introduce a new local +-- scope. +incScopedTyVarsNesting :: TCM () +incScopedTyVarsNesting = modifyScopedTyVarsEnv nestEnv + +withLocalScopedTyVarsEnv :: TCM a -> TCM a +withLocalScopedTyVarsEnv act = do + oldEnv <- getScopedTyVarsEnv + res <- act + modifyScopedTyVarsEnv $ const oldEnv + return res + +-- | Performs an action in a nested scope (by creating a nested +-- 'ScopedTyVarsEnv') and discard the nested 'ScopedTyVarsEnv' afterwards. +inNestedScopedTyVarsScope :: TCM a -> TCM a +inNestedScopedTyVarsScope act + = withLocalScopedTyVarsEnv (incScopedTyVarsNesting >> act) + getNextId :: TCM Int getNextId = do n <- S.gets nextId @@ -551,6 +591,36 @@ bindVars m = foldr $ uncurry3 $ flip (bindFun m) False rebindVars :: ModuleIdent -> ValueEnv -> [(Ident, Int, Type)] -> ValueEnv rebindVars m = foldr $ uncurry3 $ flip (rebindFun m) False +scopeType :: TypeExpr -> Type -> TCM Type +scopeType tyExpr ty = do + scoped <- hasExtension ScopedTypeVariables + if scoped + then do + env <- getScopedTyVarsEnv + return $ scope env (zip (typeVars ty) (typeExprVars tyExpr)) ty + else return ty + where + scope env vids t@(TypeVariable v) = case lookup v vids of + Nothing -> t + Just i -> case lookupNestEnv i env of + [] -> t + x:_ -> TypeVariable x + scope env vids (TypeApply t1 t2) = TypeApply (scope env vids t1) (scope env vids t2) + scope env vids (TypeArrow t1 t2) = TypeArrow (scope env vids t1) (scope env vids t2) + scope env vids (TypeContext ps t) = TypeContext ps (scope env vids t) + scope env vids (TypeForall tvs t) = TypeForall tvs $ scope env vids t + scope _ _ t = t + + typeExprVars (ConstructorType _ _) = [] + typeExprVars (ApplyType _ ty1 ty2) = typeExprVars ty1 ++ typeExprVars ty2 + typeExprVars (VariableType _ tv) = [mkIdent $ idName tv] + typeExprVars (TupleType _ tys) = concatMap typeExprVars tys + typeExprVars (ListType _ ty1) = typeExprVars ty1 + typeExprVars (ArrowType _ ty1 ty2) = typeExprVars ty1 ++ typeExprVars ty2 + typeExprVars (ParenType _ ty1) = typeExprVars ty1 + typeExprVars (ContextType _ _ ty1) = typeExprVars ty1 + typeExprVars (ForallType _ vs ty1) = filter (not . (`elem` (map (mkIdent . idName) vs))) $ typeExprVars ty1 + tcDeclVars :: Decl a -> TCM [(Ident, Int, Type)] tcDeclVars (FunctionDecl _ _ f eqs) = do sigs <- getSigEnv @@ -558,7 +628,8 @@ tcDeclVars (FunctionDecl _ _ f eqs) = do case lookupTypeSig f sigs of Just ty -> do ty' <- expandTypeExpr ty - return [(f, n, polyType ty')] + ty'' <- scopeType ty ty' + return [(f, n, polyType ty'')] Nothing -> do tys <- replicateM (n + 1) freshTypeVar return [(f, n, monoType $ foldr1 TypeArrow tys)] @@ -572,7 +643,8 @@ tcDeclVar poly v = do sigs <- getSigEnv case lookupTypeSig v sigs of Just ty | poly || null (fv ty) -> do ty' <- expandTypeExpr ty - return (v, 0, polyType ty') + ty'' <- scopeType ty ty' + return (v, 0, polyType ty'') | otherwise -> do report $ errPolymorphicVar v lambdaVar v Nothing -> lambdaVar v @@ -595,9 +667,17 @@ tcPDecl _ _ = internalError "TypeCheck.tcPDecl" tcFunctionPDecl :: Int -> PredSet -> Type -> SpanInfo -> Ident -> [Equation a] -> TCM (PredSet, (Type, PDecl Type)) tcFunctionPDecl i ps tySc p f eqs = do - ty <- snd <$> inst tySc - (ps', eqs') <- mapAccumM (tcEquation ty) ps eqs - return (ps', (ty, (i, FunctionDecl p (rawPredType tySc) f eqs'))) + sigs <- getSigEnv + scoped <- hasExtension ScopedTypeVariables + let idVars = case lookupTypeSig f sigs of + Just (ForallType _ ids _) | scoped -> map (mkIdent . idName) ids + _ -> [] + (vs, _, ty) <- skolemise tySc + let varMap = zipWith (\ident (_, b) -> (ident, b)) idVars vs + inNestedScopedTyVarsScope $ do + modifyScopedTyVarsEnv $ \env -> foldr (uncurry bindNestEnv) env varMap + (ps', eqs') <- mapAccumM (tcEquation ty) ps eqs + return (ps', (ty, (i, FunctionDecl p (rawPredType tySc) f eqs'))) tcEquation :: Type -> PredSet -> Equation a -> TCM (PredSet, Equation Type) tcEquation ty ps eqn@(Equation p lhs rhs) = @@ -755,14 +835,16 @@ tcCheckPDecl ps tySc pd = do checkPDeclType :: TypeExpr -> PredSet -> Type -> PDecl Type -> TCM (PredSet, PDecl Type) checkPDeclType tySc ps ty (i, FunctionDecl p _ f eqs) = do - tySc' <- expandTypeExpr tySc + tySc'' <- expandTypeExpr tySc + tySc' <- scopeType tySc tySc'' unlessM (checkTypeSig tySc' ty) $ do m <- getModuleIdent report $ errTypeSigTooGeneral p m (text "Function:" <+> ppIdent f) tySc (rawPredType ty) return (ps, (i, FunctionDecl p tySc' f eqs)) checkPDeclType tySc ps ty (i, PatternDecl p (VariablePattern spi _ v) rhs) = do - tySc' <- expandTypeExpr tySc + tySc'' <- expandTypeExpr tySc + tySc' <- scopeType tySc tySc'' unlessM (checkTypeSig tySc' ty) $ do m <- getModuleIdent report $ errTypeSigTooGeneral p m (text "Variable:" <+> ppIdent v) tySc @@ -797,9 +879,9 @@ eqTypes fvs = eq idSubst eq sub (TypeConstructor tc1) (TypeConstructor tc2) = return (tc1 == tc2, sub, emptyPredSet, emptyPredSet) eq sub (TypeVariable tv1) (TypeVariable tv2) - | tv1 `elem` fvs = return (False, sub, emptyPredSet, emptyPredSet) - | otherwise = do let (eqb, sub') = eqVar sub tv1 tv2 - return (eqb, sub', emptyPredSet, emptyPredSet) + | tv1 `elem` fvs && tv2 >= 0 = return (False, sub, emptyPredSet, emptyPredSet) + | otherwise = do let (eqb, sub') = eqVar sub tv1 tv2 + return (eqb, sub', emptyPredSet, emptyPredSet) eq sub (TypeConstrained ts1 tv1) (TypeConstrained ts2 tv2) = do (eqb1, sub1, ps1, ps2) <- eqs sub ts1 ts2 let (eqb2, sub2) = eqVar sub1 tv1 tv2 @@ -1257,20 +1339,28 @@ tcExpr cm p (Paren spi e) = do (ps, ty, e') <- tcExpr cm p e return (ps, ty, Paren spi e') tcExpr _ p (Typed spi e qty) = do - pty <- expandTypeExpr qty - (ps, ty) <- inst (polyType pty) - (ps', e') <- tcExpr (Check ty) p e >>- - unifyDecl p "explicitly typed expression" (pPrintPrec 0 e) emptyPredSet ty - fvs <- computeFvEnv - theta <- getTypeSubst - let (gps, lps) = splitPredSet fvs ps' - tySc = gen fvs (TypeContext lps (subst theta ty)) - unlessM (checkTypeSig pty tySc) $ do - m <- getModuleIdent - report $ - errTypeSigTooGeneral p m (text "Expression:" <+> pPrintPrec 0 e) qty - (rawPredType tySc) - return (ps `Set.union` gps, ty, Typed spi e' qty) + scoped <- hasExtension ScopedTypeVariables + pty' <- expandTypeExpr qty + let idVars = case qty of + ForallType _ ids _ | scoped -> map (mkIdent . idName) ids + _ -> [] + pty <- scopeType qty pty' + (vs, ps, ty) <- skolemise (polyType pty) + let varMap = zipWith (\i (_, b) -> (i, b)) idVars vs + inNestedScopedTyVarsScope $ do + modifyScopedTyVarsEnv $ \env -> foldr (uncurry bindNestEnv) env varMap + (ps', e') <- tcExpr (Check ty) p e >>- + unifyDecl p "explicitly typed expression" (pPrintPrec 0 e) emptyPredSet ty + fvs <- computeFvEnv + theta <- getTypeSubst + let (gps, lps) = splitPredSet fvs ps' + tySc = gen fvs (TypeContext lps (subst theta ty)) + unlessM (checkTypeSig pty tySc) $ do + m <- getModuleIdent + report $ + errTypeSigTooGeneral p m (text "Expression:" <+> pPrintPrec 0 e) qty + (rawPredType tySc) + return (ps `Set.union` gps, ty, Typed spi e' qty) tcExpr _ p e@(Record spi _ c fs) = do m <- getModuleIdent vEnv <- getValueEnv @@ -1666,8 +1756,10 @@ unifyTypes m (TypeArrow ty11 ty12) ty@(TypeApply _ _) unifyTypes m (TypeArrow ty11 ty12) (TypeArrow ty21 ty22) = unifyTypeLists m [ty11, ty12] [ty21, ty22] unifyTypes m ty1@(TypeForall _ _) ty2@(TypeForall _ _) - = do (vs1, _, ty1') <- skolemise ty1 - (vs2, _, ty2') <- skolemise ty2 + = do (vs1p, _, ty1') <- skolemise ty1 + (vs2p, _, ty2') <- skolemise ty2 + let vs1 = map snd vs1p + let vs2 = map snd vs2p res <- unifyTypes m ty1' ty2' case res of Left x -> return $ Left x @@ -1681,7 +1773,8 @@ unifyTypes m ty1@(TypeForall _ _) ty2@(TypeForall _ _) [] -> return $ Right s ev:_ -> return $ Left $ errEscapingTypeVariable m ev ty1 ty2 unifyTypes m ty1@(TypeForall _ _) ty2 - = do (vs, _, ty1') <- skolemise ty1 + = do (vsp, _, ty1') <- skolemise ty1 + let vs = map snd vsp res <- unifyTypes m ty1' ty2 case res of Left x -> return $ Left x @@ -1694,7 +1787,8 @@ unifyTypes m ty1@(TypeForall _ _) ty2 [] -> return $ Right s ev:_ -> return $ Left $ errEscapingTypeVariable m ev ty1 ty2 unifyTypes m ty1 ty2@(TypeForall _ _) - = do (vs, _, ty2') <- skolemise ty2 + = do (vsp, _, ty2') <- skolemise ty2 + let vs = map snd vsp res <- unifyTypes m ty1 ty2' case res of Left x -> return $ Left x @@ -1900,10 +1994,10 @@ inst ty = skolemise ty >>= \(_, ps, ty') -> return (ps, ty') -- | Instantiates the given type with fresh type variables. The first argument -- of the triple is the list of fresh type variables. -skolemise :: Type -> TCM ([Int], PredSet, Type) +skolemise :: Type -> TCM ([(Int, Int)], PredSet, Type) skolemise (TypeForall tvs ty) = do tys <- replicateM (length tvs) freshTypeVar - let tvs' = map (\(TypeVariable tv) -> tv) tys + let tvs' = zip tvs $ map (\(TypeVariable tv) -> tv) tys (tvs'', ps, ty') <- skolemise $ subst (foldr2 bindSubst idSubst tvs tys) ty return (tvs' ++ tvs'', ps, ty') skolemise (TypeContext ps ty) = do diff --git a/src/CompilerOpts.hs b/src/CompilerOpts.hs index 7a5c2cb2f9c63d8d5363b7be2bad3a381a21d87f..27e7f41bc366d9db3af4ea8ef573f0ef05f62cbb 100644 --- a/src/CompilerOpts.hs +++ b/src/CompilerOpts.hs @@ -327,6 +327,8 @@ extensions = , "enable arbitrary-rank types" ) , ( ExplicitForAll , "ExplicitForAll" , "enable explicit foralls" ) + , ( ScopedTypeVariables , "ScopedTypeVariables" + , "enable scoped type variables" ) ] -- ----------------------------------------------------------------------------- diff --git a/test/TestFrontend.hs b/test/TestFrontend.hs index db9ec42b2e1a5c566f29d61b4f608ff0058c96d7..5ba49cebacbadcbd6ea6535a71cd9bdb7038c9ac 100644 --- a/test/TestFrontend.hs +++ b/test/TestFrontend.hs @@ -273,6 +273,13 @@ failInfos = map (uncurry mkFailTest) , ("RankNTypesFuncPats", ["Missing instance for Prelude.Data (forall c. Prelude.Int ->"]) , ("RecordLabelIDs", ["Multiple declarations of `RecordLabelIDs.id'"]) , ("RecursiveTypeSyn", ["Mutually recursive synonym and/or renaming types A and B (line 12.6)"]) + , ("ScopedTypeVariables", + [ "Type signature too general" + , "Function: fun1" + , "Expression: x" + , "in (g x, g y))" + ] + ) , ("Subsumption", [ "Type error in application" , "applyFun idFun" @@ -360,6 +367,7 @@ passInfos = map mkPassTest , "RecordTest2" , "RecordTest3" , "ReexportTest" + , "ScopedTypeVariables" , "ScottEncoding" , "SelfExport" , "SpaceLeak" diff --git a/test/fail/ScopedTypeVariables.curry b/test/fail/ScopedTypeVariables.curry new file mode 100644 index 0000000000000000000000000000000000000000..744c766b2fb7a8dd42d0f48aaf364aa25e89934b --- /dev/null +++ b/test/fail/ScopedTypeVariables.curry @@ -0,0 +1,23 @@ +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} + +fun1 :: forall a b. a -> b -> (a, b) +fun1 a b = (idFun a, b) + where + idFun :: b -> b + idFun = id + +fun2 :: forall a. forall b. a -> [b] -> [b] +fun2 _ (x:xs) = xs ++ [x :: b] + +fun3 :: forall a. a -> forall b. [b] -> [b] +fun3 _ (x:xs) = xs ++ [x :: b] + +type A = forall b. [b] -> [b] + +fun4 :: A +fun4 (x:xs) = xs ++ [x :: b] + +fun5 = (\x y -> let g :: a -> a + g = id + in (g x, g y)) :: forall a b. a -> b -> (a, b) diff --git a/test/pass/ScopedTypeVariables.curry b/test/pass/ScopedTypeVariables.curry new file mode 100644 index 0000000000000000000000000000000000000000..c80f6f07d2819f56cda8eeba59dd01962bc7d099 --- /dev/null +++ b/test/pass/ScopedTypeVariables.curry @@ -0,0 +1,41 @@ +{-# LANGUAGE ExplicitForAll #-} +{-# LANGUAGE ScopedTypeVariables #-} + +fun1 :: forall a. [a] -> [a] +fun1 xs = ys ++ ys + where + ys :: [a] + ys = reverse xs + +fun2 :: forall a. [a] -> [a] +fun2 (x:xs) = xs ++ [x :: a] + +fun3 :: forall a. [a] -> [a] +fun3 = \(x:xs) -> xs ++ [x :: a] + +fun4 :: a -> b -> (a, b) +fun4 x y = (idFun x, idFun y) + where + idFun :: a -> a + idFun = id + +fun5 = (\x y -> let g :: a -> a + g = id + in (g x, y)) :: forall a b. a -> b -> (a, b) + +class A a where + funA :: [a] -> a + funA xs = let ys :: [a] + ys = reverse xs + in head ys + +-- instance A b => A [b] where +-- funA xs = reverse (head (xs :: [[b]])) + +class B a where + funB :: a -> a + +instance B [b] where + funB xs = let ys :: [b] + ys = reverse xs + in ys ++ ys