0 | module TTImp.Elab.Utils
  1 |
  2 | import Core.Case.CaseTree
  3 | import Core.Context
  4 | import Core.Env
  5 | import Core.Normalise
  6 | import Core.Value
  7 |
  8 | import TTImp.Elab.Check
  9 | import TTImp.TTImp
 10 |
 11 | import Libraries.Data.NatSet
 12 | import Libraries.Data.VarSet
 13 |
 14 | import Libraries.Data.List.SizeOf
 15 |
 16 | %default covering
 17 |
 18 | detagSafe : {auto c : Ref Ctxt Defs} ->
 19 |             Defs -> ClosedNF -> Core Bool
 20 | detagSafe defs (NTCon _ n _ args)
 21 |     = do Just (TCon _ _ _ _ _ _ (Just detags)) <- lookupDefExact n (gamma defs)
 22 |               | _ => pure False
 23 |          args' <- traverse (evalClosure defs . snd) args
 24 |          pure $ NatSet.isEmpty detags || notErased 0 detags args'
 25 |   where
 26 |     -- if any argument positions are in the non-empty(!) detaggable set, and unerased, then
 27 |     -- detagging is safe
 28 |     notErased : Nat -> NatSet -> List ClosedNF -> Bool
 29 |     notErased i ns [] = False
 30 |     notErased i ns (NErased _ Impossible :: rest)
 31 |         = notErased (i + 1) ns rest -- Can't detag here, look elsewhere
 32 |     notErased i ns (_ :: rest) -- Safe to detag via this argument
 33 |         = elem i ns || notErased (i + 1) ns rest
 34 | detagSafe defs _ = pure False
 35 |
 36 | findErasedFrom : {auto c : Ref Ctxt Defs} ->
 37 |                  Defs -> Nat -> ClosedNF -> Core (NatSet, NatSet)
 38 | findErasedFrom defs pos (NBind fc x (Pi _ c _ aty) scf)
 39 |     = do -- In the scope, use 'Erased fc Impossible' to mean 'argument is erased'.
 40 |          -- It's handy here, because we can use it to tell if a detaggable
 41 |          -- argument position is available
 42 |          sc <- scf defs (toClosure defaultOpts Env.empty (Erased fc (ifThenElse (isErased c) Impossible Placeholder)))
 43 |          (erest, dtrest) <- findErasedFrom defs (1 + pos) sc
 44 |          let dt' = if !(detagSafe defs !(evalClosure defs aty))
 45 |                       then (insert pos dtrest) else dtrest
 46 |          pure $ if isErased c
 47 |                    then (insert pos erest, dt')
 48 |                    else (erest, dt')
 49 | findErasedFrom defs pos tm = pure (NatSet.empty, NatSet.empty)
 50 |
 51 | -- Find the argument positions in the given type which are guaranteed to be
 52 | -- erasable
 53 | export
 54 | findErased : {auto c : Ref Ctxt Defs} ->
 55 |              ClosedTerm -> Core (NatSet, NatSet)
 56 | findErased tm
 57 |     = do defs <- get Ctxt
 58 |          tmnf <- nf defs Env.empty tm
 59 |          findErasedFrom defs 0 tmnf
 60 |
 61 | export
 62 | updateErasable : {auto c : Ref Ctxt Defs} ->
 63 |                  Name -> Core ()
 64 | updateErasable n
 65 |     = do defs <- get Ctxt
 66 |          Just gdef <- lookupCtxtExact n (gamma defs)
 67 |               | Nothing => pure ()
 68 |          (es, dtes) <- findErased (type gdef)
 69 |          ignore $ addDef n $
 70 |                     { eraseArgs := es,
 71 |                       safeErase := dtes } gdef
 72 |
 73 | export
 74 | wrapErrorC : List ElabOpt -> (Error -> Error) -> Core a -> Core a
 75 | wrapErrorC opts err
 76 |     = if InCase `elem` opts
 77 |          then id
 78 |          else wrapError err
 79 |
 80 | plicit : Binder (Term vars) -> PiInfo RawImp
 81 | plicit (Pi _ _ p _) = forgetDef p
 82 | plicit (PVar _ _ p _) = forgetDef p
 83 | plicit _ = Explicit
 84 |
 85 | export
 86 | bindNotReq : {vs : _} ->
 87 |              FC -> Int -> Env Term vs -> (sub : Thin pre vs) ->
 88 |              List (PiInfo RawImp, Name) ->
 89 |              Term vs -> (List (PiInfo RawImp, Name), Term pre)
 90 | bindNotReq fc i [] Refl ns tm = (ns, embed tm)
 91 | bindNotReq fc i (b :: env) Refl ns tm
 92 |    = let tmptm = subst (Ref fc Bound (MN "arg" i)) tm
 93 |          (ns', btm) = bindNotReq fc (1 + i) env Refl ns tmptm in
 94 |          (ns', refToLocal (MN "arg" i) _ btm)
 95 | bindNotReq fc i (b :: env) (Keep p) ns tm
 96 |    = let tmptm = subst (Ref fc Bound (MN "arg" i)) tm
 97 |          (ns', btm) = bindNotReq fc (1 + i) env p ns tmptm in
 98 |          (ns', refToLocal (MN "arg" i) _ btm)
 99 | bindNotReq {vs = n :: _} fc i (b :: env) (Drop p) ns tm
100 |    = bindNotReq fc i env p ((plicit b, n) :: ns)
101 |        (Bind fc _ (Pi (binderLoc b) (multiplicity b) Explicit (binderType b)) tm)
102 |
103 | export
104 | bindReq : {vs : _} ->
105 |           FC -> Env Term vs -> (sub : Thin pre vs) ->
106 |           List (PiInfo RawImp, Name) ->
107 |           Term pre -> Maybe (List (PiInfo RawImp, Name), List Name, ClosedTerm)
108 | bindReq {vs} fc env Refl ns tm
109 |     = pure (ns, notLets [] _ env, abstractEnvType fc env tm)
110 |   where
111 |     notLets : List Name -> (vars : Scope) -> Env Term vars -> List Name
112 |     notLets acc [] _ = acc
113 |     notLets acc (v :: vs) (b :: env) = if isLet b then notLets acc vs env
114 |                                        else notLets (v :: acc) vs env
115 | bindReq {vs = n :: _} fc (b :: env) (Keep p) ns tm
116 |     = do b' <- shrinkBinder b p
117 |          bindReq fc env p ((plicit b, n) :: ns)
118 |             (Bind fc _ (Pi (binderLoc b) (multiplicity b) Explicit (binderType b')) tm)
119 | bindReq fc (b :: env) (Drop p) ns tm
120 |     = bindReq fc env p ns tm
121 |
122 | -- This machinery is to calculate whether any top level argument is used
123 | -- more than once in a case block, in which case inlining wouldn't be safe
124 | -- since it might duplicate work.
125 |
126 | data ArgUsed = Used1 -- been used
127 |              | Used0 -- not used
128 |              | LocalVar -- don't care if it's used
129 |
130 | record Usage (vs : Scope) where
131 |   constructor MkUsage
132 |   isUsedSet : VarSet vs -- whether it's been used
133 |   isLocalSet : VarSet vs -- don't care if it's used
134 |
135 | initUsed : Usage vs
136 | initUsed = MkUsage
137 |   { isUsedSet = VarSet.empty
138 |   , isLocalSet = VarSet.empty
139 |   }
140 |
141 | initUsedCase : SizeOf vs -> Usage vs
142 | initUsedCase p = MkUsage
143 |   { isUsedSet = VarSet.empty
144 |   , isLocalSet = maybe id VarSet.delete (last p) (VarSet.full p)
145 |   }
146 |
147 | setUsedVar : Var vs -> Usage vs -> Usage vs
148 | setUsedVar v us@(MkUsage isUsedSet isLocalSet)
149 |   = -- if we don't care then we don't change anything
150 |     if v `VarSet.elem` isLocalSet then us
151 |     -- otherwise we record the variable usage
152 |     else MkUsage { isUsedSet = VarSet.insert v isUsedSet
153 |                  , isLocalSet }
154 |
155 | isUsed : Var vs -> Usage vs -> Bool
156 | isUsed v us = v `VarSet.elem` isUsedSet us
157 |
158 | data Used : Type where
159 |
160 | setUsed : {auto u : Ref Used (Usage vars)} ->
161 |           Var vars -> Core ()
162 | setUsed p = update Used $ setUsedVar p
163 |
164 | extendUsed : ArgUsed -> SizeOf inner -> Usage vars -> Usage (inner ++ vars)
165 | extendUsed LocalVar p (MkUsage iu il)
166 |   = MkUsage (weakenNs {tm = VarSet} p iu) (append p (full p) il)
167 | extendUsed Used0 p (MkUsage iu il)
168 |   = MkUsage (weakenNs {tm = VarSet} p iu) (weakenNs {tm = VarSet} p il)
169 | extendUsed Used1 p (MkUsage iu il)
170 |   = MkUsage (append p (full p) iu) (weakenNs {tm = VarSet} p il)
171 |
172 | dropUsed : SizeOf inner -> Usage (inner ++ vars) -> Usage vars
173 | dropUsed p (MkUsage iu il) = MkUsage (VarSet.dropInner p iu) (dropInner p il)
174 |
175 | inExtended : ArgUsed -> SizeOf new ->
176 |              {auto u : Ref Used (Usage vars)} ->
177 |              (Ref Used (Usage (new ++ vars)) -> Core a) ->
178 |              Core a
179 | inExtended a new sc
180 |     = do used <- get Used
181 |          u' <- newRef Used (extendUsed a new used)
182 |          res <- sc u'
183 |          put Used (dropUsed new !(get Used @{u'}))
184 |          pure res
185 |
186 | 0 InlineSafe : Scoped -> Type
187 | InlineSafe tm
188 |   = {0 vars : Scope} -> {auto u : Ref Used (Usage vars)} ->
189 |     tm vars -> Core Bool
190 |
191 | termsInlineSafe : InlineSafe (List . Term)
192 |
193 | termInlineSafe : InlineSafe Term
194 | termInlineSafe (Local fc isLet idx p)
195 |    = let v := MkVar p in
196 |      if isUsed v !(get Used)
197 |         then pure False
198 |          else do setUsed v
199 |                  pure True
200 | termInlineSafe (Meta fc x y xs)
201 |     = termsInlineSafe xs
202 | termInlineSafe (Bind fc x b scope)
203 |    = do bok <- binderInlineSafe b
204 |         if bok
205 |            then inExtended LocalVar (suc zero) (\u' => termInlineSafe scope)
206 |            else pure False
207 |   where
208 |     binderInlineSafe : Binder (Term vars) -> Core Bool
209 |     binderInlineSafe (Let _ _ val _) = termInlineSafe val
210 |     binderInlineSafe _ = pure True
211 | termInlineSafe (App fc fn arg)
212 |     = do fok <- termInlineSafe fn
213 |          if fok
214 |             then termInlineSafe arg
215 |             else pure False
216 | termInlineSafe (As fc x as pat) = termInlineSafe pat
217 | termInlineSafe (TDelayed fc x ty) = termInlineSafe ty
218 | termInlineSafe (TDelay fc x ty arg) = termInlineSafe arg
219 | termInlineSafe (TForce fc x val) = termInlineSafe val
220 | termInlineSafe _ = pure True
221 |
222 | termsInlineSafe [] = pure True
223 | termsInlineSafe (x :: xs)
224 |     = do xok <- termInlineSafe x
225 |          if xok
226 |             then termsInlineSafe xs
227 |             else pure False
228 |
229 | mutual
230 |   caseInlineSafe : InlineSafe CaseTree
231 |   caseInlineSafe (Case idx p scTy xs)
232 |       = let v := MkVar p in
233 |         if isUsed v !(get Used)
234 |            then pure False
235 |            else do setUsed v
236 |                    caseAltsInlineSafe xs
237 |   caseInlineSafe (STerm x tm) = termInlineSafe tm
238 |   caseInlineSafe (Unmatched msg) = pure True
239 |   caseInlineSafe Impossible = pure True
240 |
241 |   caseAltsInlineSafe : InlineSafe (List . CaseAlt)
242 |   caseAltsInlineSafe [] = pure True
243 |   caseAltsInlineSafe (a :: as)
244 |       = do u <- get Used
245 |            True <- caseAltInlineSafe a
246 |              | False => pure False
247 |            -- We can reset the usage information, because we're
248 |            -- only going to use one alternative at a time
249 |            put Used u
250 |            caseAltsInlineSafe as
251 |
252 |   caseAltInlineSafe : InlineSafe CaseAlt
253 |   caseAltInlineSafe (ConCase x tag args sc)
254 |       -- should these be local vars?
255 |       = inExtended Used0 (mkSizeOf args) (\u' => caseInlineSafe sc)
256 |   caseAltInlineSafe (DelayCase ty arg sc)
257 |       -- should these be local vars?
258 |       = inExtended Used0 (mkSizeOf [ty, arg]) (\u' => caseInlineSafe sc)
259 |   caseAltInlineSafe (ConstCase x sc) = caseInlineSafe sc
260 |   caseAltInlineSafe (DefaultCase sc) = caseInlineSafe sc
261 |
262 | -- An inlining is safe if no variable is used more than once in the tree,
263 | -- which means that there's no risk of an input being evaluated more than
264 | -- once after the definition is expanded.
265 | export
266 | inlineSafe : CaseTree vars -> Core Bool
267 | inlineSafe t
268 |     = do u <- newRef Used initUsed
269 |          caseInlineSafe t
270 |
271 | export
272 | canInlineDef : {auto c : Ref Ctxt Defs} ->
273 |                Name -> Core Bool
274 | canInlineDef n
275 |     = do defs <- get Ctxt
276 |          Just (PMDef _ _ _ rtree _) <- lookupDefExact n (gamma defs)
277 |              | _ => pure False
278 |          inlineSafe rtree
279 |
280 | -- This is a special case because the only argument we actually care about
281 | -- is the last one, since the others are just variables passed through from
282 | -- the environment, and duplicating a variable doesn't cost anything.
283 | export
284 | canInlineCaseBlock : {auto c : Ref Ctxt Defs} ->
285 |                      Name -> Core Bool
286 | canInlineCaseBlock n
287 |     = do defs <- get Ctxt
288 |          Just (PMDef _ vars _ rtree _) <- lookupDefExact n (gamma defs)
289 |              | _ => pure False
290 |          u <- newRef Used (initUsedCase (mkSizeOf vars))
291 |          caseInlineSafe rtree
292 |