0 | ||| Tail-call optimization.
  1 | |||
  2 | ||| Here is a lengthy explanation how this works at the
  3 | ||| moment. Assume the following call graph of functions f1,f2,f3,f4 all
  4 | ||| calling each other in tail call position:
  5 | |||
  6 | ||| ```
  7 | |||       ------------ f2 ---- f4 (result)
  8 | |||      /          /     \
  9 | ||| f1 ---- f1     /       -- f1
 10 | |||      \        /
 11 | |||       -- f3 --
 12 | ||| ```
 13 | |||
 14 | ||| First, a directed graph of all toplevel function calls
 15 | ||| (incoming and outgoing) in tail-call position is created:
 16 | |||
 17 | ||| ```idris
 18 | ||| MkCallGraph $ fromList [(f1,[f1,f2,f3]),(f2,[f1,f4]),(f3,[f2])]
 19 | |||             $ fromList [(f1,[f1,f2]),(f2,[f1,f3]),(f3,[f1]),(f4,[f2])]
 20 | ||| ```
 21 | |||
 22 | ||| Mutually tail-recursive functions form a strongly connected
 23 | ||| component in such a call graph: There is a (directed) path from every function
 24 | ||| to every other function. Tarjan's algorithm is used to identify
 25 | ||| these strongly connected components and grouping them in
 26 | ||| a `List` of `List1`s.
 27 | |||
 28 | ||| A tail-recursive group of functions is now converted to an imperative
 29 | ||| loop as follows: Let `obj={h:_, a1:_, a2:_, ...}`
 30 | ||| be a Javascript object consisting
 31 | ||| of a tag `h` and arguments `a1`,`a2`,... . `h` indicates, whether `obj.a1`
 32 | ||| contains the result of the computation (`h = 0`) or describes
 33 | ||| a continuation indicating the next function to be invoked, in which
 34 | ||| case fields `a1`,`a2`,... are the function's arguments.
 35 | ||| together with the necessary arguments. The group of mutually
 36 | ||| recursive functions is now converted to a single switch statement
 37 | ||| where each branch corresponds to one of the function.
 38 | ||| Each function will be changed in such a way that instead of
 39 | ||| (recursively) calling another function in its group it will return
 40 | ||| a new object `{h:_, a1:_, ...}` with `h` indicating the next
 41 | ||| function to call (the next branch to choose in the `switch`
 42 | ||| statement and `a1`,`a2`,... being the next function's set of
 43 | ||| arguments. The function and initial argument object will then
 44 | ||| be passed to toplevel function `__tailRec`, which loops
 45 | ||| until the object signals that we have arrived at a result.
 46 | |||
 47 | ||| Here is an example of two mutually tail-recursive functions
 48 | ||| together with the generated tail-call optimized code.
 49 | |||
 50 | ||| Original version:
 51 | |||
 52 | ||| ```javascript
 53 | ||| function isEven(arg){
 54 | |||   switch (arg) {
 55 | |||     case 0  : return 1;
 56 | |||     default : return isOdd(arg - 1);
 57 | |||   }
 58 | ||| }
 59 | |||
 60 | ||| function isOdd(arg){
 61 | |||   switch (arg) {
 62 | |||     case 0  : return 0;
 63 | |||     default : return isEven(arg - 1);
 64 | |||   }
 65 | ||| }
 66 | ||| ```
 67 | |||
 68 | ||| The above gets converted to code similar to
 69 | ||| the following.
 70 | |||
 71 | ||| ```javascript
 72 | ||| function tcOpt(arg) {
 73 | |||   switch(arg.h) {
 74 | |||   // former function isEven
 75 | |||   case 1: {
 76 | |||     switch (arg.a1) {
 77 | |||       case 0  : return {h: 0, a1: 1};
 78 | |||       default : return {h: 2, a1: arg.a1 - 1};
 79 | |||     }
 80 | |||   }
 81 | |||   // former function isOdd
 82 | |||   case 2: {
 83 | |||     switch (a1) {
 84 | |||       case 0  : return {h: 0, a1: 0};
 85 | |||       default : return {h: 1, a1: arg.a1 - 1};
 86 | |||     }
 87 | |||   }
 88 | ||| }
 89 | |||
 90 | ||| function isEven(arg){
 91 | |||   return __tailRec(tcOpt,{h: 1, a1: arg})
 92 | ||| }
 93 | |||
 94 | ||| function isOdd(arg){
 95 | |||   return __tailRec(tcOpt,{h: 2, a1: arg})
 96 | ||| }
 97 | ||| ```
 98 | |||
 99 | ||| Finally, `__tailRec` is implemented as follows:
100 | |||
101 | ||| ```javascript
102 | |||   function __tailRec(f,ini) {
103 | |||     let obj = ini;
104 | |||     while(true){
105 | |||       switch(obj.h){
106 | |||         case 0: return obj.a1;
107 | |||         default: obj = f(obj);
108 | |||       }
109 | |||     }
110 | |||   }
111 | ||| ```
112 | |||
113 | ||| While the above example is in Javascript, this module operates
114 | ||| on `NamedCExp` exclusively, so it might be used with any backend
115 | ||| where the things described above can be expressed.
116 | module Compiler.ES.TailRec
117 |
118 | import Data.List1
119 | import Data.SortedSet
120 | import Data.SortedMap as M
121 | import Libraries.Data.Graph
122 | import Core.CompileExpr
123 | import Core.Context
124 |
125 | --------------------------------------------------------------------------------
126 | --          Utilities
127 | --------------------------------------------------------------------------------
128 |
129 | -- indices of a list starting at 1
130 | indices : List a -> List Int
131 | indices as = [1 .. cast (length as)]
132 |
133 | zipWithIndices : List a -> List (Int,a)
134 | zipWithIndices as = zip (indices as) as
135 |
136 | --------------------------------------------------------------------------------
137 | --          Tailcall Graph
138 | --------------------------------------------------------------------------------
139 |
140 | ||| A (toplevel) function in a group of mutually tail recursive functions.
141 | public export
142 | record TcFunction where
143 |   constructor MkTcFunction
144 |   ||| Function's name
145 |   name  : Name
146 |
147 |   ||| Function's index in its tail call group
148 |   ||| This is used to decide on which branch to choose in
149 |   ||| the next iteration
150 |   index : Int
151 |
152 |   ||| Argument list
153 |   args  : List Name
154 |
155 |   ||| Function's definition
156 |   exp   : NamedCExp
157 |
158 | ||| A group of mutually tail recursive toplevel functions.
159 | public export
160 | record TcGroup where
161 |   constructor MkTcGroup
162 |   ||| Index of the group. This is used to generate a uniquely
163 |   ||| named tail call optimized toplevel function.
164 |   index     : Int
165 |
166 |   ||| Set of mutually recursive functions.
167 |   functions : SortedMap Name TcFunction
168 |
169 | -- tail calls made from an expression
170 | tailCalls : NamedCExp -> SortedSet Name
171 | tailCalls (NmLet _ _ _ z) = tailCalls z
172 | tailCalls (NmApp _ (NmRef _ x) _) = singleton x
173 | tailCalls (NmConCase fc sc xs x) =
174 |   concatMap (\(MkNConAlt _ _ _ _ x) => tailCalls x) xs <+>
175 |   concatMap tailCalls x
176 | tailCalls (NmConstCase fc sc xs x) =
177 |   concatMap (\(MkNConstAlt _ x) => tailCalls x) xs <+>
178 |   concatMap tailCalls x
179 | tailCalls _ = empty
180 |
181 | -- Checks if a `List1` of functions actually has any tail recursive
182 | -- function calls and needs to be optimized.
183 | -- In case of a larger group (more than one element)
184 | -- the group contains tailcalls by construction. In case
185 | -- of a single function, we need to check that at least one
186 | -- outgoing tailcall goes back to the function itself.
187 | -- We use the given mapping from `Name` to set of names
188 | -- called in tail position to verify this.
189 | hasTailCalls : SortedMap Name (SortedSet Name) -> List1 Name -> Bool
190 | hasTailCalls g (x ::: Nil) = maybe False (contains x) $ lookup x g
191 | hasTailCalls _ _           = True
192 |
193 | -- Given a strongly connected group of functions, plus
194 | -- a unique index, convert them to the `TcGroup` they belong to.
195 | toGroup :  SortedMap Name (Name,List Name,NamedCExp)
196 |         -> (Int,List1 Name)
197 |         -> TcGroup
198 | toGroup funMap (groupIndex,functions) =
199 |   let ns    = zipWithIndices $ forget functions
200 |    in MkTcGroup groupIndex . fromList $ mapMaybe fun ns
201 |   where fun : (Int,Name) -> Maybe (Name,TcFunction)
202 |         fun (fx, n) = do
203 |           (_,args,exp) <- lookup n funMap
204 |           pure (n,MkTcFunction n fx args exp)
205 |
206 | -- Returns the connected components of the tail call graph
207 | -- of a set of toplevel function definitions.
208 | -- Every `TcGroup` consists of a set of mutually tail-recursive functions.
209 | tailCallGroups :  List (Name,List Name,NamedCExp) -> List TcGroup
210 | tailCallGroups funs =
211 |   let funMap = M.fromList $ map (\t => (fst t,t)) funs
212 |       graph  = map (\(_,_,x) => tailCalls x) funMap
213 |       groups = filter (hasTailCalls graph) $ tarjan graph
214 |    in map (toGroup funMap) (zipWithIndices groups)
215 |
216 | --------------------------------------------------------------------------------
217 | --          Converting tail call groups to expressions
218 | --------------------------------------------------------------------------------
219 |
220 | public export
221 | record Function where
222 |   constructor MkFunction
223 |   name : Name
224 |   args : List Name
225 |   body : NamedCExp
226 |
227 | tcFunction : Int -> Name
228 | tcFunction = MN "$tcOpt"
229 |
230 | tcArgName : Name
231 | tcArgName = MN "$a" 0
232 |
233 | tcContinueName : (groupIndex : Int) -> (functionIndex : Int) -> Name
234 | tcContinueName gi fi = MN ("TcContinue" ++ show gi) fi
235 |
236 | tcDoneName : (groupIndex : Int) -> Name
237 | tcDoneName gi = MN "TcDone" gi
238 |
239 | -- Converts a single function in a mutually tail-recursive
240 | -- group of functions to a single branch in a pattern match.
241 | -- Recursive function calls will be replaced with an
242 | -- applied data constructor whose tag indicates the
243 | -- branch in the pattern match to use next, and whose values
244 | -- will be used as the arguments for the next function.
245 | conAlt : TcGroup -> TcFunction -> NamedConAlt
246 | conAlt (MkTcGroup tcIx funs) (MkTcFunction n ix args exp) =
247 |   let name = tcContinueName tcIx ix
248 |    in MkNConAlt name DATACON (Just ix) args (toTc exp)
249 |
250 |    where
251 |      mutual
252 |
253 |        -- this is returned in case we arrived at a result
254 |        -- (an expression not corresponding to a recursive
255 |        -- call in tail position).
256 |        tcDone : NamedCExp -> NamedCExp
257 |        tcDone x = NmCon EmptyFC (tcDoneName tcIx) DATACON (Just 0) [x]
258 |
259 |        -- this is returned in case we arrived at a resursive call
260 |        -- in tail position. The index indicates the next "function"
261 |        -- to call, the list of expressions are the function's
262 |        -- arguments.
263 |        tcContinue : (index : Int) -> List NamedCExp -> NamedCExp
264 |        tcContinue ix = NmCon EmptyFC (tcContinueName tcIx ix) DATACON (Just ix)
265 |
266 |        -- recursively converts an expression. Only the `NmApp` case is
267 |        -- interesting, the rest is pretty much boiler plate.
268 |        toTc : NamedCExp -> NamedCExp
269 |        toTc (NmLet fc x y z) = NmLet fc x y $ toTc z
270 |        toTc x@(NmApp _ (NmRef _ n) xs) =
271 |          case lookup n funs of
272 |            Just v  => tcContinue v.index xs
273 |            Nothing => tcDone x
274 |        toTc (NmConCase fc sc a d) = NmConCase fc sc (map con a) (map toTc d)
275 |        toTc (NmConstCase fc sc a d) = NmConstCase fc sc (map const a) (map toTc d)
276 |        toTc x@(NmCrash {}) = x
277 |        toTc x = tcDone x
278 |
279 |        con : NamedConAlt -> NamedConAlt
280 |        con (MkNConAlt x y tag xs z) = MkNConAlt x y tag xs (toTc z)
281 |
282 |        const : NamedConstAlt -> NamedConstAlt
283 |        const (MkNConstAlt x y) = MkNConstAlt x (toTc y)
284 |
285 | -- Converts a group of mutually tail recursive functions
286 | -- to a list of toplevel function declarations. `tailRecLoopName`
287 | -- is the name of the toplevel function that does the
288 | -- infinite looping (function `__tailRec` in the example at
289 | -- the top of this module).
290 | convertTcGroup : (tailRecLoopName : Name) -> TcGroup -> List Function
291 | convertTcGroup loop g@(MkTcGroup gindex fs) =
292 |   let functions = sortBy (comparing index) $ values fs
293 |       branches  = map (conAlt g) functions
294 |       switch    = NmConCase EmptyFC (local tcArgName) branches Nothing
295 |    in MkFunction tcFun [tcArgName] switch :: map toFun functions
296 |
297 |   where tcFun : Name
298 |         tcFun = tcFunction gindex
299 |
300 |         local : Name -> NamedCExp
301 |         local = NmLocal EmptyFC
302 |
303 |         toFun : TcFunction -> Function
304 |         toFun (MkTcFunction n ix args x) =
305 |           let exps  = map local args
306 |               tcArg = NmCon EmptyFC (tcContinueName gindex ix)
307 |                         DATACON (Just ix) exps
308 |               tcFun = NmRef EmptyFC tcFun
309 |               body  = NmApp EmptyFC (NmRef EmptyFC loop) [tcFun,tcArg]
310 |            in MkFunction n args body
311 |
312 | -- Tail recursion optimizations: Converts all groups of
313 | -- mutually tail recursive functions to an imperative loop.
314 | tailRecOptim :  List TcGroup
315 |              -> (tcOptimized : SortedSet Name)
316 |              -> (tcLoopName : Name)
317 |              -> List (Name,List Name,NamedCExp)
318 |              -> List Function
319 | tailRecOptim groups names loop ts =
320 |   let regular = mapMaybe toFun ts
321 |       tailOpt = concatMap (convertTcGroup loop) groups
322 |    in tailOpt ++ regular
323 |
324 |   where toFun : (Name,List Name,NamedCExp) -> Maybe Function
325 |         toFun (n,args,exp) =
326 |           if contains n names
327 |              then Nothing
328 |              else Just $ MkFunction n args exp
329 |
330 | ||| Converts a list of toplevel definitions (potentially
331 | ||| several groups of mutually tail-recursive functions)
332 | ||| to a new set of tail-call optimized function definitions.
333 | ||| Only `MkNmFun`s are converted. Other constructors of `NamedDef`
334 | ||| are ignored and silently dropped.
335 | export
336 | functions :  (tcLoopName : Name)
337 |           -> List (Name,FC,NamedDef)
338 |           -> List Function
339 | functions loop dfs =
340 |   let ts     = mapMaybe def dfs
341 |       groups = tailCallGroups ts
342 |       names  = SortedSet.fromList $ concatMap (keys . functions) groups
343 |    in tailRecOptim groups names loop ts
344 |    where def : (Name,FC,NamedDef) -> Maybe (Name,List Name,NamedCExp)
345 |          def (n,_,MkNmFun args x) = Just (n,args,x)
346 |          def _                    = Nothing
347 |