0 | module Language.Reflection.Unify.WithCompiler
  1 |
  2 | import Control.Monad.Either
  3 | import Control.Monad.Writer
  4 | import Control.Monad.Identity
  5 | import Data.DPair
  6 | import Data.Fin.Set
  7 | import Data.Vect
  8 | import Data.Vect.Quantifiers
  9 | import Data.SnocVect
 10 | import Data.SortedMap
 11 | import Decidable.Equality
 12 | import Language.Reflection
 13 | import Language.Reflection.Expr
 14 | import Language.Reflection.Logging
 15 | import Language.Reflection.Syntax
 16 | import Language.Reflection.Unify.Interface
 17 | import Language.Reflection.VarSubst
 18 |
 19 | %default total
 20 |
 21 | ||| Generate free variable name to index mapping
 22 | genNameToId :
 23 |   {freeVars : Nat} ->
 24 |   Vect freeVars FVData ->
 25 |   SortedMap Name $ Fin freeVars
 26 | genNameToId fvs =
 27 |   foldl (\acc, (i, fv) => insert fv.name i acc) empty (zip (allFins freeVars) fvs)
 28 |
 29 | ||| Generate free variable hole to index mapping
 30 | genHoleToId :
 31 |   {freeVars : Nat} ->
 32 |   Vect freeVars FVData ->
 33 |   SortedMap String $ Fin freeVars
 34 | genHoleToId fvs =
 35 |   foldl (\acc, (i, fv) => insert fv.holeName i acc) empty (zip (allFins freeVars) fvs)
 36 |
 37 | aMHImpl :
 38 |   {0 freeVars : Nat} ->
 39 |   MonadWriter (FinSet freeVars) m =>
 40 |   SortedMap String (Fin freeVars) ->
 41 |   TTImp ->
 42 |   m TTImp
 43 | aMHImpl h2Id h = do
 44 |   let IHole _ s = h
 45 |   | _ => pure h
 46 |   let Just id = lookup s h2Id
 47 |   | _ => pure h
 48 |   writer (h, insert id empty)
 49 |
 50 | ||| Generate a set of free variables whose holes appear in a TTImp
 51 | allMatchingHoles :
 52 |   {0 freeVars : Nat} ->
 53 |   SortedMap String (Fin freeVars) ->
 54 |   TTImp ->
 55 |   FinSet freeVars
 56 | allMatchingHoles h2Id t = execWriter $ mapMTTImp (aMHImpl h2Id) t
 57 |
 58 | fromPiInfo : Lazy t -> PiInfo t -> t
 59 | fromPiInfo _ (DefImplicit x) = x
 60 | fromPiInfo x _ = x
 61 |
 62 | ||| Generate a dependency map from unification output and hole-to-index mapping
 63 | genDeps :
 64 |   {freeVars : Nat} ->
 65 |   Vect freeVars FVData ->
 66 |   SortedMap String (Fin freeVars) ->
 67 |   Vect freeVars $ FVDeps freeVars
 68 | genDeps fvs h2Id =
 69 |   map
 70 |     (\fv =>
 71 |       MkFVDeps
 72 |         (allMatchingHoles h2Id fv.type)
 73 |         (fromMaybe empty $ allMatchingHoles h2Id <$> fv.value)
 74 |         (fromPiInfo empty $ allMatchingHoles h2Id <$> fv.piInfo)
 75 |     )
 76 |     fvs
 77 |
 78 | ||| Find free variables without value
 79 | genEmpties :
 80 |   {freeVars : Nat} ->
 81 |   Vect freeVars FVData ->
 82 |   FinSet freeVars
 83 | genEmpties fvs = foldl genEmpties' empty $ zip (allFins freeVars) fvs
 84 |   where
 85 |     genEmpties' : FinSet fv -> (Fin fv, FVData) -> FinSet fv
 86 |     genEmpties' set (i, fv) =
 87 |       if isNothing fv.value
 88 |          then insert i set
 89 |          else set
 90 |
 91 | ||| Generate a dependency graph based on free variable data
 92 | genDG :
 93 |   {freeVars : Nat} ->
 94 |   Vect freeVars FVData ->
 95 |   DependencyGraph
 96 | genDG fvs = do
 97 |   let h2Id = genHoleToId fvs
 98 |   MkDG freeVars fvs (genDeps fvs h2Id) (genEmpties fvs) (genNameToId fvs) h2Id
 99 |
