25 | import Control.Monad.Error.Either
26 | import public Control.Monad.State
27 | import public Data.List
28 | import public Data.List.Elem
29 | import Data.List.Quantifiers
30 | import Decidable.Equality
31 | import Syntax.PreorderReasoning
33 | import Compiler.Eval
35 | import Compiler.Xla.Shape
36 | import Compiler.Xla.ShapeUtil
37 | import Compiler.LiteralRW
39 | import public Literal
40 | import public Primitive
45 | XlaShape = Xla.Shape
||| A tensor with static `shape` and element type `dtype`, wrapping a `Value`.
||| The shape is also carried as an unrestricted (non-erased) implicit field,
||| so it can be recovered at runtime by pattern matching on `MkTensor {shape}`.
data Tensor : (shape : Shape) -> (dtype : Type) -> Type where
  MkTensor : Value -> {shape : _} -> Tensor shape dtype
||| A monad transformer that threads an `Env` through computations, implemented
||| as a thin wrapper over `StateT Env m`.
data TagT : (Type -> Type) -> Type -> Type where
  MkTagT : StateT Env m a -> TagT m a

-- All instances below unwrap `MkTagT`, delegate to `StateT Env m`, and rewrap.
Functor m => Functor (TagT m) where
  map f (MkTagT st) = MkTagT (f <$> st)

Monad m => Applicative (TagT m) where
  pure v = MkTagT (pure v)
  (MkTagT sf) <*> (MkTagT sx) = MkTagT (sf <*> sx)

Monad m => Monad (TagT m) where
  (MkTagT st) >>= k = MkTagT $ st >>= \y => case k y of MkTagT next => next

MonadTrans TagT where
  lift action = MkTagT (lift action)
84 | interface Taggable a where
101 | tag : Monad m => a -> TagT m a
106 | tag (BoundSet x) = pure (BoundSet x)
107 | tag x = MkTagT $
tagOp x
-- Tag the op inside a tensor's `V idx op` value, then rebuild the tensor with
-- the same index around the tagged op.
Taggable (Tensor shape dtype) where
  tag (MkTensor (V idx op)) = (\tagged => MkTensor (V idx tagged)) <$> tag op

-- Pairs are tagged component-wise; the left component's effects run first.
(Taggable a, Taggable b) => Taggable (a, b) where
  tag (a, b) = MkPair <$> tag a <*> tag b
||| Build a constant tensor from a `Literal` of matching shape.
tensor : PrimitiveRW dtype a => {shape : _} -> Literal shape a -> Tensor shape dtype
tensor lit = MkTensor (V 0 (Lit {dtype} {shape} lit))

||| A scalar `F64` tensor holding the given `Double`. This is what lets
||| floating-point literals such as `2.0` stand for scalar `F64` tensors.
fromDouble : Double -> Tensor [] F64
fromDouble d = tensor (Scalar d)

||| A scalar `S32` tensor holding the given `Integer`, converted to the
||| element type via `fromInteger`.
fromInteger : Integer -> Tensor [] S32
fromInteger i = tensor (Scalar (fromInteger i))
||| Run an `EitherT e IO` computation, returning its result.
||| On `Left`, this crashes the program with the `show`n error rather than
||| returning it.
try : Show e => EitherT e IO a -> IO a
try action = eitherT (\err => assert_total (idris_crash (show err))) pure action
140 | namespace TensorList
||| A heterogeneous list of tensors, indexed by each tensor's shape and by the
||| element type `ty` of the `Literal` that tensor is read back as
||| (each element carries a `PrimitiveRW dtype ty` instance).
data TensorList : List Shape -> List Type -> Type where
  Nil : TensorList [] []
  (::) : PrimitiveRW dtype ty =>
         Tensor shape dtype ->
         TensorList shapes tys ->
         TensorList (shape :: shapes) (ty :: tys)
165 | eval : Device -> Tag (TensorList shapes tys) -> IO (All2 Literal shapes tys)
166 | eval device (MkTagT xs) =
167 | let (env, xs) = runState empty xs
169 | xlaShapes <- buildShapes xs
170 | let (
outputs ** eq)
= lengthC xs
171 | main = MkFn 0 [] (resultTypes xs) (results xs) env
172 | lits <- execute device main {outputs} (rewrite eq in xlaShapes)
173 | readAll xs $
rewrite sym eq in lits
177 | lengthC : TensorList s t -> (
n ** n === length s)
178 | lengthC [] = (
0 ** Refl)
179 | lengthC (_ :: xs) = let (
n ** eq)
= lengthC xs in (
S n ** cong S eq)
181 | buildShapes : HasIO io => TensorList s t -> io $
Vect (length s) XlaShape
182 | buildShapes [] = pure []
183 | buildShapes (MkTensor {shape, dtype} _ :: ts) = [| mkShape shape {dtype} :: buildShapes ts |]
185 | results : TensorList s t -> Vect (length s) Value
187 | results (MkTensor x :: xs) = x :: results xs
189 | resultTypes : TensorList s t -> Vect (length s) ValueType
190 | resultTypes [] = []
191 | resultTypes (MkTensor {shape, dtype} _ :: xs) = TensorType shape dtype :: resultTypes xs
193 | readAll : HasIO io => TensorList s t -> Vect (length s) Literal -> io $
All2 Literal s t
194 | readAll [] _ = pure []
195 | readAll (MkTensor {dtype} _ :: ts) (l :: ls) = [| read {dtype} l :: readAll ts ls |]
||| Evaluate an untagged list of tensors on `device`, by lifting it into `Tag`
||| with `pure` and delegating to the tagged overload.
eval : Device -> TensorList shapes tys -> IO (All2 Literal shapes tys)
eval device tensors = eval device (pure tensors)

