0 | module Compiler.CaseOpts
  1 |
  2 | -- Case block related transformations
  3 |
  4 | import Core.CompileExpr
  5 | import Core.Context
  6 |
  7 | import Data.Vect
  8 |
  9 | import Libraries.Data.List.SizeOf
 10 |
 11 | %default covering
 12 |
 13 | {-
 14 | Lifting out lambdas:
 15 |
 16 | case t of
 17 |      C1 => \x1 => e1
 18 |      ...
 19 |      Cn => \xn = en
 20 |
 21 |   where every branch begins with a lambda, can become:
 22 |
 23 | \x => case t of
 24 |            C1 => e1[x/x1]
 25 |            ,,,
 26 |            Cn => en[x/xn]
 27 | -}
 28 |
 29 | shiftUnder : {args : _} ->
 30 |              {idx : _} ->
 31 |              (0 p : IsVar n idx (x :: args ++ vars)) ->
 32 |              NVar n (args ++ x :: vars)
 33 | shiftUnder First = weakenNVar (mkSizeOf args) (MkNVar First)
 34 | shiftUnder (Later p) = insertNVar (mkSizeOf args) (MkNVar p)
 35 |
 36 | shiftVar : {outer : Scope} -> {args : List Name} ->
 37 |            NVar n (outer ++ (x :: args ++ vars)) ->
 38 |            NVar n (outer ++ (args ++ x :: vars))
 39 | shiftVar nvar
 40 |   = let out = mkSizeOf outer in
 41 |     case locateNVar out nvar of
 42 |       Left nvar => embed nvar
 43 |       Right (MkNVar p) => weakenNs out (shiftUnder p)
 44 |
 45 | mutual
 46 |   shiftBinder : {outer, args : _} ->
 47 |                 (new : Name) ->
 48 |                 CExp (outer ++ old :: (args ++ vars)) ->
 49 |                 CExp (outer ++ (args ++ new :: vars))
 50 |   shiftBinder new (CLocal fc p)
 51 |       = case shiftVar (MkNVar p) of
 52 |              MkNVar p' => CLocal fc (renameVar p')
 53 |     where
 54 |       renameVar : IsVar x i (outer ++ (args ++ (old :: rest))) ->
 55 |                   IsVar x i (outer ++ (args ++ (new :: rest)))
 56 |       renameVar = believe_me -- it's the same index, so just the identity at run time
 57 |   shiftBinder new (CRef fc n) = CRef fc n
 58 |   shiftBinder {outer} new (CLam fc n sc)
 59 |       = CLam fc n $ shiftBinder {outer = n :: outer} new sc
 60 |   shiftBinder new (CLet fc n inlineOK val sc)
 61 |       = CLet fc n inlineOK (shiftBinder new val)
 62 |                            $ shiftBinder {outer = n :: outer} new sc
 63 |   shiftBinder new (CApp fc f args)
 64 |       = CApp fc (shiftBinder new f) $ map (shiftBinder new) args
 65 |   shiftBinder new (CCon fc ci c tag args)
 66 |       = CCon fc ci c tag $ map (shiftBinder new) args
 67 |   shiftBinder new (COp fc op args) = COp fc op $ map (shiftBinder new) args
 68 |   shiftBinder new (CExtPrim fc p args)
 69 |       = CExtPrim fc p $ map (shiftBinder new) args
 70 |   shiftBinder new (CForce fc r arg) = CForce fc r $ shiftBinder new arg
 71 |   shiftBinder new (CDelay fc r arg) = CDelay fc r $ shiftBinder new arg
 72 |   shiftBinder new (CConCase fc sc alts def)
 73 |       = CConCase fc (shiftBinder new sc)
 74 |                     (map (shiftBinderConAlt new) alts)
 75 |                     (map (shiftBinder new) def)
 76 |   shiftBinder new (CConstCase fc sc alts def)
 77 |       = CConstCase fc (shiftBinder new sc)
 78 |                       (map (shiftBinderConstAlt new) alts)
 79 |                       (map (shiftBinder new) def)
 80 |   shiftBinder new (CPrimVal fc c) = CPrimVal fc c
 81 |   shiftBinder new (CErased fc) = CErased fc
 82 |   shiftBinder new (CCrash fc msg) = CCrash fc msg
 83 |
 84 |   shiftBinderConAlt : {outer, args : _} ->
 85 |                 (new : Name) ->
 86 |                 CConAlt (outer ++ (x :: args ++ vars)) ->
 87 |                 CConAlt (outer ++ (args ++ new :: vars))
 88 |   shiftBinderConAlt new (MkConAlt n ci t args' sc)
 89 |       = let sc' : CExp ((args' ++ outer) ++ (x :: args ++ vars))
 90 |                 = rewrite sym (appendAssociative args' outer (x :: args ++ vars)) in sc in
 91 |         MkConAlt n ci t args' $
 92 |            rewrite (appendAssociative args' outer (args ++ new :: vars))
 93 |              in shiftBinder new {outer = args' ++ outer} sc'
 94 |
 95 |   shiftBinderConstAlt : {outer, args : _} ->
 96 |                 (new : Name) ->
 97 |                 CConstAlt (outer ++ (x :: args ++ vars)) ->
 98 |                 CConstAlt (outer ++ (args ++ new :: vars))
 99 |   shiftBinderConstAlt new (MkConstAlt c sc) = MkConstAlt c $ shiftBinder new sc
100 |
101 | -- If there's a lambda inside a case, move the variable so that it's bound
102 | -- outside the case block so that we can bind it just once outside the block
103 | liftOutLambda : {args : _} ->
104 |                 (new : Name) ->
105 |                 CExp (old :: args ++ vars) ->
106 |                 CExp (args ++ new :: vars)
107 | liftOutLambda = shiftBinder {outer = Scope.empty}
108 |
109 | -- If all the alternatives start with a lambda, we can have a single lambda
110 | -- binding outside
111 | tryLiftOut : (new : Name) ->
112 |              List (CConAlt vars) ->
113 |              Maybe (List (CConAlt (new :: vars)))
114 | tryLiftOut new [] = Just []
115 | tryLiftOut new (MkConAlt n ci t args (CLam fc x sc) :: as)
116 |     = do as' <- tryLiftOut new as
117 |          let sc' = liftOutLambda new sc
118 |          pure (MkConAlt n ci t args sc' :: as')
119 | tryLiftOut _ _ = Nothing
120 |
121 | tryLiftOutConst : (new : Name) ->
122 |                   List (CConstAlt vars) ->
123 |                   Maybe (List (CConstAlt (new :: vars)))
124 | tryLiftOutConst new [] = Just []
125 | tryLiftOutConst new (MkConstAlt c (CLam fc x sc) :: as)
126 |     = do as' <- tryLiftOutConst new as
127 |          let sc' = liftOutLambda {args = []} new sc
128 |          pure (MkConstAlt c sc' :: as')
129 | tryLiftOutConst _ _ = Nothing
130 |
131 | tryLiftDef : (new : Name) ->
132 |              Maybe (CExp vars) ->
133 |              Maybe (Maybe (CExp (new :: vars)))
134 | tryLiftDef new Nothing = Just Nothing
135 | tryLiftDef new (Just (CLam fc x sc))
136 |    = let sc' = liftOutLambda {args = []} new sc in
137 |          pure (Just sc')
138 | tryLiftDef _ _ = Nothing
139 |
140 | allLams : List (CConAlt vars) -> Bool
141 | allLams [] = True
142 | allLams (MkConAlt n ci t args (CLam {}) :: as)
143 |    = allLams as
144 | allLams _ = False
145 |
146 | allLamsConst : List (CConstAlt vars) -> Bool
147 | allLamsConst [] = True
148 | allLamsConst (MkConstAlt c (CLam {}) :: as)
149 |    = allLamsConst as
150 | allLamsConst _ = False
151 |
152 | -- label for next name for a lambda. These probably don't need really to be
153 | -- unique, since we've proved things about the de Bruijn index, but it's easier
154 | -- to see what's going on if they are.
155 | data NextName : Type where
156 |
157 | getName : {auto n : Ref NextName Int} ->
158 |           Core Name
159 | getName
160 |     = do n <- get NextName
161 |          put NextName (n + 1)
162 |          pure (MN "clam" n)
163 |
164 | -- The transformation itself
165 | mutual
166 |   caseLam : {auto n : Ref NextName Int} ->
167 |             CExp vars -> Core (CExp vars)
168 |   -- Interesting cases first: look for case blocks where every branch is a
169 |   -- lambda
170 |   caseLam (CConCase fc sc alts def)
171 |       = if allLams alts && defLam def
172 |            then do var <- getName
173 |                    -- These will work if 'allLams' and 'defLam' are consistent.
174 |                    -- We only do that boolean check because it saves us doing
175 |                    -- unnecessary work (say, if the last one we try fails)
176 |                    let Just newAlts = tryLiftOut var alts
177 |                             | Nothing => throw (InternalError "Can't happen caseLam 1")
178 |                    let Just newDef = tryLiftDef var def
179 |                             | Nothing => throw (InternalError "Can't happen caseLam 2")
180 |                    newAlts' <- traverse caseLamConAlt newAlts
181 |                    newDef' <- traverseOpt caseLam newDef
182 |                    -- Q: Should we go around again?
183 |                    pure (CLam fc var (CConCase fc (weaken sc) newAlts' newDef'))
184 |            else do sc' <- caseLam sc
185 |                    alts' <- traverse caseLamConAlt alts
186 |                    def' <- traverseOpt caseLam def
187 |                    pure (CConCase fc sc' alts' def')
188 |     where
189 |       defLam : Maybe (CExp vars) -> Bool
190 |       defLam Nothing = True
191 |       defLam (Just (CLam {})) = True
192 |       defLam _ = False
193 |   -- Next case is pretty much as above. There's a boring amount of repetition
194 |   -- here because ConstCase is just a little bit different.
195 |   caseLam (CConstCase fc sc alts def)
196 |       = if allLamsConst alts && defLam def
197 |            then do var <- getName
198 |                    -- These will work if 'allLams' and 'defLam' are consistent.
199 |                    -- We only do that boolean check because it saves us doing
200 |                    -- unnecessary work (say, if the last one we try fails)
201 |                    let Just newAlts = tryLiftOutConst var alts
202 |                             | Nothing => throw (InternalError "Can't happen caseLam 1")
203 |                    let Just newDef = tryLiftDef var def
204 |                             | Nothing => throw (InternalError "Can't happen caseLam 2")
205 |                    newAlts' <- traverse caseLamConstAlt newAlts
206 |                    newDef' <- traverseOpt caseLam newDef
207 |                    pure (CLam fc var (CConstCase fc (weaken sc) newAlts' newDef'))
208 |            else do sc' <- caseLam sc
209 |                    alts' <- traverse caseLamConstAlt alts
210 |                    def' <- traverseOpt caseLam def
211 |                    pure (CConstCase fc sc' alts' def')
212 |     where
213 |       defLam : Maybe (CExp vars) -> Bool
214 |       defLam Nothing = True
215 |       defLam (Just (CLam {})) = True
216 |       defLam _ = False
217 |   -- Structural recursive cases
218 |   caseLam (CLam fc x sc)
219 |       = CLam fc x <$> caseLam sc
220 |   caseLam (CLet fc x inl val sc)
221 |       = CLet fc x inl <$> caseLam val <*> caseLam sc
222 |   caseLam (CApp fc f args)
223 |       = CApp fc <$> caseLam f <*> traverse caseLam args
224 |   caseLam (CCon fc n ci t args)
225 |       = CCon fc n ci t <$> traverse caseLam args
226 |   caseLam (COp fc op args)
227 |       = COp fc op <$> traverseVect caseLam args
228 |   caseLam (CExtPrim fc p args)
229 |       = CExtPrim fc p <$> traverse caseLam args
230 |   caseLam (CForce fc r x)
231 |       = CForce fc r <$> caseLam x
232 |   caseLam (CDelay fc r x)
233 |       = CDelay fc r <$> caseLam x
234 |   -- All the others, no recursive case so just return the input
235 |   caseLam x = pure x
236 |
237 |   caseLamConAlt : {auto n : Ref NextName Int} ->
238 |                   CConAlt vars -> Core (CConAlt vars)
239 |   caseLamConAlt (MkConAlt n ci tag args sc)
240 |       = MkConAlt n ci tag args <$> caseLam sc
241 |
242 |   caseLamConstAlt : {auto n : Ref NextName Int} ->
243 |                     CConstAlt vars -> Core (CConstAlt vars)
244 |   caseLamConstAlt (MkConstAlt c sc) = MkConstAlt c <$> caseLam sc
245 |
246 | export
247 | caseLamDef : {auto c : Ref Ctxt Defs} ->
248 |              Name -> Core ()
249 | caseLamDef n
250 |     = do defs <- get Ctxt
251 |          Just def <- lookupCtxtExact n (gamma defs) | Nothing => pure ()
252 |          let Just cexpr =  compexpr def             | Nothing => pure ()
253 |          setCompiled n !(doCaseLam cexpr)
254 |   where
255 |     doCaseLam : CDef -> Core CDef
256 |     doCaseLam (MkFun args def)
257 |         = do n <- newRef NextName 0
258 |              pure $ MkFun args !(caseLam def)
259 |     doCaseLam d = pure d
260 |
261 | {-
262 |
263 | Case of case:
264 |
265 | case (case x of C1 => E1
266 |                 C2 => E2
267 |                 _ => Ed
268 |                 ...) of
269 |      D1 => F1
270 |      D2 => F2
271 |      ...
272 |      _ => Fd
273 |
274 | can become
275 |
276 | case x of
277 |      C1 => case E1 of
278 |                 D1 => F1
279 |                 D2 => F2
280 |                 ...
281 |                 _ => Fd
282 |      C2 => case E2 of
283 |                 D1 => F1
284 |                 D2 => F2
285 |                 ...
286 |                 _ => Fd
287 |     _ => case Ed of
288 |               D1 => F1
289 |               D2 => F2
290 |               ...
291 |               _ => Fd
292 |
293 | to minimise risk of duplication, do this only when E1, E2 are all
294 | constructor headed, or there's only one branch (for now)
295 |
296 | -}
297 |
298 | doCaseOfCase : FC ->
299 |                (x : CExp vars) ->
300 |                (xalts : List (CConAlt vars)) ->
301 |                (xdef : Maybe (CExp vars)) ->
302 |                (alts : List (CConAlt vars)) ->
303 |                (def : Maybe (CExp vars)) ->
304 |                CExp vars
305 | doCaseOfCase fc x xalts xdef alts def
306 |     = CConCase fc x (map updateAlt xalts) (map updateDef xdef)
307 |   where
308 |     updateAlt : CConAlt vars -> CConAlt vars
309 |     updateAlt (MkConAlt n ci t args sc)
310 |         = MkConAlt n ci t args $
311 |               CConCase fc sc
312 |                        (map (weakenNs (mkSizeOf args)) alts)
313 |                        (map (weakenNs (mkSizeOf args)) def)
314 |
315 |     updateDef : CExp vars -> CExp vars
316 |     updateDef sc = CConCase fc sc alts def
317 |
318 | doCaseOfConstCase : FC ->
319 |                     (x : CExp vars) ->
320 |                     (xalts : List (CConstAlt vars)) ->
321 |                     (xdef : Maybe (CExp vars)) ->
322 |                     (alts : List (CConstAlt vars)) ->
323 |                     (def : Maybe (CExp vars)) ->
324 |                     CExp vars
325 | doCaseOfConstCase fc x xalts xdef alts def
326 |     = CConstCase fc x (map updateAlt xalts) (map updateDef xdef)
327 |   where
328 |     updateAlt : CConstAlt vars -> CConstAlt vars
329 |     updateAlt (MkConstAlt c sc)
330 |         = MkConstAlt c $
331 |               CConstCase fc sc alts def
332 |
333 |     updateDef : CExp vars -> CExp vars
334 |     updateDef sc = CConstCase fc sc alts def
335 |
336 | tryCaseOfCase : CExp vars -> Maybe (CExp vars)
337 | tryCaseOfCase (CConCase fc (CConCase fc' x xalts xdef) alts def)
338 |     = if canCaseOfCase xalts xdef
339 |          then Just (doCaseOfCase fc' x xalts xdef alts def)
340 |          else Nothing
341 |   where
342 |     isCon : CExp vars -> Bool
343 |     isCon (CCon {}) = True
344 |     isCon _ = False
345 |
346 |     conCase : CConAlt vars -> Bool
347 |     conCase (MkConAlt _ _ _ _ (CCon {})) = True
348 |     conCase _ = False
349 |
350 |     canCaseOfCase : List (CConAlt vars) -> Maybe (CExp vars) -> Bool
351 |     canCaseOfCase [] _ = True
352 |     canCaseOfCase [x] Nothing = True
353 |     canCaseOfCase xs mdef = all conCase xs && maybe True isCon mdef
354 | tryCaseOfCase (CConstCase fc (CConstCase fc' x xalts xdef) alts def)
355 |     = if canCaseOfCase xalts xdef
356 |          then Just (doCaseOfConstCase fc' x xalts xdef alts def)
357 |          else Nothing
358 |   where
359 |     isConst : CExp vars -> Bool
360 |     isConst (CPrimVal {}) = True
361 |     isConst def = False
362 |
363 |     constCase : CConstAlt vars -> Bool
364 |     constCase (MkConstAlt _ (CPrimVal {})) = True
365 |     constCase _ = False
366 |
367 |     canCaseOfCase : List (CConstAlt vars) -> Maybe (CExp vars) -> Bool
368 |     canCaseOfCase [] _ = True
369 |     canCaseOfCase [x] Nothing = True
370 |     canCaseOfCase xs mdef = all constCase xs && maybe True isConst mdef
371 | tryCaseOfCase _ = Nothing
372 |
373 | export
374 | caseOfCase : CExp vars -> CExp vars
375 | caseOfCase tm = go 5 tm
376 |   where
377 |     go : Nat -> CExp vars -> CExp vars
378 |     go Z tm = tm
379 |     go (S k) tm = maybe tm (go k) (tryCaseOfCase tm)
380 |