-- Interface over distribution types `dist`, each indexed by an erased event
-- shape and an erased dimension count `dim`. Both methods return their result
-- wrapped in `Tag`.
interface Distribution (0 dist : (0 event : Shape) -> (0 dim : Nat) -> Type) where
  ||| The mean of the distribution, as an `F64` tensor of shape `dim :: event`.
  mean : dist event dim -> Tag (Tensor (dim :: event) F64)

  ||| The covariance of the distribution, as an `F64` tensor of shape
  ||| `dim :: dim :: event`.
  cov : dist event dim -> Tag (Tensor (dim :: dim :: event) F64)
-- Variance of a distribution whose dimension count is 1: the covariance
-- returned by `cov`, with its `1 :: 1 :: event` shape squeezed down to
-- `1 :: event`.
variance : {event : _} -> Distribution dist => dist event 1 -> Tag (Tensor (1 :: event) F64)
variance distribution = map (squeeze {from = (1 :: 1 :: event)}) (cov distribution)
-- A `Distribution` that additionally exposes `pdf` and `cdf`. Both take a
-- distribution with at least one dimension (`S d`) and a tensor of shape
-- `S d :: event`, and produce a scalar `F64` tensor wrapped in `Tag`.
interface Distribution dist =>
  ClosedFormDistribution (0 event : Shape)
    (0 dist : (0 event : Shape) -> (0 dim : Nat) -> Type) where
      ||| The density of the distribution evaluated at the given tensor.
      pdf : dist event (S d) -> Tensor (S d :: event) F64 -> Tag (Tensor [] F64)

      ||| The cumulative distribution evaluated at the given tensor.
      cdf : dist event (S d) -> Tensor (S d :: event) F64 -> Tag (Tensor [] F64)
-- Gaussian distribution, indexed by an erased event shape and an erased
-- dimension count. The only constructor fixes the dimension count to `S d`,
-- so a `Gaussian` always has at least one dimension.
data Gaussian : (0 event : Shape) -> (0 dim : Nat) -> Type where
  ||| @mean the mean, of shape `S d :: event`
  ||| @cov the covariance, of shape `S d :: S d :: event`
  MkGaussian : {d : Nat}
            -> (mean : Tensor (S d :: event) F64)
            -> (cov : Tensor (S d :: S d :: event) F64)
            -> Gaussian event (S d)
-- Tags a `Gaussian` by tagging its two tensor fields (mean first, then
-- covariance) and rebuilding the distribution from the tagged tensors.
Taggable (Gaussian event dim) where
  tag (MkGaussian mu sigma) = do
    taggedMean <- tag mu
    taggedCov <- tag sigma
    pure (MkGaussian taggedMean taggedCov)
-- `Distribution` implementation for `Gaussian`: both methods return the
-- corresponding stored constructor field unchanged.
Distribution Gaussian where
  -- The stored mean tensor.
  mean (MkGaussian mu _) = pure mu

  -- The stored covariance tensor.
  cov (MkGaussian _ sigma) = pure sigma
-- `pdf` and `cdf` for a `Gaussian` whose event shape is `[1]`.
-- NOTE: `cdf` is only implemented for the single-dimension case (`d = 0`);
-- for any larger dimension count it crashes at runtime.
ClosedFormDistribution [1] Gaussian where
  pdf (MkGaussian {d} mean cov) x = do
    -- Cholesky factor of the covariance, viewed as an `[S d, S d]` matrix.
    factor <- cholesky (squeeze {to = [S d, S d]} cov)
    lower <- tag factor
    -- Apply `|\` with the factor to the squeezed offset from the mean
    -- (presumably a triangular solve -- confirm against the tensor API).
    whitened <- tag (lower |\ squeeze (x - mean))
    let exponent = - whitened @@ whitened / 2.0
    -- Product of the factor's diagonal entries.
    sqrtDet <- reduce @{Prod} [0] (diag lower)
    -- (2 * pi) ^ (S d / 2), scaled by the diagonal product.
    let normaliser = fromDouble (pow (2.0 * pi) (cast (S d) / 2.0)) * sqrtDet
    pure $ exp exponent / normaliser

  -- Single-dimension case: (1 + erf ((x - mean) / sqrt (2 * cov))) / 2.
  cdf (MkGaussian {d = 0} mean cov) x =
    pure ((1.0 + erf (squeeze (x - mean) / sqrt (squeeze cov * 2.0))) / 2.0)
  -- Two or more dimensions: not supported, crashes with a fixed message.
  cdf (MkGaussian {d = S _} _ _) _ =
    assert_total (idris_crash "CDF not implemented for multivariate Gaussian")