0 | module IO.Async.Loop.Queue
  1 |
  2 | import Data.Linear.Ref1
  3 | import Data.Nat
  4 |
  5 | %default total
  6 |
  7 | export
  8 | inc : IORef Nat -> IO1 ()
  9 | inc r t =
 10 |   assert_total $
 11 |    let v    # t := read1 r t
 12 |        True # t := caswrite1 r v (S v) t | _ # t => inc r t
 13 |     in () # t
 14 |
 15 | export
 16 | dec : IORef Nat -> IO1 Bool
 17 | dec r t = assert_total $ let v # t := read1 r t in go v v t
 18 |   where
 19 |     go : Nat -> Nat -> IO1 Bool
 20 |     go x 0     t = False # t
 21 |     go x (S k) t =
 22 |       let True # t := caswrite1 r x k t | _ # t => dec r t
 23 |        in True # t
 24 |
 25 | --------------------------------------------------------------------------------
 26 | -- Task Queue and basic operations
 27 | --------------------------------------------------------------------------------
 28 |
 29 | ||| A specialize queue implementation enabling fast enqueue, dequeue,
 30 | ||| and work stealing.
 31 | export
 32 | record Queue a where
 33 |   constructor Q
 34 |   asleep : Bool
 35 |   head   : List a
 36 |   tail   : SnocList a
 37 |
 38 | export %inline
 39 | queueOf : (0 a : Type) -> Queue a
 40 | queueOf _ = Q False [] [<]
 41 |
 42 | export
 43 | isEmpty : Queue a -> Bool
 44 | isEmpty (Q _ [] [<]) = True
 45 | isEmpty _            = False
 46 |
 47 | export
 48 | enq : IORef (Queue a) -> a -> IO1 Bool
 49 | enq r v t = assert_total $ let q # t := read1 r t in go q q t
 50 |   where
 51 |     go : Queue a -> Queue a -> IO1 Bool
 52 |     go x (Q as [] [<]) t = case caswrite1 r x (Q as [v] [<]) t of
 53 |       True # t => as # t
 54 |       _    # t => enq r v t
 55 |     go x (Q as h tl) t = case caswrite1 r x (Q as h (tl:<v)) t of
 56 |       True # t => as # t
 57 |       _    # t => enq r v t
 58 |
 59 | export
 60 | enqall : IORef (Queue a) -> List a -> IO1 Bool
 61 | enqall r vs t = assert_total $ let q # t := read1 r t in go q q t
 62 |   where
 63 |     go : Queue a -> Queue a -> IO1 Bool
 64 |     go x (Q as [] [<]) t = case caswrite1 r x (Q as vs [<]) t of
 65 |       True # t => as # t
 66 |       _    # t => enqall r vs t
 67 |     go x (Q as h tl) t = case caswrite1 r x (Q as h (tl<><vs)) t of
 68 |       True # t => as # t
 69 |       _    # t => enqall r vs t
 70 |
 71 | export
 72 | deq : IORef (Queue a) -> IO1 (Maybe a)
 73 | deq r t = assert_total $ let q # t := read1 r t in go q q t
 74 |   where
 75 |     go : Queue a -> Queue a -> IO1 (Maybe a)
 76 |     go x (Q as h tl) t =
 77 |       case h of
 78 |         y::z => case caswrite1 r x (Q False z tl) t of
 79 |           True # t => Just y # t
 80 |           _    # t => deq r t
 81 |         []   => case tl <>> [] of
 82 |           y::z => case caswrite1 r x (Q False z [<]) t of
 83 |             True # t => Just y # t
 84 |             _    # t => deq r t
 85 |           []   => Nothing # t
 86 |
 87 | export
 88 | deqAndSleep : IORef (Queue a) -> IO1 (Maybe a)
 89 | deqAndSleep r t = assert_total $ let q # t := read1 r t in go q q t
 90 |   where
 91 |     go : Queue a -> Queue a -> IO1 (Maybe a)
 92 |     go x (Q as h tl) t =
 93 |       case h of
 94 |         y::z => case caswrite1 r x (Q False z tl) t of
 95 |           True # t => Just y # t
 96 |           _    # t => deq r t
 97 |         []   => case tl <>> [] of
 98 |           y::z => case caswrite1 r x (Q False z [<]) t of
 99 |             True # t => Just y # t
