0 | module Compiler.Opts.ConstantFold
  1 |
  2 | import Core.CompileExpr
  3 | import Core.Context.Log
  4 | import Core.Primitives
  5 | import Core.Value
  6 | import Data.Vect
  7 |
  8 | import Data.List.HasLength
  9 | import Libraries.Data.List.SizeOf
 10 |
 11 |
 12 | findConstAlt : Constant -> List (CConstAlt vars) ->
 13 |                Maybe (CExp vars) -> Maybe (CExp vars)
 14 | findConstAlt c [] def = def
 15 | findConstAlt c (MkConstAlt c' exp :: alts) def = if c == c'
 16 |     then Just exp
 17 |     else findConstAlt c alts def
 18 |
 19 | foldableOp : PrimFn ar -> Bool
 20 | foldableOp BelieveMe = False
 21 | foldableOp (Cast IntType _) = False
 22 | foldableOp (Cast _ IntType) = False
 23 | foldableOp (Cast from to)   = isJust (intKind from) && isJust (intKind to)
 24 | foldableOp _                = True
 25 |
 26 |
 27 | data Subst : Scope -> Scoped where
 28 |   Nil  : Subst Scope.empty vars
 29 |   (::) : CExp vars -> Subst ds vars -> Subst (d :: ds) vars
 30 |   Wk   : SizeOf ws -> Subst ds vars -> Subst (ws ++ ds) (ws ++ vars)
 31 |
 32 | namespace Subst
 33 |   public export
 34 |   empty : Subst Scope.empty vars
 35 |   empty = []
 36 |
 37 | initSubst : (vars : Scope) -> Subst vars vars
 38 | initSubst [] = Subst.empty
 39 | initSubst vars
 40 |   = rewrite sym $ appendNilRightNeutral vars in
 41 |     Wk (mkSizeOf vars) Subst.empty
 42 |
 43 | wk : SizeOf out -> Subst ds vars -> Subst (out ++ ds) (out ++ vars)
 44 | wk sout (Wk {ws, ds, vars} sws rho)
 45 |   = rewrite appendAssociative out ws ds in
 46 |     rewrite appendAssociative out ws vars in
 47 |     Wk (sout + sws) rho
 48 | wk ws rho = Wk ws rho
 49 |
 50 | record WkCExp (vars : Scope) where
 51 |   constructor MkWkCExp
 52 |   {0 outer, supp : Scope}
 53 |   size : SizeOf outer
 54 |   0 prf : vars === outer ++ supp
 55 |   expr : CExp supp
 56 |
 57 | Weaken WkCExp where
 58 |   weakenNs s' (MkWkCExp {outer, supp} s Refl e)
 59 |     = MkWkCExp (s' + s) (appendAssociative ns outer supp)  e
 60 |
 61 | lookup : FC -> Var ds -> Subst ds vars -> CExp vars
 62 | lookup fc (MkVar p) rho = case go p rho of
 63 |     Left (MkVar p') => CLocal fc p'
 64 |     Right (MkWkCExp s Refl e) => weakenNs s e
 65 |
 66 |   where
 67 |
 68 |   go : {i : Nat} -> {0 ds, vars : _} -> (0 _ : IsVar n i ds) ->
 69 |        Subst ds vars -> Either (Var vars) (WkCExp vars)
 70 |   go First     (val :: rho) = Right (MkWkCExp zero Refl val)
 71 |   go (Later p) (val :: rho) = go p rho
 72 |   go p         (Wk ws  rho) = case sizedView ws of
 73 |     Z => go p rho
 74 |     S ws' => case i of
 75 |       Z => Left first
 76 |       S i' => bimap later weaken (go (dropLater p) (Wk ws' rho))
 77 |
 78 | replace : CExp vars -> Bool
 79 | replace (CLocal {})   = True
 80 | replace (CPrimVal {}) = True
 81 | replace (CErased {})  = True
 82 | replace _             = False
 83 |
 84 | -- constant folding of primitive functions
 85 | -- if a primitive function is applied to only constant
 86 | -- then replace with the result
 87 | -- if there's only 1 constant argument to a commutative function
 88 | -- move the constant to the right. This simplifies Compiler.Opts.Identity
 89 | constFold : {vars' : _} ->
 90 |             Subst vars vars' ->
 91 |             CExp vars -> CExp vars'
 92 | constFold rho (CLocal fc p) = lookup fc (MkVar p) rho
 93 | constFold rho e@(CRef fc x) = CRef fc x
 94 | constFold rho (CLam fc x y)
 95 |   = CLam fc x $ constFold (wk (mkSizeOf (Scope.single x)) rho) y
 96 |
 97 | -- Expressions of the type `let x := y in x` can be introduced
 98 | -- by the compiler when inlining monadic code (for instance, `io_bind`).
 99 | -- They can be replaced by `y`.
100 | constFold rho (CLet fc x inl y z) =
101 |     let val := constFold rho y
102 |      in case replace val of
103 |           True  => constFold (val::rho) z
104 |           False => case constFold (wk (mkSizeOf (Scope.single x)) rho) z of
105 |             CLocal {idx = 0} _ _ => val
106 |             body                 => CLet fc x inl val body
107 | constFold rho (CApp fc (CRef fc2 n) [x]) =
108 |   if n == NS typesNS (UN $ Basic "prim__integerToNat")
109 |      then case constFold rho x of
110 |             CPrimVal fc3 (BI v) =>
111 |               if v >= 0 then CPrimVal fc3 (BI v) else CPrimVal fc3 (BI 0)
112 |             v                   => CApp fc (CRef fc2 n) [v]
113 |      else CApp fc (CRef fc2 n) [constFold rho x]
114 | constFold rho (CApp fc x xs) = CApp fc (constFold rho x) (constFold rho <$> xs)
115 | -- erase `UNIT` constructors, so they get constant-folded
116 | -- in `let` bindings (for instance, when optimizing `(>>)` for `IO`
117 | constFold rho (CCon fc x UNIT tag []) = CErased fc
118 | constFold rho (CCon fc x y tag xs) = CCon fc x y tag $ constFold rho <$> xs
119 | constFold rho (COp fc BelieveMe [CErased _, CErased _ , x]) = constFold rho x
120 | constFold rho (COp {arity} fc fn xs) =
121 |     let xs' = map (constFold rho) xs
122 |         e = constRight fc fn xs'
123 |      in fromMaybe e $ do
124 |           guard (foldableOp fn)
125 |           nfs <- traverse toNF xs'
126 |           nf <- getOp fn nfs
127 |           fromNF nf
128 |   where
129 |     toNF : CExp vars' -> Maybe (NF vars')
130 |     -- Don't fold `Int` and `Double` because they have varying widths
131 |     toNF (CPrimVal fc (I _)) = Nothing
132 |     toNF (CPrimVal fc (Db _)) = Nothing
133 |     -- Fold the rest
134 |     toNF (CPrimVal fc c) = Just $ NPrimVal fc c
135 |     toNF _ = Nothing
136 |
137 |     fromNF : NF vars' -> Maybe (CExp vars')
138 |     fromNF (NPrimVal fc c) = Just $ CPrimVal fc c
139 |     fromNF _ = Nothing
140 |
141 |     commutative : PrimType -> Bool
142 |     commutative DoubleType = False
143 |     commutative _ = True
144 |
145 |     constRight : {ar : _} -> FC -> PrimFn ar ->
146 |                  Vect ar (CExp vars') -> CExp vars'
147 |     constRight fc (Add ty) [x@(CPrimVal {}), y] =
148 |         if commutative ty
149 |             then COp fc (Add ty) [y, x]
150 |             else COp fc (Add ty) [x, y]
151 |     constRight fc (Mul ty) [x@(CPrimVal {}), y] =
152 |         if commutative ty
153 |             then COp fc (Mul ty) [y, x]
154 |             else COp fc (Mul ty) [x, y]
155 |     constRight fc fn args = COp fc fn args
156 |
157 | constFold rho (CExtPrim fc p xs) = CExtPrim fc p $ constFold rho <$> xs
158 | constFold rho (CForce fc x y) = CForce fc x $ constFold rho y
159 | constFold rho (CDelay fc x y) = CDelay fc x $ constFold rho y
160 | constFold rho (CConCase fc sc xs x)
161 |   = CConCase fc (constFold rho sc) (foldAlt <$> xs) (constFold rho <$> x)
162 |   where
163 |     foldAlt : CConAlt vars -> CConAlt vars'
164 |     foldAlt (MkConAlt n ci t xs e)
165 |       = MkConAlt n ci t xs $ constFold (wk (mkSizeOf xs) rho) e
166 |
167 | constFold rho (CConstCase fc sc xs x) =
168 |     let sc' = constFold rho sc
169 |      in case sc' of
170 |         CPrimVal _ val => case findConstAlt val xs x of
171 |             Just exp => constFold rho exp
172 |             Nothing => CConstCase fc (constFold rho sc) (foldAlt <$> xs) (constFold rho <$> x)
173 |         _ => CConstCase fc (constFold rho sc) (foldAlt <$> xs) (constFold rho <$> x)
174 |   where
175 |     foldAlt : CConstAlt vars -> CConstAlt vars'
176 |     foldAlt (MkConstAlt c e) = MkConstAlt c $ constFold rho e
177 | constFold rho (CPrimVal fc v) = CPrimVal fc v
178 | constFold rho (CErased fc) = CErased fc
179 | constFold rho (CCrash fc err) = CCrash fc err
180 |
181 | constFoldCDef : CDef -> Maybe CDef
182 | constFoldCDef (MkFun args exp)
183 |   = Just $ MkFun args $ constFold (initSubst args) exp
184 | constFoldCDef _ = Nothing
185 |
186 | export
187 | constantFold : Ref Ctxt Defs => Name -> Core ()
188 | constantFold fn = do
189 |     defs <- get Ctxt
190 |     Just (fnIdx, gdef) <- lookupCtxtExactI fn defs.gamma
191 |         | Nothing => pure ()
192 |     let Just cdef = gdef.compexpr
193 |         | Nothing => pure ()
194 |     let Just cdef' = constFoldCDef cdef
195 |         | Nothing => pure ()
196 |     logC "compiler.const-fold" 50 $ do pure $ "constant folding " ++ show !(getFullName fn)
197 |                                            ++ "\n\told def: " ++ show cdef
198 |                                            ++ "\n\tnew def: " ++ show cdef'
199 |     setCompiled (Resolved fnIdx) cdef'
200 |