0 | -- Implementation based on https://www.newspipe.org/article/public/309714
  1 |
  2 | module Crypto.Hash.Poly1305
  3 |
  4 | import Data.Vect
  5 | import Utils.Misc
  6 | import Utils.Bytes
  7 | import Data.Bits
  8 | import Crypto.Hash
  9 |
 10 | export
 11 | record Poly1305 where
 12 |   constructor MkPoly1305
 13 |   buffer : List Bits8
 14 |   h : Vect 3 Bits64
 15 |   r : Vect 2 Bits64
 16 |   s : Vect 2 Bits64
 17 |
 18 | record Bits128 where
 19 |   constructor MkBits128
 20 |   lo : Bits64
 21 |   hi : Bits64
 22 |
 23 | mask_low_2_bits : Bits64
 24 | mask_low_2_bits  = 0x0000000000000003
 25 |
 26 | mask_not_low_2_bits : Bits64
 27 | mask_not_low_2_bits = complement mask_low_2_bits
 28 |
 29 | p0 : Bits64
 30 | p0 = 0xFFFFFFFFFFFFFFFB
 31 |
 32 | p1 : Bits64
 33 | p1 = 0xFFFFFFFFFFFFFFFF
 34 |
 35 | p2 : Bits64
 36 | p2 = 0x0000000000000003
 37 |
 38 | add64 : Bits64 -> Bits64 -> Bits64 -> (Bits64, Bits64)
 39 | add64 x y carry =
 40 |   let sum = x + y + carry
 41 |       carry = shiftR ((x .&. y) .|. ((x .|. y) .&. (complement sum))) 63
 42 |   in (sum, carry)
 43 |
 44 | sub64 : Bits64 -> Bits64 -> Bits64 -> (Bits64, Bits64)
 45 | sub64 x y borrow =
 46 |   let diff = x - y - borrow
 47 |       borrow_out = shiftR ((complement x .&. y) .|. (diff .&. complement (x `xor` y))) 63
 48 |   in (diff, borrow_out)
 49 |
 50 | mul64_mask32 : Bits64
 51 | mul64_mask32 = cast (the Bits32 oneBits)
 52 |
 53 | mul64 : Bits64 -> Bits64 -> Bits128
 54 | mul64 x y =
 55 |   let x0 = x .&. mul64_mask32
 56 |       x1 = shiftR x 32
 57 |       y0 = y .&. mul64_mask32
 58 |       y1 = shiftR y 32
 59 |       w0 = x0 * y0
 60 |       t = (x1 * y0) + (shiftR w0 32)
 61 |       w1 = t .&. mul64_mask32
 62 |       w2 = shiftR t 32
 63 |       w1 = w1 + (x0 * y1)
 64 |   in MkBits128 (x * y) (x1 * y1 + w2 + (shiftR w1 32))
 65 |
 66 | add128 : Bits128 -> Bits128 -> Bits128
 67 | add128 a b =
 68 |   let (lo, c) = add64 a.lo b.lo 0
 69 |       (hi, _) = add64 a.hi b.hi c
 70 |   in MkBits128 lo hi
 71 |
 72 | shiftr2 : Bits128 -> Bits128
 73 | shiftr2 a = MkBits128 ((shiftR a.lo 2) .|. (shiftL (a.hi .&. 3) 62)) (shiftR a.hi 2)
 74 |
 75 | -- select64 returns x if v == 1 and y if v == 0, in constant time
 76 | select64 : Bits64 -> Bits64 -> Bits64 -> Bits64
 77 | select64 v x y = (x .&. complement (v - 1)) .|. (y .&. (v - 1))
 78 |
 79 | core : Poly1305 -> Bits64 -> Bits64 -> Bits64 -> (Bits64, Bits64, Bits64)
 80 | core state h0 h1 h2 =
 81 |   let [r0, r1] = state.r
 82 |
 83 |       h0r0 = mul64 h0 r0
 84 |       h1r0 = mul64 h1 r0
 85 |       h2r0 = mul64 h2 r0
 86 |       h0r1 = mul64 h0 r1
 87 |       h1r1 = mul64 h1 r1
 88 |       h2r1 = mul64 h2 r1
 89 |
 90 |       m0 = h0r0
 91 |       m1 = add128 h1r0 h0r1
 92 |       m2 = add128 h2r0 h1r1
 93 |       m3 = h2r1
 94 |
 95 |       t0 = m0.lo
 96 |       (t1, c) = add64 m1.lo m0.hi 0
 97 |       (t2, c) = add64 m2.lo m1.hi c
 98 |       (t3, _) = add64 m3.lo m2.hi c
 99 |
100 |       h0 = t0
101 |       h1 = t1
102 |       h2 = t2 .&. mask_low_2_bits
103 |       cc = MkBits128 (t2 .&. mask_not_low_2_bits) t3
104 |
105 |       (h0, c) = add64 h0 cc.lo 0
106 |       (h1, c) = add64 h1 cc.hi c
107 |       h2 = h2 + c
108 |
109 |       cc = shiftr2 cc
110 |
111 |       (h0, c) = add64 h0 cc.lo 0
112 |       (h1, c) = add64 h1 cc.hi c
113 |       h2 = h2 + c
114 |   in (h0, h1, h2)
115 |
116 | update' : Poly1305 -> Poly1305
117 | update' state =
118 |   case splitAtExact 16 state.buffer of
119 |     Just (buf, rest) =>
120 |       let (a, b) = bimap from_le from_le (splitAt 8 buf)
121 |           [h0, h1, h2] = state.h
122 |           (h0, c) = add64 h0 a 0
123 |           (h1, c) = add64 h1 b c
124 |           h2 = h2 + c + 1
125 |           (h0, h1, h2) = core state h0 h1 h2
126 |       in {h := [h0, h1, h2]} state
127 |     Nothing => state
128 |
129 | finalize' : Poly1305 -> Vect 16 Bits8
130 | finalize' state =
131 |   case exactLength 16 (fromList state.buffer) of
132 |     Just buf =>
133 |       let (a, b) = bimap from_le from_le (splitAt 8 buf)
134 |           [h0, h1, h2] = state.h
135 |           (h0, c) = add64 h0 a 0
136 |           (h1, c) = add64 h1 b c
137 |           h2 = h2 + c
138 |           (h0, h1, h2) = core state h0 h1 h2
139 |
140 |           (t0, b) = sub64 h0 p0 0
141 |           (t1, b) = sub64 h1 p1 b
142 |           (_ , b) = sub64 h2 p2 b
143 |
144 |           h0 = select64 b h0 t0
145 |           h1 = select64 b h1 t1
146 |
147 |           [s0, s1] = state.s
148 |           (h0, c) = add64 h0 s0 0
149 |           (h1, _) = add64 h1 s1 c
150 |       in to_le {n=8} h0 ++ to_le {n=8} h1
151 |     Nothing =>
152 |       finalize' $ { buffer := pad_zero 16 (state.buffer <+> [ 0x01 ]) } (update' state)
153 |
154 | export
155 | Digest Poly1305 where
156 |   digest_nbyte = 16
157 |   update message state = update' $ {buffer := state.buffer <+> message} state
158 |   finalize = finalize'
159 |
160 | export
161 | MAC (Vect 32 Bits8) Poly1305 where
162 |   initialize_mac key =
163 |     let ([r0, r1], s) = splitAt 2 $ map from_le $ group 4 8 key
164 |         r0 = r0 .&. 0x0FFFFFFC0FFFFFFF
165 |         r1 = r1 .&. 0x0FFFFFFC0FFFFFFC
166 |     in MkPoly1305 [] [0,0,0] [r0, r1] s
167 |