0 | {--
  1 | Copyright (C) 2022  Joel Berkeley
  2 |
  3 | This program is free software: you can redistribute it and/or modify
  4 | it under the terms of the GNU Affero General Public License as published
  5 | by the Free Software Foundation, either version 3 of the License, or
  6 | (at your option) any later version.
  7 |
  8 | This program is distributed in the hope that it will be useful,
  9 | but WITHOUT ANY WARRANTY; without even the implied warranty of
 10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 11 | GNU Affero General Public License for more details.
 12 |
 13 | You should have received a copy of the GNU Affero General Public License
 14 | along with this program.  If not, see <https://www.gnu.org/licenses/>.
 15 | --}
 16 | ||| For internal spidr use only.
 17 | module Compiler.IR
 18 |
 19 | import Control.Monad.State
 20 | import Data.Primitives.Interpolation
 21 | import public Compiler.Stablehlo.Dialect.StablehloEnums
 22 |
 23 | import Derive.Prelude
 24 | import Language.Reflection
 25 |
 26 | import Compiler.LiteralRW
 27 | import Compiler.Xla.XlaData
 28 | import Literal
 29 | import Primitive
 30 | import Types
 31 | import Util
 32 |
 33 | %language ElabReflection
 34 |
 35 | Show a => Interpolation (List a) where
 36 |   interpolate = show
 37 |
 38 | public export
 39 | data ValueType : Type where
 40 |   TensorType : Shape -> (0 dtype : Type) -> Primitive dtype => ValueType
 41 |
 42 | covering
 43 | Show ValueType where
 44 |   show (TensorType shape dtype) = "\{shape} \{repr {dtype}}"
 45 |
 46 | public export
 47 | data Op : Type
 48 |
 49 | public export
 50 | data Value = V Nat Op
 51 |
 52 | -- we use `List (Nat, Op)` for O(1) append (all we do when building the graph is append)
 53 | -- we can't use `(Nat, List Op)`, or even better `(n ** Vect n Op)`, because we don't handle
 54 | -- scoping properly so node pointers aren't contiguous and don't match list indices
 55 | export
 56 | data Env = MkEnv Nat (List (Nat, Op))
 57 |
 58 | export
 59 | empty : Env
 60 | empty = MkEnv 1 []  -- root function takes 0
 61 |
 62 | export
 63 | emptyFrom : Env -> Env
 64 | emptyFrom (MkEnv n _) = MkEnv n []
 65 |
 66 | export
 67 | updateCounterFrom : Env -> State Env ()
 68 | updateCounterFrom (MkEnv n _) = do
 69 |   MkEnv _ xs <- get
 70 |   put $ MkEnv n xs
 71 |
 72 | export
 73 | toList : Env -> List (Nat, Op)
 74 | toList (MkEnv _ env) = reverse env
 75 |
 76 | export
 77 | counter : Env -> Nat
 78 | counter (MkEnv c _) = c
 79 |
 80 | public export
 81 | record Fn (arity : Nat) where
 82 |   constructor MkFn
 83 |   tag : Nat
 84 |   paramTypes : Vect arity ValueType
 85 |   resultTypes : Vect resultCount ValueType
 86 |   results : Vect resultCount Value
 87 |   env : Env
 88 |
 89 | public export
 90 | data BinaryOp =
 91 |   Compare ComparisonDirection | And | Or | Add | Sub | Mul | Div | Rem | Pow | Min | Max
 92 |   | ShiftRightLogical
 93 |
 94 | %runElab derive "ComparisonDirection" [Show]
 95 | %runElab derive "BinaryOp" [Show]
 96 |
 97 | public export
 98 | data UnaryOp =
 99 |     Not | Neg | Ceil | Floor | Abs | Log | Exp | Logistic | Sqrt | Sin | Cos | Tan | Tanh
