0 | module Data.NumIdr.Transform.Transform
3 | import Data.NumIdr.Interfaces
4 | import Data.NumIdr.Array
5 | import Data.NumIdr.Vector
6 | import Data.NumIdr.Matrix
7 | import Data.NumIdr.Homogeneous
8 | import Data.NumIdr.Transform.Point
-- Classification of a transform: a level in `Fin 4` paired with a Bool flag.
-- Values are ordered componentwise by `(:<)` and combined by `transMult`.
TransType = Pair (Fin 4) Bool
-- Named transform types. Each is a (level, flag) pair. In the visible clauses,
-- the types sharing a level differ only in the flag:
--   level 3: TLinear (False)
--   level 2: TOrthonormal (False) / TIsometry (True)
--   level 1: TRotation (False)
--   level 0: TTrivial (False) / TTranslation (True)
-- NOTE(review): TAffine and TRigid are declared here, but their defining
-- clauses are not visible in this view -- confirm they are defined.
26 | TAffine, TIsometry, TRigid, TTranslation,
27 | TLinear, TOrthonormal, TRotation, TTrivial : TransType
29 | TIsometry = (2, True)
31 | TTranslation = (0, True)
32 | TLinear = (3, False)
33 | TOrthonormal = (2, False)
34 | TRotation = (1, False)
35 | TTrivial = (0, False)
-- Componentwise order on TransType: `x :< y` holds exactly when the level of
-- `x` is at most the level of `y` AND the flag of `x` is at most the flag of
-- `y` (using False < True). It is the `So` precondition of the `Cast`
-- implementation between transform types.
(:<) : TransType -> TransType -> Bool
(lvlX, flagX) :< (lvlY, flagY) =
  if lvlX <= lvlY then flagX <= flagY else False
-- Transform type of a product of two transforms: the larger of the two
-- levels, with the flag set if either operand has its flag set.
transMult : TransType -> TransType -> TransType
transMult (lvlA, flagA) (lvlB, flagB) = MkPair (max lvlA lvlB) (flagA || flagB)
-- Keep the level and clear the flag. This is the result type of `linearize`,
-- which also zeroes the translation part.
linearizeType : TransType -> TransType
linearizeType (lvl, _) = (lvl, False)
-- Keep the level and set the flag. This is the result type of
-- `setTranslation`.
delinearizeType : TransType -> TransType
delinearizeType (lvl, _) = (lvl, True)
-- A transform of `n`-dimensional space with scalars of type `a`, indexed at
-- the type level by its TransType. It wraps an `HMatrix' n a`. The type tag
-- is also stored as an explicit constructor field, so functions can recover
-- it at runtime by pattern matching on `MkTrans`. The constructor itself does
-- not check that the matrix matches the claimed TransType.
81 | data Transform : TransType -> Nat -> Type -> Type where
82 | MkTrans : (ty : TransType) -> HMatrix' n a -> Transform ty n a
-- Wrap a homogeneous matrix as a transform of the requested type `ty`,
-- without any check that the matrix actually is such a transform.
unsafeMkTrans : {ty : _} -> HMatrix' n a -> Transform ty n a
unsafeMkTrans hm = MkTrans ty hm
-- The underlying homogeneous matrix of a transform.
getHMatrix : Transform ty n a -> HMatrix' n a
getHMatrix tr = case tr of MkTrans _ hm => hm
-- The matrix part of a transform, obtained by delegating to `getMatrix` on
-- its underlying homogeneous matrix.
getMatrix : Transform ty n a -> Matrix' n a
getMatrix tr = case tr of MkTrans _ hm => getMatrix hm
-- Convert the underlying matrix of a transform to representation `rep`,
-- keeping the transform type unchanged. The type tag is re-used from the
-- matched constructor, because the `ty` index itself is erased.
convertRep : (rep : Rep) -> RepConstraint rep a => Transform ty n a -> Transform ty n a
convertRep rep (MkTrans tag hm) = MkTrans tag (convertRep rep hm)
-- Drop the translation part of a transform. The homogeneous matrix is rebuilt
-- from the original matrix part and a zero translation vector (`zeros _`).
-- The result's type has its flag cleared via `linearizeType`.
-- The `with (viewShape mat)` match exposes the matrix shape as `[S n, S n]`;
-- presumably this is needed so that `hmatrix`/`zeros` elaborate at the right
-- size -- confirm before refactoring.
110 | linearize : Num a => Transform ty n a -> Transform (linearizeType ty) n a
111 | linearize {n} (MkTrans _ mat) with (viewShape mat)
112 | _ | Shape [S n,S n] = MkTrans _ (hmatrix (getMatrix mat) (zeros _))
-- Replace the translation part of a transform with `offset`, keeping its
-- matrix part. The result's type has its flag set via `delinearizeType`,
-- computed from the runtime type tag of the matched constructor.
setTranslation : Num a => Vector n a -> Transform ty n a
              -> Transform (delinearizeType ty) n a
