0 | module Compiler.Opts.Identity
  1 |
  2 | import Core.CompileExpr
  3 | import Core.Context.Log
  4 | import Data.Vect
  5 |
  6 | import Libraries.Data.List.SizeOf
  7 |
  8 | makeArgs : (args : Scope) -> List (Var (args ++ vars))
  9 | makeArgs args = embed @{ListFreelyEmbeddable} (Var.allVars args)
 10 |
 11 | parameters (fn1 : Name) (idIdx : Nat)
 12 |   mutual
 13 |     -- special case for matching on 'Nat'-shaped things
 14 |     isUnsucc : Var vars -> CExp vars -> Maybe (Constant, Var (x :: vars))
 15 |     isUnsucc var (COp _ (Sub _) [CLocal _ p, CPrimVal _ c]) =
 16 |         if var == MkVar p
 17 |             then Just (c, first)
 18 |             else Nothing
 19 |     isUnsucc _ _ = Nothing
 20 |
 21 |     unsuccIdentity : Constant -> Var vars -> CExp vars -> Bool
 22 |     unsuccIdentity c1 var (COp _ (Add _) [exp, CPrimVal _ c2]) = c1 == c2 && cexpIdentity var Nothing Nothing exp
 23 |     unsuccIdentity _ _ _ = False
 24 |
 25 |     -- does the CExp evaluate to the var, the constructor or the constant?
 26 |     cexpIdentity : Var vars -> Maybe (Name, List (Var vars)) -> Maybe Constant -> CExp vars -> Bool
 27 |     cexpIdentity var _ _ (CLocal fc p) = var == MkVar p
 28 |     cexpIdentity var _ _ (CRef {}) = False
 29 |     cexpIdentity var _ _ (CLam {}) = False
 30 |     cexpIdentity var con const (CLet _ _ NotInline val sc) = False
 31 |     cexpIdentity var con const (CLet _ _ _ val sc) = (case isUnsucc var val of
 32 |         Just (c, var') => unsuccIdentity c var' sc
 33 |         Nothing => False)
 34 |         || cexpIdentity
 35 |             (weaken var)
 36 |             (map (map (map weaken)) con)
 37 |             const
 38 |             sc
 39 |     cexpIdentity var con const (CApp _ (CRef _ fn2) as) = -- special case for self-recursive functions
 40 |         fn1 == fn2 &&
 41 |         case getAt idIdx as of
 42 |             Just exp => cexpIdentity var con const exp
 43 |             Nothing => False
 44 |     cexpIdentity _ _ _ (CApp {}) = False
 45 |     cexpIdentity var (Just (con1, as1)) const (CCon _ con2 _ _ as2) =
 46 |         con1 == con2
 47 |         && eqArgs as1 as2
 48 |       where
 49 |         eqArgs : List (Var vars) -> List (CExp vars) -> Bool
 50 |         eqArgs [] [] = True
 51 |         eqArgs (v :: vs) (a :: as) = cexpIdentity v Nothing Nothing a && eqArgs vs as
 52 |         eqArgs _ _ = False
 53 |     cexpIdentity var Nothing const (CCon {}) = False
 54 |     -- special case for integerToNat, see unsuccIdentity for a easier to read
 55 |     -- version that works when the let hasn't been inlined.
 56 |     -- integerToNat : (x : Integer) -> {auto 0 _ : (x >= 0) === True} -> Nat
 57 |     -- integerToNat x = case x of
 58 |     --                       0 => Z
 59 |     --                       _ => S $ integerToNat (x - 1)
 60 |     cexpIdentity var _ _ (COp _ (Add _) [a1, a2]) = case a2 of
 61 |         CPrimVal _ c1 => case a1 of
 62 |             CApp _ (CRef _ fn2) as =>
 63 |                 fn1 == fn2
 64 |                 && (case getAt idIdx as of
 65 |                     Just (COp _ (Sub _) [a3, (CPrimVal _ c2)]) =>
 66 |                         c1 == c2 && cexpIdentity var Nothing Nothing a3
 67 |                     _ => False)
 68 |             _ => False
 69 |         _ => False
 70 |     cexpIdentity var _ _ (COp {}) = False
 71 |     cexpIdentity var _ _ (CExtPrim {}) = False
 72 |     cexpIdentity var _ _ (CForce {}) = False
 73 |     cexpIdentity var _ _ (CDelay {}) = False
 74 |     cexpIdentity var con const (CConCase _ sc xs x) =
 75 |         cexpIdentity var Nothing Nothing sc
 76 |         && all altEq xs
 77 |         && maybeVarEq var con const x
 78 |       where
 79 |
 80 |         altEq : CConAlt vars -> Bool
 81 |         altEq (MkConAlt y _ _ args exp) =
 82 |             cexpIdentity
 83 |                 (weakenNs (mkSizeOf args) var)
 84 |                 (Just (y, makeArgs args))
 85 |                 const
 86 |                 exp
 87 |     cexpIdentity var con const (CConstCase fc sc xs x) =
 88 |         cexpIdentity var Nothing Nothing sc
 89 |         && all altEq xs
 90 |         && maybeVarEq var con const x
 91 |     where
 92 |         altEq : CConstAlt vars -> Bool
 93 |         altEq (MkConstAlt c exp) = cexpIdentity var con (Just c) exp
 94 |     cexpIdentity _ _ (Just c1) (CPrimVal _ c2) = c1 == c2
 95 |     cexpIdentity _ _ Nothing (CPrimVal {}) = False
 96 |     cexpIdentity _ _ _ (CErased _) = False
 97 |     cexpIdentity _ _ _ (CCrash {}) = False
 98 |
 99 |     -- fused `all (cexpIdentity var con const)`
100 |     maybeVarEq : Var vars -> Maybe (Name, List (Var vars)) -> Maybe Constant -> Maybe (CExp vars) -> Bool
101 |     maybeVarEq _ _ _ Nothing = True
102 |     maybeVarEq var con const (Just exp) = cexpIdentity var con const exp
103 |
104 | checkIdentity : (fullName : Name) -> List (Var vars) -> CExp vars -> Nat -> Maybe Nat
105 | checkIdentity _ [] _ _ = Nothing
106 | checkIdentity fn (v :: vs) exp idx = if cexpIdentity fn idx v Nothing Nothing exp
107 |     then Just idx
108 |     else checkIdentity fn vs exp (S idx)
109 |
110 | calcIdentity : (fullName : Name) -> CDef -> Maybe Nat
111 | calcIdentity fn (MkFun args exp) = checkIdentity fn (Var.allVars args) exp Z
112 | calcIdentity _ _ = Nothing
113 |
114 | getArg : FC -> Nat -> (args : Scope) -> Maybe (CExp args)
115 | getArg _ _ [] = Nothing
116 | getArg fc Z (a :: _) = Just $ CLocal fc First
117 | getArg fc (S k) (_ :: as) = weaken <$> getArg fc k as
118 |
119 | idCDef : Nat -> CDef -> Maybe CDef
120 | idCDef idx (MkFun args exp) = MkFun args <$> getArg (getFC exp) idx args
121 | idCDef _ def = Just def
122 |
123 | export
124 | rewriteIdentityFlag : Ref Ctxt Defs => Name -> Core ()
125 | rewriteIdentityFlag fn = do
126 |     defs <- get Ctxt
127 |     Just (fnIdx, gdef) <- lookupCtxtExactI fn defs.gamma
128 |         | Nothing => pure ()
129 |     let Just flg@(Identity idx) = find isId gdef.flags
130 |         | _ => pure ()
131 |     log "compiler.identity" 5 $ "found identity flag for: "
132 |                               ++ show !(getFullName fn) ++ ", " ++ show idx
133 |                               ++ "\n\told def: " ++ show gdef.compexpr
134 |     let Just cdef = the _ $ gdef.compexpr >>= idCDef idx
135 |         | Nothing => pure ()
136 |     log "compiler.identity" 5 $ "\tnew def: " ++ show cdef
137 |     unsetFlag EmptyFC (Resolved fnIdx) flg -- other optimisations might mess with argument counts
138 |     setFlag EmptyFC (Resolved fnIdx) Inline
139 |     setCompiled (Resolved fnIdx) cdef
140 |   where
141 |     isId : DefFlag -> Bool
142 |     isId (Identity _) = True
143 |     isId _ = False
144 |
145 | export
146 | setIdentity : Ref Ctxt Defs => Name -> Core ()
147 | setIdentity fn = do
148 |     defs <- get Ctxt
149 |     Just (fnIdx, gdef) <- lookupCtxtExactI fn defs.gamma
150 |         | Nothing => pure ()
151 |     let Just idx = the _ $ gdef.compexpr >>= calcIdentity fn
152 |         | Nothing => pure ()
153 |     setFlag EmptyFC (Resolved fnIdx) (Identity idx)
154 |     rewriteIdentityFlag (Resolved fnIdx)
155 |