116 | module Compiler.ES.TailRec
119 | import Data.SortedSet
120 | import Data.SortedMap as M
121 | import Libraries.Data.Graph
122 | import Core.CompileExpr
123 | import Core.Context
||| 1-based positions `[1, 2, ..., length as]` for the elements of a list.
|||
||| Numbering starts at 1; elsewhere in this file constructor tag 0 is used
||| for the `TcDone` constructor, and these indices become the tags of the
||| `TcContinue` constructors.
|||
||| The empty list needs its own clause: Idris 2 ranges count downwards when
||| the lower bound exceeds the upper bound, so `[1 .. 0]` is `[1, 0]`, not
||| `[]`.
indices : List a -> List Int
indices [] = []
indices as = [1 .. cast (length as)]
||| Pairs every element of a list with its 1-based position.
|||
||| For example, `zipWithIndices ['x', 'y']` is `[(1, 'x'), (2, 'y')]`, and
||| the empty list yields the empty list.
zipWithIndices : List a -> List (Int,a)
zipWithIndices = number 1
  where
    number : Int -> List b -> List (Int, b)
    number _ [] = []
    number pos (y :: ys) = (pos, y) :: number (pos + 1) ys
||| A function that belongs to a group of mutually tail-recursive functions.
|||
||| Elsewhere in this file it is built and matched as
||| `MkTcFunction name index args body` (see `toGroup` and `conAlt`). The
||| `index` field identifies the function inside its group and doubles as
||| the tag of its `TcContinue` constructor.
142 | record TcFunction where
143 | constructor MkTcFunction
||| A group of functions that call each other in tail position, as produced
||| by `tailCallGroups`.
|||
||| It is constructed as `MkTcGroup groupIndex functions` (see `toGroup`),
||| where `functions` maps each member's name to its `TcFunction`.
160 | record TcGroup where
161 | constructor MkTcGroup
167 | functions : SortedMap Name TcFunction
||| Collects the names of all functions called in tail position of an
||| expression.
|||
||| Only these positions are inspected:
|||
||| * the body of a `let`;
||| * the right-hand sides (and default) of constructor and constant case
|||   alternatives;
||| * a direct application of a named reference.
|||
||| Scrutinees and call arguments are not searched. Any other expression
||| contributes no tail calls.
tailCalls : NamedCExp -> SortedSet Name
tailCalls (NmLet _ _ _ body) = tailCalls body
tailCalls (NmApp _ (NmRef _ callee) _) = singleton callee
tailCalls (NmConCase _ _ alts dflt) =
  concatMap conRhs alts <+> concatMap tailCalls dflt
  where
    conRhs : NamedConAlt -> SortedSet Name
    conRhs (MkNConAlt _ _ _ _ rhs) = tailCalls rhs
tailCalls (NmConstCase _ _ alts dflt) =
  concatMap constRhs alts <+> concatMap tailCalls dflt
  where
    constRhs : NamedConstAlt -> SortedSet Name
    constRhs (MkNConstAlt _ rhs) = tailCalls rhs
tailCalls _ = empty
||| Decides whether one component of the tail-call graph `g` (as returned by
||| `tarjan` in `tailCallGroups`) contains tail recursion.
|||
||| Components with two or more functions are accepted unconditionally. A
||| single function is accepted only if its own name appears among its tail
||| calls in `g`; a function missing from `g` is rejected.
hasTailCalls : SortedMap Name (SortedSet Name) -> List1 Name -> Bool
hasTailCalls g (fn ::: []) =
  case lookup fn g of
    Nothing      => False
    Just callees => contains fn callees
