0 | module Network.TLS.Parsing
  1 |
  2 | import Data.List1
  3 | import Data.Vect
  4 | import Utils.Bytes
  5 | import public Utils.Parser
  6 |
  7 | namespace Parserializer
  8 |   ||| bidirectional serializer
  9 |   ||| `decode` is assumed to be the inverse of `encode` and vice versa
 10 |   public export
 11 |   record Parserializer (c : Type) (i : Type) (e : Type) (a : Type) where
 12 |     constructor MkParserializer
 13 |     encode : a -> List c
 14 |     decode : Parser i e a
 15 |
 16 |   export infixr 5 <*>>
 17 |
 18 |   export
 19 |   apair : (Semigroup e, Monoid i) => Parserializer c i e a -> Parserializer c i e b -> Parserializer c i e (a, b)
 20 |   apair ma mb = MkParserializer (\(a, b) => ma.encode a <+> mb.encode b) $ (,) <$> ma.decode <*> mb.decode
 21 |
 22 |   ||| infixr for `apair`
 23 |   export
 24 |   (<*>>) : (Semigroup e, Monoid i) => Parserializer c i e a -> Parserializer c i e b -> Parserializer c i e (a, b)
 25 |   (<*>>) = apair
 26 |
 27 |   export
 28 |   map : (to : a -> b) -> (from : b -> a) -> Parserializer c i e a -> Parserializer c i e b
 29 |   map to from pser = MkParserializer (pser.encode . from) (map to pser.decode)
 30 |
 31 |   export
 32 |   mapEither : (Semigroup e, Monoid i) => (to : a -> Either e b) -> (from : b -> a) -> Parserializer c i e a -> Parserializer c i e b
 33 |   mapEither to from pser = MkParserializer (pser.encode . from) $ do
 34 |     a <- pser.decode
 35 |     case to a of
 36 |       Right b => pure b
 37 |       Left e => fail e
 38 |
 39 |   export
 40 |   (*>) : (Semigroup e, Monoid i) => Parserializer c i e () -> Parserializer c i e a -> Parserializer c i e a
 41 |   ma *> mb = map snd ((),) (ma <*>> mb)
 42 |
 43 |   export
 44 |   (<*) : (Semigroup e, Monoid i) => Parserializer c i e a -> Parserializer c i e () -> Parserializer c i e a
 45 |   ma <* mb = map fst (,()) (ma <*>> mb)
 46 |
 47 |   export
 48 |   aeither : (Semigroup e, Monoid i) => Parserializer c i e a -> Parserializer c i e b -> Parserializer c i e (Either a b)
 49 |   aeither ma mb = MkParserializer (either ma.encode mb.encode) $ (map Left ma.decode) <|> (map Right mb.decode)
 50 |
 51 |   export
 52 |   (<|>) : (Semigroup e, Monoid i) => Parserializer c i e a -> Parserializer c i e b -> Parserializer c i e (Either a b)
 53 |   (<|>) = aeither
 54 |
 55 | ||| essentially (Nat, `a`), where Nat denotes the position, usually starts with 0
 56 | public export
 57 | record Posed (a : Type) where
 58 |   constructor MkPosed
 59 |   pos : Nat
 60 |   get : a
 61 |
 62 | ||| `Parser.token` but for `Posed`
 63 | export
 64 | p_get : Cons (Posed c) i => Parser i e c
 65 | p_get = map get token
 66 |
 67 | -- serializer utils
 68 |
 69 | ||| prepend the length of `body` into `n` bytes in big endian
 70 | export
 71 | prepend_length : (n : Nat) -> (body : List Bits8) -> List Bits8
 72 | prepend_length n body = (toList $ integer_to_be n $ cast $ length body) <+> body
 73 |
 74 | -- parser utils
 75 |
 76 | ||| parse the next `n` bytes as a natural number in big endian style
 77 | export
 78 | p_nat : (Semigroup e, Monoid i, Cons (Posed Bits8) i) => (n : Nat) -> Parser i e Nat
 79 | p_nat n = cast {to = Nat} . be_to_integer <$> count n p_get
 80 |
 81 | ||| make sure that `p` MUST consume at least `n` tokens, fails otherwise
 82 | public export
 83 | p_exact : (Cons c i, Monoid i) => (n : Nat) -> (p : Parser i (SimpleError String) a) -> Parser i (SimpleError String) a
 84 | p_exact Z (Pure leftover x) = pure x
 85 | p_exact (S i) (Pure leftover x) = fail $ msg $ "over fed, " <+> show (S i) <+> " bytes more to go"
 86 | p_exact i (Fail msg) = fail msg
 87 | p_exact Z parser = fail $ msg $ "under fed, wants more"
 88 | p_exact (S i) parser = do
 89 |   b <- token
 90 |   p_exact i (feed (singleton b) parser)
 91 |
 92 | --- parserializer utils
 93 |
 94 | ||| put parser error messages under another message
 95 | ||| used for creating a treeish error message
 96 | export
 97 | under : e -> Parserializer c i (SimpleError e) a -> Parserializer c i (SimpleError e) a
 98 | under msg pser = MkParserializer pser.encode (under msg pser.decode)
 99 |
