0 | module Compiler.Interpreter.VMCode
  1 |
  2 | import Core.Primitives
  3 | import Core.Value
  4 |
  5 | import Compiler.Common
  6 | import Compiler.VMCode
  7 |
  8 | import Idris.Syntax
  9 |
 10 | import Data.IOArray
 11 | import Data.Vect
 12 | import Libraries.Data.NameMap
 13 |
 14 | public export
 15 | data Object : Type where
 16 |     Closure : (predMissing : Nat) -> (args : SnocList Object) -> Name -> Object
 17 |     Constructor : (tag : Either Int Name) -> (args : List Object) -> Object
 18 |     Const : Constant -> Object
 19 |     Null : Object
 20 |
 21 | showType : Object -> String
 22 | showType (Closure {}) = "Closure"
 23 | showType (Constructor {}) = "Constructor"
 24 | showType (Const {}) = "Constant"
 25 | showType Null = "Null"
 26 |
 27 | mutual
 28 |     showSep : Nat -> List Object -> String
 29 |     showSep k [] = ""
 30 |     showSep k [o] = showDepth k o
 31 |     showSep k (o :: os) = showDepth k o ++ ", " ++ showSep k os
 32 |
 33 |     showDepth : Nat -> Object -> String
 34 |     showDepth (S k) (Closure mis args fn) = show fn ++ "-" ++ show mis ++ "(" ++ showSep k (args <>> []) ++ ")"
 35 |     showDepth (S k) (Constructor (Left t) args) = "tag" ++ show t ++ "(" ++ showSep k args ++ ")"
 36 |     showDepth (S k) (Const c) = show c
 37 |     showDepth _ obj = showType obj
 38 |
 39 | Show Object where
 40 |     show = showDepth 5
 41 |
 42 | data State : Type where
 43 | record InterpState where
 44 |     constructor MkInterpState
 45 |     defs : NameMap VMDef
 46 |     locals : IOArray Object
 47 |     returnObj : Maybe Object
 48 |
 49 | initInterpState : List (Name, VMDef) -> Core InterpState
 50 | initInterpState defsList = do
 51 |     let defs = fromList defsList
 52 |     locals <- coreLift $ newArray 0
 53 |     let returnObj = Nothing
 54 |     pure $ MkInterpState defs locals returnObj
 55 |
 56 | 0
 57 | Stack : Type
 58 | Stack = List Name
 59 |
 60 | interpError : Ref State InterpState => Stack -> String -> Core a
 61 | interpError stk msg = do
 62 |     MkInterpState _ ls ret <- get State
 63 |     lsList <- coreLift $ toList ls
 64 |     throw $ InternalError $ "Interpreter Error in " ++ show (take 10 stk) ++ ":\n" ++ msg
 65 |         ++ "\n\nlocals:\n" ++ showWithIndex lsList
 66 |         ++ "\nreturn:\n  " ++ show ret
 67 |   where
 68 |     showWithIndex : forall a. {default 0 idx : Nat} -> Show a => List a -> String
 69 |     showWithIndex {idx} [] = ""
 70 |     showWithIndex {idx} (x :: xs) = "  " ++ show idx ++ ": " ++ show x ++ "\n" ++ showWithIndex {idx = S idx} xs
 71 |
 72 | getReg : Ref State InterpState => Stack -> Reg -> Core Object
 73 | getReg stk (Loc i) = do
 74 |     ls <- locals <$> get State
 75 |     objm <- coreLift $ readArray ls i
 76 |     case objm of
 77 |         Just obj => pure obj
 78 |         Nothing =>
 79 |             interpError stk $ "Missing local " ++ show i
 80 | getReg stk RVal = do
 81 |     objm <- returnObj <$> get State
 82 |     case objm of
 83 |         Just obj => pure obj
 84 |         Nothing => interpError stk "Missing returnObj val"
 85 | getReg stk Discard = pure Null
 86 |
 87 | setReg : Ref State InterpState => Stack -> Reg -> Object -> Core ()
 88 | setReg stk RVal obj = update State $ { returnObj := Just obj }
 89 | setReg stk (Loc i) obj = do
 90 |     ls <- locals <$> get State
 91 |     when (i >= max ls) $ interpError stk $ "Attempt to set register: " ++ show i ++ ", size of locals: " ++ show (max ls)
 92 |     coreLift_ $ writeArray ls i obj
 93 | setReg stk Discard _ = pure ()
 94 |
 95 | saveLocals : Ref State InterpState => Core a -> Core a
 96 | saveLocals act = do
 97 |     st <- get State
 98 |     x <- act
 99 |     put State st
