0 | module Data.Tensor.Axis
2 | import public Decidable.Equality
3 | import Data.Vect.Elem
4 | import Data.Vect.Quantifiers
6 | import Data.Container.Base
7 | import Data.Unique.Vect
80 | public export infixr 0 ~>
81 | public export infixr 0 ~~>
-- Give an axis a new name while keeping its container: the result is
-- `newName ~> ax.cont`.
rename : Axis -> AxisName -> Axis
rename ax newName = newName ~> ax.cont
-- Reserved axis name, apparently for temporary axes (judging by the literal).
-- The double-underscore wrapping makes a clash with user-chosen names
-- unlikely, but nothing visible here forbids a user axis with this name.
TTInternalName : AxisName
TTInternalName = "__tensortype_tempaxis__"
-- Fixed-length axis shorthand: `label ~~> len` is the axis named `label`
-- whose container is `Vect len`.
(~~>) : AxisName -> Nat -> Axis
label ~~> len = label ~> Vect len
-- Evidence that an axis has the form `name ~~> n`, i.e. its container is
-- `Vect n` for a known length `n`.
data IsCubical : Axis -> Type where
  -- The only constructor: records the axis name and its length `n`.
  MkIsCubical : (name : AxisName) -> (n : Nat) -> IsCubical (name ~~> n)
-- Read the length `n` out of a proof that the axis is `name ~~> n`.
-- The axis itself is erased; only the proof is inspected.
dimHelper : {0 a : Axis} -> IsCubical a -> Nat
dimHelper prf = case prf of
  MkIsCubical _ len => len
-- Length of a cubical axis. The axis argument is erased (quantity 0);
-- the length comes from the auto-implicit `IsCubical` proof.
dim : (0 a : Axis) -> IsCubical a => Nat
dim _ @{cubical} = dimHelper cubical
-- Evidence that an axis has the form `name ~> Nap pos`, i.e. its container
-- is `Nap pos` for some position type `pos`.
data IsNaperian : Axis -> Type where
  -- The only constructor: records the axis name and the position type.
  MkIsNaperian : (name : AxisName) -> (pos : Type) ->
                 IsNaperian (name ~> Nap pos)
-- Position type `pos` of an axis `name ~> Nap pos`, read off the
-- auto-implicit `IsNaperian` proof.
LogHelper : {0 a : Axis} -> IsNaperian a => Type
LogHelper @{MkIsNaperian _ positions} = positions
-- Position type ("logarithm") of a Naperian axis: for `name ~> Nap pos`
-- this is `pos`. The axis is erased; the work is done by `LogHelper`
-- on the auto-implicit proof.
Log : (0 a : Axis) -> IsNaperian a => Type
Log _ @{naperian} = LogHelper @{naperian}
-- Drop the axis name: turn Naperian evidence for an axis into Naperian
-- evidence for its underlying container `a.cont`.
-- NOTE(review): the result uses a container-level `IsNaperian` whose
-- constructor takes only `pos`; it is not defined in this chunk
-- (presumably Data.Container.Base) -- confirm.
toContNaperian : {0 a : Axis} -> IsNaperian a -> IsNaperian a.cont
toContNaperian (MkIsNaperian _ pos) = MkIsNaperian pos
-- Walk an `All IsCubical` proof and collect each axis length, preserving
-- axis order. The shape itself is erased.
cubicalShapeHelper : {0 shape : Vect r Axis} ->
                     All IsCubical shape -> List Nat
cubicalShapeHelper [] = []
cubicalShapeHelper (prf :: prfs) = dimHelper prf :: cubicalShapeHelper prfs
-- Axis lengths of a shape whose axes are all cubical, in axis order.
-- The shape is erased; the lengths come from the auto-implicit proof.
cubicalShape : (0 shape : Vect r Axis) -> All IsCubical shape => List Nat
cubicalShape _ @{allCubical} = cubicalShapeHelper allCubical
-- Combine the axis lengths of a cubical shape with `prod`.
-- NOTE(review): `prod` is not defined in this chunk; presumably it is the
-- product of a `List Nat` from an imported module -- confirm.
size : (0 shape : Vect r Axis) -> (ac : All IsCubical shape) => Nat
size _ @{ac} = prod (cubicalShapeHelper ac)
156 | namespace TensorShape
159 | data TensorShape : (rank : Nat) ->Type where
160 | Nil : TensorShape 0
161 | (::) : (a : Axis) -> (as : TensorShape k) ->
162 | (ac : NewAxisConsistent a as) =>
166 | toVect : TensorShape k -> Vect k Axis
168 | toVect (a :: as) = a :: toVect as
-- Evidence that axis `a` may be consed onto shape `as`. Either its name
-- does not yet occur in `as`, or the name occurs and the axis at that
-- occurrence has the same container as `a`.
data NewAxisConsistent : Axis -> TensorShape k -> Type where
  -- `a.name` is absent from the names of `as`.
  NewAxis : {0 a : Axis} -> {0 as : TensorShape k} ->
            NotElem a.name (Axis.name <$> toVect as) ->
            NewAxisConsistent a as
  -- `e` locates `a.name` among the names of `as`. The axis at that position
  -- (`elemToFin e`) must carry the same container as `a`.
  ExistingAxis : {0 a : Axis} -> {0 as : TensorShape k} ->
                 (e : Elem a.name (Axis.name <$> toVect as)) ->
                 (index (elemToFin e) (toVect as)).cont = a.cont ->
                 NewAxisConsistent a as
