0 | module Compiler.VMCode
  1 |
  2 | import Compiler.ANF
  3 |
  4 | import Core.CompileExpr
  5 | import Core.TT
  6 |
  7 | import Libraries.Data.IntMap
  8 | import Data.List
  9 | import Data.Vect
 10 |
 11 | %default covering
 12 |
 13 | public export
 14 | data Reg : Type where
 15 |      RVal : Reg
 16 |      Loc : Int -> Reg
 17 |      Discard : Reg
 18 |
 19 | -- VM instructions - first Reg is where the result goes, unless stated
 20 | -- otherwise.
 21 |
 22 | -- As long as you have a representation of closures, and an 'apply' function
 23 | -- which adds an argument and evaluates if it's fully applied, then you can
 24 | -- translate this directly to a target language program.
 25 | public export
 26 | data VMInst : Type where
 27 |      DECLARE : Reg -> VMInst
 28 |      START : VMInst -- start of the main body of the function
 29 |      ASSIGN : Reg -> Reg -> VMInst
 30 |      MKCON : Reg -> (tag : Either Int Name) -> (args : List Reg) -> VMInst
 31 |      MKCLOSURE : Reg -> Name -> (missing : Nat) -> (args : List Reg) -> VMInst
 32 |      MKCONSTANT : Reg -> Constant -> VMInst
 33 |
 34 |      APPLY : Reg -> (f : Reg) -> (a : Reg) -> VMInst
 35 |      CALL : Reg -> (tailpos : Bool) -> Name -> (args : List Reg) -> VMInst
 36 |      OP : {0 arity : Nat} -> Reg -> PrimFn arity -> Vect arity Reg -> VMInst
 37 |        --  ^ we explicitly bind arity here to silence the warnings it is shadowing
 38 |        -- an existing global definition
 39 |      EXTPRIM : Reg -> Name -> List Reg -> VMInst
 40 |
 41 |      CASE : Reg -> -- scrutinee
 42 |             (alts : List (Either Int Name, List VMInst)) -> -- based on constructor tag
 43 |             (def : Maybe (List VMInst)) ->
 44 |             VMInst
 45 |      CONSTCASE : Reg -> -- scrutinee
 46 |                  (alts : List (Constant, List VMInst)) ->
 47 |                  (def : Maybe (List VMInst)) ->
 48 |                  VMInst
 49 |      PROJECT : Reg -> (value : Reg) -> (pos : Int) -> VMInst
 50 |      NULL : Reg -> VMInst
 51 |
 52 |      ERROR : String -> VMInst
 53 |
 54 | public export
 55 | data VMDef : Type where
 56 |      MkVMFun : (args : List Int) -> List VMInst -> VMDef
 57 |      MkVMForeign : (ccs : List String) -> (fargs : List CFType) ->
 58 |                    CFType -> VMDef
 59 |      MkVMError : List VMInst -> VMDef
 60 |
 61 | export
 62 | Show Reg where
 63 |   show RVal = "RVAL"
 64 |   show (Loc i) = "v" ++ show i
 65 |   show Discard = "DISCARD"
 66 |
 67 | export
 68 | covering
 69 | Show VMInst where
 70 |   show (DECLARE r) = "DECLARE " ++ show r
 71 |   show START = "START"
 72 |   show (ASSIGN r v) = show r ++ " := " ++ show v
 73 |   show (MKCON r t args)
 74 |       = show r ++ " := MKCON " ++ show t ++ " (" ++
 75 |                   showSep ", " (map show args) ++ ")"
 76 |   show (MKCLOSURE r n m args)
 77 |       = show r ++ " := MKCLOSURE " ++ show n ++ " " ++ show m ++ " (" ++
 78 |                   showSep ", " (map show args) ++ ")"
 79 |   show (MKCONSTANT r c) = show r ++ " := MKCONSTANT " ++ show c
 80 |   show (APPLY r f a) = show r ++ " := " ++ show f ++ " @ " ++ show a
 81 |   show (CALL r t n args)
 82 |       = show r ++ " := " ++ (if t then "TAILCALL " else "CALL ") ++
 83 |         show n ++ "(" ++ showSep ", " (map show args) ++ ")"
 84 |   show (OP r op args)
 85 |       = show r ++ " := " ++ "OP " ++
 86 |         show op ++ "(" ++ showSep ", " (map show (toList args)) ++ ")"
 87 |   show (EXTPRIM r n args)
 88 |       = show r ++ " := " ++ "EXTPRIM " ++
 89 |         show n ++ "(" ++ showSep ", " (map show args) ++ ")"
 90 |
 91 |   show (CASE scr alts def)
 92 |       = "CASE " ++ show scr ++ " " ++ show alts ++ " {default: " ++ show def ++ "}"
 93 |   show (CONSTCASE scr alts def)
 94 |       = "CASE " ++ show scr ++ " " ++ show alts ++ " {default: " ++ show def ++ "}"
 95 |
 96 |   show (PROJECT r val pos)
 97 |       = show r ++ " := PROJECT(" ++ show val ++ ", " ++ show pos ++ ")"
 98 |   show (NULL r) = show r ++ " := NULL"
 99 |   show (ERROR str) = "ERROR " ++ show str
