0 | module Network.TLS.Handle
  1 |
  2 | import Control.Linear.LIO
  3 | import Control.Monad.Error.Either
  4 | import Control.Monad.State
  5 | import Crypto.ECDH
  6 | import Crypto.Random
  7 | import Data.List
  8 | import Data.List1
  9 | import Data.Vect
 10 | import Data.Void
 11 | import Network.TLS.Core
 12 | import Network.TLS.Magic
 13 | import Network.TLS.Parsing
 14 | import Network.TLS.Record
 15 | import Utils.Bytes
 16 | import Utils.Handle
 17 | import Utils.Misc
 18 | import Utils.Parser
 19 |
 20 | public export
 21 | tls_version_to_state_type : TLSVersion -> Type
 22 | tls_version_to_state_type TLS12 = TLSState Application2
 23 | tls_version_to_state_type TLS13 = TLSState Application3
 24 | tls_version_to_state_type _ = Void
 25 |
 26 | export
 27 | record TLSHandle (version : TLSVersion) t_ok t_closed where
 28 |   constructor MkTLSHandle
 29 |   1 handle : Handle t_ok t_closed (Res String $ const t_closed) (Res String $ const t_closed)
 30 |   state : tls_version_to_state_type version
 31 |   buffer : List Bits8
 32 |
 33 | public export
 34 | Uninhabited (TLSHandle TLS10 t_ok t_closed) where
 35 |   uninhabited = state
 36 |
 37 | public export
 38 | Uninhabited (TLSHandle TLS11 t_ok t_closed) where
 39 |   uninhabited = state
 40 |
 41 | OkOrError : TLSVersion -> Type -> Type -> Type
 42 | OkOrError tls_version t_ok t_closed = Res Bool $ \ok => if ok then TLSHandle tls_version t_ok t_closed else Res String (const t_closed)
 43 |
 44 | read_record : LinearIO m => (1 _ : Handle' t_ok t_closed) -> L1 m $ Res Bool $ \ok => if ok then Res (List Bits8) (const $ Handle' t_ok t_closed) else Res String (const t_closed)
 45 | read_record handle = do
 46 |   -- read header
 47 |   (True # (b_header # handle)) <- read handle 5
 48 |   | (False # (error # handle)) => pure1 (False # ("read (record header / alert) failed: " <+> error # handle))
 49 |   let (Pure [] (Right (_, TLS12, len))) =
 50 |     feed {i = List (Posed Bits8)} (map (uncurry MkPosed) $ enumerate 0 b_header) (alert <|> record_type_with_version_with_length).decode
 51 |   | Pure [] (Left x) => (close handle) >>= (\s => pure1 (False # (("ALERT: " <+> show x) # s)))
 52 |   | _ => (close handle) >>= (\s => pure1 (False # (("unable to parse header: " <+> xxd b_header) # s)))
 53 |
 54 |   -- read record content
 55 |   (True # (b_body # handle)) <- read handle (cast len)
 56 |   | (False # (error # handle)) => pure1 (False # ("read (record body) failed: " <+> error # handle))
 57 |   if length b_body == cast len
 58 |     then pure1 (True # (b_header <+> b_body # handle))
 59 |     else let err = "length does not match header: " <+> xxd b_body
 60 |                <+> "\nexpected length: " <+> show len
 61 |                <+> "\nactual length: " <+> (show $ length b_body)
 62 |          in (close handle) >>= (\s => pure1 (False # (err # s)))
 63 |
 64 | gen_key : MonadRandom m => (g : SupportedGroup) -> m (DPair SupportedGroup (\g => Pair (curve_group_to_scalar_type g) (curve_group_to_element_type g)))
 65 | gen_key group = do
 66 |   keypair <- generate_key_pair @{snd $ curve_group_to_type group}
 67 |   pure (group ** keypair)
 68 |
 69 | tls2_handshake : LinearIO io => TLSState ServerHello2 -> (1 _ : Handle' t_ok t_closed) -> CertificateCheck IO -> L1 io (OkOrError TLS12 t_ok t_closed)
 70 | tls2_handshake state handle cert_ok = do
 71 |   (True # (b_cert # handle)) <- read_record handle
 72 |   | (False # other) => pure1 (False # other)
 73 |
 74 |   let Right state = serverhello2_to_servercert state b_cert
 75 |   | Left error => (close handle) >>= (\s => pure1 (False # (error # s)))
 76 |
 77 |   (True # (b_skex # handle)) <- read_record handle
 78 |   | (False # other) => pure1 (False # other)
 79 |
 80 |   Right state <- liftIO1 $ servercert_to_serverkex state b_skex cert_ok
 81 |   | Left error => (close handle) >>= (\s => pure1 (False # (error # s)))
 82 |
 83 |   (True # (b_s_hello_done # handle)) <- read_record handle
 84 |   | (False # other) => pure1 (False # other)
 85 |
 86 |   let Right (state, handshake_data) = serverkex_process_serverhellodone state b_s_hello_done
 87 |   | Left error => (close handle) >>= (\s => pure1 (False # (error # s)))
 88 |
 89 |   (True # handle) <- write handle handshake_data
 90 |   | (False # (error # handle)) => pure1 (False # ("send_byte (handshake data) failed: " <+> error # handle))
 91 |
 92 |   (True # (b_ccs # handle)) <- read_record handle
 93 |   | (False # other) => pure1 (False # other)
 94 |
 95 |   let Right state = serverhellodone_to_applicationready2 state b_ccs
 96 |   | Left error => (close handle) >>= (\s => pure1 (False # (error # s)))
 97 |
 98 |   (True # (b_fin # handle)) <- read_record handle
 99 |   | (False # other) => pure1 (False # other)
100 |
101 |   case applicationready2_to_application2 state b_fin of
102 |     Right state => pure1 (True # MkTLSHandle handle state [])
103 |     Left error => (close handle) >>= (\s => pure1 (False # (error # s)))
104 |
105 | tls3_handshake : LinearIO io => TLSState ServerHello3 -> (1 _ : Handle' t_ok t_closed) -> CertificateCheck IO -> L1 io (OkOrError TLS13 t_ok t_closed)
106 | tls3_handshake state handle cert_ok = do
107 |   (True # (b_response # handle)) <- read_record handle
108 |   | (False # other) => pure1 (False # other)
109 |   parsed <- liftIO1 $ tls3_serverhello_to_application state b_response cert_ok
110 |   case parsed of
111 |     Right (Right (client_verify, state)) => do
112 |       (True # handle) <- write handle client_verify
113 |       | (False # (error # handle)) => pure1 (False # ("send_byte (client verify data) failed: " <+> error # handle))
114 |       pure1 (True # MkTLSHandle handle state [])
115 |     Right (Left state) =>
116 |       tls3_handshake state handle cert_ok
117 |     Left error =>
118 |       (close handle) >>= (\s => pure1 (False # (error # s)))
119 |
120 | DecryptFunction : Type -> Type
121 | DecryptFunction state = state -> List Bits8 -> Either String (state, List Bits8)
122 |
123 | EncryptFunction : Type -> Type
124 | EncryptFunction state = state -> List Bits8 -> (state, List Bits8)
125 |
126 | tlshandle_read : {version : _} -> LinearIO io => (wanted : Nat) -> (1 _ : TLSHandle version t_ok t_closed) -> DecryptFunction (tls_version_to_state_type version) -> L1 io (Res Bool $ ReadHack (TLSHandle version t_ok t_closed) (Res String (const t_closed)))
127 | tlshandle_read wanted (MkTLSHandle handle state buffer) decrypt =
128 |   let (a, b) = splitAt wanted buffer
129 |   in if (length a) == wanted
130 |         then pure1 (True # (a # MkTLSHandle handle state b))
131 |         else do
132 |           (True # (b_record # handle)) <- read_record handle
133 |           | (False # other) => pure1 (False # other)
134 |           case decrypt state b_record of
135 |             Right (state, plaintext) => tlshandle_read wanted (MkTLSHandle handle state $ buffer <+> plaintext) decrypt
136 |             Left error => (close handle) >>= (\s => pure1 (False # (error # s)))
137 |
138 | tlshandle_write : {tls_version : TLSVersion} -> LinearIO io => List (List Bits8) -> (1 _ : TLSHandle tls_version t_ok t_closed) -> EncryptFunction (tls_version_to_state_type tls_version) -> L1 io (Res Bool $ WriteHack (TLSHandle tls_version t_ok t_closed) (Res String (const t_closed)))
139 | tlshandle_write [] sock encrypt = pure1 (True # sock)
140 | tlshandle_write (x :: xs) (MkTLSHandle handle state buffer) encrypt = do
141 |   let (state, b_record) = encrypt state x
142 |   (True # handle) <- write handle b_record
143 |   | (False # (error # handle)) => pure1 (False # ("write (application data) failed: " <+> error # handle))
144 |   tlshandle_write xs (MkTLSHandle handle state buffer) encrypt
145 |
146 | ||| Reference: OpenSSL
147 | chunk_size : Nat
148 | chunk_size = 0x2000
149 |
150 | tlshandle_to_handle : {version : _} -> (1 _ : TLSHandle version t_ok t_closed) -> Handle' (TLSHandle version t_ok t_closed) t_closed
151 | tlshandle_to_handle {version=TLS10} (MkTLSHandle handle state buffer) = (kill_linear state) handle
152 | tlshandle_to_handle {version=TLS11} (MkTLSHandle handle state buffer) = (kill_linear state) handle
153 | tlshandle_to_handle {version=TLS12} handle = MkHandle
154 |   handle
155 |   ( \sock, len => tlshandle_read len sock decrypt_from_record2 )
156 |   ( \sock, input => tlshandle_write (chunk chunk_size input) sock encrypt_to_record2 )
157 |   ( \(MkTLSHandle handle state buffer) => close handle )
158 | tlshandle_to_handle {version=TLS13} handle = MkHandle
159 |   handle
160 |   ( \sock, len => tlshandle_read len sock decrypt_from_record )
161 |   ( \sock, input => tlshandle_write (chunk chunk_size input) sock encrypt_to_record )
162 |   ( \(MkTLSHandle handle state buffer) => close handle )
163 |
164 | TLSHandle' : Type -> Type -> Type
165 | TLSHandle' t_ok t_closed = Res TLSVersion $ \version => TLSHandle version t_ok t_closed
166 |
167 | abstract_tlshandle : (1 _ : TLSHandle' t_ok t_closed) -> Handle' (TLSHandle' t_ok t_closed) t_closed
168 | abstract_tlshandle x = MkHandle
169 |   x
170 |   ( \(v # h), wanted => do
171 |       (True # (output # MkHandle h _ _ _)) <- read (tlshandle_to_handle h) wanted
172 |       | (False # (err # x)) => pure1 $ False # (err # x)
173 |       pure1 $ True # (output # (_ # h))
174 |   )
175 |   ( \(v # h), input => do
176 |       (True # MkHandle h _ _ _) <- write (tlshandle_to_handle h) input
177 |       | (False # (err # x)) => pure1 $ False # (err # x)
178 |       pure1 $ True # (_ # h)
179 |   )
180 |   ( \(v # h) => close $ tlshandle_to_handle h
181 |   )
182 |
183 | export
184 | tls_handshake : (MonadRandom IO, LinearIO io) => 
185 |                 String ->
186 |                 List1 SupportedGroup ->
187 |                 List1 SignatureAlgorithm ->
188 |                 List1 CipherSuite ->
189 |                 (1 _ : Handle' t_ok t_closed) ->
190 |                 CertificateCheck IO ->
191 |                 L1 io (Res Bool $ \ok => if ok then Handle' (TLSHandle' t_ok t_closed) t_closed else Res String (const t_closed))
192 | tls_handshake target_hostname supported_groups signature_algos cipher_suites handle cert_ok = do
193 |   random <- liftIO1 $ random_bytes _
194 |   keypairs <- liftIO1 $ traverse gen_key supported_groups
195 |   let
196 |     (client_hello, state) =
197 |       tls_init_to_clienthello $ TLS_Init $ MkTLSInitialState
198 |         target_hostname
199 |         random
200 |         []
201 |         cipher_suites
202 |         signature_algos
203 |         keypairs
204 |
205 |   (True # handle) <- write handle client_hello
206 |   | (False # (error # handle)) => pure1 (False # ("send client_hello failed: " <+> error # handle))
207 |
208 |   (True # (b_server_hello # handle)) <- read_record handle
209 |   | (False # other) => pure1 (False # other)
210 |
211 |   case tls_clienthello_to_serverhello state b_server_hello of
212 |     Right (Left state) => do
213 |       (True # ok) <- tls2_handshake state handle cert_ok
214 |       | (False # no) => pure1 (False # no)
215 |       pure1 $ True # abstract_tlshandle (_ # ok)
216 |     Right (Right state) => do
217 |       (True # ok) <- tls3_handshake state handle cert_ok
218 |       | (False # no) => pure1 (False # no)
219 |       pure1 $ True # abstract_tlshandle (_ # ok)
220 |     Left error => do
221 |       h <- close handle
222 |       pure1 $ False # (error # h)
223 |