||| Evaluate a single tagged tensor on `device`. The tensor is wrapped in a
||| one-element `TensorList`, and the single resulting literal is unwrapped.
eval : Device -> PrimitiveRW dtype ty => Tag (Tensor shape dtype) -> IO (Literal shape ty)
eval device x = (\[lit] => lit) <$> eval device ((\t => [t]) <$> x)

||| Evaluate a single untagged tensor on `device`, by lifting it into `Tag`
||| with `pure`.
eval : Device -> PrimitiveRW dtype ty => Tensor shape dtype -> IO (Literal shape ty)
eval device t = eval device (pure t)
222 | Primitive dtype => Show (Tag $
Tensor shape dtype) where
224 | let (env, MkTensor x) = runState empty x
225 | in show (MkFn 0 [] [TensorType shape dtype] [x] env)
229 | inf : Tensor [] F64
233 | nan : Tensor [] F64
235 | namespace Primitive
||| Ordered element types, which expose their lower and upper bounds as scalar
||| tensors.
interface Primitive.Eq dtype => Primitive.Ord dtype where
  ||| The lower bound of this dtype.
  min : Tensor [] dtype

  ||| The upper bound of this dtype.
  max : Tensor [] dtype
-- Integer dtypes take their bounds from the backend `MinValue` / `MaxValue` ops.
Primitive.Ord U32 where
  min = MkTensor (V 0 (MinValue U32))
  max = MkTensor (V 0 (MaxValue U32))

Primitive.Ord S32 where
  min = MkTensor (V 0 (MinValue S32))
  max = MkTensor (V 0 (MaxValue S32))

Primitive.Ord U64 where
  min = MkTensor (V 0 (MinValue U64))
  max = MkTensor (V 0 (MaxValue U64))

-- For F64 the bounds are negative and positive infinity (±1 / 0), not the
-- largest finite values; `minFinite` / `maxFinite` provide those.
Primitive.Ord F64 where
  min = tensor (Scalar (-1.0 / 0.0))
  max = tensor (Scalar (1.0 / 0.0))
||| The minimum finite `F64` value, via the backend `MinFiniteFloat` op.
minFinite : Tensor [] F64
minFinite = MkTensor (V 0 MinFiniteFloat)

||| The maximum finite `F64` value, via the backend `MaxFiniteFloat` op.
maxFinite : Tensor [] F64
maxFinite = MkTensor (V 0 MaxFiniteFloat)

