0 | module PG
  1 |
  2 | import UTF8 as UTF8
  3 | import BufConn
  4 | import Cont
  5 | import PGTypes
  6 | import BitsUtil
  7 | import Data.Cryptography.Hash
  8 | import Data.Cryptography.Hash.SHA
  9 | import Data.Cryptography.SCRAM
 10 |
 11 | import Control.Monad.Reader.Interface
 12 | import Control.Monad.Reader.Reader
 13 | import Control.Monad.State.State
 14 | import Control.Monad.Trans
 15 | import Data.Buffer as Buffer
 16 | import Data.List as List
 17 | import Data.List1
 18 | import Data.String
 19 | import Data.Vect as V
 20 | import Data.Buffer.Indexed as BufIndexed
 21 | import Debug.Trace
 22 |
 23 | utf8Encode : String -> List Bits8
 24 | utf8Encode str = BufIndexed.toList (UTF8.utf8Encode str)
 25 |
 26 | utf8Decode : List Bits8 -> String
 27 | utf8Decode list = UTF8.utf8Decode (BufIndexed.bufferL list)
 28 |
 29 | log : Monad m => String -> m ()
 30 | log = flip trace $ pure ()
 31 |
 32 | -- Note that these are not safe to use on non-ASCII. E.g. ord8 'æ' == 230, but ord '文' == 25991, which then gets masked when casting
 33 | chr8 : Bits8 -> Char
 34 | chr8 = chr . cast
 35 |
 36 | ord8 : Char -> Bits8
 37 | ord8 = cast . ord
 38 |
 39 | writeString : String -> List Bits8
 40 | writeString str = utf8Encode str ++ [0]
 41 |
 42 | public export
 43 | record PGAuthentication where
 44 |   constructor MkPGAuthentication
 45 |   {auto prf: LTE n 64}
 46 |   userName: String
 47 |   password: Vect n Bits8
 48 |   zeroPadPassword: Vect n Bits8 -> Vect 64 Bits8
 49 |
 50 | mkStartup : Monad m => ReaderT PGAuthentication m (List Bits8)
 51 | mkStartup = do
 52 |   st <- ask
 53 |   pure $
 54 |       [ 0, 3, 0, 0
 55 |       ]
 56 |       ++
 57 |       writeString "user"
 58 |       ++
 59 |       writeString st.userName
 60 |       ++
 61 |       [ 0
 62 |       ]
 63 |
 64 | export
 65 | data PgErr
 66 |   = UnexpectedMessage String
 67 |   | SCRAM2Failure Phase2Err
 68 |   | SCRAM3Failure Phase3Err
 69 |
 70 | export
 71 | Show PgErr where
 72 |   show (UnexpectedMessage msg) = "PgErr.UnexpectedMessage(" ++ msg ++ ")"
 73 |   show (SCRAM2Failure scram) = "PgErr.SCRAM2Failure(" ++ show scram ++ ")"
 74 |   show (SCRAM3Failure scram) = "PgErr.SCRAM3Failure(" ++ show scram ++ ")"
 75 |
 76 | logAscii : List Bits8 -> DBufferedConn PgErr Unit Unit altOut altIn
 77 | logAscii lol = do
 78 |   log (utf8Decode lol)
 79 |   log ""
 80 |
 81 | readBits8List : Int64 -> DBufferedConn PgErr Unit (List Bits8) altOut altIn
 82 | readBits8List budget = do
 83 |   if budget < 0
 84 |      then result $ UnexpectedMessage "List too long"
 85 |      else do
 86 |        bite::Nil <- read 1
 87 |        if bite == 0
 88 |          then pure []
 89 |          else do
 90 |            rest <- readBits8List (budget - 1)
 91 |            pure $ bite :: rest
 92 |
 93 | readString : DBufferedConn PgErr Unit String altOut altIn
 94 | readString = utf8Decode <$> readBits8List 100000
 95 |
 96 | readBits32 : DBufferedConn PgErr Unit Bits32 altOut altIn
 97 | readBits32 = bits32FromBigEndian <$> read 4
 98 |
 99 | readBits16 : DBufferedConn PgErr Unit Bits16 altOut altIn
