0 | module Data.Compress.Inflate
2 | import Data.Compress.Utils.Parser
3 | import Data.Compress.Utils.Bytes
4 | import Data.Compress.Utils.Misc
5 | import Data.Compress.Huffman
6 | import Data.Compress.Interface
7 | import Data.Compress.Utils.FiniteBuffer
11 | import Data.SnocList
13 | import Control.Monad.Error.Either
-- Tag type used as the index of InflateParserState (see that type's signature).
-- NOTE(review): only the InflateUncompressed alternative is visible in this
-- excerpt. InflateInit, InflateHuffman and InflateEnd are used as indices
-- elsewhere in the file, so they are presumably declared here as well; confirm.
16 | data InflateParserState'
19 | | InflateUncompressed
-- Decoder phase, indexed by its InflateParserState' tag. Each constructor
-- carries exactly the data needed to resume decoding in that phase.
data InflateParserState : InflateParserState' -> Type where
  -- A block header is expected next; nothing to carry.
  AtHeader :
    InflateParserState InflateInit
  -- Inside a Huffman-coded block: the final-block flag and the tree in use.
  AtHuffman :
    Bool -> HuffmanTree -> InflateParserState InflateHuffman
  -- Inside an uncompressed block: the final-block flag and the bytes still due.
  AtUncompressed :
    Bool -> Nat -> InflateParserState InflateUncompressed
  -- The final block has finished; holds bytes received past the end.
  AtEnd :
    SnocList Bits8 -> InflateParserState InflateEnd
-- Sole constructor of the decoder state: pending input bits, the history
-- buffer of already-produced bytes, and the current phase packed with its tag.
-- NOTE(review): the `data` header line is not visible in this excerpt; the
-- type is presumably InflateState, which feed_inflate' returns. Confirm.
31 | = MkState Bitstream (FiniteBuffer Bits8) (DPair InflateParserState' InflateParserState)
-- Base match length per length symbol. length_lookup indexes this list in
-- step with match_extra, so both lists must keep the same number of entries.
match_off : List Nat
match_off =
  [ 3, 4, 5, 6, 7, 8, 9, 10
  , 11, 13, 15, 17
  , 19, 23, 27, 31
  , 35, 43, 51, 59
  , 67, 83, 99, 115
  , 131, 163, 195, 227
  , 258
  ]
-- Number of extra bits to read after each length symbol; entry i belongs to
-- entry i of match_off (both are indexed together by length_lookup).
match_extra : List (Fin 32)
match_extra =
  [ 0, 0, 0, 0, 0, 0, 0, 0
  , 1, 1, 1, 1
  , 2, 2, 2, 2
  , 3, 3, 3, 3
  , 4, 4, 4, 4
  , 5, 5, 5, 5
  , 0
  ]
-- Base distance per distance symbol; distance_lookup indexes this list in
-- step with dist_extra.
-- NOTE(review): the type signature is not visible in this excerpt; it is
-- presumably `List Nat`, matching what distance_lookup returns. Confirm.
40 | dist_off = [ 1,2,3,4,5,7,9,13,17,25,33,49,65,97,129,193,257,385,513,769,1025,1537,2049,3073,4097,6145,8193,12289,16385,24577 ]
-- Number of extra bits to read after each distance symbol; entry i belongs
-- to entry i of dist_off (both are indexed together by distance_lookup).
dist_extra : List (Fin 32)
dist_extra =
  [ 0, 0, 0, 0
  , 1, 1, 2, 2
  , 3, 3, 4, 4
  , 5, 5, 6, 6
  , 7, 7, 8, 8
  , 9, 9, 10, 10
  , 11, 11, 12, 12
  , 13, 13
  ]
-- Order in which parse_inflate_dynamic pairs code-length symbols with their
-- 3-bit lengths: the n-th length read belongs to the n-th symbol listed here.
clen_alphabets : List (Fin 19)
clen_alphabets =
  [ 16, 17, 18
  , 0, 8, 7, 9
  , 6, 10, 5, 11
  , 4, 12, 3, 13
  , 2, 14, 1, 15
  ]
-- Looks up entry `n` of match_off and match_extra together.
-- Yields Nothing as soon as either lookup yields Nothing.
length_lookup : Nat -> Maybe (Nat, Fin 32)
length_lookup n = do
  base <- index_may n match_off
  extra_bits <- index_may n match_extra
  pure (base, extra_bits)
-- Looks up entry `n` of dist_off and dist_extra together.
-- Yields Nothing as soon as either lookup yields Nothing.
distance_lookup : Nat -> Maybe (Nat, Fin 32)
distance_lookup n = do
  base <- index_may n dist_off
  extra_bits <- index_may n dist_extra
  pure (base, extra_bits)
-- Parses the length header of an uncompressed block and checks it against
-- the complemented copy that follows it.
-- NOTE(review): the line binding `len` and the line producing the result are
-- not visible in this excerpt; presumably `len` comes from a preceding
-- `le_nat 2` and is what gets returned. Confirm.
54 | parse_inflate_uncompressed_len : Parser Bitstream (SimpleError String) Nat
55 | parse_inflate_uncompressed_len = do
-- Second field: read with `le_nat 2` and cast so it can be complemented below.
57 | nlen <- cast <$> le_nat 2
-- The header is accepted only when `len`, viewed as Bits16, equals the
-- bitwise complement of `nlen`; otherwise parsing fails.
58 | let True = (cast {to=Bits16} len) == (complement nlen)
59 | | False => fail $
msg "invalid length header"
-- Reads code lengths with `parser` until the running symbol index reaches
-- `n_lit_code`, returning (symbol index, length) pairs. New pairs are put at
-- the front, so the head of the result is the most recently recorded pair.
-- `supplied_prev_length` is the fallback "previous length" used when no pair
-- has been recorded yet.
-- NOTE(review): the patterns of the `\case` alternatives are not visible in
-- this excerpt; the comments below describe only the visible branch bodies.
63 | parse_inflate_code_lengths : Bits32 -> Maybe Bits32 -> Parser Bitstream (SimpleError String) (Fin 19) -> ?
64 | parse_inflate_code_lengths n_lit_code supplied_prev_length parser = loop [] 0 where
-- `acc` holds the pairs recorded so far; `current` is the next symbol index.
65 | loop : List (Bits32, Bits32) -> Bits32 -> Parser Bitstream (SimpleError String) (List (Bits32, Bits32))
66 | loop acc current = if current >= n_lit_code then pure acc else parser >>= \case
-- Repeat branch: the previous length is the one at the head of `acc`, falling
-- back to `supplied_prev_length`; it is recorded for the next 3 + (2 bits)
-- symbols.
-- NOTE(review): the skip branches below leave `acc` unchanged, so after a
-- skip this still picks the last recorded length rather than the skipped
-- run; confirm that is intended.
68 | let Just prev_length = map snd (head' acc) <|> supplied_prev_length
69 | | Nothing => fail $
msg "asked for previous code length, but buffer is empty"
70 | n <- (3 +) <$> get_bits 2
71 | let literals = zip [current..(current + n - 1)] (take (cast n) $
repeat prev_length)
72 | loop (literals <+> acc) (current + n)
-- Skip branch: advance 3 + (3 bits) symbols without recording anything.
74 | n <- (3 +) <$> get_bits 3
75 | loop acc (current + n)
-- Skip branch: advance 11 + (7 bits) symbols without recording anything.
77 | n <- (11 +) <$> get_bits 7
78 | loop acc (current + n)
-- Direct branch: the decoded value itself is recorded as the length of the
-- current symbol.
80 | let len = cast $
finToNat n
81 | loop ((current, len) :: acc) (current + 1)
-- Reads the header of a dynamically coded block and builds the HuffmanTree
-- (literal parser plus distance parser) described by it. Fails with a
-- message when a count is out of range or a tree cannot be built.
parse_inflate_dynamic : Parser Bitstream (SimpleError String) HuffmanTree
parse_inflate_dynamic = do
  -- Three counts, stored as 5-, 5- and 4-bit fields with fixed offsets.
  lit_count <- (257 +) <$> get_bits 5
  dist_count <- (1 +) <$> get_bits 5
  clen_count <- (cast . (4 +)) <$> get_bits 4
  let True = clen_count <= 19
    | False => fail $ msg "n_len_code exceeds 19"

  -- One 3-bit length per code-length symbol, taken in clen_alphabets order.
  clen_lengths <- for (take (cast clen_count) clen_alphabets) (\sym => (sym,) <$> get_bits 3)
  let Just clen_parser = make_tree clen_lengths 19
    | Nothing => fail $ msg "failed to generate code length tree"

  -- Literal table first; its most recently recorded length is handed to the
  -- distance table as the initial "previous length".
  lit_lengths <- parse_inflate_code_lengths lit_count Nothing clen_parser
  let Just carried_length = map snd $ head' lit_lengths
    | Nothing => fail $ msg "literals tree parser did nothing"
  let Just lit_parser = make_tree lit_lengths (cast lit_count)
    | Nothing => fail $ msg "failed to generate code literals tree"

  dist_lengths <- parse_inflate_code_lengths dist_count (Just carried_length) clen_parser
  let Just dist_parser = make_tree dist_lengths (cast dist_count)
    | Nothing => fail $ msg "failed to generate code distances tree"

  pure (MkTree lit_parser dist_parser)
-- One decoded step of a Huffman-coded block, as produced by
-- parse_inflate_huffman.
data HuffmanOutput
  = End             -- end-of-block symbol
  | Literal Bits8   -- a single output byte
  | Copy Nat Nat    -- match length, then distance back into earlier output
-- Decodes one symbol with `tree` and classifies it:
--   below 256   -> a literal byte
--   exactly 256 -> end of block
--   257 to 285  -> a length symbol, followed by its extra bits, a distance
--                  symbol and that symbol's extra bits
--   anything else is rejected.
parse_inflate_huffman : HuffmanTree -> Parser Bitstream (SimpleError String) HuffmanOutput
parse_inflate_huffman tree = do
  x <- tree.parse_literals
  if x < 256
    then pure $ Literal (cast x)
    else if x == 256
      then pure End
      else if x < 286
        then do
          -- Length: table base plus the value of the extra bits.
          let Just (len_base, len_bits) = length_lookup (cast (x - 257))
            | Nothing => fail $ msg "length symbol out of bound"
          len <- (\bits => len_base + cast bits) <$> get_bits len_bits
          -- Distance: decoded with the distance parser, then treated the same way.
          dist_sym <- tree.parse_distance
          let Just (dist_base, dist_bits) = distance_lookup (cast dist_sym)
            | Nothing => fail $ msg "distance symbol out of bound"
          dist <- (\bits => dist_base + cast bits) <$> get_bits dist_bits
          pure $ Copy (cast len) (cast dist)
        else fail $ msg "invalid code \{show x} encountered"
-- Parses a block header and returns the phase the decoder should enter,
-- packed together with its tag.
-- NOTE(review): the line binding `final` is not visible in this excerpt;
-- presumably it is a single `token` read before the two bits below. Confirm.
125 | parse_inflate_header : Parser Bitstream (SimpleError String) (DPair InflateParserState' InflateParserState)
126 | parse_inflate_header = do
-- Two further tokens select the block type.
128 | case !(count 2 token) of
-- Uncompressed block: read its length header first.
130 | [False, False] => do
131 | len <- parse_inflate_uncompressed_len
132 | pure (
_ ** AtUncompressed final len)
-- Huffman block decoded with the predefined `default_tree`.
134 | [True , False] => do
135 | pure (
_ ** AtHuffman final default_tree)
-- Huffman block whose tree is described in the stream itself.
137 | [False, True ] => do
138 | tree <- parse_inflate_dynamic
139 | pure (
_ ** AtHuffman final tree)
-- The remaining bit pattern is rejected.
141 | [True , True ] => do
142 | fail $
msg "invalid compression method"
-- Picks the phase that follows a completed block: the end phase (with no
-- trailing bytes yet) after the final block, otherwise back to a header.
next_state : (is_final : Bool) -> DPair InflateParserState' InflateParserState
next_state is_final =
  if is_final
    then (InflateEnd ** AtEnd [<])
    else (InflateInit ** AtHeader)
-- Runs the decoder over `content` starting from the given phase.
--   acc : bytes produced so far during this call
--   ob  : history buffer of produced bytes, used to resolve Copy steps
-- Returns Left with a message on failure, or Right with the produced bytes
-- and the state to resume from.
-- NOTE(review): several case alternatives are not visible in this excerpt;
-- the comments below describe only what is visible.
148 | feed_inflate' : SnocList Bits8 -> FiniteBuffer Bits8 -> DPair InflateParserState' InflateParserState ->
149 | Bitstream -> Either String (SnocList Bits8, InflateState)
-- End phase: further input is appended to the leftover bytes; pending input
-- is reset to `neutral` and the history buffer is replaced by `empty 0`.
151 | feed_inflate' acc ob (
_ ** AtEnd leftover)
content =
152 | Right (acc, MkState neutral (empty 0) (
_ ** AtEnd (leftover <>< toBits8 content))
)
-- Header phase: on a successful header parse, continue in the parsed phase
-- with the remaining input.
154 | feed_inflate' acc ob (
InflateInit ** state)
content =
155 | case feed content parse_inflate_header of
156 | Pure leftover state =>
157 | feed_inflate' acc ob state leftover
-- Visible fallback: keep `content` as pending input and stay at AtHeader.
161 | Right (acc, MkState content ob (
_ ** AtHeader)
)
-- Uncompressed phase: move up to `remaining` bytes to both the output and the
-- history buffer. If bytes are still due, suspend with the new count;
-- otherwise continue with the phase chosen by `next_state final`.
163 | feed_inflate' acc ob (
_ ** (AtUncompressed final remaining))
content =
164 | let (output, leftover) = fromBits8 <$> splitAt remaining (toBits8 content)
165 | ob = ob +<>< output
166 | acc = acc <>< output
167 | in case minus remaining (length output) of
168 | S n => Right (acc, MkState leftover ob (
_ ** AtUncompressed final (S n))
)
169 | Z => feed_inflate' acc ob (next_state final) leftover
-- Huffman phase: `go` decodes one HuffmanOutput at a time.
171 | feed_inflate' acc' ob (
_ ** (AtHuffman final tree))
content = go acc' ob tree content where
172 | go : SnocList Bits8 -> FiniteBuffer Bits8 -> HuffmanTree -> Bitstream -> Either String (SnocList Bits8, InflateState)
173 | go acc ob tree input =
174 | case feed input (parse_inflate_huffman tree) of
-- End of block: continue with the phase chosen by `next_state final`.
177 | Pure leftover End =>
178 | feed_inflate' acc ob (next_state final) leftover
-- Literal: append the byte to the output and to the history buffer.
179 | Pure leftover (Literal literal) =>
180 | go (acc :< literal) (ob +< literal) tree leftover
-- Copy: take the last `distance` elements of the history buffer and repeat
-- that chunk as often as needed to obtain `len` elements.
181 | Pure leftover (Copy len distance) =>
182 | case take_last distance ob of
183 | Just copied_chunk =>
184 | let appended = take len $
stream_concat $
repeat copied_chunk
185 | in go (acc <>< appended) (ob +<>< appended) tree leftover
186 | Nothing => Left "asked for distance \{show distance} but only \{show (length ob)} in buffer"
-- Visible fallback: suspend with `input` kept as pending input, same block.
188 | Right (acc, MkState input ob (
_ ** (AtHuffman final tree))
)
-- Decompressor instance for the inflate state.
--   feed : prepends the pending input to the new bytes, runs feed_inflate'
--          with an empty output accumulator, and returns the output as a list.
--   done : succeeds only in the end phase, returning the trailing bytes.
--   init : empty pending input, a history buffer created with `empty 32768`,
--          header phase.
Decompressor InflateState where
  feed (MkState pending window phase) content =
    map (mapFst toList) (feed_inflate' [<] window phase (pending <+> fromBits8 content))

  done (MkState _ _ (InflateEnd ** AtEnd trailing)) = Right (toList trailing)
  done _ = Left "inflate: underfed"

  init = MkState neutral (empty 32768) (InflateInit ** AtHeader)