0 | module System.Random.Pure
  1 |
  2 | import Control.Monad.State.Interface
  3 |
  4 | import Data.Bits
  5 | import Data.Fin
  6 |
  7 | %default total
  8 |
  9 | --------------------------------
 10 | --- Interface for seed types ---
 11 | --------------------------------
 12 |
 13 | ||| Interface of algorithms of pseudo-random generation of `Bits64` values using a seed type `g`.
 14 | |||
 15 | ||| Those `Bits64` values must be generated uniformly.
 16 | ||| Splitting must give independent seeds.
 17 | public export
 18 | interface RandomGen g where
 19 |   next    : g -> (g, Bits64)
 20 |   split   : g -> (g, g)
 21 |   variant : Nat -> g -> g
 22 |
 23 | ||| Stream of independent seeds
 24 | public export
 25 | splitStream : RandomGen g => g -> Stream g
 26 | splitStream seed = do
 27 |   let (l, r) = split seed
 28 |   l :: splitStream r
 29 |
 30 | ||| Interface for getting the starting seed
 31 | public export
 32 | interface RandomGen g => CanInitSeed g m | m where
 33 |   initSeed : m g
 34 |
 35 | export %inline
 36 | ConstSeed : Applicative m => RandomGen g => g -> CanInitSeed g m
 37 | ConstSeed seed = S where [S] CanInitSeed g m where initSeed = pure seed
 38 |
 39 | ||| Stream of independent seeds in the current context of `CanInitSeed`
 40 | public export %inline
 41 | theSplitStream : CanInitSeed g m => Functor m => m $ Stream g
 42 | theSplitStream = splitStream <$> initSeed
 43 |
 44 | --------------------------------------------------------
 45 | --- Types for which values can be randomly generated ---
 46 | --------------------------------------------------------
 47 |
 48 | ||| Interface for generation of values of particular types using any appropriate `RandomGen` algorithm.
 49 | |||
 50 | ||| Contains a function for generation of uniform value sitting in the given range,
 51 | ||| and a function for generation of any value of the type.
 52 | ||| For inifinite types, when no range is given, implementation determines actual used range by itself.
 53 | public export
 54 | interface Random a where
 55 |   randomR : RandomGen g => (a, a) -> g -> (g, a)
 56 |   random  : RandomGen g => g -> (g, a)
 57 |
 58 | public export %inline
 59 | randomFor : RandomGen g => (0 a : _) -> Random a => g -> (g, a)
 60 | randomFor _ = random
 61 |
 62 | export
 63 | randomR' : Random a => RandomGen g => MonadState g m => (a, a) -> m a
 64 | randomR' bounds = let (g, x) = randomR bounds !get in put g $> x
 65 |
 66 | export
 67 | random' : Random a => RandomGen g => MonadState g m => m a
 68 | random' = let (g, x) = random !get in put g $> x
 69 |
 70 | public export %inline
 71 | randomFor' : RandomGen g => MonadState g m => (0 a : _) -> Random a => m a
 72 | randomFor' _ = random'
 73 |
 74 | export
 75 | randomThru : (0 thru : _) -> Random thru => (from : thru -> a) -> (to : a -> thru) -> Random a
 76 | randomThru thru from to = RandomThru where
 77 |   [RandomThru] Random a where
 78 |     randomR = map from .: randomR {a=thru} . mapHom to
 79 |     random  = map from . random {a=thru}
 80 |
 81 | --- Patricular implementations ---
 82 |
 83 | maxMask : Bits64 -> Bits64
 84 | maxMask max = case countLeadingZeros max of
 85 |                 Nothing  => zeroBits
 86 |                 Just off => oneBits `shiftR` off
 87 |   where
 88 |     countLeadingZeros : Bits64 -> Maybe $ Fin 64
 89 |     countLeadingZeros x = go 63 where
 90 |       go : Fin 64 -> Maybe $ Fin 64
 91 |       go i = if testBit x i then Just $ complement i else case i of
 92 |                FZ    => Nothing
 93 |                FS i' => go $ assert_smaller i $ weaken i'
 94 |
 95 | export
 96 | Random Bits64 where
 97 |   random = next
 98 |   randomR (lo, hi) = do
 99 |     let (lo, hi) = (lo `min` hi, lo `max` hi)