100 | readBits16 = do
101 |   b16 <- read 2
102 |   pure $ cast $ bits32FromBigEndian $ [0, 0] `V.(++)` b16
103 |
104 | record Field where
105 |   constructor MkField
106 |   name : String
107 |   tblOid : Bits32
108 |   tblAid : Bits16
109 |   oid : Bits32
110 |   typlen : Bits16
111 |   atttypmod : Bits32
112 |   formatCode : Bits16
113 |
114 | Show Field where
115 |   show f =
116 |     "Field { name = " ++ show f.name
117 |        ++ ", tblOid = " ++ show f.tblOid
118 |        ++ ", tblAid = " ++ show f.tblAid
119 |        ++ ", oid = " ++ show f.oid
120 |        ++ ", typlen = " ++ show f.typlen
121 |        ++ ", atttypmod = " ++ show f.atttypmod
122 |        ++ ", formatCode = " ++ show f.formatCode
123 |        ++ "}"
124 |
125 | readField : DBufferedConn PgErr Unit Field altOut altIn
126 | readField = do
127 |   MkField
128 |     <$> readString
129 |     <*> readBits32
130 |     <*> readBits16
131 |     <*> readBits32
132 |     <*> readBits16
133 |     <*> readBits32
134 |     <*> readBits16
135 |
136 | readDataRowElement : DBufferedConn PgErr Unit DataRowElement altOut altIn
137 | readDataRowElement = do
138 |   len <- readBits32
139 |   content <- read (cast len)
140 |   pure $ MkDataRowElement $ toList content
141 |
142 |
143 | readElementCountAndElements : DBufferedConn PgErr Unit element altOut altIn -> DBufferedConn PgErr Unit (List element) altOut altIn
144 | readElementCountAndElements readField = do
145 |   len <- readBits32
146 |   if len < 6
147 |     then result $ UnexpectedMessage "Bad length for message with elements"
148 |     else do
149 |       numFields <- readBits16
150 |       readFields numFields
151 |   where
152 |     readFields : Bits16 -> DBufferedConn PgErr Unit (List element) altOut altIn
153 |     readFields 0 = pure []
154 |     readFields n = do
155 |       field <- readField
156 |       rest <- readFields (n - 1)
157 |       pure $ field :: rest
158 |
159 | natToInt64 : Nat -> Int64
160 | natToInt64 = cast
161 |
162 | bits32ToInt64 : Bits32 -> Int64
163 | bits32ToInt64 = cast
164 |
165 | readUntilEmpty : {typ: Type} -> Int64 -> (List Bits8 -> typ) -> DBufferedConn PgErr Unit (List typ) altOut altIn
166 | readUntilEmpty 0 decoder = pure []
167 | readUntilEmpty n decoder = do
168 |   list <- readBits8List n
169 |   let str = decoder list
170 |       remaining : Int64
171 |       remaining = n - natToInt64 (List.length list) - 1 -- 1 is the list terminator
172 |   case remaining `compare` 0 of
173 |     LT => result $ UnexpectedMessage "Used up all the bytes in the message without finishing. Error"
174 |     EQ => pure [str]
175 |     GT => (str ::) <$> readUntilEmpty remaining decoder
176 |
177 | decodeOneField: List Bits8 -> String
178 | decodeOneField [] = "Empty field type and msg"
179 | decodeOneField (fieldType :: msg) =
180 |   utf8Decode [fieldType] ++ ": " ++ utf8Decode msg
181 |
182 | liftR : Monad m => m a -> ReaderT PGAuthentication m a
183 | liftR = lift
184 |
185 | receiveRWithCode : Bits32 -> String -> (List Bits8 -> ReaderT PGAuthentication (\a => DBufferedConn PgErr Unit a altOut altIn) Unit) -> ReaderT PGAuthentication (\a => DBufferedConn PgErr Unit a altOut altIn) Unit
186 | receiveRWithCode expectedCode sentMessageName continueWithoutCode = do
187 |   typeByte :: Nil <- liftR $ read 1
188 |   let typeChar = chr8 typeByte
189 |   if typeChar /= 'R'
190 |     then liftR . result $ UnexpectedMessage $ "Expected R (auth) message, but got unknown response to " ++ sentMessageName
191 |     else do
192 |       len <- liftR readBits32
193 |       if len < 8
194 |          then liftR . result $ UnexpectedMessage $ "Expected code field for R message, but there are not enough bytes. Expected this be at least 8: " ++ show len
195 |          else do
196 |            code <- liftR readBits32
197 |            if code /= expectedCode
198 |               then liftR . result $ UnexpectedMessage $ "Unknown response received for " ++ sentMessageName ++ ", the response has code " ++ show expectedCode
199 |               else do
200 |                 -- The first 4 is the length itself, the second is the code.
201 |                 msg <- liftR $ toList <$> read (the (Int64 -> Nat) cast $ bits32ToInt64 len - 4 - 4)
202 |                 continueWithoutCode msg
203 |
204 | handleMessage: Char -> ReaderT PGAuthentication (\a => DBufferedConn PgErr Unit a altOut altIn) Unit
205 | handleMessage 'E' = MkReaderT $ \pw => do
206 |   errLen <- readBits32
207 |   shown <- readUntilEmpty (the (Bits32 -> Int64) cast errLen - 4) decodeOneField
208 |   result $ UnexpectedMessage $ joinBy ", " shown
209 |
210 | -- Authentication message (they all have R)
211 | handleMessage 'R' = do
212 |   len <- liftR readBits32
213 |   case len of
214 |     8 => do
215 |       method <- liftR readBits32
216 |       case method of
217 |         0 => pure ()
218 |         2 => liftR . result $ UnexpectedMessage "Received AuthenticationKerberosV5. Don't know how to handle."
219 |         3 => do
220 |           st <- ask
221 |           let
222 |             password: List Bits8
223 |             password = toList st.password ++ [0]
224 |             len = length password
225 |           liftR . send $ [ord8 'p'] ++ toList (bits32ToBigEndian (4 + cast len)) ++ password
226 |
227 |           typeByte::Nil <- liftR $ read 1
228 |           let typeChar = chr8 typeByte
229 |           if typeChar /= 'R'
230 |              then liftR . result $ UnexpectedMessage "Expected R response to auth"
231 |              else handleMessage typeChar
232 |         6 => liftR . result $ UnexpectedMessage "Received AuthenticationSCMCredential. Don't know how to handle."
233 |         7 => liftR . result $ UnexpectedMessage "Received AuthenticationGSS. Don't know how to handle."
234 |         9 => liftR . result $ UnexpectedMessage "Received AuthenticationSSPI. Don't know how to handle."
235 |         unknown => liftR . result $ UnexpectedMessage $ "Authentication message has type " ++ show unknown ++ ", which we don't know how to handle"
236 |     12 => liftR . result $ UnexpectedMessage "Received AuthenticationMD5Password. Don't know how to handle."
237 |     other =>
238 |       if other < 8
239 |          then liftR . result $ UnexpectedMessage "Expected SASLContinue, SASLFinal, SASL, GSSContinue. But got a shorter message."
240 |          else do
241 |            10 <- liftR readBits32 -- 10 is AuthenticationSASL
242 |            | 8 => liftR . result $ UnexpectedMessage "Received AuthenticationGSSContinue. Don't know how to handle."
243 |            | 11 => liftR . result $ UnexpectedMessage "Received AuthenticationSASLContinue too early. Don't know how to handle."
244 |            | 12 => liftR . result $ UnexpectedMessage "Received AuthenticationSASLFinal too early. Don't know how to handle."
245 |            | _ => liftR . result $ UnexpectedMessage "Received unknown Authentication message."
246 |            algos <- liftR $ readUntilEmpty
247 |              ( bits32ToInt64 len
248 |              - 4 -- the length itself
249 |              - 4 -- the code
250 |              - 1 -- the terminator at the very end
251 |              )
252 |              utf8Decode
253 |            if not ("SCRAM-SHA-256" `elem` algos)
254 |               then
255 |                 liftR . result $
256 |                   UnexpectedMessage "Could not find the only supported SCRAM algo SCRAM-SHA-256 in list of supported algos from server"
257 |               else do
258 |                 0::Nil <- liftR $ read 1
259 |                 | invalidTerminator => liftR . result $ UnexpectedMessage "Expected terminator"
260 |                 st <- ask
261 |                 let
262 |                   (scramInitial, st1) = genFirstMessageFromClient st.userName (utf8Encode "random nonce")
263 |                   selectedAlgoStr = writeString "SCRAM-SHA-256"
264 |                   len = length selectedAlgoStr + 4 + length scramInitial
265 |                 -- SASLInitialResponse
266 |                 liftR $ send $ [ord8 'p']
267 |                        ++ toList (bits32ToBigEndian (4 + cast len))
268 |                        ++ selectedAlgoStr
269 |                        ++ toList (bits32ToBigEndian (cast $ length scramInitial))
270 |                        ++ scramInitial
271 |                 receiveRWithCode 11 "SASLInitialResponse" $ \msg => do -- 11 is SASLContinue
272 |                   st <- ask
273 |                   let
274 |                     pw1 : Vect 64 Bits8
275 |                     pw1 = st.zeroPadPassword st.password
276 |                   case recvFirstMessageFromServer {hash=sha256} (++ replicate 32 0) pw1 st1 msg of
277 |                     Left err => liftR . result $ SCRAM2Failure err
278 |                     Right (scramClientFinalMsg, phase2State) => do
279 |                       let len = length scramClientFinalMsg
280 |                       liftR . send $ [ord8 'p']
281 |                                      ++ toList (bits32ToBigEndian (4 + cast len))
282 |                                      ++ scramClientFinalMsg
283 |                       receiveRWithCode 12 "SASLResponse" $ \msgFinal => do -- 12 is SASLFinal
284 |                         case recvSecondMessageFromServer phase2State msgFinal of
285 |                           Just err => liftR . result $ SCRAM3Failure err
286 |                           Nothing => do
287 |                             typeByte::Nil <- liftR $ read 1
288 |                             let typeChar = chr8 typeByte
289 |                             handleMessage typeChar
290 |
291 | -- ParameterStatus
292 | handleMessage 'S' =  MkReaderT $ \pw => do
293 |   len <- readBits32
294 |   let lenInt = cast (len - 4)
295 |   --log $ "ParameterStatus with length " ++ show lenInt
296 |   content <- read lenInt
297 |   if length (filter (== 0) $ toList content) /= 2
298 |     then result $ UnexpectedMessage "Wrong amount of zero bytes"
299 |     else do
300 |       let (keyName, rest) = break (== 0) $ toList content
301 |       log "ParameterStatus"
302 |       log (utf8Decode keyName)
303 |       pure ()
304 |
305 | -- BackendKeyData
306 | handleMessage 'K' =  MkReaderT $ \pw => do
307 |   len <- readBits32
308 |   if len /= 12
309 |     then result $ UnexpectedMessage $ "Wrong length of BackendKeyData: " ++ show len
310 |     else do
311 |       processIDSecretKey <- read 8
312 |       pure ()
313 |
314 | -- ReadyForQuery
315 | handleMessage 'Z' =  MkReaderT $ \pw => do
316 |   len <- readBits32
317 |   if len /= 5
318 |     then result $ UnexpectedMessage $ "Wrong length of ReadyForQuery: " ++ show len
319 |     else do
320 |       transactionStatus::Nil <- read 1
321 |       pure ()
322 |
323 | -- RowDescription
324 | handleMessage 'T' = MkReaderT $ \pw => do
325 |   fields <- readElementCountAndElements readField
326 |   pure ()
327 |
328 | -- DataRow
329 | handleMessage 'D' =  MkReaderT $ \pw => do
330 |   elements <- readElementCountAndElements readDataRowElement
331 |   pure ()
332 |
333 | -- CommandComplete
334 | handleMessage 'C' =  MkReaderT $ \pw => do
335 |   len <- readBits32
336 |   let lenInt = cast (len - 4)
337 |   content <- read lenInt
338 |   pure ()
339 |
340 | handleMessage typeByte = MkReaderT $ \pw =>
341 |   result $ UnexpectedMessage $ "Unknown message type byte: " <+> show typeByte
342 |
343 | readUntilNotParameterStatus : ReaderT PGAuthentication (\a => DBufferedConn PgErr Unit a altOut altIn) Char
344 | readUntilNotParameterStatus = do
345 |   typeByte::Nil <- liftR $ read 1
346 |   let typeChar = chr8 typeByte
347 |   if typeChar == 'S'
348 |     then do
349 |       handleMessage typeChar
350 |       readUntilNotParameterStatus
351 |     else pure typeChar
352 |
353 | readDataRowsUntilFinish : PgRows -> ReaderT PGAuthentication (\a => DBufferedConn PgErr Unit a PgInput b) PgRows
354 | readDataRowsUntilFinish soFar = do
355 |   typeByte::Nil <- liftR $ read 1
356 |   case chr8 typeByte of
357 |     'C' => do
358 |       handleMessage 'C'
359 |       pure soFar
360 |     'D' => do
361 |       thisRow <- liftR $ readElementCountAndElements readDataRowElement
362 |       let new = soFar <+> [thisRow]
363 |       readDataRowsUntilFinish new
364 |     unknown =>
365 |       liftR . result $ UnexpectedMessage $ "Unexpected message with type " <+> show unknown
366 |
367 | pgClient: ReaderT PGAuthentication (\a => DBufferedConn PgErr Unit a PgInput PgRows) PgErr
368 | pgClient = do
369 |   startup <- mkStartup
370 |   liftR $ send $ toList $ bits32ToBigEndian (cast $ length startup + 4)
371 |   liftR $ send startup
372 |   typeByte::Nil <- liftR $ read 1
373 |   handleMessage (chr8 typeByte)
374 |
375 |   typeChar@'K' <- readUntilNotParameterStatus
376 |   | unknown => liftR . result $ UnexpectedMessage $ "Expected BackendKeyData, got: " ++ show unknown
377 |   handleMessage typeChar -- BackendKeyData
378 |
379 |   typeByte::Nil <- liftR $ read 1
380 |   log $ "After BackendKeyData, got " ++ show (chr8 typeByte)
381 |   handleMessage (chr8 typeByte)
382 |
383 |   let
384 |     f : PgRows -> ReaderT PGAuthentication (\a => DBufferedConn PgErr Unit a PgInput PgRows) PgErr
385 |     f elements = do
386 |       pgInput <- liftR $ lift (yieldGet2 elements)
387 |       let sqlString = pgInputToString pgInput
388 |       log $ "pgClient received sql: " ++ sqlString
389 |       let sqlStringForWireProtocol = writeString sqlString
390 |
391 |       let len = length sqlStringForWireProtocol + 4
392 |       liftR $ send $ [ord8 'Q'] ++ toList (bits32ToBigEndian (cast len)) ++ sqlStringForWireProtocol
393 |
394 |       typeByte::Nil <- liftR $ read 1
395 |       rows <-
396 |         case chr8 typeByte of
397 |           'T' => do
398 |             handleMessage (chr8 typeByte)
399 |             readDataRowsUntilFinish []
400 |           'C' => do
401 |             handleMessage (chr8 typeByte)
402 |             pure []
403 |           _ => liftR . result $ UnexpectedMessage $ "Not a RowDescription or CommandComplete: " <+> utf8Decode [typeByte]
404 |
405 |       typeByte::Nil <- liftR $ read 1
406 |       if chr8 typeByte /= 'Z'
407 |          then pure $ UnexpectedMessage $ "Expected ready for query, got: " <+> utf8Decode [typeByte]
408 |          else do
409 |            handleMessage (chr8 typeByte)
410 |            f rows
411 |
412 |   f []
413 |
414 | export
415 | initialPgIter: PGAuthentication -> DIterator (List Bits8) (List Bits8) PgInput PgRows PgErr
416 | initialPgIter pw = iteratorFromBufConn (MkBufConSt [] ()) (runReaderT pw pgClient)
417 |