0 | module Data.Autodiff.Ops
  1 |
  2 | import Data.Tensor
  3 | import Data.Container.Additive
  4 | import Data.Para
  5 | import Control.Monad.Distribution
  6 | import Control.Monad.Identity
  7 | import Data.ComMonoid
  8 |
  9 | import Misc
 10 |
 11 | %hide Data.Container.Base.Morphism.Definition.DependentLenses.(=%>)
 12 | %hide Syntax.WithProof.prefix.(@@)
 13 |
 14 | -- This is here and not in Container.Additive because it uses `Tensor` 
 15 | public export
 16 | Simplex : Nat -> AddCont
 17 | Simplex n = MkAddCont $ (_ : Dist n) !> (Tensor [TTInternalName ~~> n] Double)
 18 |
 19 | public export
 20 | MulParametric : {a : Type} -> Num a => ParaAddDLens (Const a) (Const a)
 21 | MulParametric = binaryOpToPara {p=Const a} Mul
 22 |
 23 | public export
 24 | AddParametric : {a : Type} -> Num a => ParaAddDLens (Const a) (Const a)
 25 | AddParametric = binaryOpToPara {p=Const a} Sum
 26 |
 27 | public export
 28 | AffineParametric : {a : Type} -> Num a => ParaAddDLens (Const a) (Const a)
 29 | AffineParametric = composePara MulParametric AddParametric
 30 |
 31 | public export
 32 | LeakyReLU : {a : Type} -> Num a => Ord a => FromDouble a =>
 33 |   {default 0.01 alpha : a} ->
 34 |   ParaAddDLens (Const a) (Const a)
 35 | LeakyReLU = trivialParam (!%+ \x =>
 36 |   (if x > 0 then x else alpha * x ** \x' => if x > 0 then x' else alpha))
 37 |
 38 | public export
 39 | LeakyReLUTensor : {a : Type} -> Num a => Ord a => FromDouble a =>
 40 |   {default 0.01 alpha : a} ->
 41 |   {n : Axis} -> TensorMonoid n.cont =>
 42 |   ParaAddDLens (Const (Tensor [n] a)) (Const (Tensor [n] a))
 43 | LeakyReLUTensor = trivialParam (!%+ \x =>
 44 |   (x <&> (\xx => if xx > 0 then xx else alpha * xx) ** \dy =>
 45 |     (\(d, xx) => if xx > 0 then d else alpha * d) <$> liftA2Tensor dy x))
 46 |
 47 |
 48 | public export
 49 | coproductPair : {a, b, c, d : AddCont} ->
 50 |   ParaAddDLens a c ->
 51 |   ParaAddDLens b d ->
 52 |   ParaAddDLens (a >+< b) (c >+< d)
 53 | coproductPair (MkPara p f) (MkPara q g) = MkPara
 54 |   (p >< q)
 55 |   (coprodDistrOverTensor %>> (f >+< g))
 56 |
 57 | public export
 58 | parallelTensor2 : {a, b: Type} -> Num a => Num b => {axisName : String} ->
 59 |   ParaAddDLens (Const a) (Const b) ->
 60 |   ParaAddDLens (Const (Tensor [axisName ~~> 2] a))
 61 |                (Const (Tensor [axisName ~~> 2] b))
 62 | parallelTensor2 (MkPara pCont f) = MkPara
 63 |   (pCont >< pCont)
 64 |   (!%+ \(x, (p, q)) =>
 65 |     let (b1 ** kf= (%!) f (x @@ [0], p)
 66 |         (b2 ** kg= (%!) f (x @@ [1], q)
 67 |     in (># [b1, b2] ** \bs' =>
 68 |       let (x1', p') = kf (bs' @@ [0])
 69 |           (x2', q') = kg (bs' @@ [1])
 70 |       in (># [x1', x2'], (p', q'))))
 71 |
 72 | public export
 73 | parallelTensor3 : {a, b : Type} -> Num a => Num b => {axisName : AxisName} ->
 74 |   ParaAddDLens (Const a) (Const b) ->
 75 |   ParaAddDLens (Const (Tensor [axisName ~~> 3] a))
 76 |                (Const (Tensor [axisName ~~> 3] b))
 77 | parallelTensor3 (MkPara pCont f) = MkPara
 78 |   (pCont >< pCont >< pCont)
 79 |   (!%+ \(x, (p, q, r)) =>
 80 |     let (b1 ** kf= (%!) f (x @@ [0], p)
 81 |         (b2 ** kg= (%!) f (x @@ [1], q)
 82 |         (b3 ** kh= (%!) f (x @@ [2], r)
 83 |     in (># [b1, b2, b3] ** \bs' =>
 84 |       let (x1', p') = kf (bs' @@ [0])
 85 |           (x2', q') = kg (bs' @@ [1])
 86 |           (x3', r') = kh (bs' @@ [2])
 87 |       in (># [x1', x2', x3'], (p', (q', r')))))
 88 |
 89 | ||| Produces a parametric map that produces `n` copies of the output, instead
 90 | ||| of one, by using `n` different parameters
 91 | public export
 92 | sameFromTensor2 : {a, b : Type} -> Num a => Num b =>
 93 |   {axisName1, axisName2 : AxisName} ->
 94 |   ParaAddDLens (Const a) (Const b) ->
 95 |   ParaAddDLens (Const (Tensor [axisName1 ~~> 1] a))
 96 |                (Const (Tensor [axisName2 ~~> 2] b))
 97 | sameFromTensor2 (MkPara pCont f) = MkPara
 98 |   (pCont >< pCont)
 99 |   (!%+ \(x, (p, q)) =>
100 |     let val = x @@ [0]
101 |         (b1 ** kf= (%!) f (val, p)
102 |         (b2 ** kg= (%!) f (val, q)
103 |     in (># [b1, b2] ** \bs' =>
104 |       let (x1', p') = kf (bs' @@ [0])
105 |           (x2', q') = kg (bs' @@ [1])
106 |       in (># [x1' + x2'], (p', q'))))
107 |
108 | public export
109 | sameFromTensor3 : {a, b : Type} -> Num a => Num b =>
110 |   {axisName1, axisName2 : AxisName} ->
111 |   ParaAddDLens (Const a) (Const b) ->
112 |   ParaAddDLens (Const (Tensor [axisName1 ~~> 1] a))
113 |                (Const (Tensor [axisName2 ~~> 3] b))
114 | sameFromTensor3 (MkPara pCont f) = MkPara
115 |   (pCont >< pCont >< pCont)
116 |   (!%+ \(x, (p, q, r)) =>
117 |     let val = x @@ [0]
118 |         (b1 ** kf= (%!) f (val, p)
119 |         (b2 ** kg= (%!) f (val, q)
120 |         (b3 ** kh= (%!) f (val, r)
121 |     in (># [b1, b2, b3] ** \bs' =>
122 |       let (x1', p') = kf (bs' @@ [0])
123 |           (x2', q') = kg (bs' @@ [1])
124 |           (x3', r') = kh (bs' @@ [2])
125 |       in (># [x1' + x2' + x3'], (p', q', r'))))
126 |
127 | ||| Produces a parametric map that produces `n` copies of the output, instead
128 | ||| of one, by using `n` different parameters
129 | public export
130 | sameFromTensor : {a, b : Type} -> Num a => Num b => {n : Nat} -> 
131 |   {axisName1, axisName2 : AxisName} ->
132 |   ParaAddDLens (Const a) (Const b) ->
133 |   ParaAddDLens (Const (Tensor [axisName1 ~~> 1] a))
134 |                (Const (Tensor [axisName2 ~~> n] b))
135 | sameFromTensor (MkPara pCont f) = MkPara
136 |   (AllAll $ replicate n pCont)
137 |   (!%+ \(x, psShapes) =>
138 |     let val = x @@ [0]
139 |         outAndBw = runIdentity $ dTraverse
140 |             (\p => Id $ (%!) f (val, p))
141 |             (allToVect psShapes)
142 |         out = mapPropertyRelevant (\_, (y ** bw=> y) outAndBw
143 |         bw = mapPropertyRelevant (\_, (y ** bw=> bw) outAndBw
144 |     in (># constantToVect out ** \bs' =>
145 |       let tt = bw 
146 |       in ?bww))
147 |
148 | public export
149 | sameFrom : {a : AddCont} -> ParaAddDLens a b ->
150 |   ParaAddDLens a c ->
151 |   ParaAddDLens a (b >< c)
152 | sameFrom (MkPara p f) (MkPara q g) = MkPara
153 |   (p >< q)
154 |   (!%+ \(x, (p, q)) =>
155 |     let (b ** kf= (%!) f (x, p)
156 |         (c ** kg= (%!) g (x, q)
157 |     in ((b, c) ** \(b', c') =>
158 |       let (x'1, p') = kf b'
159 |           (x'2, q') = kg c'
160 |       in (a.Plus x x'1 x'2, (p', q'))))
161 |
162 | public export
163 | sameFromConst : {a, b, c : Type} -> Num a => Num b => Num c =>
164 |   ParaAddDLens (Const a) (Const b) ->
165 |   ParaAddDLens (Const a) (Const c) ->
166 |   ParaAddDLens (Const a) (Const (b, c))
167 | sameFromConst (MkPara p f) (MkPara q g) = MkPara
168 |   (p >< q)
169 |   (!%+ \(x, (p, q)) =>
170 |     let (b ** kf= (%!) f (x, p)
171 |         (c ** kg= (%!) g (x, q)
172 |     in ((b, c) ** \(b', c') =>
173 |       let (x'1, p') = kf b'
174 |           (x'2, q') = kg c'
175 |       in (x'1 + x'2, (p', q'))))
176 |
177 | public export
178 | sameFrom3 : {a : AddCont} -> ParaAddDLens a b ->
179 |   ParaAddDLens a c ->
180 |   ParaAddDLens a d ->
181 |   ParaAddDLens a (b >< c >< d)
182 | sameFrom3 (MkPara p f) (MkPara q g) (MkPara r h) = MkPara
183 |   (p >< q >< r)
184 |   (!%+ \(x, (p, q, r)) =>
185 |     let (b ** kf= (%!) f (x, p)
186 |         (c ** kg= (%!) g (x, q)
187 |         (d ** kh= (%!) h (x, r)
188 |     in ((b, c, d) ** \(b', c', d') =>
189 |       let (x'1, p') = kf b'
190 |           (x'2, q') = kg c'
191 |           (x'3, r') = kh d'
192 |       in (a.Plus x (a.Plus x x'1 x'2) x'3, (p', q', r'))))
193 |
194 | public export
195 | sameFromConst3 : {a, b, c, d : Type} -> Num a => Num b => Num c => Num d =>
196 |   ParaAddDLens (Const a) (Const b) ->
197 |   ParaAddDLens (Const a) (Const c) ->
198 |   ParaAddDLens (Const a) (Const d) ->
199 |   ParaAddDLens (Const a) (Const (b, c, d))
200 | sameFromConst3 (MkPara p f) (MkPara q g) (MkPara r h) = MkPara
201 |   (p >< q >< r)
202 |   (!%+ \(x, (p, q, r)) =>
203 |     let (b ** kf= (%!) f (x, p)
204 |         (c ** kg= (%!) g (x, q)
205 |         (d ** kh= (%!) h (x, r)
206 |     in ((b, c, d) ** \(b', c', d') =>
207 |       let (x'1, p') = kf b'
208 |           (x'2, q') = kg c'
209 |           (x'3, r') = kh d'
210 |       in (x'1 + x'2 + x'3, (p', q', r'))))
211 |
212 | -- ||| N-ary probability intro and elimination
213 | -- NProbIntro : {ef : EffectType} ->
214 | --   {i : Nat} -> IsSucc i =>
215 | --   {ts : Vect i Ty} ->
216 | --   All (\t => Term ef ctx t) ts ->  -- for now all the components need to run with the same effect
217 | --   -- Treating probability as logits
218 | --   Vect i (Term ef ctx Number) ->
219 | --   Term Prob ctx (NProb ts)
220 | -- NProbElim : {ef : EffectType} ->
221 | --   {i : Nat} -> IsSucc i =>
222 | --   {ts : Vect i Ty} ->
223 | --   Term ef ctx (NProb ts) ->
224 | --   All (\e => Term ef (e :: ctx) c) ts -> Term Prob ctx c