0 | module Data.Compress.Inflate
  1 |
  2 | import Data.Compress.Utils.Parser
  3 | import Data.Compress.Utils.Bytes
  4 | import Data.Compress.Utils.Misc
  5 | import Data.Compress.Huffman
  6 | import Data.Compress.Interface
  7 | import Data.Compress.Utils.FiniteBuffer
  8 | import Data.Vect
  9 | import Data.Bits
 10 | import Data.List
 11 | import Data.SnocList
 12 | import Data.Stream
 13 | import Control.Monad.Error.Either
 14 |
 15 | public export
 16 | data InflateParserState'
 17 |   = InflateInit
 18 |   | InflateHuffman
 19 |   | InflateUncompressed
 20 |   | InflateEnd
 21 |
 22 | public export
 23 | data InflateParserState : InflateParserState' -> Type where
 24 |   AtHeader : InflateParserState InflateInit
 25 |   AtHuffman : Bool -> HuffmanTree -> InflateParserState InflateHuffman
 26 |   AtUncompressed : Bool -> Nat -> InflateParserState InflateUncompressed
 27 |   AtEnd : SnocList Bits8 -> InflateParserState InflateEnd
 28 |
 29 | public export
 30 | data InflateState
 31 |   = MkState Bitstream (FiniteBuffer Bits8) (DPair InflateParserState' InflateParserState)
 32 |
 33 | match_off : List Nat
 34 | match_off = [ 3,4,5,6,7,8,9,10,11,13,15,17,19,23,27,31,35,43,51,59,67,83,99,115,131,163,195,227,258 ]
 35 |
 36 | match_extra : List (Fin 32)
 37 | match_extra = [ 0,0,0,0,0,0,0,0,1,1,1,1,2,2,2,2,3,3,3,3,4,4,4,4,5,5,5,5,0 ]
 38 |
 39 | dist_off : List Nat
 40 | dist_off = [ 1,2,3,4,5,7,9,13,17,25,33,49,65,97,129,193,257,385,513,769,1025,1537,2049,3073,4097,6145,8193,12289,16385,24577 ]
 41 |
 42 | dist_extra : List (Fin 32)
 43 | dist_extra = [ 0,0,0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7,8,8,9,9,10,10,11,11,12,12,13,13 ]
 44 |
 45 | clen_alphabets : List (Fin 19)
 46 | clen_alphabets = [ 16,17,18,0,8,7,9,6,10,5,11,4,12,3,13,2,14,1,15 ]
 47 |
 48 | length_lookup : Nat -> Maybe (Nat, Fin 32)
 49 | length_lookup n = Just (!(index_may n match_off), !(index_may n match_extra))
 50 |
 51 | distance_lookup : Nat -> Maybe (Nat, Fin 32)
 52 | distance_lookup n = Just (!(index_may n dist_off), !(index_may n dist_extra))
 53 |
 54 | parse_inflate_uncompressed_len : Parser Bitstream (SimpleError String) Nat
 55 | parse_inflate_uncompressed_len = do
 56 |   len <- le_nat 2
 57 |   nlen <- cast <$> le_nat 2
 58 |   let True = (cast {to=Bits16} len) == (complement nlen)
 59 |   | False => fail $ msg "invalid length header"
 60 |   pure len
 61 |
 62 | -- List (literal, code length)
 63 | parse_inflate_code_lengths : Bits32 -> Maybe Bits32 -> Parser Bitstream (SimpleError String) (Fin 19) -> ?
 64 | parse_inflate_code_lengths n_lit_code supplied_prev_length parser = loop [] 0 where
 65 |   loop : List (Bits32, Bits32) -> Bits32 -> Parser Bitstream (SimpleError String) (List (Bits32, Bits32))
 66 |   loop acc current = if current >= n_lit_code then pure acc else parser >>= \case
 67 |     16 => do
 68 |       let Just prev_length = map snd (head' acc) <|> supplied_prev_length
 69 |       | Nothing => fail $ msg "asked for previous code length, but buffer is empty"
 70 |       n <- (3 +) <$> get_bits 2
 71 |       let literals = zip [current..(current + n - 1)] (take (cast n) $ repeat prev_length)
 72 |       loop (literals <+> acc) (current + n)
 73 |     17 => do
 74 |       n <- (3 +) <$> get_bits 3
 75 |       loop acc (current + n)
 76 |     18 => do
 77 |       n <- (11 +) <$> get_bits 7
 78 |       loop acc (current + n)
 79 |     n => do
 80 |       let len = cast $ finToNat n
 81 |       loop ((current, len) :: acc) (current + 1)
 82 |
 83 | parse_inflate_dynamic : Parser Bitstream (SimpleError String) HuffmanTree
 84 | parse_inflate_dynamic = do
 85 |   n_lit_code <- (257 +) <$> get_bits 5
 86 |   n_dist_code <- (1 +) <$> get_bits 5
 87 |   n_len_code <- (cast . (4 +)) <$> get_bits 4
 88 |   let True = n_len_code <= 19
 89 |   | False => fail $ msg "n_len_code exceeds 19"
 90 |   alphabets <- for (take (cast n_len_code) clen_alphabets) (\k => (k,) <$> get_bits 3)
 91 |   let Just code_length_parser = make_tree alphabets 19
 92 |   | Nothing => fail $ msg "failed to generate code length tree"
 93 |
 94 |   literals <- parse_inflate_code_lengths n_lit_code Nothing code_length_parser
 95 |   let Just prev_length = map snd $ head' literals
 96 |   | Nothing => fail $ msg "literals tree parser did nothing"
 97 |   let Just literals_parser = make_tree literals (cast n_lit_code)
 98 |   | Nothing => fail $ msg "failed to generate code literals tree"
 99 |
100 |   distances <- parse_inflate_code_lengths n_dist_code (Just prev_length) code_length_parser
101 |   let Just distances_parser = make_tree distances (cast n_dist_code)
102 |   | Nothing => fail $ msg "failed to generate code distances tree"
103 |
104 |   pure (MkTree literals_parser distances_parser)
105 |
106 | data HuffmanOutput = End | Literal Bits8 | Copy Nat Nat
107 |
108 | parse_inflate_huffman : HuffmanTree -> Parser Bitstream (SimpleError String) HuffmanOutput
109 | parse_inflate_huffman tree = do
110 |   x <- tree.parse_literals
111 |   if x < 256 then pure $ Literal (cast x)
112 |     else if x == 256 then pure End
113 |     else if x < 286 then do
114 |       let Just (off, extra) = length_lookup (cast (x - 257))
115 |       | Nothing => fail $ msg "length symbol out of bound"
116 |       length <- map (\b => off + cast b) (get_bits extra)
117 |       dcode <- tree.parse_distance
118 |
119 |       let Just (off, extra) = distance_lookup (cast dcode)
120 |       | Nothing => fail $ msg "distance symbol out of bound"
121 |       distance <- map (\b => off + cast b) (get_bits extra)
122 |       pure $ Copy (cast length) (cast distance)
123 |     else fail $ msg "invalid code \{show x} encountered"
124 |
125 | parse_inflate_header : Parser Bitstream (SimpleError String) (DPair InflateParserState' InflateParserState)
126 | parse_inflate_header = do
127 |   final <- token
128 |   case !(count 2 token) of
129 |     -- Uncompressed
130 |     [False, False] => do
131 |       len <- parse_inflate_uncompressed_len
132 |       pure (_ ** AtUncompressed final len)
133 |     -- Fixed Huffman
134 |     [True , False] => do
135 |       pure (_ ** AtHuffman final default_tree)
136 |     -- Dynamic Huffman
137 |     [False, True ] => do
138 |       tree <- parse_inflate_dynamic
139 |       pure (_ ** AtHuffman final tree)
140 |     -- Invalid
141 |     [True , True ] => do
142 |       fail $ msg "invalid compression method"
143 |
144 | next_state : (is_final : Bool) -> DPair InflateParserState' InflateParserState
145 | next_state True  = (_ ** AtEnd [<])
146 | next_state False = (_ ** AtHeader)
147 |
148 | feed_inflate' : SnocList Bits8 -> FiniteBuffer Bits8 -> DPair InflateParserState' InflateParserState ->
149 |                 Bitstream -> Either String (SnocList Bits8, InflateState)
150 |
151 | feed_inflate' acc ob (_ ** AtEnd leftovercontent =
152 |   Right (acc, MkState neutral (empty 0) (_ ** AtEnd (leftover <>< toBits8 content))) -- terminates
153 |
154 | feed_inflate' acc ob (InflateInit ** statecontent =
155 |   case feed content parse_inflate_header of
156 |     Pure leftover state =>
157 |       feed_inflate' acc ob state leftover
158 |     Fail err =>
159 |       Left (show err)
160 |     _ => -- underfed, need more input
161 |       Right (acc, MkState content ob (_ ** AtHeader))
162 |
163 | feed_inflate' acc ob (_ ** (AtUncompressed final remaining)content =
164 |   let (output, leftover) = fromBits8 <$> splitAt remaining (toBits8 content)
165 |       ob = ob +<>< output
166 |       acc = acc <>< output
167 |   in case minus remaining (length output) of
168 |        S n => Right (acc, MkState leftover ob (_ ** AtUncompressed final (S n))) -- underfed
169 |        Z   => feed_inflate' acc ob (next_state final) leftover
170 |
171 | feed_inflate' acc' ob (_ ** (AtHuffman final tree)content = go acc' ob tree content where
172 |   go : SnocList Bits8 -> FiniteBuffer Bits8 -> HuffmanTree -> Bitstream -> Either String (SnocList Bits8, InflateState)
173 |   go acc ob tree input =
174 |     case feed input (parse_inflate_huffman tree) of
175 |       Fail err =>
176 |         Left (show err)
177 |       Pure leftover End =>
178 |         feed_inflate' acc ob (next_state final) leftover
179 |       Pure leftover (Literal literal) =>
180 |         go (acc :< literal) (ob +< literal) tree leftover
181 |       Pure leftover (Copy len distance) =>
182 |         case take_last distance ob of
183 |           Just copied_chunk =>
184 |             let appended = take len $ stream_concat $ repeat copied_chunk
185 |             in go (acc <>< appended) (ob +<>< appended) tree leftover
186 |           Nothing => Left "asked for distance \{show distance} but only \{show (length ob)} in buffer"
187 |       _ => -- underfed, need more input
188 |         Right (acc, MkState input ob (_ ** (AtHuffman final tree)))
189 |
190 | export
191 | Decompressor InflateState where
192 |   feed (MkState ib ob state) content = mapFst toList <$> feed_inflate' Lin ob state (ib <+> fromBits8 content)
193 |   done (MkState _ _ (_ ** AtEnd leftover)) = Right (toList leftover)
194 |   done _ = Left "inflate: underfed"
195 |   init = MkState neutral (empty 32768) (_ ** AtHeader)
196 |