0 | module Crypto.RSA
  1 |
  2 | import Data.List
  3 | import Data.Vect
  4 | import Data.Bits
  5 | import Utils.Misc
  6 | import Utils.Bytes
  7 | import Crypto.Random
  8 | import Data.Nat
  9 | import Crypto.Hash
 10 | import Data.List1
 11 | import Data.Fin
 12 | import Data.Stream
 13 | import Data.Fin.Extra
 14 | import Crypto.Hash.OID
 15 |
 16 | export
 17 | data RSAPublicKey : Type where
 18 |   MkRSAPublicKey : (n : Integer) -> (e : Integer) -> RSAPublicKey
 19 |
 20 | -- TODO: check if there are more constraints needed between n and e
 21 | -- also maybe use a proof instead of masking the constructor in the future
 22 | export
 23 | mk_rsa_publickey : Integer -> Integer -> Maybe RSAPublicKey
 24 | mk_rsa_publickey n e = guard (1 == gcd' n e) $> MkRSAPublicKey n e
 25 |
 26 | export
 27 | rsa_encrypt : RSAPublicKey -> Integer -> Integer
 28 | rsa_encrypt (MkRSAPublicKey n e) m = pow_mod m e n
 29 |
 30 | -- RFC 8017
 31 |
 32 | export
 33 | os2ip : Foldable t => t Bits8 -> Integer
 34 | os2ip = be_to_integer
 35 |
 36 | export
 37 | i2osp : Nat -> Integer -> Maybe (List Bits8)
 38 | i2osp b_len x =
 39 |   let mask = (shiftL 1 (8 * b_len)) - 1
 40 |       x' = x .&. mask
 41 |   in (guard $ x == x') $> (toList $ integer_to_be b_len x)
 42 |
 43 | export
 44 | rsavp1 : RSAPublicKey -> Integer -> Maybe Integer
 45 | rsavp1 pk@(MkRSAPublicKey n e) s = guard (s > 0 && s < (n - 1)) $> rsa_encrypt pk s
 46 |
 47 | record PSSEncodedMessage n where
 48 |   hash_digest : Vect n Bits8
 49 |   db : List Bits8
 50 |
 51 | public export
 52 | MaskGenerationFunction : Type
 53 | MaskGenerationFunction = (n : Nat) -> List Bits8 -> Vect n Bits8
 54 |
 55 | export
 56 | mgf1 : {algo : _} -> (h : Hash algo) => MaskGenerationFunction
 57 | mgf1 n seed = take n $ stream_concat $ map (\x => hash algo (seed <+> (toList $ integer_to_be 4 $ cast x))) nats
 58 |
 59 | export
 60 | modulus_bits : RSAPublicKey -> Nat
 61 | modulus_bits (MkRSAPublicKey n _) = if n > 0 then go Z n else 0
 62 |   where
 63 |     go : Nat -> Integer -> Nat
 64 |     go n x = if x == 0 then n else go (S n) (shiftR x 1)
 65 |
 66 | export
 67 | emsa_pss_verify : {algo : _} -> (h : Hash algo) => MaskGenerationFunction -> Nat -> List Bits8 -> List1 Bits8 -> Nat -> Maybe ()
 68 | emsa_pss_verify mgf sLen message em emBits = do
 69 |   let mHash = hash algo message
 70 |   let emLen = divCeilNZ emBits 8 SIsNonZero
 71 |   let (em, 0xbc) = uncons1 em
 72 |   | _ => Nothing -- Invalid padding
 73 |   (maskedDB, digest) <- splitLastAt1 (digest_nbyte {algo}) em
 74 |   -- check padding
 75 |   guard $ check_padding (modFinNZ emBits 8 SIsNonZero) (head maskedDB)
 76 |   let db = zipWith xor (toList maskedDB) (toList $ mgf (length maskedDB) (toList digest))
 77 |   (padding, salt) <- splitLastAt1 sLen db
 78 |   -- unset padding bits
 79 |   bits_to_be_cleared <- natToFin (minus (8 * emLen) emBits) _
 80 |   let mask = shiftR (the Bits8 oneBits) bits_to_be_cleared
 81 |   let (pxs, px) = uncons1 $ (mask .&. head padding) ::: (tail padding)
 82 |   -- check padding
 83 |   guard (0 == (foldr (.|.) 0 pxs))
 84 |   guard (1 == px)
 85 |   -- check salt length
 86 |   guard $ digest `s_eq` hash algo (replicate 8 0 <+> toList mHash <+> toList salt)
 87 |   where
 88 |     check_padding : Fin 8 -> Bits8 -> Bool
 89 |     check_padding FZ _ = True
 90 |     check_padding n b = 0 == shiftR b n
 91 |
 92 | export
 93 | rsassa_pss_verify' : {algo : _} -> (h : Hash algo) => MaskGenerationFunction -> Nat -> RSAPublicKey -> List Bits8 -> List Bits8 -> Bool
 94 | rsassa_pss_verify' mask_gen salt_len pk message signature = isJust $ do
 95 |   let modBits = modulus_bits pk
 96 |   let s = os2ip signature
 97 |   m <- rsavp1 pk s
 98 |   let emLen = divCeilNZ (pred modBits) 8 SIsNonZero
 99 |   em <- i2osp emLen m >>= fromList
100 |   emsa_pss_verify {algo} mask_gen salt_len message em (pred modBits)
101 |
102 | export
103 | rsassa_pss_verify : {algo : _} -> Hash algo => RSAPublicKey -> List Bits8 -> List Bits8 -> Bool
104 | rsassa_pss_verify = rsassa_pss_verify' {algo} (mgf1 {algo}) (digest_nbyte {algo})
105 |
106 | export
107 | emsa_pkcs1_v15_encode : {algo : _} -> RegisteredHash algo => List Bits8 -> Nat -> Maybe (List Bits8)
108 | emsa_pkcs1_v15_encode message emLen = do
109 |   let h = hashWithHeader {algo} message
110 |   let paddingLen = (emLen `minus` der_digest_n_byte {algo}) `minus` 3
111 |   guard (paddingLen >= 8)
112 |   let padding = replicate paddingLen 0xff
113 |   pure $ [ 0x00, 0x01 ] <+> padding <+> [ 0x00 ] <+> toList h
114 |
115 | export
116 | rsassa_pkcs1_v15_verify : {algo : _} -> RegisteredHash algo => RSAPublicKey -> List Bits8 -> List Bits8 -> Bool
117 | rsassa_pkcs1_v15_verify pk message signature = isJust $ do
118 |   let k = divCeilNZ (modulus_bits pk) 8 SIsNonZero
119 |   guard (k == length signature)
120 |
121 |   let s = os2ip signature
122 |   m <- rsavp1 pk s
123 |   em <- i2osp k m
124 |
125 |   em' <- emsa_pkcs1_v15_encode {algo} message k
126 |   guard (em `s_eq'` em')
127 |