0 | module Derive.Eq
  1 |
  2 | import Language.Reflection.Util
  3 |
  4 | %default total
  5 |
  6 | --------------------------------------------------------------------------------
  7 | --          Constructor Index
  8 | --------------------------------------------------------------------------------
  9 |
 10 | export
 11 | conIndexName : Named a => a -> Name
 12 | conIndexName v = funName v "conIndex"
 13 |
 14 | ||| Type used to represent the index of a data constructor.
 15 | export
 16 | conIndexTypes : Nat -> (Bits32 -> TTImp, TTImp)
 17 | conIndexTypes n =
 18 |   let f := primVal . PrT
 19 |    in if      n < 256     then (primVal . B8 . cast, f Bits8Type)
 20 |       else if n < 0x20000 then (primVal . B16 . cast, f Bits16Type)
 21 |       else                     (primVal . B32, f Bits32Type)
 22 |
 23 | ||| Clauses returning the index for each constructor in the given
 24 | ||| list.
 25 | export
 26 | conIndexClauses : Named a => Name -> List a -> List Clause
 27 | conIndexClauses n ns = go 0 (fst $ conIndexTypes $ length ns) ns
 28 |
 29 |   where
 30 |     go : Bits32 -> (Bits32 -> TTImp) -> List a -> List Clause
 31 |     go _  _ []        = []
 32 |     go ix f (c :: cs) =
 33 |       patClause(var n `app` bindAny c) (f ix) :: go (ix + 1) f cs
 34 |
 35 | ||| Declaration of a function returning the constructor index
 36 | ||| for a value of the given data type.
 37 | export
 38 | conIndexClaim : Visibility -> (fun : Name) -> (t : TypeInfo) -> Decl
 39 | conIndexClaim vis fun t =
 40 |   let tpe := snd (conIndexTypes $ length t.cons)
 41 |       arg := t.applied
 42 |    in simpleClaim vis fun $ piAll `(~(arg) -> ~(tpe)) (t.implicits)
 43 |
 44 | ||| Definition of a function returning the constructor index
 45 | ||| for a value of the given data type.
 46 | export
 47 | conIndexDef : (fun : Name) -> TypeInfo -> Decl
 48 | conIndexDef fun t = def fun $ conIndexClauses fun t.cons
 49 |
 50 | ||| For the given data type, creates a function for returning
 51 | ||| a 0-based index for each constructor.
 52 | |||
 53 | ||| For instance, for `Either a b = Left a | Right b` this creates
 54 | ||| declarations as follows:
 55 | |||
 56 | ||| ```idris
 57 | ||| conIndexEither : Either a b -> Bits8
 58 | ||| conIndexEither (Left {})  = 0
 59 | ||| conIndexEither (Right {}) = 1
 60 | ||| ```
 61 | |||
 62 | ||| This function is useful in several situations: When deriving
 63 | ||| `Ord` for a sum type with more than one data constructors, we
 64 | ||| can use the constructor index to compare values created from
 65 | ||| distinct constructors. This allows us to only use a linear number
 66 | ||| of pattern matches to implement the ordering.
 67 | |||
 68 | ||| For enum types (all data constructors have only erased arguments - if any),
 69 | ||| there are even greater benefits: `conIndex` is
 70 | ||| the identity function at runtime, being completely eliminated during
 71 | ||| code generations. This allows us to get `Eq` and `Ord` implementations for
 72 | ||| enum types, which run in O(1)!
 73 | export
 74 | ConIndexVis : Visibility -> List Name -> ParamTypeInfo -> Res (List TopLevel)
 75 | ConIndexVis vis _ t =
 76 |   let fun := conIndexName t
 77 |    in Right [ TL (conIndexClaim vis fun t.info) (conIndexDef fun t.info) ]
 78 |
 79 | ||| Alias for `ConIndexVis Public`
 80 | export %inline
 81 | ConIndex : List Name -> ParamTypeInfo -> Res (List TopLevel)
 82 | ConIndex = ConIndexVis Public
 83 |
 84 | --------------------------------------------------------------------------------
 85 | --          Claims
 86 | --------------------------------------------------------------------------------
 87 |
 88 | ||| Top-level declaration implementing the equality test for
 89 | ||| the given data type.
 90 | export
 91 | eqClaim : Visibility -> (fun : Name) -> (p : ParamTypeInfo) -> Decl
 92 | eqClaim vis fun p =
 93 |   let arg := p.applied
 94 |       tpe := piAll `(~(arg) -> ~(arg) -> Bool) (allImplicits p "Eq")
 95 |    in simpleClaim vis fun tpe
 96 |
 97 | ||| Top-level declaration implementing the `Eq` interface for
 98 | ||| the given data type.
 99 | export
