0 | module Data.HashMap.Array
  1 |
  2 | import Data.Fin
  3 | import Data.IOArray.Prims
  4 | import Data.List
  5 | import Data.HashMap.Bits
  6 |
  7 | %default total
  8 |
  9 | %inline
 10 | prim__newArray : Bits32 -> a -> PrimIO (ArrayData a)
 11 | prim__newArray len x = Prims.prim__newArray (unsafeCast len) x
 12 |
 13 | %inline
 14 | prim__arrayGet : ArrayData a -> Bits32 -> PrimIO a
 15 | prim__arrayGet arr i = Prims.prim__arrayGet arr (unsafeCast i)
 16 |
 17 | %inline
 18 | prim__arraySet : ArrayData a -> Bits32 -> a -> PrimIO ()
 19 | prim__arraySet arr i x = Prims.prim__arraySet arr (unsafeCast i) x
 20 |
 21 | %inline
 22 | unsafeNewArray : Bits32 -> IO (ArrayData a)
 23 | unsafeNewArray len = fromPrim $ prim__newArray len (believe_me ())
 24 |
 25 | %inline
 26 | unsafeInlinePerformIO : IO a -> a
 27 | unsafeInlinePerformIO act =
 28 |     let MkIORes res %MkWorld = toPrim act %MkWorld
 29 |      in res
 30 |
 31 | -- Copy from[fromStart..fromStop) to to[toStart..)
 32 | -- %spec f
 33 | copyFromArrayBy :
 34 |     (f : a -> b) ->
 35 |     (from : ArrayData a) ->
 36 |     (to : ArrayData b) ->
 37 |     (fromStart : Bits32) ->
 38 |     (toStart : Bits32) ->
 39 |     (fromStop : Bits32) ->
 40 |     IO ()
 41 | copyFromArrayBy f from to fromStart toStart fromStop = case fromStart `prim__lt_Bits32` fromStop of
 42 |     0 => pure ()
 43 |     _ => do
 44 |         val <- fromPrim $ prim__arrayGet from fromStart
 45 |         fromPrim $ prim__arraySet to toStart (f val)
 46 |         copyFromArrayBy f from to (assert_smaller fromStart $ unsafeIncr fromStart) (unsafeIncr toStart) fromStop
 47 |
 48 | -- Hand specialised version of copyFromArrayBy id
 49 | copyFromArray :
 50 |     (from : ArrayData a) ->
 51 |     (to : ArrayData a) ->
 52 |     (fromStart : Bits32) ->
 53 |     (toStart : Bits32) ->
 54 |     (fromStop : Bits32) ->
 55 |     IO ()
 56 | copyFromArray from to fromStart toStart fromStop = case fromStart `prim__lt_Bits32` fromStop of
 57 |     0 => pure ()
 58 |     _ => do
 59 |         val <- fromPrim $ prim__arrayGet from fromStart
 60 |         fromPrim $ prim__arraySet to toStart val
 61 |         copyFromArray from to (assert_smaller fromStart $ unsafeIncr fromStart) (unsafeIncr toStart) fromStop
 62 |
 63 |
 64 | copyFromList : ArrayData a -> List a -> Bits32 -> IO ()
 65 | copyFromList arr [] idx = pure ()
 66 | copyFromList arr (x :: xs) idx = do
 67 |     fromPrim $ prim__arraySet arr idx x
 68 |     copyFromList arr xs (idx + 1)
 69 |
 70 | export
 71 | data Array : Type -> Type where
 72 |     MkArray : (len : Bits32) -> (arr : ArrayData a) -> Array a
 73 |
 74 | %name Array arr
 75 |
 76 | export
 77 | empty : Array a
 78 | empty = MkArray 0 $ unsafeInlinePerformIO $ unsafeNewArray 0
 79 |
 80 | export
 81 | singleton : a -> Array a
 82 | singleton x = unsafeInlinePerformIO $ do
 83 |     arr <- fromPrim $ Prims.prim__newArray 1 x
 84 |     pure $ MkArray 1 arr
 85 |
 86 | export
 87 | fromList : (xs : List a) -> Array a
 88 | fromList [] = empty
 89 | fromList (x :: xs) = unsafeInlinePerformIO $ do
 90 |     let len = 1 + cast (length xs)
 91 |     arr <- fromPrim $ prim__newArray len x
 92 |     copyFromList arr xs 1
 93 |     pure $ MkArray len arr
 94 |
 95 | toListOnto : Array a -> List a -> List a
 96 | toListOnto (MkArray 0 _) acc = acc
 97 | toListOnto xs@(MkArray len arr) acc =
 98 |     let last = unsafeInlinePerformIO $ fromPrim $ prim__arrayGet arr (len - 1)
 99 |      in case len of
