0 | module Data.NumIdr.Transform.Transform
  1 |
  2 | import Data.Vect
  3 | import Data.NumIdr.Interfaces
  4 | import Data.NumIdr.Array
  5 | import Data.NumIdr.Vector
  6 | import Data.NumIdr.Matrix
  7 | import Data.NumIdr.Homogeneous
  8 | import Data.NumIdr.Transform.Point
  9 |
 10 | %default total
 11 |
 12 |
 13 | --------------------------------------------------------------------------------
 14 | -- Transformation Types
 15 | --------------------------------------------------------------------------------
 16 |
 17 |
 18 | ||| A transform type encodes the properties of a transform. There are 8 transform
 19 | ||| types, and together with the coersion relation `(:<)` they form a semilattice.
 20 | export
 21 | TransType : Type
 22 | TransType = (Fin 4, Bool)
 23 |
 24 | namespace TransType
 25 |   public export
 26 |   TAffine, TIsometry, TRigid, TTranslation,
 27 |     TLinear, TOrthonormal, TRotation, TTrivial : TransType
 28 |   TAffine = (3, True)
 29 |   TIsometry = (2, True)
 30 |   TRigid = (1, True)
 31 |   TTranslation = (0, True)
 32 |   TLinear = (3, False)
 33 |   TOrthonormal = (2, False)
 34 |   TRotation = (1, False)
 35 |   TTrivial = (0, False)
 36 |
 37 |
 38 | --------------------------------------------------------------------------------
 39 | -- Transformation type operations
 40 | --------------------------------------------------------------------------------
 41 |
 42 |
 43 | ||| Coersion relation for transform types.
 44 | ||| `a :< b` is `True` if and only if any transform of type `a` can be cast into
 45 | ||| a transform of type `b`.
 46 | public export
 47 | (:<) : TransType -> TransType -> Bool
 48 | (xn, xb) :< (yn, yb) = (xn <= yn) && (xb <= yb)
 49 |
 50 | ||| Return the type of transform resulting from multiplying transforms of
 51 | ||| the two input types.
 52 | public export
 53 | transMult : TransType -> TransType -> TransType
 54 | transMult (xn, xb) (yn, yb) = (max xn yn, xb || yb)
 55 |
 56 | ||| Return the linearized transform type, i.e. the transform type resulting
 57 | ||| from removing the translation component of the original transform.
 58 | public export
 59 | linearizeType : TransType -> TransType
 60 | linearizeType = mapSnd (const False)
 61 |
 62 | ||| Return the delinearized transform type, i.e. the transform type resulting
 63 | ||| from adding a translation to the original transform.
 64 | public export
 65 | delinearizeType : TransType -> TransType
 66 | delinearizeType = mapSnd (const True)
 67 |
 68 |
 69 | ||| A transform is a wrapper for a homogeneous matrix subject to certain
 70 | ||| restrictions, such as a rotation, an isometry, or a rigid transform.
 71 | ||| The restriction on the transform is encoded by the transform's *type*.
 72 | |||
 73 | ||| Transforms have special behavior over matrices when it comes to multiplication.
 74 | ||| When a transform is applied to a vector, only the linear part of the transform
 75 | ||| is applied, as if `linearize` were called on the transform before the operation.
 76 | |||
 77 | ||| In order for non-linear transformations to be used, the transform should be
 78 | ||| applied to the special wrapper type `Point`. This separates the concepts of
 79 | ||| point and vector, which is often useful when working with affine maps.
 80 | export
 81 | data Transform : TransType -> Nat -> Type -> Type where
 82 |   MkTrans : (ty : TransType) -> HMatrix' n a -> Transform ty n a
 83 |
 84 | %name Transform t
 85 |
 86 |
 87 | export
 88 | unsafeMkTrans : {ty : _} -> HMatrix' n a -> Transform ty n a
 89 | unsafeMkTrans = MkTrans _
 90 |
 91 |
 92 | ||| Unwrap the inner homogeneous matrix from a transform.
 93 | export
 94 | getHMatrix : Transform ty n a -> HMatrix' n a
 95 | getHMatrix (MkTrans _ mat) = mat
 96 |
 97 | ||| Unwrap the inner matrix from a transform, ignoring the translation component
 98 | ||| if one exists.
 99 | export
