0 | module WebServerRacket
  1 |
  2 | import Cont
  3 | import ContAction
  4 | import PG
  5 | import UTF8 as UTF8
  6 | import RacketTCP
  7 |
  8 | import Http2
  9 | import Http2Responder
 10 |
 11 | import System as System
 12 | import System.Concurrency as Concurrency
 13 | import Data.Buffer as Buffer
 14 | import Data.Buffer.Core as BufCore
 15 | import Data.List as List
 16 | import Data.Maybe as Maybe
 17 |
 18 | -- Transitive import bug
 19 | import PGTypes
 20 |
 21 | setBytes : Buffer -> Int -> List Bits8 -> IO Unit
 22 | setBytes buf offset [] = pure ()
 23 | setBytes buf offset (x :: xs) = do
 24 |   Buffer.setBits8 buf offset x
 25 |   setBytes buf (offset + 1) xs
 26 |
 27 | record PgCtx where
 28 |   constructor MkPgCtx
 29 |   pgMutex : Concurrency.Mutex
 30 |   pgInputChan : Channel PgInput
 31 |   pgOutputChan : Channel PgRows
 32 |
 33 | SendRecvIterator : Type
 34 | SendRecvIterator = DIterator Response Act PgInput PgRows PgErr
 35 |
 36 | natToInt : Nat -> Int
 37 | natToInt = cast
 38 |
 39 | natToBits64 : Nat -> Bits64
 40 | natToBits64 = cast
 41 |
 42 | intToBits8 : Int -> Bits8
 43 | intToBits8 = cast
 44 |
 45 | newBuf : Nat -> IO Buffer
 46 | newBuf nat = do
 47 |   mbBuf <- Buffer.newBuffer (natToInt nat)
 48 |   case mbBuf of
 49 |     Nothing =>
 50 |       -- newBuffer only returns Nothing on negative values or, possibly on
 51 |       -- memory exhaustion.
 52 |       -- A Nat is always positive, so the Nothing should never happen.
 53 |       assert_total $ idris_crash "Impossible"
 54 |     Just buf => pure buf
 55 |
 56 | data ContState c d res
 57 |   = MkGotEof
 58 |   | MkGotWriteError
 59 |   | MkGotResult res
 60 |   | MkUnexpectedAmountWritten
 61 |   | MkGotCont (d, c -> DIterator Response Act c d res)
 62 |
 63 | runTillNeedsSql : InputPort -> OutputPort -> DIterator Response Act c d res -> IO (ContState c d res)
 64 | runTillNeedsSql pgInputPort pgOutputPort oldIter = do
 65 |   let
 66 |     sender : List Bits8 -> IO (Either WriteError Bits64)
 67 |     sender pgToSend = do
 68 |       writeBuf <- newBuf (length pgToSend)
 69 |       setBytes writeBuf 0 pgToSend
 70 |       writeBytes writeBuf pgOutputPort
 71 |   case oldIter of
 72 |     Susp (ActSend toSend) needPgInput => do
 73 |       Right written <- sender toSend
 74 |       | Left MkWriteError => pure MkGotWriteError
 75 |       if (written /= natToBits64 (length toSend))
 76 |          then pure MkUnexpectedAmountWritten
 77 |          else do
 78 |            putStrLn $ "Sent " <+> show written <+> " bytes"
 79 |            runTillNeedsSql pgInputPort pgOutputPort (needPgInput RSent)
 80 |     Susp ActReceive needPgInput => do
 81 |       syncEvt <- readBytesEvt 1 pgInputPort
 82 |       Right readBuf <- readSync syncEvt
 83 |       | Left MkEOF => pure MkGotEof
 84 |       readBufData <- map intToBits8 <$> Buffer.bufferData readBuf
 85 |       let newIter = needPgInput (RReceived readBufData)
 86 |
 87 |       runTillNeedsSql pgInputPort pgOutputPort newIter
 88 |     Susp2 rows needSql =>
 89 |       pure (MkGotCont (rows, needSql))
 90 |     Result err =>
 91 |       pure (MkGotResult err)
 92 |
 93 | client : DIterator Response Act PgRows PgInput Void -> InputPort -> OutputPort -> PgCtx -> IO Unit
 94 | client h2Iter inputPort outputPort pgCtx = do
 95 |   MkGotCont (pgInput, needsRows) <- runTillNeedsSql inputPort outputPort h2Iter
 96 |   | MkGotWriteError => putStrLn "Got error while writing to HTTP2"
 97 |   | MkUnexpectedAmountWritten => putStrLn "Got unexpected amount written to HTTP2"
 98 |   | MkGotEof => putStrLn "Got EOF from HTTP2"
 99 |   | MkGotResult void => absurd void