hasTailCalls _ (_ ::: (_ :: _)) = True
||| Turns one numbered component of the tail-call graph into a `TcGroup`.
|||
||| The component's names are numbered from 1 with `zipWithIndices`. Each
||| name is then looked up in `funMap` to obtain its arguments and body.
||| Names without an entry in `funMap` are dropped by `mapMaybe`, because
||| `fun` returns `Nothing` for them.
195 | toGroup : SortedMap Name (Name,List Name,NamedCExp)
196 | -> (Int,List1 Name)
198 | toGroup funMap (groupIndex,functions) =
199 | let ns = zipWithIndices $
forget functions
200 | in MkTcGroup groupIndex . fromList $
mapMaybe fun ns
201 | where fun : (Int,Name) -> Maybe (Name,TcFunction)
203 | (_,args,exp) <- lookup n funMap
204 | pure (n,MkTcFunction n fx args exp)
||| Finds the groups of (mutually) tail-recursive functions among the given
||| definitions `(name, args, body)`.
|||
||| The steps are:
|||
||| 1. Build a call graph from each definition's tail calls (`tailCalls`).
||| 2. Split the graph into components with `tarjan` (presumably Tarjan's
|||    strongly-connected-components algorithm).
||| 3. Keep only the components accepted by `hasTailCalls`.
||| 4. Number the surviving groups from 1, in the order `tarjan` returns
|||    them.
tailCallGroups : List (Name,List Name,NamedCExp) -> List TcGroup
tailCallGroups funs =
  let byName    = M.fromList $ map (\entry => (fst entry, entry)) funs
      callGraph = map (\(_, _, body) => tailCalls body) byName
      recursive = filter (hasTailCalls callGraph) (tarjan callGraph)
   in map (toGroup byName) (zipWithIndices recursive)
221 | record Function where
222 | constructor MkFunction
||| Name of the generated dispatch function for the tail-call group with the
||| given index.
|||
||| `convertTcGroup` defines this function with the single argument
||| `tcArgName` and a body that switches on it.
tcFunction : Int -> Name
tcFunction groupIndex = MN "$tcOpt" groupIndex
||| Name of the single argument of every generated dispatch function.
||| `convertTcGroup` switches on this argument to select the group member to run.
231 | tcArgName = MN "$a" 0
||| Constructor name meaning "continue with function `fi`" inside tail-call
||| group `gi`.
|||
||| The group index is embedded in the name string, so continuation
||| constructors of different groups get different names.
tcContinueName : (groupIndex : Int) -> (functionIndex : Int) -> Name
tcContinueName gi fi =
  let base = "TcContinue" ++ show gi
   in MN base fi
