7 | import Data.Cryptography.Hash
8 | import Data.Cryptography.Hash.SHA
9 | import Data.Cryptography.SCRAM
11 | import Control.Monad.Reader.Interface
12 | import Control.Monad.Reader.Reader
13 | import Control.Monad.State.State
14 | import Control.Monad.Trans
15 | import Data.Buffer as Buffer
16 | import Data.List as List
19 | import Data.Vect as V
20 | import Data.Buffer.Indexed as BufIndexed
||| Encode a `String` as UTF-8 and flatten the buffer produced by
||| `UTF8.utf8Encode` into a plain list of bytes. No terminator is appended;
||| see `writeString` for the NUL-terminated variant.
utf8Encode : String -> List Bits8
utf8Encode s = let encoded = UTF8.utf8Encode s in BufIndexed.toList encoded
||| Decode a list of bytes as UTF-8 text. The bytes are first packed into a
||| buffer with `BufIndexed.bufferL`, which is the input `UTF8.utf8Decode`
||| works on.
utf8Decode : List Bits8 -> String
utf8Decode bytes = let packed = BufIndexed.bufferL bytes in UTF8.utf8Decode packed
||| Debug logging usable in any monad: hands `msg` to `trace` and yields a
||| no-op `pure ()` action.
||| NOTE(review): `trace` prints when the returned action value is
||| evaluated, so output ordering relative to other effects is not
||| guaranteed. Treat this as a debugging aid only.
log : Monad m => String -> m ()
log msg = trace msg (pure ())
33 | chr8 : Bits8 -> Char
36 | ord8 : Char -> Bits8
||| Serialise a string as UTF-8 bytes followed by a single 0 byte. This is the
||| NUL-terminated form this module uses when sending strings (user name,
||| SASL mechanism name, query text).
||| NOTE(review): a 0 byte inside `str` is not rejected. The receiver would
||| presumably see the string as ending early.
writeString : String -> List Bits8
writeString str = utf8Encode str ++ terminator
  where
    terminator : List Bits8
    terminator = [0]