100 |     pure x
101 |
102 | total
103 | indexMaybe : List a -> Int -> Maybe a
104 | indexMaybe [] _ = Nothing
105 | indexMaybe (x :: xs) idx = if idx <= 0 then Just x else indexMaybe xs (idx - 1)
106 |
107 | callPrim : Ref State InterpState => Stack -> PrimFn ar -> Vect ar Object -> Core Object
108 | callPrim stk BelieveMe [_, _, obj] = pure obj
109 | callPrim stk fn args = case the (Either Object (Vect ar Constant)) $ traverse getConst args of
110 |     Right args' => case getOp {vars=Scope.empty} fn (NPrimVal EmptyFC <$> args') of
111 |         Just (NPrimVal _ res) => pure $ Const res
112 |         _ => interpError stk $ "OP: Error calling " ++ show (opName fn) ++ " with operands: " ++ show args'
113 |     Left obj => interpError stk $ "OP: Expected Constant, found " ++ showType obj
114 |   where
115 |     getConst : Object -> Either Object Constant
116 |     getConst (Const c) = Right c
117 |     getConst o = Left o
118 |
119 | NS_UN : Namespace -> String -> Name
120 | NS_UN ns un = NS ns (UN $ Basic un)
121 |
122 | argError : Ref State InterpState => Stack -> Vect h Object -> Core a
123 | argError stk obj = interpError stk $ "Unexpected arguments: " ++ foldMap ((" " ++) . showDepth 1) obj
124 |
125 | unit : Object
126 | unit = Const (I 0)
127 |
128 | ioRes : Object -> Object
129 | ioRes obj = obj -- ioRes is a newtype -- Constructor (Left 0) [Const WorldVal, obj]
130 |
131 | -- TODO: add more?
132 | knownForeign : NameMap (ar ** (Ref State InterpState => Stack -> Vect ar Object -> Core Object))
133 | knownForeign = fromList
134 |     [ (NS_UN ioNS "prim__putChar", (2 ** prim_putChar))
135 |     , (NS_UN ioNS "prim__getChar", (1 ** prim_getChar))
136 |     , (NS_UN ioNS "prim__getStr", (1 ** prim_getStr))
137 |     , (NS_UN ioNS "prim__putStr", (2 ** prim_putStr))
138 |     ]
139 |   where
140 |     -- %MkWorld should not be matched on
141 |     -- however a value of type %World should only be %MkWorld or and erased value
142 |     world : Ref State InterpState => Stack -> Object -> Core ()
143 |     world stk Null = pure ()
144 |     world stk (Const WorldVal) = pure ()
145 |     world stk o = interpError stk $ "expected %MkWorld or Null, got \{show o}"
146 |
147 |     prim_putChar : Ref State InterpState => Stack -> Vect 2 Object -> Core Object
148 |     prim_putChar stk [Const (Ch c), w] = world stk w *> (ioRes unit <$ coreLift_ (putChar c))
149 |     prim_putChar stk as = argError stk as
150 |
151 |     prim_getChar : Ref State InterpState => Stack -> Vect 1 Object -> Core Object
152 |     prim_getChar stk [w] = world stk w *> (ioRes . Const . Ch <$> coreLift getChar)
153 |     prim_getChar stk as = argError stk as
154 |
155 |     prim_getStr : Ref State InterpState => Stack -> Vect 1 Object -> Core Object
156 |     prim_getStr stk [w] = world stk w *> (ioRes . Const . Str <$> coreLift getLine)
157 |     prim_getStr stk as = argError stk as
158 |
159 |     prim_putStr : Ref State InterpState => Stack -> Vect 2 Object -> Core Object
160 |     prim_putStr stk [Const (Str s), w] = world stk w *> (ioRes unit <$ coreLift_ (putStr s))
161 |     prim_putStr stk as = argError stk as
162 |
163 | knownExtern : NameMap (ar ** (Ref State InterpState => Stack -> Vect ar Object -> Core Object))
164 | knownExtern = empty
165 |
166 | beginFunction : Ref State InterpState => List (Int, Object) -> List VMInst -> Int -> Core (List VMInst)
167 | beginFunction args (DECLARE (Loc i) :: is) maxLoc = beginFunction args is (Prelude.max i maxLoc)
168 | beginFunction args (DECLARE _ :: is) maxLoc = beginFunction args is maxLoc
169 | beginFunction args (START :: is) maxLoc = do
170 |     locals <- coreLift $ newArray (maxLoc + 1)
171 |     traverse_ (\(idx, arg) => coreLift $ writeArray locals idx arg) args
172 |     update State $ { locals := locals, returnObj := Nothing }
173 |     pure is
174 | beginFunction args is maxLoc = pure is
175 |
176 | parameters {auto c : Ref Ctxt Defs}
177 |   mutual
178 |     step : Stack -> Ref State InterpState => VMInst -> Core ()
179 |     step stk (DECLARE _) = pure ()
180 |     step stk START = pure ()
181 |     step stk (ASSIGN target val) = do
182 |         valObj <- getReg stk val
183 |         setReg stk target valObj
184 |     step stk (MKCON target tag args) = do
185 |         argObjs <- traverse (getReg stk) args
186 |         setReg stk target $ Constructor tag argObjs
187 |     step stk (MKCLOSURE target fn missing args) = do
188 |         argObjs <- traverse (getReg stk) args
189 |         setReg stk target $ Closure (pred missing) ([<] <>< argObjs) fn
190 |     step stk (MKCONSTANT target c) = setReg stk target $ Const c
191 |     step stk (APPLY target fn arg) = do
192 |         fnObj <- getReg stk fn
193 |         argObj <- getReg stk arg
194 |         case fnObj of
195 |             Closure Z args fn => do
196 |                 res <- callFunc stk fn (args <>> [argObj])
197 |                 setReg stk target res
198 |             Closure (S k) args fn => setReg stk target $ Closure k (args :< argObj) fn
199 |             obj => interpError stk $ "APPLY: While applying " ++ show fn ++ ", expected closure, found: " ++ show obj
200 |     step stk (CALL target _ fn args) = do
201 |         argObjs <- traverse (getReg stk) args
202 |         res <- callFunc stk fn argObjs
203 |         setReg stk target res
204 |     step stk (OP target fn args) = do
205 |         argObjs <- traverseVect (getReg stk) args
206 |         res <- callPrim stk fn argObjs
207 |         setReg stk target res
208 |     step stk (EXTPRIM target fn args) = case lookup fn knownExtern of
209 |         Nothing => interpError stk $ "EXTPRIM: Unkown foreign function: " ++ show fn
210 |         Just (ar ** op=> case toVect ar args of
211 |             Nothing => interpError stk $ "EXTPRIM: Wrong number of arguments, found: " ++ show (length args) ++ ", expected: " ++ show ar
212 |             Just argsVect => do
213 |                 argObjs <- traverseVect (getReg stk) argsVect
214 |                 res <- op stk argObjs
215 |                 setReg stk target res
216 |     step stk (CASE sc alts def) = do
217 |         scObj <- getReg stk sc
218 |         case scObj of
219 |             Constructor tag _ => matchCon stk tag alts def
220 |             _ => interpError stk $ "CASE: Expected Constructor, found " ++ showType scObj
221 |       where
222 |         matchCon : Stack -> Either Int Name -> List (Either Int Name, List VMInst) -> Maybe (List VMInst) -> Core ()
223 |         matchCon stk tag [] Nothing = interpError stk "CASE: Missing matching alternative or default"
224 |         matchCon stk tag [] (Just is) = traverse_ (step stk) is
225 |         matchCon stk tag ((t, is) :: alts) def =
226 |             if tag == t
227 |                 then traverse_ (step stk) is
228 |                 else matchCon stk tag alts def
229 |     step stk (CONSTCASE sc alts def) = do
230 |         scObj <- getReg stk sc
231 |         case scObj of
232 |             Const c => matchConst stk c alts def
233 |             _ => interpError stk $ "CONSTCASE: Expected Constant, found " ++ showType scObj
234 |       where
235 |         matchConst : Stack -> Constant -> List (Constant, List VMInst) -> Maybe (List VMInst) -> Core ()
236 |         matchConst stk c [] Nothing = interpError stk "CONSTCASE: Missing matching alternative or default"
237 |         matchConst stk c [] (Just is) = traverse_ (step stk) is
238 |         matchConst stk c ((c', is) :: alts) def =
239 |             if c == c'
240 |                 then traverse_ (step stk) is
241 |                 else matchConst stk c alts def
242 |     step stk (PROJECT target val idx) = do
243 |         valObj <- getReg stk val
244 |         case valObj of
245 |             Constructor _ args => case indexMaybe args idx of
246 |                 Nothing => interpError stk
247 |                     $ "PROJECT: Unable to project index " ++ show idx
248 |                     ++ ", missing arguments for constructor:\n" ++ show valObj
249 |                 Just arg => setReg stk target arg
250 |             _ => interpError stk $ "PROJECT: Expected Constructor, found " ++ showType valObj
251 |     step stk (NULL reg) = setReg stk reg Null
252 |     step stk (ERROR msg) = interpError stk $ "ERROR: " ++ msg
253 |
254 |     callFunc : Ref State InterpState => Stack -> Name -> List Object -> Core Object
255 |     callFunc stk fn args = saveLocals $ do
256 |         logCallStack <- logging "compiler.interpreter" 25
257 |         let ind = if logCallStack then pack $ '|' <$ stk else ""
258 |         when logCallStack $ coreLift $ putStrLn $ ind ++ "Calling " ++ show fn ++ " with args: " ++ show args
259 |         let stk' = fn :: stk
260 |         defs <- defs <$> get State
261 |         res <- case lookup fn defs of
262 |             Nothing => interpError stk $ "Undefined function: " ++ show fn
263 |             Just (MkVMFun as is) => do
264 |                 when (length as /= length args) $ interpError stk
265 |                     $ "Unexpected argument count during function call, expected: "
266 |                     ++ show (length as) ++ ", found: " ++ show (length args)
267 |                 is' <- beginFunction (zip as args) is (foldl max (-1) as)
268 |                 traverse_ (step stk') is'
269 |                 getReg stk' RVal
270 |             Just (MkVMForeign {}) => case lookup fn knownForeign of
271 |                 Nothing => interpError stk $ "Unkown foreign function: " ++ show fn
272 |                 Just (ar ** op=> case toVect ar args of
273 |                     Nothing => interpError stk $ "Wrong number of arguments, found: " ++ show (length args) ++ ", expected: " ++ show ar
274 |                     Just argsVect => op stk argsVect
275 |             Just (MkVMError is) => do
276 |                 traverse_ (step stk') is
277 |                 getReg stk' RVal
278 |         when logCallStack $ coreLift $ putStrLn $ ind ++ "Result: " ++ show res
279 |         pure res
280 |
281 | compileExpr :
282 |   Ref Ctxt Defs ->
283 |   Ref Syn SyntaxInfo ->
284 |   String -> String -> ClosedTerm -> String -> Core (Maybe String)
285 | compileExpr _ _ _ _ _ _ = throw (InternalError "compile not implemeted for vmcode-interp")
286 |
287 | executeExpr :
288 |   Ref Ctxt Defs ->
289 |   Ref Syn SyntaxInfo ->
290 |   String -> ClosedTerm -> Core ()
291 | executeExpr c s _ tm = do
292 |     cdata <- getCompileData False VMCode tm
293 |     st <- newRef State !(initInterpState cdata.vmcode)
294 |     ignore $ callFunc [] (MN "__mainExpression" 0) []
295 |
296 | export
297 | codegenVMCodeInterp : Codegen
298 | codegenVMCodeInterp = MkCG compileExpr executeExpr Nothing Nothing
299 |