0 | module Data.Num
  1 |
  2 | import Data.Vect
  3 |
  4 | ||| Interface for the Exponential
  5 | ||| We also include minus infinity because of the necessity to compute
  6 | ||| causal masks within the attention mechanism.
  7 | ||| For rules that `exp` should satisfy, see https://arxiv.org/abs/1911.04790
  8 | ||| We also have
  9 | ||| `exp . log = id`, `log . exp = id`, `exp minusInfinity = 0`...
 10 | public export
 11 | interface Num a => Exp a where
 12 |   exp : a -> a
 13 |   log : a -> a
 14 |   minusInfinity : a
 15 |
 16 | public export
 17 | Exp Double where
 18 |   exp = Prelude.exp
 19 |   log = Prelude.log
 20 |   minusInfinity = cast "-inf.0"
 21 |
 22 | public export
 23 | interface Num a => Sqrt a where
 24 |   sqrt : a -> a
 25 |
 26 | public export
 27 | Sqrt Double where
 28 |   sqrt = Prelude.sqrt
 29 |
 30 |
 31 | namespace Num
 32 |   public export
 33 |   {n : Nat} -> Num a => Num (Vect n a) where
 34 |     xs + ys = zipWith (+) xs ys
 35 |     xs * ys = zipWith (*) xs ys
 36 |     fromInteger x = replicate n (fromInteger x)
 37 |
 38 |   public export
 39 |   Num Unit where
 40 |     () + () = ()
 41 |     () * () = ()
 42 |     fromInteger x = ()
 43 |   
 44 |   public export
 45 |   Num a => Num b => Num (a, b) where
 46 |     (lFst, lSnd) * (rFst, rSnd) = (lFst * rFst, lSnd * rSnd)
 47 |     (+) (lFst, lSnd) (rFst, rSnd) = (lFst + rFst, lSnd + rSnd)
 48 |     fromInteger x = (fromInteger x, fromInteger x)
 49 |
 50 |   public export
 51 |   Num a => Num b => Num (DPair a (const b)) where
 52 |     (lFst ** lSnd* (rFst ** rSnd= (lFst * rFst ** lSnd * rSnd)
 53 |     (+) (lFst ** lSnd) (rFst ** rSnd= (lFst + rFst ** lSnd + rSnd)
 54 |     fromInteger x = (fromInteger x ** fromInteger x)
 55 |
 56 |   %hint
 57 |   public export
 58 |   depFunNum : {k : Fin n -> Type} ->
 59 |     {ss : (i : Fin n) -> Num (k i)} ->
 60 |      Num ((i : Fin n) -> k i)
 61 |   depFunNum = MkNum
 62 |     (\s, t => \i => s i + t i)
 63 |     (\f, g => \i => f i * g i)
 64 |     (\n => \i => fromInteger n)
 65 |
 66 |
 67 |
 68 | namespace Neg
 69 |   public export
 70 |   Neg Unit where
 71 |     negate () = ()
 72 |     () - () = ()
 73 |
 74 |   public export
 75 |   Neg a => Neg b => Neg (a, b) where
 76 |     negate (lFst, lSnd) = (negate lFst, negate lSnd)
 77 |     (lFst, lSnd) - (rFst, rSnd) = (lFst - rFst, lSnd - rSnd)
 78 |
 79 |   public export
 80 |   Neg a => Neg b => Neg (DPair a (const b)) where
 81 |     negate (fst ** snd= (negate fst ** negate snd)
 82 |     (fst ** snd- (rFst ** rSnd= (fst - rFst ** snd - rSnd)
 83 |
 84 |   %hint
 85 |   public export
 86 |   depFunNeg : {k : Fin n -> Type} ->
 87 |     {ss : (i : Fin n) -> Neg (k i)} ->
 88 |      Neg ((i : Fin n) -> k i)
 89 |   depFunNeg = MkNeg
 90 |     (\s => \i => negate (s i))
 91 |     (\f, g => \i => f i - g i)
 92 |
 93 | namespace FromDouble
 94 |   public export
 95 |   FromDouble Unit where
 96 |     fromDouble x = ()
 97 |
 98 |   public export
 99 |   FromDouble a => FromDouble b => FromDouble (a, b) where
100 |     fromDouble x = (fromDouble x, fromDouble x)
101 |   
102 |   public export
103 |   FromDouble a => FromDouble b => FromDouble (DPair a (const b)) where
104 |     fromDouble x = (fromDouble x ** fromDouble x)
105 |
106 |   %hint
107 |   public export
108 |   depFunFromDouble : {k : Fin n -> Type} ->
109 |     {ss : (i : Fin n) -> FromDouble (k i)} ->
110 |      FromDouble ((i : Fin n) -> k i)
111 |   depFunFromDouble = MkFromDouble
112 |     (\s => \i => fromDouble s)
113 |
114 | namespace Fractional
115 |   public export
116 |   Fractional Unit where
117 |     () / () = ()
118 |     
119 |   public export 
120 |   Fractional a => Fractional b => Fractional (a, b) where
121 |     (lFst, lSnd) / (rFst, rSnd) = (lFst / rFst, lSnd / rSnd)
122 |
123 |   public export
124 |   Fractional a => Fractional b => Fractional (DPair a (const b)) where
125 |     (fst ** snd/ (rFst ** rSnd= (fst / rFst ** snd / rSnd)
126 |
127 | namespace Sqrt
128 |   public export
129 |   Sqrt Unit where
130 |     sqrt () = ()
131 |
132 |   public export
133 |   Sqrt a => Sqrt b => Sqrt (a, b) where
134 |     sqrt (lFst, lSnd) = (sqrt lFst, sqrt lSnd)
135 |
136 |   public export
137 |   Sqrt a => Sqrt b => Sqrt (DPair a (const b)) where
138 |     sqrt (fst ** snd= (sqrt fst ** sqrt snd)