100 | getMatrix : Transform ty n a -> Matrix' n a
101 | getMatrix (MkTrans _ mat) = getMatrix mat
102 |
103 | export
104 | convertRep : (rep : Rep) -> RepConstraint rep a => Transform ty n a -> Transform ty n a
105 | convertRep rep (MkTrans _ mat) = MkTrans _ (convertRep rep mat)
106 |
107 | ||| Linearize a transform by removing its translation component.
108 | ||| If the transform is already linear, then this function does nothing.
109 | export
110 | linearize : Num a => Transform ty n a -> Transform (linearizeType ty) n a
111 | linearize {n} (MkTrans _ mat) with (viewShape mat)
112 |   _ | Shape [S n,S n] = MkTrans _ (hmatrix (getMatrix mat) (zeros _))
113 |
114 | ||| Set the translation component of a transform.
115 | |||
116 | ||| `setTranslation v tr == translate v *. linearize tr`
117 | |||
118 | ||| If `tr` is linear:
119 | ||| `setTranslation v tr == translate v *. tr`
120 | export
121 | setTranslation : Num a => Vector n a -> Transform ty n a
122 |                   -> Transform (delinearizeType ty) n a
123 | setTranslation v (MkTrans _ mat) = MkTrans _ (hmatrix (getMatrix mat) v)
124 |
125 |
126 | namespace Vector
127 |   export
128 |   applyInv : FieldCmp a => Transform ty n a -> Vector n a -> Vector n a
129 |   applyInv tr v = assert_total $ case solve (getMatrix tr) v of Just v' => v'
130 |
131 | namespace Point
132 |   export
133 |   applyInv : FieldCmp a => Transform ty n a -> Point n a -> Point n a
134 |   applyInv (MkTrans _ mat) p = assert_total $
135 |     case solve (getMatrix mat) (zipWith (-) (toVector p) (getTranslationVector mat)) of
136 |       Just v => fromVector v
137 |
138 |
139 | --------------------------------------------------------------------------------
140 | -- Interface implementations
141 | --------------------------------------------------------------------------------
142 |
143 |
144 | mulPoint : Num a => HMatrix' n a -> Point n a -> Point n a
145 | mulPoint mat p = fromVector $ zipWith (+) (getMatrix mat *. toVector p)
146 |                                           (getTranslationVector mat)
147 |
148 | mulVector : Num a => HMatrix' n a -> Vector n a -> Vector n a
149 | mulVector mat v = getMatrix mat *. v
150 |
151 | export
152 | Num a => Mult (Transform ty n a) (Point n a) (Point n a) where
153 |   MkTrans _ mat *. p = mulPoint mat p
154 |
155 | export
156 | Num a => Mult (Transform ty n a) (Vector n a) (Vector n a) where
157 |   MkTrans _ mat *. v = mulVector mat v
158 |
159 |
160 | export
161 | Num a => Mult (Transform t1 n a) (Transform t2 n a) (Transform (transMult t1 t2) n a) where
162 |   MkTrans _ m1 *. MkTrans _ m2 = MkTrans _ (m1 *. m2)
163 |
164 | export
165 | [TransformMult'] Num a => Mult (Transform ty n a) (Transform ty n a) (Transform ty n a) where
166 |   MkTrans _ m1 *. MkTrans _ m2 = MkTrans _ (m1 *. m2)
167 |
168 |
169 | export
170 | {n,ty : _} -> Num a => MultMonoid (Transform ty n a) using TransformMult' where
171 |   identity = MkTrans ty identity
172 |
173 | export
174 | {n,ty : _} -> FieldCmp a => MultGroup (Transform ty n a) where
175 |   inverse (MkTrans _ mat) = MkTrans _ (inverse mat)
176 |
177 |
178 | export
179 | {t2 : _} -> So (t1 :< t2) => Cast a b => Cast (Transform t1 n a) (Transform t2 n b) where
180 |   cast (MkTrans t1 mat) = MkTrans t2 (cast mat)
181 |