-- Copy/constant propagation optimisation
-- Replaces variables with their assigned values in-place, in order to
-- avoid creation of intermediate boxed Int/Float values

module Propagate where

import Language

import Control.Monad.State
import Debug.Trace

-- Apply the propagation pass to the body of every defined function in the
-- program; every other top-level entry is passed through untouched.
copyPropagate :: Program -> Program
copyPropagate = map optimise
  where
    -- Only a FunBind carrying a 'Defined' body has anything to rewrite
    optimise (FunBind (f, l, nm, ty, opts, Defined body) com ority) =
        FunBind (f, l, nm, ty, opts, Defined (propagate body)) com ority
    optimise other = other

-- Optimise a single expression: run up to four propagation passes, then
-- remove the assignments which the rewritten expression no longer needs.
-- (Defining this as 'id' instead switches the optimisation off.)
propagate :: Expr Name -> Expr Name
propagate e = dropUnused argTys (locsUsed rewritten) rewritten
  where
    -- rewritten expression, plus the argument types noted by the last pass
    (rewritten, argTys) = propLoop 4 e []

-- Repeat 'doProp' until a pass makes no substitution or the pass budget
-- (first argument) runs out. Each pass starts with no known values, no
-- argument types and a cleared "changed" flag; the argument types passed in
-- are only returned when the budget is already zero.
propLoop :: Int -> Expr Name -> [ArgType] -> (Expr Name, [ArgType])
propLoop 0 e tys = (e, tys)
propLoop i e _
    | changed   = propLoop (i-1) rewritten atys
    | otherwise = (rewritten, atys)
  where
    (rewritten, (_, atys, changed)) = runState (doProp e) ([], [], False)

-- A recorded propagation: the index of a local variable paired with the
-- expression that 'doProp' substitutes for later reads of that variable
type Prop = (Int, Expr Name)

-- Replace redundant assignments to local variables with Noop.
--   atys - argument types, indexed by local variable number
--   used - local variables which are still read somewhere
-- Sub-expressions are handled first via 'mapsubexpr' (presumably a generic
-- traversal of the immediate children; defined elsewhere).
dropUnused :: [ArgType] -> [Int] -> Expr Name -> Expr Name
dropUnused atys used exp = prune (mapsubexpr (dropUnused atys used) Metavar exp)
  where
     prune a@(Assign (AName i) rhs)
         | removable i rhs = Noop
         | otherwise       = a
     prune other = other

     -- An assignment may go only if its right-hand side is propagable (so
     -- it cannot be side-effecting), nothing reads the variable any more,
     -- and the variable is not a var argument.
     removable i rhs = propagable 8 rhs && i `notElem` used && not (isVarArg i)

     isVarArg i = i < length atys && atys !! i == Var

doProp :: Expr Name -> State ([Prop], [ArgType], Bool) (Expr Name)
-- Rewrite an expression, substituting recorded values for local variables.
-- The state holds: the propagations currently known, the argument types of
-- the most recently entered Lambda, and a flag which is set as soon as any
-- substitution has been made.

-- A read of a local: substitute the recorded value if there is one
doProp (Loc i) = do
    (known, argTys, _) <- get
    case lookup i known of
        Nothing  -> return (Loc i)
        Just val -> do put (known, argTys, True)
                       return val
-- Entering a lambda: remember its argument types, then rewrite the body
doProp (Lambda as ts body) = do
    (known, _, changed) <- get
    put (known, as, changed)
    liftM (Lambda as ts) (doProp body)
