0 | {--
  1 | Copyright (C) 2022  Joel Berkeley
  2 |
  3 | This program is free software: you can redistribute it and/or modify
  4 | it under the terms of the GNU Affero General Public License as published
  5 | by the Free Software Foundation, either version 3 of the License, or
  6 | (at your option) any later version.
  7 |
  8 | This program is distributed in the hope that it will be useful,
  9 | but WITHOUT ANY WARRANTY; without even the implied warranty of
 10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 11 | GNU Affero General Public License for more details.
 12 |
 13 | You should have received a copy of the GNU Affero General Public License
 14 | along with this program.  If not, see <https://www.gnu.org/licenses/>.
 15 | --}
 16 | ||| For internal spidr use only.
 17 | module Compiler.IR
 18 |
 19 | import Control.Monad.State
 20 | import Data.Primitives.Interpolation
 21 | import public Compiler.Stablehlo.Dialect.StablehloEnums
 22 |
 23 | import Derive.Prelude
 24 | import Language.Reflection
 25 |
 26 | import Compiler.LiteralRW
 27 | import Compiler.Xla.XlaData
 28 | import Literal
 29 | import DType
 30 | import Types
 31 | import Util
 32 |
 33 | %language ElabReflection
 34 |
 35 | Show a => Interpolation (List a) where
 36 |   interpolate = show
 37 |
 38 | public export
 39 | data ValueType = TensorType Shape DType
 40 |
 41 | Show ValueType where
 42 |   show (TensorType shape dtype) = "\{show shape} \{show dtype}"
 43 |
 44 | public export
 45 | data Op : Type
 46 |
 47 | public export
 48 | data Value = V Nat Op
 49 |
 50 | -- we use `List (Nat, Op)` for O(1) append (all we do when building the graph is append)
 51 | -- we can't use `(Nat, List Op)`, or even better `(n ** Vect n Op)`, because we don't handle
 52 | -- scoping properly so node pointers aren't contiguous and don't match list indices
 53 | export
 54 | data Env = MkEnv Nat (List (Nat, Op))
 55 |
 56 | export
 57 | empty : Env
 58 | empty = MkEnv 1 []  -- root function takes 0
 59 |
 60 | export
 61 | emptyFrom : Env -> Env
 62 | emptyFrom (MkEnv n _) = MkEnv n []
 63 |
 64 | export
 65 | updateCounterFrom : Env -> State Env ()
 66 | updateCounterFrom (MkEnv n _) = do
 67 |   MkEnv _ xs <- get
 68 |   put $ MkEnv n xs
 69 |
 70 | export
 71 | toList : Env -> List (Nat, Op)
 72 | toList (MkEnv _ env) = reverse env
 73 |
 74 | export
 75 | counter : Env -> Nat
 76 | counter (MkEnv c _) = c
 77 |
 78 | public export
 79 | record Fn (arity : Nat) where
 80 |   constructor MkFn
 81 |   tag : Nat
 82 |   paramTypes : Vect arity ValueType
 83 |   resultTypes : Vect resultCount ValueType
 84 |   results : Vect resultCount Value
 85 |   env : Env
 86 |
 87 | public export
 88 | data BinaryOp =
 89 |   Compare ComparisonDirection | And | Or | Add | Sub | Mul | Div | Rem | Pow | Min | Max
 90 |   | ShiftRightLogical
 91 |
 92 | %runElab derive "ComparisonDirection" [Show]
 93 | %runElab derive "BinaryOp" [Show]
 94 |
 95 | public export
 96 | data UnaryOp =
 97 |     Not | Neg | Ceil | Floor | Abs | Log | Exp | Logistic | Sqrt | Sin | Cos | Tan | Tanh
 98 |   | Erf | ErfInv | Square | Asin | Acos | Atan | Sinh | Cosh | Asinh | Acosh | Atanh
 99 |
