0 | module Core.Case.CaseTree
7 | import Idris.Pretty.Annotations
9 | import Libraries.Data.NameMap
10 | import Libraries.Text.PrettyPrint.Prettyprinter
11 | import Libraries.Data.List.SizeOf
-- A case tree over the variables in scope `vars`: `Case` nodes inspect a
-- variable and branch on a list of alternatives; the other constructors are leaves.
19 | data CaseTree : Scoped where
-- Inspect the variable `name` (the erased proof `p` shows it occurs in `vars`),
-- whose type is `scTy`, and continue with one of the listed alternatives.
-- NOTE(review): the argument binding `idx` (presumably `(idx : Nat) ->`) and the
-- constructor's result-type line are not visible in this view — confirm.
21 | Case : {name : _} ->
23 | (0 p : IsVar name idx vars) ->
24 | (scTy : Term vars) -> List (CaseAlt vars) ->
-- Leaf: a right-hand-side term. `showCT` prints the Int as `[i]`; its meaning
-- (presumably an id of the originating clause) — TODO confirm.
29 | STerm : Int -> Term vars -> CaseTree vars
-- Leaf: nothing matched; carries a message (`showCT` renders it as "Error: ...").
31 | Unmatched : (msg : String) -> CaseTree vars
-- Leaf marking an impossible branch (rendered as "Impossible").
33 | Impossible : CaseTree vars
-- One alternative of a `Case` node: what is matched, plus the subtree to continue with.
38 | data CaseAlt : Scoped where
-- Match a constructor (name and tag); the `args` names are added to the
-- scope of the subtree.
40 | ConCase : Name -> (tag : Int) -> (args : List Name) ->
41 | CaseTree (Scope.addInner vars args) -> CaseAlt vars
-- Match a delayed value; binds two names (`ty` and `arg`) in the subtree's scope.
43 | DelayCase : (ty : Name) -> (arg : Name) ->
44 | CaseTree (Scope.addInner vars [ty, arg]) -> CaseAlt vars
-- Match a literal constant.
47 | ConstCase : Constant -> CaseTree vars -> CaseAlt vars
-- Catch-all alternative (rendered as `_ => ...`).
49 | DefaultCase : CaseTree vars -> CaseAlt vars
-- No method definitions are visible here, so the interface's defaults apply.
-- NOTE(review): presumably this lets case trees be embedded into larger scopes,
-- the way `embed` is used on closed terms in `mkTerm` — confirm in Core.TT.
52 | FreelyEmbeddable CaseTree where
measure : CaseTree vars -> Nat
-- Size of a case tree: only `Case` nodes contribute, via their alternatives;
-- every leaf (STerm, Unmatched, Impossible) has size zero.
measure (Case _ _ _ alts) = sum (map measureAlts alts)
measure _ = 0

-- Size of one alternative: one for the alternative itself plus its subtree.
measureAlts : CaseAlt vars -> Nat
measureAlts alt
    = S $ case alt of
               ConCase _ _ _ sub => measure sub
               DelayCase _ _ sub => measure sub
               ConstCase _ sub => measure sub
               DefaultCase sub => measure sub
-- True exactly when the alternative is the catch-all `DefaultCase`.
isDefault : CaseAlt vars -> Bool
isDefault (DefaultCase _) = True
-- Constructor, delay and constant alternatives are not the default. Without
-- this clause the function is not covering.
isDefault _ = False
StripNamespace (CaseTree vars) where
  -- Namespaces are adjusted in the scrutinee type, in every alternative and in
  -- right-hand-side terms.
  trimNS ns (Case idx p scTy xs)
      = Case idx p (trimNS ns scTy) (map (trimNS ns) xs)
  trimNS ns (STerm x t) = STerm x (trimNS ns t)
  -- These leaves hold no names, so they are returned unchanged. Without these
  -- clauses the method is not covering.
  trimNS _ (Unmatched msg) = Unmatched msg
  trimNS _ Impossible = Impossible

  restoreNS ns (Case idx p scTy xs)
      = Case idx p (restoreNS ns scTy) (map (restoreNS ns) xs)
  restoreNS ns (STerm x t) = STerm x (restoreNS ns t)
  restoreNS _ (Unmatched msg) = Unmatched msg
  restoreNS _ Impossible = Impossible
