0 | module Data.Tensor.Axis
  1 |
  2 | import public Decidable.Equality
  3 | import Data.Vect.Elem
  4 | import Data.Vect.Quantifiers
  5 |
  6 | import Data.Container.Base
  7 | import Data.Unique.Vect
  8 | import Misc
  9 |
 10 | {-------------------------------------------------------------------------------
 11 | {-------------------------------------------------------------------------------
 12 |
 13 | ~~~~~~~~~~~~~~~
 14 | Design choices:
 15 | ~~~~~~~~~~~~~~~
 16 |
 17 | 1) Persistent axis names.
 18 |
 19 | Instead of transient axis names (bound within a function using the tensor, erased with the completion of the said function), axis names persist with the lifetime of the tensor.
 20 |
 21 | 2) Axis declarations persist globally, but are only checked for consistency at call sites.
 22 |
 23 | This means that axis names are checked for consistency at each call site, rather than at declaration sites. In a proper programming language we'd track names at declaration sites and raise errors if inconsistencies/duplicates are detected, here we opt for a more pragmatic approach.
 24 |
 25 | 3) Duplicate axis names within a tensor are allowed, as long as they refer to the same container.
 26 |
 27 | Otherwise it would not be clear how to take the diagonal/trace of a matrix while referring only to the axes: they'd have to have different names.
 28 |
 29 | 4) Does tensor contraction allow duplicate axis names?
 30 |
 31 |
 32 |
 33 | Does tensor contraction allow duplicate axis names
 34 |   * in the input (yes, this is what Einsum also allows)
 35 |   * in the output (no, because otherwise its not clear what should happen)
 36 |     * this means that we can't write `einsum("i,i->ii")`
 37 | 3) How does contraction work?
 38 |   3.1) Given `t : Tensor [BatchSize, BatchSize] Double`, what is `dotGeneral t`?
 39 |
 40 | Need to figure out how `reduce name t` acts when:
 41 | 1) `name="BatchSize"` and `t : Tensor [BatchSize, BatchSize] Double`
 42 |   - Should sum up the diagonal?
 43 | 2) `name="BatchSize"` and `t : Tensor [BatchSize] Double`
 44 |   - Should sum up the vector?
 45 | 3) `name="BatchSize"` and `t : Tensor [BatchSize, SeqLen, BatchSize] Double`
 46 |   - Should sum up the diagonal slices of SeqLen
 47 |
 48 | I suppose this is about iterators
 49 | iterating through 
 50 |
 51 |
 52 |
 53 |
 54 | --- Consistency checking: ----------------- 
 55 | We check consistency at each call site.
 56 | Alternatively if we were building a programming languge we'd check consistency with each declaration. That is, writing something like:
 57 | ```idris
 58 | BatchSize1 : Axis
 59 | BatchSize1 = "batchSize" ~> Vect 128
 60 |
 61 | BatchSize2 : Axis
 62 | BatchSize2 = "batchSize" ~> Vect 129
 63 | ```
 64 | would throw an error on the line `BatchSize2 = ...` because we're redeclaring "batchSize" which already exists.
 65 |
 66 | ------------------------------------------- 
 67 |
 68 | Similar projects/ideas:
 69 | * XArray: https://docs.xarray.dev/en/stable/ (persistent axis names)
 70 | * Haliax: https://github.com/marin-community/haliax
 71 |
 72 | -------------------------------------------------------------------------------}
 73 | -------------------------------------------------------------------------------}
 74 |
 75 | ||| The name for an axis is an arbitrary string
 76 | public export
 77 | AxisName : Type
 78 | AxisName = String
 79 |
 80 | public export infixr 0 ~> -- Constructor for container-based axes
 81 | public export infixr 0 ~~> -- 'Constructor' for cubical axes
 82 |
 83 | ||| An axis is a container (the "size" of the axis) together with its name
 84 | public export
 85 | record Axis where
 86 |   constructor (~>)
 87 |   name : AxisName
 88 |   cont : Cont
 89 |
 90 | public export
 91 | rename : Axis -> AxisName -> Axis
 92 | rename a str = str ~> a.cont
 93 |
 94 |
 95 | ||| In some cases we TensorType might need to assign a default name to an axis,
 96 | ||| one which is internal and will not be exposed to the user.
 97 | ||| This is the default name for such cases
 98 | public export
 99 | TTInternalName : AxisName
