TypeCheck.hs 47 KB
Newer Older
1 2 3 4 5
{- |
    Module      :  $Header$
    Description :  Type checking Curry programs
    Copyright   :  (c) 1999 - 2004 Wolfgang Lux
                                   Martin Engelke
                       2011 - 2015 Björn Peemöller
                       2014 - 2015 Jan Tikovsky
    License     :  OtherLicense

    Maintainer  :  bjp@informatik.uni-kiel.de
    Stability   :  experimental
    Portability :  portable

   This module implements the type checker of the Curry compiler. The
   type checker is invoked after the syntactic correctness of the program
   has been verified. Local variables have been renamed already. Thus the
   compiler can maintain a flat type environment (which is necessary in
   order to pass the type information to later phases of the compiler).
   The type checker now checks the correct typing of all expressions and
   also verifies that the type signatures given by the user match the
   inferred types. The type checker uses the algorithm by Damas and Milner
   (1982) for inferring the types of unannotated declarations, but allows
   for polymorphic recursion when a type annotation is present.
-}
25
{-# LANGUAGE CPP #-}
26 27
module Checks.TypeCheck (typeCheck) where

28
#if __GLASGOW_HASKELL__ < 710
29 30 31
import           Control.Applicative        ((<$>), (<*>))
#endif
import           Control.Monad              (replicateM, unless)
32
import qualified Control.Monad.State as S   (State, execState, gets, modify)
33
import           Data.List                  (nub, nubBy, partition)
34
import qualified Data.Map            as Map (Map, delete, empty, insert, lookup)
35
import           Data.Maybe                 (fromMaybe)
36 37
import qualified Data.Set            as Set
  (Set, fromList, member, notMember, unions)
38 39 40 41 42 43 44

import Curry.Base.Ident
import Curry.Base.Position
import Curry.Base.Pretty
import Curry.Syntax
import Curry.Syntax.Pretty

45
import Base.CurryTypes (toType, toTypes, ppType, ppTypeScheme)
46 47 48 49 50 51 52 53
import Base.Expr
import Base.Messages (Message, posMessage, internalError)
import Base.SCC
import Base.TopEnv
import Base.Types
import Base.TypeSubst
import Base.Utils (foldr2)

54
import Env.TypeConstructor (TCEnv, TypeInfo (..), bindTypeInfo, qualLookupTC)
55
import Env.Value ( ValueEnv, ValueInfo (..), bindFun, rebindFun
56
  , bindGlobalInfo, lookupValue, qualLookupValue )
57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74

infixl 5 $-$

-- | Vertical composition of two documents with a 'space' document placed
-- between them.
($-$) :: Doc -> Doc -> Doc
upper $-$ lower = upper $$ space $$ lower

-- Type checking proceeds as follows. First, the type constructor
-- environment is initialized by adding all types defined in the current
-- module. Next, the types of all data constructors and field labels
-- are entered into the type environment and then a type inference
-- for all function and value definitions is performed.
-- The type checker returns the resulting type constructor and type
-- environments.

-- | Type check the declarations of module @m@ starting from the given type
-- constructor and value environments. The type synonyms are checked first;
-- the remaining steps are only run when no error has been reported up to
-- that point (see '&&>'). The result is the triple produced by 'execTCM'.
typeCheck :: ModuleIdent -> TCEnv -> ValueEnv -> [Decl]
          -> (TCEnv, ValueEnv, [Message])
typeCheck m tcEnv tyEnv decls = execTCM (synonymCheck &&> declCheck) startState
  where
  -- type declarations are handled separately from all other declarations
  (typeDecls, valueDecls) = partition isTypeDecl decls
  startState   = TcState
    { moduleIdent = m
    , tyConsEnv   = tcEnv
    , valueEnv    = tyEnv
    , typeSubst   = idSubst
    , sigEnv      = emptySigEnv
    , nextId      = 0
    , errors      = []
    }
  synonymCheck = checkTypeSynonyms m typeDecls
  declCheck    = do
    bindTypes typeDecls
    bindConstrs
    -- field labels are only bound when no error has been reported so far
    mapM_ checkFieldLabel typeDecls &&> bindLabels
    tcDecls valueDecls
83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142

-- The type checker makes use of a state monad in order to maintain the type
-- environment, the current substitution, and a counter which is used for
-- generating fresh type variables.

-- | The state threaded through the type checking monad 'TCM'.
data TcState = TcState
  { moduleIdent :: ModuleIdent -- read only
  , tyConsEnv   :: TCEnv       -- type constructor environment
  , valueEnv    :: ValueEnv    -- value (type) environment
  , typeSubst   :: TypeSubst   -- current substitution; 'execTCM' applies it
                               -- to the final value environment
  , sigEnv      :: SigEnv      -- type signatures currently in scope
  , nextId      :: Int         -- automatic counter, advanced by 'getNextId'
  , errors      :: [Message]   -- reported errors, newest first ('report'
                               -- prepends, 'execTCM' reverses)
  }

-- | The type checking monad: a state monad over 'TcState'.
type TCM = S.State TcState

-- | Read the identifier of the module being checked.
getModuleIdent :: TCM ModuleIdent
getModuleIdent = S.gets moduleIdent

-- | Read the current type constructor environment.
getTyConsEnv :: TCM TCEnv
getTyConsEnv = S.gets tyConsEnv

-- | Replace the type constructor environment.
setTyConsEnv :: TCEnv -> TCM ()
setTyConsEnv tcEnv = S.modify install
  where install st = st { tyConsEnv = tcEnv }

-- | Read the current value environment.
getValueEnv :: TCM ValueEnv
getValueEnv = S.gets valueEnv

-- | Apply a function to the value environment.
modifyValueEnv :: (ValueEnv -> ValueEnv) -> TCM ()
modifyValueEnv f = S.modify update
  where update st = st { valueEnv = f (valueEnv st) }

-- | Read the current type substitution.
getTypeSubst :: TCM TypeSubst
getTypeSubst = S.gets typeSubst

-- | Apply a function to the type substitution.
modifyTypeSubst :: (TypeSubst -> TypeSubst) -> TCM ()
modifyTypeSubst f = S.modify update
  where update st = st { typeSubst = f (typeSubst st) }

-- | Read the current signature environment.
getSigEnv :: TCM SigEnv
getSigEnv = S.gets sigEnv

-- | Apply a function to the signature environment.
modifySigEnv :: (SigEnv -> SigEnv) -> TCM ()
modifySigEnv f = S.modify update
  where update st = st { sigEnv = f (sigEnv st) }

-- | Return the current value of the counter and store its successor.
getNextId :: TCM Int
getNextId = do
  current <- S.gets nextId
  S.modify (\ st -> st { nextId = succ current })
  return current

-- | Record an error message. Messages are prepended, so the list in the
-- state holds the most recent message first.
report :: Message -> TCM ()
report err = S.modify push
  where push st = st { errors = err : errors st }

-- | Run the first action and continue with the second one only if the
-- error list is empty afterwards. Errors reported before the first action
-- was started also suppress the second action.
(&&>) :: TCM () -> TCM () -> TCM ()
pre &&> suf = pre >> S.gets errors >>= continue
  where
  continue []      = suf
  continue (_ : _) = return ()

-- | Run a type checker action on an initial state and extract the final
-- type constructor environment, the final value environment with the final
-- substitution applied to it, and the reported messages in the order in
-- which they were reported.
execTCM :: TCM a -> TcState -> (TCEnv, ValueEnv, [Message])
execTCM tcm s =
  ( tyConsEnv final
  , subst (typeSubst final) (valueEnv final)
  , reverse (errors final)
  )
  where final = S.execState tcm s

-- Defining Types:
-- Before type checking starts, the types defined in the local module
-- have to be entered into the type constructor environment. All type
-- synonyms occurring in the definitions are fully expanded (except for
-- record types) and all type constructors are qualified with the name
-- of the module in which they are defined. This is possible because
-- Curry does not allow (mutually) recursive type synonyms.
-- In order to simplify the expansion of type synonyms, the compiler
-- first performs a dependency analysis on the type definitions.
-- This also makes it easy to identify (mutually) recursive synonyms.

-- Note that 'bindTC' is passed the final type constructor environment in
-- order to handle the expansion of type synonyms. This does not lead to a
-- termination problem because 'checkTypeDecls' has already checked that
-- there are no recursive type synonyms.

-- We have to be careful with existentially quantified type variables for
-- data constructors. An existentially quantified type variable may
-- shadow a universally quantified variable from the left hand side of
-- the type declaration. In order to avoid wrong indices being assigned
-- to these variables, we replace all shadowed variables in the left hand
-- side by anonId before passing them to 'expandMonoType' and
-- 'expandMonoTypes', respectively.

-- | Split the type declarations into dependency groups with 'scc' and pass
-- every group to 'checkTypeDecls'. Each declaration defines its own type
-- constructor; only type synonyms contribute used constructors (collected
-- by 'ft' from the synonym's right hand side).
checkTypeSynonyms :: ModuleIdent -> [Decl] -> TCM ()
checkTypeSynonyms m tds =
  mapM_ (checkTypeDecls m) (scc definedTypes usedTypes tds)
  where
  definedTypes decl = case decl of
    DataDecl    _ tc _ _ -> [tc]
    NewtypeDecl _ tc _ _ -> [tc]
    TypeDecl    _ tc _ _ -> [tc]
    _                    -> []
  usedTypes (TypeDecl _ _ _ ty) = ft m ty []
  usedTypes _                   = []

-- | Check one dependency group of type declarations. A single data or
-- newtype declaration is always accepted. A single type synonym is rejected
-- when its own constructor occurs in its right hand side. A group of more
-- than one declaration that starts with a type synonym is reported as a set
-- of recursive types. Any other shape is an internal error.
checkTypeDecls :: ModuleIdent -> [Decl] -> TCM ()
checkTypeDecls m group = case group of
  [] -> internalError "TypeCheck.checkTypeDecls: empty list"
  [DataDecl    _ _ _ _] -> return ()
  [NewtypeDecl _ _ _ _] -> return ()
  [TypeDecl _ tc _ ty]
    | tc `elem` ft m ty [] -> report $ errRecursiveTypes [tc]
    | otherwise            -> return ()
  TypeDecl _ tc _ _ : ds ->
    report $ errRecursiveTypes (tc : [tc' | TypeDecl _ tc' _ _ <- ds])
  _ -> internalError "TypeCheck.checkTypeDecls: no type synonym"

-- | Collect the type constructors of a type expression for which
-- 'localIdent' yields an identifier with respect to module @m@, and put them
-- in front of the accumulator @tcs@.
ft :: ModuleIdent -> TypeExpr -> [Ident] -> [Ident]
ft m ty tcs = case ty of
  ConstructorType tc tys -> maybe id (:) (localIdent m tc) (collect tys)
  VariableType _         -> tcs
  TupleType tys          -> collect tys
  ListType ty1           -> ft m ty1 tcs
  ArrowType ty1 ty2      -> ft m ty1 (ft m ty2 tcs)
  ParenType ty1          -> ft m ty1 tcs
  where
  -- traverse a list of argument types, rightmost first
  collect = foldr (ft m) tcs
205 206 207 208 209 210 211 212 213 214

-- When a field label occurs in more than one constructor declaration of
-- a data type, the compiler ensures that the label is defined
-- consistently, i.e. both occurrences have the same type. In addition,
-- the compiler ensures that no existentially quantified type variable occurs
-- in the type of a field label because such type variables necessarily escape
-- their scope with the type of the record selection function associated with
-- the field label.

-- | Check the field labels of a type declaration. For a data declaration
-- every label of every record constructor is checked with 'tcFieldLabel',
-- the results are grouped by label and each group is checked with
-- 'tcFieldLabels'. For a record newtype only 'tcFieldLabel' is run on its
-- single label. All other declarations are ignored.
checkFieldLabel :: Decl -> TCM ()
checkFieldLabel decl = case decl of
  DataDecl _ _ tvs cs ->
    mapM (tcFieldLabel tvs) (recordLabels cs)
      >>= mapM_ tcFieldLabels . groupLabels
  NewtypeDecl _ _ tvs (NewRecordDecl p _ _ (l, ty)) ->
    tcFieldLabel tvs (l, p, ty) >> return ()
  _ -> return ()
  where
  -- one entry per label, paired with the position and type of its field
  recordLabels cs = [ (l, p, ty) | RecordDecl _ _ _ fs <- cs
                                 , FieldDecl p ls ty   <- fs
                                 , l                   <- ls ]

-- | Expand the type of a field label with respect to the type variables
-- @tvs@ of the enclosing declaration. If the resulting type scheme
-- quantifies over more variables than @tvs@ contains, a skolem escape error
-- is reported. The label is returned with its expanded type in either case.
tcFieldLabel :: [Ident] -> (Ident, Position, TypeExpr)
             -> TCM (Ident, Position, Type)
tcFieldLabel tvs (l, p, ty) = do
  m     <- getModuleIdent
  tcEnv <- getTyConsEnv
  let ForAll n ty' = polyType $ expandMonoType' m tcEnv tvs ty
  unless (n <= length tvs) $
    report $ errSkolemEscapingScope p m (text "record selection") ty'
  return (l, p, ty')
235

236
-- | Group triples by their first component. Groups appear in the order of
-- the first occurrence of their key; each group keeps the second component
-- of that first occurrence and collects all third components in their
-- original order.
groupLabels :: Eq a => [(a,b,c)] -> [(a,b,[c])]
groupLabels triples = case triples of
  []                     -> []
  (key, pos, val) : rest ->
    let (same, others) = partition (\ (k, _, _) -> key == k) rest
    in  (key, pos, val : [v | (_, _, v) <- same]) : groupLabels others

-- | Check that all types collected for one field label agree with the
-- first one. If some type differs, an incompatible-label-types error is
-- reported for the label, naming the first type of the group and the first
-- type that actually differs from it.
--
-- The conflicting type is taken from the filtered list rather than from the
-- head of the remaining types: the head may be equal to the first type even
-- though a later occurrence is the one that conflicts.
tcFieldLabels :: (Ident, Position, [Type]) -> TCM ()
tcFieldLabels (_, _, [])       = return ()
tcFieldLabels (l, p, ty : tys) = do
  m <- getModuleIdent
  case filter (ty /=) tys of
    []          -> return ()
    (other : _) -> report $ errIncompatibleLabelTypes p m l ty other
250

251
-- The type constructor environment 'tcEnv' maintains all types
-- in fully expanded form.

-- | Enter the given type declarations into the type constructor
-- environment. The environment being built is itself handed to 'bindTC'
-- (a lazily defined, self-referential binding), so the expansion performed
-- by 'bindTC' sees the final environment.
bindTypes :: [Decl] -> TCM ()
bindTypes ds = do
  m       <- getModuleIdent
  initial <- getTyConsEnv
  let final = foldr (bindTC m final) initial ds
  setTyConsEnv final

260 261
-- | Produce the environment transformation that binds one type declaration.
-- The 'TCEnv' argument is the environment used for expanding the types that
-- occur in the declaration. Data and newtype declarations are bound with
-- their constructors; the type variables of the left hand side that are
-- shadowed by a constructor's existentially quantified variables are
-- replaced via 'cleanTVars' before expansion. Type synonyms are bound to
-- their expanded right hand side. Any other declaration leaves the
-- environment unchanged.
bindTC :: ModuleIdent -> TCEnv -> Decl -> TCEnv -> TCEnv
bindTC m tcEnv decl = case decl of
  DataDecl _ tc tvs cs ->
    bindTypeInfo DataType m tc tvs (map (mkData tvs) cs)
  NewtypeDecl _ tc tvs (NewConstrDecl _ evs c ty) ->
    bindTypeInfo RenamingType m tc tvs
      (DataConstr c (length evs) [expandOne tvs evs ty])
  NewtypeDecl _ tc tvs (NewRecordDecl _ evs c (l, ty)) ->
    bindTypeInfo RenamingType m tc tvs
      (RecordConstr c (length evs) [l] [expandOne tvs evs ty])
  TypeDecl _ tc tvs ty ->
    bindTypeInfo AliasType m tc tvs (expandMonoType' m tcEnv tvs ty)
  _ -> id
  where
  -- constructor information for one constructor declaration of a data type
  mkData tvs (ConstrDecl _ evs c tys) =
    DataConstr c (length evs) (expandAll tvs evs tys)
  mkData tvs (ConOpDecl _ evs ty1 op ty2) =
    DataConstr op (length evs) (expandAll tvs evs [ty1, ty2])
  mkData tvs (RecordDecl _ evs c fs) =
    RecordConstr c (length evs) labels (expandAll tvs evs tys)
    where (labels, tys) = unzip [(l, ty) | FieldDecl _ ls ty <- fs, l <- ls]
  -- expansion of constructor argument types with shadowed variables hidden
  expandAll tvs evs = expandMonoTypes m tcEnv (cleanTVars tvs evs)
  expandOne tvs evs = expandMonoType' m tcEnv (cleanTVars tvs evs)
282 283 284 285 286 287 288 289 290 291 292 293

-- | Replace every variable of the first list that also occurs in the second
-- list by 'anonId'; all other variables are kept. The length and order of
-- the first list are preserved.
cleanTVars :: [Ident] -> [Ident] -> [Ident]
cleanTVars tvs evs = map hide tvs
  where
  hide tv
    | tv `elem` evs = anonId
    | otherwise     = tv

-- Defining Data Constructors:
-- In the next step, the types of all data constructors are entered into
-- the type environment using the information just entered into the type
-- constructor environment. Thus, we can be sure that all type variables
-- have been properly renamed and all type synonyms are already expanded.

-- | Update the value environment with 'bindConstrs'', using the current
-- module identifier and type constructor environment.
bindConstrs :: TCM ()
bindConstrs =
  bindConstrs' <$> getModuleIdent <*> getTyConsEnv >>= modifyValueEnv

-- | Add the constructors of all local bindings of the type constructor
-- environment to the value environment. Data types contribute one
-- 'DataConstructor' entry per constructor, renaming types contribute one
-- 'NewtypeConstructor' entry, and alias types contribute nothing.
-- Constructors without field labels get 'anonId' in place of each label.
bindConstrs' :: ModuleIdent -> TCEnv -> ValueEnv -> ValueEnv
bindConstrs' m tcEnv tyEnv =
  foldr bindData tyEnv (map snd (localBindings tcEnv))
  where
  bindData info env = case info of
    DataType     tc n cs -> foldr (bindConstr n (resultType tc n)) env cs
    RenamingType tc n c  -> bindNewConstr n (resultType tc n) c env
    AliasType    _  _ _  -> env
  bindConstr n resTy constr = case constr of
    DataConstr   c n'    tys ->
      dataConstr c n n' (replicate (length tys) anonId) resTy tys
    RecordConstr c n' ls tys ->
      dataConstr c n n' ls resTy tys
  -- the constructor's type is the arrow from its argument types to resTy
  dataConstr c n n' ls resTy tys =
    bindGlobalInfo (\ qc tyScheme -> DataConstructor qc (length tys) ls tyScheme)
                   m c (ForAllExist n n' (foldr TypeArrow resTy tys))
  bindNewConstr n resTy constr = case constr of
    DataConstr   c n' [argTy]     -> newConstr c anonId n n' argTy resTy
    RecordConstr c n' [l] [argTy] -> newConstr c l      n n' argTy resTy
    _ ->
      internalError "TypeCheck.bindConstrs: newtype with illegal constructors"
  newConstr c l n n' argTy resTy =
    bindGlobalInfo (\ qc tyScheme -> NewtypeConstructor qc l tyScheme)
                   m c (ForAllExist n n' (TypeArrow argTy resTy))
  -- the type constructor applied to the type variables 0 .. n - 1
  resultType tc n = TypeConstructor tc (map TypeVariable [0 .. n - 1])

-- Defining Field Labels:
-- Next the types of all field labels are added to the type environment.
-- Since we use the type constructor environment again,
-- we can be sure that all type variables have been properly renamed
-- and all type synonyms are already expanded.

-- | Update the value environment with 'bindLabels'', using the current
-- module identifier and type constructor environment.
bindLabels :: TCM ()
bindLabels =
  bindLabels' <$> getModuleIdent <*> getTyConsEnv >>= modifyValueEnv

-- | Add the field labels of all local bindings of the type constructor
-- environment to the value environment. Every label of a data type is bound
-- once, together with the qualified names of all constructors that have
-- this label. The label of a record newtype is bound with its single
-- constructor. Renaming types with a plain constructor and alias types
-- contribute nothing.
bindLabels' :: ModuleIdent -> TCEnv -> ValueEnv -> ValueEnv
bindLabels' m tcEnv tyEnv =
  foldr bindData tyEnv (map snd (localBindings tcEnv))
  where
  bindData info env = case info of
    DataType tc n cs ->
      foldr (bindLabel n (resultType tc n)) env
            (nubBy sameLabel (labelInfos tc cs))
    RenamingType tc n (RecordConstr c _ [l] [lty]) ->
      bindLabel n (resultType tc n) (l, [qualifyLike tc c], lty) env
    RenamingType _ _ (RecordConstr _ _ _ _) -> internalError $
      "Checks.TypeCheck.bindLabels: RenamingType with more than one record label"
    RenamingType _ _ (DataConstr _ _ _) -> env
    AliasType _ _ _ -> env
  -- one entry per label occurrence: label, owning constructors, label type
  labelInfos tc cs =
    [ (l, [qualifyLike tc (constrIdent c) | c <- cs, l `elem` recLabels c], ty)
    | (l, ty) <- zip (concatMap recLabels cs) (concatMap recLabelTypes cs) ]
  sameLabel (l1, _, _) (l2, _, _) = l1 == l2
  -- a label has the type of a function from the record type to the field type
  bindLabel n ty (l, lcs, lty) =
    bindGlobalInfo (\ qc tyScheme -> Label qc lcs tyScheme) m l
                   (ForAll n (TypeArrow ty lty))
  -- the type constructor applied to the type variables 0 .. n - 1
  resultType tc n = TypeConstructor tc (map TypeVariable [0 .. n - 1])
362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415

-- Type Signatures:
-- The type checker collects type signatures in a flat environment. All
-- anonymous variables occurring in a signature are replaced by fresh
-- names. However, the type is not expanded so that the signature is
-- available for use in the error message that is printed when the
-- inferred type is less general than the signature.

-- | The signature environment maps identifiers to their (unexpanded)
-- type signatures.
type SigEnv = Map.Map Ident TypeExpr

-- | The signature environment without any entries.
emptySigEnv :: SigEnv
emptySigEnv = Map.empty

-- | Remove the signature of an identifier.
unbindTypeSig :: Ident -> SigEnv -> SigEnv
unbindTypeSig v sigs = Map.delete v sigs

-- | Add (or replace) the signature of an identifier.
bindTypeSig :: Ident -> TypeExpr -> SigEnv -> SigEnv
bindTypeSig v ty sigs = Map.insert v ty sigs

-- | Enter the signature of a type signature declaration for each of the
-- identifiers it declares; the signature is first passed through
-- 'nameSigType'. Other declarations leave the environment unchanged.
bindTypeSigs :: Decl -> SigEnv -> SigEnv
bindTypeSigs decl env = case decl of
  TypeSig _ vs ty ->
    let named = nameSigType ty
    in  foldr (\ v -> bindTypeSig v named) env vs
  _ -> env

-- | Look up the signature of an identifier.
lookupTypeSig :: Ident -> SigEnv -> Maybe TypeExpr
lookupTypeSig v sigs = Map.lookup v sigs

-- | Look up the signature of a qualified identifier. The result is
-- 'Nothing' when 'localIdent' yields no identifier for it with respect to
-- module @m@, or when that identifier has no signature.
qualLookupTypeSig :: ModuleIdent -> QualIdent -> SigEnv -> Maybe TypeExpr
qualLookupTypeSig m f sigs = do
  v <- localIdent m f
  lookupTypeSig v sigs

-- | Replace the anonymous type variables of a signature by names taken
-- from 'identSupply', skipping every name that is contained in @fv ty@.
nameSigType :: TypeExpr -> TypeExpr
nameSigType ty = fst (nameType ty unusedNames)
  where unusedNames = [tv | tv <- identSupply, tv `notElem` fv ty]

-- | Apply 'nameType' to a list of type expressions from left to right,
-- threading the supply of names through. Returns the renamed types together
-- with the remaining supply.
nameTypes :: [TypeExpr] -> [Ident] -> ([TypeExpr], [Ident])
nameTypes tys tvs = case tys of
  []        -> ([], tvs)
  ty : rest ->
    let (ty'  , tvs' ) = nameType ty tvs
        (rest', tvs'') = nameTypes rest tvs'
    in  (ty' : rest', tvs'')

-- | Replace every anonymous type variable of a type expression by the next
-- name of the supply. Returns the renamed type expression together with the
-- unused rest of the supply.
--
-- The supply is only inspected when an anonymous variable is encountered,
-- so a named variable never triggers the "empty ident list" internal error,
-- even if the supply is exhausted. An anonymous variable with an empty
-- supply is still an internal error.
nameType :: TypeExpr -> [Ident] -> (TypeExpr, [Ident])
nameType (ConstructorType tc tys) tvs = (ConstructorType tc tys', tvs')
  where (tys', tvs') = nameTypes tys tvs
nameType (VariableType tv) tvs
  | isAnonId tv = case tvs of
      tv' : tvs' -> (VariableType tv', tvs')
      []         -> internalError
        "TypeCheck.nameType: empty ident list"
  | otherwise   = (VariableType tv, tvs)
nameType (TupleType tys) tvs = (TupleType tys', tvs')
  where (tys', tvs') = nameTypes tys tvs
nameType (ListType ty) tvs = (ListType ty', tvs')
  where (ty', tvs') = nameType ty tvs
nameType (ArrowType ty1 ty2) tvs = (ArrowType ty1' ty2', tvs'')
  where (ty1', tvs' ) = nameType ty1 tvs
        (ty2', tvs'') = nameType ty2 tvs'
nameType (ParenType ty) tvs = (ParenType ty', tvs')
  where (ty', tvs') = nameType ty tvs


-- Type Inference:
-- Before type checking a group of declarations, a dependency analysis is
-- performed and the declaration group is eventually transformed into
-- nested declaration groups which are checked separately. Within each
-- declaration group, first the left hand sides of all declarations are
-- typed. Next, the right hand sides of the declarations are typed in the
-- extended type environment. Finally, the types for the left and right
-- hand sides are unified and the types of all defined functions are
-- generalized. The generalization step will also check that the type
-- signatures given by the user match the inferred types.

-- Argument and result types of foreign functions using the ccall calling
-- convention are restricted to the basic types Bool, Char, Int, and Float.
-- In addition, IO t is a legitimate result type when t is either one of the
-- basic types or ().

-- TODO: Extend the set of legitimate types to match the types admitted
-- by the Haskell Foreign Function Interface Addendum.

-- | Type check a list of declarations. The type signatures found among the
-- non-value declarations are added to the signature environment, the value
-- declarations are split into dependency groups which are checked one after
-- the other, and finally the signature environment that was in place before
-- the call is restored.
tcDecls :: [Decl] -> TCM ()
tcDecls ds = do
  m         <- getModuleIdent
  savedSigs <- getSigEnv
  modifySigEnv (flip (foldr bindTypeSigs) otherDecls)
  mapM_ tcDeclGroup (scc bv (qfv m) valueDecls)
  modifySigEnv (const savedSigs)
  where (valueDecls, otherDecls) = partition isValueDecl ds

-- | Type check one declaration group. Singleton groups consisting of a
-- foreign, external or free declaration are handled by the respective
-- helper. For every other group the left hand sides are typed first, then
-- the right hand sides, both are unified pairwise, and finally the
-- declarations are generalized with respect to the free type variables of
-- the value environment that was in place before the group was entered.
tcDeclGroup :: [Decl] -> TCM ()
tcDeclGroup declGroup = case declGroup of
  [ForeignDecl _ _ _ f ty] -> tcForeign f ty
  [ExternalDecl      _ fs] -> mapM_ tcExternal fs
  [FreeDecl          _ vs] -> mapM_ tcFree     vs
  ds                       -> do
    tyEnv0 <- getValueEnv
    lhsTys <- mapM tcDeclLhs ds
    rhsTys <- mapM (tcDeclRhs tyEnv0) ds
    sequence_ (zipWith3 unifyDecl ds lhsTys rhsTys)
    theta  <- getTypeSubst
    let fixedVars = fvEnv (subst theta tyEnv0)
    mapM_ (genDecl fixedVars theta) ds
--tcDeclGroup m tcEnv _ [ForeignDecl p cc _ f ty] =
--  tcForeign m tcEnv p cc f ty

--tcForeign :: ModuleIdent -> TCEnv -> Position -> CallConv -> Ident
--               -> TypeExpr -> TCM ()
--tcForeign m tcEnv p cc f ty =
--  S.modify (bindFun m f (checkForeignType cc (expandPolyType tcEnv ty)))
--  where checkForeignType CallConvPrimitive ty = ty
--        checkForeignType CallConvCCall (ForAll n ty) =
--          ForAll n (checkCCallType ty)
--        checkCCallType (TypeArrow ty1 ty2)
--          | isCArgType ty1 = TypeArrow ty1 (checkCCallType ty2)
--          | otherwise = errorAt p (invalidCType "argument" m ty1)
--        checkCCallType ty
--          | isCResultType ty = ty
--          | otherwise = errorAt p (invalidCType "result" m ty)
--        isCArgType (TypeConstructor tc []) = tc `elem` basicTypeId
--        isCArgType _ = False
--        isCResultType (TypeConstructor tc []) = tc `elem` basicTypeId
--        isCResultType (TypeConstructor tc [ty]) =
--          tc == qIOId && (ty == unitType || isCArgType ty)
--        isCResultType _ = False
--        basicTypeId = [qBoolId,qCharId,qIntId,qFloatId]

-- | Bind a foreign function in the value environment. Its type scheme is
-- obtained by expanding the given type expression; the arity is the arrow
-- arity of the expanded type.
tcForeign :: Ident -> TypeExpr -> TCM ()
tcForeign f ty = do
  tySc@(ForAll _ expanded) <- expandPolyType ty
  m                        <- getModuleIdent
  modifyValueEnv (bindFun m f (arrowArity expanded) tySc)

-- | Bind an external function using its type signature from the signature
-- environment. A missing signature is an internal error.
tcExternal :: Ident -> TCM ()
tcExternal f = getSigEnv >>= withSignature . lookupTypeSig f
  where
  withSignature = maybe (internalError "TypeCheck.tcExternal") (tcForeign f)

-- | Bind a free variable in the value environment with a monomorphic type.
-- Without a signature the variable gets a fresh type variable. With a
-- signature its expanded type is used, unless the signature is polymorphic:
-- then an error is reported, the signature is removed from the signature
-- environment and a fresh type variable is used instead.
tcFree :: Ident -> TCM ()
tcFree v = do
  sigs <- getSigEnv
  ty   <- maybe freshTypeVar declaredType (lookupTypeSig v sigs)
  m    <- getModuleIdent
  modifyValueEnv (bindFun m v (arrowArity ty) (monoType ty))
  where
  declaredType sig = do
    ForAll n ty' <- expandPolyType sig
    if n == 0
      then return ty'
      else do
        -- because of error aggregation, we have to fix
        -- the corrupt information
        report (errPolymorphicFreeVar v)
        modifySigEnv (unbindTypeSig v)
        freshTypeVar

-- | Type the left hand side of a function or pattern declaration. Any
-- other declaration is an internal error.
tcDeclLhs :: Decl -> TCM Type
tcDeclLhs decl = case decl of
  FunctionDecl _ f _ -> tcFunDecl f
  PatternDecl  p t _ -> tcPattern p t
  _                  -> internalError "TypeCheck.tcDeclLhs: no pattern match"

-- | Assign an initial type to a function: an instance of its expanded type
-- signature if one is present, a fresh type variable otherwise. The
-- function is bound to this type monomorphically in the value environment
-- and the type is returned.
tcFunDecl :: Ident -> TCM Type
tcFunDecl v = do
  sigs <- getSigEnv
  m    <- getModuleIdent
  ty   <- maybe freshTypeVar (\ sig -> expandPolyType sig >>= inst)
                (lookupTypeSig v sigs)
  modifyValueEnv (bindFun m v (arrowArity ty) (monoType ty))
  return ty

-- | Type the right hand side of a declaration. For a function, the type of
-- the first equation is computed and the type of every further equation is
-- unified with it; the type of the first equation is returned. For a
-- pattern declaration the type of its right hand side is returned. Any
-- other declaration (including a function without equations) is an internal
-- error.
tcDeclRhs :: ValueEnv -> Decl -> TCM Type
tcDeclRhs tyEnv0 decl = case decl of
  FunctionDecl _ f (eq : eqs) -> do
    ty <- tcEquation tyEnv0 eq
    mapM_ (agreeWith f ty) eqs
    return ty
  PatternDecl _ _ rhs -> tcRhs tyEnv0 rhs
  _ -> internalError "TypeCheck.tcDeclRhs: no pattern match"
  where
  agreeWith f ty eq1@(Equation p _ _) = do
    ty1 <- tcEquation tyEnv0 eq1
    unify p "equation" (ppDecl (FunctionDecl p f [eq1])) ty ty1

-- | Unify the two given types for a function or pattern declaration; the
-- declaration supplies the position and the description used by 'unify'.
-- Any other declaration is an internal error.
unifyDecl :: Decl -> Type -> Type -> TCM ()
unifyDecl decl = case decl of
  FunctionDecl p f _ ->
    unify p "function binding" (text "Function:" <+> ppIdent f)
  PatternDecl  p t _ ->
    unify p "pattern binding" (ppPattern 0 t)
  _ -> internalError "TypeCheck.unifyDecl: no pattern match"

-- In Curry we cannot generalize the types of let-bound variables because
-- they can refer to logic variables. Without this monomorphism
-- restriction unsound code like
--
-- bug = x =:= 1 & x =:= 'a'
--   where x :: a
--         x = fresh
-- fresh :: a
-- fresh = x where x free
--
-- could be written. Note that fresh has the polymorphic type
-- forall alpha . alpha. This is correct because fresh is a
-- function and therefore returns a different variable at each
-- invocation.

-- The code in 'genVar' below also verifies that the inferred type
-- for a variable or function matches the type declared in a type
-- signature. As the declared type is already used for assigning an initial
-- type to a variable when it is used, the inferred type can only be more
-- specific. Therefore, if the inferred type does not match the type
-- signature the declared type must be too general.

-- | Generalize the types of the entities bound by a declaration. A function
-- is passed to 'genVar' with generalization enabled and with the number of
-- arguments of its first equation as arity. The variables bound by a
-- pattern declaration are passed to 'genVar' with generalization disabled.
-- Any other declaration (including a function without equations) is an
-- internal error.
genDecl :: Set.Set Int -> TypeSubst -> Decl -> TCM ()
genDecl lvs theta decl = case decl of
  FunctionDecl _ f (Equation _ lhs _ : _) ->
    let (_, args) = flatLhs lhs
    in  genVar True lvs theta (Just (length args)) f
  PatternDecl _ t _ ->
    mapM_ (genVar False lvs theta Nothing) (bv t)
  _ -> internalError "TypeCheck.genDecl: no pattern match"

-- | Compute the final type scheme of a variable or function and rebind it
-- in the value environment. The type found in the value environment must be
-- monomorphic (otherwise an internal error is raised); after applying the
-- substitution it is generalized over the variables not in @lvs@ when
-- @poly@ is set, and kept monomorphic otherwise. If the entity has a type
-- signature whose expanded type is not equal (by 'equTypes') to the
-- computed type, a "type signature too general" error is reported. The
-- arity is taken from the 'Maybe' argument, falling back to the arity
-- stored in the value environment.
genVar :: Bool -> Set.Set Int -> TypeSubst -> Maybe Int -> Ident -> TCM ()
genVar poly lvs theta ma v = do
  sigs  <- getSigEnv
  m     <- getModuleIdent
  tyEnv <- getValueEnv
  let sigma = generalize (subst theta (varType v tyEnv))
      arity = fromMaybe (varArity v tyEnv) ma
  case lookupTypeSig v sigs of
    Nothing    -> return ()
    Just sigTy -> do
      declared <- expandPolyType sigTy
      unless (sameScheme sigma declared) $ report
        $ errTypeSigTooGeneral (idPosition v) m what sigTy sigma
  modifyValueEnv (rebindFun m v arity sigma)
  where
  what = text (if poly then "Function:" else "Variable:") <+> ppIdent v
  generalize (ForAll n ty)
    | n > 0     = internalError $ "TypeCheck.genVar: "
                    ++ showLine (idPosition v) ++ show v ++ " :: " ++ show ty
    | poly      = gen lvs ty
    | otherwise = monoType ty
  sameScheme (ForAll _ t1) (ForAll _ t2) = equTypes t1 t2

-- | Type one equation of a function: the argument patterns are typed first,
-- then the right hand side. The resulting function type (arguments to
-- result) is passed through 'checkSkolems' with respect to the given value
-- environment.
tcEquation :: ValueEnv -> Equation -> TCM Type
tcEquation tyEnv0 (Equation p lhs rhs) = do
  argTys <- mapM (tcPattern p) args
  resTy  <- tcRhs tyEnv0 rhs
  checkSkolems p (text "Function: " <+> ppIdent f) tyEnv0
                 (foldr TypeArrow resTy argTys)
  where (f, args) = flatLhs lhs

-- | Type a literal. Character, float and string literals have a fixed
-- type. An integer literal gets a fresh type variable constrained to the
-- integer and float types; the identifier carried by the literal is bound
-- to that type in the value environment.
tcLiteral :: Literal -> TCM Type
tcLiteral lit = case lit of
  Char   _ _ -> return charType
  Int    v _ -> do
    m  <- getModuleIdent
    ty <- freshConstrained [intType, floatType]
    modifyValueEnv (bindFun m v (arrowArity ty) (monoType ty))
    return ty
  Float  _ _ -> return floatType
  String _ _ -> return stringType

-- | Infer the type of a pattern. The position is used for the error
-- messages produced while unifying the types of sub-patterns.
tcPattern :: Position -> Pattern -> TCM Type
tcPattern _ (LiteralPattern    l) = tcLiteral l
tcPattern _ (NegativePattern _ l) = tcLiteral l
tcPattern _ (VariablePattern   v) = do
  sigs  <- getSigEnv
  ty    <- maybe freshTypeVar (\ sig -> expandPolyType sig >>= inst)
                 (lookupTypeSig v sigs)
  tyEnv <- getValueEnv
  m     <- getModuleIdent
  -- a variable that already has a type in the environment keeps it;
  -- otherwise it is bound to the type computed above
  case sureVarType v tyEnv of
    Just (ForAll _ known) -> return known
    Nothing               -> do
      modifyValueEnv (bindFun m v (arrowArity ty) (monoType ty))
      return ty
tcPattern p t@(ConstructorPattern c ts) = do
  m     <- getModuleIdent
  tyEnv <- getValueEnv
  skol (constrType m c tyEnv) >>= matchArgs ts
  where
  whole = ppPattern 0 t
  -- unify the argument patterns with the constructor's argument types
  matchArgs []           ty                      = return ty
  matchArgs (arg : rest) (TypeArrow argTy resTy) = do
    ty' <- tcPattern p arg
    unify p "pattern" (whole $-$ text "Term:" <+> ppPattern 0 arg) argTy ty'
    matchArgs rest resTy
  matchArgs _            _                       =
    internalError "TypeCheck.tcPattern"
tcPattern p (InfixPattern t1 op t2) =
  tcPattern p (ConstructorPattern op [t1, t2])
tcPattern p (ParenPattern        t) = tcPattern p t
tcPattern _ r@(RecordPattern  c fs) = do
  m     <- getModuleIdent
  tyEnv <- getValueEnv
  ty    <- fmap arrowBase (skol (constrType m c tyEnv))
  mapM_ (tcField tcPattern "pattern" termDoc ty) fs
  return ty
  where termDoc t1 = ppPattern 0 r $-$ text "Term:" <+> ppPattern 0 t1
tcPattern p (TuplePattern _ ts) = case ts of
  [] -> return unitType
  _  -> fmap tupleType (mapM (tcPattern p) ts)
tcPattern p t@(ListPattern _ ts) = do
  elemTy <- freshTypeVar
  mapM_ (checkElem elemTy) ts
  return (listType elemTy)
  where
  -- every element pattern must have the common element type
  checkElem elemTy t1 = do
    ty' <- tcPattern p t1
    unify p "pattern" (ppPattern 0 t $-$ text "Term:" <+> ppPattern 0 t1)
          elemTy ty'
tcPattern p t@(AsPattern v t') = do
  varTy <- tcPattern p (VariablePattern v)
  patTy <- tcPattern p t'
  unify p "pattern" (ppPattern 0 t) varTy patTy
  return varTy
tcPattern p (LazyPattern        _ t) = tcPattern p t
tcPattern p t@(FunctionPattern f ts) = do
  m     <- getModuleIdent
  tyEnv <- getValueEnv
  inst (funType m f tyEnv) >>= matchArgs ts
  where
  whole = ppPattern 0 t
  matchArgs []           ty                      = return ty
  -- the function's type is still a type variable: split it into an arrow
  matchArgs (arg : rest) ty@(TypeVariable _)     = do
    (a, b) <- tcArrow p "function pattern" whole ty
    ty'    <- tcPattern p arg
    unify p "function pattern"
          (whole $-$ text "Term:" <+> ppPattern 0 arg) ty' a
    matchArgs rest b
  matchArgs (arg : rest) (TypeArrow argTy resTy) = do
    ty' <- tcPattern p arg
    unify p "function pattern"
          (whole $-$ text "Term:" <+> ppPattern 0 arg) argTy ty'
    matchArgs rest resTy
  matchArgs _            ty                      =
    internalError $ "TypeCheck.tcPattern: " ++ show ty
tcPattern p (InfixFuncPattern t1 op t2) =
  tcPattern p (FunctionPattern op [t1, t2])
687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739

-- | Type a right hand side. The local declarations are checked first. For
-- a simple right hand side the type of its expression is passed through
-- 'checkSkolems' with respect to the given value environment; a guarded
-- right hand side is handled by 'tcCondExprs'.
tcRhs :: ValueEnv -> Rhs -> TCM Type
tcRhs tyEnv0 rhs = case rhs of
  SimpleRhs p e ds -> do
    tcDecls ds
    tcExpr p e >>= checkSkolems p (text "Expression:" <+> ppExpr 0 e) tyEnv0
  GuardedRhs es ds -> tcDecls ds >> tcCondExprs tyEnv0 es

-- | Type a list of guarded expressions. With more than one alternative all
-- guards must have the boolean type; otherwise the guard type is a fresh
-- type variable constrained to the success and boolean types. All guarded
-- expressions are unified with one common fresh type variable, which is
-- returned.
tcCondExprs :: ValueEnv -> [CondExpr] -> TCM Type
tcCondExprs tyEnv0 es = do
  guardTy  <- if length es > 1
                then return boolType
                else freshConstrained [successType, boolType]
  resultTy <- freshTypeVar
  mapM_ (checkAlt guardTy resultTy) es
  return resultTy
  where
  checkAlt guardTy resultTy (CondExpr p g e) = do
    gTy  <- tcExpr p g
    unify p "guard" (ppExpr 0 g) guardTy gTy
    eTy  <- tcExpr p e
    eTy' <- checkSkolems p (text "Expression:" <+> ppExpr 0 e) tyEnv0 eTy
    unify p "guarded expression" (ppExpr 0 e) resultTy eTy'

-- | Infer the type of an expression. The 'Position' is used for the error
-- messages reported while unifying; the returned type may still contain
-- type variables that are bound in the current type substitution.
tcExpr :: Position -> Expression -> TCM Type
tcExpr _ (Literal     l) = tcLiteral l
-- Variables: an anonymous free variable gets a fresh monomorphic type that
-- is recorded in the value environment; any other variable is instantiated
-- from its type signature if there is one, else from the value environment.
tcExpr _ (Variable    v)
  | isAnonId v' = do -- anonymous free variable
    m <- getModuleIdent
    ty <- freshTypeVar
    modifyValueEnv $ bindFun m v' (arrowArity ty) $ monoType ty
    return ty
  | otherwise   = do
    sigs <- getSigEnv
    m <- getModuleIdent
    case qualLookupTypeSig m v sigs of
      Just ty -> expandPolyType ty >>= inst
      Nothing -> getValueEnv >>= inst . funType m v
  where v' = unqualify v
tcExpr _ (Constructor c) = do
  m <- getModuleIdent
  getValueEnv >>= instExist . constrType m c
-- Explicitly typed expression: the inferred type is unified with an instance
-- of the signature; afterwards the generalized inferred type must equal the
-- signature's type scheme, otherwise the signature is too general.
tcExpr p (Typed   e sig) = do
  m <- getModuleIdent
  tyEnv0 <- getValueEnv
  ty <- tcExpr p e
  sigma' <- expandPolyType sig'
  inst sigma' >>= flip (unify p "explicitly typed expression" (ppExpr 0 e)) ty
  theta <- getTypeSubst
  let sigma  = gen (fvEnv (subst theta tyEnv0)) (subst theta ty)
  unless (sigma == sigma') $ report $
    errTypeSigTooGeneral p m (text "Expression:" <+> ppExpr 0 e) sig' sigma
  return ty
  where sig' = nameSigType sig
tcExpr p (Paren       e) = tcExpr p e
-- Record construction: the result type is the base of the constructor's
-- instantiated type; every field is checked against it.
tcExpr _ r@(Record c fs) = do
  m     <- getModuleIdent
  tyEnv <- getValueEnv
  ty    <- arrowBase <$> instExist (constrType m c tyEnv)
  mapM_ (tcField tcExpr "construction" doc ty) fs
  return ty
  where doc e1 = ppExpr 0 r $-$ text "Term:" <+> ppExpr 0 e1
tcExpr p r@(RecordUpdate e fs) = do
  ty <- tcExpr p e
  mapM_ (tcField tcExpr "update" doc ty) fs
  return ty
  where doc e1 = ppExpr 0 r $-$ text "Term:" <+> ppExpr 0 e1
tcExpr p (Tuple _ es)
  | null es   = return unitType
  | otherwise = tupleType <$> mapM (tcExpr p) es
-- List literal: all elements are unified with one fresh element type.
tcExpr p e@(List _ es) = freshTypeVar >>= tcElems (ppExpr 0 e) es
  where tcElems _   []       ty = return (listType ty)
        tcElems doc (e1:es1) ty =
          tcExpr p e1 >>=
          unify p "expression" (doc $-$ text "Term:" <+> ppExpr 0 e1)
                ty >>
          tcElems doc es1 ty
tcExpr p (ListCompr _ e qs) = do
  tyEnv0 <- getValueEnv
  mapM_ (tcQual p) qs
  ty <- tcExpr p e
  checkSkolems p (text "Expression:" <+> ppExpr 0 e) tyEnv0 (listType ty)
-- Arithmetic sequences: every bound is checked by 'tcEnum'; the result is
-- always a list of 'intType'.
tcExpr p e@(EnumFrom             e1) = do
  tcEnum p e e1
  return (listType intType)
tcExpr p e@(EnumFromThen      e1 e2) = do
  mapM_ (tcEnum p e) [e1, e2]
  return (listType intType)
tcExpr p e@(EnumFromTo        e1 e2) = do
  mapM_ (tcEnum p e) [e1, e2]
  return (listType intType)
tcExpr p e@(EnumFromThenTo e1 e2 e3) = do
  mapM_ (tcEnum p e) [e1, e2, e3]
  return (listType intType)
-- Unary minus: 'minusId' accepts an int or float operand, 'fminusId' only a
-- float operand; any other operator is an internal error.
tcExpr p e@(UnaryMinus op e1) = do
  opTy <- opType op
  ty1 <- tcExpr p e1
  unify p "unary negation" (ppExpr 0 e $-$ text "Term:" <+> ppExpr 0 e1)
        opTy ty1
  return ty1
  where opType op'
          | op' == minusId  = freshConstrained [intType,floatType]
          | op' == fminusId = return floatType
          | otherwise = internalError $ "TypeCheck.tcExpr unary " ++ idName op'
tcExpr p e@(Apply e1 e2) = do
  ty1 <- tcExpr p e1
  ty2 <- tcExpr p e2
  (alpha,beta) <-
    tcArrow p "application" (ppExpr 0 e $-$ text "Term:" <+> ppExpr 0 e1)
           ty1
  unify p "application" (ppExpr 0 e $-$ text "Term:" <+> ppExpr 0 e2)
        alpha ty2
  return beta
tcExpr p e@(InfixApply e1 op e2) = do
  opTy <- tcExpr p (infixOp op)
  ty1  <- tcExpr p e1
  ty2  <- tcExpr p e2
  (alpha,beta,gamma) <-
    tcBinary p "infix application"
             (ppExpr 0 e $-$ text "Operator:" <+> ppOp op) opTy
  unify p "infix application" (ppExpr 0 e $-$ text "Term:" <+> ppExpr 0 e1)
        alpha ty1
  unify p "infix application" (ppExpr 0 e $-$ text "Term:" <+> ppExpr 0 e2)
        beta ty2
  return gamma
tcExpr p e@(LeftSection e1 op) = do
  opTy <- tcExpr p (infixOp op)
  ty1  <- tcExpr p e1
  (alpha,beta) <-
    tcArrow p "left section" (ppExpr 0 e $-$ text "Operator:" <+> ppOp op)
            opTy
  unify p "left section" (ppExpr 0 e $-$ text "Term:" <+> ppExpr 0 e1)
        alpha ty1
  return beta
-- Right section: the operand is the operator's second argument, so the
-- section has type @alpha -> gamma@.
tcExpr p e@(RightSection op e1) = do
  opTy <- tcExpr p (infixOp op)
  ty1  <- tcExpr p e1
  (alpha,beta,gamma) <-
    tcBinary p "right section"
             (ppExpr 0 e $-$ text "Operator:" <+> ppOp op) opTy
  unify p "right section" (ppExpr 0 e $-$ text "Term:" <+> ppExpr 0 e1)
        beta ty1
  return (TypeArrow alpha gamma)
-- Lambda, let and do capture the value environment before checking their
-- local bindings and afterwards check that no skolem escapes it.
tcExpr p expr@(Lambda _ ts e) = do
  tyEnv0 <- getValueEnv
  tys <- mapM (tcPattern p) ts
  ty <- tcExpr p e
  checkSkolems p (text "Expression:" <+> ppExpr 0 expr) tyEnv0
               (foldr TypeArrow ty tys)
tcExpr p (Let ds e) = do
  tyEnv0 <- getValueEnv
  tcDecls ds
  ty <- tcExpr p e
  checkSkolems p (text "Expression:" <+> ppExpr 0 e) tyEnv0 ty
tcExpr p (Do sts e) = do
  tyEnv0 <- getValueEnv
  mapM_ (tcStmt p) sts
  alpha <- freshTypeVar
  ty <- tcExpr p e
  unify p "statement" (ppExpr 0 e) (ioType alpha) ty
  checkSkolems p (text "Expression:" <+> ppExpr 0 e) tyEnv0 ty
-- Conditional: the condition must be of 'boolType' and both branches must
-- have the same type.
tcExpr p e@(IfThenElse _ e1 e2 e3) = do
  ty1 <- tcExpr p e1
  unify p "expression" (ppExpr 0 e $-$ text "Term:" <+> ppExpr 0 e1)
        boolType ty1
  ty2 <- tcExpr p e2
  ty3 <- tcExpr p e3
  unify p "expression" (ppExpr 0 e $-$ text "Term:" <+> ppExpr 0 e3)
        ty2 ty3
  return ty3
-- Case: every pattern is unified with the scrutinee's type and every branch
-- with one fresh result type, which is returned.
tcExpr p (Case _ _ e alts) = do
  tyEnv0 <- getValueEnv
  ty <- tcExpr p e
  alpha <- freshTypeVar
  tcAlts tyEnv0 ty alpha alts
  where
  tcAlts _      _   ty  []           = return ty
  tcAlts tyEnv0 ty1 ty2 (alt1:alts1) = do
    tcAlt (ppAlt alt1) tyEnv0 ty1 ty2 alt1
    tcAlts tyEnv0 ty1 ty2 alts1
  tcAlt doc tyEnv0 ty1 ty2 (Alt p1 t rhs) = do
    ty' <- tcPattern p1 t
    unify p1 "case pattern" (doc $-$ text "Term:" <+> ppPattern 0 t) ty1 ty'
    tcRhs tyEnv0 rhs >>= unify p1 "case branch" doc ty2

-- | Type check one bound @e1@ of the arithmetic sequence @e@: its inferred
-- type is unified with 'intType'. The whole sequence @e@ is only used for
-- the error message.
tcEnum :: Position -> Expression -> Expression -> TCM ()
tcEnum p e e1 = do
  ty1 <- tcExpr p e1
  unify p "arithmetic sequence" (ppExpr 0 e $-$ text "Term:" <+> ppExpr 0 e1)
    intType ty1

-- | Type check a qualifier of a list comprehension: an expression qualifier
-- must be of 'boolType', a generator must draw its pattern from a list of
-- the pattern's type, and local declarations are passed to 'tcDecls'.
tcQual :: Position -> Statement -> TCM ()
tcQual p qual = case qual of
  StmtExpr _ e   -> do
    condTy <- tcExpr p e
    unify p "guard" (ppExpr 0 e) boolType condTy
  StmtBind _ t e -> do
    patTy <- tcPattern p t
    srcTy <- tcExpr p e
    let doc = ppStmt qual $-$ text "Term:" <+> ppExpr 0 e
    unify p "generator" doc (listType patTy) srcTy
  StmtDecl ds    -> tcDecls ds

-- | Type check a statement of a @do@ expression: an expression statement
-- must have type @IO alpha@ for a fresh @alpha@, a binding must bind its
-- pattern to the result of an @IO@ action of the pattern's type, and local
-- declarations are passed to 'tcDecls'.
tcStmt :: Position -> Statement -> TCM ()
tcStmt p (StmtExpr _ e) = do
  alpha <- freshTypeVar
  ty    <- tcExpr p e
  unify p "statement" (ppExpr 0 e) (ioType alpha) ty
tcStmt p st@(StmtBind _ t e) = do
  ty1 <- tcPattern p t
  ty2 <- tcExpr p e
  unify p "statement" (ppStmt st $-$ text "Term:" <+> ppExpr 0 e)
    (ioType ty1) ty2
tcStmt _ (StmtDecl ds) = tcDecls ds

-- | Type check a single record field.
--
-- The label's type scheme is instantiated to @ty1 -> ty2@; the record type
-- @ty@ is unified with @ty1@ and the type of the field's content (computed
-- by @tcheck@) with @ty2@. @what@ and @doc@ only feed the error messages.
-- The type of the field's content is returned.
tcField :: (Position -> a -> TCM Type) -> String -> (a -> Doc) -> Type
        -> Field a -> TCM Type
tcField tcheck what doc ty (Field p l x) = do
  m     <- getModuleIdent
  tyEnv <- getValueEnv
  lblTy <- inst (labelType m l tyEnv)
  -- Match explicitly instead of using a failable pattern in the do block:
  -- the latter relies on the monad's 'fail' (a MonadFail instance in newer
  -- GHC versions) and would hide the cause of the failure.
  case lblTy of
    TypeArrow ty1 ty2 -> do
      unify p "field label" empty ty ty1
      lty <- tcheck p x
      unify p ("record " ++ what) (doc x) ty2 lty
      return lty
    _ -> internalError $ "TypeCheck.tcField: label " ++ show l
                         ++ " does not have a function type"

-- The function 'tcArrow' checks that its argument can be used as
-- an arrow type a -> b and returns the pair (a,b).
tcArrow :: Position -> String -> Doc -> Type -> TCM (Type, Type)
tcArrow p what doc ty = do
  theta <- getTypeSubst
  -- inspect the type under the current substitution
  unaryArrow (subst theta ty)
  where
  unaryArrow (TypeArrow ty1 ty2) = return (ty1, ty2)
  -- an unbound type variable is bound to a fresh arrow type
  unaryArrow (TypeVariable   tv) = do
    alpha <- freshTypeVar
    beta  <- freshTypeVar
    modifyTypeSubst $ bindVar tv $ TypeArrow alpha beta
    return (alpha, beta)
  -- anything else is reported; fresh variables are returned so that type
  -- checking can continue
  unaryArrow ty'                 = do
    m <- getModuleIdent
    report $ errNonFunctionType p what doc m ty'
    (,) <$> freshTypeVar <*> freshTypeVar

-- The function 'tcBinary' checks that its argument can be used as an arrow type
-- a -> b -> c and returns the triple (a,b,c).
tcBinary :: Position -> String -> Doc -> Type -> TCM (Type, Type, Type)
-- The outer arrow is split by 'tcArrow'; 'binaryArrow' then splits its
-- result type once more.
tcBinary p what doc ty = tcArrow p what doc ty >>= uncurry binaryArrow
  where
  binaryArrow ty1 (TypeArrow ty2 ty3) = return (ty1, ty2, ty3)
  -- a type variable in result position is bound to a fresh arrow type
  binaryArrow ty1 (TypeVariable   tv) = do
    beta  <- freshTypeVar
    gamma <- freshTypeVar
    modifyTypeSubst $ bindVar tv $ TypeArrow beta gamma
    return (ty1, beta, gamma)
  -- anything else is reported; fresh variables stand in for the missing
  -- second argument and result types
  binaryArrow ty1 ty2                 = do
    m <- getModuleIdent
    report $ errNonBinaryOp p what doc m (TypeArrow ty1 ty2)
    (,,) <$> return ty1 <*> freshTypeVar <*> freshTypeVar

-- Unification: The unification uses Robinson's algorithm.
-- | Unify two types under the current type substitution. On success the
-- resulting substitution is composed with the current one; on failure a
-- type mismatch is reported, with @ty1@ shown as the expected and @ty2@ as
-- the inferred type. @what@ and @doc@ describe the checked construct.
unify :: Position -> String -> Doc -> Type -> Type -> TCM ()
unify p what doc ty1 ty2 = do
  theta <- getTypeSubst
  let ty1' = subst theta ty1
  let ty2' = subst theta ty2
  m     <- getModuleIdent
  case unifyTypes m ty1' ty2' of
    Left reason -> report $ errTypeMismatch p what doc m ty1' ty2' reason
    Right sigma -> modifyTypeSubst (compose sigma)

-- | Compute a unifying substitution for two types, or a document that
-- explains why they are incompatible.
unifyTypes :: ModuleIdent -> Type -> Type -> Either Doc TypeSubst
unifyTypes _ (TypeVariable tv1) (TypeVariable tv2)
  | tv1 == tv2            = Right idSubst
  | otherwise             = Right (singleSubst tv1 (TypeVariable tv2))
-- A variable is bound to the other type unless it occurs in that type.
unifyTypes m (TypeVariable tv) ty
  | tv `elem` typeVars ty = Left  (errRecursiveType m tv ty)
  | otherwise             = Right (singleSubst tv ty)
unifyTypes m ty (TypeVariable tv)
  | tv `elem` typeVars ty = Left  (errRecursiveType m tv ty)
  | otherwise             = Right (singleSubst tv ty)
-- Two constrained types unify directly only if they are the same variable
-- or have equal alternatives; otherwise the equations below apply.
unifyTypes _ (TypeConstrained tys1 tv1) (TypeConstrained tys2 tv2)
  | tv1  == tv2           = Right idSubst
  | tys1 == tys2          = Right (singleSubst tv1 (TypeConstrained tys2 tv2))
-- A constrained type is unified with the first alternative that fits; the
-- error for the first alternative is used if none does.
unifyTypes m (TypeConstrained tys tv) ty =
  foldr (choose . unifyTypes m ty) (Left (errIncompatibleTypes m ty (head tys)))
        tys
  where choose (Left _) theta' = theta'
        choose (Right theta) _ = Right (bindSubst tv ty theta)
unifyTypes m ty (TypeConstrained tys tv) =
  foldr (choose . unifyTypes m ty) (Left (errIncompatibleTypes m ty (head tys)))
        tys
  where choose (Left _) theta' = theta'
        choose (Right theta) _ = Right (bindSubst tv ty theta)
unifyTypes m (TypeConstructor tc1 tys1) (TypeConstructor tc2 tys2)
  | tc1 == tc2 = unifyTypeLists m tys1 tys2
unifyTypes m (TypeArrow ty11 ty12) (TypeArrow ty21 ty22) =
  unifyTypeLists m [ty11, ty12] [ty21, ty22]
unifyTypes _ (TypeSkolem k1) (TypeSkolem k2)
  | k1 == k2 = Right idSubst
unifyTypes m ty1 ty2 = Left (errIncompatibleTypes m ty1 ty2)

-- | Unify two lists of types pairwise. The tails are unified first; the
-- heads are then unified under the resulting substitution and both results
-- are composed. If one list is longer, its surplus elements are ignored.
unifyTypeLists :: ModuleIdent -> [Type] -> [Type] -> Either Doc TypeSubst
unifyTypeLists _ []           _            = Right idSubst
unifyTypeLists _ _            []           = Right idSubst
unifyTypeLists m (ty1 : tys1) (ty2 : tys2) =
  either Left unifyTypesTheta (unifyTypeLists m tys1 tys2)
  where
  unifyTypesTheta theta = either Left (Right . flip compose theta)
                          (unifyTypes m (subst theta ty1) (subst theta ty2))

-- For each declaration group, the type checker has to ensure that no
-- skolem type escapes its scope.
-- | Apply the current substitution to @ty@ and report an error if it
-- mentions a skolem that does not occur free in the (substituted)
-- environment @tyEnv@. The substituted type is returned in either case.
checkSkolems :: Position -> Doc -> ValueEnv -> Type -> TCM Type
checkSkolems p what tyEnv ty = do
  m     <- getModuleIdent
  sigma <- getTypeSubst
  let resolved = subst sigma ty
      inScope  = fsEnv (subst sigma tyEnv)
      enclosed = all (`Set.member` inScope) (typeSkolems resolved)
  unless enclosed $ report (errSkolemEscapingScope p m what resolved)
  return resolved

-- Instantiation and Generalization:
-- We use negative offsets for fresh type variables.

-- | Build a value from the next identifier delivered by 'getNextId'.
fresh :: (Int -> a) -> TCM a
fresh f = f <$> getNextId

-- | Like 'fresh', but maps the next identifier @n@ to the negative index
-- @-n - 1@ before applying the constructor function.
freshVar :: (Int -> a) -> TCM a
freshVar f = fresh (f . subtract 1 . negate)

-- | Create a fresh type variable; its index is chosen by 'freshVar'.
freshTypeVar :: TCM Type
freshTypeVar = freshVar TypeVariable

-- | Create a fresh constrained type variable that carries the given list
-- of alternative types.
freshConstrained :: [Type] -> TCM Type
freshConstrained alternatives = freshVar (TypeConstrained alternatives)

-- | Create a fresh skolem type. Unlike the type variables made by
-- 'freshVar', it uses the identifier from 'fresh' unchanged.
freshSkolem :: TCM Type
freshSkolem = fresh TypeSkolem

-- | Instantiate a type scheme: each of its @n@ quantified variables is
-- replaced by a fresh type variable.
inst :: TypeScheme -> TCM Type
inst (ForAll n ty) = flip expandAliasType ty <$> replicateM n freshTypeVar

-- | Instantiate an existential type scheme, replacing all @n + n'@
-- quantified variables by fresh type variables.
instExist :: ExistTypeScheme -> TCM Type
instExist (ForAllExist n n' ty) =
  (\tvs -> expandAliasType tvs ty) <$> replicateM (n + n') freshTypeVar

-- | Instantiate an existential type scheme: the first @n@ quantified
-- variables become fresh type variables, the following @n'@ variables
-- become fresh skolems.
skol :: ExistTypeScheme -> TCM Type
skol (ForAllExist n n' ty) =
  instantiate <$> replicateM n freshTypeVar <*> replicateM n' freshSkolem
  where instantiate univ exist = expandAliasType (univ ++ exist) ty

-- | Generalize a type: every type variable of @ty@ that is not a member of
-- @gvs@ is quantified and renamed, in order of first occurrence, to the
-- indices 0, 1, ...
gen :: Set.Set Int -> Type -> TypeScheme
gen gvs ty = ForAll (length quantified) (subst renaming ty)
  where
  quantified = filter (`Set.notMember` gvs) (nub (typeVars ty))
  renaming   = foldr2 bindSubst idSubst quantified (map TypeVariable [0 ..])

-- Auxiliary Functions:
-- The functions 'constrType', 'varType', and 'funType' are used to retrieve
-- the type of constructors, pattern variables, and variables in expressions,
-- respectively, from the type environment. Because the syntactical correctness
-- has already been verified by the syntax checker, none of these functions
-- should fail.

-- Note that 'varType' can handle ambiguous identifiers and returns the first
-- available type. This function is used for looking up the type of an
-- identifier on the left hand side of a rule where it unambiguously refers
-- to the local definition.

-- | Look up the type scheme of a data or newtype constructor. If the name
-- as given does not yield exactly one constructor, the lookup is retried
-- with the name qualified by the module @m@; failure of both lookups is an
-- internal error.
constrType :: ModuleIdent -> QualIdent -> ValueEnv -> ExistTypeScheme
constrType m c tyEnv = case qualLookupValue c tyEnv of
  [DataConstructor  _ _ _ sigma] -> sigma
  [NewtypeConstructor _ _ sigma] -> sigma
  _ -> case qualLookupValue (qualQualify m c) tyEnv of
    [DataConstructor  _ _ _ sigma] -> sigma
    [NewtypeConstructor _ _ sigma] -> sigma
    _ -> internalError $ "TypeCheck.constrType " ++ show c

-- | Arity of the first value bound to @v@ in the environment; an internal
-- error if the first entry is not a value.
varArity :: Ident -> ValueEnv -> Int
varArity v tyEnv
  | Value _ arity _ : _ <- lookupValue v tyEnv = arity
  | otherwise = internalError $ "TypeCheck.varArity " ++ show v

-- | Type scheme of the first value bound to @v@ in the environment; an
-- internal error if the first entry is not a value.
varType :: Ident -> ValueEnv -> TypeScheme
varType v tyEnv
  | Value _ _ scheme : _ <- lookupValue v tyEnv = scheme
  | otherwise = internalError $ "TypeCheck.varType " ++ show v

-- | Total variant of 'varType': yields 'Nothing' instead of an internal
-- error when the first entry for @v@ is not a value.
sureVarType :: Ident -> ValueEnv -> Maybe TypeScheme
sureVarType v tyEnv
  | Value _ _ scheme : _ <- lookupValue v tyEnv = Just scheme
  | otherwise                                   = Nothing

-- | Look up the type scheme of a function, variable or record label used in
-- an expression. If the name as given does not yield exactly one such
-- entry, the lookup is retried with the name qualified by the module @m@;
-- failure of both lookups is an internal error.
funType :: ModuleIdent -> QualIdent -> ValueEnv -> TypeScheme
funType m f tyEnv = case qualLookupValue f tyEnv of
  [Value _ _ sigma] -> sigma
  [Label _ _ sigma] -> sigma
  _                 -> case qualLookupValue (qualQualify m f) tyEnv of
    [Value _ _ sigma] -> sigma
    [Label _ _ sigma] -> sigma
    _                 -> internalError $ "TypeCheck.funType " ++ show f
                          ++ ", more precisely " ++ show (unqualify f)

-- | Look up the type scheme of a record label. If the name as given does
-- not yield exactly one label, the lookup is retried with the name
-- qualified by the module @m@; failure of both lookups is an internal
-- error.
labelType :: ModuleIdent -> QualIdent -> ValueEnv -> TypeScheme
labelType m l tyEnv = case qualLookupValue l tyEnv of
  [Label _ _ sigma] -> sigma
  _ -> case qualLookupValue ql tyEnv of
    [Label _ _ sigma] -> sigma
    _ -> internalError $ "TypeCheck.labelType " ++ show ql
          ++ ", more precisely " ++ show l
  where ql = qualQualify m l

-- The function 'expandType' expands all type synonyms in a type
-- and also qualifies all type constructors with the name of the module
-- in which the type was defined.

-- | Expand a type expression with 'expandMonoType' (with no type variables
-- given in advance), normalize the result and turn it into a type scheme
-- with 'polyType'.
expandPolyType :: TypeExpr -> TCM TypeScheme
expandPolyType ty = (polyType . normalize) <$> expandMonoType [] ty

-- | Monadic wrapper around 'expandMonoType'' that takes the module
-- identifier and the type constructor environment from the checker state.
expandMonoType :: [Ident] -> TypeExpr -> TCM Type
expandMonoType tvs ty = do
  m     <- getModuleIdent
  tcEnv <- getTyConsEnv
  return $ expandMonoType' m tcEnv tvs ty

-- | Convert a type expression to a type with 'toType' and expand it with
-- 'expandType'.
expandMonoType' :: ModuleIdent -> TCEnv -> [Ident] -> TypeExpr -> Type
expandMonoType' m tcEnv tvs = expandType m tcEnv . toType tvs

-- | List variant of 'expandMonoType'': the type expressions are converted
-- together by 'toTypes' and each resulting type is expanded.
expandMonoTypes :: ModuleIdent -> TCEnv -> [Ident] -> [TypeExpr] -> [Type]
expandMonoTypes m tcEnv tvs tys =
  [expandType m tcEnv ty | ty <- toTypes tvs tys]

-- | Expand the type constructors of a type. A constructor is looked up as
-- given and, if that does not yield exactly one entry, once more qualified
-- with the module @m@. Data and renaming types are replaced by the name
-- stored in the environment, alias types by their expanded right-hand side;
-- failure of both lookups is an internal error.
expandType :: ModuleIdent -> TCEnv -> Type -> Type
expandType m tcEnv ty0 = case ty0 of
  TypeConstructor tc tys ->
    let args = map (expandType m tcEnv) tys
        resolve [DataType     tc' _ _] = Just (TypeConstructor tc' args)
        resolve [RenamingType tc' _ _] = Just (TypeConstructor tc' args)
        resolve [AliasType    _  _ ty] = Just (expandAliasType args ty)
        resolve _                      = Nothing
        asGiven   = resolve (qualLookupTC tc tcEnv)
        qualified = resolve (qualLookupTC (qualQualify m tc) tcEnv)
        failure   = internalError $ "TypeCheck.expandType " ++ show tc
    in  fromMaybe (fromMaybe failure qualified) asGiven
  TypeVariable      _ -> ty0
  TypeConstrained _ _ -> ty0
  TypeArrow   ty1 ty2 ->
    TypeArrow (expandType m tcEnv ty1) (expandType m tcEnv ty2)
  TypeSkolem        _ -> ty0

-- The functions 'fvEnv' and 'fsEnv' compute the set of free type variables
-- and free skolems of a type environment, respectively. We ignore the types
-- of data constructors here because we know that they are closed.

-- | Set of the type variables with a negative index that occur in the
-- types returned by 'localTypes'.
fvEnv :: ValueEnv -> Set.Set Int
fvEnv = Set.fromList . filter (< 0) . concatMap typeVars . localTypes

-- | Set of the skolems that occur in the types returned by 'localTypes'.
fsEnv :: ValueEnv -> Set.Set Int
fsEnv tyEnv =
  Set.unions [Set.fromList (typeSkolems ty) | ty <- localTypes tyEnv]

-- | Types of all 'Value' entries among the local bindings of the
-- environment; entries of any other kind are skipped.
localTypes :: ValueEnv -> [Type]
localTypes = concatMap valueType . localBindings
  where
  valueType (_, Value _ _ (ForAll _ ty)) = [ty]
  valueType _                            = []

-- ---------------------------------------------------------------------------
-- Error functions
-- ---------------------------------------------------------------------------

-- | Message for a group of mutually recursive type synonyms. The message is
-- positioned at the first synonym; every further synonym is listed with its
-- own source line. An empty list is an internal error.
errRecursiveTypes :: [Ident] -> Message
errRecursiveTypes synonyms = case synonyms of
  []       -> internalError "TypeCheck.recursiveTypes: empty list"
  [tc]     -> posMessage tc $ hsep $ map text
                ["Recursive synonym type", idName tc]
  tc : tcs -> posMessage tc $
    text "Recursive synonym types" <+> name tc <+> others empty tcs
  where
  name    = text . idName
  site tc = parens (text $ showLine $ idPosition tc)
  -- the separator argument is only used in front of the final "and"
  others _    []          = empty
  others lead [tc]        = lead <+> text "and" <+> name tc <+> site tc
  others _    (tc : rest) = comma <+> name tc <+> site tc
                            <> others comma rest

-- | Message for a free variable whose type is polymorphic, positioned at
-- the variable.
errPolymorphicFreeVar :: Ident -> Message
errPolymorphicFreeVar v = posMessage v (hsep phrase)
  where
  phrase = [text "Free variable", text (idName v), text "has a polymorphic type"]

-- | Message for a type signature that is more general than the inferred
-- type scheme; @what@ describes the annotated entity.
errTypeSigTooGeneral :: Position -> ModuleIdent -> Doc -> TypeExpr -> TypeScheme
                     -> Message
errTypeSigTooGeneral p m what ty sigma =
  posMessage p (vcat [headline, what, inferred, declared])
  where
  headline = text "Type signature too general"
  inferred = text "Inferred type:"  <+> ppTypeScheme m sigma
  declared = text "Type signature:" <+> ppTypeExpr 0 ty

-- | Message for a term whose type is not a function type although the term
-- is applied; @what@ names the construct and @doc@ shows it.
errNonFunctionType :: Position -> String -> Doc -> ModuleIdent -> Type
                   -> Message
errNonFunctionType p what doc m ty = posMessage p $ vcat
  [ text "Type error in" <+> text what, doc
  , text "Type:" <+> ppType m ty
  , text "Cannot be applied"
  ]

-- | Message for an operator whose type does not take two arguments;
-- @what@ names the construct and @doc@ shows it.
errNonBinaryOp :: Position -> String -> Doc -> ModuleIdent -> Type -> Message
errNonBinaryOp p what doc m ty =
  posMessage p (vcat [headline, doc, shownType, verdict])
  where
  headline  = text "Type error in" <+> text what
  shownType = text "Type:" <+> ppType m ty
  verdict   = text "Cannot be used as binary operator"

-- | Message for a failed unification. Note the argument order: @ty1@ is
-- shown as the expected type and @ty2@ as the inferred type; @reason@ is
-- the explanation produced by the unifier.
errTypeMismatch :: Position -> String -> Doc -> ModuleIdent -> Type -> Type
                -> Doc -> Message
errTypeMismatch p what doc m ty1 ty2 reason = posMessage p $ vcat
  [ text "Type error in"  <+> text what, doc
  , text "Inferred type:" <+> ppType m ty2
  , text "Expected type:" <+> ppType m ty1
  , reason
  ]

-- | Message for a type that mentions a skolem outside of its scope;
-- @what@ describes the offending entity.
errSkolemEscapingScope :: Position -> ModuleIdent -> Doc -> Type -> Message
errSkolemEscapingScope p m what ty =
  posMessage p (vcat [headline, what, shownType])
  where
  headline  = text "Existential type escapes out of its scope"
  shownType = text "Type:" <+> ppType m ty

-- | Explanation for a failed occurs check: the type variable @tv@ and the
-- type containing it are described as incompatible.
errRecursiveType :: ModuleIdent -> Int -> Type -> Doc
errRecursiveType m tv = errIncompatibleTypes m (TypeVariable tv)

-- | Explanation stating that two types are incompatible.
errIncompatibleTypes :: ModuleIdent -> Type -> Type -> Doc
errIncompatibleTypes m ty1 ty2 = sep [lhs, rhs, verdict]
  where
  lhs     = text "Types" <+> ppType m ty1
  rhs     = nest 2 (text "and" <+> ppType m ty2)
  verdict = text "are incompatible"

-- | Message stating that the label @l@ is used with two incompatible
-- types.
errIncompatibleLabelTypes :: Position -> ModuleIdent -> Ident -> Type -> Type -> Message
errIncompatibleLabelTypes p m l ty1 ty2 = posMessage p $ sep
  [ text "Labeled types" <+> ppIdent l <+> text "::" <+> ppType m ty1
  , nest 10 $ text "and" <+> ppIdent l <+> text "::" <+> ppType m ty2
  , text "are incompatible"
  ]