100 |     map (lo +) . nextMax (hi - lo) where
101 |
102 |       nextMax : (max : Bits64) -> g -> (g, Bits64)
103 |       nextMax max = assert_total go where
104 |
105 |         mask : Bits64
106 |         mask = maxMask max
107 |
108 |         covering
109 |         go : g -> (g, Bits64)
110 |         go g = do
111 |           let (g', x) = next g
112 |               x' = x .&. mask
113 |           if x' > max then go g' else (g', x')
114 |
115 | export %hint
116 | RandomInt64 : Random Int64
117 | RandomInt64 = randomThru Bits64 (fromInteger . (\x => x - diff) . cast) (fromInteger . (+ diff) . cast) where
118 |   diff : Integer
119 |   diff = 1 `shiftL` 63
120 |
121 | export %hint
122 | RandomBits32 : Random Bits32
123 | RandomBits32 = randomThru Bits64 cast cast
124 |
125 | export %hint
126 | RandomInt32 : Random Int32
127 | RandomInt32 = randomThru Int64 cast cast
128 |
129 | export %hint
130 | RandomBits16 : Random Bits16
131 | RandomBits16 = randomThru Bits64 cast cast
132 |
133 | export %hint
134 | RandomInt16 : Random Int16
135 | RandomInt16 = randomThru Int64 cast cast
136 |
137 | export %hint
138 | RandomBits8 : Random Bits8
139 | RandomBits8 = randomThru Bits64 cast cast
140 |
141 | export %hint
142 | RandomInt8 : Random Int8
143 | RandomInt8 = randomThru Int64 cast cast
144 |
145 | export %hint
146 | RandomInt : Random Int
147 | RandomInt = randomThru Int64 cast cast
148 |
149 | two64 : Integer
150 | two64 = 1 `shiftL` 64
151 |
152 | export
153 | Random Integer where
154 |   random = randomR (-two64, two64) -- This is more or less arbitrary anyway
155 |   randomR (lo, hi) = do
156 |     let (lo, hi) = (lo `min` hi, lo `max` hi)
157 |     map (lo +) . nextMax (hi - lo) where
158 |
159 |       nextMax : Integer -> g -> (g, Integer)
160 |       nextMax max = do
161 |         let goMask : Nat -> Integer -> (Nat, Bits64)
162 |             goMask n x = if x < two64
163 |               then (n, maxMask $ cast x)
164 |               else goMask (S n) (assert_smaller x $ x `shiftR` 64)
165 |
166 |         let (restDigits, leadMask) = goMask 0 max
167 |         let generate : g -> (g, Integer)
168 |             generate g0 = do
169 |               let (g', x) = next g0
170 |                   x' = x .&. leadMask
171 |               go (cast x') restDigits g'
172 |               where
173 |                 go : Integer -> Nat -> g -> (g, Integer)
174 |                 go acc Z     g = (g, acc)
175 |                 go acc (S n) g =
176 |                     let (g', x) = next g
177 |                     in go (acc * two64 + cast x) n g'
178 |
179 |         let covering loop : g -> (g, Integer)
180 |             loop g = do
181 |               let (g', x) = generate g
182 |               if x > max then loop g' else (g', x)
183 |
184 |         assert_total loop
185 |
186 | export %hint
187 | RandomNat : Random Nat
188 | RandomNat = randomThru Integer (cast . abs) cast
189 |
190 | export
191 | Random Double where
192 |   random = map (\w64 => cast (w64 `shiftR` 11) * doubleULP) . next where
193 |     doubleULP : Double
194 |     doubleULP =  1.0 / cast {from=Bits64} (1 `shiftL` 53)
195 |   randomR (lo, hi) =
196 |     if lo == hi then map (const lo) . next
197 |     else if lo == 1/0 || lo == -1/0 || hi == 1/0 || hi == -1/0 then map (const $ lo + hi) . next
198 |     else map (\x => x * lo + (1 - x) * hi) . random
199 |
200 | export
201 | Random Unit where
202 |   randomR ((), ()) gen = map (const ()) $ next gen
203 |   random gen = map (const ()) $ next gen
204 |
205 | export
206 | {n : Nat} -> Random (Fin $ S n) where
207 |   random  = map (\x => natToFinLt x @{believe_me Oh}) . randomR (0, n)
208 |   randomR = map (\x => natToFinLt x @{believe_me Oh}) .: randomR . mapHom finToNat
209 |
210 | export %hint
211 | RandomBool : Random Bool
212 | RandomBool = randomThru Bits64 (\x => testBit x 0) (\b => if b then 1 else 0)
213 |
214 | export
215 | Random Char where
216 |   random  = map cast . randomR {a=Bits64} (0, 0xfffff+0xffff+1)
217 |   randomR = map cast .: randomR {a=Bits64} . mapHom cast
218 |