0 | module Core.Case.CaseTree
  1 |
  2 | import Core.TT
  3 |
  4 | import Data.List
  5 | import Data.So
  6 | import Data.String
  7 | import Idris.Pretty.Annotations
  8 |
  9 | import Libraries.Data.NameMap
 10 | import Libraries.Text.PrettyPrint.Prettyprinter
 11 | import Libraries.Data.List.SizeOf
 12 |
 13 | %default covering
 14 |
 15 | mutual
 16 |   ||| Case trees in A-normal forms
 17 |   ||| i.e. we may only dispatch on variables, not expressions
 18 |   public export
 19 |   data CaseTree : Scoped where
 20 |        ||| case x return scTy of { p1 => e1 ; ... }
 21 |        Case : {name : _} ->
 22 |               (idx : Nat) ->
 23 |               (0 p : IsVar name idx vars) ->
 24 |               (scTy : Term vars) -> List (CaseAlt vars) ->
 25 |               CaseTree vars
 26 |        ||| RHS: no need for further inspection
 27 |        ||| The Int is a clause id that allows us to see which of the
 28 |        ||| initial clauses are reached in the tree
 29 |        STerm : Int -> Term vars -> CaseTree vars
 30 |        ||| error from a partial match
 31 |        Unmatched : (msg : String) -> CaseTree vars
 32 |        ||| Absurd context
 33 |        Impossible : CaseTree vars
 34 |
 35 |   ||| Case alternatives. Unlike arbitrary patterns, they can be at most
 36 |   ||| one constructor deep.
 37 |   public export
 38 |   data CaseAlt : Scoped where
 39 |        ||| Constructor for a data type; bind the arguments and subterms.
 40 |        ConCase : Name -> (tag : Int) -> (args : List Name) ->
 41 |                  CaseTree (Scope.addInner vars args) -> CaseAlt vars
 42 |        ||| Lazy match for the Delay type use for codata types
 43 |        DelayCase : (ty : Name) -> (arg : Name) ->
 44 |                    CaseTree (Scope.addInner vars [ty, arg]) -> CaseAlt vars
 45 |                    -- TODO `arg` and `ty` should be swapped, as in Yaffle
 46 |        ||| Match against a literal
 47 |        ConstCase : Constant -> CaseTree vars -> CaseAlt vars
 48 |        ||| Catch-all case
 49 |        DefaultCase : CaseTree vars -> CaseAlt vars
 50 |
 51 | export
 52 | FreelyEmbeddable CaseTree where
 53 |
 54 | mutual
 55 |   public export
 56 |   measure : CaseTree vars -> Nat
 57 |   measure (Case idx p scTy xs) = sum $ measureAlts <$> xs
 58 |   measure (STerm x y) = 0
 59 |   measure (Unmatched msg) = 0
 60 |   measure Impossible = 0
 61 |
 62 |   measureAlts : CaseAlt vars -> Nat
 63 |   measureAlts (ConCase x tag args y) = 1 + (measure y)
 64 |   measureAlts (DelayCase ty arg x) = 1 + (measure x)
 65 |   measureAlts (ConstCase x y) = 1 + (measure y)
 66 |   measureAlts (DefaultCase x) = 1 + (measure x)
 67 |
 68 | export
 69 | isDefault : CaseAlt vars -> Bool
 70 | isDefault (DefaultCase _) = True
 71 | isDefault _ = False
 72 |
 73 | mutual
 74 |   export
 75 |   StripNamespace (CaseTree vars) where
 76 |     trimNS ns (Case idx p scTy xs)
 77 |         = Case idx p (trimNS ns scTy) (map (trimNS ns) xs)
 78 |     trimNS ns (STerm x t) = STerm x (trimNS ns t)
 79 |     trimNS ns c = c
 80 |
 81 |     restoreNS ns (Case idx p scTy xs)
 82 |         = Case idx p (restoreNS ns scTy) (map (restoreNS ns) xs)
 83 |     restoreNS ns (STerm x t) = STerm x (restoreNS ns t)
 84 |     restoreNS ns c = c
 85 |
 86 |   export
 87 |   StripNamespace (CaseAlt vars) where
 88 |     trimNS ns (ConCase x tag args t) = ConCase x tag args (trimNS ns t)
 89 |     trimNS ns (DelayCase ty arg t) = DelayCase ty arg (trimNS ns t)
 90 |     trimNS ns (ConstCase x t) = ConstCase x (trimNS ns t)
 91 |     trimNS ns (DefaultCase t) = DefaultCase (trimNS ns t)
 92 |
 93 |     restoreNS ns (ConCase x tag args t) = ConCase x tag args (restoreNS ns t)
 94 |     restoreNS ns (DelayCase ty arg t) = DelayCase ty arg (restoreNS ns t)
 95 |     restoreNS ns (ConstCase x t) = ConstCase x (restoreNS ns t)
 96 |     restoreNS ns (DefaultCase t) = DefaultCase (restoreNS ns t)
 97 |
 98 |
 99 | public export
