0 | module Data.Compress.Huffman
  1 |
  2 | import Data.Fin
  3 | import Data.Vect
  4 | import Data.List
  5 | import Data.List1
  6 | import Data.Bits
  7 | import Data.Stream
  8 | import Data.Compress.Utils.Misc
  9 | import Data.Compress.Utils.Parser
 10 |
 11 | data Tree : Type -> Type where
 12 |   Node : Tree a -> Tree a -> Tree a
 13 |   Leaf : a -> Tree a
 14 |
 15 | mk : List Bool -> a -> Tree (Maybe a)
 16 | mk [] k = Leaf (Just k)
 17 | mk (False :: dirs) k = Node (mk dirs k) (Leaf Nothing)
 18 | mk (True :: dirs) k = Node (Leaf Nothing) (mk dirs k)
 19 |
 20 | insert : List Bool -> a -> Tree (Maybe a) -> Maybe (Tree (Maybe a))
 21 | insert [] k (Leaf Nothing) = pure $ Leaf (Just k)
 22 | insert [] k (Leaf (Just _)) = Nothing
 23 | insert [] k (Node _ _) = Nothing
 24 | insert (False :: xs) k (Node left right) = pure $ Node !(insert xs k left) right
 25 | insert (True :: xs) k (Node left right) = pure $ Node left !(insert xs k right)
 26 | insert (False :: xs) k (Leaf Nothing) = pure $ Node !(insert xs k (Leaf Nothing)) (Leaf Nothing)
 27 | insert (True :: xs) k (Leaf Nothing) = pure $ Node (Leaf Nothing) !(insert xs k (Leaf Nothing))
 28 | insert (_ :: _) k (Leaf (Just _)) = Nothing
 29 |
 30 | lookup_count : Eq a => List (a, Nat) -> a -> Nat
 31 | lookup_count list key =
 32 |   case lookup key list of
 33 |     Just x => x
 34 |     Nothing => Z
 35 |
 36 | max' : Ord a => List1 a -> a
 37 | max' (x ::: xs) = foldl max x xs
 38 |
 39 | smallest_codes : (Bits32 -> Nat) -> (n : Nat) -> Vect (S n) Bits32
 40 | smallest_codes bl_count max_code_length =
 41 |   take (S max_code_length) $ map fst $ iterate go (0,0)
 42 |   where
 43 |     go : (Bits32, Bits32) -> (Bits32, Bits32)
 44 |     go (prev_code, prev_index) = (shiftL (prev_code + cast (bl_count prev_index)) 1, prev_index + 1)
 45 |
 46 | make_tree_from_length : {n : Nat} -> List Bits32 -> Vect (S n) Bits32 -> List Bits32 -> Maybe (List Bits32)
 47 | make_tree_from_length acc next_code [] = Just $ reverse acc
 48 | make_tree_from_length acc next_code (x :: xs) = do
 49 |   x' <- natToFin (cast x) (S n)
 50 |   let v = index x' next_code
 51 |   make_tree_from_length (v :: acc) (updateAt x' (+1) next_code) xs
 52 |
 53 | decompose_bits32 : Bits32 -> Fin 32 -> List Bool
 54 | decompose_bits32 i FZ = [ testBit i FZ ]
 55 | decompose_bits32 i (FS n) = testBit i (FS n) :: decompose_bits32 i (weaken n)
 56 |
 57 | decompose_l_i : Bits32 -> Bits32 -> Maybe (List Bool)
 58 | decompose_l_i l i = decompose_bits32 i <$> natToFin (cast (l-1)) 32 
 59 |
 60 | public export
 61 | record HuffmanTree where
 62 |   constructor MkTree
 63 |   parse_literals : Parser Bitstream (SimpleError String) Bits32
 64 |   parse_distance : Parser Bitstream (SimpleError String) Bits32
 65 |
 66 | export
 67 | default_tree : HuffmanTree
 68 | default_tree = MkTree first (get_huff 5) where
 69 |   first : Parser Bitstream (SimpleError String) Bits32
 70 |   first = do
 71 |     x <- get_huff 7
 72 |     if x < 24 then pure (x + 256) else do
 73 |       x <- map ((shiftL x 1) .|.) get_bit
 74 |       if x < 192 then pure (x - 48)
 75 |         else if x < 200 then pure (x + 88)
 76 |         else map (\y => ((shiftL x 1) .|. y) - 256) get_bit
 77 |
 78 | tree_to_parser : Tree (Maybe a) -> Parser Bitstream (SimpleError String) a
 79 | tree_to_parser (Leaf Nothing) = fail $ msg "no value at leaf"
 80 | tree_to_parser (Leaf (Just v)) = pure v
 81 | tree_to_parser (Node false true) = do
 82 |   b <- token
 83 |   if b then tree_to_parser true else tree_to_parser false
 84 |
 85 | export
 86 | make_tree : Ord a => List (a, Bits32) -> Nat -> Maybe (Parser Bitstream (SimpleError String) a)
 87 | make_tree elem_code_length max_n_code = do
 88 |   let elem_code_length = filter (\(a,b) => b > 0) $ sortBy (\a,b => compare (fst a) (fst b)) elem_code_length
 89 |   let code_length = map snd elem_code_length
 90 |
 91 |   guard ((length code_length) <= max_n_code)
 92 |   let code_length_count = count code_length
 93 |   let bl_count = lookup_count code_length_count
 94 |   code_length_count1 <- fromList code_length_count
 95 |
 96 |   let max_code_length = cast $ max' $ map fst code_length_count1
 97 |   let next_code = smallest_codes bl_count max_code_length
 98 |
 99 |   tree <- make_tree_from_length [] next_code code_length
100 |   elem_code <- traverse (\(i,(v,l)) => (,v) <$> decompose_l_i l i) $ zip tree elem_code_length
101 |   (elem_code_head ::: elem_code_tail) <- fromList elem_code
102 |
103 |   b_tree <- foldlM (\t,(p,v) => insert p v t) (uncurry mk elem_code_head) elem_code_tail
104 |
105 |   pure $ tree_to_parser b_tree
106 |