0 | module Derive.Ord
  1 |
  2 | import Derive.Eq
  3 | import Language.Reflection.Util
  4 |
  5 | %default total
  6 |
  7 | --------------------------------------------------------------------------------
  8 | --          Claims
  9 | --------------------------------------------------------------------------------
 10 |
 11 | ||| Top-level function declaration implementing the ordering test for
 12 | ||| the given data type.
 13 | export
 14 | ordClaim : Visibility -> (fun : Name) -> (p : ParamTypeInfo) -> Decl
 15 | ordClaim vis fun p =
 16 |   let arg := p.applied
 17 |       tpe := piAll `(~(arg) -> ~(arg) -> Ordering) (allImplicits p "Ord")
 18 |    in simpleClaim vis fun tpe
 19 |
 20 | ||| Top-level declaration implementing the `Ord` interface for
 21 | ||| the given data type.
 22 | export
 23 | ordImplClaim : Visibility -> (impl : Name) -> (p : ParamTypeInfo) -> Decl
 24 | ordImplClaim v impl p = implClaimVis v impl (implType "Ord" p)
 25 |
 26 | --------------------------------------------------------------------------------
 27 | --          Definitions
 28 | --------------------------------------------------------------------------------
 29 |
 30 | export
 31 | ordImplDef : (fun, impl : Name) -> Decl
 32 | ordImplDef fun impl =
 33 |   def impl [patClause (var impl) (var "mkOrd" `app` var fun)]
 34 |
 35 | ordEnumDef : (impl, ci : Name) -> Decl
 36 | ordEnumDef i c =
 37 |   def i [patClause (var i) `(mkOrd $ \x,y => compare (~(var c) x) (~(var c) y))]
 38 |
 39 | -- Generates the right-hand side of the ordering test on a single
 40 | -- pair of (identical) data constructors based on the given list of
 41 | -- comparisons.
 42 | rhs : SnocList TTImp -> TTImp
 43 | rhs [<]       = `(EQ)
 44 | rhs (sx :< x) = foldr (\e,acc => `(case ~(e) of {EQ => ~(acc); o => o})) x sx
 45 |
 46 | -- catch-all pattern clause for data types with more than
 47 | -- one data constructor
 48 | catchAll : (ci : Name) -> (fun : Name) -> TypeInfo -> List Clause
 49 | catchAll ci fun ti =
 50 |   let civ      := var ci
 51 |   in if length ti.cons > 1
 52 |        then [patClause `(~(var fun) x y) `(compare (~(civ) x) (~(civ) y))]
 53 |        else []
 54 |
 55 | parameters (nms : List Name)
 56 |   arg : BoundArg 2 Regular -> TTImp
 57 |   arg (BA g [x,y] _) = assertIfRec nms g.type `(compare ~(var x) ~(var y))
 58 |
 59 |   ||| Generates pattern match clauses for the constructors of
 60 |   ||| the given data type. `fun` is the name of the function we implement.
 61 |   ||| This is either a local function definition in case of a
 62 |   ||| custom derivation, or the name of a top-level function.
 63 |   export
 64 |   ordClauses : (ci, fun : Name) -> (t : TypeInfo) -> List Clause
 65 |   ordClauses ci fun ti = map clause ti.cons ++ catchAll ci fun ti
 66 |
 67 |    where
 68 |      clause : Con ti.arty ti.args -> Clause
 69 |      clause = accumArgs2 regular (\x,y => `(~(var fun) ~(x) ~(y))) rhs arg
 70 |
 71 |   ||| Definition of a (local or top-level) function implementing
 72 |   ||| the ordering check for the given data type.
 73 |   export
 74 |   ordDef : (ci, fun : Name) -> TypeInfo -> Decl
 75 |   ordDef ci fun ti = def fun (ordClauses ci fun ti)
 76 |
 77 | --------------------------------------------------------------------------------
 78 | --          Deriving
 79 | --------------------------------------------------------------------------------
 80 |
 81 | ||| Generate declarations and implementations for `Ord` for a given data type.
 82 | export
 83 | OrdVis : Visibility -> List Name -> ParamTypeInfo -> Res (List TopLevel)
 84 | OrdVis vis nms p = case isEnum p.info of
 85 |   True  =>
 86 |     let impl := implName p "Ord"
 87 |         ci   := conIndexName p
 88 |      in Right [ TL (ordImplClaim vis impl p) (ordEnumDef impl ci) ]
 89 |   False =>
 90 |     let ci   := conIndexName p
 91 |         pre  := if length p.cons > 1 then ConIndexVis vis nms p else Right []
 92 |         fun  := funName p "ord"
 93 |         impl := implName p "Ord"
 94 |      in sequenceJoin
 95 |           [ pre
 96 |           , Right
 97 |               [ TL (ordClaim vis fun p) (ordDef nms ci fun p.info)
 98 |               , TL (ordImplClaim vis impl p) (ordImplDef fun impl)
 99 |               ]
100 |           ]
101 |
102 | ||| Alias for `OrdVis Public`
103 | export %inline
104 | Ord : List Name -> ParamTypeInfo -> Res (List TopLevel)
105 | Ord = OrdVis Public
106 |