43 | record PGAuthentication where
44 | constructor MkPGAuthentication
45 | {auto prf: LTE n 64}
47 | password: Vect n Bits8
48 | zeroPadPassword: Vect n Bits8 -> Vect 64 Bits8
50 | mkStartup : Monad m => ReaderT PGAuthentication m (List Bits8)
59 | writeString st.userName
66 | = UnexpectedMessage String
67 | | SCRAM2Failure Phase2Err
68 | | SCRAM3Failure Phase3Err
72 | show (UnexpectedMessage msg) = "PgErr.UnexpectedMessage(" ++ msg ++ ")"
73 | show (SCRAM2Failure scram) = "PgErr.SCRAM2Failure(" ++ show scram ++ ")"
74 | show (SCRAM3Failure scram) = "PgErr.SCRAM3Failure(" ++ show scram ++ ")"
76 | logAscii : List Bits8 -> DBufferedConn PgErr Unit Unit altOut altIn
78 | log (utf8Decode lol)
81 | readBits8List : Int64 -> DBufferedConn PgErr Unit (List Bits8) altOut altIn
82 | readBits8List budget = do
84 | then result $
UnexpectedMessage "List too long"
90 | rest <- readBits8List (budget - 1)
||| Read a string from the connection with `readBits8List` and decode it as
||| UTF-8. The budget of 100000 bytes bounds how long the string may be;
||| `readBits8List` reports "List too long" when the budget is exhausted.
||| NOTE(review): presumably reads up to (and consumes) a NUL terminator.
||| Confirm against `readBits8List`.
readString : DBufferedConn PgErr Unit String altOut altIn
readString = map utf8Decode (readBits8List 100000)
||| Read exactly four bytes from the connection and interpret them, via
||| `bits32FromBigEndian`, as a big-endian unsigned 32-bit integer.
readBits32 : DBufferedConn PgErr Unit Bits32 altOut altIn
readBits32 = map bits32FromBigEndian (read 4)
99 | readBits16 : DBufferedConn PgErr Unit Bits16 altOut altIn
102 | pure $
cast $
bits32FromBigEndian $
[0, 0] `V.(++)` b16
105 | constructor MkField
112 | formatCode : Bits16
116 | "Field { name = " ++ show f.name
117 | ++ ", tblOid = " ++ show f.tblOid
118 | ++ ", tblAid = " ++ show f.tblAid
119 | ++ ", oid = " ++ show f.oid
120 | ++ ", typlen = " ++ show f.typlen
121 | ++ ", atttypmod = " ++ show f.atttypmod
122 | ++ ", formatCode = " ++ show f.formatCode
125 | readField : DBufferedConn PgErr Unit Field altOut altIn
136 | readDataRowElement : DBufferedConn PgErr Unit DataRowElement altOut altIn
137 | readDataRowElement = do
139 | content <- read (cast len)
140 | pure $
MkDataRowElement $
toList content
143 | readElementCountAndElements : DBufferedConn PgErr Unit element altOut altIn -> DBufferedConn PgErr Unit (List element) altOut altIn
144 | readElementCountAndElements readField = do
147 | then result $
UnexpectedMessage "Bad length for message with elements"
149 | numFields <- readBits16
150 | readFields numFields
152 | readFields : Bits16 -> DBufferedConn PgErr Unit (List element) altOut altIn
153 | readFields 0 = pure []
156 | rest <- readFields (n - 1)
157 | pure $
field :: rest
159 | natToInt64 : Nat -> Int64
||| Widen an unsigned 32-bit value to a signed 64-bit one. Every `Bits32`
||| value is within the range of `Int64`, so no information is lost.
bits32ToInt64 : Bits32 -> Int64
bits32ToInt64 x = cast x
165 | readUntilEmpty : {typ: Type} -> Int64 -> (List Bits8 -> typ) -> DBufferedConn PgErr Unit (List typ) altOut altIn
166 | readUntilEmpty 0 decoder = pure []
167 | readUntilEmpty n decoder = do
168 | list <- readBits8List n
169 | let str = decoder list
171 | remaining = n - natToInt64 (List.length list) - 1
172 | case remaining `compare` 0 of
173 | LT => result $
UnexpectedMessage "Used up all the bytes in the message without finishing. Error"
175 | GT => (str ::) <$> readUntilEmpty remaining decoder
||| Render one field of an error ('E') message as "<type>: <text>". The first
||| byte is the field-type tag and the remaining bytes are the field's text,
||| each decoded as UTF-8. An empty field yields a fixed placeholder string.
decodeOneField : List Bits8 -> String
decodeOneField bytes =
  case bytes of
    [] => "Empty field type and msg"
    tag :: body => utf8Decode [tag] ++ ": " ++ utf8Decode body
