0 | module Compiler.CaseOpts
4 | import Core.CompileExpr
9 | import Libraries.Data.List.SizeOf
-- Move a variable from scope `x :: args ++ vars` into scope
-- `args ++ x :: vars`, i.e. move the binder `x` from in front of `args` to
-- just behind it.
--   * `First` is `x` itself: it now comes after all of `args`, so it is
--     weakened by the size of `args`.
--   * `Later p` points into `args ++ vars`: `insertNVar` re-indexes it for a
--     scope in which `x` has been inserted after `args`.
29 | shiftUnder : {args : _} ->
31 | (0 p : IsVar n idx (x :: args ++ vars)) ->
32 | NVar n (args ++ x :: vars)
33 | shiftUnder First = weakenNVar (mkSizeOf args) (MkNVar First)
34 | shiftUnder (Later p) = insertNVar (mkSizeOf args) (MkNVar p)
-- `shiftUnder` generalised to work underneath an extra prefix `outer` of
-- binders. `locateNVar` splits on where the variable lives:
--   * A `Right` result is an index into `x :: args ++ vars` (the part past
--     `outer`). It is moved with `shiftUnder` and then weakened back over
--     `outer`.
--   * A `Left` result (presumably a variable inside `outer`) is unaffected
--     by the shift and is simply embedded into the new scope.
-- NOTE(review): the clause head binding `nvar` is not visible in this copy;
-- the `=` line follows the signature directly.
36 | shiftVar : {outer : Scope} -> {args : List Name} ->
37 | NVar n (outer ++ (x :: args ++ vars)) ->
38 | NVar n (outer ++ (args ++ x :: vars))
40 | = let out = mkSizeOf outer in
41 | case locateNVar out nvar of
42 | Left nvar => embed nvar
43 | Right (MkNVar p) => weakenNs out (shiftUnder p)
-- Walk an expression, moving the binder `old` (found just after `outer`) to
-- just behind `args` and renaming it to `new`:
--   outer ++ old :: (args ++ vars)   ~~>   outer ++ (args ++ new :: vars)
-- Entering a binder (`CLam`, `CLet`) adds the bound name to `outer`, so the
-- same binder is still targeted. Case alternatives are handled by
-- `shiftBinderConAlt` / `shiftBinderConstAlt`. Terms without local
-- variables (`CRef`, `CPrimVal`, `CErased`, `CCrash`) are rebuilt unchanged.
-- NOTE(review): every clause takes an explicit `new` argument, but the
-- signature as shown has no binder for it. Confirm against the original.
46 | shiftBinder : {outer, args : _} ->
48 | CExp (outer ++ old :: (args ++ vars)) ->
49 | CExp (outer ++ (args ++ new :: vars))
50 | shiftBinder new (CLocal fc p)
51 | = case shiftVar (MkNVar p) of
52 | MkNVar p' => CLocal fc (renameVar p')
-- `shiftVar` yields an index into `outer ++ (args ++ old :: vars)`;
-- `renameVar` only changes the name `old` to `new` in that scope. The index
-- itself is the same, which is why `believe_me` is used.
-- NOTE(review): presumably a `where`-local helper of the `CLocal` clause;
-- no `where` is visible in this copy.
54 | renameVar : IsVar x i (outer ++ (args ++ (old :: rest))) ->
55 | IsVar x i (outer ++ (args ++ (new :: rest)))
56 | renameVar = believe_me
57 | shiftBinder new (CRef fc n) = CRef fc n
58 | shiftBinder {outer} new (CLam fc n sc)
59 | = CLam fc n $
shiftBinder {outer = n :: outer} new sc
-- The let-bound value is outside the new binder, so only the body extends
-- `outer`.
60 | shiftBinder new (CLet fc n inlineOK val sc)
61 | = CLet fc n inlineOK (shiftBinder new val)
62 | $
shiftBinder {outer = n :: outer} new sc
63 | shiftBinder new (CApp fc f args)
64 | = CApp fc (shiftBinder new f) $
map (shiftBinder new) args
65 | shiftBinder new (CCon fc ci c tag args)
66 | = CCon fc ci c tag $
map (shiftBinder new) args
67 | shiftBinder new (COp fc op args) = COp fc op $
map (shiftBinder new) args
68 | shiftBinder new (CExtPrim fc p args)
69 | = CExtPrim fc p $
map (shiftBinder new) args
70 | shiftBinder new (CForce fc r arg) = CForce fc r $
shiftBinder new arg
71 | shiftBinder new (CDelay fc r arg) = CDelay fc r $
shiftBinder new arg
72 | shiftBinder new (CConCase fc sc alts def)
73 | = CConCase fc (shiftBinder new sc)
74 | (map (shiftBinderConAlt new) alts)
75 | (map (shiftBinder new) def)
76 | shiftBinder new (CConstCase fc sc alts def)
77 | = CConstCase fc (shiftBinder new sc)
78 | (map (shiftBinderConstAlt new) alts)
79 | (map (shiftBinder new) def)
80 | shiftBinder new (CPrimVal fc c) = CPrimVal fc c
81 | shiftBinder new (CErased fc) = CErased fc
82 | shiftBinder new (CCrash fc msg) = CCrash fc msg
-- `shiftBinder` for a constructor alternative. The pattern variables
-- `args'` are bound in front of the body's scope, so they are treated as
-- part of `outer`. The two `rewrite`s with `appendAssociative` only
-- re-bracket `args' ++ (outer ++ ...)` as `(args' ++ outer) ++ ...` (and
-- back again for the result) so the scope matches `shiftBinder`'s type.
84 | shiftBinderConAlt : {outer, args : _} ->
86 | CConAlt (outer ++ (x :: args ++ vars)) ->
87 | CConAlt (outer ++ (args ++ new :: vars))
88 | shiftBinderConAlt new (MkConAlt n ci t args' sc)
89 | = let sc' : CExp ((args' ++ outer) ++ (x :: args ++ vars))
90 | = rewrite sym (appendAssociative args' outer (x :: args ++ vars)) in sc in
91 | MkConAlt n ci t args' $
92 | rewrite (appendAssociative args' outer (args ++ new :: vars))
93 | in shiftBinder new {outer = args' ++ outer} sc'
-- `shiftBinder` for a constant alternative. A constant pattern binds no
-- variables, so the body is shifted with the same `outer`.
95 | shiftBinderConstAlt : {outer, args : _} ->
97 | CConstAlt (outer ++ (x :: args ++ vars)) ->
98 | CConstAlt (outer ++ (args ++ new :: vars))
99 | shiftBinderConstAlt new (MkConstAlt c sc) = MkConstAlt c $
shiftBinder new sc
-- The binder shift with nothing in front (`outer` empty). Given the body of
-- a lambda `\old => e` that sits inside a case alternative binding `args`,
-- produce `e` over the scope `args ++ new :: vars`. The lambda's variable,
-- renamed to `new`, is then bound outside the alternative's `args`.
103 | liftOutLambda : {args : _} ->
105 | CExp (old :: args ++ vars) ->
106 | CExp (args ++ new :: vars)
107 | liftOutLambda = shiftBinder {outer = Scope.empty}
-- Strip the leading lambda from every constructor alternative, rebinding
-- each lambda's variable as `new` outside the alternative (via
-- `liftOutLambda`). The whole attempt yields Nothing as soon as one
-- alternative's body is not a `CLam`; an empty list trivially succeeds.
tryLiftOut : (new : Name) ->
             List (CConAlt vars) ->
             Maybe (List (CConAlt (new :: vars)))
tryLiftOut new = traverse liftAlt
  where
    liftAlt : CConAlt vars -> Maybe (CConAlt (new :: vars))
    liftAlt (MkConAlt con info tag params (CLam _ _ body))
        = Just (MkConAlt con info tag params (liftOutLambda new body))
    liftAlt _ = Nothing
-- Strip the leading lambda from every constant alternative, rebinding each
-- lambda's variable as `new` in the scope outside the case. A constant
-- pattern binds nothing itself, so the lambda body is shifted with an
-- empty `args`. Yields Nothing if any alternative's body is not a `CLam`;
-- an empty list trivially succeeds.
tryLiftOutConst : (new : Name) ->
                  List (CConstAlt vars) ->
                  Maybe (List (CConstAlt (new :: vars)))
tryLiftOutConst new = traverse liftAlt
  where
    liftAlt : CConstAlt vars -> Maybe (CConstAlt (new :: vars))
    liftAlt (MkConstAlt lit (CLam _ _ body))
        = Just (MkConstAlt lit (liftOutLambda {args = []} new body))
    liftAlt _ = Nothing
-- Lift the lambda out of a case's default branch, if there is one:
--   * No default (`Nothing`) lifts trivially to `Just Nothing`.
--   * A lambda default has its body shifted with an empty `args`, so its
--     variable becomes `new`.
--   * Any other default makes lifting fail (outer `Nothing`).
-- NOTE(review): the expression following `in` on the lambda clause is not
-- visible in this copy. Given the result type, it presumably returns
-- `Just (Just sc')`.
131 | tryLiftDef : (new : Name) ->
132 | Maybe (CExp vars) ->
133 | Maybe (Maybe (CExp (new :: vars)))
134 | tryLiftDef new Nothing = Just Nothing
135 | tryLiftDef new (Just (CLam fc x sc))
136 | = let sc' = liftOutLambda {args = []} new sc in
138 | tryLiftDef _ _ = Nothing
-- Checks whether the bodies of the constructor alternatives are lambdas.
-- `caseLam` uses it (together with `defLam`) to decide whether
-- `tryLiftOut` can be applied.
-- NOTE(review): only the signature and the `CLam` clause head are visible
-- in this copy. The empty-list, recursive and fallback parts are not shown.
140 | allLams : List (CConAlt vars) -> Bool
142 | allLams (MkConAlt n ci t args (CLam {}) :: as)
-- Checks whether every constant alternative's body is a lambda. The empty
-- list gives True, and an alternative whose body is not a `CLam` gives
-- False.
-- NOTE(review): the right-hand side of the `CLam` clause is not visible in
-- this copy; presumably it recurses on `as`.
146 | allLamsConst : List (CConstAlt vars) -> Bool
147 | allLamsConst [] = True
148 | allLamsConst (MkConstAlt c (CLam {}) :: as)
150 | allLamsConst _ = False
-- Label for the `Ref NextName Int` counter that `getName` reads and
-- increments whenever `caseLam` needs a binder name for a lifted lambda.
155 | data NextName : Type where
-- Read the current `NextName` counter value and bump it by one, so that
-- successive calls see distinct counter values. `caseLam` uses the result
-- as the binder name of a lifted lambda.
-- NOTE(review): the result type and the final `pure` that turns `n` into a
-- name are not visible in this copy.
157 | getName : {auto n : Ref NextName Int} ->
160 | = do n <- get NextName
161 | put NextName (n + 1)
-- Lift lambdas out of case blocks. When every alternative of a case (and
-- its default, if present) starts with a lambda:
--
--   case t of                       \x => case t of
--     C1 => \x1 => e1       ~~>               C1 => e1[x/x1]
--     ...                                     ...
--     Cn => \xn => en                         Cn => en[x/xn]
--
-- `x` is a name obtained from `getName`. The scrutinee is weakened under
-- the new binder, and the rewritten alternatives and default are processed
-- again, so further lambdas can be lifted.
-- Otherwise the case is rebuilt with `caseLam` applied to its scrutinee,
-- alternatives and default. Constant cases are handled the same way, and
-- the other constructors shown here just recurse into their subterms.
-- NOTE(review): in the lifting branches the scrutinee is only weakened and
-- is not itself traversed by `caseLam`. Confirm this is intended.
166 | caseLam : {auto n : Ref NextName Int} ->
167 | CExp vars -> Core (CExp vars)
170 | caseLam (CConCase fc sc alts def)
171 | = if allLams alts && defLam def
172 | then do var <- getName
-- These lifts cannot fail: `allLams` / `defLam` were checked above.
176 | let Just newAlts = tryLiftOut var alts
177 | | Nothing => throw (InternalError "Can't happen caseLam 1")
178 | let Just newDef = tryLiftDef var def
179 | | Nothing => throw (InternalError "Can't happen caseLam 2")
180 | newAlts' <- traverse caseLamConAlt newAlts
181 | newDef' <- traverseOpt caseLam newDef
183 | pure (CLam fc var (CConCase fc (weaken sc) newAlts' newDef'))
184 | else do sc' <- caseLam sc
185 | alts' <- traverse caseLamConAlt alts
186 | def' <- traverseOpt caseLam def
187 | pure (CConCase fc sc' alts' def')
-- A missing default, or a default that is a lambda, permits lifting.
-- NOTE(review): no catch-all `defLam` clause is visible in this copy.
189 | defLam : Maybe (CExp vars) -> Bool
190 | defLam Nothing = True
191 | defLam (Just (CLam {})) = True
195 | caseLam (CConstCase fc sc alts def)
196 | = if allLamsConst alts && defLam def
197 | then do var <- getName
-- These lifts cannot fail: `allLamsConst` / `defLam` were checked above.
201 | let Just newAlts = tryLiftOutConst var alts
202 | | Nothing => throw (InternalError "Can't happen caseLam 1")
203 | let Just newDef = tryLiftDef var def
204 | | Nothing => throw (InternalError "Can't happen caseLam 2")
205 | newAlts' <- traverse caseLamConstAlt newAlts
206 | newDef' <- traverseOpt caseLam newDef
207 | pure (CLam fc var (CConstCase fc (weaken sc) newAlts' newDef'))
208 | else do sc' <- caseLam sc
209 | alts' <- traverse caseLamConstAlt alts
210 | def' <- traverseOpt caseLam def
211 | pure (CConstCase fc sc' alts' def')
213 | defLam : Maybe (CExp vars) -> Bool
214 | defLam Nothing = True
215 | defLam (Just (CLam {})) = True
218 | caseLam (CLam fc x sc)
219 | = CLam fc x <$> caseLam sc
220 | caseLam (CLet fc x inl val sc)
221 | = CLet fc x inl <$> caseLam val <*> caseLam sc
222 | caseLam (CApp fc f args)
223 | = CApp fc <$> caseLam f <*> traverse caseLam args
224 | caseLam (CCon fc n ci t args)
225 | = CCon fc n ci t <$> traverse caseLam args
226 | caseLam (COp fc op args)
227 | = COp fc op <$> traverseVect caseLam args
228 | caseLam (CExtPrim fc p args)
229 | = CExtPrim fc p <$> traverse caseLam args
230 | caseLam (CForce fc r x)
231 | = CForce fc r <$> caseLam x
232 | caseLam (CDelay fc r x)
233 | = CDelay fc r <$> caseLam x
-- Run `caseLam` over the body of a constructor alternative. The
-- constructor name, its info, its tag and its pattern variables are kept
-- unchanged.
caseLamConAlt : {auto n : Ref NextName Int} ->
                CConAlt vars -> Core (CConAlt vars)
caseLamConAlt (MkConAlt con info tag params body)
    = do body' <- caseLam body
         pure (MkConAlt con info tag params body')
-- Run `caseLam` over the body of a constant alternative. The constant
-- being matched is kept unchanged.
caseLamConstAlt : {auto n : Ref NextName Int} ->
                  CConstAlt vars -> Core (CConstAlt vars)
caseLamConstAlt (MkConstAlt lit body)
    = do body' <- caseLam body
         pure (MkConstAlt lit body')
-- Apply `caseLam` to the compiled form of the global definition `n` and
-- store the result back with `setCompiled`. Does nothing if `n` is not in
-- the context or has no compiled expression.
-- NOTE(review): the rest of the signature and the clause head binding `n`
-- are not visible in this copy.
247 | caseLamDef : {auto c : Ref Ctxt Defs} ->
250 | = do defs <- get Ctxt
251 | Just def <- lookupCtxtExact n (gamma defs) | Nothing => pure ()
252 | let Just cexpr = compexpr def | Nothing => pure ()
253 | setCompiled n !(doCaseLam cexpr)
-- Only `MkFun` definitions are transformed. Each one gets a fresh
-- `NextName` counter starting from 0, which `caseLam`'s `auto` argument
-- picks up. Any other compiled definition is returned as is.
255 | doCaseLam : CDef -> Core CDef
256 | doCaseLam (MkFun args def)
257 | = do n <- newRef NextName 0
258 | pure $
MkFun args !(caseLam def)
259 | doCaseLam d = pure d
-- Case-of-case for constructor cases. It pushes the outer case into the
-- inner one, turning
--   case (case x of xalts; xdef) of alts; def
-- into a case on `x`. Each of its alternatives and its default scrutinises
-- the corresponding inner result with a copy of `alts` / `def`.
-- In constructor alternatives the copies are weakened over that
-- alternative's pattern variables (`args`), since those are now in scope.
-- Note that `alts` / `def` are duplicated into every inner branch.
-- NOTE(review): some lines (the scrutinee parameter and result type in the
-- signature, the `where`, and the head of the rebuilt inner case in
-- `updateAlt`) are not visible in this copy.
298 | doCaseOfCase : FC ->
300 | (xalts : List (CConAlt vars)) ->
301 | (xdef : Maybe (CExp vars)) ->
302 | (alts : List (CConAlt vars)) ->
303 | (def : Maybe (CExp vars)) ->
305 | doCaseOfCase fc x xalts xdef alts def
306 | = CConCase fc x (map updateAlt xalts) (map updateDef xdef)
308 | updateAlt : CConAlt vars -> CConAlt vars
309 | updateAlt (MkConAlt n ci t args sc)
310 | = MkConAlt n ci t args $
312 | (map (weakenNs (mkSizeOf args)) alts)
313 | (map (weakenNs (mkSizeOf args)) def)
315 | updateDef : CExp vars -> CExp vars
316 | updateDef sc = CConCase fc sc alts def
-- Case-of-case for constant cases. It turns
--   case (case x of xalts; xdef) of alts; def
-- into a constant case on `x`. Each alternative body and the default of
-- that case is wrapped in a copy of the outer case (`alts` / `def`).
-- Constant alternatives bind no variables, so the copies need no
-- weakening.
-- NOTE(review): some lines (the scrutinee parameter, the result type, the
-- `where`, and part of `updateAlt`) are not visible in this copy.
318 | doCaseOfConstCase : FC ->
320 | (xalts : List (CConstAlt vars)) ->
321 | (xdef : Maybe (CExp vars)) ->
322 | (alts : List (CConstAlt vars)) ->
323 | (def : Maybe (CExp vars)) ->
325 | doCaseOfConstCase fc x xalts xdef alts def
326 | = CConstCase fc x (map updateAlt xalts) (map updateDef xdef)
328 | updateAlt : CConstAlt vars -> CConstAlt vars
329 | updateAlt (MkConstAlt c sc)
331 | CConstCase fc sc alts def
333 | updateDef : CExp vars -> CExp vars
334 | updateDef sc = CConstCase fc sc alts def
-- Perform one case-of-case step at the root of an expression. A
-- constructor case whose scrutinee is itself a constructor case is
-- rewritten by `doCaseOfCase`; a constant case on a constant case is
-- rewritten by `doCaseOfConstCase`. In both cases `canCaseOfCase` must hold
-- for the inner case, which it does when the inner case:
--   * has no alternatives, or
--   * has a single alternative and no default, or
--   * has every alternative body (and the default, if any) a constructor
--     application (resp. a primitive constant).
-- The restriction presumably keeps duplication down and gives each copied
-- outer case a known scrutinee that can be simplified later.
-- Returns Nothing when no step applies.
-- NOTE(review): the `else` branches, the `where`s and some catch-all
-- clauses are not visible in this copy.
336 | tryCaseOfCase : CExp vars -> Maybe (CExp vars)
337 | tryCaseOfCase (CConCase fc (CConCase fc' x xalts xdef) alts def)
338 | = if canCaseOfCase xalts xdef
339 | then Just (doCaseOfCase fc' x xalts xdef alts def)
342 | isCon : CExp vars -> Bool
343 | isCon (CCon {}) = True
346 | conCase : CConAlt vars -> Bool
347 | conCase (MkConAlt _ _ _ _ (CCon {})) = True
350 | canCaseOfCase : List (CConAlt vars) -> Maybe (CExp vars) -> Bool
351 | canCaseOfCase [] _ = True
352 | canCaseOfCase [x] Nothing = True
353 | canCaseOfCase xs mdef = all conCase xs && maybe True isCon mdef
354 | tryCaseOfCase (CConstCase fc (CConstCase fc' x xalts xdef) alts def)
355 | = if canCaseOfCase xalts xdef
356 | then Just (doCaseOfConstCase fc' x xalts xdef alts def)
359 | isConst : CExp vars -> Bool
360 | isConst (CPrimVal {}) = True
361 | isConst def = False
363 | constCase : CConstAlt vars -> Bool
364 | constCase (MkConstAlt _ (CPrimVal {})) = True
365 | constCase _ = False
367 | canCaseOfCase : List (CConstAlt vars) -> Maybe (CExp vars) -> Bool
368 | canCaseOfCase [] _ = True
369 | canCaseOfCase [x] Nothing = True
370 | canCaseOfCase xs mdef = all constCase xs && maybe True isConst mdef
371 | tryCaseOfCase _ = Nothing
-- Apply `tryCaseOfCase` repeatedly at the root of the expression, at most
-- 5 times, stopping early as soon as no step applies. Subterms are not
-- visited here.
-- NOTE(review): the `where` and the `go Z` clause are not visible in this
-- copy; presumably `go Z tm = tm`.
374 | caseOfCase : CExp vars -> CExp vars
375 | caseOfCase tm = go 5 tm
377 | go : Nat -> CExp vars -> CExp vars
379 | go (S k) tm = maybe tm (go k) (tryCaseOfCase tm)