0 | ||| Finite Rose Trees
  1 | module Data.Tree
  2 |
  3 | import Data.List
  4 | import Data.String
  5 |
  6 | import Derive.Prelude
  7 |
  8 | %language ElabReflection
  9 | %default total
 10 |
 11 | --------------------------------------------------------------------------------
 12 | --          Finite Trees
 13 | --------------------------------------------------------------------------------
 14 |
 15 | ||| A finite rose tree
 16 | public export
 17 | record Tree (a : Type) where
 18 |   constructor T
 19 |   value : a
 20 |   forest : List (Tree a)
 21 |
 22 | ||| A finite forest of trees
 23 | public export
 24 | Forest : Type -> Type
 25 | Forest = List . Tree
 26 |
 27 | %runElab derive "Tree" [Show,Eq]
 28 |
 29 | --------------------------------------------------------------------------------
 30 | --          Creating Trees
 31 | --------------------------------------------------------------------------------
 32 |
 33 | ||| Wrap a single value in a tree
 34 | public export
 35 | singleton : a -> Tree a
 36 | singleton a = T a []
 37 |
 38 | ||| Create a regular tree of the given depth, where each branch has
 39 | ||| `width` children.
 40 | public export
 41 | replicate : (width : Nat) -> (depth : Nat) -> a -> Tree a
 42 | replicate _         0 x = T x []
 43 | replicate width (S k) x = T x $ replicate width (replicate width k x)
 44 |
 45 | ||| Unfold a tree up to the given depth.
 46 | public export
 47 | unfold : (depth : Nat) -> (f : s -> (a,List s)) -> s -> Tree a
 48 | unfold 0     f s = T (fst $ f s) []
 49 | unfold (S k) f s =
 50 |   let (a,ss) := f s
 51 |    in T a (map (unfold k f) ss)
 52 |
 53 | --------------------------------------------------------------------------------
 54 | --          Flattening Trees
 55 | --------------------------------------------------------------------------------
 56 |
 57 | zipWithKeep : (a -> a -> a) -> List a -> List a -> List a
 58 | zipWithKeep f [] ys = ys
 59 | zipWithKeep f xs [] = xs
 60 | zipWithKeep f (x :: xs) (y :: ys) = f x y :: zipWithKeep f xs ys
 61 |
 62 | ||| Flatten a tree into a list.
 63 | public export
 64 | flatten : Tree a -> List a
 65 | flatten (T v vs) = v :: flattenF vs
 66 |
 67 |   where
 68 |     flattenF : Forest a -> List a
 69 |     flattenF []        = Nil
 70 |     flattenF (x :: xs) = flatten x ++ flattenF xs
 71 |
 72 | ||| Convert a tree to a list of lists, so that all values at the same
 73 | ||| depth appear in the same list.
 74 | public export
 75 | layers : Tree a -> List (List a)
 76 | layers (T v vs) = [v] :: layersF vs
 77 |
 78 |   where
 79 |     layersF : Forest a -> List (List a)
 80 |     layersF []        = Nil
 81 |     layersF (x :: xs) = zipWithKeep (++) (layers x) (layersF xs)
 82 |
 83 | --------------------------------------------------------------------------------
 84 | --          Accessing Elements
 85 | --------------------------------------------------------------------------------
 86 |
 87 | ||| Try to look up a value in the tree by following the given path.
 88 | public export
 89 | index : List Nat -> Tree a -> Maybe a
 90 | index []        x = Just x.value
 91 | index (y :: ys) x = ix y x.forest >>= index ys
 92 |
 93 |   where
 94 |     ix : Nat -> List b -> Maybe b
 95 |     ix _ []            = Nothing
 96 |     ix 0     (z :: _)  = Just z
 97 |     ix (S k) (_ :: zs) = ix k zs
 98 |
 99 | --------------------------------------------------------------------------------