100 |             _    # t => deq r t
101 |           []   => case caswrite1 r x (Q True [] [<]) t of
102 |             True # t => Nothing # t
103 |             _    # t => deq r t
104 |
105 | --------------------------------------------------------------------------------
106 | -- Work stealing
107 | --------------------------------------------------------------------------------
108 |
109 | STEAL_MAX : Nat
110 | STEAL_MAX = 10
111 |
112 | -- we want to steal at most half + 1 tasks from the head of a
113 | -- queue, so this counts elements from the head two at a time.
114 | -- we also don't want to steal more than `STEAL_MAX` tasks
115 | count : List a -> Nat -> Nat
116 | count []         k     = 0
117 | count _          0     = 0
118 | count [_]        _     = 1
119 | count (_::_::xs) (S k) = S (count xs k)
120 |
121 | -- like `count` but for the tail
122 | countsl : SnocList a -> Nat -> Nat
123 | countsl [<]        k     = 0
124 | countsl _          0     = 0
125 | countsl [<_]       _     = 1
126 | countsl (sx:<_:<_) (S k) = S (countsl sx k)
127 |
128 | splitList : SnocList a -> Nat -> List a -> (List a, List a)
129 | splitList sx (S k) (x::xs) = splitList (sx:<x) k xs
130 | splitList sx _     xs      =(xs, sx <>> [])
131 |
132 | splitSnoc : SnocList a -> Nat -> List a -> (SnocList a, List a)
133 | splitSnoc (sx:<x) (S k) xs = splitSnoc sx k (x::xs)
134 | splitSnoc sx      _     xs = (sx, xs)
135 |
136 | splitHead : List a -> (List a, List a)
137 | splitHead xs = splitList [<] (count xs STEAL_MAX) xs 
138 |
139 | splitTail : Nat -> List a -> List a -> SnocList a -> (SnocList a, List a)
140 | splitTail (S n) res (_::t) (i:<v) = splitTail n (v::res) t i
141 | splitTail 0     res _      sx     = (sx, res)
142 | splitTail n     res _      [<]    = ([<], res)
143 | splitTail n     res []     sx     =
144 |   splitSnoc sx (countsl sx n) res
145 |
146 | ||| Steals up to `STEAL_MAX` tasks from a queue but not more than half
147 | ||| the enqueued tasks (rounded up).
148 | export
149 | steal : IORef (Queue a) -> IO1 (Maybe a)
150 | steal r t = assert_total $ let q # t := read1 r t in go q q t
151 |   where
152 |     go : Queue a -> Queue a -> IO1 (Maybe a)
153 |     go x (Q a hs ts) t =
154 |       case ts of
155 |         (rem:<v) => case caswrite1 r x (Q a hs rem) t of
156 |           True # t => Just v # t
157 |           _    # t => steal r t
158 |         [<] => case hs of
159 |           v::rem => case caswrite1 r x (Q a rem [<]) t of
160 |             True # t => Just v # t
161 |             _    # t => steal r t
162 |           [] => Nothing # t
163 |   -- where
164 |   --   go : Queue a -> Queue a -> IO1 (List a)
165 |   --   go x (Q _ [] [<]) t = [] # t
166 |   --   go x (Q a h  [<]) t =
167 |   --     let (h2,res) := splitHead h
168 |   --      in case caswrite1 r x (Q a h2 [<]) t of
169 |   --           True # t => res # t
170 |   --           _    # t => steal r t
171 |   --   go x (Q a h  tl)  t =
172 |   --     let (tl2,res) := splitTail STEAL_MAX [] h tl
173 |   --      in case caswrite1 r x (Q a h tl2) t of
174 |   --           True # t => res # t
175 |   --           _    # t => steal r t
176 |
177 | --------------------------------------------------------------------------------
178 | -- Tests and Proofs
179 | --------------------------------------------------------------------------------
180 |
181 | 0 count1 : (n : Nat) -> count [] n === 0
182 | count1 _ = Refl
183 |
184 | 0 count2 : (n : Nat) -> (xs : List a) -> LTE (count xs n) n
185 | count2 0 []            = LTEZero
186 | count2 0 (x :: xs)     = LTEZero
187 | count2 (S k) []        = LTEZero
188 | count2 (S k) [x]       = LTESucc LTEZero
189 | count2 (S k) (_::_::xs)= LTESucc (count2 k xs)
190 |
191 | 0 count3 : (n : Nat) -> (xs : List a) -> LTE (count xs n) (length xs)
192 | count3 0 []            = LTEZero
193 | count3 0 (x :: xs)     = LTEZero
194 | count3 (S k) []        = LTEZero
195 | count3 (S k) [x]       = LTESucc LTEZero
196 | count3 (S k) (_::_::xs)= lteSuccRight (LTESucc $ count3 k xs)
197 |
198 | 0 count4 : count [1,2,3,4,5] 2 === 2
199 | count4 = Refl
200 |
201 | 0 count5 : count [1,2,3,4] 2 === 2
202 | count5 = Refl
203 |
204 | 0 count6 : count [1,2,3] 2 === 2
205 | count6 = Refl
206 |
207 | 0 count7 : count [1,2] 2 === 1
208 | count7 = Refl
209 |
210 | 0 countsl1 : (n : Nat) -> countsl [<] n === 0
211 | countsl1 _ = Refl
212 |
213 | 0 countsl2 : (n : Nat) -> (sx : SnocList a) -> LTE (countsl sx n) n
214 | countsl2 0 [<]           = LTEZero
215 | countsl2 0 (sx :< x)     = LTEZero
216 | countsl2 (S k) [<]       = LTEZero
217 | countsl2 (S k) [<x]      = LTESucc LTEZero
218 | countsl2 (S k) (sx:<_:<_)= LTESucc (countsl2 k sx)
219 |
220 | 0 countsl3 : (n : Nat) -> (sx : SnocList a) -> LTE (countsl sx n) (length sx)
221 | countsl3 0 [<]           = LTEZero
222 | countsl3 0 (sx:<x)       = LTEZero
223 | countsl3 (S k) [<]       = LTEZero
224 | countsl3 (S k) [<x]      = LTESucc LTEZero
225 | countsl3 (S k) (sx:<_:<_)= lteSuccRight (LTESucc $ countsl3 k sx)
226 |
227 | 0 countsl4 : countsl [<1,2,3,4,5] 2 === 2
228 | countsl4 = Refl
229 |
230 | 0 countsl5 : countsl [<1,2,3,4] 2 === 2
231 | countsl5 = Refl
232 |
233 | 0 countsl6 : countsl [<1,2,3] 2 === 2
234 | countsl6 = Refl
235 |
236 | 0 countsl7 : countsl [<1,2] 2 === 1
237 | countsl7 = Refl
238 |