0 | {--
 1 | Copyright (C) 2021  Joel Berkeley
 2 |
 3 | This program is free software: you can redistribute it and/or modify
 4 | it under the terms of the GNU Affero General Public License as published
 5 | by the Free Software Foundation, either version 3 of the License, or
 6 | (at your option) any later version.
 7 |
 8 | This program is distributed in the hope that it will be useful,
 9 | but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 | GNU Affero General Public License for more details.
12 |
13 | You should have received a copy of the GNU Affero General Public License
14 | along with this program.  If not, see <https://www.gnu.org/licenses/>.
15 | --}
16 | ||| Probability distributions.
17 | module Distribution
18 |
19 | import Data.Nat
20 | import Tensor
21 | import Constants
22 |
23 | ||| A joint, or multivariate distribution over a tensor of floating point values, where the first
24 | ||| two central moments (mean and covariance) are known. Every sub-event is assumed to have the
25 | ||| same shape.
26 | |||
27 | ||| @dist Constructs the distribution from the shape of each sub-event and the number of events in
28 | |||   the distribution.
29 | public export
30 | interface Distribution (0 dist : (0 event : Shape) -> (0 dim : Nat) -> Type) where
31 |   ||| The mean of the distribution.
32 |   mean : dist event dim -> Tag $ Tensor (dim :: event) F64
33 |
34 |   ||| The covariance, or correlation, between sub-events.
35 |   cov : dist event dim -> Tag $ Tensor (dim :: dim :: event) F64
36 |
37 | ||| The variance of a single random variable.
38 | export
39 | variance : {event : _} -> Distribution dist => dist event 1 -> Tag $ Tensor (1 :: event) F64
40 | variance dist = squeeze {from = (1 :: 1 :: event)} <$> cov dist
41 |
42 | ||| A joint, or multivariate distribution over a tensor of floating point values, where the density
43 | ||| function and corresponding cumulative density function are known (either analytically or via
44 | ||| approximation). Every sub-event is assumed to have the same shape.
45 | |||
46 | ||| @event The shape of each sub-event.
47 | ||| @dist Constructs the distribution from the shape of each sub-event and the number of events in
48 | |||   the distribution.
49 | public export
50 | interface Distribution dist  =>
51 |   ClosedFormDistribution (0 event : Shape)
52 |     (0 dist : (0 event : Shape) -> (0 dim : Nat) -> Type) where
53 |       ||| The probability density function of the distribution at the specified point.
54 |       pdf : dist event (S d) -> Tensor (S d :: event) F64 -> Tag $ Tensor [] F64
55 |
56 |       ||| The cumulative distribution function of the distribution at the specified point (that is,
57 |       ||| the probability the random variable takes a value less than or equal to the given point).
58 |       cdf : dist event (S d) -> Tensor (S d :: event) F64 -> Tag $ Tensor [] F64
59 |
60 | ||| A joint Gaussian distribution.
61 | |||
62 | ||| @event The shape of each sub-event.
63 | ||| @dim The number of sub-events.
64 | public export
65 | data Gaussian : (0 event : Shape) -> (0 dim : Nat) -> Type where
66 |   ||| @mean The mean of the events.
67 |   ||| @cov The covariance between events.
68 |   MkGaussian : {d : Nat} -> (mean : Tensor (S d :: event) F64) ->
69 |                (cov : Tensor (S d :: S d :: event) F64) ->
70 |                Gaussian event (S d)
71 |
72 | export
73 | Taggable (Gaussian event dim) where
74 |   tag (MkGaussian mean cov) = [| MkGaussian (tag mean) (tag cov) |]
75 |
76 | export
77 | Distribution Gaussian where
78 |   mean (MkGaussian mean' _) = pure mean'
79 |   cov (MkGaussian _ cov') = pure cov'
80 |
81 | ||| **NOTE** `cdf` is implemented only for univariate `Gaussian`.
82 | export
83 | ClosedFormDistribution [1] Gaussian where
84 |   pdf (MkGaussian {d} mean cov) x = do
85 |     cholCov <- tag !(cholesky $ squeeze {to = [S d, S d]} cov)
86 |     tri <- tag $ cholCov |\ squeeze (x - mean)
87 |     let exponent = - tri @@ tri / 2.0
88 |     covSqrtDet <- reduce @{Prod} [0] (diag cholCov)
89 |     let denominator = fromDouble (pow (2.0 * pi) (cast (S d) / 2.0)) * covSqrtDet
90 |     pure (exp exponent / denominator)
91 |
92 |   cdf (MkGaussian {d = S _} _ _) _ =
93 |     assert_total $ idris_crash "CDF not implemented for multivariate Gaussian"
94 |   cdf (MkGaussian {d = 0} mean cov) x =
95 |     pure $ (1.0 + erf (squeeze (x - mean) / (sqrt (squeeze cov * 2.0)))) / 2.0
96 |