0 | module Compiler.Opts.ToplevelConstants
2 | import Core.CompileExpr
7 | import Data.SortedSet
8 | import Data.SortedMap
9 | import Libraries.Data.Graph
-- A call graph: each key name is mapped to a set of names. As built by
-- `callGraph`, the set holds the names collected from that definition by
-- `defCalls`.
18 | CallGraph = SortedMap Name (SortedSet Name)
-- Collects names referenced inside a named expression. Results of the
-- sub-expressions a clause visits are merged with `<+>`.
21 | calls : NamedCExp -> SortedSet Name
22 | calls (NmLocal fc p) = empty
-- Among the visible clauses, `NmRef` is the only one that adds a name itself.
23 | calls (NmRef fc n1) = singleton n1
24 | calls (NmLam fc x y) = calls y
25 | calls (NmLet fc x z w) = calls w <+> calls z
26 | calls (NmApp fc x xs) = calls x <+> concatMap calls xs
-- The constructor name `n1` is not recorded; only the arguments are searched.
27 | calls (NmCon fc n1 x tag xs) = concatMap calls xs
28 | calls (NmOp fc f xs) = concatMap calls xs
29 | calls (NmExtPrim fc p xs) = concatMap calls xs
30 | calls (NmForce fc lz x) = calls x
31 | calls (NmDelay fc lz x) = calls x
-- NOTE(review): the right-hand sides of the two case clauses below look
-- truncated in this excerpt (they end in a dangling `<+>`). Only the
-- traversal of the alternatives' bodies is visible; presumably the scrutinee
-- `sc` and the default branch `x` are handled on the missing lines — confirm.
32 | calls (NmConCase fc sc xs x) =
34 | concatMap (\(MkNConAlt _ _ _ _ y) => calls y) xs <+>
36 | calls (NmConstCase fc sc xs x) =
38 | concatMap (\(MkNConstAlt _ y) => calls y) xs <+>
-- Leaves that contribute no names.
40 | calls (NmPrimVal fc cst) = empty
41 | calls (NmErased fc) = empty
42 | calls (NmCrash fc str) = empty
-- Names referenced by a toplevel definition. Function bodies and error
-- expressions are searched with `calls`; constructor and foreign
-- definitions contribute nothing.
defCalls : NamedDef -> SortedSet Name
defCalls (MkNmFun _ body) = calls body
defCalls (MkNmError err) = calls err
defCalls (MkNmCon {}) = empty
defCalls (MkNmForeign {}) = empty
-- Builds the call graph for a list of definitions: every definition name is
-- paired with the set of names its definition references (see `defCalls`).
-- The source location in the middle of each triple is ignored.
callGraph : List (Name, FC, NamedDef) -> CallGraph
callGraph defs = fromList (map toEntry defs)
  where
    -- Turn one definition triple into a (name, referenced names) entry.
    toEntry : (Name, FC, NamedDef) -> (Name, SortedSet Name)
    toEntry (name, _, def) = (name, defCalls def)
-- Decides whether a group of names counts as recursive.
-- A group with a single name is recursive only if that name appears in its
-- own entry of the call graph (a missing entry means "not recursive").
-- Any other group is always treated as recursive.
isRecursive : CallGraph -> List1 Name -> Bool
isRecursive graph (name ::: []) =
  case lookup name graph of
    Nothing => False
    Just callees => contains name callees
isRecursive _ _ = True
-- Set of all names that belong to a recursive group of the call graph.
-- The groups come from `tarjan`; those rejected by `isRecursive` are dropped
-- and the members of the remaining groups are merged into one set.
recursiveFunctions : CallGraph -> SortedSet Name
recursiveFunctions graph =
  concatMap (SortedSet.fromList . forget)
    (filter (isRecursive graph) (tarjan graph))
-- Tag type passed to `newRef`/`get`/`put`/`update` to identify the mutable
-- sort state (`Ref SortTag SortST`). No constructors are visible in this
-- excerpt.
66 | data SortTag : Type where
-- Fields of the sort state accessed through `Ref SortTag SortST`.
-- NOTE(review): the enclosing record header and the `graph` field read by
-- `getCalls` are not visible in this excerpt.
-- Names already handled: `markProcessed` inserts, `isProcessed` queries.
70 | processed : SortedSet Name
-- Names flagged by `checkCrash`; `sortDefs` removes them from its result set.
71 | nonconst : SortedSet Name
-- Definition triples accumulated by `appendDef`, newest at the end.
72 | triples : SnocList (Name, FC, NamedDef)
-- Lookup table from a name to its full definition triple (see `getTriple`).
73 | map : SortedMap Name (Name, FC, NamedDef)
-- Appends a definition triple to the end of the `triples` field of the sort
-- state and writes the updated state back.
-- NOTE(review): the equation head and the binding of `st` are missing from
-- this excerpt; presumably `t` is the argument and `st` is the state read
-- via `get SortTag` — confirm.
76 | appendDef : Ref SortTag SortST => (Name, FC, NamedDef) -> Core ()
79 | put SortTag $
{triples $= (:< t)} st
-- Lists the names recorded for `n` in the `graph` field of the sort state.
-- A name without an entry yields the empty list.
getCalls : Ref SortTag SortST => Name -> Core (List Name)
getCalls n = do
  st <- get SortTag
  pure $ maybe [] Prelude.toList (lookup n st.graph)
