0 | module Data.Graph.Indexed.Util.DisjointSet
 1 |
 2 | import Data.Array
 3 | import Data.SortedMap
 4 | import Data.Linear.Token
 5 | import Syntax.T1
 6 |
 7 | %default total
 8 |
 9 | data DSNode : Nat -> Type where
10 |   R : (size   : Nat) -> DSNode k -- root of set
11 |   N : (parent : Fin k) -> DSNode k -- child node of set
12 |
13 | ||| A simple [disjoint set](https://en.wikipedia.org/wiki/Disjoint-set_data_structure)
14 | ||| implementation.
15 | |||
16 | ||| This allows us to efficiently partition the values from 0 to `k-1`
17 | ||| into disjoint sets. Operations like `root`, `size`, and `union`
18 | ||| are de-facto amortized O(1).
19 | export
20 | record DisjointSet (s : Type) (k : Nat) where
21 |   constructor DSF
22 |   arr : MArray s k (DSNode k)
23 |
24 | ||| Allocates a fresh `DisjointSet` data type with all
25 | ||| values from `0` to `k-1` in their own partition.
26 | export
27 | ds : (k : Nat) -> F1 s (DisjointSet s k)
28 | ds k t = let m # t := marray1 k (R 1) t in DSF m # t
29 |
30 | dspair : DisjointSet s k -> Fin k -> F1 s (Fin k,Nat)
31 | dspair ds x t =
32 |   case get ds.arr x t of
33 |     R s # t => (x,s) # t
34 |     N p # t =>
35 |      let r # t := dspair ds (assert_smaller x p) t
36 |          _ # t := set ds.arr x (N $ fst r) t
37 |       in r # t
38 |
39 | ||| Returns the current root node of `x`, used for identifying
40 | ||| the partition, to which `x` currently belongs.
41 | export %inline
42 | root : DisjointSet s k -> (x : Fin k) -> F1 s (Fin k)
43 | root ds x t = let p # t := dspair ds x t in fst p # t
44 |
45 | ||| Returns the size of the partition `x` currently belongs to.
46 | export %inline
47 | size : DisjointSet s k -> (x : Fin k) -> F1 s Nat
48 | size ds x t = let p # t := dspair ds x t in snd p # t
49 |
50 | ||| Returns `True` if `x` and `y` belong to the same partition,
51 | ||| `False` otherwise.
52 | export
53 | samePartition : DisjointSet s k -> (x,y : Fin k) -> F1 s Bool
54 | samePartition ds x y t =
55 |  let rx # t := root ds x t
56 |      ry # t := root ds y t
57 |   in (rx == ry) # t
58 |
59 | ||| Computes the set union of the partitions of `x` and `y`.
60 | export
61 | union : DisjointSet s k -> (x,y : Fin k) -> F1' s
62 | union ds x y = T1.do
63 |   (rx,sx) <- dspair ds x
64 |   (ry,sy) <- dspair ds y
65 |   case rx == ry of
66 |     True  => pure ()
67 |     False => case sx > sy of
68 |       True  => set ds.arr rx (R $ sx + sy) >> set ds.arr ry (N rx)
69 |       False => set ds.arr ry (R $ sx + sy) >> set ds.arr rx (N ry)
70 |
71 | export
72 | sets : {k : _} -> DisjointSet s k -> F1 s (List $ List (Fin k))
73 | sets ds = go empty k
74 |   where
75 |     go :
76 |          SortedMap (Fin k) (List $ Fin k)
77 |       -> (n : Nat)
78 |       -> {auto 0 lte : LTE n k}
79 |       -> F1 s (List $ List (Fin k))
80 |     go m 0     t = values m # t
81 |     go m (S j) t =
82 |      let x     := natToFinLT j
83 |          r # t := root ds x t
84 |       in go (update (Just . maybe [x] (x::)) r m) j t
85 |