0 | module Network.TLS.Parse.DER
  1 |
  2 | import Data.List1
  3 | import Data.Bits
  4 | import Data.Vect
  5 | import Utils.Bytes
  6 | import Utils.Parser
  7 | import Network.TLS.Parsing
  8 | import Utils.Misc
  9 | import public Utils.Time
 10 |
 11 | import Debug.Trace
 12 |
 13 | public export
 14 | data TagType : Type where
 15 |   Universal : TagType
 16 |   Application : TagType
 17 |   ContextSpecific : TagType
 18 |   Private : TagType
 19 |
 20 | public export
 21 | Show TagType where
 22 |   show Universal       = "Universal"
 23 |   show Application     = "Application"
 24 |   show ContextSpecific = "ContextSpecific"
 25 |   show Private         = "Private"
 26 |
 27 | public export
 28 | record Tag where
 29 |   constructor MkTag
 30 |   type : TagType
 31 |   tag_id : Nat
 32 |
 33 | public export
 34 | Show Tag where
 35 |   show tag = "(" <+> show tag.type <+> ", " <+> show tag.tag_id <+> ")"
 36 |
 37 | export
 38 | is_constructed : Bits8 -> Bool
 39 | is_constructed tag = testBit tag 5
 40 |
 41 | public export
 42 | record BitArray where
 43 |   constructor MkBitArray
 44 |   padding : Nat
 45 |   bytes : List Bits8
 46 |
 47 | public export
 48 | data ASN1 : TagType -> Nat -> Type where
 49 |   Boolean : Bool -> ASN1 Universal 0x01
 50 |   IntVal : Integer -> ASN1 Universal 0x02
 51 |   Bitstring : BitArray -> ASN1 Universal 0x03
 52 |   OctetString : List Bits8 -> ASN1 Universal 0x04
 53 |   Null : ASN1 Universal 0x05
 54 |   OID : List Nat -> ASN1 Universal 0x06
 55 |   PrintableString : String -> ASN1 Universal 0x13
 56 |   T61String : String -> ASN1 Universal 0x14
 57 |   IA5String : String -> ASN1 Universal 0x16
 58 |   UTF8String : String -> ASN1 Universal 0x0C
 59 |   Sequence : List (t ** n ** ASN1 t n-> ASN1 Universal 0x10 -- 0x30 & 31
 60 |   Set : List (t ** n ** ASN1 t n-> ASN1 Universal 0x11 -- 0x31 & 31
 61 |   UTCTime : DateTime -> ASN1 Universal 0x17
 62 |   GeneralizedTime : DateTime -> ASN1 Universal 0x18
 63 |   UnknownConstructed : (t : TagType) -> (n : Nat) -> List (t ** n ** ASN1 t n-> ASN1 t n
 64 |   UnknownPrimitive : (t : TagType) -> (n : Nat) -> List Bits8 -> ASN1 t n
 65 |
 66 | public export
 67 | Eq BitArray where
 68 |   (MkBitArray a b) == (MkBitArray c d) = (a == c) && (b == d)
 69 |
 70 | export
 71 | constraint_parse : (Cons (Posed Bits8) i, Monoid i) => (n : Nat) ->
 72 |                    (pser : Parser i (SimpleError String) a) ->
 73 |                    Parser i (SimpleError String) (List a)
 74 | constraint_parse Z pser = pure []
 75 | constraint_parse (S len) pser = go (S len) pser
 76 |   where
 77 |   go : Nat -> Parser i (SimpleError String) a -> Parser i (SimpleError String) (List a)
 78 |   go Z (Pure leftover x) = pure [x]
 79 |   go (S i) (Pure leftover x) = (x ::) <$> go (S i) pser
 80 |   go i (Fail msg) = fail msg
 81 |   go Z parser = fail $ msg $ "under fed, want more"
 82 |   go (S i) parser = do
 83 |     b <- token
 84 |     go i (feed (singleton b) parser)
 85 |
 86 | export
 87 | parse_length : (Monoid i, Cons (Posed Bits8) i) => Parser i (SimpleError String) Nat
 88 | parse_length = do
 89 |   b <- p_get
 90 |   let b' = b .&. 0x7F
 91 |   if b' == b
 92 |      then pure $ cast b
 93 |      else p_nat (cast b')
 94 |
 95 | extract_tag_type_bits : Bits8 -> TagType
 96 | extract_tag_type_bits x =
 97 |   case get_bits x of
 98 |     [ False, False ] => Universal
 99 |     [ False, True  ] => Application
100 |     [ True,  False ] => ContextSpecific
101 |     [ True,  True  ] => Private
102 |   where
103 |     get_bits : Bits8 -> Vect 2 Bool
104 |     get_bits x = [ (testBit x 7), (testBit x 6) ]
105 |
106 | export
107 | parse_tag_id : (Monoid i, Cons (Posed Bits8) i) => Parser i (SimpleError String) (Bool, Tag)
108 | parse_tag_id = do
109 |   b <- p_get
110 |   let construct = is_constructed b
111 |   let type = extract_tag_type_bits b
112 |   let id = b .&. 31
113 |   if id == 31
114 |     then (\x => (construct, MkTag type x)) <$> parse_length
115 |     else pure (construct, MkTag type $ cast id)
116 |
117 | export
118 | signed_be_to_integer : List1 Bits8 -> Integer
119 | signed_be_to_integer l@(x ::: xs) =
120 |   let is_neg = testBit x 7
121 |       v = be_to_integer l
122 |       m = (shiftL 1 (8 * (length l))) - 1 -- 2^n - 1
123 |   in if is_neg then v .|. (complement m) else v
124 |
125 | export
126 | decode_oid_nodes : Bits8 -> List Bits8 -> List Nat
127 | decode_oid_nodes first_node nodes =
128 |   let a = first_node `div` 40
129 |       b = first_node `mod` 40
130 |       (_, result) = foldl go (0, []) nodes
131 |       nodes = cast a :: cast b :: reverse result
132 |   in integerToNat <$> nodes
133 |   where
134 |     go : (Integer, List Integer) -> Bits8 -> (Integer, List Integer)
135 |     go (value, result) byte =
136 |       let value = (shiftL value 7) .|. cast (byte .&. 0x7F)
137 |       in if byte >= 0x80 then (value, result) else (0, value :: result)
138 |
139 | export
140 | parse_boolean : (Cons (Posed Bits8) i, Monoid i) => Nat -> Parser i (SimpleError String) Bool
141 | parse_boolean 1 = map (/= 0) p_get
142 | parse_boolean n = fail $ msg $ "boolean length should be 1, got: " <+> show n
143 |
144 | export
145 | parse_integer : (Cons (Posed Bits8) i, Monoid i) => Nat -> Parser i (SimpleError String) Integer
146 | parse_integer n = do
147 |   bits <- count n p_get
148 |   case fromList $ toList bits of
149 |     Just x => pure $ signed_be_to_integer x
150 |     Nothing => fail $ msg "integer length is 0"
151 |
152 | export
153 | parse_bitarray : (Cons (Posed Bits8) i, Monoid i) => Nat -> Parser i (SimpleError String) BitArray
154 | parse_bitarray Z = fail $ msg "bitarray length is 0"
155 | parse_bitarray (S n) = do
156 |   pad_len <- p_get
157 |   bits <- toList <$> count n p_get
158 |   pure $ MkBitArray (cast pad_len) bits
159 |
160 | export
161 | parse_null : (Cons (Posed Bits8) i, Monoid i) => Nat -> Parser i (SimpleError String) ()
162 | parse_null Z = pure ()
163 | parse_null n = fail $ msg $ "null length should be 0, got: " <+> show n
164 |
165 | export
166 | parse_oid : (Cons (Posed Bits8) i, Monoid i) => Nat -> Parser i (SimpleError String) (List Nat)
167 | parse_oid Z = fail $ msg "oid length is 0"
168 | parse_oid (S n) = do
169 |   first_node <- p_get
170 |   nodes <- toList <$> count n p_get
171 |   pure $ decode_oid_nodes first_node nodes
172 |
173 | export
174 | parse_time : (Cons (Posed Bits8) i, Monoid i) => Nat -> (String -> Either String DateTime) -> Parser i (SimpleError String) DateTime
175 | parse_time len f = do
176 |   str <- count len p_get
177 |   case f $ ascii_to_string $ toList str of
178 |     Right datetime => pure datetime
179 |     Left err => fail $ msg err
180 |
181 | export
182 | parse_utf8 : (Cons (Posed Bits8) i, Monoid i) => Nat -> Parser i (SimpleError String) String
183 | parse_utf8 len = do
184 |   str <- count len p_get
185 |   case utf8_decode $ toList str of
186 |     Just str => pure str
187 |     Nothing => fail $ msg "invalid utf8 string"
188 |
189 | public export
190 | ASN1Token : Type
191 | ASN1Token = (t ** n ** ASN1 t n)
192 |
193 | export
194 | extract_string : ASN1Token -> Maybe String
195 | extract_string (Universal ** 12 ** UTF8String b= Just b
196 | extract_string (Universal ** 19 ** PrintableString b= Just b
197 | extract_string (Universal ** 20 ** T61String b= Just b
198 | extract_string (Universal ** 22 ** IA5String b= Just b
199 | extract_string _ = Nothing
200 |
201 | export
202 | extract_epoch : ASN1Token -> Maybe Integer
203 | extract_epoch (Universal ** 23 ** UTCTime time= Just $ datetime_to_epoch time
204 | extract_epoch (Universal ** 24 ** GeneralizedTime time= Just $ datetime_to_epoch time
205 | extract_epoch _ = Nothing
206 |
207 | export          
208 | parse_asn1 : (Monoid i, Cons (Posed Bits8) i) => Parser i (SimpleError String) ASN1Token
209 | parse_asn1 = do
210 |   tag' <- parse_tag_id
211 |   len <- parse_length
212 |   case tag' of
213 |     (False, MkTag Universal 1) => (\b => (Universal ** 1 ** Boolean b)) <$> parse_boolean len
214 |     (False, MkTag Universal 2) => (\b => (Universal ** 2 ** IntVal b)) <$> parse_integer len
215 |     (False, MkTag Universal 3) => (\b => (Universal ** 3 ** Bitstring b)) <$> parse_bitarray len
216 |     (False, MkTag Universal 4) => (\b => (Universal ** 4 ** OctetString $ toList b)) <$> count len p_get
217 |     (False, MkTag Universal 5) => (\b => (Universal ** 5 ** Null)) <$> parse_null len
218 |     (False, MkTag Universal 6) => (\b => (Universal ** 6 ** OID b)) <$> parse_oid len
219 |     (False, MkTag Universal 12) => (\b => (Universal ** 12 ** UTF8String b)) <$> parse_utf8 len
220 |     (False, MkTag Universal 19) => (\b => (Universal ** 19 ** PrintableString $ ascii_to_string $ toList b)) <$> count len p_get
221 |     (False, MkTag Universal 20) => (\b => (Universal ** 20 ** T61String $ ascii_to_string $ toList b)) <$> count len p_get
222 |     (False, MkTag Universal 22) => (\b => (Universal ** 22 ** IA5String $ ascii_to_string $ toList b)) <$> count len p_get
223 |     (False, MkTag Universal 23) => (\b => (Universal ** 23 ** UTCTime b)) <$> parse_time len parse_utc_time
224 |     (False, MkTag Universal 24) => (\b => (Universal ** 24 ** GeneralizedTime b)) <$> parse_time len parse_generalized_time
225 |     (True,  MkTag Universal 16) => (\b => (Universal ** 16 ** Sequence b)) <$> constraint_parse len parse_asn1
226 |     (True,  MkTag Universal 17) => (\b => (Universal ** 17 ** Set b)) <$> constraint_parse len parse_asn1
227 |     (True,  MkTag t n) => (\b => (t ** n ** UnknownConstructed t n b)) <$> constraint_parse len parse_asn1
228 |     (False, MkTag t n) => (\b => (t ** n ** UnknownPrimitive t n $ toList b)) <$> count len p_get
229 |