100 |         1 => last :: acc
101 |         _ => toListOnto (assert_smaller xs $ MkArray (len - 1) arr) (last :: acc)
102 |
103 | export %inline
104 | length : Array a -> Bits32
105 | length (MkArray len x) = len
106 |
107 | %inline
108 | unsafeIndex : Array a -> Bits32 -> a
109 | unsafeIndex (MkArray _ arr) idx = unsafeInlinePerformIO $ fromPrim $ prim__arrayGet arr idx
110 |
111 | export
112 | index : Array a -> Bits32 -> Maybe a
113 | index arr idx =
114 |     if 0 <= idx && idx < length arr
115 |         then Just $ unsafeIndex arr idx
116 |         else Nothing
117 |
118 | export
119 | update : Array a -> List (Bits32, a) -> Array a
120 | update arr [] = arr
121 | update (MkArray len arr) xs = unsafeInlinePerformIO $ do
122 |     arr' <- unsafeNewArray len
123 |     copyFromArray arr arr' 0 0 len
124 |     updateFromList arr' xs len
125 |     pure $ MkArray len arr'
126 |   where
127 |     updateFromList :
128 |         (arr : ArrayData a) ->
129 |         (upds : List (Bits32, a)) ->
130 |         (len : Bits32) ->
131 |         IO ()
132 |     updateFromList arr [] len = pure ()
133 |     updateFromList arr ((idx, val) :: xs) len = do
134 |         when (idx < len) $ fromPrim $ prim__arraySet arr idx val
135 |         updateFromList arr xs len
136 |
137 | export
138 | insert : (idx : Bits32) -> (val : a) -> Array a -> Array a
139 | insert idx val arr@(MkArray len orig) = if idx <= len
140 |     then unsafeInlinePerformIO $ do
141 |         new <- unsafeNewArray (len + 1)
142 |         copyFromArray orig new 0 0 idx
143 |         fromPrim $ prim__arraySet new idx val
144 |         copyFromArray orig new idx (idx + 1) len
145 |         pure $ MkArray (len + 1) new
146 |     else arr
147 |
148 | export
149 | delete : (idx : Bits32) -> Array a -> Array a
150 | delete idx arr@(MkArray len orig) =
151 |     if idx >= len
152 |         then arr
153 |     else if len <= 1
154 |         then empty
155 |     else unsafeInlinePerformIO $ do
156 |             new <- unsafeNewArray (len - 1)
157 |             -- orig: 0 .. idx, new: 0 .. idx
158 |             copyFromArray orig new 0 0 idx
159 |             -- orig: idx + 1 .. len, new: idx .. len - 1
160 |             copyFromArray orig new (idx + 1) idx len
161 |             pure $ MkArray (len - 1) new
162 |
163 | export
164 | findIndex : (a -> Bool) -> Array a -> Maybe Bits32
165 | findIndex f arr = go 0 (length arr)
166 |   where
167 |     go : Bits32 -> Bits32 -> Maybe Bits32
168 |     go i len =
169 |         if i >= len
170 |             then Nothing
171 |         else if f (unsafeIndex arr i)
172 |             then Just i
173 |         else go (assert_smaller i $ i + 1) len
174 |
175 | export
176 | findWithIndex : (a -> Bool) -> Array a -> Maybe (Bits32, a)
177 | findWithIndex f arr = go 0 (length arr)
178 |   where
179 |     go : Bits32 -> Bits32 -> Maybe (Bits32, a)
180 |     go i len =
181 |         if i >= len
182 |             then Nothing
183 |         else
184 |             let elem = unsafeIndex arr i
185 |             in if f elem then Just (i, elem) else go (assert_smaller i $ i + 1) len
186 |
187 | export
188 | append : (val : a) -> Array a -> Array a
189 | append val arr = insert (length arr) val arr
190 |
191 | export
192 | Functor Array where
193 |     map f (MkArray len arr) = MkArray len $
194 |         unsafeInlinePerformIO $ do
195 |             arr' <- unsafeNewArray len
196 |             copyFromArrayBy f arr arr' 0 0 len
197 |             pure arr'
198 |
199 | foldrImpl :
200 |     {0 elem : _} ->
201 |     (f : elem -> acc -> acc) ->
202 |     acc -> Bits32 -> Array elem ->
203 |     acc
204 | foldrImpl f z i arr = if i == 0
205 |     then f (unsafeIndex arr i) z
206 |     else
207 |         let elem = unsafeIndex arr i
208 |          in foldrImpl f (f elem z) (assert_smaller i $ i - 1) arr
209 |
210 | foldlImpl :
211 |     {0 elem : _} ->
212 |     (f : acc -> elem -> acc) ->
213 |     acc ->
214 |     (index : Bits32) ->
215 |     (length : Bits32) ->
216 |     Array elem ->
217 |     acc
218 | foldlImpl f z i len arr = if i >= len
219 |     then z
220 |     else
221 |         let elem = unsafeIndex arr i
222 |          in foldlImpl f z (assert_smaller i $ i + 1) len arr
223 |
224 | export
225 | Foldable Array where
226 |     foldr f z arr = foldrImpl f z (length arr - 1) arr
227 |
228 |     foldl f z arr = foldlImpl f z 0 (length arr) arr
229 |
230 |     null arr = length arr == 0
231 |     toList arr = toListOnto arr []
232 |     foldMap f arr = foldr (\elem, acc => f elem <+> acc) neutral arr
233 |
234 | export
235 | Show a => Show (Array a) where
236 |     show = show . toList
237 |
238 | parameters (pred : a -> b -> Bool)
239 |     allFrom : Bits32 -> Array a -> Array b -> Bool
240 |     allFrom idx arr1 arr2 = if length arr1 <= idx || length arr2 <= idx
241 |         then True
242 |         else
243 |             let x = index arr1 idx
244 |                 y = index arr2 idx
245 |              in fromMaybe False [| pred x y |]
246 |                 && allFrom (assert_smaller idx $ idx + 1) arr1 arr2
247 |
248 |     export
249 |     all : Array a -> Array b -> Bool
250 |     all = allFrom 0
251 |
252 | export
253 | Eq a => Eq (Array a) where
254 |     x == y = length x == length y && all (==) x y
255 |