0 | module IO.Async.Semaphore
 1 |
 2 | import Data.Linear.Ref1
 3 | import Data.Nat
 4 | import IO.Async.Loop
 5 | import IO.Async.Core
 6 |
 7 | %default total
 8 |
 9 | -- internal state of a `Semaphore` value
10 | data ST : Type where
11 |   Available : (available : Nat) -> ST
12 |   Requested : (requested : Nat) -> (cb : IO1 ()) -> ST
13 |
14 | ||| A semphore is a synchronization primitive that can
15 | ||| be observed by at most one observer.
16 | |||
17 | ||| It consists of an internal counter that is atomically
18 | ||| reduced every time `release` is invoked.
19 | |||
20 | ||| Calling `await` blocks the calling fiber until the
21 | ||| semaphore's counter has been reduced to 0.
22 | export
23 | record Semaphore where
24 |   constructor S
25 |   ref : IORef ST
26 |
27 | ||| Creates a new semaphore with an internal counter of `n`.
28 | export %inline
29 | semaphore : Lift1 World f => (n : Nat) -> f Semaphore
30 | semaphore n = S <$> newref (Available n)
31 |
32 | unobs : IORef ST -> IO1 ()
33 | unobs r t = assert_total $ let x # t := read1 r t in go x x t
34 |   where
35 |     go : ST -> ST -> IO1 ()
36 |     go x (Available n) t = () # t
37 |     go x _             t =
38 |       case caswrite1 r x (Available 0) t of
39 |         True # t => () # t
40 |         _    # t => unobs r t
41 |
42 | rel : IORef ST -> Nat -> IO1 ()
43 | rel r add t = assert_total $ let x # t := read1 r t in go x x t
44 |   where
45 |     go : ST -> ST -> IO1 ()
46 |     go x (Requested n cb) t =
47 |       case n `minus` add of
48 |         0 => case caswrite1 r x (Available (add `minus` n)) t of
49 |           True # t => cb t
50 |           _    # t => rel r add t
51 |         k => case caswrite1 r x (Requested k cb) t of
52 |           True # t => () # t
53 |           _    # t => rel r add t
54 |     go x (Available n) t =
55 |       case caswrite1 r x (Available $ n + add) t of
56 |         True # t => () # t
57 |         _    # t => rel r add t
58 |
59 | ||| Atomically reduces the internal counter of the semaphore by the
60 | ||| given number.
61 | export
62 | releaseN : HasIO io => Semaphore -> Nat -> io ()
63 | releaseN _       0   = pure ()
64 | releaseN (S ref) add = runIO (rel ref add)
65 |
66 | ||| Atomically reduces the internal counter of the semaphore by one.
67 | export %inline
68 | release : HasIO io => Semaphore -> io ()
69 | release s = releaseN s 1
70 |
71 | acq : IORef ST -> Nat -> IO1 () -> IO1 (IO1 ())
72 | acq r req cb t = assert_total $ let x # t := read1 r t in go x x t
73 |   where
74 |     go : ST -> ST -> IO1 (IO1 ())
75 |     go x (Available n) t =
76 |       case n >= req of
77 |         True  => case caswrite1 r x (Available (n `minus` req)) t of
78 |           True # t => let _ # t := cb t in unit1 # t
79 |           _    # t => acq r req cb t
80 |         False => case caswrite1 r x (Requested (req `minus` n) cb) t of
81 |           True # t => unobs r # t
82 |           _    # t => acq r req cb t
83 |     go x _ t = unit1 # t
84 |
85 | ||| Waits (possibly by semantically blocking the fiber)
86 | ||| until the given number of steps have been released.
87 | |||
88 | ||| Note: Currently, `Semaphore` values can only be observed
89 | |||       by one observer. The calling fiber will block until
90 | |||       canceled, if another fiber has called `await` already.
91 | export
92 | acquireN : Semaphore -> Nat -> Async e es ()
93 | acquireN (S ref) req = primAsync $ \cb => acq ref req (cb (Right ()))
94 |
95 | ||| Alias for `acquireN 1`
96 | export %inline
97 | acquire : Semaphore -> Async e es ()
98 | acquire s = acquireN s 1
99 |