{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TupleSections #-}
module Z3.Transformation.Transform where

import qualified Data.Map as Map
import qualified Data.List as List
import qualified Data.Set as Set
import Control.Monad.State
import Z3.Monad (AST)
import qualified Z3.Monad as Z3
import Z3.Transformation.Monad
import Z3.Transformation.DataTypes as DataTypes
import Z3.Transformation.TypeTransform as TypeTransform
import Z3.Transformation.Records as Records
import Verify.AST hiding (removeAliases,args)
import qualified Verify.AST as AST

-- | Initialise the translation state before any expression is translated.
--
-- Registers the pre- and postconditions and the known type definitions.
-- Eagerly defines the Bool and Double sorts, stores the raw (not yet
-- declared) function definitions and finally calls 'addAppendFunction'.
transformSetup :: [Function] -> [Function]
               -> [Function] -> [Type]-> SMTMonad ()
transformSetup pres posts funs tipes = do
  mapM_ (uncurry addPrecondition ) pres
  mapM_ (uncurry addPostcondition) posts
  setTypes $ Map.fromList tipes
  -- These sorts might be used implicitly, and they must already be known
  -- for the reverse lookup from Z3 sorts to sort ids.
  _ <- defineOrGetType DataTypes.BoolId
  _ <- defineOrGetType DataTypes.DoubleId
  setRawFunctions $ Map.fromList funs
  addAppendFunction

-- | Find or create the Z3 function for identifier @fid@ at sort @sort@.
--
-- Resolution order:
--
--   1. an SMT function already declared for @(sort, fid)@ ('Left');
--   2. a raw function definition for @fid@, which is declared now ('Left');
--   3. a variable named like the function. A fresh uninterpreted Z3 function
--      with parameter sorts @params@ and result sort @sort@ is created for it
--      ('Right').
--
-- Calls 'error' when none of these exists.
getOrDeclareFunction :: FunctionIdentifier -> DataTypes.SortId -> [DataTypes.SortId]
                     -> SMTMonad (Either SMTFunction Z3.FuncDecl)
getOrDeclareFunction fid@(_,name) sort params = do
  declared <- getSMTFunction (sort,fid)
  case declared of
    Just smtFun -> return (Left smtFun)
    Nothing     -> getRawFunction fid >>= maybe fromVariable (fmap Left . declareFunction fid sort)
  where
    -- The name is bound as a variable (e.g. a function-typed argument):
    -- model it as a fresh, uniquely named uninterpreted function.
    fromVariable = do
      bound <- getVariable name
      case bound of
        Nothing -> error $ "Encountered Undefined Function " ++ show fid
        Just _  -> do
          freshId    <- freshVarId
          symbol     <- Z3.mkStringSymbol ("HO#" ++ name ++ "#" ++ show freshId)
          paramSorts <- mapM defineOrGetType params
          resultSort <- defineOrGetType sort
          Right <$> Z3.mkFuncDecl symbol paramSorts resultSort

-- | Split a curried lambda sort into its argument sorts and its final result
-- sort. A sort that is not a lambda has no arguments and is its own result.
splitArgsResult :: DataTypes.SortId -> ([DataTypes.SortId],DataTypes.SortId)
splitArgsResult (LambdaId from to) =
  let (rest, result) = splitArgsResult to
  in  (from : rest, result)
splitArgsResult other = ([], other)

-- | Declare the Z3 function for source function @(m, n)@ at the (possibly
-- curried lambda) sort @sortId@. The declaration is registered together with
-- its definition @f@ via 'addFunction'.
--
-- The Z3 symbol has the form @m.n#k@ with a fresh id @k@. This keeps symbols
-- unique when the same function is declared at several sorts.
declareFunction :: FunctionIdentifier -> SortId -> FunctionDefinition -> SMTMonad SMTFunction
declareFunction (m,n) sortId f = do
  let (argIds,resultId) = splitArgsResult sortId
  whenDebug $ liftIO $ putStrLn $ "declaring function '" ++ n ++ "'"
  -- NOTE(review): confirm a fresh variable id is the intended index source.
  findex   <- freshVarId
  sym      <- Z3.mkStringSymbol (m ++ "." ++ n ++ "#" ++ show findex)
  argSorts <- mapM defineOrGetType argIds
  resSort  <- defineOrGetType resultId
  whenDebug $ liftIO $ do
    putStrLn "arg sorts:"
    print argSorts
    putStrLn   "result sort:"
    print resSort
  fun      <- Z3.mkFuncDecl sym argSorts resSort
  addFunction (sortId,(m,n)) fun f

-- | Assert the contract of the call @name args@.
--
-- The assertion states that the conjunction of the preconditions (applied to
-- @args@) implies the conjunction of the postconditions. The postconditions
-- are applied to @args@ followed by the call's result.
declareFunctionConditions :: FunctionIdentifier -> DataTypes.SortId -> FunctionDefinition -> [AST] -> SMTMonad ()
declareFunctionConditions name sortId (FunDefWithConds sig _ _ pres posts) args = do
  params <- mapM (getSortIdFromType . snd) sig
  -- Only the Z3 declaration is needed, whichever way it was obtained.
  fun <- either fst id <$> getOrDeclareFunction name sortId params
  fr <- app fun args
  -- Preconditions see the arguments only.
  presResults <- mapM (makeCond args . snd) pres
  -- Postconditions additionally see the result as last argument.
  postsResults <- mapM (makeCond (args ++ [fr]) . snd) posts
  presConj  <- Z3.mkAnd presResults
  postsConj <- Z3.mkAnd postsResults
  if List.null posts
    then whenDebugIO $ putStrLn $ "Function " ++ show name ++ " has no postconditions (including invariants)"
    else whenDebugIO $ putStrLn $
           "Function " ++ show name ++
           " has " ++ show (List.length presResults) ++
           " preconditions and " ++ show (List.length postsResults) ++
           " postconditions"
  -- Assert that conditions hold
  Z3.assert =<< Z3.mkImplies presConj postsConj
-- | Translate a pre- or postcondition into a Z3 term for the given argument
-- terms.
--
-- The condition's parameter patterns are bound to @params@. Inside a fresh
-- 'typeScope', the condition's type (parameters -> result) is merged with the
-- sorts of the actual arguments before the condition body is translated.
makeCond :: [AST] -> FunctionDefinition -> SMTMonad AST
makeCond params (FunDefWithConds ps pb rs _ _) = do
  let (patterns, paramTypes) = unzip ps
      condType               = foldr FunId rs paramTypes
  bindings <- concat <$> zipWithM bindArgToPattern patterns params
  typeScope $ do
    initNewTypeScope condType =<< mapM astToSortId params
    varsWithScope2 bindings (transformExpr pb)
-- | Translate an expression of the verification AST into a Z3 term.
--
-- Unsupported expression forms, undefined variables or functions, unknown
-- constructors and missing record fields are reported with 'error'.
transformExpr :: Expr -> SMTMonad AST
transformExpr (Call vtype what typ params) = handleCall vtype what typ params
transformExpr (Constant (VBool   b)) = Z3.mkBool b
transformExpr (Constant (VInt    n)) = Z3.mkInt n =<< defineOrGetType DataTypes.IntId
transformExpr (Constant (VDouble r)) = Z3.mkRealNum r
transformExpr (Constant VUnit      ) = do
  -- Each unit literal becomes a fresh constant of the unit sort.
  vid <- freshVarId
  sym <- Z3.mkIntSymbol vid
  usort <- defineOrGetType DataTypes.UnitId
  Z3.mkConst sym usort
transformExpr (Var (Local _ ) name _) =
  fromMaybe (error $ "Undefined Variable: " ++ name) <$> getVariable name
transformExpr (Var (Ctor _) name sort) = do
  let resSort = getResultSort sort
  sort' <- getSortFromType resSort -- doesn't handle unknown type
  let transC ctor = do
       symbol <- Z3.getDeclName ctor
       newname <- Z3.getSymbolString symbol
       return (newname,ctor)
  ctors <- mapM transC =<< Z3.getDatatypeSortConstructors sort'
  -- Z3 constructors are found by the name prefix "<name>#".
  let ctor = snd $ fromMaybe (error "Failed to find Constructor")
                 $ List.find (List.isPrefixOf (name++"#") . fst) ctors
  -- NOTE(review): a constructor used as a plain value is applied to no
  -- arguments, i.e. it is assumed to be nullary — confirm.
  app ctor []
transformExpr (Var ty name _) = do
  let mname = fromMaybe (error "Var was local") $ getVarTypeModule ty -- cannot be local
  let idt = (mname, name)
  FunDefWithConds sig e _ _ _ <-
    fromMaybe (error $ "Undefined function" ++ show idt) <$> getFunctionDef idt
  if null sig
    then transformExpr e
    else Z3.mkInt (length sig) =<< defineOrGetType DataTypes.IntId -- return arity instead
transformExpr (Tuple a b mc) = do
  ts <- mapM transformExpr $ case mc of
    Nothing -> [a,b]
    Just c  -> [a,b,c]
  -- The tuple sort is determined by the sorts of its components.
  tsid <- mapM astToSortId ts
  let sort = DataTypes.TupleId tsid
  sort' <- defineOrGetType sort
  ctor <- head <$> Z3.getDatatypeSortConstructors sort'
  whenDebug $ liftIO $ putStrLn $ "Creating Tuple " ++ show tsid
  app ctor ts
transformExpr (Record fieldMap) = do
  valsMap <- mapM transformExpr fieldMap
  sort <- defineOrGetType =<< getRecordSortId valsMap
  createRecord sort valsMap
transformExpr (String value) = createString DataTypes.StringId value
transformExpr (Char value) = createString DataTypes.CharId value
transformExpr (List exprs) = do
  exprs' <- mapM transformExpr exprs
  case exprs' of
    -- Without an element there is nothing to derive the element sort from.
    [] -> error "Z3.Transformation.Transform.transformExpr: cannot infer the element sort of an empty list literal"
    (firstElem:_) -> do
      elemSort <- astToSortId firstElem
      constructList (DataTypes.ListId elemSort) exprs'
transformExpr (Access expr name) = do
  record <- transformExpr expr
  fieldMap <- getRecordFieldMap record
  return $ fromMaybe (error $ "Z3.Transformation.Transform.transformExpr: Access of non-existing record field: " ++ show name)
         $ Map.lookup name fieldMap
transformExpr (Update expr fieldMap) = do
  record <- transformExpr expr
  sort <- Z3.getSort record
  baseMap <- getRecordFieldMap record
  updateMap <- mapM transformExpr fieldMap
  createRecord sort $ updateRecordFieldMap baseMap updateMap
transformExpr (Case expr (firstBranch:otherBranches)) = do
  -- Translate the scrutinee once and test every branch against that term.
  scrutinee <- transformExpr expr
  let translateBranches (pat, body) rest = do
        (predicate, vars) <- evaluateCaseBranch pat scrutinee
        let branchExpression = varsWithScope2 vars $ transformExpr body
        case rest of
          []          -> branch predicate branchExpression
          (next:more) -> ite predicate branchExpression $ translateBranches next more
  translateBranches firstBranch otherBranches
transformExpr v@(Let lettype expr) =
  case lettype of
    Single f -> transformSingleLet f expr
    Destruct pat patExpr -> do
      vars <- bindArgToPattern pat =<< transformExpr patExpr
      varsWithScope2 vars $ transformExpr expr
    _        -> error $ "transformExpr doesn't handle " ++ show v
transformExpr v = error $ "transformExpr doesn't handle " ++ show v

-- | Translate @let fname = fexpr in expr@.
--
-- The bound expression is translated once. The result is made available as
-- variable @fname@ while @expr@ is translated.
--
-- NOTE(review): the definition's signature (@_fsig@) and conditions are
-- ignored, so a let-bound function with parameters is treated like a
-- constant — confirm this is intended.
transformSingleLet :: Function -> Expr -> SMTMonad AST
transformSingleLet ((_fmdl, fname),FunDefWithConds _fsig fexpr _ftipe _fpres _fposts) expr = do
  boundValue <- transformExpr fexpr
  varsWithScope2 [(fname, boundValue)] $ transformExpr expr

-- | Match the term @v@ against a case pattern.
--
-- Returns a Z3 boolean that holds iff @v@ matches the pattern, together with
-- the (variable name, term) bindings the pattern introduces. The bindings are
-- built with datatype accessors and are only meaningful where the returned
-- predicate holds.
evaluateCaseBranch :: Pattern -> AST -> SMTMonad (AST,[(String,AST)])
evaluateCaseBranch (PVariable n     ) v = (,[(n,v)]) <$> Z3.mkBool True
evaluateCaseBranch  PDefault          _ = (,[]) <$> Z3.mkBool True
evaluateCaseBranch  PUnit             _ = (,[]) <$> Z3.mkBool True
evaluateCaseBranch  (PConstant value) v = do
    constant <- transformExpr $ Constant value
    (,[]) <$> mkEq constant v
evaluateCaseBranch (PRecord fields)   v = do
    -- Record patterns always match; they only bind the named fields.
    rFields <- Map.toList . flip Map.restrictKeys (Set.fromList fields) <$> getRecordFieldMap v
    (,rFields) <$> Z3.mkBool True
evaluateCaseBranch  (PAs p n)         v = do
    (takes,vars) <- evaluateCaseBranch p v
    return (takes,(n,v):vars)
evaluateCaseBranch  (PTuple p1 p2 p3) v = do
    sort <- Z3.getSort v
    -- Accessors of the (first) tuple constructor, one per component.
    accs <- head <$> Z3.getDatatypeSortConstructorAccessors sort
    let pats = p1 : p2 : maybe [] pure p3
    (takes, varss) <- unzip <$> zipWithM (\p acc -> evaluateCaseBranch p =<< app acc [v]) pats accs
    (, concat varss) <$> Z3.mkAnd takes
evaluateCaseBranch  (PCtor _ cidx subPats) v = do
    sort <- Z3.getSort v
    recog <- (!! cidx) <$> Z3.getDatatypeSortRecognizers sort
    acces <- (!! cidx) <$> Z3.getDatatypeSortConstructorAccessors sort
    isCtor <- app recog [v]
    (takes,varss) <- unzip <$> mapM (\(subPat,idx,_) -> evaluateCaseBranch subPat =<< app (acces !! idx) [v]) subPats
    (, concat varss) <$> Z3.mkAnd (isCtor:takes)
evaluateCaseBranch  (PCons   h t) v = do
    sort <- Z3.getSort v
    -- Index 1 is the cons constructor and index 0 the empty list, matching
    -- the constructor order used by 'constructList'.
    recog <- (!! 1) <$> Z3.getDatatypeSortRecognizers sort
    isCons <- app recog [v]
    acces <- (!! 1) <$> Z3.getDatatypeSortConstructorAccessors sort
    (hTake,hVars) <- evaluateCaseBranch h =<< app (head acces) [v]
    (tTake,tVars) <- evaluateCaseBranch t =<< app (acces !! 1) [v]
    (, hVars ++ tVars) <$> Z3.mkAnd [isCons, hTake, tTake]
evaluateCaseBranch  (PList     l) v = case l of
        [] -> do
            sort <- Z3.getSort v
            recogs <- Z3.getDatatypeSortRecognizers sort
            (,[]) <$> app (head recogs) [v]
        (h:t) -> evaluateCaseBranch (PCons h $ PList t) v
evaluateCaseBranch  (PString   s) v = do
    sv <- createString DataTypes.StringId s
    (,[]) <$> mkEq sv v
evaluateCaseBranch  (PChar     c) v = do
    sv <- createString DataTypes.CharId c
    (,[]) <$> mkEq sv v

-- | Recover the sort identifier of a Z3 term from its Z3 sort.
--
-- A failed reverse lookup is deferred: the returned sort id calls 'error'
-- only when it is forced.
astToSortId :: AST -> SMTMonad DataTypes.SortId
astToSortId ast = do
  sort <- Z3.getSort ast
  res <- reverseLookupSort sort
  case res of
    Just sortId -> return sortId
    Nothing -> do
      sortStr <- Z3.sortToString sort
      return $ error $ "Failed on reverse lookup on " ++ sortStr
-- | Build @if predicate then e1 else e2@.
--
-- Each branch action is run via 'branch' under its guard: the predicate for
-- @e1@ and its negation for @e2@. If 'useDouble' returns True for the two
-- results, both are converted with 'makeDouble' before the if-then-else term
-- is created.
ite :: AST -> SMTMonad AST -> SMTMonad AST -> SMTMonad AST
ite predicate e1 e2 = do
  negPredicate <- Z3.mkNot predicate
  thenAst      <- branch predicate e1
  elseAst      <- branch negPredicate e2
  needsDouble  <- useDouble thenAst elseAst
  (thenAst', elseAst') <-
    if not needsDouble
      then return (thenAst, elseAst)
      else (,) <$> makeDouble thenAst <*> makeDouble elseAst
  Z3.mkIte predicate thenAst' elseAst'

-- | Translate the call @what params@ of the given call kind and type.
--
-- Built-in operators and selected foreign functions (Basics, List, BigInt)
-- map directly onto Z3 operations. Constructor calls apply the matching Z3
-- datatype constructor. Any other call is resolved with
-- 'getOrDeclareFunction':
--
--   * a function-typed variable is applied as an uninterpreted function;
--   * otherwise the callee's contract is asserted, and the call is either
--     inlined via 'transformFunc' or, when the callee is already being
--     inlined (recursion), applied as an uninterpreted function with
--     additional equations for extensible records.
handleCall :: VarType -> String -> AST.Sort -> [Expr] -> SMTMonad AST
handleCall _ "ite" _ [p, e1, e2] = do
  predicate <- transformExpr p
  ite predicate (transformExpr e1) $ transformExpr e2
handleCall (Operator "Basics" "or") _ _ [e1, e2] =
  liftL2 Z3.mkOr (transformExpr e1) (transformExpr e2)
handleCall (Operator "Basics" "and") _ _ [e1, e2] =
  liftL2 Z3.mkAnd (transformExpr e1) (transformExpr e2)
handleCall (Operator "Basics" "neq") _ _ [e1, e2] =
  Z3.mkNot =<< lift2 mkEq (transformExpr e1) (transformExpr e2)
handleCall (Operator "Basics" "eq") _ _ [e1, e2] =
  lift2 mkEq (transformExpr e1) (transformExpr e2)
handleCall (Operator "Basics" "ge") _ _ [e1, e2] =
  lift2 mkGenGe (transformExpr e1) (transformExpr e2)
handleCall (Operator "Basics" "gt") _ _ [e1, e2] =
  lift2 mkGenGt (transformExpr e1) (transformExpr e2)
handleCall (Operator "Basics" "le") _ _ [e1, e2] =
  lift2 mkGenLe (transformExpr e1) (transformExpr e2)
handleCall (Operator "Basics" "lt") _ _ [e1, e2] =
  lift2 mkGenLt (transformExpr e1) (transformExpr e2)
handleCall (Operator "Basics" "add") _ _ [e1, e2] =
  lift2 mkGenadd (transformExpr e1) (transformExpr e2)
handleCall (Operator "Basics" "sub") _ _ [e1, e2] =
  lift2 mkGensub (transformExpr e1) (transformExpr e2)
handleCall (Operator "Basics" "mul") _ _ [e1, e2] =
  lift2 mkGenmul (transformExpr e1) (transformExpr e2)
handleCall (Operator "Basics" "fdiv") _ _ [e1, e2] =
  lift2 Z3.mkDiv (transformExpr e1) (transformExpr e2)
handleCall (Operator "Basics" "idiv") _ _ [e1, e2] =
  lift2 Z3.mkBvsdiv (transformExpr e1) (transformExpr e2)
handleCall (Operator "List" "cons") _ _ [e1, e2] =
  lift2 mkCons (transformExpr e1) (transformExpr e2)
handleCall (Operator "Basics" "append") a b [e1, e2] =
  mkGenappend a b e1 e2
handleCall (Foreign "Basics") "not" _ [e1] =
  Z3.mkNot =<< transformExpr e1
handleCall (TopLevel "Basics") "not" _ [e1] =
  Z3.mkNot =<< transformExpr e1
-- Elm's @remainderBy divisor x@ and @modBy modulus x@ take the divisor
-- first, while Z3's bvsrem/bvsmod take the dividend first. The arguments are
-- translated in call order and then swapped.
handleCall (Foreign "Basics") "remainderBy" _ [e1, e2] =
  lift2 (flip Z3.mkBvsrem) (transformExpr e1) (transformExpr e2)
handleCall (Foreign "Basics") "modBy" _ [e1, e2] =
  lift2 (flip Z3.mkBvsmod) (transformExpr e1) (transformExpr e2)
handleCall (Foreign "BigInt") "fromInt" _ [e1] =
  flip Z3.mkBv2int True =<< transformExpr e1
handleCall (Foreign "BigInt") "add" _ [e1,e2] =
  liftL2 Z3.mkAdd (transformExpr e1) (transformExpr e2)
handleCall (Foreign "BigInt") "sub" _ [e1,e2] =
  liftL2 Z3.mkSub (transformExpr e1) (transformExpr e2)
handleCall (Foreign "BigInt") "mul" _ [e1,e2] =
  liftL2 Z3.mkMul (transformExpr e1) (transformExpr e2)
handleCall (Foreign "BigInt") "div" _ [e1,e2] =
  -- NOTE(review): Z3 integer division keeps the remainder non-negative and
  -- leaves division by zero unspecified; confirm this matches BigInt.div for
  -- negative operands and a zero divisor.
  lift2 Z3.mkDiv (transformExpr e1) (transformExpr e2)
handleCall (Foreign "BigInt") "gt" _ [e1,e2] =
  lift2 Z3.mkGt (transformExpr e1) (transformExpr e2)
handleCall (Foreign "BigInt") "gte" _ [e1,e2] =
  lift2 Z3.mkGe (transformExpr e1) (transformExpr e2)
handleCall (Foreign "BigInt") "lt" _ [e1,e2] =
  lift2 Z3.mkLt (transformExpr e1) (transformExpr e2)
handleCall (Foreign "BigInt") "lte" _ [e1,e2] =
  lift2 Z3.mkLe (transformExpr e1) (transformExpr e2)
handleCall (Foreign "BigInt") "eq" _ [e1,e2] =
  lift2 mkEq (transformExpr e1) (transformExpr e2)
handleCall (Ctor _) ctorName sort params = do
  let resSort = getResultSort sort
  sort' <- getSortFromType resSort -- doesn't handle unknown type
  let transC ctor = do
       symbol <- Z3.getDeclName ctor
       name <- Z3.getSymbolString symbol
       return (name,ctor)
  ctors <- mapM transC =<< Z3.getDatatypeSortConstructors sort'
  -- Z3 constructors are found by the name prefix "<ctorName>#".
  let ctor = snd $ fromMaybe (error "Failed to find Constructor") $ List.find (List.isPrefixOf (ctorName++"#") . fst) ctors
  params' <- mapM transformExpr params
  app ctor params'
handleCall from name tipe es = do
  let mdl = fromMaybe "" $ getVarTypeModule from
  params <- mapM transformExpr es -- needs to happen in old type scope
  paramSort <- mapM astToSortId params
  typeScope $ do
      initNewTypeScope tipe paramSort
      sortId <- getSortIdFromType tipe
      whenDebugIO $ putStrLn $ "Parameter: "       ++ show es
      whenDebugIO $ putStrLn $ "Parameter sorts: " ++ show paramSort
      whenDebugIO $ putStrLn $ "Function tipe: "   ++ show tipe
      func <- getOrDeclareFunction (mdl,name) sortId paramSort
      case func of
        Right smtdecl ->
          app smtdecl params
        Left (smtdecl, fundecl@(FunDefWithConds sig _ res _ _)) -> do
          whenDebugIO $ putStrLn $ "Function Signature: " ++ show sig
          declareFunctionConditions (mdl,name) sortId fundecl params
          -- TODO check preconditions
          -- this should look somewhat like:
          {- (push)
          - (assert (and (pathconditions)))
          - (assert (not (and (preconditions applied to args))))
          - (check-sat) and do something with the result
          - (pop)
          -}
          isVisited <- isVisitedFunction (mdl, name)
          if isVisited
            then do
              -- Already being inlined (recursion): do not inline again,
              -- apply the uninterpreted function instead.
              resVal <- app smtdecl params
              -- Make Equations with extensible Records and TVars in the signature to enhance expressiveness
              (resultERs', _)                <- getExtensibleRecords res resVal
              (paramsERs', problematicTVars') <- unzip <$> zipWithM getExtensibleRecords (map snd sig) params
              let -- Remove occurrences where we might not have all related occurences in the params
                  problematicTVars = Set.unions problematicTVars'
                  resultERs = removeProblematicOccurrences problematicTVars resultERs'
                  paramsERs = removeProblematicOccurrences problematicTVars $ foldr combineOccurrences [] paramsERs'
              mapM_ (\(s, os) -> mapM_ (\o -> assertOneOfOccurrences s o paramsERs) os) resultERs
              return resVal
            else transformFunc sortId (mdl, name) params fundecl

-- | Inline a call by translating the callee's body.
--
-- The parameter patterns are bound to the argument terms. While the body is
-- translated, @name@ is the current function and is marked as visited, so a
-- recursive call is not inlined again. Afterwards the previous current
-- function is restored. The sort argument is unused.
transformFunc :: SortId -> FunctionIdentifier -> [AST] -> FunctionDefinition
              -> SMTMonad AST
transformFunc _ name params (FunDefWithConds sig body _ _ _) = do
  bindings <- concat <$> zipWithM bindArgToPattern (map fst sig) params
  callerFun <- getCurrentFunction
  setCurrentFunction name
  addVisitedFunction name
  res <- varsWithScope2 bindings $ transformExpr body
  removeVisitedFunction name
  setCurrentFunction callerFun
  return res
-- | Strip every argument type off a (curried) function type and return its
-- final result type. Non-function types are returned unchanged.
getResultSort :: AST.Sort -> AST.Sort
getResultSort (FunId _ to) = getResultSort to
getResultSort other        = other

-- | Merge the declared type of a call with the sorts of its actual arguments,
-- binding type variables via 'initMerge'.
--
-- Arguments are consumed one arrow at a time (with aliases removed from each
-- parameter type). If arguments remain but the declared type is not a
-- function type, the remaining sorts are right-folded with
-- 'DataTypes.LambdaId' and merged with that type as a whole.
initNewTypeScope :: AST.Sort -> [DataTypes.SortId] -> SMTMonad ()
initNewTypeScope callType paramSorts = case (callType, paramSorts) of
  (_, []) ->
    return ()
  (AST.FunId from to, x:xs) -> do
    initMerge (AST.removeAliases from) x
    initNewTypeScope to xs
  _ ->
    initMerge callType $ foldr1 DataTypes.LambdaId paramSorts
-- | Walk a source type and a sort id in parallel and bind each type variable
-- of the source type to the corresponding sort.
--
-- For extensible records, the record variable is bound to the whole record
-- sort and the fields present on both sides are merged. Aliases are stripped
-- first. Shapes that do not line up are silently ignored, except that merging
-- a pair type with a triple sort (or vice versa) is an error.
initMerge :: AST.Sort -> DataTypes.SortId -> SMTMonad ()
initMerge (AST.VarId name) smth = bindTypeVariable name smth
initMerge (AST.FunId fl tl) (DataTypes.LambdaId fr tr) = initMerge fl fr >> initMerge tl tr
initMerge (AST.ListId lInner) (DataTypes.ListId rInner) = initMerge lInner rInner
initMerge (AST.TupleId al bl mcl) (DataTypes.TupleId (ar: br: mcr)) = do
  initMerge al ar
  initMerge bl br
  case (mcl, mcr) of
    (Nothing, []   ) -> return ()
    (Just cl, cr:_ ) -> initMerge cl cr
    _                -> error "Trying to merge Pair with Triple!"
initMerge (AST.RecordId fieldsl mv) r@(DataTypes.RecordId fieldsr) = do
  mapM_ (`bindTypeVariable` r) mv
  mapM_ (uncurry initMerge) $ Map.intersectionWith (,) fieldsl fieldsr
initMerge AST.UnionId{AST.args = argsl} (DataTypes.TypeId _ _ argsr) = zipWithM_ initMerge argsl argsr
initMerge a@AST.AliasId{} r = initMerge (stripAliases a) r
initMerge _ _ = return ()
-- | Build the curried lambda sort id @arg1 -> ... -> argN -> res@ from a
-- signature's parameter types and a result type. Parameters are resolved
-- first, then the result type.
combineToFunctionType' :: Signature -> AST.Sort -> SMTMonad DataTypes.SortId
combineToFunctionType' sig res =
  flip (foldr LambdaId) <$> mapM (getSortIdFromType . snd) sig
                        <*> getSortIdFromType res
-- | Variant of 'pickSort' whose second sort may be missing. For 'Left' (its
-- message is ignored), the first sort is returned unchanged.
pickSort' :: AST.Sort -> Either String AST.Sort -> AST.Sort
pickSort' a = either (const a) (pickSort a)
-- | Run two actions in order and pass both results, as a two-element list,
-- to a list-consuming monadic function (e.g. Z3's n-ary @mkAnd@).
liftL2 :: Monad m => ([AST] -> m a) -> m AST -> m AST -> m a
liftL2 f am bm = f =<< sequence [am, bm]

-- | Run two actions in order and feed their results to a binary monadic
-- function.
lift2 :: Monad m => (a -> b -> m c) -> m a -> m b -> m c
lift2 f am bm = am >>= \a -> bm >>= f a

-- | Run three actions in order and feed their results to a ternary monadic
-- function.
lift3 :: Monad m => (a -> b -> c -> m d) -> m a -> m b -> m c -> m d
lift3 f am bm cm = am >>= \a -> bm >>= \b -> cm >>= f a b
{-
replaceSnds :: [(a, b)] -> [c]-> [(a, c)]
replaceSnds = zipWith (\(a,_) c -> (a, c))
-}
{-
rebindVariables :: [AST] -> [String] -> SMTMonad [(String, Z3.Sort)]
rebindVariables []            []             = return []
rebindVariables (oldVar : xs) (newName : ys) = do
  v <- fromMaybe (error "Failed on Variable lookup") <$> getVariable newName
  Z3.assert =<< mkEq oldVar v
  rebindVariables xs ys
rebindVariables _ _ = return []
-}
{-
Disabled helpers, kept for reference only (the whole region is a block
comment, closed by the final "-}").

-- Creates a tuple with variable name and variable sort for an element in a signature
-- Ignores deconstruction via pattern
getVarFromSig :: (Pattern, SortId) -> Int -> SMTMonad (String, Z3.Sort)
getVarFromSig (PVariable n, ty) _ = (,) n <$> defineOrGetType ty
getVarFromSig (_          , ty) i = (,) n <$> defineOrGetType ty
  where n = "#" ++ show i
-- Creates a list of tuples with variable name and variable sort for an element in a signature
-- Performs deconstruction via pattern
getVarsFromSig :: (Pattern, SortId) -> Int -> SMTMonad [(String, Z3.Sort)]
getVarsFromSig (PVariable n   , ty)                       _ = flip (:) [] . (,) n <$> defineOrGetType ty
getVarsFromSig (PRecord fields, ty) _ =
  forM fields (\n -> (,) n <$> defineOrGetType (fromMaybe (error $ "Record pattern tries to access unavailable field " ++ n) $ Map.lookup n (fieldmap ty)))
  where
    fieldmap (DataTypes.RecordId fm) = fm
    fieldmap t = error $ "Record pattern has type " ++ show t
getVarsFromSig (PDefault      , ty)                       i = flip (:) [] . (,) n <$> defineOrGetType ty
  where n = "#" ++ show i
getVarsFromSig (PCtor _name _index _subPattern, _ty)     _i = error   "PCtor not yet handled!" -- TO-DO
getVarsFromSig pat                                       _i = error $ "Pattern not implemented " ++ show pat -- TO-DO
-}

-- | Destructure the argument term @v@ according to an irrefutable pattern and
-- return the resulting (variable name, term) bindings.
--
-- Record patterns bind the named fields. Tuple patterns, and constructor
-- patterns of single-constructor types, bind their sub-patterns via the
-- datatype's field accessors. For refutable patterns the action still
-- succeeds, but the returned list calls 'error' as soon as it is forced.
bindArgToPattern :: Pattern -> AST -> SMTMonad [(String,AST)]
bindArgToPattern (PVariable n     ) v = return [(n,v)]
bindArgToPattern (PRecord fields)   v = Map.toList . flip Map.restrictKeys (Set.fromList fields) <$> getRecordFieldMap v
bindArgToPattern  PDefault          _ = return []
bindArgToPattern  PUnit             _ = return []
bindArgToPattern  (PAs p n)         v = (:) (n,v) <$> bindArgToPattern p v
bindArgToPattern  (PTuple p1 p2 p3) v = do
  sort <- Z3.getSort v
  -- Accessors of the (first) tuple constructor, one per component.
  accs <- head <$> Z3.getDatatypeSortConstructorAccessors sort
  let pats = p1 : p2 : maybe [] pure p3
  concat <$> zipWithM (\p acc -> bindArgToPattern p =<< app acc [v]) pats accs
bindArgToPattern  (PCtor _ 0 subPats) v = do
    sort <- Z3.getSort v
    accses <- Z3.getDatatypeSortConstructorAccessors sort
    case accses of
        [accs] ->  -- type has exactly one constructor
                  concat <$> mapM (\(subPat,idx,_) -> bindArgToPattern subPat =<< app (accs !! idx) [v]) subPats
        _      -> return $ error "Unexpected refutable pattern PCtor"  -- type has no or more than one constructor
bindArgToPattern  (PCtor _ _ _) _ = return $ error "Unexpected refutable pattern PCtor"
bindArgToPattern  (PCons   _ _) _ = return $ error "Unexpected refutable pattern PCons"
bindArgToPattern  (PList     _) _ = return $ error "Unexpected refutable pattern PList"
bindArgToPattern  (PString   _) _ = return $ error "Unexpected refutable pattern PString"
bindArgToPattern  (PChar     _) _ = return $ error "Unexpected refutable pattern PChar"
-- Bool, Int and Double all have more than one possible value and Unit has its own pattern
bindArgToPattern  (PConstant _) _ = return $ error "Unexpected refutable pattern PConstant"

-- | Pair a variable name with the Z3 sort of the given source type.
getVarFromType :: String -> AST.Sort -> SMTMonad (String, Z3.Sort)
getVarFromType name ty = do
  z3Sort <- getSortFromType ty
  return (name, z3Sort)

-- | Resolve a source type to its Z3 sort. The type is first mapped to a sort
-- id, and the sort is then defined on demand via 'defineOrGetType'.
getSortFromType :: HasCallStack => AST.Sort -> SMTMonad Z3.Sort
getSortFromType sort = do
  ressort <- getSortIdFromType sort
  whenDebug $ liftIO $ do
    putStrLn "for sort:"
    print sort
    putStrLn "recieved sort id:"
    print ressort
  defineOrGetType ressort
-- | Build the list term for @elems@ in list sort @sort@.
--
-- The sort's first constructor is used as the empty list and its second
-- constructor as cons (head, tail). The Z3 sort and its constructors are
-- looked up once, not once per element. Terms are created from the empty
-- list outwards, i.e. the last element is consed first.
constructList :: SortId -> [AST] -> SMTMonad AST
constructList sort elems = do
  listSort <- defineOrGetType sort
  ctors <- Z3.getDatatypeSortConstructors listSort
  case ctors of
    (nilCtor : consCtor : _) -> do
      let build []     = app nilCtor []
          build (x:xs) = build xs >>= \rest -> app consCtor [x, rest]
      build elems
    _ -> error $ "Z3.Transformation.Transform.constructList: sort is not a list sort: " ++ show sort

-- | Encode a Haskell string as a list term of sort @sort@. Each character is
-- converted to an integer term by 'convertChar'.
createString :: SortId -> String -> SMTMonad AST
createString sort chars = constructList sort =<< traverse convertChar chars

-- | Encode a character as an integer term (its code point) of the Int sort.
convertChar :: Char -> SMTMonad AST
convertChar c = do
  intSort <- defineOrGetType DataTypes.IntId
  Z3.mkInt (fromEnum c) intSort
-- | Translate the generic append (@++@).
--
-- It dispatches to Basics.appendStr when the left operand has the String
-- sort and to Basics.appendList otherwise.
--
-- NOTE(review): @x@ is translated here only to inspect its sort, and it is
-- translated again inside 'handleCall'.
mkGenappend :: String -> AST.Sort -> Expr -> Expr -> SMTMonad AST
mkGenappend _ b x y = do
  lhsSort    <- Z3.getSort =<< transformExpr x
  stringSort <- defineOrGetType DataTypes.StringId
  let target | lhsSort == stringSort = "appendStr"
             | otherwise             = "appendList"
  handleCall (Foreign "Basics") target b [x, y]

{-# ANN module ("HLint: ignore Reduce duplication"::String)  #-}
{-# ANN module ("HLint: ignore Use record patterns"::String) #-}