100 | --          Functor and Monad Implementations
101 | --------------------------------------------------------------------------------
102 |
103 | -- All implementations are boilerplaty to satisfy the totality checker.
104 | foldlTree : (a -> e -> a) -> a -> Tree e -> a
105 | foldlTree f acc (T v vs) = foldlF (f acc v) vs
106 |
107 |   where
108 |     foldlF : a -> Forest e -> a
109 |     foldlF y []        = y
110 |     foldlF y (x :: xs) = foldlF (foldlTree f y x) xs
111 |
112 | foldrTree : (e -> a -> a) -> a -> Tree e -> a
113 | foldrTree f acc (T v vs) = f v (foldrF acc vs)
114 |
115 |   where
116 |     foldrF : a -> Forest e -> a
117 |     foldrF y []        = y
118 |     foldrF y (x :: xs) = foldrTree f (foldrF y xs) x
119 |
120 | traverseTree : Applicative f => (a -> f b) -> Tree a -> f (Tree b)
121 | traverseTree fun (T v vs) = [| T (fun v) (traverseF vs) |]
122 |
123 |   where
124 |     traverseF : Forest a -> f (Forest b)
125 |     traverseF []        = pure []
126 |     traverseF (x :: xs) = [| traverseTree fun x :: traverseF xs |]
127 |
128 | mapTree : (a -> b) -> Tree a -> Tree b
129 | mapTree f (T v vs) = T (f v) (mapF vs)
130 |
131 |   where
132 |     mapF : Forest a -> Forest b
133 |     mapF []       = []
134 |     mapF (h :: t) = mapTree f h :: mapF t
135 |
136 | bindTree : Tree a -> (a -> Tree b) -> Tree b
137 | bindTree (T va tas) f =
138 |   let T vb tbs := f va
139 |    in T vb (tbs ++ bindF tas)
140 |
141 |   where
142 |     bindF : Forest a -> Forest b
143 |     bindF []        = []
144 |     bindF (x :: xs) = bindTree x f :: bindF xs
145 |
146 | apTree : Tree (a -> b) -> Tree a -> Tree b
147 | apTree tf ta = bindTree tf $ \f => mapTree (apply f) ta
148 |
149 | joinTree : Tree (Tree a) -> Tree a
150 | joinTree (T (T va tas) ftas) =
151 |   T va $ tas ++ joinF ftas
152 |
153 |   where
154 |     joinF : Forest (Tree a) -> Forest a
155 |     joinF []        = []
156 |     joinF (x :: xs) = joinTree x :: joinF xs
157 |
158 | --------------------------------------------------------------------------------
159 | --          Visualizing Trees
160 | --------------------------------------------------------------------------------
161 |
162 | ||| Pretty print a tree.
163 | |||
164 | ||| Unlike `prettyTree`, this has support for multi-line strings and
165 | ||| will result in a vertically elongated representation.
166 | export
167 | drawTree : Tree String -> String
168 | drawTree  = unlines . draw
169 |
170 |   where
171 |     drawForest : Forest String -> String
172 |     drawForest  = unlines . map drawTree
173 |
174 |     draw : Tree String -> List String
175 |     draw (T x ts0) = lines x ++ subTrees ts0
176 |
177 |       where
178 |         shift : String -> String -> List String -> List String
179 |         shift first other tails =
180 |           zipWith (++) (first :: replicate (length tails) other) tails
181 |
182 |         subTrees : Forest String -> List String
183 |         subTrees []      = []
184 |         subTrees [t]     = "│" :: shift "└╼" "  " (draw t)
185 |         subTrees (t::ts) = "│" :: shift "├╼" "│ " (draw t) ++ subTrees ts
186 |
187 | parameters (rev : Bool)
188 |
189 |   lst : String
190 |   lst = if rev then "┌─" else "└─"
191 |
192 |   children : String -> List (Tree String) -> SnocList String -> SnocList String
193 |   children _   []       ss = ss
194 |   children pre [T l cs] ss =
195 |     let s := pre ++ lst ++ "\{l}"
196 |      in children (pre ++ "  ") cs (ss :< s)
197 |
198 |   children pre (T l cs :: xs) ss =
199 |     let s   := pre ++ "├─\{l}"
200 |         ss2 := children (pre ++ "│ ") cs (ss :< s)
201 |      in children pre xs ss2
202 |
203 |   ||| Pretty print a tree.
204 |   |||
205 |   ||| Unlike `drawTree`, this does not work with multi-line node labels,
206 |   ||| but will otherwise result in a vertically more compact representation.
207 |   ||| In addition, it is possible to print the tree "upside-down" by
208 |   ||| setting `rev` to `True`.
209 |   export
210 |   prettyTree : Tree String -> String
211 |   prettyTree (T l cs) =
212 |     let ss := children "" cs [<"\{l}"] <>> []
213 |      in if rev then unlines (reverse ss) else unlines ss
214 |
215 | --------------------------------------------------------------------------------
216 | --          Interfaces
217 | --------------------------------------------------------------------------------
218 |
219 | public export %inline
220 | Foldable Tree where
221 |   foldl  = foldlTree
222 |   foldr  = foldrTree
223 |   null _ = False
224 |
225 | public export %inline
226 | Functor Tree where
227 |   map = mapTree
228 |
229 | public export %inline
230 | Applicative Tree where
231 |   pure a = T a Nil
232 |   (<*>)  = apTree
233 |
234 | public export %inline
235 | Monad Tree where
236 |   (>>=) = bindTree
237 |   join  = joinTree
238 |
239 | public export %inline
240 | Traversable Tree where
241 |   traverse = traverseTree
242 |