-- Looks up the full definition triple stored for `n` in the `map` field of
-- the sort state; `Nothing` when no triple is stored under that name.
getTriple : Ref SortTag SortST => Name -> Core (Maybe (Name,FC,NamedDef))
getTriple n = do
  st <- get SortTag
  pure (lookup n st.map)
-- Adds `n` to the `processed` set of the sort state and stores the result.
-- NOTE(review): the line binding `st` is missing from this excerpt;
-- presumably it is the current state read via `get SortTag` — confirm.
87 | markProcessed : Ref SortTag SortST => Name -> Core ()
88 | markProcessed n = do
90 | put SortTag $
{processed $= insert n} st
-- Reports whether `n` is already contained in the `processed` set of the
-- sort state.
isProcessed : Ref SortTag SortST => Name -> Core Bool
isProcessed n = do
  st <- get SortTag
  pure (contains n st.processed)
-- Flags definitions in the `nonconst` set of the sort state. A name is
-- inserted when its definition is `MkNmError`, when it is a function whose
-- body is `NmCrash` or an `NmOp` with the `Crash` operator, or when any of
-- the names returned by `getCalls` for it is already in `nonconst`.
95 | checkCrash : Ref SortTag SortST => (Name, FC, NamedDef) -> Core ()
96 | checkCrash (n, _, MkNmError _) = update SortTag $
{ nonconst $= insert n }
97 | checkCrash (n, _, MkNmFun args (NmCrash {})) = update SortTag $
{ nonconst $= insert n }
98 | checkCrash (n, _, MkNmFun args (NmOp _ Crash _)) = update SortTag $
{ nonconst $= insert n }
-- Fallback clause: the flag propagates from callees to the caller.
-- NOTE(review): the line binding `st` is missing from this excerpt;
-- presumably it is the state read via `get SortTag` — confirm.
99 | checkCrash (n, _, def) = do
101 | when (any (flip contains st.nonconst) !(getCalls n)) $
102 | put SortTag $
{ nonconst $= insert n } st
-- Handles the definition named `n` unless it is already marked as processed.
-- The visible steps recurse into every name in `cs` and then fetch the
-- triple stored for `n`.
-- NOTE(review): several lines are missing from this excerpt (the equation
-- head, the binding of `cs`, and whatever follows the lookup). Presumably
-- `cs` comes from `getCalls n` and the triple is then recorded — confirm.
104 | sortDef : Ref SortTag SortST => Name -> Core ()
-- Stop immediately for names that were handled before.
106 | False <- isProcessed n | True => pure ()
109 | traverse_ sortDef cs
-- When no triple is stored for `n`, nothing further is done.
110 | Just t <- getTriple n | Nothing => pure ()
-- A definition is a constant candidate when it is a function taking no
-- arguments whose name is not in the given set of recursive functions.
-- Every other kind of definition is rejected.
isConstant : (recursiveFunctions : SortedSet Name) -> (Name,FC,NamedDef) -> Bool
isConstant recs (name, _, def) =
  case def of
    MkNmFun [] _ => not (contains name recs)
    _ => False
-- Runs `sortDef` over the name of every given definition and returns the
-- triples accumulated in the sort state (`triples st <>> []`) together with
-- a set of names: those accepted by `isConstant` (zero-argument functions
-- not in the recursive set) that were not flagged in `nonconst`.
-- NOTE(review): several lines are missing from this excerpt (the equation
-- head, most of the initial state `init`, and the binding of `st`);
-- presumably `st` is the final state read via `get SortTag` — confirm.
119 | sortDefs : List (Name, FC, NamedDef) -> Core (List (Name, FC, NamedDef), SortedSet Name)
121 | let graph := callGraph ts
122 | rec := recursiveFunctions graph
-- Candidate constants, before crash information is taken into account.
123 | consts := map fst $
filter (isConstant rec) ts
-- Index every triple by its own name.
128 | , map = fromList (map (\t => (fst t, t)) ts)
132 | s <- newRef SortTag init
133 | traverse_ sortDef (map fst ts)
135 | let sorted = triples st <>> []
-- Shadows the earlier `consts`: drop every candidate flagged in `nonconst`.
136 | let consts = filter (not . flip contains (nonconst st)) consts
137 | pure (sorted, fromList consts)