0 | module Core.Termination.CallGraph
  1 |
  2 | import Core.Case.CaseTree
  3 | import Core.Context.Log
  4 | import Core.Env
  5 | import Core.Normalise
  6 | import Core.Options
  7 | import Core.Value
  8 |
  9 | import Libraries.Data.List.SizeOf
 10 | import Libraries.Data.SparseMatrix
 11 |
 12 | import Data.String
 13 |
 14 | %default covering
 15 |
 16 | data Guardedness = Toplevel | Unguarded | Guarded | InDelay
 17 |
 18 | Show Guardedness where
 19 |   show Toplevel = "Toplevel"
 20 |   show Unguarded = "Unguarded"
 21 |   show Guarded = "Guarded"
 22 |   show InDelay = "InDelay"
 23 |
 24 | sizeEq : {auto 0 cv : CompatibleVars rhsVars lhsVars} ->
 25 |          Term rhsVars -> -- RHS
 26 |          Term lhsVars -> -- LHS: may contain dot-patterns, try both sides of as patterns
 27 |          Bool
 28 | sizeEq (Local _ _ idx _) (Local _ _ idx' _) = idx == idx'
 29 | sizeEq (Ref _ _ n) (Ref _ _ n') = n == n'
 30 | sizeEq (Meta _ _ i args) (Meta _ _ i' args')
 31 |     = i == i' && assert_total (all (uncurry sizeEq) (zip args args'))
 32 | sizeEq (Bind _ _ b sc) (Bind _ _ b' sc') = eqBinderBy sizeEq b b' && sizeEq sc sc'
 33 | sizeEq (App _ f a) (App _ f' a') = sizeEq f f' && sizeEq a a'
 34 | sizeEq (As _ _ a p) p' = sizeEq p p'
 35 | sizeEq p (As _ _ a p') = sizeEq p a || sizeEq p p'
 36 | sizeEq (TDelayed _ _ t) (TDelayed _ _ t') = sizeEq t t'
 37 | sizeEq (TDelay _ _ t x) (TDelay _ _ t' x') = sizeEq t t' && sizeEq x x'
 38 | sizeEq (TForce _ _ t) (TForce _ _ t') = sizeEq t t'
 39 | sizeEq (PrimVal _ c) (PrimVal _ c') = c == c'
 40 | -- traverse dotted LHS terms
 41 | sizeEq t (Erased _ (Dotted t')) = eqTerm t t' -- t' is no longer a pattern
 42 | sizeEq (TType {}) (TType {}) = True
 43 | sizeEq _ _ = False
 44 |
 45 | -- Remove all force and delay annotations which are nothing to do with
 46 | -- coinduction meaning that all Delays left guard coinductive calls.
 47 | delazy : Defs -> Term vars -> Term vars
 48 | delazy defs (TDelayed fc r tm)
 49 |     = let tm' = delazy defs tm in
 50 |           case r of
 51 |                LInf => TDelayed fc r tm'
 52 |                _ => tm'
 53 | delazy defs (TDelay fc r ty tm)
 54 |     = let ty' = delazy defs ty
 55 |           tm' = delazy defs tm in
 56 |           case r of
 57 |                LInf => TDelay fc r ty' tm'
 58 |                _ => tm'
 59 | delazy defs (TForce fc r t)
 60 |     = case r of
 61 |            LInf => TForce fc r (delazy defs t)
 62 |            _ => delazy defs t
 63 | delazy defs (Meta fc n i args) = Meta fc n i (map (delazy defs) args)
 64 | delazy defs (Bind fc x b sc)
 65 |     = Bind fc x (map (delazy defs) b) (delazy defs sc)
 66 | delazy defs (App fc f a) = App fc (delazy defs f) (delazy defs a)
 67 | delazy defs (As fc s a p) = As fc s (delazy defs a) (delazy defs p)
 68 | delazy defs tm = tm
 69 |
 70 | mutual
 71 |   findSC : {vars : _} ->
 72 |            {auto c : Ref Ctxt Defs} ->
 73 |            Defs -> Env Term vars -> Guardedness ->
 74 |            List (Term vars) -> -- LHS args
 75 |            Term vars -> -- RHS
 76 |            Core (List SCCall)
 77 |   findSC {vars} defs env g pats (Bind fc n b sc)
 78 |        = pure $
 79 |             !(findSCbinder b) ++
 80 |             !(findSC defs (b :: env) g (map weaken pats) sc)
 81 |     where
 82 |       findSCbinder : Binder (Term vars) -> Core (List SCCall)
 83 |       findSCbinder (Let _ c val ty) = findSC defs env g pats val
 84 |       findSCbinder b = pure [] -- only types, no need to look
 85 |   -- If we're Guarded and find a Delay, continue with the argument as InDelay
 86 |   findSC defs env Guarded pats (TDelay _ _ _ tm)
 87 |       = findSC defs env InDelay pats tm
 88 |   findSC defs env g pats (TDelay _ _ _ tm)
 89 |       = findSC defs env g pats tm
 90 |   findSC defs env g pats (TForce _ _ tm)
 91 |       = findSC defs env Unguarded pats tm
 92 |   findSC defs env g pats tm
 93 |       = do let (fn, args) = getFnArgs tm
 94 |            False <- isAssertTotal fn
 95 |                | True => pure []
 96 |            -- if it's a 'case' or 'if' just go straight into the arguments
 97 |            Nothing <- handleCase fn args
 98 |                | Just res => pure res
 99 |            fn' <- conIfGuarded fn -- pretend it's a data constructor if
100 |                                   -- it has the AllGuarded flag
101 |            case (g, fn', args) of
102 |     -- If we're InDelay and find a constructor (or a function call which is
103 |     -- guaranteed to return a constructor; AllGuarded set), continue as InDelay
104 |              (InDelay, Ref fc (DataCon {}) cn, args) =>
105 |                  do scs <- traverse (findSC defs env InDelay pats) args
106 |                     pure (concat scs)
107 |              -- If we're InDelay otherwise, just check the arguments, the
108 |              -- function call is okay
109 |              (InDelay, _, args) =>
110 |                  do scs <- traverse (findSC defs env Unguarded pats) args
111 |                     pure (concat scs)
112 |              (Guarded, Ref fc (DataCon {}) cn, args) =>
113 |                     findSCcall defs env Guarded pats fc cn args
114 |              (Toplevel, Ref fc (DataCon {}) cn, args) =>
115 |                     findSCcall defs env Guarded pats fc cn args
116 |              (_, Ref fc Func fn, args) =>
117 |                  do logC "totality" 50 $
118 |                        pure $ "Looking up type of " ++ show !(toFullNames fn)
119 |                     findSCcall defs env Unguarded pats fc fn args
120 |              (_, f, args) =>
121 |                  do scs <- traverse (findSC defs env Unguarded pats) args
122 |                     pure (concat scs)
123 |       where
124 |         handleCase : Term vars -> List (Term vars) -> Core (Maybe (List SCCall))
125 |         handleCase (Ref fc nt n) args
126 |             = do n' <- toFullNames n
127 |                  if caseFn n'
128 |                     then Just <$> findSCcall defs env g pats fc n args
129 |                     else pure Nothing
130 |         handleCase _ _ = pure Nothing
131 |
132 |         isAssertTotal : Term vars -> Core Bool
133 |         isAssertTotal (Ref fc Func fn)
134 |             = pure $ !(toFullNames fn) == NS builtinNS (UN $ Basic "assert_total")
135 |         isAssertTotal tm = pure False
136 |
137 |         conIfGuarded : Term vars -> Core (Term vars)
138 |         conIfGuarded (Ref fc Func n)
139 |             = do defs <- get Ctxt
140 |                  Just gdef <- lookupCtxtExact n (gamma defs)
141 |                       | Nothing => pure $ Ref fc Func n
142 |                  if AllGuarded `elem` flags gdef
143 |                     then pure $ Ref fc (DataCon 0 0) n
144 |                     else pure $ Ref fc Func n
145 |         conIfGuarded tm = pure tm
146 |
147 |   knownOr : Core SizeChange -> Core SizeChange -> Core SizeChange
148 |   knownOr x y = case !x of Unknown => y_ => x
149 |
150 |   plusLazy : Core SizeChange -> Core SizeChange -> Core SizeChange
151 |   plusLazy x y = case !x of Smaller => pure Smallerx => pure $ x |+| !y
152 |
153 |   -- Return whether first argument is structurally smaller than the second.
154 |   sizeCompare : {auto defs : Defs} ->
155 |                 Nat -> -- backtracking fuel
156 |                 Term vars -> -- RHS: term we're checking
157 |                 Term vars -> -- LHS: argument it might be smaller than
158 |                 Core SizeChange
159 |
160 |   sizeCompareCon : {auto defs : Defs} -> Nat -> Term vars -> Term vars -> Core Bool
161 |   sizeCompareTyCon : {auto defs : Defs} -> Nat -> Term vars -> Term vars -> Core Bool
162 |   sizeCompareConArgs : {auto defs : Defs} -> Nat -> Term vars -> List (Term vars) -> Core Bool
163 |   sizeCompareApp : {auto defs : Defs} -> Nat -> Term vars -> Term vars -> Core SizeChange
164 |
165 |   sizeCompare fuel s (Erased _ (Dotted t)) = sizeCompare fuel s t
166 |   sizeCompare fuel _ (Erased {}) = pure Unknown -- incomparable!
167 |   -- for an as pattern, it's smaller if it's smaller than either part
168 |   sizeCompare fuel s (As _ _ p t)
169 |       = knownOr (sizeCompare fuel s p) (sizeCompare fuel s t)
170 |   sizeCompare fuel (As _ _ p s) t
171 |       = knownOr (sizeCompare fuel p t) (sizeCompare fuel s t)
172 |   -- if they're both metas, let sizeEq check if they're the same
173 |   sizeCompare fuel s@(Meta {}) t@(Meta {}) = pure (if sizeEq s t then Same else Unknown)
174 |   -- otherwise try to expand RHS meta
175 |   sizeCompare fuel s@(Meta n _ i args) t = do
176 |     Just gdef <- lookupCtxtExact (Resolved i) (gamma defs) | _ => pure Unknown
177 |     let (PMDef _ [] (STerm _ tm) _ _) = definition gdef | _ => pure Unknown
178 |     tm <- substMeta (embed tm) args zero Subst.empty
179 |     sizeCompare fuel tm t
180 |     where
181 |       substMeta : {0 drop, vs : _} ->
182 |                   Term (drop ++ vs) -> List (Term vs) ->
183 |                   SizeOf drop -> SubstEnv drop vs ->
184 |                   Core (Term vs)
185 |       substMeta (Bind bfc n (Lam _ c e ty) sc) (a :: as) drop env
186 |           = substMeta sc as (suc drop) (a :: env)
187 |       substMeta (Bind bfc n (Let _ c val ty) sc) as drop env
188 |           = substMeta (subst val sc) as drop env
189 |       substMeta rhs [] drop env = pure (substs drop env rhs)
190 |       substMeta rhs _ _ _ = throw (InternalError ("Badly formed metavar solution \{show n}"))
191 |
192 |   sizeCompare fuel s t
193 |      = if !(sizeCompareTyCon fuel s t) then pure Same
194 |        else if !(sizeCompareCon fuel s t)
195 |           then pure Smaller
196 |           else knownOr (sizeCompareApp fuel s t) (pure $ if sizeEq s t then Same else Unknown)
197 |
198 |   sizeCompareProdConArgs : {auto defs : Defs} -> Nat -> List (Term vars) -> List (Term vars) -> Core SizeChange
199 |   sizeCompareProdConArgs _ [] [] = pure Same
200 |   sizeCompareProdConArgs fuel (x :: xs) (y :: ys) =
201 |     case !(sizeCompare fuel x y) of
202 |       Unknown => pure Unknown
203 |       t => (t |*|) <$> sizeCompareProdConArgs fuel xs ys
204 |   sizeCompareProdConArgs _ _ _ = pure Unknown
205 |
206 |   sizeCompareTyCon fuel s t =
207 |     let (f, args) = getFnArgs t in
208 |     let (g, args') = getFnArgs s in
209 |     case f of
210 |       Ref _ (TyCon {}) cn => case g of
211 |         Ref _ (TyCon {}) cn' => if cn == cn'
212 |             then (Unknown /=) <$> sizeCompareProdConArgs fuel args' args
213 |             else pure False
214 |         _ => pure False
215 |       _ => pure False
216 |
217 |   sizeCompareCon fuel s t
218 |       = let (f, args) = getFnArgs t in
219 |         case f of
220 |              Ref _ (DataCon t a) cn =>
221 |                 -- if s is smaller or equal to an arg, then it is smaller than t
222 |                 if !(sizeCompareConArgs (minus fuel 1) s args) then pure True
223 |                 else let (g, args') = getFnArgs s in
224 |                     case (fuel, g) of
225 |                         (S k, Ref _ (DataCon t' a') cn') => do
226 |                                 -- if s is a matching DataCon, applied to same number of args,
227 |                                 -- no Unknown args, and at least one Smaller
228 |                                 if cn == cn' && length args == length args'
229 |                                   then (Smaller ==) <$> sizeCompareProdConArgs k args' args
230 |                                   else pure False
231 |                         _ => pure $ False
232 |              _ => pure False
233 |
234 |   sizeCompareConArgs _ s [] = pure False
235 |   sizeCompareConArgs fuel s (t :: ts)
236 |       = case !(sizeCompare fuel s t) of
237 |           Unknown => sizeCompareConArgs fuel s ts
238 |           _ => pure True
239 |
240 |   sizeCompareApp fuel (App _ f _) t = sizeCompare fuel f t
241 |   sizeCompareApp _ _ t = pure Unknown
242 |
243 |   sizeCompareAsserted : {auto defs : Defs} -> Nat -> Maybe (Term vars) -> Term vars -> Core SizeChange
244 |   sizeCompareAsserted fuel (Just s) t
245 |       = pure $ case !(sizeCompare fuel s t) of
246 |           Unknown => Unknown
247 |           _ => Smaller
248 |   sizeCompareAsserted _ Nothing _ = pure Unknown
249 |
250 |   -- if the argument is an 'assert_smaller', return the thing it's smaller than
251 |   asserted : Name -> Term vars -> Maybe (Term vars)
252 |   asserted aSmaller tm
253 |        = case getFnArgs tm of
254 |               (Ref _ nt fn, [_, _, b, _])
255 |                    => if fn == aSmaller
256 |                          then Just b
257 |                          else Nothing
258 |               _ => Nothing
259 |
260 |   -- Calculate the size change for the given argument.  i.e., return the
261 |   -- relative size of the given argument to each entry in 'pats'.
262 |   mkChange : Defs -> Name ->
263 |              (pats : List (Term vars)) ->
264 |              (arg : Term vars) ->
265 |              Core (List SizeChange)
266 |   mkChange defs aSmaller pats arg
267 |     = let fuel = defs.options.elabDirectives.totalLimit
268 |       in traverse (\p => plusLazy (sizeCompareAsserted fuel (asserted aSmaller arg) p) (sizeCompare fuel arg p)) pats
269 |
270 |   -- Given a name of a case function, and a list of the arguments being
271 |   -- passed to it, update the pattern list so that it's referring to the LHS
272 |   -- of the case block function and return the corresponding RHS.
273 |
274 |   -- This way, we can build case blocks directly into the size change graph
275 |   -- rather than treating the definitions separately.
276 |   getCasePats : {auto c : Ref Ctxt Defs} ->
277 |                 {vars : _} ->
278 |                 Defs -> Name -> List (Term vars) ->
279 |                 List (Term vars) ->
280 |                 Core (Maybe (List (vs ** (Env Term vs,
281 |                                          List (Term vs), Term vs))))
282 |
283 |   getCasePats {vars} defs n pats args
284 |       = do Just (PMDef _ _ _ _ pdefs) <- lookupDefExact n (gamma defs)
285 |              | _ => pure Nothing
286 |            log "totality" 20 $
287 |              unwords ["Looking at the", show (length pdefs), "cases of", show  n]
288 |            let pdefs' = map matchArgs pdefs
289 |            logC "totality" 20 $ do
290 |               old <- for pdefs $ \ (_ ** (_, lhs, rhs)) => do
291 |                        lhs <- toFullNames lhs
292 |                        rhs <- toFullNames rhs
293 |                        pure $ "    " ++ show lhs ++ " => " ++ show rhs
294 |               new <- for pdefs' $ \ (_ ** (_, lhs, rhs)) => do
295 |                        lhs <- traverse toFullNames lhs
296 |                        rhs <- toFullNames rhs
297 |                        pure $ "    " ++ show lhs ++ " => " ++ show rhs
298 |               pure $ unlines $ "Updated" :: old ++ "  to:" :: new
299 |            pure $ Just pdefs'
300 |
301 |     where
302 |       updateRHS : {vs, vs' : _} ->
303 |                   List (Term vs, Term vs') -> Term vs -> Term vs'
304 |       updateRHS {vs} {vs'} ms tm
305 |           = case lookupTm tm ms of
306 |                  Nothing => urhs tm
307 |                  Just t => t
308 |         where
309 |           urhs : Term vs -> Term vs'
310 |           urhs (Local fc _ _ _) = Erased fc Placeholder
311 |           urhs (Ref fc nt n) = Ref fc nt n
312 |           urhs (Meta fc m i margs) = Meta fc m i (map (updateRHS ms) margs)
313 |           urhs (App fc f a) = App fc (updateRHS ms f) (updateRHS ms a)
314 |           urhs (As fc s a p) = As fc s (updateRHS ms a) (updateRHS ms p)
315 |           urhs (TDelayed fc r ty) = TDelayed fc r (updateRHS ms ty)
316 |           urhs (TDelay fc r ty tm)
317 |               = TDelay fc r (updateRHS ms ty) (updateRHS ms tm)
318 |           urhs (TForce fc r tm) = TForce fc r (updateRHS ms tm)
319 |           urhs (Bind fc x b sc)
320 |               = Bind fc x (map (updateRHS ms) b)
321 |                   (updateRHS (map (\vt => (weaken (fst vt), weaken (snd vt))) ms) sc)
322 |           urhs (PrimVal fc c) = PrimVal fc c
323 |           urhs (Erased fc Impossible) = Erased fc Impossible
324 |           urhs (Erased fc Placeholder) = Erased fc Placeholder
325 |           urhs (Erased fc (Dotted t)) = Erased fc (Dotted (updateRHS ms t))
326 |           urhs (TType fc u) = TType fc u
327 |
328 |           lookupTm : Term vs -> List (Term vs, Term vs') -> Maybe (Term vs')
329 |           lookupTm tm [] = Nothing
330 |           lookupTm (As fc s p tm) tms -- Want to keep the pattern and the variable,
331 |                                       -- if there was an @ in the parent
332 |               = do tm' <- lookupTm tm tms
333 |                    Just $ As fc s tm' (urhs tm)
334 |           lookupTm tm ((As fc s p tm', v) :: tms)
335 |               = if tm == p
336 |                    then Just v
337 |                    else do tm' <- lookupTm tm ((tm', v) :: tms)
338 |                            Just $ As fc s (urhs p) tm'
339 |           lookupTm tm ((tm', v) :: tms)
340 |               = if tm == tm'
341 |                    then Just v
342 |                    else lookupTm tm tms
343 |
344 |       updatePat : {vs, vs' : _} ->
345 |                   List (Term vs, Term vs') -> Term vs -> Term vs'
346 |       updatePat ms tm = updateRHS ms tm
347 |
348 |       matchArgs : (vs ** (Env Term vs, Term vs, Term vs)->
349 |                   (vs ** (Env Term vs, List (Term vs), Term vs))
350 |       matchArgs (_ ** (env', lhs, rhs))
351 |          = let patMatch = reverse (zip args (getArgs lhs)) in
352 |                (_ ** (env', map (updatePat patMatch) pats, rhs))
353 |
354 |   findSCcall : {vars : _} ->
355 |                {auto c : Ref Ctxt Defs} ->
356 |                Defs -> Env Term vars -> Guardedness ->
357 |                List (Term vars) ->
358 |                FC -> Name -> List (Term vars) ->
359 |                Core (List SCCall)
360 |   findSCcall defs env g pats fc fn_in args
361 |         -- Under 'assert_total' we assume that all calls are fine, so leave
362 |         -- the size change list empty
363 |       = do fn <- getFullName fn_in
364 |            logC "totality.termination.sizechange" 10 $ do pure $ "Looking under " ++ show !(toFullNames fn)
365 |            aSmaller <- resolved (gamma defs) (NS builtinNS (UN $ Basic "assert_smaller"))
366 |            if caseFn fn
367 |               then do scs1 <- traverse (findSC defs env g pats) args
368 |                       mps  <- getCasePats defs fn pats args
369 |                       scs2 <- traverse (findInCase defs g) $ fromMaybe [] mps
370 |                       pure (concat (scs1 ++ scs2))
371 |               else do scs <- traverse (findSC defs env g pats) args
372 |                       pure $ [MkSCCall fn
373 |                                (fromListList
374 |                                     !(traverse (mkChange defs aSmaller pats) args))
375 |                                fc]
376 |                                ++ concat scs
377 |
378 |   findInCase : {auto c : Ref Ctxt Defs} ->
379 |                Defs -> Guardedness ->
380 |                (vs ** (Env Term vs, List (Term vs), Term vs)->
381 |                Core (List SCCall)
382 |   findInCase defs g (_ ** (env, pats, tm))
383 |      = do logC "totality" 10 $
384 |                    do ps <- traverse toFullNames pats
385 |                       pure ("Looking in case args " ++ show ps)
386 |           logTermNF "totality" 10 "        =" env tm
387 |           rhs <- normaliseOpts tcOnly defs env tm
388 |           findSC defs env g pats (delazy defs rhs)
389 |
390 | findCalls : {auto c : Ref Ctxt Defs} ->
391 |             Defs -> (vars ** (Env Term vars, Term vars, Term vars)->
392 |             Core (List SCCall)
393 | findCalls defs (_ ** (env, lhs, rhs_in))
394 |    = do let pargs = getArgs (delazy defs lhs)
395 |         rhs <- normaliseOpts tcOnly defs env rhs_in
396 |         findSC defs env Toplevel pargs (delazy defs rhs)
397 |
398 | getSC : {auto c : Ref Ctxt Defs} ->
399 |         Defs -> Def -> Core (List SCCall)
400 | getSC defs (PMDef _ args _ _ pats)
401 |    = do sc <- traverse (findCalls defs) pats
402 |         pure $ nub (concat sc)
403 | getSC defs _ = pure []
404 |
405 | export
406 | calculateSizeChange : {auto c : Ref Ctxt Defs} ->
407 |                       FC -> Name -> Core (List SCCall)
408 | calculateSizeChange loc n
409 |     = do logC "totality.termination.sizechange" 5 $ do pure $ "Calculating Size Change: " ++ show !(toFullNames n)
410 |          defs <- get Ctxt
411 |          Just def <- lookupCtxtExact n (gamma defs)
412 |               | Nothing => undefinedName loc n
413 |          getSC defs (definition def)
414 |