100 | %runElab derive "UnaryOp" [Show]
101 |
102 | public export
103 | data Op : Type where
104 |   BoundSet : Nat -> Op
105 |
106 |   Lit : (shape : Shape) -> (dtype : DType) -> Literal shape (idrisType dtype) -> Op
107 |   Grad : Shape -> Fn 1 -> Value -> Op
108 |   MinValue : DType -> Op
109 |   MaxValue : DType -> Op
110 |   MinFiniteFloat : Op
111 |   MaxFiniteFloat : Op
112 |   Iota : (shape : Shape) -> DType -> (axis : Nat) -> Op
113 |   BitCastConvert : DType -> Shape -> Value -> Op
114 |   Convert : DType -> Shape -> Value -> Op
115 |   Reshape : DType -> Shape -> Value -> Op
116 |   Slice : (starts, stops, strides : List Nat) -> Value -> Op
117 |   DynamicSlice : (starts : List Value) -> (sizes : List Nat) -> Value -> Op
118 |   Concat : (axis : Nat) -> Vect (S n) Value -> Op
119 |   Transpose : (ordering : List Nat) -> Value -> Op
120 |   Broadcast : DType -> (from, to : Shape) -> Value -> Op
121 |   Map : Fn arity -> Vect arity Value -> (resultType : ValueType) -> Shape -> Op
122 |   Reduce : Fn (n + n) -> (inits : Vect n Value) -> (axes : List Nat) -> Vect n Value -> Op
123 |   Sort : Fn 2 -> (axis : Nat) -> (isStable : Bool) -> Value -> Op
124 |   Reverse : (axes : List Nat) -> Value -> Op
125 |   BinaryElementwise : BinaryOp -> Value -> Value -> Op
126 |   UnaryElementwise : UnaryOp -> Value -> Op
127 |   Select : (predicate, onTrue, onFalse : Value) -> Op
128 |   While : (condition, body : Fn n) -> (init : Vect n Value) -> Op
129 |   If : (resultType : ValueType) -> (predicate : Value) -> (onTrue, onFalse : Fn 0) -> Op
130 |   DotGeneral :
131 |     (lBatch, lContract, rBatch, rContract: List Nat) ->
132 |     (resultType : ValueType) ->
133 |     Value ->
134 |     Value ->
135 |     Op
136 |   Cholesky : Value -> Op
137 |   TriangularSolve : Value -> Value -> (isLower : Bool) -> Op
138 |   Rng : (state : Value) -> (resultType : ValueType) -> Op
139 |
140 | export
141 | tagOp : Monad m => Op -> StateT Env m Op
142 | tagOp expr = do
143 |   MkEnv next env <- get
144 |   put $ MkEnv (S next) ((next, expr) :: env)
145 |   pure (BoundSet next)
146 |
147 | export
148 | reserve : State Env Nat
149 | reserve = do
150 |   MkEnv next env <- get
151 |   put $ MkEnv (S next) env
152 |   pure next
153 |
154 | covering
155 | showOp : Nat -> Op -> String
156 |
157 | covering
158 | showValue : Nat -> Value -> String
159 | showValue indent (V idx op) = "(\{showOp indent op}):\{show idx}"
160 |
161 | covering
162 | showValueList : Nat -> List Value -> String
163 | showValueList indent xs = "[" ++ joinBy ", " (toList $ map (showValue indent) xs) ++ "]"
164 |
165 | covering
166 | showEnv : Nat -> Env -> String
167 | showEnv indent (MkEnv max env) = joinBy "\n" $ assert_total $ map fmt (reverse env)
168 |
169 |   where
170 |
171 |   fmt : (Nat, Op) -> String
172 |   fmt (n, x) =
173 |     let sep = replicate (4 + length (show max) `minus` length (show n)) ' '
174 |      in "\{replicate indent ' '}\{show n}\{sep}\{showOp indent x}"
175 |
176 | covering
177 | showFn : Nat -> Fn arity -> String
178 | showFn indent (MkFn parameterSetTag paramTypes resultTypes results env@(MkEnv _ env')) =
179 |   let params = "\{show parameterSetTag} \{show paramTypes}"
180 |       res = "\{showValueList (indent + 2) $ toList results}" in
181 |   case env' of
182 |     [] => "\{params} => \{res}"
183 |     _  =>
184 |       "\{params} => \{res} with vars {\n\{showEnv (indent + 4) env}\n\{replicate (indent + 2) ' '}}"
185 |
186 | export Show (Fn arity) where show = assert_total $ showFn 0
187 |
188 | showOp indent (Lit shape dtype x) = "Lit \{shape} \{show dtype}"
189 | showOp indent (BoundSet k) = "Bound \{k}"
190 | showOp indent (Grad _ op x) = "Grad {op = \{showFn indent op}} \{showValue indent x}"
191 | showOp index (MinValue dtype) = "MinValue \{show dtype}"
192 | showOp index (MaxValue dtype) = "MaxValue \{show dtype}"
193 | showOp index MinFiniteFloat = "MinFiniteFloat"
194 | showOp index MaxFiniteFloat = "MaxFiniteFloat"
195 | showOp indent (Iota dtype shape axis) =
196 |   "Iota {shape = \{show shape}, dtype = \{show dtype}, axis = \{axis}}"
197 | showOp indent (Convert dtype shape x) =
198 |   "Convert {dtype = \{show dtype}} \{showValue indent x}"
199 | showOp indent (BitCastConvert dtype shape x) =
200 |   "BitCastConvert {dtype = \{show dtype}} \{showValue indent x}"
201 | showOp indent (Reshape _ to x) = "Reshape {to = \{to}} \{showValue indent x}"
202 | showOp indent (Slice starts stops strides x) =
203 |   "Slice {starts = \{starts}, stops = \{stops}, strides = \{strides}} \{showValue indent x}"
204 | showOp indent (DynamicSlice starts sizes x) =
205 |   "DynamicSlice {starts = \{showValueList indent starts}, sizes = \{sizes}} \{showValue indent x}"
206 | showOp indent (Concat axis xs) = "Concat {axis = \{axis}} \{showValueList indent $ toList xs}"
207 | showOp indent (Transpose ordering x) = "Transpose {ordering = \{ordering}} \{showValue indent x}"
208 | showOp indent (Broadcast _ from to x) =
209 |   "Broadcast {from = \{from}, to = \{to}} \{showValue indent x}"
210 | showOp indent (Map f xs _ _) = "Map {f = \{showFn indent f}} \{showValueList indent $ toList xs}"
211 | showOp indent (Reduce op neutrals axes xs) =
212 |   "Reduce {op = \{showFn indent op}, inits = \{showValueList indent $ toList neutrals}," ++
213 |     " axes = \{axes}} \{showValueList indent $ toList xs}"
214 | showOp indent (Sort f axis _ xs) =
215 |   "Sort {f = \{showFn indent f}, axis = \{axis}} \{showValue indent xs}"
216 | showOp indent (Reverse axes x) = "Reverse \{axes} \{showValue indent x}"
217 | showOp indent (BinaryElementwise op x y) = "\{show op} \{showValue indent x} \{showValue indent y}"
218 | showOp indent (UnaryElementwise op x) = "\{show op} \{showValue indent x}"
219 | showOp indent (Select p t f) =
220 |   "Select {predicate = \{showValue indent p}, onTrue = \{showValue indent t}," ++
221 |     " onFalse = \{showValue indent f}}"
222 | showOp indent (While c b is) =
223 |   "While {condition = \{showFn indent c}, body = \{showFn indent b}," ++
224 |     " initials = \{showValueList indent $ toList is}}"
225 | showOp indent (If _ p ft ff) =
226 |   "If {predicate = \{showValue indent p}, onTrue = \{showFn indent ft}," ++
227 |     " onFalse = \{showFn indent ff}}"
228 | showOp indent (DotGeneral lBatch lContract rBatch rContract _ x y) =
229 |   "DotGeneral {lBatch = \{lBatch}, lContract = \{lContract}," ++
230 |     " rBatch = \{rBatch}, rContract = \{rContract}} \{showValue indent x} \{showValue indent y}"
231 | showOp indent (Cholesky x) = "Cholesky \{showValue indent x}"
232 | showOp indent (TriangularSolve x y isLower) =
233 |   "TriangularSolve {isLower = \{show isLower}} \{showValue indent x} \{showValue indent y}"
234 | showOp indent (Rng state shape) = "Rng {state = \{showValue indent state}, shape = \{show shape}}"
235 |
236 | export Show Op where show = assert_total $ showOp 0
237 |
238 | export
239 | Show Value where
240 |   show (V pos x) = "V \{show pos} \{show x}"
241 |