0 | module Crypto.ChaCha
 1 |
 2 | import Control.Monad.State
 3 | import Data.Bits
 4 | import Data.DPair
 5 | import Data.Fin
 6 | import Data.Fin.Extra
 7 | import Data.Nat.Order.Properties
 8 | import Data.Vect
 9 | import Utils.Bytes
10 | import Utils.Misc
11 |
12 | public export
13 | Key : Type
14 | Key = Vect 8 Bits32 -- 32 * 8 = 256
15 |
16 | public export
17 | ChaChaState : Type
18 | ChaChaState = Vect 16 Bits32
19 |
20 | -- The first four words (0-3) are constants
21 | public export
22 | constants : Vect 4 Bits32
23 | constants = [0x61707865, 0x3320646e, 0x79622d32, 0x6b206574] -- ['expa', 'nd 3', '2-by', 'te k']
24 |
25 | public export
26 | quarter_rotate : Fin 16 -> Fin 16 -> Fin 16 -> Fin 16 -> State ChaChaState ()
27 | quarter_rotate a b c d = do
28 |   modify (\s => updateAt a (+ index b s) s)
29 |   modify (\s => updateAt d (`xor` index a s) s)
30 |   modify (\s => updateAt d (rotl 16) s)
31 |
32 |   modify (\s => updateAt c (+ index d s) s)
33 |   modify (\s => updateAt b (`xor` index c s) s)
34 |   modify (\s => updateAt b (rotl 12) s)
35 |
36 |   modify (\s => updateAt a (+ index b s) s)
37 |   modify (\s => updateAt d (`xor` index a s) s)
38 |   modify (\s => updateAt d (rotl 8) s)
39 |
40 |   modify (\s => updateAt c (+ index d s) s)
41 |   modify (\s => updateAt b (`xor` index c s) s)
42 |   modify (\s => updateAt b (rotl 7) s)
43 |
44 | public export
45 | double_rotate : State ChaChaState ()
46 | double_rotate = do
47 |   quarter_rotate 0 4  8 12
48 |   quarter_rotate 1 5  9 13
49 |   quarter_rotate 2 6 10 14
50 |   quarter_rotate 3 7 11 15
51 |   ------------------------
52 |   quarter_rotate 0 5 10 15
53 |   quarter_rotate 1 6 11 12
54 |   quarter_rotate 2 7  8 13
55 |   quarter_rotate 3 4  9 14
56 |
57 | public export
58 | run2x : (n_double_rounds : Nat) -> ChaChaState -> ChaChaState
59 | run2x n_double_rounds s =
60 |   execState s $ do
61 |     original <- get
62 |     go last
63 |     modify (zipWith (+) original)
64 |   where
65 |   go : Fin (S n_double_rounds) -> State ChaChaState ()
66 |   go FZ = pure ()
67 |   go (FS wit) = double_rotate *> go (weaken wit)
68 |
69 | ||| ChaCha construction with 4 octets counter and 12 octets nonce as per RFC8439
70 | public export
71 | chacha_rfc8439_block : Nat -> (counter : Bits32) -> Key -> Vect 3 Bits32 -> Vect 64 Bits8
72 | chacha_rfc8439_block rounds counter key nonce = concat $ map (to_le {n = 4}) $ run2x rounds $ constants ++ key ++ [counter] ++ nonce
73 |
74 | ||| ChaCha construction with 8 octets counter and 8 octets nonce as per the original ChaCha specification
75 | public export
76 | chacha_original_block : Nat -> (counter : Bits64) -> Key -> Vect 2 Bits32 -> Vect 64 Bits8
77 | chacha_original_block rounds counter key nonce = concat $ map (to_le {n = 4}) $ run2x rounds $ constants ++ key ++ split_word counter ++ nonce
78 |   where
79 |     split_word : Bits64 -> Vect 2 Bits32
80 |     split_word a = [ cast a, cast (shiftR a 32) ]
81 |