0 | module NN.Architectures.Utils
 1 |
 2 | import Data.Para
 3 | import Data.Tensor
 4 |
 5 | ||| Batching only works simply when we have a non-dependent Para
 6 | public export
 7 | paraMapFirstAxis : {c : Axis} ->
 8 |   {cs : TensorShape rank} -> {ds : TensorShape rank'} ->
 9 |   NewAxisConsistent c cs => NewAxisConsistent c ds =>
10 |   Num a =>
11 |   (pf : Tensor cs a -\-> Tensor ds a) ->
12 |   (nonDep : IsNotDependent pf) =>
13 |   Tensor (c :: cs) a -\-> Tensor (c :: ds) a
14 | paraMapFirstAxis (MkPara (const pType) f) {nonDep = MkNonDep pType f} = MkPara
15 |   (\_ => pType) (\(t ** p=> flip (curry f) p <-$> t)