100 |
101 | export
102 | covering
103 | Show VMDef where
104 |   show (MkVMFun args body) = show args ++ ": " ++ show body
105 |   show (MkVMForeign ccs args ret)
106 |       = "Foreign call " ++ show ccs ++ " " ++
107 |         show args ++ " " ++ show ret
108 |   show (MkVMError err) = "Error: " ++ show err
109 |
110 | toReg : AVar -> Reg
111 | toReg (ALocal i) = Loc i
112 | toReg ANull = Discard
113 |
114 | projectArgs : Int -> Int -> (used : IntMap ()) -> (args : List Int) -> List VMInst
115 | projectArgs scr i used [] = []
116 | projectArgs scr i used (arg :: args)
117 |     = case lookup arg used of
118 |            Just _ => PROJECT (Loc arg) (Loc scr) i :: projectArgs scr (i + 1) used args
119 |            Nothing => projectArgs scr (i + 1) used args
120 |
121 | collectReg : Reg -> IntMap ()
122 | collectReg (Loc i) = singleton i ()
123 | collectReg _ = empty
124 |
125 | collectUsed : VMInst -> IntMap ()
126 | collectUsed (DECLARE reg) = collectReg reg
127 | collectUsed START = empty
128 | collectUsed (ASSIGN _ val) = collectReg val
129 | collectUsed (MKCON _ _ args) = foldMap collectReg args
130 | collectUsed (MKCLOSURE _ _ _ args) = foldMap collectReg args
131 | collectUsed (MKCONSTANT {}) = empty
132 | collectUsed (APPLY _ fn arg) = collectReg fn <+> collectReg arg
133 | collectUsed (CALL _ _ _ args) = foldMap collectReg args
134 | collectUsed (OP _ _ args) = foldMap collectReg args
135 | collectUsed (EXTPRIM _ _ args) = foldMap collectReg args
136 | collectUsed (CASE sc is mdef)
137 |     = collectReg sc
138 |       <+> foldMap (foldMap collectUsed . snd) is
139 |       <+> maybe empty (foldMap collectUsed) mdef
140 | collectUsed (CONSTCASE sc is mdef)
141 |     = collectReg sc
142 |       <+> foldMap (foldMap collectUsed . snd) is
143 |       <+> maybe empty (foldMap collectUsed) mdef
144 | collectUsed (PROJECT _ val _) = collectReg val
145 | collectUsed (NULL _) = empty
146 | collectUsed (ERROR _) = empty
147 |
148 | toVM : (tailpos : Bool) -> (target : Reg) -> ANF -> List VMInst
149 | toVM t Discard _ = []
150 | toVM t res (AV fc (ALocal i))
151 |     = [ASSIGN res (Loc i)]
152 | toVM t res (AAppName fc _ n args)
153 |     = [CALL res t n (map toReg args)]
154 | toVM t res (AUnderApp fc n m args)
155 |     = [MKCLOSURE res n m (map toReg args)]
156 | toVM t res (AApp fc _ f a)
157 |     = [APPLY res (toReg f) (toReg a)]
158 | toVM t res (ALet fc var val body)
159 |     = toVM False (Loc var) val ++ toVM t res body
160 | toVM t res (ACon fc n ci (Just tag) args)
161 |     = [MKCON res (Left tag) (map toReg args)]
162 | toVM t res (ACon fc n ci Nothing args)
163 |     = [MKCON res (Right n) (map toReg args)]
164 | toVM t res (AOp fc _ op args)
165 |     = [OP res op (map toReg args)]
166 | toVM t res (AExtPrim fc _ p args)
167 |     = [EXTPRIM res p (map toReg args)]
168 | toVM t res (AConCase fc (ALocal scr) [MkAConAlt n ci mt args code] Nothing) -- exactly one alternative, so skip matching
169 |     = let body = toVM t res code
170 |           used = foldMap collectUsed body
171 |        in projectArgs scr 0 used args ++ body
172 | toVM t res (AConCase fc (ALocal scr) alts def)
173 |     = [CASE (Loc scr) (map toVMConAlt alts) (map (toVM t res) def)]
174 |   where
175 |     toVMConAlt : AConAlt -> (Either Int Name, List VMInst)
176 |     toVMConAlt (MkAConAlt n ci tag args code)
177 |        = let body = toVM t res code
178 |              used = foldMap collectUsed body
179 |           in (maybe (Right n) Left tag, projectArgs scr 0 used args ++ body)
180 | toVM t res (AConstCase fc (ALocal scr) alts def)
181 |     = [CONSTCASE (Loc scr) (map toVMConstAlt alts) (map (toVM t res) def)]
182 |   where
183 |     toVMConstAlt : AConstAlt -> (Constant, List VMInst)
184 |     toVMConstAlt (MkAConstAlt c code)
185 |         = (c, toVM t res code)
186 | toVM t res (APrimVal fc c)
187 |     = [MKCONSTANT res c]
188 | toVM t res (AErased fc)
189 |     = [NULL res]
190 | toVM t res (ACrash fc err)
191 |     = [ERROR err]
192 | toVM t res _
193 |     = [NULL res]
194 |
195 | findVars : VMInst -> List Int
196 | findVars (ASSIGN (Loc r) _) = [r]
197 | findVars (MKCON (Loc r) _ _) = [r]
198 | findVars (MKCLOSURE (Loc r) _ _ _) = [r]
199 | findVars (MKCONSTANT (Loc r) _) = [r]
200 | findVars (APPLY (Loc r) _ _) = [r]
201 | findVars (CALL (Loc r) _ _ _) = [r]
202 | findVars (OP (Loc r) _ _) = [r]
203 | findVars (EXTPRIM (Loc r) _ _) = [r]
204 | findVars (CASE _ alts d)
205 |     = foldMap findVarAlt alts ++ fromMaybe [] (map (foldMap findVars) d)
206 |   where
207 |     findVarAlt : (Either Int Name, List VMInst) -> List Int
208 |     findVarAlt (t, code) = foldMap findVars code
209 | findVars (CONSTCASE _ alts d)
210 |     = foldMap findConstVarAlt alts ++ fromMaybe [] (map (foldMap findVars) d)
211 |   where
212 |     findConstVarAlt : (Constant, List VMInst) -> List Int
213 |     findConstVarAlt (t, code) = foldMap findVars code
214 | findVars (PROJECT (Loc r) _ _) = [r]
215 | findVars _ = []
216 |
217 | declareVars : List Int -> List VMInst -> List VMInst
218 | declareVars got code
219 |     = let vs = foldMap findVars code in
220 |           declareAll got vs
221 |   where
222 |     declareAll : List Int -> List Int -> List VMInst
223 |     declareAll got [] = START :: code
224 |     declareAll got (i :: is)
225 |         = if i `elem` got
226 |              then declareAll got is
227 |              else DECLARE (Loc i) :: declareAll (i :: got) is
228 |
229 | export
230 | toVMDef : ANFDef -> Maybe VMDef
231 | toVMDef (MkAFun args body)
232 |     = Just $ MkVMFun args (declareVars args (toVM True RVal body))
233 | toVMDef (MkAForeign ccs cargs ret)
234 |     = Just $ MkVMForeign ccs cargs ret
235 | toVMDef (MkAError body)
236 |     = Just $ MkVMError (declareVars [] (toVM True RVal body))
237 | toVMDef _ = Nothing
238 |
239 | export
240 | allDefs : List (Name, ANFDef) -> List (Name, VMDef)
241 | allDefs = mapMaybe (\ (n, d) => do d' <- toVMDef dpure (n, d'))
242 |