0 | module Data.CT.DependentPara.Instances
  1 |
  2 | import Data.DPair
  3 | import Data.CT.Category.Definition
  4 | import Data.CT.Functor.Definition
  5 | import Data.CT.DependentAction.Definition
  6 | import Data.CT.DependentPara.Definition
  7 | import Data.CT.Category.Instances
  8 | import Data.CT.Functor.Instances
  9 | import Data.CT.DependentAction.Instances
 10 |
 11 | import Data.Container.Base
 12 | import Data.Container.Additive
 13 |
 14 | public export infixr 1 -\-> -- dependent parametric functions
 15 | public export infixr 1 -\--> -- non-dependent parametric functions
 16 | public export infixr 1 =\\=> -- dependent parametric (additive) dependent lenses
 17 | public export infixr 1 =\\==> -- non-dependent parametric (additive) dependent lenses
 18 |
 19 | {-------------------------------------------------------------------------------
 20 | {-------------------------------------------------------------------------------
 21 | Instead going in and defining full-blown definitions of dependent actegories with units and coherences we instead leverage the main definition in the background and only instantiate cases, manually:
 22 | one for parametric functions and one for parametric additive dependent lenses.
 23 | We instantiate them using same names in different namespaces, and leverage Idris' name resolution mechanisms to allow the user to use the same name and
 24 | reduce cognitive load
 25 | -------------------------------------------------------------------------------}
 26 | -------------------------------------------------------------------------------}
 27 |
 28 | namespace ParametricFunctions
 29 |   public export
 30 |   Para : (a, b : Type) -> Type
 31 |   Para = DepParaMor PairType
 32 |
 33 |   ||| Infix notation for non-dependent parametric functions
 34 |   ||| We interpret the extra "-" as a mental symbol for "flat",
 35 |   ||| i.e. "non-dependent"
 36 |   public export
 37 |   (-\-->) : (a, b : Type) -> Type
 38 |   a -\--> b = Para a b
 39 |
 40 |   ||| Parametric functions
 41 |   ||| "Usual" dependent para on sets and functions
 42 |   public export
 43 |   DPara : (a, b : Type) -> Type 
 44 |   DPara = DepParaMor DPairType
 45 |   
 46 |   ||| Infix notation for dependent parametric functions
 47 |   ||| We interpret the crossed line as a parameter coming in from the top
 48 |   public export
 49 |   (-\->) : (a, b : Type) -> Type
 50 |   a -\-> b = DPara a b
 51 |
 52 |   public export
 53 |   trivialParam : (a -> b) -> a -\-> b
 54 |   trivialParam f = MkPara 
 55 |     (\_ => Unit)
 56 |     (\(a ** ()=> f a)
 57 |
 58 |   public export
 59 |   id : a -\-> a
 60 |   id = trivialParam id
 61 |
 62 |   public export
 63 |   composePara : a -\-> b -> b -\-> c -> a -\-> c
 64 |   composePara (MkPara p f) (MkPara q g) = MkPara
 65 |     (\x => DPair (p x) (\p' => q (f (x ** p'))) )
 66 |     (\(x ** (p' ** q')) => g (f (x ** p'** q'))
 67 |   
 68 |   public export infixr 10 \>>
 69 |   
 70 |   public export
 71 |   (\>>) : a -\-> b -> b -\-> c -> a -\-> c
 72 |   (\>>) = composePara
 73 |
 74 |   public export
 75 |   reparam : (pf : a -\-> b) ->
 76 |     {q : a -> Type} ->
 77 |     (r : (x : a) -> q x -> pf.Param x) ->
 78 |     a -\-> b
 79 |   reparam (MkPara p f) r = MkPara q (\(x ** qq=> f (x ** (r x qq)))
 80 |
 81 |   public export
 82 |   Param : DPara a b -> a -> Type
 83 |   Param = DepParaMor.Param
 84 |   
 85 |   public export
 86 |   Run : (pf : DPara a b) -> (x : a) -> Param pf x -> b
 87 |   Run pf = DPair.curry (DepParaMor.Run pf)
 88 |
 89 |   public export
 90 |   data IsNotDependent : DPara a b -> Type where
 91 |     MkNonDep : (p : Type) -> (f : DPair a (const p) -> b) ->
 92 |       IsNotDependent (MkPara (\_ => p) f)
 93 |   
 94 |   public export
 95 |   GetNonDep : (pf : DPara a b) ->
 96 |     IsNotDependent pf => (p : Type ** DPair a (const p) -> b)
 97 |   GetNonDep _ @{MkNonDep p f} = (p ** f)
 98 |
 99 |   ||| Get the parameter of a non-dependent parametric function
100 |   public export
101 |   GetParam : (pf : DPara a b) ->
102 |     IsNotDependent pf => Type
103 |   GetParam _ @{MkNonDep p f} = p
104 |
105 |
106 |   public export
107 |   composeNTimes : Nat -> a -\-> a -> a -\-> a
108 |   composeNTimes 0 f = id
109 |   composeNTimes 1 f = f -- to get rid of the annoying Unit parameter
110 |   composeNTimes (S k) f = composePara f (composeNTimes k f)
111 |
112 |   public export
113 |   binaryOpToPara : {p : Type} -> (f : (a, p) -> b) -> a -\-> b
114 |   binaryOpToPara f = MkPara
115 |     (\_ => p)
116 |     (\(x ** p'=> f (x, p'))
117 |
118 | namespace ParametricDependentLenses
119 |   ||| DParametric dependent lenses
120 |   ||| Not really used in this repo
121 |   public export
122 |   DParaDLens : (a, b : Cont) -> Type
123 |   DParaDLens = DepParaMor DPairCont
124 |
125 |   public export
126 |   ParaDLens : (a, b : Cont) -> Type
127 |   ParaDLens = DepParaMor PairCont
128 |
129 |   public export
130 |   ParaAddDLens : (a, b : AddCont) -> Type
131 |   ParaAddDLens = DepParaMor PairAddCont
132 |
133 |   public export
134 |   (=\\==>) : (a, b : AddCont) -> Type
135 |   a =\\==> b = ParaAddDLens a b
136 |
137 |   public export
138 |   trivialParam : {0 a, b : AddCont} -> (a =%> b) -> a =\\==> b
139 |   trivialParam f = MkPara
140 |     Scalar
141 |     (!%+ \(x, ()) =>
142 |       let (y ** ky= (%!) f x
143 |       in (y ** \y' => (ky y', ())))
144 |
145 |   public export
146 |   id : a =\\==> a
147 |   id = trivialParam id
148 |
149 |   public export
150 |   GetParam : ParaAddDLens a b -> AddCont
151 |   GetParam (MkPara p _) = p
152 |
153 |   public export
154 |   toHomRepresentation : (f : ParaAddDLens a b) ->
155 |     (GetParam f) =%> InternalLensAdditive a b
156 |   toHomRepresentation (MkPara pType f) = !%+ \p =>
157 |     (!%+ \a => (f.fwd (a, p) ** \b' => fst (f.bwd (a, p) b')**
158 |       \l => foldr (\(a ** b'=> pType.Plus p (snd (f.bwd (a, p) b'))) (pType.Zero p) l)
159 |
160 |   public export
161 |   composePara : a =\\==> b -> b =\\==> c -> a =\\==> c
162 |   composePara (MkPara p f) (MkPara q g) = MkPara
163 |     (p >< q)
164 |     (!%+ \(x, (ps, qs)) => 
165 |       (g.fwd (f.fwd (x, ps), qs) ** \cPos =>
166 |         let (bPos, qPos) = g.bwd (f.fwd (x, ps), qs) cPos
167 |             (aPos, pPos) = f.bwd (x, ps) bPos
168 |         in (aPos, (pPos, qPos))))
169 |
170 |
171 | namespace DependentParametricDependentLenses
172 |
173 |   ||| Used in this repo, as all neural networks are additive dependent lenses
174 |   public export
175 |   DParaAddDLens : (a, b : AddCont) -> Type
176 |   DParaAddDLens = DepParaMor DPairAddCont
177 |
178 |   ||| Infix notation for additive parametric dependent lenses
179 |   public export
180 |   (=\\=>) : (a, b : AddCont) -> Type
181 |   a =\\=> b = DParaAddDLens a b
182 |   
183 |   public export
184 |   trivialParam : {0 a, b : AddCont} -> (a =%> b) -> a =\\=> b
185 |   trivialParam f = MkPara
186 |     (\_ => Scalar)
187 |     (!% !% \(x ** ()=> let (y ** ky= (%!) f x in (y ** \y' => (ky y', ())))
188 |
189 |   public export
190 |   id : a =\\=> a
191 |   id = trivialParam id
192 |   
193 |   public export
194 |   composePara : a =\\=> b -> b =\\=> c -> a =\\=> c
195 |   composePara (MkPara p f) (MkPara q g) = MkPara
196 |     (\x => DepHancockProduct (p x) (\ps => q (f.fwd (x ** ps))))
197 |     (!%+ \(x ** (ps ** qs)) =>
198 |       (g.fwd (f.fwd (x ** ps** qs** \cPos =>
199 |         let (bPos, qPos) = g.bwd (f.fwd (x ** ps** qscPos
200 |             (aPos, pPos) = f.bwd (x ** psbPos
201 |         in (aPos, (pPos, qPos))))
202 |
203 |   public export infixr 10 &>>
204 |
205 |   public export
206 |   (&>>) : a =\\=> b -> b =\\=> c -> a =\\=> c
207 |   (&>>) = composePara
208 |
209 |   ||| A predicate witnessing that a parametric additive dependent lens has
210 |   ||| a non-dependent (constant) parameter.
211 |   public export
212 |   data IsNotDependent : DParaAddDLens a b -> Type where
213 |     MkNonDep : (p : AddCont) -> (f : DepHancockProduct a (const p) =%> b) ->
214 |       IsNotDependent {a=a} {b=b} (MkPara (\_ => p) f)
215 |   
216 |   public export
217 |   GetNonDep : (pf : DParaAddDLens a b) ->
218 |     IsNotDependent pf => (pc : AddCont ** DepHancockProduct a (const pc) =%> b)
219 |   GetNonDep _ @{MkNonDep pc f} = (pc ** f)
220 |
221 |   public export
222 |   GetParam : (pf : DParaAddDLens a b) ->
223 |     IsNotDependent pf => AddCont
224 |   GetParam (MkPara (const p) f) @{MkNonDep p f} = p
225 |
226 |   public export
227 |   toHomRepresentation : (pf : DParaAddDLens a b) ->
228 |     IsNotDependent pf =>
229 |     GetParam pf =%> (InternalLensAdditive a b)
230 |   toHomRepresentation (MkPara (const pc) f) @{MkNonDep pc f}
231 |     = !%+ \p => (!%+ \x => (f.fwd (x ** p** \b' => fst (f.bwd (x ** pb')** \l => foldr (\(x ** b'=> pc.Plus p (snd (f.bwd (x ** pb'))) (pc.Zero p) l)
232 |
233 |   public export
234 |   composeNTimes : Nat -> a =\\=> a -> a =\\=> a
235 |   composeNTimes 0 f = id
236 |   composeNTimes 1 f = f -- to get rid of the annoying Unit parameter
237 |   composeNTimes (S k) f = composePara f (composeNTimes k f)
238 |
239 |   ||| Convert a morphism from product container to one from DepHancockProduct
240 |   ||| This witnesses the isomorphism (a >< p) ≅ DepHancockProduct a (const p)
241 |   public export
242 |   fromNonDepProduct : {0 a, p, b : AddCont} ->
243 |     (a >< p) =%> b -> DepHancockProduct a (const p) =%> b
244 |   fromNonDepProduct f = !%+ \(x ** p'=> (%!) f (x, p')
245 |
246 |   public export
247 |   binaryOpToPara : {p : AddCont} ->
248 |     (a >< p) =%> b -> a =\\==> b
249 |   binaryOpToPara f = MkPara p f
250 |
251 |   %hide Data.Container.Base.Morphism.Definition.DependentLenses.(=%>)
252 |
253 | -- public export
254 | -- dependentMap : {t : a -> Type} -> (f : (x : a) -> t x) ->
255 | --   Vect n a -> Vect n (x : a ** t x)
256 | -- dependentMap f [] = []
257 | -- dependentMap f (x :: xs) = (x ** f x) :: dependentMap f xs
258 | -- 
259 | -- public export infixr 10 <$^>
260 | -- public export
261 | -- (<$^>) : {t : a -> Type} -> (f : (x : a) -> t x) ->
262 | --   Vect n a -> Vect n (x : a ** t x)
263 | -- (<$^>) f xs = dependentMap f xs
264 |
265 |
266 | -- composePara_rhs_1 : (p : Vect n Type) -> (q : Vect m Type)
267 | --   -> (a -> All Prelude.id p -> b)
268 | --   -> (b -> All Prelude.id q -> c)
269 | --   -> (a -> All Prelude.id (p ++ q) -> c)
270 | -- composePara_rhs_1 [] [] f g a [] = ?composePara_rhs_1_rhs_2
271 | -- composePara_rhs_1 [] (q :: ws) f g a (pq :: pqs) = ?composePara_rhs_1_rhs_3
272 | -- composePara_rhs_1 (p :: ps) q f g a pq = ?composePara_rhs_1_rhs_1
273 | -- 
274 | -- composePara : Para a n b -> Para b m c -> Para a (n + m) c
275 | -- composePara (MkPara p f) (MkPara q g) = MkPara (p ++ q) (composePara_rhs_1 p q f g)
276 |
277 |