0 | {--
   1 | Copyright (C) 2021  Joel Berkeley
   2 |
   3 | This program is free software: you can redistribute it and/or modify
   4 | it under the terms of the GNU Affero General Public License as published
   5 | by the Free Software Foundation, either version 3 of the License, or
   6 | (at your option) any later version.
   7 |
   8 | This program is distributed in the hope that it will be useful,
   9 | but WITHOUT ANY WARRANTY; without even the implied warranty of
  10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  11 | GNU Affero General Public License for more details.
  12 |
  13 | You should have received a copy of the GNU Affero General Public License
  14 | along with this program.  If not, see <https://www.gnu.org/licenses/>.
  15 | --}
  16 | ||| Defines `Tensor`, an array of numbers or booleans, along with a number of functions operating on
  17 | ||| `Tensor`s. `Tensor` operations typically leverage hardware acceleration and graph compilation.
  18 | ||| spidr tracks tensor shape and data type in the types, so you can be sure that if your tensor
  19 | ||| code compiles, these are consistent.
  20 | |||
  21 | ||| spidr achieves efficient reuse of tensor computations with `Tag`. See the tutorial
  22 | ||| _Nuisances in the Tensor API_ for a discussion of pitfalls to avoid when using `Tag`.
  23 | module Tensor
  24 |
  25 | import public Control.Monad.State
  26 | import Control.Monad.Error.Either
  27 | import Syntax.PreorderReasoning
  28 |
  29 | import Compiler.Eval
  30 | import Compiler.IR
  31 | import Compiler.Xla.Shape
  32 | import Compiler.Xla.ShapeUtil
  33 | import Compiler.LiteralRW
  34 | import Device
  35 | import public DType
  36 | import public Literal
  37 | import public Types
  38 | import public Util
  39 |
  40 | 0 XlaShape : Type
  41 | XlaShape = Xla.Shape
  42 |
  43 | %hide Xla.Shape
  44 |
  45 | ||| A scalar or array. Construct a `Tensor` with function `tensor`.
  46 | export
  47 | data Tensor : Shape -> DType -> Type where
  48 |   MkTensor : Value -> {shape : _} -> {dtype : _} -> Tensor shape dtype
  49 |
  50 | ||| The effect of tagging nodes in a computational graph.
  51 | export
  52 | data TagT : (Type -> Type) -> Type -> Type where
  53 |   MkTagT : StateT Env m a -> TagT m a
  54 |
  55 | public export 0
  56 | Tag : Type -> Type
  57 | Tag = TagT Identity
  58 |
  59 | export
  60 | Functor m => Functor (TagT m) where
  61 |   map f (MkTagT x) = MkTagT (map f x)
  62 |
  63 | export
  64 | Monad m => Applicative (TagT m) where
  65 |   pure x = MkTagT (pure x)
  66 |   (MkTagT f) <*> (MkTagT x) = MkTagT (f <*> x)
  67 |
  68 | export
  69 | Monad m => Monad (TagT m) where
  70 |   (MkTagT x) >>= f = MkTagT $ x >>= (\y => let MkTagT z = f y in z)
  71 |
  72 | export
  73 | MonadTrans TagT where
  74 |   lift = MkTagT . lift
  75 |
  76 | public export
  77 | interface Taggable a where
  78 |   ||| Mark an expression to be efficiently reused. For example, in
  79 |   ||| ```
  80 |   ||| bad : Tensor [9999999] F64
  81 |   ||| bad = let x = fill {shape = [9999999]} 1.0 in x + x
  82 |   |||
  83 |   ||| good : Tag $ Tensor [9999999] F64
  84 |   ||| good = do x <- tag $ fill {shape = [9999999]} 1.0
  85 |   |||           pure (x + x)
  86 |   ||| ```
  87 |   ||| the large vector `x` is calculated twice in `bad`, but once in `good`, as `tag` marks it for
  88 |   ||| sharing.
  89 |   |||
  90 |   ||| Types that implement this interface should `tag` constituent components it deems worth sharing.
  91 |   ||| For example, see the implementation for tuples.
  92 |   |||
  93 |   ||| See tutorial _Nuisances in the Tensor API_ for details.
  94 |   tag : Monad m => a -> TagT m a
  95 |
  96 | Taggable Op where
  97 |   -- `BoundSet` case is an optimization. Note this will mean you cannot re-bind a value
  98 |   -- to an inner scope, but I can't see why that would be useful
  99 |   tag (BoundSet x) = pure (BoundSet x)
 100 |   tag x = MkTagT $ tagOp x
 101 |
 102 | export
 103 | Taggable (Tensor shape dtype) where
 104 |   tag (MkTensor $ V idx op) = map (\op => MkTensor $ V idx op) (tag op)
 105 |
 106 | export
 107 | (Taggable a, Taggable b) => Taggable (a, b) where
 108 |   tag (a, b) = [| (tag a, tag b) |]
 109 |
 110 | ||| Construct a `Tensor` from `Literal` data. For example
 111 | ||| ```
 112 | ||| x : Tensor [2, 3] S32
 113 | ||| x = tensor [[1, 2, 3],
 114 | |||             [4, 5, 6]]
 115 | ||| ```
 116 | export
 117 | tensor : {shape : _} -> {dtype : _} -> Literal shape (idrisType dtype) -> Tensor shape dtype
 118 | tensor lit = MkTensor $ V 0 $ Lit shape dtype lit
 119 |
 120 | namespace F64
 121 |   export
 122 |   fromDouble : Double -> Tensor [] F64
 123 |   fromDouble = tensor . Scalar
 124 |
 125 | try : Show e => EitherT e IO a -> IO a
 126 | try = eitherT (\e => assert_total $ idris_crash $ show e) pure
 127 |
 128 | %hide Literal.All2.All2
 129 |
 130 | namespace List
 131 |   namespace Tag
 132 |     ||| Evaluate a list of `Tensor`s as a list of `Literal`s. Tensors in the list can have different
 133 |     ||| shapes and element types. For example,
 134 |     ||| ```
 135 |     ||| main : Device -> IO ()
 136 |     ||| main device = do [x, y] <- eval device $ do let x = tensor {dtype = F64} [1.2, 3.4]
 137 |     |||                                             y <- reduce @{Sum} [0] x
 138 |     |||                                             pure [x, y]
 139 |     |||                  printLn x
 140 |     |||                  printLn y
 141 |     ||| ```
 142 |     ||| In contrast to `Tensor.eval` when called on multiple tensors, this function constructs and
 143 |     ||| compiles the graph just once.
 144 |     export covering
 145 |     eval :
 146 |       Device ->
 147 |       Tag (All2 Tensor shapes dtypes) ->
 148 |       IO (All2 Literal shapes (DType.idrisType <$> dtypes))
 149 |     eval device (MkTagT xs) =
 150 |       let (env, xs) = runState empty xs
 151 |        in try $ do
 152 |             xlaShapes <- buildShapes xs
 153 |             let (outputs ** eq= lengthC xs
 154 |                 main = MkFn 0 [] (resultTypes xs) (results xs) env
 155 |             lits <- execute device main {outputs} (rewrite eq in xlaShapes)
 156 |             readAll xs $ rewrite sym eq in lits
 157 |
 158 |       where
 159 |
 160 |       lengthC : All2 Tensor s t -> (n ** n === length s)
 161 |       lengthC [] = (0 ** Refl)
 162 |       lengthC (_ :: xs) = let (n ** eq= lengthC xs in (S n ** cong S eq)
 163 |
 164 |       buildShapes : HasIO io => All2 Tensor s t -> io $ Vect (length s) XlaShape
 165 |       buildShapes [] = pure []
 166 |       buildShapes (MkTensor {shape, dtype} _ :: ts) = [| mkShape shape dtype :: buildShapes ts |]
 167 |
 168 |       results : All2 Tensor s t -> Vect (length s) Value
 169 |       results [] = []
 170 |       results (MkTensor x :: xs) = x :: results xs
 171 |
 172 |       resultTypes : All2 Tensor s t -> Vect (length s) ValueType
 173 |       resultTypes [] = []
 174 |       resultTypes (MkTensor {shape, dtype} _ :: xs) = TensorType shape dtype :: resultTypes xs
 175 |
 176 |       readAll :
 177 |         HasIO io =>
 178 |         All2 Tensor s t ->
 179 |         Vect (length s) Literal ->
 180 |         io $ All2 Literal s (DType.idrisType <$> t)
 181 |       readAll [] _ = pure []
 182 |       readAll (MkTensor {dtype} _ :: ts) (l :: ls) = [| read dtype l :: readAll ts ls |]
 183 |
 184 |   ||| A convenience wrapper for `List.Tag.eval`, for use with a bare list of `Tensor`s.
 185 |   export covering
 186 |   eval :
 187 |     Device ->
 188 |     All2 Tensor shapes dtypes ->
 189 |     IO (All2 Literal shapes (DType.idrisType <$> dtypes))
 190 |   eval device xs = eval device (pure xs)
 191 |
 192 | namespace Tag
 193 |   ||| Evaluate a `Tensor`, returning its value as a `Literal`. This function builds and executes the
 194 |   ||| computational graph.
 195 |   |||
 196 |   ||| **Note:** Each call to `eval` will rebuild and execute the graph; multiple calls to `eval` on
 197 |   ||| different tensors, even if they are in the same computation, will be treated independently.
 198 |   ||| To efficiently evaluate multiple tensors at once, use `List.Tag.eval`.
 199 |   export covering
 200 |   eval : Device -> Tag (Tensor shape dtype) -> IO (Literal shape (DType.idrisType dtype))
 201 |   eval device x = map (\[z] => z) $ List.Tag.eval device $ map (\z => [z]) x
 202 |
 203 | ||| A convenience wrapper for `Tag.eval`, for use with a bare `Tensor`.
 204 | export covering
 205 | eval : Device -> Tensor shape dtype -> IO (Literal shape (DType.idrisType dtype))
 206 | eval device x = eval device (pure x)
 207 |
 208 | ||| A string representation of a tensor graph.
 209 | |||
 210 | ||| There are no guarantees whatsoever as to the string structure and contents.
 211 | export
 212 | Show (Tag $ Tensor shape dtype) where
 213 |   show (MkTagT x) =
 214 |     let (env, MkTensor x) = runState empty x
 215 |      in show (MkFn 0 [] [TensorType shape dtype] [x] env)
 216 |
 217 | export
 218 | Show (Tensor shape dtype) where show = show . pure {f = Tag}
 219 |
 220 | ||| Positive infinity.
 221 | export
 222 | inf : Tensor [] F64
 223 |
 224 | ||| NaN (not a number).
 225 | export
 226 | nan : Tensor [] F64
 227 |
 228 | namespace Bound
 229 |   ||| Compares less than or equal to any other value (except NaN).
 230 |   export
 231 |   min : Ord dtype => Tensor [] dtype
 232 |   min @{OrdS32} = MkTensor $ V 0 $ MinValue S32
 233 |   min @{OrdS64} = MkTensor $ V 0 $ MinValue S64
 234 |   min @{OrdU32} = MkTensor $ V 0 $ MinValue U32
 235 |   min @{OrdU64} = MkTensor $ V 0 $ MinValue U64
 236 |   min @{OrdF64} = tensor $ Scalar $ -1.0 / 0.0
 237 |
 238 |   ||| Compares greater than or equal to any other value (except NaN).
 239 |   export
 240 |   max : Ord dtype => Tensor [] dtype
 241 |   max @{OrdS32} = MkTensor $ V 0 $ MaxValue S32
 242 |   max @{OrdS64} = MkTensor $ V 0 $ MaxValue S64
 243 |   max @{OrdU32} = MkTensor $ V 0 $ MaxValue U32
 244 |   max @{OrdU64} = MkTensor $ V 0 $ MaxValue U64
 245 |   max @{OrdF64} = tensor $ Scalar $ 1.0 / 0.0
 246 |
 247 | ||| The most negative possible finite float, approx. -1.8e308
 248 | export
 249 | minFinite : Tensor [] F64
 250 | minFinite = MkTensor $ V 0 $ MinFiniteFloat
 251 |
 252 | ||| The most positive possible finite float, approx. 1.8e308
 253 | export
 254 | maxFinite : Tensor [] F64
 255 | maxFinite = MkTensor $ V 0 $ MaxFiniteFloat
 256 |
 257 | ||| Cast the element type. For example, `castDtype (tensor {dtype = S32} [1, -2])` is
 258 | ||| `tensor {dtype = F64} [1.0, -2.0]`.
 259 | export
 260 | castDtype : Integral dtype => Tensor shape dtype -> Tensor shape F64
 261 | castDtype $ MkTensor {shape} x = MkTensor $ V 0 $ Convert F64 shape x
 262 |
 263 | fn0 : {a : _} -> Tag (Tensor s a) -> Tag $ Fn 0
 264 | fn0 f = MkTagT $ do
 265 |   addr <- reserve
 266 |
 267 |   let MkTagT res = f
 268 |       (env, MkTensor res) = runState (emptyFrom !get) res
 269 |       f = MkFn addr [] [TensorType s a] [res] env
 270 |
 271 |   updateCounterFrom env
 272 |   pure f
 273 |
 274 | fn1 : {a : _} -> {s : _} -> (Tensor s a -> Tag $ Tensor s' a') -> Tag $ Fn 1
 275 | fn1 f = MkTagT $ do
 276 |   addr <- reserve
 277 |
 278 |   let MkTagT res = f (MkTensor $ V 0 $ BoundSet addr)
 279 |       (env, MkTensor res) = runState (emptyFrom !get) res
 280 |       f = MkFn addr [TensorType s a] [TensorType s' a'] [res] env
 281 |
 282 |   updateCounterFrom env
 283 |   pure f
 284 |
 285 | fn2 :
 286 |   {a, a' : _} -> {s, s' : _} ->
 287 |   (Tensor s a -> Tensor s' a' -> Tag $ Tensor s'' a'') ->
 288 |   Tag $ Fn 2
 289 | fn2 f = MkTagT $ do
 290 |   addr <- reserve
 291 |
 292 |   let MkTagT res = f (MkTensor $ V 0 $ BoundSet addr) (MkTensor $ V 1 $ BoundSet addr)
 293 |       (env, MkTensor res) = runState (emptyFrom !get) res
 294 |       f = MkFn addr [TensorType s a, TensorType s' a'] [TensorType s'' a''] [res] env
 295 |
 296 |   updateCounterFrom env
 297 |   pure f
 298 |
 299 | fn22 :
 300 |   {a, a', a'' : _} -> {s, s' : _} ->
 301 |   (Tensor s a -> Tensor s' a' -> Tag $ (Tensor s'' a'', Tensor s''' a''')) ->
 302 |   Tag $ Fn 2
 303 | fn22 f = MkTagT $ do
 304 |   addr <- reserve
 305 |
 306 |   let MkTagT res = f (MkTensor $ V 0 $ BoundSet addr) (MkTensor $ V 1 $ BoundSet addr)
 307 |       (env, (MkTensor res0, MkTensor res1)) = runState (emptyFrom !get) res
 308 |       resTys = [TensorType s'' a'', TensorType s''' a''']
 309 |       f = MkFn addr [TensorType s a, TensorType s' a'] resTys [res0, res1] env
 310 |
 311 |   updateCounterFrom env
 312 |   pure f
 313 |
 314 | ||| Reverse-mode automatic differentiation.
 315 | |||
 316 | ||| `grad` can be applied repeatedly to obtain higher derivatives, though we do not yet support
 317 | ||| derivatives of vector-valued functions.
 318 | |||
 319 | ||| For example, for
 320 | ||| ```
 321 | ||| f : Tensor [2] F64 -> Tag $ Tensor [] F64
 322 | ||| f x = do
 323 | |||   x <- tag x
 324 | |||   let (x0, x1) = (slice [at 0] x, slice [at 1] x)
 325 | |||   pure $ x0 / x1
 326 | ||| ```
 327 | ||| `grad f (tensor [3.0, 2.0])` produces `tensor [0.5, -0.75]`.
 328 | |||
 329 | ||| **Warning:** `grad` is experimental, and only implemented for a subset of the tensor API.
 330 | export partial
 331 | grad : (Tensor shape F64 -> Tag $ Tensor [] F64) -> Tensor shape F64 -> Tag $ Tensor shape F64
 332 | grad f (MkTensor x) = pure $ MkTensor $ V 0 $ Grad shape !(fn1 f) x
 333 |
 334 | ||| Reshape a `Tensor`. For example, `reshape {to = [2, 1]} (tensor [3, 4])` is
 335 | ||| `tensor [[3], [4]]`. The output can have a different rank to the input.
 336 | export
 337 | reshape :
 338 |   {to : _} ->
 339 |   {auto 0 sizesEqual : product from = product to} ->
 340 |   Tensor from dtype ->
 341 |   Tensor to dtype
 342 | reshape $ MkTensor {shape} x = MkTensor $ V 0 $ Reshape dtype to x
 343 |
 344 | ||| Add a dimension of length one at the specified `axis`. The new dimension will be at the
 345 | ||| specified `axis` in the new `Tensor` (as opposed to the original `Tensor`). For example,
 346 | ||| `expand 1 $ tensor [[1, 2], [3, 4], [5, 6]]` is `tensor [[[1, 2]], [[3, 4]], [[5, 6]]]`.
 347 | export
 348 | expand :
 349 |   (axis : Nat) ->
 350 |   {auto 0 inBounds : axis `LTE` length shape} ->
 351 |   Tensor shape dtype ->
 352 |   Tensor (insertAt axis 1 shape) dtype
 353 | expand axis $ MkTensor {shape = _} x = MkTensor $ V 0 $ Reshape dtype (insertAt axis 1 shape) x
 354 |
 355 | namespace Squeezable
 356 |   ||| A `Squeezable from to` constitutes proof that the shape `from` can be squeezed to the
 357 |   ||| shape `to`. Squeezing is the process of removing any number of dimensions of length one.
 358 |   public export
 359 |   data Squeezable : (0 from : Shape) -> (0 to : Shape) -> Type where
 360 |     ||| Proof that a shape can be squeezed to itself. For example:
 361 |     |||
 362 |     ||| [] to []
 363 |     ||| [3, 4] to [3, 4]
 364 |     Same : Squeezable x x
 365 |
 366 |     ||| Proof that any dimensions (including those of length 1) can be preserved in the process of
 367 |     ||| squeezing. For example:
 368 |     |||
 369 |     ||| ...
 370 |     Match : Squeezable from to -> Squeezable (x :: from) (x :: to)
 371 |
 372 |     ||| Proof that any dimensions of length one can be squeezed out. For example:
 373 |     |||
 374 |     ||| [1, 3, 1, 1, 4] to [3, 4]
 375 |     Nest : Squeezable from to -> Squeezable (1 :: from) to
 376 |
 377 | ||| Remove dimensions of length one from a `Tensor` such that it has the desired shape. For example:
 378 | |||
 379 | ||| ```
 380 | ||| x : Tensor [2, 1, 3, 1] S32
 381 | ||| x = tensor [[[[4], [5], [6]]],
 382 | |||             [[[7], [8], [9]]]]
 383 | |||
 384 | ||| y : Tensor [2, 1, 3] S32
 385 | ||| y = squeeze x
 386 | ||| ```
 387 | ||| is
 388 | ||| ```
 389 | ||| y : Tensor [2, 1, 3] S32
 390 | ||| y = tensor [[[4, 5, 6]],
 391 | |||             [[7, 8, 9]]]
 392 | ||| ```
 393 | export
 394 | squeeze :
 395 |   {to : _} ->
 396 |   {auto 0 shapesSqueezable : Squeezable from to} ->
 397 |   Tensor from dtype ->
 398 |   Tensor to dtype
 399 | squeeze $ MkTensor {shape} x = MkTensor $ V 0 $ Reshape dtype to x
 400 |
 401 | ||| A `SliceOrIndex d` is a valid slice or index into a dimension of size `d`. See `slice` for
 402 | ||| details.
 403 | export
 404 | data SliceOrIndex : Nat -> Type where
 405 |   Slice :
 406 |     (from, to : Nat) ->
 407 |     {size : _} ->
 408 |     {auto 0 fromTo : from + size = to} ->
 409 |     {auto 0 inDim : LTE to d} ->
 410 |     SliceOrIndex d
 411 |   Index : (idx : Nat) -> {auto 0 inDim : LT idx d} -> SliceOrIndex d
 412 |   DynamicSlice : Tensor [] U64 -> (size : Nat) -> {auto 0 inDim : LTE size d} -> SliceOrIndex d
 413 |   DynamicIndex : Tensor [] U64 -> SliceOrIndex d
 414 |
 415 | ||| Index at `idx`. See `slice` for details.
 416 | public export
 417 | at : (idx : Nat) -> {auto 0 inDim : LT idx d} -> SliceOrIndex d
 418 | at = Index
 419 |
 420 | namespace Dynamic
 421 |   ||| Index at the specified index. See `slice` for details.
 422 |   public export
 423 |   at : Tensor [] U64 -> SliceOrIndex d
 424 |   at = DynamicIndex
 425 |
 426 | ||| Slice from `from` (inclusive) to `to` (exclusive). See `slice` for details.
 427 | public export
 428 | (.to) :
 429 |   (from, to : Nat) ->
 430 |   {size : _} ->
 431 |   {auto 0 fromTo : from + size = to} ->
 432 |   {auto 0 inDim : LTE to d} ->
 433 |   SliceOrIndex d
 434 | (.to) = Slice
 435 |
 436 | ||| Slice `size` elements starting at the specified scalar `U64` index. See `slice` for details.
 437 | public export
 438 | (.size) : Tensor [] U64 -> (size : Nat) -> {auto 0 inDim : LTE size d} -> SliceOrIndex d
 439 | (.size) = DynamicSlice
 440 |
 441 | ||| Slice across all indices along an axis. See `slice` for details.
 442 | public export
 443 | all : {d : _} -> SliceOrIndex d
 444 | all = Slice 0 @{%search} @{reflexive {ty = Nat}} d
 445 |
 446 | ||| A `MultiSlice shape` is a valid multi-dimensional slice into a tensor with shape `shape`.
 447 | ||| See `slice` for details.
 448 | public export
 449 | data MultiSlice : Shape -> Type where
 450 |   Nil : MultiSlice ds
 451 |   (::) : SliceOrIndex d -> MultiSlice ds -> MultiSlice (d :: ds)
 452 |
 453 | namespace MultiSlice
 454 |   ||| The shape of a tensor produced by slicing with the specified multi-dimensional slice. See
 455 |   ||| `Tensor.slice` for details.
 456 |   public export
 457 |   slice : {shape : _} -> MultiSlice shape -> Shape
 458 |   slice {shape} [] = shape
 459 |   slice {shape = (_ :: _)} (Slice {size} _ _ :: xs) = size :: slice xs
 460 |   slice {shape = (_ :: _)} (Index _ :: xs) = slice xs
 461 |   slice {shape = (_ :: _)} (DynamicSlice _ size :: xs) = size :: slice xs
 462 |   slice {shape = (_ :: _)} (DynamicIndex _ :: xs) = slice xs
 463 |
 464 | ||| Slice or index `Tensor` axes. Each axis can be sliced or indexed, and this can be done with
 465 | ||| either static (`Nat`) or dynamic (scalar `U64`) indices.
 466 | |||
 467 | ||| **Static indices**
 468 | |||
 469 | ||| Static indices are `Nat`s. For example, for
 470 | ||| ```
 471 | ||| x : Tensor [5, 6] S32
 472 | ||| x = tensor [[ 0,  1,  2,  3,  4,  5],
 473 | |||             [ 6,  7,  8,  9, 10, 11],
 474 | |||             [12, 13, 14, 15, 16, 17],
 475 | |||             [18, 19, 20, 21, 22, 23],
 476 | |||             [24, 25, 26, 27, 28, 29]]
 477 | ||| ```
 478 | ||| we can index as `slice [at 1] x` to get
 479 | ||| ```
 480 | ||| x : Tensor [6] S32
 481 | ||| x = tensor [6, 7, 8, 9, 10, 11]
 482 | ||| ```
 483 | ||| or we can slice as `slice [2.to 4] x` to get
 484 | ||| ```
 485 | ||| x : Tensor [2, 6] S32
 486 | ||| x = tensor [[12, 13, 14, 15, 16, 17],
 487 | |||             [18, 19, 20, 21, 22, 23]]
 488 | ||| ```
 489 | ||| Note that in `2.to 4`, the 2 is inclusive, and the 4 exclusive, so we return indices 2 and 3.
 490 | |||
 491 | ||| **Dynamic indices**
 492 | |||
 493 | ||| Dynamic indices are scalar `U64` values, and the API works slightly differently because we
 494 | ||| can't know the value of dynamic indices until the graph is executed. For indexing, with scalar
 495 | ||| `U64` index `i` in `slice [at i] x`, `i` is clamped to be a valid index into that dimension.
 496 | ||| For example, for `i = tensor 1`, `slice [at i] x` is
 497 | ||| ```
 498 | ||| x : Tensor [6] S32
 499 | ||| x = tensor [6, 7, 8, 9, 10, 11]
 500 | ||| ```
 501 | ||| as in the static case. However, for `i = tensor 10`, `slice [at i] x` returns the last row
 502 | ||| ```
 503 | ||| x : Tensor [6] S32
 504 | ||| x = tensor [24, 25, 26, 27, 28, 29]
 505 | ||| ```
 506 | ||| We can also slice by specifying a scalar `U64` start index, and a static size, as
 507 | ||| `slice [i.size 2] x` with `i = tensor 2` to get
 508 | ||| ```
 509 | ||| x : Tensor [2, 6] S32
 510 | ||| x = tensor [[12, 13, 14, 15, 16, 17],
 511 | |||             [18, 19, 20, 21, 22, 23]]
 512 | ||| ```
 513 | ||| For a given slice `size`, the dynamic start index is clamped such that we always get `size`
 514 | ||| elements along that axis. For example, `slice [i.size 2] x` with `i = tensor 4` is
 515 | ||| ```
 516 | ||| x : Tensor [2, 6] S32
 517 | ||| x = tensor [[18, 19, 20, 21, 22, 23],
 518 | |||             [24, 25, 26, 27, 28, 29]]
 519 | ||| ```
 520 | ||| which starts at index 3 rather than index 4.
 521 | |||
 522 | ||| **Mixed static, dynamic, slicing and indexing**
 523 | |||
 524 | ||| Each axis can only be sliced or indexed, and must use only static or dynamic indices. However,
 525 | ||| across axes, we can mix these four arbitrarily. For example, with `slice [2.to 4, at 1] x` to
 526 | ||| get
 527 | ||| ```
 528 | ||| x : Tensor [2] S32
 529 | ||| x = tensor [13, 19]
 530 | ||| ```
 531 | ||| or with `i = tensor 2` in `slice [at 1, i.size 2] x` to get
 532 | ||| ```
 533 | ||| x : Tensor [2] S32
 534 | ||| x = tensor [7, 8]
 535 | ||| ```
 536 | |||
 537 | ||| Slices and indices apply to the leading axes of the tensor. For trailing axes omitted from the
 538 | ||| multi-dimensional slice, the whole of the axis is returned. If we want to slice or index over
 539 | ||| later axes and retain all indices in a leading axis, we can use the convenience function `all`,
 540 | ||| as `slice [all, at 3] x` to get
 541 | ||| ```
 542 | ||| x : Tensor [5] S32
 543 | ||| x = tensor [[3], [9], [15], [21], [27]]
 544 | ||| ```
 545 | ||| This is exactly the same as the more manual `slice [0.to 5, at 3] x` and
 546 | ||| `slice [(tensor 0).size 5, at 3] x`.
 547 | |||
 548 | ||| @at The multi-dimensional slices and indices at which to slice the tensor.
 549 | export
 550 | slice : (at : MultiSlice shape) -> Tensor shape dtype -> Tensor (slice at) dtype
 551 | slice at $ MkTensor x = MkTensor $ V 0 $
 552 |   let x = V 0 $ Slice (mapd start (const 0) at) (mapd stop id at) (replicate (length shape) 1) x
 553 |       -- we shortcut DynamicSlice to allow autodiff for static slicing
 554 |       x = if isDynamic at then V 0 $ DynamicSlice (dynStarts [] at) (mapd size id at) x else x
 555 |    in Reshape dtype (MultiSlice.slice at) x
 556 |
 557 |       where
 558 |       mapd : ((Nat -> a) -> {d : Nat} -> SliceOrIndex d -> a) ->
 559 |              (Nat -> a) ->
 560 |              {shape : Shape} ->
 561 |              MultiSlice shape ->
 562 |              List a
 563 |       mapd _ dflt {shape} [] = Prelude.map dflt shape
 564 |       mapd f dflt (x :: xs) = f dflt x :: mapd f dflt xs
 565 |
 566 |       start : (Nat -> Nat) -> {d : Nat} -> SliceOrIndex d -> Nat
 567 |       start _ (Slice from _) = from
 568 |       start _ (Index idx) = idx
 569 |       start f {d} _ = f d
 570 |
 571 |       stop : (Nat -> Nat) -> {d : Nat} -> SliceOrIndex d -> Nat
 572 |       stop _ (Slice _ to) = to
 573 |       stop _ (Index idx) = S idx
 574 |       stop f {d} _ = f d
 575 |
 576 |       size : (Nat -> Nat) -> {d : Nat} -> SliceOrIndex d -> Nat
 577 |       size _ (Slice {size = size'} _ _) = size'
 578 |       size _ (Index _) = 1
 579 |       size _ (DynamicSlice _ size') = size'
 580 |       size _ (DynamicIndex _) = 1
 581 |
 582 |       zero : Value
 583 |       zero = V 0 $ Lit [] U64 $ Scalar 0
 584 |
 585 |       isDynamic : {shape : _} -> MultiSlice shape -> Bool
 586 |       isDynamic [] = False
 587 |       isDynamic {shape = (_ :: _)} (DynamicSlice _ _ :: _) = True
 588 |       isDynamic {shape = (_ :: _)} (DynamicIndex _ :: _) = True
 589 |       isDynamic (_ :: ds) = isDynamic ds
 590 |
 591 |       dynStarts : List Value -> {shape : _} -> MultiSlice shape -> List Value
 592 |       dynStarts idxs {shape} [] = replicate (length shape) zero ++ idxs
 593 |       dynStarts idxs (DynamicSlice (MkTensor i) _ :: ds) = i :: dynStarts idxs ds
 594 |       dynStarts idxs (DynamicIndex (MkTensor i) :: ds) = i :: dynStarts idxs ds
 595 |       dynStarts idxs (_ :: ds) = zero :: dynStarts idxs ds
 596 |
 597 | ||| Concatenate two `Tensor`s along the specified `axis`. For example,
 598 | ||| `concat 0 (tensor [[1, 2], [3, 4]]) (tensor [[5, 6]])` and
 599 | ||| `concat 1 (tensor [[3], [6]]) (tensor [[4, 5], [7, 8]])` are both
 600 | ||| `tensor [[1, 2], [3, 4], [5, 6]]`.
 601 | export
 602 | concat :
 603 |   (axis : Nat) ->
 604 |   Tensor s dtype ->
 605 |   Tensor s' dtype ->
 606 |   {auto 0 inBounds : (InBounds axis s, InBounds axis s')} ->
 607 |   {auto 0 shapesConcatenable : deleteAt axis s = deleteAt axis s'} ->
 608 |   Tensor (replaceAt axis (index axis s + index axis s') s) dtype
 609 | concat axis (MkTensor x) (MkTensor x') = MkTensor $ V 0 $ Concat axis [x, x']
 610 |
 611 | ||| Transpose a matrix. For example, `(tensor [[1, 2], [3, 4]]).T` is `tensor [[1, 3], [2, 4]]`.
 612 | export
 613 | (.T) : Tensor [m, n] dtype -> Tensor [n, m] dtype
 614 | (MkTensor x).T = MkTensor $ V 0 $ Transpose [1, 0] x
 615 |
 616 | ||| Transpose axes of a tensor. This is a more general version of `(.T)`, in which you can
 617 | ||| transpose any number of axes in a tensor of arbitrary rank. The i'th axis in the resulting
 618 | ||| tensor corresponds to the `index i ordering`'th axis in the input tensor. For example, for
 619 | ||| ```
 620 | ||| x : Tensor [2, 3, 4] S32
 621 | ||| x = tensor [[[ 0,  1,  2,  3],
 622 | |||              [ 4,  5,  6,  7],
 623 | |||              [ 8,  9, 10, 11]],
 624 | |||             [[12, 13, 14, 15],
 625 | |||              [16, 17, 18, 19],
 626 | |||              [20, 21, 22, 23]]]
 627 | ||| ```
 628 | ||| `transpose [0, 2, 1] x` is
 629 | ||| ```
 630 | ||| x : Tensor [2, 4, 3] S32
 631 | ||| x = tensor [[[ 0,  4,  8],
 632 | |||              [ 1,  5,  9],
 633 | |||              [ 2,  6, 10],
 634 | |||              [ 3,  7, 11]],
 635 | |||             [[12, 16, 20],
 636 | |||              [13, 17, 21],
 637 | |||              [14, 18, 22],
 638 | |||              [15, 19, 23]]]
 639 | ||| ```
 640 | ||| `transpose [2, 0, 1] x` is
 641 | ||| ```
 642 | ||| x : Tensor [4, 2, 3] S32
 643 | ||| x = tensor [[[ 0,  4,  8],
 644 | |||              [12, 16, 20]],
 645 | |||             [[ 1,  5,  9],
 646 | |||              [13, 17, 21]],
 647 | |||             [[ 2,  6, 10],
 648 | |||              [14, 18, 22]],
 649 | |||             [[ 3,  7, 11],
 650 | |||              [15, 19, 23]]]
 651 | ||| ```
 652 | |||
 653 | ||| In order to see what effect transposing a tensor has, it can help to bear in mind the following:
 654 | ||| * if an element can be found with `slice [at 3, at 4, at 5] x` in the original tensor,
 655 | |||   that same element can instead be found with `slice [at 5, at 3, at 4]` given a
 656 | |||   `transpose [2, 0, 1]`. That is, transposing axes re-orders indices when indexing.
 657 | ||| * with `transpose [2, 0, 1]`, traversing the first axis in the result is equivalent to
 658 | |||   traversing the last axis in the input. Similarly, traversing the last axis in the result is
 659 | |||   equivalent to traversing the second axis in the input.
 660 | export
 661 | transpose :
 662 |   (ordering : List Nat) ->
 663 |   Tensor shape dtype ->
 664 |   {auto 0 lengths : length ordering = length shape} ->
 665 |   {auto 0 axesUnique : unique ordering = True} ->
 666 |   {auto 0 inBounds : All (flip InBounds shape) ordering} ->
 667 |   Tensor (multiIndex ordering shape) dtype
 668 | transpose ordering $ MkTensor x = MkTensor $ V 0 $ Transpose ordering x
 669 |
 670 | ||| A `DimBroadcastable from to` proves that a dimension of size `from` can be broadcast to a
 671 | ||| dimension of size `to`.
 672 | public export
 673 | data DimBroadcastable : (0 from : Nat) -> (0 to : Nat) -> Type where
 674 |   ||| Proof that any dimension can be broadcast to itself. For example in shapes `[2, 3]` to
 675 |   ||| `[2, 3]`.
 676 |   Same : DimBroadcastable x x
 677 |
 678 |   ||| Proof that a dimension of length one can be broadcast to any size. For example in shapes
 679 |   ||| `[2, 1]` to `[2, 3]`
 680 |   Stack : DimBroadcastable 1 _
 681 |
 682 |   ||| Proof that any dimension can be broadcast to zero. For example in shapes `[2, 3]` to `[2, 0]`.
 683 |   Zero : DimBroadcastable _ 0
 684 |
 685 | namespace Broadcastable
 686 |   ||| A `Broadcastable from to` constitutes proof that the shape `from` can be broadcast to the
 687 |   ||| shape `to`.
 688 |   public export
 689 |   data Broadcastable : (0 from : Shape) -> (0 to : Shape) -> Type where
 690 |     ||| Proof that a shape can be broadcast to itself. For example:
 691 |     |||
 692 |     ||| [] to []
 693 |     ||| [3, 4] to [3, 4]
 694 |     |||
 695 |     ||| Implementation note: we could have used `Broadcast [] []`, which would have resulted in more
 696 |     ||| atomic constructors for `Broadcastable`, but the author guesses that this implementation helps
 697 |     ||| the type checker avoid applications of `Match`.
 698 |     Same : Broadcastable x x
 699 |
 700 |     ||| Proof that a dimension of size `f` can be broadcast to size `t` if these dimensions
 701 |     ||| are `DimBroadcastable f t`. For example:
 702 |     |||
 703 |     ||| [2, 3] to [2, 3]
 704 |     ||| [2, 1] to [2, 3]
 705 |     ||| [2, 1] to [2, 0]
 706 |     Match : forall from, to .
 707 |             {auto 0 ranksEq : length from = length to} ->
 708 |             {auto 0 dimBroadcastable : DimBroadcastable f t} ->
 709 |             Broadcastable from to ->
 710 |             Broadcastable (f :: from) (t :: to)
 711 |
 712 |     ||| Proof that broadcasting can add outer dimensions i.e. nesting. For example:
 713 |     |||
 714 |     ||| [3] to [1, 3]
 715 |     ||| [3] to [5, 3]
 716 |     Nest : Broadcastable f t -> Broadcastable f (_ :: t)
 717 |
 718 | ||| A shape can be extended with any number of leading dimensions.
 719 | |||
 720 | ||| @leading The leading dimensions.
 721 | export
 722 | broadcastableByLeading : (leading : List Nat) -> Broadcastable shape (leading ++ shape)
 723 | broadcastableByLeading [] = Same
 724 | broadcastableByLeading (l :: ls) = Nest (broadcastableByLeading ls)
 725 |
 726 | ||| A scalar can be broadcast to any shape.
 727 | %hint
 728 | export
 729 | scalarToAnyOk : (to : Shape) -> Broadcastable [] to
 730 | scalarToAnyOk to = rewrite sym $ appendNilRightNeutral to in broadcastableByLeading to
 731 |
 732 | ||| Broadcast a `Tensor` to a new compatible shape. For example,
 733 | ||| ```
 734 | ||| x : Tensor [2, 3] S32
 735 | ||| x = broadcast (tensor [4, 5, 6])
 736 | ||| ```
 737 | ||| is
 738 | ||| ```
 739 | ||| x : Tensor [2, 3] S32
 740 | ||| x = tensor [[4, 5, 6], [4, 5, 6]]
 741 | ||| ```
 742 | export
 743 | broadcast :
 744 |   {to : _} -> {dtype : _} ->
 745 |   {auto shapesOK : Broadcastable from to} ->
 746 |   Tensor from dtype ->
 747 |   Tensor to dtype
 748 | broadcast $ MkTensor {shape = _} x = MkTensor $ V 0 $ Broadcast dtype from to x
 749 |
 750 | ||| A `Tensor` where every element has the specified value. For example,
 751 | ||| ```
 752 | ||| fives : Tensor [2, 3] S32
 753 | ||| fives = fill 5
 754 | ||| ```
 755 | ||| is
 756 | ||| ```
 757 | ||| fives : Tensor [2, 3] S32
 758 | ||| fives = tensor [[5, 5, 5],
 759 | |||                 [5, 5, 5]]
 760 | ||| ```
 761 | export
 762 | fill : {shape : _} -> {dtype : _} -> idrisType dtype -> Tensor shape dtype
 763 | fill x = broadcast {shapesOK = scalarToAnyOk shape} (tensor (Scalar x))
 764 |
 765 | ||| A constant where values increment from zero along the specified `axis`. For example,
 766 | ||| ```
 767 | ||| x : Tensor [3, 5] S32
 768 | ||| x = iota 1
 769 | ||| ```
 770 | ||| is the same as
 771 | ||| ```
 772 | ||| x : Tensor [3, 5] S32
 773 | ||| x = tensor [[0, 1, 2, 3, 4],
 774 | |||             [0, 1, 2, 3, 4],
 775 | |||             [0, 1, 2, 3, 4]]
 776 | ||| ```
 777 | ||| and
 778 | ||| ```
 779 | ||| x : Tensor [3, 5] S32
 780 | ||| x = iota 0
 781 | ||| ```
 782 | ||| is the same as
 783 | ||| ```
 784 | ||| x : Tensor [3, 5] S32
 785 | ||| x = tensor [[0, 0, 0, 0, 0],
 786 | |||             [1, 1, 1, 1, 1],
 787 | |||             [2, 2, 2, 2, 2]]
 788 | ||| ```
 789 | export
 790 | iota :
 791 |   {shape : _} ->
 792 |   {dtype : _} ->
 793 |   {auto 0 _ : Num dtype} ->
 794 |   (axis : Nat) ->
 795 |   {auto 0 inBounds : InBounds axis shape} ->
 796 |   Tensor shape dtype
 797 | iota dimension = MkTensor $ V 0 $ Iota shape dtype dimension
 798 |
 799 | ||| A while-loop operating on a single tensor.
 800 | |||
 801 | ||| `while1` iteratively checks if the tensor satisfies `condition`, and if it does, updates it with
 802 | ||| `body`.
 803 | |||
 804 | ||| **Note:** For the XLA compiler, there are scoping restrictions on functions passed to `while1`.
 805 | ||| See the tutorial _Nuisances in the Tensor API_ for details.
 806 | |||
 807 | ||| @condition The guard condition for each iteration.
 808 | ||| @body The update step.
 809 | ||| @initial The initial tensor.
 810 | export covering
 811 | while1 :
 812 |   (condition : Tensor shape dtype -> Tag $ Tensor [] PRED) ->
 813 |   (body : Tensor shape dtype -> Tag $ Tensor shape dtype) ->
 814 |   (initial : Tensor shape dtype) ->
 815 |   Tag $ Tensor shape dtype
 816 | while1 condition body (MkTensor i0) =
 817 |   pure $ MkTensor $ V 0 $ While !(fn1 condition) !(fn1 body) [i0]
 818 |
 819 | ||| A while-loop operating on two tensors.
 820 | |||
 821 | ||| `while2` iteratively checks if the tensors together satisfy `condition`, and if so, updates them
 822 | ||| with `body`.
 823 | |||
 824 | ||| **Note:** For the XLA compiler, there are scoping restrictions on functions passed to `while2`.
 825 | ||| See the tutorial _Nuisances in the Tensor API_ for details.
 826 | |||
 827 | ||| @condition The guard condition for each iteration.
 828 | ||| @body The update step.
 829 | ||| @initial One initial tensor.
 830 | ||| @initial' The other initial tensor.
 831 | export covering
 832 | while2 :
 833 |   (condition : Tensor s a -> Tensor s' a' -> Tag $ Tensor [] PRED) ->
 834 |   (body : Tensor s a -> Tensor s' a' -> Tag (Tensor s a, Tensor s' a')) ->
 835 |   (initial : Tensor s a) -> (initial' : Tensor s' a') ->
 836 |   Tag (Tensor s a, Tensor s' a')
 837 | while2 condition body (MkTensor i) (MkTensor i') = do
 838 |   res <- tag $ While !(fn2 condition) !(fn22 body) [i, i']
 839 |   pure (MkTensor $ V 0 res, MkTensor $ V 1 res)
 840 |
 841 | ||| Lift a unary function on scalars to an element-wise function on `Tensor`s of arbitrary shape.
 842 | ||| For example,
 843 | ||| ```
 844 | ||| recip : Tensor [] F64 -> Tag $ Tensor [] F64
 845 | ||| recip x = pure $ 1.0 / x
 846 | ||| ```
 847 | ||| can be lifted to an element-wise reciprocal function as `map recip (tensor [-2, 0.4])`,
 848 | ||| which produces `tensor [-0.5, 2.5]`.
 849 | |||
 850 | ||| **Note:** For the XLA compiler, there are scoping restrictions on the function passed to `map`.
 851 | ||| See the tutorial _Nuisances in the Tensor API_ for details.
 852 | export
 853 | map : {b : _} -> (Tensor [] a -> Tag $ Tensor [] b) -> Tensor shape a -> Tag $ Tensor shape b
 854 | map f $ MkTensor {shape = _} x =
 855 |   pure $ MkTensor $ V 0 $ Map !(fn1 f) [x] (TensorType shape b) (range $ length shape)
 856 |
 857 | ||| Lift a binary function on scalars to an element-wise function on `Tensor`s of arbitrary shape.
 858 | ||| For example,
 859 | ||| ```
 860 | ||| addRecip : Tensor [] F64 -> Tensor [] F64 -> Tag $ Tensor [] F64
 861 | ||| addRecip x y = pure $ x + 1.0 / y
 862 | ||| ```
 863 | ||| can be lifted to an element-wise function as
 864 | ||| `map2 addRecip (tensor [3.0, -3.0]) (tensor [-2.0, 0.4])`, which produces `tensor [2.5, -0.5]`.
 865 | |||
 866 | ||| **Note:** For the XLA compiler, there are scoping restrictions on the function passed to `map2`.
 867 | ||| See the tutorial _Nuisances in the Tensor API_ for details.
 868 | export
 869 | map2 :
 870 |   {c : _} ->
 871 |   (Tensor [] a -> Tensor [] b -> Tag $ Tensor [] c) ->
 872 |   Tensor shape a -> Tensor shape b -> Tag $ Tensor shape c
 873 | map2 f (MkTensor {shape = _} x) (MkTensor x') =
 874 |   pure $ MkTensor $ V 0 $ Map !(fn2 f) [x, x'] (TensorType shape c) (range $ length shape)
 875 |
 876 | ||| Reduce elements along one `axis` of a `Tensor` according to a specified `reducer` `Monoid`.
 877 | ||| For example, if `x = tensor [[0, 1, 2], [3, 4, 5]]`, then reduce @{Sum} 0 x` produces
 878 | ||| `tensor [3, 5, 7]`, and `reduce @{Sum} 1 x` produces `tensor [3, 12]`.
 879 | |||
 880 | ||| **Note:** `Semigroup` doesn't use `Tag`, which limits the functions that can be used in
 881 | ||| `reduce`. However, the most commonly used semigroups don't need `Tag`, including `Sum`,
 882 | ||| `Prod`, `Min` and `Max`, so for ergonomics, we have opted to use `Monoid` as is. We can
 883 | ||| provide an overloaded variant if requested.
 884 | |||
 885 | ||| **Note:** For the XLA compiler, there are scoping restrictions on the monoid's semigroup.
 886 | ||| See the tutorial _Nuisances in the Tensor API_ for details.
 887 | |||
 888 | ||| @reducer How to reduce elements along the given `axis`.
 889 | ||| @axis The axis along which to reduce elements.
 890 | export
 891 | reduce :
 892 |   (reducer : Monoid (Tensor [] dtype)) =>
 893 |   (axes : List Nat) ->
 894 |   {auto 0 axesUnique : Sorted LT axes} ->
 895 |   {auto 0 axesInBounds : All (flip InBounds shape) axes} ->
 896 |   Tensor shape dtype ->
 897 |   Tag $ Tensor (deleteAt axes shape) dtype
 898 | reduce axes $ MkTensor x = do
 899 |   let semigroup : Monoid a -> Semigroup a
 900 |       semigroup _ = %search
 901 |
 902 |   g <- fn2 (pure .: (<+>) @{semigroup reducer})
 903 |   let MkTensor neutral' = neutral @{reducer}
 904 |   pure $ MkTensor $ V 0 $ Reduce g [neutral'] axes [x]
 905 |
 906 | ||| Sort the elements of a `Tensor` along a specified `dimension` according to a scalar-wise
 907 | ||| ordering. For sorting function `f`, elements are sorted such that for consecutive sorted
 908 | ||| elements `a` and `b`, either `f a b` is true, or `f a b` *and* `f b a` are false.
 909 | |||
 910 | ||| **Note:** Sorting is not stable, meaning elements that compare equal according the ordering may
 911 | ||| be sorted in a different order to the order they appear in the input.
 912 | |||
 913 | ||| **Note:** `sort` is limited to use comparison function without `Tag`. However, since the most
 914 | ||| commonly-used functions, including (>), (<), (>=), and (<=), don't use `Tag`, we have opted to
 915 | ||| omit it for ergonomics. We can trivially provide an overloaded variant if requested.
 916 | |||
 917 | ||| For example, for `x = tensor [[1, 6, 4], [3, 2, 5]]`, `sort (<) 0 x` produces
 918 | ||| `tensor [[1, 2, 4], [3, 6, 5]]`, while `sort (<) 1 x` produces
 919 | ||| `tensor [[1, 4, 6], [2, 3, 5]]`.
 920 | |||
 921 | ||| **Note:** For the XLA compiler, there are scoping restrictions on the function passed to `sort`.
 922 | ||| See the tutorial _Nuisances in the Tensor API_ for details.
 923 | export
 924 | sort :
 925 |   (Tensor [] dtype -> Tensor [] dtype -> Tensor [] PRED) ->
 926 |   (dimension : Nat) ->
 927 |   Tensor shape dtype ->
 928 |   {auto 0 dimInBounds : InBounds dimension shape} ->
 929 |   Tag $ Tensor shape dtype
 930 | sort comp dimension $ MkTensor x =
 931 |   pure $ MkTensor $ V 0 $ Sort !(fn2 $ pure .: comp) dimension False x
 932 |
 933 | ||| Reverse elements along the specified axes. For example, for
 934 | ||| ```
 935 | ||| x : Tensor [2, 3] S32
 936 | ||| x = tensor [[-2, -1,  0],
 937 | |||             [ 1,  2,  3]]
 938 | ||| ```
 939 | ||| `reverse [0] x` is
 940 | ||| ```
 941 | ||| x : Tensor [2, 3] S32
 942 | ||| x = tensor [[ 1,  2,  3],
 943 | |||             [-2, -1,  0]]
 944 | ||| ```
 945 | ||| `reverse [1] x` is
 946 | ||| ```
 947 | ||| x : Tensor [2, 3] S32
 948 | ||| x = tensor [[ 0, -1, -2],
 949 | |||             [ 3,  2,  1]]
 950 | ||| ```
 951 | ||| and `reverse [0, 1] x` is
 952 | ||| ```
 953 | ||| x : Tensor [2, 3] S32
 954 | ||| x = tensor [[ 3,  2,  1],
 955 | |||             [ 0, -1, -2]]
 956 | ||| ```
 957 | |||
 958 | ||| **Note:** This function requires `axes` is ordered simply so that elements are unique.
 959 | ||| The ordering itself is irrelevant to the implementation, but ensures uniqueness without using
 960 | ||| proofs of contradiction that can be difficult for Idris to construct.
 961 | export
 962 | reverse :
 963 |   (axes : List Nat) ->
 964 |   {auto 0 axesUnique : Sorted LT axes} ->
 965 |   {auto 0 axesInBounds : All (flip InBounds shape) axes} ->
 966 |   Tensor shape dtype ->
 967 |   Tensor shape dtype
 968 | reverse axes $ MkTensor x = MkTensor $ V 0 $ Reverse axes x
 969 |
 970 | ewUnary : UnaryOp -> Tensor s a -> Tensor s a
 971 | ewUnary op $ MkTensor x = MkTensor $ V 0 $ UnaryElementwise op x
 972 |
 973 | ewBinary : BinaryOp -> Tensor s a -> Tensor s a -> Tensor s a
 974 | ewBinary op (MkTensor x) (MkTensor x') = MkTensor $ V 0 $ BinaryElementwise op x x'
 975 |
 976 | ewBinary' : {out : _} -> BinaryOp -> Tensor s a -> Tensor s a -> Tensor s out
 977 | ewBinary' op (MkTensor x) (MkTensor x') = MkTensor $ V 0 $ BinaryElementwise op x x'
 978 |
 979 | ||| Element-wise equality. For example, `tensor [1, 2] /= tensor [1, 3]` is
 980 | ||| `tensor [True, False]`.
 981 | export
 982 | (==) : Eq dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape PRED
 983 | (==) = ewBinary' $ Compare Eq
 984 |
 985 | ||| Element-wise inequality. For example, `tensor [1, 2] /= tensor [1, 3]` is
 986 | ||| `tensor [False, True]`.
 987 | export
 988 | (/=) : Eq dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape PRED
 989 | (/=) = ewBinary' $ Compare Ne
 990 |
 991 | ||| Element-wise less than. For example, `tensor [1, 2, 3] < tensor [2, 2, 2]` is
 992 | ||| `tensor [True, False, False]`.
 993 | export
 994 | (<) : Ord dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape PRED
 995 | (<) = ewBinary' $ Compare Lt
 996 |
 997 | ||| Element-wise greater than. For example, `tensor [1, 2, 3] > tensor [2, 2, 2]` is
 998 | ||| `tensor [False, False, True]`.
 999 | export
1000 | (>) : Ord dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape PRED
1001 | (>) = ewBinary' $ Compare Gt
1002 |
1003 | ||| Element-wise less than or equal. For example, `tensor [1, 2, 3] <= tensor [2, 2, 2]`
1004 | ||| is `tensor [True, True, False]`.
1005 | export
1006 | (<=) : Ord dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape PRED
1007 | (<=) = ewBinary' $ Compare Le
1008 |
1009 | ||| Element-wise greater than or equal. For example,
1010 | ||| `tensor [1, 2, 3] >= tensor [2, 2, 2]` is `tensor [False, True, True]`.
1011 | export
1012 | (>=) : Ord dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape PRED
1013 | (>=) = ewBinary' $ Compare Ge
1014 |
1015 | ||| Element-wise boolean and. For example,
1016 | ||| `tensor [True, True, False, False] && tensor [True, False, True, False]` is
1017 | ||| `tensor [True, False, False, False]`.
1018 | export
1019 | (&&) : Tensor shape PRED -> Tensor shape PRED -> Tensor shape PRED
1020 | (&&) = ewBinary And
1021 |
1022 | namespace Semigroup
1023 |   export
1024 |   [All] Semigroup (Tensor shape PRED) where
1025 |     (<+>) = (&&)
1026 |
1027 | namespace Monoid
1028 |   export
1029 |   [All] {shape : _} -> Monoid (Tensor shape PRED) using Tensor.Semigroup.All where
1030 |     neutral = fill True
1031 |
1032 | ||| Element-wise boolean or. For example,
1033 | ||| `tensor [True, True, False, False] || tensor [True, False, True, False]` is
1034 | ||| `tensor [True, True, True, False]`.
1035 | export
1036 | (||) : Tensor shape PRED -> Tensor shape PRED -> Tensor shape PRED
1037 | (||) = ewBinary Or
1038 |
1039 | namespace Semigroup
1040 |   export
1041 |   [Any] Semigroup (Tensor shape PRED) where
1042 |     (<+>) = (||)
1043 |
1044 | namespace Monoid
1045 |   export
1046 |   [Any] {shape : _} -> Monoid (Tensor shape PRED) using Tensor.Semigroup.Any where
1047 |     neutral = fill False
1048 |
1049 | ||| Element-wise boolean negation. For example, `not (tensor [True, False])` is
1050 | ||| `tensor [False, True]`.
1051 | export
1052 | not : Tensor shape PRED -> Tensor shape PRED
1053 | not = ewUnary Not
1054 |
1055 | ||| Choose elements from two `Tensor`s based on a `Tensor` of predicates. For each element in the
1056 | ||| predicates, the output will use the corresponding element from `onTrue` if the element is
1057 | ||| truthy, else the element from `onFalse`. For example, for
1058 | ||| ```
1059 | ||| preds : Tensor [3] PRED
1060 | ||| preds = tensor [False, True, False]
1061 | |||
1062 | ||| onTrue : Tensor [3] S32
1063 | ||| onTrue = tensor [1, 2, 3]
1064 | |||
1065 | ||| onFalse : Tensor [3] S32
1066 | ||| onFalse = tensor [4, 5, 6]
1067 | ||| ```
1068 | ||| `select preds onTrue onFalse` is `tensor [4, 2, 6]`.
1069 | |||
1070 | ||| @onTrue The elements to choose where the predicate elements are truthy.
1071 | ||| @onFalse The elements to choose where the predicate elements are falsy.
1072 | export
1073 | select :
1074 |   Tensor shape PRED ->
1075 |   (onTrue, onFalse : Tensor shape dtype) ->
1076 |   Tensor shape dtype
1077 | select (MkTensor p) (MkTensor t) (MkTensor f) = MkTensor $ V 0 $ Select p t f
1078 |
1079 | ||| Use a scalar predicate to evaluate one of two branches. If the predicate is truthy,
1080 | ||| evaluate `onTrue`, else `onFalse`. Each branch is evaluated lazily; only one will be
1081 | ||| evaluated.
1082 | |||
1083 | ||| For example, for
1084 | ||| ```
1085 | ||| f : Tensor [] F64 -> Tag $ Tensor [] F64
1086 | ||| f x = if_ (x < 1.0) (pure $ cos x) (do x <- tag x; x * x)
1087 | ||| ```
1088 | ||| `f 0.0` produces `1.0`, and `f 4.0` produces `8.0`.
1089 | |||
1090 | ||| **Note:** For the XLA compiler, there are scoping restrictions on branches used in `if_`, as
1091 | ||| StableHLO gives branches their own scope. See the tutorial _Nuisances in the Tensor API_
1092 | ||| for details.
1093 | |||
1094 | ||| @onTrue The branch to evaluate if the predicate is truthy.
1095 | ||| @onFalse The branch to evaluate if the predicate is falsy.
1096 | export
1097 | if_ :
1098 |   {shape : _} -> {dtype : _} ->
1099 |   Tensor [] PRED ->
1100 |   (onTrue, onFalse : Tag $ Tensor shape dtype) ->
1101 |   Tag $ Tensor shape dtype
1102 | if_ (MkTensor pred) onTrue onFalse =
1103 |   pure $ MkTensor $ V 0 $ If (TensorType shape dtype) pred !(fn0 onTrue) !(fn0 onFalse)
1104 |
1105 | ||| The identity tensor, with inferred shape and element type. For example,
1106 | ||| ```
1107 | ||| x : Tensor [2, 2] S32
1108 | ||| x = identity
1109 | ||| ```
1110 | ||| is
1111 | ||| ```
1112 | ||| x : Tensor [2, 2] S32
1113 | ||| x = tensor [[1, 0],
1114 | |||             [0, 1]]
1115 | ||| ```
1116 | export
1117 | identity : {n : _} -> {dtype : _} -> Num dtype => Tensor [n, n] dtype
1118 | identity =
1119 |   let MkTensor x = iota 0 {shape = [n, n], dtype = U64} == iota 1
1120 |    in MkTensor $ V 0 $ Convert dtype [n, n] x
1121 |
1122 | -- see https://www.python.org/dev/peps/pep-0465/#precedence-and-associativity
1123 | export infixl 9 @@
1124 |
1125 | namespace Vector
1126 |   ||| Vector dot product with a tensor of any rank. The vector dot product is with the first axis of
1127 |   ||| the right-hand side tensor. For example `tensor [0, 1, 2] @@ tensor [-1, -3, -1]` is
1128 |   ||| `-1`.
1129 |   export
1130 |   (@@) : Num dtype => Tensor [S m] dtype -> Tensor [S m] dtype -> Tensor [] dtype
1131 |   (MkTensor x) @@ (MkTensor x') =
1132 |     MkTensor $ V 0 $ DotGeneral [] [] [0] [0] (TensorType [] dtype) x x'
1133 |
1134 | namespace Matrix
1135 |   ||| Matrix multiplication with a matrix or vector. Contraction is along the last axis of the first
1136 |   ||| and the first axis of the last. For example,
1137 |   ||| ```
1138 |   ||| x : Tensor [2, 3] S32
1139 |   ||| x = tensor [[-1, -2, -3],
1140 |   |||             [ 0,  1,  2]]
1141 |   |||
1142 |   ||| y : Tensor [3, 1] S32
1143 |   ||| y = tensor [[4, 0, 5]]
1144 |   |||
1145 |   ||| z : Tensor [2, 1] S32
1146 |   ||| z = x @@ y
1147 |   ||| ```
1148 |   ||| is
1149 |   ||| ```
1150 |   ||| z : Tensor [2, 1] S32
1151 |   ||| z = tensor [-19, 10]
1152 |   ||| ```
1153 |   export
1154 |   (@@) : Num dtype =>
1155 |          Tensor [n, S m] dtype ->
1156 |          Tensor (S m :: tl) dtype ->
1157 |          {auto 0 vectorTail : length tl `LTE` 1} ->
1158 |          Tensor (n :: tl) dtype
1159 |   (MkTensor x) @@ (MkTensor x') =
1160 |     MkTensor $ V 0 $ DotGeneral [] [] [1] [0] (TensorType (n :: tl) dtype) x x'
1161 |
1162 | ||| The output shape of a `dotGeneral` operation.
1163 | public export
1164 | contract : (lBatch, rBatch, lContract, rContract : List Nat) ->
1165 |            (ls, rs : Shape) ->
1166 |            {auto 0 lInBoundsBatch : All (flip InBounds ls) lBatch} ->
1167 |            {auto 0 rInBoundsBatch : All (flip InBounds rs) rBatch} ->
1168 |            {auto 0 lInBoundsContract : All (flip InBounds ls) lContract} ->
1169 |            {auto 0 rInBoundsContract : All (flip InBounds rs) rContract} ->
1170 |            Shape
1171 | contract lBatch rBatch lContract rContract ls rs =
1172 |   let lResultDims = deleteAt {inBounds = lInBoundsBatch ++ lInBoundsContract}
1173 |                              (lBatch ++ lContract) ls
1174 |       rResultDims = deleteAt {inBounds = rInBoundsBatch ++ rInBoundsContract}
1175 |                              (rBatch ++ rContract) rs
1176 |    in multiIndex lBatch ls ++ lResultDims ++ rResultDims
1177 |
1178 | ||| Matrix multiplication.
1179 | |||
1180 | ||| This is a much more general version of `(@@)`, in which you can specify any number of batch
1181 | ||| and contracting axes. Matrix multiplication is done over each contracting axis.
1182 | ||| The operation is vectorized over batch axes. For each contracting axis on the left-hand
1183 | ||| operand, there is one contracting axis on the right-hand operand. These can be different axes
1184 | ||| in each operand. The same is true for each batch axis.
1185 | |||
1186 | ||| For example, we can vectorize over a typical rank-two matrix multiplication as follows: given
1187 | ||| two inputs tensors
1188 | ||| ```
1189 | ||| let x : Tensor [3, 4, 5, 6] F64
1190 | |||     y : Tensor [3, 4, 6, 7] F64
1191 | ||| ```
1192 | ||| we do
1193 | ||| ```
1194 | ||| let z : Tensor [3, 4, 5, 7] F64 = dotGeneral [0, 1] [0, 1] [3] [2] x y
1195 | ||| ```
1196 | ||| Here, we vectorized over the first two axes `[0, 1]`, and do standard matrix multiplication
1197 | ||| over the remaining axes by specifying the axes 3 and 2 respectively as contracting axes. Notice
1198 | ||| how the batch axes appear once each at the start of the output shape, and the contracting axis
1199 | ||| disappears. Remaining axes appear in order from left to right.
1200 | |||
1201 | ||| Note this API is somewhat of a quickfix to bring general matrix multiplication to the tensor
1202 | |||   API. It is not thoroughly tested. Expect it to change in the future.
1203 | export
1204 | dotGeneral :
1205 |   Num dtype =>
1206 |   (lBatch, rBatch, lContract, rContract : List Nat) ->
1207 |   {auto 0 lUnique : unique (lBatch ++ lContract) = True} ->
1208 |   {auto 0 rUnique : unique (rBatch ++ rContract) = True} ->
1209 |   {auto 0 lInBoundsBatch : All (flip InBounds ls) lBatch} ->
1210 |   {auto 0 rInBoundsBatch : All (flip InBounds rs) rBatch} ->
1211 |   {auto 0 lInBoundsContract : All (flip InBounds ls) lContract} ->
1212 |   {auto 0 rInBoundsContract : All (flip InBounds rs) rContract} ->
1213 |   {auto 0 batchDimsEq : multiIndex lBatch ls = multiIndex rBatch rs} ->
1214 |   {auto 0 contractDimsEq : multiIndex lContract ls = multiIndex rContract rs} ->
1215 |   Tensor ls dtype ->
1216 |   Tensor rs dtype ->
1217 |   Tensor (contract lBatch rBatch lContract rContract ls rs) dtype
1218 | dotGeneral lb rb lc rc (MkTensor x) (MkTensor y) =
1219 |   let resultType = TensorType (contract lb rb lc rc ls rs) dtype
1220 |    in MkTensor $ V 0 $ DotGeneral lb rb lc rc resultType x y
1221 |
1222 | ||| Element-wise addition. For example, `tensor [1, 2] + tensor [3, 4]` is
1223 | ||| `tensor [4, 6]`.
1224 | export
1225 | (+) : Num dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape dtype
1226 | (+) = ewBinary Add
1227 |
1228 | namespace Semigroup
1229 |   export
1230 |   [Sum] Num dtype => Semigroup (Tensor shape dtype) where
1231 |     (<+>) = (+)
1232 |
1233 | namespace Monoid
1234 |   export
1235 |   [Sum] {shape : _} -> {dtype : _} -> Prelude.Num (idrisType dtype) => Num dtype =>
1236 |     Monoid (Tensor shape dtype) using Semigroup.Sum where
1237 |       neutral = fill 0
1238 |
1239 | ||| Element-wise negation. For example, `- tensor [1, -2]` is `tensor [-1, 2]`.
1240 | export
1241 | negate : Neg dtype => Tensor shape dtype -> Tensor shape dtype
1242 | negate $ MkTensor i = MkTensor $ V 0 $ UnaryElementwise Neg i
1243 |
1244 | ||| Element-wise subtraction. For example, `tensor [3, 4] - tensor [4, 2]` is
1245 | ||| `tensor [-1, 2]`.
1246 | export
1247 | (-) : Neg dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape dtype
1248 | (-) = ewBinary Sub
1249 |
1250 | ||| Element-wise multiplication. For example, `tensor [2, 3] * tensor [4, 5]` is
1251 | ||| `tensor [8, 15]`.
1252 | export
1253 | (*) : Num dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape dtype
1254 | (*) = ewBinary Mul
1255 |
1256 | namespace Scalarwise
1257 |   ||| Multiplication by a scalar. For example, `tensor 2 * tensor [3, 5]` is
1258 |   ||| `tensor [6, 10]`.
1259 |   |||
1260 |   ||| The RHS is required to be non-scalar simply to avoid ambiguities with element-wise `(*)`.
1261 |   export
1262 |   (*) : Num dtype => Tensor [] dtype -> Tensor (d :: ds) dtype -> Tensor (d :: ds) dtype
1263 |   l * r =
1264 |     let MkTensor {shape = _ :: _} _ = r
1265 |      in broadcast {shapesOK = scalarToAnyOk (d :: ds)} l * r
1266 |
1267 | namespace Semigroup
1268 |   export
1269 |   [Prod] Num dtype => Semigroup (Tensor shape dtype) where
1270 |     (<+>) = (*)
1271 |
1272 | namespace Monoid
1273 |   export
1274 |   [Prod] {shape : _} -> {dtype : _} -> Prelude.Num (idrisType dtype) => Num dtype =>
1275 |     Monoid (Tensor shape dtype) using Semigroup.Prod where
1276 |       neutral = fill 1
1277 |
1278 | ||| Element-wise floating point division. For example, `tensor [2, 3] / tensor [4, 5]` is
1279 | ||| `tensor [0.5, 0.6]`.
1280 | export
1281 | (/) : Fractional dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape dtype
1282 | (/) = ewBinary Div
1283 |
1284 | namespace Scalarwise
1285 |   ||| Floating point division by a scalar. For example, `tensor [3.4, -5.6] / tensor 2` is
1286 |   ||| `tensor [1.7, -2.8]`.
1287 |   |||
1288 |   ||| The LHS is required to be non-scalar simply to avoid ambiguities with element-wise `(/)`.
1289 |   export
1290 |   (/) : Fractional dtype => Tensor (d :: ds) dtype -> Tensor [] dtype -> Tensor (d :: ds) dtype
1291 |   l / r =
1292 |     let MkTensor {shape = _ :: _} _ = l
1293 |      in l / broadcast {shapesOK = scalarToAnyOk (d :: ds)} r
1294 |
1295 | inf = tensor 1.0 / tensor 0.0
1296 | nan = tensor 0.0 / tensor 0.0
1297 |
1298 | ||| Element-wise division of natural numbers. For example,
1299 | ||| `div (tensor [13, 8]) [3, 4]` is `tensor [4, 2]`.
1300 | export
1301 | div : Tensor shape U64 ->
1302 |       (denom : Literal shape Nat) ->
1303 |       {auto 0 isSucc : All IsSucc denom} ->
1304 |       Tensor shape U64
1305 | div x y with (x)
1306 |   _ | (MkTensor {shape = _} _) = ewBinary Div x (tensor {dtype = U64} $ cast <$> y)
1307 |
1308 | namespace Integral
1309 |   ||| Element-wise remainder for natural numbers. For example,
1310 |   ||| `rem (tensor [13, 8]) [3, 4]` is `tensor [1, 0]`.
1311 |   export
1312 |   rem : Tensor shape U64 ->
1313 |         (denom : Literal shape Nat) ->
1314 |         {auto 0 isSucc : All IsSucc denom} ->
1315 |         Tensor shape U64
1316 |   rem x y with (x)
1317 |     _ | (MkTensor {shape = _} _) = ewBinary Rem x (tensor {dtype = U64} $ cast <$> y)
1318 |
1319 | export infixr 9 ^
1320 |
1321 | ||| Each element in `base` raised to the power of the corresponding element in `exponent`.
1322 | ||| example, `tensor [2, 25, -9] ^ tensor [3, -0.5, 0.5]` is `tensor [8, 0.2, nan]`.
1323 | |||
1324 | ||| Note: The behaviour of this function is not well-defined at negative or positive infinity, or
1325 | |||   NaN.
1326 | |||
1327 | ||| Note: The first root is used.
1328 | export
1329 | (^) : Tensor shape F64 -> Tensor shape F64 -> Tensor shape F64
1330 | (^) = ewBinary Pow
1331 |
1332 | (>>) : Tensor shape U64 -> Tensor shape U64 -> Tensor shape U64
1333 | (>>) = ewBinary ShiftRightLogical
1334 |
1335 | ||| Element-wise absolute value. For example, `abs (tensor [-2, 3])` is `tensor [2, 3]`.
1336 | export
1337 | abs : Abs dtype => Tensor shape dtype -> Tensor shape dtype
1338 | abs = ewUnary Abs
1339 |
1340 | ||| The element-wise natural exponential. For example, `exp (tensor [-1, 0, 2])` is
1341 | ||| `tensor [1 / euler, 1, pow euler 2]`.
1342 | export
1343 | exp : Tensor shape F64 -> Tensor shape F64
1344 | exp = ewUnary Exp
1345 |
1346 | ||| The element-wise floor function. For example,
1347 | ||| `floor (tensor [-1.6, -1.5, -1.4, -1.0, 1.0, 1.4, 1.5, 1.6])` is
1348 | ||| `tensor [-2.0, -2.0, -2.0, -1.0, 1.0, 1.0, 1.0, 1.0]`.
1349 | export
1350 | floor : Tensor shape F64 -> Tensor shape F64
1351 | floor = ewUnary Floor
1352 |
1353 | ||| The element-wise ceiling function. For example,
1354 | ||| `ceil (tensor [-1.6, -1.5, -1.4, -1.0, 1.0, 1.4, 1.5, 1.6])` is
1355 | ||| `tensor [-1.0, -1.0, -1.0, -1.0, 1.0, 2.0, 2.0, 2.0]`.
1356 | export
1357 | ceil : Tensor shape F64 -> Tensor shape F64
1358 | ceil = ewUnary Ceil
1359 |
1360 | ||| The element-wise natural logarithm. Negative inputs yield NaN output. For example,
1361 | ||| `log (tensor [1 / euler, 1, euler * euler])` is `tensor [-1, 0, 2]`.
1362 | export
1363 | log : Tensor shape F64 -> Tensor shape F64
1364 | log = ewUnary Log
1365 |
1366 | ||| The element-wise logistic function equivalent to `1 / 1 + exp (-x)`.
1367 | export
1368 | logistic : Tensor shape F64 -> Tensor shape F64
1369 | logistic = ewUnary Logistic
1370 |
1371 | ||| The element-wise sine.
1372 | export
1373 | sin : Tensor shape F64 -> Tensor shape F64
1374 | sin = ewUnary Sin
1375 |
1376 | ||| The element-wise cosine.
1377 | export
1378 | cos : Tensor shape F64 -> Tensor shape F64
1379 | cos = ewUnary Cos
1380 |
1381 | ||| The element-wise tangent.
1382 | export
1383 | tan : Tensor shape F64 -> Tensor shape F64
1384 | tan = ewUnary Tan
1385 |
1386 | ||| The element-wise inverse sine.
1387 | export
1388 | asin : Tensor shape F64 -> Tensor shape F64
1389 | asin = ewUnary Asin
1390 |
1391 | ||| The element-wise inverse cosine.
1392 | export
1393 | acos : Tensor shape F64 -> Tensor shape F64
1394 | acos = ewUnary Acos
1395 |
1396 | ||| The element-wise inverse tangent.
1397 | export
1398 | atan : Tensor shape F64 -> Tensor shape F64
1399 | atan = ewUnary Atan
1400 |
1401 | ||| The element-wise hyperbolic sine.
1402 | export
1403 | sinh : Tensor shape F64 -> Tensor shape F64
1404 | sinh = ewUnary Sinh
1405 |
1406 | ||| The element-wise hyperbolic cosine.
1407 | export
1408 | cosh : Tensor shape F64 -> Tensor shape F64
1409 | cosh = ewUnary Cosh
1410 |
1411 | ||| The element-wise hyperbolic tangent.
1412 | export
1413 | tanh : Tensor shape F64 -> Tensor shape F64
1414 | tanh = ewUnary Tanh
1415 |
1416 | ||| The element-wise inverse hyperbolic sine.
1417 | export
1418 | asinh : Tensor shape F64 -> Tensor shape F64
1419 | asinh = ewUnary Asinh
1420 |
1421 | ||| The element-wise inverse hyperbolic cosine.
1422 | export
1423 | acosh : Tensor shape F64 -> Tensor shape F64
1424 | acosh = ewUnary Acosh
1425 |
1426 | ||| The element-wise inverse hyperbolic tangent.
1427 | export
1428 | atanh : Tensor shape F64 -> Tensor shape F64
1429 | atanh = ewUnary Atanh
1430 |
1431 | ||| An approximation to the element-wise error function.
1432 | export
1433 | erf : Tensor shape F64 -> Tensor shape F64
1434 | erf = ewUnary Erf
1435 |
1436 | erfInv : Tensor shape F64 -> Tensor shape F64
1437 | erfInv = ewUnary ErfInv
1438 |
1439 | ||| The element-wise square. For example, `square (tensor [-2, 0, 3])`
1440 | ||| is `tensor [4, 0, 9]`.
1441 | export
1442 | square : Tensor shape F64 -> Tensor shape F64
1443 | square = ewUnary Square
1444 |
1445 | ||| The element-wise square root. The first root is used. Negative inputs yield NaN output.
1446 | ||| For example, `sqrt (tensor [0, 9])` is `tensor [0, 3]`.
1447 | export
1448 | sqrt : Tensor shape F64 -> Tensor shape F64
1449 | sqrt = ewUnary Sqrt
1450 |
1451 | ||| The element-wise minimum of the first argument compared to the second. For example,
1452 | ||| `min (tensor [-3, -1, 3]) (tensor [-1, 0, 1])` is `tensor [-3, -1, 1]`.
1453 | export
1454 | min : Ord dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape dtype
1455 | min (MkTensor x) (MkTensor x') = MkTensor $ V 0 $ BinaryElementwise Min x x'
1456 |
1457 | namespace Semigroup
1458 |   export
1459 |   [Min] {shape : _} -> Ord dtype => Semigroup (Tensor shape dtype) where
1460 |     (<+>) = min
1461 |
1462 | namespace Monoid
1463 |   export
1464 |   [Min] {shape : _} -> {dtype : _} -> Ord dtype =>
1465 |     Monoid (Tensor shape dtype) using Semigroup.Min where
1466 |       neutral = broadcast max
1467 |
1468 | ||| The element-wise maximum of the first argument compared to the second. For example,
1469 | ||| `max (tensor [-3, -1, 3]) (tensor [-1, 0, 1])` is `tensor [-1, 0, 3]`.
1470 | export
1471 | max : Ord dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape dtype
1472 | max (MkTensor x) (MkTensor x') = MkTensor $ V 0 $ BinaryElementwise Max x x'
1473 |
1474 | namespace Semigroup
1475 |   export
1476 |   [Max] Ord dtype => Semigroup (Tensor shape dtype) where
1477 |     (<+>) = max
1478 |
1479 | namespace Monoid
1480 |   export
1481 |   [Max] {shape : _} -> {dtype : _} -> Ord dtype =>
1482 |     Monoid (Tensor shape dtype) using Semigroup.Max where
1483 |       neutral = broadcast min
1484 |
1485 | ||| The diagonal of a matrix as a vector. For example, for
1486 | ||| ```
1487 | ||| x : Tensor [3, 3] S32
1488 | ||| x = tensor [[0, 1, 2],
1489 | |||             [3, 4, 5],
1490 | |||             [6, 7, 8]]
1491 | ||| ```
1492 | ||| `diag x` is `tensor [0, 4, 8]`.
1493 | export
1494 | diag : Num dtype => Prelude.Num (idrisType dtype) => Tensor [n, n] dtype -> Tensor [n] dtype
1495 | diag {n = 0} x@(MkTensor {shape = [0, 0]} _) = reshape x
1496 | diag {n = S n} x@(MkTensor {shape = [S n, S n]} _) = (x * identity) @@ fill 1
1497 |
1498 | argmxx :
1499 |   Ord dtype =>
1500 |   (Tensor [] dtype -> Tensor [] dtype -> Tensor [] PRED) ->
1501 |   Tensor [] dtype ->
1502 |   Tensor [S n] dtype ->
1503 |   Tag $ Tensor [] U64
1504 | argmxx cmp (MkTensor bound) x@(MkTensor {shape = _} _) = do
1505 |   let MkTensor idxs : Tensor [S n] U64 = iota 0
1506 |       MkTensor x = x
1507 |       MkTensor zero = tensor {dtype = U64} 0
1508 |
1509 |       mon :
1510 |         Tensor [] dtype -> Tensor [] U64 ->
1511 |         Tensor [] dtype -> Tensor [] U64 ->
1512 |         Tag (Tensor [] dtype, Tensor [] U64)
1513 |       mon x y x' y' =
1514 |         let useNext = x == x && (cmp x' x || x' /= x')
1515 |          in pure (select useNext x' x, select useNext y' y)
1516 |
1517 |   MkTagT $ do
1518 |     addr <- reserve
1519 |
1520 |     let MkTagT res = mon
1521 |           (MkTensor $ V 0 $ BoundSet addr)
1522 |           (MkTensor $ V 1 $ BoundSet addr)
1523 |           (MkTensor $ V 2 $ BoundSet addr)
1524 |           (MkTensor $ V 3 $ BoundSet addr)
1525 |         (env, (MkTensor m, MkTensor i)) = runState (emptyFrom !get) res
1526 |         argTys = [TensorType [] dtype, TensorType [] U64]
1527 |         f = MkFn addr (argTys ++ argTys) argTys [m, i] env
1528 |
1529 |     updateCounterFrom env
1530 |     pure $ MkTensor $ V 1 $ Reduce f [bound, zero] [0] [x, idxs]
1531 |
1532 | ||| The first index of the maximum value in a vector. For example,
1533 | ||| `argmax (tensor [-1, 3, -2, -2, 3])` produces `tensor 1`. If the vector contains NaN values,
1534 | ||| `argmax` returns the index of the first NaN.
1535 | export
1536 | argmax : Ord dtype => Tensor [S n] dtype -> Tag $ Tensor [] U64
1537 | argmax x@(MkTensor {dtype} _) = argmxx (>) min x
1538 |
1539 | ||| The first index of the minimum value in a vector. For example,
1540 | ||| `argmin (tensor [-1, 3, -2, -2, 3])` produces `tensor 2`. If the vector contains NaN values,
1541 | ||| `argmin` returns the index of the first NaN.
1542 | export
1543 | argmin : Ord dtype => Tensor [S n] dtype -> Tag $ Tensor [] U64
1544 | argmin x@(MkTensor {dtype} _) = argmxx (<) max x
1545 |
1546 | ||| Represents the upper- or lower-triangular component of a matrix.
1547 | public export
1548 | data Triangle = Upper | Lower
1549 |
1550 | ||| Get the upper- or lower-triangular component of a matrix, always including the matrix diagonal.
1551 | ||| Remaining elements will be zero. For example, for
1552 | ||| ```
1553 | ||| x : Tensor [3, 3] S32
1554 | ||| x = tensor [[1, 2, 3],
1555 | |||             [4, 5, 6],
1556 | |||             [7, 8, 9]]
1557 | ||| ```
1558 | ||| `triangle Lower x` produces
1559 | ||| ```
1560 | ||| x : Tensor [3, 3] S32
1561 | ||| x = tensor [[1, 0, 0],
1562 | |||             [4, 5, 0],
1563 | |||             [7, 8, 9]]
1564 | ||| ```
1565 | export
1566 | triangle :
1567 |   Prelude.Num (idrisType dtype) =>
1568 |   Triangle ->
1569 |   Tensor [n, n] dtype ->
1570 |   Tag $ Tensor [n, n] dtype
1571 | triangle tri (MkTensor x) = do
1572 |   let range : Tensor [n * n] U64 = iota 0
1573 |   indices <- tag $ reshape {to = [n, n], sizesEqual = productSquare n} range
1574 |   let op = case tri of
1575 |         Upper => Tensor.(>)
1576 |         Lower => Tensor.(<)
1577 |   pure $ select (op indices indices.T) (fill $ fromInteger 0) (MkTensor x)
1578 |
1579 |   where
1580 |
1581 |   productSquare : (m : Nat) -> product (the Shape [m * m]) = product (the Shape [m, m])
1582 |   productSquare m =
1583 |     Calc $
1584 |       |~ (m * m) + 0
1585 |       ~~ m * m ... plusZeroRightNeutral (m * m)
1586 |       ~~ (m + 0) * m ..< cong (* m) (plusZeroRightNeutral m)
1587 |
1588 | ||| Cholesky decomposition. Computes the lower triangular matrix `L` from the symmetric, positive
1589 | ||| semi-definite matrix `X` s.t. `X = L @@ L.T`. Values will be NaN if the input matrix is not
1590 | ||| positive semi-definite. The remaining matrix components - those not in the lower triangle or
1591 | ||| diagonal - will always be zero.
1592 | export
1593 | cholesky : Tensor [S n, S n] F64 -> Tag $ Tensor [S n, S n] F64
1594 | cholesky $ MkTensor x = triangle Lower (MkTensor $ V 0 $ Cholesky x)
1595 |
1596 | export infix 9 |\, \|
1597 |
1598 | namespace Matrix
1599 |   ||| Solve the set of linear equations `a @@ x = b` for `x` where `a` is a lower-triangular matrix.
1600 |   ||| `a` is given by the lower-triangular elements of the first argument. Values in the
1601 |   ||| upper-triangular part are ignored. If `a` is lower-triangular already,
1602 |   ||| this is written `a |\ b`.
1603 |   |||
1604 |   ||| The operator is shaped like the lower-triangular portion of a matrix to signal that it uses
1605 |   ||| this portion of its argument. This is in contrast to `(\|)`.
1606 |   export
1607 |   (|\) : Tensor [m, m] F64 -> Tensor [m, n] F64 -> Tensor [m, n] F64
1608 |   (MkTensor a) |\ (MkTensor b) = MkTensor $ V 0 $ TriangularSolve a b True
1609 |
1610 |   ||| Solve the set of linear equations `a @@ x = b` for `x` where `a` is an upper-triangular
1611 |   ||| matrix. `a` is given by the upper-triangular elements of the first argument. Values in the
1612 |   ||| lower-triangular part are ignored. If `a` is upper-triangular already, this is written
1613 |   ||| `a \| b`.
1614 |   |||
1615 |   ||| The operator is shaped like the upper-triangular portion of a matrix to signal that it uses
1616 |   ||| this portion of its argument. This is in contrast to `(|\)`.
1617 |   export
1618 |   (\|) : Tensor [m, m] F64 -> Tensor [m, n] F64 -> Tensor [m, n] F64
1619 |   (MkTensor a) \| (MkTensor b) = MkTensor $ V 0 $ TriangularSolve a b False
1620 |
1621 | namespace Vector
1622 |   ||| Solve the set of linear equations `a @@ x = b` for `x` where `a` is a lower-triangular matrix.
1623 |   ||| `a` is given by the lower-triangular elements of the first argument. Values in the
1624 |   ||| upper-triangular part are ignored. If `a` is lower-triangular already,
1625 |   ||| this is written `a |\ b`.
1626 |   |||
1627 |   ||| The operator is shaped like the lower-triangular portion of a matrix to signal that it uses
1628 |   ||| this portion of its argument. This is in contrast to `(\|)`.
1629 |   export
1630 |   (|\) : Tensor [m, m] F64 -> Tensor [m] F64 -> Tensor [m] F64
1631 |   a |\ b = let (MkTensor {shape = [_]} _) = b in squeeze (a |\ expand 1 b)
1632 |
1633 |   ||| Solve the set of linear equations `a @@ x = b` for `x` where `a` is an upper-triangular
1634 |   ||| matrix. `a` is given by the upper-triangular elements of the first argument. Values in the
1635 |   ||| lower-triangular part are ignored. If `a` is upper-triangular already, this is written
1636 |   ||| `a \| b`.
1637 |   |||
1638 |   ||| The operator is shaped like the upper-triangular portion of a matrix to signal that it uses
1639 |   ||| this portion of its argument. This is in contrast to `(|\)`.
1640 |   export
1641 |   (\|) : Tensor [m, m] F64 -> Tensor [m] F64 -> Tensor [m] F64
1642 |   a \| b = let (MkTensor {shape = [_]} _) = b in squeeze (a \| expand 1 b)
1643 |
1644 | ||| Sum the elements along the diagonal of the input. For example,
1645 | ||| `trace (tensor [[-1, 5], [1, 4]])` produces `3`.
1646 | export
1647 | trace : Num dtype => Prelude.Num (idrisType dtype) =>
1648 |         Tensor [S n, S n] dtype ->
1649 |         Tag $ Tensor [] dtype
1650 | trace x with (x)
1651 |   _ | MkTensor {shape = [_, _]} _ = reduce @{Sum} [0, 1] $ x * identity
1652 |
1653 | ||| A `Rand a` produces a pseudo-random value of type `a` from a `Tensor [2] U64` state.
1654 | ||| The state is updated every time a new value is generated.
1655 | public export 0
1656 | Rand : Type -> Type
1657 | Rand = StateT (Tensor [2] U64) Tag
1658 |
1659 | ||| Generate independent and identically distributed (IID) uniform samples.
1660 | |||
1661 | ||| The generated samples are a deterministic function of the input key and state, but may vary
1662 | ||| between PJRT plugin and library version.
1663 | |||
1664 | ||| Example usage, multiplying two uniform samples
1665 | ||| ```
1666 | ||| x : Tag $ Tensor [3] U64
1667 | ||| x = let seed = tensor [1, 1] in evalStateT seed [| rng * rng |]
1668 | ||| ```
1669 | export
1670 | rng : {shape : _} -> Rand $ Tensor shape U64
1671 | rng = ST $ \(MkTensor state) => do
1672 |   res <- tag $ Rng state (TensorType shape U64)
1673 |   pure (MkTensor $ V 0 res, MkTensor $ V 1 res)
1674 |
1675 | ||| Generate independent and identically distributed (IID) from the uniform distribution U(0, 1).
1676 | |||
1677 | ||| The generated samples are a deterministic function of the input key and state, but may vary
1678 | ||| between PJRT plugin and library version.
1679 | |||
1680 | ||| Example usage, multiplying two uniform samples
1681 | ||| ```
1682 | ||| x : Rand $ Tensor [3] F64
1683 | ||| x = [| uniform * uniform |]
1684 | ||| ```
1685 | export
1686 | uniform : {shape : _} -> Rand $ Tensor shape F64
1687 | uniform =
1688 |   let numMantissaBits : Bits64 = 52
1689 |       scale = broadcast $ 2.0 ^ tensor (Scalar $ - cast {to = Double} numMantissaBits)
1690 |       shift = fill $ 64 - numMantissaBits
1691 |    in rng {shape} <&> \x => castDtype (x >> shift) * scale
1692 |
1693 | ||| Generate independent and identically distributed (IID) samples from the standard normal
1694 | ||| distribution N(0, 1).
1695 | |||
1696 | ||| The generated samples are a deterministic function of the input key and state, but may vary
1697 | ||| between PJRT plugin and library version.
1698 | |||
1699 | ||| Example usage, multiplying two normal samples
1700 | ||| ```
1701 | ||| x : Rand $ Tensor [3] F64
1702 | ||| x = [| normal * normal |]
1703 | ||| ```
1704 | export
1705 | normal : {shape : _} -> Rand $ Tensor shape F64
1706 | normal = uniform <&> \x => sqrt (broadcast 2.0) * erfInv (broadcast 2.0 * x - broadcast 1.0)
1707 |