100 | ||| parserialize a single posed token
101 | export
102 | token : (Semigroup e, Cons (Posed c) i, Monoid i) => Parserializer c i e c
103 | token = MkParserializer pure p_get
104 |
105 | ||| parserialize `n` posed tokens
106 | export
107 | ntokens : (Semigroup e, Cons (Posed c) i, Monoid i) => (n : Nat) -> Parserializer c i e (Vect n c)
108 | ntokens n = MkParserializer (toList) (count n p_get)
109 |
110 | ||| parserialize the next `n` bytes in big endian style as a length describing the number of bytes of the following data to be fed to `pser`
111 | export
112 | lengthed : (Cons (Posed Bits8) i, Monoid i) => (n : Nat) -> (pser : Parserializer Bits8 i (SimpleError String) a) -> Parserializer Bits8 i (SimpleError String) a
113 | lengthed n pser = MkParserializer (prepend_length n . pser.encode) $ do
114 |   len <- p_nat n
115 |   p_exact len pser.decode
116 |
117 | ||| parserialize the next `n` bytes in big endian style as a length describing the number of bytes of the following data to be fed to `pser`
118 | ||| when `pser` completes, the result becomes an entry in the resulting list
119 | ||| when there are exactly zero bytes left, the list of results is returned
120 | ||| if under feeding `pser` for the last entry, the parser fails
121 | export
122 | lengthed_list : (Cons (Posed Bits8) i, Monoid i) => (n : Nat) -> (pser : Parserializer Bits8 i (SimpleError String) a) -> Parserializer Bits8 i (SimpleError String) (List a)
123 | lengthed_list youmu pser = MkParserializer (prepend_length youmu . concat . map pser.encode) $ do
124 |   S len <- p_nat youmu
125 |   | Z => pure []
126 |   go (S len) pser.decode
127 |   where
128 |   go : Nat -> Parser i (SimpleError String) a -> Parser i (SimpleError String) (List a)
129 |   go Z (Pure leftover x) = pure [x]
130 |   go (S i) (Pure leftover x) = (x ::) <$> go (S i) pser.decode
131 |   go i (Fail msg) = fail msg
132 |   go Z parser = fail $ msg $ "under fed, want more"
133 |   go (S i) parser = do
134 |     b <- token
135 |     go i (feed (singleton b) parser)
136 |
137 | ||| `lengthed_list` but `List1`
138 | export
139 | lengthed_list1 : (Cons (Posed Bits8) i, Monoid i) => (youmu : Nat) -> Parserializer Bits8 i (SimpleError String) a -> Parserializer Bits8 i (SimpleError String) (List1 a)
140 | lengthed_list1 youmu pser =
141 |   let
142 |     pser' = lengthed_list youmu pser
143 |   in
144 |     MkParserializer (pser'.encode . toList) $ do
145 |       (x :: xs) <- pser'.decode
146 |       | [] => fail $ msg $ "empty list"
147 |       pure (x ::: xs)
148 |
149 | ||| basically the parserializer version of `p_nat`
150 | export
151 | nat : Semigroup e => (Cons (Posed Bits8) i, Monoid i) => (n : Nat) -> Parserializer Bits8 i e Nat
152 | nat n = MkParserializer (toList . integer_to_be n . cast) (p_nat n)
153 |
154 | ||| parserialize a list of bytes with nice error messages specialized for displaying byte sequences
155 | export
156 | is : (Cons (Posed Bits8) i, Monoid i) => {k : Nat} -> Vect (S k) Bits8 -> Parserializer Bits8 i (SimpleError String) ()
157 | is cs = MkParserializer (const $ toList cs) $ do
158 |   bs <- count (S k) token
159 |   let cs' = map get bs
160 |   case cs == cs' of
161 |     True => pure ()
162 |     False =>
163 |       let
164 |         (begin, end) = mapHom pos (head bs, last bs)
165 |       in
166 |         fail $ msg $ "at position " <+> show begin <+> "-" <+> show end <+> ", expected " <+> xxd (toList cs) <+> " but got " <+> xxd (toList cs')
167 |