0 | module Data.Layout
 1 |
 2 | import Data.Fin.Split
 3 | import Data.List.Quantifiers
 4 | import Language.Reflection
 5 | import Derive.Prelude
 6 |
 7 | %language ElabReflection
 8 |
 9 | ||| We often deal with the 'logical' representation of tensors, but for
10 | ||| performance characteristics we need to be cognisant of how these tensors
11 | ||| are stored in the physical memory, which is in 1D linear order.
12 | ||| There are two options: row-major and column-major format
13 | ||| NumPy, PyTorch, TensorFlow and JAX use row-major indexing
14 | ||| The idea is that once linearised, with:
15 | ||| - row-major the last index of the array varies fastest
16 | ||| - column-major the first index of the array varies fastest
17 | public export
18 | data LayoutOrder = RowMajor
19 |                  | ColumnMajor
20 |
21 | %runElab derive "LayoutOrder" [Eq, Show]
22 | %name LayoutOrder lo
23 |
24 |
25 | ||| Following most popular conventions, we use row-major ordering by default
26 | public export
27 | DefaultLayoutOrder : LayoutOrder
28 | DefaultLayoutOrder = RowMajor
29 |
30 | ||| Layout-aware version of splitProd from Data.Fin.Split.
31 | ||| 
32 | ||| Row-major: index k in a m×n matrix maps to (k/n, k%n)
33 | |||   - goes through all columns before moving to next row
34 | ||| Column-major: index k maps to (k%m, k/m)  
35 | |||   - goes through all rows before moving to next column
36 | |||
37 | ||| For a 2×3 matrix:
38 | |||   Row-major order:    (0,0), (0,1), (0,2), (1,0), (1,1), (1,2)
39 | |||   Column-major order: (0,0), (1,0), (0,1), (1,1), (0,2), (1,2)
40 | public export
41 | splitFinProd : {m, n : Nat} ->
42 |   LayoutOrder ->
43 |   Fin (m * n) ->
44 |   (Fin m, Fin n)
45 | splitFinProd RowMajor p = splitProd p
46 | splitFinProd ColumnMajor p = swap (splitProd {m=n} {n=m}
47 |   (replace {p = Fin} (multCommutative m n) p))
48 |
49 | ||| Layout-aware version of indexProd from Data.Fin.Split.
50 | ||| Inverse of splitFinProd: given (row, col) indices, compute linear index.
51 | |||
52 | ||| Row-major: linear index = row * n + col
53 | ||| Column-major: linear index = col * m + row
54 | |||
55 | ||| For a 2×3 matrix with (row=1, col=2):
56 | |||   Row-major:    1 * 3 + 2 = 5
57 | |||   Column-major: 2 * 2 + 1 = 5
58 | public export
59 | indexFinProd : {m, n : Nat} ->
60 |   LayoutOrder ->
61 |   Fin m ->
62 |   Fin n ->
63 |   Fin (m * n)
64 | indexFinProd RowMajor row col = indexProd row col
65 | indexFinProd ColumnMajor row col = 
66 |   replace {p = Fin} (sym $ multCommutative m n) (indexProd {m=n} {n=m} col row)