||| Constructor name used to wrap the final result of the tail-call group
||| with the given index.
|||
||| `conAlt` builds values of this constructor with tag 0.
tcDoneName : (groupIndex : Int) -> Name
tcDoneName = MN "TcDone"
||| Builds the case alternative that runs one member of a tail-call group.
|||
||| The alternative:
|||
||| * matches the group's `TcContinue` constructor whose tag is the
|||   function's index;
||| * binds the function's argument names;
||| * evaluates the function body with its tail positions rewritten by
|||   `toTc`.
245 | conAlt : TcGroup -> TcFunction -> NamedConAlt
246 | conAlt (MkTcGroup tcIx funs) (MkTcFunction n ix args exp) =
247 | let name = tcContinueName tcIx ix
248 | in MkNConAlt name DATACON (Just ix) args (toTc exp)
-- Wraps a final result in this group's `TcDone` constructor (tag 0).
256 | tcDone : NamedCExp -> NamedCExp
257 | tcDone x = NmCon EmptyFC (tcDoneName tcIx) DATACON (Just 0) [x]
-- Builds a `TcContinue` value selecting group member `ix` (also used as the
-- constructor tag) with the given argument expressions.
263 | tcContinue : (index : Int) -> List NamedCExp -> NamedCExp
264 | tcContinue ix = NmCon EmptyFC (tcContinueName tcIx ix) DATACON (Just ix)
-- Rewrites tail positions of a body.
-- A call to a member of this group (found in `funs`) becomes `tcContinue`
-- with that member's index; a call to any other function becomes `tcDone`.
-- `let` bodies and case branches are rewritten recursively; `NmCrash` is
-- kept as is.
-- NOTE(review): only the forms below are matched here. Confirm a
-- fall-through clause (e.g. `toTc x = tcDone x`) exists; otherwise other
-- tail expressions are not covered.
268 | toTc : NamedCExp -> NamedCExp
269 | toTc (NmLet fc x y z) = NmLet fc x y $
toTc z
270 | toTc x@(NmApp _ (NmRef _ n) xs) =
271 | case lookup n funs of
272 | Just v => tcContinue v.index xs
273 | Nothing => tcDone x
274 | toTc (NmConCase fc sc a d) = NmConCase fc sc (map con a) (map toTc d)
275 | toTc (NmConstCase fc sc a d) = NmConstCase fc sc (map const a) (map toTc d)
276 | toTc x@(NmCrash {}) = x
-- Applies `toTc` to the right-hand side of a constructor alternative.
279 | con : NamedConAlt -> NamedConAlt
280 | con (MkNConAlt x y tag xs z) = MkNConAlt x y tag xs (toTc z)
-- Applies `toTc` to the right-hand side of a constant alternative.
282 | const : NamedConstAlt -> NamedConstAlt
283 | const (MkNConstAlt x y) = MkNConstAlt x (toTc y)
||| Converts one tail-call group into plain functions.
|||
||| The result has two parts:
|||
||| * The first function is the group's dispatch function (`tcFunction
|||   gindex`). It takes the single argument `tcArgName` and switches on it,
|||   with no default branch. There is one `conAlt` branch per member,
|||   ordered by member index.
||| * It is followed by one replacement per member, keeping the member's
|||   original name and arguments. Each replacement's body calls `loop` with
|||   the dispatch function and an initial `TcContinue` value built from
|||   its arguments.
290 | convertTcGroup : (tailRecLoopName : Name) -> TcGroup -> List Function
291 | convertTcGroup loop g@(MkTcGroup gindex fs) =
292 | let functions = sortBy (comparing index) $
values fs
293 | branches = map (conAlt g) functions
294 | switch = NmConCase EmptyFC (local tcArgName) branches Nothing
295 | in MkFunction tcFun [tcArgName] switch :: map toFun functions
298 | tcFun = tcFunction gindex
300 | local : Name -> NamedCExp
301 | local = NmLocal EmptyFC
-- Builds the replacement for one member.
-- Inside the `let`, `tcFun` on the right-hand side is the dispatch function's
-- name bound above; the `let` shadows it with a reference expression to that
-- function.
303 | toFun : TcFunction -> Function
304 | toFun (MkTcFunction n ix args x) =
305 | let exps = map local args
306 | tcArg = NmCon EmptyFC (tcContinueName gindex ix)
307 | DATACON (Just ix) exps
308 | tcFun = NmRef EmptyFC tcFun
309 | body = NmApp EmptyFC (NmRef EmptyFC loop) [tcFun,tcArg]
310 | in MkFunction n args body
||| Combines the converted tail-call groups with the remaining definitions.
|||
||| The output of `convertTcGroup loop` for every group comes first. It is
||| followed by the definitions from `ts` kept by `toFun`, which wraps each
||| one unchanged in `MkFunction`.
|||
||| NOTE(review): the `then` branch of `toFun` (for names in the optimized
||| set `names`) is presumably `Nothing`, dropping definitions already
||| replaced by a converted group. Confirm this.
314 | tailRecOptim : List TcGroup
315 | -> (tcOptimized : SortedSet Name)
316 | -> (tcLoopName : Name)
317 | -> List (Name,List Name,NamedCExp)
319 | tailRecOptim groups names loop ts =
320 | let regular = mapMaybe toFun ts
321 | tailOpt = concatMap (convertTcGroup loop) groups
322 | in tailOpt ++ regular
-- Keeps a definition as a plain `Function` unless its name is in `names`.
324 | where toFun : (Name,List Name,NamedCExp) -> Maybe Function
325 | toFun (n,args,exp) =
326 | if contains n names
328 | else Just $
MkFunction n args exp
336 | functions : (tcLoopName : Name)
337 | -> List (Name,FC,NamedDef)
339 | functions loop dfs =
340 | let ts = mapMaybe def dfs
341 | groups = tailCallGroups ts
342 | names = SortedSet.fromList $
concatMap (keys . functions) groups
343 | in tailRecOptim groups names loop ts
344 | where def : (Name,FC,NamedDef) -> Maybe (Name,List Name,NamedCExp)
345 | def (n,_,MkNmFun args x) = Just (n,args,x)