0 | module Control.MonadRec
  1 |
  2 | import public Control.WellFounded
  3 | import Control.Monad.Either
  4 | import Control.Monad.Identity
  5 | import Control.Monad.Maybe
  6 | import Control.Monad.Reader
  7 | import Control.Monad.RWS
  8 | import Control.Monad.State
  9 | import Control.Monad.Writer
 10 |
 11 | import Data.List
 12 | import Data.SnocList
 13 | import public Data.Fuel
 14 | import public Data.Nat
 15 |
 16 | %default total
 17 |
 18 | --------------------------------------------------------------------------------
 19 | --          Sized Implementations
 20 | --------------------------------------------------------------------------------
 21 |
 22 | public export
 23 | Sized Fuel where
 24 |   size Dry      = 0
 25 |   size (More f) = S $ size f
 26 |
 27 | public export
 28 | Sized (SnocList a) where
 29 |   size = length
 30 |
 31 | --------------------------------------------------------------------------------
 32 | --          Step
 33 | --------------------------------------------------------------------------------
 34 |
 35 | ||| Single step in a recursive computation.
 36 | |||
 37 | ||| A `Step` is either `Done`, in which case we return the
 38 | ||| final result, or `Cont`, in which case we continue
 39 | ||| iterating. In case of a `Cont`, we get a new seed for
 40 | ||| the next iteration plus an updated state. In addition
 41 | ||| we proof that the sequence of seeds is related via `rel`.
 42 | ||| If `rel` is well-founded, the recursion will provably
 43 | ||| come to an end in a finite number of steps.
 44 | public export
 45 | data Step :  (rel   : a -> a -> Type)
 46 |           -> (seed  : a)
 47 |           -> (accum : Type)
 48 |           -> (res   : Type)
 49 |           -> Type where
 50 |
 51 |   ||| Keep iterating with a new `seed2`, which is
 52 |   ||| related to the current `seed` via `rel`.
 53 |   ||| `vst` is the accumulated state of the iteration.
 54 |   Cont :  (seed2 : a)
 55 |        -> (0 prf : rel seed2 seed)
 56 |        -> (vst   : st)
 57 |        -> Step rel seed st res
 58 |
 59 |   ||| Stop iterating and return the given result.
 60 |   Done : (vres : res) -> Step rel v st res
 61 |
 62 | public export
 63 | Bifunctor (Step rel seed) where
 64 |   bimap f _ (Cont s2 prf st) = Cont s2 prf (f st)
 65 |   bimap _ g (Done res)       = Done (g res)
 66 |
 67 |   mapFst f (Cont s2 prf st) = Cont s2 prf (f st)
 68 |   mapFst _ (Done res)       = Done res
 69 |
 70 |   mapSnd _ (Cont s2 prf st) = Cont s2 prf st
 71 |   mapSnd g (Done res)       = Done (g res)
 72 |
 73 | --------------------------------------------------------------------------------
 74 | --          MonadRec
 75 | --------------------------------------------------------------------------------
 76 |
 77 | ||| Interface for tail-call optimized monadic recursion.
 78 | public export
 79 | interface Monad m => MonadRec m where
 80 |   ||| Implementers must make sure they implement this function
 81 |   ||| in a tail recursive manner.
 82 |   ||| The general idea is to loop using the given `step` function
 83 |   ||| until it returns a `Done`.
 84 |   |||
 85 |   ||| To convey to the totality checker that the sequence
 86 |   ||| of seeds generated during recursion must come to an
 87 |   ||| end after a finite number of steps, this function
 88 |   ||| requires an erased proof of accessibility.
 89 |   total
 90 |   tailRecM :
 91 |        {0 rel : a -> a -> Type}
 92 |     -> (seed  : a)
 93 |     -> (0 prf : Accessible rel seed)
 94 |     -> (ini   : st)
 95 |     -> (step  : (seed2 : a) -> st -> m (Step rel seed2 st b))
 96 |     -> m b
 97 |
 98 | ||| Monadic tail recursion over a well-founded structure.
 99 | public export %inline