||| Convert every element of an integral tensor to `F64`, preserving the shape.
castDtype : Primitive.Integral a => Tensor shape a -> Tensor shape F64
castDtype (MkTensor {shape} x) = MkTensor (V 0 (Convert {dtype = F64} shape x))
281 | fn0 : Primitive a => Tag (Tensor s a) -> Tag $
Fn 0
282 | fn0 f = MkTagT $
do
286 | (env, MkTensor res) = runState (emptyFrom !get) res
287 | f = MkFn addr [] [TensorType s a] [res] env
289 | updateCounterFrom env
292 | fn1 : Primitive a => Primitive a' => {s : _} -> (Tensor s a -> Tag $
Tensor s' a') -> Tag $
Fn 1
293 | fn1 f = MkTagT $
do
296 | let MkTagT res = f (MkTensor $
V 0 $
BoundSet addr)
297 | (env, MkTensor res) = runState (emptyFrom !get) res
298 | f = MkFn addr [TensorType s a] [TensorType s' a'] [res] env
300 | updateCounterFrom env
304 | Primitive a => Primitive a' => Primitive a'' => {s, s' : _} ->
305 | (Tensor s a -> Tensor s' a' -> Tag $
Tensor s'' a'') ->
307 | fn2 f = MkTagT $
do
310 | let MkTagT res = f (MkTensor $
V 0 $
BoundSet addr) (MkTensor $
V 1 $
BoundSet addr)
311 | (env, MkTensor res) = runState (emptyFrom !get) res
312 | f = MkFn addr [TensorType s a, TensorType s' a'] [TensorType s'' a''] [res] env
314 | updateCounterFrom env
318 | Primitive a => Primitive a' => Primitive a'' => Primitive a''' => {s, s' : _} ->
319 | (Tensor s a -> Tensor s' a' -> Tag $
(Tensor s'' a'', Tensor s''' a''')) ->
321 | fn22 f = MkTagT $
do
324 | let MkTagT res = f (MkTensor $
V 0 $
BoundSet addr) (MkTensor $
V 1 $
BoundSet addr)
325 | (env, (MkTensor res0, MkTensor res1)) = runState (emptyFrom !get) res
326 | resTys = [TensorType s'' a'', TensorType s''' a''']
327 | f = MkFn addr [TensorType s a, TensorType s' a'] resTys [res0, res1] env
329 | updateCounterFrom env
||| The gradient of the scalar-valued function `f`, evaluated at `x`.
||| `f` is first staged as a one-argument function with `fn1`; the result has
||| the same shape as `x`.
grad : (Tensor shape F64 -> Tag $ Tensor [] F64) -> Tensor shape F64 -> Tag $ Tensor shape F64
grad f (MkTensor x) = do
  staged <- fn1 f
  pure (MkTensor (V 0 (Grad shape staged x)))
358 | {auto 0 sizesEqual : product from = product to} ->
359 | Tensor from dtype ->
361 | reshape $
MkTensor {shape} x = MkTensor $
V 0 $
Reshape {dtype} to x
370 | {auto 0 inBounds : axis `LTE` length shape} ->
371 | Tensor shape dtype ->
372 | Tensor (insertAt axis 1 shape) dtype
373 | expand axis $
MkTensor {shape = _} x = MkTensor $
V 0 $
Reshape {dtype} (insertAt axis 1 shape) x
375 | namespace Squeezable
||| Evidence that shape `from` can be turned into shape `to` solely by removing
||| dimensions of size one.
data Squeezable : (0 from : Shape) -> (0 to : Shape) -> Type where
  ||| Every shape squeezes to itself.
  Same : Squeezable x x

  ||| A leading dimension common to both shapes is kept, and the tails must squeeze.
  Match : Squeezable from to -> Squeezable (x :: from) (x :: to)

  ||| A leading dimension of size one is dropped.
  Nest : Squeezable from to -> Squeezable (1 :: from) to
417 | {auto 0 shapesSqueezable : Squeezable from to} ->
418 | Tensor from dtype ->
420 | squeeze $
MkTensor {shape} x = MkTensor $
V 0 $
Reshape {dtype} to x
425 | data SliceOrIndex : Nat -> Type where
427 | (from, to : Nat) ->
429 | {auto 0 fromTo : from + size = to} ->
430 | {auto 0 inDim : LTE to d} ->
432 | Index : (idx : Nat) -> {auto 0 inDim : LT idx d} -> SliceOrIndex d
433 | DynamicSlice : Tensor [] U64 -> (size : Nat) -> {auto 0 inDim : LTE size d} -> SliceOrIndex d
434 | DynamicIndex : Tensor [] U64 -> SliceOrIndex d
438 | at : (idx : Nat) -> {auto 0 inDim : LT idx d} -> SliceOrIndex d
444 | at : Tensor [] U64 -> SliceOrIndex d
450 | (from, to : Nat) ->
452 | {auto 0 fromTo : from + size = to} ->
453 | {auto 0 inDim : LTE to d} ->
||| A slice of static length `size` whose start position is given at runtime by
||| a scalar `U64` tensor.
(.size) : Tensor [] U64 -> (size : Nat) -> {auto 0 inDim : LTE size d} -> SliceOrIndex d
(.size) = DynamicSlice

||| Slice the whole dimension, from `0` to `d`.
all : {d : _} -> SliceOrIndex d
all = Slice 0 @{%search} @{reflexive {ty = Nat}} d
||| One `SliceOrIndex` per leading dimension of a shape. `Nil` is valid for any
||| shape, so a `MultiSlice` may cover only a prefix of the dimensions.
data MultiSlice : Shape -> Type where
  Nil : MultiSlice ds
  (::) : SliceOrIndex d -> MultiSlice ds -> MultiSlice (d :: ds)
474 | namespace MultiSlice
||| The shape produced by applying a `MultiSlice` to a tensor of shape `shape`:
|||   * indices (static or dynamic) remove their dimension;
|||   * slices (static or dynamic) keep the dimension with the slice's size;
|||   * dimensions not covered by the `MultiSlice` are kept unchanged.
slice : {shape : _} -> MultiSlice shape -> Shape
slice {shape} [] = shape
slice {shape = (_ :: _)} (Index _ :: rest) = slice rest
slice {shape = (_ :: _)} (DynamicIndex _ :: rest) = slice rest
slice {shape = (_ :: _)} (Slice {size} _ _ :: rest) = size :: slice rest
slice {shape = (_ :: _)} (DynamicSlice _ size :: rest) = size :: slice rest
573 | (at : MultiSlice shape) ->
574 | Tensor shape dtype ->
575 | Tensor (slice at) dtype
576 | slice at $
MkTensor x = MkTensor $
V 0 $
577 | let x = V 0 $
Slice (mapd start (const 0) at) (mapd stop id at) (replicate (length shape) 1) x
579 | x = if isDynamic at then V 0 $
DynamicSlice (dynStarts [] at) (mapd size id at) x else x
580 | in Reshape {dtype} (MultiSlice.slice at) x
583 | mapd : ((Nat -> a) -> {d : Nat} -> SliceOrIndex d -> a) ->
586 | MultiSlice shape ->
588 | mapd _ dflt {shape} [] = Prelude.map dflt shape
589 | mapd f dflt (x :: xs) = f dflt x :: mapd f dflt xs
591 | start : (Nat -> Nat) -> {d : Nat} -> SliceOrIndex d -> Nat
592 | start _ (Slice from _) = from
593 | start _ (Index idx) = idx
594 | start f {d} _ = f d
596 | stop : (Nat -> Nat) -> {d : Nat} -> SliceOrIndex d -> Nat
597 | stop _ (Slice _ to) = to
598 | stop _ (Index idx) = S idx
601 | size : (Nat -> Nat) -> {d : Nat} -> SliceOrIndex d -> Nat
602 | size _ (Slice {size = size'} _ _) = size'
603 | size _ (Index _) = 1
604 | size _ (DynamicSlice _ size') = size'
605 | size _ (DynamicIndex _) = 1
608 | zero = V 0 $
Lit {shape = []} {dtype = U64} 0
610 | isDynamic : {shape : _} -> MultiSlice shape -> Bool
611 | isDynamic [] = False
612 | isDynamic {shape = (_ :: _)} (DynamicSlice _ _ :: _) = True
613 | isDynamic {shape = (_ :: _)} (DynamicIndex _ :: _) = True
614 | isDynamic (_ :: ds) = isDynamic ds
616 | dynStarts : List Value -> {shape : _} -> MultiSlice shape -> List Value
617 | dynStarts idxs {shape} [] = replicate (length shape) zero ++ idxs
618 | dynStarts idxs (DynamicSlice (MkTensor i) _ :: ds) = i :: dynStarts idxs ds
619 | dynStarts idxs (DynamicIndex (MkTensor i) :: ds) = i :: dynStarts idxs ds
620 | dynStarts idxs (_ :: ds) = zero :: dynStarts idxs ds
632 | {auto 0 inBounds : (InBounds axis s, InBounds axis s')} ->
633 | {auto 0 shapesConcatenable : deleteAt axis s = deleteAt axis s'} ->
634 | Tensor (replaceAt axis (index axis s + index axis s') s) dtype
635 | concat axis (MkTensor x) (MkTensor x') = MkTensor $
V 0 $
Concat axis [x, x']
||| Matrix transpose: swaps the two axes of a rank-2 tensor.
(.T) : Tensor [m, n] dtype -> Tensor [n, m] dtype
(MkTensor x).T = MkTensor (V 0 (Transpose [1, 0] x))
688 | (ordering : List Nat) ->
689 | Tensor shape dtype ->
690 | {auto 0 lengths : length ordering = length shape} ->
691 | {auto 0 axesUnique : unique ordering = True} ->
692 | {auto 0 inBounds : All (flip InBounds shape) ordering} ->
693 | Tensor (multiIndex ordering shape) dtype
694 | transpose ordering $
MkTensor x = MkTensor $
V 0 $
Transpose ordering x
||| Evidence that a single dimension of size `from` can be broadcast to size `to`.
data DimBroadcastable : (0 from : Nat) -> (0 to : Nat) -> Type where
  ||| A dimension broadcasts to itself.
  Same : DimBroadcastable x x

  ||| A dimension of size one can be stacked to any size.
  Stack : DimBroadcastable 1 _

  ||| Any dimension can be broadcast to size zero.
  Zero : DimBroadcastable _ 0
711 | namespace Broadcastable
||| Evidence that a tensor of shape `from` can be broadcast to shape `to`.
data Broadcastable : (0 from : Shape) -> (0 to : Shape) -> Type where
  ||| Every shape broadcasts to itself.
  Same : Broadcastable x x

  ||| Two shapes of equal rank broadcast when their leading dimensions do
  ||| (`DimBroadcastable f t`) and their tails do.
  Match : forall from, to .
          {auto 0 ranksEq : length from = length to} ->
          {auto 0 dimBroadcastable : DimBroadcastable f t} ->
          Broadcastable from to ->
          Broadcastable (f :: from) (t :: to)

  ||| If `f` broadcasts to `t`, it also broadcasts to `t` with one more leading
  ||| dimension of any size.
  Nest : Broadcastable f t -> Broadcastable f (_ :: t)

||| Any shape broadcasts to itself with extra leading dimensions prepended.
broadcastableByLeading : (leading : List Nat) -> Broadcastable shape (leading ++ shape)
broadcastableByLeading [] = Same
broadcastableByLeading (_ :: rest) = Nest (broadcastableByLeading rest)

||| A scalar (shape `[]`) broadcasts to any shape: every dimension of `to` is a
||| leading dimension prepended to `[]`.
scalarToAnyOk : (to : Shape) -> Broadcastable [] to
scalarToAnyOk to = rewrite sym (appendNilRightNeutral to) in broadcastableByLeading to
772 | {auto shapesOK : Broadcastable from to} ->
773 | Tensor from dtype ->
775 | broadcast $
MkTensor {shape = _} x = MkTensor $
V 0 $
Broadcast {dtype} from to x
||| A tensor of the given shape in which every element is `x`.
||| Built as a scalar constant broadcast to `shape`, using `scalarToAnyOk` as the
||| broadcast proof.
fill : PrimitiveRW dtype ty => {shape : _} -> ty -> Tensor shape dtype
fill x = broadcast {shapesOK = scalarToAnyOk shape} (tensor (Scalar x))
817 | iota : Primitive.Num dtype =>
820 | {auto 0 inBounds : InBounds axis shape} ->
822 | iota dimension = MkTensor $
V 0 $
Iota shape {dtype} dimension
838 | (condition : Tensor shape dtype -> Tag $
Tensor [] PRED) ->
839 | (body : Tensor shape dtype -> Tag $
Tensor shape dtype) ->
840 | (initial : Tensor shape dtype) ->
841 | Tag $
Tensor shape dtype
842 | while1 condition body (MkTensor i0) =
843 | pure $
MkTensor $
V 0 $
While !(fn1 condition) !(fn1 body) [i0]
859 | Primitive a => Primitive a' =>
860 | (condition : Tensor s a -> Tensor s' a' -> Tag $
Tensor [] PRED) ->
861 | (body : Tensor s a -> Tensor s' a' -> Tag (Tensor s a, Tensor s' a')) ->
862 | (initial : Tensor s a) -> (initial' : Tensor s' a') ->
863 | Tag (Tensor s a, Tensor s' a')
864 | while2 condition body (MkTensor i) (MkTensor i') = do
865 | res <- tag $
While !(fn2 condition) !(fn22 body) [i, i']
866 | pure (MkTensor $
V 0 res, MkTensor $
V 1 res)
||| Apply a scalar function to every element of a tensor.
||| `f` is staged once with `fn1` and applied across all dimensions
||| (`range (length shape)`), producing a tensor of the same shape.
map : (Primitive a, Primitive b) =>
      (Tensor [] a -> Tag $ Tensor [] b) ->
      Tensor shape a -> Tag $ Tensor shape b
map f (MkTensor {shape = _} x) = do
  staged <- fn1 f
  pure (MkTensor (V 0 (Map staged [x] (TensorType shape b) (range (length shape)))))
899 | (Primitive a, Primitive b, Primitive c) =>
900 | (Tensor [] a -> Tensor [] b -> Tag $
Tensor [] c) ->
901 | Tensor shape a -> Tensor shape b -> Tag $
Tensor shape c
902 | map2 f (MkTensor {shape = _} x) (MkTensor x') =
903 | pure $
MkTensor $
V 0 $
Map !(fn2 f) [x, x'] (TensorType shape c) (range $
length shape)
921 | (reducer : Monoid (Tensor [] dtype)) =>
923 | (axes : List Nat) ->
924 | {auto 0 axesUnique : Sorted LT axes} ->
925 | {auto 0 axesInBounds : All (flip InBounds shape) axes} ->
926 | Tensor shape dtype ->
927 | Tag $
Tensor (deleteAt axes shape) dtype
928 | reduce axes $
MkTensor x = do
929 | let semigroup : Monoid a -> Semigroup a
930 | semigroup _ = %search
932 | g <- fn2 (pure .: (<+>) @{semigroup reducer})
933 | let MkTensor neutral' = neutral @{reducer}
934 | pure $
MkTensor $
V 0 $
Reduce g [neutral'] axes [x]
956 | (Tensor [] dtype -> Tensor [] dtype -> Tensor [] PRED) ->
957 | (dimension : Nat) ->
958 | Tensor shape dtype ->
959 | {auto 0 dimInBounds : InBounds dimension shape} ->
960 | Tag $
Tensor shape dtype
961 | sort comp dimension $
MkTensor x =
962 | pure $
MkTensor $
V 0 $
Sort !(fn2 $
pure .: comp) dimension False x
994 | (axes : List Nat) ->
995 | {auto 0 axesUnique : Sorted LT axes} ->
996 | {auto 0 axesInBounds : All (flip InBounds shape) axes} ->
997 | Tensor shape dtype ->
999 | reverse axes $
MkTensor x = MkTensor $
V 0 $
Reverse axes x
||| Apply a unary element-wise op to a tensor. The output dtype is left to the
||| caller's signature; callers must choose one consistent with `op`.
ewUnary : UnaryOp -> Tensor s a -> Tensor s a'
ewUnary op (MkTensor operand) = MkTensor (V 0 (UnaryElementwise op operand))

||| Apply a binary element-wise op to two tensors of the same shape. As with
||| `ewUnary`, the output dtype is chosen by the caller's signature.
ewBinary : BinaryOp -> Tensor s a -> Tensor s a' -> Tensor s a''
ewBinary op (MkTensor lhs) (MkTensor rhs) = MkTensor (V 0 (BinaryElementwise op lhs rhs))
||| Element-wise equality.
(==) : Primitive.Eq dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape PRED
(==) = ewBinary (Compare Eq)

||| Element-wise inequality.
(/=) : Primitive.Eq dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape PRED
(/=) = ewBinary (Compare Ne)

||| Element-wise less-than.
(<) : Primitive.Ord dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape PRED
(<) = ewBinary (Compare Lt)

||| Element-wise greater-than.
(>) : Primitive.Ord dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape PRED
(>) = ewBinary (Compare Gt)

||| Element-wise less-than-or-equal.
(<=) : Primitive.Ord dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape PRED
(<=) = ewBinary (Compare Le)

||| Element-wise greater-than-or-equal.
(>=) : Primitive.Ord dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape PRED
(>=) = ewBinary (Compare Ge)
1047 | (&&) : Tensor shape PRED -> Tensor shape PRED -> Tensor shape PRED
1052 | [All] Semigroup (Tensor shape PRED) where
1057 | [All] {shape : _} -> Monoid (Tensor shape PRED) using Tensor.Semigroup.All where
1064 | (||) : Tensor shape PRED -> Tensor shape PRED -> Tensor shape PRED
1069 | [Any] Semigroup (Tensor shape PRED) where
1074 | [Any] {shape : _} -> Monoid (Tensor shape PRED) using Tensor.Semigroup.Any where
1080 | not : Tensor shape PRED -> Tensor shape PRED
1104 | (onTrue, onFalse : Tensor shape dtype) ->
1106 | select (MkTensor p) (MkTensor t) (MkTensor f) = MkTensor $
V 0 $
Select p t f
1130 | (onTrue, onFalse : Tag $
Tensor shape dtype) ->
1131 | Tag $
Tensor shape dtype
1132 | if_ (MkTensor pred) onTrue onFalse =
1133 | pure $
MkTensor $
V 0 $
If (TensorType shape dtype) pred !(fn0 onTrue) !(fn0 onFalse)
1147 | identity : Primitive.Num dtype => {n : _} -> Tensor [n, n] dtype
1149 | let MkTensor x = iota 0 {shape = [n, n], dtype = U64} == iota 1
1150 | in MkTensor $
V 0 $
Convert {dtype} [n, n] x
||| Dot product of two vectors of the same non-zero length, giving a scalar.
||| Implemented as a `DotGeneral` contracting axis 0 of both operands, with no
||| batch axes.
(@@) : Primitive.Num dtype => Tensor [S m] dtype -> Tensor [S m] dtype -> Tensor [] dtype
(MkTensor u) @@ (MkTensor v) = MkTensor (V 0 (DotGeneral [] [] [0] [0] (TensorType [] dtype) u v))
1184 | (@@) : (Primitive dtype, Primitive.Num dtype) =>
1185 | Tensor [n, S m] dtype ->
1186 | Tensor (S m :: tl) dtype ->
1187 | {auto 0 vectorTail : length tl `LTE` 1} ->
1189 | (MkTensor x) @@ (MkTensor x') =
1190 | MkTensor $
V 0 $
DotGeneral [] [] [1] [0] (TensorType (n :: tl) dtype) x x'
1194 | contract : (lBatch, rBatch, lContract, rContract : List Nat) ->
1196 | {auto 0 lInBoundsBatch : All (flip InBounds ls) lBatch} ->
1197 | {auto 0 rInBoundsBatch : All (flip InBounds rs) rBatch} ->
1198 | {auto 0 lInBoundsContract : All (flip InBounds ls) lContract} ->
1199 | {auto 0 rInBoundsContract : All (flip InBounds rs) rContract} ->
1201 | contract lBatch rBatch lContract rContract ls rs =
1202 | let lResultDims = deleteAt {inBounds = lInBoundsBatch ++ lInBoundsContract}
1203 | (lBatch ++ lContract) ls
1204 | rResultDims = deleteAt {inBounds = rInBoundsBatch ++ rInBoundsContract}
1205 | (rBatch ++ rContract) rs
1206 | in multiIndex lBatch ls ++ lResultDims ++ rResultDims
1236 | (lBatch, rBatch, lContract, rContract : List Nat) ->
1237 | {auto 0 lUnique : unique (lBatch ++ lContract) = True} ->
1238 | {auto 0 rUnique : unique (rBatch ++ rContract) = True} ->
1239 | {auto 0 lInBoundsBatch : All (flip InBounds ls) lBatch} ->
1240 | {auto 0 rInBoundsBatch : All (flip InBounds rs) rBatch} ->
1241 | {auto 0 lInBoundsContract : All (flip InBounds ls) lContract} ->
1242 | {auto 0 rInBoundsContract : All (flip InBounds rs) rContract} ->
1243 | {auto 0 batchDimsEq : multiIndex lBatch ls = multiIndex rBatch rs} ->
1244 | {auto 0 contractDimsEq : multiIndex lContract ls = multiIndex rContract rs} ->
1247 | Tensor (contract lBatch rBatch lContract rContract ls rs) dtype
1248 | dotGeneral lb rb lc rc (MkTensor x) (MkTensor y) =
1249 | let resultType = TensorType (contract lb rb lc rc ls rs) dtype
1250 | in MkTensor $
V 0 $
DotGeneral lb rb lc rc resultType x y
1255 | (+) : Primitive.Num dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape dtype
1260 | [Sum] Primitive.Num dtype => Semigroup (Tensor shape dtype) where
1269 | Monoid (Tensor shape dtype) using Semigroup.Sum where
||| Element-wise negation.
||| Equivalent to building `UnaryElementwise Neg` directly, via the shared
||| `ewUnary` helper.
negate : Primitive.Neg dtype => Tensor shape dtype -> Tensor shape dtype
negate = ewUnary Neg
1280 | (-) : Primitive.Neg dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape dtype
1286 | (*) : Primitive.Num dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape dtype
1295 | (*) : Primitive.Num dtype => Tensor [] dtype -> Tensor (d :: ds) dtype -> Tensor (d :: ds) dtype
1297 | let MkTensor {shape = _ :: _} _ = r
1298 | in broadcast {shapesOK = scalarToAnyOk (d :: ds)} l * r
1302 | [Prod] Primitive.Num dtype => Semigroup (Tensor shape dtype) where
1311 | Monoid (Tensor shape dtype) using Semigroup.Prod where
1317 | (/) : Primitive.Fractional dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape dtype
1326 | (/) : Primitive.Fractional dtype =>
1327 | Tensor (d :: ds) dtype ->
1331 | let MkTensor {shape = _ :: _} _ = l
1332 | in l / broadcast {shapesOK = scalarToAnyOk (d :: ds)} r
1334 | inf = tensor 1.0 / tensor 0.0
1335 | nan = tensor 0.0 / tensor 0.0
1340 | div : Tensor shape U64 ->
1341 | (denom : Literal shape Nat) ->
1342 | {auto 0 isSucc : All IsSucc denom} ->
1345 | _ | (MkTensor {shape = _} _) = ewBinary Div x (tensor {dtype = U64} y)
1351 | rem : Tensor shape U64 ->
1352 | (denom : Literal shape Nat) ->
1353 | {auto 0 isSucc : All IsSucc denom} ->
1356 | _ | (MkTensor {shape = _} _) = ewBinary Rem x (tensor {dtype = U64} y)
1368 | (^) : Tensor shape F64 -> Tensor shape F64 -> Tensor shape F64
||| Element-wise logical right shift of `x` by `n` bits.
(>>) : Tensor shape U64 -> Tensor shape U64 -> Tensor shape U64
x >> n = ewBinary ShiftRightLogical x n
1376 | abs : Primitive.Abs dtype => Tensor shape dtype -> Tensor shape dtype
1382 | exp : Tensor shape F64 -> Tensor shape F64
1389 | floor : Tensor shape F64 -> Tensor shape F64
1396 | ceil : Tensor shape F64 -> Tensor shape F64
1402 | log : Tensor shape F64 -> Tensor shape F64
||| Element-wise logistic (sigmoid) function.
logistic : Tensor shape F64 -> Tensor shape F64
logistic x = ewUnary Logistic x
1412 | sin : Tensor shape F64 -> Tensor shape F64
1417 | cos : Tensor shape F64 -> Tensor shape F64
1422 | tan : Tensor shape F64 -> Tensor shape F64
1427 | asin : Tensor shape F64 -> Tensor shape F64
1432 | acos : Tensor shape F64 -> Tensor shape F64
1437 | atan : Tensor shape F64 -> Tensor shape F64
1442 | sinh : Tensor shape F64 -> Tensor shape F64
1447 | cosh : Tensor shape F64 -> Tensor shape F64
1452 | tanh : Tensor shape F64 -> Tensor shape F64
1457 | asinh : Tensor shape F64 -> Tensor shape F64
1462 | acosh : Tensor shape F64 -> Tensor shape F64
1467 | atanh : Tensor shape F64 -> Tensor shape F64
1472 | erf : Tensor shape F64 -> Tensor shape F64
||| Element-wise inverse error function.
erfInv : Tensor shape F64 -> Tensor shape F64
erfInv x = ewUnary ErfInv x

||| Element-wise square.
square : Tensor shape F64 -> Tensor shape F64
square x = ewUnary Square x
1487 | sqrt : Tensor shape F64 -> Tensor shape F64
||| Element-wise minimum of two tensors of the same shape.
||| Equivalent to building `BinaryElementwise Min` directly, via the shared
||| `ewBinary` helper.
min : Primitive.Ord dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape dtype
min = ewBinary Min
1498 | [Min] {shape : _} -> Primitive.Ord dtype => Semigroup (Tensor shape dtype) where
-- Monoid under element-wise minimum. The identity is the dtype's upper bound
-- (`Primitive.Ord`'s `max`) broadcast to `shape`.
[Min] {shape : _} -> Primitive.Ord dtype => Monoid (Tensor shape dtype) using Semigroup.Min where
  neutral = broadcast max
||| Element-wise maximum of two tensors of the same shape.
||| Equivalent to building `BinaryElementwise Max` directly, via the shared
||| `ewBinary` helper.
max : Primitive.Ord dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape dtype
max = ewBinary Max
1514 | [Max] Primitive.Ord dtype => Semigroup (Tensor shape dtype) where
-- Monoid under element-wise maximum. The identity is the dtype's lower bound
-- (`Primitive.Ord`'s scalar `min`) broadcast to `shape`.
[Max] {shape : _} -> Primitive.Ord dtype => Monoid (Tensor shape dtype) using Semigroup.Max where
  neutral = broadcast min
1533 | PrimitiveRW dtype ty =>
1537 | diag {n = 0} x@(MkTensor {shape = [0, 0]} _) = reshape x
1538 | diag {n = S n} x@(MkTensor {shape = [S n, S n]} _) = (x * identity) @@ fill 1
1542 | (Tensor [] dtype -> Tensor [] dtype -> Tensor [] PRED) ->
1546 | argmxx cmp (MkTensor bound) x@(MkTensor {shape = _} _) = do
1547 | let MkTensor idxs : Tensor [S n] U64 = iota 0
1549 | MkTensor zero = tensor {dtype = U64} $
Scalar 0
1552 | Tensor [] dtype -> Tensor [] U64 ->
1553 | Tensor [] dtype -> Tensor [] U64 ->
1554 | Tag (Tensor [] dtype, Tensor [] U64)
1556 | let useNext = x == x && (cmp x' x || x' /= x')
1557 | in pure (select useNext x' x, select useNext y' y)
1563 | (MkTensor $
V 0 $
BoundSet addr)
1564 | (MkTensor $
V 1 $
BoundSet addr)
1565 | (MkTensor $
V 2 $
BoundSet addr)
1566 | (MkTensor $
V 3 $
BoundSet addr)
1567 | (env, (MkTensor m, MkTensor i)) = runState (emptyFrom !get) res
1568 | argTys = [TensorType [] dtype, TensorType [] U64]
1569 | f = MkFn addr (argTys ++ argTys) argTys [m, i] env
1572 | pure $
MkTensor $
V 1 $
Reduce f [bound, zero] [0] [x, idxs]
||| Index of a largest element of a non-empty vector.
||| The search compares with `(>)`, seeded with the dtype's lower bound `min`.
argmax : Primitive.Ord dtype => Tensor [S n] dtype -> Tag $ Tensor [] U64
argmax = argmxx (>) min

||| Index of a smallest element of a non-empty vector.
||| The search compares with `(<)`, seeded with the dtype's upper bound `max`.
argmin : Primitive.Ord dtype => Tensor [S n] dtype -> Tag $ Tensor [] U64
argmin = argmxx (<) max
||| Which triangle (upper or lower) of a square matrix an operation refers to.
data Triangle
  = Upper
  | Lower
1609 | PrimitiveRW dtype ty =>
1613 | Tag $
Tensor [n, n] dtype
1614 | triangle tri (MkTensor x) = do
1615 | let range : Tensor [n * n] U64 = iota 0
1616 | indices <- tag $
reshape {to = [n, n], sizesEqual = productSquare n} range
1620 | pure $
select (op indices indices.T) (fill $
fromInteger 0) (MkTensor x)
1624 | productSquare : (m : Nat) -> product (the Shape [m * m]) = product (the Shape [m, m])
1628 | ~~ m * m ... plusZeroRightNeutral (m * m)
1629 | ~~ (m + 0) * m ..< cong (* m) (plusZeroRightNeutral m)
||| Cholesky decomposition of a non-empty square matrix.
||| The result of the backend `Cholesky` op is passed through
||| `triangle Lower`, so only a lower-triangular factor is returned.
||| NOTE(review): presumably requires a symmetric positive-definite input; confirm.
cholesky : Tensor [S n, S n] F64 -> Tag $ Tensor [S n, S n] F64
cholesky (MkTensor x) = triangle Lower (MkTensor (V 0 (Cholesky x)))

||| Triangular solve of `a` against a matrix right-hand side `b`, via
||| `TriangularSolve` with flag `True`.
||| NOTE(review): by naming convention this is presumably the lower-triangular
||| variant of `\|`; confirm the meaning of the flag.
(|\) : Tensor [m, m] F64 -> Tensor [m, n] F64 -> Tensor [m, n] F64
(MkTensor a) |\ (MkTensor b) = MkTensor (V 0 (TriangularSolve a b True))

||| Triangular solve of `a` against a matrix right-hand side `b`, via
||| `TriangularSolve` with flag `False`.
||| NOTE(review): presumably the upper-triangular counterpart of `|\`; confirm.
(\|) : Tensor [m, m] F64 -> Tensor [m, n] F64 -> Tensor [m, n] F64
(MkTensor a) \| (MkTensor b) = MkTensor (V 0 (TriangularSolve a b False))

||| `|\` for a vector right-hand side: `b` is treated as an `[m, 1]` column via
||| `expand 1`, solved, and squeezed back to `[m]`.
(|\) : Tensor [m, m] F64 -> Tensor [m] F64 -> Tensor [m] F64
a |\ b =
  let (MkTensor {shape = [_]} _) = b
   in squeeze (a |\ expand 1 b)

||| `\|` for a vector right-hand side: `b` is treated as an `[m, 1]` column via
||| `expand 1`, solved, and squeezed back to `[m]`.
(\|) : Tensor [m, m] F64 -> Tensor [m] F64 -> Tensor [m] F64
a \| b =
  let (MkTensor {shape = [_]} _) = b
   in squeeze (a \| expand 1 b)
1690 | trace : (Primitive.Num dtype, Prelude.Num a) =>
1692 | Tensor [S n, S n] dtype ->
1695 | _ | MkTensor {shape = [_, _]} _ = reduce @{Sum} [0, 1] $
x * identity
1701 | Rand = StateT (Tensor [2] U64) Tag
||| Draw a tensor of random `U64` bits of the requested shape, threading the
||| `[2] U64` generator state held in `Rand`.
||| Output 0 of the `Rng` node becomes the next state; output 1 is the sample
||| (`StateT` step results are `(newState, value)`).
rng : {shape : _} -> Rand $ Tensor shape U64
rng = ST $ \(MkTensor key) => do
  outputs <- tag $ Rng key (TensorType shape U64)
  let nextState = MkTensor $ V 0 outputs
      sample = MkTensor $ V 1 outputs
  pure (nextState, sample)
1730 | uniform : {shape : _} -> Rand $
Tensor shape F64
1732 | let numMantissaBits = 52
1733 | scale = broadcast $
2.0 ^ tensor (Scalar $
- cast {to = Double} numMantissaBits)
1734 | shift = fill $
64 `minus` numMantissaBits
1735 | in rng {shape} <&> \x => castDtype (x >> shift) * scale
||| Standard normal samples, obtained by pushing `uniform` samples `u` through
||| the inverse normal CDF: `sqrt 2 * erfInv (2u - 1)`.
||| NOTE(review): `uniform` computes `castDtype (x >> shift) * scale`, which is
||| exactly 0 when the shifted bits are all zero. That maps to `erfInv (-1)`,
||| i.e. `-inf`; confirm whether that rare case matters.
normal : {shape : _} -> Rand $ Tensor shape F64
normal =
  (\u => sqrt (broadcast 2.0) * erfInv (broadcast 2.0 * u - broadcast 1.0)) <$> uniform