0 | {--
 1 | Copyright (C) 2021  Joel Berkeley
 2 |
 3 | This program is free software: you can redistribute it and/or modify
 4 | it under the terms of the GNU Affero General Public License as published
 5 | by the Free Software Foundation, either version 3 of the License, or
 6 | (at your option) any later version.
 7 |
 8 | This program is distributed in the hope that it will be useful,
 9 | but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 | GNU Affero General Public License for more details.
12 |
13 | You should have received a copy of the GNU Affero General Public License
14 | along with this program.  If not, see <https://www.gnu.org/licenses/>.
15 | --}
16 | ||| Kernel functions, particularly for use in Gaussian processes.
17 | module Model.Kernel
18 |
19 | import Tensor
20 | import Data.Nat
21 |
22 | ||| A `Kernel` function maps pairs of points in a feature space to the covariance between those two
23 | ||| points in some target space.
24 | |||
25 | ||| @features The shape of the feature domain.
26 | public export 0
27 | Kernel : (0 features : Shape) -> Type
28 | Kernel features =
29 |   {sk, sk' : _} ->
30 |   Tensor (sk :: features) F64 ->
31 |   Tensor (sk' :: features) F64 ->
32 |   Tag $ Tensor [sk, sk'] F64
33 |
34 | scaledL2Norm :
35 |   Tensor [] F64 ->
36 |   {d, n, n' : _} ->
37 |   Tensor [n, S d] F64 ->
38 |   Tensor [n', S d] F64 ->
39 |   Tag $ Tensor [n, n'] F64
40 | scaledL2Norm len x x' =
41 |   let xs = broadcast {to = [n, n', S d]} $ expand 1 x
42 |    in reduce @{Sum} [2] $ ((xs - broadcast (expand 0 x')) / len) ^ fill 2.0
43 |
44 | ||| The radial basis function, or squared exponential kernel. This is a stationary kernel with form
45 | |||
46 | ||| (\mathbf x_i, \mathbf x_j) \mapsto \exp \left(- \frac{r^2}{2l^2} \right)
47 | |||
48 | ||| where `r^2 = (\mathbf x_i - \mathbf x_j)^ \intercal (\mathbf x_i - \mathbf x_j)` and the
49 | ||| length scale `l > 0`.
50 | |||
51 | ||| Two points that are close in feature space will be more tightly correlated than points that
52 | ||| are further apart. The distance over which the correlation reduces is given by the length
53 | ||| scale `l`. Smaller length scales result in faster-varying target values.
54 | |||
55 | ||| @lengthScale The length scale `l`.
56 | export
57 | rbf : (lengthScale : Tensor [] F64) -> {d : _} -> Kernel [S d]
58 | rbf lengthScale x x' = pure $ exp (- !(scaledL2Norm lengthScale x x') / 2.0)
59 |
60 | ||| The Matern kernel for parameter 5/2. This is a stationary kernel with form
61 | |||
62 | ||| (\mathbf x_i, \mathbf x_j) \mapsto \sigma^2 \left(
63 | |||   1 + \frac{\sqrt{5}r}{l} + \frac{5 r^2}{3 l^2}
64 | ||| \right) \exp \left( -\frac{\sqrt{5}r}{l} \right)
65 | |||
66 | ||| where `r^2 = (\mathbf x_i - \mathbf x_j)^ \intercal (\mathbf x_i - \mathbf x_j)` and the
67 | ||| length scale `l > 0`.
68 | |||
69 | ||| @amplitude The amplitude `\sigma`.
70 | ||| @length_scale The length scale `l`.
71 | export
72 | matern52 :
73 |   (amplitude : Tensor [] F64) -> (length_scale : Tensor [] F64) -> {d : _} -> Kernel [S d]
74 | matern52 amp len x x' = do
75 |   d2 <- tag $ 5.0 * !(scaledL2Norm len x x')
76 |   d <- tag $ d2 ^ fill 0.5
77 |   pure $ (amp ^ 2.0) * (d2 / 3.0 + d + fill 1.0) * exp (- d)
78 |