_gpr.py
21.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
"""Gaussian processes regression. """
# Authors: Jan Hendrik Metzen <jhm@informatik.uni-bremen.de>
# Modified by: Pete Green <p.l.green@liverpool.ac.uk>
# License: BSD 3 clause
import warnings
from operator import itemgetter
import numpy as np
from scipy.linalg import cholesky, cho_solve, solve_triangular
import scipy.optimize
from ..base import BaseEstimator, RegressorMixin, clone
from ..base import MultiOutputMixin
from .kernels import RBF, ConstantKernel as C
from ..utils import check_random_state
from ..utils.validation import check_array
from ..utils.optimize import _check_optimize_result
from ..utils.validation import _deprecate_positional_args
class GaussianProcessRegressor(MultiOutputMixin,
                               RegressorMixin, BaseEstimator):
    """Gaussian process regression (GPR).

    The implementation is based on Algorithm 2.1 of Gaussian Processes
    for Machine Learning (GPML) by Rasmussen and Williams.

    In addition to standard scikit-learn estimator API,
    GaussianProcessRegressor:

       * allows prediction without prior fitting (based on the GP prior)
       * provides an additional method sample_y(X), which evaluates samples
         drawn from the GPR (prior or posterior) at given inputs
       * exposes a method log_marginal_likelihood(theta), which can be used
         externally for other ways of selecting hyperparameters, e.g., via
         Markov chain Monte Carlo.

    Read more in the :ref:`User Guide <gaussian_process>`.

    .. versionadded:: 0.18

    Parameters
    ----------
    kernel : kernel instance, default=None
        The kernel specifying the covariance function of the GP. If None is
        passed, the kernel "1.0 * RBF(1.0)" is used as default. Note that
        the kernel's hyperparameters are optimized during fitting.

    alpha : float or array-like of shape (n_samples,), default=1e-10
        Value added to the diagonal of the kernel matrix during fitting.
        Larger values correspond to increased noise level in the observations.
        This can also prevent a potential numerical issue during fitting, by
        ensuring that the calculated values form a positive definite matrix.
        If an array is passed, it must have the same number of entries as the
        data used for fitting and is used as datapoint-dependent noise level.
        An array with a single entry is treated like a scalar.
        Note that this is equivalent to adding a WhiteKernel with c=alpha.
        Allowing to specify the noise level directly as a parameter is mainly
        for convenience and for consistency with Ridge.

    optimizer : "fmin_l_bfgs_b" or callable, default="fmin_l_bfgs_b"
        Can either be one of the internally supported optimizers for optimizing
        the kernel's parameters, specified by a string, or an externally
        defined optimizer passed as a callable. If a callable is passed, it
        must have the signature::

            def optimizer(obj_func, initial_theta, bounds):
                # * 'obj_func' is the objective function to be minimized, which
                #   takes the hyperparameters theta as parameter and an
                #   optional flag eval_gradient, which determines if the
                #   gradient is returned additionally to the function value
                # * 'initial_theta': the initial value for theta, which can be
                #   used by local optimizers
                # * 'bounds': the bounds on the values of theta
                ....
                # Returned are the best found hyperparameters theta and
                # the corresponding value of the target function.
                return theta_opt, func_min

        Per default, the 'L-BFGS-B' algorithm from scipy.optimize.minimize
        is used. If None is passed, the kernel's parameters are kept fixed.
        Available internal optimizers are::

            'fmin_l_bfgs_b'

    n_restarts_optimizer : int, default=0
        The number of restarts of the optimizer for finding the kernel's
        parameters which maximize the log-marginal likelihood. The first run
        of the optimizer is performed from the kernel's initial parameters,
        the remaining ones (if any) from thetas sampled log-uniform randomly
        from the space of allowed theta-values. If greater than 0, all bounds
        must be finite. Note that n_restarts_optimizer == 0 implies that one
        run is performed.

    normalize_y : bool, default=False
        Whether the target values y are normalized, the mean and variance of
        the target values are set equal to 0 and 1 respectively. This is
        recommended for cases where zero-mean, unit-variance priors are used.
        Note that, in this implementation, the normalisation is reversed
        before the GP predictions are reported. A target with zero standard
        deviation (constant target) is only centered, not rescaled.

        .. versionchanged:: 0.23

    copy_X_train : bool, default=True
        If True, a persistent copy of the training data is stored in the
        object. Otherwise, just a reference to the training data is stored,
        which might cause predictions to change if the data is modified
        externally.

    random_state : int or RandomState, default=None
        Determines random number generation used to draw the initial
        hyperparameters of the optimizer restarts.
        Pass an int for reproducible results across multiple function calls.
        See :term:`Glossary <random_state>`.

    Attributes
    ----------
    X_train_ : array-like of shape (n_samples, n_features) or list of object
        Feature vectors or other representations of training data (also
        required for prediction).

    y_train_ : array-like of shape (n_samples,) or (n_samples, n_targets)
        Target values in training data (also required for prediction)

    kernel_ : kernel instance
        The kernel used for prediction. The structure of the kernel is the
        same as the one passed as parameter but with optimized hyperparameters

    L_ : array-like of shape (n_samples, n_samples)
        Lower-triangular Cholesky decomposition of the kernel in ``X_train_``

    alpha_ : array-like of shape (n_samples,) or (n_samples, n_targets)
        Dual coefficients of training data points in kernel space

    log_marginal_likelihood_value_ : float
        The log-marginal-likelihood of ``self.kernel_.theta``

    Examples
    --------
    >>> from sklearn.datasets import make_friedman2
    >>> from sklearn.gaussian_process import GaussianProcessRegressor
    >>> from sklearn.gaussian_process.kernels import DotProduct, WhiteKernel
    >>> X, y = make_friedman2(n_samples=500, noise=0, random_state=0)
    >>> kernel = DotProduct() + WhiteKernel()
    >>> gpr = GaussianProcessRegressor(kernel=kernel,
    ...         random_state=0).fit(X, y)
    >>> gpr.score(X, y)
    0.3680...
    >>> gpr.predict(X[:2,:], return_std=True)
    (array([653.0..., 592.1...]), array([316.6..., 316.6...]))

    """
    @_deprecate_positional_args
    def __init__(self, kernel=None, *, alpha=1e-10,
                 optimizer="fmin_l_bfgs_b", n_restarts_optimizer=0,
                 normalize_y=False, copy_X_train=True, random_state=None):
        # Parameters are stored untouched; validation happens in ``fit``.
        self.kernel = kernel
        self.alpha = alpha
        self.optimizer = optimizer
        self.n_restarts_optimizer = n_restarts_optimizer
        self.normalize_y = normalize_y
        self.copy_X_train = copy_X_train
        self.random_state = random_state

    def fit(self, X, y):
        """Fit Gaussian process regression model.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features) or list of object
            Feature vectors or other representations of training data.

        y : array-like of shape (n_samples,) or (n_samples, n_targets)
            Target values

        Returns
        -------
        self : returns an instance of self.

        Raises
        ------
        ValueError
            If ``alpha`` is an array whose length is neither 1 nor the number
            of training samples, if optimizer restarts are requested while
            some kernel bounds are not finite, or if ``optimizer`` is
            neither a supported string nor a callable.
        numpy.linalg.LinAlgError
            If the kernel matrix (plus ``alpha``) is not positive definite.
        """
        if self.kernel is None:  # Use an RBF kernel as default
            self.kernel_ = C(1.0, constant_value_bounds="fixed") \
                * RBF(1.0, length_scale_bounds="fixed")
        else:
            self.kernel_ = clone(self.kernel)

        self._rng = check_random_state(self.random_state)

        if self.kernel_.requires_vector_input:
            X, y = self._validate_data(X, y, multi_output=True, y_numeric=True,
                                       ensure_2d=True, dtype="numeric")
        else:
            X, y = self._validate_data(X, y, multi_output=True, y_numeric=True,
                                       ensure_2d=False, dtype=None)

        # Normalize target value
        if self.normalize_y:
            self._y_train_mean = np.mean(y, axis=0)
            y_std = np.std(y, axis=0)
            # A constant target has zero standard deviation; dividing by it
            # would fill y with NaN/inf. Use a unit scale for such targets so
            # they are centered only.
            self._y_train_std = np.where(y_std == 0, 1.0, y_std)

            # Remove mean and make unit variance
            y = (y - self._y_train_mean) / self._y_train_std
        else:
            self._y_train_mean = np.zeros(1)
            self._y_train_std = 1

        if np.iterable(self.alpha):
            # np.shape also works for plain sequences such as lists, and the
            # constructor parameter itself is left unmodified: a one-element
            # alpha simply broadcasts when added to the kernel diagonal.
            n_alpha = np.shape(self.alpha)[0]
            if n_alpha != y.shape[0] and n_alpha != 1:
                raise ValueError("alpha must be a scalar or an array"
                                 " with same number of entries as y.(%d != %d)"
                                 % (n_alpha, y.shape[0]))

        self.X_train_ = np.copy(X) if self.copy_X_train else X
        self.y_train_ = np.copy(y) if self.copy_X_train else y

        if self.optimizer is not None and self.kernel_.n_dims > 0:
            # Choose hyperparameters based on maximizing the log-marginal
            # likelihood (potentially starting from several initial values)
            def obj_func(theta, eval_gradient=True):
                if eval_gradient:
                    lml, grad = self.log_marginal_likelihood(
                        theta, eval_gradient=True, clone_kernel=False)
                    return -lml, -grad
                else:
                    return -self.log_marginal_likelihood(theta,
                                                         clone_kernel=False)

            # First optimize starting from theta specified in kernel
            optima = [(self._constrained_optimization(obj_func,
                                                      self.kernel_.theta,
                                                      self.kernel_.bounds))]

            # Additional runs are performed from log-uniform chosen initial
            # theta
            if self.n_restarts_optimizer > 0:
                if not np.isfinite(self.kernel_.bounds).all():
                    raise ValueError(
                        "Multiple optimizer restarts (n_restarts_optimizer>0) "
                        "requires that all bounds are finite.")
                bounds = self.kernel_.bounds
                for _ in range(self.n_restarts_optimizer):
                    theta_initial = \
                        self._rng.uniform(bounds[:, 0], bounds[:, 1])
                    optima.append(
                        self._constrained_optimization(obj_func, theta_initial,
                                                       bounds))
            # Select result from run with minimal (negative) log-marginal
            # likelihood
            lml_values = list(map(itemgetter(1), optima))
            self.kernel_.theta = optima[np.argmin(lml_values)][0]
            self.log_marginal_likelihood_value_ = -np.min(lml_values)
        else:
            self.log_marginal_likelihood_value_ = \
                self.log_marginal_likelihood(self.kernel_.theta,
                                             clone_kernel=False)

        # Precompute quantities required for predictions which are independent
        # of actual query points
        K = self.kernel_(self.X_train_)
        K[np.diag_indices_from(K)] += self.alpha
        try:
            self.L_ = cholesky(K, lower=True)  # Line 2
            # self.L_ changed, self._K_inv needs to be recomputed
            self._K_inv = None
        except np.linalg.LinAlgError as exc:
            exc.args = ("The kernel, %s, is not returning a "
                        "positive definite matrix. Try gradually "
                        "increasing the 'alpha' parameter of your "
                        "GaussianProcessRegressor estimator."
                        % self.kernel_,) + exc.args
            raise
        self.alpha_ = cho_solve((self.L_, True), self.y_train_)  # Line 3
        return self

    def _rescale_variance(self, values):
        """Undo the target normalisation on variances or covariances.

        ``values`` is multiplied by the squared training-target scale. With a
        single scale the shape of ``values`` is kept. With one scale per
        target a trailing axis of length n_targets is appended, so that each
        target gets its own (co)variance.
        """
        scale = np.atleast_1d(self._y_train_std) ** 2
        if scale.shape[0] == 1:
            return values * scale[0]
        return values[..., np.newaxis] * scale

    def predict(self, X, return_std=False, return_cov=False):
        """Predict using the Gaussian process regression model

        We can also predict based on an unfitted model by using the GP prior.
        In addition to the mean of the predictive distribution, also its
        standard deviation (return_std=True) or covariance (return_cov=True)
        can be returned. Note that at most one of the two can be requested.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features) or list of object
            Query points where the GP is evaluated.

        return_std : bool, default=False
            If True, the standard-deviation of the predictive distribution at
            the query points is returned along with the mean.

        return_cov : bool, default=False
            If True, the covariance of the joint predictive distribution at
            the query points is returned along with the mean

        Returns
        -------
        y_mean : ndarray of shape (n_samples, [n_output_dims])
            Mean of predictive distribution at query points

        y_std : ndarray of shape (n_samples,) or (n_samples, n_targets), \
                optional
            Standard deviation of predictive distribution at query points.
            The per-target axis is only present when the model was fitted on
            several targets with ``normalize_y=True``.
            Only returned when `return_std` is True.

        y_cov : ndarray of shape (n_samples, n_samples) or \
                (n_samples, n_samples, n_targets), optional
            Covariance of joint predictive distribution at query points.
            The per-target axis is only present when the model was fitted on
            several targets with ``normalize_y=True``.
            Only returned when `return_cov` is True.

        Raises
        ------
        RuntimeError
            If both ``return_std`` and ``return_cov`` are True.
        """
        if return_std and return_cov:
            raise RuntimeError(
                "Not returning standard deviation of predictions when "
                "returning full covariance.")

        if self.kernel is None or self.kernel.requires_vector_input:
            X = check_array(X, ensure_2d=True, dtype="numeric")
        else:
            X = check_array(X, ensure_2d=False, dtype=None)

        if not hasattr(self, "X_train_"):  # Unfitted;predict based on GP prior
            if self.kernel is None:
                kernel = (C(1.0, constant_value_bounds="fixed") *
                          RBF(1.0, length_scale_bounds="fixed"))
            else:
                kernel = self.kernel
            y_mean = np.zeros(X.shape[0])
            if return_cov:
                y_cov = kernel(X)
                return y_mean, y_cov
            elif return_std:
                y_var = kernel.diag(X)
                return y_mean, np.sqrt(y_var)
            else:
                return y_mean
        else:  # Predict based on GP posterior
            K_trans = self.kernel_(X, self.X_train_)
            y_mean = K_trans.dot(self.alpha_)  # Line 4 (y_mean = f_star)

            # undo normalisation
            y_mean = self._y_train_std * y_mean + self._y_train_mean

            if return_cov:
                v = cho_solve((self.L_, True), K_trans.T)  # Line 5
                y_cov = self.kernel_(X) - K_trans.dot(v)  # Line 6

                # undo normalisation (one covariance per target if the
                # targets were scaled individually)
                y_cov = self._rescale_variance(y_cov)

                return y_mean, y_cov
            elif return_std:
                # cache result of K_inv computation
                if self._K_inv is None:
                    # compute inverse K_inv of K based on its Cholesky
                    # decomposition L and its inverse L_inv
                    L_inv = solve_triangular(self.L_.T,
                                             np.eye(self.L_.shape[0]))
                    self._K_inv = L_inv.dot(L_inv.T)

                # Compute variance of predictive distribution. Work on a
                # copy so the in-place updates below never write into an
                # array owned by the kernel.
                y_var = self.kernel_.diag(X).copy()
                y_var -= np.einsum("ij,ij->i",
                                   np.dot(K_trans, self._K_inv), K_trans)

                # Check if any of the variances is negative because of
                # numerical issues. If yes: set the variance to 0.
                y_var_negative = y_var < 0
                if np.any(y_var_negative):
                    warnings.warn("Predicted variances smaller than 0. "
                                  "Setting those variances to 0.")
                    y_var[y_var_negative] = 0.0

                # undo normalisation (one variance per target if the
                # targets were scaled individually)
                y_var = self._rescale_variance(y_var)

                return y_mean, np.sqrt(y_var)
            else:
                return y_mean

    def sample_y(self, X, n_samples=1, random_state=0):
        """Draw samples from Gaussian process and evaluate at X.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features) or list of object
            Query points where the GP is evaluated.

        n_samples : int, default=1
            The number of samples drawn from the Gaussian process

        random_state : int, RandomState, default=0
            Determines random number generation to randomly draw samples.
            Pass an int for reproducible results across multiple function
            calls.
            See :term:`Glossary <random_state>`.

        Returns
        -------
        y_samples : ndarray of shape (n_samples_X, [n_output_dims], n_samples)
            Values of n_samples samples drawn from Gaussian process and
            evaluated at query points.
        """
        rng = check_random_state(random_state)

        y_mean, y_cov = self.predict(X, return_cov=True)
        if y_mean.ndim == 1:
            y_samples = rng.multivariate_normal(y_mean, y_cov, n_samples).T
        else:
            # predict returns one covariance per target (3-D array) when the
            # targets were normalized individually, otherwise a single
            # covariance shared by all targets.
            y_samples = [
                rng.multivariate_normal(
                    y_mean[:, target],
                    y_cov[..., target] if y_cov.ndim == 3 else y_cov,
                    n_samples).T[:, np.newaxis]
                for target in range(y_mean.shape[1])]
            y_samples = np.hstack(y_samples)
        return y_samples

    def log_marginal_likelihood(self, theta=None, eval_gradient=False,
                                clone_kernel=True):
        """Returns log-marginal likelihood of theta for training data.

        Parameters
        ----------
        theta : array-like of shape (n_kernel_params,) default=None
            Kernel hyperparameters for which the log-marginal likelihood is
            evaluated. If None, the precomputed log_marginal_likelihood
            of ``self.kernel_.theta`` is returned.

        eval_gradient : bool, default=False
            If True, the gradient of the log-marginal likelihood with respect
            to the kernel hyperparameters at position theta is returned
            additionally. If True, theta must not be None.

        clone_kernel : bool, default=True
            If True, the kernel attribute is copied. If False, the kernel
            attribute is modified, but may result in a performance improvement.

        Returns
        -------
        log_likelihood : float
            Log-marginal likelihood of theta for training data. ``-np.inf``
            if the kernel matrix is not positive definite for ``theta``.

        log_likelihood_gradient : ndarray of shape (n_kernel_params,), optional
            Gradient of the log-marginal likelihood with respect to the kernel
            hyperparameters at position theta.
            Only returned when eval_gradient is True.

        Raises
        ------
        ValueError
            If ``eval_gradient`` is True while ``theta`` is None.
        """
        if theta is None:
            if eval_gradient:
                raise ValueError(
                    "Gradient can only be evaluated for theta!=None")
            return self.log_marginal_likelihood_value_

        if clone_kernel:
            kernel = self.kernel_.clone_with_theta(theta)
        else:
            kernel = self.kernel_
            kernel.theta = theta

        if eval_gradient:
            K, K_gradient = kernel(self.X_train_, eval_gradient=True)
        else:
            K = kernel(self.X_train_)

        K[np.diag_indices_from(K)] += self.alpha
        try:
            L = cholesky(K, lower=True)  # Line 2
        except np.linalg.LinAlgError:
            # Not positive definite: report the worst possible likelihood
            # (with a zero gradient) instead of failing the optimization.
            return (-np.inf, np.zeros_like(theta)) \
                if eval_gradient else -np.inf

        # Support multi-dimensional output of self.y_train_
        y_train = self.y_train_
        if y_train.ndim == 1:
            y_train = y_train[:, np.newaxis]

        alpha = cho_solve((L, True), y_train)  # Line 3

        # Compute log-likelihood (compare line 7)
        log_likelihood_dims = -0.5 * np.einsum("ik,ik->k", y_train, alpha)
        log_likelihood_dims -= np.log(np.diag(L)).sum()
        log_likelihood_dims -= K.shape[0] / 2 * np.log(2 * np.pi)
        log_likelihood = log_likelihood_dims.sum(-1)  # sum over dimensions

        if not eval_gradient:
            return log_likelihood

        # compare Equation 5.9 from GPML
        tmp = np.einsum("ik,jk->ijk", alpha, alpha)  # k: output-dimension
        tmp -= cho_solve((L, True), np.eye(K.shape[0]))[:, :, np.newaxis]
        # Compute "0.5 * trace(tmp.dot(K_gradient))" without
        # constructing the full matrix tmp.dot(K_gradient) since only
        # its diagonal is required
        log_likelihood_gradient_dims = \
            0.5 * np.einsum("ijl,jik->kl", tmp, K_gradient)
        log_likelihood_gradient = log_likelihood_gradient_dims.sum(-1)

        return log_likelihood, log_likelihood_gradient

    def _constrained_optimization(self, obj_func, initial_theta, bounds):
        """Minimize ``obj_func`` from ``initial_theta`` within ``bounds``.

        Uses scipy's L-BFGS-B when ``self.optimizer`` is "fmin_l_bfgs_b",
        otherwise calls the user supplied callable.

        Returns
        -------
        theta_opt, func_min
            The best hyperparameters found and the objective value there.

        Raises
        ------
        ValueError
            If ``self.optimizer`` is neither "fmin_l_bfgs_b" nor a callable.
        """
        if self.optimizer == "fmin_l_bfgs_b":
            # jac=True: obj_func returns the value and its gradient together
            opt_res = scipy.optimize.minimize(
                obj_func, initial_theta, method="L-BFGS-B", jac=True,
                bounds=bounds)
            _check_optimize_result("lbfgs", opt_res)
            theta_opt, func_min = opt_res.x, opt_res.fun
        elif callable(self.optimizer):
            theta_opt, func_min = \
                self.optimizer(obj_func, initial_theta, bounds=bounds)
        else:
            raise ValueError("Unknown optimizer %s." % self.optimizer)

        return theta_opt, func_min

    def _more_tags(self):
        """Estimator tags: prediction is allowed without a prior fit."""
        return {'requires_fit': False}