0 | module Compiler.Opts.ConstantFold
2 | import Core.CompileExpr
3 | import Core.Context.Log
4 | import Core.Primitives
8 | import Data.List.HasLength
9 | import Libraries.Data.List.SizeOf
-- Look for the constant-case alternative belonging to constant `c`.
-- The alternatives are examined in order, comparing each alternative's
-- constant `c'` with `c` via `==`. When the list is exhausted, the fallback
-- `def` is returned; constFold passes the case's default branch here.
-- NOTE(review): the `then` branch of the `if` below is missing from this
-- listing; presumably `then Just exp` — confirm against the full source.
12 | findConstAlt : Constant -> List (CConstAlt vars) ->
13 | Maybe (CExp vars) -> Maybe (CExp vars)
14 | findConstAlt c [] def = def
15 | findConstAlt c (MkConstAlt c' exp :: alts) def = if c == c'
17 | else findConstAlt c alts def
-- Whether constFold may evaluate a primitive operation at compile time.
-- `BelieveMe` is never folded, and neither is any `Cast` to or from
-- `IntType`. Any other `Cast` is folded only when both the source and the
-- target type have an `intKind`.
-- NOTE(review): no catch-all clause is visible in this listing. constFold
-- guards the general `COp` case with this predicate for arbitrary
-- primitives, so a final `foldableOp _ = ...` clause is expected for
-- coverage — confirm it exists.
19 | foldableOp : PrimFn ar -> Bool
20 | foldableOp BelieveMe = False
21 | foldableOp (Cast IntType _) = False
22 | foldableOp (Cast _ IntType) = False
23 | foldableOp (Cast from to) = isJust (intKind from) && isJust (intKind to)
-- Explicit substitution sending the variables of scope `ds` to expressions
-- in scope `vars`. Constructors:
--   Nil        : the empty scope, nothing to substitute;
--   val :: rho : the first variable `d` (de Bruijn index 0) is replaced by
--                `val`, which already lives in `vars`; `rho` handles the rest;
--   Wk ws rho  : both scopes are extended by the same block `ws`, whose
--                variables remain variables (initSubst builds the identity
--                substitution this way).
data Subst : Scope -> Scoped where
  Nil  : Subst Scope.empty vars
  (::) : CExp vars -> Subst ds vars -> Subst (d :: ds) vars
  Wk   : SizeOf ws -> Subst ds vars -> Subst (ws ++ ds) (ws ++ vars)
-- The empty substitution. Callers refer to it as `Subst.empty`, so it is
-- presumably declared inside a `namespace Subst`. Neither that header nor the
-- definition body (presumably `empty = []`, i.e. `Nil`) is visible in this
-- listing — confirm.
34 | empty : Subst Scope.empty vars
-- Build a substitution from `vars` to itself, used as the starting point for
-- folding a function body.
-- For the empty scope it is just the empty substitution. Otherwise the whole
-- scope is wrapped in a single `Wk` over the empty substitution. The
-- `appendNilRightNeutral` rewrite reconciles `vars ++ Scope.empty` (the type
-- of that `Wk`) with `vars`.
-- NOTE(review): the head of the second clause (the line before `= rewrite`)
-- is missing from this listing; presumably `initSubst vars` — confirm.
37 | initSubst : (vars : Scope) -> Subst vars vars
38 | initSubst [] = Subst.empty
40 | = rewrite sym $
appendNilRightNeutral vars in
41 | Wk (mkSizeOf vars) Subst.empty
-- Extend a substitution under a block `out` of newly bound variables.
-- The block is prepended to both the source and the target scope. constFold
-- uses this for lambda and let binders and for constructor-pattern
-- arguments.
-- If the substitution is already a `Wk`, the two `appendAssociative`
-- rewrites re-associate `out ++ (ws ++ _)`, presumably so the blocks can be
-- merged into a single `Wk`. Any other substitution is wrapped in a new `Wk`.
-- NOTE(review): the final expression of the first clause is missing from
-- this listing (presumably a `Wk` combining `sout` and `sws`) — confirm.
43 | wk : SizeOf out -> Subst ds vars -> Subst (out ++ ds) (out ++ vars)
44 | wk sout (Wk {ws, ds, vars} sws rho)
45 | = rewrite appendAssociative out ws ds in
46 | rewrite appendAssociative out ws vars in
48 | wk ws rho = Wk ws rho
-- A pending weakening: a payload together with a proof that the scope
-- `vars` splits as `outer ++ supp`. lookup returns values of this type and
-- applies the recorded weakening only at the very end, via `weakenNs`.
-- NOTE(review): some fields are missing from this listing. The constructor is
-- used as `MkWkCExp s Refl e`, which suggests a size of `outer`, this proof,
-- and presumably an expression in scope `supp` — confirm.
50 | record WkCExp (vars : Scope) where
51 | constructor MkWkCExp
52 | {0 outer, supp : Scope}
54 | 0 prf : vars === outer ++ supp
-- Weakening a `WkCExp` only enlarges its recorded outer block. The new size
-- is `s' + s`, and the scope equation is re-associated with
-- `appendAssociative`. The payload `e` is not traversed here; lookup applies
-- the accumulated weakening to it once.
-- NOTE(review): this is an interface-method clause; the enclosing
-- implementation header is not visible in this listing.
58 | weakenNs s' (MkWkCExp {outer, supp} s Refl e)
59 | = MkWkCExp (s' + s) (appendAssociative ns outer supp) e
-- Resolve the variable `MkVar p` of source scope `ds` through the
-- substitution `rho`, producing an expression in target scope `vars`.
-- The helper `go` has two outcomes:
--   Left  : the variable survives as a (possibly renumbered) variable, which
--           becomes a `CLocal`;
--   Right : it hits a substituted expression, returned with the weakening
--           still needed, which is then applied with `weakenNs`.
-- NOTE(review): the `where` keyword and the `Wk` branches before `S i'` are
-- missing from this listing — confirm.
61 | lookup : FC -> Var ds -> Subst ds vars -> CExp vars
62 | lookup fc (MkVar p) rho = case go p rho of
63 | Left (MkVar p') => CLocal fc p'
64 | Right (MkWkCExp s Refl e) => weakenNs s e
68 | go : {i : Nat} -> {0 ds, vars : _} -> (0 _ : IsVar n i ds) ->
69 | Subst ds vars -> Either (Var vars) (WkCExp vars)
-- Index 0 reaches a substituted value: return it with an empty weakening.
70 | go First (val :: rho) = Right (MkWkCExp zero Refl val)
-- A later index skips the substituted head and continues in the tail.
71 | go (Later p) (val :: rho) = go p rho
-- Inside a `Wk` block the block size is inspected with `sizedView`. In the
-- visible branch, a successor index recurses on the smaller block and shifts
-- the result back under one binder (`later` for variables, `weaken` for
-- pending expressions).
72 | go p (Wk ws rho) = case sizedView ws of
76 | S i' => bimap later weaken (go (dropLater p) (Wk ws' rho))
-- Whether a folded let-bound value should be substituted directly into the
-- let body, dropping the `let`. Locals, primitive constants and erased terms
-- are accepted.
-- NOTE(review): only the `True` clauses are visible here, yet constFold's
-- `CLet` case also branches on `False`. A catch-all such as
-- `replace _ = False` is presumably present in the full source — confirm,
-- otherwise this function is not covering.
78 | replace : CExp vars -> Bool
79 | replace (CLocal {}) = True
80 | replace (CPrimVal {}) = True
81 | replace (CErased {}) = True
-- Constant-fold an expression of scope `vars` into scope `vars'`, applying
-- the substitution `rho` to its variables along the way.
-- NOTE(review): the middle line of this signature is missing from the
-- listing. Every clause takes a substitution `rho` as its first explicit
-- argument and passes it to `lookup`, so the line is presumably
-- `Subst vars vars' ->` — confirm.
89 | constFold : {vars' : _} ->
91 | CExp vars -> CExp vars'
-- Variables are resolved through the substitution. Global references are
-- rebuilt unchanged in the target scope. A lambda body is folded under the
-- substitution extended by the lambda's single binder.
constFold rho (CLocal fc p) = lookup fc (MkVar p) rho
constFold rho (CRef fc n) = CRef fc n
constFold rho (CLam fc x body)
    = CLam fc x (constFold (wk (mkSizeOf (Scope.single x)) rho) body)
-- `let x = y in z`: fold the bound value first.
-- If `replace` accepts it, the value is substituted straight into the body
-- and the `let` disappears. Otherwise the body is folded under the binder.
-- A body that is just the bound variable (index 0) collapses to the value;
-- anything else keeps the `let` around the folded value and body.
constFold rho (CLet fc x inl y z)
    = let folded := constFold rho y in
      if replace folded
         then constFold (folded :: rho) z
         else case constFold (wk (mkSizeOf (Scope.single x)) rho) z of
                CLocal {idx = 0} _ _ => folded
                body => CLet fc x inl folded body
-- Applications.
-- A call to `prim__integerToNat` whose folded argument is an `Integer`
-- literal is evaluated here, with negative inputs clamped to 0. Every other
-- application has its function and arguments folded.
constFold rho (CApp fc (CRef fc2 n) [x])
    = let arg = constFold rho x
          toNat = n == NS typesNS (UN $ Basic "prim__integerToNat")
      in case (toNat, arg) of
           (True, CPrimVal fc3 (BI v)) => CPrimVal fc3 (BI (max 0 v))
           _ => CApp fc (CRef fc2 n) [arg]
constFold rho (CApp fc f args) = CApp fc (constFold rho f) (map (constFold rho) args)
-- Constructors.
-- A constructor with `UNIT` info and no arguments becomes `CErased`; any
-- other constructor only has its arguments folded.
-- A `BelieveMe` whose first two arguments are erased reduces to its folded
-- third argument.
constFold rho (CCon fc _ UNIT _ []) = CErased fc
constFold rho (CCon fc n ci tag args) = CCon fc n ci tag (map (constFold rho) args)
constFold rho (COp fc BelieveMe [CErased _, CErased _, val]) = constFold rho val
-- General primitive operation: fold the arguments first.
-- If the operation passes `foldableOp` and every folded argument converts to
-- a normal form via `toNF`, the rest of the `do` block presumably evaluates
-- the primitive and converts the result back with `fromNF`.
-- Whenever that `Maybe` computation fails, the result is `e`: the operation
-- rebuilt by `constRight`.
-- NOTE(review): the tail of the `do` block, the `where` keyword, some
-- `toNF`/`fromNF` clauses and the `if` lines of `constRight` are missing
-- from this listing — confirm against the full source.
120 | constFold rho (COp {arity} fc fn xs) =
121 | let xs' = map (constFold rho) xs
122 | e = constRight fc fn xs'
123 | in fromMaybe e $
do
124 | guard (foldableOp fn)
125 | nfs <- traverse toNF xs'
-- Only primitive literals are converted. `Int` and `Double` literals are
-- deliberately excluded, so operations on them are never folded here.
129 | toNF : CExp vars' -> Maybe (NF vars')
131 | toNF (CPrimVal fc (I _)) = Nothing
132 | toNF (CPrimVal fc (Db _)) = Nothing
134 | toNF (CPrimVal fc c) = Just $
NPrimVal fc c
-- Convert a normal-form primitive value back into an expression.
137 | fromNF : NF vars' -> Maybe (CExp vars')
138 | fromNF (NPrimVal fc c) = Just $
CPrimVal fc c
-- Every primitive type except `DoubleType` is treated as commutative.
139 | 
141 | commutative : PrimType -> Bool
142 | commutative DoubleType = False
143 | commutative _ = True
-- For `Add`/`Mul` whose first argument is a literal, the arguments may be
-- swapped so the literal ends up on the right. The swap condition is not
-- visible here (presumably `commutative ty`). All other operations are
-- rebuilt unchanged.
145 | constRight : {ar : _} -> FC -> PrimFn ar ->
146 | Vect ar (CExp vars') -> CExp vars'
147 | constRight fc (Add ty) [x@(CPrimVal {}), y] =
149 | then COp fc (Add ty) [y, x]
150 | else COp fc (Add ty) [x, y]
151 | constRight fc (Mul ty) [x@(CPrimVal {}), y] =
153 | then COp fc (Mul ty) [y, x]
154 | else COp fc (Mul ty) [x, y]
155 | constRight fc fn args = COp fc fn args
-- Structural cases: fold every sub-expression and rebuild the same node.
constFold rho (CExtPrim fc p args) = CExtPrim fc p (map (constFold rho) args)
constFold rho (CForce fc r e) = CForce fc r (constFold rho e)
constFold rho (CDelay fc r e) = CDelay fc r (constFold rho e)
-- Constructor case: fold the scrutinee, every alternative and the default.
-- Each alternative's right-hand side is folded under the substitution
-- extended by the pattern variables it binds.
constFold rho (CConCase fc sc alts def)
    = CConCase fc (constFold rho sc) (map foldAlt alts) (map (constFold rho) def)
  where
    foldAlt : CConAlt vars -> CConAlt vars'
    foldAlt (MkConAlt n ci t args rhs)
        = MkConAlt n ci t args (constFold (wk (mkSizeOf args) rho) rhs)
-- Constant case: the scrutinee is folded exactly once.
-- If it folds to a literal, `findConstAlt` picks the branch to take (an
-- alternative or the default), and that branch is folded in place of the
-- whole case.
-- If no branch applies, or the scrutinee is not a literal, the case is
-- rebuilt around the already-folded scrutinee with every branch folded.
constFold rho (CConstCase fc sc alts def)
    = let sc' = constFold rho sc in
      case sc' of
        CPrimVal _ c => case findConstAlt c alts def of
          Just rhs => constFold rho rhs
          Nothing => rebuild sc'
        _ => rebuild sc'
  where
    foldAlt : CConstAlt vars -> CConstAlt vars'
    foldAlt (MkConstAlt c e) = MkConstAlt c (constFold rho e)

    -- Rebuild the case around a scrutinee that has already been folded,
    -- instead of folding `sc` a second time.
    rebuild : CExp vars' -> CExp vars'
    rebuild scrut = CConstCase fc scrut (map foldAlt alts) (map (constFold rho) def)
-- Leaf nodes are rebuilt unchanged; the substitution is never consulted.
constFold _ (CPrimVal fc c) = CPrimVal fc c
constFold _ (CErased fc) = CErased fc
constFold _ (CCrash fc msg) = CCrash fc msg
-- Constant-fold a compiled definition.
-- Only function definitions (`MkFun`) are handled: the body is folded
-- starting from `initSubst args`, a substitution from the argument scope to
-- itself. Any other kind of definition yields `Nothing`.
constFoldCDef : CDef -> Maybe CDef
constFoldCDef (MkFun args body) = Just (MkFun args (constFold (initSubst args) body))
constFoldCDef _ = Nothing
-- Constant-fold the compiled definition of `fn` and store the result.
-- Does nothing when any of these holds:
--   * `fn` is not found by `lookupCtxtExactI`;
--   * the definition has no `compexpr`;
--   * `constFoldCDef` returns `Nothing` (the compiled definition is not an
--     `MkFun`).
-- Otherwise it logs the old and new definitions via `logC` (topic
-- compiler.const-fold, level 50) and saves the folded definition with
-- `setCompiled`.
-- NOTE(review): `defs` is used below, but its binding line is missing from
-- this listing (presumably `defs <- get Ctxt`) — confirm.
187 | constantFold : Ref Ctxt Defs => Name -> Core ()
188 | constantFold fn = do
190 | Just (fnIdx, gdef) <- lookupCtxtExactI fn defs.gamma
191 | | Nothing => pure ()
192 | let Just cdef = gdef.compexpr
193 | | Nothing => pure ()
194 | let Just cdef' = constFoldCDef cdef
195 | | Nothing => pure ()
196 | logC "compiler.const-fold" 50 $
do pure $
"constant folding " ++ show !(getFullName fn)
197 | ++ "\n\told def: " ++ show cdef
198 | ++ "\n\tnew def: " ++ show cdef'
199 | setCompiled (Resolved fnIdx) cdef'