0 | module Data.Tensor.Shape.Shape
2 | import public Decidable.Equality
3 | import Data.Vect.Elem
4 | import Data.Vect.Quantifiers
6 | import Data.Container.Base
7 | import Data.Tensor.Shape.Axis
-- NOTE(review): this excerpt skips listing line 75 and everything after 77,
-- so the remaining constructor(s) and the result type of `(::)` are not shown.
-- Style: `->Type` on the header line is missing a space before `Type`.
||| A tensor shape of rank `rank`: an ordered sequence of named axes
||| (`Axis` values, which carry a `.name` and a `.cont`).
||| Each cons cell carries an auto-implicit `ConsistentWith` proof that the new
||| axis agrees with the axes already in the tail. A name may therefore repeat
||| only when it is paired with the container already recorded for it.
74 | data TensorShape : (rank : Nat) ->Type where
76 | (::) : (a : Axis) -> (as : TensorShape k) ->
77 | (axisConsistent : a `ConsistentWith` as) =>
||| Evidence that axis `a` may be prepended to the shape `as`.
|||
||| * `NewAxis`: no axis of `as` is named `a.name` (`NotElem`).
||| * `ExistingAxis`: `a.name` already occurs in `as` (witnessed by `e`), and
|||   the container that `index` finds under that name equals `a.cont`.
|||
||| A name can therefore repeat within a shape only if it is paired with the
||| container already recorded for it.
84 | data ConsistentWith : Axis -> TensorShape k -> Type where
85 | NewAxis : {0 a : Axis} -> {0 as : TensorShape k} ->
86 | NotElem a.name as ->
87 | a `ConsistentWith` as
88 | ExistingAxis : {0 a : Axis} -> {0 as : TensorShape k} ->
89 | (e : Elem a.name as) ->
90 | index as a.name = a.cont ->
91 | a `ConsistentWith` as
||| Proof that some axis of the shape is named `axisName`.
|||
||| * `Here`: the head axis's name equals `axisName` (an auto-implicit
|||   equality proof).
||| * `There`: the name occurs in the tail; the tail proof is the
|||   auto-implicit `elem`.
|||
||| `There` does not require the head's name to differ from `axisName`. When a
||| name repeats, a shape can therefore have several distinct `Elem` proofs for
||| it. `UniqueElem` is the variant that rules this out.
98 | data Elem : (axisName : AxisName) -> (as : TensorShape rank) -> Type where
99 | Here : {0 as : TensorShape rank} ->
100 | a `ConsistentWith` as =>
101 | axisName = a.name =>
102 | Elem axisName (a :: as)
103 | There : {0 a : Axis} -> {0 as : TensorShape rank} ->
104 | a `ConsistentWith` as =>
105 | (elem : Elem axisName as) =>
106 | Elem axisName (a :: as)
||| Proof that no axis of the shape is named `axisElem`.
|||
||| * `NotInEmpty`: the empty shape contains no names.
||| * `NotInNonEmpty`: decidable equality reports that `axisElem` differs from
|||   the head's name (`neq`), and `axisElem` is absent from the tail
|||   (auto-implicit `notElem`).
|||
||| Unlike the other constructors here, `NotInNonEmpty` takes the head's
||| `ConsistentWith` proof as an explicit argument, not an auto-implicit one.
110 | data NotElem : (axisElem : AxisName) -> (as : TensorShape rank) -> Type where
111 | NotInEmpty : NotElem axisElem []
112 | NotInNonEmpty : {0 axisElem : AxisName} -> {0 a : Axis} ->
113 | {0 as : TensorShape rank} ->
114 | (neq : IsNo (decEq axisElem a.name)) ->
115 | (notElem : NotElem axisElem as) =>
116 | (a `ConsistentWith` as) ->
117 | NotElem axisElem (a :: as)
-- NOTE(review): listing line 126 (the result type) is not part of this excerpt.
-- The recursive call in the `There` clause passes no proof explicitly; the
-- tail's `Elem` proof is left to auto-implicit search.
||| Look up the container of an axis by name.
|||
||| The auto-implicit `Elem` proof selects the axis that is used: `Here`
||| returns the head's `cont`, and `There` recurses into the tail.
123 | index : (shape : TensorShape rank) ->
124 | (axisName : AxisName) ->
125 | (isElem : Elem axisName shape) =>
127 | index (a :: _) axisName @{Here} = a.cont
128 | index (_ :: as) axisName @{There} = index as axisName
-- NOTE(review): the consistency proof of every rebuilt cons cell is postulated
-- with `believe_me`, and nothing checks `newAxisName` against the names already
-- in the shape. If `newAxisName` already names an axis with a different
-- container, the result violates `ConsistentWith`. For example, renaming "i"
-- to "j" in ["i" ~> List, "j" ~> BinTree] does this, yet the type checker is
-- told the result is consistent. Making this sound needs an extra precondition
-- on `newAxisName`.
-- Listing lines 135-136 (the result type and any other clauses) are not shown.
||| Rename every axis named `axisName` to `newAxisName`, using the
||| `Axis`-level `rename`. All other axes are left unchanged. The list of
||| containers is preserved, as `renamePreservesConts` proves.
132 | rename : (shape : TensorShape rank) ->
133 | (axisName : AxisName) ->
134 | (newAxisName : AxisName) ->
137 | rename (a :: as) axisName newAxisName
138 | = (::) (applyWhen (axisName == a.name) (flip rename newAxisName) a) (rename as axisName newAxisName) @{believe_me "consistentAfterRenaming"}
140 | namespace RenameByIndex
-- NOTE(review): in both clauses the consistency proof is postulated with
-- `believe_me`. If `newAxisName` already names another axis with a different
-- container, that postulate is false.
-- Listing lines 141-142 and 146 are not part of this excerpt.
||| Rename the single axis at position `axisIndex` to `newAxisName`, keeping
||| its container (`newAxisName ~> a.cont`). Axes at other positions are left
||| unchanged.
143 | rename : (shape : TensorShape rank) ->
144 | (axisIndex : Fin rank) ->
145 | (newAxisName : AxisName) ->
147 | rename (a :: as) FZ newAxisName
148 | = (::) (newAxisName ~> a.cont) as @{believe_me "consistentAfterRenamingByIndex"}
149 | rename (a :: as) (FS axisIndex) newAxisName
150 | = (::) a (rename as axisIndex newAxisName) @{believe_me "consistentAfterRenamingByIndex"}
154 | namespace Quantifiers
-- NOTE(review): listing lines 155, 157, 159 and 161-163 are not part of this
-- excerpt, so only some of `All`'s constructor fields are shown.
||| `All p shape`: presumably `p` holds for every axis of `shape`. The fields
||| carrying that evidence fall on listing lines not shown here.
||| The cons cell takes the head's `ConsistentWith` proof as an auto-implicit,
||| presumably so that the index `a :: as` can be formed.
156 | data All : (p : Axis -> Type) -> TensorShape k -> Type where
158 | (::) : {0 a : Axis} -> {0 as : TensorShape k} ->
160 | a `ConsistentWith` as =>
||| `Any p shape`: `p` holds for at least one axis of `shape`.
||| `Here` covers the head axis; `There` covers some axis of the tail.
164 | data Any : (p : Axis -> Type) -> TensorShape k -> Type where
165 | Here : {0 a : Axis} -> {0 as : TensorShape k} ->
166 | a `ConsistentWith` as =>
167 | p a -> Any p (a :: as)
168 | There : {0 a : Axis} -> {0 as : TensorShape k} ->
169 | a `ConsistentWith` as =>
170 | Any p as -> Any p (a :: as)
176 | namespace QuantifierOnContainers
-- NOTE(review): listing lines 177, 179 and 183-185 are not part of this
-- excerpt (including `AllC`'s empty-shape constructor).
||| `AllC p shape`: `p` holds for the container (`cont`) of every axis of
||| `shape`. Each cons cell pairs `p a.cont` with the proof for the tail.
178 | data AllC : (p : Cont -> Type) -> TensorShape k -> Type where
180 | (::) : {0 a : Axis} -> {0 as : TensorShape k} ->
181 | p a.cont -> AllC p as ->
182 | a `ConsistentWith` as =>
||| `AnyC p shape`: `p` holds for the container of at least one axis.
||| `Here` covers the head's container; `There` covers some axis of the tail.
186 | data AnyC : (p : Cont -> Type) -> TensorShape k -> Type where
187 | Here : {0 a : Axis} -> {0 as : TensorShape k} ->
188 | a `ConsistentWith` as =>
189 | p a.cont -> AnyC p (a :: as)
190 | There : {0 a : Axis} -> {0 as : TensorShape k} ->
191 | a `ConsistentWith` as =>
192 | AnyC p as -> AnyC p (a :: as)
||| The type of evidence that every axis of `s1` is `ConsistentWith` `s2`,
||| quantified with the axis-level `All`. In other words, each axis of `s1`
||| could individually be prepended to `s2`. The two shapes may have
||| different ranks.
195 | tensorShapesConsistent : TensorShape k -> TensorShape k' -> Type
196 | tensorShapesConsistent s1 s2 = All (\a => a `ConsistentWith` s2) s1
-- NOTE(review): the `[]` clause (listing line 201) is not part of this excerpt.
||| The axes of a shape, in order, as a length-indexed vector. The per-cons
||| `ConsistentWith` proofs are dropped.
200 | toVect : TensorShape k -> Vect k Axis
202 | toVect (a :: as) = a :: toVect as
-- NOTE(review): the `[]` clause (listing line 206) is not part of this excerpt.
||| The axes of a shape, in order, as a plain list. The per-cons
||| `ConsistentWith` proofs are dropped.
205 | toList : TensorShape k -> List Axis
207 | toList (a :: as) = a :: toList as
||| The container of each axis, in axis order.
-- Proofs such as `renamePreservesConts` (`Refl` / `cong`) rely on this
-- definition unfolding by computation, so changing its form can break them.
212 | conts : TensorShape k -> List Cont
213 | conts ts = cont <$> toList ts
||| Renaming axes by name never changes the list of containers.
|||
||| The `with` on `axisName == a.name` lets `applyWhen` inside `rename` reduce.
||| Once the boolean is known, both branches are the same `cong` step.
||| In the `True` branch this type-checks only because the `Axis`-level rename
||| leaves `.cont` definitionally unchanged.
217 | renamePreservesConts : (shape : TensorShape rank) ->
218 | (axisName : AxisName) ->
219 | (newAxisName : AxisName) ->
220 | conts (rename shape axisName newAxisName) = conts shape
221 | renamePreservesConts [] _ _ = Refl
222 | renamePreservesConts (a :: as) axisName newAxisName with (axisName == a.name)
223 | _ | True = cong (a.cont ::) (renamePreservesConts as axisName newAxisName)
224 | _ | False = cong (a.cont ::) (renamePreservesConts as axisName newAxisName)
226 | namespace RenameByIndex
-- NOTE(review): listing lines 227-228 are not part of this excerpt.
||| Renaming the axis at `axisIndex` never changes the list of containers.
|||
||| At `FZ` both sides reduce to the same list, because
||| `newAxisName ~> a.cont` has container `a.cont`, so `Refl` suffices.
||| At `FS` the proof recurses into the tail.
||| No `[]` clause is needed because `Fin 0` is uninhabited.
229 | renamePreservesConts : (shape : TensorShape rank) ->
230 | (axisIndex : Fin rank) ->
231 | (newAxisName : AxisName) ->
232 | conts (rename shape axisIndex newAxisName) = conts shape
233 | renamePreservesConts (a :: as) FZ newAxisName = Refl
234 | renamePreservesConts (a :: as) (FS axisIndex) newAxisName
235 | = cong (a.cont ::) (renamePreservesConts as axisIndex newAxisName)
||| The name of each axis, in axis order, as a length-indexed vector.
239 | axisNames : TensorShape k -> Vect k AxisName
240 | axisNames ts = name <$> toVect ts
-- NOTE(review): despite its name, this returns containers (`Cont`), not numeric
-- sizes. It is the `Vect` counterpart of `conts`.
||| The container of each axis, in axis order, as a length-indexed vector.
244 | axisSizes : TensorShape k -> Vect k Cont
245 | axisSizes ts = cont <$> toVect ts
||| Apply the `size` defined for a list of containers (not shown in this excerpt)
||| to this shape's containers (`conts shape`).
||| The auto-implicit argument requires every container of the shape to
||| satisfy `IsCubical`.
249 | size : (shape : TensorShape k) -> All IsCubical (conts shape) => Nat
250 | size shape = size (conts shape)
-- NOTE(review): two things look off here and should be confirmed (the
-- surrounding listing lines 251-256 are not shown):
--   * In `size`, `IsCubical` is applied to the elements of a `List Cont`, so it
--     appears to be a predicate on containers. `All` over a `TensorShape`,
--     however, is the axis-level quantifier (`Axis -> Type`).
--     `AllC IsCubical shape` looks like what was intended.
--   * `Either _ ()` is always inhabited by `Right ()`, so as written this type
--     carries no evidence at all.
||| Erased (quantity `0`) type family built from a shape; presumably meant as
||| evidence that all of its containers are cubical.
257 | 0 TensorCubEvidence : TensorShape k -> Type
258 | TensorCubEvidence shape = Either (All IsCubical shape) ()
-- NOTE(review): listing line 274 (part of `There`'s signature) is not shown.
||| Proof that exactly one axis of the shape is named `axisName`.
|||
||| * `Here`: the head is named `axisName`, and the name does not occur in the
|||   tail (`NotElem`).
||| * `There`: the head's name is decidably different from `axisName` (`neq`),
|||   and the name occurs exactly once in the tail (`uniqueElem`).
267 | data UniqueElem : AxisName -> TensorShape rank -> Type where
268 | Here : {0 as : TensorShape rank} ->
269 | axisName = ax.name =>
270 | NotElem axisName as =>
271 | ax `ConsistentWith` as =>
272 | UniqueElem axisName (ax :: as)
273 | There : {0 ax : Axis} -> {0 as : TensorShape rank} ->
275 | (uniqueElem : UniqueElem axisName as) =>
276 | (neq : IsNo (decEq axisName ax.name)) =>
277 | ax `ConsistentWith` as =>
278 | UniqueElem axisName (ax :: as)
-- NOTE(review): listing line 284 (the result type) is not part of this excerpt.
||| Weaken a `UniqueElem` proof to a plain `Elem` proof for the same name.
||| `Here` maps to `Here`. `There` maps to `There`, with the tail proof
||| weakened recursively and passed as `Elem`'s `elem` field.
282 | forgetUnique : {as : TensorShape rank} ->
283 | UniqueElem axisName as ->
285 | forgetUnique {as = (a :: as)} Here = Here
286 | forgetUnique {as = (a :: as)} (There {uniqueElem=elem})
287 | = There {elem=forgetUnique elem}
-- NOTE(review): listing line 293 (the result type) is not part of this excerpt.
||| Look up the container of the unique axis named `axisName`.
||| `Here` returns the head's `cont`; `There` recurses into the tail, leaving
||| the tail's `UniqueElem` proof to auto-implicit search.
290 | index : (shape : TensorShape rank) ->
291 | (axisName : AxisName) ->
292 | (uniqueElem : UniqueElem axisName shape) =>
294 | index (a :: _) axisName @{Here} = a.cont
295 | index (_ :: as) axisName @{There} = index as axisName
-- NOTE(review): the `@{ItIsSucc}` pattern presumably matches a constraint
-- declared on listing line 274 (not shown) that makes the tail non-empty. The
-- recursive call needs this, because `removeAxis` only accepts a
-- `TensorShape (S rank)`.
-- Listing line 303 (the result type) is also not part of this excerpt.
||| Remove the unique axis named `toRemove` from a non-empty shape, lowering
||| the rank by one.
|||
||| * `Here`: the head is the axis to remove, so the tail is returned.
||| * `There`: the head is kept and the axis is removed from the tail.
|||   The head's consistency with the shortened tail is postulated by
|||   `consistentAfterRemoving`. Binding that result as `cProof` puts it in
|||   local scope, presumably so auto search can supply it to `(::)`.
299 | removeAxis : {rank : Nat} ->
300 | (toRemove : AxisName) ->
301 | (shape : TensorShape (S rank)) ->
302 | (is : UniqueElem toRemove shape) =>
304 | removeAxis toRemove (_ :: as) @{Here} = as
305 | removeAxis toRemove (a :: as) @{There @{ItIsSucc}}
306 | = let cProof = consistentAfterRemoving a as toRemove
307 | in a :: removeAxis toRemove as
-- NOTE(review): this is not proven. The body is `believe_me`, so the type
-- checker accepts the statement without evidence. Replacing it with a real
-- proof, by cases on the `ConsistentWith` and `UniqueElem` proofs, would
-- remove an unchecked assumption from `removeAxis`.
||| Postulate: if `a` is consistent with `as`, it stays consistent after the
||| unique axis named `toRemove` is removed from `as`.
310 | consistentAfterRemoving : {rank : Nat} ->
311 | (a : Axis) -> (as : TensorShape (S rank)) ->
312 | a `ConsistentWith` as =>
313 | (toRemove : AxisName) ->
314 | (uElem : UniqueElem toRemove as) =>
315 | a `ConsistentWith` (removeAxis toRemove as)
316 | consistentAfterRemoving = believe_me "consistentAfterRemoving"
||| Example: "i" is absent from a shape whose axes are named "g" and "j";
||| the `NotElem` proof is found by `%search`.
318 | notElemExample1 : NotElem "i" ["g" ~> List, "j" ~> BinTree]
319 | notElemExample1 = %search
||| Example: a rank-2 shape with two distinct axis names.
321 | tensorShapeTest1 : TensorShape 2
322 | tensorShapeTest1 = ["batchSize" ~> Vect 128, "seqLen" ~> List]
-- NOTE(review): listing line 325 (the start of this definition) is not shown.
||| Example: a repeated axis name ("batchSize") is allowed because both
||| occurrences use the same container (`Vect 128`), via `ExistingAxis`.
324 | tensorShapeTest2 : TensorShape 3
326 | = ["batchSize" ~> Vect 128, "seqLen" ~> List, "batchSize" ~> Vect 128]
-- NOTE(review): this must be rejected. Presumably it sits inside a `failing`
-- block on a listing line not shown (328).
||| Example: a repeated axis name with different containers (`Vect 128` vs
||| `Vect 13`) violates `ConsistentWith`.
329 | tensorShapeTest3 : TensorShape 2
330 | tensorShapeTest3 = ["batchSize" ~> Vect 128, "batchSize" ~> Vect 13]
||| Example: "j" occurs exactly once. The repeated "i" axes are accepted
||| because they share the container `List`.
332 | uniqueElemExample1 : UniqueElem "j" ["i" ~> List, "j" ~> BinTree, "i" ~> List]
333 | uniqueElemExample1 = %search
-- NOTE(review): expected to fail because "x" does not occur in the shape.
-- Presumably it sits inside a `failing` block on a listing line not shown.
336 | uniqueElemExampleFail : UniqueElem "x" ["i" ~> List, "j" ~> BinTree]
337 | uniqueElemExampleFail = %search
-- NOTE(review): expected to fail, judging by its name. Its signature is spread
-- over several listing lines, and some of them are not part of this excerpt.
339 | uniqueElemExampleFail2 : UniqueElem
"i" [
"i" ~> List
342 | uniqueElemExampleFail2
= %search
||| Example: a rank-3 shape with distinct axis names.
345 | TensorTest1 : TensorShape 3
346 | TensorTest1 = ["batchSize" ~> Vect 128, "seqLen" ~> List, "feat" ~> Vect 64]
||| Example: presumably shows that any axis is consistent with the
||| single-axis shape `[i]`, via `ExistingAxis`. `i.name` occurs there, and
||| its recorded container is `i.cont`.
350 | TensorTest2 : (i : Axis) -> ConsistentWith i [i]
351 | TensorTest2 i = %search
-- NOTE(review): "asdf" is not an axis name of `TensorTest1`, so this search
-- must fail. Presumably it sits inside a `failing` block on a listing line not
-- shown; otherwise the module would not type-check.
354 | TensorElemTest2 : Elem "asdf" TensorTest1
355 | TensorElemTest2 = %search