0 | module TTImp.Elab.Record
  1 |
  2 | import Core.Env
  3 | import Core.Metadata
  4 | import Core.Unify
  5 | import Core.Value
  6 |
  7 | import Idris.REPL.Opts
  8 | import Idris.Syntax
  9 |
 10 | import TTImp.Elab.Check
 11 | import TTImp.Elab.Delayed
 12 | import TTImp.TTImp
 13 |
 14 | import Data.SortedSet
 15 |
 16 | %default covering
 17 |
 18 | getRecordType : Env Term vars -> NF vars -> Maybe Name
 19 | getRecordType env (NTCon _ n _ _) = Just n
 20 | getRecordType env _ = Nothing
 21 |
 22 | getNames : {auto c : Ref Ctxt Defs} -> Defs -> ClosedNF -> Core $ SortedSet Name
 23 | getNames defs (NApp _ hd args)
 24 |     = do eargs <- traverse (evalClosure defs . snd) args
 25 |          pure $ nheadNames hd `union` concat !(traverse (getNames defs) eargs)
 26 |   where
 27 |     nheadNames : NHead Scope.empty -> SortedSet Name
 28 |     nheadNames (NRef Bound n) = singleton n
 29 |     nheadNames _ = empty
 30 | getNames defs (NDCon _ _ _ _ args)
 31 |     = do eargs <- traverse (evalClosure defs . snd) args
 32 |          pure $ concat !(traverse (getNames defs) eargs)
 33 | getNames defs (NTCon _ _ _ args)
 34 |   = do eargs <- traverse (evalClosure defs . snd) args
 35 |        pure $ concat !(traverse (getNames defs) eargs)
 36 | getNames defs (NDelayed _ _ tm) = getNames defs tm
 37 | getNames {} = pure empty
 38 |
 39 | data Rec : Type where
 40 |      Field : Maybe Name -> -- implicit argument name, if any
 41 |              String -> RawImp -> Rec -- field name on left, value on right
 42 |      Constr : Maybe Name -> -- implicit argument name, if any
 43 |               Name -> List (String, Rec) -> Rec
 44 |
 45 | covering
 46 | Show Rec where
 47 |   show (Field mn n ty)
 48 |       = "Field " ++ show mn ++ "; " ++ show n ++ " : " ++ show ty
 49 |   show (Constr mn n args)
 50 |       = "Constr " ++ show mn ++ " " ++ show n ++ " " ++ show args
 51 |
 52 | toLHS' : FC -> Rec -> (Maybe Name, RawImp)
 53 | toLHS' loc (Field mn@(Just _) n _)
 54 |     = (mn, IAs loc (virtualiseFC loc) UseRight (UN $ Basic n) (Implicit loc True))
 55 | toLHS' loc (Field mn n _) = (mn, IBindVar (virtualiseFC loc) (UN $ Basic n))
 56 | toLHS' loc (Constr mn con args)
 57 |     = let args' = map (toLHS' loc . snd) args in
 58 |           (mn, gapply (IVar loc con) args')
 59 |
 60 | toLHS : FC -> Rec -> RawImp
 61 | toLHS fc r = snd (toLHS' fc r)
 62 |
 63 | toRHS' : FC -> Rec -> (Maybe Name, RawImp)
 64 | toRHS' loc (Field mn _ val) = (mn, val)
 65 | toRHS' loc (Constr mn con args)
 66 |     = let args' = map (toRHS' loc . snd) args in
 67 |           (mn, gapply (IVar loc con) args')
 68 |
 69 | toRHS : FC -> Rec -> RawImp
 70 | toRHS fc r = snd (toRHS' fc r)
 71 |
 72 | findConName : Defs -> Name -> Core (Maybe Name)
 73 | findConName defs tyn
 74 |     = case !(lookupDefExact tyn (gamma defs)) of
 75 |            Just (TCon _ _ _ _ _ (Just [con]) _) => pure (Just con)
 76 |            _ => pure Nothing
 77 |
 78 | findFieldsAndTypeArgs : {auto c : Ref Ctxt Defs} ->
 79 |                         Defs -> Name ->
 80 |                         Core $ Maybe (List (String, Maybe Name, Maybe Name), SortedSet Name)
 81 | findFieldsAndTypeArgs defs con
 82 |     = case !(lookupTyExact con (gamma defs)) of
 83 |            Just t => pure (Just !(getExpNames empty [] !(nf defs Env.empty t)))
 84 |            _ => pure Nothing
 85 |   where
 86 |     getExpNames : SortedSet Name ->
 87 |                   List (String, Maybe Name, Maybe Name) ->
 88 |                   ClosedNF ->
 89 |                   Core (List (String, Maybe Name, Maybe Name), SortedSet Name)
 90 |     getExpNames names expNames (NBind fc x (Pi _ _ p ty) sc)
 91 |         = do let imp = case p of
 92 |                             Explicit => Nothing
 93 |                             _ => Just x
 94 |              nfty <- evalClosure defs ty
 95 |              let names = !(getNames defs nfty) `union` names
 96 |              let expNames = (nameRoot x, imp, getRecordType Env.empty nfty) :: expNames
 97 |              getExpNames names expNames !(sc defs (toClosure defaultOpts Env.empty (Ref fc Bound x)))
 98 |     getExpNames names expNames nfty = pure (reverse expNames, (!(getNames defs nfty) `union` names))
 99 |
100 | genFieldName : {auto u : Ref UST UState} ->
101 |                String -> Core String
102 | genFieldName root
103 |     = do ust <- get UST
104 |          put UST ({ nextName $= (+1) } ust)
105 |          pure (root ++ show (nextName ust))
106 |
107 | -- There's probably a generic version of this in the prelude isn't
108 | -- there?
109 | replace : String -> Rec ->
110 |           List (String, Rec) -> List (String, Rec)
111 | replace k v [] = []
112 | replace k v ((k', v') :: vs)
113 |     = if k == k'
114 |          then ((k, v) :: vs)
115 |          else ((k', v') :: replace k v vs)
116 |
117 | findPath : {auto c : Ref Ctxt Defs} ->
118 |            {auto u : Ref UST UState} ->
119 |            FC -> List String -> List String -> Maybe Name ->
120 |            (String -> RawImp) ->
121 |            Rec -> Core Rec
122 | findPath loc [] full tyn val (Field mn lhs _) = pure (Field mn lhs (val lhs))
123 | findPath loc [] full tyn val rec
124 |    = throw (IncompatibleFieldUpdate loc full)
125 | findPath loc (p :: ps) full Nothing val (Field mn n v)
126 |    = throw (NotRecordField loc p Nothing)
127 | findPath loc (p :: ps) full (Just tyn) val (Field mn n v)
128 |    = do defs <- get Ctxt
129 |         Just con <- findConName defs tyn
130 |              | Nothing => throw (NotRecordType loc tyn)
131 |         Just (fs, tyArgs) <- findFieldsAndTypeArgs defs con
132 |              | Nothing => throw (NotRecordType loc tyn)
133 |         args <- mkArgs fs tyArgs
134 |         let rec' = Constr mn con args
135 |         findPath loc (p :: ps) full (Just tyn) val rec'
136 |   where
137 |     mkArgs : List (String, Maybe Name, Maybe Name) ->
138 |              SortedSet Name ->
139 |              Core (List (String, Rec))
140 |     mkArgs [] _ = pure []
141 |     mkArgs ((p, imp, _) :: ps) tyArgs
142 |         = do fldn <- genFieldName p
143 |              args' <- mkArgs ps tyArgs
144 |              -- If other types depend on that implicit argument, leave it as _ by default
145 |              let arg = case (flip contains tyArgs) <$> imp of
146 |                   Just True => Implicit loc False
147 |                   _ => IVar (virtualiseFC loc) (UN $ Basic fldn)
148 |              pure ((p, Field imp fldn arg) :: args')
149 |
150 | findPath loc (p :: ps) full tyn val (Constr mn con args)
151 |    = do let Just prec = lookup p args
152 |                  | Nothing => throw (NotRecordField loc p tyn)
153 |         defs <- get Ctxt
154 |         Just (fs, _) <- findFieldsAndTypeArgs defs con
155 |              | Nothing => pure (Constr mn con args)
156 |         let Just (imp, mfty) = lookup p fs
157 |                  | Nothing => throw (NotRecordField loc p tyn)
158 |         prec' <- findPath loc ps full mfty val prec
159 |         pure (Constr mn con (replace p prec' args))
160 |
161 | getSides : {auto c : Ref Ctxt Defs} ->
162 |            {auto u : Ref UST UState} ->
163 |            FC -> IFieldUpdate -> Name -> RawImp -> Rec ->
164 |            Core Rec
165 | getSides loc (ISetField path val) tyn orig rec
166 |    -- update 'rec' so that 'path' is accessible on the lhs and rhs,
167 |    -- then set the path on the rhs to 'val'
168 |    = findPath loc path path (Just tyn) (const val) rec
169 | getSides loc (ISetFieldApp path val) tyn orig rec
170 |    = findPath loc path path (Just tyn)
171 |       (\n => apply val [IVar (virtualiseFC loc) (UN $ Basic n)]) rec
172 |
173 | getAllSides : {auto c : Ref Ctxt Defs} ->
174 |               {auto u : Ref UST UState} ->
175 |               FC -> List IFieldUpdate -> Name ->
176 |               RawImp -> Rec ->
177 |               Core Rec
178 | getAllSides loc [] tyn orig rec = pure rec
179 | getAllSides loc (u :: upds) tyn orig rec
180 |     = getAllSides loc upds tyn orig !(getSides loc u tyn orig rec)
181 |
182 | checkForDuplicates :
183 |   List IFieldUpdate ->
184 |   (seen, dups : SortedSet (List String)) ->
185 |   SortedSet (List String)
186 | checkForDuplicates [] seen dups = dups
187 | checkForDuplicates (x :: xs) seen dups
188 |   = let path = getFieldUpdatePath x
189 |         dups = ifThenElse (contains path seen) (insert path dups) dups
190 |     in checkForDuplicates xs (insert path seen) dups
191 |
192 | -- Convert the collection of high level field accesses into a case expression
193 | -- which does the updates all in one go
194 | export
195 | recUpdate : {vars : _} ->
196 |             {auto c : Ref Ctxt Defs} ->
197 |             {auto u : Ref UST UState} ->
198 |             RigCount -> ElabInfo -> FC ->
199 |             NestedNames vars -> Env Term vars ->
200 |             List IFieldUpdate ->
201 |             (rec : RawImp) -> (grecty : Glued vars) ->
202 |             Core RawImp
203 | recUpdate rigc elabinfo iloc nest env flds rec grecty
204 |       = do let dups = checkForDuplicates flds empty empty
205 |            unless (null dups) $
206 |              throw (DuplicatedRecordUpdatePath iloc $ Prelude.toList dups)
207 |            defs <- get Ctxt
208 |            rectynf <- getNF grecty
209 |            let Just rectyn = getRecordType env rectynf
210 |                     | Nothing => throw (RecordTypeNeeded iloc env)
211 |            fldn <- genFieldName "__fld"
212 |            sides <- getAllSides iloc flds rectyn rec
213 |                                 (Field Nothing fldn (IVar vloc (UN $ Basic fldn)))
214 |            pure $ ICase vloc [] rec (Implicit vloc False) [mkClause sides]
215 |   where
216 |     vloc : FC
217 |     vloc = virtualiseFC iloc
218 |
219 |     mkClause : Rec -> ImpClause
220 |     mkClause rec = PatClause vloc (toLHS vloc rec) (toRHS vloc rec)
221 |
222 | needType : Error -> Bool
223 | needType (RecordTypeNeeded {}) = True
224 | needType (InType _ _ err) = needType err
225 | needType (InCon _ err) = needType err
226 | needType (InLHS _ _ err) = needType err
227 | needType (InRHS _ _ err) = needType err
228 | needType (WhenUnifying _ _ _ _ _ err) = needType err
229 | needType _ = False
230 |
231 | export
232 | checkUpdate : {vars : _} ->
233 |               {auto c : Ref Ctxt Defs} ->
234 |               {auto m : Ref MD Metadata} ->
235 |               {auto u : Ref UST UState} ->
236 |               {auto e : Ref EST (EState vars)} ->
237 |               {auto s : Ref Syn SyntaxInfo} ->
238 |               {auto o : Ref ROpts REPLOpts} ->
239 |               RigCount -> ElabInfo ->
240 |               NestedNames vars -> Env Term vars ->
241 |               FC -> List IFieldUpdate -> RawImp -> Maybe (Glued vars) ->
242 |               Core (Term vars, Glued vars)
243 | checkUpdate rig elabinfo nest env fc upds rec expected
244 |     = do recty <- case expected of
245 |                        Just ret => pure ret
246 |                        _ => do (_, ty) <- checkImp rig elabinfo
247 |                                                    nest env rec Nothing
248 |                                pure ty
249 |          let solvemode = case elabMode elabinfo of
250 |                               InLHS c => inLHS
251 |                               _ => inTerm
252 |          delayOnFailure fc rig env (Just recty) needType RecordUpdate $
253 |            \delayed =>
254 |              do solveConstraints solvemode Normal
255 |                 exp <- getTerm recty
256 |                 -- We can't just use the old NF on the second attempt,
257 |                 -- because we might know more now, so recalculate it
258 |                 let recty' = if delayed
259 |                                 then gnf env exp
260 |                                 else recty
261 |                 logGlueNF "elab.record" 5 (show delayed ++ " record type " ++ show rec) env recty'
262 |                 rcase <- recUpdate rig elabinfo fc nest env upds rec recty'
263 |                 log "elab.record" 5 $ "Record update: " ++ show rcase
264 |                 check rig elabinfo nest env rcase expected
265 |