0 | module Libraries.Data.NatSet
  1 |
  2 | import Data.Bits
  3 |
  4 | %default total
  5 |
  6 | export
  7 | NatSet : Type
  8 | NatSet = Integer
  9 |
 10 | export %inline
 11 | empty : NatSet
 12 | empty = 0
 13 |
 14 | export %inline
 15 | elem : Nat -> NatSet -> Bool
 16 | elem = flip testBit
 17 |
 18 | export
 19 | drop : NatSet -> List a -> List a
 20 | drop 0  xs = xs
 21 | drop ds xs = go 0 xs
 22 |   where
 23 |     go : Nat -> List a -> List a
 24 |     go _ [] = []
 25 |     go i (x :: xs)
 26 |         = if i `elem` ds
 27 |              then go (S i) xs
 28 |              else x :: go (S i) xs
 29 |
 30 | export %inline
 31 | take : NatSet -> List a -> List a
 32 | take = drop . complement
 33 |
 34 | export
 35 | isEmpty : NatSet -> Bool
 36 | isEmpty 0 = True
 37 | isEmpty _ = False
 38 |
 39 | export
 40 | size : NatSet -> Nat
 41 | size = go 0
 42 |   where
 43 |     go : Nat -> NatSet -> Nat
 44 |     go acc 0 = acc
 45 |     go acc n =
 46 |       -- cast is modulo, i.e. takes the lower bits
 47 |       let acc = acc + popCount (the Int64 (cast n)) in
 48 |       go acc (assert_smaller n (shiftR n 64))
 49 |
 50 | export %inline
 51 | Cast NatSet Integer where
 52 |   cast ns = ns
 53 |
 54 | export %inline
 55 | Cast Integer NatSet where
 56 |   cast n = n
 57 |
 58 | export
 59 | insert : Nat -> NatSet -> NatSet
 60 | insert = flip setBit
 61 |
 62 | export
 63 | delete : Nat -> NatSet -> NatSet
 64 | delete = flip clearBit
 65 |
 66 | export
 67 | toList : NatSet -> List Nat
 68 | toList = go 0
 69 |   where
 70 |     go : Nat -> NatSet -> List Nat
 71 |     go i 0 = []
 72 |     go i ns =
 73 |        let is = go (S i) (assert_smaller ns (shiftR ns 1)) in
 74 |        if 0 `elem` ns then i :: is else is
 75 |
 76 | fromList : List Nat -> NatSet
 77 | fromList = foldr insert empty
 78 |
 79 | export
 80 | Show NatSet where
 81 |   show ns = show (toList ns)
 82 |
 83 | export
 84 | partition : NatSet -> List a -> (List a , List a)
 85 | partition ps = go 0
 86 |   where
 87 |     go : Nat -> List a -> (List a , List a)
 88 |     go i [] = ([], [])
 89 |     go i (x :: xs)
 90 |       = let xys = go (S i) xs in
 91 |         if i `elem` ps
 92 |            then mapFst (x ::) xys
 93 |            else mapSnd (x ::) xys
 94 |
 95 | export
 96 | intersection : NatSet -> NatSet -> NatSet
 97 | intersection = (.&.)
 98 |
 99 | export
100 | union : NatSet -> NatSet -> NatSet
101 | union = (.|.)
102 |
103 | export
104 | intersectAll : List NatSet -> NatSet
105 | intersectAll [] = empty
106 | intersectAll (x::xs) = foldr intersection x xs
107 |
108 | export
109 | allLessThan : Nat -> NatSet
110 | allLessThan n = shiftL 1 n - 1
111 |
112 | 0 allLessThanSpecEmpty : toList (allLessThan 0) === []
113 | allLessThanSpecEmpty = Refl
114 |
115 | 0 allLessThanSpecNonEmpty : toList (allLessThan 10) === [0..9]
116 | allLessThanSpecNonEmpty = Refl
117 |
118 | export
119 | overwrite : a -> NatSet -> List a -> List a
120 | overwrite c 0  xs = xs
121 | overwrite c ds xs = go 0 xs
122 |   where
123 |     go : Nat -> List a -> List a
124 |     go _ [] = []
125 |     go i (x :: xs)
126 |         = if i `elem` ds
127 |              then c :: go (S i) xs
128 |              else x :: go (S i) xs
129 |
130 |
131 |
132 | -- Pop the zero (whether or not in the set) and shift all the
133 | -- other positions by -1 (useful when coming back from under
134 | -- a binder)
135 | export %inline
136 | popZ : NatSet -> NatSet
137 | popZ = (`shiftR` 1)
138 |
139 | export %inline
140 | popNs : Nat -> NatSet -> NatSet
141 | popNs = flip shiftR
142 |
143 | -- Add a 'new' Zero (not in the set) and shift all the
144 | -- other positions by +1 (useful when going under a binder)
145 | export %inline
146 | addZ : NatSet -> NatSet
147 | addZ = (`shiftL` 1)
148 |