0 | module Deriving.DepTyCheck.Gen.ConsRecs
  1 |
  2 | import public Data.Alternative
  3 | import public Data.Fuel
  4 | import public Data.List.Ex
  5 | import public Data.List.Map
  6 | import public Data.Nat1
  7 | import public Data.SortedMap
  8 | import public Data.SortedMap.Extra
  9 | import public Data.SortedSet
 10 | import public Data.SortedSet.Extra
 11 |
 12 | import public Deriving.DepTyCheck.Gen.Signature
 13 | import public Deriving.DepTyCheck.Gen.Tuning
 14 |
 15 | import public Language.Reflection.Compat.TypeInfo
 16 | import public Language.Reflection.Logging
 17 |
 18 | import public Syntax.IHateParens.Function
 19 |
 20 | %default total
 21 |
 22 | ----------------------------------
 23 | --- Constructors recursiveness ---
 24 | ----------------------------------
 25 |
 26 | ||| Weight info of recursive constructors
 27 | public export
 28 | data RecWeightInfo : Type where
 29 |   SpendingFuel : ((leftFuelVarName : Name) -> TTImp) -> RecWeightInfo
 30 |   StructurallyDecreasing : (decrTy : TypeInfo) -> (wExpr : TTImp) -> RecWeightInfo
 31 |
 32 | public export
 33 | record ConWeightInfo where
 34 |   constructor MkConWeightInfo
 35 |   ||| Either a constant (for non-recursive) or a function returning weight info (for recursive)
 36 |   weight : Either Nat1 RecWeightInfo
 37 |
 38 | liftWeight1 : TTImp
 39 | liftWeight1 = `(Data.Nat1.one)
 40 |
 41 | export
 42 | reflectNat1 : Nat1 -> TTImp
 43 | reflectNat1 $ FromNat 1 = liftWeight1
 44 | reflectNat1 $ FromNat n = `(fromInteger ~(primVal $ BI $ cast n))
 45 |
 46 | export
 47 | isWeight1 : TTImp -> Bool
 48 | isWeight1 = (== liftWeight1)
 49 |
 50 | public export
 51 | weightExpr : ConWeightInfo -> Either TTImp ((leftFuelVarName : Name) -> TTImp)
 52 | weightExpr $ MkConWeightInfo $ Left n = Left $ reflectNat1 n
 53 | weightExpr $ MkConWeightInfo $ Right $ StructurallyDecreasing {wExpr, _} = Left wExpr
 54 | weightExpr $ MkConWeightInfo $ Right $ SpendingFuel e = Right e
 55 |
 56 | export
 57 | usedWeightFun : ConWeightInfo -> Maybe TypeInfo
 58 | usedWeightFun $ MkConWeightInfo $ Right $ StructurallyDecreasing {decrTy, _} = Just decrTy
 59 | usedWeightFun $ MkConWeightInfo $ Right $ SpendingFuel _ = Nothing
 60 | usedWeightFun $ MkConWeightInfo $ Left _ = Nothing
 61 |
 62 | record ConRec where
 63 |   constructor MkConRec
 64 |   constr    : Con
 65 |   conWeight : Either Nat1 (TTImp -> TTImp, SortedSet $ Fin constr.args.length)
 66 |            -- ^^^^^^                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 67 |            --    |                                   \- directly recursive args
 68 |            --    \- `Left` for non-recursive, `Right` for recursive constructor
 69 |
 70 | -- determine if this type is a nat-or-list-like data, i.e. one which we can measure for the probability
 71 | weightableTy : List ConRec -> Bool
 72 | weightableTy = any weightableCon where
 73 |   weightableCon : ConRec -> Bool
 74 |   weightableCon $ MkConRec _ $ Right (_, dra) = not $ null dra
 75 |   weightableCon $ MkConRec _ $ Left _         = False
 76 |
 77 | record TyConsRec where
 78 |   constructor MkTyConsRec
 79 |   typeInfo         : TypeInfo
 80 |   weightableTyArgs : SortedMap (Fin typeInfo.args.length) (TypeInfo, Name)
 81 |   constructors     : List ConRec
 82 |
 83 | export
 84 | record ConsRecs where
 85 |   constructor MkConsRecs
 86 |   consRecs : SortedMap Name TyConsRec
 87 |
 88 | Semigroup ConsRecs where
 89 |   MkConsRecs cw <+> MkConsRecs cw' = MkConsRecs $ cw `mergeLeft` cw'
 90 |
 91 | -------------------------------------
 92 | --- Getting (deriving) `ConsRecs` ---
 93 | -------------------------------------
 94 |
 95 | -- This is a workaround of some bad and not yet understood behaviour, leading to both compile- and runtime errors
 96 | removeNamedApps, workaroundFromNat : TTImp -> TTImp
 97 | removeNamedApps = mapTTImp $ \case INamedApp _ lhs _ _ => lhse => e
 98 | workaroundFromNat = mapTTImp $ \e => case fst $ unAppAny e of IVar _ `{Data.Nat1.FromNat} => removeNamedApps e_ => e
 99 |
100 | weightFunName : TypeInfo -> Name
101 | weightFunName ty = fromString "weight^\{show ty.name}"
102 |
103 | -- this is a workaround for Idris compiler bug #2983
104 | export
105 | interimNamesWrapper : Name -> Name
106 | interimNamesWrapper n = UN $ Basic "inter^<\{show n}>"
107 |
108 | -- This function is moved out from `getConsRecs` to reduce the closure of the returned function
109 | deriveW : TyConsRec -> Maybe (Decl, Decl)
110 | deriveW $ MkTyConsRec ty _ cons = do
111 |   guard $ weightableTy cons -- continue only when this type has structurally decreasing argument
112 |   let weightFunName = weightFunName ty
113 |
114 |   let inTyArg = arg $ foldl (\f, n => namedApp f n $ var n) .| var ty.name .| mapMaybe name ty.args
115 |   let funSig = export' weightFunName $ piAll `(Data.Nat1.Nat1) $ map {piInfo := ImplicitArg} ty.args ++ [inTyArg]
116 |
117 |   let wClauses = cons <&> \(MkConRec con e) => do
118 |     let wArgs = either (const empty) snd e
119 |     let lhsArgs : List (_, _) = mapI con.args $ \idx, arg => appArg arg <$> if contains idx wArgs && arg.count == MW
120 |                                   then let bindName = UN $ Basic "arg^\{show idx}" in (Just bindName, bindVar bindName)
121 |                                   else (Nothing, implicitTrue)
122 |     let callSelfOn : Name -> TTImp
123 |         callSelfOn n = var weightFunName .$ var n
124 |     patClause (var weightFunName .$ (reAppAny .| var con.name .| snd <$> lhsArgs)) $ case mapMaybe fst lhsArgs of
125 |       []  => liftWeight1
126 |       [x] => `(succ ~(callSelfOn x))
127 |       xs  => `(succ $ Prelude.concat @{Maximum} ~(liftList' $ xs <&> callSelfOn))
128 |
129 |   pure (funSig, def weightFunName wClauses)
130 |
131 | getAppVar : TTImp -> Maybe Name
132 | getAppVar e = case fst $ unAppAny e of IVar _ n => Just n_ => Nothing
133 |
134 | -- TODO to think of better placement for this function; this anyway is intended to be called from the derived code.
135 | public export
136 | leftDepth : Fuel -> Nat1
137 | leftDepth = go 1 where
138 |   go : Nat1 -> Fuel -> Nat1
139 |   go n Dry      = n
140 |   go n (More x) = go (succ n) x
141 |
142 | -- This function is moved out from `getConsRecs` to reduce the closure of the returned function
143 | finCR : NamesInfoInTypes =>
144 |         (tyCR : TyConsRec) ->
145 |         (givenTyArgs : SortedSet $ Fin tyCR.typeInfo.args.length) ->
146 |         List (Con, ConWeightInfo)
147 | finCR (MkTyConsRec ti wTyArgs cons) givenTyArgs = do
148 |   let wTyArgs = wTyArgs `intersectionMap` givenTyArgs
149 |   cons <&> \(MkConRec con e) => (con,) $ MkConWeightInfo $ e <&> \(wMod, directRecConArgs) => do
150 |     let conRetTyArgs = snd $ unAppAny con.type
151 |     let directRecConArgArgs = flip mapMaybe con.args $ \conArg => case unAppAny conArg.type of (conArgTy, conArgArgs) => do
152 |                                 toMaybe (getAppVar conArgTy == Just ti.name) conArgArgs
153 |     -- default behaviour, spend fuel, weight proportional to fuel
154 |     fromMaybe (SpendingFuel $ wMod . app `(Deriving.DepTyCheck.Gen.ConsRecs.leftDepth) . var) $ do
155 |     -- work only with given args
156 |     -- fail-fast if no given weightable args
157 |     guard $ not $ null wTyArgs
158 |     -- If for any weightable type argument (in `wTyArgs`) there exists a directly recursive constructor arg (in `directRecConArgs`) that has
159 |     -- this type argument strictly decreasing, we consider this constructor to be non-fuel-spending.
160 |     let conArgNames = SortedSet.fromList $ mapMaybe name con.args
161 |     (decrTy, weightExpr) <- foldAlt' wTyArgs.asList $ \(wTyArg, weightTy, weightArgName) => map (weightTy,) $ do
162 |       let wTyArg = finToNat wTyArg
163 |       conRetTyArg <- getExpr <$> getAt wTyArg conRetTyArgs
164 |       guard $ isJust $ lookupCon =<< getAppVar conRetTyArg
165 |       let freeNamesLessThanOrig = allVarNames' conRetTyArg `intersection` conArgNames
166 |       foldAlt' directRecConArgArgs $ \conArgArgs => do
167 |         getAt wTyArg conArgArgs >>= getAppVar . getExpr >>= \arg => toMaybe .| contains arg freeNamesLessThanOrig .|
168 |           var (weightFunName weightTy) .$ var (interimNamesWrapper weightArgName)
169 |     pure $ StructurallyDecreasing decrTy $ wMod weightExpr
170 |
171 | weightableTyArgs : (consRecs : SortedMap Name (TypeInfo, List ConRec)) -> (ti : TypeInfo) -> SortedMap (Fin ti.args.length) (TypeInfo, Name)
172 | weightableTyArgs consRecs ti = fromList $ flip List.mapMaybe ti.args.withIdx $ \(idx, ar) =>
173 |   getAppVar ar.type >>= lookup' consRecs >>= \(wti, cons) => guard (weightableTy cons) >> (idx, wti,) <$> ar.name
174 |
175 | -- Builds `ConsRecs` only for the given types, assuming that given `NamesInfoInTypes` contains info for them and their dependencies
176 | getConsRecsFor : NamesInfoInTypes => Elaboration m => (desiredTypes : ListMap Name TypeInfo) -> m ConsRecs
177 | getConsRecsFor desiredTypes = do
178 |   consRecs <- for (toSortedMap desiredTypes) $ \targetType => logBounds DetailedTrace "deptycheck.derive.consRec" [targetType] $ do
179 |     crsForTy <- for targetType.cons $ \con => do
180 |       tuneImpl <- search $ ProbabilityTuning con.name
181 |       w : Either Nat1 (TTImp -> TTImp, SortedSet $ Fin con.args.length) <- case isRecursive {containingType=Just targetType} con of
182 |         --             ^^^^^^^^^^^^^^  ^^^^^^^^^^^^^^^ <- set of directly recursive constructor arguments
183 |         --                    \------ Modifier of the standard weight expression
184 |         False => pure $ Left $ maybe one (\impl => tuneWeight @{impl} one) tuneImpl
185 |         True  => Right <$> do
186 |           fuelWeightExpr <- case tuneImpl of
187 |             Nothing   => pure id
188 |             Just impl => quote (tuneWeight @{impl}) <&> \wm, expr => workaroundFromNat $ wm `applySyn` expr
189 |           let directlyRecArgs : List $ Fin con.args.length := flip mapMaybe con.args.withIdx $ \idxarg => do
190 |             argTy <- getAppVar (snd idxarg).type
191 |             whenT .| argTy == targetType.name .| fst idxarg
192 |           when (not $ null directlyRecArgs) $
193 |             logPoint FineDetails "deptycheck.derive.consRec" [targetType, con]
194 |               "- directly recursive args: \{show $ finToNat <$> directlyRecArgs}"
195 |           pure (fuelWeightExpr, fromList directlyRecArgs)
196 |       pure $ MkConRec con w
197 |     pure (targetType, crsForTy)
198 |   let 0 _ : SortedMap Name (TypeInfo, List ConRec) := consRecs
199 |
200 |   pure $ MkConsRecs $ mapWithKey' consRecs $ \tyName, (ti, cons) => do
201 |     MkTyConsRec ti (weightableTyArgs consRecs ti) cons
202 |
203 | export
204 | getConsRecs : NamesInfoInTypes => Elaboration m => m ConsRecs
205 | getConsRecs = getConsRecsFor knownTypes
206 |
207 | export
208 | lookupConsWithWeight : ConsRecs => NamesInfoInTypes => GenSignature -> Maybe $ List (Con, ConWeightInfo)
209 | lookupConsWithWeight @{MkConsRecs crs} sig = do
210 |   cr <- lookup sig.targetType.name crs
211 |   let Yes prf = decEq cr.typeInfo.args.length sig.targetType.args.length | No _ => Nothing
212 |   pure $ finCR cr $ rewrite prf in sig.givenParams
213 |
214 | export
215 | deriveWeightingFun : ConsRecs => TypeInfo -> Maybe (Decl, Decl)
216 | deriveWeightingFun @{MkConsRecs crs} ti = lookup ti.name crs >>= deriveW
217 |
218 | export
219 | isTypeKnown : ConsRecs => TypeInfo -> Bool
220 | isTypeKnown @{MkConsRecs crs} ti = isJust $ lookup ti.name crs
221 |
222 | -- Having a `ConsRecs` being built from the given `NamesInfoInTypes`,
223 | -- it'll get the updated `NamesInfoInTypes` and a `ConsRecs` equivalent to those being built from this `NamesInfoInTypes`, but more effective.
224 | export
225 | updateNamesAndConsRecs : NamesInfoInTypes => ConsRecs => Elaboration m => List TypeInfo -> m (NamesInfoInTypes, ConsRecs)
226 | updateNamesAndConsRecs @{niit} @{crs} tis = do
227 |   newNiit <- logBounds Trace "deptycheck.derive.namesInfo.update" [] $ enrichNamesInfoInTypes tis niit
228 |   newCr <- logBounds Trace "deptycheck.derive.consRec.update" [] $ map (crs <+>) $ getConsRecsFor @{newNiit} $ fromList $ tis <&> \ti => (ti.name, ti)
229 |   pure (newNiit, newCr)
230 |