StripNamespace (CaseAlt vars) where
  -- Only the subtree of each alternative is touched; constructor names, tags,
  -- bound argument names and constants are kept as they are.
  trimNS ns alt
      = case alt of
             ConCase con tag args sub => ConCase con tag args (trimNS ns sub)
             DelayCase ty arg sub => DelayCase ty arg (trimNS ns sub)
             ConstCase c sub => ConstCase c (trimNS ns sub)
             DefaultCase sub => DefaultCase (trimNS ns sub)

  restoreNS ns alt
      = case alt of
             ConCase con tag args sub => ConCase con tag args (restoreNS ns sub)
             DelayCase ty arg sub => DelayCase ty arg (restoreNS ns sub)
             ConstCase c sub => ConstCase c (restoreNS ns sub)
             DefaultCase sub => DefaultCase (restoreNS ns sub)
-- Patterns (not indexed by a scope); `mkTerm` converts them into terms.
100 | data Pat : Type where
-- As-pattern `n@(p)`: names a subpattern.
101 | PAs : FC -> Name -> Pat -> Pat
-- Data constructor pattern with its tag and arity.
-- NOTE(review): the rest of this signature (presumably the argument patterns and
-- the result, cf. `PCon _ n i _ args` elsewhere) is not visible in this view.
102 | PCon : FC -> Name -> (tag : Int) -> (arity : Nat) ->
-- Type constructor pattern with its arity and argument patterns.
104 | PTyCon : FC -> Name -> (arity : Nat) -> List Pat -> Pat
-- Literal constant pattern.
105 | PConst : FC -> (c : Constant) -> Pat
-- Function-type pattern; `mkTerm` binds `x` in the codomain.
106 | PArrow : FC -> (x : Name) -> Pat -> Pat -> Pat
-- Delay pattern: a laziness reason, a type pattern and the delayed pattern.
107 | PDelay : FC -> LazyReason -> Pat -> Pat -> Pat
-- A name: `mkTerm` makes it a local variable when in scope, else a bound reference.
109 | PLoc : FC -> Name -> Pat
-- A closed term in pattern position; embedded as-is by `mkTerm`, shown as `.(tm)`.
110 | PUnmatchable : FC -> ClosedTerm -> Pat
-- Extracts the literal from a top-level constant pattern. Any other pattern,
-- including an as-pattern wrapping a constant, yields Nothing.
isPConst : Pat -> Maybe Constant
isPConst pat
    = case pat of
           PConst _ lit => Just lit
           _ => Nothing
-- Whether a pattern is headed by a constructor-like form (data or type
-- constructor, constant, arrow or delay), looking through as-patterns.
-- Names (PLoc) and unmatchable terms are not constructor patterns.
0 isConPat : Pat -> Bool
isConPat (PAs _ _ p) = isConPat p
isConPat (PCon {}) = True
isConPat (PTyCon {}) = True
isConPat (PConst {}) = True
isConPat (PArrow {}) = True
isConPat (PDelay {}) = True
-- Without these two clauses the function is not covering.
isConPat (PLoc {}) = False
isConPat (PUnmatchable {}) = False
-- Type-level witness that `isConPat` holds for the given pattern.
0 IsConPat : Pat -> Type
IsConPat pat = So (isConPat pat)
-- Renders a case tree for debugging. `indent` prefixes the lines of nested
-- alternatives and grows by " " at each level of nesting.
showCT : {vars : _} -> (indent : String) -> CaseTree vars -> String
-- Renders one alternative as `<pattern> => <subtree>`.
showCA : {vars : _} -> (indent : String) -> CaseAlt vars -> String

