25 | import public Control.Monad.State
26 | import Control.Monad.Error.Either
27 | import Syntax.PreorderReasoning
29 | import Compiler.Eval
31 | import Compiler.Xla.Shape
32 | import Compiler.Xla.ShapeUtil
33 | import Compiler.LiteralRW
36 | import public Literal
41 | XlaShape = Xla.Shape
-- A tensor handle: a `Value` (built with `V idx op` throughout this file)
-- indexed at the type level by its `Shape` and `DType`. Shape and dtype are
-- also stored as implicit constructor arguments, so matches such as
-- `MkTensor {shape, dtype} _` can recover them at runtime.
47 | data Tensor : Shape -> DType -> Type where
48 | MkTensor : Value -> {shape : _} -> {dtype : _} -> Tensor shape dtype
-- `TagT m` wraps a `StateT Env m` computation. The `Env` state is threaded
-- through graph construction: `eval` runs it from `empty`, and the `fn*`
-- helpers run nested computations from `emptyFrom !get`.
-- Each instance below unwraps `MkTagT`, delegates to the underlying
-- `StateT` instance, and rewraps the result.
-- In `>>=`, the continuation's result is unwrapped by matching `MkTagT z`
-- inside the lambda. `lift` lifts an `m` action through `StateT`.
52 | data TagT : (Type -> Type) -> Type -> Type where
53 | MkTagT : StateT Env m a -> TagT m a
60 | Functor m => Functor (TagT m) where
61 | map f (MkTagT x) = MkTagT (map f x)
64 | Monad m => Applicative (TagT m) where
65 | pure x = MkTagT (pure x)
66 | (MkTagT f) <*> (MkTagT x) = MkTagT (f <*> x)
69 | Monad m => Monad (TagT m) where
70 | (MkTagT x) >>= f = MkTagT $
x >>= (\y => let MkTagT z = f y in z)
73 | MonadTrans TagT where
74 | lift = MkTagT . lift
77 | interface Taggable a where
94 | tag : Monad m => a -> TagT m a
99 | tag (BoundSet x) = pure (BoundSet x)
100 | tag x = MkTagT $
tagOp x
-- Tags a tensor by tagging its underlying op with `tag op`, then rewrapping
-- the result at the same output index `idx`. Shape and dtype are unchanged.
-- The lambda's `op` shadows the matched `op`; it is the tagged op.
103 | Taggable (Tensor shape dtype) where
104 | tag (MkTensor $
V idx op) = map (\op => MkTensor $
V idx op) (tag op)
-- Tags both components of a pair. The idiom brackets desugar to
-- `(,) <$> tag a <*> tag b`, so `a` is tagged before `b`.
107 | (Taggable a, Taggable b) => Taggable (a, b) where
108 | tag (a, b) = [| (tag a, tag b) |]
-- Builds a constant tensor from a `Literal`: a `Lit` node, taken at output
-- index 0. Shape and dtype come from the implicit arguments.
117 | tensor : {shape : _} -> {dtype : _} -> Literal shape (idrisType dtype) -> Tensor shape dtype
118 | tensor lit = MkTensor $
V 0 $
Lit shape dtype lit
-- Wraps a `Double` as a scalar F64 constant tensor.
-- Presumably this exists so that double literals such as `2.0` elaborate to
-- `Tensor [] F64` through Idris's name-based `fromDouble` desugaring
-- (see `uniform`/`normal`). TODO confirm.
122 | fromDouble : Double -> Tensor [] F64
123 | fromDouble = tensor . Scalar
-- Runs an `EitherT e IO a` and returns the `Right` value. On `Left e` it
-- aborts the program via `idris_crash`, with `show e` as the message.
-- `assert_total` is presumably needed because `idris_crash` is not total.
125 | try : Show e => EitherT e IO a -> IO a
126 | try = eitherT (\e => assert_total $
idris_crash $
show e) pure
128 | %hide Literal.All2.All2
147 | Tag (All2 Tensor shapes dtypes) ->
148 | IO (All2 Literal shapes (DType.idrisType <$> dtypes))
149 | eval device (MkTagT xs) =
150 | let (env, xs) = runState empty xs
152 | xlaShapes <- buildShapes xs
153 | let (
outputs ** eq)
= lengthC xs
154 | main = MkFn 0 [] (resultTypes xs) (results xs) env
155 | lits <- execute device main {outputs} (rewrite eq in xlaShapes)
156 | readAll xs $
rewrite sym eq in lits
-- Computes the number of tensors in an `All2` list, paired with a proof that
-- it equals `length s`. `eval` uses the number as `outputs` and the proof
-- `eq` to rewrite between `outputs` and `length s`.
160 | lengthC : All2 Tensor s t -> (
n ** n === length s)
161 | lengthC [] = (
0 ** Refl)
162 | lengthC (_ :: xs) = let (
n ** eq)
= lengthC xs in (
S n ** cong S eq)
-- Builds one XLA shape per tensor with `mkShape shape dtype`, in list order.
-- The effects are sequenced through idiom brackets, and the results are
-- collected into a `Vect (length s)`.
164 | buildShapes : HasIO io => All2 Tensor s t -> io $
Vect (length s) XlaShape
165 | buildShapes [] = pure []
166 | buildShapes (MkTensor {shape, dtype} _ :: ts) = [| mkShape shape dtype :: buildShapes ts |]
168 | results : All2 Tensor s t -> Vect (length s) Value
170 | results (MkTensor x :: xs) = x :: results xs
-- Collects `TensorType shape dtype` for each tensor, in order. `eval` uses
-- these as the result types of the top-level `MkFn`.
172 | resultTypes : All2 Tensor s t -> Vect (length s) ValueType
173 | resultTypes [] = []
174 | resultTypes (MkTensor {shape, dtype} _ :: xs) = TensorType shape dtype :: resultTypes xs
179 | Vect (length s) Literal ->
180 | io $
All2 Literal s (DType.idrisType <$> t)
181 | readAll [] _ = pure []
182 | readAll (MkTensor {dtype} _ :: ts) (l :: ls) = [| read dtype l :: readAll ts ls |]
188 | All2 Tensor shapes dtypes ->
189 | IO (All2 Literal shapes (DType.idrisType <$> dtypes))
190 | eval device xs = eval device (pure xs)
-- Evaluates a single tagged tensor in three steps:
--   1. wrap it in a one-element `All2` list;
--   2. delegate to `List.Tag.eval`;
--   3. unwrap the one-element result.
-- The single-clause `\[z] => z` is exhaustive because the result's type
-- fixes its length at one.
200 | eval : Device -> Tag (Tensor shape dtype) -> IO (Literal shape (DType.idrisType dtype))
201 | eval device x = map (\[z] => z) $
List.Tag.eval device $
map (\z => [z]) x
-- Evaluates an untagged tensor by lifting it into `Tag` with `pure`.
205 | eval : Device -> Tensor shape dtype -> IO (Literal shape (DType.idrisType dtype))
206 | eval device x = eval device (pure x)
212 | Show (Tag $
Tensor shape dtype) where
214 | let (env, MkTensor x) = runState empty x
215 | in show (MkFn 0 [] [TensorType shape dtype] [x] env)
-- Shows an untagged tensor by lifting it into `Tag` and reusing the
-- `Show (Tag ...)` instance.
218 | Show (Tensor shape dtype) where show = show . pure {f = Tag}
-- Scalar F64 positive infinity and NaN. Only the signatures are declared
-- here; the definitions appear later in the file. They are presumably
-- deferred because they are written with tensor `/`, which is defined later.
222 | inf : Tensor [] F64
226 | nan : Tensor [] F64
-- Smallest value of each ordered dtype, as a scalar tensor.
-- Integer dtypes use the `MinValue` op. F64 uses negative infinity, built
-- from the literal `-1.0 / 0.0`.
-- This is the neutral element of the `Max` monoid and the starting bound
-- used by `argmax`.
231 | min : Ord dtype => Tensor [] dtype
232 | min @{OrdS32} = MkTensor $
V 0 $
MinValue S32
233 | min @{OrdS64} = MkTensor $
V 0 $
MinValue S64
234 | min @{OrdU32} = MkTensor $
V 0 $
MinValue U32
235 | min @{OrdU64} = MkTensor $
V 0 $
MinValue U64
236 | min @{OrdF64} = tensor $
Scalar $
-1.0 / 0.0
-- Largest value of each ordered dtype, as a scalar tensor.
-- Integer dtypes use the `MaxValue` op. F64 uses positive infinity, built
-- from the literal `1.0 / 0.0`.
-- This is the neutral element of the `Min` monoid and the starting bound
-- used by `argmin`.
240 | max : Ord dtype => Tensor [] dtype
241 | max @{OrdS32} = MkTensor $
V 0 $
MaxValue S32
242 | max @{OrdS64} = MkTensor $
V 0 $
MaxValue S64
243 | max @{OrdU32} = MkTensor $
V 0 $
MaxValue U32
244 | max @{OrdU64} = MkTensor $
V 0 $
MaxValue U64
245 | max @{OrdF64} = tensor $
Scalar $
1.0 / 0.0
-- Scalar F64 built from the `MinFiniteFloat` op. This is presumably the
-- lowest (most negative) finite double; TODO confirm against the op.
249 | minFinite : Tensor [] F64
250 | minFinite = MkTensor $
V 0 $
MinFiniteFloat
-- Scalar F64 built from the `MaxFiniteFloat` op. This is presumably the
-- largest finite double; TODO confirm against the op.
254 | maxFinite : Tensor [] F64
255 | maxFinite = MkTensor $
V 0 $
MaxFiniteFloat
-- Converts a tensor of an integral dtype to F64 elementwise, keeping its
-- shape, via a `Convert F64 shape` node.
260 | castDtype : Integral dtype => Tensor shape dtype -> Tensor shape F64
261 | castDtype $
MkTensor {shape} x = MkTensor $
V 0 $
Convert F64 shape x
263 | fn0 : {a : _} -> Tag (Tensor s a) -> Tag $
Fn 0
264 | fn0 f = MkTagT $
do
268 | (env, MkTensor res) = runState (emptyFrom !get) res
269 | f = MkFn addr [] [TensorType s a] [res] env
271 | updateCounterFrom env
274 | fn1 : {a : _} -> {s : _} -> (Tensor s a -> Tag $
Tensor s' a') -> Tag $
Fn 1
275 | fn1 f = MkTagT $
do
278 | let MkTagT res = f (MkTensor $
V 0 $
BoundSet addr)
279 | (env, MkTensor res) = runState (emptyFrom !get) res
280 | f = MkFn addr [TensorType s a] [TensorType s' a'] [res] env
282 | updateCounterFrom env
286 | {a, a' : _} -> {s, s' : _} ->
287 | (Tensor s a -> Tensor s' a' -> Tag $
Tensor s'' a'') ->
289 | fn2 f = MkTagT $
do
292 | let MkTagT res = f (MkTensor $
V 0 $
BoundSet addr) (MkTensor $
V 1 $
BoundSet addr)
293 | (env, MkTensor res) = runState (emptyFrom !get) res
294 | f = MkFn addr [TensorType s a, TensorType s' a'] [TensorType s'' a''] [res] env
296 | updateCounterFrom env
300 | {a, a', a'' : _} -> {s, s' : _} ->
301 | (Tensor s a -> Tensor s' a' -> Tag $
(Tensor s'' a'', Tensor s''' a''')) ->
303 | fn22 f = MkTagT $
do
306 | let MkTagT res = f (MkTensor $
V 0 $
BoundSet addr) (MkTensor $
V 1 $
BoundSet addr)
307 | (env, (MkTensor res0, MkTensor res1)) = runState (emptyFrom !get) res
308 | resTys = [TensorType s'' a'', TensorType s''' a''']
309 | f = MkFn addr [TensorType s a, TensorType s' a'] resTys [res0, res1] env
311 | updateCounterFrom env
-- Builds a `Grad` node for the scalar-valued function `f`, evaluated at the
-- point `x`. The result has the same shape as `x`.
-- `!(fn1 f)` (bang notation) first runs `fn1 f` in `Tag`, which compiles `f`
-- into a one-argument `Fn`.
331 | grad : (Tensor shape F64 -> Tag $
Tensor [] F64) -> Tensor shape F64 -> Tag $
Tensor shape F64
332 | grad f (MkTensor x) = pure $
MkTensor $
V 0 $
Grad shape !(fn1 f) x
339 | {auto 0 sizesEqual : product from = product to} ->
340 | Tensor from dtype ->
342 | reshape $
MkTensor {shape} x = MkTensor $
V 0 $
Reshape dtype to x
350 | {auto 0 inBounds : axis `LTE` length shape} ->
351 | Tensor shape dtype ->
352 | Tensor (insertAt axis 1 shape) dtype
353 | expand axis $
MkTensor {shape = _} x = MkTensor $
V 0 $
Reshape dtype (insertAt axis 1 shape) x
355 | namespace Squeezable
359 | data Squeezable : (0 from : Shape) -> (0 to : Shape) -> Type where
364 | Same : Squeezable x x
370 | Match : Squeezable from to -> Squeezable (x :: from) (x :: to)
375 | Nest : Squeezable from to -> Squeezable (1 :: from) to
396 | {auto 0 shapesSqueezable : Squeezable from to} ->
397 | Tensor from dtype ->
399 | squeeze $
MkTensor {shape} x = MkTensor $
V 0 $
Reshape dtype to x
404 | data SliceOrIndex : Nat -> Type where
406 | (from, to : Nat) ->
408 | {auto 0 fromTo : from + size = to} ->
409 | {auto 0 inDim : LTE to d} ->
411 | Index : (idx : Nat) -> {auto 0 inDim : LT idx d} -> SliceOrIndex d
412 | DynamicSlice : Tensor [] U64 -> (size : Nat) -> {auto 0 inDim : LTE size d} -> SliceOrIndex d
413 | DynamicIndex : Tensor [] U64 -> SliceOrIndex d
417 | at : (idx : Nat) -> {auto 0 inDim : LT idx d} -> SliceOrIndex d
423 | at : Tensor [] U64 -> SliceOrIndex d
429 | (from, to : Nat) ->
431 | {auto 0 fromTo : from + size = to} ->
432 | {auto 0 inDim : LTE to d} ->
-- A dynamic slice of static length `size`, starting at a runtime index.
-- It is an alias for `DynamicSlice`, usable with postfix-projection syntax.
438 | (.size) : Tensor [] U64 -> (size : Nat) -> {auto 0 inDim : LTE size d} -> SliceOrIndex d
439 | (.size) = DynamicSlice
-- Selects an entire dimension: `Slice 0 d`, with `size` inferred as `d`.
-- The `from + size = to` proof is found by `%search`; `LTE d d` comes from
-- reflexivity.
443 | all : {d : _} -> SliceOrIndex d
444 | all = Slice 0 @{%search} @{reflexive {ty = Nat}} d
-- A list with one `SliceOrIndex` per leading dimension of a shape.
-- `Nil` is valid for any shape, so a `MultiSlice` may be shorter than the
-- rank; the remaining dimensions are left untouched.
449 | data MultiSlice : Shape -> Type where
450 | Nil : MultiSlice ds
451 | (::) : SliceOrIndex d -> MultiSlice ds -> MultiSlice (d :: ds)
453 | namespace MultiSlice
-- The result shape of applying a `MultiSlice`:
--   * static and dynamic slices contribute their `size`;
--   * static and dynamic indices drop the dimension;
--   * dimensions not covered by the `MultiSlice` are kept as-is.
457 | slice : {shape : _} -> MultiSlice shape -> Shape
458 | slice {shape} [] = shape
459 | slice {shape = (_ :: _)} (Slice {size} _ _ :: xs) = size :: slice xs
460 | slice {shape = (_ :: _)} (Index _ :: xs) = slice xs
461 | slice {shape = (_ :: _)} (DynamicSlice _ size :: xs) = size :: slice xs
462 | slice {shape = (_ :: _)} (DynamicIndex _ :: xs) = slice xs
550 | slice : (at : MultiSlice shape) -> Tensor shape dtype -> Tensor (slice at) dtype
551 | slice at $
MkTensor x = MkTensor $
V 0 $
552 | let x = V 0 $
Slice (mapd start (const 0) at) (mapd stop id at) (replicate (length shape) 1) x
554 | x = if isDynamic at then V 0 $
DynamicSlice (dynStarts [] at) (mapd size id at) x else x
555 | in Reshape dtype (MultiSlice.slice at) x
558 | mapd : ((Nat -> a) -> {d : Nat} -> SliceOrIndex d -> a) ->
561 | MultiSlice shape ->
563 | mapd _ dflt {shape} [] = Prelude.map dflt shape
564 | mapd f dflt (x :: xs) = f dflt x :: mapd f dflt xs
-- Static start offset along one dimension:
--   * `from` for a `Slice`;
--   * `idx` for an `Index`;
--   * `f d` otherwise, i.e. for dynamic entries.
-- The static `Slice` in `slice` calls this with `const 0`, so dynamic entries
-- span the whole dimension at that stage.
566 | start : (Nat -> Nat) -> {d : Nat} -> SliceOrIndex d -> Nat
567 | start _ (Slice from _) = from
568 | start _ (Index idx) = idx
569 | start f {d} _ = f d
571 | stop : (Nat -> Nat) -> {d : Nat} -> SliceOrIndex d -> Nat
572 | stop _ (Slice _ to) = to
573 | stop _ (Index idx) = S idx
-- Number of elements kept along one dimension: the slice size for
-- (dynamic) slices, and 1 for (dynamic) indices. Indexed dimensions are
-- dropped afterwards by the final `Reshape` in `slice`.
-- The function argument is unused here; it exists to fit `mapd`'s
-- signature.
576 | size : (Nat -> Nat) -> {d : Nat} -> SliceOrIndex d -> Nat
577 | size _ (Slice {size = size'} _ _) = size'
578 | size _ (Index _) = 1
579 | size _ (DynamicSlice _ size') = size'
580 | size _ (DynamicIndex _) = 1
583 | zero = V 0 $
Lit [] U64 $
Scalar 0
-- True iff any entry is a `DynamicSlice` or `DynamicIndex`. `slice` uses
-- this to decide whether to emit a `DynamicSlice` node.
585 | isDynamic : {shape : _} -> MultiSlice shape -> Bool
586 | isDynamic [] = False
587 | isDynamic {shape = (_ :: _)} (DynamicSlice _ _ :: _) = True
588 | isDynamic {shape = (_ :: _)} (DynamicIndex _ :: _) = True
589 | isDynamic (_ :: ds) = isDynamic ds
-- Start indices for the `DynamicSlice` node, in dimension order:
--   * the runtime index tensor for dynamic entries;
--   * `zero` for static entries, whose offset the preceding static `Slice`
--     has already applied;
--   * `zero` for each dimension not covered by the `MultiSlice`.
-- The list `idxs` is appended after all of these.
591 | dynStarts : List Value -> {shape : _} -> MultiSlice shape -> List Value
592 | dynStarts idxs {shape} [] = replicate (length shape) zero ++ idxs
593 | dynStarts idxs (DynamicSlice (MkTensor i) _ :: ds) = i :: dynStarts idxs ds
594 | dynStarts idxs (DynamicIndex (MkTensor i) :: ds) = i :: dynStarts idxs ds
595 | dynStarts idxs (_ :: ds) = zero :: dynStarts idxs ds
606 | {auto 0 inBounds : (InBounds axis s, InBounds axis s')} ->
607 | {auto 0 shapesConcatenable : deleteAt axis s = deleteAt axis s'} ->
608 | Tensor (replaceAt axis (index axis s + index axis s') s) dtype
609 | concat axis (MkTensor x) (MkTensor x') = MkTensor $
V 0 $
Concat axis [x, x']
-- Matrix transpose: swaps the two axes via `Transpose [1, 0]`.
613 | (.T) : Tensor [m, n] dtype -> Tensor [n, m] dtype
614 | (MkTensor x).T = MkTensor $
V 0 $
Transpose [1, 0] x
662 | (ordering : List Nat) ->
663 | Tensor shape dtype ->
664 | {auto 0 lengths : length ordering = length shape} ->
665 | {auto 0 axesUnique : unique ordering = True} ->
666 | {auto 0 inBounds : All (flip InBounds shape) ordering} ->
667 | Tensor (multiIndex ordering shape) dtype
668 | transpose ordering $
MkTensor x = MkTensor $
V 0 $
Transpose ordering x
673 | data DimBroadcastable : (0 from : Nat) -> (0 to : Nat) -> Type where
676 | Same : DimBroadcastable x x
680 | Stack : DimBroadcastable 1 _
683 | Zero : DimBroadcastable _ 0
685 | namespace Broadcastable
689 | data Broadcastable : (0 from : Shape) -> (0 to : Shape) -> Type where
698 | Same : Broadcastable x x
706 | Match : forall from, to .
707 | {auto 0 ranksEq : length from = length to} ->
708 | {auto 0 dimBroadcastable : DimBroadcastable f t} ->
709 | Broadcastable from to ->
710 | Broadcastable (f :: from) (t :: to)
716 | Nest : Broadcastable f t -> Broadcastable f (_ :: t)
-- Proof that any shape broadcasts to itself with extra leading dimensions
-- prepended. It applies `Nest` once per leading dimension and ends in `Same`.
722 | broadcastableByLeading : (leading : List Nat) -> Broadcastable shape (leading ++ shape)
723 | broadcastableByLeading [] = Same
724 | broadcastableByLeading (l :: ls) = Nest (broadcastableByLeading ls)
-- Proof that a scalar broadcasts to any shape. It is
-- `broadcastableByLeading to` with `shape = []`, rewritten from `to ++ []`
-- to `to` with `appendNilRightNeutral`.
729 | scalarToAnyOk : (to : Shape) -> Broadcastable [] to
730 | scalarToAnyOk to = rewrite sym $
appendNilRightNeutral to in broadcastableByLeading to
744 | {to : _} -> {dtype : _} ->
745 | {auto shapesOK : Broadcastable from to} ->
746 | Tensor from dtype ->
748 | broadcast $
MkTensor {shape = _} x = MkTensor $
V 0 $
Broadcast dtype from to x
-- A tensor of the requested shape with every element equal to `x`. It is a
-- scalar constant broadcast using the `scalarToAnyOk` proof.
762 | fill : {shape : _} -> {dtype : _} -> idrisType dtype -> Tensor shape dtype
763 | fill x = broadcast {shapesOK = scalarToAnyOk shape} (tensor (Scalar x))
793 | {auto 0 _ : Num dtype} ->
795 | {auto 0 inBounds : InBounds axis shape} ->
797 | iota dimension = MkTensor $
V 0 $
Iota shape dtype dimension
812 | (condition : Tensor shape dtype -> Tag $
Tensor [] PRED) ->
813 | (body : Tensor shape dtype -> Tag $
Tensor shape dtype) ->
814 | (initial : Tensor shape dtype) ->
815 | Tag $
Tensor shape dtype
816 | while1 condition body (MkTensor i0) =
817 | pure $
MkTensor $
V 0 $
While !(fn1 condition) !(fn1 body) [i0]
833 | (condition : Tensor s a -> Tensor s' a' -> Tag $
Tensor [] PRED) ->
834 | (body : Tensor s a -> Tensor s' a' -> Tag (Tensor s a, Tensor s' a')) ->
835 | (initial : Tensor s a) -> (initial' : Tensor s' a') ->
836 | Tag (Tensor s a, Tensor s' a')
837 | while2 condition body (MkTensor i) (MkTensor i') = do
838 | res <- tag $
While !(fn2 condition) !(fn22 body) [i, i']
839 | pure (MkTensor $
V 0 res, MkTensor $
V 1 res)
-- Applies the scalar function `f` elementwise.
-- `f` is compiled once into a one-argument `Fn` (via `fn1`, run with bang
-- notation). It is then used in a `Map` node over every dimension
-- (`range $ length shape`). The output keeps `shape` and has dtype `b`.
853 | map : {b : _} -> (Tensor [] a -> Tag $
Tensor [] b) -> Tensor shape a -> Tag $
Tensor shape b
854 | map f $
MkTensor {shape = _} x =
855 | pure $
MkTensor $
V 0 $
Map !(fn1 f) [x] (TensorType shape b) (range $
length shape)
871 | (Tensor [] a -> Tensor [] b -> Tag $
Tensor [] c) ->
872 | Tensor shape a -> Tensor shape b -> Tag $
Tensor shape c
873 | map2 f (MkTensor {shape = _} x) (MkTensor x') =
874 | pure $
MkTensor $
V 0 $
Map !(fn2 f) [x, x'] (TensorType shape c) (range $
length shape)
892 | (reducer : Monoid (Tensor [] dtype)) =>
893 | (axes : List Nat) ->
894 | {auto 0 axesUnique : Sorted LT axes} ->
895 | {auto 0 axesInBounds : All (flip InBounds shape) axes} ->
896 | Tensor shape dtype ->
897 | Tag $
Tensor (deleteAt axes shape) dtype
898 | reduce axes $
MkTensor x = do
899 | let semigroup : Monoid a -> Semigroup a
900 | semigroup _ = %search
902 | g <- fn2 (pure .: (<+>) @{semigroup reducer})
903 | let MkTensor neutral' = neutral @{reducer}
904 | pure $
MkTensor $
V 0 $
Reduce g [neutral'] axes [x]
925 | (Tensor [] dtype -> Tensor [] dtype -> Tensor [] PRED) ->
926 | (dimension : Nat) ->
927 | Tensor shape dtype ->
928 | {auto 0 dimInBounds : InBounds dimension shape} ->
929 | Tag $
Tensor shape dtype
930 | sort comp dimension $
MkTensor x =
931 | pure $
MkTensor $
V 0 $
Sort !(fn2 $
pure .: comp) dimension False x
963 | (axes : List Nat) ->
964 | {auto 0 axesUnique : Sorted LT axes} ->
965 | {auto 0 axesInBounds : All (flip InBounds shape) axes} ->
966 | Tensor shape dtype ->
968 | reverse axes $
MkTensor x = MkTensor $
V 0 $
Reverse axes x
-- Applies a `UnaryOp` elementwise, preserving shape and dtype.
970 | ewUnary : UnaryOp -> Tensor s a -> Tensor s a
971 | ewUnary op $
MkTensor x = MkTensor $
V 0 $
UnaryElementwise op x
-- Applies a `BinaryOp` elementwise to two tensors of the same shape and
-- dtype. The result has that same shape and dtype.
973 | ewBinary : BinaryOp -> Tensor s a -> Tensor s a -> Tensor s a
974 | ewBinary op (MkTensor x) (MkTensor x') = MkTensor $
V 0 $
BinaryElementwise op x x'
-- Like `ewBinary`, but the result dtype `out` may differ from the operand
-- dtype. The comparison operators use it to return `PRED` tensors.
976 | ewBinary' : {out : _} -> BinaryOp -> Tensor s a -> Tensor s a -> Tensor s out
977 | ewBinary' op (MkTensor x) (MkTensor x') = MkTensor $
V 0 $
BinaryElementwise op x x'
-- Elementwise comparisons. Each one produces a `PRED` tensor of the operand
-- shape, via `ewBinary'` with the matching `Compare` op.
-- Equality and inequality need `Eq dtype`; the orderings need `Ord dtype`.
982 | (==) : Eq dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape PRED
983 | (==) = ewBinary' $
Compare Eq
988 | (/=) : Eq dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape PRED
989 | (/=) = ewBinary' $
Compare Ne
994 | (<) : Ord dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape PRED
995 | (<) = ewBinary' $
Compare Lt
1000 | (>) : Ord dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape PRED
1001 | (>) = ewBinary' $
Compare Gt
1006 | (<=) : Ord dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape PRED
1007 | (<=) = ewBinary' $
Compare Le
1012 | (>=) : Ord dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape PRED
1013 | (>=) = ewBinary' $
Compare Ge
1019 | (&&) : Tensor shape PRED -> Tensor shape PRED -> Tensor shape PRED
1024 | [All] Semigroup (Tensor shape PRED) where
1029 | [All] {shape : _} -> Monoid (Tensor shape PRED) using Tensor.Semigroup.All where
1036 | (||) : Tensor shape PRED -> Tensor shape PRED -> Tensor shape PRED
1041 | [Any] Semigroup (Tensor shape PRED) where
1046 | [Any] {shape : _} -> Monoid (Tensor shape PRED) using Tensor.Semigroup.Any where
1052 | not : Tensor shape PRED -> Tensor shape PRED
1075 | (onTrue, onFalse : Tensor shape dtype) ->
1077 | select (MkTensor p) (MkTensor t) (MkTensor f) = MkTensor $
V 0 $
Select p t f
1098 | {shape : _} -> {dtype : _} ->
1100 | (onTrue, onFalse : Tag $
Tensor shape dtype) ->
1101 | Tag $
Tensor shape dtype
1102 | if_ (MkTensor pred) onTrue onFalse =
1103 | pure $
MkTensor $
V 0 $
If (TensorType shape dtype) pred !(fn0 onTrue) !(fn0 onFalse)
1117 | identity : {n : _} -> {dtype : _} -> Num dtype => Tensor [n, n] dtype
1119 | let MkTensor x = iota 0 {shape = [n, n], dtype = U64} == iota 1
1120 | in MkTensor $
V 0 $
Convert dtype [n, n] x
-- Vector dot product of two equal-length, non-empty vectors.
-- `DotGeneral` is called with no batch dimensions and contracts axis 0 of
-- both operands, producing a scalar.
1130 | (@@) : Num dtype => Tensor [S m] dtype -> Tensor [S m] dtype -> Tensor [] dtype
1131 | (MkTensor x) @@ (MkTensor x') =
1132 | MkTensor $
V 0 $
DotGeneral [] [] [0] [0] (TensorType [] dtype) x x'
1155 | Tensor [n, S m] dtype ->
1156 | Tensor (S m :: tl) dtype ->
1157 | {auto 0 vectorTail : length tl `LTE` 1} ->
1159 | (MkTensor x) @@ (MkTensor x') =
1160 | MkTensor $
V 0 $
DotGeneral [] [] [1] [0] (TensorType (n :: tl) dtype) x x'
1164 | contract : (lBatch, rBatch, lContract, rContract : List Nat) ->
1166 | {auto 0 lInBoundsBatch : All (flip InBounds ls) lBatch} ->
1167 | {auto 0 rInBoundsBatch : All (flip InBounds rs) rBatch} ->
1168 | {auto 0 lInBoundsContract : All (flip InBounds ls) lContract} ->
1169 | {auto 0 rInBoundsContract : All (flip InBounds rs) rContract} ->
1171 | contract lBatch rBatch lContract rContract ls rs =
1172 | let lResultDims = deleteAt {inBounds = lInBoundsBatch ++ lInBoundsContract}
1173 | (lBatch ++ lContract) ls
1174 | rResultDims = deleteAt {inBounds = rInBoundsBatch ++ rInBoundsContract}
1175 | (rBatch ++ rContract) rs
1176 | in multiIndex lBatch ls ++ lResultDims ++ rResultDims
1206 | (lBatch, rBatch, lContract, rContract : List Nat) ->
1207 | {auto 0 lUnique : unique (lBatch ++ lContract) = True} ->
1208 | {auto 0 rUnique : unique (rBatch ++ rContract) = True} ->
1209 | {auto 0 lInBoundsBatch : All (flip InBounds ls) lBatch} ->
1210 | {auto 0 rInBoundsBatch : All (flip InBounds rs) rBatch} ->
1211 | {auto 0 lInBoundsContract : All (flip InBounds ls) lContract} ->
1212 | {auto 0 rInBoundsContract : All (flip InBounds rs) rContract} ->
1213 | {auto 0 batchDimsEq : multiIndex lBatch ls = multiIndex rBatch rs} ->
1214 | {auto 0 contractDimsEq : multiIndex lContract ls = multiIndex rContract rs} ->
1217 | Tensor (contract lBatch rBatch lContract rContract ls rs) dtype
1218 | dotGeneral lb rb lc rc (MkTensor x) (MkTensor y) =
1219 | let resultType = TensorType (contract lb rb lc rc ls rs) dtype
1220 | in MkTensor $
V 0 $
DotGeneral lb rb lc rc resultType x y
1225 | (+) : Num dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape dtype
1230 | [Sum] Num dtype => Semigroup (Tensor shape dtype) where
1235 | [Sum] {shape : _} -> {dtype : _} -> Prelude.Num (idrisType dtype) => Num dtype =>
1236 | Monoid (Tensor shape dtype) using Semigroup.Sum where
-- Elementwise negation, via the `Neg` unary op.
1241 | negate : Neg dtype => Tensor shape dtype -> Tensor shape dtype
1242 | negate $
MkTensor i = MkTensor $
V 0 $
UnaryElementwise Neg i
1247 | (-) : Neg dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape dtype
1253 | (*) : Num dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape dtype
1262 | (*) : Num dtype => Tensor [] dtype -> Tensor (d :: ds) dtype -> Tensor (d :: ds) dtype
1264 | let MkTensor {shape = _ :: _} _ = r
1265 | in broadcast {shapesOK = scalarToAnyOk (d :: ds)} l * r
1269 | [Prod] Num dtype => Semigroup (Tensor shape dtype) where
1274 | [Prod] {shape : _} -> {dtype : _} -> Prelude.Num (idrisType dtype) => Num dtype =>
1275 | Monoid (Tensor shape dtype) using Semigroup.Prod where
1281 | (/) : Fractional dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape dtype
1290 | (/) : Fractional dtype => Tensor (d :: ds) dtype -> Tensor [] dtype -> Tensor (d :: ds) dtype
1292 | let MkTensor {shape = _ :: _} _ = l
1293 | in l / broadcast {shapesOK = scalarToAnyOk (d :: ds)} r
-- Definitions for the `inf`/`nan` signatures declared near the top of the
-- file. They are tensor divisions: under IEEE-754, 1/0 = +inf and 0/0 = NaN.
1295 | inf = tensor 1.0 / tensor 0.0
1296 | nan = tensor 0.0 / tensor 0.0
1301 | div : Tensor shape U64 ->
1302 | (denom : Literal shape Nat) ->
1303 | {auto 0 isSucc : All IsSucc denom} ->
1306 | _ | (MkTensor {shape = _} _) = ewBinary Div x (tensor {dtype = U64} $
cast <$> y)
1312 | rem : Tensor shape U64 ->
1313 | (denom : Literal shape Nat) ->
1314 | {auto 0 isSucc : All IsSucc denom} ->
1317 | _ | (MkTensor {shape = _} _) = ewBinary Rem x (tensor {dtype = U64} $
cast <$> y)
1329 | (^) : Tensor shape F64 -> Tensor shape F64 -> Tensor shape F64
-- Elementwise logical right shift of U64 tensors (the `ShiftRightLogical` op).
1332 | (>>) : Tensor shape U64 -> Tensor shape U64 -> Tensor shape U64
1333 | (>>) = ewBinary ShiftRightLogical
1337 | abs : Abs dtype => Tensor shape dtype -> Tensor shape dtype
1343 | exp : Tensor shape F64 -> Tensor shape F64
1350 | floor : Tensor shape F64 -> Tensor shape F64
1357 | ceil : Tensor shape F64 -> Tensor shape F64
1363 | log : Tensor shape F64 -> Tensor shape F64
-- Elementwise logistic function, via the `Logistic` unary op.
1368 | logistic : Tensor shape F64 -> Tensor shape F64
1369 | logistic = ewUnary Logistic
1373 | sin : Tensor shape F64 -> Tensor shape F64
1378 | cos : Tensor shape F64 -> Tensor shape F64
1383 | tan : Tensor shape F64 -> Tensor shape F64
1388 | asin : Tensor shape F64 -> Tensor shape F64
1393 | acos : Tensor shape F64 -> Tensor shape F64
1398 | atan : Tensor shape F64 -> Tensor shape F64
1403 | sinh : Tensor shape F64 -> Tensor shape F64
1408 | cosh : Tensor shape F64 -> Tensor shape F64
1413 | tanh : Tensor shape F64 -> Tensor shape F64
1418 | asinh : Tensor shape F64 -> Tensor shape F64
1423 | acosh : Tensor shape F64 -> Tensor shape F64
1428 | atanh : Tensor shape F64 -> Tensor shape F64
1433 | erf : Tensor shape F64 -> Tensor shape F64
-- Elementwise inverse error function, via the `ErfInv` unary op.
1436 | erfInv : Tensor shape F64 -> Tensor shape F64
1437 | erfInv = ewUnary ErfInv
-- Elementwise square, via the `Square` unary op.
1442 | square : Tensor shape F64 -> Tensor shape F64
1443 | square = ewUnary Square
1448 | sqrt : Tensor shape F64 -> Tensor shape F64
-- Elementwise minimum of two same-shaped tensors, via `BinaryElementwise Min`.
1454 | min : Ord dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape dtype
1455 | min (MkTensor x) (MkTensor x') = MkTensor $
V 0 $
BinaryElementwise Min x x'
1459 | [Min] {shape : _} -> Ord dtype => Semigroup (Tensor shape dtype) where
1464 | [Min] {shape : _} -> {dtype : _} -> Ord dtype =>
1465 | Monoid (Tensor shape dtype) using Semigroup.Min where
1466 | neutral = broadcast max
-- Elementwise maximum of two same-shaped tensors, via `BinaryElementwise Max`.
1471 | max : Ord dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape dtype
1472 | max (MkTensor x) (MkTensor x') = MkTensor $
V 0 $
BinaryElementwise Max x x'
1476 | [Max] Ord dtype => Semigroup (Tensor shape dtype) where
1481 | [Max] {shape : _} -> {dtype : _} -> Ord dtype =>
1482 | Monoid (Tensor shape dtype) using Semigroup.Max where
1483 | neutral = broadcast min
-- The main diagonal of a square matrix.
-- For n = 0, the empty matrix is simply reshaped to an empty vector.
-- Otherwise, off-diagonal entries are zeroed by an elementwise product with
-- `identity`. Each row is then summed by a matrix-vector product with a
-- vector of ones.
1494 | diag : Num dtype => Prelude.Num (idrisType dtype) => Tensor [n, n] dtype -> Tensor [n] dtype
1495 | diag {n = 0} x@(MkTensor {shape = [0, 0]} _) = reshape x
1496 | diag {n = S n} x@(MkTensor {shape = [S n, S n]} _) = (x * identity) @@ fill 1
1500 | (Tensor [] dtype -> Tensor [] dtype -> Tensor [] PRED) ->
1504 | argmxx cmp (MkTensor bound) x@(MkTensor {shape = _} _) = do
1505 | let MkTensor idxs : Tensor [S n] U64 = iota 0
1507 | MkTensor zero = tensor {dtype = U64} 0
1510 | Tensor [] dtype -> Tensor [] U64 ->
1511 | Tensor [] dtype -> Tensor [] U64 ->
1512 | Tag (Tensor [] dtype, Tensor [] U64)
1514 | let useNext = x == x && (cmp x' x || x' /= x')
1515 | in pure (select useNext x' x, select useNext y' y)
1521 | (MkTensor $
V 0 $
BoundSet addr)
1522 | (MkTensor $
V 1 $
BoundSet addr)
1523 | (MkTensor $
V 2 $
BoundSet addr)
1524 | (MkTensor $
V 3 $
BoundSet addr)
1525 | (env, (MkTensor m, MkTensor i)) = runState (emptyFrom !get) res
1526 | argTys = [TensorType [] dtype, TensorType [] U64]
1527 | f = MkFn addr (argTys ++ argTys) argTys [m, i] env
1530 | pure $
MkTensor $
V 1 $
Reduce f [bound, zero] [0] [x, idxs]
-- Index of a maximal element of a non-empty vector.
-- It delegates to `argmxx`, using `(>)` as the comparator and the dtype's
-- `min` constant as the starting bound.
1536 | argmax : Ord dtype => Tensor [S n] dtype -> Tag $
Tensor [] U64
1537 | argmax x@(MkTensor {dtype} _) = argmxx (>) min x
-- Index of a minimal element of a non-empty vector.
-- It delegates to `argmxx`, using `(<)` as the comparator and the dtype's
-- `max` constant as the starting bound.
1543 | argmin : Ord dtype => Tensor [S n] dtype -> Tag $
Tensor [] U64
1544 | argmin x@(MkTensor {dtype} _) = argmxx (<) max x
-- Selects which triangular half of a square matrix `triangle` works with
-- (e.g. `triangle Lower` in `cholesky`).
1548 | data Triangle = Upper | Lower
1567 | Prelude.Num (idrisType dtype) =>
1570 | Tag $
Tensor [n, n] dtype
1571 | triangle tri (MkTensor x) = do
1572 | let range : Tensor [n * n] U64 = iota 0
1573 | indices <- tag $
reshape {to = [n, n], sizesEqual = productSquare n} range
1577 | pure $
select (op indices indices.T) (fill $
fromInteger 0) (MkTensor x)
1581 | productSquare : (m : Nat) -> product (the Shape [m * m]) = product (the Shape [m, m])
1585 | ~~ m * m ... plusZeroRightNeutral (m * m)
1586 | ~~ (m + 0) * m ..< cong (* m) (plusZeroRightNeutral m)
1593 | cholesky : Tensor [S n, S n] F64 -> Tag $
Tensor [S n, S n] F64
1594 | cholesky $
MkTensor x = triangle Lower (MkTensor $
V 0 $
Cholesky x)
-- Solves a triangular system `a` x = `b` for a matrix right-hand side, via
-- `TriangularSolve a b True`.
-- NOTE(review): the `True` flag presumably selects the lower-triangular
-- variant; confirm against the `TriangularSolve` definition.
1607 | (|\) : Tensor [m, m] F64 -> Tensor [m, n] F64 -> Tensor [m, n] F64
1608 | (MkTensor a) |\ (MkTensor b) = MkTensor $
V 0 $
TriangularSolve a b True
-- Solves a triangular system `a` x = `b` for a matrix right-hand side, via
-- `TriangularSolve a b False`.
-- NOTE(review): the `False` flag presumably selects the upper-triangular
-- variant; confirm against the `TriangularSolve` definition.
1618 | (\|) : Tensor [m, m] F64 -> Tensor [m, n] F64 -> Tensor [m, n] F64
1619 | (MkTensor a) \| (MkTensor b) = MkTensor $
V 0 $
TriangularSolve a b False
-- Vector right-hand-side version of the matrix solve.
-- `b` is expanded to an `[m, 1]` matrix, the matrix solve is applied, and
-- the result is squeezed back to `[m]`. The `let` match presumably exposes
-- `b`'s shape so that `expand`/`squeeze` can elaborate.
1630 | (|\) : Tensor [m, m] F64 -> Tensor [m] F64 -> Tensor [m] F64
1631 | a |\ b = let (MkTensor {shape = [_]} _) = b in squeeze (a |\ expand 1 b)
-- Vector right-hand-side version of the matrix solve.
-- `b` is expanded to an `[m, 1]` matrix, the matrix solve is applied, and
-- the result is squeezed back to `[m]`. The `let` match presumably exposes
-- `b`'s shape so that `expand`/`squeeze` can elaborate.
1641 | (\|) : Tensor [m, m] F64 -> Tensor [m] F64 -> Tensor [m] F64
1642 | a \| b = let (MkTensor {shape = [_]} _) = b in squeeze (a \| expand 1 b)
1647 | trace : Num dtype => Prelude.Num (idrisType dtype) =>
1648 | Tensor [S n, S n] dtype ->
1651 | _ | MkTensor {shape = [_, _]} _ = reduce @{Sum} [0, 1] $
x * identity
1657 | Rand = StateT (Tensor [2] U64) Tag
-- Draws a tensor of random U64 bits from the `Rand` state.
-- `ST` builds the state transition directly: the current state tensor feeds
-- a tagged `Rng` node. StateT pairs are `(state, value)`, as with `runState`
-- in `eval`. So output 0 of the node becomes the new state, and output 1 is
-- the random tensor.
1670 | rng : {shape : _} -> Rand $
Tensor shape U64
1671 | rng = ST $
\(MkTensor state) => do
1672 | res <- tag $
Rng state (TensorType shape U64)
1673 | pure (MkTensor $
V 0 res, MkTensor $
V 1 res)
1686 | uniform : {shape : _} -> Rand $
Tensor shape F64
1688 | let numMantissaBits : Bits64 = 52
1689 | scale = broadcast $
2.0 ^ tensor (Scalar $
- cast {to = Double} numMantissaBits)
1690 | shift = fill $
64 - numMantissaBits
1691 | in rng {shape} <&> \x => castDtype (x >> shift) * scale
-- Standard-normal samples by the inverse-CDF transform: if U ~ Uniform[0, 1),
-- then sqrt(2) * erfInv(2U - 1) ~ N(0, 1).
-- NOTE(review): `uniform` computes (x >> 12) * 2^-52, which is exactly 0 when
-- the top 52 bits are zero. That gives erfInv(-1) = -inf with probability
-- 2^-52. Consider nudging U into the open interval (0, 1); confirm whether
-- callers care.
1705 | normal : {shape : _} -> Rand $
Tensor shape F64
1706 | normal = uniform <&> \x => sqrt (broadcast 2.0) * erfInv (broadcast 2.0 * x - broadcast 1.0)