doProp (Closure as ts body) = liftM (Closure as ts) (doProp body)
doProp (Declare f l n t ex) = liftM (Declare f l n t) (doProp ex)
-- liftM2 runs its actions left to right: bound value first, then the body
doProp (Bind n t v b) = liftM2 (Bind n t) (doProp v) (doProp b)
doProp (Return x) = liftM Return (doProp x)
doProp (Throw x) = liftM Throw (doProp x)
-- Assignment to a local variable: rewrite the right-hand side, forget every
-- propagation which mentions i, then (if suitable) record the new value.
doProp (Assign (AName i) ex) = 
    do ex' <- doProp ex
       (ps, as, act) <- get
       -- Remove anything from props where i is on either side
       let ps' = dropRHS i ps
       -- Only record when 'ex' is propagable and the value actually being
       -- recorded (ex', i.e. after substitution) does not mention i.
       -- Testing the original ex is not enough: in  j = i+1; i = j  the
       -- right-hand side j is rewritten to i+1, and recording (i, i+1)
       -- would turn every later read of i into i+1 -- one increment too
       -- many.
       let ps'' = if (not (i `elem` locsUsed ex') && propagable 8 ex) 
                     then addProp ps' i ex'
                     else ps'
       put (ps'', as, act)
       return (Assign (AName i) ex')
-- Assignment to any other kind of lvalue: only the right-hand side is
-- rewritten, and no propagation is recorded or forgotten
doProp (Assign a ex) = liftM (Assign a) (doProp ex)
-- Operator assignment: the target changes, so forget what is known about
-- it before rewriting the right-hand side
doProp (AssignOp op a ex) = do
    invalidate a
    liftM (AssignOp op a) (doProp ex)
  where
    invalidate (AName i) = do (known, argTys, changed) <- get
                              put (dropRHS i known, argTys, changed)
    invalidate _ = return ()
-- Handled exactly like AssignOp
doProp (AssignApp a ex) = do
    invalidate a
    liftM (AssignApp a) (doProp ex)
  where
    invalidate (AName i) = do (known, argTys, changed) <- get
                              put (dropRHS i known, argTys, changed)
    invalidate _ = return ()
-- Sequencing: rewrite left to right, so the second part sees whatever the
-- first part recorded
doProp (Seq x y) = liftM2 Seq (doProp x) (doProp y)
-- Can't do propagation on anything which gets modified in the loop,
-- so check what does
doProp (While t block) = do let mods = modified block
                            (ps, as, act) <- get
                            put (dropMod mods ps, as, act)
                            t' <- doProp t
                            block' <- doProp block
                            -- Block may not get executed, so go back to the
                            -- propagations known before the loop (minus
                            -- anything the loop modifies). The "changed"
                            -- flag must be re-read rather than restored:
                            -- writing back the value captured before the
                            -- loop would forget substitutions made inside
                            -- it, and propLoop could then stop too early.
                            (_, _, act') <- get
                            put (dropMod mods ps, as, act')
                            return (While t' block')
-- As for While, nothing modified by the loop body may be propagated into
-- the loop. Nothing is restored afterwards.
doProp (DoWhile block t) = do
    (known, argTys, changed) <- get
    put (dropMod (modified block) known, argTys, changed)
    t' <- doProp t
    block' <- doProp block
    return (DoWhile block' t')
doProp (If x t e) = do
    x' <- doProp x
    t' <- doProp t
    -- What the then branch modified must not leak into the else branch
    forget (modified t)
    e' <- doProp e
    -- Neither branch is certain to have run, so what the else branch
    -- modified has to be forgotten as well
    forget (modified e)
    return (If x' t' e')
  where
    forget mods = do (known, argTys, changed) <- get
                     put (dropMod mods known, argTys, changed)
-- Case: rewrite the scrutinee, then each alternative in turn.
doProp (Case e alts) = do
    e' <- doProp e
    alts' <- mapM doPropCase alts
    return (Case e' alts')
  where
    doPropCase (Alt tag tot es body) = liftM (Alt tag tot es) (inBranch es body)
    doPropCase (ArrayAlt es body) = liftM (ArrayAlt es) (inBranch es body)
    -- These two kinds of alternative bind no variables
    doPropCase (ConstAlt t c body) = liftM (ConstAlt t c) (inBranch [] body)
    doPropCase (Default body) = liftM Default (inBranch [] body)

    -- Rewrite the body of one alternative whose pattern binds 'binds'.
    inBranch binds body = do
        -- The pattern assigns to its variables before the body runs, so a
        -- value recorded for one of them earlier is stale inside the body
        -- and must be forgotten before rewriting it.
        (ps, as, act) <- get
        put (dropBind binds ps, as, act)
        body' <- doProp body
        -- Anything modified in the branch has not necessarily been
        -- executed, so forget those assignments (and the pattern
        -- variables again) for whatever follows.
        (ps', as', act') <- get
        put (dropBind binds (dropMod (modified body) ps'), as', act')
        return body'

    -- Forget every propagation which involves a pattern-bound local
    dropBind [] ps = ps
    dropBind ((Loc i):xs) ps = (dropRHS i (dropBind xs ps))
    dropBind (_:xs) ps = dropBind xs ps

-- Operators and other pure combining forms: rewrite the operands from left
-- to right (liftM2 runs its actions in argument order) and rebuild the node
doProp (Infix op x y) = liftM2 (Infix op) (doProp x) (doProp y)
doProp (RealInfix op x y) = liftM2 (RealInfix op) (doProp x) (doProp y)
doProp (Unary op x) = liftM (Unary op) (doProp x)
doProp (RealUnary op x) = liftM (RealUnary op) (doProp x)
doProp (Coerce t1 t2 x) = liftM (Coerce t1 t2) (doProp x)
doProp (CmpExcept op x y) = liftM2 (CmpExcept op) (doProp x) (doProp y)
doProp (CmpStr op x y) = liftM2 (CmpStr op) (doProp x) (doProp y)
doProp (Append x y) = liftM2 Append (doProp x) (doProp y)
doProp (AppendChain es) = liftM AppendChain (mapM doProp es)
doProp (ArrayInit es) = liftM ArrayInit (mapM doProp es)
-- For loop: the loop variable and everything the body modifies cannot be
-- propagated into the body. Only the body is rewritten.
doProp (For i m j loopvar arr block) =
    do invalidate loopvar
       (ps, as, act) <- get
       let mods = modified block ++ inLval loopvar
       put (dropMod mods ps, as, act)
       block' <- doProp block
       -- Block may not get executed, so go back to what was known before
       -- the loop. The "changed" flag is re-read rather than restored, so
       -- that substitutions made inside the body are still reported to
       -- propLoop.
       (_, _, act') <- get
       put (dropMod mods ps, as, act')
       return (For i m j loopvar arr block')
   where invalidate (AName v) = do (ps, as, act) <- get
                                   put (dropRHS v ps, as, act)
         invalidate _ = return ()
-- Calls. A bare local passed as an argument gets boxed anyway, and might be
-- a var argument, so arguments go through doPropArg, which leaves such a
-- local in place instead of substituting for it.
doProp (Apply f as) = liftM2 Apply (doProp f) (mapM doPropArg as)
doProp (Partial f as i) = do
    f' <- doProp f
    as' <- mapM doPropArg as
    return (Partial f' as' i)
-- Foreign call arguments are pairs; only the first components are rewritten
doProp (Foreign t n es) = do
    let (args, tags) = unzip es
    args' <- mapM doPropArg args
    return (Foreign t n (zip args' tags))
-- Try/catch: rewrite the guarded expression, then each handler.
doProp (NewTryCatch e cs) = do
    e' <- doProp e
    -- The guarded expression may be abandoned part-way through, so nothing
    -- it assigns can be relied on in a handler or after the construct
    forget (modified e)
    cs' <- mapM doPropC cs
    return (NewTryCatch e' cs')
  where
    forget mods = do (ps, as, act) <- get
                     put (dropMod mods ps, as, act)
    -- A handler may not run at all, so what it modifies is forgotten too
    doPropC (Catch (Left (n,es)) h) = do es' <- mapM doPropArg es
                                         h' <- doProp h
                                         forget (modified h)
                                         return (Catch (Left (n,es')) h')
    doPropC (Catch (Right r) h) = do r' <- doProp r
                                     h' <- doProp h
                                     forget (modified h)
                                     return (Catch (Right r') h')
-- Exception construction treats its operands like call arguments
doProp (NewExcept es) = liftM NewExcept (mapM doPropArg es)
doProp (Index x y) = liftM2 Index (doProp x) (doProp y)
doProp (Field e n a t) = liftM (\e' -> Field e' n a t) (doProp e)
doProp (Annotation a ex) = liftM (Annotation a) (doProp ex)
-- Every other form is returned as it is
doProp x = return x

-- Rewrite an expression in argument position. A bare local is left alone,
-- and whatever is known about it is forgotten, since the callee might have
-- changed it; anything else is rewritten as usual.
doPropArg :: Expr Name -> State ([Prop], [ArgType], Bool) (Expr Name)
doPropArg arg@(Loc i) = do
    (known, argTys, changed) <- get
    put (dropRHS i known, argTys, changed)
    return arg
doPropArg x = doProp x

-- Record a value for variable n. An existing entry for n is overwritten in
-- place (its value has changed); otherwise the new entry goes on the end.
addProp :: [Prop] -> Int -> Expr Name -> [Prop]
addProp props n ex = go props
  where
    go [] = [(n, ex)]
    go (entry@(p, _) : rest)
        | n == p    = (n, ex) : rest
        | otherwise = entry : go rest

-- Forget everything which depends on variable i after it has been updated:
-- both the entry for i itself and any entry whose value mentions i
dropRHS :: Int -> [Prop] -> [Prop]
dropRHS i = filter unaffected
  where
    unaffected (n, ex) = not (i `elem` locsUsed ex) && i /= n

-- Forget every propagation touched by a set of modified variables: entries
-- for a modified variable, and entries whose value mentions one
dropMod :: [Int] -> [Prop] -> [Prop]
dropMod mods = filter unaffected
  where
    unaffected (i, x) = i `notElem` mods
                        && not (any (`elem` locsUsed x) mods)


-- Decide whether an expression is small and simple enough to propagate:
-- numeric, character, boolean and real constants, local variables, and
-- arithmetic built from those. The first argument is a (somewhat arbitrary)
-- nesting budget: each operator uses up one level, an annotation uses none,
-- and nothing is propagable once the budget reaches zero.

propagable :: Int -> Expr n -> Bool
propagable 0 _ = False -- reached threshold, too big
propagable th expr = case expr of
    GConst (Num _)  -> True
    GConst (Ch _)   -> True
    GConst (Bo _)   -> True
    GConst (Re _)   -> True
    Loc _           -> True
    Infix _ x y     -> both x y
    RealInfix _ x y -> both x y
    Unary _ x       -> oneDeeper x
    RealUnary _ x   -> oneDeeper x
    Annotation _ e  -> propagable th e
    _               -> False
  where
    oneDeeper = propagable (th-1)
    both x y = oneDeeper x && oneDeeper y

