... | ... |
@@ -7,6 +7,12 @@ from . import runner |
7 | 7 |
|
8 | 8 |
from .learner import (Learner1D, Learner2D, AverageLearner, |
9 | 9 |
BalancingLearner, DataSaver, IntegratorLearner) |
10 |
+try: |
|
11 |
+ # Only available if 'scikit-optimize' is installed |
|
12 |
+ from .learner import SKOptLearner |
|
13 |
+except ImportError: |
|
14 |
+ pass |
|
15 |
+ |
|
10 | 16 |
from .runner import Runner, BlockingRunner |
11 | 17 |
from . import version |
12 | 18 |
|
... | ... |
@@ -6,3 +6,9 @@ from .learner1D import Learner1D |
6 | 6 |
from .learner2D import Learner2D |
7 | 7 |
from .integrator_learner import IntegratorLearner |
8 | 8 |
from .data_saver import DataSaver |
9 |
+ |
|
10 |
+try: |
|
11 |
+ # Only available if 'scikit-optimize' is installed |
|
12 |
+ from .skopt_learner import SKOptLearner |
|
13 |
+except ImportError: |
|
14 |
+ pass |
9 | 15 |
new file mode 100644 |
... | ... |
@@ -0,0 +1,92 @@ |
1 |
+# -*- coding: utf-8 -*- |
|
2 |
+from copy import deepcopy |
|
3 |
+import heapq |
|
4 |
+import itertools |
|
5 |
+import math |
|
6 |
+ |
|
7 |
+import numpy as np |
|
8 |
+import sortedcontainers |
|
9 |
+ |
|
10 |
+from ..notebook_integration import ensure_holoviews |
|
11 |
+from .base_learner import BaseLearner |
|
12 |
+ |
|
13 |
+from skopt import Optimizer |
|
14 |
+ |
|
15 |
+ |
|
16 |
+class SKOptLearner(Optimizer, BaseLearner): |
|
17 |
+ """Learn a function minimum using 'skopt.Optimizer'. |
|
18 |
+ |
|
19 |
+ This is an 'Optimizer' from 'scikit-optimize', |
|
20 |
+ with the necessary methods added to make it conform |
|
21 |
+ to the 'adaptive' learner interface. |
|
22 |
+ |
|
23 |
+ Parameters |
|
24 |
+ ---------- |
|
25 |
+ function : callable |
|
26 |
+ The function to learn. |
|
27 |
+ **kwargs : |
|
28 |
+ Arguments to pass to 'skopt.Optimizer'. |
|
29 |
+ """ |
|
30 |
+ |
|
31 |
+ def __init__(self, function, **kwargs): |
|
32 |
+ self.function = function |
|
33 |
+ super().__init__(**kwargs) |
|
34 |
+ |
|
35 |
+ def add_point(self, x, y): |
|
36 |
+ if y is not None: |
|
37 |
+ # 'skopt.Optimizer' takes care of points we |
|
38 |
+ # have not got results for. |
|
39 |
+ self.tell([x], y) |
|
40 |
+ |
|
41 |
+ def remove_unfinished(self): |
|
42 |
+ pass |
|
43 |
+ |
|
44 |
+ def loss(self, real=True): |
|
45 |
+ if not self.models: |
|
46 |
+ return np.inf |
|
47 |
+ else: |
|
48 |
+ model = self.models[-1] |
|
49 |
+ # Return the in-sample error (i.e. test the model |
|
50 |
+ # with the training data). This is not the best |
|
51 |
+ # estimator of loss, but it is the cheapest. |
|
52 |
+ return 1 / model.score(self.Xi, self.yi) |
|
53 |
+ |
|
54 |
+ def choose_points(self, n, add_data=True): |
|
55 |
+ points = self.ask(n) |
|
56 |
+ if self.space.n_dims > 1: |
|
57 |
+ return points, [np.inf] * n |
|
58 |
+ else: |
|
59 |
+ return [p[0] for p in points], [np.inf] * n |
|
60 |
+ |
|
61 |
+ @property |
|
62 |
+ def npoints(self): |
|
63 |
+ return len(self.Xi) |
|
64 |
+ |
|
65 |
+ def plot(self): |
|
66 |
+ hv = ensure_holoviews() |
|
67 |
+ if self.space.n_dims > 1: |
|
68 |
+ raise ValueError('Can only plot 1D functions') |
|
69 |
+ bounds = self.space.bounds[0] |
|
70 |
+ if not self.Xi: |
|
71 |
+ p = hv.Scatter([]) * hv.Area([]) |
|
72 |
+ else: |
|
73 |
+ scatter = hv.Scatter(([p[0] for p in self.Xi], self.yi)) |
|
74 |
+ if self.models: |
|
75 |
+ # Plot 95% confidence interval as colored area around points |
|
76 |
+ model = self.models[-1] |
|
77 |
+ xs = np.linspace(*bounds, 201) |
|
78 |
+ y_pred, sigma = model.predict(np.atleast_2d(xs).transpose(), |
|
79 |
+ return_std=True) |
|
80 |
+ area = hv.Area( |
|
81 |
+ (xs, y_pred - 1.96 * sigma, y_pred + 1.96 * sigma), |
|
82 |
+ vdims=['y', 'y2'], |
|
83 |
+ ).opts(style=dict(alpha=0.5, line_alpha=0)) |
|
84 |
+ else: |
|
85 |
+ area = hv.Area([]) |
|
86 |
+ p = scatter * area |
|
87 |
+ |
|
88 |
+ # Plot with 5% empty margins such that the boundary points are visible |
|
89 |
+ margin = 0.05 * (bounds[1] - bounds[0]) |
|
90 |
+ plot_bounds = (bounds[0] - margin, bounds[1] + margin) |
|
91 |
+ |
|
92 |
+ return p.redim(x=dict(range=plot_bounds)) |