100 | eqImplClaim : Visibility -> (impl : Name) -> (p : ParamTypeInfo) -> Decl
101 | eqImplClaim vis impl p = implClaimVis vis impl (implType "Eq" p)
102 |
103 | --------------------------------------------------------------------------------
104 | --          Definitions
105 | --------------------------------------------------------------------------------
106 |
107 | eqImplDef : (fun, impl : Name) -> Decl
108 | eqImplDef fun impl = def impl [patClause (var impl) (var "mkEq" `app` var fun)]
109 |
110 | eqEnumDef : (impl, ci : Name) -> Decl
111 | eqEnumDef i c =
112 |   def i [patClause (var i) `(mkEq $ \x,y => ~(var c) x == ~(var c) y)]
113 |
114 | -- catch-all pattern clause for data types with more than
115 | -- one data constructor
116 | catchAll : (fun : Name) -> TypeInfo -> List Clause
117 | catchAll fun ti =
118 |   if length ti.cons > 1
119 |      then [patClause `(~(var fun) _ _) `(False)]
120 |      else []
121 |
122 | -- accumulate right-hand side of a single pattern clause
123 | rhs : SnocList TTImp -> TTImp
124 | rhs Lin       = `(True)
125 | rhs (sx :< x) = foldr (\e,acc => `(~(e) && ~(acc))) x sx
126 |
127 | parameters (nms : List Name)
128 |   arg : BoundArg 2 Regular -> TTImp
129 |   arg (BA g [x,y] _) = assertIfRec nms g.type `(~(var x) == ~(var y))
130 |
131 |   ||| Generates pattern match clauses for the constructors of
132 |   ||| the given data type. `fun` is the name of the function we implement.
133 |   ||| This is either a local function definition in case of a
134 |   ||| custom derivation, or the name of a top-level function.
135 |   export
136 |   eqClauses : (fun : Name) -> TypeInfo -> List Clause
137 |   eqClauses fun ti = map clause ti.cons ++ catchAll fun ti
138 |
139 |    where
140 |      clause : Con ti.arty ti.args -> Clause
141 |      clause = accumArgs2 regular (\x,y => `(~(var fun) ~(x) ~(y))) rhs arg
142 |
143 |   ||| Definition of a (local or top-level) function implementing
144 |   ||| the equality check for the given data type.
145 |   export
146 |   eqDef : Name -> TypeInfo -> Decl
147 |   eqDef fun ti = def fun (eqClauses fun ti)
148 |
149 | --------------------------------------------------------------------------------
150 | --          Deriving
151 | --------------------------------------------------------------------------------
152 |
153 | ||| Derive an implementation of `Eq a` for a custom data type `a`.
154 | |||
155 | ||| Note: This is mainly to be used for indexed data types. Consider using
156 | |||       `derive` together with `Derive.Eq.Eq` for parameterized data types.
157 | export %macro
158 | deriveEq : Elab (Eq f)
159 | deriveEq = do
160 |   Just tpe <- goal
161 |     | Nothing => fail "Can't infer goal"
162 |   let Just (resTpe, nm) := extractResult tpe
163 |     | Nothing => fail "Invalid goal type: \{show tpe}"
164 |   ti <- getInfo' nm
165 |
166 |   let impl :=
167 |            lam (lambdaArg {a = Name} "x") $
168 |            lam (lambdaArg {a = Name} "y") $
169 |            iCase `(MkPair x y) implicitFalse (eqClauses [ti.name] "MkPair" ti)
170 |
171 |   logMsg "derive.definitions" 1 $ show impl
172 |   check $ var "mkEq" `app` impl
173 |
174 | ||| Generate declarations and implementations for `Eq` for a given data type.
175 | export
176 | EqVis : Visibility -> List Name -> ParamTypeInfo -> Res (List TopLevel)
177 | EqVis vis nms p = case isEnum p.info of
178 |   True  =>
179 |     let impl := implName p "Eq"
180 |         ci   := conIndexName p
181 |      in sequenceJoin
182 |           [ ConIndexVis vis nms p
183 |           , Right [ TL (eqImplClaim vis impl p) (eqEnumDef impl ci) ]
184 |           ]
185 |   False =>
186 |     let fun  := funName p "eq"
187 |         impl := implName p "Eq"
188 |      in Right
189 |           [ TL (eqClaim vis fun p) (eqDef nms fun p.info)
190 |           , TL (eqImplClaim vis impl p) (eqImplDef fun impl)
191 |           ]
192 |
193 | ||| Alias for `EqVis Public`
194 | export %inline
195 | Eq : List Name -> ParamTypeInfo -> Res (List TopLevel)
196 | Eq = EqVis Public
197 |