100 |   | Erf | ErfInv | Square | Asin | Acos | Atan | Sinh | Cosh | Asinh | Acosh | Atanh
101 |
102 | %runElab derive "UnaryOp" [Show]
103 |
104 | public export
105 | data Op : Type where
106 |   BoundSet : Nat -> Op
107 |
108 |   Lit : PrimitiveRW dtype ty => {shape : _} -> Literal shape ty -> Op
109 |   Grad : Shape -> Fn 1 -> Value -> Op
110 |   MinValue : (0 dtype : _) -> Primitive dtype => Op
111 |   MaxValue : (0 dtype : _) -> Primitive dtype => Op
112 |   MinFiniteFloat : Op
113 |   MaxFiniteFloat : Op
114 |   Iota : Primitive dtype => (shape : Shape) -> (axis : Nat) -> Op
115 |   BitCastConvert : Primitive dtype => Shape -> Value -> Op
116 |   Convert : Primitive dtype => Shape -> Value -> Op
117 |   Reshape : Primitive dtype => Shape -> Value -> Op
118 |   Slice : (starts, stops, strides : List Nat) -> Value -> Op
119 |   DynamicSlice : (starts : List Value) -> (sizes : List Nat) -> Value -> Op
120 |   Concat : (axis : Nat) -> Vect (S n) Value -> Op
121 |   Transpose : (ordering : List Nat) -> Value -> Op
122 |   Broadcast : Primitive dtype => (from, to : Shape) -> Value -> Op
123 |   Map : Fn arity -> Vect arity Value -> (resultType : ValueType) -> Shape -> Op
124 |   Reduce : Fn (n + n) -> (inits : Vect n Value) -> (axes : List Nat) -> Vect n Value -> Op
125 |   Sort : Fn 2 -> (axis : Nat) -> (isStable : Bool) -> Value -> Op
126 |   Reverse : (axes : List Nat) -> Value -> Op
127 |   BinaryElementwise : BinaryOp -> Value -> Value -> Op
128 |   UnaryElementwise : UnaryOp -> Value -> Op
129 |   Select : (predicate, onTrue, onFalse : Value) -> Op
130 |   While : (condition, body : Fn n) -> (init : Vect n Value) -> Op
131 |   If : (resultType : ValueType) -> (predicate : Value) -> (onTrue, onFalse : Fn 0) -> Op
132 |   DotGeneral :
133 |     (lBatch, lContract, rBatch, rContract: List Nat) ->
134 |     (resultType : ValueType) ->
135 |     Value ->
136 |     Value ->
137 |     Op
138 |   Cholesky : Value -> Op
139 |   TriangularSolve : Value -> Value -> (isLower : Bool) -> Op
140 |   Rng : (state : Value) -> (resultType : ValueType) -> Op
141 |
142 | export
143 | tagOp : Monad m => Op -> StateT Env m Op
144 | tagOp expr = do
145 |   MkEnv next env <- get
146 |   put $ MkEnv (S next) ((next, expr) :: env)
147 |   pure (BoundSet next)
148 |
149 | export
150 | reserve : State Env Nat
151 | reserve = do
152 |   MkEnv next env <- get
153 |   put $ MkEnv (S next) env
154 |   pure next
155 |
156 | covering
157 | showOp : Nat -> Op -> String
158 |
159 | covering
160 | showValue : Nat -> Value -> String
161 | showValue indent (V idx op) = "(\{showOp indent op}):\{show idx}"
162 |
163 | covering
164 | showValueList : Nat -> List Value -> String
165 | showValueList indent xs = "[" ++ joinBy ", " (toList $ map (showValue indent) xs) ++ "]"
166 |
167 | covering
168 | showEnv : Nat -> Env -> String
169 | showEnv indent (MkEnv max env) = joinBy "\n" $ assert_total $ map fmt (reverse env)
170 |
171 |   where
172 |
173 |   fmt : (Nat, Op) -> String
174 |   fmt (n, x) =
175 |     let sep = replicate (4 + length (show max) `minus` length (show n)) ' '
176 |      in "\{replicate indent ' '}\{show n}\{sep}\{showOp indent x}"
177 |
178 | covering
179 | showFn : Nat -> Fn arity -> String
180 | showFn indent (MkFn parameterSetTag paramTypes resultTypes results env@(MkEnv _ env')) =
181 |   let params = "\{show parameterSetTag} \{show paramTypes}"
182 |       res = "\{showValueList (indent + 2) $ toList results}" in
183 |   case env' of
184 |     [] => "\{params} => \{res}"
185 |     _  =>
186 |       "\{params} => \{res} with vars {\n\{showEnv (indent + 4) env}\n\{replicate (indent + 2) ' '}}"
187 |
188 | export Show (Fn arity) where show = assert_total $ showFn 0
189 |
190 | showOp indent (Lit {shape, dtype} x) = "Lit \{shape} \{repr {dtype}}"
191 | showOp indent (BoundSet k) = "Bound \{k}"
192 | showOp indent (Grad _ op x) = "Grad {op = \{showFn indent op}} \{showValue indent x}"
193 | showOp index (MinValue dtype) = "MinValue \{repr {dtype}}"
194 | showOp index (MaxValue dtype) = "MaxValue \{repr {dtype}}"
195 | showOp index MinFiniteFloat = "MinFiniteFloat"
196 | showOp index MaxFiniteFloat = "MaxFiniteFloat"
197 | showOp indent (Iota {dtype} shape axis) =
198 |   "Iota {shape = \{show shape}, dtype = \{repr {dtype}}, axis = \{axis}}"
199 | showOp indent (Convert {dtype} shape x) =
200 |   "Convert {dtype = \{repr {dtype}}} \{showValue indent x}"
201 | showOp indent (BitCastConvert {dtype} shape x) =
202 |   "BitCastConvert {dtype = \{repr {dtype}}} \{showValue indent x}"
203 | showOp indent (Reshape to x) = "Reshape {to = \{to}} \{showValue indent x}"
204 | showOp indent (Slice starts stops strides x) =
205 |   "Slice {starts = \{starts}, stops = \{stops}, strides = \{strides}} \{showValue indent x}"
206 | showOp indent (DynamicSlice starts sizes x) =
207 |   "DynamicSlice {starts = \{showValueList indent starts}, sizes = \{sizes}} \{showValue indent x}"
208 | showOp indent (Concat axis xs) = "Concat {axis = \{axis}} \{showValueList indent $ toList xs}"
209 | showOp indent (Transpose ordering x) = "Transpose {ordering = \{ordering}} \{showValue indent x}"
210 | showOp indent (Broadcast from to x) =
211 |   "Broadcast {from = \{from}, to = \{to}} \{showValue indent x}"
212 | showOp indent (Map f xs _ _) = "Map {f = \{showFn indent f}} \{showValueList indent $ toList xs}"
213 | showOp indent (Reduce op neutrals axes xs) =
214 |   "Reduce {op = \{showFn indent op}, inits = \{showValueList indent $ toList neutrals}," ++
215 |     " axes = \{axes}} \{showValueList indent $ toList xs}"
216 | showOp indent (Sort f axis _ xs) =
217 |   "Sort {f = \{showFn indent f}, axis = \{axis}} \{showValue indent xs}"
218 | showOp indent (Reverse axes x) = "Reverse \{axes} \{showValue indent x}"
219 | showOp indent (BinaryElementwise op x y) = "\{show op} \{showValue indent x} \{showValue indent y}"
220 | showOp indent (UnaryElementwise op x) = "\{show op} \{showValue indent x}"
221 | showOp indent (Select p t f) =
222 |   "Select {predicate = \{showValue indent p}, onTrue = \{showValue indent t}," ++
223 |     " onFalse = \{showValue indent f}}"
224 | showOp indent (While c b is) =
225 |   "While {condition = \{showFn indent c}, body = \{showFn indent b}," ++
226 |     " initials = \{showValueList indent $ toList is}}"
227 | showOp indent (If _ p ft ff) =
228 |   "If {predicate = \{showValue indent p}, onTrue = \{showFn indent ft}," ++
229 |     " onFalse = \{showFn indent ff}}"
230 | showOp indent (DotGeneral lBatch lContract rBatch rContract _ x y) =
231 |   "DotGeneral {lBatch = \{lBatch}, lContract = \{lContract}," ++
232 |     " rBatch = \{rBatch}, rContract = \{rContract}} \{showValue indent x} \{showValue indent y}"
233 | showOp indent (Cholesky x) = "Cholesky \{showValue indent x}"
234 | showOp indent (TriangularSolve x y isLower) =
235 |   "TriangularSolve {isLower = \{show isLower}} \{showValue indent x} \{showValue indent y}"
236 | showOp indent (Rng state shape) = "Rng {state = \{showValue indent state}, shape = \{show shape}}"
237 |
238 | export Show Op where show = assert_total $ showOp 0
239 |
240 | export
241 | Show Value where
242 |   show (V pos x) = "V \{show pos} \{show x}"
243 |