100 |   Concurrency.mutexAcquire pgCtx.pgMutex
101 |   Concurrency.channelPut pgCtx.pgInputChan pgInput
102 |   rows <- Concurrency.channelGet pgCtx.pgOutputChan
103 |   Concurrency.mutexRelease pgCtx.pgMutex
104 |   client (needsRows rows) inputPort outputPort pgCtx
105 |
106 | export
107 | runResponder : PGAuthentication -> (Request -> InnerCont (List Frame)) -> IO Unit
108 | runResponder pgAuth responder = do
109 |   pgCtx <- MkPgCtx <$> Concurrency.makeMutex <*> Concurrency.makeChannel <*> Concurrency.makeChannel
110 |
111 |   Right (pgIn, pgOut) <- tcpConnect "localhost" 5432
112 |   | Left MkCouldn'tConnect => putStrLn "Couldn't connect to PostgreSQL at localhost:5432"
113 |
114 |   MkGotCont ([], needsSql) <- runTillNeedsSql pgIn pgOut (mkSendRecvIterator $ initialPgIter pgAuth)
115 |   | MkGotCont (rows, needsSql) => putStrLn "Expected no rows"
116 |   | MkGotEof => putStrLn "Err in pg handshake: Got End-Of-File"
117 |   | MkGotResult pgErr => putStrLn $ "Err in pg handshake: PgErr: " <+> show pgErr
118 |   | MkGotWriteError => putStrLn "Err in pg handshake: WriteError"
119 |   | MkUnexpectedAmountWritten => putStrLn "Err in pg handshake: UnexpectedAmountWritten"
120 |   putStrLn "Done with PostgreSQL handshake"
121 |   let
122 |     pgForever : (PgInput -> SendRecvIterator) -> IO Unit
123 |     pgForever iter = do
124 |       pgInput <- Concurrency.channelGet pgCtx.pgInputChan
125 |       MkGotCont (result, needsSql2) <- runTillNeedsSql pgIn pgOut $ iter pgInput
126 |       | MkGotEof => putStrLn "Err in pg sql execution: Got End-Of-File on PostgreSQL connection"
127 |       | MkGotResult pgErr => putStrLn $ "Err in pg sql execution: PgErr: " <+> show pgErr
128 |       | MkGotWriteError => putStrLn "Err in pg sql execution: WriteError"
129 |       | MkUnexpectedAmountWritten => putStrLn "Err in pg sql execution: UnexpectedAmountWritten"
130 |       Concurrency.channelPut pgCtx.pgOutputChan result
131 |       pgForever needsSql2
132 |
133 |   pgThreadId <- fork (pgForever needsSql)
134 |
135 |   Right listener <- tcpListen 8000
136 |   | Left MkListenErr => putStrLn "Couldn't listen on port 8000"
137 |   putStrLn "Listening on port 8000..."
138 |   let
139 |     initialHttpIter: DIterator (List Bits8) (List Bits8) PgRows PgInput Void
140 |     initialHttpIter = Http2.mkInitialHttp2Iter responder
141 |     forever : IO Unit
142 |     forever = do
143 |       (inputPortRaw, outputPortRaw) <- tcpAccept listener
144 |       let h2Buf = BufCore.unsafeGetBuffer $ UTF8.utf8Encode "h2"
145 |       errOrPorts <- portsToSslPorts inputPortRaw outputPortRaw "key.pem" "cert.pem" h2Buf
146 |       case errOrPorts of
147 |         Left MkCouldn'tSetUpTLS => do
148 |           putStrLn "Couldn't set up TLS for client, dropping..."
149 |           forever
150 |         Right (inputPort, outputPort) => do
151 |           h2ThreadId <- fork $ client (mkSendRecvIterator initialHttpIter) inputPort outputPort pgCtx
152 |           forever
153 |   forever
154 |