0 | module Data.Tensor.Utils
5 | import Data.Tensor.Tensor
6 | import Data.Container.SubTerm
-- Short aliases for tensors of rank 0, 1 and 2.
27 | namespace CommonNames
||| A rank-0 tensor: a `Tensor` over the empty axis list with elements of type `a`.
29 | Scalar : (a : Type) -> Type
30 | Scalar a = Tensor [] a
||| A rank-1 tensor over the single axis `c`.
33 | Vector : (c : Axis) -> (a : Type) -> Type
34 | Vector c a = Tensor [c] a
||| A rank-2 tensor with axes `[row, col]`, in that order.
||| The auto-implicit `NewAxisConsistent row [col]` must be solvable for `row` to be
||| placed in front of `[col]`; its meaning is defined elsewhere
||| (NOTE(review): presumably it rejects conflicting axes — confirm).
37 | Matrix : (row, col : Axis) -> NewAxisConsistent row [col] => (a : Type) -> Type
38 | Matrix row col a = Tensor [row, col] a
40 | namespace FillZerosOnes
-- NOTE(review): listing line 44 (the rest of `fill`'s type) is elided.
-- NOTE(review): the visible body of `fill` does no arithmetic, so `Num a` may be
-- unnecessary here; confirm against the elided line and `tensorReplicate`.
||| A tensor of the implicit `shape` built by `tensorReplicate x`, which presumably
||| places `x` at every position. Every container in the shape needs a
||| `TensorMonoid` instance.
42 | fill : Num a => {shape : TensorShape rank} ->
43 | All TensorMonoid (conts shape) =>
45 | fill x = tensorReplicate x
-- NOTE(review): listing line 50 (the rest of `zeros`'s type) is elided.
||| `fill` with the value `fromInteger 0`.
48 | zeros : Num a => {shape : TensorShape rank} ->
49 | All TensorMonoid (conts shape) =>
51 | zeros = fill (fromInteger 0)
-- NOTE(review): listing line 56 (the rest of `ones`'s type) is elided.
||| `fill` with the value `fromInteger 1`.
54 | ones : Num a => {shape : TensorShape rank} ->
55 | All TensorMonoid (conts shape) =>
57 | ones = fill (fromInteger 1)
-- NOTE(review): listing line 62 (the result type) is elided. `identity` maps
-- `fromBool` over this value to obtain a `Tensor [c, c] a`, so the type is
-- presumably `Tensor [c, c] Bool`.
||| Boolean identity pattern for the cubical axis `c`. The positions of shape `()`
||| are combined with themselves by `outerWith (==)`. Assuming `outerWith` forms an
||| outer product, an entry is `True` exactly when its two positions are equal.
61 | identityBool : {0 c : Axis} -> IsCubical c =>
-- The match on `MkIsCubical _ n` binds `n` (presumably the axis length). The body
-- does not use `n`; presumably the match is what lets `positions {sh=()}` typecheck.
63 | identityBool @{MkIsCubical _ n}
64 | = outerWith (==) (positions {sh=()}) (positions {sh=()})
||| Identity matrix over the cubical axis `c`. It is `identityBool` with each entry
||| converted by `fromBool` (presumably `True` to 1 and `False` to 0).
69 | identity : {0 c : Axis} -> IsCubical c =>
70 | Num a => Tensor [c, c] a
-- `n` from the `MkIsCubical` match is not used in the body; presumably the match
-- refines `c` so that `identityBool`'s auto-implicit can be found — confirm.
71 | identity @{MkIsCubical _ n} = fromBool <$> identityBool
||| Tensor over the cubical axis `stop` whose entry at each position is that
||| position's index. The index comes from `finToNat` and is then `cast` into `a`,
||| presumably giving 0, 1, ..., n - 1 where `n` is the length of `stop`.
86 | arange : {0 stop : Axis} -> IsCubical stop =>
87 | Cast Nat a => Tensor [stop] a
-- `cast . finToNat` is applied to every position via `<$>`.
88 | arange @{MkIsCubical _ n} = cast . finToNat <$> positions {sh=()}
-- NOTE(review): listing line 94 (presumably the binder for `stop`) is elided.
||| Offset range: entry k is `cast (k + n)`, where `n` is bound from `start`'s
||| `IsCubical` proof (presumably `dim start`). The result axis has `stop`'s name and
||| length `minus (dim stop) (dim start)`, so the values presumably run from
||| `dim start` up to, but not including, `dim stop`.
||| `minus` on `Nat` truncates at 0, so if `start` is longer than `stop` the result
||| axis has length 0.
||| `start` defaults to `TTInternalName ~~> 0`, a length-0 axis. With that default
||| the values presumably coincide with `arange`.
93 | arangeFromTo : {default (TTInternalName ~~> 0) start : Axis} ->
95 | (cStart : IsCubical start) => (cStop : IsCubical stop) =>
96 | Cast Nat a => Tensor [stop.name ~~> minus (dim stop) (dim start)] a
-- `m` (bound from `stop`'s proof) is not used in the body.
97 | arangeFromTo {cStart=(MkIsCubical _ n)} {cStop=(MkIsCubical _ m)}
98 | = cast . (+n) . finToNat <$> positions {sh=()}
-- NOTE(review): only the type of `flip` is in this listing; its definition is elided.
-- NOTE(review): this name overlaps Prelude's `flip`. Other code in this module uses
-- `flip isSubTerm`, presumably meaning Prelude's, so Idris must disambiguate by type.
||| Takes an axis index `axis` and a tensor, and returns a tensor of the same shape.
||| The axis selected by `axis` must be cubical. Presumably this reverses the
||| entries along that axis — confirm against the definition.
103 | flip : {shape : TensorShape rank} ->
104 | (axis : Fin rank) ->
105 | IsCubical (index axis (toVect shape)) =>
106 | Tensor shape a -> Tensor shape a
109 | namespace Concatenate
||| Concatenates two tensors along their leading axes `x` and `y`. Both axes must be
||| cubical, and the two tensors must share the trailing `shape`. The result's
||| leading axis is named `l` and has length `dim x + dim y`; `t` comes first,
||| followed by `t'`.
113 | concat : {shape : TensorShape rank} -> {l : AxisName} ->
114 | {x, y : Axis} -> IsCubical x => IsCubical y =>
115 | NewAxisConsistent (l ~~> dim x + dim y) shape =>
116 | NewAxisConsistent x shape =>
117 | NewAxisConsistent y shape =>
118 | Tensor (x :: shape) a ->
119 | Tensor (y :: shape) a ->
120 | Tensor ((l ~~> dim x + dim y) :: shape) a
-- `extractTopExt` presumably splits off the leading axis. The two pieces are
-- appended with `++` and rebuilt into a tensor with `embedTopExt`.
121 | concat @{MkIsCubical _ n} @{MkIsCubical _ m} t t'
122 | = embedTopExt $
extractTopExt t ++ extractTopExt t'
-- NOTE(review): only the type is in this listing; the definition is elided.
||| Returns a `Nat` for a tensor of the implicit `shape`. This is presumably the
||| number of entries — confirm against the definition.
131 | size : {shape : TensorShape rank} ->
132 | Tensor shape a -> Nat
||| Size computed from the shape alone, for shapes whose axes are all cubical.
||| The tensor argument has quantity 0, so it is erased and never inspected. The
||| result is `size` applied to `toVect shape`.
137 | size : {shape : TensorShape rank} ->
138 | All IsCubical (toVect shape) =>
139 | (0 _ : Tensor shape a) -> Nat
140 | size {shape} _ = size (toVect shape)
-- NOTE(review): only the type is in this listing; the definition is elided.
||| Turns a tensor into a `List`, presumably collecting its entries through its
||| `Foldable` instance. `shape` has quantity 0, so it is not available at run time.
147 | flatten : {0 shape : TensorShape rank} ->
148 | Foldable (Tensor shape) =>
149 | Tensor shape a -> List a
||| Flattens a tensor whose containers are all cubical into a `Vect` of length
||| `size shape`. Reading the pipeline right to left:
|||   1. unwrap the tensor with `GetT`;
|||   2. map `flattenCubical DefaultLayoutOrder` over the result with `extMap`;
|||   3. convert with `toVect`.
||| The entry order is whatever `DefaultLayoutOrder` specifies
||| (NOTE(review): presumably row-major — confirm).
156 | flatten : {shape : TensorShape rank} ->
157 | All IsCubical (conts shape) =>
158 | Tensor shape a -> Vect (size shape) a
159 | flatten = toVect . extMap (flattenCubical DefaultLayoutOrder) . GetT
||| Largest entry of a tensor under `Ord a`. The tensor is `flatten`ed to a list and
||| passed to `maxInList`. Because `shape` is erased here, this presumably uses the
||| `Foldable`-based `flatten` that returns a `List`.
||| The result is presumably `Nothing` for a tensor with no entries.
165 | max : {0 shape : TensorShape rank} ->
166 | Foldable (Tensor shape) => Ord a =>
167 | Tensor shape a -> Maybe a
168 | max = maxInList . flatten
||| Vector over the cubical axis `c` with `1` at index `i`. It starts from `zeros` and
||| uses `set` to write `1` at position `[i]`, so every other entry keeps its `zeros`
||| value.
172 | oneHot : {0 c : Axis} -> IsCubical c =>
173 | (i : Fin (dim c)) ->
174 | Num a => Tensor [c] a
175 | oneHot @{MkIsCubical _ n} i = set zeros [i] 1
177 | namespace Triangular
||| Boolean triangular mask over `[c, c]`, computed at container shape `sh`.
||| The positions of `sh` are combined pairwise by `outerWith (flip isSubTerm)`.
||| Assuming `outerWith f` applies `f` to a row position and then a column position,
||| the entry for positions (i, j) is `isSubTerm j i`.
||| The order on positions is an `MOrd`, supplied via `InterfaceOnPositions`.
180 | cTriBool : {c : Axis} ->
181 | (ip : InterfaceOnPositions c.cont MOrd) =>
182 | TensorMonoid c.cont =>
183 | (sh : c.cont.Shp) -> Tensor [c, c] Bool
184 | cTriBool {ip = MkI {p}} sh
185 | = let cPositions = positions {sh=sh}
-- `pp` is not referenced textually below. Presumably it is bound to bring the `MOrd`
-- instance `p sh` into scope for `isSubTerm` — confirm it is picked up.
186 | pp : MOrd (c.cont.Pos sh) := p sh
-- `flip` here is presumably Prelude's `flip`, not this module's tensor `flip`.
187 | in outerWith (flip isSubTerm) cPositions cPositions
-- NOTE(review): listing line 191 (the result type) is elided. `tri` maps `fromBool`
-- over this value to get `Tensor [c, c] a`, so the type is presumably
-- `Tensor [c, c] Bool`.
||| `cTriBool` for a cubical axis, evaluated at the container shape `()`.
190 | triBool : {0 c : Axis} -> IsCubical c =>
192 | triBool @{MkIsCubical _ n} = cTriBool ()
||| Numeric triangular mask: `triBool` with each entry converted by `fromBool`.
198 | tri : {0 c : Axis} -> IsCubical c =>
199 | Num a => Tensor [c, c] a
200 | tri @{MkIsCubical _ n} = fromBool <$> triBool
||| Multiplies its argument by the mask `tri`. Entries where `triBool` is `True` are
||| kept, and the rest are zeroed, presumably (this assumes an elementwise `*` and
||| `fromBool False = 0`). Which entries survive, including whether the diagonal
||| does, is decided by `isSubTerm` in `cTriBool`.
205 | lowerTriangular : {0 c : Axis} -> IsCubical c =>
206 | Num a => Tensor [c, c] a -> Tensor [c, c] a
-- Right section: `(* tri) m = m * tri`.
207 | lowerTriangular @{MkIsCubical _ n} = (* tri)
||| Multiplies its argument by the complement of `triBool`. Each entry of the mask is
||| negated with `not`, then converted with `fromBool`.
-- NOTE(review): this mask is the exact complement of the one `lowerTriangular` uses
-- (`fromBool <$> triBool`), so exactly one of the two functions keeps the diagonal.
-- If `isSubTerm` is reflexive, `lowerTriangular` keeps the diagonal and this one
-- zeroes it, giving a strict upper triangle. numpy's `triu`, by contrast, keeps the
-- diagonal. Confirm the asymmetry is intended.
-- With an elementwise `*` and a 0/1 `fromBool`, the asymmetry does give
-- `lowerTriangular m + upperTriangular m == m`.
212 | upperTriangular : {0 c : Axis} -> IsCubical c =>
213 | Num a => Tensor [c, c] a -> Tensor [c, c] a
214 | upperTriangular @{MkIsCubical _ n} = (* ((fromBool . not) <$> triBool))
-- NOTE(review): listing lines 222-223 are elided. They hold the rest of the type,
-- presumably the `fill` argument and the result.
-- NOTE(review): the parameter `fill` shadows `FillZerosOnes.fill` inside the body,
-- and the visible body does not use `Num a`.
||| Replaces the entries of `t` where `mask` is `True` with `fill`. Where `mask` is
||| `False`, the entry of `t` is kept. The mask and the tensor are paired entrywise
||| by `liftA2Tensor mask t`, which yields `(maskVal, tVal)` tuples.
218 | maskedFill : {shape : TensorShape rank} ->
219 | Num a => All TensorMonoid (conts shape) =>
220 | (t : Tensor shape a) ->
221 | (mask : Tensor shape Bool) ->
224 | maskedFill t mask fill = liftA2Tensor mask t <&>
225 | (\(maskVal, tVal) => if maskVal then fill else tVal)
-- NOTE(review): only the type is in this listing; the definition is elided.
||| Reduces a tensor of the implicit `shape` to a single value using the
||| `Algebra (Tensor shape) a` instance.
229 | sum : {shape : TensorShape rank} ->
230 | Algebra (Tensor shape) a =>
231 | Tensor shape a -> a
-- NOTE(review): listing lines 237-238 are elided. They presumably hold the
-- constraints that provide `/` and `cast`.
||| Arithmetic mean: `sum t` divided by `Cubical.size t` (presumably the number of
||| entries), which is cast into `a`.
-- NOTE(review): for a tensor with no entries this divides by `cast 0`; the result
-- then depends on how `a`'s `/` handles zero.
235 | mean : {shape : TensorShape rank} ->
236 | All IsCubical (toVect shape) =>
239 | Algebra (Tensor shape) a =>
240 | Tensor shape a -> a
241 | mean t = sum t / cast (Cubical.size t)
-- NOTE(review): listing line 246 is elided. It is presumably the argument and result
-- type, e.g. `Tensor [c] a -> a`.
||| Variance: the `mean` of the squared deviations from `mean t`. Because `mean`
||| divides by the number of entries rather than one less, this is the population
||| (biased) estimator.
244 | variance : {c : Axis} -> IsCubical c =>
245 | Neg a => Fractional a => Cast Nat a =>
247 | variance @{MkIsCubical _ n} t =
-- `pure (mean t)` presumably replicates the scalar mean over the tensor's shape, so
-- the subtraction and the product below are entrywise.
248 | let inputMinusMean = t - pure (mean t)
249 | in mean (inputMinusMean * inputMinusMean)
-- NOTE(review): only the type is in this listing; the definition is elided.
||| Maps a vector over the cubical axis `c` to another vector over `c`. Presumably it
||| computes the running (prefix) sums — confirm against the definition.
252 | cumulativeSum : {c : Axis} -> Num a =>
253 | (isCubical : IsCubical c) =>
254 | Tensor [c] a -> Tensor [c] a
263 | namespace Traversals
||| Converts a tensor over a `BinTreeNode`-shaped axis into one over a `List`-shaped
||| axis, in three steps:
|||   1. convert to an extension with `vectorToExt`;
|||   2. map `BinTreeNode.inorder` over it with `extMap`;
|||   3. convert back with `extToVector`.
||| The result presumably lists the node values in in-order. The result axis name `l`
||| is an implicit argument fixed by the caller or by unification.
265 | inorder : Tensor [b ~> BinTreeNode] a -> Tensor [l ~> List] a
266 | inorder = extToVector . extMap BinTreeNode.inorder . vectorToExt
-- NOTE(review): listing lines 269 and 271 are elided. They hold the start of this
-- instance header and one of its constraints, so only part of the instance is shown.
270 | {shape : TensorShape rank} ->
272 | Applicative (Tensor shape) =>
273 | Traversable (Tensor shape) =>
274 | Random (Tensor shape a) where
-- `pure randomIO` places the same action at every position. `sequence` then runs it
-- once per position and collects the results into a tensor.
275 | randomIO = sequence (pure randomIO)
-- TODO(review): `randomRIO` is still the hole `?qhwhwh`, so reaching it at run time
-- fails. One option is to pair the bounds entrywise with the
-- `Applicative (Tensor shape)` instance, then `sequence` one `randomRIO` per position.
276 | randomRIO = ?qhwhwh
-- Instance-resolution checks: these declarations only assert that the listed
-- instances exist. Their definitions are elided from this listing
-- (NOTE(review): presumably `%search`).
279 | tta : Applicative (Tensor ["a" ~~> 1])
282 | ttt : Traversable (Tensor ["b" ~~> 1])
285 | ttd : Random Double
||| Random tensor of the given `shape`, produced in any `HasIO` context. `pure` places
||| the action `randomRIO (0, 1)` at every position, and `sequence` runs it once per
||| position. Each entry is therefore drawn by `randomRIO` with bounds `0` and `1`.
||| These literals are why `Num a` is required.
-- NOTE(review): the visible body does not use the `All IsCubical (toVect shape)`
-- constraint.
290 | random : Num a => Random a => HasIO io =>
291 | (shape : TensorShape rank) ->
292 | All IsCubical (toVect shape) =>
293 | Applicative (Tensor shape) =>
294 | Traversable (Tensor shape) =>
295 | io (Tensor shape a)
296 | random shape = sequence $
pure $
randomRIO (0, 1)
-- More instance-resolution checks; their definitions are elided from this listing.
-- NOTE(review): a `ttt` is also declared earlier with type
-- `Traversable (Tensor ["b" ~~> 1])`. Unless the two are in different namespaces
-- (not visible here), the names clash.
298 | tt : Traversable (Vect 2)
301 | ttt : Traversable (Ext (Vect 2))
304 | tttt : Traversable (Tensor ["i" ~~> 2])
-- NOTE(review): listing line 308 and the lines after 309 are elided, so most of
-- `testRand`'s `do` block is not shown.
||| Example: draws a random `Double` tensor with axes `"i" ~~> 2` and `"j" ~~> 3`.
307 | testRand : IO (Tensor ["i" ~~> 2, "j" ~~> 3] Double)
309 | t <- random ["i" ~~> 2, "j" ~~> 3]
||| Example: a random `Double` vector over `"i" ~~> 5`.
313 | testRand2 : IO (Tensor ["i" ~~> 5] Double)
314 | testRand2 = random ["i" ~~> 5]
-- NOTE(review): this needs a `Random Unit` instance, which is not visible here. It
-- looks like leftover scratch code — confirm or remove.
||| `randomIO` at type `Unit`.
316 | testRand3 : IO Unit
317 | testRand3 = randomIO
-- NOTE(review): the `\case` arms of `exMatrix` are elided from this listing.
||| Example 3x3 `Double` matrix, written as an extension of `Vect 3 >< Vect 3` with
||| shape `((), ())`.
320 | exMatrix : Ext (Vect 3 >< Vect 3) Double
321 | exMatrix = ((), ()) <| \case
||| Maps `tensorM` over an extension of `Vect n >< Vect n` with `extMap`, producing
||| an extension of `Vect n`.
333 | applMap : {n : Nat} -> Ext (Vect n >< Vect n) Double -> Ext (Vect n) Double
334 | applMap = extMap tensorM
||| Assigns a value to every pair of leaf positions. The first component ranges over
||| the 2 leaves of `NodeS LeafS LeafS`; the second ranges over the 3 leaves of
||| `NodeS (NodeS LeafS LeafS) LeafS`. The 6 cases are numbered 0 to 5, with the
||| first component varying fastest.
336 | allPos : (BinTreePosLeaf (NodeS LeafS LeafS), BinTreePosLeaf (NodeS (NodeS LeafS LeafS) LeafS)) -> Double
337 | allPos ((GoLeft AtLeaf), (GoLeft (GoLeft AtLeaf))) = 0
338 | allPos ((GoRight AtLeaf), (GoLeft (GoLeft AtLeaf))) = 1
339 | allPos ((GoLeft AtLeaf), (GoLeft (GoRight AtLeaf))) = 2
340 | allPos ((GoRight AtLeaf), (GoLeft (GoRight AtLeaf))) = 3
341 | allPos ((GoLeft AtLeaf), (GoRight AtLeaf)) = 4
342 | allPos ((GoRight AtLeaf), (GoRight AtLeaf)) = 5
||| Example extension over `BinTreeLeaf >< BinTreeLeaf`. It uses the two tree shapes
||| matched by `allPos`, with `allPos` as its content.
344 | exTree : Ext (BinTreeLeaf >< BinTreeLeaf) Double
345 | exTree = (NodeS LeafS LeafS, NodeS (NodeS LeafS LeafS) LeafS) <| allPos
||| Maps `tensorM` over an extension of `BinTreeLeaf >< BinTreeLeaf` with `extMap`,
||| producing an extension of `BinTreeLeaf`.
347 | applMapTree : Ext (BinTreeLeaf >< BinTreeLeaf) Double -> Ext (BinTreeLeaf) Double
348 | applMapTree = extMap tensorM
-- NOTE(review): these are scratch examples, and several of them are only partly
-- shown in this listing: the rest of `ff`'s `let`, most of `t0`'s literal, and the
-- definitions of `t1` and `tTest` are elided.
||| Builds `g` by `extMap` of `tensorM` at container `Vect 4`, presumably to reduce a
||| 4x4 tensor to a length-4 vector. The rest of the body is elided.
350 | ff : Tensor ["v" ~~> 4, "v" ~~> 4] Double -> Tensor ["v" ~~> 4] Double
351 | ff t = let g = extMap {a=Double} (tensorM {c=Vect 4})
||| Example 3x4 `Double` tensor, written as nested list literals with `>#`.
355 | t0 : Tensor ["j" ~~> 3, "k" ~~> 4] Double
356 | t0 = ># [ [0, 1, 2, 3]
360 | t1 : Tensor ["i" ~~> 6] Double
||| `arange` over an axis of length 9, reshaped into a 3x3 tensor.
363 | exMatrix2 : Tensor ["v" ~~> 3, "v" ~~> 3] Double
364 | exMatrix2 = reshape $
arange {stop="v" ~~> 9}
369 | tTest : Tensor ["i" ~~> 800] Double
||| `tTest` (800 entries) reshaped into a 2x400 tensor.
373 | tRes : Tensor ["i" ~~> 2, "j" ~~> 400] Double
374 | tRes = reshape tTest