0 | module Data.Tensor.Utils
6 | import Data.Tensor.Tensor
7 | import Data.Container.SubTerm
-- Type-level shorthands for tensors of rank 0, 1 and 2.
28 | namespace CommonNames
||| A rank-0 tensor: a `Tensor` over the empty shape `[]`.
30 | Scalar : (a : Type) -> Type
31 | Scalar a = Tensor [] a
||| A rank-1 tensor over the single axis `c`.
34 | Vector : (c : Axis) -> (a : Type) -> Type
35 | Vector c a = Tensor [c] a
||| A rank-2 tensor over the axes `row` then `col`.
||| The `ConsistentWith row [col]` constraint is defined elsewhere
||| (TODO confirm exactly which axis pairs it rules out).
38 | Matrix : (row, col : Axis) -> ConsistentWith row [col] =>
40 | Matrix row col a = Tensor [row, col] a
42 | namespace FillZerosOnes
||| Constant tensor built with `tensorReplicate x` (presumably `x` at every
||| position of `shape` -- TODO confirm). The `AllC TensorMonoid shape`
||| constraint is required, presumably a `TensorMonoid` for every axis.
44 | fill : Num a => {shape : TensorShape rank} ->
45 | AllC TensorMonoid shape =>
47 | fill x = tensorReplicate x
||| `fill` with `fromInteger 0`.
50 | zeros : Num a => {shape : TensorShape rank} ->
51 | AllC TensorMonoid shape =>
53 | zeros = fill (fromInteger 0)
||| `fill` with `fromInteger 1`.
56 | ones : Num a => {shape : TensorShape rank} ->
57 | AllC TensorMonoid shape =>
59 | ones = fill (fromInteger 1)
||| Boolean identity matrix over a cubical axis `c`. Every position of `c`
||| is compared with every position via `outerWith (==)`. Assuming
||| `outerWith` forms the outer product, entry (i, j) is `i == j`.
||| Matching on `MkIsCubical _ n` refines the axis' container. Presumably
||| this is what lets `positions {sh=()}` (unit shape) typecheck
||| (TODO confirm).
63 | identityBool : {0 c : Axis} -> IsCubical c =>
65 | identityBool @{MkIsCubical _ n}
66 | = outerWith (==) (positions {sh=()}) (positions {sh=()})
||| Numeric identity matrix over a cubical axis `c`: the Boolean
||| `identityBool` with each entry converted to a number by `fromBool`.
71 | identity : {0 c : Axis} -> IsCubical c =>
72 | Num a => Tensor [c, c] a
73 | identity @{MkIsCubical _ n} = map fromBool identityBool
||| The indices 0, 1, ... of every position of the cubical axis `stop`,
||| each converted from `Nat` to `a` with `cast`. This assumes `positions`
||| enumerates them in order.
88 | arange : {0 stop : Axis} -> IsCubical stop =>
89 | Cast Nat a => Tensor [stop] a
90 | arange @{MkIsCubical _ n} = cast . finToNat <$> positions {sh=()}
||| Offset variant of `arange`: position i of the result holds `cast (i + n)`.
||| Here `n` is bound by the `MkIsCubical` match on `start` (presumably
||| `dim start`).
||| The result axis keeps `stop`'s name and has length
||| `minus (dim stop) (dim start)`. `minus` on `Nat` truncates at 0, so a
||| `start` longer than `stop` gives an empty axis rather than an error.
||| `start` defaults to `TTInternalName ~~> 0` (presumably a size-0 axis).
||| The `m` bound from `cStop` is not used in the body.
95 | arangeFromTo : {default (TTInternalName ~~> 0) start : Axis} ->
97 | (cStart : IsCubical start) => (cStop : IsCubical stop) =>
98 | Cast Nat a => Tensor [stop.name ~~> minus (dim stop) (dim start)] a
99 | arangeFromTo {cStart=(MkIsCubical _ n)} {cStop=(MkIsCubical _ m)}
100 | = cast . (+n) . finToNat <$> positions {sh=()}
||| Presumably reverses the order of entries along `axis`. The
||| implementation is not part of this excerpt (TODO confirm). The axis
||| selected by `axis` must be cubical.
105 | flip : {shape : TensorShape rank} ->
106 | (axis : Fin rank) ->
107 | IsCubical (index axis (toVect shape)) =>
108 | Tensor shape a -> Tensor shape a
-- Concatenation of tensors along their leading axis.
111 | namespace Concatenate
||| Concatenates `t` and `t'` along their first axis. Both inputs share the
||| trailing `shape`, and both leading axes must be cubical. The result's
||| first axis is named `l` and has size `dim x + dim y`.
||| Implementation: expose each input's top axis with `extractTopExt`,
||| append the two with `++`, and re-wrap the result with `embedTopExt`.
115 | concat : {shape : TensorShape rank} -> {l : AxisName} ->
116 | {x, y : Axis} -> IsCubical x => IsCubical y =>
117 | ConsistentWith (l ~~> dim x + dim y) shape =>
118 | ConsistentWith x shape =>
119 | ConsistentWith y shape =>
120 | Tensor (x :: shape) a ->
121 | Tensor (y :: shape) a ->
122 | Tensor ((l ~~> dim x + dim y) :: shape) a
123 | concat @{MkIsCubical _ n} @{MkIsCubical _ m} t t'
124 | = embedTopExt $
extractTopExt t ++ extractTopExt t'
||| Number of elements in a tensor (implementation not shown in this excerpt).
133 | size : {shape : TensorShape rank} ->
134 | Tensor shape a -> Nat
||| Element count for tensors whose axes are all cubical, computed from the
||| shape alone as `size (toVect shape)`. The tensor argument has quantity 0,
||| so it is erased and never inspected at runtime; it only fixes `shape`.
139 | size : {shape : TensorShape rank} ->
140 | All IsCubical (toVect shape) =>
141 | (0 _ : Tensor shape a) -> Nat
142 | size {shape} _ = size (toVect shape)
||| Elements of any `Foldable` tensor as a `List` (implementation not shown in
||| this excerpt).
149 | flatten : {0 shape : TensorShape rank} ->
150 | Foldable (Tensor shape) =>
151 | Tensor shape a -> List a
||| Length-indexed flattening for tensors whose axes are all cubical. The
||| underlying extension (`GetT`) is flattened with
||| `flattenCubical DefaultLayoutOrder` (presumably row-major -- TODO confirm)
||| and converted with `toVect`, so the result has length `size shape`.
158 | flatten : {shape : TensorShape rank} ->
159 | All IsCubical (conts shape) =>
160 | Tensor shape a -> Vect (size shape) a
161 | flatten = toVect . extMap (flattenCubical DefaultLayoutOrder) . GetT
-- NOTE(review): the right-hand `max` must resolve to some `List a -> Maybe a`
-- function defined elsewhere. Neither has that type:
--   * the Prelude `max : a -> a -> a`;
--   * this definition, which takes a `Tensor`.
-- If no such list-level `max` is in scope, this will not elaborate.
-- TODO confirm.
||| Largest element of a `Foldable` tensor, computed on its flattened list.
||| The result is a `Maybe`, presumably `Nothing` for an empty tensor.
167 | max : {0 shape : TensorShape rank} ->
168 | Foldable (Tensor shape) => Ord a =>
169 | Tensor shape a -> Maybe a
170 | max = max . flatten
||| One-hot vector over a cubical axis `c`: `zeros` with position `i` set to 1.
174 | oneHot : {0 c : Axis} -> IsCubical c =>
175 | (i : Fin (dim c)) ->
176 | Num a => Tensor [c] a
177 | oneHot @{MkIsCubical _ n} i = set zeros [i] 1
-- Triangular masks and projections for square tensors.
179 | namespace Triangular
||| Boolean triangular mask for an axis whose container positions carry an
||| `MOrd` ordering. Entry (i, j) is `flip isSubTerm i j`, i.e.
||| `isSubTerm j i`, assuming `outerWith` forms the outer product.
||| `flip` here is presumably the Prelude argument swap. The tensor `flip`
||| in this module takes a `Fin`, so it cannot be applied to `isSubTerm`.
||| `pp` binds the ordering at shape `sh` but is never referenced
||| explicitly. Presumably it is there so `isSubTerm` can find it by
||| auto-search (TODO confirm).
182 | cTriBool : {c : Axis} ->
183 | (ip : InterfaceOnPositions c.cont MOrd) =>
184 | TensorMonoid c.cont =>
185 | (sh : c.cont.Shp) -> Tensor [c, c] Bool
186 | cTriBool {ip = MkI p} sh
187 | = let cPositions = positions {sh=sh}
188 | pp : MOrd (c.cont.Pos sh) := p sh
189 | in outerWith (flip isSubTerm) cPositions cPositions
||| `cTriBool` for a cubical axis, evaluated at the unit shape `()`.
192 | triBool : {0 c : Axis} -> IsCubical c =>
194 | triBool @{MkIsCubical _ n} = cTriBool ()
||| Numeric version of `triBool`: each `Bool` entry converted with `fromBool`.
200 | tri : {0 c : Axis} -> IsCubical c =>
201 | Num a => Tensor [c, c] a
202 | tri @{MkIsCubical _ n} = map fromBool triBool
||| Masks a square tensor with `tri` by multiplying with it. This uses the
||| tensor `Num` instance's `(*)`, presumably elementwise. Entries where the
||| mask is `fromBool True` are kept; the rest are scaled by
||| `fromBool False`.
207 | lowerTriangular : {0 c : Axis} -> IsCubical c =>
208 | Num a => Tensor [c, c] a -> Tensor [c, c] a
209 | lowerTriangular @{MkIsCubical _ n} = \t => t * tri
-- NOTE(review): if `isSubTerm` is reflexive, `triBool` contains the
-- diagonal. In that case `lowerTriangular` keeps the diagonal while this
-- function zeroes it, unlike NumPy's `triu`. Confirm that the asymmetry is
-- intended.
||| Complement of `lowerTriangular`: multiplies `t` by the mask obtained by
||| mapping `fromBool . not` over `triBool`.
214 | upperTriangular : {0 c : Axis} -> IsCubical c =>
215 | Num a => Tensor [c, c] a -> Tensor [c, c] a
216 | upperTriangular @{MkIsCubical _ n} = \t => t * map (fromBool . not) triBool
||| Replaces the entries of `t` whose `mask` entry is `True` with `fill` and
||| keeps the others. `mask` and `t` are combined with `liftA2Tensor`,
||| presumably pairing them position-wise. Note that the local `fill`
||| argument shadows the `fill` function from `FillZerosOnes`.
220 | maskedFill : {shape : TensorShape rank} ->
221 | Num a => AllC TensorMonoid shape =>
222 | (t : Tensor shape a) ->
223 | (mask : Tensor shape Bool) ->
226 | maskedFill t mask fill = liftA2Tensor mask t <&>
227 | (\(maskVal, tVal) => if maskVal then fill else tVal)
||| Sum of all elements via the tensor's `Algebra` instance (implementation
||| not shown in this excerpt).
231 | sum : {shape : TensorShape rank} ->
232 | Algebra (Tensor shape) a =>
233 | Tensor shape a -> a
||| Arithmetic mean: `sum t` divided by the element count `Cubical.size t`
||| cast to `a`. For an empty tensor this divides by `cast 0`; the result
||| then depends on `a`'s `/` (e.g. NaN for `Double`).
237 | mean : {shape : TensorShape rank} ->
238 | All IsCubical (toVect shape) =>
241 | Algebra (Tensor shape) a =>
242 | Tensor shape a -> a
243 | mean t = sum t / cast (Cubical.size t)
||| Variance of `t` (presumably a vector over `c`): the `mean` of the squared
||| deviations from `mean t`, with squares taken by the tensor `*`.
||| Because `mean` divides by the element count, this is the population
||| (divide-by-n) variance, not the n-1 sample estimator.
246 | variance : {c : Axis} -> IsCubical c =>
247 | Neg a => Fractional a => Cast Nat a =>
249 | variance @{MkIsCubical _ n} t =
250 | let inputMinusMean = t - pure (mean t)
251 | in mean (inputMinusMean * inputMinusMean)
||| Running (inclusive) sum along a vector: `scanl1 (+)` applied through
||| `#>#`. Presumably `#>#` lifts a function on the underlying vector to
||| the tensor (TODO confirm).
254 | cumulativeSum : {c : Axis} -> Num a =>
255 | (isCubical : IsCubical c) =>
256 | Tensor [c] a -> Tensor [c] a
257 | cumulativeSum {isCubical=(MkIsCubical _ n)} t
258 | = (#>#) (scanl1 (+)) t
-- Traversals of tree-shaped tensor axes.
267 | namespace Traversals
||| Converts a `BinTreeNode`-shaped axis into a `List`-shaped axis. It
||| applies `BinTreeNode.inorder` to the underlying extension, converting
||| with `vectorToExt` / `extToVector`.
269 | inorder : Tensor [b ~> BinTreeNode] a -> Tensor [l ~> List] a
270 | inorder = extToVector . extMap BinTreeNode.inorder . vectorToExt
274 | {shape : TensorShape rank} ->
276 | Applicative (Tensor shape) =>
277 | Traversable (Tensor shape) =>
278 | Random (Tensor shape a) where
-- `pure randomIO` places the same action at every position. `sequence`
-- then runs it once per position, so each entry gets its own draw.
279 | randomIO = sequence (pure randomIO)
-- NOTE(review): `?qhwhwh` is an unfilled hole, so calling `randomRIO` on a
-- tensor fails at runtime. TODO implement, e.g. by drawing each entry
-- between the corresponding entries of the lower and upper bound tensors.
280 | randomRIO = ?qhwhwh
-- Presumably checks that the instances used above resolve for small shapes
-- (bodies not shown in this excerpt).
283 | tta : Applicative (Tensor ["a" ~~> 1])
286 | ttt : Traversable (Tensor ["b" ~~> 1])
289 | ttd : Random Double
||| Random tensor of the given `shape`. `pure` replicates the single
||| `randomRIO (0, 1)` action at every position, and `sequence` runs it once
||| per position, so each entry is a separate draw.
||| The explicit `shape` argument only fixes the result type. Whether the
||| bounds 0 and 1 are inclusive depends on `a`'s `Random` instance.
294 | random : Num a => Random a => HasIO io =>
295 | (shape : TensorShape rank) ->
296 | All IsCubical (toVect shape) =>
297 | Applicative (Tensor shape) =>
298 | Traversable (Tensor shape) =>
299 | io (Tensor shape a)
300 | random shape = sequence $
pure $
randomRIO (0, 1)
-- Instance-resolution checks and small random-generation smoke tests.
302 | tt : Traversable (Vect 2)
-- NOTE(review): `ttt` is also declared above as
-- `ttt : Traversable (Tensor ["b" ~~> 1])`. Unless the two live in
-- different namespaces, this is a duplicate definition.
305 | ttt : Traversable (Ext (Vect 2))
308 | tttt : Traversable (Tensor ["i" ~~> 2])
||| Five random `Double`s, generated with `random`.
317 | testRand2 : IO (Tensor ["i" ~~> 5] Double)
318 | testRand2 = random ["i" ~~> 5]
-- NOTE(review): this requires a `Random Unit` instance, which is not
-- visible in this excerpt. Confirm it exists, or drop this test.
320 | testRand3 : IO Unit
321 | testRand3 = randomIO
-- Worked examples.
||| A 3x3 matrix extension over `Vect 3 >< Vect 3`, with entries given by
||| the `\case` (not shown in this excerpt).
324 | exMatrix : Ext (Vect 3 >< Vect 3) Double
325 | exMatrix = ((), ()) <| \case
||| Applies `tensorM` to the extension's payload via `extMap`. This turns
||| an `n x n` matrix extension into a length-`n` vector extension.
337 | applMap : {n : Nat} -> Ext (Vect n >< Vect n) Double -> Ext (Vect n) Double
338 | applMap = extMap tensorM
||| Entries of `exTree`. Each position pairs a leaf of a 2-leaf tree with a
||| leaf of a 3-leaf tree. All 6 combinations are covered, numbered 0..5
||| with the first tree's leaf varying fastest.
340 | allPos : (BinTreePosLeaf (NodeS LeafS LeafS), BinTreePosLeaf (NodeS (NodeS LeafS LeafS) LeafS)) -> Double
341 | allPos ((GoLeft AtLeaf), (GoLeft (GoLeft AtLeaf))) = 0
342 | allPos ((GoRight AtLeaf), (GoLeft (GoLeft AtLeaf))) = 1
343 | allPos ((GoLeft AtLeaf), (GoLeft (GoRight AtLeaf))) = 2
344 | allPos ((GoRight AtLeaf), (GoLeft (GoRight AtLeaf))) = 3
345 | allPos ((GoLeft AtLeaf), (GoRight AtLeaf)) = 4
346 | allPos ((GoRight AtLeaf), (GoRight AtLeaf)) = 5
||| Tree-shaped analogue of `exMatrix`: a 2-leaf by 3-leaf binary-tree
||| shape with entries from `allPos`.
348 | exTree : Ext (BinTreeLeaf >< BinTreeLeaf) Double
349 | exTree = (NodeS LeafS LeafS, NodeS (NodeS LeafS LeafS) LeafS) <| allPos
||| Tree-shaped analogue of `applMap`.
351 | applMapTree : Ext (BinTreeLeaf >< BinTreeLeaf) Double -> Ext (BinTreeLeaf) Double
352 | applMapTree = extMap tensorM
||| Matrix-to-vector map built from `extMap` over `tensorM` at `Vect 4`
||| (the rest of the body is not shown in this excerpt).
354 | ff : Tensor ["v" ~~> 4, "v" ~~> 4] Double -> Tensor ["v" ~~> 4] Double
355 | ff t = let g = extMap {a=Double} (tensorM {c=Vect 4})
||| 3x4 matrix of literals (remaining rows not shown in this excerpt).
359 | t0 : Tensor ["j" ~~> 3, "k" ~~> 4] Double
360 | t0 = ># [ [0, 1, 2, 3]
364 | t1 : Tensor ["i" ~~> 6] Double
||| 3x3 matrix obtained by reshaping `arange` over a 9-element axis
||| (presumably 0..8 in row-major order).
367 | exMatrix2 : Tensor ["v" ~~> 3, "v" ~~> 3] Double
368 | exMatrix2 = reshape $
arange {stop="l" ~~> 9}