0 | module Data.Cryptography.SCRAM
  1 |
  2 | import Data.Cryptography.Hash
  3 | import Data.Cryptography.HMAC
  4 | import UTF8
  5 | import Data.String.Search
  6 | import Data.List
  7 | import Data.Vect
  8 | import Data.Bits
  9 | import Data.String.Base64
 10 | import Data.SortedMap as SortedMap
 11 | import Control.Monad.State
 12 | import Data.Buffer -- for stringByteLength
 13 | import Data.Buffer.Indexed as BufIndexed
 14 |
 15 | export
 16 | record Phase1State where
 17 |   constructor MkPhase1State
 18 |   clientNonceBase64: List Bits8
 19 |   clientFirstMessageBare: List Bits8
 20 |
 21 | ||| @ userName: The user name to authenticate as.
 22 | ||| @ nonce: The nonce should be non-empty and random for SCRAM as intended.
 23 | ||| The first element of the returned tuple is the first client message to send to the server.
 24 | ||| The second element of the returned tuple is that opaque state that must be
 25 | ||| used when receiving the first message from the server, see 'recvFirstMessageFromServer'.
 26 | export
 27 | genFirstMessageFromClient: String -> List Bits8 -> (List Bits8, Phase1State)
 28 | genFirstMessageFromClient userName nonce =
 29 |   let bare = BufIndexed.toList (utf8Encode "n=") ++ BufIndexed.toList (utf8Encode userName) ++ BufIndexed.toList (utf8Encode ",r=") ++ base64Encode nonce
 30 |   in
 31 |   (BufIndexed.toList (utf8Encode "n,,") ++ bare, MkPhase1State{clientFirstMessageBare=bare, clientNonceBase64=base64Encode nonce})
 32 |
 33 | public export
 34 | data Phase2Err
 35 |   = MkRNotAtStart
 36 |   | MkRMissing
 37 |   | MkREmpty
 38 |   | MkRNotPrefixedByClientNonce
 39 |   | MkRHasEmptyServerNonce
 40 |   | MkIterationsMissing
 41 |   | MkIterationCountUnsupported
 42 |   | MkSaltMissing
 43 |   | MkSaltEmpty
 44 |   | MkSaltInvalidBase64
 45 |
 46 | export
 47 | Show Phase2Err where
 48 |   show MkRNotAtStart               = "RNotAtStart"
 49 |   show MkRMissing                  = "RMissing"
 50 |   show MkREmpty                    = "REmpty"
 51 |   show MkRNotPrefixedByClientNonce = "RNotPrefixedByClientNonce"
 52 |   show MkRHasEmptyServerNonce      = "RHasEmptyServerNonce"
 53 |   show MkIterationsMissing         = "IterationsMissing"
 54 |   show MkIterationCountUnsupported = "IterationCountUnsupported"
 55 |   show MkSaltMissing               = "SaltMissing"
 56 |   show MkSaltEmpty                 = "SaltEmpty"
 57 |   show MkSaltInvalidBase64         = "SaltInvalidBase64"
 58 |
 59 | record Phase2State (hash: HashAlgorithm) where
 60 |   constructor MkPhase2State
 61 |   expectedServerSignature: Vect hash.outputSize Bits8
 62 |
 63 | hmac : {hash: HashAlgorithm}
 64 |     -> {blockSize: Nat}
 65 |     -> Vect blockSize Bits8
 66 |     -> List Bits8
 67 |     -> Vect hash.outputSize Bits8
 68 | hmac key str = finalizeHmac $ appendHmac str $ mkHmacCtx {hash} {keyLength=blockSize} key
 69 |
 70 | bits32ToBigEndian : Bits32 -> Vect 4 Bits8
 71 | bits32ToBigEndian streamIdent =
 72 |   [ cast (streamIdent `shiftR` 24)
 73 |   , cast (streamIdent `shiftR` 16)
 74 |   , cast (streamIdent `shiftR` 8)
 75 |   , cast streamIdent
 76 |   ]
 77 |
 78 | hi : {hash: HashAlgorithm}
 79 |   -> {blockSize: Nat}
 80 |   -> Vect blockSize Bits8
 81 |   -> List Bits8
 82 |   -> Int
 83 |   -> Vect hash.outputSize Bits8
 84 | hi str salt n =
 85 |   let
 86 |     root: Vect hash.outputSize Bits8
 87 |     root = hmac {hash} {blockSize} str (salt ++ toList (bits32ToBigEndian 1))
 88 |     -- hi/ui is from the spec, but with caching added in the State
 89 |     ui : Int -> State (SortedMap Int (Vect hash.outputSize Bits8)) (Vect hash.outputSize Bits8)
 90 |     ui 1 = pure root
 91 |     ui n = do
 92 |       st <- get
 93 |       case lookup n st of
 94 |         Nothing => do
 95 |           val <- hmac {hash} {blockSize} str . toList <$> ui (n - 1)
 96 |           modify (insert n val)
 97 |           pure val
 98 |         Just found =>
 99 |           pure found