100 | ||| Find all free variables that can be substituted
101 | canSub :
102 |   (dg : DependencyGraph) ->
103 |   FinSet dg.freeVars
104 | canSub dg =
105 |   flip difference dg.empties $
106 |     foldl
107 |       (\s, (i, deps) =>
108 |         if (flip difference dg.empties (mergeDeps deps)) == empty
109 |            then insert i s
110 |            else s)
111 |       Fin.Set.empty
112 |       $ zip (allFins dg.freeVars) dg.fvDeps
113 |
114 | subMatchingHolesImpl :
115 |   (dg : DependencyGraph) ->
116 |   FinSet dg.freeVars ->
117 |   TTImp ->
118 |   TTImp
119 | subMatchingHolesImpl dg fbs ih@(IHole _ h) =
120 |   case lookup h dg.holeToId of
121 |     Just id =>
122 |       if contains id fbs
123 |         then
124 |           let fv = index id dg.fvData
125 |           in fromMaybe ih fv.value
126 |         else ih
127 |     Nothing => ih
128 | subMatchingHolesImpl _ _ t = t
129 |
130 | ||| Substitute all holes matching free variables in set with their values
131 | subMatchingHoles :
132 |   (dg : DependencyGraph) ->
133 |   FinSet dg.freeVars ->
134 |   TTImp ->
135 |   TTImp
136 | subMatchingHoles dg fbs = mapTTImp $ subMatchingHolesImpl dg fbs
137 |
138 | ||| Substitute all free variables in set within other free variable's data
139 | fvSubMatching :
140 |   (dg: DependencyGraph) ->
141 |   FinSet dg.freeVars ->
142 |   FVData ->
143 |   FVData
144 | fvSubMatching dg canSub =
145 |   { type $= subMatchingHoles dg canSub
146 |   , value $= map $ subMatchingHoles dg canSub
147 |   , piInfo $= map $ subMatchingHoles dg canSub
148 |   }
149 |
150 | valDepsOfVar : (dg : DependencyGraph) -> Fin dg.freeVars -> FinSet dg.freeVars
151 | valDepsOfVar dg id = valueDeps $ index id dg.fvDeps
152 |
153 | valDepsOfVars : (dg : DependencyGraph) -> FinSet dg.freeVars -> FinSet dg.freeVars
154 | valDepsOfVars dg vars = foldl (\a, b => union (valDepsOfVar dg b) a) empty (toList vars)
155 |
156 | fvSubMatching' :
157 |   (dg: DependencyGraph) ->
158 |   FinSet dg.freeVars ->
159 |   (FVData, FVDeps dg.freeVars) ->
160 |   (FVData, FVDeps dg.freeVars)
161 | fvSubMatching' dg canSub (fvData, MkFVDeps tyDeps valDeps piInfoDeps) = do
162 |   let canSubTy = intersection tyDeps canSub
163 |   let canSubVal = intersection valDeps canSub
164 |   let canSubPiInfo = intersection piInfoDeps canSub
165 |   let tyAddDeps = valDepsOfVars dg canSubTy
166 |   let valAddDeps = valDepsOfVars dg canSubVal
167 |   let piInfoAddDeps = valDepsOfVars dg canSubPiInfo
168 |   let newTyDeps = union tyAddDeps (difference tyDeps canSubTy)
169 |   let newValDeps = union valAddDeps (difference valDeps canSubVal)
170 |   let newPiInfoDeps = union piInfoAddDeps (difference piInfoDeps canSubPiInfo)
171 |   (fvSubMatching dg canSub fvData, MkFVDeps newTyDeps newValDeps newPiInfoDeps)
172 |
173 |
174 | ||| Substitute all free variables in set within dependency graph
175 | doSub :
176 |   (dg : DependencyGraph) ->
177 |   FinSet dg.freeVars ->
178 |   DependencyGraph
179 | doSub dg canSub = do
180 |   let (newFvData, newFvDeps) = unzip $ fvSubMatching' dg canSub <$> zip dg.fvData dg.fvDeps
181 |   ({fvData := newFvData, fvDeps := newFvDeps} dg)
182 |
183 | subEmptiesTImpl : (dg : DependencyGraph) -> TTImp -> TTImp
184 | subEmptiesTImpl dg t@(IHole _ h) = do
185 |   let Just id = lookup h dg.holeToId
186 |   | _ => t
187 |   if contains id dg.empties
188 |     then
189 |       let fv = index id dg.fvData
190 |       in IVar EmptyFC fv.name
191 |     else t
192 | subEmptiesTImpl dg t = t
193 |
194 | subEmptiesT : DependencyGraph -> TTImp -> TTImp
195 | subEmptiesT dg = mapTTImp $ subEmptiesTImpl dg
196 |
197 | subEmptiesFV :
198 |   (dg: DependencyGraph) ->
199 |   FVData ->
200 |   FVData
201 | subEmptiesFV dg  =
202 |   { type $= subEmptiesT dg
203 |   , value $= map $ subEmptiesT dg
204 |   }
205 |
206 | ||| Substitute holes of empty free variables with their names
207 | subEmpties :
208 |   (dg : DependencyGraph) ->
209 |   DependencyGraph
210 | subEmpties dg = {fvData $= map $ subEmptiesFV dg} dg
211 |
212 | ||| Solve dependency graph
213 | solveDG :
214 |   Monad m =>
215 |   (dg : DependencyGraph) ->
216 |   m DependencyGraph
217 | solveDG dg = do
218 |   let cs = canSub dg
219 |   let False = null cs
220 |   | _ => pure dg
221 |   ds <- pure $ doSub dg cs
222 |   -- DG <= DS because cs is non-empty, and every doSub may shrink the set of possibly substitutable variables
223 |   -- If doSub can't shrink it, the dependency graph stays the same
224 |   if ds == dg
225 |      then pure dg
226 |      else solveDG $ assert_smaller dg ds
227 |
228 | ArgDPair : Type
229 | ArgDPair = Subset Arg IsNamedArg
230 |
231 | ||| Generate hole name for free variables
232 | genHoleNames :
233 |   Elaboration m =>
234 |   SnocVect l ArgDPair ->
235 |   m $ (SortedMap Name String, SnocVect l String)
236 | genHoleNames [<] = pure (empty, [<])
237 | genHoleNames (xs :< (Element fv isNamed)) = do
238 |   let n = argName fv
239 |   gs <- genSym $ show n
240 |   (others, others') <- genHoleNames xs
241 |   pure $ (insert n (show gs) others, others' :< show gs)
242 |
243 | ||| Build up dependent pair type for typechecking
244 | buildUpDPair : SnocVect l ArgDPair -> TTImp -> TTImp
245 | buildUpDPair [<] t = t
246 | buildUpDPair (xs :< (Element fv isNamed)) t =
247 |   buildUpDPair xs
248 |     `(Builtin.DPair.DPair
249 |       ~(fv.type)
250 |       ~(ILam EmptyFC MW ExplicitArg (Just $ argName fv) fv.type t))
251 |
252 | ||| Build up dependent pair value for typechecking
253 | buildUpTarget : SnocVect l (String, ArgDPair) -> TTImp -> TTImp
254 | buildUpTarget [<] t = t
255 | buildUpTarget (xs :< (s, _)) t =
256 |   buildUpTarget xs `((~(IHole EmptyFC s) ** ~t))
257 |
258 | extractFVData :
259 |   Elaboration m =>
260 |   MonadError (Maybe UnificationError) m =>
261 |   (t : Type) ->
262 |   t ->
263 |   Vect l ArgDPair ->
264 |   Vect l String ->
265 |   m $ Vect l (Name, TTImp, Maybe TTImp)
266 | extractFVData t v ((Element fv isNamed) :: xs) (hn :: hns) = do
267 |   case t of
268 |     DPair myTy dNext => do
269 |       let (vv ** vRest= v
270 |       quoteV <- quote vv
271 |       quoteT <- quote myTy
272 |       rest <- extractFVData (dNext vv) vRest xs hns
273 |       let retVal =
274 |         case quoteV of
275 |             IHole _ hh =>
276 |               if hh == hn then Nothing else Just quoteV
277 |             qv => Just qv
278 |       pure $ (argName fv, quoteT, retVal) :: rest
279 |     _ => do
280 |       qT <- quote t
281 |       throwError $ Just $ ExtractionError qT
282 | extractFVData t v [] [] = do
283 |   qT <- quote t
284 |   qV <- quote v
285 |   case t of
286 |     Equal x y =>
287 |       case qV of
288 |         INamedApp _ (INamedApp _ `(Builtin.Refl) _ _) _ _ => pure ()
289 |         _ => throwError $ Nothing
290 |     _ => throwError $ Just $ InternalError "DPairs don't correspond to each other. Should never occur."
291 |   pure []
292 |
293 | ||| Run unification
294 | unify' :
295 |   Elaboration m =>
296 |   MonadError (Maybe UnificationError) m =>
297 |   UnificationTask ->
298 |   m $ DependencyGraph
299 | unify' task = do
300 |   let 0 allFVsNamed = task.lhsAreNamed ++ task.rhsAreNamed
301 |   let allFreeVars = pushIn (task.lhsFreeVars ++ task.rhsFreeVars) allFVsNamed
302 |   let snocLFV : SnocVect task.lfv _ =
303 |     cast $ pushIn task.lhsFreeVars task.lhsAreNamed
304 |   let snocRFV : SnocVect task.rfv _ =
305 |     cast $ pushIn task.rhsFreeVars task.rhsAreNamed
306 |   (lhsNMap, lhsNames) <- genHoleNames snocLFV
307 |   (rhsNMap, rhsNames) <- genHoleNames snocRFV
308 |   let hole2N = IHole EmptyFC <$> mergeLeft lhsNMap rhsNMap
309 |   let allNames = lhsNames ++ rhsNames
310 |   -- Assemble the type, the value of which is all our free variables + proof of equality
311 |   let checkTargetType =
312 |     buildUpDPair snocLFV $
313 |       buildUpDPair snocRFV `(~(task.lhsExpr) ~=~ ~(task.rhsExpr))
314 |   -- Assemble the value (holes + Refl)
315 |   let checkTarget =
316 |     buildUpTarget (zip lhsNames snocLFV) $
317 |       buildUpTarget (zip rhsNames snocRFV) `(Refl)
318 |   logPoint DetailedDebug "unifyWithCompiler" [] "Target type: \{show checkTargetType}"
319 |   logPoint DetailedDebug "unifyWithCompiler" [] "Target value: \{show checkTarget}"
320 |   -- Instantiate target type
321 |   Just checkTargetType' : Maybe Type <-
322 |     try (Just <$> check checkTargetType) (pure Nothing)
323 |   | _ => throwError $ Just $ TargetTypeError checkTargetType
324 |   -- Run unification
325 |   Just checkTarget' : Maybe checkTargetType' <-
326 |     try (Just <$> check checkTarget) (pure Nothing)
327 |   | _ => throwError $ Just NoUnificationError
328 |   ctQuote <- quote checkTarget'
329 |   logPoint DetailedDebug "unifyWithCompiler" [] "Target value after quoting: \{show ctQuote}"
330 |   let vectNames = cast allNames
331 |   -- Extract unification results
332 |   uniResults <-
333 |     extractFVData checkTargetType' checkTarget' allFreeVars vectNames
334 |   logPoint DetailedDebug "unifyWithCompiler" [] "Raw unification results: \{show uniResults}"
335 |   -- Generate dependency graph
336 |   let allZipped = zip vectNames $ zip (task.lhsFreeVars ++ task.rhsFreeVars) uniResults
337 |   let dg = genDG $ makeFVData <$> allZipped
338 |   let dg = {fvData $= map {piInfo $= map $ substituteVariables hole2N}} dg
339 |   logPoint DetailedDebug "unifyWithCompiler" [] "Initial DG: \{show dg}"
340 |   let dg = subEmpties dg
341 |   logPoint DetailedDebug "unifyWithCompiler" [] "DG after subEmpties: \{show dg}"
342 |   solved <- solveDG dg
343 |   logPoint DetailedDebug "unifyWithCompiler" [] "Solved DG: \{show solved}"
344 |   pure solved
345 |
346 | ||| Run unification in a try block
347 | export
348 | unifyWithCompiler :
349 |   Elaboration m =>
350 |   MonadError (Maybe UnificationError) m =>
351 |   UnificationTask ->
352 |   m $ UnificationResult
353 | unifyWithCompiler task = do
354 |   let ret = runEitherT {m=Elab} {e=Maybe UnificationError} $ unify' task
355 |   let err = pure {f=Elab} $ Left $ Just CatastrophicError
356 |   rr <- try ret err
357 |   dg <- liftEither rr
358 |   ur <- pure $ finalizeDG task dg
359 |   logPoint DetailedDebug "unifyWithCompiler" [] "Unification result: \{show ur}"
360 |   pure ur
361 |
362 | ||| Run unification in a try block
363 | export
364 | unifyWithCompiler' :
365 |   Elaboration m =>
366 |   MonadError (Maybe UnificationError) m =>
367 |   UnificationTask ->
368 |   m $ UnificationResult
369 | unifyWithCompiler' task = do
370 |   dg <- unify' task
371 |   pure $ finalizeDG task dg
372 |
373 | export
374 | [UnifyWithCompiler]
375 | Elaboration m => CanUnify m where
376 |   unify = map cast . runEitherT {m} {e=Maybe UnificationError} . unifyWithCompiler
377 |