100 | TTInternalName = "__tensortype_tempaxis__"
101 |
102 |
103 | namespace Cubical
104 |   ||| A "constructor" for cubical axes
105 |   public export
106 |   (~~>) : AxisName -> Nat -> Axis
107 |   (~~>) axisName n = axisName ~> Vect n
108 |
109 |   ||| Follows the pattern of `IsCubical` from `Data.Container.Object.Instances`
110 |   public export
111 |   data IsCubical : Axis -> Type where
112 |     MkIsCubical : (name : AxisName) -> (n : Nat) -> IsCubical (name ~~> n)
113 |
114 |   public export
115 |   dimHelper : {0 a : Axis} -> IsCubical a -> Nat
116 |   dimHelper (MkIsCubical _ n) = n
117 |
118 |   public export
119 |   dim : (0 a : Axis) -> IsCubical a => Nat
120 |   dim _ @{ic} = dimHelper ic
121 |
122 |   public export
123 |   data IsNaperian : Axis -> Type where
124 |     MkIsNaperian : (name : AxisName) -> (pos : Type) ->
125 |       IsNaperian (name ~> Nap pos)
126 |
127 |   public export
128 |   LogHelper : {0 a : Axis} -> IsNaperian a => Type
129 |   LogHelper @{MkIsNaperian _ pos} = pos
130 |
131 |   public export
132 |   Log : (0 a : Axis) -> IsNaperian a => Type
133 |   Log a @{inn} = LogHelper @{inn}
134 |
135 |   public export
136 |   toContNaperian : {0 a : Axis} -> IsNaperian a ->  IsNaperian a.cont
137 |   toContNaperian (MkIsNaperian name pos) = MkIsNaperian pos
138 |
139 |   public export
140 |   cubicalShapeHelper : {0 shape : Vect r Axis} ->
141 |     All IsCubical shape -> List Nat
142 |   cubicalShapeHelper [] = []
143 |   cubicalShapeHelper (ic :: ns) = dimHelper ic :: cubicalShapeHelper ns
144 |
145 |   ||| Given a list of cubical axes, return the list of their dimensions
146 |   public export
147 |   cubicalShape : (0 shape : Vect r Axis) -> All IsCubical shape => List Nat
148 |   cubicalShape _ @{ac} = cubicalShapeHelper ac
149 |
150 |   ||| Size of a cubical tensor, i.e. its number of elements
151 |   public export
152 |   size : (0 shape : Vect r Axis) -> (ac : All IsCubical shape) => Nat
153 |   size ss = prod (cubicalShape ss)
154 |
155 |
156 | namespace TensorShape
157 |   mutual
158 |     public export
159 |     data TensorShape : (rank : Nat) ->Type where
160 |       Nil : TensorShape 0
161 |       (::) : (a : Axis) -> (as : TensorShape k) ->
162 |         (ac : NewAxisConsistent a as) =>
163 |         TensorShape (S k)
164 |
165 |     public export
166 |     toVect : TensorShape k -> Vect k Axis
167 |     toVect [] = []
168 |     toVect (a :: as) = a :: toVect as
169 |
170 |     public export
171 |     data NewAxisConsistent : Axis -> TensorShape k -> Type where
172 |       NewAxis : {0 a : Axis} -> {0 as : TensorShape k} ->
173 |         NotElem a.name (Axis.name <$> toVect as) ->
174 |         NewAxisConsistent a as
175 |       ExistingAxis : {0 a : Axis} -> {0 as : TensorShape k} ->
176 |         (e : Elem a.name (Axis.name <$> toVect as)) ->
177 |         (index (elemToFin e) (toVect as)).cont = a.cont ->
178 |         NewAxisConsistent a as
179 |
180 |   public export
181 |   toList : TensorShape k -> List Axis
182 |   toList [] = []
183 |   toList (a :: as) = a :: toList as
184 |
185 |   ||| Convenience function, turns it also into a list
186 |   ||| Because `Data.Container` uses lists with tensors
187 |   public export
188 |   conts : TensorShape k -> List Cont
189 |   conts ts = cont <$> toList ts
190 |
191 |   ||| Names of the axes in a tensor shape
192 |   public export
193 |   axisNames : TensorShape k -> Vect k AxisName
194 |   axisNames ts = name <$> toVect ts
195 |
196 |   ||| Sizes of the axes in a tensor shape
197 |   public export
198 |   axisSizes : TensorShape k -> Vect k Cont
199 |   axisSizes ts = cont <$> toVect ts
200 |
201 |   ||| Size of a tensor shape, i.e. its number of elements
202 |   public export
203 |   size : (shape : TensorShape k) -> All IsCubical (conts shape) => Nat
204 |   size shape = size (conts shape)
205 |
206 |   test1 : TensorShape 2
207 |   test1 = ["batchSize" ~> Vect 128, "seqLen" ~> List]
208 |
209 |   test2 : TensorShape 3
210 |   test2 = ["batchSize" ~> Vect 128, "seqLen" ~> List, "batchSize" ~> Vect 128]
211 |
212 |   failing
213 |     test3 : TensorShape 2
214 |     test3 = ["batchSize" ~> Vect 128, "batchSize" ~> Vect 13]
215 |
216 |   -- ||| If an axis `i` can be added into a singleton list `[j]`, then
217 |   -- ||| the axis `j` can be added into a singleton list `[i]`
218 |   -- public export
219 |   -- axisConsistentSym : {i, j : Axis} ->
220 |   --   NewAxisConsistent i [j] -> NewAxisConsistent j [i]
221 |   -- axisConsistentSym (NewAxis ne) = NewAxis (notElemSym ne)
222 |   -- -- For some reason we can't pattern match on `Here`? The proof should still 
223 |   -- -- be fine... 
224 |   -- axisConsistentSym (ExistingAxis (There Here) _) impossible
225 |   -- axisConsistentSym (ExistingAxis (There (There later)) _) impossible
226 |
227 |   ||| Proof that an axis name appears in a tensor shape n times
228 |   ||| The proof indirectly carries data of the exact indices where it appears
229 |   ||| Notably, can appear zero times, this case is needed for recursion
230 |   public export
231 |   data InShape : AxisName -> TensorShape k -> Nat -> Type where
232 |     Here : {as : TensorShape k} -> InShape axisName as n =>
233 |       NewAxisConsistent (axisName ~> a) as =>
234 |       InShape axisName ((axisName ~> a) :: as) (S n)
235 |     There : {as : TensorShape k} -> InShape axisName as n =>
236 |       NewAxisConsistent a as =>
237 |       InShape axisName (a :: as) n
238 |
239 |
240 |   ||| Recovers the axis from a shape given its name, and a prof that it is there
241 |   ||| Recovers the first occurence
242 |   public export
243 |   (.getByName) : (shape : TensorShape k) ->
244 |     (axisName : AxisName) -> 
245 |     (inShape : InShape axisName shape n) ->
246 |     IsSucc n =>
247 |     Axis
248 |   (.getByName) ((axisName ~> a) :: as) axisName Here = axisName ~> a
249 |   (.getByName) (a :: as) axisName (There @{is}) = as.getByName axisName is
250 |
251 |   public export
252 |   removeAllOccurrences : {k, rank : Nat} ->(shape : TensorShape rank) ->
253 |     (toDelete : AxisName) ->
254 |     (inShape : InShape toDelete shape k) =>
255 |     (m : Nat ** TensorShape m)
256 |   removeAllOccurrences {k=0} shape toDelete = (rank ** shape)
257 |   removeAllOccurrences ((toDelete ~> a) :: ss) toDelete @{Here @{is}}
258 |     = removeAllOccurrences ss toDelete @{is}
259 |   removeAllOccurrences (s :: ss) toDelete @{There @{is}}
260 |     = let (m ** ss'= removeAllOccurrences ss toDelete @{is}
261 |       in (S m ** (::) {ac=(believe_me ())} s ss'-- should write this later
262 |
263 |
264 |   ||| TODO rethink this function?
265 |   ||| In a tensor shape removes all but the first occurence of an axis
266 |   ||| removeDuplicates ["x" ~> 1, "y" ~> 3, "x" ~> 1] "x" = ["x" ~> 1, "y" ~> 1]
267 |   public export
268 |   removeDuplicates : {k, rank : Nat} -> (shape : TensorShape rank) ->
269 |     (axisName : AxisName) ->
270 |     (inShape : InShape axisName shape k) =>
271 |     IsSucc k =>
272 |     (m : Nat ** TensorShape m)
273 |   removeDuplicates shape axisName {inShape} {k = 1}
274 |     = (rank ** shape)
275 |   removeDuplicates ((_ ~> a) :: as) axisName {inShape = Here @{is}} {k = (S (S k))}
276 |     = removeDuplicates as axisName {inShape=is}
277 |   removeDuplicates (s :: as) axisName {inShape = There @{is}} {k = (S (S k))}
278 |     = let (m ** as'= removeDuplicates as axisName {inShape=is}
279 |       in (S m ** (::) {ac=(believe_me ())} s as')