showCT indent (Case {name} idx prf ty alts)
    = "case " ++ show name ++ "[" ++ show idx ++ "] : " ++ show ty ++ " of"
      ++ "\n" ++ indent ++ " { "
      ++ showSep ("\n" ++ indent ++ " | ")
                 (assert_total (map (showCA (" " ++ indent)) alts))
      ++ "\n" ++ indent ++ " }"
showCT indent (STerm i tm) = "[" ++ show i ++ "] " ++ show tm
showCT indent (Unmatched msg) = "Error: " ++ show msg
showCT indent Impossible = "Impossible"

showCA indent (ConCase n tag args sc)
    = showSep " " (map show (n :: args)) ++ " => " ++
      -- The subtree is rendered exactly like in the other alternatives (the
      -- expression previously ended in a dangling `++`).
      showCT indent sc
showCA indent (DelayCase _ arg sc)
    = "Delay " ++ show arg ++ " => " ++ showCT indent sc
showCA indent (ConstCase c sc)
    = "Constant " ++ show c ++ " => " ++ showCT indent sc
showCA indent (DefaultCase sc)
    = "_ => " ++ showCT indent sc
-- NOTE(review): no `show` definition is visible for this instance in this view.
-- Presumably it delegates to `showCT` (e.g. with an empty indent) — confirm; an
-- instance relying only on Show's mutually-defined defaults would not terminate.
156 | {vars : _} -> Show (CaseTree vars) where
-- NOTE(review): no `show` definition is visible for this instance in this view.
-- Presumably it delegates to `showCA` (e.g. with an empty indent) — confirm; an
-- instance relying only on Show's mutually-defined defaults would not terminate.
161 | {vars : _} -> Show (CaseAlt vars) where
-- Structural comparison of two case trees, possibly over different scopes.
-- Terms are compared with `eqTerm`; `STerm` ints and `Unmatched` messages are ignored.
-- NOTE(review): the `=` line of the `Case` clause (presumably comparing `i` with
-- `i'`, which are otherwise unused) and a clause for mismatched constructors
-- are not visible in this view — confirm.
166 | eqTree : CaseTree vs -> CaseTree vs' -> Bool
167 | eqTree (Case i _ _ alts) (Case i' _ _ alts')
169 | && length alts == length alts'
170 | && all (uncurry eqAlt) (zip alts alts')
171 | eqTree (STerm _ t) (STerm _ t') = eqTerm t t'
172 | eqTree (Unmatched _) (Unmatched _) = True
173 | eqTree Impossible Impossible = True
-- Compares alternatives of the same kind: constructor names (tags and bound
-- argument names are not compared), constants with `==`, subtrees with `eqTree`.
-- NOTE(review): no clause for alternatives of different kinds is visible here.
176 | eqAlt : CaseAlt vs -> CaseAlt vs' -> Bool
177 | eqAlt (ConCase n t args tree) (ConCase n' t' args' tree')
178 | = n == n' && eqTree tree tree'
179 | eqAlt (DelayCase _ _ tree) (DelayCase _ _ tree')
180 | = eqTree tree tree'
181 | eqAlt (ConstCase c tree) (ConstCase c' tree')
182 | = c == c' && eqTree tree tree'
183 | eqAlt (DefaultCase tree) (DefaultCase tree')
184 | = eqTree tree tree'
-- Debug rendering of `Pat` (the enclosing instance header is not visible in this
-- view). Arrows omit their bound name; constructor patterns print the tag after
-- the name; type constructors are prefixed with "<TyCon>".
190 | show (PAs _ n p) = show n ++ "@(" ++ show p ++ ")"
191 | show (PCon _ n i _ args) = show n ++ " " ++ show i ++ " " ++ assert_total (show args)
192 | show (PTyCon _ n _ args) = "<TyCon>" ++ show n ++ " " ++ assert_total (show args)
193 | show (PConst _ c) = show c
194 | show (PArrow _ x s t) = "(" ++ show s ++ " -> " ++ show t ++ ")"
195 | show (PDelay _ _ _ p) = "(Delay " ++ show p ++ ")"
196 | show (PLoc _ n) = show n
197 | show (PUnmatchable _ tm) = ".(" ++ show tm ++ ")"
Pretty IdrisSyntax Pat where
  -- Constructor applications and arrows are parenthesised whenever they occur
  -- above open precedence; constructor arguments are printed at App precedence.
  prettyPrec _ (PAs _ nm sub)
      = pretty0 nm <++> keyword "@" <+> parens (pretty sub)
  prettyPrec prec (PCon _ nm _ _ pats)
      = parenthesise (prec > Open)
          (hsep (pretty0 nm :: (prettyPrec App <$> pats)))
  prettyPrec prec (PTyCon _ nm _ pats)
      = parenthesise (prec > Open)
          (hsep (pretty0 nm :: (prettyPrec App <$> pats)))
  prettyPrec _ (PConst _ lit) = pretty lit
  prettyPrec prec (PArrow _ _ dom cod)
      = parenthesise (prec > Open) (pretty dom <++> arrow <++> pretty cod)
  prettyPrec _ (PDelay _ _ _ sub) = parens ("Delay" <++> pretty sub)
  prettyPrec _ (PLoc _ nm) = pretty0 nm
  prettyPrec _ (PUnmatchable _ tm) = keyword "." <+> parens (byShow tm)
-- Weakens a case tree by inserting the names `ns` between an outer prefix (of
-- size `outer`) and the rest of the scope. The scrutinised variable, scrutinee
-- type, right-hand-side terms and every alternative are adjusted accordingly.
-- NOTE(review): one line of each signature (the `ns` argument) and the tail of
-- the `ConCase` clause below are not visible in this view.
214 | insertCaseNames : SizeOf outer ->
216 | CaseTree (outer ++ inner) ->
217 | CaseTree (outer ++ (ns ++ inner))
218 | insertCaseNames outer ns (Case idx prf scTy alts)
219 | = let MkNVar prf' = insertNVarNames outer ns (MkNVar prf) in
220 | Case _ prf' (insertNames outer ns scTy)
221 | (map (insertCaseAltNames outer ns) alts)
222 | insertCaseNames outer ns (STerm i x) = STerm i (insertNames outer ns x)
223 | insertCaseNames _ _ (Unmatched msg) = Unmatched msg
224 | insertCaseNames _ _ Impossible = Impossible
-- Same insertion for a single alternative.
226 | insertCaseAltNames : SizeOf outer ->
228 | CaseAlt (outer ++ inner) ->
229 | CaseAlt (outer ++ (ns ++ inner))
-- The constructor's `args` are bound innermost, so the outer prefix grows by
-- `args`; the rewrites re-associate the list appends to match the types.
230 | insertCaseAltNames p q (ConCase x tag args ct)
231 | = ConCase x tag args
232 | (rewrite appendAssociative args outer (ns ++ inner) in
233 | insertCaseNames (mkSizeOf args + p) q {inner}
234 | (rewrite sym (appendAssociative args outer inner) in
-- A delay alternative binds two names, so the outer prefix grows by two.
236 | insertCaseAltNames outer ns (DelayCase tyn valn ct)
237 | = DelayCase tyn valn
238 | (insertCaseNames (suc (suc outer)) ns ct)
239 | insertCaseAltNames outer ns (ConstCase x ct)
240 | = ConstCase x (insertCaseNames outer ns ct)
241 | insertCaseAltNames outer ns (DefaultCase ct)
242 | = DefaultCase (insertCaseNames outer ns ct)
Weaken CaseTree where
  -- Weakening inserts the new names in front of the whole existing scope,
  -- i.e. with an empty outer prefix.
  weakenNs extra tree = insertCaseNames zero extra tree
-- Folds `add` over every right-hand-side term (`STerm`) of a case tree,
-- threading the accumulated NameMap left to right through the alternatives.
-- Scrutinee types of `Case` nodes are not visited.
-- NOTE(review): the helpers below call `add`, so they are presumably local
-- (`where`/`mutual`) to getNames; those lines are not visible in this view.
249 | getNames : (forall vs . NameMap Bool -> Term vs -> NameMap Bool) ->
250 | NameMap Bool -> CaseTree vars -> NameMap Bool
251 | getNames add ns sc = getSet ns sc
254 | getAltSet : NameMap Bool -> CaseAlt vs -> NameMap Bool
255 | getAltSet ns (ConCase n t args sc) = getSet ns sc
256 | getAltSet ns (DelayCase t a sc) = getSet ns sc
257 | getAltSet ns (ConstCase i sc) = getSet ns sc
258 | getAltSet ns (DefaultCase sc) = getSet ns sc
260 | getAltSets : NameMap Bool -> List (CaseAlt vs) -> NameMap Bool
261 | getAltSets ns [] = ns
262 | getAltSets ns (a :: as) = getAltSets (getAltSet ns a) as
264 | getSet : NameMap Bool -> CaseTree vs -> NameMap Bool
265 | getSet ns (Case _ x ty xs) = getAltSets ns xs
266 | getSet ns (STerm i tm) = add ns tm
267 | getSet ns (Unmatched msg) = ns
268 | getSet ns Impossible = ns
-- Folds the term-level `addRefs False aTotal` over the tree's right-hand
-- sides, starting from an empty map.
getRefs : (aTotal : Name) -> CaseTree vars -> NameMap Bool
getRefs aTot tree = getNames (addRefs False aTot) empty tree
-- Like getRefs, but extends an existing map instead of starting from empty.
-- The inner call resolves to the term-level addRefs (its first argument is a Bool).
addRefs : (aTotal : Name) -> NameMap Bool -> CaseTree vars -> NameMap Bool
addRefs aTot acc tree = getNames (addRefs False aTot) acc tree
-- Folds the term-level `addMetas False` over the tree's right-hand sides,
-- starting from an empty map.
getMetas : CaseTree vars -> NameMap Bool
getMetas tree = getNames (addMetas False) empty tree
-- Converts a pattern into a term in scope `vars`:
--   * an as-pattern is replaced by its inner pattern;
--   * data/type constructor patterns become applications of a constructor Ref;
--   * an arrow pattern becomes a Pi binder whose codomain sees `x`;
--   * a name becomes a Local when it is in scope, otherwise a Bound Ref;
--   * an unmatchable (closed) term is embedded as-is.
mkTerm : (vars : Scope) -> Pat -> Term vars
mkTerm sc (PAs _ _ inner) = mkTerm sc inner
mkTerm sc (PCon fc con tag arity pats)
    = apply fc (Ref fc (DataCon tag arity) con) (mkTerm sc <$> pats)
mkTerm sc (PTyCon fc ty arity pats)
    = apply fc (Ref fc (TyCon arity) ty) (mkTerm sc <$> pats)
mkTerm sc (PConst fc c) = PrimVal fc c
mkTerm sc (PArrow fc x dom cod)
    = Bind fc x (Pi fc top Explicit (mkTerm sc dom))
                (mkTerm (x :: sc) cod)
mkTerm sc (PDelay fc r ty p)
    = TDelay fc r (mkTerm sc ty) (mkTerm sc p)
mkTerm sc (PLoc fc n)
    = case isVar n sc of
           Just (MkVar prf) => Local fc Nothing _ prf
           Nothing => Ref fc Bound n
mkTerm sc (PUnmatchable fc tm) = embed tm