182 | liftR : Monad m => m a -> ReaderT PGAuthentication m a
185 | receiveRWithCode : Bits32 -> String -> (List Bits8 -> ReaderT PGAuthentication (\a => DBufferedConn PgErr Unit a altOut altIn) Unit) -> ReaderT PGAuthentication (\a => DBufferedConn PgErr Unit a altOut altIn) Unit
186 | receiveRWithCode expectedCode sentMessageName continueWithoutCode = do
187 | typeByte :: Nil <- liftR $
read 1
188 | let typeChar = chr8 typeByte
190 | then liftR . result $
UnexpectedMessage $
"Expected R (auth) message, but got unknown response to " ++ sentMessageName
192 | len <- liftR readBits32
194 | then liftR . result $
UnexpectedMessage $
"Expected code field for R message, but there are not enough bytes. Expected this be at least 8: " ++ show len
196 | code <- liftR readBits32
197 | if code /= expectedCode
198 | then liftR . result $
UnexpectedMessage $
"Unknown response received for " ++ sentMessageName ++ ", the response has code " ++ show expectedCode
201 | msg <- liftR $
toList <$> read (the (Int64 -> Nat) cast $
bits32ToInt64 len - 4 - 4)
202 | continueWithoutCode msg
204 | handleMessage: Char -> ReaderT PGAuthentication (\a => DBufferedConn PgErr Unit a altOut altIn) Unit
205 | handleMessage 'E' = MkReaderT $
\pw => do
206 | errLen <- readBits32
207 | shown <- readUntilEmpty (the (Bits32 -> Int64) cast errLen - 4) decodeOneField
208 | result $
UnexpectedMessage $
joinBy ", " shown
211 | handleMessage 'R' = do
212 | len <- liftR readBits32
215 | method <- liftR readBits32
218 | 2 => liftR . result $
UnexpectedMessage "Received AuthenticationKerberosV5. Don't know how to handle."
222 | password: List Bits8
223 | password = toList st.password ++ [0]
224 | len = length password
225 | liftR . send $
[ord8 'p'] ++ toList (bits32ToBigEndian (4 + cast len)) ++ password
227 | typeByte::Nil <- liftR $
read 1
228 | let typeChar = chr8 typeByte
230 | then liftR . result $
UnexpectedMessage "Expected R response to auth"
231 | else handleMessage typeChar
232 | 6 => liftR . result $
UnexpectedMessage "Received AuthenticationSCMCredential. Don't know how to handle."
233 | 7 => liftR . result $
UnexpectedMessage "Received AuthenticationGSS. Don't know how to handle."
234 | 9 => liftR . result $
UnexpectedMessage "Received AuthenticationSSPI. Don't know how to handle."
235 | unknown => liftR . result $
UnexpectedMessage $
"Authentication message has type " ++ show unknown ++ ", which we don't know how to handle"
236 | 12 => liftR . result $
UnexpectedMessage "Received AuthenticationMD5Password. Don't know how to handle."
239 | then liftR . result $
UnexpectedMessage "Expected SASLContinue, SASLFinal, SASL, GSSContinue. But got a shorter message."
241 | 10 <- liftR readBits32
242 | | 8 => liftR . result $
UnexpectedMessage "Received AuthenticationGSSContinue. Don't know how to handle."
243 | | 11 => liftR . result $
UnexpectedMessage "Received AuthenticationSASLContinue too early. Don't know how to handle."
244 | | 12 => liftR . result $
UnexpectedMessage "Received AuthenticationSASLFinal too early. Don't know how to handle."
245 | | _ => liftR . result $
UnexpectedMessage "Received unknown Authentication message."
246 | algos <- liftR $
readUntilEmpty
247 | ( bits32ToInt64 len
253 | if not ("SCRAM-SHA-256" `elem` algos)
256 | UnexpectedMessage "Could not find the only supported SCRAM algo SCRAM-SHA-256 in list of supported algos from server"
258 | 0::Nil <- liftR $
read 1
259 | | invalidTerminator => liftR . result $
UnexpectedMessage "Expected terminator"
262 | (scramInitial, st1) = genFirstMessageFromClient st.userName (utf8Encode "random nonce")
263 | selectedAlgoStr = writeString "SCRAM-SHA-256"
264 | len = length selectedAlgoStr + 4 + length scramInitial
266 | liftR $
send $
[ord8 'p']
267 | ++ toList (bits32ToBigEndian (4 + cast len))
269 | ++ toList (bits32ToBigEndian (cast $
length scramInitial))
271 | receiveRWithCode 11 "SASLInitialResponse" $
\msg => do
274 | pw1 : Vect 64 Bits8
275 | pw1 = st.zeroPadPassword st.password
276 | case recvFirstMessageFromServer {hash=sha256} (++ replicate 32 0) pw1 st1 msg of
277 | Left err => liftR . result $
SCRAM2Failure err
278 | Right (scramClientFinalMsg, phase2State) => do
279 | let len = length scramClientFinalMsg
280 | liftR . send $
[ord8 'p']
281 | ++ toList (bits32ToBigEndian (4 + cast len))
282 | ++ scramClientFinalMsg
283 | receiveRWithCode 12 "SASLResponse" $
\msgFinal => do
284 | case recvSecondMessageFromServer phase2State msgFinal of
285 | Just err => liftR . result $
SCRAM3Failure err
287 | typeByte::Nil <- liftR $
read 1
288 | let typeChar = chr8 typeByte
289 | handleMessage typeChar
292 | handleMessage 'S' = MkReaderT $
\pw => do
294 | let lenInt = cast (len - 4)
296 | content <- read lenInt
297 | if length (filter (== 0) $
toList content) /= 2
298 | then result $
UnexpectedMessage "Wrong amount of zero bytes"
300 | let (keyName, rest) = break (== 0) $
toList content
301 | log "ParameterStatus"
302 | log (utf8Decode keyName)
306 | handleMessage 'K' = MkReaderT $
\pw => do
309 | then result $
UnexpectedMessage $
"Wrong length of BackendKeyData: " ++ show len
311 | processIDSecretKey <- read 8
315 | handleMessage 'Z' = MkReaderT $
\pw => do
318 | then result $
UnexpectedMessage $
"Wrong length of ReadyForQuery: " ++ show len
320 | transactionStatus::Nil <- read 1
324 | handleMessage 'T' = MkReaderT $
\pw => do
325 | fields <- readElementCountAndElements readField
329 | handleMessage 'D' = MkReaderT $
\pw => do
330 | elements <- readElementCountAndElements readDataRowElement
334 | handleMessage 'C' = MkReaderT $
\pw => do
336 | let lenInt = cast (len - 4)
337 | content <- read lenInt
340 | handleMessage typeByte = MkReaderT $
\pw =>
341 | result $
UnexpectedMessage $
"Unknown message type byte: " <+> show typeByte
343 | readUntilNotParameterStatus : ReaderT PGAuthentication (\a => DBufferedConn PgErr Unit a altOut altIn) Char
344 | readUntilNotParameterStatus = do
345 | typeByte::Nil <- liftR $
read 1
346 | let typeChar = chr8 typeByte
349 | handleMessage typeChar
350 | readUntilNotParameterStatus
353 | readDataRowsUntilFinish : PgRows -> ReaderT PGAuthentication (\a => DBufferedConn PgErr Unit a PgInput b) PgRows
354 | readDataRowsUntilFinish soFar = do
355 | typeByte::Nil <- liftR $
read 1
356 | case chr8 typeByte of
361 | thisRow <- liftR $
readElementCountAndElements readDataRowElement
362 | let new = soFar <+> [thisRow]
363 | readDataRowsUntilFinish new
365 | liftR . result $
UnexpectedMessage $
"Unexpected message with type " <+> show unknown
367 | pgClient: ReaderT PGAuthentication (\a => DBufferedConn PgErr Unit a PgInput PgRows) PgErr
369 | startup <- mkStartup
370 | liftR $
send $
toList $
bits32ToBigEndian (cast $
length startup + 4)
371 | liftR $
send startup
372 | typeByte::Nil <- liftR $
read 1
373 | handleMessage (chr8 typeByte)
375 | typeChar@'K' <- readUntilNotParameterStatus
376 | | unknown => liftR . result $
UnexpectedMessage $
"Expected BackendKeyData, got: " ++ show unknown
377 | handleMessage typeChar
379 | typeByte::Nil <- liftR $
read 1
380 | log $
"After BackendKeyData, got " ++ show (chr8 typeByte)
381 | handleMessage (chr8 typeByte)
384 | f : PgRows -> ReaderT PGAuthentication (\a => DBufferedConn PgErr Unit a PgInput PgRows) PgErr
386 | pgInput <- liftR $
lift (yieldGet2 elements)
387 | let sqlString = pgInputToString pgInput
388 | log $
"pgClient received sql: " ++ sqlString
389 | let sqlStringForWireProtocol = writeString sqlString
391 | let len = length sqlStringForWireProtocol + 4
392 | liftR $
send $
[ord8 'Q'] ++ toList (bits32ToBigEndian (cast len)) ++ sqlStringForWireProtocol
394 | typeByte::Nil <- liftR $
read 1
396 | case chr8 typeByte of
398 | handleMessage (chr8 typeByte)
399 | readDataRowsUntilFinish []
401 | handleMessage (chr8 typeByte)
403 | _ => liftR . result $
UnexpectedMessage $
"Not a RowDescription or CommandComplete: " <+> utf8Decode [typeByte]
405 | typeByte::Nil <- liftR $
read 1
406 | if chr8 typeByte /= 'Z'
407 | then pure $
UnexpectedMessage $
"Expected ready for query, got: " <+> utf8Decode [typeByte]
409 | handleMessage (chr8 typeByte)
||| Build the client iterator. `pgClient` is run with the given authentication
||| settings, and the resulting buffered-connection computation is turned
||| into a `DIterator` starting from state `MkBufConSt [] ()`.
||| NOTE(review): `[]` is presumably the initial (empty) read buffer. Confirm
||| against `MkBufConSt`.
initialPgIter : PGAuthentication -> DIterator (List Bits8) (List Bits8) PgInput PgRows PgErr
initialPgIter auth =
  iteratorFromBufConn (MkBufConSt [] ()) $ runReaderT auth pgClient