Commit 44ec8933 authored by Björn Peemöller 's avatar Björn Peemöller
Browse files

Simplified the lifting phase

parent c072740b
...@@ -23,16 +23,18 @@ module Transformations.Lift (lift) where ...@@ -23,16 +23,18 @@ module Transformations.Lift (lift) where
#if __GLASGOW_HASKELL__ < 710 #if __GLASGOW_HASKELL__ < 710
import Control.Applicative ((<$>), (<*>)) import Control.Applicative ((<$>), (<*>))
#endif #endif
import Control.Arrow (first)
import qualified Control.Monad.State as S (State, runState, gets, modify) import qualified Control.Monad.State as S (State, runState, gets, modify)
import Data.List import Data.List
import qualified Data.Map as Map (Map, empty, insert, lookup) import qualified Data.Map as Map (Map, empty, insert, lookup)
import qualified Data.Set as Set (toList, fromList, unions) import qualified Data.Set as Set (toList, fromList, unions)
import Curry.Base.Ident import Curry.Base.Ident
import Curry.Base.Position (Position)
import Curry.Syntax import Curry.Syntax
import Base.Expr import Base.Expr
import Base.Messages (internalError) import Base.Messages (internalError)
import Base.SCC import Base.SCC
import Base.Types import Base.Types
...@@ -103,7 +105,7 @@ absLhs (FunLhs f ts) = FunLhs f <$> mapM absPat ts ...@@ -103,7 +105,7 @@ absLhs (FunLhs f ts) = FunLhs f <$> mapM absPat ts
absLhs _ = error "Lift.absLhs: no simple LHS" absLhs _ = error "Lift.absLhs: no simple LHS"
absRhs :: String -> [Ident] -> Rhs -> LiftM Rhs absRhs :: String -> [Ident] -> Rhs -> LiftM Rhs
absRhs pre lvs (SimpleRhs p e _) = flip (SimpleRhs p) [] <$> absExpr pre lvs e absRhs pre lvs (SimpleRhs p e _) = simpleRhs p <$> absExpr pre lvs e
absRhs _ _ _ = error "Lift.absRhs: no simple RHS" absRhs _ _ _ = error "Lift.absRhs: no simple RHS"
-- Within a declaration group we have to split the list of declarations -- Within a declaration group we have to split the list of declarations
...@@ -118,13 +120,13 @@ absRhs _ _ _ = error "Lift.absRhs: no simple RHS" ...@@ -118,13 +120,13 @@ absRhs _ _ _ = error "Lift.absRhs: no simple RHS"
-- call each other. -- call each other.
-- --
-- f = g True -- f = g True
-- where x = f 1 -- where x = h 1
-- f z = y + z -- h z = y + z
-- y = g False -- y = g False
-- g z = if z then x else 0 -- g z = if z then x else 0
-- --
-- Because of this fact, f and g can be abstracted separately by adding -- Because of this fact, 'g' and 'h' can be abstracted separately by adding
-- only 'y' to 'f' and 'x' to 'g'. On the other hand, in the following example -- only 'y' to 'h' and 'x' to 'g'. On the other hand, in the following example
-- --
-- f x y = g 4 -- f x y = g 4
-- where g p = h p + x -- where g p = h p + x
...@@ -161,7 +163,6 @@ absDeclGroup pre lvs ds e = do ...@@ -161,7 +163,6 @@ absDeclGroup pre lvs ds e = do
absFunDecls pre (lvs ++ bv vds) (scc bv (qfv m) fds) vds e absFunDecls pre (lvs ++ bv vds) (scc bv (qfv m) fds) vds e
where (fds, vds) = partition isFunDecl ds where (fds, vds) = partition isFunDecl ds
-- TODO: too complicated?
absFunDecls :: String -> [Ident] -> [[Decl]] -> [Decl] -> Expression absFunDecls :: String -> [Ident] -> [[Decl]] -> [Decl] -> Expression
-> LiftM Expression -> LiftM Expression
absFunDecls pre lvs [] vds e = do absFunDecls pre lvs [] vds e = do
...@@ -169,32 +170,40 @@ absFunDecls pre lvs [] vds e = do ...@@ -169,32 +170,40 @@ absFunDecls pre lvs [] vds e = do
e' <- absExpr pre lvs e e' <- absExpr pre lvs e
return (Let vds' e') return (Let vds' e')
absFunDecls pre lvs (fds:fdss) vds e = do absFunDecls pre lvs (fds:fdss) vds e = do
m <- getModuleIdent m <- getModuleIdent
env <- getAbstractEnv env <- getAbstractEnv
let fs = bv fds tyEnv <- getValueEnv
fvs = filter (`elem` lvs) (Set.toList fvsRhs) let -- defined functions
env' = foldr (bindF fvs) env fs fs = bv fds
-- free variables on the right-hand sides
fvsRhs = Set.unions fvsRhs = Set.unions
[ Set.fromList (maybe [v] (qfv m . asFunCall) (Map.lookup v env)) [ Set.fromList (maybe [v] (qfv m . asFunCall) (Map.lookup v env))
| v <- qfv m fds] | v <- qfv m fds]
-- free variables that are local
fvs = filter (`elem` lvs) (Set.toList fvsRhs)
-- extended abstraction environment
env' = foldr (bindF fvs) env fs
bindF fvs' f = Map.insert f (qualifyWith m $ liftIdent pre f, fvs') bindF fvs' f = Map.insert f (qualifyWith m $ liftIdent pre f, fvs')
isLifted tyEnv f = null $ lookupValue f tyEnv -- newly abstracted functions
fs' <- (\tyEnv -> filter (not . isLifted tyEnv) fs) <$> getValueEnv fs' = filter (\f -> not $ null $ lookupValue f tyEnv) fs
-- update environment
modifyValueEnv $ absFunTypes m pre fvs fs' modifyValueEnv $ absFunTypes m pre fvs fs'
(fds', e') <- withLocalAbstractEnv env' $ do withLocalAbstractEnv env' $ do
fds'' <- mapM (absFunDecl pre fvs lvs) -- add variables to functions
[d | d <- fds, any (`elem` fs') (bv d)] fds' <- mapM (absFunDecl pre fvs lvs) [d | d <- fds, any (`elem` fs') (bv d)]
e'' <- absFunDecls pre lvs fdss vds e -- abstract remaining declarations
return (fds'', e'') e' <- absFunDecls pre lvs fdss vds e
return (Let fds' e') return (Let fds' e')
-- Add the additional variables to the types of the functions and rebind
-- the functions in the value environment
absFunTypes :: ModuleIdent -> String -> [Ident] -> [Ident] absFunTypes :: ModuleIdent -> String -> [Ident] -> [Ident]
-> ValueEnv -> ValueEnv -> ValueEnv -> ValueEnv
absFunTypes m pre fvs fs tyEnv = foldr abstractFunType tyEnv fs absFunTypes m pre fvs fs tyEnv = foldr abstractFunType tyEnv fs
where tys = map (varType tyEnv) fvs where tys = map (varType tyEnv) fvs
abstractFunType f tyEnv' = abstractFunType f tyEnv' =
qualBindFun m (liftIdent pre f) qualBindFun m (liftIdent pre f)
(length fvs + varArity tyEnv' f) -- (arrowArity ty) (length fvs + varArity tyEnv' f)
(polyType (normType ty)) (polyType (normType ty))
(unbindFun f tyEnv') (unbindFun f tyEnv')
where ty = foldr TypeArrow (varType tyEnv' f) tys where ty = foldr TypeArrow (varType tyEnv' f) tys
...@@ -242,6 +251,8 @@ absExpr _ _ e = internalError $ "Lift.absExpr: " ++ show e ...@@ -242,6 +251,8 @@ absExpr _ _ e = internalError $ "Lift.absExpr: " ++ show e
absAlt :: String -> [Ident] -> Alt -> LiftM Alt absAlt :: String -> [Ident] -> Alt -> LiftM Alt
absAlt pre lvs (Alt p t rhs) = Alt p t <$> absRhs pre (lvs ++ bv t) rhs absAlt pre lvs (Alt p t rhs) = Alt p t <$> absRhs pre (lvs ++ bv t) rhs
-- TODO: Remove since functional patterns should not be abstracted
absPat :: Pattern -> LiftM Pattern absPat :: Pattern -> LiftM Pattern
absPat v@(VariablePattern _) = return v absPat v@(VariablePattern _) = return v
absPat l@(LiteralPattern _) = return l absPat l@(LiteralPattern _) = return l
...@@ -262,14 +273,14 @@ absPat p = error $ "Lift.absPat: " ++ show p ...@@ -262,14 +273,14 @@ absPat p = error $ "Lift.absPat: " ++ show p
-- to the top-level. -- to the top-level.
liftFunDecl :: Decl -> [Decl] liftFunDecl :: Decl -> [Decl]
liftFunDecl (FunctionDecl p f eqs) = (FunctionDecl p f eqs' : concat dss') liftFunDecl (FunctionDecl p f eqs) = FunctionDecl p f eqs' : concat dss'
where (eqs', dss') = unzip $ map liftEquation eqs where (eqs', dss') = unzip $ map liftEquation eqs
liftFunDecl d = [d] liftFunDecl d = [d]
liftVarDecl :: Decl -> (Decl, [Decl]) liftVarDecl :: Decl -> (Decl, [Decl])
liftVarDecl (PatternDecl p t rhs) = (PatternDecl p t rhs', ds') liftVarDecl (PatternDecl p t rhs) = (PatternDecl p t rhs', ds')
where (rhs', ds') = liftRhs rhs where (rhs', ds') = liftRhs rhs
liftVarDecl ex@(FreeDecl _ _) = (ex, []) liftVarDecl ex@(FreeDecl _ _) = (ex, [])
liftVarDecl _ = error "Lift.liftVarDecl: no pattern match" liftVarDecl _ = error "Lift.liftVarDecl: no pattern match"
liftEquation :: Equation -> (Equation, [Decl]) liftEquation :: Equation -> (Equation, [Decl])
...@@ -277,12 +288,11 @@ liftEquation (Equation p lhs rhs) = (Equation p lhs rhs', ds') ...@@ -277,12 +288,11 @@ liftEquation (Equation p lhs rhs) = (Equation p lhs rhs', ds')
where (rhs', ds') = liftRhs rhs where (rhs', ds') = liftRhs rhs
liftRhs :: Rhs -> (Rhs, [Decl]) liftRhs :: Rhs -> (Rhs, [Decl])
liftRhs (SimpleRhs p e _) = (SimpleRhs p e' [], ds') liftRhs (SimpleRhs p e _) = first (simpleRhs p) (liftExpr e)
where (e', ds') = liftExpr e liftRhs _ = error "Lift.liftRhs: no pattern match"
liftRhs _ = error "Lift.liftRhs: no pattern match"
liftDeclGroup :: [Decl] -> ([Decl],[Decl]) liftDeclGroup :: [Decl] -> ([Decl],[Decl])
liftDeclGroup ds = (vds', concat $ map liftFunDecl fds ++ dss') liftDeclGroup ds = (vds', concat (map liftFunDecl fds ++ dss'))
where (fds , vds ) = partition isFunDecl ds where (fds , vds ) = partition isFunDecl ds
(vds', dss') = unzip $ map liftVarDecl vds (vds', dss') = unzip $ map liftVarDecl vds
...@@ -290,13 +300,12 @@ liftExpr :: Expression -> (Expression, [Decl]) ...@@ -290,13 +300,12 @@ liftExpr :: Expression -> (Expression, [Decl])
liftExpr l@(Literal _) = (l, []) liftExpr l@(Literal _) = (l, [])
liftExpr v@(Variable _) = (v, []) liftExpr v@(Variable _) = (v, [])
liftExpr c@(Constructor _) = (c, []) liftExpr c@(Constructor _) = (c, [])
liftExpr (Apply e1 e2) = (Apply e1' e2', ds' ++ ds'') liftExpr (Apply e1 e2) = (Apply e1' e2', ds1 ++ ds2)
where (e1', ds' ) = liftExpr e1 where (e1', ds1) = liftExpr e1
(e2', ds'') = liftExpr e2 (e2', ds2) = liftExpr e2
liftExpr (Let ds e) = (mkLet ds' e', ds'' ++ ds''') liftExpr (Let ds e) = (mkLet ds' e', ds1 ++ ds2)
where (ds', ds'' ) = liftDeclGroup ds where (ds', ds1) = liftDeclGroup ds
(e' , ds''') = liftExpr e (e' , ds2) = liftExpr e
mkLet ds1 e1 = if null ds1 then e1 else Let ds1 e1
liftExpr (Case r ct e alts) = (Case r ct e' alts', concat $ ds' : dss') liftExpr (Case r ct e alts) = (Case r ct e' alts', concat $ ds' : dss')
where (e' ,ds' ) = liftExpr e where (e' ,ds' ) = liftExpr e
(alts',dss') = unzip $ map liftAlt alts (alts',dss') = unzip $ map liftAlt alts
...@@ -321,9 +330,15 @@ asFunCall (f, vs) = apply (Variable f) (map mkVar vs) ...@@ -321,9 +330,15 @@ asFunCall (f, vs) = apply (Variable f) (map mkVar vs)
mkVar :: Ident -> Expression mkVar :: Ident -> Expression
mkVar v = Variable $ qualify v mkVar v = Variable $ qualify v
mkLet :: [Decl] -> Expression -> Expression
mkLet ds e = if null ds then e else Let ds e
apply :: Expression -> [Expression] -> Expression apply :: Expression -> [Expression] -> Expression
apply = foldl Apply apply = foldl Apply
simpleRhs :: Position -> Expression -> Rhs
simpleRhs p e = SimpleRhs p e []
varArity :: ValueEnv -> Ident -> Int varArity :: ValueEnv -> Ident -> Int
varArity tyEnv v = case lookupValue v tyEnv of varArity tyEnv v = case lookupValue v tyEnv of
[Value _ a _] -> a [Value _ a _] -> a
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment