0 | module Data.Layout
 1 |
 2 | import Data.Fin.Split
 3 | import Data.List.Quantifiers
 4 | import Language.Reflection
 5 | import Derive.Prelude
 6 | import Misc
 7 |
 8 | %language ElabReflection
 9 |
10 | ||| We often deal with the 'logical' representation of tensors, but for
11 | ||| performance characteristics we need to be cognisant of how these tensors
12 | ||| are stored in the physical memory, which is in 1D linear order.
13 | ||| There are two options: row-major and column-major format
14 | ||| NumPy, PyTorch, TensorFlow and JAX use row-major indexing
15 | ||| The idea is that once linearised, with:
16 | ||| - row-major the last index of the array varies fastest
17 | ||| - column-major the first index of the array varies fastest
18 | public export
19 | data LayoutOrder = RowMajor
20 |                  | ColumnMajor
21 |
22 | %runElab derive "LayoutOrder" [Eq, Show]
23 | %name LayoutOrder lo
24 |
25 |
26 | ||| Following most popular conventions, we use row-major ordering by default
27 | public export
28 | DefaultLayoutOrder : LayoutOrder
29 | DefaultLayoutOrder = RowMajor
30 |
31 | ||| Layout-aware version of splitProd from Data.Fin.Split.
32 | ||| 
33 | ||| Row-major: index k in a m×n matrix maps to (k/n, k%n)
34 | |||   - goes through all columns before moving to next row
35 | ||| Column-major: index k maps to (k%m, k/m)  
36 | |||   - goes through all rows before moving to next column
37 | |||
38 | ||| For a 2×3 matrix:
39 | |||   Row-major order:    (0,0), (0,1), (0,2), (1,0), (1,1), (1,2)
40 | |||   Column-major order: (0,0), (1,0), (0,1), (1,1), (0,2), (1,2)
41 | public export
42 | splitFinProd : {m, n : Nat} ->
43 |   LayoutOrder ->
44 |   Fin (m * n) ->
45 |   (Fin m, Fin n)
46 | splitFinProd RowMajor p = splitProd p
47 | splitFinProd ColumnMajor p = swap (splitProd {m=n} {n=m}
48 |   (replace {p = Fin} (multCommutative m n) p))
49 |
50 | ||| Layout-aware version of indexProd from Data.Fin.Split.
51 | ||| Inverse of splitFinProd: given (row, col) indices, compute linear index.
52 | |||
53 | ||| Row-major: linear index = row * n + col
54 | ||| Column-major: linear index = col * m + row
55 | |||
56 | ||| For a 2×3 matrix with (row=1, col=2):
57 | |||   Row-major:    1 * 3 + 2 = 5
58 | |||   Column-major: 2 * 2 + 1 = 5
59 | public export
60 | indexFinProd : {m, n : Nat} ->
61 |   LayoutOrder ->
62 |   Fin m ->
63 |   Fin n ->
64 |   Fin (m * n)
65 | indexFinProd RowMajor row col = indexProd row col
66 | indexFinProd ColumnMajor row col = 
67 |   replace {p = Fin} (sym $ multCommutative m n) (indexProd {m=n} {n=m} col row)
68 |
69 | ||| Like `splitFinProd`, but here the order is fixed for us by dependency
70 | public export
71 | splitFinProdDep : {n : Nat} -> (content : Fin n -> Nat) ->
72 |   Fin (sum content) -> (i : Fin n ** Fin (content i))
73 | splitFinProdDep {n = 0} content x = ?shouldBeImpossibleToReach
74 | splitFinProdDep {n = (S k)} content x = case splitSum x of
75 |   Left y => (FZ ** y)
76 |   Right y => let (i ** j= splitFinProdDep (content . FS) y in (FS i ** j)