181 | toList : TensorShape k -> List Axis
183 | toList (a :: as) = a :: toList as
-- Container of every axis, in axis order, as a plain List.
conts : TensorShape k -> List Cont
conts shape = map (.cont) (toList shape)
-- Name of every axis, in axis order. Repeated names are kept as-is.
axisNames : TensorShape k -> Vect k AxisName
axisNames shape = map (.name) (toVect shape)
-- Container of every axis, in axis order, as a length-indexed Vect.
-- NOTE(review): despite the name this yields containers (`Cont`), not
-- numeric sizes.
axisSizes : TensorShape k -> Vect k Cont
axisSizes shape = map (.cont) (toVect shape)
-- Size of a tensor shape, obtained by applying `size` to its list of
-- containers.
-- NOTE(review): both `All IsCubical (conts shape)` and `size (conts s)` range
-- over a `List Cont`. They presumably resolve to container-level
-- `IsCubical`/`size` (e.g. from Data.Container.Base) rather than the
-- Axis-level ones in this module -- confirm.
size : (shape : TensorShape k) -> All IsCubical (conts shape) => Nat
size s = size (conts s)
-- Example: two distinct axis names. Neither name repeats, so every cons
-- can only be justified by `NewAxis`.
test1 : TensorShape 2
test1 = ("batchSize" ~> Vect 128) :: ("seqLen" ~> List) :: Nil
-- Example: "batchSize" appears twice with the same container `Vect 128`.
-- The outer cons is accepted via `ExistingAxis`.
test2 : TensorShape 3
test2 = ("batchSize" ~> Vect 128)
     :: ("seqLen" ~> List)
     :: ("batchSize" ~> Vect 128)
     :: Nil
-- Negative example: "batchSize" is reused with a different container
-- (`Vect 128` vs `Vect 13`). Neither `NewAxis` nor `ExistingAxis` can be
-- proved for the outer cons, so this is expected NOT to typecheck.
test3 : TensorShape 2
test3 = ("batchSize" ~> Vect 128) :: ("batchSize" ~> Vect 13) :: Nil
231 | data InShape : AxisName -> TensorShape k -> Nat -> Type where
232 | Here : {as : TensorShape k} -> InShape axisName as n =>
233 | NewAxisConsistent (axisName ~> a) as =>
234 | InShape axisName ((axisName ~> a) :: as) (S n)
235 | There : {as : TensorShape k} -> InShape axisName as n =>
236 | NewAxisConsistent a as =>
237 | InShape axisName (a :: as) n
243 | (.getByName) : (shape : TensorShape k) ->
244 | (axisName : AxisName) ->
245 | (inShape : InShape axisName shape n) ->
248 | (.getByName) ((axisName ~> a) :: as) axisName Here = axisName ~> a
249 | (.getByName) (a :: as) axisName (There @{is}) = as.getByName axisName is
-- Remove the axes named `toDelete` that the `InShape` proof points at.
-- Returns the remaining shape together with its rank.
--   * k = 0 : nothing left to remove; the shape is returned unchanged.
--   * Here  : the head is named `toDelete`; drop it and recurse on the tail.
--   * There : keep the head and recurse on the tail.
-- NOTE(review): `believe_me ()` fakes the `NewAxisConsistent` proof for the
-- kept head against the shortened tail. That proof is unrestricted (not
-- quantity 0) in `(::)`, so runtime code that inspects it will see `()`
-- instead of a real `NewAxis`/`ExistingAxis` value.
-- NOTE(review): whether *all* occurrences are removed depends on `InShape`.
-- As written, its `There` constructor does not require the head's name to
-- differ from `toDelete`, so a proof may step over a matching axis.
removeAllOccurrences : {k, rank : Nat} -> (shape : TensorShape rank) ->
                       (toDelete : AxisName) ->
                       (inShape : InShape toDelete shape k) =>
                       (m : Nat ** TensorShape m)
removeAllOccurrences {k = 0} shape toDelete = (rank ** shape)
removeAllOccurrences ((toDelete ~> a) :: rest) toDelete @{Here @{prf}} =
  removeAllOccurrences rest toDelete @{prf}
removeAllOccurrences (keep :: rest) toDelete @{There @{prf}} =
  case removeAllOccurrences rest toDelete @{prf} of
    (m ** rest') => (S m ** (::) {ac = (believe_me ())} keep rest')
268 | removeDuplicates : {k, rank : Nat} -> (shape : TensorShape rank) ->
269 | (axisName : AxisName) ->
270 | (inShape : InShape axisName shape k) =>
272 | (m : Nat ** TensorShape m)
273 | removeDuplicates shape axisName {inShape} {k = 1}
275 | removeDuplicates ((_ ~> a) :: as) axisName {inShape = Here @{is}} {k = (S (S k))}
276 | = removeDuplicates as axisName {inShape=is}
277 | removeDuplicates (s :: as) axisName {inShape = There @{is}} {k = (S (S k))}
278 | = let (
m ** as')
= removeDuplicates as axisName {inShape=is}
279 | in (
S m ** (::) {ac=(believe_me ())} s as')