0 | {--
   1 | Copyright (C) 2021  Joel Berkeley
   2 |
   3 | This program is free software: you can redistribute it and/or modify
   4 | it under the terms of the GNU Affero General Public License as published
   5 | by the Free Software Foundation, either version 3 of the License, or
   6 | (at your option) any later version.
   7 |
   8 | This program is distributed in the hope that it will be useful,
   9 | but WITHOUT ANY WARRANTY; without even the implied warranty of
  10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  11 | GNU Affero General Public License for more details.
  12 |
  13 | You should have received a copy of the GNU Affero General Public License
  14 | along with this program.  If not, see <https://www.gnu.org/licenses/>.
  15 | --}
  16 | ||| Defines `Tensor`, an array of numbers or booleans, along with a number of functions operating on
  17 | ||| `Tensor`s. `Tensor` operations typically leverage hardware acceleration and graph compilation.
  18 | ||| spidr tracks tensor shape and data type in the types, so you can be sure that if your tensor
  19 | ||| code compiles, these are consistent.
  20 | |||
  21 | ||| spidr achieves efficient reuse of tensor computations with `Tag`. See the tutorial
  22 | ||| _Nuisances in the Tensor API_ for a discussion of pitfalls to avoid when using `Tag`.
  23 | module Tensor
  24 |
  25 | import Control.Monad.Error.Either
  26 | import public Control.Monad.State
  27 | import public Data.List
  28 | import public Data.List.Elem
  29 | import Data.List.Quantifiers
  30 | import Decidable.Equality
  31 | import Syntax.PreorderReasoning
  32 |
  33 | import Compiler.Eval
  34 | import Compiler.IR
  35 | import Compiler.Xla.Shape
  36 | import Compiler.Xla.ShapeUtil
  37 | import Compiler.LiteralRW
  38 | import Device
  39 | import public Literal
  40 | import public Primitive
  41 | import public Types
  42 | import public Util
  43 |
  44 | 0 XlaShape : Type
  45 | XlaShape = Xla.Shape
  46 |
  47 | %hide Xla.Shape
  48 |
  49 | ||| A scalar or array. Construct a `Tensor` with the function `tensor`.
  50 | |||
  51 | ||| @shape The `Tensor` shape.
  52 | ||| @dtype The element type.
  53 | export
  54 | data Tensor : (shape : Shape) -> (dtype : Type) -> Type where
  55 |   MkTensor : Value -> {shape : _} -> Tensor shape dtype
  56 |
  57 | ||| The effect of tagging nodes in a computational graph.
  58 | export
  59 | data TagT : (Type -> Type) -> Type -> Type where
  60 |   MkTagT : StateT Env m a -> TagT m a
  61 |
  62 | public export 0
  63 | Tag : Type -> Type
  64 | Tag = TagT Identity
  65 |
  66 | export
  67 | Functor m => Functor (TagT m) where
  68 |   map f (MkTagT x) = MkTagT (map f x)
  69 |
  70 | export
  71 | Monad m => Applicative (TagT m) where
  72 |   pure x = MkTagT (pure x)
  73 |   (MkTagT f) <*> (MkTagT x) = MkTagT (f <*> x)
  74 |
  75 | export
  76 | Monad m => Monad (TagT m) where
  77 |   (MkTagT x) >>= f = MkTagT $ x >>= (\y => let MkTagT z = f y in z)
  78 |
  79 | export
  80 | MonadTrans TagT where
  81 |   lift = MkTagT . lift
  82 |
  83 | public export
  84 | interface Taggable a where
  85 |   ||| Mark an expression to be efficiently reused. For example, in
  86 |   ||| ```
  87 |   ||| bad : Tensor [9999999] F64
  88 |   ||| bad = let x = fill {shape = [9999999]} 1.0 in x + x
  89 |   |||
  90 |   ||| good : Tag $ Tensor [9999999] F64
  91 |   ||| good = do x <- tag $ fill {shape = [9999999]} 1.0
  92 |   |||           pure (x + x)
  93 |   ||| ```
  94 |   ||| the large vector `x` is calculated twice in `bad`, but once in `good`, as `tag` marks it for
  95 |   ||| sharing.
  96 |   |||
  97 |   ||| Types that implement this interface should `tag` constituent components it deems worth sharing.
  98 |   ||| For example, see the implementation for tuples.
  99 |   |||
 100 |   ||| See tutorial _Nuisances in the Tensor API_ for details.
 101 |   tag : Monad m => a -> TagT m a
 102 |
 103 | Taggable Op where
 104 |   -- `BoundSet` case is an optimization. Note this will mean you cannot re-bind a value
 105 |   -- to an inner scope, but I can't see why that would be useful
 106 |   tag (BoundSet x) = pure (BoundSet x)
 107 |   tag x = MkTagT $ tagOp x
 108 |
 109 | export
 110 | Taggable (Tensor shape dtype) where
 111 |   tag (MkTensor $ V idx op) = map (\op => MkTensor $ V idx op) (tag op)
 112 |
 113 | export
 114 | (Taggable a, Taggable b) => Taggable (a, b) where
 115 |   tag (a, b) = [| (tag a, tag b) |]
 116 |
 117 | ||| Construct a `Tensor` from `Literal` data. For example
 118 | ||| ```
 119 | ||| x : Tensor [2, 3] S32
 120 | ||| x = tensor [[1, 2, 3],
 121 | |||             [4, 5, 6]]
 122 | ||| ```
 123 | export
 124 | tensor : PrimitiveRW dtype a => {shape : _} -> Literal shape a -> Tensor shape dtype
 125 | tensor lit = MkTensor $ V 0 $ Lit {dtype} {shape} lit
 126 |
 127 | namespace F64
 128 |   export
 129 |   fromDouble : Double -> Tensor [] F64
 130 |   fromDouble = tensor . Scalar
 131 |
 132 | namespace S32
 133 |   export
 134 |   fromInteger : Integer -> Tensor [] S32
 135 |   fromInteger = tensor . Scalar . fromInteger
 136 |
 137 | try : Show e => EitherT e IO a -> IO a
 138 | try = eitherT (\e => assert_total $ idris_crash $ show e) pure
 139 |
 140 | namespace TensorList
 141 |   namespace Tag
 142 |     ||| A list of `Tensor`s, along with the conversions needed to evaluate them to `Literal`s.
 143 |     ||| The list is parametrized by the shapes and types of the resulting `Literal`s.
 144 |     public export
 145 |     data TensorList : List Shape -> List Type -> Type where
 146 |       Nil : TensorList [] []
 147 |       (::) : PrimitiveRW dtype ty =>
 148 |              Tensor shape dtype ->
 149 |              TensorList shapes tys ->
 150 |              TensorList (shape :: shapes) (ty :: tys)
 151 |
 152 |     ||| Evaluate a list of `Tensor`s as a list of `Literal`s. Tensors in the list can have different
 153 |     ||| shapes and element types. For example,
 154 |     ||| ```
 155 |     ||| main : Device -> IO ()
 156 |     ||| main device = do [x, y] <- eval device $ do let x = tensor {dtype = F64} [1.2, 3.4]
 157 |     |||                                             y <- reduce @{Sum} [0] x
 158 |     |||                                             pure [x, y]
 159 |     |||                  printLn x
 160 |     |||                  printLn y
 161 |     ||| ```
 162 |     ||| In contrast to `Tensor.eval` when called on multiple tensors, this function constructs and
 163 |     ||| compiles the graph just once.
 164 |     export covering
 165 |     eval : Device -> Tag (TensorList shapes tys) -> IO (All2 Literal shapes tys)
 166 |     eval device (MkTagT xs) =
 167 |       let (env, xs) = runState empty xs
 168 |        in try $ do
 169 |             xlaShapes <- buildShapes xs
 170 |             let (outputs ** eq= lengthC xs
 171 |                 main = MkFn 0 [] (resultTypes xs) (results xs) env
 172 |             lits <- execute device main {outputs} (rewrite eq in xlaShapes)
 173 |             readAll xs $ rewrite sym eq in lits
 174 |
 175 |       where
 176 |
 177 |       lengthC : TensorList s t -> (n ** n === length s)
 178 |       lengthC [] = (0 ** Refl)
 179 |       lengthC (_ :: xs) = let (n ** eq= lengthC xs in (S n ** cong S eq)
 180 |
 181 |       buildShapes : HasIO io => TensorList s t -> io $ Vect (length s) XlaShape
 182 |       buildShapes [] = pure []
 183 |       buildShapes (MkTensor {shape, dtype} _ :: ts) = [| mkShape shape {dtype} :: buildShapes ts |]
 184 |
 185 |       results : TensorList s t -> Vect (length s) Value
 186 |       results [] = []
 187 |       results (MkTensor x :: xs) = x :: results xs
 188 |
 189 |       resultTypes : TensorList s t -> Vect (length s) ValueType
 190 |       resultTypes [] = []
 191 |       resultTypes (MkTensor {shape, dtype} _ :: xs) = TensorType shape dtype :: resultTypes xs
 192 |
 193 |       readAll : HasIO io => TensorList s t -> Vect (length s) Literal -> io $ All2 Literal s t
 194 |       readAll [] _ = pure []
 195 |       readAll (MkTensor {dtype} _ :: ts) (l :: ls) = [| read {dtype} l :: readAll ts ls |]
 196 |
 197 |   ||| A convenience wrapper for `TensorList.Tag.eval`, for use with a bare `TensorList`.
 198 |   export covering
 199 |   eval : Device -> TensorList shapes tys -> IO (All2 Literal shapes tys)
 200 |   eval device xs = eval device (pure xs)
 201 |
 202 | namespace Tag
 203 |   ||| Evaluate a `Tensor`, returning its value as a `Literal`. This function builds and executes the
 204 |   ||| computational graph.
 205 |   |||
 206 |   ||| **Note:** Each call to `eval` will rebuild and execute the graph; multiple calls to `eval` on
 207 |   ||| different tensors, even if they are in the same computation, will be treated independently.
 208 |   ||| To efficiently evaluate multiple tensors at once, use `TensorList.Tag.eval`.
 209 |   export covering
 210 |   eval : Device -> PrimitiveRW dtype ty => Tag (Tensor shape dtype) -> IO (Literal shape ty)
 211 |   eval device x = map (\[z] => z) $ eval device $ map (\z => [z]) x
 212 |
 213 | ||| A convenience wrapper for `Tag.eval`, for use with a bare `Tensor`.
 214 | export covering
 215 | eval : Device -> PrimitiveRW dtype ty => Tensor shape dtype -> IO (Literal shape ty)
 216 | eval device x = eval device (pure x)
 217 |
 218 | ||| A string representation of a tensor graph.
 219 | |||
 220 | ||| There are no guarantees whatsoever as to the string structure and contents.
 221 | export
 222 | Primitive dtype => Show (Tag $ Tensor shape dtype) where
 223 |   show (MkTagT x) =
 224 |     let (env, MkTensor x) = runState empty x
 225 |      in show (MkFn 0 [] [TensorType shape dtype] [x] env)
 226 |
 227 | ||| Positive infinity.
 228 | export
 229 | inf : Tensor [] F64
 230 |
 231 | ||| NaN (not a number).
 232 | export
 233 | nan : Tensor [] F64
 234 |
 235 | namespace Primitive
 236 |   ||| Ordered element types.
 237 |   public export
 238 |   interface Primitive.Eq dtype => Primitive.Ord dtype where
 239 |     ||| Compares less than or equal to any other value (except NaN).
 240 |     min : Tensor [] dtype
 241 |
 242 |     ||| Compares greater than or equal to any other value (except NaN).
 243 |     max : Tensor [] dtype
 244 |
 245 |   export
 246 |   Primitive.Ord U32 where
 247 |     min = MkTensor $ V 0 $ MinValue U32
 248 |     max = MkTensor $ V 0 $ MaxValue U32
 249 |
 250 |   export
 251 |   Primitive.Ord S32 where
 252 |     min = MkTensor $ V 0 $ MinValue S32
 253 |     max = MkTensor $ V 0 $ MaxValue S32
 254 |
 255 |   export
 256 |   Primitive.Ord U64 where
 257 |     min = MkTensor $ V 0 $ MinValue U64
 258 |     max = MkTensor $ V 0 $ MaxValue U64
 259 |
 260 |   export
 261 |   Primitive.Ord F64 where
 262 |     min = tensor $ Scalar $ -1.0 / 0.0
 263 |     max = tensor $ Scalar $ 1.0 / 0.0
 264 |
 265 | ||| The most negative possible finite float, approx. -1.8e308
 266 | export
 267 | minFinite : Tensor [] F64
 268 | minFinite = MkTensor $ V 0 $ MinFiniteFloat
 269 |
 270 | ||| The most positive possible finite float, approx. 1.8e308
 271 | export
 272 | maxFinite : Tensor [] F64
 273 | maxFinite = MkTensor $ V 0 $ MaxFiniteFloat
 274 |
 275 | ||| Cast the element type. For example, `castDtype (tensor {dtype = S32} [1, -2])` is
 276 | ||| `tensor {dtype = F64} [1.0, -2.0]`.
 277 | export
 278 | castDtype : Primitive.Integral a => Tensor shape a -> Tensor shape F64
 279 | castDtype $ MkTensor {shape} x = MkTensor $ V 0 $ Convert {dtype = F64} shape x
 280 |
 281 | fn0 : Primitive a => Tag (Tensor s a) -> Tag $ Fn 0
 282 | fn0 f = MkTagT $ do
 283 |   addr <- reserve
 284 |
 285 |   let MkTagT res = f
 286 |       (env, MkTensor res) = runState (emptyFrom !get) res
 287 |       f = MkFn addr [] [TensorType s a] [res] env
 288 |
 289 |   updateCounterFrom env
 290 |   pure f
 291 |
 292 | fn1 : Primitive a => Primitive a' => {s : _} -> (Tensor s a -> Tag $ Tensor s' a') -> Tag $ Fn 1
 293 | fn1 f = MkTagT $ do
 294 |   addr <- reserve
 295 |
 296 |   let MkTagT res = f (MkTensor $ V 0 $ BoundSet addr)
 297 |       (env, MkTensor res) = runState (emptyFrom !get) res
 298 |       f = MkFn addr [TensorType s a] [TensorType s' a'] [res] env
 299 |
 300 |   updateCounterFrom env
 301 |   pure f
 302 |
 303 | fn2 :
 304 |   Primitive a => Primitive a' => Primitive a'' => {s, s' : _} ->
 305 |   (Tensor s a -> Tensor s' a' -> Tag $ Tensor s'' a'') ->
 306 |   Tag $ Fn 2
 307 | fn2 f = MkTagT $ do
 308 |   addr <- reserve
 309 |
 310 |   let MkTagT res = f (MkTensor $ V 0 $ BoundSet addr) (MkTensor $ V 1 $ BoundSet addr)
 311 |       (env, MkTensor res) = runState (emptyFrom !get) res
 312 |       f = MkFn addr [TensorType s a, TensorType s' a'] [TensorType s'' a''] [res] env
 313 |
 314 |   updateCounterFrom env
 315 |   pure f
 316 |
 317 | fn22 :
 318 |   Primitive a => Primitive a' => Primitive a'' => Primitive a''' => {s, s' : _} ->
 319 |   (Tensor s a -> Tensor s' a' -> Tag $ (Tensor s'' a'', Tensor s''' a''')) ->
 320 |   Tag $ Fn 2
 321 | fn22 f = MkTagT $ do
 322 |   addr <- reserve
 323 |
 324 |   let MkTagT res = f (MkTensor $ V 0 $ BoundSet addr) (MkTensor $ V 1 $ BoundSet addr)
 325 |       (env, (MkTensor res0, MkTensor res1)) = runState (emptyFrom !get) res
 326 |       resTys = [TensorType s'' a'', TensorType s''' a''']
 327 |       f = MkFn addr [TensorType s a, TensorType s' a'] resTys [res0, res1] env
 328 |
 329 |   updateCounterFrom env
 330 |   pure f
 331 |
 332 | ||| Reverse-mode automatic differentiation.
 333 | |||
 334 | ||| `grad` can be applied repeatedly to obtain higher derivatives, though we do not yet support
 335 | ||| derivatives of vector-valued functions.
 336 | |||
 337 | ||| For example, for
 338 | ||| ```
 339 | ||| f : Tensor [2] F64 -> Tag $ Tensor [] F64
 340 | ||| f x = do
 341 | |||   x <- tag x
 342 | |||   let (x0, x1) = (slice [at 0] x, slice [at 1] x)
 343 | |||   pure $ x0 / x1
 344 | ||| ```
 345 | ||| `grad f (tensor [3.0, 2.0])` produces `tensor [0.5, -0.75]`.
 346 | |||
 347 | ||| **Warning:** `grad` is experimental, and only implemented for a subset of the tensor API.
 348 | export partial
 349 | grad : (Tensor shape F64 -> Tag $ Tensor [] F64) -> Tensor shape F64 -> Tag $ Tensor shape F64
 350 | grad f (MkTensor x) = pure $ MkTensor $ V 0 $ Grad shape !(fn1 f) x
 351 |
 352 | ||| Reshape a `Tensor`. For example, `reshape {to = [2, 1]} (tensor [3, 4])` is
 353 | ||| `tensor [[3], [4]]`. The output can have a different rank to the input.
 354 | export
 355 | reshape :
 356 |   Primitive dtype =>
 357 |   {to : _} ->
 358 |   {auto 0 sizesEqual : product from = product to} ->
 359 |   Tensor from dtype ->
 360 |   Tensor to dtype
 361 | reshape $ MkTensor {shape} x = MkTensor $ V 0 $ Reshape {dtype} to x
 362 |
 363 | ||| Add a dimension of length one at the specified `axis`. The new dimension will be at the
 364 | ||| specified `axis` in the new `Tensor` (as opposed to the original `Tensor`). For example,
 365 | ||| `expand 1 $ tensor [[1, 2], [3, 4], [5, 6]]` is `tensor [[[1, 2]], [[3, 4]], [[5, 6]]]`.
 366 | export
 367 | expand :
 368 |   Primitive dtype =>
 369 |   (axis : Nat) ->
 370 |   {auto 0 inBounds : axis `LTE` length shape} ->
 371 |   Tensor shape dtype ->
 372 |   Tensor (insertAt axis 1 shape) dtype
 373 | expand axis $ MkTensor {shape = _} x = MkTensor $ V 0 $ Reshape {dtype} (insertAt axis 1 shape) x
 374 |
 375 | namespace Squeezable
 376 |   ||| A `Squeezable from to` constitutes proof that the shape `from` can be squeezed to the
 377 |   ||| shape `to`. Squeezing is the process of removing any number of dimensions of length one.
 378 |   public export
 379 |   data Squeezable : (0 from : Shape) -> (0 to : Shape) -> Type where
 380 |     ||| Proof that a shape can be squeezed to itself. For example:
 381 |     |||
 382 |     ||| [] to []
 383 |     ||| [3, 4] to [3, 4]
 384 |     Same : Squeezable x x
 385 |
 386 |     ||| Proof that any dimensions (including those of length 1) can be preserved in the process of
 387 |     ||| squeezing. For example:
 388 |     |||
 389 |     ||| ...
 390 |     Match : Squeezable from to -> Squeezable (x :: from) (x :: to)
 391 |
 392 |     ||| Proof that any dimensions of length one can be squeezed out. For example:
 393 |     |||
 394 |     ||| [1, 3, 1, 1, 4] to [3, 4]
 395 |     Nest : Squeezable from to -> Squeezable (1 :: from) to
 396 |
 397 | ||| Remove dimensions of length one from a `Tensor` such that it has the desired shape. For example:
 398 | |||
 399 | ||| ```
 400 | ||| x : Tensor [2, 1, 3, 1] S32
 401 | ||| x = tensor [[[[4], [5], [6]]],
 402 | |||             [[[7], [8], [9]]]]
 403 | |||
 404 | ||| y : Tensor [2, 1, 3] S32
 405 | ||| y = squeeze x
 406 | ||| ```
 407 | ||| is
 408 | ||| ```
 409 | ||| y : Tensor [2, 1, 3] S32
 410 | ||| y = tensor [[[4, 5, 6]],
 411 | |||             [[7, 8, 9]]]
 412 | ||| ```
 413 | export
 414 | squeeze :
 415 |   Primitive dtype =>
 416 |   {to : _} ->
 417 |   {auto 0 shapesSqueezable : Squeezable from to} ->
 418 |   Tensor from dtype ->
 419 |   Tensor to dtype
 420 | squeeze $ MkTensor {shape} x = MkTensor $ V 0 $ Reshape {dtype} to x
 421 |
 422 | ||| A `SliceOrIndex d` is a valid slice or index into a dimension of size `d`. See `slice` for
 423 | ||| details.
 424 | export
 425 | data SliceOrIndex : Nat -> Type where
 426 |   Slice :
 427 |     (from, to : Nat) ->
 428 |     {size : _} ->
 429 |     {auto 0 fromTo : from + size = to} ->
 430 |     {auto 0 inDim : LTE to d} ->
 431 |     SliceOrIndex d
 432 |   Index : (idx : Nat) -> {auto 0 inDim : LT idx d} -> SliceOrIndex d
 433 |   DynamicSlice : Tensor [] U64 -> (size : Nat) -> {auto 0 inDim : LTE size d} -> SliceOrIndex d
 434 |   DynamicIndex : Tensor [] U64 -> SliceOrIndex d
 435 |
 436 | ||| Index at `idx`. See `slice` for details.
 437 | public export
 438 | at : (idx : Nat) -> {auto 0 inDim : LT idx d} -> SliceOrIndex d
 439 | at = Index
 440 |
 441 | namespace Dynamic
 442 |   ||| Index at the specified index. See `slice` for details.
 443 |   public export
 444 |   at : Tensor [] U64 -> SliceOrIndex d
 445 |   at = DynamicIndex
 446 |
 447 | ||| Slice from `from` (inclusive) to `to` (exclusive). See `slice` for details.
 448 | public export
 449 | (.to) :
 450 |   (from, to : Nat) ->
 451 |   {size : _} ->
 452 |   {auto 0 fromTo : from + size = to} ->
 453 |   {auto 0 inDim : LTE to d} ->
 454 |   SliceOrIndex d
 455 | (.to) = Slice
 456 |
 457 | ||| Slice `size` elements starting at the specified scalar `U64` index. See `slice` for details.
 458 | public export
 459 | (.size) : Tensor [] U64 -> (size : Nat) -> {auto 0 inDim : LTE size d} -> SliceOrIndex d
 460 | (.size) = DynamicSlice
 461 |
 462 | ||| Slice across all indices along an axis. See `slice` for details.
 463 | public export
 464 | all : {d : _} -> SliceOrIndex d
 465 | all = Slice 0 @{%search} @{reflexive {ty = Nat}} d
 466 |
 467 | ||| A `MultiSlice shape` is a valid multi-dimensional slice into a tensor with shape `shape`.
 468 | ||| See `slice` for details.
 469 | public export
 470 | data MultiSlice : Shape -> Type where
 471 |   Nil : MultiSlice ds
 472 |   (::) : SliceOrIndex d -> MultiSlice ds -> MultiSlice (d :: ds)
 473 |
 474 | namespace MultiSlice
 475 |   ||| The shape of a tensor produced by slicing with the specified multi-dimensional slice. See
 476 |   ||| `Tensor.slice` for details.
 477 |   public export
 478 |   slice : {shape : _} -> MultiSlice shape -> Shape
 479 |   slice {shape} [] = shape
 480 |   slice {shape = (_ :: _)} (Slice {size} _ _ :: xs) = size :: slice xs
 481 |   slice {shape = (_ :: _)} (Index _ :: xs) = slice xs
 482 |   slice {shape = (_ :: _)} (DynamicSlice _ size :: xs) = size :: slice xs
 483 |   slice {shape = (_ :: _)} (DynamicIndex _ :: xs) = slice xs
 484 |
 485 | ||| Slice or index `Tensor` axes. Each axis can be sliced or indexed, and this can be done with
 486 | ||| either static (`Nat`) or dynamic (scalar `U64`) indices.
 487 | |||
 488 | ||| **Static indices**
 489 | |||
 490 | ||| Static indices are `Nat`s. For example, for
 491 | ||| ```
 492 | ||| x : Tensor [5, 6] S32
 493 | ||| x = tensor [[ 0,  1,  2,  3,  4,  5],
 494 | |||             [ 6,  7,  8,  9, 10, 11],
 495 | |||             [12, 13, 14, 15, 16, 17],
 496 | |||             [18, 19, 20, 21, 22, 23],
 497 | |||             [24, 25, 26, 27, 28, 29]]
 498 | ||| ```
 499 | ||| we can index as `slice [at 1] x` to get
 500 | ||| ```
 501 | ||| x : Tensor [6] S32
 502 | ||| x = tensor [6, 7, 8, 9, 10, 11]
 503 | ||| ```
 504 | ||| or we can slice as `slice [2.to 4] x` to get
 505 | ||| ```
 506 | ||| x : Tensor [2, 6] S32
 507 | ||| x = tensor [[12, 13, 14, 15, 16, 17],
 508 | |||             [18, 19, 20, 21, 22, 23]]
 509 | ||| ```
 510 | ||| Note that in `2.to 4`, the 2 is inclusive, and the 4 exclusive, so we return indices 2 and 3.
 511 | |||
 512 | ||| **Dynamic indices**
 513 | |||
 514 | ||| Dynamic indices are scalar `U64` values, and the API works slightly differently because we
 515 | ||| can't know the value of dynamic indices until the graph is executed. For indexing, with scalar
 516 | ||| `U64` index `i` in `slice [at i] x`, `i` is clamped to be a valid index into that dimension.
 517 | ||| For example, for `i = tensor 1`, `slice [at i] x` is
 518 | ||| ```
 519 | ||| x : Tensor [6] S32
 520 | ||| x = tensor [6, 7, 8, 9, 10, 11]
 521 | ||| ```
 522 | ||| as in the static case. However, for `i = tensor 10`, `slice [at i] x` returns the last row
 523 | ||| ```
 524 | ||| x : Tensor [6] S32
 525 | ||| x = tensor [24, 25, 26, 27, 28, 29]
 526 | ||| ```
 527 | ||| We can also slice by specifying a scalar `U64` start index, and a static size, as
 528 | ||| `slice [i.size 2] x` with `i = tensor 2` to get
 529 | ||| ```
 530 | ||| x : Tensor [2, 6] S32
 531 | ||| x = tensor [[12, 13, 14, 15, 16, 17],
 532 | |||             [18, 19, 20, 21, 22, 23]]
 533 | ||| ```
 534 | ||| For a given slice `size`, the dynamic start index is clamped such that we always get `size`
 535 | ||| elements along that axis. For example, `slice [i.size 2] x` with `i = tensor 4` is
 536 | ||| ```
 537 | ||| x : Tensor [2, 6] S32
 538 | ||| x = tensor [[18, 19, 20, 21, 22, 23],
 539 | |||             [24, 25, 26, 27, 28, 29]]
 540 | ||| ```
 541 | ||| which starts at index 3 rather than index 4.
 542 | |||
 543 | ||| **Mixed static, dynamic, slicing and indexing**
 544 | |||
 545 | ||| Each axis can only be sliced or indexed, and must use only static or dynamic indices. However,
 546 | ||| across axes, we can mix these four arbitrarily. For example, with `slice [2.to 4, at 1] x` to
 547 | ||| get
 548 | ||| ```
 549 | ||| x : Tensor [2] S32
 550 | ||| x = tensor [13, 19]
 551 | ||| ```
 552 | ||| or with `i = tensor 2` in `slice [at 1, i.size 2] x` to get
 553 | ||| ```
 554 | ||| x : Tensor [2] S32
 555 | ||| x = tensor [7, 8]
 556 | ||| ```
 557 | |||
 558 | ||| Slices and indices apply to the leading axes of the tensor. For trailing axes omitted from the
 559 | ||| multi-dimensional slice, the whole of the axis is returned. If we want to slice or index over
 560 | ||| later axes and retain all indices in a leading axis, we can use the convenience function `all`,
 561 | ||| as `slice [all, at 3] x` to get
 562 | ||| ```
 563 | ||| x : Tensor [5] S32
 564 | ||| x = tensor [[3], [9], [15], [21], [27]]
 565 | ||| ```
 566 | ||| This is exactly the same as the more manual `slice [0.to 5, at 3] x` and
 567 | ||| `slice [(tensor 0).size 5, at 3] x`.
 568 | |||
 569 | ||| @at The multi-dimensional slices and indices at which to slice the tensor.
 570 | export
 571 | slice :
 572 |   Primitive dtype =>
 573 |   (at : MultiSlice shape) ->
 574 |   Tensor shape dtype ->
 575 |   Tensor (slice at) dtype
 576 | slice at $ MkTensor x = MkTensor $ V 0 $
 577 |   let x = V 0 $ Slice (mapd start (const 0) at) (mapd stop id at) (replicate (length shape) 1) x
 578 |       -- we shortcut DynamicSlice to allow autodiff for static slicing
 579 |       x = if isDynamic at then V 0 $ DynamicSlice (dynStarts [] at) (mapd size id at) x else x
 580 |    in Reshape {dtype} (MultiSlice.slice at) x
 581 |
 582 |       where
 583 |       mapd : ((Nat -> a) -> {d : Nat} -> SliceOrIndex d -> a) ->
 584 |              (Nat -> a) ->
 585 |              {shape : Shape} ->
 586 |              MultiSlice shape ->
 587 |              List a
 588 |       mapd _ dflt {shape} [] = Prelude.map dflt shape
 589 |       mapd f dflt (x :: xs) = f dflt x :: mapd f dflt xs
 590 |
 591 |       start : (Nat -> Nat) -> {d : Nat} -> SliceOrIndex d -> Nat
 592 |       start _ (Slice from _) = from
 593 |       start _ (Index idx) = idx
 594 |       start f {d} _ = f d
 595 |
 596 |       stop : (Nat -> Nat) -> {d : Nat} -> SliceOrIndex d -> Nat
 597 |       stop _ (Slice _ to) = to
 598 |       stop _ (Index idx) = S idx
 599 |       stop f {d} _ = f d
 600 |
 601 |       size : (Nat -> Nat) -> {d : Nat} -> SliceOrIndex d -> Nat
 602 |       size _ (Slice {size = size'} _ _) = size'
 603 |       size _ (Index _) = 1
 604 |       size _ (DynamicSlice _ size') = size'
 605 |       size _ (DynamicIndex _) = 1
 606 |
 607 |       zero : Value
 608 |       zero = V 0 $ Lit {shape = []} {dtype = U64} 0
 609 |
 610 |       isDynamic : {shape : _} -> MultiSlice shape -> Bool
 611 |       isDynamic [] = False
 612 |       isDynamic {shape = (_ :: _)} (DynamicSlice _ _ :: _) = True
 613 |       isDynamic {shape = (_ :: _)} (DynamicIndex _ :: _) = True
 614 |       isDynamic (_ :: ds) = isDynamic ds
 615 |
 616 |       dynStarts : List Value -> {shape : _} -> MultiSlice shape -> List Value
 617 |       dynStarts idxs {shape} [] = replicate (length shape) zero ++ idxs
 618 |       dynStarts idxs (DynamicSlice (MkTensor i) _ :: ds) = i :: dynStarts idxs ds
 619 |       dynStarts idxs (DynamicIndex (MkTensor i) :: ds) = i :: dynStarts idxs ds
 620 |       dynStarts idxs (_ :: ds) = zero :: dynStarts idxs ds
 621 |
 622 | ||| Concatenate two `Tensor`s along the specified `axis`. For example,
 623 | ||| `concat 0 (tensor [[1, 2], [3, 4]]) (tensor [[5, 6]])` and
 624 | ||| `concat 1 (tensor [[3], [6]]) (tensor [[4, 5], [7, 8]])` are both
 625 | ||| `tensor [[1, 2], [3, 4], [5, 6]]`.
 626 | export
 627 | concat :
 628 |   Primitive dtype =>
 629 |   (axis : Nat) ->
 630 |   Tensor s dtype ->
 631 |   Tensor s' dtype ->
 632 |   {auto 0 inBounds : (InBounds axis s, InBounds axis s')} ->
 633 |   {auto 0 shapesConcatenable : deleteAt axis s = deleteAt axis s'} ->
 634 |   Tensor (replaceAt axis (index axis s + index axis s') s) dtype
 635 | concat axis (MkTensor x) (MkTensor x') = MkTensor $ V 0 $ Concat axis [x, x']
 636 |
 637 | ||| Transpose a matrix. For example, `(tensor [[1, 2], [3, 4]]).T` is `tensor [[1, 3], [2, 4]]`.
 638 | export
 639 | (.T) : Tensor [m, n] dtype -> Tensor [n, m] dtype
 640 | (MkTensor x).T = MkTensor $ V 0 $ Transpose [1, 0] x
 641 |
 642 | ||| Transpose axes of a tensor. This is a more general version of `(.T)`, in which you can
 643 | ||| transpose any number of axes in a tensor of arbitrary rank. The i'th axis in the resulting
 644 | ||| tensor corresponds to the `index i ordering`'th axis in the input tensor. For example, for
 645 | ||| ```
 646 | ||| x : Tensor [2, 3, 4] S32
 647 | ||| x = tensor [[[ 0,  1,  2,  3],
 648 | |||              [ 4,  5,  6,  7],
 649 | |||              [ 8,  9, 10, 11]],
 650 | |||             [[12, 13, 14, 15],
 651 | |||              [16, 17, 18, 19],
 652 | |||              [20, 21, 22, 23]]]
 653 | ||| ```
 654 | ||| `transpose [0, 2, 1] x` is
 655 | ||| ```
 656 | ||| x : Tensor [2, 4, 3] S32
 657 | ||| x = tensor [[[ 0,  4,  8],
 658 | |||              [ 1,  5,  9],
 659 | |||              [ 2,  6, 10],
 660 | |||              [ 3,  7, 11]],
 661 | |||             [[12, 16, 20],
 662 | |||              [13, 17, 21],
 663 | |||              [14, 18, 22],
 664 | |||              [15, 19, 23]]]
 665 | ||| ```
 666 | ||| `transpose [2, 0, 1] x` is
 667 | ||| ```
 668 | ||| x : Tensor [4, 2, 3] S32
 669 | ||| x = tensor [[[ 0,  4,  8],
 670 | |||              [12, 16, 20]],
 671 | |||             [[ 1,  5,  9],
 672 | |||              [13, 17, 21]],
 673 | |||             [[ 2,  6, 10],
 674 | |||              [14, 18, 22]],
 675 | |||             [[ 3,  7, 11],
 676 | |||              [15, 19, 23]]]
 677 | ||| ```
 678 | |||
 679 | ||| In order to see what effect transposing a tensor has, it can help to bear in mind the following:
 680 | ||| * if an element can be found with `slice [at 3, at 4, at 5] x` in the original tensor,
 681 | |||   that same element can instead be found with `slice [at 5, at 3, at 4]` given a
 682 | |||   `transpose [2, 0, 1]`. That is, transposing axes re-orders indices when indexing.
 683 | ||| * with `transpose [2, 0, 1]`, traversing the first axis in the result is equivalent to
 684 | |||   traversing the last axis in the input. Similarly, traversing the last axis in the result is
 685 | |||   equivalent to traversing the second axis in the input.
 686 | export
 687 | transpose :
 688 |   (ordering : List Nat) ->
 689 |   Tensor shape dtype ->
 690 |   {auto 0 lengths : length ordering = length shape} ->
 691 |   {auto 0 axesUnique : unique ordering = True} ->
 692 |   {auto 0 inBounds : All (flip InBounds shape) ordering} ->
 693 |   Tensor (multiIndex ordering shape) dtype
 694 | transpose ordering $ MkTensor x = MkTensor $ V 0 $ Transpose ordering x
 695 |
 696 | ||| A `DimBroadcastable from to` proves that a dimension of size `from` can be broadcast to a
 697 | ||| dimension of size `to`.
 698 | public export
 699 | data DimBroadcastable : (0 from : Nat) -> (0 to : Nat) -> Type where
 700 |   ||| Proof that any dimension can be broadcast to itself. For example in shapes `[2, 3]` to
 701 |   ||| `[2, 3]`.
 702 |   Same : DimBroadcastable x x
 703 |
 704 |   ||| Proof that a dimension of length one can be broadcast to any size. For example in shapes
 705 |   ||| `[2, 1]` to `[2, 3]`
 706 |   Stack : DimBroadcastable 1 _
 707 |
 708 |   ||| Proof that any dimension can be broadcast to zero. For example in shapes `[2, 3]` to `[2, 0]`.
 709 |   Zero : DimBroadcastable _ 0
 710 |
 711 | namespace Broadcastable
 712 |   ||| A `Broadcastable from to` constitutes proof that the shape `from` can be broadcast to the
 713 |   ||| shape `to`.
 714 |   public export
 715 |   data Broadcastable : (0 from : Shape) -> (0 to : Shape) -> Type where
 716 |     ||| Proof that a shape can be broadcast to itself. For example:
 717 |     |||
 718 |     ||| [] to []
 719 |     ||| [3, 4] to [3, 4]
 720 |     |||
 721 |     ||| Implementation note: we could have used `Broadcast [] []`, which would have resulted in more
 722 |     ||| atomic constructors for `Broadcastable`, but the author guesses that this implementation helps
 723 |     ||| the type checker avoid applications of `Match`.
 724 |     Same : Broadcastable x x
 725 |
 726 |     ||| Proof that a dimension of size `f` can be broadcast to size `t` if these dimensions
 727 |     ||| are `DimBroadcastable f t`. For example:
 728 |     |||
 729 |     ||| [2, 3] to [2, 3]
 730 |     ||| [2, 1] to [2, 3]
 731 |     ||| [2, 1] to [2, 0]
 732 |     Match : forall from, to .
 733 |             {auto 0 ranksEq : length from = length to} ->
 734 |             {auto 0 dimBroadcastable : DimBroadcastable f t} ->
 735 |             Broadcastable from to ->
 736 |             Broadcastable (f :: from) (t :: to)
 737 |
 738 |     ||| Proof that broadcasting can add outer dimensions i.e. nesting. For example:
 739 |     |||
 740 |     ||| [3] to [1, 3]
 741 |     ||| [3] to [5, 3]
 742 |     Nest : Broadcastable f t -> Broadcastable f (_ :: t)
 743 |
 744 | ||| A shape can be extended with any number of leading dimensions.
 745 | |||
 746 | ||| @leading The leading dimensions.
 747 | export
 748 | broadcastableByLeading : (leading : List Nat) -> Broadcastable shape (leading ++ shape)
 749 | broadcastableByLeading [] = Same
 750 | broadcastableByLeading (l :: ls) = Nest (broadcastableByLeading ls)
 751 |
 752 | ||| A scalar can be broadcast to any shape.
 753 | %hint
 754 | export
 755 | scalarToAnyOk : (to : Shape) -> Broadcastable [] to
 756 | scalarToAnyOk to = rewrite sym $ appendNilRightNeutral to in broadcastableByLeading to
 757 |
 758 | ||| Broadcast a `Tensor` to a new compatible shape. For example,
 759 | ||| ```
 760 | ||| x : Tensor [2, 3] S32
 761 | ||| x = broadcast (tensor [4, 5, 6])
 762 | ||| ```
 763 | ||| is
 764 | ||| ```
 765 | ||| x : Tensor [2, 3] S32
 766 | ||| x = tensor [[4, 5, 6], [4, 5, 6]]
 767 | ||| ```
 768 | export
 769 | broadcast :
 770 |   Primitive dtype =>
 771 |   {to : _} ->
 772 |   {auto shapesOK : Broadcastable from to} ->
 773 |   Tensor from dtype ->
 774 |   Tensor to dtype
 775 | broadcast $ MkTensor {shape = _} x = MkTensor $ V 0 $ Broadcast {dtype} from to x
 776 |
 777 | ||| A `Tensor` where every element has the specified value. For example,
 778 | ||| ```
 779 | ||| fives : Tensor [2, 3] S32
 780 | ||| fives = fill 5
 781 | ||| ```
 782 | ||| is
 783 | ||| ```
 784 | ||| fives : Tensor [2, 3] S32
 785 | ||| fives = tensor [[5, 5, 5],
 786 | |||                 [5, 5, 5]]
 787 | ||| ```
 788 | export
 789 | fill : PrimitiveRW dtype ty => {shape : _} -> ty -> Tensor shape dtype
 790 | fill x = broadcast {shapesOK = scalarToAnyOk shape} (tensor (Scalar x))
 791 |
 792 | ||| A constant where values increment from zero along the specified `axis`. For example,
 793 | ||| ```
 794 | ||| x : Tensor [3, 5] S32
 795 | ||| x = iota 1
 796 | ||| ```
 797 | ||| is the same as
 798 | ||| ```
 799 | ||| x : Tensor [3, 5] S32
 800 | ||| x = tensor [[0, 1, 2, 3, 4],
 801 | |||             [0, 1, 2, 3, 4],
 802 | |||             [0, 1, 2, 3, 4]]
 803 | ||| ```
 804 | ||| and
 805 | ||| ```
 806 | ||| x : Tensor [3, 5] S32
 807 | ||| x = iota 0
 808 | ||| ```
 809 | ||| is the same as
 810 | ||| ```
 811 | ||| x : Tensor [3, 5] S32
 812 | ||| x = tensor [[0, 0, 0, 0, 0],
 813 | |||             [1, 1, 1, 1, 1],
 814 | |||             [2, 2, 2, 2, 2]]
 815 | ||| ```
 816 | export
 817 | iota : Primitive.Num dtype =>
 818 |        {shape : _} ->
 819 |        (axis : Nat) ->
 820 |        {auto 0 inBounds : InBounds axis shape} ->
 821 |        Tensor shape dtype
 822 | iota dimension = MkTensor $ V 0 $ Iota shape {dtype} dimension
 823 |
 824 | ||| A while-loop operating on a single tensor.
 825 | |||
 826 | ||| `while1` iteratively checks if the tensor satisfies `condition`, and if it does, updates it with
 827 | ||| `body`.
 828 | |||
 829 | ||| **Note:** For the XLA compiler, there are scoping restrictions on functions passed to `while1`.
 830 | ||| See the tutorial _Nuisances in the Tensor API_ for details.
 831 | |||
 832 | ||| @condition The guard condition for each iteration.
 833 | ||| @body The update step.
 834 | ||| @initial The initial tensor.
 835 | export covering
 836 | while1 :
 837 |   Primitive dtype =>
 838 |   (condition : Tensor shape dtype -> Tag $ Tensor [] PRED) ->
 839 |   (body : Tensor shape dtype -> Tag $ Tensor shape dtype) ->
 840 |   (initial : Tensor shape dtype) ->
 841 |   Tag $ Tensor shape dtype
 842 | while1 condition body (MkTensor i0) =
 843 |   pure $ MkTensor $ V 0 $ While !(fn1 condition) !(fn1 body) [i0]
 844 |
 845 | ||| A while-loop operating on two tensors.
 846 | |||
 847 | ||| `while2` iteratively checks if the tensors together satisfy `condition`, and if so, updates them
 848 | ||| with `body`.
 849 | |||
 850 | ||| **Note:** For the XLA compiler, there are scoping restrictions on functions passed to `while2`.
 851 | ||| See the tutorial _Nuisances in the Tensor API_ for details.
 852 | |||
 853 | ||| @condition The guard condition for each iteration.
 854 | ||| @body The update step.
 855 | ||| @initial One initial tensor.
 856 | ||| @initial' The other initial tensor.
 857 | export covering
 858 | while2 :
 859 |   Primitive a => Primitive a' =>
 860 |   (condition : Tensor s a -> Tensor s' a' -> Tag $ Tensor [] PRED) ->
 861 |   (body : Tensor s a -> Tensor s' a' -> Tag (Tensor s a, Tensor s' a')) ->
 862 |   (initial : Tensor s a) -> (initial' : Tensor s' a') ->
 863 |   Tag (Tensor s a, Tensor s' a')
 864 | while2 condition body (MkTensor i) (MkTensor i') = do
 865 |   res <- tag $ While !(fn2 condition) !(fn22 body) [i, i']
 866 |   pure (MkTensor $ V 0 res, MkTensor $ V 1 res)
 867 |
 868 | ||| Lift a unary function on scalars to an element-wise function on `Tensor`s of arbitrary shape.
 869 | ||| For example,
 870 | ||| ```
 871 | ||| recip : Tensor [] F64 -> Tag $ Tensor [] F64
 872 | ||| recip x = pure $ 1.0 / x
 873 | ||| ```
 874 | ||| can be lifted to an element-wise reciprocal function as `map recip (tensor [-2, 0.4])`,
 875 | ||| which produces `tensor [-0.5, 2.5]`.
 876 | |||
 877 | ||| **Note:** For the XLA compiler, there are scoping restrictions on the function passed to `map`.
 878 | ||| See the tutorial _Nuisances in the Tensor API_ for details.
 879 | export
 880 | map : (Primitive a, Primitive b) =>
 881 |       (Tensor [] a -> Tag $ Tensor [] b) ->
 882 |       Tensor shape a -> Tag $ Tensor shape b
 883 | map f $ MkTensor {shape = _} x =
 884 |   pure $ MkTensor $ V 0 $ Map !(fn1 f) [x] (TensorType shape b) (range $ length shape)
 885 |
 886 | ||| Lift a binary function on scalars to an element-wise function on `Tensor`s of arbitrary shape.
 887 | ||| For example,
 888 | ||| ```
 889 | ||| addRecip : Tensor [] F64 -> Tensor [] F64 -> Tag $ Tensor [] F64
 890 | ||| addRecip x y = pure $ x + 1.0 / y
 891 | ||| ```
 892 | ||| can be lifted to an element-wise function as
 893 | ||| `map2 addRecip (tensor [3.0, -3.0]) (tensor [-2.0, 0.4])`, which produces `tensor [2.5, -0.5]`.
 894 | |||
 895 | ||| **Note:** For the XLA compiler, there are scoping restrictions on the function passed to `map2`.
 896 | ||| See the tutorial _Nuisances in the Tensor API_ for details.
 897 | export
 898 | map2 :
 899 |   (Primitive a, Primitive b, Primitive c) =>
 900 |   (Tensor [] a -> Tensor [] b -> Tag $ Tensor [] c) ->
 901 |   Tensor shape a -> Tensor shape b -> Tag $ Tensor shape c
 902 | map2 f (MkTensor {shape = _} x) (MkTensor x') =
 903 |   pure $ MkTensor $ V 0 $ Map !(fn2 f) [x, x'] (TensorType shape c) (range $ length shape)
 904 |
 905 | ||| Reduce elements along one `axis` of a `Tensor` according to a specified `reducer` `Monoid`.
 906 | ||| For example, if `x = tensor [[0, 1, 2], [3, 4, 5]]`, then reduce @{Sum} 0 x` produces
 907 | ||| `tensor [3, 5, 7]`, and `reduce @{Sum} 1 x` produces `tensor [3, 12]`.
 908 | |||
 909 | ||| **Note:** `Semigroup` doesn't use `Tag`, which limits the functions that can be used in
 910 | ||| `reduce`. However, the most commonly used semigroups don't need `Tag`, including `Sum`,
 911 | ||| `Prod`, `Min` and `Max`, so for ergonomics, we have opted to use `Monoid` as is. We can
 912 | ||| provide an overloaded variant if requested.
 913 | |||
 914 | ||| **Note:** For the XLA compiler, there are scoping restrictions on the monoid's semigroup.
 915 | ||| See the tutorial _Nuisances in the Tensor API_ for details.
 916 | |||
 917 | ||| @reducer How to reduce elements along the given `axis`.
 918 | ||| @axis The axis along which to reduce elements.
 919 | export
 920 | reduce :
 921 |   (reducer : Monoid (Tensor [] dtype)) =>
 922 |   Primitive dtype =>
 923 |   (axes : List Nat) ->
 924 |   {auto 0 axesUnique : Sorted LT axes} ->
 925 |   {auto 0 axesInBounds : All (flip InBounds shape) axes} ->
 926 |   Tensor shape dtype ->
 927 |   Tag $ Tensor (deleteAt axes shape) dtype
 928 | reduce axes $ MkTensor x = do
 929 |   let semigroup : Monoid a -> Semigroup a
 930 |       semigroup _ = %search
 931 |
 932 |   g <- fn2 (pure .: (<+>) @{semigroup reducer})
 933 |   let MkTensor neutral' = neutral @{reducer}
 934 |   pure $ MkTensor $ V 0 $ Reduce g [neutral'] axes [x]
 935 |
 936 | ||| Sort the elements of a `Tensor` along a specified `dimension` according to a scalar-wise
 937 | ||| ordering. For sorting function `f`, elements are sorted such that for consecutive sorted
 938 | ||| elements `a` and `b`, either `f a b` is true, or `f a b` *and* `f b a` are false.
 939 | |||
 940 | ||| **Note:** Sorting is not stable, meaning elements that compare equal according the ordering may
 941 | ||| be sorted in a different order to the order they appear in the input.
 942 | |||
 943 | ||| **Note:** `sort` is limited to use comparison function without `Tag`. However, since the most
 944 | ||| commonly-used functions, including (>), (<), (>=), and (<=), don't use `Tag`, we have opted to
 945 | ||| omit it for ergonomics. We can trivially provide an overloaded variant if requested.
 946 | |||
 947 | ||| For example, for `x = tensor [[1, 6, 4], [3, 2, 5]]`, `sort (<) 0 x` produces
 948 | ||| `tensor [[1, 2, 4], [3, 6, 5]]`, while `sort (<) 1 x` produces
 949 | ||| `tensor [[1, 4, 6], [2, 3, 5]]`.
 950 | |||
 951 | ||| **Note:** For the XLA compiler, there are scoping restrictions on the function passed to `sort`.
 952 | ||| See the tutorial _Nuisances in the Tensor API_ for details.
 953 | export
 954 | sort :
 955 |   Primitive dtype =>
 956 |   (Tensor [] dtype -> Tensor [] dtype -> Tensor [] PRED) ->
 957 |   (dimension : Nat) ->
 958 |   Tensor shape dtype ->
 959 |   {auto 0 dimInBounds : InBounds dimension shape} ->
 960 |   Tag $ Tensor shape dtype
 961 | sort comp dimension $ MkTensor x =
 962 |   pure $ MkTensor $ V 0 $ Sort !(fn2 $ pure .: comp) dimension False x
 963 |
 964 | ||| Reverse elements along the specified axes. For example, for
 965 | ||| ```
 966 | ||| x : Tensor [2, 3] S32
 967 | ||| x = tensor [[-2, -1,  0],
 968 | |||             [ 1,  2,  3]]
 969 | ||| ```
 970 | ||| `reverse [0] x` is
 971 | ||| ```
 972 | ||| x : Tensor [2, 3] S32
 973 | ||| x = tensor [[ 1,  2,  3],
 974 | |||             [-2, -1,  0]]
 975 | ||| ```
 976 | ||| `reverse [1] x` is
 977 | ||| ```
 978 | ||| x : Tensor [2, 3] S32
 979 | ||| x = tensor [[ 0, -1, -2],
 980 | |||             [ 3,  2,  1]]
 981 | ||| ```
 982 | ||| and `reverse [0, 1] x` is
 983 | ||| ```
 984 | ||| x : Tensor [2, 3] S32
 985 | ||| x = tensor [[ 3,  2,  1],
 986 | |||             [ 0, -1, -2]]
 987 | ||| ```
 988 | |||
 989 | ||| **Note:** This function requires `axes` is ordered simply so that elements are unique.
 990 | ||| The ordering itself is irrelevant to the implementation, but ensures uniqueness without using
 991 | ||| proofs of contradiction that can be difficult for Idris to construct.
 992 | export
 993 | reverse :
 994 |   (axes : List Nat) ->
 995 |   {auto 0 axesUnique : Sorted LT axes} ->
 996 |   {auto 0 axesInBounds : All (flip InBounds shape) axes} ->
 997 |   Tensor shape dtype ->
 998 |   Tensor shape dtype
 999 | reverse axes $ MkTensor x = MkTensor $ V 0 $ Reverse axes x
1000 |
1001 | ewUnary : UnaryOp -> Tensor s a -> Tensor s a'
1002 | ewUnary op $ MkTensor x = MkTensor $ V 0 $ UnaryElementwise op x
1003 |
1004 | ewBinary : BinaryOp -> Tensor s a -> Tensor s a' -> Tensor s a''
1005 | ewBinary op (MkTensor x) (MkTensor x') = MkTensor $ V 0 $ BinaryElementwise op x x'
1006 |
1007 | ||| Element-wise equality. For example, `tensor [1, 2] /= tensor [1, 3]` is
1008 | ||| `tensor [True, False]`.
1009 | export
1010 | (==) : Primitive.Eq dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape PRED
1011 | (==) = ewBinary $ Compare Eq
1012 |
1013 | ||| Element-wise inequality. For example, `tensor [1, 2] /= tensor [1, 3]` is
1014 | ||| `tensor [False, True]`.
1015 | export
1016 | (/=) : Primitive.Eq dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape PRED
1017 | (/=) = ewBinary $ Compare Ne
1018 |
1019 | ||| Element-wise less than. For example, `tensor [1, 2, 3] < tensor [2, 2, 2]` is
1020 | ||| `tensor [True, False, False]`.
1021 | export
1022 | (<) : Primitive.Ord dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape PRED
1023 | (<) = ewBinary $ Compare Lt
1024 |
1025 | ||| Element-wise greater than. For example, `tensor [1, 2, 3] > tensor [2, 2, 2]` is
1026 | ||| `tensor [False, False, True]`.
1027 | export
1028 | (>) : Primitive.Ord dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape PRED
1029 | (>) = ewBinary $ Compare Gt
1030 |
1031 | ||| Element-wise less than or equal. For example, `tensor [1, 2, 3] <= tensor [2, 2, 2]`
1032 | ||| is `tensor [True, True, False]`.
1033 | export
1034 | (<=) : Primitive.Ord dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape PRED
1035 | (<=) = ewBinary $ Compare Le
1036 |
1037 | ||| Element-wise greater than or equal. For example,
1038 | ||| `tensor [1, 2, 3] >= tensor [2, 2, 2]` is `tensor [False, True, True]`.
1039 | export
1040 | (>=) : Primitive.Ord dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape PRED
1041 | (>=) = ewBinary $ Compare Ge
1042 |
1043 | ||| Element-wise boolean and. For example,
1044 | ||| `tensor [True, True, False, False] && tensor [True, False, True, False]` is
1045 | ||| `tensor [True, False, False, False]`.
1046 | export
1047 | (&&) : Tensor shape PRED -> Tensor shape PRED -> Tensor shape PRED
1048 | (&&) = ewBinary And
1049 |
1050 | namespace Semigroup
1051 |   export
1052 |   [All] Semigroup (Tensor shape PRED) where
1053 |     (<+>) = (&&)
1054 |
1055 | namespace Monoid
1056 |   export
1057 |   [All] {shape : _} -> Monoid (Tensor shape PRED) using Tensor.Semigroup.All where
1058 |     neutral = fill True
1059 |
1060 | ||| Element-wise boolean or. For example,
1061 | ||| `tensor [True, True, False, False] || tensor [True, False, True, False]` is
1062 | ||| `tensor [True, True, True, False]`.
1063 | export
1064 | (||) : Tensor shape PRED -> Tensor shape PRED -> Tensor shape PRED
1065 | (||) = ewBinary Or
1066 |
1067 | namespace Semigroup
1068 |   export
1069 |   [Any] Semigroup (Tensor shape PRED) where
1070 |     (<+>) = (||)
1071 |
1072 | namespace Monoid
1073 |   export
1074 |   [Any] {shape : _} -> Monoid (Tensor shape PRED) using Tensor.Semigroup.Any where
1075 |     neutral = fill False
1076 |
1077 | ||| Element-wise boolean negation. For example, `not (tensor [True, False])` is
1078 | ||| `tensor [False, True]`.
1079 | export
1080 | not : Tensor shape PRED -> Tensor shape PRED
1081 | not = ewUnary Not
1082 |
1083 | ||| Choose elements from two `Tensor`s based on a `Tensor` of predicates. For each element in the
1084 | ||| predicates, the output will use the corresponding element from `onTrue` if the element is
1085 | ||| truthy, else the element from `onFalse`. For example, for
1086 | ||| ```
1087 | ||| preds : Tensor [3] PRED
1088 | ||| preds = tensor [False, True, False]
1089 | |||
1090 | ||| onTrue : Tensor [3] S32
1091 | ||| onTrue = tensor [1, 2, 3]
1092 | |||
1093 | ||| onFalse : Tensor [3] S32
1094 | ||| onFalse = tensor [4, 5, 6]
1095 | ||| ```
1096 | ||| `select preds onTrue onFalse` is `tensor [4, 2, 6]`.
1097 | |||
1098 | ||| @onTrue The elements to choose where the predicate elements are truthy.
1099 | ||| @onFalse The elements to choose where the predicate elements are falsy.
1100 | export
1101 | select :
1102 |   Primitive dtype =>
1103 |   Tensor shape PRED ->
1104 |   (onTrue, onFalse : Tensor shape dtype) ->
1105 |   Tensor shape dtype
1106 | select (MkTensor p) (MkTensor t) (MkTensor f) = MkTensor $ V 0 $ Select p t f
1107 |
1108 | ||| Use a scalar predicate to evaluate one of two branches. If the predicate is truthy,
1109 | ||| evaluate `onTrue`, else `onFalse`. Each branch is evaluated lazily; only one will be
1110 | ||| evaluated.
1111 | |||
1112 | ||| For example, for
1113 | ||| ```
1114 | ||| f : Tensor [] F64 -> Tag $ Tensor [] F64
1115 | ||| f x = if_ (x < 1.0) (pure $ cos x) (do x <- tag x; x * x)
1116 | ||| ```
1117 | ||| `f 0.0` produces `1.0`, and `f 4.0` produces `8.0`.
1118 | |||
1119 | ||| **Note:** For the XLA compiler, there are scoping restrictions on branches used in `if_`, as
1120 | ||| StableHLO gives branches their own scope. See the tutorial _Nuisances in the Tensor API_
1121 | ||| for details.
1122 | |||
1123 | ||| @onTrue The branch to evaluate if the predicate is truthy.
1124 | ||| @onFalse The branch to evaluate if the predicate is falsy.
1125 | export
1126 | if_ :
1127 |   Primitive dtype =>
1128 |   {shape : _} ->
1129 |   Tensor [] PRED ->
1130 |   (onTrue, onFalse : Tag $ Tensor shape dtype) ->
1131 |   Tag $ Tensor shape dtype
1132 | if_ (MkTensor pred) onTrue onFalse =
1133 |   pure $ MkTensor $ V 0 $ If (TensorType shape dtype) pred !(fn0 onTrue) !(fn0 onFalse)
1134 |
1135 | ||| The identity tensor, with inferred shape and element type. For example,
1136 | ||| ```
1137 | ||| x : Tensor [2, 2] S32
1138 | ||| x = identity
1139 | ||| ```
1140 | ||| is
1141 | ||| ```
1142 | ||| x : Tensor [2, 2] S32
1143 | ||| x = tensor [[1, 0],
1144 | |||             [0, 1]]
1145 | ||| ```
1146 | export
1147 | identity : Primitive.Num dtype => {n : _} -> Tensor [n, n] dtype
1148 | identity =
1149 |   let MkTensor x = iota 0 {shape = [n, n], dtype = U64} == iota 1
1150 |    in MkTensor $ V 0 $ Convert {dtype} [n, n] x
1151 |
1152 | -- see https://www.python.org/dev/peps/pep-0465/#precedence-and-associativity
1153 | export infixl 9 @@
1154 |
1155 | namespace Vector
1156 |   ||| Vector dot product with a tensor of any rank. The vector dot product is with the first axis of
1157 |   ||| the right-hand side tensor. For example `tensor [0, 1, 2] @@ tensor [-1, -3, -1]` is
1158 |   ||| `-1`.
1159 |   export
1160 |   (@@) : Primitive.Num dtype => Tensor [S m] dtype -> Tensor [S m] dtype -> Tensor [] dtype
1161 |   (MkTensor x) @@ (MkTensor x') =
1162 |     MkTensor $ V 0 $ DotGeneral [] [] [0] [0] (TensorType [] dtype) x x'
1163 |
1164 | namespace Matrix
1165 |   ||| Matrix multiplication with a matrix or vector. Contraction is along the last axis of the first
1166 |   ||| and the first axis of the last. For example,
1167 |   ||| ```
1168 |   ||| x : Tensor [2, 3] S32
1169 |   ||| x = tensor [[-1, -2, -3],
1170 |   |||             [ 0,  1,  2]]
1171 |   |||
1172 |   ||| y : Tensor [3, 1] S32
1173 |   ||| y = tensor [[4, 0, 5]]
1174 |   |||
1175 |   ||| z : Tensor [2, 1] S32
1176 |   ||| z = x @@ y
1177 |   ||| ```
1178 |   ||| is
1179 |   ||| ```
1180 |   ||| z : Tensor [2, 1] S32
1181 |   ||| z = tensor [-19, 10]
1182 |   ||| ```
1183 |   export
1184 |   (@@) : (Primitive dtype, Primitive.Num dtype) =>
1185 |          Tensor [n, S m] dtype ->
1186 |          Tensor (S m :: tl) dtype ->
1187 |          {auto 0 vectorTail : length tl `LTE` 1} ->
1188 |          Tensor (n :: tl) dtype
1189 |   (MkTensor x) @@ (MkTensor x') =
1190 |     MkTensor $ V 0 $ DotGeneral [] [] [1] [0] (TensorType (n :: tl) dtype) x x'
1191 |
1192 | ||| The output shape of a `dotGeneral` operation.
1193 | public export
1194 | contract : (lBatch, rBatch, lContract, rContract : List Nat) ->
1195 |            (ls, rs : Shape) ->
1196 |            {auto 0 lInBoundsBatch : All (flip InBounds ls) lBatch} ->
1197 |            {auto 0 rInBoundsBatch : All (flip InBounds rs) rBatch} ->
1198 |            {auto 0 lInBoundsContract : All (flip InBounds ls) lContract} ->
1199 |            {auto 0 rInBoundsContract : All (flip InBounds rs) rContract} ->
1200 |            Shape
1201 | contract lBatch rBatch lContract rContract ls rs =
1202 |   let lResultDims = deleteAt {inBounds = lInBoundsBatch ++ lInBoundsContract}
1203 |                              (lBatch ++ lContract) ls
1204 |       rResultDims = deleteAt {inBounds = rInBoundsBatch ++ rInBoundsContract}
1205 |                              (rBatch ++ rContract) rs
1206 |    in multiIndex lBatch ls ++ lResultDims ++ rResultDims
1207 |
1208 | ||| Matrix multiplication.
1209 | |||
1210 | ||| This is a much more general version of `(@@)`, in which you can specify any number of batch
1211 | ||| and contracting axes. Matrix multiplication is done over each contracting axis.
1212 | ||| The operation is vectorized over batch axes. For each contracting axis on the left-hand
1213 | ||| operand, there is one contracting axis on the right-hand operand. These can be different axes
1214 | ||| in each operand. The same is true for each batch axis.
1215 | |||
1216 | ||| For example, we can vectorize over a typical rank-two matrix multiplication as follows: given
1217 | ||| two inputs tensors
1218 | ||| ```
1219 | ||| let x : Tensor [3, 4, 5, 6] F64
1220 | |||     y : Tensor [3, 4, 6, 7] F64
1221 | ||| ```
1222 | ||| we do
1223 | ||| ```
1224 | ||| let z : Tensor [3, 4, 5, 7] F64 = dotGeneral [0, 1] [0, 1] [3] [2] x y
1225 | ||| ```
1226 | ||| Here, we vectorized over the first two axes `[0, 1]`, and do standard matrix multiplication
1227 | ||| over the remaining axes by specifying the axes 3 and 2 respectively as contracting axes. Notice
1228 | ||| how the batch axes appear once each at the start of the output shape, and the contracting axis
1229 | ||| disappears. Remaining axes appear in order from left to right.
1230 | |||
1231 | ||| Note this API is somewhat of a quickfix to bring general matrix multiplication to the tensor
1232 | |||   API. It is not thoroughly tested. Expect it to change in the future.
1233 | export
1234 | dotGeneral :
1235 |   Primitive.Num dtype =>
1236 |   (lBatch, rBatch, lContract, rContract : List Nat) ->
1237 |   {auto 0 lUnique : unique (lBatch ++ lContract) = True} ->
1238 |   {auto 0 rUnique : unique (rBatch ++ rContract) = True} ->
1239 |   {auto 0 lInBoundsBatch : All (flip InBounds ls) lBatch} ->
1240 |   {auto 0 rInBoundsBatch : All (flip InBounds rs) rBatch} ->
1241 |   {auto 0 lInBoundsContract : All (flip InBounds ls) lContract} ->
1242 |   {auto 0 rInBoundsContract : All (flip InBounds rs) rContract} ->
1243 |   {auto 0 batchDimsEq : multiIndex lBatch ls = multiIndex rBatch rs} ->
1244 |   {auto 0 contractDimsEq : multiIndex lContract ls = multiIndex rContract rs} ->
1245 |   Tensor ls dtype ->
1246 |   Tensor rs dtype ->
1247 |   Tensor (contract lBatch rBatch lContract rContract ls rs) dtype
1248 | dotGeneral lb rb lc rc (MkTensor x) (MkTensor y) =
1249 |   let resultType = TensorType (contract lb rb lc rc ls rs) dtype
1250 |    in MkTensor $ V 0 $ DotGeneral lb rb lc rc resultType x y
1251 |
1252 | ||| Element-wise addition. For example, `tensor [1, 2] + tensor [3, 4]` is
1253 | ||| `tensor [4, 6]`.
1254 | export
1255 | (+) : Primitive.Num dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape dtype
1256 | (+) = ewBinary Add
1257 |
1258 | namespace Semigroup
1259 |   export
1260 |   [Sum] Primitive.Num dtype => Semigroup (Tensor shape dtype) where
1261 |     (<+>) = (+)
1262 |
1263 | namespace Monoid
1264 |   export
1265 |   [Sum] {shape : _} ->
1266 |         Prelude.Num a =>
1267 |         PrimitiveRW dtype a =>
1268 |         Primitive.Num dtype =>
1269 |     Monoid (Tensor shape dtype) using Semigroup.Sum where
1270 |       neutral = fill 0
1271 |
1272 | ||| Element-wise negation. For example, `- tensor [1, -2]` is `tensor [-1, 2]`.
1273 | export
1274 | negate : Primitive.Neg dtype => Tensor shape dtype -> Tensor shape dtype
1275 | negate $ MkTensor i = MkTensor $ V 0 $ UnaryElementwise Neg i
1276 |
1277 | ||| Element-wise subtraction. For example, `tensor [3, 4] - tensor [4, 2]` is
1278 | ||| `tensor [-1, 2]`.
1279 | export
1280 | (-) : Primitive.Neg dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape dtype
1281 | (-) = ewBinary Sub
1282 |
1283 | ||| Element-wise multiplication. For example, `tensor [2, 3] * tensor [4, 5]` is
1284 | ||| `tensor [8, 15]`.
1285 | export
1286 | (*) : Primitive.Num dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape dtype
1287 | (*) = ewBinary Mul
1288 |
1289 | namespace Scalarwise
1290 |   ||| Multiplication by a scalar. For example, `tensor 2 * tensor [3, 5]` is
1291 |   ||| `tensor [6, 10]`.
1292 |   |||
1293 |   ||| The RHS is required to be non-scalar simply to avoid ambiguities with element-wise `(*)`.
1294 |   export
1295 |   (*) : Primitive.Num dtype => Tensor [] dtype -> Tensor (d :: ds) dtype -> Tensor (d :: ds) dtype
1296 |   l * r =
1297 |     let MkTensor {shape = _ :: _} _ = r
1298 |      in broadcast {shapesOK = scalarToAnyOk (d :: ds)} l * r
1299 |
1300 | namespace Semigroup
1301 |   export
1302 |   [Prod] Primitive.Num dtype => Semigroup (Tensor shape dtype) where
1303 |     (<+>) = (*)
1304 |
1305 | namespace Monoid
1306 |   export
1307 |   [Prod] {shape : _} ->
1308 |          Prelude.Num a =>
1309 |          PrimitiveRW dtype a =>
1310 |          Primitive.Num dtype =>
1311 |     Monoid (Tensor shape dtype) using Semigroup.Prod where
1312 |       neutral = fill 1
1313 |
1314 | ||| Element-wise floating point division. For example, `tensor [2, 3] / tensor [4, 5]` is
1315 | ||| `tensor [0.5, 0.6]`.
1316 | export
1317 | (/) : Primitive.Fractional dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape dtype
1318 | (/) = ewBinary Div
1319 |
1320 | namespace Scalarwise
1321 |   ||| Floating point division by a scalar. For example, `tensor [3.4, -5.6] / tensor 2` is
1322 |   ||| `tensor [1.7, -2.8]`.
1323 |   |||
1324 |   ||| The LHS is required to be non-scalar simply to avoid ambiguities with element-wise `(/)`.
1325 |   export
1326 |   (/) : Primitive.Fractional dtype =>
1327 |         Tensor (d :: ds) dtype ->
1328 |         Tensor [] dtype ->
1329 |         Tensor (d :: ds) dtype
1330 |   l / r =
1331 |     let MkTensor {shape = _ :: _} _ = l
1332 |      in l / broadcast {shapesOK = scalarToAnyOk (d :: ds)} r
1333 |
1334 | inf = tensor 1.0 / tensor 0.0
1335 | nan = tensor 0.0 / tensor 0.0
1336 |
1337 | ||| Element-wise division of natural numbers. For example,
1338 | ||| `div (tensor [Scalar 13, Scalar 8]) [3, 4]` is `tensor [4, 2]`.
1339 | export
1340 | div : Tensor shape U64 ->
1341 |       (denom : Literal shape Nat) ->
1342 |       {auto 0 isSucc : All IsSucc denom} ->
1343 |       Tensor shape U64
1344 | div x y with (x)
1345 |   _ | (MkTensor {shape = _} _) = ewBinary Div x (tensor {dtype = U64} y)
1346 |
1347 | namespace Integral
1348 |   ||| Element-wise remainder for natural numbers. For example,
1349 |   ||| `rem (tensor [Scalar 13, Scalar 8]) [3, 4]` is `tensor [1, 0]`.
1350 |   export
1351 |   rem : Tensor shape U64 ->
1352 |         (denom : Literal shape Nat) ->
1353 |         {auto 0 isSucc : All IsSucc denom} ->
1354 |         Tensor shape U64
1355 |   rem x y with (x)
1356 |     _ | (MkTensor {shape = _} _) = ewBinary Rem x (tensor {dtype = U64} y)
1357 |
1358 | export infixr 9 ^
1359 |
1360 | ||| Each element in `base` raised to the power of the corresponding element in `exponent`.
1361 | ||| example, `tensor [2, 25, -9] ^ tensor [3, -0.5, 0.5]` is `tensor [8, 0.2, nan]`.
1362 | |||
1363 | ||| Note: The behaviour of this function is not well-defined at negative or positive infinity, or
1364 | |||   NaN.
1365 | |||
1366 | ||| Note: The first root is used.
1367 | export
1368 | (^) : Tensor shape F64 -> Tensor shape F64 -> Tensor shape F64
1369 | (^) = ewBinary Pow
1370 |
1371 | (>>) : Tensor shape U64 -> Tensor shape U64 -> Tensor shape U64
1372 | (>>) = ewBinary ShiftRightLogical
1373 |
1374 | ||| Element-wise absolute value. For example, `abs (tensor [-2, 3])` is `tensor [2, 3]`.
1375 | export
1376 | abs : Primitive.Abs dtype => Tensor shape dtype -> Tensor shape dtype
1377 | abs = ewUnary Abs
1378 |
1379 | ||| The element-wise natural exponential. For example, `exp (tensor [-1, 0, 2])` is
1380 | ||| `tensor [1 / euler, 1, pow euler 2]`.
1381 | export
1382 | exp : Tensor shape F64 -> Tensor shape F64
1383 | exp = ewUnary Exp
1384 |
1385 | ||| The element-wise floor function. For example,
1386 | ||| `floor (tensor [-1.6, -1.5, -1.4, -1.0, 1.0, 1.4, 1.5, 1.6])` is
1387 | ||| `tensor [-2.0, -2.0, -2.0, -1.0, 1.0, 1.0, 1.0, 1.0]`.
1388 | export
1389 | floor : Tensor shape F64 -> Tensor shape F64
1390 | floor = ewUnary Floor
1391 |
1392 | ||| The element-wise ceiling function. For example,
1393 | ||| `ceil (tensor [-1.6, -1.5, -1.4, -1.0, 1.0, 1.4, 1.5, 1.6])` is
1394 | ||| `tensor [-1.0, -1.0, -1.0, -1.0, 1.0, 2.0, 2.0, 2.0]`.
1395 | export
1396 | ceil : Tensor shape F64 -> Tensor shape F64
1397 | ceil = ewUnary Ceil
1398 |
1399 | ||| The element-wise natural logarithm. Negative inputs yield NaN output. For example,
1400 | ||| `log (tensor [1 / euler, 1, euler * euler])` is `tensor [-1, 0, 2]`.
1401 | export
1402 | log : Tensor shape F64 -> Tensor shape F64
1403 | log = ewUnary Log
1404 |
1405 | ||| The element-wise logistic function equivalent to `1 / 1 + exp (-x)`.
1406 | export
1407 | logistic : Tensor shape F64 -> Tensor shape F64
1408 | logistic = ewUnary Logistic
1409 |
1410 | ||| The element-wise sine.
1411 | export
1412 | sin : Tensor shape F64 -> Tensor shape F64
1413 | sin = ewUnary Sin
1414 |
1415 | ||| The element-wise cosine.
1416 | export
1417 | cos : Tensor shape F64 -> Tensor shape F64
1418 | cos = ewUnary Cos
1419 |
1420 | ||| The element-wise tangent.
1421 | export
1422 | tan : Tensor shape F64 -> Tensor shape F64
1423 | tan = ewUnary Tan
1424 |
1425 | ||| The element-wise inverse sine.
1426 | export
1427 | asin : Tensor shape F64 -> Tensor shape F64
1428 | asin = ewUnary Asin
1429 |
1430 | ||| The element-wise inverse cosine.
1431 | export
1432 | acos : Tensor shape F64 -> Tensor shape F64
1433 | acos = ewUnary Acos
1434 |
1435 | ||| The element-wise inverse tangent.
1436 | export
1437 | atan : Tensor shape F64 -> Tensor shape F64
1438 | atan = ewUnary Atan
1439 |
1440 | ||| The element-wise hyperbolic sine.
1441 | export
1442 | sinh : Tensor shape F64 -> Tensor shape F64
1443 | sinh = ewUnary Sinh
1444 |
1445 | ||| The element-wise hyperbolic cosine.
1446 | export
1447 | cosh : Tensor shape F64 -> Tensor shape F64
1448 | cosh = ewUnary Cosh
1449 |
1450 | ||| The element-wise hyperbolic tangent.
1451 | export
1452 | tanh : Tensor shape F64 -> Tensor shape F64
1453 | tanh = ewUnary Tanh
1454 |
1455 | ||| The element-wise inverse hyperbolic sine.
1456 | export
1457 | asinh : Tensor shape F64 -> Tensor shape F64
1458 | asinh = ewUnary Asinh
1459 |
1460 | ||| The element-wise inverse hyperbolic cosine.
1461 | export
1462 | acosh : Tensor shape F64 -> Tensor shape F64
1463 | acosh = ewUnary Acosh
1464 |
1465 | ||| The element-wise inverse hyperbolic tangent.
1466 | export
1467 | atanh : Tensor shape F64 -> Tensor shape F64
1468 | atanh = ewUnary Atanh
1469 |
1470 | ||| An approximation to the element-wise error function.
1471 | export
1472 | erf : Tensor shape F64 -> Tensor shape F64
1473 | erf = ewUnary Erf
1474 |
1475 | erfInv : Tensor shape F64 -> Tensor shape F64
1476 | erfInv = ewUnary ErfInv
1477 |
1478 | ||| The element-wise square. For example, `square (tensor [-2, 0, 3])`
1479 | ||| is `tensor [4, 0, 9]`.
1480 | export
1481 | square : Tensor shape F64 -> Tensor shape F64
1482 | square = ewUnary Square
1483 |
1484 | ||| The element-wise square root. The first root is used. Negative inputs yield NaN output.
1485 | ||| For example, `sqrt (tensor [0, 9])` is `tensor [0, 3]`.
1486 | export
1487 | sqrt : Tensor shape F64 -> Tensor shape F64
1488 | sqrt = ewUnary Sqrt
1489 |
1490 | ||| The element-wise minimum of the first argument compared to the second. For example,
1491 | ||| `min (tensor [-3, -1, 3]) (tensor [-1, 0, 1])` is `tensor [-3, -1, 1]`.
1492 | export
1493 | min : Primitive.Ord dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape dtype
1494 | min (MkTensor x) (MkTensor x') = MkTensor $ V 0 $ BinaryElementwise Min x x'
1495 |
1496 | namespace Semigroup
1497 |   export
1498 |   [Min] {shape : _} -> Primitive.Ord dtype => Semigroup (Tensor shape dtype) where
1499 |     (<+>) = min
1500 |
1501 | namespace Monoid
1502 |   export
1503 |   [Min] {shape : _} -> Primitive.Ord dtype => Monoid (Tensor shape dtype) using Semigroup.Min where
1504 |     neutral = broadcast max
1505 |
1506 | ||| The element-wise maximum of the first argument compared to the second. For example,
1507 | ||| `max (tensor [-3, -1, 3]) (tensor [-1, 0, 1])` is `tensor [-1, 0, 3]`.
1508 | export
1509 | max : Primitive.Ord dtype => Tensor shape dtype -> Tensor shape dtype -> Tensor shape dtype
1510 | max (MkTensor x) (MkTensor x') = MkTensor $ V 0 $ BinaryElementwise Max x x'
1511 |
1512 | namespace Semigroup
1513 |   export
1514 |   [Max] Primitive.Ord dtype => Semigroup (Tensor shape dtype) where
1515 |     (<+>) = max
1516 |
1517 | namespace Monoid
1518 |   export
1519 |   [Max] {shape : _} -> Primitive.Ord dtype => Monoid (Tensor shape dtype) using Semigroup.Max where
1520 |     neutral = broadcast min
1521 |
1522 | ||| The diagonal of a matrix as a vector. For example, for
1523 | ||| ```
1524 | ||| x : Tensor [3, 3] S32
1525 | ||| x = tensor [[0, 1, 2],
1526 | |||             [3, 4, 5],
1527 | |||             [6, 7, 8]]
1528 | ||| ```
1529 | ||| `diag x` is `tensor [0, 4, 8]`.
1530 | export
1531 | diag :
1532 |   Primitive.Num dtype =>
1533 |   PrimitiveRW dtype ty =>
1534 |   Prelude.Num ty =>
1535 |   Tensor [n, n] dtype ->
1536 |   Tensor [n] dtype
1537 | diag {n = 0} x@(MkTensor {shape = [0, 0]} _) = reshape x
1538 | diag {n = S n} x@(MkTensor {shape = [S n, S n]} _) = (x * identity) @@ fill 1
1539 |
1540 | argmxx :
1541 |   Primitive.Ord dtype =>
1542 |   (Tensor [] dtype -> Tensor [] dtype -> Tensor [] PRED) ->
1543 |   Tensor [] dtype ->
1544 |   Tensor [S n] dtype ->
1545 |   Tag $ Tensor [] U64
1546 | argmxx cmp (MkTensor bound) x@(MkTensor {shape = _} _) = do
1547 |   let MkTensor idxs : Tensor [S n] U64 = iota 0
1548 |       MkTensor x = x
1549 |       MkTensor zero = tensor {dtype = U64} $ Scalar 0
1550 |
1551 |       mon :
1552 |         Tensor [] dtype -> Tensor [] U64 ->
1553 |         Tensor [] dtype -> Tensor [] U64 ->
1554 |         Tag (Tensor [] dtype, Tensor [] U64)
1555 |       mon x y x' y' =
1556 |         let useNext = x == x && (cmp x' x || x' /= x')
1557 |          in pure (select useNext x' x, select useNext y' y)
1558 |
1559 |   MkTagT $ do
1560 |     addr <- reserve
1561 |
1562 |     let MkTagT res = mon
1563 |           (MkTensor $ V 0 $ BoundSet addr)
1564 |           (MkTensor $ V 1 $ BoundSet addr)
1565 |           (MkTensor $ V 2 $ BoundSet addr)
1566 |           (MkTensor $ V 3 $ BoundSet addr)
1567 |         (env, (MkTensor m, MkTensor i)) = runState (emptyFrom !get) res
1568 |         argTys = [TensorType [] dtype, TensorType [] U64]
1569 |         f = MkFn addr (argTys ++ argTys) argTys [m, i] env
1570 |
1571 |     updateCounterFrom env
1572 |     pure $ MkTensor $ V 1 $ Reduce f [bound, zero] [0] [x, idxs]
1573 |
1574 | ||| The first index of the maximum value in a vector. For example,
1575 | ||| `argmax (tensor [-1, 3, -2, -2, 3])` produces `tensor 1`. If the vector contains NaN values,
1576 | ||| `argmax` returns the index of the first NaN.
1577 | export
1578 | argmax : Primitive.Ord dtype => Tensor [S n] dtype -> Tag $ Tensor [] U64
1579 | argmax = argmxx (>) min
1580 |
1581 | ||| The first index of the minimum value in a vector. For example,
1582 | ||| `argmin (tensor [-1, 3, -2, -2, 3])` produces `tensor 2`. If the vector contains NaN values,
1583 | ||| `argmin` returns the index of the first NaN.
1584 | export
1585 | argmin : Primitive.Ord dtype => Tensor [S n] dtype -> Tag $ Tensor [] U64
1586 | argmin = argmxx (<) max
1587 |
1588 | ||| Represents the upper- or lower-triangular component of a matrix.
1589 | public export
1590 | data Triangle = Upper | Lower
1591 |
1592 | ||| Get the upper- or lower-triangular component of a matrix, always including the matrix diagonal.
1593 | ||| Remaining elements will be zero. For example, for
1594 | ||| ```
1595 | ||| x : Tensor [3, 3] S32
1596 | ||| x = tensor [[1, 2, 3],
1597 | |||             [4, 5, 6],
1598 | |||             [7, 8, 9]]
1599 | ||| ```
1600 | ||| `triangle Lower x` produces
1601 | ||| ```
1602 | ||| x : Tensor [3, 3] S32
1603 | ||| x = tensor [[1, 0, 0],
1604 | |||             [4, 5, 0],
1605 | |||             [7, 8, 9]]
1606 | ||| ```
1607 | export
1608 | triangle :
1609 |   PrimitiveRW dtype ty =>
1610 |   Prelude.Num ty =>
1611 |   Triangle ->
1612 |   Tensor [n, n] dtype ->
1613 |   Tag $ Tensor [n, n] dtype
1614 | triangle tri (MkTensor x) = do
1615 |   let range : Tensor [n * n] U64 = iota 0
1616 |   indices <- tag $ reshape {to = [n, n], sizesEqual = productSquare n} range
1617 |   let op = case tri of
1618 |         Upper => Tensor.(>)
1619 |         Lower => Tensor.(<)
1620 |   pure $ select (op indices indices.T) (fill $ fromInteger 0) (MkTensor x)
1621 |
1622 |   where
1623 |
1624 |   productSquare : (m : Nat) -> product (the Shape [m * m]) = product (the Shape [m, m])
1625 |   productSquare m =
1626 |     Calc $
1627 |       |~ (m * m) + 0
1628 |       ~~ m * m ... plusZeroRightNeutral (m * m)
1629 |       ~~ (m + 0) * m ..< cong (* m) (plusZeroRightNeutral m)
1630 |
1631 | ||| Cholesky decomposition. Computes the lower triangular matrix `L` from the symmetric, positive
1632 | ||| semi-definite matrix `X` s.t. `X = L @@ L.T`. Values will be NaN if the input matrix is not
1633 | ||| positive semi-definite. The remaining matrix components - those not in the lower triangle or
1634 | ||| diagonal - will always be zero.
1635 | export
1636 | cholesky : Tensor [S n, S n] F64 -> Tag $ Tensor [S n, S n] F64
1637 | cholesky $ MkTensor x = triangle Lower (MkTensor $ V 0 $ Cholesky x)
1638 |
1639 | export infix 9 |\, \|
1640 |
1641 | namespace Matrix
1642 |   ||| Solve the set of linear equations `a @@ x = b` for `x` where `a` is a lower-triangular matrix.
1643 |   ||| `a` is given by the lower-triangular elements of the first argument. Values in the
1644 |   ||| upper-triangular part are ignored. If `a` is lower-triangular already,
1645 |   ||| this is written `a |\ b`.
1646 |   |||
1647 |   ||| The operator is shaped like the lower-triangular portion of a matrix to signal that it uses
1648 |   ||| this portion of its argument. This is in contrast to `(\|)`.
1649 |   export
1650 |   (|\) : Tensor [m, m] F64 -> Tensor [m, n] F64 -> Tensor [m, n] F64
1651 |   (MkTensor a) |\ (MkTensor b) = MkTensor $ V 0 $ TriangularSolve a b True
1652 |
1653 |   ||| Solve the set of linear equations `a @@ x = b` for `x` where `a` is an upper-triangular
1654 |   ||| matrix. `a` is given by the upper-triangular elements of the first argument. Values in the
1655 |   ||| lower-triangular part are ignored. If `a` is upper-triangular already, this is written
1656 |   ||| `a \| b`.
1657 |   |||
1658 |   ||| The operator is shaped like the upper-triangular portion of a matrix to signal that it uses
1659 |   ||| this portion of its argument. This is in contrast to `(|\)`.
1660 |   export
1661 |   (\|) : Tensor [m, m] F64 -> Tensor [m, n] F64 -> Tensor [m, n] F64
1662 |   (MkTensor a) \| (MkTensor b) = MkTensor $ V 0 $ TriangularSolve a b False
1663 |
1664 | namespace Vector
1665 |   ||| Solve the set of linear equations `a @@ x = b` for `x` where `a` is a lower-triangular matrix.
1666 |   ||| `a` is given by the lower-triangular elements of the first argument. Values in the
1667 |   ||| upper-triangular part are ignored. If `a` is lower-triangular already,
1668 |   ||| this is written `a |\ b`.
1669 |   |||
1670 |   ||| The operator is shaped like the lower-triangular portion of a matrix to signal that it uses
1671 |   ||| this portion of its argument. This is in contrast to `(\|)`.
1672 |   export
1673 |   (|\) : Tensor [m, m] F64 -> Tensor [m] F64 -> Tensor [m] F64
1674 |   a |\ b = let (MkTensor {shape = [_]} _) = b in squeeze (a |\ expand 1 b)
1675 |
1676 |   ||| Solve the set of linear equations `a @@ x = b` for `x` where `a` is an upper-triangular
1677 |   ||| matrix. `a` is given by the upper-triangular elements of the first argument. Values in the
1678 |   ||| lower-triangular part are ignored. If `a` is upper-triangular already, this is written
1679 |   ||| `a \| b`.
1680 |   |||
1681 |   ||| The operator is shaped like the upper-triangular portion of a matrix to signal that it uses
1682 |   ||| this portion of its argument. This is in contrast to `(|\)`.
1683 |   export
1684 |   (\|) : Tensor [m, m] F64 -> Tensor [m] F64 -> Tensor [m] F64
1685 |   a \| b = let (MkTensor {shape = [_]} _) = b in squeeze (a \| expand 1 b)
1686 |
1687 | ||| Sum the elements along the diagonal of the input. For example,
1688 | ||| `trace (tensor [[-1, 5], [1, 4]])` produces `3`.
1689 | export
1690 | trace : (Primitive.Num dtype, Prelude.Num a) =>
1691 |         PrimitiveRW dtype a =>
1692 |         Tensor [S n, S n] dtype ->
1693 |         Tag $ Tensor [] dtype
1694 | trace x with (x)
1695 |   _ | MkTensor {shape = [_, _]} _ = reduce @{Sum} [0, 1] $ x * identity
1696 |
1697 | ||| A `Rand a` produces a pseudo-random value of type `a` from a `Tensor [2] U64` state.
1698 | ||| The state is updated every time a new value is generated.
1699 | public export 0
1700 | Rand : Type -> Type
1701 | Rand = StateT (Tensor [2] U64) Tag
1702 |
1703 | ||| Generate independent and identically distributed (IID) uniform samples.
1704 | |||
1705 | ||| The generated samples are a deterministic function of the input key and state, but may vary
1706 | ||| between PJRT plugin and library version.
1707 | |||
1708 | ||| Example usage, multiplying two uniform samples
1709 | ||| ```
1710 | ||| x : Tag $ Tensor [3] U64
1711 | ||| x = let seed = tensor [1, 1] in evalStateT seed [| rng * rng |]
1712 | ||| ```
1713 | export
1714 | rng : {shape : _} -> Rand $ Tensor shape U64
1715 | rng = ST $ \(MkTensor state) => do
1716 |   res <- tag $ Rng state (TensorType shape U64)
1717 |   pure (MkTensor $ V 0 res, MkTensor $ V 1 res)
1718 |
1719 | ||| Generate independent and identically distributed (IID) from the uniform distribution U(0, 1).
1720 | |||
1721 | ||| The generated samples are a deterministic function of the input key and state, but may vary
1722 | ||| between PJRT plugin and library version.
1723 | |||
1724 | ||| Example usage, multiplying two uniform samples
1725 | ||| ```
1726 | ||| x : Rand $ Tensor [3] F64
1727 | ||| x = [| uniform * uniform |]
1728 | ||| ```
1729 | export
1730 | uniform : {shape : _} -> Rand $ Tensor shape F64
1731 | uniform =
1732 |   let numMantissaBits = 52
1733 |       scale = broadcast $ 2.0 ^ tensor (Scalar $ - cast {to = Double} numMantissaBits)
1734 |       shift = fill $ 64 `minus` numMantissaBits
1735 |    in rng {shape} <&> \x => castDtype (x >> shift) * scale
1736 |
1737 | ||| Generate independent and identically distributed (IID) samples from the standard normal
1738 | ||| distribution N(0, 1).
1739 | |||
1740 | ||| The generated samples are a deterministic function of the input key and state, but may vary
1741 | ||| between PJRT plugin and library version.
1742 | |||
1743 | ||| Example usage, multiplying two normal samples
1744 | ||| ```
1745 | ||| x : Rand $ Tensor [3] F64
1746 | ||| x = [| normal * normal |]
1747 | ||| ```
1748 | export
1749 | normal : {shape : _} -> Rand $ Tensor shape F64
1750 | normal = uniform <&> \x => sqrt (broadcast 2.0) * erfInv (broadcast 2.0 * x - broadcast 1.0)
1751 |