100 | data Pat : Type where
101 |      PAs : FC -> Name -> Pat -> Pat
102 |      PCon : FC -> Name -> (tag : Int) -> (arity : Nat) ->
103 |             List Pat -> Pat
104 |      PTyCon : FC -> Name -> (arity : Nat) -> List Pat -> Pat
105 |      PConst : FC -> (c : Constant) -> Pat
106 |      PArrow : FC -> (x : Name) -> Pat -> Pat -> Pat
107 |      PDelay : FC -> LazyReason -> Pat -> Pat -> Pat
108 |      -- TODO: Matching on lazy types
109 |      PLoc : FC -> Name -> Pat
110 |      PUnmatchable : FC -> ClosedTerm -> Pat
111 |
112 | export
113 | isPConst : Pat -> Maybe Constant
114 | isPConst (PConst _ c) = Just c
115 | isPConst _ = Nothing
116 |
117 | public export
118 | 0 isConPat : Pat -> Bool
119 | isConPat (PAs _ _ p) = isConPat p
120 | isConPat (PCon {}) = True
121 | isConPat (PTyCon {}) = True
122 | isConPat (PConst {}) = True
123 | isConPat (PArrow {}) = True
124 | isConPat (PDelay {}) = True
125 | isConPat _ = False
126 |
127 | public export
128 | 0 IsConPat : Pat -> Type
129 | IsConPat = So . isConPat
130 |
131 | showCT : {vars : _} -> (indent : String) -> CaseTree vars -> String
132 | showCA : {vars : _} -> (indent : String) -> CaseAlt vars  -> String
133 |
134 | showCT indent (Case {name} idx prf ty alts)
135 |   = "case " ++ show name ++ "[" ++ show idx ++ "] : " ++ show ty ++ " of"
136 |   ++ "\n" ++ indent ++ " { "
137 |   ++ showSep ("\n" ++ indent ++ " | ")
138 |              (assert_total (map (showCA ("  " ++ indent)) alts))
139 |   ++ "\n" ++ indent ++ " }"
140 | showCT indent (STerm i tm) = "[" ++ show i ++ "] " ++ show tm
141 | showCT indent (Unmatched msg) = "Error: " ++ show msg
142 | showCT indent Impossible = "Impossible"
143 |
144 | showCA indent (ConCase n tag args sc)
145 |         = showSep " " (map show (n :: args)) ++ " => " ++
146 |           showCT indent sc
147 | showCA indent (DelayCase _ arg sc)
148 |         = "Delay " ++ show arg ++ " => " ++ showCT indent sc
149 | showCA indent (ConstCase c sc)
150 |         = "Constant " ++ show c ++ " => " ++ showCT indent sc
151 | showCA indent (DefaultCase sc)
152 |         = "_ => " ++ showCT indent sc
153 |
154 | export
155 | covering
156 | {vars : _} -> Show (CaseTree vars) where
157 |   show = showCT ""
158 |
159 | export
160 | covering
161 | {vars : _} -> Show (CaseAlt vars) where
162 |   show = showCA ""
163 |
164 | mutual
165 |   export
166 |   eqTree : CaseTree vs -> CaseTree vs' -> Bool
167 |   eqTree (Case i _ _ alts) (Case i' _ _ alts')
168 |       = i == i'
169 |        && length alts == length alts'
170 |        && all (uncurry eqAlt) (zip alts alts')
171 |   eqTree (STerm _ t) (STerm _ t') = eqTerm t t'
172 |   eqTree (Unmatched _) (Unmatched _) = True
173 |   eqTree Impossible Impossible = True
174 |   eqTree _ _ = False
175 |
176 |   eqAlt : CaseAlt vs -> CaseAlt vs' -> Bool
177 |   eqAlt (ConCase n t args tree) (ConCase n' t' args' tree')
178 |       = n == n' && eqTree tree tree' -- assume arities match, since name does
179 |   eqAlt (DelayCase _ _ tree) (DelayCase _ _ tree')
180 |       = eqTree tree tree'
181 |   eqAlt (ConstCase c tree) (ConstCase c' tree')
182 |       = c == c' && eqTree tree tree'
183 |   eqAlt (DefaultCase tree) (DefaultCase tree')
184 |       = eqTree tree tree'
185 |   eqAlt _ _ = False
186 |
187 | export
188 | covering
189 | Show Pat where
190 |   show (PAs _ n p) = show n ++ "@(" ++ show p ++ ")"
191 |   show (PCon _ n i _ args) = show n ++ " " ++ show i ++ " " ++ assert_total (show args)
192 |   show (PTyCon _ n _ args) = "<TyCon>" ++ show n ++ " " ++ assert_total (show args)
193 |   show (PConst _ c) = show c
194 |   show (PArrow _ x s t) = "(" ++ show s ++ " -> " ++ show t ++ ")"
195 |   show (PDelay _ _ _ p) = "(Delay " ++ show p ++ ")"
196 |   show (PLoc _ n) = show n
197 |   show (PUnmatchable _ tm) = ".(" ++ show tm ++ ")"
198 |
199 | export
200 | Pretty IdrisSyntax Pat where
201 |   prettyPrec d (PAs _ n p) = pretty0 n <++> keyword "@" <+> parens (pretty p)
202 |   prettyPrec d (PCon _ n _ _ args) =
203 |     parenthesise (d > Open) $ hsep (pretty0 n :: map (prettyPrec App) args)
204 |   prettyPrec d (PTyCon _ n _ args) =
205 |     parenthesise (d > Open) $ hsep (pretty0 n :: map (prettyPrec App) args)
206 |   prettyPrec d (PConst _ c) = pretty c
207 |   prettyPrec d (PArrow _ _ p q) =
208 |     parenthesise (d > Open) $ pretty p <++> arrow <++> pretty q
209 |   prettyPrec d (PDelay _ _ _ p) = parens ("Delay" <++> pretty p)
210 |   prettyPrec d (PLoc _ n) = pretty0 n
211 |   prettyPrec d (PUnmatchable _ tm) = keyword "." <+> parens (byShow tm)
212 |
213 | mutual
214 |   insertCaseNames : SizeOf outer ->
215 |                     SizeOf ns ->
216 |                     CaseTree (outer ++ inner) ->
217 |                     CaseTree (outer ++ (ns ++ inner))
218 |   insertCaseNames outer ns (Case idx prf scTy alts)
219 |       = let MkNVar prf' = insertNVarNames outer ns (MkNVar prf) in
220 |             Case _ prf' (insertNames outer ns scTy)
221 |                 (map (insertCaseAltNames outer ns) alts)
222 |   insertCaseNames outer ns (STerm i x) = STerm i (insertNames outer ns x)
223 |   insertCaseNames _ _ (Unmatched msg) = Unmatched msg
224 |   insertCaseNames _ _ Impossible = Impossible
225 |
226 |   insertCaseAltNames : SizeOf outer ->
227 |                        SizeOf ns ->
228 |                        CaseAlt (outer ++ inner) ->
229 |                        CaseAlt (outer ++ (ns ++ inner))
230 |   insertCaseAltNames p q (ConCase x tag args ct)
231 |       = ConCase x tag args
232 |            (rewrite appendAssociative args outer (ns ++ inner) in
233 |                     insertCaseNames (mkSizeOf args + p) q {inner}
234 |                         (rewrite sym (appendAssociative args outer inner) in
235 |                                  ct))
236 |   insertCaseAltNames outer ns (DelayCase tyn valn ct)
237 |       = DelayCase tyn valn
238 |                   (insertCaseNames (suc (suc outer)) ns ct)
239 |   insertCaseAltNames outer ns (ConstCase x ct)
240 |       = ConstCase x (insertCaseNames outer ns ct)
241 |   insertCaseAltNames outer ns (DefaultCase ct)
242 |       = DefaultCase (insertCaseNames outer ns ct)
243 |
244 | export
245 | Weaken CaseTree where
246 |   weakenNs ns t = insertCaseNames zero ns t
247 |
248 | total
249 | getNames : (forall vs . NameMap Bool -> Term vs -> NameMap Bool) ->
250 |            NameMap Bool -> CaseTree vars -> NameMap Bool
251 | getNames add ns sc = getSet ns sc
252 |   where
253 |     mutual
254 |       getAltSet : NameMap Bool -> CaseAlt vs -> NameMap Bool
255 |       getAltSet ns (ConCase n t args sc) = getSet ns sc
256 |       getAltSet ns (DelayCase t a sc) = getSet ns sc
257 |       getAltSet ns (ConstCase i sc) = getSet ns sc
258 |       getAltSet ns (DefaultCase sc) = getSet ns sc
259 |
260 |       getAltSets : NameMap Bool -> List (CaseAlt vs) -> NameMap Bool
261 |       getAltSets ns [] = ns
262 |       getAltSets ns (a :: as) = getAltSets (getAltSet ns a) as
263 |
264 |       getSet : NameMap Bool -> CaseTree vs -> NameMap Bool
265 |       getSet ns (Case _ x ty xs) = getAltSets ns xs
266 |       getSet ns (STerm i tm) = add ns tm
267 |       getSet ns (Unmatched msg) = ns
268 |       getSet ns Impossible = ns
269 |
270 | export
271 | getRefs : (aTotal : Name) -> CaseTree vars -> NameMap Bool
272 | getRefs at = getNames (addRefs False at) empty
273 |
274 | export
275 | addRefs : (aTotal : Name) -> NameMap Bool -> CaseTree vars -> NameMap Bool
276 | addRefs at ns = getNames (addRefs False at) ns
277 |
278 | export
279 | getMetas : CaseTree vars -> NameMap Bool
280 | getMetas = getNames (addMetas False) empty
281 |
282 | export
283 | mkTerm : (vars : Scope) -> Pat -> Term vars
284 | mkTerm vars (PAs fc x y) = mkTerm vars y
285 | mkTerm vars (PCon fc x tag arity xs)
286 |     = apply fc (Ref fc (DataCon tag arity) x)
287 |                (map (mkTerm vars) xs)
288 | mkTerm vars (PTyCon fc x arity xs)
289 |     = apply fc (Ref fc (TyCon arity) x)
290 |                (map (mkTerm vars) xs)
291 | mkTerm vars (PConst fc c) = PrimVal fc c
292 | mkTerm vars (PArrow fc x s t)
293 |     = Bind fc x (Pi fc top Explicit (mkTerm vars s)) (mkTerm (x :: vars) t)
294 | mkTerm vars (PDelay fc r ty p)
295 |     = TDelay fc r (mkTerm vars ty) (mkTerm vars p)
296 | mkTerm vars (PLoc fc n)
297 |     = case isVar n vars of
298 |            Just (MkVar prf) => Local fc Nothing _ prf
299 |            _ => Ref fc Bound n
300 | mkTerm vars (PUnmatchable fc tm) = embed tm
301 |