0 | module Derive.HDecEq
 1 |
 2 | import public Decidable.HDecEq
 3 | import Language.Reflection.Util
 4 |
 5 | %default total
 6 |
 7 | --------------------------------------------------------------------------------
 8 | --          Claims
 9 | --------------------------------------------------------------------------------
10 |
11 | ||| Top-level declaration implementing the equality test for
12 | ||| the given data type.
13 | export
14 | hdeceqClaim : Visibility -> (fun : Name) -> (p : ParamTypeInfo) -> Decl
15 | hdeceqClaim vis fun p =
16 |   let arg := var p.info.name
17 |       tpe := `((x : ~(arg)) -> (y : ~(arg)) -> Maybe0 (x === y))
18 |    in simpleClaim vis fun tpe
19 |
20 | ||| Top-level declaration implementing the `Eq` interface for
21 | ||| the given data type.
22 | export
23 | hdeceqImplClaim : Visibility -> (impl : Name) -> (p : ParamTypeInfo) -> Decl
24 | hdeceqImplClaim vis impl p = implClaimVis vis impl (implType "HDecEq" p)
25 |
26 | --------------------------------------------------------------------------------
27 | --          Definitions
28 | --------------------------------------------------------------------------------
29 |
30 | hdeceqImplDef : (fun, impl : Name) -> Decl
31 | hdeceqImplDef fun impl =
32 |   def impl [patClause (var impl) (var "MkHDecEq" `app` var fun)]
33 |
34 | -- catch-all pattern clause for data types with more than
35 | -- one data constructor
36 | catchAll : (fun : Name) -> TypeInfo -> List Clause
37 | catchAll fun ti =
38 |   if length ti.cons > 1
39 |      then [patClause `(~(var fun) _ _) `(Nothing0)]
40 |      else []
41 |
42 | ||| Generates pattern match clauses for the constructors of
43 | ||| the given data type. `fun` is the name of the function we implement.
44 | ||| This is either a local function definition in case of a
45 | ||| custom derivation, or the name of a top-level function.
46 | export
47 | hdeceqClauses : (fun : Name) -> TypeInfo -> List Clause
48 | hdeceqClauses fun ti = map clause ti.cons ++ catchAll fun ti
49 |
50 |  where
51 |    clause : Con ti.arty ti.args -> Clause
52 |    clause c =
53 |      let v := var c.name
54 |       in patClause `(~(var fun) ~(v) ~(v)) `(Just0 Refl)
55 |
56 | ||| Definition of a (local or top-level) function implementing
57 | ||| the equality check for the given data type.
58 | export
59 | hdeceqDef : Name -> TypeInfo -> Decl
60 | hdeceqDef fun ti = def fun (hdeceqClauses fun ti)
61 |
62 | --------------------------------------------------------------------------------
63 | --          Deriving
64 | --------------------------------------------------------------------------------
65 |
66 | export
67 | failEnum : Res a
68 | failEnum = Left "Interface HDecEq can currently only be derived for enumerations"
69 |
70 | ||| Generate declarations and implementations for `Eq` for a given data type.
71 | export
72 | HDecEqVis : Visibility -> List Name -> ParamTypeInfo -> Res (List TopLevel)
73 | HDecEqVis vis nms p = case isEnum p.info of
74 |   True  =>
75 |     let fun  := funName p "hdecEq"
76 |         impl := implName p "HDecEq"
77 |      in Right
78 |           [ TL (hdeceqClaim vis fun p) (hdeceqDef fun p.info)
79 |           , TL (hdeceqImplClaim vis impl p) (hdeceqImplDef fun impl)
80 |           ]
81 |   False => failEnum
82 |
83 | ||| Alias for `EqVis Public`
84 | export %inline
85 | HDecEq : List Name -> ParamTypeInfo -> Res (List TopLevel)
86 | HDecEq = HDecEqVis Public
87 |