0 | module Core.Case.CaseBuilder
2 | import Core.Case.CaseTree
3 | import Core.Case.Util
4 | import Core.Context.Log
6 | import Core.Normalise
10 | import Idris.Pretty.Annotations
13 | import Data.List.Quantifiers
14 | import Data.SortedSet
17 | import Libraries.Data.IMaybe
18 | import Libraries.Data.List.SizeOf
19 | import Libraries.Data.List.LengthMatch
20 | import Libraries.Data.List01
21 | import Libraries.Data.List01.Quantifiers
23 | import Decidable.Equality
25 | import Libraries.Text.PrettyPrint.Prettyprinter
29 | %hide Symbols.equals
-- Phase in which case trees are built: compile time, carrying a RigCount
-- (tested with `isErased` in getClauseType), or run time.
32 | data Phase = CompileTime RigCount | RunTime
-- Equality on phases: two CompileTime phases are equal when their RigCounts
-- are equal, and RunTime equals RunTime.
-- NOTE(review): the instance header (line 34) and any catch-all clause for
-- mixed CompileTime/RunTime comparisons (line 37) are not in this excerpt.
35 | CompileTime r == CompileTime r' = r == r'
36 | RunTime == RunTime = True
-- What is known about the type of a pattern argument, scoped over `vars`.
--   Known   : its multiplicity and its type term are known.
--   Stuck   : only a function type `fty` is available. substInPatInfo
--             re-normalises it after a substitution and switches the
--             argument to Known once it exposes a Pi binder.
--   Unknown : nothing is known yet.
-- NOTE(review): lines 42-43 are not in this excerpt.
39 | data ArgType : Scoped where
40 | Known : RigCount -> (ty : Term vars) -> ArgType vars
41 | Stuck : (fty : Term vars) -> ArgType vars
44 | Unknown : ArgType vars
-- Name normalisation for argument types: `full` and `resolved` are mapped
-- over the contained type term, and Unknown is returned unchanged.
47 | HasNames (ArgType vars) where
48 | full gam (Known c ty) = Known c <$> full gam ty
49 | full gam (Stuck ty) = Stuck <$> full gam ty
50 | full gam Unknown = pure Unknown
52 | resolved gam (Known c ty) = Known c <$> resolved gam ty
53 | resolved gam (Stuck ty) = Stuck <$> resolved gam ty
54 | resolved gam Unknown = pure Unknown
-- Debug rendering: the constructor name followed by its shown fields.
57 | {ns : _} -> Show (ArgType ns) where
58 | show (Known c t) = "Known " ++ show c ++ " " ++ show t
59 | show (Stuck t) = "Stuck " ++ show t
60 | show Unknown = "Unknown"
-- Information about the pattern for argument `pvar`, scoped over `vars`.
-- The visible fields are:
--   loc     : an erased proof that `name` is the variable at index `idx`
--             in `vars`;
--   argType : what is known about the argument's type.
-- NOTE(review): lines 63-66 are not in this excerpt. They presumably declare
-- the constructor (MkInfo, used below) and the `pat`, `name` and `idx`
-- fields referenced elsewhere.
62 | record PatInfo (pvar : Name) (vars : Scope) where
67 | 0 loc : IsVar name idx vars
68 | argType : ArgType vars
-- Debug rendering as "<pattern> : <argument type>".
72 | {vars : _} -> Show (PatInfo n vars) where
73 | show pi = show (pat pi) ++ " : " ++ show (argType pi)
-- Name normalisation for a pattern's info: applies `full` / `resolved` to
-- the pattern and to its argument type. The erased `loc` proof is kept as is.
75 | HasNames (PatInfo n vars) where
76 | full gam (MkInfo pat loc argType)
77 | = do pat <- full gam pat
78 | argType <- full gam argType
79 | pure $
MkInfo pat loc argType
81 | resolved gam (MkInfo pat loc argType)
82 | = do pat <- resolved gam pat
83 | argType <- resolved gam argType
84 | pure $
MkInfo pat loc argType
-- Patterns still to be matched, indexed by their argument names (`todo`,
-- first argument first) and scoped over `vars`.
-- NOTE(review): lines 98 and 101-103 of this declaration are not in this
-- excerpt.
97 | data NamedPats : List Name ->
99 | Nil : NamedPats [] vars
100 | (::) : PatInfo pvar vars ->
104 | NamedPats ns vars -> NamedPats (pvar :: ns) vars
-- The patterns of a NamedPats, in order.
-- NOTE(review): the `[]` clause (line 107) is not in this excerpt.
106 | getPatInfo : NamedPats todo vars -> List Pat
108 | getPatInfo (x :: xs) = pat x :: getPatInfo xs
-- Refine the recorded argument types of a pattern list against `nf`, the
-- normalised type the patterns are being applied to.
--  * If `nf` is a Pi binder, the first pattern's type becomes `Known c` of
--    the binder's domain, quoted against a context emptied by clearDefs.
--    The remaining patterns are then refined against the Pi's scope,
--    instantiated with a reference to the pattern's own name `pvar`.
--  * Otherwise the first pattern's type becomes `Stuck`, holding the quoted
--    `nf`.
-- Patterns that fall into the `_` alternatives are returned unchanged.
-- NOTE(review): lines 112, 117 and 125 (the `env` parameter type and the
-- case alternative that is refined) are not in this excerpt. Presumably only
-- Unknown argument types are refined; confirm.
110 | updatePats : {vars, todo : _} ->
111 | {auto c : Ref Ctxt Defs} ->
113 | NF vars -> NamedPats todo vars -> Core (NamedPats todo vars)
114 | updatePats env nf [] = pure []
115 | updatePats {todo = pvar :: ns} env (NBind fc _ (Pi _ c _ farg) fsc) (p :: ps)
116 | = case argType p of
118 | do defs <- get Ctxt
119 | empty <- clearDefs defs
120 | pure ({ argType := Known c !(quote empty env farg) } p
121 | :: !(updatePats env !(fsc defs (toClosure defaultOpts env (Ref fc Bound pvar))) ps))
122 | _ => pure (p :: ps)
123 | updatePats env nf (p :: ps)
124 | = case argType p of
126 | do defs <- get Ctxt
127 | empty <- clearDefs defs
128 | pure ({ argType := Stuck !(quote empty env nf) } p :: ps)
129 | _ => pure (p :: ps)
-- Substitute `tm` for the name `n` in the argument type of `p`. This may
-- also refine the types of the following patterns `ps`, so an updated `ps`
-- is returned as well.
--  * Known: the type is normalised (lines 141-142 are not shown), and `p`
--    gets `Known c (substName n tm ty)`.
--  * Stuck: the substituted function type is re-normalised. If it now
--    exposes a Pi binder, `p` gets a Known type from the binder's domain.
--    `ps` is then refined against the Pi's scope, instantiated with a
--    reference to `pvar` (presumably via updatePats on line 153, which is
--    not shown).
--  * Unknown: nothing changes.
-- NOTE(review): lines 138, 141-142, 144-146, 153 and 156 (the case
-- alternatives' heads and the fallback branches) are not in this excerpt.
131 | substInPatInfo : {pvar, vars, todo : _} ->
132 | {auto c : Ref Ctxt Defs} ->
133 | FC -> Name -> Term vars -> PatInfo pvar vars ->
134 | NamedPats todo vars ->
135 | Core (PatInfo pvar vars, NamedPats todo vars)
136 | substInPatInfo fc n tm p ps
137 | = case argType p of
139 | do defs <- get Ctxt
140 | tynf <- nf defs (mkEnv fc vars) ty
143 | pure ({ argType := Known c (substName n tm ty) } p, ps)
147 | do defs <- get Ctxt
148 | empty <- clearDefs defs
149 | let env = mkEnv fc vars
150 | case !(nf defs env (substName n tm fty)) of
151 | NBind pfc _ (Pi _ c _ farg) fsc =>
152 | pure ({ argType := Known c !(quote empty env farg) } p,
154 | !(fsc defs (toClosure defaultOpts env
155 | (Ref pfc Bound pvar))) ps))
157 | Unknown => pure (p, ps)
-- Substitute `tm` for `n` in the argument types of every pattern, left to
-- right. Each step can also refine the patterns after it (substInPatInfo
-- returns an updated tail), so the recursion continues on that updated tail.
161 | substInPats : {vars, todo : _} ->
162 | {auto c : Ref Ctxt Defs} ->
163 | FC -> Name -> Term vars -> NamedPats todo vars ->
164 | Core (NamedPats todo vars)
165 | substInPats fc n tm [] = pure []
166 | substInPats fc n tm (p :: ps)
167 | = do (p', ps') <- substInPatInfo fc n tm p ps
168 | pure (p' :: !(substInPats fc n tm ps'))
-- The pattern info for the argument `nm` at position `el` in `ps`.
170 | getPat : {idx : Nat} ->
171 | (0 el : IsVar nm idx ps) -> NamedPats ps ns -> PatInfo nm ns
172 | getPat First (x :: xs) = x
173 | getPat (Later p) (x :: xs) = getPat p xs
-- Remove the pattern at position `el`, keeping the others in order.
175 | dropPat : {idx : Nat} ->
176 | (0 el : IsVar nm idx ps) ->
177 | NamedPats ps ns -> NamedPats (dropIsVar ps el) ns
178 | dropPat First (x :: xs) = xs
179 | dropPat (Later p) (x :: xs) = x :: dropPat p xs
-- Name normalisation applied to each pattern info in turn.
181 | HasNames (NamedPats todo vars) where
182 | full gam [] = pure []
183 | full gam (x::xs) = [| (::) (full gam x) (full gam xs) |]
185 | resolved gam [] = pure []
186 | resolved gam (x::xs) = [| (::) (resolved gam x) (resolved gam xs) |]
-- Debug rendering as "[name pat [argType], ...]". `showAll` is a local
-- helper.
-- NOTE(review): lines 191 and 193 (presumably `where` and the `[]` case of
-- showAll) are not in this excerpt.
189 | {vars : _} -> {todo : _} -> Show (NamedPats todo vars) where
190 | show xs = "[" ++ showAll xs ++ "]"
192 | showAll : {vs, ts : _} -> NamedPats ts vs -> String
194 | showAll {ts = t :: _} [x]
195 | = show t ++ " " ++ show (pat x) ++ " [" ++ show (argType x) ++ "]"
196 | showAll {ts = t :: _} (x :: xs)
197 | = show t ++ " " ++ show (pat x) ++ " [" ++ show (argType x) ++ "]"
198 | ++ ", " ++ showAll xs
-- Pretty rendering: each pattern as "(name = pat)", separated by spaces.
-- NOTE(review): lines 202, 204 and 207-208 of the `prettyAll` helper are not
-- in this excerpt; presumably they handle `[]` and the recursive tail.
200 | {vars : _} -> {todo : _} -> Pretty IdrisSyntax (NamedPats todo vars) where
201 | pretty xs = hsep $
prettyAll xs
203 | prettyAll : {vs, ts : _} -> NamedPats ts vs -> List (Doc IdrisSyntax)
205 | prettyAll {ts = t :: _} (x :: xs)
206 | = parens (pretty0 t <++> equals <++> pretty (pat x))
-- Weakening an argument type weakens the contained type term.
209 | Weaken ArgType where
210 | weaken (Known c ty) = Known c (weaken ty)
211 | weaken (Stuck fty) = Stuck (weaken fty)
212 | weaken Unknown = Unknown
214 | weakenNs s (Known c ty) = Known c (weakenNs s ty)
215 | weakenNs s (Stuck fty) = Stuck (weakenNs s fty)
216 | weakenNs s Unknown = Unknown
-- Weakening a pattern info shifts its variable proof and weakens its
-- argument type. Only weakenNs is defined here.
218 | Weaken (PatInfo p) where
219 | weakenNs s (MkInfo p el fty) = MkInfo p (weakenIsVar s el) (weakenNs s fty)
-- Weakening a pattern list weakens each pattern.
-- NOTE(review): line 222 (presumably `weaken [] = []`) is not in this
-- excerpt.
221 | Weaken (NamedPats todo) where
223 | weaken (p :: ps) = weaken p :: weaken ps
225 | weakenNs ns [] = []
226 | weakenNs ns (p :: ps) = weakenNs ns p :: weakenNs ns ps
-- Concatenate two pattern lists over the same scope.
-- NOTE(review): the `[]` clause (line 229) is not in this excerpt.
228 | (++) : NamedPats ms vars -> NamedPats ns vars -> NamedPats (ms ++ ns) vars
230 | (++) (x :: xs) ys = x :: xs ++ ys
-- Drop the first pattern of a non-empty pattern list.
232 | tail : NamedPats (p :: ps) vars -> NamedPats ps vars
233 | tail (x :: xs) = xs
-- A clause being compiled. Its fields are:
--  * the pattern variables bound so far (varRule prepends to this list);
--  * the patterns for the remaining arguments `todo`;
--  * an Int clause identifier (`match` puts it into `STerm`);
--  * the right-hand side.
235 | data PatClause : (todo : List Name) -> Scoped where
236 | MkPatClause : List Name ->
237 | NamedPats todo vars ->
238 | Int -> (rhs : Term vars) -> PatClause todo vars
-- The patterns of a clause.
240 | getNPs : PatClause todo vars -> NamedPats todo vars
241 | getNPs (MkPatClause _ lhs pid rhs) = lhs
-- Debug rendering as "<patterns> => <rhs>".
244 | {vars : _} -> {todo : _} -> Show (PatClause todo vars) where
245 | show (MkPatClause _ ps pid rhs)
246 | = show ps ++ " => " ++ show rhs
-- Pretty rendering: the patterns, a fat arrow, then the shown rhs.
248 | {vars : _} -> {todo : _} -> Pretty IdrisSyntax (PatClause todo vars) where
249 | pretty (MkPatClause _ ps _ rhs)
250 | = pretty ps <++> fatArrow <++> byShow rhs
-- Name normalisation of a clause's pattern-variable names, patterns and rhs.
-- The clause id is left unchanged.
252 | HasNames (PatClause todo vars) where
253 | full gam (MkPatClause ns nps i rhs)
254 | = [| MkPatClause (traverse (full gam) ns) (full gam nps) (pure i) (full gam rhs) |]
256 | resolved gam (MkPatClause ns nps i rhs)
257 | = [| MkPatClause (traverse (resolved gam) ns) (resolved gam nps) (pure i) (resolved gam rhs) |]
-- A clause is a constructor clause when its first pattern satisfies
-- IsConPat.
259 | 0 IsConClause : PatClause (a :: todo) vars -> Type
260 | IsConClause (MkPatClause _ (MkInfo pat _ _ :: _) _ _) = IsConPat pat
-- Substitute the term for the clause's first pattern (`mkTerm vars pat`)
-- for the first column's name `a` in the types of the remaining patterns.
-- The first pattern, and with it the IsConClause proof, is unchanged.
262 | substInClause : {a, vars, todo : _} ->
263 | {auto c : Ref Ctxt Defs} ->
264 | FC -> Subset (PatClause (a :: todo) vars) IsConClause ->
265 | Core (Subset (PatClause (a :: todo) vars) IsConClause)
266 | substInClause fc (Element (MkPatClause pvars (MkInfo pat pprf fty :: pats) pid rhs) isCons)
267 | = do pats' <- substInPats fc a (mkTerm vars pat) pats
268 | pure $
Element (MkPatClause pvars (MkInfo pat pprf fty :: pats') pid rhs) isCons
-- A clause list split into consecutive non-empty runs (List01 True) of
-- constructor clauses and variable clauses, in order. A constructor run
-- carries erased proofs that each of its clauses is a constructor clause.
-- The index `cs ++ ps` / `vs ++ ps` ties the runs back to the original list.
270 | data Partitions : List01 ne (PatClause (a :: todo) vars) -> Type where
271 | ConClauses : {a, todo, vars : _} ->
272 | {ps : List01 ne (PatClause (a :: todo) vars)} ->
273 | (cs : List01 True (PatClause (a :: todo) vars )) ->
274 | (0 isCons : All IsConClause cs) =>
275 | Partitions ps -> Partitions (cs ++ ps)
276 | VarClauses : {a, todo, vars : _} ->
277 | {ps : List01 ne (PatClause (a :: todo) vars)} ->
278 | (vs : List01 True (PatClause (a :: todo) vars)) ->
279 | Partitions ps -> Partitions (vs ++ ps)
280 | NoClauses : Partitions []
-- Debug rendering. Each run is a "CON" or "VAR" header line followed by its
-- indented clauses. Runs are separated by "\n, ", and the list ends with
-- "NONE".
283 | {ps : _} -> Show (Partitions ps) where
284 | show (ConClauses cs rest)
285 | = unlines ("CON" :: map ((" " ++) . show) (forget cs))
286 | ++ "\n, " ++ show rest
287 | show (VarClauses vs rest)
288 | = unlines ("VAR" :: map ((" " ++) . show) (forget vs))
289 | ++ "\n, " ++ show rest
290 | show NoClauses = "NONE"
-- Classification of a clause by its first pattern. ConClause carries an
-- erased IsConClause proof.
292 | data ClauseType : PatClause (a :: todo) vars -> Type where
293 | ConClause : (0 isCon : IsConClause p) => ClauseType p
294 | VarClause : ClauseType p
-- True when every as-pattern name and every local (PLoc) name in the
-- pattern is in `pvars`. The search recurses through constructor,
-- type-constructor, arrow and delay sub-patterns; any other pattern counts
-- as True.
296 | namesIn : List Name -> Pat -> Bool
297 | namesIn pvars (PAs _ n p) = (n `elem` pvars) && namesIn pvars p
298 | namesIn pvars (PCon _ _ _ _ ps) = all (namesIn pvars) ps
299 | namesIn pvars (PTyCon _ _ _ ps) = all (namesIn pvars) ps
300 | namesIn pvars (PArrow _ _ s t) = namesIn pvars s && namesIn pvars t
301 | namesIn pvars (PDelay _ _ t p) = namesIn pvars t && namesIn pvars p
302 | namesIn pvars (PLoc _ n) = n `elem` pvars
303 | namesIn pvars _ = True
-- All as-pattern names and local names in a pattern, left to right.
-- NOTE(review): line 312 (presumably a catch-all clause for the remaining
-- pattern forms) is not in this excerpt.
305 | namesFrom : Pat -> List Name
306 | namesFrom (PAs _ n p) = n :: namesFrom p
307 | namesFrom (PCon _ _ _ _ ps) = concatMap namesFrom ps
308 | namesFrom (PTyCon _ _ _ ps) = concatMap namesFrom ps
309 | namesFrom (PArrow _ _ s t) = namesFrom s ++ namesFrom t
310 | namesFrom (PDelay _ _ t p) = namesFrom t ++ namesFrom p
311 | namesFrom (PLoc _ n) = [n]
-- Classify a clause by its first pattern. It is a constructor clause if
-- getClauseType produces an IsConPat proof for that pattern, and a variable
-- clause otherwise.
-- NOTE(review): lines 315-319 and several lines of the where-helpers (322-324,
-- 327-328, 336, 341-342, 345-346, 348) are not in this excerpt.
314 | clauseType : Phase -> (p : PatClause (a :: as) vars) -> ClauseType p
320 | clauseType phase (MkPatClause pvars (MkInfo arg _ ty :: rest) pid rhs)
321 | = maybe VarClause (\isCon => ConClause @{isCon}) $
getClauseType phase arg ty
-- splitCon: gives the IsConPat proof (here `So True`) only when the pattern
-- has exactly `arity` arguments, i.e. it is fully applied.
325 | splitCon : Nat -> List Pat -> Maybe (So True)
326 | splitCon arity xs = toMaybe (arity == length xs) Oh
-- clauseType': purely syntactic classification. Fully applied data and type
-- constructors, constants, arrows and delays are constructor patterns;
-- everything else is not.
329 | clauseType' : (p : Pat) -> Maybe (IsConPat p)
330 | clauseType' (PCon _ _ _ a xs) = splitCon a xs
331 | clauseType' (PTyCon _ _ a xs) = splitCon a xs
332 | clauseType' (PConst _ x) = Just Oh
333 | clauseType' (PArrow _ _ s t) = Just Oh
334 | clauseType' (PDelay {}) = Just Oh
335 | clauseType' _ = Nothing
-- getClauseType: classification that also consults the argument type.
--  * At compile time, a data-constructor pattern gets a special case when
--    all of the following hold:
--      - its argument's multiplicity `r` is erased;
--      - the phase's RigCount `cr` is not erased;
--      - its sub-patterns only use names from `pvars` or from the remaining
--        patterns.
--    The branches for this case (lines 341-342) are not shown.
--  * As-patterns are classified by their inner pattern.
--  * For other Known arguments the result depends on `isErased r`; the
--    branches (lines 345-346) are not shown.
--  * Otherwise the syntactic clauseType' is used.
337 | getClauseType : Phase -> (p : Pat) -> ArgType vars -> Maybe (IsConPat p)
338 | getClauseType (CompileTime cr) (PCon _ _ _ a xs) (Known r t)
339 | = if (isErased r && not (isErased cr) &&
340 | all (namesIn (pvars ++ concatMap namesFrom (getPatInfo rest))) xs)
343 | getClauseType phase (PAs _ _ p) t = getClauseType phase p t
344 | getClauseType phase l (Known r t) = if isErased r
347 | getClauseType phase l _ = clauseType' l
-- Split clauses into runs by ClauseType. The tail is partitioned first.
-- The head clause then either joins the run that starts right after it
-- (when it is of the same kind) or starts a new single-clause run in front
-- of it. The resulting runs are therefore maximal and in the original order.
349 | partition : {a, as, vars : _} ->
350 | Phase -> (ps : List01 ne (PatClause (a :: as) vars)) -> Partitions ps
351 | partition phase [] = NoClauses
352 | partition phase (x :: xs) with (partition phase xs)
353 | partition phase (x :: .(cs ++ ps)) | (ConClauses cs rest)
354 | = case clauseType phase x of
355 | ConClause => ConClauses (x :: cs) rest
356 | VarClause => VarClauses [x] (ConClauses cs rest)
357 | partition phase (x :: .(vs ++ ps)) | (VarClauses vs rest)
358 | = case clauseType phase x of
359 | ConClause => ConClauses [x] (VarClauses vs rest)
360 | VarClause => VarClauses (x :: vs) rest
361 | partition phase [x] | NoClauses
362 | = case clauseType phase x of
363 | ConClause => ConClauses [x] NoClauses
364 | VarClause => VarClauses [x] NoClauses
-- The head a clause is matched on: a named constructor with its tag, or a
-- constant.
-- NOTE(review): line 368 is not in this excerpt. Presumably it declares
-- CDelay, which is used by GroupMatch and checkGroupMatch.
366 | data ConType : Type where
367 | CName : Name -> (tag : Int) -> ConType
369 | CConst : Constant -> ConType
-- Clauses gathered under one case alternative. Each group holds a non-empty
-- clause list.
--  * ConGroup: a constructor name and tag. The constructor's argument names
--    `newargs` are prepended to both the remaining pattern names and the
--    scope.
--  * DelayGroup: the same, with a type argument and a value argument.
--  * ConstGroup: a constant, with no new arguments.
-- NOTE(review): lines 372, 376, 380 and 382-384 are not in this excerpt.
371 | data Group : List Name ->
373 | ConGroup : {newargs : _} ->
374 | Name -> (tag : Int) ->
375 | List01 True (PatClause (newargs ++ todo) (newargs ++ vars)) ->
377 | DelayGroup : {tyarg, valarg : _} ->
378 | List01 True (PatClause (tyarg :: valarg :: todo)
379 | (tyarg :: valarg :: vars)) ->
381 | ConstGroup : Constant -> List01 True (PatClause todo vars) ->
-- Debug rendering of a group's head and its clauses.
385 | {vars : _} -> {todo : _} -> Show (Group todo vars) where
386 | show (ConGroup c t cs) = "Con " ++ show c ++ ": " ++ show cs
387 | show (DelayGroup cs) = "Delay: " ++ show cs
388 | show (ConstGroup c cs) = "Const " ++ show c ++ ": " ++ show cs
-- Evidence about whether a head (ConType) with argument patterns `ps`
-- belongs to group `g`. The matching constructors expose the group's first
-- clause and its remaining clauses.
--  * ConMatch: carries a LengthMatch between `ps` and the group's new
--    argument names.
--  * DelayMatch / ConstMatch: apply to delay / constant groups when there
--    are no argument patterns.
--  * NoMatch: any other case.
390 | data GroupMatch : ConType -> List Pat -> Group todo vars -> Type where
391 | ConMatch : {tag : Int} -> LengthMatch ps newargs ->
392 | GroupMatch (CName n tag) ps
393 | (ConGroup {newargs} n tag (MkPatClause pvs pats pid rhs :: rest))
394 | DelayMatch : GroupMatch CDelay []
395 | (DelayGroup {tyarg} {valarg} (MkPatClause pvs pats pid rhs :: rest))
396 | ConstMatch : GroupMatch (CConst c) []
397 | (ConstGroup c (MkPatClause pvs pats pid rhs :: rest))
398 | NoMatch : GroupMatch ct ps g
-- Decide GroupMatch.
--  * A named constructor matches a ConGroup when the pattern count matches
--    the group's new arguments (checkLengthMatch) and both the name
--    (nameEq) and the tag (decEq) agree.
--  * A delay with no argument patterns is checked against a DelayGroup.
--  * A constant with no argument patterns matches a ConstGroup with an
--    equal constant (constantEq).
--  * Anything else is NoMatch.
-- NOTE(review): lines 401, 404, 407, 410 and 413 (the return type, the
-- failure branches and the DelayGroup right-hand side) are not in this
-- excerpt.
400 | checkGroupMatch : (c : ConType) -> (ps : List Pat) -> (g : Group todo vars) ->
402 | checkGroupMatch (CName x tag) ps (ConGroup {newargs} x' tag' (MkPatClause pvs pats pid rhs :: rest))
403 | = case checkLengthMatch ps newargs of
405 | Just prf => case (nameEq x x', decEq tag tag') of
406 | (Just Refl, Yes Refl) => ConMatch prf
408 | checkGroupMatch (CName x tag) ps _ = NoMatch
409 | checkGroupMatch CDelay [] (DelayGroup (MkPatClause pvs pats pid rhs :: rest))
411 | checkGroupMatch (CConst c) [] (ConstGroup c' (MkPatClause pvs pats pid rhs :: rest))
412 | = case constantEq c c' of
414 | Just Refl => ConstMatch
415 | checkGroupMatch _ _ _ = NoMatch
-- Label for the Ref holding the Int counter used to generate fresh names.
417 | data PName : Type where
-- Generate a fresh name from the given root string, reading the PName
-- counter.
-- NOTE(review): lines 421 and 423-425 are not in this excerpt. Presumably
-- they bind the root, increment the counter and build the name; confirm.
419 | nextName : {auto i : Ref PName Int} ->
420 | String -> Core Name
422 | = do x <- get PName
-- Create one fresh name per pattern in the list (named from `root`). Returns
-- the new names, their count (SizeOf), and a NamedPats for them, scoped over
-- the new names followed by `vars`.
-- The optional normalised type `fty` is walked alongside the patterns:
--  * No type: the pattern's type is Unknown.
--  * A Pi binder: the pattern gets a Known type from the binder's
--    (evaluated, quoted) domain. The Pi's scope, instantiated with the new
--    name, becomes the type for the next pattern. Lines 442-443 and 445,
--    which are not shown, special-case some domains.
--  * Any other type: the pattern is Stuck on that type, and the rest become
--    Unknown.
-- Each pattern's type is weakened over the names created after it, and its
-- variable proof is First.
-- NOTE(review): lines 435, 438, 442-443, 445, 448, 452 and 456 are not in this
-- excerpt. In particular the binding of `n` (presumably a nextName call)
-- is not visible.
426 | nextNames : {vars : _} ->
427 | {auto i : Ref PName Int} ->
428 | {auto c : Ref Ctxt Defs} ->
429 | FC -> String -> List Pat -> Maybe (NF vars) ->
430 | Core (
args ** (SizeOf args, NamedPats args (args ++ vars)))
431 | nextNames fc root [] fty = pure (
[] ** (zero, []))
432 | nextNames fc root (p :: pats) fty
433 | = do defs <- get Ctxt
434 | empty <- clearDefs defs
436 | let env = mkEnv fc vars
437 | fa_tys <- the (Core (Maybe (NF vars), ArgType vars)) $
439 | Nothing => pure (Nothing, Unknown)
440 | Just (NBind pfc _ (Pi _ c _ fargc) fsc) =>
441 | do farg <- evalClosure defs fargc
444 | pure (Just !(fsc defs (toClosure defaultOpts env (Ref pfc Bound n))),
446 | _ => pure (Just !(fsc defs (toClosure defaultOpts env (Ref pfc Bound n))),
447 | Known c !(quote empty env farg))
449 | pure (Nothing, Stuck !(quote empty env t))
450 | (
args ** (l, ps))
<- nextNames fc root pats (fst fa_tys)
451 | let argTy = case snd fa_tys of
453 | Known rig t => Known rig (weakenNs (suc l) t)
454 | Stuck t => Stuck (weakenNs (suc l) t)
455 | pure (
n :: args ** (suc l, MkInfo p First argTy :: weaken ps))
-- Replace the pattern field of the first `length pargs` named patterns with
-- `pargs`, keeping their names, proofs and types. The remaining `todo`
-- patterns are dropped.
-- NOTE(review): the result type (line 460) is not in this excerpt.
458 | newPats : (pargs : List Pat) -> LengthMatch pargs ns ->
459 | NamedPats (ns ++ todo) vars ->
461 | newPats [] NilMatch rest = []
462 | newPats (newpat :: xs) (ConsMatch w) (pi :: rest)
463 | = { pat := newpat } pi :: newPats xs w rest
-- Build a renaming from (new name, pattern) pairs. Each pattern that is a
-- local variable `PLoc p` gives the pair (p, new name).
-- NOTE(review): line 470 (presumably `update _ = Nothing`) is not in this
-- excerpt.
465 | updateNames : List (Name, Pat) -> List (Name, Name)
466 | updateNames = mapMaybe update
468 | update : (Name, Pat) -> Maybe (Name, Name)
469 | update (n, PLoc fc p) = Just (p, n)
-- Rename as-pattern names and local names throughout every pattern, using
-- the (old, new) pairs `ns`. Names not found in `ns` are kept.
-- NOTE(review): lines 486 (presumably the `update (PLoc fc n)` head) and 490
-- are not in this excerpt.
472 | updatePatNames : List (Name, Name) -> NamedPats todo vars -> NamedPats todo vars
473 | updatePatNames _ [] = []
474 | updatePatNames ns (pi :: ps)
475 | = { pat $= update } pi :: updatePatNames ns ps
477 | update : Pat -> Pat
478 | update (PAs fc n p)
479 | = case lookup n ns of
480 | Nothing => PAs fc n (update p)
481 | Just n' => PAs fc n' (update p)
482 | update (PCon fc n i a ps) = PCon fc n i a (map update ps)
483 | update (PTyCon fc n a ps) = PTyCon fc n a (map update ps)
484 | update (PArrow fc x s t) = PArrow fc x (update s) (update t)
485 | update (PDelay fc r t p) = PDelay fc r (update t) (update p)
487 | = case lookup n ns of
488 | Nothing => PLoc fc n
489 | Just n' => PLoc fc n'
-- Group a run of constructor clauses by the head of their first pattern.
-- Groups appear in order of first occurrence, and clauses keep their
-- relative order inside each group: new groups and new clauses are always
-- appended at the end. Each clause's constructor argument patterns become
-- new leading pattern columns within its group.
-- NOTE(review): lines 495-496 (presumably the FC / Name / List Name
-- parameters, going by `groupCons fc fn pvars ...`) and several lines inside
-- the where-helpers are not in this excerpt.
492 | groupCons : {a, vars, todo : _} ->
493 | {auto i : Ref PName Int} ->
494 | {auto ct : Ref Ctxt Defs} ->
497 | (cs : List01 True (PatClause (a :: todo) vars)) ->
498 | (0 isCons : All IsConClause cs) =>
499 | Core (List01 True (Group todo vars))
500 | groupCons fc fn pvars (x :: xs) {isCons = p :: ps}
501 | = foldlC (uncurry . gc) !(gc [] x p) $
pushIn xs ps
-- addConG: add a constructor clause with head `n` / `tag` and argument
-- patterns `pargs`.
--  * No matching group: a new ConGroup is appended. Its argument names come
--    from nextNames, typed against the constructor's type. That type is
--    looked up with lookupTyExact, or is a synthesised two-argument function
--    type for "->", or an erased placeholder if the lookup fails.
--  * A matching group: the clause is appended to that group's clauses,
--    reusing the group's argument names and types (via newPats on the
--    group's first clause).
-- In both cases, local names in the argument patterns are renamed to the
-- new argument names in the remaining patterns, and the rhs is weakened
-- over the new names.
-- NOTE(review): in the matching case, `pvars` is rebound from the group's
-- first clause (shadowing groupCons's parameter), so the added clause takes
-- that list rather than its own. `gc` also discards each clause's own list.
-- Confirm this is intended.
503 | addConG : {vars', todo' : _} ->
504 | Name -> (tag : Int) ->
505 | List Pat -> NamedPats todo' vars' ->
506 | Int -> (rhs : Term vars') ->
507 | (acc : List01 ne (Group todo' vars')) ->
508 | Core (List01 True (Group todo' vars'))
513 | addConG n tag pargs pats pid rhs []
514 | = do cty <- if n == UN (Basic "->")
515 | then pure $
NBind fc (MN "_" 0) (Pi fc top Explicit (MkNFClosure defaultOpts (mkEnv fc vars') (NType fc (MN "top" 0)))) $
516 | (\d, a => pure $
NBind fc (MN "_" 1) (Pi fc top Explicit (MkNFClosure defaultOpts (mkEnv fc vars') (NErased fc Placeholder)))
517 | (\d, a => pure $
NType fc (MN "top" 0)))
518 | else do defs <- get Ctxt
519 | Just t <- lookupTyExact n (gamma defs)
520 | | Nothing => pure (NErased fc Placeholder)
521 | nf defs (mkEnv fc vars') (embed t)
522 | (
patnames ** (l, newargs))
<- nextNames fc "e" pargs (Just cty)
525 | let pats' = updatePatNames (updateNames (zip patnames pargs))
527 | let clause = MkPatClause pvars (newargs ++ pats') pid (weakenNs l rhs)
528 | pure [ConGroup n tag [clause]]
529 | addConG n tag pargs pats pid rhs (g :: gs) with (checkGroupMatch (CName n tag) pargs g)
530 | addConG n tag pargs pats pid rhs
531 | (ConGroup n tag (MkPatClause pvars ps tid tm :: rest) :: gs) | ConMatch {newargs} lprf
532 | = do let newps = newPats pargs lprf ps
533 | let l = mkSizeOf newargs
534 | let pats' = updatePatNames (updateNames (zip newargs pargs))
536 | let newclause = MkPatClause pvars (newps ++ pats') pid (weakenNs l rhs)
539 | pure $
ConGroup n tag (MkPatClause pvars ps tid tm :: rest ++ [newclause]) :: gs
540 | addConG n tag pargs pats pid rhs (g :: gs) | NoMatch
541 | = (g ::) <$> addConG n tag pargs pats pid rhs gs
-- addDelayG: as addConG, for a Delay pattern with type pattern `pty` and
-- value pattern `parg`. A new group gets two fresh argument names, typed
-- against a synthesised type: an erased Type argument `a`, then an explicit
-- argument of type `a`, returning a delayed `a`. If nextNames does not
-- return exactly two names, an InternalError is thrown.
547 | addDelayG : {vars', todo' : _} ->
548 | Pat -> Pat -> NamedPats todo' vars' ->
549 | Int -> (rhs : Term vars') ->
550 | (acc : List01 ne (Group todo' vars')) ->
551 | Core (List01 True (Group todo' vars'))
552 | addDelayG pty parg pats pid rhs []
553 | = do let dty = NBind fc (MN "a" 0) (Pi fc erased Explicit (MkNFClosure defaultOpts (mkEnv fc vars') (NType fc (MN "top" 0)))) $
555 | do a' <- evalClosure d a
556 | pure (NBind fc (MN "x" 0) (Pi fc top Explicit a)
557 | (\dv, av => pure (NDelayed fc LUnknown a'))))
558 | (
[tyname, argname] ** (l, newargs))
<- nextNames fc "e" [pty, parg]
560 | | _ => throw (InternalError "Error compiling Delay pattern match")
561 | let pats' = updatePatNames (updateNames [(tyname, pty),
564 | let clause = MkPatClause pvars (newargs ++ pats') pid (weakenNs l rhs)
565 | pure [DelayGroup [clause]]
566 | addDelayG pty parg pats pid rhs (g :: gs) with (checkGroupMatch CDelay [] g)
567 | addDelayG pty parg pats pid rhs
568 | (DelayGroup (MkPatClause pvars ps tid tm :: rest) :: gs) | DelayMatch {tyarg} {valarg}
569 | = do let l = mkSizeOf [tyarg, valarg]
570 | let newps = newPats [pty, parg] (ConsMatch (ConsMatch NilMatch)) ps
571 | let pats' = updatePatNames (updateNames [(tyarg, pty),
574 | let newclause = MkPatClause pvars (newps ++ pats') pid (weakenNs l rhs)
575 | pure $
DelayGroup (MkPatClause pvars ps tid tm :: rest ++ [newclause]) :: gs
576 | addDelayG pty parg pats pid rhs (g :: gs) | NoMatch
577 | = (g ::) <$> addDelayG pty parg pats pid rhs gs
-- addConstG: constants add no argument columns. The clause either starts a
-- new ConstGroup at the end or is appended to the matching group.
579 | addConstG : {vars', todo' : _} ->
580 | Constant -> NamedPats todo' vars' ->
581 | Int -> (rhs : Term vars') ->
582 | (acc : List01 ne (Group todo' vars')) ->
583 | Core (List01 True (Group todo' vars'))
584 | addConstG c pats pid rhs []
585 | = pure [ConstGroup c [MkPatClause pvars pats pid rhs]]
586 | addConstG c pats pid rhs (g :: gs) with (checkGroupMatch (CConst c) [] g)
587 | addConstG c pats pid rhs
588 | (ConstGroup c (MkPatClause pvars ps tid tm :: rest) :: gs) | ConstMatch
589 | = do let newclause = MkPatClause pvars pats pid rhs
590 | pure $
ConstGroup c (MkPatClause pvars ps tid tm :: rest ++ [newclause]) :: gs
591 | addConstG c pats pid rhs (g :: gs) | NoMatch
592 | = (g ::) <$> addConstG c pats pid rhs gs
-- addGroup: dispatch on the head pattern.
--  * An as-pattern substitutes a local reference to the scrutinee variable
--    for its name in the rhs, then groups by the inner pattern.
--  * Data and type constructors must be fully applied, otherwise
--    NotFullyApplied is thrown. Type constructors and "->" use tag 0.
--  * Delays and constants go to their own helpers.
595 | addGroup : {vars, todo, idx : _} ->
595 | (pat : Pat) -> (0 _ : IsConPat pat) =>
596 | (0 p : IsVar nm idx vars) ->
597 | NamedPats todo vars -> Int -> Term vars ->
598 | List01 ne (Group todo vars) ->
599 | Core (List01 True (Group todo vars))
602 | addGroup (PAs fc n p) pprf pats pid rhs acc
603 | = addGroup p pprf pats pid (substName n (Local fc (Just True) idx pprf) rhs) acc
604 | addGroup (PCon cfc n t a pargs) pprf pats pid rhs acc
605 | = if a == length pargs
606 | then addConG n t pargs pats pid rhs acc
607 | else throw (CaseCompile cfc fn (NotFullyApplied n))
608 | addGroup (PTyCon cfc n a pargs) pprf pats pid rhs acc
609 | = if a == length pargs
610 | then addConG n 0 pargs pats pid rhs acc
611 | else throw (CaseCompile cfc fn (NotFullyApplied n))
612 | addGroup (PArrow _ _ s t) pprf pats pid rhs acc
613 | = addConG (UN $
Basic "->") 0 [s, t] pats pid rhs acc
616 | addGroup (PDelay _ _ pty parg) pprf pats pid rhs acc
617 | = addDelayG pty parg pats pid rhs acc
618 | addGroup (PConst _ c) pprf pats pid rhs acc
619 | = addConstG c pats pid rhs acc
-- gc: add one constructor clause (with its IsConClause proof) to the
-- accumulated groups by passing its first pattern to addGroup. The clause's
-- own pattern-variable list is not used.
621 | gc : {a, vars, todo : _} ->
622 | List01 ne (Group todo vars) ->
623 | (p : PatClause (a :: todo) vars) ->
624 | (0 _ : IsConClause p) ->
625 | Core (List01 True (Group todo vars))
626 | gc acc (MkPatClause _ (MkInfo pat pprf _ :: pats) pid rhs) isCon
627 | = addGroup pat pprf pats pid rhs acc
-- The pattern in the first column of a non-empty pattern list.
629 | getFirstPat : NamedPats (p :: ps) ns -> Pat
630 | getFirstPat (p :: _) = pat p
-- The argument type recorded for the first column.
632 | getFirstArgType : NamedPats (p :: ps) ns -> ArgType ns
633 | getFirstArgType (p :: _) = argType p
-- A non-empty list of pattern rows (one per clause) over columns `p :: ps`,
-- with one Int score per column.
637 | data ScoredPats : List Name -> Scoped where
638 | Scored : List01 True (NamedPats (p :: ps) ns) -> Vect (length (p :: ps)) Int -> ScoredPats (p :: ps) ns
-- Debug rendering as "<column names>//<scores>". The pattern rows are not
-- shown.
640 | {ps : _} -> Show (ScoredPats ps ns) where
641 | show (Scored xs ys) = (show ps) ++ "//" ++ (show ys)
-- Pair the rows with a score of 0 for every column.
643 | zeroedScore : {ps : _} -> List01 True (NamedPats (p :: ps) ns) -> ScoredPats (p :: ps) ns
644 | zeroedScore nps = Scored nps (replicate (S $
length ps) 0)
-- Walk the column names and their scores, tracking:
--  * the best score so far (`high`);
--  * the column it belongs to, as an NVar into `prev ++ names`, where `prev`
--    are the names already passed;
--  * whether another column has matched that best score (`duped`).
-- Returns the highest-scoring column, or Nothing when that highest score is
-- shared by more than one column.
-- NOTE(review): lines 649, 651, 653 and 660 (the remaining parameter types
-- and `in`) are not in this excerpt.
648 | highScore : {prev : List Name} ->
650 | (scores : Vect (length names) Int) ->
652 | (highIdx : (
n ** NVar n (prev ++ names))
) ->
654 | Maybe (
n ** NVar n (prev ++ names))
655 | highScore [] [] high idx True = Nothing
656 | highScore [] [] high idx False = Just idx
657 | highScore (x :: xs) (y :: ys) high idx duped =
658 | let next = highScore {prev = prev `snoc` x} xs ys
659 | prf = appendAssociative prev [x] xs
661 | case compare y high of
662 | LT => next high (rewrite sym $
prf in idx) duped
663 | EQ => next high (rewrite sym $
prf in idx) True
664 | GT => next y (
x ** rewrite sym $
prf in weakenNVar (mkSizeOf prev) (MkNVar First))
False
-- The uniquely highest-scoring column, if any. The initial best is one below
-- the first column's score, so the first column always becomes the
-- provisional winner.
671 | highScoreIdx : {p : _} -> {ps : _} -> ScoredPats (p :: ps) ns -> Maybe (
n ** NVar n (p :: ps))
672 | highScoreIdx (Scored xs (y :: ys)) = highScore {prev = []} (p :: ps) (y :: ys) (y - 1) (
p ** MkNVar First)
False
-- Penalty for a pattern's head. Data-constructor and type-constructor heads
-- get the given function applied to their arity; as-patterns and delays are
-- looked through to the inner pattern; constants, arrows, locals and
-- unmatchable patterns get 0.
676 | headConsPenalty : (penality : Nat -> Int) -> Pat -> Int
677 | headConsPenalty p (PAs _ _ w) = headConsPenalty p w
678 | headConsPenalty p (PCon _ n _ arity pats) = p arity
679 | headConsPenalty p (PTyCon _ _ arity _) = p arity
680 | headConsPenalty _ (PConst {}) = 0
681 | headConsPenalty _ (PArrow {}) = 0
682 | headConsPenalty p (PDelay _ _ _ w) = headConsPenalty p w
683 | headConsPenalty _ (PLoc {}) = 0
684 | headConsPenalty _ (PUnmatchable {}) = 0
-- Split off the first column of a non-empty list of pattern rows. Returns
-- the first column's pattern infos (one per row) and the remaining columns.
686 | splitColumn : (nps : List01 True (NamedPats (p :: ps) ns)) -> (Vect (length nps) (PatInfo p ns), List01 True (NamedPats ps ns))
687 | splitColumn [(w :: ws)] = ([w], [ws])
688 | splitColumn ((w :: ws) :: nps@(_ :: _)) = bimap (w ::) (ws ::) $
splitColumn nps
-- Add to each column's score the sum of `scorePat` over that column's
-- patterns in all rows.
-- NOTE(review): lines 696-697 (presumably `in Scored xs ys'` and `where`)
-- are not in this excerpt.
692 | consScoreHeuristic : {ps : _} -> (scorePat : Pat -> Int) -> ScoredPats ps ns -> ScoredPats ps ns
693 | consScoreHeuristic scorePat (Scored xs ys) =
694 | let columnScores = scoreColumns xs
695 | ys' = zipWith (+) ys columnScores
698 | scoreColumns : {ps' : _} -> (nps : List01 True (NamedPats ps' ns)) -> Vect (length ps') Int
699 | scoreColumns {ps' = []} nps = []
700 | scoreColumns {ps' = w :: ws} nps =
701 | let (col, nps') = splitColumn nps
702 | in sum (scorePat . pat <$> col) :: scoreColumns nps'
-- First-row heuristic: add 1 to each column whose pattern in the first row
-- is not blank (not a PLoc).
-- NOTE(review): lines 711, 714 (presumably `isBlank _ = False`), 715, 717
-- (the `[]` case of scores) and 720-722 are not in this excerpt.
706 | heuristicF : {ps : _} -> ScoredPats (p :: ps) ns -> ScoredPats (p :: ps) ns
707 | heuristicF (Scored (x :: xs) ys) =
708 | let columnScores = scores x
709 | ys' = zipWith (+) ys columnScores
710 | in Scored (x :: xs) ys'
712 | isBlank : Pat -> Bool
713 | isBlank (PLoc {}) = True
716 | scores : NamedPats ps' ns' -> Vect (length ps') Int
718 | scores (y :: ys) = let score : Int = if isBlank (pat y) then 0 else 1
719 | in score :: scores ys
-- Subtract 1 from a column's score for each pattern in it whose head is a
-- constructor of non-zero arity.
723 | heuristicB : {ps : _} -> ScoredPats ps ns -> ScoredPats ps ns
724 | heuristicB = consScoreHeuristic (headConsPenalty (\arity => if arity == 0 then 0 else -
1))
-- Subtract from a column's score the arity of each constructor head in it.
727 | heuristicA : {ps : _} -> ScoredPats ps ns -> ScoredPats ps ns
728 | heuristicA = consScoreHeuristic (headConsPenalty (negate . cast))
-- Return the uniquely highest-scoring column under the current scores.
-- On a tie, apply the next heuristic to the scores accumulated so far and
-- try again. Returns Nothing if every stage ties.
-- NOTE(review): line 731 is not in this excerpt.
730 | applyHeuristics : {p : _} ->
732 | ScoredPats (p :: ps) ns ->
733 | List (ScoredPats (p :: ps) ns -> ScoredPats (p :: ps) ns) ->
734 | Maybe (
n ** NVar n (p :: ps))
735 | applyHeuristics x [] = highScoreIdx x
736 | applyHeuristics x (f :: fs) = highScoreIdx x <|> applyHeuristics (f x) fs
-- Choose the column to split on next:
--  * heuristics disabled, or compile time: the first column;
--  * run time: start from zero scores and apply heuristicF, heuristicB and
--    heuristicA in that order, falling back to the first column on a tie.
-- NOTE(review): lines 744 and 746 (presumably the Phase parameter) are not
-- in this excerpt.
743 | nextIdxByScore : {p : _} ->
745 | (useHeuristics : Bool) ->
747 | List01 True (NamedPats (p :: ps) ns) ->
748 | (
n ** NVar n (p :: ps))
749 | nextIdxByScore False _ _ = (
_ ** (MkNVar First))
750 | nextIdxByScore _ (CompileTime _) _ = (
_ ** (MkNVar First))
751 | nextIdxByScore True RunTime xs =
752 | fromMaybe (
_ ** (MkNVar First)) $
753 | applyHeuristics (zeroedScore xs) [heuristicF, heuristicB, heuristicA]
-- Check that the first column has compatible types across all rows. The
-- first row's type must be Known. Every other row's type must also be Known
-- and have the same head (headEq) once normalised. Otherwise DifferingTypes
-- is thrown.
-- NOTE(review): lines 762, 768, 771, 774, 786 and 795-797 are not in this
-- excerpt. Line 768 presumably normalises the first row's type.
758 | sameType : {ns : _} ->
759 | {auto c : Ref Ctxt Defs} ->
760 | FC -> Phase -> Name ->
761 | Env Term ns -> List01 ne (NamedPats (p :: ps) ns) ->
763 | sameType fc phase fn env [] = pure ()
764 | sameType {ns} fc phase fn env (p :: xs)
765 | = do defs <- get Ctxt
766 | case getFirstArgType p of
767 | Known _ t => sameTypeAs phase
769 | (map getFirstArgType xs)
770 | ty => throw (CaseCompile fc fn DifferingTypes)
772 | firstPat : NamedPats (np :: nps) ns -> Pat
773 | firstPat (pinf :: _) = pat pinf
-- headEq: two normalised types have the same head when any of these hold:
--  * both are Pi types;
--  * they are the same type constructor;
--  * they are equal primitive values;
--  * both are Type;
--  * at run time only, they are applications of the same reference, or
--    either one is erased.
-- Dotted erased terms are looked through.
775 | headEq : NF ns -> NF ns -> Phase -> Bool
776 | headEq (NBind _ _ (Pi {}) _) (NBind _ _ (Pi {}) _) _ = True
777 | headEq (NTCon _ n _ _) (NTCon _ n' _ _) _ = n == n'
778 | headEq (NPrimVal _ c) (NPrimVal _ c') _ = c == c'
779 | headEq (NType {}) (NType {}) _ = True
780 | headEq (NApp _ (NRef _ n) _) (NApp _ (NRef _ n') _) RunTime = n == n'
781 | headEq (NErased _ (Dotted x)) y ph = headEq x y ph
782 | headEq x (NErased _ (Dotted y)) ph = headEq x y ph
783 | headEq (NErased {}) _ RunTime = True
784 | headEq _ (NErased {}) RunTime = True
785 | headEq _ _ _ = False
-- NOTE(review): this passes the outer `phase` to headEq rather than its own
-- parameter `ph`. At the only visible call site the two are the same value,
-- but using `ph` would be clearer.
787 | sameTypeAs : forall ne. Phase -> NF ns -> List01 ne (ArgType ns) -> Core ()
788 | sameTypeAs _ ty [] = pure ()
789 | sameTypeAs ph ty (Known r t :: xs) =
790 | do defs <- get Ctxt
791 | if headEq ty !(nf defs env t) phase
792 | then sameTypeAs ph ty xs
793 | else throw (CaseCompile fc fn DifferingTypes)
794 | sameTypeAs p ty _ = throw (CaseCompile fc fn DifferingTypes)
-- True when every row's first pattern has the same head as the first row's,
-- ignoring one top-level as-pattern.
-- NOTE(review): lines 799, 802 and 805-806 (presumably the clause head,
-- `where`, and `dropAs p = p`) are not in this excerpt.
798 | samePat : List01 True (NamedPats (p :: ps) ns) -> Bool
800 | = samePatAs (dropAs (getFirstPat pi))
801 | (map (dropAs . getFirstPat) xs)
803 | dropAs : Pat -> Pat
804 | dropAs (PAs _ _ p) = p
-- samePatAs: the heads agree when they are:
--  * the same type constructor;
--  * the same data constructor name and tag;
--  * equal constants;
--  * both arrows;
--  * both delays;
--  * both locals.
-- Anything else is False.
807 | samePatAs : Pat -> List01 ne Pat -> Bool
808 | samePatAs p [] = True
809 | samePatAs (PTyCon fc n a args) (PTyCon _ n' _ _ :: ps)
810 | = n == n' && samePatAs (PTyCon fc n a args) ps
811 | samePatAs (PCon fc n t a args) (PCon _ n' t' _ _ :: ps)
812 | = n == n' && t == t' && samePatAs (PCon fc n t a args) ps
813 | samePatAs (PConst fc c) (PConst _ c' :: ps)
814 | = c == c' && samePatAs (PConst fc c) ps
815 | samePatAs (PArrow fc x s t) (PArrow _ _ s' t' :: ps)
816 | = samePatAs (PArrow fc x s t) ps
817 | samePatAs (PDelay fc r t p) (PDelay _ _ _ _ :: ps)
818 | = samePatAs (PDelay fc r t p) ps
819 | samePatAs (PLoc fc n) (PLoc _ _ :: ps) = samePatAs (PLoc fc n) ps
820 | samePatAs x y = False
-- Run sameType on the first column of the rows. Success is `Right ()`, and a
-- CaseCompile error is returned as `Left err`.
-- NOTE(review): lines 829 and 831-834 are not in this excerpt. Presumably
-- other errors are rethrown.
822 | getScore : {ns : _} ->
823 | {auto c : Ref Ctxt Defs} ->
824 | FC -> Phase -> Name ->
825 | List01 True (NamedPats (p :: ps) ns) ->
826 | Core (Either CaseError ())
827 | getScore fc phase name npss
828 | = catch (Right () <$ sameType fc phase name (mkEnv fc ns) npss)
830 | CaseCompile _ _ err => pure $
Left err
-- Choose the first column, scanning left to right, that can be split on.
-- A column is accepted immediately when the condition on line 841 / 847 (not
-- shown; presumably samePat) holds. Otherwise getScore must succeed for it.
--  * Last column: a getScore failure is rethrown as a CaseCompile error.
--  * Earlier column: a getScore failure moves on to the next column.
-- NOTE(review): lines 839, 841 and 847 are not in this excerpt.
835 | pickNextViable : {p, ns, ps : _} ->
836 | {auto c : Ref Ctxt Defs} ->
837 | FC -> Phase -> Name -> List01 True (NamedPats (p :: ps) ns) ->
838 | Core (
n ** NVar n (p :: ps))
840 | pickNextViable {ps = []} fc phase fn npss
842 | then pure (
_ ** MkNVar First)
843 | else do Right () <- getScore fc phase fn npss
844 | | Left err => throw (CaseCompile fc fn err)
845 | pure (
_ ** MkNVar First)
846 | pickNextViable {ps = q :: qs} fc phase fn npss
848 | then pure (
_ ** MkNVar First)
849 | else case !(getScore fc phase fn npss) of
850 | Right () => pure (
_ ** MkNVar First)
851 | _ => do (
_ ** MkNVar var)
<- pickNextViable fc phase fn (map tail npss)
852 | pure (
_ ** MkNVar (Later var))
-- Move the pattern at position `el` to the front, keeping the others in
-- order.
854 | moveFirst : {idx : Nat} -> (0 el : IsVar nm idx ps) -> NamedPats ps ns ->
855 | NamedPats (nm :: dropIsVar ps el) ns
856 | moveFirst el nps = getPat el nps :: dropPat el nps
-- Move column `el` of a clause to the front. The clause is returned
-- unchanged when `el` is already First.
858 | shuffleVars : {idx : Nat} -> (0 el : IsVar nm idx todo) -> PatClause todo vars ->
859 | PatClause (nm :: dropIsVar todo el) vars
860 | shuffleVars First orig@(MkPatClause pvars lhs pid rhs) = orig
861 | shuffleVars el (MkPatClause pvars lhs pid rhs)
862 | = MkPatClause pvars (moveFirst el lhs) pid rhs
-- Compile clauses into a case tree.
--  * Columns remain: move the column chosen by nextIdxByScore to the front,
--    then move the first viable column (pickNextViable) to the front. Next,
--    partition the clauses into constructor / variable runs and compile the
--    partition with `mixture`.
--  * No columns remain: the first clause's rhs becomes `STerm pid rhs`. An
--    Impossible rhs is handled separately (line 896 is not shown).
-- NOTE(review): lines 878-880, 894, 896 and 899 are not in this excerpt.
871 | match : {vars, todo : _} ->
872 | {auto i : Ref PName Int} ->
873 | {auto c : Ref Ctxt Defs} ->
874 | FC -> Name -> Phase ->
875 | List01 True (PatClause todo vars) ->
876 | IMaybe ne (CaseTree vars) ->
877 | Core (CaseTree vars)
881 | match {todo = _ :: _} fc fn phase clauses err
882 | = do let nps = getNPs <$> clauses
883 | let (
_ ** (MkNVar next))
= nextIdxByScore (caseTreeHeuristics !getSession) phase nps
884 | let prioritizedClauses = shuffleVars next <$> clauses
885 | (
n ** MkNVar next')
<- pickNextViable fc phase fn (getNPs <$> prioritizedClauses)
886 | log "compile.casetree.pick" 25 $
"Picked " ++ show n ++ " as the next split"
887 | let clauses' = shuffleVars next' <$> prioritizedClauses
888 | log "compile.casetree.clauses" 25 $
889 | unlines ("Using clauses:" :: map ((" " ++) . show) (forget clauses'))
890 | let ps = partition phase clauses'
891 | log "compile.casetree.partition" 25 $
"Got Partition:\n" ++ show ps
892 | Just mix <- mixture fc fn phase ps err
893 | log "compile.casetree.intermediate" 25 $
"match: new case tree " ++ show mix
895 | match {todo = []} fc fn phase (MkPatClause pvars [] pid (Erased _ Impossible) :: _) err
897 | match {todo = []} fc fn phase (MkPatClause pvars [] pid rhs :: _) err
898 | = pure $
STerm pid rhs
-- Build a Case on the variable `el`, whose type `ty` has its names resolved.
-- There is one alternative per group, plus a DefaultCase from the error
-- case when one is present.
--  * Constructor and delay alternatives recursively `match` the group's
--    clauses, with the error case weakened over the new argument names.
--  * Constant alternatives match with the error case unchanged.
-- NOTE(review): line 909 (presumably `where`) is not in this excerpt.
900 | caseGroups : {pvar, vars, todo : _} ->
901 | {auto i : Ref PName Int} ->
902 | {auto c : Ref Ctxt Defs} ->
903 | FC -> Name -> Phase ->
904 | {idx : Nat} -> (0 p : IsVar pvar idx vars) -> Term vars ->
905 | List01 True (Group todo vars) -> IMaybe ne (CaseTree vars) ->
906 | Core (CaseTree vars)
907 | caseGroups fc fn phase el ty gs errorCase
908 | = Case idx el (resolveNames vars ty) <$> altGroups gs
910 | altGroups : forall ne. List01 ne (Group todo vars) -> Core (List (CaseAlt vars))
911 | altGroups [] = pure $
toList $
DefaultCase <$> errorCase
912 | altGroups (ConGroup {newargs} cn tag rest :: cs)
913 | = do crest <- match fc fn phase rest (map (weakenNs (mkSizeOf newargs)) errorCase)
914 | cs' <- altGroups cs
915 | pure (ConCase cn tag newargs crest :: cs')
916 | altGroups (DelayGroup {tyarg} {valarg} rest :: cs)
917 | = do crest <- match fc fn phase rest (map (weakenNs (mkSizeOf [tyarg, valarg])) errorCase)
918 | cs' <- altGroups cs
919 | pure (DelayCase tyarg valarg crest :: cs')
920 | altGroups (ConstGroup c rest :: cs)
921 | = do crest <- match fc fn phase rest errorCase
922 | cs' <- altGroups cs
923 | pure (ConstCase c crest :: cs')
-- Compile a run of constructor clauses:
--  1. Substitute each clause's head pattern into the types of its remaining
--     patterns (substInClause).
--  2. Group the clauses by constructor.
--  3. Build the case on the first column's variable, using that column's
--     Known type from the first clause. A non-Known type throws UnknownType.
-- NOTE(review): lines 933-936 and 940 (presumably the `case` on `fty`) are
-- not in this excerpt.
925 | conRule : {a, vars, todo : _} ->
926 | {auto i : Ref PName Int} ->
927 | {auto c : Ref Ctxt Defs} ->
928 | FC -> Name -> Phase ->
929 | (cs : List01 True (PatClause (a :: todo) vars)) ->
930 | (0 isCons : All IsConClause cs) =>
931 | IMaybe ne (CaseTree vars) ->
932 | Core (CaseTree vars)
937 | conRule {a} fc fn phase cs@(MkPatClause pvars (MkInfo pat pprf fty :: pats) pid rhs :: rest) err
938 | = do Element refinedcs _ <- pullOut <$> traverseList01 (substInClause fc) (pushIn cs isCons)
939 | groups <- groupCons fc fn pvars refinedcs
941 | Known _ t => pure t
942 | _ => throw (CaseCompile fc fn UnknownType)
943 | caseGroups fc fn phase pprf ty groups err
945 | varRule : {a, vars, todo : _} ->
946 | {auto i : Ref PName Int} ->
947 | {auto c : Ref Ctxt Defs} ->
948 | FC -> Name -> Phase ->
949 | List01 True (PatClause (a :: todo) vars) ->
950 | IMaybe ne (CaseTree vars) ->
951 | Core (CaseTree vars)
952 | varRule fc fn phase cs err
953 | = do alts' <- traverseList01 updateVar cs
954 | match fc fn phase alts' err
956 | updateVar : PatClause (a :: todo) vars -> Core (PatClause todo vars)
958 | updateVar (MkPatClause pvars (MkInfo (PLoc pfc n) prf fty :: pats) pid rhs)
959 | = pure $
MkPatClause (n :: pvars)
960 | !(substInPats fc a (Local pfc (Just False) _ prf) pats)
961 | pid (substName n (Local pfc (Just False) _ prf) rhs)
964 | updateVar (MkPatClause pvars (MkInfo (PAs pfc n pat) prf fty :: pats) pid rhs)
965 | = do pats' <- substInPats fc a (mkTerm _ pat) pats
966 | let rhs' = substName n (Local pfc (Just True) _ prf) rhs
967 | updateVar (MkPatClause pvars (MkInfo pat prf fty :: pats') pid rhs')
970 | updateVar (MkPatClause pvars (MkInfo pat prf fty :: pats) pid rhs)
971 | = pure $
MkPatClause pvars
972 | !(substInPats fc a (mkTerm vars pat) pats) pid rhs
974 | mixture : {a, vars, todo : _} ->
975 | {auto i : Ref PName Int} ->
976 | {auto c : Ref Ctxt Defs} ->
977 | {ps : List01 ne (PatClause (a :: todo) vars)} ->
978 | FC -> Name -> Phase ->
980 | IMaybe neErr (CaseTree vars) ->
981 | Core (IMaybe (ne || neErr) (CaseTree vars))
982 | mixture fc fn phase (ConClauses cs rest) err
983 | = do fallthrough <- mixture fc fn phase rest err
984 | Just <$> conRule fc fn phase cs fallthrough
985 | mixture fc fn phase (VarClauses vs rest) err
986 | = do fallthrough <- mixture fc fn phase rest err
987 | Just <$> varRule fc fn phase vs fallthrough
988 | mixture fc fn phase NoClauses err
992 | mkPat : {auto c : Ref Ctxt Defs} ->
993 | (matchable : Bool) -> List Pat -> ClosedTerm -> ClosedTerm -> Core Pat
994 | mkPat _ [] orig (Ref fc Bound n) = pure $
PLoc fc n
995 | mkPat True args orig (Ref fc (DataCon t a) n) = pure $
PCon fc n t a args
996 | mkPat True args orig (Ref fc (TyCon a) n) = pure $
PTyCon fc n a args
997 | mkPat True args orig (Ref fc Func n)
998 | = do prims <- getPrimitiveNames
999 | mtm <- normalisePrims (const True) isPConst True prims n args orig Env.empty
1001 | Just tm => if tm /= orig
1004 | then mkPat True [] tm tm
1006 | pure $
PUnmatchable (getLoc orig) orig
1008 | do log "compile.casetree" 10 $
1009 | "Unmatchable function: " ++ show n
1010 | pure $
PUnmatchable (getLoc orig) orig
1011 | mkPat True args orig (Bind fc x (Pi _ _ _ s) t)
1013 | = case subst (Erased fc Placeholder) t of
1014 | App _ t'@(Ref fc Bound n) (Erased {}) => pure $
PArrow fc x !(mkPat True [] s s) !(mkPat False [] t' t')
1015 | t' => pure $
PArrow fc x !(mkPat True [] s s) !(mkPat False [] t' t')
1016 | mkPat True args orig (App fc fn arg)
1017 | = do parg <- mkPat True [] arg arg
1018 | mkPat True (parg :: args) orig fn
1019 | mkPat True args orig (As fc _ (Ref _ Bound n) ptm)
1020 | = pure $
PAs fc n !(mkPat True [] ptm ptm)
1021 | mkPat True args orig (As fc _ _ ptm)
1022 | = mkPat True [] orig ptm
1023 | mkPat True args orig (TDelay fc r ty p)
1024 | = pure $
PDelay fc r !(mkPat True [] orig ty) !(mkPat True [] orig p)
1025 | mkPat True args orig (PrimVal fc $
PrT c)
1026 | = pure $
PTyCon fc (UN (Basic $
show c)) 0 []
1027 | mkPat True args orig (PrimVal fc c) = pure $
PConst fc c
1028 | mkPat True args orig (TType fc _) = pure $
PTyCon fc (UN $
Basic "Type") 0 []
1030 | = do log "compile.casetree" 10 $
1031 | "Catchall: marking " ++ show tm ++ " as unmatchable"
1032 | pure $
PUnmatchable (getLoc orig) orig
-- Turn a closed argument term from a clause's left-hand side into a
-- pattern. `mkPat` receives the same term twice: once as the original
-- term it may fall back to (e.g. when building `PUnmatchable`) and once
-- as the term to translate. It starts with an empty argument spine and
-- matching enabled.
argToPat : {auto c : Ref Ctxt Defs} -> ClosedTerm -> Core Pat
argToPat = \arg => mkPat True [] arg arg
1038 | mkPatClause : {auto c : Ref Ctxt Defs} ->
1040 | (args : Scope) -> ClosedTerm ->
1041 | Int -> (List Pat, ClosedTerm) ->
1042 | Core (PatClause args args)
1043 | mkPatClause fc fn args ty pid (ps, rhs)
1044 | = maybe (throw (CaseCompile fc fn DifferingArgNumbers))
1047 | nty <- nf defs Env.empty ty
1048 | ns <- mkNames args ps eq (Just nty)
1049 | log "compile.casetree" 20 $
1050 | "Make pat clause for names " ++ show ns
1051 | ++ " in LHS " ++ show ps
1052 | pure (MkPatClause [] ns pid
1053 | (rewrite sym (appendNilRightNeutral args) in
1054 | (weakenNs (mkSizeOf args) rhs))))
1055 | (checkLengthMatch args ps)
1057 | mkNames : (vars : Scope) -> (ps : List Pat) ->
1058 | LengthMatch vars ps -> Maybe (NF []) ->
1059 | Core (NamedPats vars vars)
1060 | mkNames [] [] NilMatch fty = pure []
1061 | mkNames (arg :: args) (p :: ps) (ConsMatch eq) fty
1063 | empty <- clearDefs defs
1064 | fa_tys <- the (Core (Maybe _, ArgType _)) $
1066 | Nothing => pure (Nothing, CaseBuilder.Unknown)
1067 | Just (NBind pfc _ (Pi _ c _ farg) fsc) =>
1068 | pure (Just !(fsc defs (toClosure defaultOpts [] (Ref pfc Bound arg))),
1069 | Known c (embed !(quote empty [] farg)))
1071 | pure (Nothing, Stuck (embed !(quote empty [] t)))
1072 | pure (MkInfo p First (Builtin.snd fa_tys)
1073 | :: weaken !(mkNames args ps eq (Builtin.fst fa_tys)))
1076 | patCompile : {auto c : Ref Ctxt Defs} ->
1078 | ClosedTerm -> List01 True (List Pat, ClosedTerm) ->
1079 | Core (
args ** CaseTree args)
1080 | patCompile fc fn phase ty (p :: ps)
1081 | = do let (
ns ** n)
= getNames 0 (fst p)
1082 | pats <- mkPatClausesFrom 0 ns (p :: ps)
1084 | logC "compile.casetree" 5 $
do
1085 | pats <- traverse toFullNames $
forget pats
1086 | pure $
"Pattern clauses:\n"
1087 | ++ show (indent 2 $
vcat $
pretty <$> pats)
1089 | log "compile.casetree" 10 $
show pats
1090 | i <- newRef PName (the Int 0)
1091 | cases <- match fc fn phase pats Nothing
1094 | mkPatClausesFrom : Int -> (args : Scope) ->
1095 | List01 ne (List Pat, ClosedTerm) ->
1096 | Core (List01 ne (PatClause args args))
1097 | mkPatClausesFrom i ns [] = pure []
1098 | mkPatClausesFrom i ns (p :: ps)
1099 | = do p' <- mkPatClause fc fn ns ty i p
1100 | ps' <- mkPatClausesFrom (i + 1) ns ps
1103 | getNames : Int -> List Pat -> (ns : Scope ** SizeOf ns)
1104 | getNames i [] = (
[] ** zero)
1106 | let (
ns ** n)
= getNames (i + 1) xs
1107 | in (
MN "arg" i :: ns ** suc n)
1109 | toPatClause : {auto c : Ref Ctxt Defs} ->
1110 | FC -> Name -> (ClosedTerm, ClosedTerm) ->
1111 | Core (List Pat, ClosedTerm)
1112 | toPatClause fc n (lhs, rhs)
1113 | = case getFnArgs lhs of
1114 | (Ref ffc Func fn, args)
1116 | (np, _) <- getPosition n (gamma defs)
1117 | (fnp, _) <- getPosition fn (gamma defs)
1119 | then pure (!(traverse argToPat args), rhs)
1120 | else throw (GenericMsg ffc ("Wrong function name in pattern LHS " ++ show (n, fn)))
1121 | (f, args) => throw (GenericMsg fc "Not a function name in pattern LHS")
1127 | simpleCase : {auto c : Ref Ctxt Defs} ->
1128 | FC -> Phase -> Name -> ClosedTerm ->
1129 | (clauses : List01 True (ClosedTerm, ClosedTerm)) ->
1130 | Core (
args ** CaseTree args)
1131 | simpleCase fc phase fn ty clauses
1132 | = do logC "compile.casetree" 5 $
1133 | do cs <- traverse (\ (c,d) => [| MkPair (toFullNames c) (toFullNames d) |]) (forget clauses)
1134 | pure $
"simpleCase: Clauses:\n" ++ show (
1135 | indent 2 $
vcat $
flip map cs $
\ lrhs =>
1136 | byShow (fst lrhs) <++> pretty "=" <++> byShow (snd lrhs))
1137 | ps <- traverseList01 (toPatClause fc fn) clauses
1139 | patCompile fc fn phase ty ps
-- Collect the clause indices reachable through a single case alternative.
-- Every kind of alternative delegates to `findReached` on the sub-tree it
-- guards. The sub-trees live in different scopes (constructor arguments,
-- delay arguments, or none), so the call is made separately in each
-- branch rather than on a shared result.
findReachedAlts : CaseAlt ns' -> List Int
findReachedAlts alt
    = case alt of
           ConCase _ _ _ sub => findReached sub
           DelayCase _ _ sub => findReached sub
           ConstCase _ sub => findReached sub
           DefaultCase sub => findReached sub
1148 | findReached : CaseTree ns -> List Int
1149 | findReached (Case _ _ _ alts) = concatMap findReachedAlts alts
1150 | findReached (STerm i _) = [i]
1158 | identifyUnreachableDefaults : {auto c : Ref Ctxt Defs} ->
1160 | FC -> Defs -> NF vars -> List (CaseAlt vars) ->
1164 | identifyUnreachableDefaults fc defs (NPrimVal {}) cs = pure empty
1165 | identifyUnreachableDefaults fc defs (NType {}) cs = pure empty
1166 | identifyUnreachableDefaults fc defs nfty cs
1167 | = do cs' <- traverse rep cs
1168 | let (cs'', extraClauseIdxs) = dropRep (concat cs') empty
1170 | if (length cs == (length cs'' + 1))
1175 | when (not $
null extraClauseIdxs') $
1176 | log "compile.casetree.clauses" 25 $
1177 | "Marking the following clause indices as unreachable under the current branch of the tree: " ++ (show extraClauseIdxs')
1180 | rep : CaseAlt vars -> Core (List (CaseAlt vars))
1182 | = do allCons <- getCons defs nfty
1183 | pure (map (mkAlt fc sc) allCons)
1186 | dropRep : List (CaseAlt vars) -> SortedSet Int -> (List (CaseAlt vars), SortedSet Int)
1187 | dropRep [] extra = ([], extra)
1188 | dropRep (c@(ConCase n t args sc) :: rest) extra
1191 | = let (filteredClauses, extraCases) = partition (not . tagIs t) rest
1192 | extraClauses = extraCases >>= findReachedAlts
1193 | (rest', extra') = dropRep filteredClauses $
fromList extraClauses
1194 | in (c :: rest', extra `union` extra')
1195 | dropRep (c :: rest) extra
1196 | = let (rest', extra') = dropRep rest extra
1197 | in (c :: rest', extra')
1206 | findExtraDefaults : {auto c : Ref Ctxt Defs} ->
1208 | FC -> Defs -> CaseTree vars ->
1210 | findExtraDefaults fc defs (Case idx el ty altsIn)
1211 | = do let fenv = mkEnv fc vars
1212 | nfty <- nf defs fenv ty
1213 | extraCases <- identifyUnreachableDefaults fc defs nfty altsIn
1214 | extraCases' <- concat <$> traverse findExtraAlts altsIn
1215 | pure (Prelude.toList extraCases ++ extraCases')
1217 | findExtraAlts : CaseAlt vars -> Core (List Int)
1218 | findExtraAlts (ConCase x tag args ctree) = findExtraDefaults fc defs ctree
1219 | findExtraAlts (DelayCase x arg ctree) = findExtraDefaults fc defs ctree
1220 | findExtraAlts (ConstCase x ctree) = findExtraDefaults fc defs ctree
1222 | findExtraAlts (DefaultCase ctree) = pure []
1224 | findExtraDefaults fc defs ctree = pure []
1228 | getPMDef : {auto c : Ref Ctxt Defs} ->
1229 | FC -> Phase -> Name -> ClosedTerm -> List Clause ->
1230 | Core (
args ** (CaseTree args, List Clause))
1234 | getPMDef fc phase fn ty []
1235 | = do log "compile.casetree.getpmdef" 20 "getPMDef: No clauses!"
1237 | pure (
!(getArgs 0 !(nf defs Env.empty ty)) ** (Unmatched "No clauses in \{show fn}", []))
1239 | getArgs : Int -> ClosedNF -> Core (List Name)
1240 | getArgs i (NBind fc x (Pi {}) sc)
1242 | sc' <- sc defs (toClosure defaultOpts Env.empty (Erased fc Placeholder))
1243 | pure (MN "arg" i :: !(getArgs i sc'))
1245 | getPMDef fc phase fn ty clauses@(_ :: _)
1247 | let cs = map (toClosed defs) (labelPat 0 $
fromList clauses)
1248 | (
_ ** t)
<- simpleCase fc phase fn ty cs
1249 | logC "compile.casetree.getpmdef" 20 $
1250 | pure $
"Compiled to: " ++ show !(toFullNames t)
1251 | let reached = findReached t
1252 | log "compile.casetree.clauses" 25 $
1253 | "Reached clauses: " ++ (show reached)
1254 | extraDefaults <- findExtraDefaults fc defs t
1255 | let unreachable = getUnreachable 0 (reached \\ extraDefaults) clauses
1256 | pure (
_ ** (t, unreachable))
1258 | getUnreachable : Int -> List Int -> List Clause -> List Clause
1259 | getUnreachable i is [] = []
1260 | getUnreachable i is (c :: cs)
1262 | then getUnreachable (i + 1) is cs
1263 | else c :: getUnreachable (i + 1) is cs
1265 | labelPat : Int -> List01 ne a -> List01 ne (String, a)
1267 | labelPat i (x :: xs) = ("pat" ++ show i ++ ":", x) :: labelPat (i + 1) xs
1269 | toClosed : Defs -> (String, Clause) -> (ClosedTerm, ClosedTerm)
1270 | toClosed defs (pname, MkClause env lhs rhs)
1271 | = (close fc pname env lhs, close fc pname env rhs)