100 |     uis : List (Vect hash.outputSize Bits8)
101 |     uis = evalState SortedMap.empty $ traverse ui [1..n]
102 |   in
103 |   foldr (zipWith xor) (replicate hash.outputSize 0) uis
104 |
105 | mkProof : {hash: HashAlgorithm}
106 |        -> {blockSize: Nat}
107 |        -> (Vect hash.outputSize Bits8 -> Vect blockSize Bits8)
108 |        -> Vect blockSize Bits8
109 |        -> List Bits8
110 |        -> List Bits8
111 | mkProof zeroPad saltedPassword authMessage =
112 |   let
113 |     clientKey = hmac {hash} saltedPassword (BufIndexed.toList (utf8Encode "Client Key"))
114 |     storedKey: Vect hash.outputSize Bits8
115 |     storedKey = hash.finalizeHash $ hash.appendHash (toList clientKey) hash.mkHashCtx
116 |     clientSignature = hmac {hash} (zeroPad storedKey) authMessage
117 |     clientProof: Vect hash.outputSize Bits8
118 |     clientProof = zipWith xor clientKey clientSignature
119 |   in toList clientProof
120 |
121 | ||| @ zeroPad: Zero pad hash output size to block size
122 | ||| @ normalizedPassword: Normalized password according to https://datatracker.ietf.org/doc/html/rfc5802#section-2.2
123 | |||                       and zero padded according to https://datatracker.ietf.org/doc/html/rfc2104#section-2 .
124 | |||                       Note that (quote):
125 | |||                          Applications that use keys longer
126 | |||                          than B bytes will first hash the key using H and then use the
127 | |||                          resultant L byte string as the actual key to HMAC.
128 | |||                       This extra hashing is not done as part of this function, and the caller will need to do it
129 | |||                       conditionally if strict compatibility is needed and long passwords are used.
130 | |||                       Also note that non-ASCII passwords must be rejected if normalization is not supported.
131 | export
132 | recvFirstMessageFromServer : {hash: HashAlgorithm}
133 |                           -> {blockSize: Nat}
134 |                           -> (Vect hash.outputSize Bits8 -> Vect blockSize Bits8)
135 |                           -> Vect blockSize Bits8
136 |                           -> Phase1State
137 |                           -> List Bits8
138 |                           -> Either Phase2Err (List Bits8, Phase2State hash)
139 | recvFirstMessageFromServer zeroPad pw1 st1 msg = do
140 |   Just ([], rBareAndAfter) <- pure $ splitBits8 (BufIndexed.toList (utf8Encode "r=")) msg
141 |   | Nothing => Left MkRMissing
142 |   | _ => Left MkRNotAtStart
143 |   Just (fullNonceBase64, sBareAndAfter) <- pure $ splitBits8 (BufIndexed.toList (utf8Encode ",s=")) rBareAndAfter
144 |   | Nothing => Left MkSaltMissing
145 |   S _ <- pure $ length fullNonceBase64
146 |   | Z => Left MkREmpty
147 |   True <- pure $ st1.clientNonceBase64 `isPrefixOf` fullNonceBase64
148 |   | False => Left MkRNotPrefixedByClientNonce
149 |   False <- pure $ length st1.clientNonceBase64 == length fullNonceBase64
150 |   | True => Left MkRHasEmptyServerNonce
151 |   Just (saltBase64, iterationsDecimal) <- pure $ splitBits8 (BufIndexed.toList (utf8Encode ",i=")) sBareAndAfter
152 |   | Nothing => Left MkIterationsMissing
153 |   case base64DecodeBits8 saltBase64 of
154 |     Just [] => Left MkSaltEmpty
155 |     Nothing => Left MkSaltInvalidBase64
156 |     Just nonEmptySalt =>
157 |       case utf8Decode (BufIndexed.bufferL iterationsDecimal) of
158 |         "4096" => do
159 |           let
160 |             clientFinalMessageWithoutProof =
161 |               BufIndexed.toList (utf8Encode "c=biws,r=") ++ fullNonceBase64
162 |             authMessage =
163 |               st1.clientFirstMessageBare
164 |                 ++ BufIndexed.toList (utf8Encode ",") ++ msg
165 |                 ++ BufIndexed.toList (utf8Encode ",") ++ clientFinalMessageWithoutProof
166 |             saltedPassword : Vect blockSize Bits8
167 |             saltedPassword = zeroPad (hi pw1 nonEmptySalt 4096)
168 |             clientProofBase64 = base64Encode $ mkProof zeroPad saltedPassword authMessage
169 |             serverKey = hmac {hash} saltedPassword (BufIndexed.toList (utf8Encode "Server Key"))
170 |             expectedServerSignature = hmac {hash} (zeroPad serverKey) authMessage
171 |           Right
172 |             (clientFinalMessageWithoutProof ++ BufIndexed.toList (utf8Encode ",p=") ++ clientProofBase64
173 |             , MkPhase2State
174 |                 { expectedServerSignature =
175 |                     expectedServerSignature
176 |                 }
177 |             )
178 |         _ => Left MkIterationCountUnsupported
179 |
180 | public export
181 | data Phase3Err
182 |   = MkServerSignatureInvalidBase64
183 |   | MkServerSignatureMissing
184 |   | MkServerSignatureMismatch
185 |
186 | export
187 | Show Phase3Err where
188 |   show MkServerSignatureInvalidBase64 = "ServerSignatureInvalidBase64"
189 |   show MkServerSignatureMissing       = "ServerSignatureMissing"
190 |   show MkServerSignatureMismatch      = "ServerSignatureMismatch"
191 |
192 | ||| If Nothing is returned, authentication was successful.
193 | export
194 | recvSecondMessageFromServer: {hash: HashAlgorithm} -> Phase2State hash -> List Bits8 -> Maybe Phase3Err
195 | recvSecondMessageFromServer st2 receivedASCII =
196 |   if BufIndexed.toList (utf8Encode "v=") `isPrefixOf` receivedASCII
197 |      then
198 |        case base64DecodeBits8 (drop 2 receivedASCII) of
199 |          Nothing => Just MkServerSignatureInvalidBase64
200 |          Just decodedServerSignature =>
201 |             if decodedServerSignature == toList st2.expectedServerSignature
202 |                then Nothing
203 |                else Just MkServerSignatureMismatch
204 |      else Just MkServerSignatureMissing
205 |