setTranslation offset (MkTrans tag hm) =
  MkTrans (delinearizeType tag) (hmatrix (getMatrix hm) offset)
-- Apply the inverse of a transform to a vector by solving the linear system
-- given by the transform's matrix part. Only the matrix part is used,
-- matching how `*.` applies a transform to a vector.
-- If `solve` returns Nothing (presumably a singular matrix part), this crashes
-- with an explicit message. Previously the case had no Nothing branch and
-- produced an anonymous unhandled-case crash.
applyInv : FieldCmp a => Transform ty n a -> Vector n a -> Vector n a
applyInv tr v = assert_total $
  case solve (getMatrix tr) v of
    Just v' => v'
    Nothing => idris_crash "applyInv: transform is not invertible"
-- Apply the inverse of a transform to a point: subtract the translation
-- vector, then solve the linear system given by the matrix part.
-- If `solve` returns Nothing (presumably a singular matrix part), this crashes
-- with an explicit message. Previously the case had no Nothing branch and
-- produced an anonymous unhandled-case crash.
applyInv : FieldCmp a => Transform ty n a -> Point n a -> Point n a
applyInv (MkTrans _ mat) p = assert_total $
  case solve (getMatrix mat) (zipWith (-) (toVector p) (getTranslationVector mat)) of
    Just v => fromVector v
    Nothing => idris_crash "applyInv: transform is not invertible"
-- Apply a homogeneous matrix to a point: multiply the point's coordinates by
-- the matrix part, then add the translation vector.
mulPoint : Num a => HMatrix' n a -> Point n a -> Point n a
mulPoint hm pt =
  let scaled : Vector n a = getMatrix hm *. toVector pt
      offset : Vector n a = getTranslationVector hm
  in fromVector (zipWith (+) scaled offset)
-- Apply a homogeneous matrix to a vector. Only the matrix part is used; the
-- translation vector does not take part.
mulVector : Num a => HMatrix' n a -> Vector n a -> Vector n a
mulVector hm = (getMatrix hm *.)
-- Transforms act on points through their full homogeneous matrix.
Num a => Mult (Transform ty n a) (Point n a) (Point n a) where
  tr *. pt = mulPoint (getHMatrix tr) pt
-- Transforms act on vectors through their matrix part only.
Num a => Mult (Transform ty n a) (Vector n a) (Vector n a) where
  tr *. vec = mulVector (getHMatrix tr) vec
-- Composition of transforms: multiply the homogeneous matrices. The result
-- type `transMult t1 t2` is rebuilt from the runtime type tags of both
-- operands, since the `t1`/`t2` indices are erased.
Num a => Mult (Transform t1 n a) (Transform t2 n a) (Transform (transMult t1 t2) n a) where
  MkTrans tagL hmL *. MkTrans tagR hmR = MkTrans (transMult tagL tagR) (hmL *. hmR)
-- Named implementation: composition of two transforms of the same type,
-- keeping that type. It is used to derive `MultMonoid` for transforms.
[TransformMult'] Num a => Mult (Transform ty n a) (Transform ty n a) (Transform ty n a) where
  MkTrans tag hmL *. MkTrans _ hmR = MkTrans tag (hmL *. hmR)
-- The identity transform of any type `ty` wraps the identity homogeneous
-- matrix. `n` and `ty` are bound at runtime so they can be used here.
{n,ty : _} -> Num a => MultMonoid (Transform ty n a) using TransformMult' where
  identity = unsafeMkTrans identity
-- The inverse of a transform inverts its homogeneous matrix and keeps its
-- type tag.
{n,ty : _} -> FieldCmp a => MultGroup (Transform ty n a) where
  inverse (MkTrans tag hm) = MkTrans tag (inverse hm)
-- Cast a transform to a type `t2` that is at least as general
-- (`t1 :< t2`), converting scalar types via `Cast a b` on the underlying
-- matrix. The source tag is not needed because the target tag `t2` is bound
-- at runtime.
{t2 : _} -> So (t1 :< t2) => Cast a b => Cast (Transform t1 n a) (Transform t2 n b) where
  cast (MkTrans _ hm) = MkTrans t2 (cast hm)