100 | trWellFounded :
101 |      {auto _ : MonadRec m}
102 |   -> {auto 0 _ : WellFounded a rel}
103 |   -> (seed  : a)
104 |   -> (ini   : st)
105 |   -> (step  : (seed2 : a) -> st -> m (Step rel seed2 st b))
106 |   -> m b
107 | trWellFounded seed = tailRecM seed (wellFounded seed)
108 |
109 | public export %inline
110 | ||| Monadic tail recursion over a sized structure.
111 | trSized :
112 |      {auto _ : MonadRec m}
113 |   -> {auto 0 _ : Sized a}
114 |   -> (seed : a)
115 |   -> (ini  : st)
116 |   -> (step : (v : a) -> st -> m (Step Smaller v st b))
117 |   -> m b
118 | trSized x ini = tailRecM x (sizeAccessible x) ini
119 |
120 | ||| This is NOT a tail-recursive implementation, allowing any monad be used
121 | ||| with the same API as if it is tail-recursive. Avoid using it at all costs!
122 | export
123 | [NonStackSafe] Monad m => MonadRec m where
124 |   tailRecM seed (Access acc) init step = step seed init >>= \case
125 |     Cont seed2 prf vst => tailRecM seed2 (acc seed2 prf) vst step
126 |     Done vres          => pure vres
127 |
128 | --------------------------------------------------------------------------------
129 | --          Base Implementations
130 | --------------------------------------------------------------------------------
131 |
132 | public export
133 | MonadRec Identity where
134 |   tailRecM seed (Access rec) st1 f = case f seed st1 of
135 |     Id (Done b)         => Id b
136 |     Id (Cont y prf st2) => tailRecM y (rec y prf) st2 f
137 |
138 | public export
139 | MonadRec Maybe where
140 |   tailRecM seed (Access rec) st1 f = case f seed st1 of
141 |     Nothing               => Nothing
142 |     Just (Done b)         => Just b
143 |     Just (Cont y prf st2) => tailRecM y (rec y prf) st2 f
144 |
145 | public export
146 | MonadRec (Either e) where
147 |   tailRecM seed (Access rec) st1 f = case f seed st1 of
148 |     Left e                 => Left e
149 |     Right (Done b)         => Right b
150 |     Right (Cont y prf st2) => tailRecM y (rec y prf) st2 f
151 |
152 | trIO :
153 |      (x : a)
154 |   -> (0 _ : Accessible rel x)
155 |   -> (ini : st)
156 |   -> (f : (v : a) -> st -> IO (Step rel v st b))
157 |   -> IO b
158 | trIO x acc ini f = fromPrim $ run x acc ini
159 |
160 |   where
161 |     run :
162 |          (y : a)
163 |       -> (0 _ : Accessible rel y)
164 |       -> (st1 : st)
165 |       -> (1 w : %World)
166 |       -> IORes b
167 |     run y (Access rec) st1 w = case toPrim (f y st1) w of
168 |       MkIORes (Done b) w2          => MkIORes b w2
169 |       MkIORes (Cont y2 prf st2) w2 => run y2 (rec y2 prf) st2 w2
170 |
171 | public export %inline
172 | MonadRec IO where
173 |   tailRecM = trIO
174 |
175 | --------------------------------------------------------------------------------
176 | --          Transformer Implementations
177 | --------------------------------------------------------------------------------
178 |
179 | ---------------------------
180 | -- StateT
181 |
182 | %inline
183 | convST :
184 |      {auto _ : Functor m}
185 |   -> (f : (v : a) -> st -> StateT s m (Step rel v st b))
186 |   -> (v : a)
187 |   -> (st,s)
188 |   -> m (Step rel v (st,s) (s,b))
189 | convST f v (st1,s1) =
190 |   (\(s2,stp) => bimap (,s2) (s2,) stp) <$> runStateT s1 (f v st1)
191 |
192 | public export
193 | MonadRec m => MonadRec (StateT s m) where
194 |   tailRecM x acc ini f =
195 |     ST $ \s1 => tailRecM x acc (ini,s1) (convST f)
196 |
197 | ---------------------------
198 | -- EitherT
199 |
200 | convE :
201 |      {auto _ : Functor m}
202 |   -> (f : (v : a) -> st -> EitherT e m (Step rel v st b))
203 |   -> (v : a)
204 |   -> (ini : st)
205 |   -> m (Step rel v st (Either e b))
206 | convE f v s1 = map conv $ runEitherT (f v s1)
207 |
208 |   where
209 |     conv : Either e (Step rel v st b) -> Step rel v st (Either e b)
210 |     conv (Left err)                = Done (Left err)
211 |     conv (Right $ Done b)          = Done (Right b)
212 |     conv (Right $ Cont v2 prf st2) = Cont v2 prf st2
213 |
214 | public export
215 | MonadRec m => MonadRec (EitherT e m) where
216 |   tailRecM x acc ini f =
217 |     MkEitherT $ tailRecM x acc ini (convE f)
218 |
219 | ---------------------------
220 | -- MaybeT
221 |
222 | convM :
223 |      {auto _ : Functor m}
224 |   -> (f : (v : a) -> st -> MaybeT m (Step rel v st b))
225 |   -> (v : a)
226 |   -> (ini : st)
227 |   -> m (Step rel v st (Maybe b))
228 | convM f v s1 = map conv $ runMaybeT (f v s1)
229 |
230 |   where
231 |     conv : Maybe (Step rel v st b) -> Step rel v st (Maybe b)
232 |     conv Nothing                  = Done Nothing
233 |     conv (Just $ Done b)          = Done (Just b)
234 |     conv (Just $ Cont v2 prf st2) = Cont v2 prf st2
235 |
236 | public export
237 | MonadRec m => MonadRec (MaybeT m) where
238 |   tailRecM x acc ini f =
239 |     MkMaybeT $ tailRecM x acc ini (convM f)
240 |
241 | ---------------------------
242 | -- ReaderT
243 |
244 | convR :
245 |      (f : (v : a) -> st -> ReaderT e m (Step rel v st b))
246 |   -> (env : e)
247 |   -> (v : a)
248 |   -> (ini : st)
249 |   -> m (Step rel v st b)
250 | convR f env v s1 = runReaderT env (f v s1)
251 |
252 | public export
253 | MonadRec m => MonadRec (ReaderT e m) where
254 |   tailRecM x acc ini f =
255 |     MkReaderT $ \env => tailRecM x acc ini (convR f env)
256 |
257 | ---------------------------
258 | -- WriterT
259 |
260 | convW :
261 |      {auto _ : Functor m}
262 |   -> (f : (v : a) -> st -> WriterT w m (Step rel v st b))
263 |   -> (v : a)
264 |   -> (st,w)
265 |   -> m (Step rel v (st,w) (b,w))
266 | convW f v (s1,w1) =
267 |   (\(stp,w2) => bimap (,w2) (,w2) stp) <$> unWriterT (f v s1) w1
268 |
269 | public export
270 | MonadRec m => MonadRec (WriterT w m) where
271 |   tailRecM x acc ini f =
272 |     MkWriterT $ \w1 => tailRecM x acc (ini,w1) (convW f)
273 |
274 | ---------------------------
275 | -- RWST
276 |
277 | convRWS :
278 |      {auto _ : Functor m}
279 |   -> (f : (v : a) -> st -> RWST r w s m (Step rel v st b))
280 |   -> (env : r)
281 |   -> (v : a)
282 |   -> (st,s,w)
283 |   -> m (Step rel v (st,s,w) (b,s,w))
284 | convRWS f env v (st1,s1,w1) =
285 |   (\(stp,s2,w2) => bimap (,s2,w2) (,s2,w2) stp) <$> unRWST (f v st1) env s1 w1
286 |
287 | public export
288 | MonadRec m => MonadRec (RWST r w s m) where
289 |   tailRecM x acc ini f =
290 |     MkRWST $ \r1,s1,w1 => tailRecM x acc (ini,s1,w1) (convRWS f r1)
291 |