0 | module Compiler.Opts.ToplevelConstants
  1 |
  2 | import Core.CompileExpr
  3 | import Core.Context
  4 |
  5 | import Data.List1
  6 | import Data.Vect
  7 | import Data.SortedSet
  8 | import Data.SortedMap
  9 | import Libraries.Data.Graph
 10 |
 11 | --------------------------------------------------------------------------------
 12 | --          Call Graph
 13 | --------------------------------------------------------------------------------
 14 |
 15 | -- direct calls from a top-level function's expression to other
 16 | -- top-level functions.
 17 | 0 CallGraph : Type
 18 | CallGraph = SortedMap Name (SortedSet Name)
 19 |
 20 | -- top-level functions called by an expression
 21 | calls : NamedCExp -> SortedSet Name
 22 | calls (NmLocal fc p) = empty
 23 | calls (NmRef fc n1) = singleton n1
 24 | calls (NmLam fc x y) = calls y
 25 | calls (NmLet fc x z w) = calls w <+> calls z
 26 | calls (NmApp fc x xs) = calls x <+> concatMap calls xs
 27 | calls (NmCon fc n1 x tag xs) = concatMap calls xs
 28 | calls (NmOp fc f xs) = concatMap calls xs
 29 | calls (NmExtPrim fc p xs) = concatMap calls xs
 30 | calls (NmForce fc lz x) = calls x
 31 | calls (NmDelay fc lz x) = calls x
 32 | calls (NmConCase fc sc xs x) =
 33 |   calls sc <+>
 34 |   concatMap (\(MkNConAlt _ _ _ _ y) => calls y) xs <+>
 35 |   concatMap calls x
 36 | calls (NmConstCase fc sc xs x) =
 37 |   calls sc <+>
 38 |   concatMap (\(MkNConstAlt _ y) => calls y) xs <+>
 39 |   concatMap calls x
 40 | calls (NmPrimVal fc cst) = empty
 41 | calls (NmErased fc) = empty
 42 | calls (NmCrash fc str) = empty
 43 |
 44 | defCalls : NamedDef -> SortedSet Name
 45 | defCalls (MkNmFun args x) = calls x
 46 | defCalls (MkNmCon tag arity nt) = empty
 47 | defCalls (MkNmForeign ccs fargs x) = empty
 48 | defCalls (MkNmError x) = calls x
 49 |
 50 | callGraph : List (Name, FC, NamedDef) -> CallGraph
 51 | callGraph = fromList . map (\(n,_,d) => (n, defCalls d))
 52 |
 53 | isRecursive : CallGraph -> List1 Name -> Bool
 54 | isRecursive g (x ::: Nil) = maybe False (contains x) $ lookup x g
 55 | isRecursive _ _           = True
 56 |
 57 | recursiveFunctions : CallGraph -> SortedSet Name
 58 | recursiveFunctions graph =
 59 |   let groups := filter (isRecursive graph) $ tarjan graph
 60 |    in concatMap (SortedSet.fromList . forget) groups
 61 |
 62 | --------------------------------------------------------------------------------
 63 | --          Sorting Functions
 64 | --------------------------------------------------------------------------------
 65 |
 66 | data SortTag : Type where
 67 |
 68 | record SortST where
 69 |   constructor SST
 70 |   processed : SortedSet Name
 71 |   nonconst  : SortedSet Name
 72 |   triples   : SnocList (Name, FC, NamedDef)
 73 |   map       : SortedMap Name (Name, FC, NamedDef)
 74 |   graph     : CallGraph
 75 |
 76 | appendDef : Ref SortTag SortST => (Name, FC, NamedDef) -> Core ()
 77 | appendDef t = do
 78 |   st <- get SortTag
 79 |   put SortTag $ {triples $= (:< t)} st
 80 |
 81 | getCalls : Ref SortTag SortST => Name -> Core (List Name)
 82 | getCalls n = map (maybe [] Prelude.toList . lookup n . graph) (get SortTag)
 83 |
 84 | getTriple : Ref SortTag SortST => Name -> Core (Maybe (Name,FC,NamedDef))
 85 | getTriple n = map (lookup n . map) (get SortTag)
 86 |
 87 | markProcessed : Ref SortTag SortST => Name -> Core ()
 88 | markProcessed n = do
 89 |   st <- get SortTag
 90 |   put SortTag $ {processed $= insert n} st
 91 |
 92 | isProcessed : Ref SortTag SortST => Name -> Core Bool
 93 | isProcessed n = map (contains n . processed) (get SortTag)
 94 |
 95 | checkCrash : Ref SortTag SortST => (Name, FC, NamedDef) -> Core ()
 96 | checkCrash (n, _, MkNmError _) = update SortTag $ { nonconst $= insert n }
 97 | checkCrash (n, _, MkNmFun args (NmCrash {})) = update SortTag $ { nonconst $= insert n }
 98 | checkCrash (n, _, MkNmFun args (NmOp _ Crash _)) = update SortTag $ { nonconst $= insert n }
 99 | checkCrash (n, _, def) = do
100 |   st <- get SortTag
101 |   when (any (flip contains st.nonconst) !(getCalls n)) $
102 |     put SortTag $ { nonconst $= insert n } st
103 |
104 | sortDef : Ref SortTag SortST => Name -> Core ()
105 | sortDef n = do
106 |   False  <- isProcessed n | True => pure ()
107 |   markProcessed n
108 |   cs     <- getCalls n
109 |   traverse_ sortDef cs
110 |   Just t <- getTriple n | Nothing => pure ()
111 |   appendDef t
112 |   checkCrash t
113 |
114 | isConstant : (recursiveFunctions : SortedSet Name) -> (Name,FC,NamedDef) -> Bool
115 | isConstant rec (n, _, MkNmFun [] _) = not $ contains n rec
116 | isConstant _   _                  = False
117 |
118 | export
119 | sortDefs : List (Name, FC, NamedDef) -> Core (List (Name, FC, NamedDef), SortedSet Name)
120 | sortDefs ts =
121 |   let graph  := callGraph ts
122 |       rec    := recursiveFunctions graph
123 |       consts := map fst $ filter (isConstant rec) ts
124 |       init   := SST {
125 |                     processed = empty
126 |                   , nonconst  = empty
127 |                   , triples   = Lin
128 |                   , map       = fromList (map (\t => (fst t, t)) ts)
129 |                   , graph     = graph
130 |                   }
131 |    in do
132 |      s       <- newRef SortTag init
133 |      traverse_ sortDef (map fst ts)
134 |      st <- get SortTag
135 |      let sorted = triples st <>> []
136 |      let consts = filter (not . flip contains (nonconst st)) consts
137 |      pure (sorted, fromList consts)
138 |